Fix beam search with batch processing in Whisper decoding #2197

Open · wants to merge 1 commit into main
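The diff below repeats `audio_features` by the sampling group size so that batched inputs work together with beam search; previously only `tokens` was repeated, so with more than one item in the batch the audio features and tokens ended up with different leading dimensions. A minimal sketch of the call this change is meant to enable, assuming the public `whisper` API on `main` (the model size, file path, and option values are illustrative, not taken from the PR):

```python
import whisper

# Illustrative reproduction: decode a batch of two identical clips with beam search.
model = whisper.load_model("tiny")  # any checkpoint; "tiny" keeps the example fast
audio = whisper.pad_or_trim(whisper.load_audio("tests/jfk.flac"))
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# Batch of two mel spectrograms: shape (2, n_mels, n_frames).
batch_mel = mel.unsqueeze(0).repeat(2, 1, 1)

# Before this change only `tokens` was repeated by the beam/group size (see the
# whisper/decoding.py hunk below), so a batch dimension greater than one did not
# line up with the repeated token tensor; with the fix, one result per item comes back.
options = whisper.DecodingOptions(language="en", beam_size=2, fp16=False)
results = model.decode(batch_mel, options)
print([r.text for r in results])
```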
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -53,4 +53,4 @@ jobs:
       - uses: actions/checkout@v3
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
       - run: pip install .["dev"]
-      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
+      - run: pytest --durations=0 -vv -k 'not (test_transcribe or test_decode) or test_transcribe[tiny] or test_transcribe[tiny.en] or test_decode[tiny] or test_decode[tiny.en]' -m 'not requires_cuda'
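The widened `-k` expression appears intended to keep CI fast: it deselects every `test_transcribe` and `test_decode` parametrization except the `tiny` and `tiny.en` variants, so the new regression test still runs on the CPU-only runners without downloading larger checkpoints.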
52 changes: 52 additions & 0 deletions tests/test_decode.py
@@ -0,0 +1,52 @@
import os

import pytest
import torch

import whisper


@pytest.mark.parametrize("model_name", whisper.available_models())
def test_decode(model_name: str):
    # Regression test: batch_size and beam_size should work together
    beam_size = 2
    batch_size = 2

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_name).to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    language = "en" if model_name.endswith(".en") else None

    options = whisper.DecodingOptions(language=language, beam_size=beam_size)

    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)
    # use the model's expected number of mel bins (large-v3 checkpoints use 128, earlier ones 80)
    mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(device)

    # Create a small batch
    batch_mel = mel.unsqueeze(0).repeat(batch_size, 1, 1)

    results = model.decode(batch_mel, options)

    # Since both examples are the same, results should be identical
    assert len(results) == batch_size
    assert results[0].text == results[1].text

    decoded_text = results[0].text.lower()
    assert "my fellow americans" in decoded_text
    assert "your country" in decoded_text
    assert "do for you" in decoded_text

    timing_checked = False
    if hasattr(results[0], "segments"):
        for segment in results[0].segments:
            for timing in segment["words"]:
                assert timing["start"] < timing["end"]
                if timing["word"].strip(" ,") == "Americans":
                    assert timing["start"] <= 1.8
                    assert timing["end"] >= 1.8
                    timing_checked = True

    if hasattr(results[0], "segments"):
        assert timing_checked
3 changes: 3 additions & 0 deletions whisper/decoding.py
@@ -731,6 +731,9 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
             ]

         # repeat text tensors by the group size, for beam search or best-of-n sampling
+        audio_features = audio_features.repeat_interleave(self.n_group, dim=0).to(
+            audio_features.device
+        )
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

         # call the main sampling loop
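For context on the three added lines, a standalone sketch of the shape bookkeeping (not part of the diff; the tensor sizes are illustrative and `n_group` stands in for `beam_size`/`best_of`):

```python
import torch

n_audio, n_group = 2, 2                             # two audio items, beam size of two
audio_features = torch.randn(n_audio, 1500, 384)    # illustrative (n_audio, n_audio_ctx, n_state)
tokens = torch.zeros(n_audio, 3, dtype=torch.long)  # illustrative initial token sequences

# The existing code repeats only `tokens`, giving it a leading dimension of
# n_audio * n_group = 4 while `audio_features` stays at n_audio = 2.
tokens = tokens.repeat_interleave(n_group, dim=0)
assert tokens.shape[0] == n_audio * n_group

# The added lines apply the same repetition to `audio_features`, so every beam
# hypothesis lines up with its own copy of the encoder output.
audio_features = audio_features.repeat_interleave(n_group, dim=0)
assert audio_features.shape[0] == tokens.shape[0]
```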