Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import json
import re
import weakref
from enum import Enum

import jsonschema
import pytest
from pydantic import BaseModel

from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
Expand Down Expand Up @@ -330,3 +332,44 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)


class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
truck = "Truck"
coupe = "Coupe"


class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
json_schema = CarDescription.model_json_schema()
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts="Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's",
sampling_params=sampling_params,
use_tqdm=True)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt

generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)
4 changes: 0 additions & 4 deletions vllm/model_executor/guided_decoding/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ def check_object(obj: dict) -> bool:
if "pattern" in obj:
return True

# Check for enum restrictions
if "enum" in obj:
return True

# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj for key in [
Expand Down