[Core] Optimize Async + Multi-step (vllm-project#8050)
alexm-neuralmagic authored Sep 3, 2024
1 parent 95a178f commit 6d646d0
Showing 8 changed files with 326 additions and 248 deletions.
4 changes: 2 additions & 2 deletions tests/multi_step/test_correctness_async_llm.py
@@ -103,13 +103,13 @@ async def test_multi_step(
model,
server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)
max_wait_seconds=5 * 240)
test_completions = await completions_with_server_args(
prompts,
model,
ms_server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)
max_wait_seconds=5 * 240)

# Assert multi-step scheduling produces identical tokens
# to single-step scheduling.
109 changes: 54 additions & 55 deletions vllm/engine/async_llm_engine.py
@@ -280,40 +280,27 @@ async def step_async(
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

ctx = self.scheduler_contexts[virtual_engine]

# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()

# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):

# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()

# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)

# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))
self._process_model_outputs(ctx=ctx)

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
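
Note, not part of the diff: the hunk above replaces the `virtual_engine`/`is_async` arguments of `_process_model_outputs` with a single per-virtual-engine scheduler context. The following is a minimal sketch of what such a context might hold, inferred only from the fields this hunk touches (`output_queue`, `request_outputs`, `seq_group_metadata_list`, `scheduler_outputs`); the real `SchedulerContext` in vLLM may differ.

```python
# Illustrative sketch only -- field names are inferred from usage in the hunk
# above, not copied from the actual SchedulerContext definition.
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Optional, Tuple

@dataclass
class SchedulerContextSketch:
    # Pending model outputs waiting to be turned into RequestOutputs.
    output_queue: Deque[Tuple[Any, ...]] = field(default_factory=deque)
    # Finished/streamable RequestOutputs for the current iteration.
    request_outputs: List[Any] = field(default_factory=list)
    # Scheduling state cached for the in-flight batch.
    seq_group_metadata_list: Optional[List[Any]] = None
    scheduler_outputs: Optional[Any] = None
```
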
@@ -351,26 +338,20 @@ async def step_async(
last_sampled_token_ids=last_sampled_token_ids)

if allow_async_output_proc:
async_callback = self.async_callback_multi_step[
virtual_engine] if use_async_and_multi_step \
else self.async_callback[virtual_engine]

execute_model_req.async_callback = async_callback
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]

# Execute the model.
output = await self.model_executor.execute_model_async(
execute_model_req)

# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
output = []

# Finish the current step for all the sequence groups.
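
Note, not part of the diff: with the `use_async_and_multi_step` branch removed, the callback selection above collapses to a single `self.async_callbacks[virtual_engine]` lookup. Below is a hedged sketch of how such a per-virtual-engine callback list could be pre-built (for example with `functools.partial`); the helper names are illustrative, not the actual vLLM implementation.

```python
# Sketch (assumption): pre-binding one output-processing callback per virtual
# engine, so the worker never needs to know whether multi-step is enabled.
import functools
from typing import Any, Dict, List

def process_model_outputs(ctx: Dict[str, Any]) -> None:
    # Drain the queued outputs for this virtual engine's context.
    while ctx["output_queue"]:
        ctx["output_queue"].pop(0)  # ... convert to RequestOutputs here

pipeline_parallel_size = 2  # illustrative value
scheduler_contexts: List[Dict[str, Any]] = [
    {"output_queue": [], "request_outputs": []}
    for _ in range(pipeline_parallel_size)
]
async_callbacks = [
    functools.partial(process_model_outputs, ctx=scheduler_contexts[ve])
    for ve in range(pipeline_parallel_size)
]
```
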
@@ -384,24 +365,22 @@ async def step_async(
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()

if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
is_async = allow_async_output_proc
is_last_step = True
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step))

if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if output and allow_async_output_proc:
assert len(
output
) == 1, "Async postprocessor expects only a single output set"
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

if not allow_async_output_proc:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=False)
self._process_model_outputs(ctx=ctx)

# Log stats.
self.do_log_stats(scheduler_outputs, output)
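
Note, not part of the diff: each entry appended to `ctx.output_queue` above now carries explicit `is_async` and `is_last_step` flags alongside the output and scheduling state, and `_process_model_outputs(ctx=ctx)` drains that queue in both modes. The sketch below shows only the entry layout and control flow under simplified stand-in types; the real processing lives in `LLMEngine._process_model_outputs`.

```python
# Sketch of the queue entry layout introduced above; types are simplified
# stand-ins, and the drain loop only illustrates the control flow.
from collections import deque
from typing import Any, Deque, List, NamedTuple, Optional

class OutputQueueEntry(NamedTuple):
    outputs: Optional[List[Any]]          # sampler output(s); None before execution
    seq_group_metadata_list: List[Any]
    scheduler_outputs: Any
    is_async: bool                        # processed via the async callback?
    is_last_step: bool                    # last micro-step of a multi-step batch?

def drain_output_queue(output_queue: Deque[OutputQueueEntry]) -> List[Any]:
    request_outputs: List[Any] = []
    while output_queue:
        entry = output_queue.popleft()
        if entry.outputs:
            # ... convert sampler outputs into RequestOutputs here
            request_outputs.extend(entry.outputs)
    return request_outputs
```
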
@@ -411,17 +390,12 @@ async def step_async(

else:
# Multi-step case
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []
return ctx.request_outputs

if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0

return ctx.request_outputs
@@ -640,6 +614,17 @@ def __init__(self,
self.log_requests = log_requests
self.engine = self._init_engine(*args, **kwargs)

# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
# TODO: Currently, disabled for engine_use_ray, ask
# Cody/Will/Woosuk about this case.
self.use_process_request_outputs_callback = not self.engine_use_ray
if self.use_process_request_outputs_callback:
self.engine.process_request_outputs_callback = \
self.process_request_outputs

if self.engine_use_ray:
print_warning_once(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
@@ -883,13 +868,27 @@ async def engine_step(self, virtual_engine: int) -> bool:
request_outputs = await self.engine.step_async(virtual_engine)

# Put the outputs into the corresponding streams.
finished = True
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
if not self.use_process_request_outputs_callback:
all_finished = self.process_request_outputs(request_outputs)
else:
# For callback case, we only need to detect when all
# requests are finished
all_finished = all(request_output.finished
for request_output in request_outputs)

return not all_finished

def process_request_outputs(self, request_outputs) -> bool:
# Put the outputs into the corresponding streams.
all_finished = True
for request_output in request_outputs:
self._request_tracker.process_request_output(
request_output, verbose=self.log_requests)
finished = finished and request_output.finished
all_finished = all_finished and request_output.finished

return not finished
return all_finished

async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
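
Note, not part of the diff: the new `process_request_outputs` helper lets the engine push outputs to the per-request streams as soon as they are processed (via `process_request_outputs_callback`), while `engine_step` only detects completion. Below is a rough, self-contained sketch of the two delivery paths; `RequestOutputStub` and `TrackerStub` are invented stand-ins, not vLLM classes.

```python
# Sketch (assumptions: RequestOutputStub / TrackerStub are stand-ins) of the two
# delivery paths: direct call from engine_step vs. engine-side callback.
from dataclasses import dataclass
from typing import List

@dataclass
class RequestOutputStub:
    request_id: str
    finished: bool

class TrackerStub:
    def process_request_output(self, out: RequestOutputStub, verbose: bool = False):
        print(f"stream <- {out.request_id} (finished={out.finished})")

class AsyncEngineSketch:
    def __init__(self, engine_use_ray: bool = False) -> None:
        self._request_tracker = TrackerStub()
        # As in the diff: the callback path is skipped when the engine runs
        # as a Ray actor.
        self.use_process_request_outputs_callback = not engine_use_ray

    def process_request_outputs(self, outs: List[RequestOutputStub]) -> bool:
        all_finished = True
        for out in outs:
            self._request_tracker.process_request_output(out)
            all_finished = all_finished and out.finished
        return all_finished

    def engine_step_tail(self, outs: List[RequestOutputStub]) -> bool:
        # Returns True while there is still unfinished work in flight.
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(outs)
        else:
            # The callback already streamed the outputs; only detect completion.
            all_finished = all(o.finished for o in outs)
        return not all_finished
```
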