[Core] Combine async postprocessor and multi-step #7921

Merged: 9 commits, Aug 29, 2024

tests/multi_step/test_correctness_async_llm.py: 6 additions, 4 deletions
@@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
 @pytest.mark.parametrize("eager_mode", [False, True])
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.parametrize("is_async", [False, True])
 @pytest.mark.asyncio
 async def test_multi_step(example_prompts, model: str, tp_size: int,
                           pp_size: int, eager_mode: int,
-                          num_scheduler_steps: int, num_prompts: int):
+                          num_scheduler_steps: int, num_prompts: int,
+                          is_async: bool):
 
     prompts = example_prompts
     if len(prompts) < num_prompts:
@@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
     ms_server_args = DEFAULT_SERVER_ARGS + \
         ["--num-scheduler-steps", f"{num_scheduler_steps}"]
 
-    # Disable output proc callback as its not supported
-    # with multi-step right now
-    ms_server_args += ["--disable-async-output-proc"]
+    if not is_async:
+        ms_server_args += ["--disable-async-output-proc"]
 
     if eager_mode:
         ms_server_args.append("--enforce-eager")
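
For context, a minimal runnable sketch of how the new is_async parameter maps to server flags: async output processing now stays enabled unless the test explicitly opts out. DEFAULT_SERVER_ARGS below is a hypothetical stand-in, not the constant from the real test module; the three flags themselves appear in the diff.

from typing import List

# Hypothetical stand-in for the test module's DEFAULT_SERVER_ARGS.
DEFAULT_SERVER_ARGS: List[str] = ["--disable-log-requests"]

def build_ms_server_args(num_scheduler_steps: int, is_async: bool,
                         eager_mode: bool) -> List[str]:
    """Mirror the updated test's flag logic (illustrative only)."""
    args = DEFAULT_SERVER_ARGS + [
        "--num-scheduler-steps", f"{num_scheduler_steps}"
    ]
    if not is_async:
        # Only the non-async variant disables async output processing.
        args += ["--disable-async-output-proc"]
    if eager_mode:
        args.append("--enforce-eager")
    return args

assert "--disable-async-output-proc" not in build_ms_server_args(8, True, False)
assert "--disable-async-output-proc" in build_ms_server_args(8, False, False)
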
vllm/core/scheduler.py: 1 addition, 4 deletions
@@ -1107,10 +1107,7 @@ def schedule(
         if not self.cache_config.enable_prefix_caching:
             common_computed_block_nums = []
 
-        # TODO: Combine multi-step and async postprocessor
-        allow_async_output_proc: bool = (
-            self.use_async_output_proc
-            and not self.scheduler_config.is_multi_step)
+        allow_async_output_proc: bool = self.use_async_output_proc
 
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
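
The deleted guard is the heart of the PR: the scheduler no longer vetoes async output processing just because multi-step is on. A minimal before/after sketch of the predicate, with plain booleans standing in for the config objects:

def allow_async_before(use_async_output_proc: bool,
                       is_multi_step: bool) -> bool:
    # Old behavior: multi-step unconditionally blocked async output proc.
    return use_async_output_proc and not is_multi_step

def allow_async_after(use_async_output_proc: bool) -> bool:
    # New behavior: the engine-level setting alone decides.
    return use_async_output_proc

assert allow_async_before(True, True) is False
assert allow_async_after(True) is True  # multi-step no longer blocks it
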
vllm/engine/async_llm_engine.py: 44 additions, 21 deletions
@@ -279,6 +279,10 @@ async def step_async(
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc
 
+        # Detect async + multi-step
+        use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                    and allow_async_output_proc)
+
         ctx = self.scheduler_contexts[virtual_engine]
 
         # skip the scheduler if there are any remaining steps in the seq groups.
@@ -289,17 +293,27 @@
             # Clear outputs on scheduler iteration start
             ctx.request_outputs.clear()
 
+            # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()
 
-            # If current scheduler iteration has no async postprocessor,
-            # then we need first to drain the pending async postprocessor
-            # before moving forward
+            # Detect async + multi-step
+            use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                        and allow_async_output_proc)
+
+            # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)
 
+            # For async + multi-step, init the queue
+            if use_async_and_multi_step:
+                assert len(ctx.output_queue) == 0
+                assert seq_group_metadata_list is not None
+                ctx.output_queue.append(
+                    (None, seq_group_metadata_list, scheduler_outputs))
+
             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have
@@ -311,9 +325,6 @@
             assert seq_group_metadata_list is not None
             assert scheduler_outputs is not None
 
-            assert not (self.scheduler_config.is_multi_step and \
-                        allow_async_output_proc)
-
         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
                 virtual_engine].get_and_reset_finished_requests_ids()
@@ -339,8 +350,13 @@
                 last_sampled_token_ids=last_sampled_token_ids)
 
             if allow_async_output_proc:
-                execute_model_req.async_callback = self.async_callback[
-                    virtual_engine]
+                async_callback = self.async_callback_multi_step[
+                    virtual_engine] if use_async_and_multi_step \
+                    else self.async_callback[virtual_engine]
+
+                execute_model_req.async_callback = async_callback
+                execute_model_req.use_async_and_multi_step = \
+                    use_async_and_multi_step
 
             # Execute the model.
             output = await self.model_executor.execute_model_async(
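
A hedged sketch of the callback selection above: one callback per virtual engine, with the multi-step-aware variant chosen only when both features are active. The lists here are illustrative stand-ins for the engine's self.async_callback and self.async_callback_multi_step.

from typing import Callable, List

def pick_callback(use_async_and_multi_step: bool, virtual_engine: int,
                  single_step_cbs: List[Callable[[], None]],
                  multi_step_cbs: List[Callable[[], None]]
                  ) -> Callable[[], None]:
    # Mirrors the conditional expression in the diff.
    if use_async_and_multi_step:
        return multi_step_cbs[virtual_engine]
    return single_step_cbs[virtual_engine]

single = [lambda: print("process outputs once per step")]
multi = [lambda: print("process outputs across pipelined steps")]
pick_callback(True, 0, single, multi)()  # prints the multi-step variant
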
@@ -350,7 +366,7 @@
             if self.scheduler_config.is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(ctx.output_queue) > 0:
+            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)
@@ -362,22 +378,25 @@
                 seq_group.finish_step()
 
         if not self._has_remaining_steps(seq_group_metadata_list):
-            # clear the cache if we have finished all the steps
+            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()
 
-            # Cache results in engine
-            ctx.output_queue.append(
-                (output, seq_group_metadata_list, scheduler_outputs))
+            if use_async_and_multi_step:
+                # For async + multi-step, clear the queue
+                ctx.output_queue.clear()
+            else:
+                ctx.output_queue.append(
+                    (output, seq_group_metadata_list, scheduler_outputs))
 
-            if output and allow_async_output_proc:
-                assert len(
-                    output
-                ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
-                self._advance_to_next_step(
-                    output[0], seq_group_metadata_list,
-                    scheduler_outputs.scheduled_seq_groups)
+                if output and allow_async_output_proc:
+                    assert len(
+                        output
+                    ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
+                    self._advance_to_next_step(
+                        output[0], seq_group_metadata_list,
+                        scheduler_outputs.scheduled_seq_groups)
 
             if not allow_async_output_proc:
                 self._process_model_outputs(virtual_engine=virtual_engine,
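
The end-of-step bookkeeping now splits: under async + multi-step the pre-seeded queue entry is cleared because the multi-step callback already consumed it during execution, while the plain path still appends fresh results and advances sequences here. A compact sketch with a deque standing in for ctx.output_queue:

from collections import deque

def finalize_queue(output_queue: deque, use_async_and_multi_step: bool,
                   output, seq_group_metadata_list,
                   scheduler_outputs) -> None:
    if use_async_and_multi_step:
        # Entry was consumed by the multi-step callback mid-execution.
        output_queue.clear()
    else:
        # Single-step async path: queue the results for later processing.
        output_queue.append(
            (output, seq_group_metadata_list, scheduler_outputs))

q = deque([(None, "metadata", "sched-outputs")])  # pre-seeded earlier
finalize_queue(q, True, None, None, None)
assert len(q) == 0
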
@@ -390,7 +409,11 @@
             self.do_tracing(scheduler_outputs)
 
         else:
-            ctx.request_outputs = []
+            # Multi-step case
+            if use_async_and_multi_step:
+                return []
+            else:
+                ctx.request_outputs = []
 
         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)
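
On the branch where a multi-step sequence still has steps remaining, step_async now returns an empty list when async + multi-step is on, since outputs have already been streamed out through the callback. A sketch of that return-path decision as a pure function (illustrative names, not the engine's API):

from typing import List

def step_return_value(use_async_and_multi_step: bool,
                      request_outputs: List[str]) -> List[str]:
    if use_async_and_multi_step:
        # Outputs were already delivered via the async callback.
        return []
    # Otherwise the engine resets and later returns ctx.request_outputs.
    return request_outputs

assert step_return_value(True, ["pending-output"]) == []
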