Adds Tiktoken tokenizer for Nemotron-Mistral 12B #9797

Merged · 153 commits · Jul 22, 2024

Commits
118ead1
Adding context- & expert-parallism to MegatronStrategy (#9525)
marcromeyn Jun 24, 2024
28eaa1b
Add CICD test for Stable Diffusion (#9464)
michal2409 Jun 24, 2024
d27d00f
Akoumparouli/nemo ux mixtral (#9446)
akoumpa Jun 24, 2024
d339062
update mcoreddp call (#9345)
akoumpa Jun 25, 2024
0c0752b
[NeMo-UX] Llama and Gemma (#9528)
cuichenx Jun 25, 2024
c5590d7
[NeMo-UX] minor logging bug fixes (#9529)
ashors1 Jun 25, 2024
01c8389
mcore distOpt restore fix (#9421)
akoumpa Jun 25, 2024
9f76e93
Custom Tiktoken tokenizer.
ertkonuk Jun 26, 2024
990a371
Fixed the tokenizer decoding on special tokens.
ertkonuk Jun 29, 2024
51e5743
Apply isort and black reformatting
ertkonuk Jul 18, 2024
84a6952
Added token_to_id() method.
ertkonuk Jul 19, 2024
996fdd1
Update neva conversion script from and to HF (#9296)
yaoyu-33 Jun 25, 2024
2c5bcd4
vLLM Export Support (#9381)
apanteleev Jun 25, 2024
b9cecab
PL: Delete precision if using plugin. TODO switch to MegatronTrainerB…
akoumpa Jun 25, 2024
1d9fd4d
Add page context fmha (#9526)
meatybobby Jun 25, 2024
d82018c
extend get_gpt_layer_modelopt_spec to support MoE (#9532)
akoumpa Jun 26, 2024
2d7c4f2
fix mock data generation for legacy dataset (#9530)
dimapihtar Jun 26, 2024
88f632d
[Nemo-UX] IO fixes (#9512)
marcromeyn Jun 26, 2024
21fea92
Test C++ runtime on demand in nemo_export.py to avoid possible OOMs (…
janekl Jun 26, 2024
57d6465
Fix lhotse tests for v1.24.2 (#9546)
pzelasko Jun 26, 2024
11fabac
gpu_unitTests_notOptional (#9551)
pablo-garay Jun 27, 2024
fe86da4
add reset learning rate functionality (#9372)
dimapihtar Jun 27, 2024
023fa71
Add Python AIStore SDK to container and bump min Lhotse version (#9537)
pzelasko Jun 27, 2024
1806cff
Adding 'use_dynamo' option for export to use onnx.dynamo_export() ins…
borisfom Jun 27, 2024
9dc51ef
[NeMo-UX] Fix tokenizer IO (#9555)
marcromeyn Jun 27, 2024
7f5cc82
[NeMo UX] Move mistral_7b.py to mistral.py (#9545)
akoumpa Jun 27, 2024
d7ac5e0
Use closed-formula to round by multiple (#9307)
akoumpa Jun 27, 2024
6535e17
ci: Do not attempt to send slack on fork (#9556)
ko3n1g Jun 27, 2024
146dcdc
Fix nemo export test (#9547)
oyilmaz-nvidia Jun 27, 2024
6161348
Fix SDXL incorrect name in docs (#9534)
suiyoubi Jun 27, 2024
da711d7
GPU unit tests: Mark flaky tests to be fixed (#9559)
pablo-garay Jun 27, 2024
825ab7e
Bump PTL version (#9557)
athitten Jun 28, 2024
8e43b3e
[Resiliency] Straggler detection (#9473)
jbieniusiewi Jun 28, 2024
cb049cc
switch to torch_dist as default dist checkpointing backend (#9541)
ashors1 Jun 28, 2024
bb5132f
[NeMo-UX] Checkpointing bug fixes (#9562)
ashors1 Jun 28, 2024
7182633
Add tps and pps params to the export script (#9558)
oyilmaz-nvidia Jun 28, 2024
e79908f
Consolidate gpt continue training script into pretraining script (#9413)
yaoyu-33 Jun 28, 2024
411e88c
Add support to change Multi task model prompt (#9542)
titu1994 Jun 28, 2024
094d5a2
Add Multimodal Exporter (#9256)
meatybobby Jun 28, 2024
b2cc3d9
Enable encoder adapters for Canary and MultiTaskAED models (#9409)
titu1994 Jun 29, 2024
e856c6a
pass option through (#9570)
maanug-nv Jul 1, 2024
e95f3c6
PTQ refinements (#9574)
janekl Jul 1, 2024
dcfd711
Audio model collection (#9263)
anteju Jul 1, 2024
144ed66
[NeMo-UX] Fix Trainer serialization (#9571)
marcromeyn Jul 1, 2024
7e998ae
Update click version requirement (#9580)
thomasdhc Jul 1, 2024
b97152d
[Fault tolerance] Heartbeat detection (#9352)
maanug-nv Jul 1, 2024
786ef6c
Add ModelOpt QAT example for Llama2 SFT model (#9326)
kevalmorabia97 Jul 1, 2024
6cba41e
Set TE flag in legacy -> mcore conversion script (#9585)
cuichenx Jul 1, 2024
4630e4f
[Nemo-UX] Add fabric-API for manual forward-pass (#9577)
marcromeyn Jul 2, 2024
c5a8ad2
[Nemo-UX] Add SDK-factories to llm-collection (#9589)
marcromeyn Jul 2, 2024
db6c8f1
Multimodal projection layer adapter fix for PP>1 (#9445)
paul-gibbons Jul 2, 2024
28129f8
Add offline quantization script for QLoRA deployment (#9455)
cuichenx Jul 2, 2024
1fc59b5
qlora support more models (#9488)
cuichenx Jul 2, 2024
131e8b3
[NeMo-UX] Some improvements to NeMoLogger (#9591)
marcromeyn Jul 2, 2024
d4d4841
Set n_gpu to None in nemo export (#9593)
oyilmaz-nvidia Jul 2, 2024
0499992
Inflight nemo model export support (#9527)
JimmyZhang12 Jul 3, 2024
896897f
vLLM Export Improvements (#9596)
apanteleev Jul 3, 2024
b8ec574
Set finalize_model_grads_func in on_fit_start instead to make sure it…
marcromeyn Jul 3, 2024
6fc68d6
Set no_sync_func & grad_sync_fucn (#9601)
akoumpa Jul 3, 2024
1a0edc1
small nemo logger bug fix (#9607)
ashors1 Jul 3, 2024
2371ed7
fix the dict format returned by scheduler method (#9609)
sararb Jul 3, 2024
1d4ddf2
[NeMo-UX] Dataloading enhancements and bug fixes (#9595)
ashors1 Jul 4, 2024
38564e4
Fix serialization of AutoResume (#9616)
sararb Jul 4, 2024
5b0730d
Chat template support for megatron_gpt_eval.py (#9354)
akoumpa Jul 4, 2024
07520fe
Jsonl support (#9611)
adityavavre Jul 4, 2024
2cab60a
[NeMo-UX] Add PEFT (#9490)
cuichenx Jul 5, 2024
b2e043b
Akoumparouli/mistral import instruct chat template fix (#9567)
akoumpa Jul 5, 2024
0c2e1f8
Remove .cuda calls, use device isntead (#9602)
akoumpa Jul 5, 2024
20282f5
fix converter defautl args (#9565)
akoumpa Jul 5, 2024
46bd64d
mixtral export (#9603)
akoumpa Jul 5, 2024
86b5434
fix: remove non_blocking from PTL's .cuda call (#9618)
akoumpa Jul 5, 2024
60204db
Alit/mamba tmp (#9612)
JRD971000 Jul 5, 2024
06949f8
TitaNet Batch Verify Speaker (#9337)
monica-sekoyan Jul 5, 2024
0e0a29d
Enable MCore checkpointing optimizations (#9505)
mikolajblaz Jul 5, 2024
544b8e8
Change mixtral moe key name for trt-llm (#9620)
oyilmaz-nvidia Jul 5, 2024
82c529f
fix ckpt load bug (#9621)
dimapihtar Jul 6, 2024
e79f049
NeVA Minor Fixes (#9608)
yaoyu-33 Jul 6, 2024
91be2cf
fix pretrianing data sizes and weights (#9627)
cuichenx Jul 6, 2024
b20d668
Alit/mamba (#9575)
JRD971000 Jul 6, 2024
faf89d2
[NeMo-UX] async checkpointing support (#9466)
ashors1 Jul 8, 2024
26a63e4
Fix the arguments of forward_for_export function in msdd_models (#9624)
tango4j Jul 8, 2024
be64d15
Change default parallel_save to False (#9632)
mikolajblaz Jul 8, 2024
55511e5
Unwrap ckpt_io for model opt (async save) (#9622)
mikolajblaz Jul 8, 2024
1d00f68
MCore T5 support for NeMo - Training (#9432)
huvunvidia Jul 8, 2024
fc8980e
[Nemo-UX] Expose transformer_layer_spec inside GPTConfig (#9592)
marcromeyn Jul 8, 2024
a9d6499
Update NeMo Clip to Use MCore Modules (#9594)
yaoyu-33 Jul 8, 2024
dc359cd
Add REST API to deploy module (#9539)
athitten Jul 8, 2024
1b8136f
Mistral + Mixtral Support for NeVa (#9459)
paul-gibbons Jul 8, 2024
1344ebf
ci: Timeout per step, not job (#9635)
ko3n1g Jul 8, 2024
8b433a5
Adding support for mcore generate (#9566)
shanmugamr1992 Jul 8, 2024
5bd2679
Improve error messaging during trt-llm export (#9638)
oyilmaz-nvidia Jul 8, 2024
0bbb2e2
Nemotron export - fixing megatron_export.py (#9625)
borisfom Jul 8, 2024
227647e
support lora when kv_channel != hidden_size / num_heads (#9636)
suiyoubi Jul 8, 2024
e195637
[Nemo CICD] Docker temp files auto-cleanup (#9642)
pablo-garay Jul 9, 2024
8bf1d0b
Update Dockerfile.ci (#9651)
huvunvidia Jul 9, 2024
5f402e8
SDXL improvements (and support for Draft+) [DRAFT PR] (#9543)
rohitrango Jul 9, 2024
8478599
Triton deployment improvements for in-framework models (#9600)
jukim-nv Jul 9, 2024
879b560
Use FP8 in GPT TP2 test (#9451)
jbaczek Jul 9, 2024
8e55110
enables default data step in megatron parallel to operate on a wider …
jomitchellnv Jul 10, 2024
2409503
Revert "enables default data step in megatron parallel to operate on …
marcromeyn Jul 10, 2024
651fb03
Contrastive Reranker/Reward model (#9171)
arendu Jul 10, 2024
836020a
unpin transformers version (#9606)
dimapihtar Jul 10, 2024
c9ee483
Added CPU offloading docs (#9479)
sanandaraj5597 Jul 10, 2024
ab4d89e
Update llama-3 PEFT notebook to download model from NGC (#9667)
shashank3959 Jul 10, 2024
9282795
fix pipeline parallel dtype bug (#9637) (#9661)
github-actions[bot] Jul 10, 2024
a3e6c2a
LITA integration (#9578)
Slyne Jul 11, 2024
2f68b46
Parametrize FPS group (#9648) (#9669)
github-actions[bot] Jul 11, 2024
45333e8
Huvu/mcore t5 (#9677)
huvunvidia Jul 11, 2024
4656a19
chore: Version bump NeMo (#9631)
ko3n1g Jul 11, 2024
abea8b5
add a bit more for timeout (#9702)
pablo-garay Jul 11, 2024
470cf45
Alit/mamba (#9696)
JRD971000 Jul 11, 2024
9531b94
NeMo performance feature documentation (#9482)
erhoo82 Jul 11, 2024
408b893
[TTS] Add fullband mel codec checkpoints (#9704)
rlangman Jul 11, 2024
c7b2ead
Adding support for mcore T5 Eval - SFT - PEFT (#9679)
huvunvidia Jul 12, 2024
3662b61
Allows non-strict load with distributed checkpoints (#9613) (#9715)
github-actions[bot] Jul 12, 2024
d0b648d
refactor: Uniform BRANCH for notebooks (#9710)
ko3n1g Jul 12, 2024
dc06818
fix legacy ds padding bug (#9716)
dimapihtar Jul 15, 2024
9287114
enables default data step in megatron parallel to operate on a wider …
jomitchellnv Jul 15, 2024
7c10575
[NeMo-UX] Fix when optimizers are setup for PEFT (#9619) (#9647)
github-actions[bot] Jul 15, 2024
e907667
refactor: README (#9712)
ko3n1g Jul 15, 2024
1b77c94
Remove mask if use fusion mask (#9723)
hsiehjackson Jul 15, 2024
879f4d9
[NeMo-UX] Fix imports so local configuration of runs works again (#96…
github-actions[bot] Jul 15, 2024
b84b97e
add contianer (#9731)
JRD971000 Jul 15, 2024
7e586fe
update pretrained model text (#9724) (#9745)
github-actions[bot] Jul 15, 2024
bfcc5ae
[Nemo-UX] Including all trainable-params in a PEFT-checkpoint (#9650)…
github-actions[bot] Jul 15, 2024
05db815
[NeMo-UX] Make TE and Apex dependencies optional (#9732)
ashors1 Jul 15, 2024
8d1b19a
[NeMo-UX] Minor bug fix when TE/Apex not installed (#9749)
ashors1 Jul 16, 2024
872554b
make 'load_directly_on_device' configurable (#9657) (#9674)
github-actions[bot] Jul 16, 2024
c4b76b5
TorchAudio installation workaround for incorrect `PYTORCH_VERSION` en…
github-actions[bot] Jul 16, 2024
5629117
Create __init__.py (#9755)
stevehuang52 Jul 16, 2024
56a4e8c
Canary Adapters tutorial (#9670)
titu1994 Jul 16, 2024
6971720
match nemo 1's default behavior for drop_last and pad_samples_to_glob…
github-actions[bot] Jul 17, 2024
f0f2f01
ci: Bump MCore tag (#9744)
ko3n1g Jul 17, 2024
32e8889
Fix the serialization of partial functions in nemo 2.0 (#9668)
sararb Jul 17, 2024
670da1d
ci: Add PAT to create-pullrequest action (#9769)
ko3n1g Jul 17, 2024
d95d3f6
Speeds up copying of necessary artifact files with SaveRestoreConnect…
terrykong Jul 17, 2024
865e2bd
ci: Remove ko3n1g from reviewers (#9773)
ko3n1g Jul 17, 2024
06cfacb
bump mcore commit in Dockerfile (#9766)
ashors1 Jul 17, 2024
65550c4
Yuya/add checkpoints section (#9329)
yaoyu-33 Jul 17, 2024
84fa8a3
Release automation (#9687)
ko3n1g Jul 18, 2024
b7e91c3
Rename speech dockerfile appropriately (#9778)
pablo-garay Jul 18, 2024
516ca1f
Add option to convert PyTriton response to OpenAI format (#9726)
athitten Jul 18, 2024
3b6a770
ci: Fix changelog-config (#9788)
ko3n1g Jul 18, 2024
b6daddd
Support configurable extra fields for LazyNeMoTarredIterator (#9548)
pzelasko Jul 19, 2024
d3bcefe
upper bound huggingface-hub version to 0.24.0 (exc.) (#9799)
akoumpa Jul 19, 2024
2c910ea
CodeQL fixes
akoumpa Jul 19, 2024
e51c8f0
import guard
akoumpa Jul 19, 2024
6a4f78f
add tiktoken to requirements
akoumpa Jul 19, 2024
f31a63e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2024
575c1b6
Apply isort and black reformatting
pre-commit-ci[bot] Jul 19, 2024
ad35944
Apply isort and black reformatting
ertkonuk Jul 19, 2024
ef1c7dd
Merge branch 'main' into tkonuk/tiktoken
akoumpa Jul 19, 2024
643fe07
Apply isort and black reformatting
akoumpa Jul 19, 2024
Files changed (5)
1 change: 1 addition & 0 deletions nemo/collections/common/tokenizers/__init__.py
@@ -19,6 +19,7 @@
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer

200 changes: 200 additions & 0 deletions nemo/collections/common/tokenizers/tiktoken_tokenizer.py
@@ -0,0 +1,200 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import json
import os
from pathlib import Path
from typing import Dict, List, Optional

try:
import tiktoken
except ImportError:
pass

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

__all__ = ['TiktokenTokenizer']


def reload_mergeable_ranks(
path: str,
max_vocab: Optional[int] = None,
) -> Dict[bytes, int]:
"""
Reload the tokenizer JSON file and convert it to Tiktoken format.
"""
assert path.endswith(".json")

# reload vocab
with open(path, "r") as f:
vocab = json.load(f)
assert isinstance(vocab, list)
print(f"Vocab size: {len(vocab)}")
if max_vocab is not None:
vocab = vocab[:max_vocab]
print(f"Cutting vocab to first {len(vocab)} tokens.")

# build ranks
ranks: Dict[bytes, int] = {}
for i, x in enumerate(vocab):
assert x.keys() == {"rank", "token_bytes", "token_str"}
assert x["rank"] == i
merge = base64.b64decode(x["token_bytes"])
assert i >= 256 or merge == bytes([i])
ranks[merge] = x["rank"]

# sanity check
assert len(ranks) == len(vocab)
assert set(ranks.values()) == set(range(len(ranks)))

return ranks
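# Illustrative sketch of the expected vocab JSON (inferred from the assertions
# above; not part of the original file): a list ordered by rank, e.g.
#     [{"rank": 0, "token_bytes": "AA==", "token_str": "\u0000"}, ...]
# where "token_bytes" is the base64-encoded token and the first 256 entries
# decode to the raw single bytes 0..255.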


PATTERN_TIKTOKEN = "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
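# A rough reading of the pattern above (illustrative, not from the original
# source): it splits text into case-aware word pieces (an optional leading
# non-letter character plus capitalized and/or lowercase letter runs), single
# numeric characters, punctuation runs, and newline/whitespace runs, using
# Unicode categories (\p{L} letters, \p{N} numbers, \p{Lu} uppercase, ...).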
DEFAULT_TIKTOKEN_MAX_VOCAB = 2**17 # 131072
SPECIAL_TOKENS = ["<unk>", "<s>", "</s>"]
SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>"


class TiktokenTokenizer(TokenizerSpec):
"""
    TiktokenTokenizer https://github.com/openai/tiktoken.

Args:
        vocab_file: path to the tokenizer vocabulary JSON file
        pattern: regex pattern used to split text before BPE
        vocab_size: total vocabulary size, including special tokens
        num_special_tokens: number of token ids reserved for special tokens
        special_tokens: list of special tokens (must include <unk>, <s>, </s>)
"""

def __init__(
self,
vocab_file: str,
pattern: str = PATTERN_TIKTOKEN,
vocab_size: int = DEFAULT_TIKTOKEN_MAX_VOCAB, # 131072
num_special_tokens: int = 1000,
special_tokens: Optional[List[str]] = None,
):
if not vocab_file or not os.path.exists(vocab_file):
raise ValueError(f"vocab_file: {vocab_file} is invalid")

if special_tokens is None:
special_tokens = SPECIAL_TOKENS.copy()

assert len(special_tokens) == len(set(special_tokens)), f"Special tokens should be unique: {special_tokens}"
assert len(special_tokens) <= num_special_tokens < vocab_size
assert set(SPECIAL_TOKENS) <= set(special_tokens), f"Custom special tokens should include {SPECIAL_TOKENS}"

self._unk_id = special_tokens.index("<unk>")
self._bos_id = special_tokens.index("<s>")
self._eos_id = special_tokens.index("</s>")

self._vocab_size = vocab_size
print(f'{self._vocab_size = }')
self.num_special_tokens = num_special_tokens
special_filler = [SPECIAL_TOKEN_TEMPLATE.format(id=i) for i in range(len(special_tokens), num_special_tokens)]
if special_filler:
print(f"Adding special tokens {special_filler[0]}, ..., {special_filler[-1]}")
self.special_tokens = special_tokens + special_filler
assert len(set(self.special_tokens)) == len(self.special_tokens) == num_special_tokens, self.special_tokens
self.inner_vocab_size = vocab_size - num_special_tokens

# reload vocab
self.token2id = reload_mergeable_ranks(vocab_file, max_vocab=self.inner_vocab_size)
self.id2token = {v: k for k, v in self.token2id.items()}
assert set(range(self.inner_vocab_size)) == set(self.id2token.keys())

self.shifted_id2token = {i: tok for i, tok in enumerate(self.special_tokens)}
for key, value in self.id2token.items():
self.shifted_id2token[key + self.num_special_tokens] = value

self.tokenizer = tiktoken.Encoding(
name=Path(vocab_file).parent.name,
pat_str=pattern,
mergeable_ranks=self.token2id,
special_tokens={}, # special tokens are handled manually
)

def text_to_tokens(self, text: str):
token_ids = self.tokenizer.encode(text)
return [self.tokenizer.decode_single_token_bytes(token) for token in token_ids]

    def tokens_to_text(self, tokens: List[bytes]):
        token_ids = [self.tokenizer.encode_single_token(token) for token in tokens]
return self.tokenizer.decode(token_ids)

def token_to_id(self, token):
return self.tokenizer.encode_single_token(token)

def tokens_to_ids(self, tokens):
return [self.tokenizer.encode_single_token(token) for token in tokens]

def ids_to_tokens(self, token_ids):
tokens = []
for token_id in token_ids:
if token_id < self.num_special_tokens:
tokens.append(self.special_tokens[token_id])
else:
token_id -= self.num_special_tokens
token_bytes = self.tokenizer.decode_single_token_bytes(token_id)
tokens.append(token_bytes.decode('utf-8', errors='replace'))
return tokens

def text_to_ids(self, text: str):
tokens = self.tokenizer.encode(text)
tokens = [t + self.num_special_tokens for t in tokens]
return tokens

def ids_to_text(self, tokens: List[int]):
# Filter out special tokens and adjust the remaining tokens
adjusted_tokens = [
t - self.num_special_tokens
for t in tokens
            if t not in {self.bos_id, self.eos_id} and t >= self.num_special_tokens
]

# Decode only if there are tokens left after filtering
if adjusted_tokens:
return self.tokenizer.decode(adjusted_tokens)
else:
return "" # Return an empty string if all tokens were filtered out

@property
def bos_id(self):
return self._bos_id

@property
def eos_id(self):
return self._eos_id

@property
def unk_id(self):
return self._unk_id

@property
def vocab(self):
return self.token2id

@property
def decoder(self):
return self.shifted_id2token

@property
def encoder(self):
return self.vocab

@property
def vocab_size(self) -> int:
return self._vocab_size
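
For reference, a minimal usage sketch of the class above (not part of the diff; the vocab path and sample text are hypothetical):

from nemo.collections.common.tokenizers import TiktokenTokenizer

# Build the tokenizer from a Nemotron-style vocab JSON.
tokenizer = TiktokenTokenizer(vocab_file="/path/to/nemotron_vocab.json")

# text_to_ids() shifts raw tiktoken ranks up by num_special_tokens (1000 by
# default), reserving ids 0..999 for <unk>, <s>, </s> and the <SPECIAL_{id}>
# fillers; ids_to_text() reverses the shift and drops special-token ids.
ids = tokenizer.text_to_ids("Hello, world!")
assert tokenizer.ids_to_text(ids) == "Hello, world!"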
5 changes: 5 additions & 0 deletions nemo/collections/nlp/modules/common/tokenizer_utils.py
@@ -22,6 +22,7 @@
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer
from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer
from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer
from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer
from nemo.collections.nlp.modules.common.huggingface.huggingface_utils import get_huggingface_pretrained_lm_models_list
from nemo.collections.nlp.modules.common.lm_utils import get_pretrained_lm_models_list
@@ -122,6 +123,8 @@ def get_tokenizer(
legacy=True,
chat_template=chat_template,
)
elif tokenizer_name == 'tiktoken':
        return TiktokenTokenizer(vocab_file=vocab_file)
elif tokenizer_name == 'word':
return WordTokenizer(vocab_file=vocab_file, **special_tokens_dict)
elif tokenizer_name == 'char':
@@ -221,6 +224,8 @@ def get_nmt_tokenizer(
)
elif library == 'tabular':
return TabularTokenizer(vocab_file, delimiter=delimiter)
elif library == 'tiktoken':
return TiktokenTokenizer(vocab_file=vocab_file)
else:
raise NotImplementedError(
'Currently we only support "huggingface", "sentencepiece", "megatron", and "byte-level" tokenizer'
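A sketch of selecting the new tokenizer by library name after this change (argument names follow the hunk above; the vocab path is hypothetical):

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# 'tiktoken' routes to TiktokenTokenizer(vocab_file=...) per the branch added above.
tokenizer = get_nmt_tokenizer(library='tiktoken', vocab_file='/path/to/nemotron_vocab.json')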
1 change: 0 additions & 1 deletion nemo/export/multimodal/run.py
@@ -80,7 +80,6 @@ def init_tokenizer(self, llm_engine_dir):

self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(llm_engine_dir, 'huggingface_tokenizer'))
self.tokenizer.pad_token = self.tokenizer.eos_token

if self.model_type == 'vita':
self.tokenizer.im_start_id = self.tokenizer.convert_tokens_to_ids("<extra_id_4>")
self.tokenizer.im_end_id = self.tokenizer.convert_tokens_to_ids("<extra_id_5>")
1 change: 1 addition & 0 deletions requirements/requirements_nlp.txt
@@ -20,4 +20,5 @@ rouge_score
sacrebleu # manually install sacrebleu[ja] for Japanese support; MeCab is unsupported in Python 3.11+
sentence_transformers
tensorstore<0.1.46
tiktoken==0.7.0
zarr