Skip to content

Commit

Permalink
Support for whisper
Browse files Browse the repository at this point in the history
  • Loading branch information
abuelnasr0 committed Apr 2, 2024
1 parent 085c233 commit c231372
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
13 changes: 7 additions & 6 deletions keras_nlp/models/whisper/whisper_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class WhisperTokenizer(BytePairTokenizer):
language_tokens: string or dict, maps language tokens to integer IDs. If
not None, the tokenizer will be assumed to be a multilingual
tokenizer.
special_tokens_in_strings: bool. A bool to indicate if the tokenizer
should expect special tokens in input strings that should be
tokenized and mapped correctly to their ids. Defaults to False.
"""

def __init__(
Expand All @@ -53,6 +57,7 @@ def __init__(
merges=None,
special_tokens=None,
language_tokens=None,
special_tokens_in_strings=False,
**kwargs,
):
special_tokens = _load_dict(special_tokens)
Expand Down Expand Up @@ -105,6 +110,7 @@ def __init__(
vocabulary=vocabulary,
merges=merges,
unsplittable_tokens=unsplittable_tokens,
unsplittable_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

Expand Down Expand Up @@ -148,12 +154,7 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

def get_config(self):
config = super().get_config()

# In the constructor, we pass the list of special tokens to the
# `unsplittable_tokens` arg of the superclass' constructor. Hence, we
# delete it from the config here.
del config["unsplittable_tokens"]

del config["unsplittable_tokens"] # Not configurable; set in __init__.
config.update(
{
"special_tokens": self.special_tokens,
Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/whisper/whisper_tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def setUp(self):
"merges": self.merges,
"special_tokens": self.special_tokens,
"language_tokens": self.language_tokens,
"special_tokens_in_strings": True,
}
self.input_data = [
" airplane at airport<|endoftext|>",
Expand Down

0 comments on commit c231372

Please sign in to comment.