Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[Core] Eliminate parallel worker per-step task scheduling overhead (v…
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored and robertgshaw2-neuralmagic committed Jul 14, 2024
1 parent d9e9332 commit 9f3dac0
Show file tree
Hide file tree
Showing 12 changed files with 348 additions and 209 deletions.
10 changes: 9 additions & 1 deletion vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,14 @@ async def step_async(
# Log stats.
self.do_log_stats(scheduler_outputs, output)

if not request_outputs:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
await self.model_executor.stop_remote_worker_execution_loop_async()

return request_outputs

async def encode_request_async(
Expand Down Expand Up @@ -687,7 +695,7 @@ async def encode(
multi_modal_data: Multi modal data per request.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
Details:
Expand Down
8 changes: 8 additions & 0 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Log stats.
self.do_log_stats(scheduler_outputs, output)

if not request_outputs:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
self.model_executor.stop_remote_worker_execution_loop()

return request_outputs

def do_log_stats(
Expand Down
123 changes: 97 additions & 26 deletions vllm/executor/distributed_gpu_executor.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
import asyncio
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union

from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput

logger = init_logger(__name__)


class DistributedGPUExecutor(GPUExecutor):
"""Abstract superclass of multi-GPU executor implementations."""

def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
# Updated by implementations that require additional args to be passed
# to the _run_workers execute_model call
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

super().__init__(*args, **kwargs)

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
Expand Down Expand Up @@ -52,13 +63,28 @@ def initialize_cache(self, num_gpu_blocks: int,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)

def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
all_outputs = self._run_workers("execute_model",
driver_args=args,
driver_kwargs=kwargs)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_remote_workers_only=True,
**self.extra_execute_model_run_workers_kwargs)

# Only the driver worker returns the sampling results.
return all_outputs[0]
return self._driver_execute_model(execute_model_req)

def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return

self._driver_execute_model()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)

def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
Expand Down Expand Up @@ -88,39 +114,84 @@ def save_sharded_state(
pattern=pattern,
max_size=max_size)

@abstractmethod
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
raise NotImplementedError

@abstractmethod
def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
async_run_remote_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
"""
raise NotImplementedError

@abstractmethod
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
raise NotImplementedError


class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):

async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
self._start_worker_execution_loop())

# Only the driver worker returns the sampling results.
return await self._driver_execute_model_async(execute_model_req)

async def stop_remote_worker_execution_loop_async(self) -> None:
if self.parallel_worker_tasks is None:
return

await self._driver_execute_model_async()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await parallel_worker_tasks

@abstractmethod
async def _run_workers_async(
async def _driver_execute_model_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
raise NotImplementedError
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Execute the model asynchronously in the driver worker.
async def execute_model_async(self, *args,
**kwargs) -> List[SamplerOutput]:
all_outputs = await self._run_workers_async("execute_model",
driver_args=args,
driver_kwargs=kwargs)
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
raise NotImplementedError

# Only the driver worker returns the sampling results.
return all_outputs[0]
@abstractmethod
async def _start_worker_execution_loop(self):
"""Run execution loop on all workers. It guarantees all workers run
the loop or None of them is running the loop. Loop can be stopped by
`stop_remote_worker_execution_loop`.
The API is idempotent (guarantee only 1 loop run at any moment)."""
raise NotImplementedError
8 changes: 8 additions & 0 deletions vllm/executor/executor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def execute_model(
"""Executes at least one model step on the given sequences."""
raise NotImplementedError

def stop_remote_worker_execution_loop(self) -> None:
"""Releases parallel workers from model loop."""
return

@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
Expand Down Expand Up @@ -109,6 +113,10 @@ async def execute_model_async(
"""Executes one model step on the given sequences."""
raise NotImplementedError

async def stop_remote_worker_execution_loop_async(self) -> None:
"""Releases parallel workers from model loop."""
return

async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
Expand Down
73 changes: 44 additions & 29 deletions vllm/executor/multiproc_gpu_executor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import asyncio
import os
from functools import partial
from typing import Any, Dict, Optional, Tuple
from typing import Any, List, Optional

from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)

Expand Down Expand Up @@ -71,16 +72,34 @@ def shutdown(self):
None)) is not None:
worker_monitor.close()

def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_model(
execute_model_req=execute_model_req)

def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
async_run_remote_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
"""

if max_concurrent_workers:
raise NotImplementedError(
Expand All @@ -92,15 +111,12 @@ def _run_workers(
for worker in self.workers
]

if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
if async_run_remote_workers_only:
# Just return futures
return worker_outputs

# Start the driver worker after all the ray workers.
driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*driver_args,
**driver_kwargs)
driver_worker_output = driver_worker_method(*args, **kwargs)

# Get the results of the workers.
return [driver_worker_output
Expand All @@ -111,30 +127,29 @@ def check_health(self) -> None:
if not self.worker_monitor.is_alive():
raise RuntimeError("Worker processes are not running")

def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
for result in parallel_worker_tasks:
result.get()


class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
DistributedGPUExecutorAsync):

async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)

driver_executor = make_async(getattr(self.driver_worker, method))
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_model(execute_model_req)

# Run all the workers asynchronously.
coros = [driver_executor(*driver_args, **driver_kwargs)] + [
worker.execute_method_async(method, *args, **kwargs)
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method_async("start_worker_execution_loop")
for worker in self.workers
]

return await asyncio.gather(*coros)
Loading

0 comments on commit 9f3dac0

Please sign in to comment.