
Commit 1356ab0

rebase over multi-step and fix bugs
1 parent 1b2e046 commit 1356ab0

File tree

2 files changed: +40 −16 lines changed

vllm/core/scheduler.py

Lines changed: 8 additions & 5 deletions
@@ -1103,12 +1103,16 @@ def schedule(
         if not self.cache_config.enable_prefix_caching:
             common_computed_block_nums = []
 
+        # TODO: Combine multi-step and async postprocessor
+        allow_output_proc_callback: bool = (
+            self.use_output_proc_callback
+            and not self.scheduler_config.is_multi_step)
+
         # Create list of scheduled request ids
         scheduled_ids: List[Tuple[ScheduledSequenceGroup,
                                   SequenceGroupMetadata]] = []
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        allow_output_proc_callback: bool = False
         for i, scheduled_seq_group in enumerate(
                 scheduler_outputs.scheduled_seq_groups):
             seq_group = scheduled_seq_group.seq_group
@@ -1209,10 +1213,9 @@ def schedule(
             )
             seq_group_metadata_list.append(seq_group_metadata)
 
-            if self.use_output_proc_callback:
-                allow_output_proc_callback = (
-                    allow_output_proc_callback
-                    and self._allow_output_proc_callback(seq_group))
+            if allow_output_proc_callback:
+                allow_output_proc_callback = self._allow_output_proc_callback(
+                    seq_group)
 
             scheduled_ids.append((scheduled_seq_group, seq_group_metadata))
         # Now that the batch has been created, we can assume all blocks in the
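
Note on the bug fixed in the second hunk: the old code initialized allow_output_proc_callback to False and then AND-ed it with the per-group check, so the flag could never become True. The rewrite seeds the flag from the engine config (and turns it off under multi-step scheduling, per the TODO), then lets each scheduled group veto it. A minimal standalone sketch of that pattern, with the hypothetical group_allows standing in for self._allow_output_proc_callback:

from typing import Callable, Iterable


def gate_callback(use_callback: bool, is_multi_step: bool,
                  groups: Iterable[object],
                  group_allows: Callable[[object], bool]) -> bool:
    # Seed from config; multi-step disables the callback entirely.
    allow = use_callback and not is_multi_step
    for group in groups:
        # A group can only turn the flag off, never back on.
        if allow:
            allow = group_allows(group)
    return allow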

vllm/engine/async_llm_engine.py

Lines changed: 32 additions & 11 deletions
@@ -12,7 +12,7 @@
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
-from vllm.core.scheduler import SchedulerOutputs
+from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
@@ -258,6 +258,9 @@ class SchedulerOutputState:
     last_output: Optional[SamplerOutput] = None
     seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
     scheduler_outputs: Optional[SchedulerOutputs] = None
+    scheduled_ids: Optional[List[Tuple[ScheduledSequenceGroup,
+                                       SequenceGroupMetadata]]] = None
+    allow_output_proc_callback: bool = False
 
 
 class _AsyncLLMEngine(LLMEngine):
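
For readers outside the diff: SchedulerOutputState is the per-virtual-engine cache entry kept between multi-step iterations, and this hunk adds the two new fields to it. A rough sketch of the shape, with Any standing in for the vllm types and CacheEntry as a hypothetical stand-in:

from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class CacheEntry:  # hypothetical stand-in for SchedulerOutputState
    last_output: Optional[Any] = None
    seq_group_metadata_list: Optional[List[Any]] = None
    scheduler_outputs: Optional[Any] = None
    # New in this commit: carried across steps of the same batch.
    scheduled_ids: Optional[List[Tuple[Any, Any]]] = None
    allow_output_proc_callback: bool = False


# One entry per virtual engine, mirroring cached_scheduler_outputs:
caches = [CacheEntry() for _ in range(2)]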
@@ -288,22 +291,27 @@ async def step_async(
         cached_outputs = self.cached_scheduler_outputs[virtual_engine]
         seq_group_metadata_list = cached_outputs.seq_group_metadata_list
         scheduler_outputs = cached_outputs.scheduler_outputs
+        scheduled_ids = cached_outputs.scheduled_ids
+        allow_output_proc_callback = cached_outputs.allow_output_proc_callback
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
-            (seq_group_metadata_list, scheduler_outputs, scheduled_ids, allow_output_proc_callback) = self.scheduler[
-                virtual_engine].schedule()
+            (seq_group_metadata_list, scheduler_outputs, scheduled_ids,
+             allow_output_proc_callback
+             ) = self.scheduler[virtual_engine].schedule()
 
             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have
                 # lookahead slots
                 self._cache_scheduler_outputs_for_multi_step(
-                    virtual_engine, seq_group_metadata_list, scheduler_outputs)
+                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
+                    scheduled_ids, allow_output_proc_callback)
 
         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None
+        assert scheduled_ids is not None
 
         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
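
Condensing the control flow of the hunk above (names simplified, illustrative rather than the vllm API): restore whatever the previous step cached, call the scheduler only when the current batch has no remaining steps, and re-cache when multi-step lookahead slots remain.

def plan_step(self, v: int):  # illustrative condensation only
    cached = self.cached_scheduler_outputs[v]
    state = (cached.seq_group_metadata_list, cached.scheduler_outputs,
             cached.scheduled_ids, cached.allow_output_proc_callback)
    if not self._has_remaining_steps(state[0]):
        # Previous batch done: schedule a fresh one.
        state = self.scheduler[v].schedule()
        if self.scheduler_config.is_multi_step and \
                state[1].num_lookahead_slots > 0:
            # Keep it around for the batch's remaining steps.
            self._cache_scheduler_outputs_for_multi_step(v, *state)
    return state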
@@ -328,6 +336,10 @@ async def step_async(
                 # We use ExecuteModelRequest to pass the last sampled_token_ids
                 # to each of the non-last PP stages for in-place prepare_input.
                 last_sampled_token_ids=last_sampled_token_ids)
+
+            if allow_output_proc_callback:
+                execute_model_req.callback_fn = self._process_model_outputs
+
             # Execute the model.
             output = await self.model_executor.execute_model_async(
                 execute_model_req)
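
The contract this hunk assumes, sketched: when the callback flag survives scheduling, the engine hands its output processor to the request so the executor can fire it at a point where output processing overlaps model execution; with the flag off, processing stays on the synchronous path shown further down. A hypothetical executor-side sketch (not the vllm executor API):

from typing import Awaitable, Callable, List, Optional


async def run_request(model_step: Callable[[], Awaitable[List]],
                      callback_fn: Optional[Callable[[], None]]) -> List:
    # Hypothetical executor side: run the model, then invoke whatever
    # output-processing callback was attached, so that work overlaps
    # the engine's next iteration instead of blocking step_async.
    output = await model_step()
    if callback_fn is not None:
        callback_fn()
    return output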
@@ -350,17 +362,18 @@ async def step_async(
             if self.scheduler_config.is_multi_step:
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()
-
+
             # Cache results in engine
             self.output_queue.append(
-                (output, scheduled_ids, scheduler_outputs.ignored_seq_groups))
+                (output, scheduled_ids, scheduler_outputs.ignored_seq_groups))
 
             if (len(output) > 0) and allow_output_proc_callback:
                 assert len(
                     output
                 ) == 1, "Multi step decoding does not work with output processor callback"  # noqa: E501
-                self._advance_to_next_step(output[0], seq_group_metadata_list,
-                                           scheduler_outputs.scheduled_seq_groups)
+                self._advance_to_next_step(
+                    output[0], seq_group_metadata_list,
+                    scheduler_outputs.scheduled_seq_groups)
 
             if not allow_output_proc_callback:
                 self._process_model_outputs(is_async=False)
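
The queue-then-drain pattern used above, in isolation: each step appends its raw results to output_queue, and they are consumed either by the async callback path or, when the callback is not allowed, synchronously at the end of the step. A simplified sketch (hypothetical class, not the vllm engine):

from collections import deque
from typing import Any, Deque, Tuple


class OutputPipeline:  # hypothetical, for illustration
    def __init__(self) -> None:
        self.output_queue: Deque[Tuple[Any, Any, Any]] = deque()

    def submit(self, output: Any, scheduled_ids: Any, ignored: Any,
               allow_callback: bool) -> None:
        self.output_queue.append((output, scheduled_ids, ignored))
        if not allow_callback:
            # Sync fallback, mirroring _process_model_outputs(is_async=False).
            self.drain()

    def drain(self) -> None:
        while self.output_queue:
            output, scheduled_ids, ignored = self.output_queue.popleft()
            # ...turn raw sampler output into per-request results...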
@@ -398,12 +411,20 @@ def _has_remaining_steps(
     def _cache_scheduler_outputs_for_multi_step(
             self, virtual_engine: int,
             seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-            scheduler_outputs: SchedulerOutputs) -> None:
+            scheduler_outputs: SchedulerOutputs,
+            scheduled_ids: Optional[List[Tuple[ScheduledSequenceGroup,
+                                               SequenceGroupMetadata]]],
+            allow_output_proc_callback: bool) -> None:
+        v = virtual_engine
         self.cached_scheduler_outputs[
             virtual_engine].seq_group_metadata_list = seq_group_metadata_list
-        self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
+        self.cached_scheduler_outputs[v].scheduler_outputs = \
             scheduler_outputs
-        self.cached_scheduler_outputs[virtual_engine].last_output = None
+        self.cached_scheduler_outputs[v].scheduled_ids = \
+            scheduled_ids
+        self.cached_scheduler_outputs[v].allow_output_proc_callback = \
+            allow_output_proc_callback
+        self.cached_scheduler_outputs[v].last_output = None
 
     def _get_last_sampled_token_ids(
             self, virtual_engine: int) -> Optional[torch.Tensor]:
