Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing kwargs through to TFBertTokenizer #24324

Merged
merged 1 commit into from
Jun 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/transformers/models/bert/tokenization_bert_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class TFBertTokenizer(tf.keras.layers.Layer):
return_attention_mask (`bool`, *optional*, defaults to `True`):
Whether to return the attention_mask.
use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`):
If set to false will use standard TF Text BertTokenizer, making it servable by TF Serving.
If True, will use the FastBertTokenizer class from Tensorflow Text. If False, will use the BertTokenizer
class instead. BertTokenizer supports some additional options, but is slower and cannot be exported to
TFLite.
"""

def __init__(
Expand All @@ -65,11 +67,12 @@ def __init__(
return_token_type_ids: bool = True,
return_attention_mask: bool = True,
use_fast_bert_tokenizer: bool = True,
**tokenizer_kwargs,
):
super().__init__()
if use_fast_bert_tokenizer:
self.tf_tokenizer = FastBertTokenizer(
vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case
vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs
)
else:
lookup_table = tf.lookup.StaticVocabularyTable(
Expand All @@ -81,7 +84,9 @@ def __init__(
),
num_oov_buckets=1,
)
self.tf_tokenizer = BertTokenizerLayer(lookup_table, token_out_type=tf.int64, lower_case=do_lower_case)
self.tf_tokenizer = BertTokenizerLayer(
lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs
)

self.vocab_list = vocab_list
self.do_lower_case = do_lower_case
Expand Down