From 3000ad082d2abc92aeafd9126212289b33a46c76 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Fri, 19 Jan 2024 11:25:01 +0000 Subject: [PATCH] Add w2v2bert to pipeline (#28585) * generalize asr pipeline to fbank models * change w2v2 pipeline output * Update test_pipelines_automatic_speech_recognition.py --- .../pipelines/automatic_speech_recognition.py | 7 +++++-- ...st_pipelines_automatic_speech_recognition.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 1bc68a8209aab7..2c8bf5e2ad9084 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -517,8 +517,11 @@ def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None): out["stride"] = stride else: - input_values = model_inputs.pop("input_values") - outputs = self.model(input_values=input_values, attention_mask=attention_mask) + inputs = { + self.model.main_input_name: model_inputs.pop(self.model.main_input_name), + "attention_mask": attention_mask, + } + outputs = self.model(**inputs) logits = outputs.logits if self.type == "ctc_with_lm": diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 3da55ab9da107f..1aae29e5d45bec 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -298,6 +298,23 @@ def test_torch_large(self): output = speech_recognizer(filename) self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"}) + @require_torch + @slow + def test_torch_large_with_input_features(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="hf-audio/wav2vec2-bert-CV16-en", + framework="pt", + ) + waveform = np.tile(np.arange(1000, dtype=np.float32), 34) + output = speech_recognizer(waveform) + self.assertEqual(output, {"text": ""}) + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") + filename = ds[40]["file"] + output = speech_recognizer(filename) + self.assertEqual(output, {"text": "a man said to the universe sir i exist"}) + @slow @require_torch @slow