diff --git a/tests/full_tests/structured_outputs.py b/tests/full_tests/structured_outputs.py index d4977b2e..e696a70b 100644 --- a/tests/full_tests/structured_outputs.py +++ b/tests/full_tests/structured_outputs.py @@ -2,11 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copied from vllm/examples/offline_inference/ """ -This file demonstrates the example usage of guided decoding -to generate structured outputs using vLLM. It shows how to apply -different guided decoding techniques such as Choice, Regex, JSON schema, -and Grammar to produce structured and formatted results -based on specific prompts. +This file demonstrates the example usage of structured outputs +in vLLM. It shows how to apply different constraints such as choice, +regex, json schema, and grammar to produce structured and formatted +results based on specific prompts. """ from enum import Enum @@ -14,22 +13,28 @@ from pydantic import BaseModel from vllm import LLM, SamplingParams -from vllm.sampling_params import GuidedDecodingParams +from vllm.sampling_params import StructuredOutputsParams -# Guided decoding by Choice (list of possible options) -guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"]) -sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice) +MAX_TOKENS = 1024 + +# Structured outputs by Choice (list of possible options) +structured_outputs_params_choice = StructuredOutputsParams(choice=["Positive", "Negative"]) +sampling_params_choice = SamplingParams(structured_outputs=structured_outputs_params_choice) prompt_choice = "Classify this sentiment: vLLM is wonderful!" -# Guided decoding by Regex -guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n") -sampling_params_regex = SamplingParams(guided_decoding=guided_decoding_params_regex, stop=["\n"]) +# Structured outputs by Regex +structured_outputs_params_regex = StructuredOutputsParams(regex=r"\w+@\w+\.com\n") +sampling_params_regex = SamplingParams( + structured_outputs=structured_outputs_params_regex, + stop=["\n"], + max_tokens=MAX_TOKENS, +) prompt_regex = ("Generate an email address for Alan Turing, who works in Enigma." "End in .com and new line. Example result:" "alan.turing@enigma.com\n") -# Guided decoding by JSON using Pydantic schema +# Structured outputs by JSON using Pydantic schema class CarType(str, Enum): sedan = "sedan" suv = "SUV" @@ -44,12 +49,12 @@ class CarDescription(BaseModel): json_schema = CarDescription.model_json_schema() -guided_decoding_params_json = GuidedDecodingParams(json=json_schema) -sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json) -prompt_json = ("Generate a JSON with the brand, model and car_type of" +structured_outputs_params_json = StructuredOutputsParams(json=json_schema) +sampling_params_json = SamplingParams(structured_outputs=structured_outputs_params_json, max_tokens=MAX_TOKENS) +prompt_json = ("Generate a JSON with the brand, model and car_type of " "the most iconic car from the 90's") -# Guided decoding by Grammar +# Structured outputs by Grammar simplified_sql_grammar = """ root ::= select_statement select_statement ::= "SELECT " column " from " table " where " condition @@ -58,10 +63,9 @@ class CarDescription(BaseModel): condition ::= column "= " number number ::= "1 " | "2 " """ -guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar) -sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar) -prompt_grammar = ("Generate an SQL query to show the " - "'username' and 'email'from the 'users' table.") +structured_outputs_params_grammar = StructuredOutputsParams(grammar=simplified_sql_grammar) +sampling_params_grammar = SamplingParams(structured_outputs=structured_outputs_params_grammar, max_tokens=MAX_TOKENS) +prompt_grammar = ("Generate an SQL query to show the 'username' and 'email' from the 'users' table.") def format_output(title: str, output: str): @@ -77,16 +81,16 @@ def main(): llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=256) choice_output = generate_output(prompt_choice, sampling_params_choice, llm) - format_output("Guided decoding by Choice", choice_output) + format_output("Structured outputs by Choice", choice_output) regex_output = generate_output(prompt_regex, sampling_params_regex, llm) - format_output("Guided decoding by Regex", regex_output) + format_output("Structured outputs by Regex", regex_output) json_output = generate_output(prompt_json, sampling_params_json, llm) - format_output("Guided decoding by JSON", json_output) + format_output("Structured outputs by JSON", json_output) grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm) - format_output("Guided decoding by Grammar", grammar_output) + format_output("Structured outputs by Grammar", grammar_output) if __name__ == "__main__":