diff --git a/test/test_pipelines.py b/test/test_pipelines.py index 1903f758f..c40e6d3b1 100644 --- a/test/test_pipelines.py +++ b/test/test_pipelines.py @@ -3,20 +3,6 @@ from unittest.mock import MagicMock import pytest import transformers -from transformers import ( - AudioClassificationPipeline, - AutomaticSpeechRecognitionPipeline, - DocumentQuestionAnsweringPipeline, - FeatureExtractionPipeline, - FillMaskPipeline, - ImageClassificationPipeline, - ObjectDetectionPipeline, - QuestionAnsweringPipeline, - TextClassificationPipeline, - TextGenerationPipeline, - VisualQuestionAnsweringPipeline, - ZeroShotClassificationPipeline, -) import gradio as gr from gradio.pipelines_utils import ( @@ -24,6 +10,14 @@ from gradio.pipelines_utils import ( ) +def _get_pipeline_cls(name: str): + """Resolve a pipeline class by name from transformers, returning None if it + was removed in the installed version.""" + return getattr(transformers, name, None) or getattr( + transformers.pipelines, name, None + ) + + @pytest.mark.flaky def test_interface_in_blocks(): pipe1 = transformers.pipeline(model="deepset/roberta-base-squad2") # type: ignore @@ -50,50 +44,66 @@ def test_transformers_load_from_pipeline(): class TestHandleTransformersPipelines(unittest.TestCase): + def _require(self, name: str): + """Return the pipeline class or skip the test if it was removed.""" + cls = _get_pipeline_cls(name) + if cls is None: + self.skipTest( + f"{name} not available in transformers {transformers.__version__}" + ) + return cls + def test_audio_classification_pipeline(self): - pipe = MagicMock(spec=AudioClassificationPipeline) + cls = self._require("AudioClassificationPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"].label == "Input" assert pipeline_info["outputs"].label == "Class" def test_automatic_speech_recognition_pipeline(self): - pipe = MagicMock(spec=AutomaticSpeechRecognitionPipeline) + cls = self._require("AutomaticSpeechRecognitionPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"].label == "Input" assert pipeline_info["outputs"].label == "Output" def test_object_detection_pipeline(self): - pipe = MagicMock(spec=ObjectDetectionPipeline) + cls = self._require("ObjectDetectionPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"].label == "Input Image" assert pipeline_info["outputs"].label == "Objects Detected" def test_feature_extraction_pipeline(self): - pipe = MagicMock(spec=FeatureExtractionPipeline) + cls = self._require("FeatureExtractionPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"].label == "Input" assert pipeline_info["outputs"].label == "Output" def test_fill_mask_pipeline(self): - pipe = MagicMock(spec=FillMaskPipeline) + cls = self._require("FillMaskPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"].label == "Input" assert pipeline_info["outputs"].label == "Classification" def test_image_classification_pipeline(self): - pipe = MagicMock(spec=ImageClassificationPipeline) + cls = self._require("ImageClassificationPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"].label == "Input Image" assert pipeline_info["outputs"].label == "Classification" def test_question_answering_pipeline(self): - pipe = MagicMock(spec=QuestionAnsweringPipeline) + cls = self._require("QuestionAnsweringPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"][0].label == "Context" @@ -102,21 +112,24 @@ class TestHandleTransformersPipelines(unittest.TestCase): assert pipeline_info["outputs"][1].label == "Score" def test_text_classification_pipeline(self): - pipe = MagicMock(spec=TextClassificationPipeline) + cls = self._require("TextClassificationPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"].label == "Input" assert pipeline_info["outputs"].label == "Classification" def test_text_generation_pipeline(self): - pipe = MagicMock(spec=TextGenerationPipeline) + cls = self._require("TextGenerationPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"].label == "Input" assert pipeline_info["outputs"].label == "Output" def test_zero_shot_classification_pipeline(self): - pipe = MagicMock(spec=ZeroShotClassificationPipeline) + cls = self._require("ZeroShotClassificationPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"][0].label == "Input" @@ -127,7 +140,8 @@ class TestHandleTransformersPipelines(unittest.TestCase): assert pipeline_info["outputs"].label == "Classification" def test_document_question_answering_pipeline(self): - pipe = MagicMock(spec=DocumentQuestionAnsweringPipeline) + cls = self._require("DocumentQuestionAnsweringPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"][0].label == "Input Document" @@ -135,7 +149,8 @@ class TestHandleTransformersPipelines(unittest.TestCase): assert pipeline_info["outputs"].label == "Label" def test_visual_question_answering_pipeline(self): - pipe = MagicMock(spec=VisualQuestionAnsweringPipeline) + cls = self._require("VisualQuestionAnsweringPipeline") + pipe = MagicMock(spec=cls) pipeline_info = handle_transformers_pipeline(pipe) assert pipeline_info is not None assert pipeline_info["inputs"][0].label == "Input Image"