diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 6bdfa0fae4a2..00fa47575b6a 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -4,10 +4,12 @@ import json import re +from enum import Enum from typing import Any import jsonschema import pytest +from pydantic import BaseModel from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput @@ -390,3 +392,54 @@ def test_guided_choice_completion( assert generated_text is not None assert generated_text in sample_guided_choice print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + +class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +@pytest.mark.parametrize("model_name", MODELS_TO_TEST) +def test_guided_json_completion_with_enum( + monkeypatch: pytest.MonkeyPatch, + guided_decoding_backend: str, + model_name: str, +): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=model_name, + max_model_len=1024, + guided_decoding_backend=guided_decoding_backend) + json_schema = CarDescription.model_json_schema() + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams(json=json_schema)) + outputs = llm.generate( + prompts="Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's", + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=json_schema) diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index 694e46f763f0..a771256ef29f 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -26,10 +26,6 @@ def check_object(obj: dict[str, Any]) -> bool: if "pattern" in obj: return True - # Check for enum restrictions - if "enum" in obj: - return True - # Check for numeric ranges if obj.get("type") in ("integer", "number") and any( key in obj