Merged
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
from openai import OpenAI

# This example demonstrates the `structural_tag` response format.
# It can be used to specify a structured output format that occurs between
# specific tags in the response. This example shows how it could be used
# to enforce the format of a tool call response, but it could be used for
# any structured output within a subset of the response.


def main():
    client = OpenAI(
        base_url="http://localhost:8000/v1",
        api_key="-",
    )

    messages = [{
        "role": "user",
        "content": """
You have access to the following function to retrieve the weather in a city:

{
    "name": "get_weather",
    "parameters": {
        "city": {
            "param_type": "string",
            "description": "The city to get the weather for",
            "required": True
        }
    }
}

If you choose to call a function, ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where

start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function
    argument value as value.
end_tag => `</function>`

Here is an example,
<function=example_function_name>{"example_name": "example_value"}</function>

Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query

You are a helpful assistant.

Given the previous instructions, what is the weather in New York City, Boston,
and San Francisco?
""",
    }]

    response = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=messages,
        response_format={
            "type": "structural_tag",
            "structures": [{
                "begin": "<function=get_weather>",
                "schema": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string"
                        }
                    }
                },
                "end": "</function>"
            }],
            "triggers": ["<function="],
        },
    )
    print(response)


if __name__ == "__main__":
    main()
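For reference, a conforming completion places each tool call between the configured begin/end tags, with the payload constrained to the declared schema. Below is a minimal sketch of extracting such calls from the returned text; the parse_weather_calls helper is illustrative only, not part of this PR:

import json
import re


def parse_weather_calls(text: str) -> list[dict]:
    """Extract schema-constrained get_weather payloads from a completion."""
    # e.g. '<function=get_weather>{"city": "Boston"}</function>'
    payloads = re.findall(r"<function=get_weather>(.*?)</function>", text)
    # The structural_tag constraint should make each payload valid JSON
    # matching the declared schema, so json.loads is expected to succeed.
    return [json.loads(p) for p in payloads]

With the three-city prompt above, a compliant response would contain up to three such tags, one per city.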
101 changes: 101 additions & 0 deletions tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -350,6 +350,7 @@ def test_structured_output(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=json_schema))
+
    outputs = llm.generate(
        prompts="Generate a description of a frog using 50 characters.",
        sampling_params=sampling_params,
@@ -368,6 +369,106 @@ def test_structured_output(
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=json_schema)

    #
    # Test 11: Generate structured output using structural_tag format
    #
    structural_tag_config = {
        "type": "structural_tag",
        "structures": [{
            "begin": "<function=get_weather>",
            "schema": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string"
                    }
                }
            },
            "end": "</function>"
        }],
        "triggers": ["<function="]
    }

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=100,
        guided_decoding=GuidedDecodingParams(
            structural_tag=json.dumps(structural_tag_config)))

    prompt = """
You have access to the following function to retrieve the weather in a city:

{
    "name": "get_weather",
    "parameters": {
        "city": {
            "param_type": "string",
            "description": "The city to get the weather for",
            "required": True
        }
    }
}

If you choose to call a function, ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where

start_tag => `<function`
parameters => a JSON dict with the function argument name
    as key and function argument value as value.
end_tag => `</function>`

Here is an example,
<function=example_function_name>{"example_name": "example_value"}</function>

Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query

You are a helpful assistant.

Given the previous instructions, what is the weather in New York City?
"""

    # Change this once other backends support structural_tag
    if guided_decoding_backend.startswith("xgrammar"):
        outputs = llm.generate(prompts=prompt,
                               sampling_params=sampling_params,
                               use_tqdm=True)
        assert outputs is not None
    else:
        outputs = []

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        generated_text = output.outputs[0].text
        assert generated_text is not None

        # Search for the function call pattern in the response
        function_call_pattern = r'<function=get_weather>(.*?)</function>'
        matches = re.findall(function_call_pattern, generated_text)

        if not matches:
            print(f"Warning: No function calls found in response: "
                  f"{generated_text!r}")
            continue

        # Take the first function call if multiple are found
        json_str = matches[0]
        try:
            json_content = json.loads(json_str)
            assert "city" in json_content
            assert isinstance(json_content["city"], str)
            print(f"Found valid function call: {generated_text!r}")
        except (json.JSONDecodeError, AssertionError) as e:
            pytest.fail("Invalid function call format: "
                        f"{generated_text!r}\nError: {str(e)}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("model_name, tokenizer_mode",
4 changes: 3 additions & 1 deletion vllm/entrypoints/llm.py
@@ -1395,7 +1395,9 @@ def _add_guided_params(
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
-            whitespace_pattern=guided_options.guided_whitespace_pattern)
+            whitespace_pattern=guided_options.guided_whitespace_pattern,
+            structural_tag=guided_options.structural_tag,
+        )
        return params

    def _run_engine(
46 changes: 39 additions & 7 deletions vllm/entrypoints/openai/protocol.py
@@ -2,6 +2,7 @@

# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+import json
import re
import time
from argparse import Namespace
@@ -139,12 +140,30 @@ class JsonSchemaResponseFormat(OpenAIBaseModel):
    strict: Optional[bool] = None


class StructuralTag(OpenAIBaseModel):
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: Optional[dict[str, Any]] = Field(default=None,
                                                            alias="schema")
    end: str


class StructuralTagResponseFormat(OpenAIBaseModel):
    type: Literal["structural_tag"]
    structures: list[StructuralTag]
    triggers: list[str]


class ResponseFormat(OpenAIBaseModel):
-    # type must be "json_schema", "json_object" or "text"
+    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None


AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat]


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = False
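For intuition on the alias workaround, here is a hypothetical standalone round-trip using a plain pydantic BaseModel as a stand-in for OpenAIBaseModel: validation accepts the wire name "schema", the Python attribute is structural_tag_schema, and by_alias=True restores the wire name on the way out.

from typing import Any, Optional

from pydantic import BaseModel, Field


class StructuralTagSketch(BaseModel):
    begin: str
    structural_tag_schema: Optional[dict[str, Any]] = Field(default=None,
                                                            alias="schema")
    end: str


tag = StructuralTagSketch.model_validate({
    "begin": "<function=get_weather>",
    "schema": {"type": "object"},
    "end": "</function>",
})
assert tag.structural_tag_schema == {"type": "object"}
assert "schema" in tag.model_dump(by_alias=True)  # wire name restored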
@@ -227,7 +246,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
    max_completion_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
-    response_format: Optional[ResponseFormat] = None
+    response_format: Optional[AnyResponseFormat] = None
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
@@ -340,6 +359,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
        description=(
            "If specified, the output will follow the context free grammar."),
    )
+    structural_tag: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the structural tag schema."),
+    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
@@ -476,6 +500,12 @@ def to_sampling_params(
            json_schema = self.response_format.json_schema
            assert json_schema is not None
            self.guided_json = json_schema.json_schema
+        elif self.response_format.type == "structural_tag":
+            structural_tag = self.response_format
+            assert structural_tag is not None and isinstance(
+                structural_tag, StructuralTagResponseFormat)
+            s_tag_obj = structural_tag.model_dump(by_alias=True)
+            self.structural_tag = json.dumps(s_tag_obj)
Review comment (Collaborator) on lines +507 to +508:

Suggested change:
-            s_tag_obj = structural_tag.model_dump(by_alias=True)
-            self.structural_tag = json.dumps(s_tag_obj)
+            self.structural_tag = structural_tag.model_dump_json(by_alias=True)

Then you don't have to use json here :)


        guided_decoding = GuidedDecodingParams.from_optional(
            json=self._get_guided_json_from_tool() or self.guided_json,
@@ -485,6 +515,7 @@ def to_sampling_params(
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern,
+            structural_tag=self.structural_tag,
        )

        return SamplingParams.from_optional(
@@ -742,12 +773,13 @@ class CompletionRequest(OpenAIBaseModel):
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
response_format: Optional[ResponseFormat] = Field(
response_format: Optional[AnyResponseFormat] = Field(
default=None,
description=
("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
"{'type': 'text' } is supported."),
description=(
"Similar to chat completion, this parameter specifies the format "
"of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
11 changes: 6 additions & 5 deletions vllm/model_executor/guided_decoding/guided_fields.py
@@ -27,14 +27,15 @@ class GuidedDecodingRequest:
    guided_decoding_backend: Optional[str] = None
    guided_whitespace_pattern: Optional[str] = None
    guided_json_object: Optional[bool] = None
+    structural_tag: Optional[str] = None

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
-        guide_count = sum([
-            self.guided_json is not None, self.guided_regex is not None,
-            self.guided_choice is not None, self.guided_grammar is not None,
-            self.guided_json_object is not None
-        ])
+        guide_count = sum(x is not None
+                          for x in (self.guided_json, self.guided_regex,
+                                    self.guided_choice, self.guided_grammar,
+                                    self.guided_json_object,
+                                    self.structural_tag))
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding but multiple are "
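A minimal sketch of the mutual-exclusion check in practice (illustrative values; assumes the class is imported from the module shown above):

from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest)

# Exactly one guided-decoding mode is allowed per request.
GuidedDecodingRequest(structural_tag='{"type": "structural_tag"}')  # fine

try:
    # Combining structural_tag with guided_regex trips __post_init__.
    GuidedDecodingRequest(guided_regex=r"\d+",
                          structural_tag='{"type": "structural_tag"}')
except ValueError as exc:
    print(exc)  # "You can only use one kind of guided decoding ..."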
7 changes: 5 additions & 2 deletions vllm/sampling_params.py
@@ -38,6 +38,7 @@ class GuidedDecodingParams:
    """These are other options that can be set"""
    backend: Optional[str] = None
    whitespace_pattern: Optional[str] = None
+    structural_tag: Optional[str] = None

    @staticmethod
    def from_optional(
@@ -48,9 +49,10 @@ def from_optional(
        json_object: Optional[bool] = None,
        backend: Optional[str] = None,
        whitespace_pattern: Optional[str] = None,
+        structural_tag: Optional[str] = None,
    ) -> Optional["GuidedDecodingParams"]:
-        if all(arg is None
-               for arg in (json, regex, choice, grammar, json_object)):
+        if all(arg is None for arg in (json, regex, choice, grammar,
+                                       json_object, structural_tag)):
            return None
        # Extract json schemas from pydantic models
        if isinstance(json, (BaseModel, type(BaseModel))):
@@ -63,6 +65,7 @@ def from_optional(
            json_object=json_object,
            backend=backend,
            whitespace_pattern=whitespace_pattern,
+            structural_tag=structural_tag,
        )

    @property
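As a usage note, from_optional returns None when no guided-decoding option is set, so callers can pass the result straight through. A minimal sketch, assuming a local vllm install:

import json

from vllm.sampling_params import GuidedDecodingParams, SamplingParams

assert GuidedDecodingParams.from_optional() is None  # nothing requested

tag_config = {"type": "structural_tag", "structures": [], "triggers": []}
params = SamplingParams(guided_decoding=GuidedDecodingParams.from_optional(
    structural_tag=json.dumps(tag_config)))
assert params.guided_decoding.structural_tag is not None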
3 changes: 3 additions & 0 deletions vllm/v1/structured_output/backend_guidance.py
@@ -194,6 +194,9 @@ def serialize_guidance_grammar(
        tp = "grammar"
    elif request_type == StructuredOutputOptions.CHOICE:
        tp = "choice"
+    elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
+        raise ValueError("Structural tag is not supported "
+                         "for guidance backend yet")
    else:
        logger.error("Validation should have already occurred. "
                     "Please file an issue.")
1 change: 1 addition & 0 deletions vllm/v1/structured_output/backend_types.py
@@ -12,6 +12,7 @@ class StructuredOutputOptions(enum.Enum):
    REGEX = enum.auto()
    GRAMMAR = enum.auto()
    CHOICE = enum.auto()
+    STRUCTURAL_TAG = enum.auto()


StructuredOutputKey = tuple[StructuredOutputOptions, str]