Update T5 tokenizer (adding additional tokens to tokenizer config) (#10972)

* initial commit

* restore t5_pretraining

* Apply isort and black reformatting

Signed-off-by: huvunvidia <huvunvidia@users.noreply.github.com>

---------

Signed-off-by: huvunvidia <huvunvidia@users.noreply.github.com>
Co-authored-by: Huy Vu2 <huvu@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: huvunvidia <huvunvidia@users.noreply.github.com>
3 people authored and yashaswikarnati committed Oct 24, 2024
1 parent c457d45 commit 9b3f602
Showing 6 changed files with 23 additions and 9 deletions.
11 changes: 10 additions & 1 deletion nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
@@ -13,7 +13,7 @@
# limitations under the License.

from collections import OrderedDict
from typing import Optional
from typing import List, Optional

from transformers import AutoTokenizer as AUTOTOKENIZER

@@ -43,6 +43,7 @@ def __init__(
sep_token: Optional[str] = None,
cls_token: Optional[str] = None,
unk_token: Optional[str] = None,
additional_special_tokens: Optional[List] = [],
use_fast: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
):
@@ -60,6 +61,7 @@ def __init__(
sep_token: token used for separating sequences
cls_token: class token. Usually equal to bos_token
unk_token: token to use for unknown tokens
additional_special_tokens: list of other tokens besides standard special tokens (bos, eos, pad, etc.). For example, sentinel tokens for T5 (<extra_id_0>, <extra_id_1>, etc.)
use_fast: whether to use fast HuggingFace tokenizer
"""
try:
@@ -124,10 +126,17 @@ def __init__(
elif self.tokenizer.cls_token is None and self.tokenizer.bos_token:
special_tokens_dict["cls_token"] = self.tokenizer.bos_token

# add additional special tokens (not standard special tokens such as bos, eod, sep)
if additional_special_tokens is not None:
special_tokens_dict["additional_special_tokens"] = additional_special_tokens

new_tokens_in_vocab = []
for token in [mask_token, bos_token, eos_token, pad_token, sep_token, cls_token, unk_token]:
if token is not None and token not in self.tokenizer.get_vocab():
new_tokens_in_vocab.append(token)
for token in additional_special_tokens:
if token is not None and token not in self.tokenizer.get_vocab():
new_tokens_in_vocab.append(token)

if len(new_tokens_in_vocab) > 0:
"""
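For context, a minimal usage sketch of the new argument (not part of this commit; the pretrained model name "t5-small" and the count of 100 sentinel tokens are assumptions for illustration):

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

# T5-style sentinel tokens: <extra_id_0> ... <extra_id_99> (count is an assumption)
sentinel_tokens = [f'<extra_id_{i}>' for i in range(100)]

tokenizer = AutoTokenizer(
    "t5-small",  # assumed pretrained model name
    additional_special_tokens=sentinel_tokens,  # registered as special tokens; added to the vocab if missing
)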
2 changes: 0 additions & 2 deletions nemo/collections/llm/t5/data/fine_tuning.py
@@ -61,8 +61,6 @@ def __init__(
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "BertWordPieceCase")
additional_tokens = {'additional_special_tokens': [f'<extra_id_{i}>' for i in range(100)]}
self.tokenizer.add_special_tokens(additional_tokens)

self.memmap_workers = memmap_workers
self.num_workers = num_workers
4 changes: 0 additions & 4 deletions nemo/collections/llm/t5/data/pre_training.py
@@ -130,10 +130,6 @@ def __init__(
# add additional tokens for T5 tokenizer
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "BertWordPieceCase")
additional_tokens = {'additional_special_tokens': [f'<extra_id_{i}>' for i in range(100)]}
self.tokenizer.add_special_tokens(additional_tokens)

self.data_sampler = MegatronDataSampler(
seq_len=self.seq_length,
micro_batch_size=micro_batch_size,
9 changes: 7 additions & 2 deletions nemo/collections/nlp/modules/common/tokenizer_utils.py
@@ -69,7 +69,8 @@ def get_tokenizer(
To see the list of all HuggingFace pretrained models, use:
nemo_nlp.modules.common.get_huggingface_pretrained_lm_models_list()
tokenizer_model: tokenizer model file of sentencepiece
special_tokens: dict of special tokens
special_tokens: dict of special tokens.
For additional special tokens besides standard special tokens (bos, eos, pad, etc.), such as sentinel tokens for T5 (<extra_id_0>, <extra_id_1>, etc.), use key 'additional_special_tokens'
vocab_file: path to vocab file
use_fast: (only for HuggingFace AutoTokenizer) set to True to use fast HuggingFace tokenizer
bpe_dropout: (experimental) BPE dropout tries to corrupt the standard segmentation
@@ -224,7 +225,11 @@ def get_nmt_tokenizer(
f'Getting Megatron tokenizer for pretrained model name: {model_name}, custom vocab file: {vocab_file}, and merges file: {merges_file}'
)
return get_tokenizer(
tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file, chat_template=chat_template
tokenizer_name=model_name,
vocab_file=vocab_file,
merges_file=merges_file,
special_tokens=special_tokens_dict,
chat_template=chat_template,
)
elif library == 'tabular':
from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer
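The megatron branch now forwards special_tokens into get_tokenizer, so T5 sentinel tokens can be supplied when the tokenizer is built. A minimal sketch of that call pattern, mirroring the updated tests below (the 100-token sentinel range is taken from those tests):

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# Sentinel tokens go under the 'additional_special_tokens' key of the special_tokens dict
special_tokens = {'additional_special_tokens': [f'<extra_id_{i}>' for i in range(100)]}

tokenizer = get_nmt_tokenizer(
    library="megatron",
    model_name="BertWordPieceCase",
    special_tokens=special_tokens,  # passed through as special_tokens_dict per the diff above
)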
3 changes: 3 additions & 0 deletions tests/collections/llm/megatron_t5_finetuning.py
@@ -35,9 +35,12 @@ def get_args():

args = get_args()

special_tokens = {}
special_tokens['additional_special_tokens'] = [f'<extra_id_{i}>' for i in range(100)]
tokenizer = get_nmt_tokenizer(
"megatron",
"BertWordPieceCase",
special_tokens=special_tokens,
)

data = SquadDataModule(
3 changes: 3 additions & 0 deletions tests/collections/llm/megatron_t5_pretraining.py
@@ -50,10 +50,13 @@ def get_args():

args = get_args()

special_tokens = {}
special_tokens['additional_special_tokens'] = [f'<extra_id_{i}>' for i in range(100)]
tokenizer = get_nmt_tokenizer(
"megatron",
"BertWordPieceCase",
vocab_file=args.vocab_path,
special_tokens=special_tokens,
)
data = PreTrainingDataModule(
paths=args.data_path,
