1313from vllm .outputs import RequestOutput
1414from vllm .sampling_params import GuidedDecodingParams , SamplingParams
1515
16- GUIDED_DECODING_BACKENDS_V1 = ["xgrammar" ]
16+ GUIDED_DECODING_BACKENDS_V1 = ["xgrammar" , "guidance" ]
1717MODELS_TO_TEST = [
1818 "Qwen/Qwen2.5-1.5B-Instruct" , "mistralai/Ministral-8B-Instruct-2410"
1919]
@@ -30,12 +30,13 @@ def test_guided_json_completion(
3030 model_name : str ,
3131):
3232 monkeypatch .setenv ("VLLM_USE_V1" , "1" )
33- llm = LLM (model = model_name , max_model_len = 1024 )
34- sampling_params = SamplingParams (temperature = 1.0 ,
35- max_tokens = 1000 ,
36- guided_decoding = GuidedDecodingParams (
37- json = sample_json_schema ,
38- backend = guided_decoding_backend ))
33+ llm = LLM (model = model_name ,
34+ max_model_len = 1024 ,
35+ guided_decoding_backend = guided_decoding_backend )
36+ sampling_params = SamplingParams (
37+ temperature = 1.0 ,
38+ max_tokens = 1000 ,
39+ guided_decoding = GuidedDecodingParams (json = sample_json_schema ))
3940 outputs = llm .generate (prompts = [
4041 f"Give an example JSON for an employee profile "
4142 f"that fits this schema: { sample_json_schema } "
@@ -67,13 +68,14 @@ def test_guided_json_object(
6768 model_name : str ,
6869):
6970 monkeypatch .setenv ("VLLM_USE_V1" , "1" )
70- llm = LLM (model = model_name , max_model_len = 1024 )
71- sampling_params = SamplingParams (temperature = 1.0 ,
72- max_tokens = 100 ,
73- n = 2 ,
74- guided_decoding = GuidedDecodingParams (
75- json_object = True ,
76- backend = guided_decoding_backend ))
71+ llm = LLM (model = model_name ,
72+ max_model_len = 1024 ,
73+ guided_decoding_backend = guided_decoding_backend )
74+ sampling_params = SamplingParams (
75+ temperature = 1.0 ,
76+ max_tokens = 100 ,
77+ n = 2 ,
78+ guided_decoding = GuidedDecodingParams (json_object = True ))
7779
7880 outputs = llm .generate (
7981 prompts = ("Generate a JSON object with curly braces for a person with "
@@ -98,7 +100,7 @@ def test_guided_json_object(
98100
99101@pytest .mark .skip_global_cleanup
100102@pytest .mark .parametrize ("guided_decoding_backend" ,
101- GUIDED_DECODING_BACKENDS_V1 )
103+ GUIDED_DECODING_BACKENDS_V1 + [ "auto" ] )
102104@pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
103105def test_guided_json_unsupported_schema (
104106 monkeypatch : pytest .MonkeyPatch ,
@@ -107,21 +109,43 @@ def test_guided_json_unsupported_schema(
107109 model_name : str ,
108110):
109111 monkeypatch .setenv ("VLLM_USE_V1" , "1" )
110- llm = LLM (model = model_name , max_model_len = 1024 )
111- sampling_params = SamplingParams (temperature = 1.0 ,
112- max_tokens = 1000 ,
113- guided_decoding = GuidedDecodingParams (
114- json = unsupported_json_schema ,
115- backend = guided_decoding_backend ))
116- with pytest .raises (ValueError ,
117- match = "The provided JSON schema contains features "
118- "not supported by xgrammar." ):
119- llm .generate (prompts = [
120- f"Give an example JSON for an employee profile "
121- f"that fits this schema: { unsupported_json_schema } "
122- ] * 2 ,
123- sampling_params = sampling_params ,
124- use_tqdm = True )
112+ llm = LLM (model = model_name ,
113+ max_model_len = 1024 ,
114+ guided_decoding_backend = guided_decoding_backend )
115+ sampling_params = SamplingParams (
116+ temperature = 1.0 ,
117+ max_tokens = 1000 ,
118+ guided_decoding = GuidedDecodingParams (json = unsupported_json_schema ))
119+ if guided_decoding_backend == "xgrammar" :
120+ with pytest .raises (ValueError ,
121+ match = "The provided JSON schema contains features "
122+ "not supported by xgrammar." ):
123+ llm .generate (prompts = [
124+ f"Give an example JSON for an employee profile "
125+ f"that fits this schema: { unsupported_json_schema } "
126+ ] * 2 ,
127+ sampling_params = sampling_params ,
128+ use_tqdm = True )
129+ else :
130+ # This should work for both "guidance" and "auto".
131+
132+ outputs = llm .generate (
133+ prompts = ("Give an example JSON object for a grade "
134+ "that fits this schema: "
135+ f"{ unsupported_json_schema } " ),
136+ sampling_params = sampling_params ,
137+ use_tqdm = True )
138+ assert outputs is not None
139+ for output in outputs :
140+ assert output is not None
141+ assert isinstance (output , RequestOutput )
142+ generated_text = output .outputs [0 ].text
143+ assert generated_text is not None
144+ print (generated_text )
145+
146+ # Parse to verify it is valid JSON
147+ parsed_json = json .loads (generated_text )
148+ assert isinstance (parsed_json , dict )
125149
126150
127151@pytest .mark .skip_global_cleanup
@@ -135,13 +159,14 @@ def test_guided_grammar_ebnf(
135159 model_name : str ,
136160):
137161 monkeypatch .setenv ("VLLM_USE_V1" , "1" )
138- llm = LLM (model = model_name , max_model_len = 1024 )
139- sampling_params = SamplingParams (temperature = 0.8 ,
140- top_p = 0.95 ,
141- max_tokens = 1000 ,
142- guided_decoding = GuidedDecodingParams (
143- grammar = sample_sql_ebnf ,
144- backend = guided_decoding_backend ))
162+ llm = LLM (model = model_name ,
163+ max_model_len = 1024 ,
164+ guided_decoding_backend = guided_decoding_backend )
165+ sampling_params = SamplingParams (
166+ temperature = 0.8 ,
167+ top_p = 0.95 ,
168+ max_tokens = 1000 ,
169+ guided_decoding = GuidedDecodingParams (grammar = sample_sql_ebnf ))
145170 outputs = llm .generate (
146171 prompts = ("Generate a sql statement that selects col_1 from "
147172 "table_1 where it is equal to 1" ),
@@ -178,13 +203,14 @@ def test_guided_grammar_lark(
178203 model_name : str ,
179204):
180205 monkeypatch .setenv ("VLLM_USE_V1" , "1" )
181- llm = LLM (model = model_name , max_model_len = 1024 )
182- sampling_params = SamplingParams (temperature = 0.8 ,
183- top_p = 0.95 ,
184- max_tokens = 1000 ,
185- guided_decoding = GuidedDecodingParams (
186- grammar = sample_sql_lark ,
187- backend = guided_decoding_backend ))
206+ llm = LLM (model = model_name ,
207+ max_model_len = 1024 ,
208+ guided_decoding_backend = guided_decoding_backend )
209+ sampling_params = SamplingParams (
210+ temperature = 0.8 ,
211+ top_p = 0.95 ,
212+ max_tokens = 1000 ,
213+ guided_decoding = GuidedDecodingParams (grammar = sample_sql_lark ))
188214 outputs = llm .generate (
189215 prompts = ("Generate a sql statement that selects col_1 from "
190216 "table_1 where it is equal to 1" ),
@@ -225,13 +251,14 @@ def test_guided_grammar_ebnf_invalid(
225251 model_name : str ,
226252):
227253 monkeypatch .setenv ("VLLM_USE_V1" , "1" )
228- llm = LLM (model = model_name , max_model_len = 1024 )
229- sampling_params = SamplingParams (temperature = 0.8 ,
230- top_p = 0.95 ,
231- max_tokens = 1000 ,
232- guided_decoding = GuidedDecodingParams (
233- grammar = "not a grammar" ,
234- backend = guided_decoding_backend ))
254+ llm = LLM (model = model_name ,
255+ max_model_len = 1024 ,
256+ guided_decoding_backend = guided_decoding_backend )
257+ sampling_params = SamplingParams (
258+ temperature = 0.8 ,
259+ top_p = 0.95 ,
260+ max_tokens = 1000 ,
261+ guided_decoding = GuidedDecodingParams (grammar = "not a grammar" ))
235262 with pytest .raises (ValueError ,
236263 match = "Failed to convert the grammar "
237264 "from Lark to EBNF." ):
@@ -254,12 +281,13 @@ def test_guided_regex(
254281 model_name : str ,
255282):
256283 monkeypatch .setenv ("VLLM_USE_V1" , "1" )
257- llm = LLM (model = model_name , max_model_len = 1024 )
258- sampling_params = SamplingParams (temperature = 0.8 ,
259- top_p = 0.95 ,
260- guided_decoding = GuidedDecodingParams (
261- regex = sample_regex ,
262- backend = guided_decoding_backend ))
284+ llm = LLM (model = model_name ,
285+ max_model_len = 1024 ,
286+ guided_decoding_backend = guided_decoding_backend )
287+ sampling_params = SamplingParams (
288+ temperature = 0.8 ,
289+ top_p = 0.95 ,
290+ guided_decoding = GuidedDecodingParams (regex = sample_regex ))
263291 outputs = llm .generate (
264292 prompts = [
265293 f"Give an example IPv4 address with this regex: { sample_regex } "
@@ -291,12 +319,13 @@ def test_guided_choice_completion(
291319 model_name : str ,
292320):
293321 monkeypatch .setenv ("VLLM_USE_V1" , "1" )
294- llm = LLM (model = model_name , max_model_len = 1024 )
295- sampling_params = SamplingParams (temperature = 0.8 ,
296- top_p = 0.95 ,
297- guided_decoding = GuidedDecodingParams (
298- choice = sample_guided_choice ,
299- backend = guided_decoding_backend ))
322+ llm = LLM (model = model_name ,
323+ max_model_len = 1024 ,
324+ guided_decoding_backend = guided_decoding_backend )
325+ sampling_params = SamplingParams (
326+ temperature = 0.8 ,
327+ top_p = 0.95 ,
328+ guided_decoding = GuidedDecodingParams (choice = sample_guided_choice ))
300329 outputs = llm .generate (
301330 prompts = "The best language for type-safe systems programming is " ,
302331 sampling_params = sampling_params ,
0 commit comments