Skip to content

Commit

Permalink
Rebased !
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jun 14, 2021
1 parent 8107249 commit f2ff3a7
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
8 changes: 4 additions & 4 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
)
from ..models.speech_to_text.modeling_speech_to_text import Speech2TextForConditionalGeneration
from ..models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForCTC
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
Expand All @@ -112,8 +110,10 @@
SUPPORTED_TASKS = {
"automatic-speech-recognition": {
"impl": AutomaticSpeechRecognitionPipeline,
"tf": None,
"pt": "config" if is_torch_available() else None,
"tf": (),
# Only load from `config.architectures`, AutoModelForCTC and AutoModelForConditionalGeneration
# do not exist yet.
"pt": () if is_torch_available() else (),
"default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
},
"feature-extraction": {
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def infer_framework_load_model(
classes = []
for architecture in config.architectures:
transformers_module = importlib.import_module("transformers")
if look_tf:
if look_pt:
_class = getattr(transformers_module, architecture, None)
if _class is not None:
classes.append(_class)
if look_pt:
if look_tf:
_class = getattr(transformers_module, f"TF{architecture}", None)
if _class is not None:
classes.append(_class)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_pipelines_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

from transformers import AutoFeatureExtractor, AutoTokenizer, Speech2TextForConditionalGeneration, Wav2Vec2ForCTC
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.testing_utils import require_datasets, require_torch, require_torchaudio, slow
from transformers.testing_utils import is_pipeline_test, require_datasets, require_torch, require_torchaudio, slow


# We can't use this mixin because it assumes TF support.
# from .test_pipelines_common import CustomInputPipelineCommonMixin


@is_pipeline_test
class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
@require_torch
@slow
Expand Down

0 comments on commit f2ff3a7

Please sign in to comment.