
Commit fe2f183

njhill authored and lulmer committed
[BugFix][V1] Fix parallel sampling finishing/aborts (vllm-project#14512)
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent c61a56b commit fe2f183

File tree

7 files changed: +137, -113 lines


tests/v1/engine/test_async_llm.py

Lines changed: 50 additions & 6 deletions
@@ -46,6 +46,7 @@ async def generate(engine: AsyncLLM,
                    prompt: PromptType,
                    output_kind: RequestOutputKind,
                    max_tokens: int,
+                   n: int = 1,
                    prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
     # Ensure generate doesn't complete too fast for cancellation test.
     await asyncio.sleep(0.2)
@@ -54,13 +55,15 @@ async def generate(engine: AsyncLLM,
     sampling_params = SamplingParams(max_tokens=max_tokens,
                                      ignore_eos=True,
                                      output_kind=output_kind,
-                                     temperature=0,
+                                     temperature=0.5,
+                                     seed=33,
+                                     n=n,
                                      prompt_logprobs=prompt_logprobs)
     async for out in engine.generate(request_id=request_id,
                                      prompt=prompt,
                                      sampling_params=sampling_params):

-        num_tokens = len(out.outputs[0].token_ids)
+        num_tokens = sum(len(output.token_ids) for output in out.outputs)
         if output_kind == RequestOutputKind.DELTA:
             count += num_tokens
         else:
@@ -136,17 +139,22 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,

         NUM_REQUESTS = 100
         NUM_EXPECTED_TOKENS = 100
+        NUM_EXPECTED_TOKENS_LONG = 50000
         REQUEST_IDS_TO_ABORT = range(1, 100, 10)
+        PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)

         request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

         # Create concurrent requests.
         tasks: list[asyncio.Task] = []
-        for request_id in request_ids:
+        for idx, request_id in enumerate(request_ids):
+            max_tokens = NUM_EXPECTED_TOKENS_LONG if (
+                idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
+            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
             tasks.append(
                 asyncio.create_task(
                     generate(engine, request_id, prompt, output_kind,
-                             NUM_EXPECTED_TOKENS)))
+                             max_tokens, n)))

         # API server cancels requests when they disconnect.
         for idx in REQUEST_IDS_TO_ABORT:
@@ -162,10 +170,13 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
             else:
                 # Otherwise, make sure the request was not impacted.
                 num_generated_tokens, request_id = await task
-                assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
+                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
+                expected_tokens = NUM_EXPECTED_TOKENS * n
+                assert num_generated_tokens == expected_tokens, (
                     f"{request_id} generated {num_generated_tokens} but "
-                    f"expected {NUM_EXPECTED_TOKENS}")
+                    f"expected {expected_tokens}")

+        # Make sure all aborted requests were really aborted.
         assert not engine.output_processor.has_unfinished_requests()

         # Confirm we can do another generation.
@@ -176,3 +187,36 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
         num_generated_tokens, request_id = await task
         assert num_generated_tokens == NUM_EXPECTED_TOKENS
         assert not engine.output_processor.has_unfinished_requests()
+
+
+@pytest.mark.parametrize("n", [1, 3])
+@pytest.mark.parametrize("engine_args_and_prompt",
+                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
+                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.asyncio
+async def test_finished_flag(monkeypatch, n: int,
+                             engine_args_and_prompt: tuple[AsyncEngineArgs,
+                                                           PromptType]):
+
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+        engine_args, prompt = engine_args_and_prompt
+
+        engine = AsyncLLM.from_engine_args(engine_args)
+        after.callback(engine.shutdown)
+
+        sampling_params = SamplingParams(max_tokens=100,
+                                         output_kind=RequestOutputKind.DELTA,
+                                         temperature=1.0,
+                                         seed=33,
+                                         n=n)
+        outputs = [
+            out
+            async for out in engine.generate(request_id="request-33",
+                                             prompt=prompt,
+                                             sampling_params=sampling_params)
+        ]
+
+        # Assert only the last output has the finished flag set
+        assert all(not out.finished for out in outputs[:-1])
+        assert outputs[-1].finished
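
For reference, a rough standalone sketch (not part of this commit) of the invariant the new test_finished_flag pins down: with n > 1 and DELTA outputs, AsyncLLM.generate yields one merged RequestOutput per step and only the very last one is marked finished, while token deltas are spread across all parallel samples. The model name below is an assumption; the AsyncLLM and SamplingParams calls mirror the test above.

# Hedged usage sketch; assumes a V1-enabled vLLM install and a small model.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM


async def main() -> None:
    # Assumed model choice; any small generative model works for the demo.
    engine = AsyncLLM.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    try:
        params = SamplingParams(max_tokens=32,
                                n=3,
                                seed=33,
                                output_kind=RequestOutputKind.DELTA)
        total_tokens = 0
        finished_flags = []
        async for out in engine.generate(request_id="demo-0",
                                         prompt="What is an LLM?",
                                         sampling_params=params):
            # Each step's deltas are summed across all parallel samples.
            total_tokens += sum(len(o.token_ids) for o in out.outputs)
            finished_flags.append(out.finished)
        # Only the final streamed output should carry finished=True.
        assert finished_flags[-1] and not any(finished_flags[:-1])
        print(f"streamed {total_tokens} tokens across 3 samples")
    finally:
        engine.shutdown()


if __name__ == "__main__":
    asyncio.run(main())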

tests/v1/entrypoints/openai/test_completion.py

Lines changed: 15 additions & 6 deletions
@@ -263,22 +263,24 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,

     prompt = "What is an LLM?"
     n = 3
-    max_tokens = 5
+    max_tokens = 50  # we want some to finish earlier than others

     # High temperature to maximize chance of unique completions.
     completion = await client.completions.create(model=model_name,
                                                  prompt=prompt,
                                                  max_tokens=max_tokens,
                                                  n=n,
-                                                 temperature=0.95,
+                                                 temperature=1.0,
                                                  stream=False,
+                                                 logprobs=0,
                                                  seed=42)

     # Assert `n` completions
     num_completions = len(completion.choices)
     assert num_completions == n, (
         f"Num completions {num_completions} but expected {n}.")
     completion_repeats: dict[str, int] = {}
+    output_token_lengths = set()
     for idx, choice in enumerate(completion.choices):
         # Assert correct completion index & some finish reason.
         assert choice.index == idx, (
@@ -287,6 +289,9 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
             "None finish_reason is invalid.")
         text = choice.text
         completion_repeats[text] = completion_repeats.get(text, 0) + 1
+        output_token_lengths.add(len(choice.logprobs.tokens))
+    # Assert subrequests finished at different times
+    assert len(output_token_lengths) > 1
     # Assert `n` unique completions
     num_unique = len(completion_repeats)
     if num_unique != n:
@@ -312,16 +317,16 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):

     prompt = "What is an LLM?"
     n = 3
-    max_tokens = 5
+    max_tokens = 50  # we want some to finish earlier than others

     stream = await client.completions.create(model=model_name,
                                              prompt=prompt,
                                              max_tokens=max_tokens,
                                              n=n,
-                                             temperature=0.95,
+                                             temperature=1.0,
                                              stream=True,
                                              seed=42)
-    chunks: list[list[str]] = [[] for i in range(n)]
+    chunks: list[list[str]] = [[] for _ in range(n)]
     finish_reason_count = 0
     async for chunk in stream:
         index = chunk.choices[0].index
@@ -333,14 +338,18 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
     assert finish_reason_count == n, (
         f"Expected {n} completions with valid indices and finish_reason.")
     completion_repeats: dict[str, int] = {}
+    chunk_lengths = set()
     for chunk in chunks:
         chunk_len = len(chunk)
         # Assert correct number of completion tokens
-        assert chunk_len == max_tokens, (
+        chunk_lengths.add(chunk_len)
+        assert chunk_len <= max_tokens, (
             f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
         text = "".join(chunk)
         completion_repeats[text] = completion_repeats.get(text, 0) + 1
         print(text)
+    # Assert subrequests finished at different times
+    assert len(chunk_lengths) > 1
     # Assert `n` unique completions
     num_unique = len(completion_repeats)
     if num_unique != n:
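
The updated tests use per-choice token counts to confirm that parallel subrequests can now finish independently of one another. A hedged client-side sketch of the same no-streaming check against an OpenAI-compatible vLLM server; the base URL, API key, and model name are assumptions, not values from this commit.

# Rough sketch of the non-streaming length check; server details are assumed.
import asyncio

import openai


async def check_parallel_lengths() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")  # assumed local vLLM server
    completion = await client.completions.create(
        model="facebook/opt-125m",  # assumed model served by vLLM
        prompt="What is an LLM?",
        n=3,
        max_tokens=50,
        temperature=1.0,
        seed=42,
        logprobs=0,  # logprobs=0 still returns the sampled tokens per choice
    )
    lengths = {len(choice.logprobs.tokens) for choice in completion.choices}
    # With subrequests finishing independently, the three choices usually
    # stop at different lengths rather than all running to max_tokens.
    print("per-choice token counts:", sorted(lengths))


asyncio.run(check_parallel_lengths())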

vllm/outputs.py

Lines changed: 18 additions & 46 deletions
@@ -134,57 +134,29 @@ def __init__(
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens

-    @classmethod
-    def new(
-        cls,
-        request_id: str,
-        prompt: Optional[str],
-        prompt_token_ids: Optional[list[int]],
-        text: str,
-        token_ids: list[int],
-        logprobs: Optional[SampleLogprobs],
-        prompt_logprobs: Optional[PromptLogprobs],
-        cumulative_logprob: Optional[float],
-        finished: bool = False,
-    ) -> "RequestOutput":
-        """Initialize a new RequestOutput object."""
-
-        # TODO: Support `n` > 1.
-        completion_output = CompletionOutput(
-            index=0,
-            text=text,
-            token_ids=token_ids,
-            cumulative_logprob=cumulative_logprob,
-            logprobs=logprobs)
-
-        return RequestOutput(
-            request_id=request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            prompt_logprobs=prompt_logprobs,
-            outputs=[completion_output],
-            finished=finished,
-        )
-
     def add(self, next_output: "RequestOutput") -> None:
         """Merge subsequent RequestOutput into this one"""

-        self.prompt = next_output.prompt
-        self.prompt_token_ids = next_output.prompt_token_ids
-        self.prompt_logprobs = next_output.prompt_logprobs
         self.finished |= next_output.finished

-        #TODO assuming n == 1 for now
-        completion = self.outputs[0]
-        next_completion = next_output.outputs[0]
-        completion.text += next_completion.text
-        if not isinstance(completion.token_ids, MutableSequence):
-            completion.token_ids = list(completion.token_ids)
-        completion.token_ids.extend(next_completion.token_ids)
-        if next_completion.logprobs:
-            assert completion.logprobs is not None
-            completion.logprobs.extend(next_completion.logprobs)
-        completion.cumulative_logprob = next_completion.cumulative_logprob
+        for next_completion in next_output.outputs:
+            for completion in self.outputs:
+                if completion.index == next_completion.index:
+                    # Merge outputs with same index
+                    completion.text += next_completion.text
+                    if not isinstance(completion.token_ids, MutableSequence):
+                        completion.token_ids = list(completion.token_ids)
+                    completion.token_ids.extend(next_completion.token_ids)
+                    if next_completion.logprobs:
+                        assert completion.logprobs is not None
+                        completion.logprobs.extend(next_completion.logprobs)
+                    completion.cumulative_logprob = (
+                        next_completion.cumulative_logprob)
+                    completion.finish_reason = next_completion.finish_reason
+                    completion.stop_reason = next_completion.stop_reason
+                    break
+            else:
+                self.outputs.append(next_completion)

     @classmethod
     def from_seq_group(

vllm/v1/engine/async_llm.py

Lines changed: 1 addition & 2 deletions
@@ -298,9 +298,8 @@ async def _run_output_handler(self):
     async def abort(self, request_id: str) -> None:
         """Abort RequestId in OutputProcessor and EngineCore."""

-        request_ids = [request_id]
+        request_ids = self.output_processor.abort_requests((request_id, ))
         await self.engine_core.abort_requests_async(request_ids)
-        self.output_processor.abort_requests(request_ids)

         if self.log_requests:
             logger.info("Aborted request %s.", request_id)

vllm/v1/engine/llm_engine.py

Lines changed: 1 addition & 1 deletion
@@ -137,8 +137,8 @@ def validate_outputs(cls, outputs, output_type):
     def abort_request(self, request_ids: list[str]) -> None:
         """Remove request_ids from EngineCore and Detokenizer."""

+        request_ids = self.output_processor.abort_requests(request_ids)
         self.engine_core.abort_requests(request_ids)
-        self.output_processor.abort_requests(request_ids)

     def add_request(
         self,

vllm/v1/engine/output_processor.py

Lines changed: 30 additions & 19 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import asyncio
+from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Optional, Union

@@ -102,33 +103,32 @@ def make_request_output(
     ) -> Optional[RequestOutput]:

         finished = finish_reason is not None
-        output_kind = self.output_kind
-        final_only = output_kind == RequestOutputKind.FINAL_ONLY
+        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

         # In follow up, we will switch to invariant where EngineCore
         # does not stream partial prefills.
         if not finished and (self.is_prefilling or final_only):
             # Only the final output is required in FINAL_ONLY mode.
             return None

-        def new_request_output(request_id: str) -> RequestOutput:
-            return self._new_request_output(request_id, finished)
-
         completion_output = self._new_completion_output(
             new_token_ids, finish_reason, stop_reason)

-        if self.parent_req is not None:
-            return self.parent_req.make_request_output(final_only,
-                                                       completion_output,
-                                                       new_request_output)
+        request_id = self.request_id
+        if self.parent_req is None:
+            outputs = [completion_output]
+        else:
+            request_id, outputs, finished = self.parent_req.get_outputs(
+                request_id, completion_output)
+            if not outputs:
+                return None

-        request_output = new_request_output(self.request_id)
-        request_output.outputs.append(completion_output)
-        return request_output
+        return self._new_request_output(request_id, outputs, finished)

     def _new_request_output(
         self,
         request_id: str,
+        outputs: list[CompletionOutput],
         finished: bool,
     ) -> RequestOutput:

@@ -143,7 +143,7 @@ def _new_request_output(
             prompt=self.prompt,
             prompt_token_ids=self.prompt_token_ids,
             prompt_logprobs=prompt_logprobs,
-            outputs=[],
+            outputs=outputs,
             finished=finished,
         )

@@ -188,6 +188,7 @@ def __init__(
         self.log_stats = log_stats
         self.tokenizer = tokenizer
         self.request_states: dict[str, RequestState] = {}
+        self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates()

     def get_num_unfinished_requests(self):
@@ -198,14 +199,20 @@ def has_unfinished_requests(self) -> bool:

     def abort_requests(
         self,
-        request_ids: list[str],
-    ) -> None:
+        request_ids: Iterable[str],
+    ) -> list[str]:
+        request_ids_to_abort = []
         for request_id in request_ids:
             req_state = self.request_states.pop(request_id, None)
             if req_state is not None:
                 self.lora_states.abort_request(req_state)
-                if req_state.parent_req is not None:
-                    req_state.parent_req.finish_child_request(request_id)
+                request_ids_to_abort.append(request_id)
+            else:
+                parent = self.parent_requests.pop(request_id, None)
+                if parent and parent.child_requests:
+                    self.abort_requests(parent.child_requests)
+                    request_ids_to_abort.extend(parent.child_requests)
+        return request_ids_to_abort

     def add_request(
         self,
@@ -227,6 +234,8 @@ def add_request(
                                  log_stats=self.log_stats)
         self.request_states[request_id] = req_state
         self.lora_states.add_request(req_state)
+        if parent_req:
+            self.parent_requests[parent_req.request_id] = parent_req

     def process_outputs(
         self,
@@ -314,12 +323,14 @@ def process_outputs(
                 # Free completed requests.
                 if finish_reason is not None:
                     self.request_states.pop(req_id)
+                    # Remove parent request if applicable.
+                    parent_req = req_state.parent_req
+                    if parent_req and not parent_req.child_requests:
+                        self.parent_requests.pop(parent_req.request_id, None)
                     if not engine_core_output.finished:
                         # If req not finished in EngineCore, but Detokenizer
                         # detected stop string, abort needed in EngineCore.
                         reqs_to_abort.append(req_id)
-                        if req_state.parent_req is not None:
-                            req_state.parent_req.finish_child_request(req_id)

                 # Track per-request stats
                 self._update_stats_from_finished(req_state, finish_reason,
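
Taken together with the async_llm.py and llm_engine.py changes, abort_requests now expands a parent (n > 1) request id into its outstanding child ids and returns the concrete list that the caller forwards to EngineCore. A simplified stand-in of that resolution pattern; the class name and request-id naming below are illustrative, not the real vLLM types.

# Simplified stand-in (not the real OutputProcessor) showing the abort fan-out.
from collections.abc import Iterable


class MiniOutputProcessor:

    def __init__(self) -> None:
        # child request id -> opaque per-request state
        self.request_states: dict[str, object] = {}
        # parent request id -> list of child request ids
        self.parent_requests: dict[str, list[str]] = {}

    def abort_requests(self, request_ids: Iterable[str]) -> list[str]:
        to_abort: list[str] = []
        for request_id in request_ids:
            if self.request_states.pop(request_id, None) is not None:
                # A concrete child request: abort it directly.
                to_abort.append(request_id)
            else:
                # A parent id: expand to its still-outstanding children.
                children = self.parent_requests.pop(request_id, None)
                if children:
                    self.abort_requests(children)
                    to_abort.extend(children)
        return to_abort


proc = MiniOutputProcessor()
proc.request_states = {"child-0": object(), "child-1": object()}
proc.parent_requests = {"parent-req": ["child-0", "child-1"]}

# Aborting the parent id yields the child ids the engine core must abort.
print(proc.abort_requests(["parent-req"]))  # -> ['child-0', 'child-1']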
