
Commit 31e8138

add guided decoding
Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent 14d9a64 commit 31e8138

File tree

7 files changed: +144 -1 lines changed


fusion_result.json

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+null

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 modelscope
 pytest >= 6.0
 pytest-asyncio
+types-jsonschema

tests/__init__.py

Whitespace-only changes.

tests/entrypoints/__init__.py

Whitespace-only changes.

tests/entrypoints/conftest.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+# Inspired from https://github.com/vllm-project/vllm/blob/main/tests/entrypoints/llm/test_guided_generate.py
+import pytest
+
+
+@pytest.fixture
+def sample_regex():
+    return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
+            r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
+
+
+@pytest.fixture
+def sample_json_schema():
+    return {
+        "type": "object",
+        "properties": {
+            "name": {
+                "type": "string"
+            },
+            "age": {
+                "type": "integer"
+            },
+            "skills": {
+                "type": "array",
+                "items": {
+                    "type": "string",
+                    "maxLength": 10
+                },
+                "minItems": 3
+            },
+            "work_history": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "company": {
+                            "type": "string"
+                        },
+                        "duration": {
+                            "type": "number"
+                        },
+                        "position": {
+                            "type": "string"
+                        }
+                    },
+                    "required": ["company", "position"]
+                }
+            }
+        },
+        "required": ["name", "age", "skills", "work_history"]
+    }
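The sample_regex fixture above encodes the 0-255 per-octet range of a dotted-quad IPv4 address. A quick standalone sanity check (illustrative, not part of the commit) of the behavior the tests later assert with re.fullmatch:

# Sanity check (not part of the diff): the fixture's regex accepts
# in-range IPv4 addresses and rejects out-of-range octets.
import re

IPV4_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
              r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

assert re.fullmatch(IPV4_REGEX, "192.168.0.1") is not None
assert re.fullmatch(IPV4_REGEX, "255.255.255.255") is not None
assert re.fullmatch(IPV4_REGEX, "256.1.1.1") is None  # octet 256 is out of range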
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+# Inspired from https://github.com/vllm-project/vllm/blob/main/tests/entrypoints/llm/test_guided_generate.py
+import gc
+import json
+import os
+import re
+import weakref
+
+import jsonschema
+import pytest
+import torch
+from vllm.entrypoints.llm import LLM
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
+
+os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
+MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
+GUIDED_DECODING_BACKENDS = [
+    "outlines",
+    "lm-format-enforcer",
+    "xgrammar",
+]
+
+
+def clean_up():
+    gc.collect()
+    torch.npu.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0)
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+    del llm
+    clean_up()
+
+
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
+def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     guided_decoding=GuidedDecodingParams(
+                                         regex=sample_regex,
+                                         backend=guided_decoding_backend))
+    outputs = llm.generate(prompts=[
+        f"Give an example IPv4 address with this regex: {sample_regex}"
+    ] * 2,
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
+
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(generated_text)
+        assert generated_text is not None
+        assert re.fullmatch(sample_regex, generated_text) is not None
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
+def test_guided_json_completion(sample_json_schema, llm,
+                                guided_decoding_backend: str):
+    sampling_params = SamplingParams(temperature=1.0,
+                                     max_tokens=1000,
+                                     guided_decoding=GuidedDecodingParams(
+                                         json=sample_json_schema,
+                                         backend=guided_decoding_backend))
+    outputs = llm.generate(prompts=[
+        f"Give an example JSON for an employee profile "
+        f"that fits this schema: {sample_json_schema}"
+    ] * 2,
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
+
+    assert outputs is not None
+
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        output_json = json.loads(generated_text)
+        jsonschema.validate(instance=output_json, schema=sample_json_schema)
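For context, the guided-decoding API these tests exercise can also be driven directly. A minimal sketch, assuming vLLM with the vllm-ascend plugin is installed and the model fits on one device; the tiny schema and prompt here are illustrative, not from the commit:

# Minimal sketch of guided decoding via vLLM's offline LLM API.
# Assumptions: vllm-ascend is installed; schema/prompt are illustrative.
from vllm.entrypoints.llm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024, seed=0)

# Constrain generation to a JSON object with a single "city" string field.
guided = GuidedDecodingParams(
    json={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
    backend="xgrammar",
)
params = SamplingParams(temperature=0.8,
                        top_p=0.95,
                        max_tokens=64,
                        guided_decoding=guided)

outputs = llm.generate(prompts=["Name one large city as JSON:"],
                       sampling_params=params)
print(outputs[0].outputs[0].text)  # e.g. {"city": "Shanghai"}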

tests/test_offline_inference.py

Lines changed: 1 addition & 1 deletion
@@ -24,9 +24,9 @@
 
 import pytest
 import vllm # noqa: F401
-from conftest import VllmRunner
 
 import vllm_ascend # noqa: F401
+from tests.conftest import VllmRunner
 
 MODELS = [
     "Qwen/Qwen2.5-0.5B-Instruct",
