From 82af104ebc2e00fa7b1389681a19077fe722af78 Mon Sep 17 00:00:00 2001
From: Qiong Zhou Huang
Date: Tue, 6 May 2025 23:37:56 +0000
Subject: [PATCH 1/2] Remove special tokens for Mistral when using bad words

Signed-off-by: Qiong Zhou Huang
---
 vllm/sampling_params.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 66a77681be9a..affc5c64b941 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -13,7 +13,6 @@
 from vllm.logger import init_logger
 from vllm.logits_process import LogitsProcessor
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 
 logger = init_logger(__name__)
 
@@ -491,13 +490,8 @@ def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
             for add_prefix_space in [False, True]:
                 prefix = " " if add_prefix_space else ""
                 prompt = prefix + bad_word.lstrip()
-
-                if isinstance(tokenizer, MistralTokenizer):
-                    # Mistral tokenizers should not add special tokens
-                    prompt_token_ids = tokenizer.encode(text=prompt)
-                else:
-                    prompt_token_ids = tokenizer.encode(
-                        text=prompt, add_special_tokens=False)
+                prompt_token_ids = tokenizer.encode(text=prompt,
+                                                    add_special_tokens=False)
 
                 # If no space at the beginning
                 # or if prefix space produces a new word token

From 601f6ec2eb19e786f0852d429e4a82896d051b0e Mon Sep 17 00:00:00 2001
From: Qiong Zhou Huang
Date: Wed, 7 May 2025 16:16:35 +0000
Subject: [PATCH 2/2] Set add_special_tokens=False for Mistral in
 logits_process as well

Signed-off-by: Qiong Zhou Huang
---
 vllm/logits_process.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/vllm/logits_process.py b/vllm/logits_process.py
index e3faf20029ec..29a73656bf65 100644
--- a/vllm/logits_process.py
+++ b/vllm/logits_process.py
@@ -4,11 +4,12 @@
 
 import torch
 
-from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 
-LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
-                        Callable[[list[int], list[int], torch.Tensor],
-                                 torch.Tensor]]
+LogitsProcessor = Union[
+    Callable[[list[int], torch.Tensor], torch.Tensor],
+    Callable[[list[int], list[int], torch.Tensor], torch.Tensor],
+]
 """LogitsProcessor is a function that takes a list
 of previously generated tokens, the logits tensor
 for the next token and, optionally, prompt tokens as a
@@ -29,12 +30,8 @@ def get_bad_words_logits_processors(
             prefix = " " if add_prefix_space else ""
             prompt = prefix + bad_word.lstrip()
 
-            if isinstance(tokenizer, MistralTokenizer):
-                # Mistral tokenizers should not add special tokens
-                prompt_token_ids = tokenizer.encode(text=prompt)
-            else:
-                prompt_token_ids = tokenizer.encode(text=prompt,
-                                                    add_special_tokens=False)
+            prompt_token_ids = tokenizer.encode(text=prompt,
+                                                add_special_tokens=False)
 
             # If no space at the beginning
             # or if prefix space produces a new word token
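
Note: the sketch below illustrates the bad-words encoding path that both hunks now share, outside of vLLM. It is a minimal example assuming a Hugging Face-style tokenizer whose encode() accepts add_special_tokens; the model name, the sample bad word, and the printout are illustrative and not part of the patch (the gated Mistral repo may require Hugging Face access).

    # Minimal sketch of the shared bad-words encoding path after this patch.
    # Assumes a Hugging Face-style tokenizer; the model name is illustrative.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

    bad_word = "badword"  # illustrative bad word
    for add_prefix_space in [False, True]:
        prefix = " " if add_prefix_space else ""
        prompt = prefix + bad_word.lstrip()
        # add_special_tokens=False keeps BOS/EOS ids out of the banned
        # sequence, so only the bad word's own tokens are matched later.
        prompt_token_ids = tokenizer.encode(text=prompt,
                                            add_special_tokens=False)
        print(add_prefix_space, prompt_token_ids)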