Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix whisper tokenizer saving #1334

Merged
merged 1 commit into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 10 additions & 99 deletions keras_nlp/models/whisper/whisper_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.python_utils import format_docstring


@keras_nlp_export("keras_nlp.models.WhisperPreprocessor")
Expand All @@ -49,9 +48,11 @@ class WhisperPreprocessor(Preprocessor):
directly to a Whisper model.

Args:
audio_feature_extractor: A `keras_nlp.models.WhisperAudioFeatureExtractor`
instance.
tokenizer: A `keras_nlp.models.WhisperTokenizer` instance.
audio_feature_extractor: A
`keras_nlp.models.WhisperAudioFeatureExtractor` instance or `None`.
If `None` a feature extractor with default parameters will be
created.
decoder_sequence_length: The length of the packed decoder inputs.
language: string, language token. Should only be passed if your
tokenizer is multilingual.
Expand All @@ -73,7 +74,9 @@ class WhisperPreprocessor(Preprocessor):

Directly calling the layer on data.
```python
preprocessor = keras_nlp.models.WhisperPreprocessor.from_preset("whisper_tiny_en")
preprocessor = keras_nlp.models.WhisperPreprocessor.from_preset(
"whisper_tiny_en",
)

# Preprocess unbatched inputs.
input_data = {
Expand Down Expand Up @@ -153,15 +156,17 @@ class WhisperPreprocessor(Preprocessor):

def __init__(
self,
audio_feature_extractor,
tokenizer,
audio_feature_extractor=None,
decoder_sequence_length=448,
language=None,
task=None,
no_timestamps=True,
**kwargs,
):
super().__init__(**kwargs)
if audio_feature_extractor is None:
audio_feature_extractor = WhisperAudioFeatureExtractor()
self.audio_feature_extractor = audio_feature_extractor
self.tokenizer = tokenizer
self.decoder_sequence_length = decoder_sequence_length
Expand Down Expand Up @@ -313,97 +318,3 @@ def tokenizer_cls(cls):
@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)

@classmethod
def from_preset(
    cls,
    preset,
    language=None,
    task=None,
    no_timestamps=True,
    **kwargs,
):
    """Instantiate `WhisperPreprocessor` from preset architecture.

    Args:
        preset: string. Must be one of "{{preset_names}}".
        language: string, language token (eg., `"<|en|>"`). Should only be
            passed if your tokenizer is multilingual.
        task: string, task name. One of `"transcribe"`, `"translate"`.
            Should only be passed if your tokenizer is multilingual.
        no_timestamps: bool. If True, `"<|no_timestamps|>"` will be added as
            a special token to your input.

    Examples:
    ```python
    # Load a preprocessor layer from a preset.
    preprocessor = keras_nlp.models.WhisperPreprocessor.from_preset(
        "{{example_preset_name}}",
    )
    ```
    """
    # Overrides the base class `from_preset` so we can also build the
    # audio feature extractor, resolve `decoder_sequence_length`, and
    # forward the Whisper-specific special-token options.
    if not cls.presets:
        raise NotImplementedError(
            "No presets have been created for this class."
        )
    if preset not in cls.presets:
        raise ValueError(
            "`preset` must be one of "
            f'{", ".join(cls.presets)}. Received: {preset}.'
        )

    audio_feature_extractor = cls.audio_feature_extractor_cls.from_preset(
        preset
    )
    tokenizer = cls.tokenizer_cls.from_preset(preset)

    # Task model presets nest the backbone config one level down.
    config = cls.presets[preset]["config"]
    if "backbone" in config:
        backbone_config = config["backbone"]["config"]
    else:
        backbone_config = config

    # Default `decoder_sequence_length` to the model maximum; if the
    # caller passed one, make sure it does not exceed that maximum.
    max_length = backbone_config["max_decoder_sequence_length"]
    requested_length = kwargs.pop("decoder_sequence_length", None)
    if requested_length is None:
        decoder_sequence_length = max_length
    elif requested_length > max_length:
        raise ValueError(
            "`decoder_sequence_length` cannot be longer than "
            f"`{preset}` preset's `max_decoder_sequence_length` of "
            f"{max_length}. Received: {requested_length}."
        )
    else:
        decoder_sequence_length = requested_length

    return cls(
        audio_feature_extractor=audio_feature_extractor,
        tokenizer=tokenizer,
        decoder_sequence_length=decoder_sequence_length,
        language=language,
        task=task,
        no_timestamps=no_timestamps,
        **kwargs,
    )


# Substitute the `{{example_preset_name}}` and `{{preset_names}}`
# placeholders in `from_preset`'s docstring with the registered preset
# names. `.__func__` unwraps the classmethod so the underlying function's
# `__doc__` can be rewritten.
format_docstring(
    example_preset_name=next(iter(backbone_presets), ""),
    preset_names='", "'.join(backbone_presets),
)(WhisperPreprocessor.from_preset.__func__)
18 changes: 0 additions & 18 deletions keras_nlp/models/whisper/whisper_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@
"<|transcribe|>": 50357,
}

AUDIO_FEATURE_EXTRACTOR_CONFIG = {
"num_mels": 80,
"num_fft_bins": 400,
"stride": 160,
"sampling_rate": 16000,
"max_audio_length": 30,
}

LANGUAGE_TOKENS = {
"<|af|>": 50327,
"<|am|>": 50334,
Expand Down Expand Up @@ -161,7 +153,6 @@
"max_encoder_sequence_length": 3000,
"max_decoder_sequence_length": 448,
},
"audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
"preprocessor_config": {
"special_tokens": ENGLISH_SPECIAL_TOKENS,
"language_tokens": None,
Expand Down Expand Up @@ -195,7 +186,6 @@
"max_encoder_sequence_length": 3000,
"max_decoder_sequence_length": 448,
},
"audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
"preprocessor_config": {
"special_tokens": ENGLISH_SPECIAL_TOKENS,
"language_tokens": None,
Expand Down Expand Up @@ -229,7 +219,6 @@
"max_encoder_sequence_length": 3000,
"max_decoder_sequence_length": 448,
},
"audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
"preprocessor_config": {
"special_tokens": ENGLISH_SPECIAL_TOKENS,
"language_tokens": None,
Expand Down Expand Up @@ -263,7 +252,6 @@
"max_encoder_sequence_length": 3000,
"max_decoder_sequence_length": 448,
},
"audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
"preprocessor_config": {
"special_tokens": ENGLISH_SPECIAL_TOKENS,
"language_tokens": None,
Expand Down Expand Up @@ -297,7 +285,6 @@
"max_encoder_sequence_length": 3000,
"max_decoder_sequence_length": 448,
},
"audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
"preprocessor_config": {
"special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
"language_tokens": LANGUAGE_TOKENS,
Expand Down Expand Up @@ -331,7 +318,6 @@
"max_encoder_sequence_length": 3000,
"max_decoder_sequence_length": 448,
},
"audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
"preprocessor_config": {
"special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
"language_tokens": LANGUAGE_TOKENS,
Expand Down Expand Up @@ -365,7 +351,6 @@
"max_encoder_sequence_length": 3000,
"max_decoder_sequence_length": 448,
},
"audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
"preprocessor_config": {
"special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
"language_tokens": LANGUAGE_TOKENS,
Expand Down Expand Up @@ -399,7 +384,6 @@
"max_encoder_sequence_length": 3000,
"max_decoder_sequence_length": 448,
},
"audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
"preprocessor_config": {
"special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
"language_tokens": LANGUAGE_TOKENS,
Expand Down Expand Up @@ -433,7 +417,6 @@
"max_encoder_sequence_length": 3000,
"max_decoder_sequence_length": 448,
},
"audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
"preprocessor_config": {
"special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
"language_tokens": LANGUAGE_TOKENS,
Expand Down Expand Up @@ -468,7 +451,6 @@
"max_encoder_sequence_length": 3000,
"max_decoder_sequence_length": 448,
},
"audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
"preprocessor_config": {
"special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
"language_tokens": LANGUAGE_TOKENS,
Expand Down
13 changes: 13 additions & 0 deletions keras_nlp/models/whisper/whisper_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,20 @@ def __init__(
**kwargs,
)

def save_assets(self, dir_path):
    """Save tokenizer assets using the unmutated vocabulary.

    Args:
        dir_path: string. Directory to write the vocabulary assets into.
    """
    # TODO: Whisper is currently mutating its vocabulary before passing
    # it to the super class, so we need to restore the unmutated
    # vocabulary before saving our assets. We should find a more robust
    # (and memory efficient) way to do this.
    mutated_vocabulary = self.vocabulary
    self.vocabulary = self._initial_vocabulary
    try:
        super().save_assets(dir_path)
    finally:
        # Restore the runtime vocabulary even if saving raises, so the
        # tokenizer is not left in a half-swapped state.
        self.vocabulary = mutated_vocabulary

def set_vocabulary_and_merges(self, vocabulary, merges):
if vocabulary is not None:
vocabulary = _load_dict(vocabulary)
self._initial_vocabulary = dict(vocabulary)

if self.language_tokens is not None:
# Multilingual tokenizer.
Expand All @@ -133,6 +144,8 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
self.transcribe_token,
]:
vocabulary[token] = self.special_tokens[token]
else:
self._initial_vocabulary = None

super().set_vocabulary_and_merges(vocabulary, merges)

Expand Down