
Commit 16f053c

gaocegege authored and chaunceyjiang committed
[Frontend] Skip stop in reasoning content (vllm-project#14550)
Signed-off-by: Ce Gao <cegao@tensorchord.ai>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
1 parent 076c9dc commit 16f053c
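
In short, this change threads an optional ReasoningParser into StopChecker: while the parser reports that reasoning has not ended, stop-token and stop-string checks are skipped (EOS still ends the sequence). A minimal, self-contained sketch of the gate — toy classes only, for illustration; the real types live in vllm.reasoning and vllm/engine/output_processor/stop_checker.py:

    # Toy sketch of the new gating in StopChecker.maybe_stop_sequence
    # (token id 42 stands in for an end-of-reasoning token such as </think>).
    class ToyReasoningParser:
        END_OF_THINK = 42  # hypothetical token id, for illustration only

        def is_reasoning_end(self, input_ids: list[int]) -> bool:
            return self.END_OF_THINK in input_ids


    def should_check_stops(reasoner, token_ids: list[int]) -> bool:
        # Mirrors the early return added in this commit: while reasoning is
        # in progress, stop-token and stop-string checks are skipped.
        return reasoner is None or reasoner.is_reasoning_end(token_ids)


    parser = ToyReasoningParser()
    assert not should_check_stops(parser, [1, 2, 3])    # still inside reasoning
    assert should_check_stops(parser, [1, 2, 42, 7])    # reasoning has ended
    assert should_check_stops(None, [1, 2, 3])          # no reasoner: old behavior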

File tree

3 files changed: +256 −4 lines changed

tests/engine/test_stop_checker.py

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from transformers import AutoTokenizer

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus

REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"


class MockReasoningParser(ReasoningParser):
    """Mock reasoning parser for testing purposes."""

    def __init__(self,
                 tokenizer: AutoTokenizer,
                 reasoning_active: bool = False):
        super().__init__(tokenizer)
        self.reasoning_active = reasoning_active

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return not self.reasoning_active

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return input_ids


class MockSequence(Sequence):
    """Mock sequence for testing purposes."""

    def __init__(self, token_ids, output_text="test_output", eos_token_id=0):
        self.token_ids = token_ids
        self.output_text = output_text
        self.eos_token_id = eos_token_id
        self.status = SequenceStatus.RUNNING
        self.stop_reason = None

    def get_token_ids(self):
        return self.token_ids

    def get_last_token_id(self):
        return self.token_ids[-1] if self.token_ids else None

    def get_len(self):
        return len(self.token_ids)

    def get_output_len(self):
        return len(self.token_ids) - 1  # Simulating prompt + outputs


@pytest.fixture
def deepseek_r1_qwen_tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


@pytest.fixture
def stop_checker():
    return StopChecker(max_model_len=10,
                       get_tokenizer_for_seq=deepseek_r1_qwen_tokenizer)


@pytest.fixture
def stop_checker_with_reasoner():
    reasoner = MockReasoningParser(deepseek_r1_qwen_tokenizer)
    return StopChecker(max_model_len=10,
                       get_tokenizer_for_seq=deepseek_r1_qwen_tokenizer,
                       reasoner=reasoner)


def test_eos_token_stopping(stop_checker):
    """Test sequence stopping when EOS token is encountered."""
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams()

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED


def test_ignore_eos(stop_checker):
    """Test sequence continuing when EOS token is ignored."""
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams(ignore_eos=True)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.RUNNING


def test_min_tokens(stop_checker):
    """Test min_tokens prevents early stopping."""
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams(min_tokens=3)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.RUNNING


def test_stop_token_ids(stop_checker):
    """Test sequence stopping with custom stop token IDs."""
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(stop_token_ids=[3])

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED
    assert seq.stop_reason == 3


def test_stop_strings(stop_checker):
    """Test sequence stopping with stop strings."""
    seq = MockSequence(token_ids=[1, 2, 3],
                       output_text="test output with STOP",
                       eos_token_id=0)
    sampling_params = SamplingParams(stop=["STOP"])

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED
    assert seq.stop_reason == "STOP"
    assert "STOP" not in seq.output_text  # Default behavior removes stop string


def test_include_stop_str_in_output(stop_checker):
    """Test keeping stop strings in output."""
    seq = MockSequence(token_ids=[1, 2, 3],
                       output_text="test output with STOP",
                       eos_token_id=0)
    sampling_params = SamplingParams(stop=["STOP"],
                                     include_stop_str_in_output=True)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=5,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED
    assert "STOP" in seq.output_text


def test_max_tokens(stop_checker):
    """Test sequence stopping at max_tokens."""
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(max_tokens=2)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED


def test_max_model_len(stop_checker):
    """Test sequence stopping at max_model_len."""
    seq = MockSequence(token_ids=list(range(11)),
                       eos_token_id=0)  # 11 tokens, max is 10
    sampling_params = SamplingParams()

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED


def test_reasoning_skip_stops(stop_checker_with_reasoner):
    """Test that stop tokens and strings are ignored during reasoning."""
    # Set reasoning_active to True to simulate being in reasoning mode
    stop_checker_with_reasoner.reasoner.reasoning_active = True

    # Test with stop token
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(stop_token_ids=[3])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=1, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.RUNNING

    # Test with stop string
    seq = MockSequence(token_ids=[1, 2, 3], output_text="test STOP")
    sampling_params = SamplingParams(stop=["STOP"])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=4, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.RUNNING

    # But EOS token still stops the sequence
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams()

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=1, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.FINISHED_STOPPED


def test_reasoning_end_enables_stops(stop_checker_with_reasoner):
    """Test that stop tokens work after reasoning ends."""
    # Set reasoning_active to False to simulate being out of reasoning mode
    stop_checker_with_reasoner.reasoner.reasoning_active = False

    # Test with stop token
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(stop_token_ids=[3])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=1, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.FINISHED_STOPPED

    # Test with stop string
    seq = MockSequence(token_ids=[1, 2, 3], output_text="test STOP")
    sampling_params = SamplingParams(stop=["STOP"])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=4, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.FINISHED_STOPPED
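
The MockReasoningParser above simply flips a flag. For intuition, a real parser for a DeepSeek-R1-style model might report that reasoning has ended once the end-of-think token has been generated — a toy sketch, not vLLM's actual DeepSeek-R1 parser, and it assumes "</think>" maps to a single token id in this tokenizer's vocabulary:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")


    class ThinkTagParser:
        """Toy parser: reasoning ends once the </think> token id appears."""

        def __init__(self, tokenizer):
            # Assumption: </think> is a single token in this vocabulary.
            self.end_token_id = tokenizer.convert_tokens_to_ids("</think>")

        def is_reasoning_end(self, input_ids: list[int]) -> bool:
            return self.end_token_id in input_ids

The tests themselves can be run with pytest tests/engine/test_stop_checker.py; note that the tokenizer fixture downloads the model's tokenizer from the Hugging Face Hub on first use.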

vllm/engine/llm_engine.py

Lines changed: 15 additions & 2 deletions
@@ -40,6 +40,7 @@
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                            Sequence, SequenceGroup, SequenceGroupBase,
@@ -372,6 +373,14 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                 "vllm.llm_engine",
                 self.observability_config.otlp_traces_endpoint)

+        # Initialize reasoning parser if reasoning backend is set.
+        if self.decoding_config.reasoning_backend and \
+                self.tokenizer:
+            reasoner_class = ReasoningParserManager.get_reasoning_parser(
+                self.decoding_config.reasoning_backend)
+            self.reasoner: ReasoningParser = reasoner_class(
+                self.tokenizer.get_lora_tokenizer())
+
         # Create sequence output processor, e.g. for beam search or
         # speculative decoding.
         self.output_processor = (
@@ -381,8 +390,12 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                 self.scheduler,
                 self.seq_counter,
                 get_tokenizer_for_seq,
-                stop_checker=StopChecker(self.scheduler_config.max_model_len,
-                                         get_tokenizer_for_seq),
+                stop_checker=StopChecker(
+                    self.scheduler_config.max_model_len,
+                    get_tokenizer_for_seq,
+                    self.reasoner if self.decoding_config.reasoning_backend
+                    and self.tokenizer else None,
+                ),
             ))

         self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
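
Note that self.reasoner is only constructed when decoding_config.reasoning_backend is set and a tokenizer is available; in every other configuration StopChecker receives None and behaves exactly as before. For intuition, the equivalent manual wiring might look like this — a sketch, where the backend name "deepseek_r1" and the tokenizer choice are assumptions, not taken from this diff:

    from transformers import AutoTokenizer

    from vllm.engine.output_processor.stop_checker import StopChecker
    from vllm.reasoning import ReasoningParserManager

    tokenizer = AutoTokenizer.from_pretrained(
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

    # Resolve a registered parser by name, as llm_engine.py now does with
    # decoding_config.reasoning_backend ("deepseek_r1" is assumed registered).
    reasoner_cls = ReasoningParserManager.get_reasoning_parser("deepseek_r1")
    reasoner = reasoner_cls(tokenizer)

    stop_checker = StopChecker(
        max_model_len=4096,
        get_tokenizer_for_seq=lambda seq: tokenizer,
        reasoner=reasoner,
    )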

vllm/engine/output_processor/stop_checker.py

Lines changed: 13 additions & 2 deletions
@@ -4,6 +4,7 @@
 from typing import Callable, List, Optional, Tuple

 from vllm.lora.request import LoRARequest
+from vllm.reasoning import ReasoningParser
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Sequence, SequenceStatus
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -16,11 +17,16 @@ class StopChecker:
     emitted, or if we have exceeded the max model len.
     """

-    def __init__(self, max_model_len: int,
-                 get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
+    def __init__(
+        self,
+        max_model_len: int,
+        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
+        reasoner: Optional[ReasoningParser] = None,
+    ):
         # Do not use it directly, but use `self._get_max_model_len`.
         self._max_model_len = max_model_len
         self.get_tokenizer_for_seq = get_tokenizer_for_seq
+        self.reasoner = reasoner

     def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
         if lora_req and lora_req.long_lora_max_len:
@@ -57,6 +63,11 @@ def maybe_stop_sequence(
             seq.status = SequenceStatus.FINISHED_STOPPED
             return

+        # Skip stop string/token checks if in reasoning content generation
+        if self.reasoner is not None and \
+                not self.reasoner.is_reasoning_end(seq.get_token_ids()):
+            return
+
         # Check if a stop token was encountered.
         # This assumes a single token produced per step.
         last_token_id = seq.get_last_token_id()