[whisper] alternative fix for long-form timestamps #32131
Conversation
@@ -589,36 +587,33 @@ def _compute_offsets(self, token_ids, time_precision=0.02, longform_timestamps=N
        consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)

        last_slice = np.where(timestamp_tokens)[0][0]
        for i, current_slice in enumerate(consecutive):
            cur_max_timestamp = 0
With no change to the overall processor/tokenizer design, we can fix the original timestamp issue by keeping track of the last timestamp predicted.
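A minimal sketch of that idea (a simplified, self-contained version of the approach, not the PR's actual _compute_offsets changes; the function and variable names here are ours):

def compute_long_form_offsets(segment_timestamps):
    # segment_timestamps: (start, end) pairs decoded from the timestamp
    # tokens, where each 30s chunk restarts its clock at 0.00
    offset = 0.0    # accumulated length of all previous chunks
    last_end = 0.0  # last timestamp predicted within the current chunk
    absolute = []
    for start, end in segment_timestamps:
        # the clock jumping backwards means a new chunk has started:
        # fold the previous chunk's last timestamp into the running offset
        if start < last_end:
            offset += last_end
        absolute.append((offset + start, offset + end))
        last_end = end
    return absolute

# e.g. for the two chunks discussed further down in this thread, where
# chunk 1 ends at <|29.44|> and chunk 2 restarts at <|0.00|>:
print(compute_long_form_offsets([(0.0, 6.56), (6.56, 29.44), (0.0, 4.28)]))
# [(0.0, 6.56), (6.56, 29.44), (29.44, 33.72)]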
@sanchit-gandhi thanks for working on this, this is indeed a better solution than the one in PR #32003! Could you make sure that the slow test still passes? Thank you!
The test does indeed pass. I've updated it to use a different dataset that makes it a bit easier to see the effect of the correct timestamps (a 1-minute sample of LibriSpeech, rather than 10 concatenated samples of the same utterance).
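A rough sketch of that test-input change (the old fixture name, hf-internal-testing/librispeech_asr_dummy, is an assumption based on the usual transformers Whisper test setup; the exact test diff isn't reproduced here):

import numpy as np
from datasets import load_dataset

# before: the same short utterance concatenated 10 times
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio = np.tile(ds[0]["audio"]["array"], 10)

# after: a genuine ~1 minute recording, so each segment's text differs
# and an incorrect long-form timestamp is easy to spot
ds = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
audio = ds[0]["audio"]["array"]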
Hey @sanchit-gandhi, thanks for this PR and sorry for not having caught the issue with the previous PR!
While I understand most of the logic of the PR, I'm still struggling with an edge case that I might have misunderstood.
The way I see it, Whisper processes audio in 30s chunks, but sometimes no speech is detected in the last seconds of a chunk.
For example, in the tiny test you added, the first 29.44 seconds contain the following speech:

<|0.00|> Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.<|6.56|>
<|6.56|> Nor is Mr. Quilter's manner less interesting than his matter.<|11.24|>
<|11.24|> He tells us that at this festive season of the year, with Christmas and roast beef looming<|16.88|>
<|16.88|> before us, similarly drawn from eating and its results occur most readily to the mind.<|23.76|>
<|23.76|> He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and<|29.44|>

The next 30 seconds contain something that begins with:

<|0.00|> can discover in it but little of rocky ithaka.<|4.28|>
<|4.28|> Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite itals<|10.88|>
<|10.88|> are as national as a jingo poem.<|15.28|>

So there's a small bit of audio in between the two chunks (30 - 29.44 = 0.56s) that doesn't seem to be taken into account when computing prev_segments_len.
For this particular example it doesn't matter much, but let's say we have a 45s audio sample with speech in the first 15s and the last 15s, and no speech in between.
In that case, prev_segments_len will correspond to 15s, so the last 15s of speech will be detected as starting at 00:15 instead of 00:30. Is that correct? If so, this is not what we want, right?
Let me know what you think!
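To make the numbers concrete, here is the arithmetic behind both cases (illustrative only, using the figures quoted above):

chunk_len = 30.0

# Case 1: the chunk from the tiny test ends with <|29.44|>
prev_segments_len = 29.44            # accumulated from predicted segment ends
gap = chunk_len - prev_segments_len  # 0.56s of trailing audio not counted

# Case 2: 45s input with speech in [0, 15] and [30, 45], silence in between.
# If prev_segments_len only reflects the 15s of transcribed speech, a segment
# starting at <|0.00|> in the final chunk is reported at 15.0s, not 30.0s:
prev_segments_len = 15.0
reported_start = prev_segments_len + 0.0  # 15.0, although speech resumes at 30.0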
Here's a script for creating such an example and running inference:

from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import numpy as np

# load model + processor
processor = AutoProcessor.from_pretrained("openai/whisper-small.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en")

# load dataset
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]["array"]
sampling_rate = dataset[0]["audio"]["sampling_rate"]

# Contrived 45s audio sample:
# 1. First 15 seconds: standard speech
# 2. Next 15 seconds: silence (zeros)
# 3. Last 15 seconds: speech
sample = np.concatenate(
    [
        sample[: 15 * sampling_rate],
        np.zeros(15 * sampling_rate),
        sample[15 * sampling_rate : 30 * sampling_rate],
    ]
)

# pre-process
inputs = processor(
    sample,
    sampling_rate=16_000,
    padding="longest",
    truncation=False,
    return_attention_mask=True,
    return_tensors="pt",
)

# inference
output = model.generate(**inputs, return_timestamps=True, return_segments=True)

# pass token ids to the processor's decode method
result = processor.batch_decode(output["sequences"], skip_special_tokens=True, output_offsets=True)
print(result)

What we get for this Transformers PR (formatted in VTT style):
What we get for original Whisper:
=> the start/end timings vary slightly between implementations (as a result of numerical precision), but both handle the long period of silence in the same way: they lump it into the preceding segment, giving an end timestamp that encompasses all of the silence.
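For reference, one way to produce the VTT-style rendering mentioned above from the script's result (to_vtt is our own helper, not a transformers API; each offset entry carries "text" and a (start, end) "timestamp", as returned by output_offsets=True):

def to_vtt(offsets):
    # render each offset entry as "MM:SS.mmm --> MM:SS.mmm" plus its text
    def fmt(t):
        minutes, seconds = divmod(t, 60)
        return f"{int(minutes):02d}:{seconds:06.3f}"
    lines = []
    for chunk in offsets:
        start, end = chunk["timestamp"]
        lines.append(f"{fmt(start)} --> {fmt(end)}")
        lines.append(chunk["text"].strip())
        lines.append("")
    return "\n".join(lines)

print(to_vtt(result[0]["offsets"]))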
Thanks for the explanation @sanchit-gandhi, LGTM!
cc @ArthurZucker for review!
Thanks both 🤗
* [whisper] alternative fix for long-form timestamps
* update test
What does this PR do?
Fixes #31942 and supersedes #32003 by implementing new logic in the tokenizer method _compute_offsets. The advantages of this PR over #32003 are:
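(For context, _compute_offsets is the method that backs output_offsets=True when decoding. Below is a toy illustration of that path, assuming the usual Whisper vocab layout in which timestamp token ids start right after the last special token id; the ts helper and the hand-built token ids are ours, for illustration only.)

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")

# timestamp token ids follow the last special token id, in 0.02s steps
timestamp_begin = tokenizer.all_special_ids[-1] + 1

def ts(seconds):
    return timestamp_begin + int(seconds / 0.02)

# toy sequence equivalent to "<|0.00|> Hello world.<|1.00|>"
token_ids = [ts(0.0), *tokenizer(" Hello world.", add_special_tokens=False).input_ids, ts(1.0)]

out = tokenizer.decode(token_ids, skip_special_tokens=True, output_offsets=True)
print(out["offsets"])  # expected: [{'text': ' Hello world.', 'timestamp': (0.0, 1.0)}]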