import json
import re
import weakref
+ from enum import Enum

import jsonschema
import pytest
+ from pydantic import BaseModel

from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
@@ -330,3 +332,44 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
    # Parse to verify it is valid JSON
    parsed_json = json.loads(generated_text)
    assert isinstance(parsed_json, dict)
+
+
+ class CarType(str, Enum):
+     sedan = "sedan"
+     suv = "SUV"
+     truck = "Truck"
+     coupe = "Coupe"
+
+
+ class CarDescription(BaseModel):
+     brand: str
+     model: str
+     car_type: CarType
+
+
+ @pytest.mark.skip_global_cleanup
+ @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
+ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
+     json_schema = CarDescription.model_json_schema()
+     sampling_params = SamplingParams(temperature=1.0,
+                                      max_tokens=1000,
+                                      guided_decoding=GuidedDecodingParams(
+                                          json=json_schema,
+                                          backend=guided_decoding_backend))
+     outputs = llm.generate(
+         prompts="Generate a JSON with the brand, model and car_type of"
+         " the most iconic car from the 90's",
+         sampling_params=sampling_params,
+         use_tqdm=True)
+
+     assert outputs is not None
+     for output in outputs:
+         assert output is not None
+         assert isinstance(output, RequestOutput)
+         prompt = output.prompt
+
+         generated_text = output.outputs[0].text
+         assert generated_text is not None
+         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+         output_json = json.loads(generated_text)
+         jsonschema.validate(instance=output_json, schema=json_schema)
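For reference, the enum constraint exercised by this test comes entirely from the generated JSON schema: under pydantic v2, CarDescription.model_json_schema() places the CarType values in a $defs entry with an "enum" list, so jsonschema.validate rejects any other car_type. A minimal standalone sketch of that behavior (no LLM involved; the brand/model values below are invented for illustration):

    # Standalone sketch, assuming pydantic v2 and the CarType/CarDescription classes above.
    schema = CarDescription.model_json_schema()

    # The CarType enum surfaces as a string constraint under the schema's $defs.
    assert set(schema["$defs"]["CarType"]["enum"]) == {"sedan", "SUV", "Truck", "Coupe"}

    # A conforming object validates cleanly...
    jsonschema.validate(instance={"brand": "Toyota", "model": "Supra", "car_type": "SUV"},
                        schema=schema)

    # ...while a car_type outside the enum is rejected.
    with pytest.raises(jsonschema.ValidationError):
        jsonschema.validate(instance={"brand": "Toyota", "model": "Supra",
                                      "car_type": "hatchback"},
                            schema=schema)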