[Whisper Tokenizer] Make decoding faster after adding timestamps #26299

Merged: 1 commit, Sep 28, 2023
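What the PR changes, in short: WhisperTokenizer.decode used to strip timestamp tokens at the id level. On every call, _preprocess_token_ids rebuilt the full list of 1501 timestamp ids via timestamp_ids() and ran a membership test for each token against that list. After this PR, the ids are decoded as-is and a precompiled regex removes the rendered <|x.xx|> markers from the text in one pass. A minimal sketch of the old and new filtering, simplified from the diff below rather than copied from it:

import re

# The pattern introduced by this PR; matches rendered timestamp tokens such as "<|3.42|>".
TIMESTAMP_PAT = re.compile(r"<\|(\d+\.\d+)\|>")

def strip_timestamps_old(token_ids, timestamp_ids):
    # Old path: membership test of every token against the ~1501 timestamp ids,
    # O(len(token_ids) * len(timestamp_ids)) with a plain list.
    return [tok for tok in token_ids if tok not in timestamp_ids]

def strip_timestamps_new(decoded_text):
    # New path: a single regex pass over the already-decoded string.
    return TIMESTAMP_PAT.sub("", decoded_text)

The old path also paid for convert_tokens_to_ids on all 1501 timestamp strings on every decode call, since timestamp_ids() recomputes its result each time.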
src/transformers/models/whisper/tokenization_whisper.py (14 additions, 17 deletions)
@@ -314,6 +314,7 @@ def __init__(
 
         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         super().__init__(
@@ -560,10 +561,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
                 # strip timestamp tokens from the text output
-                sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+                sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+                text = self._decode(sliced_tokens)
+                text = self._filter_timestamp_ids(text)
                 offsets.append(
                     {
-                        "text": self._decode(sliced_tokens),
+                        "text": text,
                         "timestamp": (
                             start_timestamp_position * time_precision,
                             end_timestamp_position * time_precision,
@@ -585,9 +588,7 @@ def timestamp_ids(self, time_precision=0.02):
         """
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
 
@@ -597,24 +598,17 @@ def _preprocess_token_ids(
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     def decode(
         self,
         token_ids,
@@ -644,6 +638,8 @@ def decode(
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -652,8 +648,6 @@
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
 
         text = super().decode(
@@ -668,6 +662,9 @@
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
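The fast tokenizer receives the same treatment below, plus the import re it now needs. As a quick illustration of what the new _filter_timestamp_ids helper does, here is its regex applied to an assumed decoded string (note that the helper operates on decoded text, despite its token_ids parameter name):

import re

timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")

text = "<|0.00|> Hey, how are you feeling?<|3.68|>"
print(re.sub(timestamp_pat, "", text))
# -> " Hey, how are you feeling?" (only the markers are removed; the leading space survives)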
src/transformers/models/whisper/tokenization_whisper_fast.py (16 additions, 17 deletions)
@@ -15,6 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
+import re
 from functools import lru_cache
 from typing import List, Optional, Tuple
 
@@ -190,6 +191,7 @@ def __init__(
             self.english_spelling_normalizer = None
 
         self.add_prefix_space = add_prefix_space
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         self.task = task
@@ -269,10 +271,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
                 # strip timestamp tokens from the text output
-                sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+                sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+                text = self._decode(sliced_tokens)
+                text = self._filter_timestamp_ids(text)
                 offsets.append(
                     {
-                        "text": self._decode(sliced_tokens),
+                        "text": text,
                         "timestamp": (
                             start_timestamp_position * time_precision,
                             end_timestamp_position * time_precision,
@@ -296,9 +300,7 @@ def timestamp_ids(self, time_precision=0.02):
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
 
@@ -308,24 +310,18 @@ def _preprocess_token_ids(
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._filter_timestamp_ids
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
     def decode(
         self,
@@ -356,6 +352,8 @@
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -364,8 +362,6 @@
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
 
         text = super().decode(
@@ -380,6 +376,9 @@
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
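Taken together, the decode paths touched by this PR behave roughly as in the hedged sketch below. The checkpoint name and the example transcript are assumptions, and the sketch presumes a tokenizer whose vocabulary contains the timestamp tokens so that they render as <|x.xx|> strings when decoded:

from transformers import WhisperTokenizer

# Assumed checkpoint; any Whisper tokenizer with timestamp tokens in its vocab should do.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# Build ids containing real timestamp tokens via the timestamp_ids() helper shown above.
ts = tokenizer.timestamp_ids(time_precision=0.02)  # ids for <|0.00|> ... <|30.00|>
ids = [ts[0]]
ids += tokenizer(" Hey, how are you feeling?", add_special_tokens=False).input_ids
ids += [ts[184]]  # ts[184] renders as <|3.68|>

# Default path: markers are rendered during decoding, then regex-stripped from the text.
print(tokenizer.decode(ids, skip_special_tokens=True))

# Timestamp path: keep the "<|x.xx|>" markers in the raw text.
print(tokenizer.decode(ids, decode_with_timestamps=True, time_precision=0.02))

# Offset path: the text plus per-segment (start, end) times in seconds.
print(tokenizer.decode(ids, output_offsets=True, time_precision=0.02))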