
Commit 3d6eb7b

gaocegege authored and shreyankg committed
[v0][structured output] Support reasoning output (vllm-project#12955)
Signed-off-by: Ce Gao <cegao@tensorchord.ai>
1 parent e3f067a commit 3d6eb7b

File tree: 16 files changed, +400 / -76 lines


docs/source/features/reasoning_outputs.md

Lines changed: 35 additions & 8 deletions
@@ -76,7 +76,13 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
 }
 ```
 
-Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests.
+Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. You could checkout the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py).
+
+## Limitations
+
+- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
+- It is not compatible with [`tool_calling`](#tool_calling).
+- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
 
 ## How to support a new reasoning model
 
@@ -137,15 +143,36 @@ class ExampleParser(ReasoningParser):
     """
 ```
 
-After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when making a request to the chat completion endpoint.
+Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`.
+
+```python
+@dataclass
+class DeepSeekReasoner(Reasoner):
+    """
+    Reasoner for DeepSeek R series models.
+    """
+    start_token_id: int
+    end_token_id: int
+
+    start_token: str = "<think>"
+    end_token: str = "</think>"
+
+    @classmethod
+    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
+        return cls(start_token_id=tokenizer.encode(
+            "<think>", add_special_tokens=False)[0],
+                   end_token_id=tokenizer.encode("</think>",
+                                                 add_special_tokens=False)[0])
+
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.end_token_id in input_ids
+```
+
+The structured output engine like xgrammar will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case.
+
+Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.
 
 ```bash
 vllm serve <model_tag> \
     --enable-reasoning --reasoning-parser example
 ```
-
-## Limitations
-
-- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
-- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features.
-- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
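
A minimal sketch of the gating behaviour described in the doc above (leave the logits untouched until `is_reasoning_end` returns True, then apply the grammar constraint). This is illustrative rather than the actual vLLM processor; the names `GuidedProcessor` and `apply_grammar_mask` are assumptions.

```python
import torch


class GuidedProcessor:
    """Illustrative logits processor that defers structured output until the
    reasoning section (e.g. <think> ... </think>) has ended."""

    def __init__(self, reasoner, apply_grammar_mask):
        # reasoner: object exposing is_reasoning_end(input_ids) -> bool,
        #           like the DeepSeekReasoner shown above, or None.
        # apply_grammar_mask: backend-specific masking callable (hypothetical).
        self.reasoner = reasoner
        self.apply_grammar_mask = apply_grammar_mask

    def __call__(self, input_ids: list[int],
                 logits: torch.Tensor) -> torch.Tensor:
        # While the model is still "thinking", do not force the reasoning
        # text to follow the JSON schema or regex.
        if self.reasoner is not None and not self.reasoner.is_reasoning_end(
                input_ids):
            return logits
        # Reasoning is over: constrain the remaining output as usual.
        return self.apply_grammar_mask(input_ids, logits)
```
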
New example file (filename not captured in this view)

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+An example shows how to generate structured outputs from reasoning models
+like DeepSeekR1. The thinking process will not be guided by the JSON
+schema provided by the user. Only the final output will be structured.
+
+To run this example, you need to start the vLLM server with the reasoning
+parser:
+
+```bash
+vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
+    --enable-reasoning --reasoning-parser deepseek_r1
+```
+
+This example demonstrates how to generate chat completions from reasoning models
+using the OpenAI Python client library.
+"""
+
+from enum import Enum
+
+from openai import OpenAI
+from pydantic import BaseModel
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+
+# Guided decoding by JSON using Pydantic schema
+class CarType(str, Enum):
+    sedan = "sedan"
+    suv = "SUV"
+    truck = "Truck"
+    coupe = "Coupe"
+
+
+class CarDescription(BaseModel):
+    brand: str
+    model: str
+    car_type: CarType
+
+
+json_schema = CarDescription.model_json_schema()
+
+prompt = ("Generate a JSON with the brand, model and car_type of"
+          "the most iconic car from the 90's, think in 100 tokens")
+completion = client.chat.completions.create(
+    model=model,
+    messages=[{
+        "role": "user",
+        "content": prompt,
+    }],
+    extra_body={"guided_json": json_schema},
+)
+print("content", completion.choices[0].message.content)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)

tests/model_executor/test_guided_processors.py

Lines changed: 102 additions & 14 deletions
@@ -16,25 +16,41 @@
 
 MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
 GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
+REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
 
 
-def test_guided_logits_processors(sample_regex, sample_json_schema):
+# Initialize the tokenizer for the model here to avoid repeated loading
+@pytest.fixture(scope="module")
+def zephyr_7B_tokenzer():
+    return AutoTokenizer.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture(scope="module")
+def deepseek_r1_qwen_tokenizer():
+    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+
+
+def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
+                                  sample_json_schema):
     """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
-    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
-    regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
+    regex_LP = RegexLogitsProcessor(sample_regex,
+                                    zephyr_7B_tokenzer,
+                                    reasoner=None)
     json_LP = JSONLogitsProcessor(sample_json_schema,
-                                  tokenizer,
-                                  whitespace_pattern=None)
+                                  zephyr_7B_tokenzer,
+                                  whitespace_pattern=None,
+                                  reasoner=None)
 
-    token_ids = tokenizer.encode(
+    token_ids = zephyr_7B_tokenzer.encode(
         f"Give an example IPv4 address with this regex: {sample_regex}")
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
     regex_LP(token_ids, tensor)
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)
 
-    token_ids = tokenizer.encode(
+    token_ids = zephyr_7B_tokenzer.encode(
         f"Give an employee profile that fits this schema: {sample_json_schema}"
     )
     tensor = torch.rand(32000)
@@ -49,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
 @pytest.mark.parametrize("is_local", [True, False])
 async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
                                                  sample_regex,
-                                                 sample_json_schema):
+                                                 sample_json_schema,
+                                                 zephyr_7B_tokenzer):
 
     config = ModelConfig(
         MODEL_NAME,
@@ -60,29 +77,100 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
         seed=0,
         dtype="bfloat16",
     )
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    token_ids = tokenizer.encode(
+    token_ids = zephyr_7B_tokenzer.encode(
         f"Give an example IPv4 address with this regex: {sample_regex}")
     regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
 
     regex_lp = get_local_guided_decoding_logits_processor(
-        regex_request, tokenizer, config) if is_local else \
+        regex_request, zephyr_7B_tokenzer, config) if is_local else \
         await get_guided_decoding_logits_processor(
-            regex_request, tokenizer, config)
+            regex_request, zephyr_7B_tokenzer, config)
     assert regex_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
     tensor = regex_lp(token_ids, tensor)
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)
 
-    token_ids = tokenizer.encode(
+    token_ids = zephyr_7B_tokenzer.encode(
         f"Give an employee profile that fits this schema: {sample_json_schema}"
     )
     json_request = GuidedDecodingParams(json=sample_json_schema,
                                         backend=backend)
     json_lp = await get_guided_decoding_logits_processor(
-        json_request, tokenizer, config)
+        json_request, zephyr_7B_tokenzer, config)
+    assert json_lp is not None
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = json_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("backend",
+                         GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
+@pytest.mark.parametrize("is_local", [True, False])
+@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
+async def test_guided_logits_processor_with_reasoning(
+        backend: str, is_local: bool, reasoning_backend: str, sample_regex,
+        sample_json_schema, deepseek_r1_qwen_tokenizer):
+
+    config = ModelConfig(
+        REASONING_MODEL_NAME,
+        task="generate",
+        tokenizer=REASONING_MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="bfloat16",
+    )
+    token_ids = deepseek_r1_qwen_tokenizer.encode(
+        f"Give an example IPv4 address with this regex: {sample_regex}."
+        "<think>here is the thinking process")
+    regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
+
+    regex_lp = get_local_guided_decoding_logits_processor(regex_request,
+        deepseek_r1_qwen_tokenizer, config,
+        reasoning_backend) if is_local else \
+        await get_guided_decoding_logits_processor(
+            regex_request, deepseek_r1_qwen_tokenizer, config,
+            reasoning_backend)
+    assert regex_lp is not None
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = regex_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert torch.allclose(tensor, original_tensor)
+
+    token_ids = deepseek_r1_qwen_tokenizer.encode(
+        f"Give an employee profile that fits this schema: {sample_json_schema}."
+        "<think>here is the thinking process")
+    json_request = GuidedDecodingParams(json=sample_json_schema,
+                                        backend=backend)
+    json_lp = get_local_guided_decoding_logits_processor(
+        json_request, deepseek_r1_qwen_tokenizer, config,
+        reasoning_backend) if is_local else \
+        await get_guided_decoding_logits_processor(
+            json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
+    assert json_lp is not None
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = json_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert torch.allclose(tensor, original_tensor)
+
+    # Thinking is over, so the tensor should change.
+    token_ids = deepseek_r1_qwen_tokenizer.encode(
+        f"Give an employee profile that fits this schema: {sample_json_schema}."
+        "<think>here is the thinking process</think> Then")
+    json_request = GuidedDecodingParams(json=sample_json_schema,
+                                        backend=backend)
+    json_lp = get_local_guided_decoding_logits_processor(
+        json_request, deepseek_r1_qwen_tokenizer, config,
+        reasoning_backend) if is_local else \
+        await get_guided_decoding_logits_processor(
+            json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
     assert json_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
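
One way to exercise just the new reasoning test from this file is sketched below; the pytest invocation is illustrative and assumes the usual vLLM test environment is already set up.

```bash
# Select only the reasoning-related test in this module (-k filters by name).
pytest tests/model_executor/test_guided_processors.py -k "reasoning" -v
```
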

vllm/config.py

Lines changed: 2 additions & 0 deletions
@@ -2717,6 +2717,8 @@ class DecodingConfig:
     # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
     guided_decoding_backend: str = 'xgrammar'
 
+    reasoning_backend: Optional[str] = None
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
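
For reference, a minimal sketch of setting the new field directly when constructing the config; the values are illustrative, and `reasoning_backend` stays `None` unless reasoning is enabled.

```python
from vllm.config import DecodingConfig

# "deepseek_r1" is the only reasoning parser wired up in this commit;
# "xgrammar" is the existing default guided decoding backend.
decoding_config = DecodingConfig(
    guided_decoding_backend="xgrammar",
    reasoning_backend="deepseek_r1",
)
```
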

vllm/engine/arg_utils.py

Lines changed: 25 additions & 1 deletion
@@ -214,6 +214,8 @@ class EngineArgs:
     calculate_kv_scales: Optional[bool] = None
 
     additional_config: Optional[Dict[str, Any]] = None
+    enable_reasoning: Optional[bool] = None
+    reasoning_parser: Optional[str] = None
 
     def __post_init__(self):
         if not self.tokenizer:
@@ -1060,6 +1062,25 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "Different platforms may support different configs. Make sure the "
             "configs are valid for the platform you are using. The input format"
             " is like '{\"config_key\":\"config_value\"}'")
+
+        parser.add_argument(
+            "--enable-reasoning",
+            action="store_true",
+            default=False,
+            help="Whether to enable reasoning_content for the model. "
+            "If enabled, the model will be able to generate reasoning content."
+        )
+
+        parser.add_argument(
+            "--reasoning-parser",
+            type=str,
+            choices=["deepseek_r1"],
+            default=None,
+            help=
+            "Select the reasoning parser depending on the model that you're "
+            "using. This is used to parse the reasoning content into OpenAI "
+            "API format. Required for ``--enable-reasoning``.")
+
         return parser
 
     @classmethod
@@ -1333,7 +1354,10 @@ def create_engine_config(self,
             if self.enable_prompt_adapter else None
 
         decoding_config = DecodingConfig(
-            guided_decoding_backend=self.guided_decoding_backend)
+            guided_decoding_backend=self.guided_decoding_backend,
+            reasoning_backend=self.reasoning_parser
+            if self.enable_reasoning else None,
+        )
 
         show_hidden_metrics = False
         if self.show_hidden_metrics_for_version is not None:
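
Taken together, the two new flags are passed to `vllm serve`, as in the example file above:

```bash
# --reasoning-parser currently only accepts "deepseek_r1".
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --enable-reasoning --reasoning-parser deepseek_r1
```
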

vllm/engine/async_llm_engine.py

Lines changed: 8 additions & 3 deletions
@@ -509,6 +509,7 @@ async def add_request_async(
             tokenizer=await self.get_tokenizer_async(lora_request),
             default_guided_backend=self.decoding_config.
             guided_decoding_backend,
+            reasoning_backend=self.decoding_config.reasoning_backend,
             model_config=self.model_config)
 
         self._add_processed_request(
@@ -530,7 +531,7 @@ async def check_health_async(self) -> None:
 
 async def build_guided_decoding_logits_processor_async(
         sampling_params: SamplingParams, tokenizer: AnyTokenizer,
-        default_guided_backend: str,
+        default_guided_backend: str, reasoning_backend: Optional[str],
         model_config: ModelConfig) -> SamplingParams:
     """Constructs logits processors based on the guided_decoding,
     logits_bias, and allowed_token_ids fields in sampling_params. Deletes
@@ -545,14 +546,18 @@ async def build_guided_decoding_logits_processor_async(
     sampling_params = copy.copy(sampling_params)
     guided_decoding = sampling_params.guided_decoding
 
-    logger.debug("Building guided decoding logits processor. "
-                 "Params: %s", guided_decoding)
+    logger.info(
+        "Building guided decoding logits processor. "
+        "guided_decoding: %s%s", guided_decoding,
+        f", reasoning_backend: {reasoning_backend}"
+        if reasoning_backend is not None else "")
 
     guided_decoding.backend = guided_decoding.backend or default_guided_backend
 
     processor = await get_guided_decoding_logits_processor(
         guided_params=guided_decoding,
         tokenizer=tokenizer,
+        reasoning_backend=reasoning_backend,
         model_config=model_config)
 
     if processor:
