Skip to content

Commit aa1b82f

Browse files
leon-seidel and russellb authored and committed
Enums for XGrammar with V0
Signed-off-by: Leon Seidel <leon.seidel@fau.de>
1 parent 090c856 commit aa1b82f

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

tests/entrypoints/llm/test_guided_generate.py

Lines changed: 43 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -3,9 +3,11 @@
33
import json
44
import re
55
import weakref
6+
from enum import Enum
67

78
import jsonschema
89
import pytest
10+
from pydantic import BaseModel
911

1012
from vllm.distributed import cleanup_dist_env_and_memory
1113
from vllm.entrypoints.llm import LLM
@@ -330,3 +332,44 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
330332
# Parse to verify it is valid JSON
331333
parsed_json = json.loads(generated_text)
332334
assert isinstance(parsed_json, dict)
335+
336+
337+
# Closed set of car body styles. Built with the functional Enum API and a
# ``str`` mixin so members compare/serialize as plain strings; the mixed
# casing of the values ("SUV", "Truck", ...) deliberately exercises
# case-sensitive enum matching in the guided-decoding backends.
CarType = Enum(
    "CarType",
    [
        ("sedan", "sedan"),
        ("suv", "SUV"),
        ("truck", "Truck"),
        ("coupe", "Coupe"),
    ],
    type=str,
)
342+
343+
344+
class CarDescription(BaseModel):
    """Pydantic model used to derive a JSON schema for guided decoding.

    The ``car_type`` field uses :class:`CarType`, so the generated schema
    restricts that field to the enum's values.
    """

    # Free-form text fields.
    brand: str
    model: str
    # Constrained to the members of CarType.
    car_type: CarType
348+
349+
350+
@pytest.mark.skip_global_cleanup
351+
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
352+
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
    """Guided decoding must honor JSON-schema ``enum`` constraints.

    Builds a schema from ``CarDescription`` (whose ``car_type`` field is an
    enum) and asserts that every completion parses as JSON and validates
    against that schema for the given backend.
    """
    json_schema = CarDescription.model_json_schema()
    sampling_params = SamplingParams(temperature=1.0,
                                     max_tokens=1000,
                                     guided_decoding=GuidedDecodingParams(
                                         json=json_schema,
                                         backend=guided_decoding_backend))
    # NOTE: trailing space before the continuation string is intentional —
    # without it, implicit concatenation yields "...car_type ofthe most...".
    outputs = llm.generate(
        prompts="Generate a JSON with the brand, model and car_type of "
        "the most iconic car from the 90's",
        sampling_params=sampling_params,
        use_tqdm=True)

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=json_schema)

vllm/model_executor/guided_decoding/utils.py

Lines changed: 0 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -14,10 +14,6 @@ def check_object(obj: dict) -> bool:
1414
if "pattern" in obj:
1515
return True
1616

17-
# Check for enum restrictions
18-
if "enum" in obj:
19-
return True
20-
2117
# Check for numeric ranges
2218
if obj.get("type") in ("integer", "number") and any(
2319
key in obj for key in [

0 commit comments

Comments
 (0)