tests/v1/entrypoints/llm/test_struct_output_generate.py (222 changes: 89 additions & 133 deletions)
@@ -23,20 +23,46 @@
]


class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
truck = "Truck"
coupe = "Coupe"


class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion(
def test_structured_output(
monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any],
unsupported_json_schema: dict[str, Any],
sample_sql_ebnf: str,
sample_sql_lark: str,
sample_regex: str,
sample_guided_choice: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")

# Use a single LLM instance for several scenarios to
# speed up the test suite.
llm = LLM(model=model_name,
enforce_eager=True,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)

#
# Test 1: Generate JSON output based on a provided schema
#
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
@@ -63,20 +89,9 @@ def test_guided_json_completion(
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
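
A hedged sketch of Test 1's collapsed request, modeled on the GuidedDecodingParams(json=...) call visible later in this diff; the schema, prompt, and model name below are placeholders, not the test's fixtures:

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams

    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"},
                       "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
              max_model_len=1024)
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=schema))
    outputs = llm.generate(
        prompts=f"Give an example JSON object that fits this schema: {schema}",
        sampling_params=sampling_params)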


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_object(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
#
# Test 2: Generate JSON object without a schema
#
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=100,
@@ -111,21 +126,9 @@ def test_guided_json_object(
allowed_types = (dict, list)
assert isinstance(parsed_json, allowed_types)
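
Test 2's collapsed request presumably asks for any JSON value rather than one matching a schema. Assuming GuidedDecodingParams exposes a json_object flag for this mode (an assumption about the vLLM API, not shown in this diff), a minimal sketch reusing the imports and llm from the sketch above:

    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=100,
        guided_decoding=GuidedDecodingParams(json_object=True))  # assumed flag
    outputs = llm.generate(
        prompts="Generate a JSON object describing a person.",
        sampling_params=sampling_params)
    parsed_json = json.loads(outputs[0].outputs[0].text)
    assert isinstance(parsed_json, (dict, list))  # mirrors allowed_types above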


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1 + ["auto"])
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_unsupported_schema(
monkeypatch: pytest.MonkeyPatch,
unsupported_json_schema: dict[str, Any],
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
#
# Test 3: Test a JSON schema incompatible with xgrammar
#
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
@@ -141,8 +144,6 @@ def test_guided_json_unsupported_schema(
sampling_params=sampling_params,
use_tqdm=True)
else:
# This should work for both "guidance" and "auto".

outputs = llm.generate(
prompts=("Give an example JSON object for a grade "
"that fits this schema: "
@@ -161,21 +162,9 @@
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
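
A sketch of the branch logic in Test 3, reusing the test's guided_decoding_backend parameter and the names from the sketches above, under the assumption that xgrammar rejects the unsupported construct at request time while "guidance" and "auto" fall back; the schema below is a hypothetical stand-in for the unsupported_json_schema fixture:

    import pytest

    unsupported = {"type": "string", "pattern": "^a.*b$",
                   "minLength": 1}  # hypothetical; the real fixture is elided
    sampling_params = SamplingParams(
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=unsupported))
    if guided_decoding_backend == "xgrammar":
        with pytest.raises(ValueError):
            llm.generate(prompts="Example JSON, please.",
                         sampling_params=sampling_params)
    else:
        outputs = llm.generate(prompts="Example JSON, please.",
                               sampling_params=sampling_params)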


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_grammar_ebnf(
monkeypatch: pytest.MonkeyPatch,
sample_sql_ebnf: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
#
# Test 4: Generate SQL statement using EBNF grammar
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
@@ -205,21 +194,9 @@ def test_guided_grammar_ebnf(

print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_grammar_lark(
monkeypatch: pytest.MonkeyPatch,
sample_sql_lark: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
#
# Test 5: Generate SQL statement using Lark grammar
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
@@ -254,20 +231,9 @@ def test_guided_grammar_lark(

print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_grammar_ebnf_invalid(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
#
# Test 6: Test invalid grammar input
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
@@ -281,21 +247,9 @@ def test_guided_grammar_ebnf_invalid(
use_tqdm=True,
)
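
Test 6's grammar string is elided; presumably any string that fails to parse as a grammar triggers the ValueError asserted above. An illustrative guess, reusing names from the sketches above:

    with pytest.raises(ValueError):
        llm.generate(
            prompts="Generate SQL.",
            sampling_params=SamplingParams(
                guided_decoding=GuidedDecodingParams(grammar="not a grammar")),
            use_tqdm=True)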


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_regex(
monkeypatch: pytest.MonkeyPatch,
sample_regex: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
#
# Test 7: Generate text based on a regex pattern
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
@@ -319,21 +273,9 @@ def test_guided_regex(
assert re.fullmatch(sample_regex, generated_text) is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_choice_completion(
monkeypatch: pytest.MonkeyPatch,
sample_guided_choice: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
#
# Test 8: Generate text based on a list of choices
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
@@ -353,33 +295,9 @@ def test_guided_choice_completion(
assert generated_text in sample_guided_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
truck = "Truck"
coupe = "Coupe"


class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion_with_enum(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
#
# Test 9: Generate structured output using a Pydantic model with an enum
#
json_schema = CarDescription.model_json_schema()
sampling_params = SamplingParams(
temperature=1.0,
@@ -403,3 +321,41 @@ def test_guided_json_completion_with_enum(
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)
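
Because the schema in Test 9 is derived from a Pydantic model, the generated text can also be round-tripped through Pydantic itself, which checks the enum values as well as the structure; a short sketch using the pydantic v2 API:

    generated = outputs[0].outputs[0].text  # assuming the outputs above
    car = CarDescription.model_validate_json(generated)
    assert isinstance(car.car_type, CarType)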


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_structured_output_auto_mode(
monkeypatch: pytest.MonkeyPatch,
unsupported_json_schema: dict[str, Any],
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")

llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend="auto")

sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))

# This would fail with the default of "xgrammar", but in "auto"
# we will handle fallback automatically.
outputs = llm.generate(prompts=("Give an example JSON object for a grade "
"that fits this schema: "
f"{unsupported_json_schema}"),
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
assert generated_text is not None
print(generated_text)

# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
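
A hedged extension of the "auto" test: the same engine should presumably serve both a schema xgrammar supports (the fast path) and the unsupported one (silent fallback), with no per-request configuration. The simple schema below is a placeholder, and the loop reuses the fixtures and imports from the test above:

    simple_schema = {"type": "object",
                     "properties": {"grade": {"type": "string"}}}
    for schema in (simple_schema, unsupported_json_schema):
        params = SamplingParams(
            max_tokens=500,
            guided_decoding=GuidedDecodingParams(json=schema))
        out = llm.generate(prompts=f"Example JSON for: {schema}",
                           sampling_params=params)
        assert isinstance(json.loads(out[0].outputs[0].text), dict)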