From 27abf442dbd126d7f593e31ccfef27105cb5a49a Mon Sep 17 00:00:00 2001
From: Vinay R Damodaran
Date: Thu, 24 Oct 2024 01:05:49 -0400
Subject: [PATCH] [Bugfix]: Make chat content text allow type content (#9358)

Signed-off-by: Vinay Damodaran
Signed-off-by: Erkin Sagiroglu
---
 .../serving/openai_compatible_server.md       | 17 +++++++
 tests/entrypoints/openai/test_serving_chat.py |  1 +
 tests/entrypoints/test_chat_utils.py          | 48 ++++++++++++++++++-
 vllm/config.py                                |  2 +
 vllm/engine/arg_utils.py                      | 10 ++++
 vllm/engine/llm_engine.py                     |  3 +-
 vllm/entrypoints/chat_utils.py                | 31 ++++++++----
 vllm/entrypoints/openai/serving_chat.py       |  7 ++-
 8 files changed, 107 insertions(+), 12 deletions(-)

diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md
index cc8e539a8a6d3..413c87ab28755 100644
--- a/docs/source/serving/openai_compatible_server.md
+++ b/docs/source/serving/openai_compatible_server.md
@@ -103,6 +103,23 @@ vllm serve <model> --chat-template ./path-to-chat-template.jinja
 ```
 
 vLLM community provides a set of chat templates for popular models. You can find them in the examples directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
 
+With the inclusion of multi-modal chat APIs, the OpenAI spec now accepts chat messages in a new format that specifies
+both a `type` and a `text` field. An example is provided below:
+```python
+completion = client.chat.completions.create(
+    model="NousResearch/Meta-Llama-3-8B-Instruct",
+    messages=[
+        {"role": "user", "content": [{"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}]}
+    ]
+)
+```
+Most chat templates for LLMs expect the `content` field to be a string, but there are some newer models like
+`meta-llama/Llama-Guard-3-1B` that expect the content to be formatted according to the new OpenAI spec. To select
+which format vLLM should use to parse the content, set the `--chat-template-text-format` argument to either
+`string` or `openai`. The default value is `string`; unless `openai` is explicitly specified, vLLM internally
+converts both spec formats to match the `string` format.
+
+
 ## Command line arguments for the server
 
 ```{argparse}
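For reviewers who want to try the documented behavior end to end, here is a minimal client sketch. The launch command uses the flag added by this patch; the model choice, port, and API key are illustrative assumptions, not part of the diff:

```python
# Assumed server launch (the flag below is introduced by this patch):
#   vllm serve meta-llama/Llama-Guard-3-1B --chat-template-text-format openai
from openai import OpenAI

# Placeholder endpoint/key for a locally running vLLM server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="meta-llama/Llama-Guard-3-1B",
    messages=[
        # In "openai" mode, vLLM renders text content as typed parts, so
        # both this plain string and the [{"type": "text", ...}] form work.
        {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"},
    ],
)
print(completion.choices[0].message.content)
```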
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index d9342fad9f018..e969d33775d86 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -26,6 +26,7 @@ class MockModelConfig:
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
+    chat_template_text_format = "string"
     max_model_len = 100
     tokenizer_revision = None
     multimodal_config = MultiModalConfig()
diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
index f64743e065fc8..5fa466f8f041f 100644
--- a/tests/entrypoints/test_chat_utils.py
+++ b/tests/entrypoints/test_chat_utils.py
@@ -17,7 +17,7 @@
 MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def phi3v_model_config():
     return ModelConfig(PHI3V_MODEL_ID,
                        task="generate",
@@ -26,6 +26,7 @@ def phi3v_model_config():
                        trust_remote_code=True,
                        dtype="bfloat16",
                        seed=0,
+                       chat_template_text_format="string",
                        limit_mm_per_prompt={
                            "image": 2,
                        })
@@ -330,6 +331,51 @@ def test_parse_chat_messages_multiple_images_across_messages(
     _assert_mm_data_is_image_input(mm_data, 2)
 
 
+def test_parse_chat_messages_context_text_format(
+    phi3v_model_config,
+    phi3v_tokenizer,
+):
+    phi3v_model_config.chat_template_text_format = "openai"
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role": "user",
+            "content": [{
+                "type": "text",
+                "text": "What's in this text?"
+            }]
+        }, {
+            "role": "assistant",
+            "content": "Some stuff."
+        }, {
+            "role": "user",
+            "content": "What about this one?"
+        }], phi3v_model_config, phi3v_tokenizer)
+
+    assert conversation == [
+        {
+            "role": "user",
+            "content": [{
+                "type": "text",
+                "text": "What's in this text?"
+            }]
+        },
+        {
+            "role": "assistant",
+            "content": [{
+                "type": "text",
+                "text": "Some stuff."
+            }]
+        },
+        {
+            "role": "user",
+            "content": [{
+                "type": "text",
+                "text": "What about this one?"
+            }]
+        },
+    ]
+
+
 def test_parse_chat_messages_rejects_too_many_images_in_one_message(
     phi3v_model_config,
     phi3v_tokenizer,
diff --git a/vllm/config.py b/vllm/config.py
index c569789c650ab..25f841231dedd 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -142,6 +142,7 @@ def __init__(self,
                  use_async_output_proc: bool = True,
                  override_neuron_config: Optional[Dict[str, Any]] = None,
                  config_format: ConfigFormat = ConfigFormat.AUTO,
+                 chat_template_text_format: str = "string",
                  mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -176,6 +177,7 @@ def __init__(self,
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.use_async_output_proc = use_async_output_proc
+        self.chat_template_text_format = chat_template_text_format
         self.mm_processor_kwargs = mm_processor_kwargs
 
         # Set enforce_eager to False if the value is unset.
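The new test above asserts that, in `openai` mode, every message's `content`, including plain strings, comes back as a list of typed text parts. Below is a self-contained sketch of that normalization; it is a simplified illustration only, since the real logic lives in `chat_utils.py` and also handles multi-modal parts:

```python
from typing import Dict, List, Union

def normalize_content(
        content: Union[str, List[Dict[str, str]]]) -> List[Dict[str, str]]:
    """Wrap a plain string as [{"type": "text", "text": ...}] parts.

    Simplified illustration of chat_template_text_format="openai";
    not the actual vLLM implementation.
    """
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    return content

# Mirrors the assistant message in the test above.
assert normalize_content("Some stuff.") == [{
    "type": "text",
    "text": "Some stuff."
}]
```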
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index a5cfaf3977a4f..c49f475b9ee61 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -89,6 +89,7 @@ class EngineArgs:
     task: TaskOption = "auto"
     skip_tokenizer_init: bool = False
     tokenizer_mode: str = 'auto'
+    chat_template_text_format: str = 'string'
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
     load_format: str = 'auto'
@@ -250,6 +251,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'fast tokenizer if available.\n* "slow" will '
             'always use the slow tokenizer. \n* '
             '"mistral" will always use the `mistral_common` tokenizer.')
+        parser.add_argument(
+            '--chat-template-text-format',
+            type=str,
+            default=EngineArgs.chat_template_text_format,
+            choices=['string', 'openai'],
+            help='The format to render text content within a chat template. '
+            '"string" will keep the content field as a string whereas '
+            '"openai" will parse content in the current OpenAI format.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')
@@ -858,6 +867,7 @@ def create_model_config(self) -> ModelConfig:
             # We know this is not None because we set it in __post_init__
             tokenizer=cast(str, self.tokenizer),
             tokenizer_mode=self.tokenizer_mode,
+            chat_template_text_format=self.chat_template_text_format,
             trust_remote_code=self.trust_remote_code,
             dtype=self.dtype,
             seed=self.seed,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 167efa51e3e2f..0d73ed7c8e7ab 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -254,7 +254,7 @@ def __init__(
             "num_scheduler_steps=%d, chunked_prefill_enabled=%s "
             "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
             "use_async_output_proc=%s, use_cached_outputs=%s, "
-            "mm_processor_kwargs=%s)",
+            "chat_template_text_format=%s, mm_processor_kwargs=%s)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -289,6 +289,7 @@ def __init__(
             cache_config.enable_prefix_caching,
             model_config.use_async_output_proc,
             use_cached_outputs,
+            model_config.chat_template_text_format,
             model_config.mm_processor_kwargs,
         )
         # TODO(woosuk): Print more configs in debug mode.
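To see how the new engine argument travels from the CLI into `ModelConfig`, here is a trimmed-down, hypothetical mirror of the plumbing added in `arg_utils.py`; only the flag name, default, and choices shown in this patch are real:

```python
import argparse

parser = argparse.ArgumentParser()
# Same flag, default, and choices as the parser.add_argument call above.
parser.add_argument(
    "--chat-template-text-format",
    type=str,
    default="string",
    choices=["string", "openai"],
)

# argparse maps the dashed flag to args.chat_template_text_format, which
# create_model_config() then forwards to ModelConfig in this patch.
args = parser.parse_args(["--chat-template-text-format", "openai"])
assert args.chat_template_text_format == "openai"
```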
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index faa493d518a7c..fef6a91414db6 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -121,7 +121,7 @@ class ConversationMessage(TypedDict, total=False):
     role: Required[str]
     """The role of the message's author."""
 
-    content: Optional[str]
+    content: Union[Optional[str], List[Dict[str, str]]]
     """The contents of the message"""
 
     tool_call_id: Optional[str]
@@ -431,7 +431,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
 def _parse_chat_message_content_mm_part(
         part: ChatCompletionContentPartParam) -> Tuple[str, str]:
     """
-    Parses a given multi modal content part based on its type.
+    Parses a given multi-modal content part based on its type.
 
     Args:
         part: A dict containing the content part, with a potential 'type' field.
@@ -485,21 +485,26 @@ def _parse_chat_message_content_parts(
     role: str,
     parts: Iterable[ChatCompletionContentPartParam],
     mm_tracker: BaseMultiModalItemTracker,
+    chat_template_text_format: str,
 ) -> List[ConversationMessage]:
     content: List[Union[str, Dict[str, str]]] = []
 
     mm_parser = mm_tracker.create_parser()
-    keep_multimodal_content = \
+    wrap_dicts = \
         mm_tracker._model_config.hf_config.model_type in \
-        MODEL_KEEP_MULTI_MODAL_CONTENT
+        MODEL_KEEP_MULTI_MODAL_CONTENT or \
+        (chat_template_text_format == "openai")
 
     for part in parts:
         parse_res = _parse_chat_message_content_part(
-            part, mm_parser, wrap_dicts=keep_multimodal_content)
+            part,
+            mm_parser,
+            wrap_dicts=wrap_dicts,
+        )
         if parse_res:
             content.append(parse_res)
 
-    if keep_multimodal_content:
+    if wrap_dicts:
         # Parsing wraps images and texts as interleaved dictionaries
         return [ConversationMessage(role=role,
                                     content=content)]  # type: ignore
@@ -560,6 +565,7 @@ def _parse_chat_message_content_part(
 def _parse_chat_message_content(
     message: ChatCompletionMessageParam,
     mm_tracker: BaseMultiModalItemTracker,
+    chat_template_text_format: str,
 ) -> List[ConversationMessage]:
     role = message["role"]
     content = message.get("content")
@@ -575,6 +581,7 @@ def _parse_chat_message_content(
         role,
         content,  # type: ignore
         mm_tracker,
+        chat_template_text_format,
     )
 
     for result_msg in result:
@@ -618,7 +625,11 @@ def parse_chat_messages(
     mm_tracker = MultiModalItemTracker(model_config, tokenizer)
 
     for msg in messages:
-        sub_messages = _parse_chat_message_content(msg, mm_tracker)
+        sub_messages = _parse_chat_message_content(
+            msg,
+            mm_tracker,
+            model_config.chat_template_text_format,
+        )
 
         conversation.extend(sub_messages)
 
@@ -636,7 +647,11 @@ def parse_chat_messages_futures(
     mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
 
     for msg in messages:
-        sub_messages = _parse_chat_message_content(msg, mm_tracker)
+        sub_messages = _parse_chat_message_content(
+            msg,
+            mm_tracker,
+            model_config.chat_template_text_format,
+        )
 
         conversation.extend(sub_messages)
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index b9b240b64850e..cd2883a3b323b 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -384,7 +384,7 @@ async def chat_completion_stream_generator(
                     # Send response to echo the input portion of the
                     # last message
                     if request.echo or request.continue_final_message:
-                        last_msg_content: str = ""
+                        last_msg_content: Union[str, List[Dict[str, str]]] = ""
                         if conversation and "content" in conversation[
                                 -1] and conversation[-1].get("role") == role:
                             last_msg_content = conversation[-1]["content"] or ""
@@ -724,10 +724,13 @@ async def chat_completion_full_generator(
             choices.append(choice_data)
 
         if request.echo or request.continue_final_message:
-            last_msg_content = ""
+            last_msg_content: Union[str, List[Dict[str, str]]] = ""
             if conversation and "content" in conversation[-1] and conversation[
                     -1].get("role") == role:
                 last_msg_content = conversation[-1]["content"] or ""
+            if isinstance(last_msg_content, list):
+                last_msg_content = "\n".join(msg['text']
+                                             for msg in last_msg_content)
 
             for choice in choices:
                 full_message = last_msg_content + (choice.message.content
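The final hunk flattens list-style content back to a plain string before echoing it. Here is a standalone check of that join; the sample data is illustrative, not from the patch:

```python
from typing import Dict, List, Union

last_msg_content: Union[str, List[Dict[str, str]]] = [
    {"type": "text", "text": "What's in this text?"},
    {"type": "text", "text": "What about this one?"},
]

# Mirrors the patch: join the "text" fields so echo/continue_final_message
# can prepend a plain string to the generated completion.
if isinstance(last_msg_content, list):
    last_msg_content = "\n".join(msg["text"] for msg in last_msg_content)

assert last_msg_content == "What's in this text?\nWhat about this one?"
```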