Skip to content

Commit e31d0f8

Browse files
chaunceyjianglk-chen
authored and committed
[Bugfix][v1] xgrammar structured output supports Enum. (vllm-project#15594)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
1 parent ae9e713 commit e31d0f8

File tree

2 files changed

+53
-4
lines changed

2 files changed

+53
-4
lines changed

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
import json
66
import re
7+
from enum import Enum
78
from typing import Any
89

910
import jsonschema
1011
import pytest
12+
from pydantic import BaseModel
1113

1214
from vllm.entrypoints.llm import LLM
1315
from vllm.outputs import RequestOutput
@@ -390,3 +392,54 @@ def test_guided_choice_completion(
390392
assert generated_text is not None
391393
assert generated_text in sample_guided_choice
392394
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
395+
396+
397+
class CarType(str, Enum):
    """Closed set of car body styles used in the guided-decoding test schema.

    The member *values* (not the member names) are what appears in the JSON
    schema's ``enum`` list, so the mixed casing ("SUV", "Truck", "Coupe") is
    intentional — it exercises case handling in the structured-output backend.
    """

    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"
403+
404+
class CarDescription(BaseModel):
    """Pydantic model whose JSON schema is fed to guided decoding.

    ``model_json_schema()`` on this class yields a schema containing the
    ``CarType`` enum, which is exactly the case the enum-support test checks.
    """

    brand: str
    model: str
    car_type: CarType
409+
410+
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion_with_enum(
    monkeypatch: pytest.MonkeyPatch,
    guided_decoding_backend: str,
    model_name: str,
):
    """End-to-end check that guided JSON decoding supports Enum fields.

    Builds the JSON schema from ``CarDescription`` (which contains the
    ``CarType`` enum), generates completions constrained by that schema,
    and validates every generated JSON document against it.
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend)
    json_schema = CarDescription.model_json_schema()
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=json_schema))
    # BUG FIX: the two implicitly-concatenated literals previously joined
    # with no separator, sending "...car_type ofthe most iconic..." to the
    # model; a trailing space restores the intended prompt text.
    outputs = llm.generate(
        prompts="Generate a JSON with the brand, model and car_type of "
        "the most iconic car from the 90's",
        sampling_params=sampling_params,
        use_tqdm=True)

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        # Guided decoding must yield valid JSON that conforms to the schema,
        # including the enum-restricted car_type field.
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=json_schema)

vllm/v1/structured_output/utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@ def check_object(obj: dict[str, Any]) -> bool:
2626
if "pattern" in obj:
2727
return True
2828

29-
# Check for enum restrictions
30-
if "enum" in obj:
31-
return True
32-
3329
# Check for numeric ranges
3430
if obj.get("type") in ("integer", "number") and any(
3531
key in obj

0 commit comments

Comments
 (0)