
Commit 3a66eb7

alexm-redhat authored and Alvant committed
[Core] Async_output_proc: Add virtual engine support (towards pipeline parallel) (vllm-project#7911)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent 5cf2382 commit 3a66eb7

File tree

6 files changed: +122 additions, -67 deletions

vllm/core/scheduler.py

Lines changed: 5 additions & 6 deletions
@@ -302,7 +302,7 @@ def __init__(
         cache_config: CacheConfig,
         lora_config: Optional[LoRAConfig],
         pipeline_parallel_size: int = 1,
-        output_proc_callback_fn: Optional[Callable] = None,
+        output_proc_callback: Optional[Callable] = None,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config

@@ -376,8 +376,8 @@ def __init__(
         # iterations. I.e. since the output processing is lagged one step,
         # we cannot reuse the cached objects immediately when the schedule()
         # is called again, but only when schedule() is called the second time.
-        self.output_proc_callback_fn = output_proc_callback_fn
-        self.use_async_output_proc = self.output_proc_callback_fn is not None
+        self.output_proc_callback = output_proc_callback
+        self.use_async_output_proc = self.output_proc_callback is not None
         self.num_cache_iters = 2 if self.use_async_output_proc else 1

         self.cache_id = 0

@@ -573,8 +573,8 @@ def _schedule_running(
                     seq_group):
                 tmp = self.running
                 self.running = orig_running
-                assert self.output_proc_callback_fn is not None
-                self.output_proc_callback_fn(is_async=True)
+                assert self.output_proc_callback is not None
+                self.output_proc_callback()
                 self.running = tmp

             while not self._can_append_slots(seq_group):

@@ -1091,7 +1091,6 @@ def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
         no_beam_search = seq_group.sampling_params is None or (
             seq_group.sampling_params.best_of == 1
             and not seq_group.sampling_params.use_beam_search)
-
         return no_beam_search

     def schedule(
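
Note on the rename above: dropping the "_fn" suffix goes hand in hand with the callback now arriving pre-bound to its virtual engine (the engine builds it with functools.partial; see llm_engine.py below), so the scheduler invokes it with no arguments instead of passing is_async=True itself. A minimal sketch of the pattern, using a hypothetical ToyScheduler in place of the real Scheduler class:

    import functools
    from typing import Callable, Optional

    class ToyScheduler:
        """Illustration only -- not the real vllm.core.scheduler.Scheduler."""

        def __init__(self, output_proc_callback: Optional[Callable] = None) -> None:
            self.output_proc_callback = output_proc_callback
            self.use_async_output_proc = output_proc_callback is not None

        def maybe_drain_outputs(self) -> None:
            # The callback already carries virtual_engine and is_async=True,
            # so the call site stays argument-free.
            if self.use_async_output_proc:
                assert self.output_proc_callback is not None
                self.output_proc_callback()

    def _process_model_outputs(virtual_engine: int, is_async: bool) -> None:
        print(f"draining outputs: virtual_engine={virtual_engine}, is_async={is_async}")

    scheduler = ToyScheduler(
        output_proc_callback=functools.partial(_process_model_outputs,
                                               virtual_engine=0,
                                               is_async=True))
    scheduler.maybe_drain_outputs()
    # -> draining outputs: virtual_engine=0, is_async=True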

vllm/engine/async_llm_engine.py

Lines changed: 27 additions & 10 deletions
@@ -279,19 +279,26 @@ async def step_async(
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

+        ctx = self.scheduler_contexts[virtual_engine]
+
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
+
+            # Clear outputs on scheduler iteration start
+            ctx.request_outputs.clear()
+
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()

             # If current scheduler iteration has no async postprocessor,
             # then we need first to drain the pending async postprocessor
             # before moving forward
-            if not allow_async_output_proc and len(self.output_queue) > 0:
-                self._process_model_outputs(is_async=True)
+            if not allow_async_output_proc and len(ctx.output_queue) > 0:
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):

@@ -332,8 +339,8 @@ async def step_async(
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                execute_model_req.output_proc_callback_fn = \
-                    self._process_model_outputs
+                execute_model_req.async_callback = self.async_callback[
+                    virtual_engine]

             # Execute the model.
             output = await self.model_executor.execute_model_async(

@@ -343,9 +350,10 @@ async def step_async(
             if self.scheduler_config.is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(self.output_queue) > 0:
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
             output = []

         # Finish the current step for all the sequence groups.

@@ -360,7 +368,7 @@ async def step_async(
                     virtual_engine] = SchedulerOutputState()

             # Cache results in engine
-            self.output_queue.append(
+            ctx.output_queue.append(
                 (output, seq_group_metadata_list, scheduler_outputs))

             if output and allow_async_output_proc:

@@ -372,7 +380,8 @@ async def step_async(
                     scheduler_outputs.scheduled_seq_groups)

             if not allow_async_output_proc:
-                self._process_model_outputs(is_async=False)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=False)

                 # Log stats.
                 self.do_log_stats(scheduler_outputs, output)

@@ -381,9 +390,17 @@ async def step_async(
                 self.do_tracing(scheduler_outputs)

         else:
-            self.request_outputs = []
+            ctx.request_outputs = []
+
+        if not self.has_unfinished_requests():
+            # Drain async postprocessor (if exists)
+            if len(ctx.output_queue) > 0:
+                assert not self.scheduler_config.is_multi_step
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            assert len(ctx.output_queue) == 0

-        return self.request_outputs
+        return ctx.request_outputs

     async def stop_remote_worker_execution_loop_async(self) -> None:
         """Stop the remote worker execution loop."""

vllm/engine/llm_engine.py

Lines changed: 79 additions & 42 deletions
@@ -1,7 +1,8 @@
+import functools
 import time
 from collections import deque
 from contextlib import contextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
                     Mapping, Optional)
 from typing import Sequence as GenericSequence

@@ -88,6 +89,17 @@ class SchedulerOutputState:
     last_output: Optional[SamplerOutput] = None


+@dataclass
+class SchedulerContext:
+    output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
+                              SchedulerOutputs]] = field(
+                                  default_factory=lambda: deque())
+
+    request_outputs: List[Union[RequestOutput,
+                                EmbeddingRequestOutput]] = field(
+                                    default_factory=lambda: [])
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.


@@ -350,9 +362,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             Scheduler(
                 scheduler_config, cache_config, lora_config,
                 parallel_config.pipeline_parallel_size,
-                self._process_model_outputs
+                functools.partial(self._process_model_outputs,
+                                  virtual_engine=v_id,
+                                  is_async=True)
                 if model_config.use_async_output_proc else None)
-            for _ in range(parallel_config.pipeline_parallel_size)
+            for v_id in range(parallel_config.pipeline_parallel_size)
         ]

         # Metric Logging.

@@ -406,12 +420,17 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]

-        # Async output processing pointers
-        self.output_queue: Deque[Tuple[List[SamplerOutput],
-                                       List[SequenceGroupMetadata],
-                                       SchedulerOutputs]] = deque()
-        self.request_outputs: List[Union[RequestOutput,
-                                         EmbeddingRequestOutput]] = []
+        self.scheduler_contexts = [
+            SchedulerContext()
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
+        self.async_callback = [
+            functools.partial(self._process_model_outputs,
+                              virtual_engine=v_id,
+                              is_async=True)
+            for v_id in range(self.parallel_config.pipeline_parallel_size)
+        ]

     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).

@@ -1265,32 +1284,28 @@ def _process_sequence_group_outputs(

         return

-    def _process_model_outputs(self,
-                               is_async: bool,
-                               clear_outputs: bool = True) -> None:
+    def _process_model_outputs(self, virtual_engine: int,
+                               is_async: bool) -> None:
         """Apply the model output to the sequences in the scheduled seq groups.

+        virtual_engine: The engine id to operate on
         is_async: Indicates whether this postprocessor runs in
            parallel with the GPU forward pass and is processing
            tokens from the previous step. If this is true, then
            no tokens need to be appended since it is already done
            externally (before the next schedule() call)
-        clear_outputs: Sometimes existing outputs need to be combined
-            with outputs of this call. This happens for postprocessor
-            draining at the final stage (like when sequences are finished)

        Returns RequestOutputs that can be returned to the client.
        """
        now = time.time()

-        if clear_outputs:
-            self.request_outputs.clear()
+        ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]

-        if len(self.output_queue) == 0:
+        if len(ctx.output_queue) == 0:
            return None

        (outputs, seq_group_metadata_list,
-         scheduler_outputs) = self.output_queue.popleft()
+         scheduler_outputs) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(

@@ -1365,11 +1380,11 @@ def _process_model_outputs(self,
                 if (seq_group.is_finished()
                         if self.step_return_finished_only else True):
                     request_output = RequestOutputFactory.create(seq_group)
-                    self.request_outputs.append(request_output)
+                    ctx.request_outputs.append(request_output)

         for seq_group in scheduler_outputs.ignored_seq_groups:
             request_output = RequestOutputFactory.create(seq_group)
-            self.request_outputs.append(request_output)
+            ctx.request_outputs.append(request_output)

         if is_async:
             # Log stats.

@@ -1465,29 +1480,43 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 "Pipeline parallelism is only supported through AsyncLLMEngine "
                 "as performance will be severely degraded otherwise.")

+        # For llm_engine, there is no pipeline parallel support, so the engine
+        # used is always 0
+        virtual_engine = 0
+
         # These are cached outputs from previous iterations. None if on first
         # iteration
-        cached_outputs = self.cached_scheduler_outputs[0]
+        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
         seq_group_metadata_list = cached_outputs.seq_group_metadata_list
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

+        ctx = self.scheduler_contexts[virtual_engine]
+
         # Skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
+
+            # Clear outputs on scheduler iteration start
+            ctx.request_outputs.clear()
+
+            # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
-             allow_async_output_proc) = self.scheduler[0].schedule()
+             allow_async_output_proc
+             ) = self.scheduler[virtual_engine].schedule()

-            if not allow_async_output_proc and len(self.output_queue) > 0:
-                self._process_model_outputs(is_async=True)
+            # Maybe switch from async mode to sync mode
+            if not allow_async_output_proc and len(ctx.output_queue) > 0:
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have
                 # lookahead slots
                 self._cache_scheduler_outputs_for_multi_step(
-                    0, seq_group_metadata_list, scheduler_outputs,
+                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                     allow_async_output_proc)

         assert seq_group_metadata_list is not None

@@ -1498,14 +1527,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
-                0].get_and_reset_finished_requests_ids()
+                virtual_engine].get_and_reset_finished_requests_ids()

             # Check if we have a cached last_output from the previous iteration.
             # For supporting PP this is probably the best way to pass the
             # sampled_token_ids, as a separate broadcast over all the PP stages
             # will cause one virtual engine's microbatch to block the pipeline.
             last_sampled_token_ids = \
-                self._get_last_sampled_token_ids(0)
+                self._get_last_sampled_token_ids(virtual_engine)

             execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,

@@ -1520,20 +1549,24 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                execute_model_req.output_proc_callback_fn = \
-                    self._process_model_outputs
+                execute_model_req.async_callback = self.async_callback[
+                    virtual_engine]

             output = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)

-            # we need to do this here so that last step's sampled_token_ids can
+            # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
-                self._update_cached_scheduler_output(0, output)
+                self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(self.output_queue) > 0:
+            # Nothing scheduled => If there is pending async postprocessor,
+            # then finish it here.
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            # No outputs in this case
             output = []

         # Finish the current step for all the sequence groups.

@@ -1548,7 +1581,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

             # Add results to the output_queue
             # (for async or non-async postprocessing)
-            self.output_queue.append(
+            ctx.output_queue.append(
                 (output, seq_group_metadata_list, scheduler_outputs))

             if output and allow_async_output_proc:

@@ -1559,23 +1592,27 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                     output[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)

+            # Check if need to run the usual non-async path
             if not allow_async_output_proc:
-                self._process_model_outputs(is_async=False)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=False)

                 # Log stats.
                 self.do_log_stats(scheduler_outputs, output)

                 # Tracing
                 self.do_tracing(scheduler_outputs)
         else:
-            self.request_outputs = []
+            # Multi-step case
+            ctx.request_outputs = []

         if not self.has_unfinished_requests():
-            # Drain async postprocessor
-            if len(self.output_queue) > 0:
+            # Drain async postprocessor (if exists)
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True, clear_outputs=False)
-                assert len(self.output_queue) == 0
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            assert len(ctx.output_queue) == 0

             # Stop the execute model loop in parallel workers until there are
             # more requests to process. This avoids waiting indefinitely in

@@ -1584,7 +1621,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             # queued control plane messages, such as add/remove lora adapters.
             self.model_executor.stop_remote_worker_execution_loop()

-        return self.request_outputs
+        return ctx.request_outputs

     def _has_remaining_steps(
             self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
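
Two small implementation details in the llm_engine.py diff are worth spelling out. First, SchedulerContext declares its mutable members with field(default_factory=...), so every context gets its own deque and list instead of a shared mutable default. Second, the async_callback list binds virtual_engine and is_async=True once via functools.partial, which is what lets execute_model_req.async_callback and the scheduler's output_proc_callback be invoked with no arguments. A short self-contained check of both patterns (names here are illustrative, not the engine's API):

    import functools
    from collections import deque
    from dataclasses import dataclass, field
    from typing import Any, Deque, List

    @dataclass
    class Ctx:
        # default_factory: each instance gets fresh containers; a bare mutable
        # default would either be rejected by dataclasses or shared across instances.
        output_queue: Deque[Any] = field(default_factory=deque)
        request_outputs: List[Any] = field(default_factory=list)

    a, b = Ctx(), Ctx()
    a.output_queue.append("x")
    assert len(b.output_queue) == 0  # no shared state between contexts

    def _process_model_outputs(virtual_engine: int, is_async: bool) -> None:
        print(f"callback for virtual engine {virtual_engine} (is_async={is_async})")

    pipeline_parallel_size = 2  # assumed for the example
    async_callback = [
        functools.partial(_process_model_outputs, virtual_engine=v_id, is_async=True)
        for v_id in range(pipeline_parallel_size)
    ]
    async_callback[1]()  # -> callback for virtual engine 1 (is_async=True)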
