diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 9d46ad1bb8a698..1d353fba0cfeef 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -95,8 +95,6 @@ AutoModelForTableQuestionAnswering, AutoModelForTokenClassification, ) - from ..models.speech_to_text.modeling_speech_to_text import Speech2TextForConditionalGeneration - from ..models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForCTC if TYPE_CHECKING: from ..modeling_tf_utils import TFPreTrainedModel from ..modeling_utils import PreTrainedModel @@ -112,8 +110,10 @@ SUPPORTED_TASKS = { "automatic-speech-recognition": { "impl": AutomaticSpeechRecognitionPipeline, - "tf": None, - "pt": "config" if is_torch_available() else None, + "tf": (), + # Only load from `config.architectures`, AutoModelForCTC and AutoModelForConditionalGeneration + # do not exist yet. + "pt": (), "default": {"model": {"pt": "facebook/wav2vec2-base-960h"}}, }, "feature-extraction": { diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 5065c56ca29d76..19d9840fc7fad0 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -104,11 +104,11 @@ def infer_framework_load_model( classes = [] for architecture in config.architectures: transformers_module = importlib.import_module("transformers") - if look_tf: + if look_pt: _class = getattr(transformers_module, architecture, None) if _class is not None: classes.append(_class) - if look_pt: + if look_tf: _class = getattr(transformers_module, f"TF{architecture}", None) if _class is not None: classes.append(_class) diff --git a/tests/test_pipelines_automatic_speech_recognition.py b/tests/test_pipelines_automatic_speech_recognition.py index fba845c7e71f4f..eebdc477b0dd18 100644 --- a/tests/test_pipelines_automatic_speech_recognition.py +++ b/tests/test_pipelines_automatic_speech_recognition.py @@ -16,13 +16,14 @@ 
from transformers import AutoFeatureExtractor, AutoTokenizer, Speech2TextForConditionalGeneration, Wav2Vec2ForCTC from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline -from transformers.testing_utils import require_datasets, require_torch, require_torchaudio, slow +from transformers.testing_utils import is_pipeline_test, require_datasets, require_torch, require_torchaudio, slow # We can't use this mixin because it assumes TF support. # from .test_pipelines_common import CustomInputPipelineCommonMixin +@is_pipeline_test class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): @require_torch @slow