TFBertTokenizer - support for "never_split" #23798
Comments
Hi @benzitohhh, and sorry for the delay! This is an interesting and useful idea, but we depend on the underlying TensorFlow Text layers, specifically BertTokenizer in this case. I don't think there is a 'never_split' option there, but we could use the …
Hi @Rocketknight1 Thanks for the response, and sorry I was so slow getting back to you as well! Just to check I understand: in our case, the vocabulary (token_id to token mapping) looks as below, where ids 5-9 inclusive are "special" tokens:
So with the input `' [abstract] some text here. '`:

```python
# out = [2, 5, 1735, 3458, 1847, 3]  # expected tokenized ids

# 1. Replace each "special" token with a unique "unused" token.
#    So we need to map:
#      '[abstract]' -> '[unused0]'
#      '[claims]'   -> '[unused1]'
#      etc.
#    I guess we could use some regex for this.
input__unused = '[unused0] some text here'

# 2. Do the tokenization.
bert_input__unused = tokenizer(tf.constant([input__unused]))
# { 'input_ids': ... numpy=array([[ 2, 10, 1735, 3458, 1847, 3]]) ... etc. }
# i.e. the "10" above is the '[unused0]' token

# 3. Replace "unused" token_ids with the correct special token_ids.
#    Not sure exactly how to do this with tensor operations, but I guess it's possible?
#    So we need to map:
#      10 ('[unused0]') -> 5 ('[abstract]')
#      11 ('[unused1]') -> 6 ('[claims]')
#      etc.
bert_input = ...
# { 'input_ids': ... numpy=array([[ 2, 5, 1735, 3458, 1847, 3]]) ... etc. }
```

Will the above work? If so, that would be amazing and would totally solve our situation. Obviously, being able to add a "never_split" param would be much nicer :) Anyway, let us know what is possible. Thanks!
Hi @benzitohhh, yes, that's correct! I'd also need to file a PR to expose the option in our tokenizer, but if you're interested then I can do that soon. For the issue of mapping the unused token ids back to the special token ids, something like this should work:

```python
# `True` for `unused` tokens, `False` otherwise
condition = (input_ids >= unused_start_idx) & (input_ids <= unused_end_idx)
# Subtract the offset value from all unused tokens
input_ids = tf.where(condition, input_ids - offset, input_ids)
```

In the vocab list you linked above, an offset of 5 would map 10 ('[unused0]') back to 5 ('[abstract]'), 11 ('[unused1]') to 6 ('[claims]'), and so on.
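A self-contained sketch of that remapping, using the ids from the example above; the unused-token id range (10-14) and the offset of 5 are assumptions based on the vocab described in this thread:

```python
import tensorflow as tf

# Assumed layout: special tokens [abstract], [claims], ... at ids 5-9,
# and [unused0]..[unused4] at ids 10-14.
unused_start_idx = 10
unused_end_idx = 14
offset = 5  # 10 -> 5, 11 -> 6, ...

input_ids = tf.constant([[2, 10, 1735, 3458, 1847, 3]])

# `True` for unused-token positions, `False` otherwise
condition = (input_ids >= unused_start_idx) & (input_ids <= unused_end_idx)
# Shift only the unused-token ids; leave everything else untouched
input_ids = tf.where(condition, input_ids - offset, input_ids)

print(input_ids)  # [[   2    5 1735 3458 1847    3]]
```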
For more complex replacements, you could also just reorder the …
@Rocketknight1 Ok, this would totally work for us and would allow us to create an end-to-end model - yay! If you could create a PR, that would be super appreciated. Thanks again for all your help here and the super clear explanations. Have a good weekend meanwhile.
Hi @benzitohhh, the PR is now open at #24324. You can try out the PR branch with the following command:
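One generic way to install a GitHub PR branch for testing, not necessarily the exact command given in this comment, is `pip install "git+https://github.com/huggingface/transformers.git@refs/pull/24324/head"`.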
When creating the …
@Rocketknight1 Ah amazing, thanks! Will try this out first thing on Monday and let you know asap.
@Rocketknight1 Ok, just tested the PR, it works perfectly. Thanks again for making this happen!
Cool! Hopefully we can merge the PR soon in that case, so you can stop installing from the branch.
@benzitohhh this has now been merged. You can now get it just by installing from `main`.
It will be included in the next release of transformers in a few weeks, at which point you can go back to the usual `pip install`.
@Rocketknight1 amazing - thanks again
Feature request
Often vocabularies contain special tokens that should not be split.
For example, in the model "anferico/bert-for-patents", the vocabulary contains a token "[abstract]" (token_id 5):
https://huggingface.co/anferico/bert-for-patents/raw/main/vocab.txt
The normal `BertTokenizer` supports a `never_split` param for this, as in the sketch below: even though '[abstract]' contains brackets, it is not split.
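For illustration, a minimal sketch of that usage, assuming the model name from this issue and that `never_split` is forwarded to the slow tokenizer's constructor:

```python
from transformers import BertTokenizer

# Slow (Python) tokenizer: strings listed in `never_split` are kept intact
# by the basic tokenizer instead of being broken on punctuation.
tokenizer = BertTokenizer.from_pretrained(
    "anferico/bert-for-patents",
    never_split=["[abstract]", "[claims]"],
)

print(tokenizer("[abstract] some text here")["input_ids"])
# '[abstract]' comes out as a single token (id 5 in this vocab) rather than being split.
```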
But `TFBertTokenizer` does not have a `never_split` param, so there is no way to prevent the split. For example, in the sketch below, token_id 5 ('[abstract]') is missing from the input_ids; '[abstract]' has instead been split into three separate tokens.
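A sketch of the same input through the in-graph tokenizer (this assumes `tensorflow_text` is installed; the exact ids of the split pieces depend on the vocab and are not shown):

```python
import tensorflow as tf
from transformers import TFBertTokenizer

# In-graph tokenizer built on TensorFlow Text
tf_tokenizer = TFBertTokenizer.from_pretrained("anferico/bert-for-patents")

out = tf_tokenizer(tf.constant(["[abstract] some text here"]))
print(out["input_ids"])
# id 5 ('[abstract]') does not appear: the string is split into separate
# pieces (roughly '[', 'abstract', ']') instead of staying one token.
```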
Motivation
We would like to use an end-to-end model, on TensorFlow Serving, with in-graph tokenization.
But we need to be able to include special tokens in our input, such as '[abstract]', '[claims]', etc.:
https://huggingface.co/anferico/bert-for-patents/raw/main/vocab.txt
If `TFBertTokenizer` had a "never_split" param, this would be possible.
But currently it does not, so we have to do the tokenization outside the server.