[Bugfix]: Make chat content text allow type content (vllm-project#9358)
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: Erkin Sagiroglu <erkin@infra-aipipeline-1-at1-prox-prod-a.ipa.corp.telnyx.com>
vrdn-23 authored and Erkin Sagiroglu committed Oct 26, 2024
1 parent a383e7b commit 27abf44
Showing 8 changed files with 107 additions and 12 deletions.
17 changes: 17 additions & 0 deletions docs/source/serving/openai_compatible_server.md
@@ -103,6 +103,23 @@ vllm serve <model> --chat-template ./path-to-chat-template.jinja
vLLM community provides a set of chat templates for popular models. You can find them in the examples
directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)

With the inclusion of multi-modal chat APIs, the OpenAI spec now also accepts chat messages in a format where the content is a list of parts,
each specifying both a `type` and a `text` field. An example is provided below:
```python
completion = client.chat.completions.create(
    model="NousResearch/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "user", "content": [{"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}]}
    ]
)
```
Most chat templates for LLMs expect the `content` field to be a string, but some newer models like
`meta-llama/Llama-Guard-3-1B` expect the content to be formatted according to the OpenAI schema shown above. To control
which format vLLM uses when rendering content into the chat template, use the `--chat-template-text-format` argument,
which accepts either `string` or `openai`. The default is `string`, in which case vLLM converts content provided in
either format into a plain string before applying the template; specify `openai` to keep the content as a list of
structured parts.
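For example, a model whose chat template consumes OpenAI-style content could be served with the flag set explicitly. This is a minimal sketch; the model name is only an illustration, assuming its template expects this format:
```bash
vllm serve meta-llama/Llama-Guard-3-1B --chat-template-text-format openai
```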


## Command line arguments for the server

```{argparse}
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -26,6 +26,7 @@ class MockModelConfig:
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
chat_template_text_format = "string"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
48 changes: 47 additions & 1 deletion tests/entrypoints/test_chat_utils.py
@@ -17,7 +17,7 @@
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def phi3v_model_config():
return ModelConfig(PHI3V_MODEL_ID,
task="generate",
@@ -26,6 +26,7 @@ def phi3v_model_config():
trust_remote_code=True,
dtype="bfloat16",
seed=0,
chat_template_text_format="string",
limit_mm_per_prompt={
"image": 2,
})
@@ -330,6 +331,51 @@ def test_parse_chat_messages_multiple_images_across_messages(
_assert_mm_data_is_image_input(mm_data, 2)


def test_parse_chat_messages_context_text_format(
    phi3v_model_config,
    phi3v_tokenizer,
):
    phi3v_model_config.chat_template_text_format = "openai"
    conversation, mm_data = parse_chat_messages(
        [{
            "role": "user",
            "content": [{
                "type": "text",
                "text": "What's in this text?"
            }]
        }, {
            "role": "assistant",
            "content": "Some stuff."
        }, {
            "role": "user",
            "content": "What about this one?"
        }], phi3v_model_config, phi3v_tokenizer)

    assert conversation == [
        {
            "role": "user",
            "content": [{
                "type": "text",
                "text": "What's in this text?"
            }]
        },
        {
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Some stuff."
            }]
        },
        {
            "role": "user",
            "content": [{
                "type": "text",
                "text": "What about this one?"
            }]
        },
    ]


def test_parse_chat_messages_rejects_too_many_images_in_one_message(
phi3v_model_config,
phi3v_tokenizer,
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -142,6 +142,7 @@ def __init__(self,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -176,6 +177,7 @@ def __init__(self,
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.chat_template_text_format = chat_template_text_format
self.mm_processor_kwargs = mm_processor_kwargs

# Set enforce_eager to False if the value is unset.
10 changes: 10 additions & 0 deletions vllm/engine/arg_utils.py
@@ -89,6 +89,7 @@ class EngineArgs:
task: TaskOption = "auto"
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
chat_template_text_format: str = 'string'
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
@@ -250,6 +251,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
parser.add_argument(
'--chat-template-text-format',
type=str,
default=EngineArgs.chat_template_text_format,
choices=['string', 'openai'],
help='The format to render text content within a chat template. '
'"string" will keep the content field as a string whereas '
'"openai" will parse content in the current OpenAI format.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
@@ -858,6 +867,7 @@ def create_model_config(self) -> ModelConfig:
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
tokenizer_mode=self.tokenizer_mode,
chat_template_text_format=self.chat_template_text_format,
trust_remote_code=self.trust_remote_code,
dtype=self.dtype,
seed=self.seed,
3 changes: 2 additions & 1 deletion vllm/engine/llm_engine.py
@@ -254,7 +254,7 @@ def __init__(
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s)",
"chat_template_text_format=%s, mm_processor_kwargs=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
@@ -289,6 +289,7 @@ def __init__(
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
use_cached_outputs,
model_config.chat_template_text_format,
model_config.mm_processor_kwargs,
)
# TODO(woosuk): Print more configs in debug mode.
31 changes: 23 additions & 8 deletions vllm/entrypoints/chat_utils.py
@@ -121,7 +121,7 @@ class ConversationMessage(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""

content: Optional[str]
content: Union[Optional[str], List[Dict[str, str]]]
"""The contents of the message"""

tool_call_id: Optional[str]
@@ -431,7 +431,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str, str]:
"""
Parses a given multi modal content part based on its type.
Parses a given multi-modal content part based on its type.
Args:
part: A dict containing the content part, with a potential 'type' field.
@@ -485,21 +485,26 @@ def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker,
chat_template_text_format: str,
) -> List[ConversationMessage]:
content: List[Union[str, Dict[str, str]]] = []

mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
wrap_dicts = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT
MODEL_KEEP_MULTI_MODAL_CONTENT or \
(chat_template_text_format == "openai")

for part in parts:
parse_res = _parse_chat_message_content_part(
part, mm_parser, wrap_dicts=keep_multimodal_content)
part,
mm_parser,
wrap_dicts=wrap_dicts,
)
if parse_res:
content.append(parse_res)

if keep_multimodal_content:
if wrap_dicts:
# Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role,
content=content)] # type: ignore
@@ -560,6 +565,7 @@ def _parse_chat_message_content_part(
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
chat_template_text_format: str,
) -> List[ConversationMessage]:
role = message["role"]
content = message.get("content")
@@ -575,6 +581,7 @@ def _parse_chat_message_content(
role,
content, # type: ignore
mm_tracker,
chat_template_text_format,
)

for result_msg in result:
@@ -618,7 +625,11 @@ def parse_chat_messages(
mm_tracker = MultiModalItemTracker(model_config, tokenizer)

for msg in messages:
sub_messages = _parse_chat_message_content(msg, mm_tracker)
sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
model_config.chat_template_text_format,
)

conversation.extend(sub_messages)

@@ -636,7 +647,11 @@ def parse_chat_messages_futures(
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

for msg in messages:
sub_messages = _parse_chat_message_content(msg, mm_tracker)
sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
model_config.chat_template_text_format,
)

conversation.extend(sub_messages)

7 changes: 5 additions & 2 deletions vllm/entrypoints/openai/serving_chat.py
@@ -384,7 +384,7 @@ async def chat_completion_stream_generator(
# Send response to echo the input portion of the
# last message
if request.echo or request.continue_final_message:
last_msg_content: str = ""
last_msg_content: Union[str, List[Dict[str, str]]] = ""
if conversation and "content" in conversation[
-1] and conversation[-1].get("role") == role:
last_msg_content = conversation[-1]["content"] or ""
@@ -724,10 +724,13 @@ async def chat_completion_full_generator(
choices.append(choice_data)

if request.echo or request.continue_final_message:
last_msg_content = ""
last_msg_content: Union[str, List[Dict[str, str]]] = ""
if conversation and "content" in conversation[-1] and conversation[
-1].get("role") == role:
last_msg_content = conversation[-1]["content"] or ""
if isinstance(last_msg_content, list):
last_msg_content = "\n".join(msg['text']
for msg in last_msg_content)

for choice in choices:
full_message = last_msg_content + (choice.message.content