From b116311768e6770f9fe40c5727f83620d4b478c1 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Tue, 17 Sep 2024 17:59:16 +0200 Subject: [PATCH] Fix tests in ASR pipeline --- ..._pipelines_automatic_speech_recognition.py | 74 +++++++++---------- 1 file changed, 35 insertions(+), 39 deletions(-) diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index abb07d831ad003..842933d2b76c94 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -295,8 +295,8 @@ def test_torch_large(self): self.assertEqual(output, {"text": ""}) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") - filename = ds[40]["file"] - output = speech_recognizer(filename) + audio = ds[40]["audio"] + output = speech_recognizer(audio) self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"}) @require_torch @@ -312,8 +312,8 @@ def test_torch_large_with_input_features(self): self.assertEqual(output, {"text": ""}) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") - filename = ds[40]["file"] - output = speech_recognizer(filename) + audio = ds[40]["audio"] + output = speech_recognizer(audio) self.assertEqual(output, {"text": "a man said to the universe sir i exist"}) @slow @@ -542,11 +542,11 @@ def test_torch_whisper(self): framework="pt", ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") - filename = ds[40]["file"] - output = speech_recognizer(filename) + audio = ds[40]["audio"] + output = speech_recognizer(audio) self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."}) - output = speech_recognizer([filename], chunk_length_s=5, batch_size=4) + output = speech_recognizer([ds[40]["audio"]], chunk_length_s=5, batch_size=4) self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}]) @require_torch @@ -1014,8 +1014,8 @@ def test_torch_speech_encoder_decoder(self): ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") - filename = ds[40]["file"] - output = speech_recognizer(filename) + audio = ds[40]["audio"] + output = speech_recognizer(audio) self.assertEqual(output, {"text": 'Ein Mann sagte zum Universum : " Sir, ich existiert! "'}) @slow @@ -1032,13 +1032,11 @@ def test_simple_wav2vec2(self): self.assertEqual(output, {"text": ""}) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") - filename = ds[40]["file"] - output = asr(filename) + audio = ds[40]["audio"] + output = asr(audio) self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"}) - filename = ds[40]["file"] - with open(filename, "rb") as f: - data = f.read() + data = Audio().encode_example(ds[40]["audio"])["bytes"] output = asr(data) self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"}) @@ -1058,13 +1056,11 @@ def test_simple_s2t(self): self.assertEqual(output, {"text": "(Applausi)"}) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") - filename = ds[40]["file"] - output = asr(filename) + audio = ds[40]["audio"] + output = asr(audio) self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."}) - filename = ds[40]["file"] - with open(filename, "rb") as f: - data = f.read() + data = Audio().encode_example(ds[40]["audio"])["bytes"] output = asr(data) self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."}) @@ -1078,13 +1074,13 @@ def test_simple_whisper_asr(self): framework="pt", ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - filename = ds[0]["file"] - output = speech_recognizer(filename) + audio = ds[0]["audio"] + output = speech_recognizer(audio) self.assertEqual( output, {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."}, ) - output = speech_recognizer(filename, return_timestamps=True) + output = speech_recognizer(ds[0]["audio"], return_timestamps=True) self.assertEqual( output, { @@ -1100,7 +1096,7 @@ def test_simple_whisper_asr(self): }, ) speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]] - output = speech_recognizer(filename, return_timestamps="word") + output = speech_recognizer(ds[0]["audio"], return_timestamps="word") # fmt: off self.assertEqual( output, @@ -1135,7 +1131,7 @@ def test_simple_whisper_asr(self): "^Whisper cannot return `char` timestamps, only word level or segment level timestamps. " "Use `return_timestamps='word'` or `return_timestamps=True` respectively.$", ): - _ = speech_recognizer(filename, return_timestamps="char") + _ = speech_recognizer(audio, return_timestamps="char") @slow @require_torch @@ -1147,8 +1143,8 @@ def test_simple_whisper_translation(self): framework="pt", ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") - filename = ds[40]["file"] - output = speech_recognizer(filename) + audio = ds[40]["audio"] + output = speech_recognizer(audio) self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."}) model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large") @@ -1158,7 +1154,7 @@ def test_simple_whisper_translation(self): speech_recognizer_2 = AutomaticSpeechRecognitionPipeline( model=model, tokenizer=tokenizer, feature_extractor=feature_extractor ) - output_2 = speech_recognizer_2(filename) + output_2 = speech_recognizer_2(ds[0]["audio"]) self.assertEqual(output, output_2) # either use generate_kwargs or set the model's generation_config @@ -1170,7 +1166,7 @@ def test_simple_whisper_translation(self): feature_extractor=feature_extractor, generate_kwargs={"task": "transcribe", "language": "<|it|>"}, ) - output_3 = speech_translator(filename) + output_3 = speech_translator(ds[0]["audio"]) self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."}) @slow @@ -1182,10 +1178,10 @@ def test_whisper_language(self): framework="pt", ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - filename = ds[0]["file"] + audio = ds[0]["audio"] # 1. English-only model compatible with no language argument - output = speech_recognizer(filename) + output = speech_recognizer(audio) self.assertEqual( output, {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."}, @@ -1197,7 +1193,7 @@ def test_whisper_language(self): "Cannot specify `task` or `language` for an English-only model. If the model is intended to be multilingual, " "pass `is_multilingual=True` to generate, or update the generation config.", ): - _ = speech_recognizer(filename, generate_kwargs={"language": "en"}) + _ = speech_recognizer(ds[0]["audio"], generate_kwargs={"language": "en"}) # 3. Multilingual model accepts language argument speech_recognizer = pipeline( @@ -1205,7 +1201,7 @@ def test_whisper_language(self): model="openai/whisper-tiny", framework="pt", ) - output = speech_recognizer(filename, generate_kwargs={"language": "en"}) + output = speech_recognizer(ds[0]["audio"], generate_kwargs={"language": "en"}) self.assertEqual( output, {"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."}, @@ -1315,8 +1311,8 @@ def test_xls_r_to_en(self): ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") - filename = ds[40]["file"] - output = speech_recognizer(filename) + audio = ds[40]["audio"] + output = speech_recognizer(audio) self.assertEqual(output, {"text": "A man said to the universe: “Sir, I exist."}) @slow @@ -1331,8 +1327,8 @@ def test_xls_r_from_en(self): ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") - filename = ds[40]["file"] - output = speech_recognizer(filename) + audio = ds[40]["audio"] + output = speech_recognizer(audio) self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."}) @slow @@ -1348,9 +1344,8 @@ def test_speech_to_text_leveraged(self): ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") - filename = ds[40]["file"] - - output = speech_recognizer(filename) + audio = ds[40]["audio"] + output = speech_recognizer(audio) self.assertEqual(output, {"text": "a man said to the universe sir i exist"}) @slow @@ -1561,6 +1556,7 @@ def test_whisper_longform(self): feature_extractor=processor.feature_extractor, max_new_tokens=128, device=torch_device, + return_timestamps=True, # to allow longform generation ) ds = load_dataset("distil-whisper/meanwhile", "default")["test"]