
Commit 7c418ca

[Bugfix] fix client socket timeout when serving a multi-node model in Ray
Signed-off-by: <>
Signed-off-by: yangli5t <yangli5t@users.noreply.github.com>
1 parent d2b58ca commit 7c418ca

File tree: 2 files changed, +44 -70 lines


vllm/executor/ray_distributed_executor.py

Lines changed: 34 additions & 25 deletions
@@ -5,9 +5,10 @@
 import os
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from itertools import islice, repeat
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Union)

-import cloudpickle
 import msgspec

 import vllm.envs as envs
@@ -212,17 +213,15 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                     num_gpus=num_gpus,
                     scheduling_strategy=scheduling_strategy,
                     **ray_remote_kwargs,
-                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rpc_rank=rank)
+                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
             else:
                 worker = ray.remote(
                     num_cpus=0,
                     num_gpus=0,
                     resources={current_platform.ray_device_key: num_gpus},
                     scheduling_strategy=scheduling_strategy,
                     **ray_remote_kwargs,
-                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rpc_rank=rank)
+                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
             worker_metadata.append(
                 RayWorkerMetaData(worker=worker, created_rank=rank))

@@ -244,7 +243,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 # as the resource holder for the driver process.
                 self.driver_dummy_worker = worker
                 self.driver_worker = RayWorkerWrapper(
-                    vllm_config=self.vllm_config, rpc_rank=0)
+                    vllm_config=self.vllm_config)
                 worker_metadata.pop(i)
                 break

@@ -283,11 +282,6 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
         for i, item in enumerate(sorted_worker_metadata):
             item.adjusted_rank = i + start_rank
         self.workers = [item.worker for item in sorted_worker_metadata]
-        rerank_mapping = {
-            item.created_rank: item.adjusted_rank
-            for item in sorted_worker_metadata
-        }
-        self._run_workers("adjust_rank", rerank_mapping)

         # Get the set of GPU IDs used on each node.
         worker_node_and_gpu_ids = []
@@ -328,10 +322,10 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
                 " each node.")

         # Set environment variables for the driver and workers.
-        all_args_to_update_environment_variables = [{
+        all_args_to_update_environment_variables = [({
             current_platform.device_control_env_var:
             ",".join(map(str, node_gpus[node_id])),
-        } for (node_id, _) in worker_node_and_gpu_ids]
+        }, ) for (node_id, _) in worker_node_and_gpu_ids]

         # Environment variables to copy from driver to workers
         env_vars_to_copy = [
@@ -347,7 +341,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
             # TODO: refactor platform-specific env vars
             for name in env_vars_to_copy:
                 if name in os.environ:
-                    args[name] = os.environ[name]
+                    args[0][name] = os.environ[name]

         logger.info("non_carry_over_env_vars from config: %s",
                     self.non_carry_over_env_vars)
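
Note on the two hunks above: the per-worker environment dictionaries are now wrapped in one-element tuples because the reworked _run_workers (further down in this diff) splats each entry of all_args as positional arguments, so driver-side variables copied into each entry land in args[0]. A minimal sketch of that packing, with hypothetical node/GPU data and CUDA_VISIBLE_DEVICES standing in for current_platform.device_control_env_var:

# Sketch only: node_gpus, worker_node_and_gpu_ids and the copied
# variable name are illustrative, not values taken from vLLM.
import os

node_gpus = {"node-a": [0, 1], "node-b": [0, 1]}
worker_node_and_gpu_ids = [("node-a", [0]), ("node-a", [1]),
                           ("node-b", [0]), ("node-b", [1])]

# One (env_dict,) tuple per worker, so each tuple can later be splatted
# as the positional argument list of update_environment_variables.
all_args = [({
    "CUDA_VISIBLE_DEVICES": ",".join(map(str, node_gpus[node_id])),
}, ) for (node_id, _) in worker_node_and_gpu_ids]

# Variables copied from the driver go into the dict inside each tuple,
# hence args[0][name] rather than args[name].
for args in all_args:
    for name in ("HF_TOKEN",):  # hypothetical variable to copy
        if name in os.environ:
            args[0][name] = os.environ[name]

print(all_args)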
@@ -362,7 +356,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
                      all_args_to_update_environment_variables)

         self._run_workers("update_environment_variables",
-                          self._get_env_vars_to_be_updated())
+                          all_args=self._get_env_vars_to_be_updated())

         if len(node_gpus) == 1:
             # in single node case, we don't need to get the IP address.
@@ -390,7 +384,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
             )
             all_kwargs.append(kwargs)
-        self._run_workers("init_worker", all_kwargs)
+        self._run_workers("init_worker", all_kwargs=all_kwargs)

         self._run_workers("init_device")
         self._run_workers("load_model",
@@ -466,6 +460,8 @@ def _run_workers(
         method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
+        all_args: Optional[List[Tuple[Any, ...]]] = None,
+        all_kwargs: Optional[List[Dict[str, Any]]] = None,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
@@ -479,11 +475,6 @@ def _run_workers(
               rather than blocking on the results.
             - args/kwargs: All workers share the same args/kwargs
         """
-        if isinstance(method, str):
-            sent_method = method
-        else:
-            sent_method = cloudpickle.dumps(method)
-        del method
         if self.use_ray_spmd_worker:
             assert not async_run_tensor_parallel_workers_only, (
                 "async_run_tensor_parallel_workers_only is not supported for "
@@ -493,13 +484,27 @@ def _run_workers(
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")

+        count = len(self.workers) if not \
+            async_run_tensor_parallel_workers_only \
+            else len(self.non_driver_workers)
+
+        # If using SPMD worker, all workers are the same, so we should execute
+        # the args on all workers. Otherwise, we skip the first worker's args
+        # because those args will go to the driver worker.
+        first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
+        all_worker_args = repeat(args, count) if all_args is None \
+            else islice(all_args, first_worker_args_index, None)
+        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
+            else islice(all_kwargs, first_worker_args_index, None)
+
         # Start the ray workers first.
         ray_workers = self.workers
         if async_run_tensor_parallel_workers_only:
             ray_workers = self.non_driver_workers
         ray_worker_outputs = [
-            worker.execute_method.remote(sent_method, *args, **kwargs)
-            for worker in ray_workers
+            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
+            for (worker, worker_args, worker_kwargs
+                 ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
         ]

         if async_run_tensor_parallel_workers_only:
@@ -511,9 +516,13 @@ def _run_workers(
         # so we only explicitly execute on the driver worker if using a
         # non-SPMD worker class.
         if not self.use_ray_spmd_worker:
+            driver_args = args if all_args is None else all_args[0]
+            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
             # Start the driver worker after all the ray workers.
             driver_worker_output = [
-                self.driver_worker.execute_method(sent_method, *args, **kwargs)
+                self.driver_worker.execute_method(method, *driver_args,
+                                                  **driver_kwargs)
             ]

         # Get the results of the ray workers.
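
The heart of the executor change is this reworked _run_workers: instead of cloudpickling the method and relying on each wrapper's rpc_rank to pick its slice of all_args/all_kwargs, the executor now slices per-worker arguments itself with itertools.repeat and islice. The sketch below reproduces only that fan-out with plain callables standing in for Ray actors; run_workers, remote_workers and the sample data are assumptions for illustration, not vLLM APIs.

# Sketch of the all_args/all_kwargs fan-out, assuming four logical ranks
# where rank 0 is the in-process driver worker (non-SPMD mode).
from itertools import islice, repeat

def run_workers(method, *args, all_args=None, all_kwargs=None,
                remote_workers=(), use_spmd=False, **kwargs):
    count = len(remote_workers)
    # SPMD: every entry feeds a remote worker. Non-SPMD: entry 0 is
    # reserved for the driver, remote workers start at entry 1.
    first = 0 if use_spmd else 1
    worker_args = repeat(args, count) if all_args is None \
        else islice(all_args, first, None)
    worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
        else islice(all_kwargs, first, None)

    outputs = [method(w, *a, **kw)
               for w, a, kw in zip(remote_workers, worker_args, worker_kwargs)]

    if not use_spmd:
        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
        outputs.insert(0, method("driver", *driver_args, **driver_kwargs))
    return outputs

# Per-rank env dicts, packed as one-element tuples like the executor does.
envs = [({"RANK": str(i)}, ) for i in range(4)]
print(run_workers(lambda who, env: (who, env),
                  all_args=envs, remote_workers=["w1", "w2", "w3"]))

Because each Ray worker only ever receives its own argument tuple and the driver keeps entry 0, no worker needs to know its rank up front, which is what allows the rpc_rank plumbing in worker_base.py below to be removed.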

vllm/worker/worker_base.py

Lines changed: 10 additions & 45 deletions
@@ -495,60 +495,29 @@ class WorkerWrapperBase:
     def __init__(
         self,
         vllm_config: VllmConfig,
-        rpc_rank: int = 0,
     ) -> None:
-        """
-        Initialize the worker wrapper with the given vllm_config and rpc_rank.
-        Note: rpc_rank is the rank of the worker in the executor. In most cases,
-        it is also the rank of the worker in the distributed group. However,
-        when multiple executors work together, they can be different.
-        e.g. in the case of SPMD-style offline inference with TP=2,
-        users can launch 2 engines/executors, each with only 1 worker.
-        All workers have rpc_rank=0, but they have different ranks in the TP
-        group.
-        """
-        self.rpc_rank = rpc_rank
+        self.vllm_config = vllm_config
+        trust_remote_code = vllm_config.model_config.trust_remote_code
         self.worker: Optional[WorkerBase] = None
-        # do not store this `vllm_config`, `init_worker` will set the final
-        # one. TODO: investigate if we can remove this field in
-        # `WorkerWrapperBase`, `init_cached_hf_modules` should be
-        # unnecessary now.
-        if vllm_config.model_config is not None:
-            # it can be None in tests
-            trust_remote_code = vllm_config.model_config.trust_remote_code
-            if trust_remote_code:
-                # note: lazy import to avoid importing torch before initializing
-                from vllm.utils import init_cached_hf_modules
-                init_cached_hf_modules()
-
-    def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
-        """
-        Adjust the rpc_rank based on the given mapping.
-        It is only used during the initialization of the executor,
-        to adjust the rpc_rank of workers after we create all workers.
-        """
-        if self.rpc_rank in rank_mapping:
-            self.rpc_rank = rank_mapping[self.rpc_rank]
+        if trust_remote_code:
+            # note: lazy import to avoid importing torch before initializing
+            from vllm.utils import init_cached_hf_modules
+            init_cached_hf_modules()

-    def update_environment_variables(self, envs_list: List[Dict[str,
-                                                                 str]]) -> None:
-        envs = envs_list[self.rpc_rank]
+    @staticmethod
+    def update_environment_variables(envs: Dict[str, str]) -> None:
         key = 'CUDA_VISIBLE_DEVICES'
         if key in envs and key in os.environ:
             # overwriting CUDA_VISIBLE_DEVICES is desired behavior
             # suppress the warning in `update_environment_variables`
             del os.environ[key]
         update_environment_variables(envs)

-    def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
+    def init_worker(self, *args, **kwargs):
         """
         Here we inject some common logic before initializing the worker.
         Arguments are passed to the worker class constructor.
         """
-        kwargs = all_kwargs[self.rpc_rank]
-        self.vllm_config = kwargs.get("vllm_config", None)
-        assert self.vllm_config is not None, (
-            "vllm_config is required to initialize the worker")
         enable_trace_function_call_for_thread(self.vllm_config)

         from vllm.plugins import load_general_plugins
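
On the worker side, update_environment_variables becomes a static method that receives a single dict instead of indexing a per-rank list, so it no longer needs rpc_rank at all. A standalone sketch of the same behavior (update_env and the sample values are illustrative; in vLLM the final write goes through its own update_environment_variables helper):

# Sketch of the single-dict path: drop an existing CUDA_VISIBLE_DEVICES
# before applying the worker's dict so no overwrite warning is triggered.
import os

def update_env(envs: dict) -> None:
    key = "CUDA_VISIBLE_DEVICES"
    if key in envs and key in os.environ:
        # Overwriting CUDA_VISIBLE_DEVICES is the desired behavior here.
        del os.environ[key]
    os.environ.update(envs)  # stand-in for vLLM's helper

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
update_env({"CUDA_VISIBLE_DEVICES": "2,3"})  # hypothetical per-worker dict
print(os.environ["CUDA_VISIBLE_DEVICES"])    # -> 2,3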
@@ -591,13 +560,9 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
             worker_extension_cls, worker_class, extended_calls)
         with set_current_vllm_config(self.vllm_config):
             # To make vLLM config available during worker initialization
-            self.worker = worker_class(**kwargs)
+            self.worker = worker_class(*args, **kwargs)
         assert self.worker is not None

-    def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
-        kv_cache_config = kv_cache_configs[self.rpc_rank]
-        self.worker.initialize_from_config(kv_cache_config)  # type: ignore
-
     def init_device(self):
         with set_current_vllm_config(self.vllm_config):
             # To make vLLM config available during device initialization
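
Taken together, the wrapper now keeps only vllm_config and the worker instance, and init_worker forwards whatever the executor sends straight into the worker class constructor. A condensed sketch of that delegation (WrapperSketch and WorkerStub are illustrative stand-ins, not vLLM classes):

# Sketch of the simplified wrapper: no rpc_rank, no per-rank indexing;
# init_worker just passes *args/**kwargs through to the worker class.
from typing import Any, Optional

class WorkerStub:  # stand-in for the resolved worker class
    def __init__(self, vllm_config: Any, local_rank: int = 0) -> None:
        self.vllm_config = vllm_config
        self.local_rank = local_rank

class WrapperSketch:
    def __init__(self, vllm_config: Any) -> None:
        self.vllm_config = vllm_config
        self.worker: Optional[WorkerStub] = None

    def init_worker(self, *args, **kwargs) -> None:
        # Arguments are passed to the worker class constructor as-is.
        self.worker = WorkerStub(*args, **kwargs)

w = WrapperSketch(vllm_config={"model": "dummy"})
w.init_worker(vllm_config={"model": "dummy"}, local_rank=1)
print(w.worker.local_rank)  # -> 1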
