Add fast sentencepiece tokenization to T5. #683

Open · wants to merge 1 commit into main
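This PR adds a use_fast_tokenizer flag to SentencePieceVocabulary; when set, the tf_tokenizer property returns tf_text.FastSentencepieceTokenizer instead of tf_text.SentencepieceTokenizer. A minimal usage sketch (the model path below is a placeholder, not part of this PR):

# Usage sketch for the new flag. The model path is a placeholder; the only
# change from existing usage is passing use_fast_tokenizer=True.
import tensorflow as tf
from seqio import vocabularies

vocab = vocabularies.SentencePieceVocabulary(
    "/path/to/sentencepiece.model",  # placeholder model file
    extra_ids=100,
    use_fast_tokenizer=True,  # flag added by this PR; defaults to False
)
token_ids = vocab.encode_tf(tf.constant("hello world"))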
2 changes: 2 additions & 0 deletions seqio/test_utils.py
@@ -1265,12 +1265,14 @@ def sentencepiece_vocab(
sentencepiece_model_pb2.NormalizerSpec
] = None,
reverse_extra_ids: bool = True,
use_fast_tokenizer: bool = False,
):
return vocabularies.SentencePieceVocabulary(
os.path.join(TEST_DATA_DIR, "sentencepiece", "sentencepiece.model"),
extra_ids=extra_ids,
normalizer_spec_overrides=normalizer_spec_overrides,
reverse_extra_ids=reverse_extra_ids,
use_fast_tokenizer=use_fast_tokenizer,
)


6 changes: 6 additions & 0 deletions seqio/vocabularies.py
@@ -285,6 +285,7 @@ def __init__(
sentencepiece_model_pb2.NormalizerSpec
] = None,
reverse_extra_ids: bool = True,
use_fast_tokenizer: bool = False,
):
"""Create a SentencePieceVocabulary.

@@ -300,11 +301,14 @@ def __init__(
reverse_extra_ids: if True, extra_ids are numbered in descending order, so
the first extra_id has the highest number. This is done for
compatibility with span_corruption mask generation in T5.
use_fast_tokenizer: if True, use the tf_text FastSentencepieceTokenizer
implementation, which runs much faster.
"""
self._sentencepiece_model_file = sentencepiece_model_file
self._normalizer_spec_overrides = normalizer_spec_overrides
self._reverse_extra_ids = reverse_extra_ids
self._model: Optional[SentencePieceVocabulary._ModelContext] = None
self._use_fast_tokenizer = use_fast_tokenizer

super().__init__(extra_ids=extra_ids)

@@ -436,6 +440,8 @@ def tokenizer(self) -> sentencepiece_processor.SentencePieceProcessor:
@property
def tf_tokenizer(self):
"""Instantiate and return a TF tokenizer."""
if self._use_fast_tokenizer:
return tf_text.FastSentencepieceTokenizer(model=self.sp_model)
return tf_text.SentencepieceTokenizer(model=self.sp_model)

@property
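
Reviewer note: a quick way to sanity-check that the two tf_text tokenizers produce identical ids outside of seqio (a sketch under assumptions; MODEL_PATH is a placeholder and not part of this PR):

# Sanity-check sketch: compare the two tokenizers that tf_tokenizer can
# now return. MODEL_PATH is a placeholder for a local SentencePiece model.
import tensorflow as tf
import tensorflow_text as tf_text

MODEL_PATH = "/path/to/sentencepiece.model"  # placeholder
sp_model = tf.io.gfile.GFile(MODEL_PATH, "rb").read()

slow = tf_text.SentencepieceTokenizer(model=sp_model)
fast = tf_text.FastSentencepieceTokenizer(model=sp_model)

text = tf.constant("a quick sanity check")
# Both should yield the same token ids for plain text input.
tf.debugging.assert_equal(slow.tokenize(text), fast.tokenize(text))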
14 changes: 14 additions & 0 deletions seqio/vocabularies_test.py
@@ -298,6 +298,20 @@ def test_extra_ids(self):
test_tokens, tuple(vocab.encode_tf(test_string).numpy())
)

def test_fast_tokenizer(self):
vocab = test_utils.sentencepiece_vocab(
extra_ids=10, use_fast_tokenizer=True)
self.assertEqual(36, vocab.vocab_size)
self.assertEqual("v", vocab.decode([25]))
test_string = "<extra_id_0> <extra_id_1> v <extra_id_9>"
test_tokens = (35, 34, 3, 25, 26)
self.assertEqual(test_string, vocab.decode(test_tokens))
self.assertEqual(test_string, _decode_tf(vocab, test_tokens))
self.assertSequenceEqual(test_tokens, vocab.encode(test_string))
self.assertSequenceEqual(
test_tokens, tuple(vocab.encode_tf(test_string).numpy())
)

def test_force_repeated_whitespace_preservation(self):
test_string = "a a a a" # string with repeated whitespaces
