Skip to content

Commit 115929e

Browse files
qionghuang6mawong-amd
authored andcommitted
[Bugfix] Fix bad words for Mistral models (vllm-project#17753)
Signed-off-by: Qiong Zhou Huang <qiong@phonic.co>
1 parent e24ade4 commit 115929e

File tree

2 files changed

+9
-18
lines changed

2 files changed

+9
-18
lines changed

vllm/logits_process.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44

55
import torch
66

7-
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
7+
from vllm.transformers_utils.tokenizer import AnyTokenizer
88

9-
LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
10-
Callable[[list[int], list[int], torch.Tensor],
11-
torch.Tensor]]
9+
LogitsProcessor = Union[
10+
Callable[[list[int], torch.Tensor], torch.Tensor],
11+
Callable[[list[int], list[int], torch.Tensor], torch.Tensor],
12+
]
1213
"""LogitsProcessor is a function that takes a list
1314
of previously generated tokens, the logits tensor
1415
for the next token and, optionally, prompt tokens as a
@@ -29,12 +30,8 @@ def get_bad_words_logits_processors(
2930
prefix = " " if add_prefix_space else ""
3031
prompt = prefix + bad_word.lstrip()
3132

32-
if isinstance(tokenizer, MistralTokenizer):
33-
# Mistral tokenizers should not add special tokens
34-
prompt_token_ids = tokenizer.encode(text=prompt)
35-
else:
36-
prompt_token_ids = tokenizer.encode(text=prompt,
37-
add_special_tokens=False)
33+
prompt_token_ids = tokenizer.encode(text=prompt,
34+
add_special_tokens=False)
3835

3936
# If no space at the beginning
4037
# or if prefix space produces a new word token

vllm/sampling_params.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from vllm.logger import init_logger
1414
from vllm.logits_process import LogitsProcessor
1515
from vllm.transformers_utils.tokenizer import AnyTokenizer
16-
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
1716

1817
logger = init_logger(__name__)
1918

@@ -491,13 +490,8 @@ def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
491490
for add_prefix_space in [False, True]:
492491
prefix = " " if add_prefix_space else ""
493492
prompt = prefix + bad_word.lstrip()
494-
495-
if isinstance(tokenizer, MistralTokenizer):
496-
# Mistral tokenizers should not add special tokens
497-
prompt_token_ids = tokenizer.encode(text=prompt)
498-
else:
499-
prompt_token_ids = tokenizer.encode(
500-
text=prompt, add_special_tokens=False)
493+
prompt_token_ids = tokenizer.encode(text=prompt,
494+
add_special_tokens=False)
501495

502496
# If no space at the beginning
503497
# or if prefix space produces a new word token

0 commit comments

Comments
 (0)