Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Add missing attributes in mistral tokenizer #8364

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,11 +513,14 @@ def apply_hf_chat_template(
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: List[ChatCompletionMessageParam],
    chat_template: Optional[str] = None,
    **kwargs: Any,
) -> List[int]:
    """Render chat ``messages`` into prompt token ids with a Mistral tokenizer.

    ``chat_template`` is accepted only for interface parity with
    ``apply_hf_chat_template``; Mistral tokenizers use their own built-in
    template, so a supplied value is ignored with a warning rather than
    forwarded.
    """
    if chat_template is not None:
        logger.warning(
            "'chat_template' cannot be overridden for mistral tokenizer.")

    return tokenizer.apply_chat_template(
        messages=messages,
        **kwargs,
    )
88 changes: 58 additions & 30 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,25 @@ class MistralTokenizer:
def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
    """Wrap a ``mistral_common`` tokenizer behind vLLM's tokenizer interface.

    Eagerly builds the token -> id mapping (``self._vocab``) and stores the
    raw underlying tokenizer on ``self.tokenizer``.

    Raises:
        TypeError: if the underlying tokenizer is neither a ``Tekkenizer``
            nor a ``SentencePieceTokenizer``.
    """
    self.mistral = tokenizer
    self.instruct = tokenizer.instruct_tokenizer

    tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
    if isinstance(tokenizer_, Tekkenizer):
        # Make sure special tokens will not raise
        tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE

        self._vocab = {
            token: idx
            for idx, token in enumerate(tokenizer_.vocab())
        }
    elif isinstance(tokenizer_, SentencePieceTokenizer):
        self._vocab = {
            token: idx
            for idx, token in enumerate(tokenizer_.vocab())
        }
    else:
        raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

    self.tokenizer = tokenizer_

@classmethod
def from_pretrained(cls,
Expand Down Expand Up @@ -102,6 +101,38 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
revision=revision)
return tokenizer_file

# the following attributes are set to fit VLLM's design
@property
def all_special_tokens_extended(self) -> List[str]:
    """Extended special tokens; always empty for Mistral tokenizers."""
    tokens: List[str] = []
    return tokens

@property
def all_special_tokens(self) -> List[str]:
    """Special token strings; always empty for Mistral tokenizers."""
    special: List[str] = []
    return special

@property
def all_special_ids(self) -> List[int]:
    """Special token ids; always empty for Mistral tokenizers."""
    ids: List[int] = []
    return ids

@property
def bos_token_id(self) -> int:
    """Beginning-of-sequence token id, taken from the underlying tokenizer."""
    underlying = self.tokenizer
    return underlying.bos_id
Comment on lines +117 to +119
Copy link
Member Author

@DarkLight1337 DarkLight1337 Sep 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While not reported in the linked issue, I added this attribute as the tokenizer would otherwise throw an error when it's required by LLMEngine._get_bos_token_id.


@property
def eos_token_id(self) -> int:
    """End-of-sequence token id, taken from the underlying tokenizer."""
    underlying = self.tokenizer
    return underlying.eos_id

@property
def is_fast(self) -> bool:
    # Always report as a "fast" tokenizer so vLLM treats this wrapper
    # like a HF fast tokenizer.
    return True

@property
def vocab_size(self) -> int:
    """Number of entries in the token -> id mapping."""
    vocab = self._vocab
    return len(vocab)

def __len__(self) -> int:
    # Mirror the vocabulary size so len(tokenizer) works as expected.
    return self.vocab_size

def __call__(
self,
prompt: str,
Expand All @@ -117,9 +148,12 @@ def __call__(

return Encoding(input_ids=input_ids)

def get_added_vocab(self) -> List[str]:
def get_vocab(self) -> Dict[str, int]:
    """Return the mapping from token string to token id."""
    vocab = self._vocab
    return vocab
def get_added_vocab(self) -> Dict[str, int]:
    """Return tokens added on top of the base vocabulary.

    Mistral tokenizers have no added vocabulary, so this is always an
    empty dict (matching the ``Dict[str, int]`` return annotation).
    """
    return {}

def encode(self, prompt: str) -> List[int]:
# `encode` should only be used for prompt completion
Expand All @@ -141,7 +175,7 @@ def apply_chat_template(self,
return encoded.tokens

def convert_tokens_to_string(self, tokens: List[str]) -> str:
    """Join decoded token pieces back into a single string.

    Tekken tokens are already text fragments and concatenate directly;
    SentencePiece tokens must go through the tokenizer's own decode.
    """
    if isinstance(self.tokenizer, Tekkenizer):
        return "".join(tokens)
    else:
        return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
Expand All @@ -151,14 +185,11 @@ def decode(self, ids: Union[List[int], int]) -> str:
ids = [ids]
return self.tokenizer.decode(ids)

@property
def eos_token_id(self):
return self.tokenizer.eos_id

def convert_ids_to_tokens(
self,
ids: List[int],
skip_special_tokens: Optional[bool] = True) -> List[str]:
self,
ids: List[int],
skip_special_tokens: bool = True,
) -> List[str]:
# TODO(Patrick) - potentially allow special tokens to not be skipped
assert (
skip_special_tokens
Expand All @@ -170,6 +201,3 @@ def convert_ids_to_tokens(

tokens = [self.tokenizer.id_to_piece(id) for id in ids]
return tokens

def __len__(self):
return self.vocab_size
Loading