
add_tokens does not preserve spacing #28384

Closed
2 of 4 tasks
denizyuret opened this issue Jan 8, 2024 · 7 comments
Comments

@denizyuret

System Info

  • transformers version: 4.35.2
  • Platform: Linux-4.18.0-348.el8.x86_64-x86_64-with-glibc2.28
  • Python version: 3.11.6
  • Huggingface_hub version: 0.20.1
  • Safetensors version: 0.3.3
  • Accelerate version: 0.22.0

Who can help?

@ArthurZucker and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

>>> from transformers import AutoTokenizer
>>> mtok = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1')
>>> str = 'Arthur met Deniz today.'
>>> mtok.decode(mtok.encode(str))
'<s> Arthur met Deniz today.'
>>> mtok.add_tokens(['Deniz'])
1
>>> mtok.decode(mtok.encode(str))
'<s> Arthur metDeniz today.'

Expected behavior

Spaces should be preserved when text is encoded and then decoded.
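The expectation can be stated as a round-trip property. A minimal sketch, where round_trips is a hypothetical helper (not from the thread) and skip_special_tokens is decode's real keyword argument for dropping the leading <s>:

```python
def round_trips(tokenizer, text):
    """Return True if decoding the encoded text reproduces it exactly,
    ignoring special tokens such as the leading <s>."""
    ids = tokenizer.encode(text)
    return tokenizer.decode(ids, skip_special_tokens=True) == text
```

With 'Deniz' added as a token, round_trips(mtok, 'Arthur met Deniz today.') returns False, which is the bug reported here.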

@ArthurZucker
Collaborator

A few things can come into play here. By default the token is normalized, which means the normalizer adds a prefix space to it; when decoding, that space is removed. You should add the token using tokenizer.add_tokens(AddedToken("Deniz", normalized=False, special=False)).

@denizyuret
Author

denizyuret commented Jan 8, 2024

This time an extra space appears after the token:

>>> from transformers import AddedToken, AutoTokenizer
>>> mtok = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1')
>>> mtok.add_tokens(AddedToken('Deniz', normalized=False, special=False))
1
>>> str = 'Arthur met Deniz today.'
>>> mtok.decode(mtok.encode(str))
'<s> Arthur met Deniz  today.'

Here are the individual tokens, if it helps:

>>> mtok.convert_ids_to_tokens(mtok.encode(str))
['<s>', '▁Arthur', '▁met', '▁', 'Deniz', '▁', '▁today', '.']
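The double space falls out of how SentencePiece-style "metaspace" decoding joins pieces: every '▁' (U+2581) word marker becomes a space, so the stray '▁' token before '▁today' turns into a second space. A stdlib-only sketch of that joining rule (an illustration, not the actual transformers implementation, and it ignores special tokens like <s>):

```python
def metaspace_decode(tokens):
    # Join the pieces, turn every "▁" (U+2581) word marker into a space,
    # then drop the single leading space the first marker introduces.
    text = "".join(tokens).replace("\u2581", " ")
    return text.removeprefix(" ")

pieces = ["\u2581Arthur", "\u2581met", "\u2581", "Deniz", "\u2581", "\u2581today", "."]
print(metaspace_decode(pieces))  # Arthur met Deniz  today.
```

The bare '▁' emitted before 'Deniz' is absorbed into the preceding word boundary, but the '▁' emitted after it sits next to '▁today', yielding two spaces.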

@ArthurZucker
Collaborator

That is also expected; you should either use a slow tokenizer (use_fast=False) or follow the fix in #26678. It's a known issue with the normalizer.

@denizyuret
Author

use_fast=False has an interesting behavior:

>>> mtok = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1', use_fast=False)
>>> mtok.add_tokens([AddedToken('Deniz', normalized=False, special=False)])
>>> mtok.decode(mtok.encode('Arthur met Deniz today.'))
'<s>Arthur met  Deniz today.'
>>> mtok.convert_ids_to_tokens(mtok.encode('Arthur met Deniz today.'))
['<s>', '▁Arthur', '▁met', '▁', 'Deniz', '▁', '▁today', '.']

i.e. decode puts the extra space before the added token, whereas convert_ids_to_tokens shows the extra space after it.

I tried adding ' Deniz' and '▁Deniz', thinking they would better match the behavior of regular tokens like '▁Arthur', but with no success.

So far I haven't been able to find a combination of options that will preserve spacing. Will check out your fix next.

@ArthurZucker
Collaborator

This is the combination that works:

In [44]: mtok = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1', use_fast=False, legacy = False)

In [45]: mtok.tokenize('Arthur met Deniz today.')
Out[45]: ['▁Arthur', '▁met', '▁Den', 'iz', '▁today', '.']

In [46]: mtok.add_tokens([AddedToken('Deniz', normalized=False, special=False)])
Out[46]: 1

In [47]: mtok.tokenize('Arthur met Deniz today.')
Out[47]: ['▁Arthur', '▁met', '▁', 'Deniz', '▁today', '.']

@denizyuret
Author

mtok.decode still throws in some extra spaces, but this works for model training etc. Thanks!


github-actions bot commented Feb 7, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
