Skip to content

Commit

Permalink
[whisper] alternative fix for long-form timestamps (huggingface#32131)
Browse files Browse the repository at this point in the history
* [whisper] alternative fix for long-form timestamps

* update test
  • Loading branch information
sanchit-gandhi authored and BernardZach committed Dec 5, 2024
1 parent 8de9659 commit 45cecbb
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 4 deletions.
13 changes: 11 additions & 2 deletions src/transformers/models/whisper/tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,11 +587,20 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)

last_slice = np.where(timestamp_tokens)[0][0]
cur_max_timestamp = 0
prev_segments_len = 0
for current_slice in consecutive:
sliced_tokens = token_ids[last_slice:current_slice]
if len(sliced_tokens) > 1:
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin

if start_timestamp_position < cur_max_timestamp:
# next segment has started
prev_segments_len += cur_max_timestamp

cur_max_timestamp = end_timestamp_position

# strip timestamp tokens from the text output
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
text = self._decode(sliced_tokens)
Expand All @@ -600,8 +609,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
{
"text": text,
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
(start_timestamp_position + prev_segments_len) * time_precision,
(end_timestamp_position + prev_segments_len) * time_precision,
),
}
)
Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,20 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)

last_slice = np.where(timestamp_tokens)[0][0]
cur_max_timestamp = 0
prev_segments_len = 0
for current_slice in consecutive:
sliced_tokens = token_ids[last_slice:current_slice]
if len(sliced_tokens) > 1:
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin

if start_timestamp_position < cur_max_timestamp:
# next segment has started
prev_segments_len += cur_max_timestamp

cur_max_timestamp = end_timestamp_position

# strip timestamp tokens from the text output
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
text = self._decode(sliced_tokens)
Expand All @@ -242,8 +251,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
{
"text": text,
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
(start_timestamp_position + prev_segments_len) * time_precision,
(end_timestamp_position + prev_segments_len) * time_precision,
),
}
)
Expand Down
59 changes: 59 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2099,6 +2099,65 @@ def test_tiny_timestamp_generation(self):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

@slow
def test_tiny_longform_timestamps_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]

input_features = processor(
sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
)
input_features = input_features.to(torch_device)

generated_ids = model.generate(**input_features, return_timestamps=True, return_segments=True)

EXPECTED_TRANSCRIPT = [
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"timestamp": (0.0, 6.5600000000000005),
},
{
"text": " Nor is Mr. Quilter's manner less interesting than his matter.",
"timestamp": (6.5600000000000005, 11.24),
},
{
"text": " He tells us that at this festive season of the year, with Christmas and roast beef looming",
"timestamp": (11.24, 16.88),
},
{
"text": " before us, similarly drawn from eating and its results occur most readily to the mind.",
"timestamp": (16.88, 23.76),
},
{
"text": " He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and",
"timestamp": (23.76, 29.44),
},
{"text": " can discover in it but little of rocky ithaka.", "timestamp": (29.44, 33.72)},
{
"text": " Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite itals",
"timestamp": (33.72, 40.32),
},
{"text": " are as national as a jingo poem.", "timestamp": (40.32, 44.72)},
{
"text": " Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used",
"timestamp": (44.72, 50.4),
},
{"text": " to flash his teeth.", "timestamp": (50.4, 52.96)},
{
"text": " And Mr. John Collier gives his sitter a cheerful slap on the back before he says, like",
"timestamp": (52.96, 58.68),
},
{"text": " a shampoo and a Turkish bath next man.", "timestamp": (58.68, 61.96)},
]

transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)

@slow
def test_large_timestamp_generation(self):
set_seed(0)
Expand Down

0 comments on commit 45cecbb

Please sign in to comment.