Skip to content

Commit

Permalink
[Bugfix] Add missing attributes in mistral tokenizer (vllm-project#8364)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored and Jeffwan committed Sep 19, 2024
1 parent 8fa53f1 commit 6d25ecc
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 32 deletions.
7 changes: 5 additions & 2 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,14 @@ def apply_hf_chat_template(
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
messages: List[ChatCompletionMessageParam],
chat_template: Optional[str],
chat_template: Optional[str] = None,
**kwargs: Any,
) -> List[int]:
if chat_template is not None:
logger.warning(
"'chat_template' cannot be overridden for mistral tokenizer.")

return tokenizer.apply_chat_template(
messages=messages,
chat_template=chat_template,
**kwargs,
)
88 changes: 58 additions & 30 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,25 @@ class MistralTokenizer:
def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
self.mistral = tokenizer
self.instruct = tokenizer.instruct_tokenizer
self.tokenizer = tokenizer.instruct_tokenizer.tokenizer

self.vocab_size = len(self.tokenizer.vocab())

assert isinstance(self.tokenizer,
(Tekkenizer, SentencePieceTokenizer)), type(
self.tokenizer)

if (is_tekken := isinstance(self.tokenizer, Tekkenizer)):
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
if isinstance(tokenizer_, Tekkenizer):
# Make sure special tokens will not raise
self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE

self._is_tekken = is_tekken
tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE

self._vocab = {
token: idx
for idx, token in enumerate(tokenizer_.vocab())
}
elif isinstance(tokenizer_, SentencePieceTokenizer):
self._vocab = {
token: idx
for idx, token in enumerate(tokenizer_.vocab())
}
else:
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

# the following attributes are set to fit VLLM's design
self.is_fast = True
self.chat_template = True
self.all_special_ids: List[Any] = []
self.all_special_tokens: List[Any] = []
self.all_special_tokens_extended: List[Any] = []
self.tokenizer = tokenizer_

@classmethod
def from_pretrained(cls,
Expand Down Expand Up @@ -102,6 +101,38 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
revision=revision)
return tokenizer_file

# the following attributes are set to fit VLLM's design
@property
def all_special_tokens_extended(self) -> List[str]:
return []

@property
def all_special_tokens(self) -> List[str]:
return []

@property
def all_special_ids(self) -> List[int]:
return []

@property
def bos_token_id(self) -> int:
return self.tokenizer.bos_id

@property
def eos_token_id(self) -> int:
return self.tokenizer.eos_id

@property
def is_fast(self) -> bool:
return True

@property
def vocab_size(self) -> int:
return len(self._vocab)

def __len__(self) -> int:
return self.vocab_size

def __call__(
self,
prompt: str,
Expand All @@ -117,9 +148,12 @@ def __call__(

return Encoding(input_ids=input_ids)

def get_added_vocab(self) -> List[str]:
def get_vocab(self) -> Dict[str, int]:
return self._vocab

def get_added_vocab(self) -> Dict[str, int]:
# Mistral tokenizers have no added vocabulary
return []
return {}

def encode(self, prompt: str) -> List[int]:
# `encode` should only be used for prompt completion
Expand All @@ -141,7 +175,7 @@ def apply_chat_template(self,
return encoded.tokens

def convert_tokens_to_string(self, tokens: List[str]) -> str:
if self._is_tekken:
if isinstance(self.tokenizer, Tekkenizer):
return "".join(tokens)
else:
return self.tokenizer.decode(tokens) # type: ignore[arg-type]
Expand All @@ -151,14 +185,11 @@ def decode(self, ids: Union[List[int], int]) -> str:
ids = [ids]
return self.tokenizer.decode(ids)

@property
def eos_token_id(self):
return self.tokenizer.eos_id

def convert_ids_to_tokens(
self,
ids: List[int],
skip_special_tokens: Optional[bool] = True) -> List[str]:
self,
ids: List[int],
skip_special_tokens: bool = True,
) -> List[str]:
# TODO(Patrick) - potentially allow special tokens to not be skipped
assert (
skip_special_tokens
Expand All @@ -170,6 +201,3 @@ def convert_ids_to_tokens(

tokens = [self.tokenizer.id_to_piece(id) for id in ids]
return tokens

def __len__(self):
return self.vocab_size

0 comments on commit 6d25ecc

Please sign in to comment.