50 changes: 29 additions & 21 deletions examples/offline_inference/structured_outputs.py
@@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file demonstrates the example usage of guided decoding
to generate structured outputs using vLLM. It shows how to apply
different guided decoding techniques such as Choice, Regex, JSON schema,
and Grammar to produce structured and formatted results
This file demonstrates example usage of generating structured
outputs with vLLM. It shows how to apply different techniques
such as Choice, Regex, JSON schema, and Grammar to produce
structured and formatted results
based on specific prompts.
"""

@@ -13,17 +13,21 @@
from pydantic import BaseModel

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from vllm.sampling_params import StructuredOutputsParams

# Guided decoding by Choice (list of possible options)
guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
# Structured outputs by Choice (list of possible options)
structured_outputs_params_choice = StructuredOutputsParams(
choice=["Positive", "Negative"]
)
sampling_params_choice = SamplingParams(
structured_outputs=structured_outputs_params_choice
)
prompt_choice = "Classify this sentiment: vLLM is wonderful!"

# Guided decoding by Regex
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
# Structured outputs by Regex
structured_outputs_params_regex = StructuredOutputsParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams(
guided_decoding=guided_decoding_params_regex, stop=["\n"]
structured_outputs=structured_outputs_params_regex, stop=["\n"]
)
prompt_regex = (
"Generate an email address for Alan Turing, who works in Enigma."
@@ -32,7 +36,7 @@
)


# Guided decoding by JSON using Pydantic schema
# Structured outputs by JSON using Pydantic schema
class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
@@ -47,14 +51,14 @@ class CarDescription(BaseModel):


json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json)
structured_outputs_params_json = StructuredOutputsParams(json=json_schema)
sampling_params_json = SamplingParams(structured_outputs=structured_outputs_params_json)
prompt_json = (
"Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's"
)

# Guided decoding by Grammar
# Structured outputs by Grammar
simplified_sql_grammar = """
root ::= select_statement
select_statement ::= "SELECT " column " from " table " where " condition
@@ -63,8 +67,12 @@ class CarDescription(BaseModel):
condition ::= column "= " number
number ::= "1 " | "2 "
"""
guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar)
structured_outputs_params_grammar = StructuredOutputsParams(
grammar=simplified_sql_grammar
)
sampling_params_grammar = SamplingParams(
structured_outputs=structured_outputs_params_grammar
)
prompt_grammar = (
"Generate an SQL query to show the 'username' and 'email'from the 'users' table."
)
@@ -83,16 +91,16 @@ def main():
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)

choice_output = generate_output(prompt_choice, sampling_params_choice, llm)
format_output("Guided decoding by Choice", choice_output)
format_output("Structured outputs by Choice", choice_output)

regex_output = generate_output(prompt_regex, sampling_params_regex, llm)
format_output("Guided decoding by Regex", regex_output)
format_output("Structured outputs by Regex", regex_output)

json_output = generate_output(prompt_json, sampling_params_json, llm)
format_output("Guided decoding by JSON", json_output)
format_output("Structured outputs by JSON", json_output)

grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm)
format_output("Guided decoding by Grammar", grammar_output)
format_output("Structured outputs by Grammar", grammar_output)


if __name__ == "__main__":
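For readers skimming the diff, here is a minimal end-to-end sketch of the renamed API as this example file now uses it. This is a sketch, not part of the diff; it assumes the class is exported as `StructuredOutputsParams` from `vllm.sampling_params` after this change, and it reuses the model and prompt from the example above.

```python
# Minimal sketch of the renamed API (assumes this PR's rename is applied).
from vllm import LLM, SamplingParams
from vllm.sampling_params import StructuredOutputsParams

llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)

# Constrain the output to one of two literal choices.
params = SamplingParams(
    structured_outputs=StructuredOutputsParams(choice=["Positive", "Negative"])
)
outputs = llm.generate("Classify this sentiment: vLLM is wonderful!", params)
print(outputs[0].outputs[0].text)  # either "Positive" or "Negative"
```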
@@ -5,8 +5,8 @@
without any specific flags:

```bash
VLLM_USE_V1=0 vllm serve unsloth/Llama-3.2-1B-Instruct \
--guided-decoding-backend outlines
vllm serve unsloth/Llama-3.2-1B-Instruct \
--structured-output-config '{"backend": "xgrammar"}'
```

This example demonstrates how to generate chat completions
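The updated serve command above starts an OpenAI-compatible server. Below is a hedged sketch of a matching client request; the `guided_choice` extra-body key is the one exercised by the OpenAI entrypoint tests later in this diff, while the host, port, and API key are the usual vLLM defaults and are assumptions here.

```python
# Sketch of a chat-completions request against the server started above.
# Assumes the default localhost:8000 endpoint and a dummy API key.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="unsloth/Llama-3.2-1B-Instruct",
    messages=[{"role": "user", "content": "Classify: vLLM is wonderful!"}],
    max_completion_tokens=10,
    extra_body=dict(guided_choice=["Positive", "Negative"]),
)
print(chat.choices[0].message.content)  # constrained to one of the two choices
```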
1 change: 0 additions & 1 deletion tests/async_engine/test_async_llm_engine.py
@@ -128,7 +128,6 @@ async def test_new_requests_event():
engine = MockAsyncLLMEngine()
assert engine.get_model_config() is not None
assert engine.get_tokenizer() is not None
assert engine.get_decoding_config() is not None


def start_engine():
2 changes: 1 addition & 1 deletion tests/entrypoints/conftest.py
@@ -184,7 +184,7 @@ def sample_enum_json_schema():


@pytest.fixture
def sample_guided_choice():
def sample_choice():
return [
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
"Ruby", "Swift", "Kotlin"
36 changes: 18 additions & 18 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -13,7 +13,7 @@
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sampling_params import SamplingParams, StructuredOutputsParams

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

@@ -49,7 +49,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
regex=sample_regex,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
@@ -81,7 +81,7 @@ def test_guided_json_completion(sample_json_schema, llm,
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
json=sample_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
@@ -115,7 +115,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
json=sample_complex_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
@@ -150,7 +150,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
json=sample_definition_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
@@ -185,7 +185,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
json=sample_enum_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
@@ -224,14 +224,14 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
ALL_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,
def test_guided_choice_completion(sample_choice, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
choice=sample_guided_choice,
structured_outputs=StructuredOutputsParams(
choice=sample_choice,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
@@ -247,7 +247,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert generated_text in sample_guided_choice
assert generated_text in sample_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@@ -261,7 +261,7 @@ def test_guided_grammar(sample_sql_statements, llm,
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
grammar=sample_sql_statements,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
@@ -310,7 +310,7 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex))
structured_outputs=StructuredOutputsParams(regex=sample_regex))

with pytest.raises(ValueError, match="Cannot set both"):
llm.generate(prompts="This should fail",
@@ -333,7 +333,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
}
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
json=unsupported_json,
backend="xgrammar",
disable_fallback=True))
@@ -356,7 +356,7 @@ def test_guided_json_object(llm, guided_decoding_backend: str,
temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
json_object=True,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
@@ -409,7 +409,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
json=json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
@@ -460,7 +460,7 @@ def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
json=sample_output_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace),
@@ -516,14 +516,14 @@ def test_guidance_no_additional_properties(llm):
"<|im_end|>\n<|im_start|>assistant\n")

def generate_with_backend(backend, disable_additional_properties):
guided_params = GuidedDecodingParams(
guided_params = StructuredOutputsParams(
json=schema,
backend=backend,
disable_any_whitespace=True,
disable_additional_properties=disable_additional_properties)
sampling_params = SamplingParams(temperature=0,
max_tokens=256,
guided_decoding=guided_params)
structured_outputs=guided_params)

outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
assert outputs is not None
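All of the tests above migrate the same call shape: the keyword on `SamplingParams` changes from `guided_decoding=` to `structured_outputs=`, while per-request options such as `backend` and `disable_any_whitespace` stay on the params object. A condensed sketch of that pattern follows; the regex and backend value are illustrative, taken from values appearing elsewhere in this diff.

```python
# Condensed form of the migration pattern repeated throughout these tests.
from vllm.sampling_params import SamplingParams, StructuredOutputsParams

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    structured_outputs=StructuredOutputsParams(
        regex=r"\w+@\w+\.com\n",   # constrain output to an email-like string
        backend="xgrammar",        # backend pinned per request, as in the tests
        disable_any_whitespace=True,
    ),
)
```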
15 changes: 7 additions & 8 deletions tests/entrypoints/openai/test_chat.py
@@ -487,8 +487,7 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,


@pytest.mark.asyncio
async def test_guided_choice_chat(client: openai.AsyncOpenAI,
sample_guided_choice):
async def test_guided_choice_chat(client: openai.AsyncOpenAI, sample_choice):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -503,9 +502,9 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
messages=messages,
max_completion_tokens=10,
temperature=0.7,
extra_body=dict(guided_choice=sample_guided_choice))
extra_body=dict(guided_choice=sample_choice))
choice1 = chat_completion.choices[0].message.content
assert choice1 in sample_guided_choice
assert choice1 in sample_choice

messages.append({"role": "assistant", "content": choice1})
messages.append({
@@ -517,9 +516,9 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
messages=messages,
max_completion_tokens=10,
temperature=0.7,
extra_body=dict(guided_choice=sample_guided_choice))
extra_body=dict(guided_choice=sample_choice))
choice2 = chat_completion.choices[0].message.content
assert choice2 in sample_guided_choice
assert choice2 in sample_choice
assert choice1 != choice2


@@ -624,7 +623,7 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):

@pytest.mark.asyncio
async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
sample_guided_choice):
sample_choice):

messages = [{
"role": "system",
Expand All @@ -641,7 +640,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
max_completion_tokens=10,
logprobs=True,
top_logprobs=5,
extra_body=dict(guided_choice=sample_guided_choice))
extra_body=dict(guided_choice=sample_choice))

assert chat_completion.choices[0].logprobs is not None
assert chat_completion.choices[0].logprobs.content is not None
6 changes: 3 additions & 3 deletions tests/entrypoints/openai/test_completion.py
@@ -686,20 +686,20 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_guided_choice):
sample_choice):
completion = await client.completions.create(
model=MODEL_NAME,
prompt="The best language for type-safe systems programming is ",
n=2,
temperature=1.0,
max_tokens=10,
extra_body=dict(guided_choice=sample_guided_choice,
extra_body=dict(guided_choice=sample_choice,
guided_decoding_backend=guided_decoding_backend))

assert completion.id is not None
assert len(completion.choices) == 2
for i in range(2):
assert completion.choices[i].text in sample_guided_choice
assert completion.choices[i].text in sample_choice


@pytest.mark.asyncio
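The completions test keeps passing `guided_choice` and `guided_decoding_backend` through `extra_body`. A sketch of the equivalent direct client call is below; the endpoint and model name are assumptions mirroring the test setup, and the choice list is a subset of the `sample_choice` fixture above.

```python
# Sketch mirroring test_guided_choice_completion above; the extra_body keys
# are passed through to the server unchanged.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    prompt="The best language for type-safe systems programming is ",
    n=2,
    temperature=1.0,
    max_tokens=10,
    extra_body=dict(
        guided_choice=["Python", "Java", "C++"],
        guided_decoding_backend="xgrammar",
    ),
)
for choice in completion.choices:
    print(choice.text)  # each constrained to one of the listed languages
```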