Incorrect Whisper long-form decoding timestamps #31942

Robinysh opened this issue Jul 12, 2024 · 9 comments · Fixed by #32131

@Robinysh

System Info

  • transformers version: 4.42.4
  • Platform: Linux-6.8.9-arch1-2-x86_64-with-glibc2.39
  • Python version: 3.11.9
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA GeForce RTX 3090

Who can help?

@Narsil @sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import numpy as np
import json
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset

device = "cuda"
torch_dtype = torch.bfloat16
# model_id = "openai/whisper-large-v3"
model_id = "distil-whisper/distil-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=False, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=None,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
sample = dataset[0]["audio"]
sample = np.concatenate([sample["array"]] * 10)

results = pipe(
    sample,
    return_timestamps=True,
    generate_kwargs={
        "language": "english",
    },
)
print(json.dumps(results, indent=4))

Output

{
    "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome His Gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel.",
    "chunks": [
        {
            "timestamp": [
                0.0,
                6.5
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                6.5,
                12.5
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                12.5,
                18.24
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome His Gospel."
        },
        {
            "timestamp": [
                18.24,
                24.0
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                24.0,
                29.84
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                0.0,
                4.7
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                5.76,
                10.54
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                11.6,
                16.4
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                17.46,
                22.26
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                23.12,
                28.12
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        }
    ]
}

Expected behavior

Currently the timestamps reset to zero after every 30 seconds of audio. I expect them to increase monotonically.
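
For illustration, here is a rough post-processing sketch of what I would expect (the make_monotonic helper is hypothetical and assumes every decoding window restarts at 0.0 with chunks emitted in order):

# Hypothetical post-processing: re-anchor chunk timestamps that reset to 0.0
# after each ~30 s decoding window so that they increase monotonically.
def make_monotonic(chunks):
    offset = 0.0
    prev_end = 0.0
    fixed = []
    for chunk in chunks:
        start, end = chunk["timestamp"]
        if start + offset < prev_end:  # a reset back towards 0.0 was detected
            offset = prev_end
        start, end = start + offset, end + offset
        prev_end = end
        fixed.append({**chunk, "timestamp": (start, end)})
    return fixed

results["chunks"] = make_monotonic(results["chunks"])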

@Robinysh
Author

Additionally, the bug does not happen when I add chunk_length_s=30 to pipe(), i.e.

results = pipe(
    sample,
    chunk_length_s=30,
    return_timestamps=True,
    generate_kwargs={
        "language": "english",
    },
)

However, this workaround is not applicable in my case because I would also like to supply the compression_ratio_threshold argument in generate_kwargs, and that is not supported with short-form transcription.
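
For reference, the call I would like to make looks roughly like this (the compression_ratio_threshold value is only an example):

results = pipe(
    sample,
    chunk_length_s=30,
    return_timestamps=True,
    generate_kwargs={
        "language": "english",
        # example threshold; currently not accepted together with short-form/chunked transcription
        "compression_ratio_threshold": 1.35,
    },
)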

@amyeroberts
Collaborator

cc @kamilakesbi

@kamilakesbi
Contributor

Thanks for opening this issue @Robinysh!

This is indeed a problem; I'll open a PR to solve it!

Note that we're working on unifying short-form and long-form generation in PR #30984. Once it's merged, you should be able to use compression_ratio_threshold with short-form transcription :)

@kamilakesbi
Contributor

kamilakesbi commented Jul 16, 2024

Hi @Robinysh,

You can use this workaround until we properly integrate the fix into Transformers:

import numpy as np
import json
from transformers import AutoProcessor, WhisperForConditionalGeneration
import torch
from datasets import load_dataset

device = "cuda"
torch_dtype = torch.float16

processor = AutoProcessor.from_pretrained("distil-whisper/distil-large-v3")
model = WhisperForConditionalGeneration.from_pretrained(
    "distil-whisper/distil-large-v3", torch_dtype=torch_dtype
)
model = model.to(device)

dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
sample = dataset[0]["audio"]
sample = np.concatenate([sample["array"]] * 10)

inputs = processor(sample, return_tensors="pt", truncation=False, sampling_rate=16_000)
inputs = inputs.to(device, torch_dtype)

# return_segments=True makes generate also return segment-level start/end timestamps
output = model.generate(**inputs, return_timestamps=True, return_segments=True)

result = processor.batch_decode(output["sequences"], skip_special_tokens=True, output_offsets=True)

# Overwrite the (incorrect) decoded offsets with the segment-level timestamps
for i in range(len(result[0]["offsets"])):
    result[0]["offsets"][i]["timestamp"] = (
        output["segments"][0][i]["start"].item(),
        output["segments"][0][i]["end"].item(),
    )

print(json.dumps(result, indent=4))

Explanation:

When performing long-form generation with Whisper, the correct utterance-level timestamps are returned by generate when we specify return_segments=True and return_timestamps=True.

The problem arises at the decoding level: batch_decode currently doesn't support the output of long-form generation (the segments). When output_offsets is specified, it returns the incorrect timestamps you obtained previously.

One simple solution is to replace the returned timestamps with the ones stored in output["segments"], which I do with these lines:

for i in range(len(result[0]["offsets"])):
    result[0]["offsets"][i]["timestamp"] = (
        output["segments"][0][i]["start"].item(),
        output["segments"][0][i]["end"].item(),
    )

cc @sanchit-gandhi @ylacombe (we should integrate this properly in batch_decode and also handle it in the automatic speech recognition pipeline; I'll open a PR for that :) )

@ialvata

ialvata commented Jan 14, 2025

Any estimated time for a solution to this issue?

@Rocketknight1
Member

cc @eustlb - I'm not sure if this is fixed already?

@eustlb
Contributor

eustlb commented Jan 15, 2025

Related to #34210 and not fixed yet for the pipeline.
Tackling this issue this week! 🤗

In the meantime, please run:

import numpy as np
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset, Audio

device = "cuda"
torch_dtype = torch.bfloat16
model_id = "distil-whisper/distil-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=False, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
sample = dataset[0]["audio"]
sample = np.concatenate([sample["array"]] * 10)

input_features = processor(
    sample, return_tensors="pt", truncation=False, sampling_rate=16000
).input_features
input_features = input_features.to(device, torch_dtype)

generated_ids = model.generate(input_features, return_timestamps=True, return_segments=True)

transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
for el in transcript[0]["offsets"]:
    print(el)
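
If the printed offsets still reset every ~30 seconds on your version, you can overwrite them with the segment-level timestamps returned by generate, as in the earlier workaround above:

# Optional: replace the decoded offsets with the segment-level timestamps
# from generate (mirrors the earlier workaround in this thread).
segments = generated_ids["segments"][0]
for offset, segment in zip(transcript[0]["offsets"], segments):
    offset["timestamp"] = (segment["start"].item(), segment["end"].item())
    print(offset)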

@eustlb
Contributor

eustlb commented Jan 17, 2025

Fixed in #35750, which will be merged ASAP! Thanks a lot for raising this issue, and thanks a lot for your patience 🤗

@SomaRe

SomaRe commented Jan 22, 2025

@eustlb Looks like it's not yet merged! I came across the issue today and will probably try the suggested alternative solutions for now.
