nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
1diff --git a/test/test_pipelines.py b/test/test_pipelines.py
2index 1903f758f..c40e6d3b1 100644
3--- a/test/test_pipelines.py
4+++ b/test/test_pipelines.py
5@@ -3,20 +3,6 @@ from unittest.mock import MagicMock
6
7 import pytest
8 import transformers
9-from transformers import (
10- AudioClassificationPipeline,
11- AutomaticSpeechRecognitionPipeline,
12- DocumentQuestionAnsweringPipeline,
13- FeatureExtractionPipeline,
14- FillMaskPipeline,
15- ImageClassificationPipeline,
16- ObjectDetectionPipeline,
17- QuestionAnsweringPipeline,
18- TextClassificationPipeline,
19- TextGenerationPipeline,
20- VisualQuestionAnsweringPipeline,
21- ZeroShotClassificationPipeline,
22-)
23
24 import gradio as gr
25 from gradio.pipelines_utils import (
26@@ -24,6 +10,14 @@ from gradio.pipelines_utils import (
27 )
28
29
30+def _get_pipeline_cls(name: str):
31+ """Resolve a pipeline class by name from transformers, returning None if it
32+ was removed in the installed version."""
33+ return getattr(transformers, name, None) or getattr(
34+ transformers.pipelines, name, None
35+ )
36+
37+
38 @pytest.mark.flaky
39 def test_interface_in_blocks():
40 pipe1 = transformers.pipeline(model="deepset/roberta-base-squad2") # type: ignore
41@@ -50,50 +44,66 @@ def test_transformers_load_from_pipeline():
42
43
44 class TestHandleTransformersPipelines(unittest.TestCase):
45+ def _require(self, name: str):
46+ """Return the pipeline class or skip the test if it was removed."""
47+ cls = _get_pipeline_cls(name)
48+ if cls is None:
49+ self.skipTest(
50+ f"{name} not available in transformers {transformers.__version__}"
51+ )
52+ return cls
53+
54 def test_audio_classification_pipeline(self):
55- pipe = MagicMock(spec=AudioClassificationPipeline)
56+ cls = self._require("AudioClassificationPipeline")
57+ pipe = MagicMock(spec=cls)
58 pipeline_info = handle_transformers_pipeline(pipe)
59 assert pipeline_info is not None
60 assert pipeline_info["inputs"].label == "Input"
61 assert pipeline_info["outputs"].label == "Class"
62
63 def test_automatic_speech_recognition_pipeline(self):
64- pipe = MagicMock(spec=AutomaticSpeechRecognitionPipeline)
65+ cls = self._require("AutomaticSpeechRecognitionPipeline")
66+ pipe = MagicMock(spec=cls)
67 pipeline_info = handle_transformers_pipeline(pipe)
68 assert pipeline_info is not None
69 assert pipeline_info["inputs"].label == "Input"
70 assert pipeline_info["outputs"].label == "Output"
71
72 def test_object_detection_pipeline(self):
73- pipe = MagicMock(spec=ObjectDetectionPipeline)
74+ cls = self._require("ObjectDetectionPipeline")
75+ pipe = MagicMock(spec=cls)
76 pipeline_info = handle_transformers_pipeline(pipe)
77 assert pipeline_info is not None
78 assert pipeline_info["inputs"].label == "Input Image"
79 assert pipeline_info["outputs"].label == "Objects Detected"
80
81 def test_feature_extraction_pipeline(self):
82- pipe = MagicMock(spec=FeatureExtractionPipeline)
83+ cls = self._require("FeatureExtractionPipeline")
84+ pipe = MagicMock(spec=cls)
85 pipeline_info = handle_transformers_pipeline(pipe)
86 assert pipeline_info is not None
87 assert pipeline_info["inputs"].label == "Input"
88 assert pipeline_info["outputs"].label == "Output"
89
90 def test_fill_mask_pipeline(self):
91- pipe = MagicMock(spec=FillMaskPipeline)
92+ cls = self._require("FillMaskPipeline")
93+ pipe = MagicMock(spec=cls)
94 pipeline_info = handle_transformers_pipeline(pipe)
95 assert pipeline_info is not None
96 assert pipeline_info["inputs"].label == "Input"
97 assert pipeline_info["outputs"].label == "Classification"
98
99 def test_image_classification_pipeline(self):
100- pipe = MagicMock(spec=ImageClassificationPipeline)
101+ cls = self._require("ImageClassificationPipeline")
102+ pipe = MagicMock(spec=cls)
103 pipeline_info = handle_transformers_pipeline(pipe)
104 assert pipeline_info is not None
105 assert pipeline_info["inputs"].label == "Input Image"
106 assert pipeline_info["outputs"].label == "Classification"
107
108 def test_question_answering_pipeline(self):
109- pipe = MagicMock(spec=QuestionAnsweringPipeline)
110+ cls = self._require("QuestionAnsweringPipeline")
111+ pipe = MagicMock(spec=cls)
112 pipeline_info = handle_transformers_pipeline(pipe)
113 assert pipeline_info is not None
114 assert pipeline_info["inputs"][0].label == "Context"
115@@ -102,21 +112,24 @@ class TestHandleTransformersPipelines(unittest.TestCase):
116 assert pipeline_info["outputs"][1].label == "Score"
117
118 def test_text_classification_pipeline(self):
119- pipe = MagicMock(spec=TextClassificationPipeline)
120+ cls = self._require("TextClassificationPipeline")
121+ pipe = MagicMock(spec=cls)
122 pipeline_info = handle_transformers_pipeline(pipe)
123 assert pipeline_info is not None
124 assert pipeline_info["inputs"].label == "Input"
125 assert pipeline_info["outputs"].label == "Classification"
126
127 def test_text_generation_pipeline(self):
128- pipe = MagicMock(spec=TextGenerationPipeline)
129+ cls = self._require("TextGenerationPipeline")
130+ pipe = MagicMock(spec=cls)
131 pipeline_info = handle_transformers_pipeline(pipe)
132 assert pipeline_info is not None
133 assert pipeline_info["inputs"].label == "Input"
134 assert pipeline_info["outputs"].label == "Output"
135
136 def test_zero_shot_classification_pipeline(self):
137- pipe = MagicMock(spec=ZeroShotClassificationPipeline)
138+ cls = self._require("ZeroShotClassificationPipeline")
139+ pipe = MagicMock(spec=cls)
140 pipeline_info = handle_transformers_pipeline(pipe)
141 assert pipeline_info is not None
142 assert pipeline_info["inputs"][0].label == "Input"
143@@ -127,7 +140,8 @@ class TestHandleTransformersPipelines(unittest.TestCase):
144 assert pipeline_info["outputs"].label == "Classification"
145
146 def test_document_question_answering_pipeline(self):
147- pipe = MagicMock(spec=DocumentQuestionAnsweringPipeline)
148+ cls = self._require("DocumentQuestionAnsweringPipeline")
149+ pipe = MagicMock(spec=cls)
150 pipeline_info = handle_transformers_pipeline(pipe)
151 assert pipeline_info is not None
152 assert pipeline_info["inputs"][0].label == "Input Document"
153@@ -135,7 +149,8 @@ class TestHandleTransformersPipelines(unittest.TestCase):
154 assert pipeline_info["outputs"].label == "Label"
155
156 def test_visual_question_answering_pipeline(self):
157- pipe = MagicMock(spec=VisualQuestionAnsweringPipeline)
158+ cls = self._require("VisualQuestionAnsweringPipeline")
159+ pipe = MagicMock(spec=cls)
160 pipeline_info = handle_transformers_pipeline(pipe)
161 assert pipeline_info is not None
162 assert pipeline_info["inputs"][0].label == "Input Image"