
TFBertTokenizer - support for "never_split" #23798

Closed
benzitohhh opened this issue May 26, 2023 · 12 comments · Fixed by #24324

Comments


benzitohhh commented May 26, 2023

Feature request

Often vocabularies contain special tokens that should not be split.

For example, in model "anferico/bert-for-patents", the vocabulary contains a token "[abstract]" (token_id is 5)
https://huggingface.co/anferico/bert-for-patents/raw/main/vocab.txt

The normal BertTokenizer supports a param "never_split" for this:

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('anferico/bert-for-patents', never_split=['[abstract]'])
tokenizer.tokenize('[abstract] some text here')
# ['[abstract]', 'some', 'text', 'here']

So above, even though '[abstract]' contains brackets, it is not split.

But TFBertTokenizer does not have a "never_split" param, and so there is no way to prevent splits. For example:

import tensorflow as tf
from transformers import TFBertTokenizer

tokenizer = TFBertTokenizer.from_pretrained('anferico/bert-for-patents')
tokenizer(tf.constant(['[abstract] some text here']))
# {'input_ids': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[   2, 1036, 9726, 1038, 1735, 3458, 1847,    3]])>, 'attention_mask': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[1, 1, 1, 1, 1, 1, 1, 1]])>, 'token_type_ids': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[0, 0, 0, 0, 0, 0, 0, 0]])>}

Above, notice that token_id 5 ('[abstract]') is missing from the input_ids; in fact, '[abstract]' has been split into three separate tokens:

  • '[' - 1036
  • 'abstract' - 9726
  • ']' - 1038

Motivation

We would like to use an end-to-end model, on TensorFlow Serving, with in-graph tokenization.

But we need to be able to include special tokens in our input, such as '[abstract]', '[claim]', etc.:
https://huggingface.co/anferico/bert-for-patents/raw/main/vocab.txt

If TFBertTokenizer had a "never_split" param, this would be possible.

But currently it does not, so we have to do tokenization outside the server.

@benzitohhh benzitohhh changed the title TFBertTokenizer - support for "never_split", or a way of inserting special tokens. TFBertTokenizer - support for "never_split" May 26, 2023
sgugger (Collaborator) commented May 26, 2023

cc @Rocketknight1

Rocketknight1 (Member) commented Jun 1, 2023

Hi @benzitohhh, and sorry for the delay! This is an interesting and useful idea, but we depend on the underlying TensorFlow Text layers, specifically BertTokenizer in this case.

I don't think there is a 'never_split' option there, but we could expose the preserve_unused_token argument. This would mean that tokens of the form [unused0], [unused1], etc. are never split, so you could use those in place of control tokens like [abstract]. Would this work for your use case? If it's useful to you, it's probably useful to other people, and we can add it to the TFBertTokenizer layer in a PR.
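
For illustration, a minimal sketch of the underlying TensorFlow Text behaviour (this assumes tensorflow_text is installed and that vocab.txt is a local copy of the model's vocabulary; the lookup-table construction shown is just one way to build it):

import tensorflow as tf
import tensorflow_text as tf_text

# Build a string -> id lookup table from the vocab file (one token per line)
vocab_table = tf.lookup.StaticVocabularyTable(
    tf.lookup.TextFileInitializer(
        'vocab.txt',
        key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
        value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER),
    num_oov_buckets=1)

tokenizer = tf_text.BertTokenizer(vocab_table, preserve_unused_token=True)
tokenizer.tokenize(tf.constant(['[unused0] some text here']))
# '[unused0]' survives as a single token instead of '[', 'unused0', ']'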

benzitohhh (Author) commented Jun 15, 2023

Hi @Rocketknight1, thanks for the response, and sorry for being so slow getting back to you too!

Just to check I understand...

In our case, the vocabulary (token_id to token mapping) looks as below, where IDs 5-9 inclusive are "special" tokens:
https://huggingface.co/anferico/bert-for-patents/raw/main/vocab.txt

0: [PAD]
1: [UNK]
2: [CLS]
3: [SEP]
4: [MASK]
5: [abstract]
6: [claim]
7: [summary]
8: [invention]
9: [cpc]
10: [unused0]
11: [unused1]
12: [unused2]
13: [unused3]
14: [unused4]
15: [unused5]
etc...

So with the preserve_unused_token approach, I guess we'd need to do something like:

input = '         [abstract]   some      text      here.       '
#out  = [2,       5,           1735,     3458,     1847,      3]  #### Expected tokenized ids

# 1. Replace each "special" token with a unique "unused" token
# So we need to map:
#    '[abstract]' -> '[unused0]'
#    '[claim]'    -> '[unused1]'
# etc..
# I guess we could use some regex for this (see the sketch after this block).
input__unused = '[unused0] some text here'

# 2. Do the tokenization
bert_input__unused = tokenizer(tf.constant([input__unused]))
# { 'input_ids': ... numpy=array([[   2, 10, 1735, 3458, 1847,  3]])> etc... }
# i.e. the "10" above is the '[unused0]' token

# 3. Replace "unused" token_ids with the correct special token_ids
# Not sure exactly how to do this with tensor operations, but I guess it's possible?
# So we need to map:
#    10 ('[unused0]') -> 5 ('[abstract]')
#    11 ('[unused1]') -> 6 ('[claim]')
# etc..
bert_input = ...
# { 'input_ids': ... numpy=array([[   2, 5, 1735, 3458, 1847,  3]])> etc... } 
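
For step 1, a minimal in-graph sketch using tf.strings.regex_replace (the mapping below is the hypothetical one from above; brackets must be escaped because the pattern is a regex):

import tensorflow as tf

def replace_special_with_unused(text):
    # Hypothetical mapping: '[abstract]' -> '[unused0]', '[claim]' -> '[unused1]', etc.
    mapping = {r'\[abstract\]': '[unused0]', r'\[claim\]': '[unused1]'}
    for pattern, rewrite in mapping.items():
        text = tf.strings.regex_replace(text, pattern, rewrite)
    return text

replace_special_with_unused(tf.constant(['[abstract] some text here']))
# <tf.Tensor: ... numpy=array([b'[unused0] some text here'])>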

Will the above work?

If so, that would be amazing, and totally solve our situation.

Obviously, being able to add a "never_split" param would be much nicer :)

Anyway, let us know what's possible - thanks!

Rocketknight1 (Member)

Hi @benzitohhh, yes, that's correct! I'd also need to file a PR to expose the option in our tokenizer, but if you're interested then I can do that soon.

For the issue of mapping the unused token IDs to the correct special token IDs, I suggest using a block of unused token IDs in the same order as your special token IDs. Then all you would need to do is:

import tensorflow as tf

# Assumes `input_ids`, `unused_start_idx`, `unused_end_idx`, and `offset` are already defined
# `True` for `unused` tokens, `False` otherwise
condition = (input_ids >= unused_start_idx) & (input_ids <= unused_end_idx)
# Subtract the offset from all unused tokens, leaving other IDs unchanged
input_ids = tf.where(condition, input_ids - offset, input_ids)

In the vocab list you linked above, an offset of 5 would map [unused0] -> [abstract] and so on.
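
Concretely, for the vocab above ([unused0]..[unused4] at IDs 10-14, specials at IDs 5-9), a worked version of that snippet:

import tensorflow as tf

input_ids = tf.constant([[2, 10, 1735, 3458, 1847, 3]])  # 10 == '[unused0]'
condition = (input_ids >= 10) & (input_ids <= 14)  # the unused-token range
input_ids = tf.where(condition, input_ids - 5, input_ids)  # offset of 5
# input_ids is now [[2, 5, 1735, 3458, 1847, 3]], where 5 == '[abstract]'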

Rocketknight1 (Member)

For more complex replacements, you could also just reorder the vocab_list for the TFBertTokenizer so it generates the indices you want!
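
For example, a hypothetical sketch of that reordering, assuming the vocab layout listed above and that TFBertTokenizer can be constructed from a vocab_list:

from transformers import BertTokenizer, TFBertTokenizer

# Tokens in ID order; the slow tokenizer's vocab dict preserves file order
vocab = list(BertTokenizer.from_pretrained('anferico/bert-for-patents').vocab)

# Swap each special token (IDs 5-9) with the matching unused token (IDs 10-14),
# so that '[unused0]' in the input text (kept whole via preserve_unused_token)
# produces ID 5, i.e. the model's '[abstract]' embedding
for special_id, unused_id in zip(range(5, 10), range(10, 15)):
    vocab[special_id], vocab[unused_id] = vocab[unused_id], vocab[special_id]

tf_tokenizer = TFBertTokenizer(vocab, do_lower_case=True)  # match the checkpoint's casing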

benzitohhh (Author)

@Rocketknight1 Ok this would totally work for us, and would allow us to create an end-to-end model - yay!

If you could create a PR that would be super appreciated.

Thanks again for all your help here, and the super clear explanations. Have a good weekend meanwhile.

Rocketknight1 (Member) commented Jun 16, 2023

Hi @benzitohhh, the PR is now open at #24324. You can try out the PR branch with the following command:

pip install git+https://github.com/huggingface/transformers.git@allow_tf_tokenizer_kwargs

When creating the TFBertTokenizer, add the arguments use_fast_bert_tokenizer=False and preserve_unused_token=True. Note that only the slower TF tokenizer layer supports preserve_unused_token, but only the fast layer can be exported to TFLite, so this solution won't work if you need a TFLite export!
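
In other words, something like this (a sketch of the suggested usage; the two kwargs are the ones named above):

import tensorflow as tf
from transformers import TFBertTokenizer

tokenizer = TFBertTokenizer.from_pretrained(
    'anferico/bert-for-patents',
    use_fast_bert_tokenizer=False,  # the slow TF layer supports preserve_unused_token
    preserve_unused_token=True,     # keep '[unusedN]' tokens intact
)
tokenizer(tf.constant(['[unused0] some text here']))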

benzitohhh (Author)

@Rocketknight1 Ah, amazing, thanks! I'll try this out first thing on Monday and let you know ASAP.

benzitohhh (Author)

@Rocketknight1 OK, just tested the PR - it works perfectly.

Thanks again for making this happen!

Rocketknight1 (Member)

Cool! Hopefully we can merge the PR soon in that case, so you can stop installing from the branch.

Rocketknight1 (Member)

@benzitohhh this has now been merged. You can get it by installing from main with

pip install git+https://github.com/huggingface/transformers.git

It will be included in the next release of transformers in a few weeks, at which point you can go back to the usual pip install transformers.

benzitohhh (Author)

@Rocketknight1 amazing - thanks again
