 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
-from vllm.core.scheduler import SchedulerOutputs
+from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
@@ -258,6 +258,9 @@ class SchedulerOutputState:
     last_output: Optional[SamplerOutput] = None
     seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
     scheduler_outputs: Optional[SchedulerOutputs] = None
+    scheduled_ids: Optional[List[Tuple[ScheduledSequenceGroup,
+                                       SequenceGroupMetadata]]] = None
+    allow_output_proc_callback: bool = False


 class _AsyncLLMEngine(LLMEngine):
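Note: the two new SchedulerOutputState fields let a multi-step iteration resume without rescheduling. A minimal, self-contained sketch of the round trip (the vLLM types are stubbed with Any; this is an illustration, not the engine's actual code):

    from dataclasses import dataclass
    from typing import Any, List, Optional, Tuple

    @dataclass
    class SchedulerOutputState:
        last_output: Optional[Any] = None
        seq_group_metadata_list: Optional[List[Any]] = None
        scheduler_outputs: Optional[Any] = None
        scheduled_ids: Optional[List[Tuple[Any, Any]]] = None
        allow_output_proc_callback: bool = False

    state = SchedulerOutputState()
    # Step N: the scheduler ran, so cache what it produced.
    state.scheduled_ids = [(object(), object())]
    state.allow_output_proc_callback = True
    # Step N+1: read the cached values back instead of calling schedule().
    assert state.scheduled_ids is not None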
@@ -288,22 +291,27 @@ async def step_async(
         cached_outputs = self.cached_scheduler_outputs[virtual_engine]
         seq_group_metadata_list = cached_outputs.seq_group_metadata_list
         scheduler_outputs = cached_outputs.scheduler_outputs
+        scheduled_ids = cached_outputs.scheduled_ids
+        allow_output_proc_callback = cached_outputs.allow_output_proc_callback
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
-            (seq_group_metadata_list, scheduler_outputs, scheduled_ids, allow_output_proc_callback) = self.scheduler[
-                virtual_engine].schedule()
+            (seq_group_metadata_list, scheduler_outputs, scheduled_ids,
+             allow_output_proc_callback
+             ) = self.scheduler[virtual_engine].schedule()

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have
                 # lookahead slots
                 self._cache_scheduler_outputs_for_multi_step(
-                    virtual_engine, seq_group_metadata_list, scheduler_outputs)
+                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
+                    scheduled_ids, allow_output_proc_callback)

         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None
+        assert scheduled_ids is not None

         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
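Note: schedule() now returns a 4-tuple rather than the previous (seq_group_metadata_list, scheduler_outputs) pair. A hedged stub of the new contract, modeling shapes only:

    from typing import Any, List, Tuple

    def schedule() -> Tuple[List[Any], Any, List[Tuple[Any, Any]], bool]:
        # Stub: empty placeholders with the new 4-tuple shape.
        seq_group_metadata_list: List[Any] = []
        scheduler_outputs: Any = object()
        scheduled_ids: List[Tuple[Any, Any]] = []
        allow_output_proc_callback = True
        return (seq_group_metadata_list, scheduler_outputs, scheduled_ids,
                allow_output_proc_callback)

    (seq_group_metadata_list, scheduler_outputs, scheduled_ids,
     allow_output_proc_callback) = schedule()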
@@ -328,6 +336,10 @@ async def step_async(
                 # We use ExecuteModelRequest to pass the last sampled_token_ids
                 # to each of the non-last PP stages for in-place prepare_input.
                 last_sampled_token_ids=last_sampled_token_ids)
+
+            if allow_output_proc_callback:
+                execute_model_req.callback_fn = self._process_model_outputs
+
             # Execute the model.
             output = await self.model_executor.execute_model_async(
                 execute_model_req)
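Note: attaching _process_model_outputs as callback_fn lets output processing of the previous step overlap with model execution. A hedged sketch of how such a callback might be invoked; the exact call site lives in vLLM's executor/runner code and is an assumption here, not shown in this diff:

    import asyncio
    from types import SimpleNamespace

    async def execute_model_async(req):
        await asyncio.sleep(0)  # stand-in for the model forward pass
        if getattr(req, "callback_fn", None) is not None:
            # Overlap: process earlier outputs while this step completes.
            req.callback_fn(is_async=True)
        return []

    req = SimpleNamespace(callback_fn=lambda is_async: None)
    asyncio.run(execute_model_async(req))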
@@ -350,17 +362,18 @@ async def step_async(
         if self.scheduler_config.is_multi_step:
             self.cached_scheduler_outputs[
                 virtual_engine] = SchedulerOutputState()
-
+
         # Cache results in engine
         self.output_queue.append(
-            (output, scheduled_ids, scheduler_outputs.ignored_seq_groups))
+            (output, scheduled_ids, scheduler_outputs.ignored_seq_groups))

         if (len(output) > 0) and allow_output_proc_callback:
             assert len(
                 output
             ) == 1, "Multi step decoding does not work with output processor callback"  # noqa: E501
-            self._advance_to_next_step(output[0], seq_group_metadata_list,
-                                       scheduler_outputs.scheduled_seq_groups)
+            self._advance_to_next_step(
+                output[0], seq_group_metadata_list,
+                scheduler_outputs.scheduled_seq_groups)

         if not allow_output_proc_callback:
             self._process_model_outputs(is_async=False)
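Note: after this hunk there are two output paths: with the callback enabled, results sit in output_queue as (output, scheduled_ids, ignored_seq_groups) 3-tuples to be drained by the callback-driven path; otherwise they are processed inline via _process_model_outputs(is_async=False). A minimal consumer sketch (collections.deque stands in for the engine's queue):

    from collections import deque

    output_queue: deque = deque()
    # Each entry mirrors the append in the hunk above.
    output_queue.append(([], [], []))  # (output, scheduled_ids, ignored)

    allow_output_proc_callback = False
    if not allow_output_proc_callback:
        # Synchronous fallback: drain and process the entry immediately.
        output, scheduled_ids, ignored_seq_groups = output_queue.popleft()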
@@ -398,12 +411,20 @@ def _has_remaining_steps(
     def _cache_scheduler_outputs_for_multi_step(
             self, virtual_engine: int,
             seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-            scheduler_outputs: SchedulerOutputs) -> None:
+            scheduler_outputs: SchedulerOutputs,
+            scheduled_ids: Optional[List[Tuple[ScheduledSequenceGroup,
+                                               SequenceGroupMetadata]]],
+            allow_output_proc_callback: bool) -> None:
+        v = virtual_engine
         self.cached_scheduler_outputs[
             virtual_engine].seq_group_metadata_list = seq_group_metadata_list
-        self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
+        self.cached_scheduler_outputs[v].scheduler_outputs = \
             scheduler_outputs
-        self.cached_scheduler_outputs[virtual_engine].last_output = None
+        self.cached_scheduler_outputs[v].scheduled_ids = \
+            scheduled_ids
+        self.cached_scheduler_outputs[v].allow_output_proc_callback = \
+            allow_output_proc_callback
+        self.cached_scheduler_outputs[v].last_output = None

     def _get_last_sampled_token_ids(
             self, virtual_engine: int) -> Optional[torch.Tensor]:
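Note: the helper now persists all four scheduler products; the short alias v = virtual_engine only keeps the new assignments within line-length limits. A standalone sketch of the same caching logic (a plain dict and SimpleNamespace stand in for cached_scheduler_outputs and its state objects):

    from types import SimpleNamespace
    from typing import Any, Dict

    def cache_for_multi_step(cached: Dict[int, Any], virtual_engine: int,
                             metadata, outputs, scheduled_ids,
                             allow_cb: bool) -> None:
        v = virtual_engine
        cached[v].seq_group_metadata_list = metadata
        cached[v].scheduler_outputs = outputs
        cached[v].scheduled_ids = scheduled_ids
        cached[v].allow_output_proc_callback = allow_cb
        cached[v].last_output = None  # reset so stale samples are not reused

    cached = {0: SimpleNamespace()}
    cache_for_multi_step(cached, 0, [], object(), [], True)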