From aa1b82f43a3dbb3a9521fc93b35e07ebc89d5a25 Mon Sep 17 00:00:00 2001 From: Leon Seidel Date: Tue, 1 Apr 2025 12:42:52 +0200 Subject: [PATCH] Enums for XGrammar with V0 Signed-off-by: Leon Seidel --- tests/entrypoints/llm/test_guided_generate.py | 43 +++++++++++++++++++ vllm/model_executor/guided_decoding/utils.py | 4 -- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 3f275e0b2ec7..3b85ad68c057 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -3,9 +3,11 @@ import json import re import weakref +from enum import Enum import jsonschema import pytest +from pydantic import BaseModel from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM @@ -330,3 +332,44 @@ def test_guided_json_object(llm, guided_decoding_backend: str): # Parse to verify it is valid JSON parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) + + +class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + +class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str): + json_schema = CarDescription.model_json_schema() + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=json_schema, + backend=guided_decoding_backend)) + outputs = llm.generate( + prompts="Generate a JSON with the brand, model and car_type of " + "the most iconic car from the 90's", + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + 
generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=json_schema) \ No newline at end of file diff --git a/vllm/model_executor/guided_decoding/utils.py b/vllm/model_executor/guided_decoding/utils.py index 10981776e768..ba7c10252699 100644 --- a/vllm/model_executor/guided_decoding/utils.py +++ b/vllm/model_executor/guided_decoding/utils.py @@ -14,10 +14,6 @@ def check_object(obj: dict) -> bool: if "pattern" in obj: return True - # Check for enum restrictions - if "enum" in obj: - return True - # Check for numeric ranges if obj.get("type") in ("integer", "number") and any( key in obj for key in [