44
55import  torch 
66
7- from  vllm .transformers_utils .tokenizer  import  AnyTokenizer ,  MistralTokenizer 
7+ from  vllm .transformers_utils .tokenizer  import  AnyTokenizer 
88
9- LogitsProcessor  =  Union [Callable [[list [int ], torch .Tensor ], torch .Tensor ],
10-                         Callable [[list [int ], list [int ], torch .Tensor ],
11-                                  torch .Tensor ]]
9+ LogitsProcessor  =  Union [
10+     Callable [[list [int ], torch .Tensor ], torch .Tensor ],
11+     Callable [[list [int ], list [int ], torch .Tensor ], torch .Tensor ],
12+ ]
1213"""LogitsProcessor is a function that takes a list 
1314of previously generated tokens, the logits tensor 
1415for the next token and, optionally, prompt tokens as a 
@@ -29,12 +30,8 @@ def get_bad_words_logits_processors(
2930            prefix  =  " "  if  add_prefix_space  else  "" 
3031            prompt  =  prefix  +  bad_word .lstrip ()
3132
32-             if  isinstance (tokenizer , MistralTokenizer ):
33-                 # Mistral tokenizers should not add special tokens 
34-                 prompt_token_ids  =  tokenizer .encode (text = prompt )
35-             else :
36-                 prompt_token_ids  =  tokenizer .encode (text = prompt ,
37-                                                     add_special_tokens = False )
33+             prompt_token_ids  =  tokenizer .encode (text = prompt ,
34+                                                 add_special_tokens = False )
3835
3936            # If no space at the beginning 
4037            # or if prefix space produces a new word token 
0 commit comments