Commit ce4d0af

align tokenizer interface

Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
youngkent committed Jan 29, 2025
1 parent b6c37e3 commit ce4d0af
Showing 7 changed files with 28 additions and 14 deletions.
4 changes: 1 addition & 3 deletions vllm/entrypoints/llm.py
@@ -1105,9 +1105,7 @@ def ensure_str(prompt: SingletonPrompt):
         parsed_prompts = []

         for q, t in input_pairs:
-            prompt_inputs = tokenizer(text=q,
-                                      text_pair=t,
-                                      **tokenization_kwargs)
+            prompt_inputs = tokenizer(q, text_pair=t, **tokenization_kwargs)
             engine_prompt = TokensPrompt(
                 prompt_token_ids=prompt_inputs["input_ids"],
                 token_type_ids=prompt_inputs.get("token_type_ids"))
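
This call-site change is the crux of the alignment: `text=` is the Hugging Face keyword name, while vLLM's `TokenizerBase.__call__` names its first parameter `prompt`, so only a positional first argument works across both families. A minimal runnable sketch of the idea, using a hypothetical stub rather than a real tokenizer:

from typing import List, Optional

class StubTokenizer:
    """Hypothetical stand-in mirroring the aligned __call__ signature."""

    def __call__(self,
                 prompt: str,
                 text_pair: Optional[str] = None,
                 add_special_tokens: bool = False,
                 truncation: bool = False,
                 max_length: Optional[int] = None) -> dict:
        ids: List[int] = [ord(c) for c in prompt]
        if text_pair is not None:
            ids += [0] + [ord(c) for c in text_pair]  # 0 as a toy separator id
        return {"input_ids": ids}

tokenizer = StubTokenizer()
# A positional first argument works whether the parameter is named
# `text` (HF) or `prompt` (TokenizerBase); `text=q` would not.
prompt_inputs = tokenizer("query", text_pair="document")
print(prompt_inputs["input_ids"])
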
3 changes: 1 addition & 2 deletions vllm/entrypoints/openai/serving_engine.py
@@ -398,8 +398,7 @@ async def _preprocess_chat(
         _chat_template_kwargs.update(chat_template_kwargs or {})

         request_prompt: Union[str, List[int]]
-        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
-        if is_mistral_tokenizer:
+        if isinstance(tokenizer, MistralTokenizer):
             request_prompt = apply_mistral_chat_template(
                 tokenizer,
                 messages=messages,
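
Beyond brevity, inlining the check can matter for static analysis: some type checkers (mypy among them) narrow the type of `tokenizer` inside the branch only when the `isinstance` call appears directly in the `if` condition, not through an intermediate boolean. A small sketch with hypothetical stand-in classes:

from typing import Union

class MistralLike:
    def apply_chat_template(self, messages: list) -> str:
        return "prompt"

class HFLike:
    def encode(self, text: str) -> list:
        return [ord(c) for c in text]

def preprocess(tokenizer: Union[MistralLike, HFLike]) -> None:
    if isinstance(tokenizer, MistralLike):
        # The checker narrows `tokenizer` to MistralLike here, so the
        # Mistral-only method is known to exist.
        print(tokenizer.apply_chat_template([]))
    else:
        print(tokenizer.encode("hello"))

preprocess(MistralLike())
preprocess(HFLike())
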
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_score.py
@@ -119,7 +119,7 @@ async def create_score(

         tokenize_async = make_async(tokenizer.__call__,
                                     executor=self._tokenizer_executor)
-        prompt_inputs = await tokenize_async(text=q,
+        prompt_inputs = await tokenize_async(q,
                                              text_pair=t,
                                              **tokenization_kwargs)
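
`make_async` offloads the blocking `tokenizer.__call__` onto an executor so the event loop is not stalled; the fix here only changes the first argument from `text=q` to positional `q`, mirroring llm.py. A self-contained sketch of the same offloading pattern using the standard library (the `tokenize` function is a stand-in for the real tokenizer):

import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def tokenize(q: str, text_pair: str = "") -> dict:
    # Stand-in for a blocking tokenizer.__call__.
    return {"input_ids": [ord(c) for c in q + text_pair]}

async def main() -> None:
    loop = asyncio.get_running_loop()
    executor = ThreadPoolExecutor(max_workers=1)
    # run_in_executor forwards only positional arguments, so keywords
    # must be bound with functools.partial first.
    prompt_inputs = await loop.run_in_executor(
        executor, partial(tokenize, "query", text_pair="document"))
    print(prompt_inputs["input_ids"])

asyncio.run(main())
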
2 changes: 1 addition & 1 deletion vllm/logits_process.py
@@ -29,7 +29,7 @@ def get_bad_words_logits_processors(

             if isinstance(tokenizer, MistralTokenizer):
                 # Mistral tokenizers should not add special tokens
-                prompt_token_ids = tokenizer.encode(prompt=prompt)
+                prompt_token_ids = tokenizer.encode(text=prompt)
             else:
                 prompt_token_ids = tokenizer.encode(text=prompt,
                                                     add_special_tokens=False)
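
With `prompt=` renamed to `text=`, both branches now address the first parameter by the same keyword; only the `add_special_tokens` flag differs, because Mistral tokenizers manage special tokens themselves. A toy sketch of the now-uniform call shape (both classes are illustrative, not vLLM's):

from typing import List, Optional

class MistralLikeTokenizer:
    def encode(self, text: str,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        # Special tokens are handled internally; the flag is ignored.
        return [ord(c) for c in text]

class HFLikeTokenizer:
    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        ids = [ord(c) for c in text]
        return [101] + ids + [102] if add_special_tokens else ids

print(MistralLikeTokenizer().encode(text="bad word"))
print(HFLikeTokenizer().encode(text="bad word", add_special_tokens=False))
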
2 changes: 2 additions & 0 deletions vllm/transformers_utils/tokenizer.py
@@ -183,6 +183,8 @@ def get_tokenizer(
             'encoding and decoding.',
             FutureWarning,
             stacklevel=2)
+
+    tokenizer: AnyTokenizer
     if tokenizer_mode == "mistral":
         tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                      revision=revision)
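
The bare `tokenizer: AnyTokenizer` annotation exists for the type checker: without it, mypy pins the variable to the type of its first assignment (`MistralTokenizer`) and rejects the other branch. A reduced sketch of the pattern with hypothetical classes:

from typing import Union

class MistralTok: ...
class CachedTok: ...

AnyTok = Union[MistralTok, CachedTok]

def get_tok(mode: str) -> AnyTok:
    tok: AnyTok  # declare the union up front so both branches type-check
    if mode == "mistral":
        tok = MistralTok()
    else:
        tok = CachedTok()
    return tok

print(type(get_tok("mistral")).__name__)
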
14 changes: 11 additions & 3 deletions vllm/transformers_utils/tokenizer_base.py
@@ -48,6 +48,11 @@ def vocab_size(self) -> int:
     def max_token_id(self) -> int:
         raise NotImplementedError()

+    @property
+    @abstractmethod
+    def sep_token(self) -> int:
+        raise NotImplementedError()
+
     @abstractmethod
     def __len__(self) -> int:
         raise NotImplementedError()
@@ -56,6 +61,7 @@ def __len__(self) -> int:
     def __call__(
         self,
         prompt: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
         max_length: Optional[int] = None,
@@ -73,14 +79,16 @@ def get_added_vocab(self) -> Dict[str, int]:
     @abstractmethod
     def encode_one(
         self,
-        prompt: str,
+        text: str,
         truncation: bool = False,
         max_length: Optional[int] = None,
     ) -> List[int]:
         raise NotImplementedError()

     @abstractmethod
-    def encode(self, prompt: str) -> List[int]:
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
         raise NotImplementedError()

     @abstractmethod
@@ -114,7 +122,7 @@ class TokenizerRegistry:
     REGISTRY: Dict[str, Tuple[str, str]] = {}

     @staticmethod
-    def register(name: str, module: str, class_name: str) -> TokenizerBase:
+    def register(name: str, module: str, class_name: str) -> None:
         TokenizerRegistry.REGISTRY[name] = (module, class_name)

     @staticmethod
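
Together these changes make `TokenizerBase` mirror the Hugging Face surface (`text`, `text_pair`, `add_special_tokens`) closely enough that call sites no longer need per-family keywords, and they correct the annotation on `register`, which never returned a tokenizer. A minimal sketch of a custom tokenizer written against the updated abstract surface; `WhitespaceTokenizer` is hypothetical, not part of vLLM:

from typing import List, Optional

class WhitespaceTokenizer:  # would subclass TokenizerBase in vLLM
    @property
    def sep_token(self) -> int:
        raise NotImplementedError()  # no separator token in this toy

    def encode(self,
               text: str,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        # The flag is accepted to satisfy the shared signature even
        # though this toy tokenizer has no special tokens to add.
        vocab: dict = {}
        return [vocab.setdefault(w, len(vocab)) for w in text.split()]

print(WhitespaceTokenizer().encode(text="align the tokenizer interface"))
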
15 changes: 11 additions & 4 deletions vllm/transformers_utils/tokenizers/mistral.py
@@ -220,12 +220,17 @@ def vocab_size(self) -> int:
     def max_token_id(self) -> int:
         return self._max_token_id

+    @property
+    def sep_token(self) -> int:
+        raise NotImplementedError()
+
     def __len__(self) -> int:
         return self.vocab_size

     def __call__(
         self,
         prompt: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
         max_length: Optional[int] = None,
@@ -257,22 +262,24 @@ def get_added_vocab(self) -> Dict[str, int]:

     def encode_one(
         self,
-        prompt: str,
+        text: str,
         truncation: bool = False,
         max_length: Optional[int] = None,
     ) -> List[int]:
         # Mistral Tokenizers should not add special tokens
-        input_ids = self.encode(prompt)
+        input_ids = self.encode(text)

         if truncation:
             input_ids = input_ids[:max_length]
         return input_ids

-    def encode(self, prompt: str) -> List[int]:
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
         # `encode` should only be used for prompt completion
         # it should never be used for chat_completion.
         # For chat completion use `apply_chat_template`
-        return self.tokenizer.encode(prompt, bos=True, eos=False)
+        return self.tokenizer.encode(text, bos=True, eos=False)

     def apply_chat_template(self,
                             messages: List["ChatCompletionMessageParam"],
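
Two details are worth noting: `sep_token` is left unimplemented because Mistral models define no SEP token, and `encode` accepts `add_special_tokens` purely for signature compatibility while always emitting BOS and never EOS. A sketch of how a caller might tolerate the unimplemented property (the class is a stand-in, not vLLM's):

class MistralLike:
    @property
    def sep_token(self) -> int:
        raise NotImplementedError()

    def encode(self, text: str, add_special_tokens=None) -> list:
        # BOS (id 1) is always prepended, mirroring bos=True, eos=False.
        return [1] + [ord(c) for c in text]

tok = MistralLike()
try:
    sep = tok.sep_token
except NotImplementedError:
    sep = None  # fall back when the tokenizer defines no SEP token
print(sep, tok.encode("hi"))
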
