Skip to content

Commit

Permalink
Adding support for pipeline("automatic-speech-recognition"). (#11525)
Browse files Browse the repository at this point in the history
* Adding support for `pipeline("automatic-speech-recognition")`.

- Ugly `"config"` choice for AutoModel. It would be great to have the
possibility to have something like `AutoModelFor` that would implement
the same logic (Load the config, check Architectures and load the first
one)

* Remove `model_id` was not needed in the end.

* Rebased !

* Remove old code.

* Rename `nlp`.
  • Loading branch information
Narsil authored Jul 7, 2021
1 parent 7d321b7 commit ebc69af
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 10 deletions.
8 changes: 8 additions & 0 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@
"ner": "token-classification",
}
SUPPORTED_TASKS = {
"automatic-speech-recognition": {
"impl": AutomaticSpeechRecognitionPipeline,
"tf": (),
# Only load from `config.architectures`, AutoModelForCTC and AutoModelForConditionalGeneration
# do not exist yet.
"pt": () if is_torch_available() else (),
"default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
},
"feature-extraction": {
"impl": FeatureExtractionPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (),
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def infer_framework_load_model(
classes = []
for architecture in config.architectures:
transformers_module = importlib.import_module("transformers")
if look_tf:
if look_pt:
_class = getattr(transformers_module, architecture, None)
if _class is not None:
classes.append(_class)
if look_pt:
if look_tf:
_class = getattr(transformers_module, f"TF{architecture}", None)
if _class is not None:
classes.append(_class)
Expand Down
53 changes: 45 additions & 8 deletions tests/test_pipelines_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,57 @@
import unittest

from transformers import AutoFeatureExtractor, AutoTokenizer, Speech2TextForConditionalGeneration, Wav2Vec2ForCTC
from transformers.pipelines import AutomaticSpeechRecognitionPipeline
from transformers.testing_utils import require_datasets, require_torch, require_torchaudio, slow
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.testing_utils import is_pipeline_test, require_datasets, require_torch, require_torchaudio, slow


# We can't use this mixin because it assumes TF support.
# from .test_pipelines_common import CustomInputPipelineCommonMixin


@is_pipeline_test
class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
# pipeline_task = "automatic-speech-recognition"
# small_models = ["facebook/s2t-small-mustc-en-fr-st"] # Models tested without the @slow decorator
# large_models = [
# "facebook/wav2vec2-base-960h",
# "facebook/s2t-small-mustc-en-fr-st",
# ] # Models tested with the @slow decorator
@require_torch
@slow
def test_pt_defaults(self):
pipeline("automatic-speech-recognition", framework="pt")

@require_torch
def test_torch_small(self):
import numpy as np

speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/s2t-small-mustc-en-fr-st",
tokenizer="facebook/s2t-small-mustc-en-fr-st",
framework="pt",
)
waveform = np.zeros((34000,))
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "C'est ce que j'ai fait à ce moment-là."})

@require_datasets
@require_torch
@slow
def test_torch_large(self):
import numpy as np

speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/wav2vec2-base-960h",
tokenizer="facebook/wav2vec2-base-960h",
framework="pt",
)
waveform = np.zeros((34000,))
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": ""})

from datasets import load_dataset

ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
filename = ds[0]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

@slow
@require_torch
Expand Down

0 comments on commit ebc69af

Please sign in to comment.