4 changes: 2 additions & 2 deletions tests/entrypoints/openai/test_chat_template.py
@@ -122,10 +122,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,

# Call the function and get the result
result = apply_hf_chat_template(
model_config,
tokenizer,
tokenizer=tokenizer,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,
model_config=model_config,
tools=None,
add_generation_prompt=mock_request.add_generation_prompt,
continue_final_message=mock_request.continue_final_message,
10 changes: 5 additions & 5 deletions tests/entrypoints/test_chat_utils.py
@@ -793,10 +793,10 @@ def get_conversation(is_hf: bool):
)

vllm_result = apply_hf_chat_template(
model_config,
tokenizer,
tokenizer=tokenizer,
conversation=conversation,
chat_template=None,
model_config=model_config,
tools=None,
add_generation_prompt=True,
)
@@ -903,11 +903,11 @@ def test_resolve_content_format_hf_defined(model, expected_format):
print(_try_extract_ast(chat_template))

resolved_format = resolve_chat_template_content_format(
model_config,
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
model_config=model_config,
)

assert resolved_format == expected_format
@@ -962,11 +962,11 @@ def test_resolve_content_format_fallbacks(model, expected_format):
print(_try_extract_ast(chat_template))

resolved_format = resolve_chat_template_content_format(
model_config,
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
model_config=model_config,
)

assert resolved_format == expected_format
@@ -1021,11 +1021,11 @@ def test_resolve_content_format_examples(template_path, expected_format):
print(_try_extract_ast(chat_template))

resolved_format = resolve_chat_template_content_format(
model_config,
chat_template,
None,
"auto",
dummy_tokenizer,
model_config=model_config,
)

assert resolved_format == expected_format
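
The test updates above reflect the calling-convention change made in this PR: `model_config` moves from the first positional parameter to a keyword-only parameter of `apply_hf_chat_template`, `resolve_hf_chat_template`, and `resolve_chat_template_content_format`. A minimal before/after sketch of a call site (the variable names stand in for whatever the caller already has in scope):

# Before this change: model_config was passed positionally.
result = apply_hf_chat_template(
    model_config,
    tokenizer,
    conversation=conversation,
    chat_template=None,
    tools=None,
    add_generation_prompt=True,
)

# After this change: model_config must be passed by keyword;
# passing it positionally now raises a TypeError.
result = apply_hf_chat_template(
    tokenizer,
    conversation=conversation,
    chat_template=None,
    tools=None,
    add_generation_prompt=True,
    model_config=model_config,
)
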
35 changes: 27 additions & 8 deletions vllm/entrypoints/chat_utils.py
@@ -44,7 +44,7 @@
# yapf: enable
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
from vllm.utils import deprecate_kwargs, random_uuid

logger = init_logger(__name__)

@@ -329,11 +329,17 @@ def resolve_mistral_chat_template(
"so it will be ignored.")
return None

@deprecate_kwargs(
"trust_remote_code",
additional_message="Please use `model_config.trust_remote_code` instead.",
)
def resolve_hf_chat_template(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
model_config: ModelConfig,
trust_remote_code: Optional[bool] = None,
) -> Optional[str]:
# 1st priority: The given chat template
if chat_template is not None:
@@ -379,18 +385,19 @@ def resolve_hf_chat_template(


def _resolve_chat_template_content_format(
model_config: ModelConfig,
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
hf_chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)
else:
hf_chat_template = None
@@ -428,19 +435,25 @@ def _log_chat_template_content_format(
)


@deprecate_kwargs(
"trust_remote_code",
additional_message="Please use `model_config.trust_remote_code` instead.",
)
def resolve_chat_template_content_format(
model_config: ModelConfig,
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
model_config: ModelConfig,
trust_remote_code: Optional[bool] = None,
) -> _ChatTemplateContentFormat:
detected_format = _resolve_chat_template_content_format(
model_config,
chat_template,
tools,
given_format,
tokenizer,
model_config=model_config,
)

_log_chat_template_content_format(
@@ -1191,21 +1204,27 @@ def parse_chat_messages_futures(
return conversation, mm_tracker.all_mm_data()


@deprecate_kwargs(
"trust_remote_code",
additional_message="Please use `model_config.trust_remote_code` instead.",
)
def apply_hf_chat_template(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
model_config: ModelConfig,
tokenize: bool = False, # Different from HF's default
# Deprecated, explicitly capture here so it doesn't slip into kwargs.
trust_remote_code: Optional[bool] = None,
**kwargs: Any,
) -> str:
hf_chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)

if hf_chat_template is None:
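
The `deprecate_kwargs` helper imported at the top of this file comes from `vllm.utils`; its implementation is not part of this diff. As a rough, illustrative sketch only (not the actual vLLM code), a decorator of this shape would emit a `DeprecationWarning` whenever a caller still passes one of the named keyword arguments:

import functools
import warnings
from typing import Any, Callable, Optional


def deprecate_kwargs(*names: str,
                     additional_message: Optional[str] = None) -> Callable:
    """Illustrative sketch: warn when a deprecated keyword argument is passed."""

    def decorator(func: Callable) -> Callable:

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for name in names:
                if kwargs.get(name) is not None:
                    msg = f"The keyword argument {name!r} is deprecated."
                    if additional_message is not None:
                        msg += f" {additional_message}"
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
            # The deprecated value is otherwise ignored; the wrapped function
            # keeps an explicit parameter for it so the call still binds.
            return func(*args, **kwargs)

        return wrapper

    return decorator

This is why the decorated functions above retain an explicit `trust_remote_code: Optional[bool] = None` parameter: the deprecated keyword is still accepted (and ignored) rather than raising a TypeError, and in `apply_hf_chat_template` it is captured before it can slip into `**kwargs`.
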
6 changes: 3 additions & 3 deletions vllm/entrypoints/llm.py
@@ -731,11 +731,11 @@ def chat(
tokenizer = self.get_tokenizer(lora_request)
model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
model_config,
chat_template,
tools,
chat_template_content_format,
tokenizer,
model_config=model_config,
)

_chat_template_kwargs: dict[str, Any] = dict(
@@ -767,9 +767,9 @@
)
else:
prompt_str = apply_hf_chat_template(
model_config,
tokenizer,
tokenizer=tokenizer,
conversation=conversation,
model_config=model_config,
**_chat_template_kwargs,
)
# Special tokens are already included in chat templates so
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/api_server.py
@@ -971,10 +971,10 @@ async def init_app_state(
chat_template=resolved_chat_template)
else:
hf_chat_template = resolve_hf_chat_template(
vllm_config.model_config,
tokenizer,
tokenizer=tokenizer,
chat_template=None,
tools=None,
model_config=vllm_config.model_config,
)

if hf_chat_template != resolved_chat_template:
6 changes: 3 additions & 3 deletions vllm/entrypoints/openai/serving_engine.py
@@ -670,11 +670,11 @@ async def _preprocess_chat(
model_config = self.model_config

resolved_content_format = resolve_chat_template_content_format(
model_config,
chat_template,
tool_dicts,
chat_template_content_format,
tokenizer,
model_config=model_config,
)
conversation, mm_data_future = parse_chat_messages_futures(
messages,
@@ -701,9 +701,9 @@
)
else:
request_prompt = apply_hf_chat_template(
model_config,
tokenizer,
tokenizer=tokenizer,
conversation=conversation,
model_config=model_config,
**_chat_template_kwargs,
)
