diff --git a/megatron/arguments.py b/megatron/arguments.py
index 631d4b12e8..be3a79e9e2 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -1273,6 +1273,7 @@ def _add_data_args(parser):
                                 'GPT2BPETokenizer',
                                 'SentencePieceTokenizer',
                                 'GPTSentencePieceTokenizer',
+                                'Llama2Tokenizer',
                                 'HFTokenizer',
                                 'NullTokenizer'],
                        help='What type of tokenizer to use.')
diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py
index 43c251bab1..643bc4f7cf 100644
--- a/megatron/tokenizer/tokenizer.py
+++ b/megatron/tokenizer/tokenizer.py
@@ -35,6 +35,9 @@ def build_tokenizer(args):
     elif args.tokenizer_type == 'GPTSentencePieceTokenizer':
         assert args.tokenizer_model is not None
         tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model)
+    elif args.tokenizer_type == 'Llama2Tokenizer':
+        assert args.tokenizer_model is not None
+        tokenizer = _Llama2Tokenizer(args.tokenizer_model)
     elif args.tokenizer_type == 'NullTokenizer':
         assert args.vocab_size is not None
         tokenizer = _NullTokenizer(args.vocab_size)
@@ -465,6 +468,7 @@ def mask(self):
     def additional_special_tokens_ids(self):
         return [self.vocab[k] for k in self._t5_tokens]
 
+
 class _GPTSentencePieceTokenizer(_SentencePieceTokenizer):
     """SentencePieceTokenizer-Megatron wrapper"""
 
@@ -504,6 +508,57 @@ def eod(self):
     def additional_special_tokens_ids(self):
         return None
 
+
+class _Llama2Tokenizer(_SentencePieceTokenizer):
+    """SentencePieceTokenizer-Megatron wrapper for the Llama-2 tokenizer"""
+
+    def __init__(self, model_file):
+        super().__init__(model_file, vocab_extra_ids=0)
+
+    def _initalize(self, vocab_extra_ids):  # (sic) overrides _SentencePieceTokenizer's hook name
+        self._populate_vocab()
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.tokenizer.vocab_size()
+        self.bos_id: int = self.tokenizer.bos_id()
+        self.eos_id: int = self.tokenizer.eos_id()
+        self.pad_id: int = self.tokenizer.pad_id()
+        assert self.tokenizer.vocab_size() == self.tokenizer.get_piece_size()
+
+    def tokenize(self, s: str, bos=True, eos=False):
+        '''Default args for text completion, not chat/dialog.'''
+        assert type(s) is str
+        t = self.tokenizer.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def detokenize(self, ids):
+        return self.tokenizer.decode_ids(ids)
+
+    @property
+    def cls(self):
+        return -1
+
+    @property
+    def sep(self):
+        return -1
+
+    @property
+    def mask(self):
+        return -1
+
+    @property
+    def eod(self):
+        return self.eos_id
+
+    @property
+    def additional_special_tokens_ids(self):
+        return None
+
+
 class _NullTokenizer:
     def __init__(self, vocab_size):
         vocab_size = int(vocab_size)
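
Usage sketch (not part of the patch): a minimal round trip with the new class, assuming a local Llama-2 tokenizer.model; the path below is a placeholder. In a real training run the tokenizer is built through build_tokenizer by passing --tokenizer-type Llama2Tokenizer and --tokenizer-model, not constructed directly.

    from megatron.tokenizer.tokenizer import _Llama2Tokenizer

    tok = _Llama2Tokenizer('/path/to/tokenizer.model')  # placeholder path

    # Default args prepend BOS and omit EOS (text completion, not dialog).
    ids = tok.tokenize('Hello, world!')
    assert ids[0] == tok.bos_id

    # SentencePiece drops control tokens such as BOS when decoding.
    print(tok.detokenize(ids))  # -> 'Hello, world!'

    # Megatron's end-of-document id maps to the SentencePiece EOS id.
    assert tok.eod == tok.eos_id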