[Core] Combine async postprocessor and multi-step #7921

Merged: 9 commits, merged on Aug 29, 2024
Changes from 5 commits
10 changes: 6 additions & 4 deletions tests/multi_step/test_correctness_async_llm.py
@@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
pp_size: int, eager_mode: int,
num_scheduler_steps: int, num_prompts: int):
num_scheduler_steps: int, num_prompts: int,
is_async: bool):

prompts = example_prompts
if len(prompts) < num_prompts:
@@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]

# Disable output proc callback as its not supported
# with multi-step right now
ms_server_args += ["--disable-async-output-proc"]
if not is_async:
ms_server_args += ["--disable-async-output-proc"]

if eager_mode:
ms_server_args.append("--enforce-eager")

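For reference, a minimal sketch (not part of the diff) of the two server configurations the new is_async parameter exercises. The value 8 below is a stand-in for one entry of the test module's NUM_SCHEDULER_STEPS constant, and DEFAULT_SERVER_ARGS is omitted:

num_scheduler_steps = 8  # stand-in for one entry of NUM_SCHEDULER_STEPS

ms_server_args = ["--num-scheduler-steps", f"{num_scheduler_steps}"]

# is_async=True: async output processing stays enabled (the default), so the
# combined async + multi-step path introduced by this PR is exercised.
args_async = list(ms_server_args)

# is_async=False: the pre-existing behaviour, with async output processing
# explicitly disabled.
args_sync = ms_server_args + ["--disable-async-output-proc"]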
5 changes: 1 addition & 4 deletions vllm/core/scheduler.py
@@ -1107,10 +1107,7 @@ def schedule(
if not self.cache_config.enable_prefix_caching:
common_computed_block_nums = []

# TODO: Combine multi-step and async postprocessor
allow_async_output_proc: bool = (
self.use_async_output_proc
and not self.scheduler_config.is_multi_step)
allow_async_output_proc: bool = self.use_async_output_proc

# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
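Expressed as two standalone helper functions (the function names are hypothetical; the conditions are copied from the lines above), the change to the scheduler's gate is:

def allow_async_output_proc_before(use_async_output_proc: bool,
                                   is_multi_step: bool) -> bool:
    # Pre-PR: async output processing was forced off whenever multi-step
    # scheduling was enabled.
    return use_async_output_proc and not is_multi_step


def allow_async_output_proc_after(use_async_output_proc: bool,
                                  is_multi_step: bool) -> bool:
    # Post-PR: the scheduler only checks the engine-level switch; the
    # multi-step interaction is handled in the engine's step logic instead.
    return use_async_output_proc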
65 changes: 44 additions & 21 deletions vllm/engine/async_llm_engine.py
@@ -279,6 +279,10 @@ async def step_async(
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

ctx = self.scheduler_contexts[virtual_engine]

# skip the scheduler if there are any remaining steps in the seq groups.
@@ -289,17 +293,27 @@
# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()

# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

# If current scheduler iteration has no async postprocessor,
# then we need first to drain the pending async postprocessor
# before moving forward
# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)

# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
@@ -311,9 +325,6 @@
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None

assert not (self.scheduler_config.is_multi_step and \
allow_async_output_proc)

if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
@@ -339,8 +350,13 @@
last_sampled_token_ids=last_sampled_token_ids)

if allow_async_output_proc:
execute_model_req.async_callback = self.async_callback[
virtual_engine]
async_callback = self.async_callback_multi_step[
virtual_engine] if use_async_and_multi_step \
else self.async_callback[virtual_engine]

execute_model_req.async_callback = async_callback
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step

# Execute the model.
output = await self.model_executor.execute_model_async(
@@ -350,7 +366,7 @@
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if len(ctx.output_queue) > 0:
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
@@ -362,22 +378,25 @@
seq_group.finish_step()

if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
# Clear the cache if we have finished all the steps
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()

# Cache results in engine
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))

if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

if not allow_async_output_proc:
self._process_model_outputs(virtual_engine=virtual_engine,
Expand All @@ -390,7 +409,11 @@ async def step_async(
self.do_tracing(scheduler_outputs)

else:
ctx.request_outputs = []
# Multi-step case
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []

if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
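Taken together, the step_async changes route the combined mode through a per-step callback instead of the usual output queue. A condensed, self-contained sketch of the branch selection (the helper name and return labels are illustrative; the conditions mirror the hunks above):

def step_async_branch(is_multi_step: bool, allow_async_output_proc: bool,
                      has_remaining_steps: bool) -> str:
    """Illustrative summary of the control flow added to step_async."""
    use_async_and_multi_step = is_multi_step and allow_async_output_proc

    if has_remaining_steps:
        # A multi-step batch is still in flight, so the scheduler is skipped.
        if use_async_and_multi_step:
            # Outputs are produced by the per-step callback; the engine
            # returns [] without touching ctx.request_outputs.
            return "return []"
        return "ctx.request_outputs = []"

    if use_async_and_multi_step:
        # A placeholder (None, seq_group_metadata_list, scheduler_outputs)
        # entry is queued before execution, and async_callback_multi_step is
        # attached to the ExecuteModelRequest.
        return "queue placeholder entry + async_callback_multi_step"
    if allow_async_output_proc:
        # Single-step async path: real outputs are queued after execution and
        # async_callback is attached.
        return "queue real outputs + async_callback"
    return "synchronous _process_model_outputs"


# The combination this PR enables:
print(step_async_branch(is_multi_step=True, allow_async_output_proc=True,
                        has_remaining_steps=False))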
108 changes: 82 additions & 26 deletions vllm/engine/llm_engine.py
@@ -91,7 +91,8 @@ class SchedulerOutputState:

@dataclass
class SchedulerContext:
output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
List[SequenceGroupMetadata],
SchedulerOutputs]] = field(
default_factory=lambda: deque())

@@ -432,6 +433,13 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
for v_id in range(self.parallel_config.pipeline_parallel_size)
]

self.async_callback_multi_step = [
functools.partial(self._process_model_outputs,
virtual_engine=v_id,
is_async=False)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]

def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).

@@ -1240,8 +1248,11 @@ def _process_sequence_group_outputs(

return

def _process_model_outputs(self, virtual_engine: int,
is_async: bool) -> None:
def _process_model_outputs(self,
virtual_engine: int,
is_async: bool,
sampler_output: Optional[SamplerOutput] = None,
is_last_output: bool = False) -> None:
Contributor:

nit: docstring for sampler_output and is_last_output can be useful here

Collaborator Author:

Good idea, added
"""Apply the model output to the sequences in the scheduled seq groups.

virtual_engine: The engine id to operate on
@@ -1255,13 +1266,25 @@ def _process_model_outputs(self, virtual_engine: int,
"""
now = time.time()

is_multi_step = sampler_output is not None

ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]

if len(ctx.output_queue) == 0:
return None

(outputs, seq_group_metadata_list,
scheduler_outputs) = ctx.output_queue.popleft()
if is_multi_step:
# Async + multi-step case
(outputs, seq_group_metadata_list,
scheduler_outputs) = ctx.output_queue[0]
assert outputs is None
outputs = [sampler_output]
else:
# Async standard case
(outputs, seq_group_metadata_list,
scheduler_outputs) = ctx.output_queue.popleft()

assert outputs is not None

# Sanity check
assert len(seq_group_metadata_list) == len(
@@ -1320,15 +1343,19 @@ def _process_model_outputs(self, virtual_engine: int,
self.output_processor.process_outputs(seq_group, output,
is_async)

# Free the finished sequence groups.
# For async + multi-step, free finished seqs and create outputs
# only on the final step.
if is_multi_step and not is_last_output:
return

for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()

# Create the outputs.
for i, _ in enumerate(seq_group_metadata_list):
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

if i in finished_before:
if not is_multi_step and i in finished_before:
Contributor:

Why add `not is_multi_step` here? Shouldn't double processing also be avoided for multi-step?

Collaborator Author (@alexm-neuralmagic, Aug 29, 2024):

Actually not, since multi-step uses is_last_output to indicate whether to run this code or not. Added a docstring below to explain better.
continue # Avoids double processing

seq_group = scheduled_seq_group.seq_group
Expand All @@ -1342,7 +1369,11 @@ def _process_model_outputs(self, virtual_engine: int,
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)

if is_async:
# For async + multi-step, do stats only on the last output.
# Otherwise, do stats if the execution is async
do_stats = is_multi_step or is_async

if do_stats:
# Log stats.
self.do_log_stats(scheduler_outputs, outputs, finished_before)

@@ -1437,7 +1468,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"as performance will be severely degraded otherwise.")

# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0
# used is always 0.
virtual_engine = 0

# These are cached outputs from previous iterations. None if on first
Expand All @@ -1447,6 +1478,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

ctx = self.scheduler_contexts[virtual_engine]

# Skip the scheduler if there are any remaining steps in the seq groups.
@@ -1462,11 +1497,22 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)

# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
@@ -1478,9 +1524,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None

assert not (self.scheduler_config.is_multi_step and \
allow_async_output_proc)

if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
@@ -1505,8 +1548,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
last_sampled_token_ids=last_sampled_token_ids)

if allow_async_output_proc:
execute_model_req.async_callback = self.async_callback[
virtual_engine]
async_callback = self.async_callback_multi_step[
virtual_engine] if use_async_and_multi_step \
else self.async_callback[virtual_engine]

execute_model_req.async_callback = async_callback
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step

output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
@@ -1518,7 +1566,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
else:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if len(ctx.output_queue) > 0:
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
@@ -1535,18 +1583,23 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()

# Add results to the output_queue
# (for async or non-async postprocessing)
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
# Add results to the output_queue
# (for async or non-async postprocessing)
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))

if output and allow_async_output_proc:
assert len(output) == 1, ("Multi step decoding does not work "
"with async output processing.")
if output and allow_async_output_proc:
assert len(output) == 1, (
"Multi step decoding does not work "
"with async output processing.")

self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

# Check if need to run the usual non-async path
if not allow_async_output_proc:
Expand All @@ -1560,7 +1613,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
self.do_tracing(scheduler_outputs)
else:
# Multi-step case
ctx.request_outputs = []
if use_async_and_multi_step:
return []
Contributor:

Why is this needed? Won't the ctx.request_outputs = [] branch take care of both cases?

Collaborator Author:

The async postprocessor may modify ctx.request_outputs at each step, so I did not want to touch it in the middle of a multi-step run.
else:
ctx.request_outputs = []

if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
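
As a usage note, the combined path is what you get after this PR by turning on multi-step while leaving async output processing at its default. A hedged sketch, assuming the AsyncEngineArgs fields mirror the CLI flags used in the test above (field names may differ between vLLM versions):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",   # any model; this one is a small test model
    num_scheduler_steps=8,       # > 1 enables multi-step scheduling
    # disable_async_output_proc is left at its default (False), so the async
    # postprocessor now runs together with multi-step.
)
engine = AsyncLLMEngine.from_engine_args(engine_args)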