
Fix beam search with batch processing in Whisper decoding #2197

Open · wants to merge 1 commit into main
Conversation

@zuazo zuazo commented Jun 1, 2024

When calling the Whisper model's decode() method with both beam_size and batch_size greater than 1, a dimension mismatch error occurs. This issue was apparently introduced accidentally in PR #1483.

Example to Reproduce the Error

import torch
import whisper

device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model("tiny").to(device)

# load audio and pad/trim it to fit 30 seconds
audio = whisper.load_audio("tests/jfk.flac")
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# Create a small batch with 2 examples
batch_mel = mel.unsqueeze(0).repeat(2, 1, 1)

# decode the audio with beam size > 1
options = whisper.DecodingOptions(beam_size=5)
results = model.decode(batch_mel, options)

# print the recognized text
for result in results:
    print(result.text)

This outputs the following error:

Traceback (most recent call last):
  File "decode_example.py", line 17, in <module>
    results = model.decode(batch_mel, options)
  File "~/.anaconda3/envs/whisper-dev/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "~/src/whisper/whisper/decoding.py", line 824, in decode
    result = DecodingTask(model, options).run(mel)
  File "~/.anaconda3/envs/whisper-dev/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "~/src/whisper/whisper/decoding.py", line 737, in run
    tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
  File "~/src/whisper/whisper/decoding.py", line 687, in _main_loop
    logits = self.inference.logits(tokens, audio_features)
  File "~/src/whisper/whisper/decoding.py", line 163, in logits
    return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
  File "~/.anaconda3/envs/whisper-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/src/whisper/whisper/model.py", line 211, in forward
    x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
  File "~/.anaconda3/envs/whisper-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/src/whisper/whisper/model.py", line 138, in forward
    x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
  File "~/.anaconda3/envs/whisper-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/src/whisper/whisper/model.py", line 90, in forward
    wv, qk = self.qkv_attention(q, k, v, mask)
  File "~/src/whisper/whisper/model.py", line 102, in qkv_attention
    qk = q @ k
RuntimeError: The size of tensor a (10) must match the size of tensor b (2) at non-singleton dimension 0
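The mismatch can be reproduced in isolation with toy tensors (the shapes below are hypothetical stand-ins for the decoder's q and k): with batch_size=2 and beam_size=5, the token hypotheses are expanded to 2 × 5 = 10 rows, but the cached audio features keep only 2 rows, so the batched matmul in qkv_attention cannot broadcast.

```python
import torch

batch_size, beam_size, n_ctx, n_state = 2, 5, 1, 8

q = torch.randn(batch_size * beam_size, n_ctx, n_state)  # 10 expanded hypotheses
k = torch.randn(batch_size, n_state, n_ctx)              # only 2 cached feature rows

try:
    q @ k  # same shape mismatch as `qk = q @ k` in qkv_attention
except RuntimeError as e:
    print(e)  # size of tensor a (10) must match the size of tensor b (2) ...

# duplicating the features per beam restores matching batch dimensions
k_fixed = k.repeat_interleave(beam_size, dim=0)
print((q @ k_fixed).shape)  # torch.Size([10, 1, 1])
```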

Solution

  • Ensure that audio features are correctly duplicated across beams for each batch item.
  • Add a regression test for decode() covering this case.
  • Update .github/workflows/test.yml to run the new decode() test on the tiny model.

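A minimal sketch of the duplication the first bullet describes (an illustration under assumed tensor shapes, not the PR's actual diff; `expand_audio_features` is a hypothetical helper name):

```python
import torch


def expand_audio_features(audio_features: torch.Tensor, beam_size: int) -> torch.Tensor:
    """Duplicate each batch item's audio features across its beams.

    (n_batch, n_audio_ctx, n_state) -> (n_batch * beam_size, n_audio_ctx, n_state)
    """
    return audio_features.repeat_interleave(beam_size, dim=0)


# toy batch of 2 feature rows
feats = torch.arange(4.0).reshape(2, 1, 2)
expanded = expand_audio_features(feats, beam_size=5)
print(expanded.shape)  # torch.Size([10, 1, 2])
```

`repeat_interleave` keeps each batch item's beams contiguous (rows 0–4 belong to item 0, rows 5–9 to item 1), which matches how beam search groups the expanded token hypotheses.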

@nevisende nevisende left a comment


The problem is solved.
