From 1de59d6eed8d895dd96d3e74179c34b57be7e38a Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Wed, 29 Nov 2023 15:28:14 -0800
Subject: [PATCH] Fix whisper tokenizer saving

---
 .../models/whisper/whisper_preprocessor.py    | 109 ++----------------
 keras_nlp/models/whisper/whisper_presets.py   |  18 ---
 keras_nlp/models/whisper/whisper_tokenizer.py |  13 +++
 3 files changed, 23 insertions(+), 117 deletions(-)

diff --git a/keras_nlp/models/whisper/whisper_preprocessor.py b/keras_nlp/models/whisper/whisper_preprocessor.py
index 2f8673c52f..abcff0d770 100644
--- a/keras_nlp/models/whisper/whisper_preprocessor.py
+++ b/keras_nlp/models/whisper/whisper_preprocessor.py
@@ -30,7 +30,6 @@
 )
 from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
 from keras_nlp.utils.python_utils import classproperty
-from keras_nlp.utils.python_utils import format_docstring
 
 
 @keras_nlp_export("keras_nlp.models.WhisperPreprocessor")
@@ -49,9 +48,11 @@ class WhisperPreprocessor(Preprocessor):
     directly to a Whisper model.
 
     Args:
-        audio_feature_extractor: A `keras_nlp.models.WhisperAudioFeatureExtractor`
-            instance.
         tokenizer: A `keras_nlp.models.WhisperTokenizer` instance.
+        audio_feature_extractor: A
+            `keras_nlp.models.WhisperAudioFeatureExtractor` instance or `None`.
+            If `None`, a feature extractor with default parameters will be
+            created.
         decoder_sequence_length: The length of the packed decoder inputs.
         language: string, language token. Should only be passed if your
             tokenizer is multilingual.
@@ -73,7 +74,9 @@ class WhisperPreprocessor(Preprocessor):
     Directly calling the layer on data.
 
     ```python
-    preprocessor = keras_nlp.models.WhisperPreprocessor.from_preset("whisper_tiny_en")
+    preprocessor = keras_nlp.models.WhisperPreprocessor.from_preset(
+        "whisper_tiny_en",
+    )
 
     # Preprocess unbatched inputs.
     input_data = {
@@ -153,8 +156,8 @@ class WhisperPreprocessor(Preprocessor):
 
     def __init__(
         self,
-        audio_feature_extractor,
         tokenizer,
+        audio_feature_extractor=None,
         decoder_sequence_length=448,
         language=None,
         task=None,
@@ -162,6 +165,8 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
+        if audio_feature_extractor is None:
+            audio_feature_extractor = WhisperAudioFeatureExtractor()
         self.audio_feature_extractor = audio_feature_extractor
         self.tokenizer = tokenizer
         self.decoder_sequence_length = decoder_sequence_length
@@ -313,97 +318,3 @@ def tokenizer_cls(cls):
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
-
-    @classmethod
-    def from_preset(
-        cls,
-        preset,
-        language=None,
-        task=None,
-        no_timestamps=True,
-        **kwargs,
-    ):
-        """Instantiate `WhisperPreprocessor` from preset architecture.
-
-        Args:
-            preset: string. Must be one of "{{preset_names}}".
-            language: string, language token (eg., `"<|en|>"`). Should only be
-                passed if your tokenizer is multilingual.
-            task: string, task name. One of `"transcribe"`, `"translate"`.
-                Should only be passed if your tokenizer is multilingual.
-            no_timestamps: bool. If True, `"<|no_timestamps|>"` will be added as
-                a special token to your input.
-
-        Examples:
-        ```python
-        # Load a preprocessor layer from a preset.
-        preprocessor = keras_nlp.models.WhisperPreprocessor.from_preset(
-            "{{example_preset_name}}",
-        )
-        ```
-        """
-        # Override base class's `from_preset` to handle audio feature extractor
-        # , `decoder_sequence_length` and special tokens.
-        if not cls.presets:
-            raise NotImplementedError(
-                "No presets have been created for this class."
-            )
-        if preset not in cls.presets:
-            raise ValueError(
-                "`preset` must be one of "
-                f"""{", ".join(cls.presets)}. Received: {preset}."""
-            )
-
-        audio_feature_extractor = cls.audio_feature_extractor_cls.from_preset(
-            preset
-        )
-        tokenizer = cls.tokenizer_cls.from_preset(preset)
-
-        metadata = cls.presets[preset]
-        # For task model presets, the backbone config is nested.
-        if "backbone" in metadata["config"]:
-            backbone_config = metadata["config"]["backbone"]["config"]
-        else:
-            backbone_config = metadata["config"]
-
-        # Use model's `max_decoder_sequence_length` if `decoder_sequence_length`
-        # is unspecified; otherwise check that `decoder_sequence_length` is not
-        # too long.
-        decoder_sequence_length = kwargs.pop("decoder_sequence_length", None)
-        max_decoder_sequence_length = backbone_config[
-            "max_decoder_sequence_length"
-        ]
-
-        def check_sequence_length(sequence_length, max_sequence_length, name):
-            if sequence_length is not None:
-                if sequence_length > max_sequence_length:
-                    raise ValueError(
-                        f"`{name}` cannot be longer than `{preset}` "
-                        f"preset's `max_{name}` of {max_sequence_length}. "
-                        f"Received: {sequence_length}."
-                    )
-                return sequence_length
-            else:
-                return max_sequence_length
-
-        decoder_sequence_length = check_sequence_length(
-            decoder_sequence_length,
-            max_decoder_sequence_length,
-            "decoder_sequence_length",
-        )
-
-        return cls(
-            audio_feature_extractor=audio_feature_extractor,
-            tokenizer=tokenizer,
-            decoder_sequence_length=decoder_sequence_length,
-            language=language,
-            task=task,
-            no_timestamps=no_timestamps,
-            **kwargs,
-        )
-
-
-format_docstring(
-    example_preset_name=next(iter(backbone_presets), ""),
-    preset_names='", "'.join(backbone_presets),
-)(WhisperPreprocessor.from_preset.__func__)
diff --git a/keras_nlp/models/whisper/whisper_presets.py b/keras_nlp/models/whisper/whisper_presets.py
index e8c0d075a4..8ec5a7353d 100644
--- a/keras_nlp/models/whisper/whisper_presets.py
+++ b/keras_nlp/models/whisper/whisper_presets.py
@@ -27,14 +27,6 @@
     "<|transcribe|>": 50357,
 }
 
-AUDIO_FEATURE_EXTRACTOR_CONFIG = {
-    "num_mels": 80,
-    "num_fft_bins": 400,
-    "stride": 160,
-    "sampling_rate": 16000,
-    "max_audio_length": 30,
-}
-
 LANGUAGE_TOKENS = {
     "<|af|>": 50327,
     "<|am|>": 50334,
@@ -161,7 +153,6 @@
         "max_encoder_sequence_length": 3000,
         "max_decoder_sequence_length": 448,
     },
-    "audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
     "preprocessor_config": {
         "special_tokens": ENGLISH_SPECIAL_TOKENS,
         "language_tokens": None,
@@ -195,7 +186,6 @@
         "max_encoder_sequence_length": 3000,
         "max_decoder_sequence_length": 448,
     },
-    "audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
     "preprocessor_config": {
         "special_tokens": ENGLISH_SPECIAL_TOKENS,
         "language_tokens": None,
@@ -229,7 +219,6 @@
         "max_encoder_sequence_length": 3000,
         "max_decoder_sequence_length": 448,
     },
-    "audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
     "preprocessor_config": {
         "special_tokens": ENGLISH_SPECIAL_TOKENS,
         "language_tokens": None,
@@ -263,7 +252,6 @@
         "max_encoder_sequence_length": 3000,
         "max_decoder_sequence_length": 448,
     },
-    "audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
     "preprocessor_config": {
         "special_tokens": ENGLISH_SPECIAL_TOKENS,
         "language_tokens": None,
@@ -297,7 +285,6 @@
         "max_encoder_sequence_length": 3000,
         "max_decoder_sequence_length": 448,
     },
-    "audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
     "preprocessor_config": {
         "special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
         "language_tokens": LANGUAGE_TOKENS,
@@ -331,7 +318,6 @@
         "max_encoder_sequence_length": 3000,
         "max_decoder_sequence_length": 448,
     },
-    "audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
     "preprocessor_config": {
         "special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
         "language_tokens": LANGUAGE_TOKENS,
@@ -365,7 +351,6 @@
         "max_encoder_sequence_length": 3000,
         "max_decoder_sequence_length": 448,
     },
-    "audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
     "preprocessor_config": {
         "special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
         "language_tokens": LANGUAGE_TOKENS,
@@ -399,7 +384,6 @@
         "max_encoder_sequence_length": 3000,
         "max_decoder_sequence_length": 448,
     },
-    "audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
     "preprocessor_config": {
         "special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
         "language_tokens": LANGUAGE_TOKENS,
@@ -433,7 +417,6 @@
         "max_encoder_sequence_length": 3000,
         "max_decoder_sequence_length": 448,
     },
-    "audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
     "preprocessor_config": {
         "special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
         "language_tokens": LANGUAGE_TOKENS,
@@ -468,7 +451,6 @@
         "max_encoder_sequence_length": 3000,
         "max_decoder_sequence_length": 448,
     },
-    "audio_feature_extractor_config": AUDIO_FEATURE_EXTRACTOR_CONFIG,
     "preprocessor_config": {
         "special_tokens": MULTILINGUAL_SPECIAL_TOKENS,
         "language_tokens": LANGUAGE_TOKENS,
diff --git a/keras_nlp/models/whisper/whisper_tokenizer.py b/keras_nlp/models/whisper/whisper_tokenizer.py
index 10998ffed8..996a501d18 100644
--- a/keras_nlp/models/whisper/whisper_tokenizer.py
+++ b/keras_nlp/models/whisper/whisper_tokenizer.py
@@ -111,9 +111,20 @@ def __init__(
             **kwargs,
         )
 
+    def save_assets(self, dir_path):
+        # TODO: Whisper currently mutates its vocabulary before passing
+        # it to the super class, so we need to restore the unmutated
+        # vocabulary before saving our assets. We should find a more
+        # robust (and memory-efficient) way to do this.
+        vocabulary = self.vocabulary
+        self.vocabulary = self._initial_vocabulary
+        super().save_assets(dir_path)
+        self.vocabulary = vocabulary
+
     def set_vocabulary_and_merges(self, vocabulary, merges):
         if vocabulary is not None:
             vocabulary = _load_dict(vocabulary)
+            self._initial_vocabulary = dict(vocabulary)
 
             if self.language_tokens is not None:
                 # Multilingual tokenizer.
@@ -133,6 +144,8 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
                     self.transcribe_token,
                 ]:
                     vocabulary[token] = self.special_tokens[token]
+        else:
+            self._initial_vocabulary = None
 
         super().set_vocabulary_and_merges(vocabulary, merges)
 
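
The sketch below is not part of the patch; it illustrates the behavior the diff introduces: `WhisperPreprocessor` now takes `tokenizer` first with an optional `audio_feature_extractor`, and `WhisperTokenizer.save_assets` writes the original (unmutated) vocabulary. The preset name `whisper_tiny_en` comes from the docstring above; the `./whisper_assets` path is a hypothetical stand-in, and creating the directory first is a precaution rather than a documented requirement.

```python
import os

import keras_nlp

# Tokenizer-first construction. `audio_feature_extractor` may be omitted;
# per this patch, a `WhisperAudioFeatureExtractor` with default parameters
# is created inside `__init__` when it is `None`.
tokenizer = keras_nlp.models.WhisperTokenizer.from_preset("whisper_tiny_en")
preprocessor = keras_nlp.models.WhisperPreprocessor(tokenizer=tokenizer)

# Save the tokenizer's assets. The overridden `save_assets` temporarily
# swaps back `_initial_vocabulary`, so the saved file reflects the
# vocabulary as originally passed in, not the working copy into which
# `set_vocabulary_and_merges` inserts special tokens. A save/load round
# trip therefore no longer accumulates those extra entries.
assets_dir = "./whisper_assets"  # hypothetical path
os.makedirs(assets_dir, exist_ok=True)
tokenizer.save_assets(assets_dir)
```

An explicit `WhisperAudioFeatureExtractor` can still be passed to the preprocessor when non-default settings are needed; the `None` default only covers the common case.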