@@ -1,7 +1,8 @@
+import functools
 import time
 from collections import deque
 from contextlib import contextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
                     Mapping, Optional)
 from typing import Sequence as GenericSequence
@@ -88,6 +89,17 @@ class SchedulerOutputState:
     last_output: Optional[SamplerOutput] = None


+@dataclass
+class SchedulerContext:
+    output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
+                              SchedulerOutputs]] = field(
+                                  default_factory=lambda: deque())
+
+    request_outputs: List[Union[RequestOutput,
+                                EmbeddingRequestOutput]] = field(
+                                    default_factory=lambda: [])
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.

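For illustration only: a minimal standalone sketch of the new per-engine context, with `Any` standing in for the real vLLM types (`SamplerOutput`, `SequenceGroupMetadata`, `SchedulerOutputs`, `RequestOutput`). The key property is that each virtual engine owns an independent output queue and request-output list, so engines never share postprocessing state.

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Tuple

@dataclass
class SchedulerContext:
    # Batches awaiting postprocessing; Any stands in for the vLLM types.
    output_queue: Deque[Tuple[Any, Any, Any]] = field(default_factory=deque)
    # Finished outputs accumulated for the current scheduler iteration.
    request_outputs: List[Any] = field(default_factory=list)

# One context per virtual engine (pipeline_parallel_size == 2 here).
contexts = [SchedulerContext() for _ in range(2)]
contexts[0].output_queue.append(("outputs", "metadata", "sched_outputs"))
assert not contexts[1].output_queue  # the queues are independent
```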
@@ -350,9 +362,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             Scheduler(
                 scheduler_config, cache_config, lora_config,
                 parallel_config.pipeline_parallel_size,
-                self._process_model_outputs
+                functools.partial(self._process_model_outputs,
+                                  virtual_engine=v_id,
+                                  is_async=True)
                 if model_config.use_async_output_proc else None)
-            for _ in range(parallel_config.pipeline_parallel_size)
+            for v_id in range(parallel_config.pipeline_parallel_size)
         ]

         # Metric Logging.
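The `functools.partial` calls above bind `virtual_engine` and `is_async` up front, so each scheduler can later invoke its callback with no arguments and still land on the right engine. A standalone sketch, where `process` is a hypothetical stand-in for `_process_model_outputs`:

```python
import functools

def process(virtual_engine: int, is_async: bool) -> None:
    print(f"engine={virtual_engine} async={is_async}")

# Bind the per-engine arguments eagerly, one callback per virtual engine.
callbacks = [
    functools.partial(process, virtual_engine=v_id, is_async=True)
    for v_id in range(2)
]
callbacks[1]()  # prints: engine=1 async=True
```

Note that a closure such as `lambda: process(v_id, True)` would late-bind `v_id` and leave every callback pointing at the last engine; `partial` freezes the value at construction time.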
@@ -406,12 +420,17 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]

-        # Async output processing pointers
-        self.output_queue: Deque[Tuple[List[SamplerOutput],
-                                       List[SequenceGroupMetadata],
-                                       SchedulerOutputs]] = deque()
-        self.request_outputs: List[Union[RequestOutput,
-                                         EmbeddingRequestOutput]] = []
+        self.scheduler_contexts = [
+            SchedulerContext()
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
+        self.async_callback = [
+            functools.partial(self._process_model_outputs,
+                              virtual_engine=v_id,
+                              is_async=True)
+            for v_id in range(self.parallel_config.pipeline_parallel_size)
+        ]

     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
@@ -1265,32 +1284,28 @@ def _process_sequence_group_outputs(

         return

-    def _process_model_outputs(self,
-                               is_async: bool,
-                               clear_outputs: bool = True) -> None:
+    def _process_model_outputs(self, virtual_engine: int,
+                               is_async: bool) -> None:
         """Apply the model output to the sequences in the scheduled seq groups.

+        virtual_engine: The engine id to operate on
         is_async: Indicates whether this postprocessor runs in
             parallel with the GPU forward pass and is processing
             tokens from the previous step. If this is true, then
             no tokens need to be appended since it is already done
             externally (before the next schedule() call)
-        clear_outputs: Sometimes existing outputs need to be combined
-            with outputs of this call. This happens for postprocessor
-            draining at the final stage (like when sequences are finished)

         Returns RequestOutputs that can be returned to the client.
         """
         now = time.time()

-        if clear_outputs:
-            self.request_outputs.clear()
+        ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]

-        if len(self.output_queue) == 0:
+        if len(ctx.output_queue) == 0:
             return None

         (outputs, seq_group_metadata_list,
-         scheduler_outputs) = self.output_queue.popleft()
+         scheduler_outputs) = ctx.output_queue.popleft()

         # Sanity check
         assert len(seq_group_metadata_list) == len(
@@ -1365,11 +1380,11 @@ def _process_model_outputs(self,
             if (seq_group.is_finished()
                     if self.step_return_finished_only else True):
                 request_output = RequestOutputFactory.create(seq_group)
-                self.request_outputs.append(request_output)
+                ctx.request_outputs.append(request_output)

         for seq_group in scheduler_outputs.ignored_seq_groups:
             request_output = RequestOutputFactory.create(seq_group)
-            self.request_outputs.append(request_output)
+            ctx.request_outputs.append(request_output)

         if is_async:
             # Log stats.
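After this change, each `_process_model_outputs` call pops exactly one queued batch from its own engine's context; whatever remains is handled by later calls or by the drain in `step()`. A toy sketch of that per-call contract, using string placeholders for the real tuples:

```python
from collections import deque

queue: deque = deque([("outputs0", "metadata0", "sched0"),
                      ("outputs1", "metadata1", "sched1")])

# One call == one popleft; the second batch waits for the next call.
outputs, seq_group_metadata_list, scheduler_outputs = queue.popleft()
assert outputs == "outputs0" and len(queue) == 1
```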
@@ -1465,29 +1480,43 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 "Pipeline parallelism is only supported through AsyncLLMEngine "
                 "as performance will be severely degraded otherwise.")

+        # For llm_engine, there is no pipeline parallel support, so the
+        # engine used is always 0.
+        virtual_engine = 0
+
         # These are cached outputs from previous iterations. None if on first
         # iteration
-        cached_outputs = self.cached_scheduler_outputs[0]
+        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
         seq_group_metadata_list = cached_outputs.seq_group_metadata_list
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

+        ctx = self.scheduler_contexts[virtual_engine]
+
         # Skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
+
+            # Clear outputs on scheduler iteration start.
+            ctx.request_outputs.clear()
+
+            # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
-             allow_async_output_proc) = self.scheduler[0].schedule()
+             allow_async_output_proc
+             ) = self.scheduler[virtual_engine].schedule()

-            if not allow_async_output_proc and len(self.output_queue) > 0:
-                self._process_model_outputs(is_async=True)
+            # Maybe switch from async mode to sync mode
+            if not allow_async_output_proc and len(ctx.output_queue) > 0:
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have
                 # lookahead slots
                 self._cache_scheduler_outputs_for_multi_step(
-                    0, seq_group_metadata_list, scheduler_outputs,
+                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                     allow_async_output_proc)

         assert seq_group_metadata_list is not None
@@ -1498,14 +1527,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
-                0].get_and_reset_finished_requests_ids()
+                virtual_engine].get_and_reset_finished_requests_ids()

             # Check if we have a cached last_output from the previous iteration.
             # For supporting PP this is probably the best way to pass the
             # sampled_token_ids, as a separate broadcast over all the PP stages
             # will cause one virtual engine's microbatch to block the pipeline.
             last_sampled_token_ids = \
-                self._get_last_sampled_token_ids(0)
+                self._get_last_sampled_token_ids(virtual_engine)

             execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
@@ -1520,20 +1549,24 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                execute_model_req.output_proc_callback_fn = \
-                    self._process_model_outputs
+                execute_model_req.async_callback = self.async_callback[
+                    virtual_engine]

             output = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)

-            # we need to do this here so that last step's sampled_token_ids can
+            # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
-                self._update_cached_scheduler_output(0, output)
+                self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(self.output_queue) > 0:
+            # Nothing was scheduled => if there is a pending async
+            # postprocessor, finish it here.
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            # No outputs in this case
             output = []

         # Finish the current step for all the sequence groups.
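The `clear_outputs=False` escape hatch is gone: `ctx.request_outputs` is now cleared exactly once, at the start of a scheduler iteration, so a later drain within the same iteration appends to earlier results instead of overwriting them. Reusing the standalone `SchedulerContext` sketch from above:

```python
ctx = SchedulerContext()
ctx.request_outputs.append("stale")    # left over from a previous iteration

ctx.request_outputs.clear()            # start of a new scheduler iteration
ctx.request_outputs.append("fresh")    # output from normal processing
ctx.request_outputs.append("drained")  # a later drain appends, never clears
assert ctx.request_outputs == ["fresh", "drained"]
```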
@@ -1548,7 +1581,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

                 # Add results to the output_queue
                 # (for async or non-async postprocessing)
-                self.output_queue.append(
+                ctx.output_queue.append(
                     (output, seq_group_metadata_list, scheduler_outputs))

                 if output and allow_async_output_proc:
@@ -1559,23 +1592,27 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                         output[0], seq_group_metadata_list,
                         scheduler_outputs.scheduled_seq_groups)

+                # Check if we need to run the usual non-async path
                 if not allow_async_output_proc:
-                    self._process_model_outputs(is_async=False)
+                    self._process_model_outputs(virtual_engine=virtual_engine,
+                                                is_async=False)

                     # Log stats.
                     self.do_log_stats(scheduler_outputs, output)

                     # Tracing
                     self.do_tracing(scheduler_outputs)
         else:
-            self.request_outputs = []
+            # Multi-step case
+            ctx.request_outputs = []

         if not self.has_unfinished_requests():
-            # Drain async postprocessor
-            if len(self.output_queue) > 0:
+            # Drain async postprocessor (if it exists)
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True, clear_outputs=False)
-                assert len(self.output_queue) == 0
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            assert len(ctx.output_queue) == 0

             # Stop the execute model loop in parallel workers until there are
             # more requests to process. This avoids waiting indefinitely in
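The final drain above assumes at most one pending batch when no new work was scheduled, so a single synchronous `_process_model_outputs` call must empty the queue; the `assert` makes that invariant explicit. A toy sketch:

```python
from collections import deque

queue: deque = deque([("outputs", "metadata", "sched_outputs")])

if queue:            # drain the async postprocessor, if one is pending
    queue.popleft()  # stands in for one _process_model_outputs(...) call
assert not queue     # the invariant the assert above enforces
```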
@@ -1584,7 +1621,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             # queued control plane messages, such as add/remove lora adapters.
             self.model_executor.stop_remote_worker_execution_loop()

-        return self.request_outputs
+        return ctx.request_outputs

     def _has_remaining_steps(
         self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]