Skip to content

Commit

Permalink
Fix tests in ASR pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
ylacombe committed Sep 17, 2024
1 parent ac5a055 commit b116311
Showing 1 changed file with 35 additions and 39 deletions.
74 changes: 35 additions & 39 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ def test_torch_large(self):
self.assertEqual(output, {"text": ""})

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

@require_torch
Expand All @@ -312,8 +312,8 @@ def test_torch_large_with_input_features(self):
self.assertEqual(output, {"text": ""})

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})

@slow
Expand Down Expand Up @@ -542,11 +542,11 @@ def test_torch_whisper(self):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})

output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
output = speech_recognizer([ds[40]["audio"]], chunk_length_s=5, batch_size=4)
self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])

@require_torch
Expand Down Expand Up @@ -1014,8 +1014,8 @@ def test_torch_speech_encoder_decoder(self):
)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": 'Ein Mann sagte zum Universum : " Sir, ich existiert! "'})

@slow
Expand All @@ -1032,13 +1032,11 @@ def test_simple_wav2vec2(self):
self.assertEqual(output, {"text": ""})

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = asr(filename)
audio = ds[40]["audio"]
output = asr(audio)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

filename = ds[40]["file"]
with open(filename, "rb") as f:
data = f.read()
data = Audio().encode_example(ds[40]["audio"])["bytes"]
output = asr(data)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

Expand All @@ -1058,13 +1056,11 @@ def test_simple_s2t(self):
self.assertEqual(output, {"text": "(Applausi)"})

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = asr(filename)
audio = ds[40]["audio"]
output = asr(audio)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})

filename = ds[40]["file"]
with open(filename, "rb") as f:
data = f.read()
data = Audio().encode_example(ds[40]["audio"])["bytes"]
output = asr(data)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})

Expand All @@ -1078,13 +1074,13 @@ def test_simple_whisper_asr(self):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
filename = ds[0]["file"]
output = speech_recognizer(filename)
audio = ds[0]["audio"]
output = speech_recognizer(audio)
self.assertEqual(
output,
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
)
output = speech_recognizer(filename, return_timestamps=True)
output = speech_recognizer(ds[0]["audio"], return_timestamps=True)
self.assertEqual(
output,
{
Expand All @@ -1100,7 +1096,7 @@ def test_simple_whisper_asr(self):
},
)
speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
output = speech_recognizer(filename, return_timestamps="word")
output = speech_recognizer(ds[0]["audio"], return_timestamps="word")
# fmt: off
self.assertEqual(
output,
Expand Down Expand Up @@ -1135,7 +1131,7 @@ def test_simple_whisper_asr(self):
"^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
"Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
):
_ = speech_recognizer(filename, return_timestamps="char")
_ = speech_recognizer(audio, return_timestamps="char")

@slow
@require_torch
Expand All @@ -1147,8 +1143,8 @@ def test_simple_whisper_translation(self):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
Expand All @@ -1158,7 +1154,7 @@ def test_simple_whisper_translation(self):
speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
output_2 = speech_recognizer_2(filename)
output_2 = speech_recognizer_2(ds[0]["audio"])
self.assertEqual(output, output_2)

# either use generate_kwargs or set the model's generation_config
Expand All @@ -1170,7 +1166,7 @@ def test_simple_whisper_translation(self):
feature_extractor=feature_extractor,
generate_kwargs={"task": "transcribe", "language": "<|it|>"},
)
output_3 = speech_translator(filename)
output_3 = speech_translator(ds[0]["audio"])
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})

@slow
Expand All @@ -1182,10 +1178,10 @@ def test_whisper_language(self):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
filename = ds[0]["file"]
audio = ds[0]["audio"]

# 1. English-only model compatible with no language argument
output = speech_recognizer(filename)
output = speech_recognizer(audio)
self.assertEqual(
output,
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
Expand All @@ -1197,15 +1193,15 @@ def test_whisper_language(self):
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be multilingual, "
"pass `is_multilingual=True` to generate, or update the generation config.",
):
_ = speech_recognizer(filename, generate_kwargs={"language": "en"})
_ = speech_recognizer(ds[0]["audio"], generate_kwargs={"language": "en"})

# 3. Multilingual model accepts language argument
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny",
framework="pt",
)
output = speech_recognizer(filename, generate_kwargs={"language": "en"})
output = speech_recognizer(ds[0]["audio"], generate_kwargs={"language": "en"})
self.assertEqual(
output,
{"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
Expand Down Expand Up @@ -1315,8 +1311,8 @@ def test_xls_r_to_en(self):
)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "A man said to the universe: “Sir, I exist."})

@slow
Expand All @@ -1331,8 +1327,8 @@ def test_xls_r_from_en(self):
)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})

@slow
Expand All @@ -1348,9 +1344,8 @@ def test_speech_to_text_leveraged(self):
)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]

output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})

@slow
Expand Down Expand Up @@ -1561,6 +1556,7 @@ def test_whisper_longform(self):
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
device=torch_device,
return_timestamps=True, # to allow longform generation
)

ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
Expand Down

0 comments on commit b116311

Please sign in to comment.