diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index dffc17c61..48dca5248 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -53,4 +53,4 @@ jobs:
       - uses: actions/checkout@v3
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
       - run: pip install .["dev"]
-      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
+      - run: pytest --durations=0 -vv -k 'not (test_transcribe or test_decode) or test_transcribe[tiny] or test_transcribe[tiny.en] or test_decode[tiny] or test_decode[tiny.en]' -m 'not requires_cuda'
diff --git a/tests/test_decode.py b/tests/test_decode.py
new file mode 100644
index 000000000..b0c0a6ea6
--- /dev/null
+++ b/tests/test_decode.py
@@ -0,0 +1,52 @@
+import os
+
+import pytest
+import torch
+
+import whisper
+
+
+@pytest.mark.parametrize("model_name", whisper.available_models())
+def test_decode(model_name: str):
+    # Regression test: batch_size and beam_size should work together
+    beam_size = 2
+    batch_size = 2
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = whisper.load_model(model_name).to(device)
+    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
+
+    language = "en" if model_name.endswith(".en") else None
+
+    options = whisper.DecodingOptions(language=language, beam_size=beam_size)
+
+    audio = whisper.load_audio(audio_path)
+    audio = whisper.pad_or_trim(audio)
+    mel = whisper.log_mel_spectrogram(audio).to(device)
+
+    # Create a small batch
+    batch_mel = mel.unsqueeze(0).repeat(batch_size, 1, 1)
+
+    results = model.decode(batch_mel, options)
+
+    # Since both examples are the same, results should be identical
+    assert len(results) == batch_size
+    assert results[0].text == results[1].text
+
+    decoded_text = results[0].text.lower()
+    assert "my fellow americans" in decoded_text
+    assert "your country" in decoded_text
+    assert "do for you" in decoded_text
+
+    timing_checked = False
+    if hasattr(results[0], "segments"):
+        for segment in results[0].segments:
+            for timing in segment["words"]:
+                assert timing["start"] < timing["end"]
+                if timing["word"].strip(" ,") == "Americans":
+                    assert timing["start"] <= 1.8
+                    assert timing["end"] >= 1.8
+                    timing_checked = True
+
+    if hasattr(results[0], "segments"):
+        assert timing_checked
diff --git a/whisper/decoding.py b/whisper/decoding.py
index 49485d009..a516918f7 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -731,6 +731,9 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
         ]
 
         # repeat text tensors by the group size, for beam search or best-of-n sampling
+        audio_features = audio_features.repeat_interleave(self.n_group, dim=0).to(
+            audio_features.device
+        )
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
 
         # call the main sampling loop
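
The snippet below is a minimal sketch of the scenario the new test exercises: batched decoding combined with beam search, which is what the `audio_features` repeat in decoding.py addresses. It follows the same calls as tests/test_decode.py; the "tiny" model name and the audio path are illustrative placeholders, not part of the patch.

# Sketch of batched beam-search decoding (model name and audio path are examples).
import torch
import whisper

device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model("tiny").to(device)

audio = whisper.pad_or_trim(whisper.load_audio("tests/jfk.flac"))
mel = whisper.log_mel_spectrogram(audio).to(device)

# Duplicate the same clip to form a batch of two.
batch_mel = mel.unsqueeze(0).repeat(2, 1, 1)

# beam_size > 1 expands the token batch by the group size; the decoding.py change
# repeats audio_features accordingly so batched inputs decode correctly.
options = whisper.DecodingOptions(language="en", beam_size=2)
results = model.decode(batch_mel, options)

assert len(results) == 2 and results[0].text == results[1].text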