Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 29 additions & 25 deletions tests/full_tests/structured_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,39 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copied from vllm/examples/offline_inference/
"""
This file demonstrates the example usage of guided decoding
to generate structured outputs using vLLM. It shows how to apply
different guided decoding techniques such as Choice, Regex, JSON schema,
and Grammar to produce structured and formatted results
based on specific prompts.
This file demonstrates the example usage of structured outputs
in vLLM. It shows how to apply different constraints such as choice,
regex, JSON schema, and grammar to produce structured and formatted
results based on specific prompts.
"""

from enum import Enum

from pydantic import BaseModel

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from vllm.sampling_params import StructuredOutputsParams

# Guided decoding by Choice (list of possible options)
guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
MAX_TOKENS = 1024

# Structured outputs by Choice (list of possible options)
structured_outputs_params_choice = StructuredOutputsParams(choice=["Positive", "Negative"])
sampling_params_choice = SamplingParams(structured_outputs=structured_outputs_params_choice)
prompt_choice = "Classify this sentiment: vLLM is wonderful!"

# Guided decoding by Regex
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams(guided_decoding=guided_decoding_params_regex, stop=["\n"])
# Structured outputs by Regex
structured_outputs_params_regex = StructuredOutputsParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams(
structured_outputs=structured_outputs_params_regex,
stop=["\n"],
max_tokens=MAX_TOKENS,
)
# Adjacent string literals are joined with no separator, so each fragment
# carries its own trailing space to keep the sentences apart in the prompt.
prompt_regex = ("Generate an email address for Alan Turing, who works in Enigma. "
                "End in .com and new line. Example result: "
                "alan.turing@enigma.com\n")


# Guided decoding by JSON using Pydantic schema
# Structured outputs by JSON using Pydantic schema
class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
Expand All @@ -44,12 +49,12 @@ class CarDescription(BaseModel):


json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json)
prompt_json = ("Generate a JSON with the brand, model and car_type of"
structured_outputs_params_json = StructuredOutputsParams(json=json_schema)
sampling_params_json = SamplingParams(structured_outputs=structured_outputs_params_json, max_tokens=MAX_TOKENS)
prompt_json = ("Generate a JSON with the brand, model and car_type of "
"the most iconic car from the 90's")

# Guided decoding by Grammar
# Structured outputs by Grammar
simplified_sql_grammar = """
root ::= select_statement
select_statement ::= "SELECT " column " from " table " where " condition
Expand All @@ -58,10 +63,9 @@ class CarDescription(BaseModel):
condition ::= column "= " number
number ::= "1 " | "2 "
"""
guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar)
prompt_grammar = ("Generate an SQL query to show the "
"'username' and 'email'from the 'users' table.")
structured_outputs_params_grammar = StructuredOutputsParams(grammar=simplified_sql_grammar)
sampling_params_grammar = SamplingParams(structured_outputs=structured_outputs_params_grammar, max_tokens=MAX_TOKENS)
prompt_grammar = ("Generate an SQL query to show the 'username' and 'email' from the 'users' table.")


def format_output(title: str, output: str):
Expand All @@ -77,16 +81,16 @@ def main():
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=256)

choice_output = generate_output(prompt_choice, sampling_params_choice, llm)
format_output("Guided decoding by Choice", choice_output)
format_output("Structured outputs by Choice", choice_output)

regex_output = generate_output(prompt_regex, sampling_params_regex, llm)
format_output("Guided decoding by Regex", regex_output)
format_output("Structured outputs by Regex", regex_output)

json_output = generate_output(prompt_json, sampling_params_json, llm)
format_output("Guided decoding by JSON", json_output)
format_output("Structured outputs by JSON", json_output)

grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm)
format_output("Guided decoding by Grammar", grammar_output)
format_output("Structured outputs by Grammar", grammar_output)


if __name__ == "__main__":
Expand Down