WhisperTokenizer decode is offsetting timestamps incorrectly #34472

Closed · wallrothm opened this issue Oct 28, 2024 · 2 comments · Fixed by #34537
wallrothm commented Oct 28, 2024

System Info

  • transformers version: 4.47.0.dev0
  • Platform: Linux-5.15.0-1073-azure-x86_64-with-glibc2.35
  • Python version: 3.11.0rc1
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.2
  • Accelerate version: 0.31.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Tensorflow version (GPU?): 2.16.1 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: Tesla T4

Who can help?

@ylacombe
@eustlb
@sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Decoding the output from Whisper with the WhisperTokenizer seemingly offsets the timestamps of consecutive chunks incorrectly, which for long audio causes timestamp accuracy to degrade significantly over time.

I have not found any open bug report on this matter. The issue #31942 and the PR intended to fix it #32131 are related, and hence I've added @sanchit-gandhi to this issue as well.

From my understanding, the above-mentioned PR solves it under the assumption that the predicted timestamps always span the entire previous chunk, so that incrementing the timestamps of consecutive chunks based on cur_max_timestamp is sufficient. However, cur_max_timestamp does not offset the timestamps correctly in general. The example described in #32131 (comment) does generate the correct output, but slightly altering the silence unfortunately leads to incorrect timestamps.
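To make the failure mode concrete, here is a minimal standalone sketch of the offset-accumulation logic as I understand it (this is not the actual transformers code; accumulate_offsets and the chunk structure are made up for illustration): if the last predicted timestamp of a chunk falls short of the true 30 s chunk length, for example because of trailing silence, every subsequent chunk gets offset by too little.

def accumulate_offsets(chunks):
    """Shift each chunk's timestamps by the last timestamp seen in the previous chunk."""
    offset = 0.0
    cur_max_timestamp = 0.0
    result = []
    for chunk in chunks:
        # assumes the previous chunk's timestamps covered its full 30 s window
        offset += cur_max_timestamp
        for start, end, text in chunk:
            result.append((start + offset, end + offset, text))
            cur_max_timestamp = end
    return result

# The first chunk's last timestamp is 15.00 because the remaining ~15 s are silence,
# so the second chunk is offset by 15.00 instead of the true 30.00.
chunks = [
    [(0.00, 6.38, "Mr. Quilter is ..."), (6.38, 11.32, "Nor is ..."), (11.32, 15.00, "He tells us ...")],
    [(0.00, 6.76, "With Christmas ...")],
]
for start, end, text in accumulate_offsets(chunks):
    print(f"{start:.2f} -> {end:.2f} : {text}")
# the second chunk prints 15.00 -> 21.76 rather than the expected 30.00 -> 36.76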

The following snippet should reproduce the issue (simply increasing the inserted silence from 15 s to 16 s):

from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import numpy as np

# load model + processor
processor = AutoProcessor.from_pretrained("openai/whisper-small.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en")

# load dataset
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]["array"]
sampling_rate = dataset[0]["audio"]["sampling_rate"]

# insert 16 s of silence after the first 15 s of speech
sample = np.concatenate(
    [sample[:15 * sampling_rate], np.zeros(16 * sampling_rate), sample[15 * sampling_rate:]]
)

# pre-process
inputs = processor(
    sample,
    sampling_rate=16_000,
    padding="longest",
    truncation=False,
    return_attention_mask=True,
    return_tensors="pt",
)

# inference
output = model.generate(**inputs, return_timestamps=True, return_segments=True)

# pass token ids to processor's decode method
result = processor.batch_decode(output["sequences"], skip_special_tokens=True, output_offsets=True)

# format output offsets for readability
print("\n".join([f"{chunk['timestamp'][0]:.2f} -> {chunk['timestamp'][1]:.2f} : {chunk['text']}" for chunk in result[0]["offsets"]]))

which results in the following output:

0.00 -> 6.38 :  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.
6.38 -> 11.32 :  Nor is Mr. Quilter's manner less interesting than his matter.
11.32 -> 15.00 :  He tells us that at this festive season of the year,
15.00 -> 21.76 :  With Christmas and roast beef looming before us, similes drawn from eating and its results
21.76 -> 24.80 :  occur most readily to the mind.
24.80 -> 30.38 :  He has grave doubts whether Sir Frederick Layton's work is really Greek after all and
30.38 -> 34.00 :  can discover in it but little of rocky Ithaca.
34.00 -> 41.28 :  Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles
41.28 -> 49.12 :  are as national as a jingo poem. Mr. Burkett fosters landscape's smile at one much in
49.12 -> 55.76 :  the same way that Mr. Karker used to flash his teeth. And Mr. John Collier gives his
55.76 -> 62.16 :  sitter a cheerful slap on the back before he says, like a shampoo or in a Turkish bath,
62.16 -> 63.16 :  Next Man

while inspecting output["segments"] gives the following segment timestamps:

0.00 -> 6.38
6.38 -> 11.32
11.32 -> 15.00
30.00 -> 36.76
36.76 -> 39.80
39.80 -> 45.38
45.38 -> 49.00
49.00 -> 56.28
56.28 -> 64.12
64.12 -> 70.76
70.76 -> 77.16
77.16 -> 78.16

which in turn are close to the output that https://github.com/openai/whisper generates and explain how the fourth segment becomes 15.00 -> 21.76 instead of the expected 30.00 -> 36.76.
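For reference, the segment timestamps above can be read out of output["segments"] with something like the following sketch (assuming each segment dict exposes scalar start and end tensors, which is my understanding of what return_segments=True returns):

# output["segments"] holds one list of segment dicts per batch item
for segment in output["segments"][0]:
    print(f"{segment['start'].item():.2f} -> {segment['end'].item():.2f}")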

Expected behavior

I would expect WhisperTokenizer to handle timestamp offsets correctly when decoding, so that timestamps do not become misaligned with their corresponding chunks.

wallrothm added the bug label on Oct 28, 2024
eustlb (Contributor) commented Oct 29, 2024

Hey @wallrothm,

Thanks a lot for raising this issue!

There's indeed a problem with WhisperTokenizer. I confirmed it by running my forked version of the original Whisper implementation (see #34111 for more info) on input features built as you describe above, then saving the generated tokens and passing them through result = processor.batch_decode(output["sequences"], skip_special_tokens=True, output_offsets=True). I see the same issue you described, with 15.00 -> 21.76 instead of the expected 30.00 -> 36.76. Let me open a PR to solve it.
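For anyone who wants to repeat that check, feeding externally generated token ids through the same decode path would look roughly like this (a sketch; whisper_token_ids.pt is a hypothetical file holding the tokens saved from the forked OpenAI implementation):

import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("openai/whisper-small.en")

# hypothetical: token ids saved from the reference OpenAI Whisper run
whisper_token_ids = torch.load("whisper_token_ids.pt")

result = processor.batch_decode(whisper_token_ids, skip_special_tokens=True, output_offsets=True)
for chunk in result[0]["offsets"]:
    print(f"{chunk['timestamp'][0]:.2f} -> {chunk['timestamp'][1]:.2f} : {chunk['text']}")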


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
