
Commit 3838107

add guided decoding
Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent 96d6fa7 commit 3838107
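
For context, the test added in this commit exercises vLLM's guided decoding through GuidedDecodingParams. Below is a minimal usage sketch reusing the model, backend, and API names that appear in the test; it is illustrative only and not part of the commit, and the regex is a simplified stand-in for the full IPv4 pattern used in the test.

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Illustrative sketch; model and backend names taken from the test below.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024, seed=0)
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    guided_decoding=GuidedDecodingParams(regex=r"\d{1,3}(\.\d{1,3}){3}",
                                         backend="outlines"))
outputs = llm.generate(prompts=["Give an example IPv4 address:"],
                       sampling_params=sampling_params)
print(outputs[0].outputs[0].text)  # generation is constrained to match the regex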

File tree

2 files changed: +170 additions, -0 deletions

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@ pytest >= 6.0
 pytest-asyncio
 lm-eval
 ray
+types-jsonschema
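
The single added dependency, types-jsonschema, provides type stubs for the jsonschema package imported by the new test, so static checkers can verify the validation call. A minimal sketch of that call (hypothetical schema, not taken from the diff):

import jsonschema

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
# Raises jsonschema.ValidationError if the instance does not match the schema.
jsonschema.validate(instance={"name": "Alice"}, schema=schema)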
Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
import json
import os
import re
import weakref

import jsonschema
import pytest
import torch
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = [
    "outlines",
    "lm-format-enforcer",
    "xgrammar:disable-any-whitespace",
]


@pytest.fixture(scope="module")
def sample_regex():
    return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
            r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")


@pytest.fixture(scope="module")
def sample_json_schema():
    return {
        "type": "object",
        "properties": {
            "name": {
                "type": "string"
            },
            "age": {
                "type": "integer"
            },
            "skills": {
                "type": "array",
                "items": {
                    "type": "string",
                    "maxLength": 10
                },
                "minItems": 3
            },
            "work_history": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "company": {
                            "type": "string"
                        },
                        "duration": {
                            "type": "number"
                        },
                        "position": {
                            "type": "string"
                        }
                    },
                    "required": ["company", "position"]
                }
            }
        },
        "required": ["name", "age", "skills", "work_history"]
    }


def clean_up():
    gc.collect()
    torch.npu.empty_cache()


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0)
    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)
    del llm
    clean_up()


# TODO: Add v1 fully tested
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
                    reason="v1 does not support guided decoding")
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     guided_decoding=GuidedDecodingParams(
                                         regex=sample_regex,
                                         backend=guided_decoding_backend))
    print(f"Using backend: {guided_decoding_backend}")
    outputs = llm.generate(prompts=[
        f"Give an example IPv4 address with this regex: {sample_regex}"
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
                    reason="v1 does not support guided decoding")
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm,
                                guided_decoding_backend: str):
    if guided_decoding_backend == "xgrammar:disable-any-whitespace":
        # xgrammar does not support json schema, will fall back to outlines, skip it
        pytest.skip(
            f"{guided_decoding_backend} does not support json schema validation"
        )

    sampling_params = SamplingParams(temperature=1.0,
                                     max_tokens=1000,
                                     guided_decoding=GuidedDecodingParams(
                                         json=sample_json_schema,
                                         backend=guided_decoding_backend))
    print(f"Using backend: {guided_decoding_backend}")
    outputs = llm.generate(prompts=[
        f"Give an example JSON for an employee profile "
        f"that fits this schema: {sample_json_schema}"
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)
