Commit 469d4ec

Add "auto" mode, enable test coverage

Signed-off-by: Russell Bryant <rbryant@redhat.com>
1 parent 122da1c commit 469d4ec

File tree

4 files changed: +127 -79 lines changed

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 92 additions & 63 deletions
@@ -13,7 +13,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
+GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
 MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
@@ -30,12 +30,13 @@ def test_guided_json_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=sample_json_schema,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
     outputs = llm.generate(prompts=[
         f"Give an example JSON for an employee profile "
         f"that fits this schema: {sample_json_schema}"
@@ -67,13 +68,14 @@ def test_guided_json_object(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=100,
-                                     n=2,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json_object=True,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=100,
+        n=2,
+        guided_decoding=GuidedDecodingParams(json_object=True))
 
     outputs = llm.generate(
         prompts=("Generate a JSON object with curly braces for a person with "
@@ -98,7 +100,7 @@ def test_guided_json_object(
 
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
+                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
 @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
 def test_guided_json_unsupported_schema(
     monkeypatch: pytest.MonkeyPatch,
@@ -107,21 +109,43 @@ def test_guided_json_unsupported_schema(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=unsupported_json_schema,
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="The provided JSON schema contains features "
-                       "not supported by xgrammar."):
-        llm.generate(prompts=[
-            f"Give an example JSON for an employee profile "
-            f"that fits this schema: {unsupported_json_schema}"
-        ] * 2,
-                     sampling_params=sampling_params,
-                     use_tqdm=True)
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
+    if guided_decoding_backend == "xgrammar":
+        with pytest.raises(ValueError,
+                           match="The provided JSON schema contains features "
+                           "not supported by xgrammar."):
+            llm.generate(prompts=[
+                f"Give an example JSON for an employee profile "
+                f"that fits this schema: {unsupported_json_schema}"
+            ] * 2,
+                         sampling_params=sampling_params,
+                         use_tqdm=True)
+    else:
+        # This should work for both "guidance" and "auto".
+
+        outputs = llm.generate(
+            prompts=("Give an example JSON object for a grade "
+                     "that fits this schema: "
+                     f"{unsupported_json_schema}"),
+            sampling_params=sampling_params,
+            use_tqdm=True)
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            generated_text = output.outputs[0].text
+            assert generated_text is not None
+            print(generated_text)
+
+            # Parse to verify it is valid JSON
+            parsed_json = json.loads(generated_text)
+            assert isinstance(parsed_json, dict)
 
 
 @pytest.mark.skip_global_cleanup
@@ -135,13 +159,14 @@ def test_guided_grammar_ebnf(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_ebnf,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -178,13 +203,14 @@ def test_guided_grammar_lark(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_lark,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -225,13 +251,14 @@ def test_guided_grammar_ebnf_invalid(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar="not a grammar",
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
     with pytest.raises(ValueError,
                        match="Failed to convert the grammar "
                        "from Lark to EBNF."):
@@ -254,12 +281,13 @@ def test_guided_regex(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         regex=sample_regex,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(regex=sample_regex))
     outputs = llm.generate(
         prompts=[
             f"Give an example IPv4 address with this regex: {sample_regex}"
@@ -291,12 +319,13 @@ def test_guided_choice_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         choice=sample_guided_choice,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
     outputs = llm.generate(
         prompts="The best language for type-safe systems programming is ",
         sampling_params=sampling_params,
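For reference, a minimal usage sketch of the pattern these tests now exercise: the backend is chosen once on the LLM, and per-request GuidedDecodingParams no longer carry a backend. This is a hedged example, not code from the commit; the model name and schema are illustrative stand-ins for the test fixtures.

# Hedged usage sketch; model and schema are illustrative only.
import os

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

os.environ["VLLM_USE_V1"] = "1"

# The backend is now selected once, at engine construction time.
# "auto" lets the engine pick xgrammar when the request is supported
# and fall back to guidance otherwise.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          max_model_len=1024,
          guided_decoding_backend="auto")

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(json=schema))

outputs = llm.generate(
    prompts=[f"Give an example JSON that fits this schema: {schema}"],
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)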

vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -2788,7 +2788,7 @@ def __post_init__(self):
         v0_valid_guided_backends = [
             'outlines', 'lm-format-enforcer', 'xgrammar'
         ]
-        v1_valid_guided_backends = ['xgrammar', 'guidance']
+        v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
 
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name

vllm/engine/arg_utils.py

Lines changed: 9 additions & 12 deletions
@@ -382,16 +382,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default='xgrammar',
             help='Which engine will be used for guided decoding'
             ' (JSON schema / regex etc) by default. Currently support '
-            'https://github.com/outlines-dev/outlines, '
-            'https://github.com/mlc-ai/xgrammar, and '
-            'https://github.com/noamgat/lm-format-enforcer.'
-            ' Can be overridden per request via guided_decoding_backend'
-            ' parameter.\n'
-            'Backend-specific options can be supplied in a comma-separated '
-            'list following a colon after the backend name. Valid backends and '
-            'all available options are: [xgrammar:no-fallback, '
-            'xgrammar:disable-any-whitespace, '
-            'outlines:no-fallback, lm-format-enforcer:no-fallback]')
+            'https://github.com/mlc-ai/xgrammar and '
+            'https://github.com/guidance-ai/llguidance. '
+            'Valid backend values are "xgrammar", "guidance", and "auto". '
+            'With "auto", we will make opinionated choices based on request '
+            'contents and what the backend libraries currently support, so '
+            'the behavior is subject to change in each release. '
+            'The default is xgrammar.')
         parser.add_argument(
             '--logits-processor-pattern',
             type=nullable_str,
@@ -1461,8 +1458,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False
 
-        # Only support Xgrammar for guided decoding so far.
-        SUPPORTED_GUIDED_DECODING = ["xgrammar", "xgrammar:nofallback"]
+        # Support xgrammar, guidance, or an opinionated automatic mode.
+        SUPPORTED_GUIDED_DECODING = ["xgrammar", "guidance", "auto"]
         if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
             _raise_or_fallback(feature_name="--guided-decoding-backend",
                                recommend_to_remove=False)
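The same choice is available programmatically. Below is a sketch, under the assumption that EngineArgs exposes the field the flag above populates (the model name is an arbitrary example); on the command line the equivalent is "vllm serve <model> --guided-decoding-backend auto".

# Sketch only: EngineArgs mirrors the --guided-decoding-backend CLI flag.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # example model
    guided_decoding_backend="auto",  # "xgrammar", "guidance", or "auto"
)
# With VLLM_USE_V1=1, other values are rejected (or fall back to V0)
# by the support oracle shown above.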

vllm/v1/engine/processor.py

Lines changed: 25 additions & 3 deletions
@@ -121,7 +121,16 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         if not params.guided_decoding or not self.decoding_config:
             return
 
-        supported_backends = ["xgrammar", "guidance"]
+        # Platform validation
+        if vllm.platforms.current_platform.is_tpu():
+            raise ValueError("Structured output is not supported on TPU.")
+
+        # Backend validation
+        # - ensure the backend is supported in v1
+        # - if a backend was included in the request, ensure it matches
+        #   the backend configured for the engine. We don't support changing
+        #   the backend per request in V1.
+        supported_backends = ["auto", "xgrammar", "guidance"]
         engine_level_backend = self.decoding_config.guided_decoding_backend
         if engine_level_backend not in supported_backends:
             raise ValueError(f"Only {supported_backends} structured output is "
@@ -135,11 +144,24 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         else:
             params.guided_decoding.backend = engine_level_backend
 
-        if vllm.platforms.current_platform.is_tpu():
-            raise ValueError("Structured output is not supported on TPU.")
+        # Request content validation
 
         if engine_level_backend == "xgrammar":
+            # xgrammar with no fallback
             validate_structured_output_request_xgrammar(params)
+            params.guided_decoding.backend = "xgrammar"
+        elif engine_level_backend == "auto":
+            # "auto" is an opt-in to opinionated behavior where we try to
+            # choose a backend based on request contents. This is not the
+            # default as it is less predictable and subject to change
+            # between releases as feature support changes.
+            try:
+                validate_structured_output_request_xgrammar(params)
+                params.guided_decoding.backend = "xgrammar"
+            except ValueError:
+                # The request includes some jsonschema feature(s) that
+                # are not supported in xgrammar. Fall back to guidance.
+                params.guided_decoding.backend = "guidance"
 
     def process_inputs(
         self,
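In short, "auto" keys off the same validator that xgrammar's no-fallback path uses. An illustrative, self-contained restatement of the selection pattern follows; the names here are hypothetical, not the vLLM internals.

# Sketch of the "auto" selection pattern above. choose_backend and
# xgrammar_validator are hypothetical names used for illustration.
def choose_backend(request_params, xgrammar_validator) -> str:
    try:
        # The validator raises ValueError when the request uses
        # features xgrammar does not support.
        xgrammar_validator(request_params)
        return "xgrammar"
    except ValueError:
        # Fall back to the more permissive guidance backend.
        return "guidance"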
