2 changes: 1 addition & 1 deletion vllm/executor/mp_distributed_executor.py
@@ -103,7 +103,7 @@ def _init_executor(self) -> None:
# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well

self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)
self.driver_worker = WorkerWrapperBase(self.vllm_config)

all_kwargs = []
distributed_init_method = get_distributed_init_method(
59 changes: 34 additions & 25 deletions vllm/executor/ray_distributed_executor.py
@@ -5,9 +5,10 @@
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)

import cloudpickle
import msgspec

import vllm.envs as envs
@@ -212,17 +213,15 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rpc_rank=rank)
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
else:
worker = ray.remote(
num_cpus=0,
num_gpus=0,
resources={current_platform.ray_device_key: num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rpc_rank=rank)
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
worker_metadata.append(
RayWorkerMetaData(worker=worker, created_rank=rank))

@@ -244,7 +243,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
vllm_config=self.vllm_config, rpc_rank=0)
vllm_config=self.vllm_config)
worker_metadata.pop(i)
break

@@ -283,11 +282,6 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
for i, item in enumerate(sorted_worker_metadata):
item.adjusted_rank = i + start_rank
self.workers = [item.worker for item in sorted_worker_metadata]
rerank_mapping = {
item.created_rank: item.adjusted_rank
for item in sorted_worker_metadata
}
self._run_workers("adjust_rank", rerank_mapping)

# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
@@ -328,10 +322,10 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
" each node.")

# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [{
all_args_to_update_environment_variables = [({
current_platform.device_control_env_var:
",".join(map(str, node_gpus[node_id])),
} for (node_id, _) in worker_node_and_gpu_ids]
}, ) for (node_id, _) in worker_node_and_gpu_ids]

# Environment variables to copy from driver to workers
env_vars_to_copy = [
@@ -347,7 +341,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
# TODO: refactor platform-specific env vars
for name in env_vars_to_copy:
if name in os.environ:
args[name] = os.environ[name]
args[0][name] = os.environ[name]

logger.info("non_carry_over_env_vars from config: %s",
self.non_carry_over_env_vars)
@@ -362,7 +356,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
all_args_to_update_environment_variables)

self._run_workers("update_environment_variables",
self._get_env_vars_to_be_updated())
all_args=self._get_env_vars_to_be_updated())

if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
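
Taken together, the three hunks above restore the `all_args` calling convention for environment-variable setup: each worker receives a tuple of positional arguments (here a 1-tuple holding one env dict), which is why the dict is now wrapped in `({...}, )` and the copy loop writes into `args[0]`. A minimal sketch of that shape, using made-up node IDs, GPU assignments, and a placeholder env-var name:

```python
import os

# Hypothetical inputs, shaped like the structures in the hunks above.
node_gpus = {"node-a": [0, 1], "node-b": [2, 3]}
worker_node_and_gpu_ids = [("node-a", [0]), ("node-a", [1]),
                           ("node-b", [2]), ("node-b", [3])]
device_control_env_var = "CUDA_VISIBLE_DEVICES"  # platform-dependent in vLLM

# One positional-argument tuple per worker; each call receives a single
# dict, so every entry is a 1-tuple wrapping that dict.
all_args = [({
    device_control_env_var: ",".join(map(str, node_gpus[node_id])),
}, ) for (node_id, _) in worker_node_and_gpu_ids]

# Copy a driver-side env var into every worker's dict, mirroring
# `args[0][name] = os.environ[name]` above ("SOME_ENV_VAR" is a placeholder).
for args in all_args:
    if "SOME_ENV_VAR" in os.environ:
        args[0]["SOME_ENV_VAR"] = os.environ["SOME_ENV_VAR"]

print(all_args[0])  # ({'CUDA_VISIBLE_DEVICES': '0,1'},)
```
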
@@ -390,7 +384,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
all_kwargs.append(kwargs)
self._run_workers("init_worker", all_kwargs)
self._run_workers("init_worker", all_kwargs=all_kwargs)

self._run_workers("init_device")
self._run_workers("load_model",
@@ -466,6 +460,8 @@ def _run_workers(
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
@@ -479,11 +475,6 @@
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
"""
if isinstance(method, str):
sent_method = method
else:
sent_method = cloudpickle.dumps(method)
del method
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for "
@@ -493,13 +484,27 @@
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")

count = len(self.workers) if not \
async_run_tensor_parallel_workers_only \
else len(self.non_driver_workers)

# If using SPMD worker, all workers are the same, so we should execute
# the args on all workers. Otherwise, we skip the first worker's args
# because those args will go to the driver worker.
first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, first_worker_args_index, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, first_worker_args_index, None)

# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(sent_method, *args, **kwargs)
for worker in ray_workers
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
]

if async_run_tensor_parallel_workers_only:
@@ -511,9 +516,13 @@
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

# Start the driver worker after all the ray workers.
driver_worker_output = [
self.driver_worker.execute_method(sent_method, *args, **kwargs)
self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
]

# Get the results of the ray workers.
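
The `_run_workers` rewrite is the heart of this file's change: when `all_args`/`all_kwargs` are omitted, every worker receives the shared `*args`/`**kwargs`; when they are provided, index 0 is reserved for the in-process driver worker (except in Ray SPMD mode) and the remote workers consume the remaining entries via `islice`. A self-contained sketch of just that selection logic, with the Ray `execute_method.remote` calls left out and a hypothetical `fan_out` helper standing in for the executor:

```python
from itertools import islice, repeat
from typing import Any, Dict, List, Optional, Tuple


def fan_out(num_workers: int,
            use_ray_spmd_worker: bool,
            args: Tuple[Any, ...] = (),
            kwargs: Optional[Dict[str, Any]] = None,
            all_args: Optional[List[Tuple[Any, ...]]] = None,
            all_kwargs: Optional[List[Dict[str, Any]]] = None):
    """Yield (worker_args, worker_kwargs) for each remote worker,
    mirroring the repeat/islice pattern in the hunk above."""
    kwargs = kwargs or {}
    # In non-SPMD mode, index 0 of all_args/all_kwargs belongs to the
    # in-process driver worker, so the remote workers start at index 1.
    first = 0 if use_ray_spmd_worker else 1
    worker_args = repeat(args, num_workers) if all_args is None \
        else islice(all_args, first, None)
    worker_kwargs = repeat(kwargs, num_workers) if all_kwargs is None \
        else islice(all_kwargs, first, None)
    return list(zip(worker_args, worker_kwargs))


# Example: 2 remote workers plus a driver, each with its own kwargs dict.
per_worker_kwargs = [{"rank": 0}, {"rank": 1}, {"rank": 2}]
print(fan_out(2, use_ray_spmd_worker=False, all_kwargs=per_worker_kwargs))
# -> [((), {'rank': 1}), ((), {'rank': 2})]; {'rank': 0} stays with the driver.
```
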
55 changes: 10 additions & 45 deletions vllm/worker/worker_base.py
@@ -495,60 +495,29 @@ class WorkerWrapperBase:
def __init__(
self,
vllm_config: VllmConfig,
rpc_rank: int = 0,
) -> None:
"""
Initialize the worker wrapper with the given vllm_config and rpc_rank.
Note: rpc_rank is the rank of the worker in the executor. In most cases,
it is also the rank of the worker in the distributed group. However,
when multiple executors work together, they can be different.
e.g. in the case of SPMD-style offline inference with TP=2,
users can launch 2 engines/executors, each with only 1 worker.
All workers have rpc_rank=0, but they have different ranks in the TP
group.
"""
self.rpc_rank = rpc_rank
self.vllm_config = vllm_config
trust_remote_code = vllm_config.model_config.trust_remote_code
self.worker: Optional[WorkerBase] = None
# do not store this `vllm_config`, `init_worker` will set the final
# one. TODO: investigate if we can remove this field in
# `WorkerWrapperBase`, `init_cached_hf_modules` should be
# unnecessary now.
if vllm_config.model_config is not None:
# it can be None in tests
trust_remote_code = vllm_config.model_config.trust_remote_code
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()

def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
"""
Adjust the rpc_rank based on the given mapping.
It is only used during the initialization of the executor,
to adjust the rpc_rank of workers after we create all workers.
"""
if self.rpc_rank in rank_mapping:
self.rpc_rank = rank_mapping[self.rpc_rank]
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()

def update_environment_variables(self, envs_list: List[Dict[str,
str]]) -> None:
envs = envs_list[self.rpc_rank]
@staticmethod
def update_environment_variables(envs: Dict[str, str]) -> None:
key = 'CUDA_VISIBLE_DEVICES'
if key in envs and key in os.environ:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
# suppress the warning in `update_environment_variables`
del os.environ[key]
update_environment_variables(envs)

def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
def init_worker(self, *args, **kwargs):
"""
Here we inject some common logic before initializing the worker.
Arguments are passed to the worker class constructor.
"""
kwargs = all_kwargs[self.rpc_rank]
self.vllm_config = kwargs.get("vllm_config", None)
assert self.vllm_config is not None, (
"vllm_config is required to initialize the worker")
enable_trace_function_call_for_thread(self.vllm_config)

from vllm.plugins import load_general_plugins
@@ -591,13 +560,9 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
worker_extension_cls, worker_class, extended_calls)
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
self.worker = worker_class(*args, **kwargs)
assert self.worker is not None

def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank]
self.worker.initialize_from_config(kv_cache_config) # type: ignore

def init_device(self):
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization
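
On the worker side, `WorkerWrapperBase` no longer tracks an `rpc_rank` (previously used to index into rank-keyed lists such as `envs_list` and `all_kwargs`): `update_environment_variables` now takes a single dict, and `init_worker` forwards its arguments straight to the worker class constructor. A stripped-down, runnable sketch of that flow, using toy classes rather than the real vLLM worker:

```python
import os
from typing import Any, Dict, Optional


class DummyWorker:
    """Stand-in for a real vLLM worker class."""

    def __init__(self, vllm_config: Any, rank: int, is_driver_worker: bool):
        self.vllm_config = vllm_config
        self.rank = rank
        self.is_driver_worker = is_driver_worker


class MiniWrapper:
    """Toy version of WorkerWrapperBase after this change."""

    def __init__(self, vllm_config: Any) -> None:
        self.vllm_config = vllm_config
        self.worker: Optional[DummyWorker] = None

    @staticmethod
    def update_environment_variables(envs: Dict[str, str]) -> None:
        # One dict per call; no envs_list[self.rpc_rank] indexing anymore.
        os.environ.update(envs)

    def init_worker(self, *args, **kwargs):
        # Arguments are passed straight through to the worker constructor.
        self.worker = DummyWorker(*args, **kwargs)


# The executor sends one kwargs dict per worker via all_kwargs; each wrapper
# only ever sees its own dict.
wrapper = MiniWrapper(vllm_config=object())
MiniWrapper.update_environment_variables({"CUDA_VISIBLE_DEVICES": "0"})
wrapper.init_worker(vllm_config=wrapper.vllm_config, rank=1,
                    is_driver_worker=False)
assert wrapper.worker is not None and wrapper.worker.rank == 1
```
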