tests/v1/test_async_llm_dp.py (13 changes: 10 additions & 3 deletions)

@@ -59,14 +59,22 @@ async def generate(engine: AsyncLLM,
 
 
 @pytest.mark.parametrize(
-    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
+    "output_kind",
+    [
+        RequestOutputKind.DELTA,
+        RequestOutputKind.FINAL_ONLY,
+    ],
+)
+@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
 @pytest.mark.asyncio
-async def test_load(output_kind: RequestOutputKind):
+async def test_load(output_kind: RequestOutputKind,
+                    data_parallel_backend: str):
 
     with ExitStack() as after:
 
         prompt = "This is a test of data parallel"
 
+        engine_args.data_parallel_backend = data_parallel_backend
         engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
@@ -82,7 +90,6 @@ async def test_load(output_kind: RequestOutputKind):
             asyncio.create_task(
                 generate(engine, request_id, prompt, output_kind,
                          NUM_EXPECTED_TOKENS)))
-
         # Confirm that we got all the EXPECTED tokens from the requests.
         done, pending = await asyncio.wait(tasks,
                                            return_when=asyncio.FIRST_EXCEPTION)
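For reference, a minimal standalone sketch of what the new "ray" parametrization exercises, assuming this PR is applied and the V1 engine is in use; the model name, token count, and direct AsyncEngineArgs construction are illustrative placeholders rather than values from the test:

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM


async def main():
    engine_args = AsyncEngineArgs(
        model="facebook/opt-125m",  # placeholder model
        data_parallel_size=2,
        data_parallel_backend="ray",  # the new knob under test
    )
    engine = AsyncLLM.from_engine_args(engine_args)
    try:
        # Stream outputs for a single request, as the test's generate()
        # helper does for many requests concurrently.
        async for output in engine.generate(
                prompt="This is a test of data parallel",
                sampling_params=SamplingParams(max_tokens=16),
                request_id="req-0"):
            if output.finished:
                print(output.outputs[0].text)
    finally:
        engine.shutdown()


asyncio.run(main())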
vllm/config.py (2 changes: 2 additions & 0 deletions)

@@ -1693,6 +1693,8 @@ class ParallelConfig:
     """Port for data parallel messaging."""
     data_parallel_master_port: int = 29500
     """Port of the data parallel master."""
+    data_parallel_backend: str = "mp"
+    """Backend to use for data parallel, either "mp" or "ray"."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
     max_parallel_loading_workers: Optional[int] = None
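A quick sketch of the new field in isolation; "mp" stays the default and "ray" opts in to Ray-managed engines. Note that ParallelConfig's post-init may derive further DP fields from the environment, so this is illustrative only:

from vllm.config import ParallelConfig

# The new field rides along with the existing DP settings.
pc = ParallelConfig(data_parallel_size=2, data_parallel_backend="ray")
assert pc.data_parallel_backend in ("mp", "ray")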
vllm/engine/arg_utils.py (9 changes: 9 additions & 0 deletions)

@@ -290,6 +290,7 @@ class EngineArgs:
     data_parallel_size_local: Optional[int] = None
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
+    data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers
@@ -618,6 +619,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                     type=int,
                                     help='Port for data parallel RPC '
                                     'communication.')
+        parallel_group.add_argument('--data-parallel-backend',
+                                    '-dpb',
+                                    type=str,
+                                    help='Backend for data parallel, either '
+                                    '"mp" or "ray".')
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
@@ -1058,13 +1064,16 @@ def create_engine_config(
             self.data_parallel_rpc_port
             is not None) else ParallelConfig.data_parallel_rpc_port
 
+        data_parallel_backend = self.data_parallel_backend
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
             data_parallel_size=self.data_parallel_size,
             data_parallel_size_local=data_parallel_size_local,
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
+            data_parallel_backend=data_parallel_backend,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
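A hedged sketch of how the new flag flows from argv into EngineArgs; the parser setup mirrors how vLLM wires FlexibleArgumentParser, and the model name is a placeholder:

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "facebook/opt-125m",  # placeholder
    "--data-parallel-size", "2",
    "--data-parallel-backend", "ray",  # short form: -dpb ray
])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.data_parallel_backend == "ray"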
vllm/v1/engine/async_llm.py (15 changes: 11 additions & 4 deletions)

@@ -25,7 +25,8 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, cdiv
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
+from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient,
+                                        RayDPClient)
 from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 from vllm.v1.engine.output_processor import (OutputProcessor,
                                              RequestOutputCollector)
@@ -111,9 +112,15 @@ def __init__(
             log_stats=self.log_stats)
 
         # EngineCore (starts the engine in background process).
-        core_client_class = AsyncMPClient if (
-            vllm_config.parallel_config.data_parallel_size
-            == 1) else DPAsyncMPClient
+        core_client_class: Union[type[RayDPClient], type[DPAsyncMPClient],
+                                 type[AsyncMPClient]]
+        if vllm_config.parallel_config.data_parallel_size > 1:
+            if vllm_config.parallel_config.data_parallel_backend == "ray":
+                core_client_class = RayDPClient
+            else:
+                core_client_class = DPAsyncMPClient
+        else:
+            core_client_class = AsyncMPClient
 
         self.engine_core = core_client_class(
             vllm_config=vllm_config,
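Restated as a standalone helper, the dispatch added above reduces to the following (a paraphrase for clarity, not code from the PR):

from vllm.config import ParallelConfig
from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient,
                                        RayDPClient)


def select_core_client_class(parallel_config: ParallelConfig):
    # DP size 1: plain async client; DP > 1: choose by backend.
    if parallel_config.data_parallel_size > 1:
        if parallel_config.data_parallel_backend == "ray":
            return RayDPClient
        return DPAsyncMPClient
    return AsyncMPClient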
vllm/v1/engine/core_client.py (114 changes: 112 additions & 2 deletions)

@@ -27,6 +27,7 @@
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.engine.exceptions import EngineDeadError
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.ray_dp import CoreEngineActorManager
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
 from vllm.v1.utils import CoreEngineProcManager
 
@@ -67,7 +68,11 @@ def make_client(
 
         if multiprocess_mode and asyncio_mode:
             if vllm_config.parallel_config.data_parallel_size > 1:
-                return DPAsyncMPClient(vllm_config, executor_class, log_stats)
+                if vllm_config.parallel_config.data_parallel_backend == "ray":
+                    return RayDPClient(vllm_config, executor_class, log_stats)
+                else:
+                    return DPAsyncMPClient(vllm_config, executor_class,
+                                           log_stats)
 
             return AsyncMPClient(vllm_config, executor_class, log_stats)
 
@@ -289,7 +294,8 @@ class BackgroundResources:
     circular reference back to the client object."""
 
     ctx: Union[zmq.Context]
-    local_engine_manager: Optional[CoreEngineProcManager] = None
+    local_engine_manager: Optional[Union[CoreEngineProcManager,
+                                         CoreEngineActorManager]] = None
    output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
     input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
     output_queue_task: Optional[asyncio.Task] = None
@@ -463,6 +469,9 @@ def _wait_for_engine_startup(self, output_address: str,
         poller.register(sync_input_socket, zmq.POLLIN)
         proc_manager = self.resources.local_engine_manager
         if proc_manager is not None:
+            assert isinstance(proc_manager, CoreEngineProcManager), (
+                "_wait_for_engine_startup should only be called "
+                "with CoreEngineProcManager")
             for sentinel in proc_manager.sentinels():
                 poller.register(sentinel, zmq.POLLIN)
         while any(conn_pending) or any(start_pending):
@@ -1018,3 +1027,104 @@ async def _abort_requests(self, request_ids: list[str],
         if not self.resources.engine_dead:
             await self._send_input(EngineCoreRequestType.ABORT, request_ids,
                                    engine)
+
+
+class RayDPClient(DPAsyncMPClient):
+    """
+    Ray-based client for multi-proc, multi-engine (data parallel)
+    EngineCore.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        executor_class: type[Executor],
+        log_stats: bool,
+    ):
+        self.current_wave = 0
+        self.engines_running = False
+        self.reqs_in_flight: dict[str, CoreEngine] = {}
+
+        self.vllm_config = vllm_config
+        # Serialization setup.
+        self.encoder = MsgpackEncoder()
+        self.decoder = MsgpackDecoder(EngineCoreOutputs)
+
+        # ZMQ setup.
+        sync_ctx = zmq.Context(io_threads=2)
+        self.ctx = zmq.asyncio.Context(sync_ctx)
+
+        # This will ensure resources created so far are closed
+        # when the client is garbage collected, even if an
+        # exception is raised mid-construction.
+        self.resources = BackgroundResources(ctx=sync_ctx)
+        self._finalizer = weakref.finalize(self, self.resources)
+        success = False
+        try:
+            parallel_config = vllm_config.parallel_config
+            local_engine_count = parallel_config.data_parallel_size_local
+            start_index = parallel_config.data_parallel_rank
+            local_start_index = parallel_config.data_parallel_rank_local
+
+            # SPMD mode is where there is an LLM instance per DP rank and
+            # one core engine per LLM, see
+            # examples/offline_inference/data_parallel.py.
+            spmd_mode = local_start_index is not None
+            if spmd_mode:
+                assert local_engine_count == 1
+                self.core_engines = [
+                    CoreEngine(index=local_start_index, local=True)
+                ]
+            else:
+                assert start_index == 0
+                local_start_index = 0
+                self.core_engines = [
+                    CoreEngine(index=i, local=(i < local_engine_count))
+                    for i in range(parallel_config.data_parallel_size)
+                ]
+
+            input_address, output_address = self._get_zmq_addresses(
+                parallel_config, spmd_mode)
+
+            # Create input and output sockets.
+            self.input_socket = self.resources.input_socket = make_zmq_socket(
+                self.ctx, input_address, zmq.ROUTER, bind=True)
+            self.resources.output_socket = make_zmq_socket(
+                self.ctx, output_address, zmq.constants.PULL)
+
+            # Start all engines.
+            self.resources.local_engine_manager = CoreEngineActorManager(
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=log_stats,
+                input_address=input_address,
+                output_address=output_address,
+                local_engine_count=local_engine_count,
+                start_index=start_index,
+                local_start_index=local_start_index)
+
+            self.core_engine = self.core_engines[0]
+
+            self.utility_results: dict[int, AnyFuture] = {}
+
+            # Request objects which may contain pytorch-allocated tensors
+            # that we need to keep references to until zmq is done with the
+            # underlying data.
+            self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()
+            self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
+                                                     Exception]]()
+
+            success = True
+        finally:
+            if not success:
+                self._finalizer()
+
+        try:
+            # If we are running in an asyncio event loop, start the queue task.
+            # Otherwise, it will be started lazily. If it is not started here,
+            # we could miss EXECUTOR_FAILED messages from engine core if they
+            # occur prior to any requests being sent.
+            asyncio.get_running_loop()
+            self._ensure_output_queue_task()
+        except RuntimeError:
+            pass
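Finally, a hedged sketch of reaching RayDPClient through the make_client factory touched in the first hunk of this file; the VllmConfig is assumed to come from create_engine_config() with data_parallel_size > 1 and data_parallel_backend="ray":

from vllm.config import VllmConfig
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.executor.abstract import Executor


def build_async_client(vllm_config: VllmConfig) -> EngineCoreClient:
    executor_class = Executor.get_class(vllm_config)
    # Yields RayDPClient for DP > 1 with the "ray" backend,
    # DPAsyncMPClient for "mp", and AsyncMPClient when DP == 1.
    return EngineCoreClient.make_client(
        multiprocess_mode=True,
        asyncio_mode=True,
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_stats=False,
    )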