diff --git a/vllm/executor/mp_distributed_executor.py b/vllm/executor/mp_distributed_executor.py
index d1f8c36fbbec..39d592e933bf 100644
--- a/vllm/executor/mp_distributed_executor.py
+++ b/vllm/executor/mp_distributed_executor.py
@@ -103,7 +103,7 @@ def _init_executor(self) -> None:
         # Set up signal handlers to shutdown the executor cleanly
         # sometimes gc does not work well

-        self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)
+        self.driver_worker = WorkerWrapperBase(self.vllm_config)

         all_kwargs = []
         distributed_init_method = get_distributed_init_method(
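For orientation, here is a minimal self-contained sketch (toy classes, not the vLLM API) of the convention this series moves toward: wrappers no longer store an rpc_rank and index into shared per-rank lists; the caller hands each wrapper only its own arguments.

    from typing import Any, Dict, List


    class RankIndexedWrapper:
        """Old pattern: every wrapper receives the full list and picks its entry."""

        def __init__(self, config: dict, rpc_rank: int = 0) -> None:
            self.config = config
            self.rpc_rank = rpc_rank

        def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
            self.worker_kwargs = all_kwargs[self.rpc_rank]


    class RankFreeWrapper:
        """New pattern: the caller passes only this worker's kwargs."""

        def __init__(self, config: dict) -> None:
            self.config = config

        def init_worker(self, **kwargs: Any) -> None:
            self.worker_kwargs = kwargs


    per_rank = [{"rank": r, "local_rank": r} for r in range(2)]
    old = [RankIndexedWrapper({}, rpc_rank=r) for r in range(2)]
    new = [RankFreeWrapper({}) for _ in range(2)]
    for w in old:
        w.init_worker(per_rank)        # every worker sees the whole list
    for w, kwargs in zip(new, per_rank):
        w.init_worker(**kwargs)        # each worker sees only its own entry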
diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py
index 9b0b98731e03..4c4460bbbfa5 100644
--- a/vllm/executor/ray_distributed_executor.py
+++ b/vllm/executor/ray_distributed_executor.py
@@ -5,9 +5,10 @@
 import os
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from itertools import islice, repeat
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Union)

-import cloudpickle
 import msgspec

 import vllm.envs as envs
@@ -212,8 +213,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 num_gpus=num_gpus,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
-            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                       rpc_rank=rank)
+            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
         else:
             worker = ray.remote(
                 num_cpus=0,
@@ -221,8 +221,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 resources={current_platform.ray_device_key: num_gpus},
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
-            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                       rpc_rank=rank)
+            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
             worker_metadata.append(
                 RayWorkerMetaData(worker=worker, created_rank=rank))

@@ -244,7 +243,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 # as the resource holder for the driver process.
                 self.driver_dummy_worker = worker
                 self.driver_worker = RayWorkerWrapper(
-                    vllm_config=self.vllm_config, rpc_rank=0)
+                    vllm_config=self.vllm_config)
                 worker_metadata.pop(i)
                 break

@@ -283,11 +282,6 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
         for i, item in enumerate(sorted_worker_metadata):
             item.adjusted_rank = i + start_rank
         self.workers = [item.worker for item in sorted_worker_metadata]
-        rerank_mapping = {
-            item.created_rank: item.adjusted_rank
-            for item in sorted_worker_metadata
-        }
-        self._run_workers("adjust_rank", rerank_mapping)

         # Get the set of GPU IDs used on each node.
         worker_node_and_gpu_ids = []
@@ -328,10 +322,10 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
                              " each node.")

         # Set environment variables for the driver and workers.
-        all_args_to_update_environment_variables = [{
+        all_args_to_update_environment_variables = [({
             current_platform.device_control_env_var:
             ",".join(map(str, node_gpus[node_id])),
-        } for (node_id, _) in worker_node_and_gpu_ids]
+        }, ) for (node_id, _) in worker_node_and_gpu_ids]

         # Environment variables to copy from driver to workers
         env_vars_to_copy = [
@@ -347,7 +341,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
             # TODO: refactor platform-specific env vars
             for name in env_vars_to_copy:
                 if name in os.environ:
-                    args[name] = os.environ[name]
+                    args[0][name] = os.environ[name]

         logger.info("non_carry_over_env_vars from config: %s",
                     self.non_carry_over_env_vars)
@@ -362,7 +356,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
             all_args_to_update_environment_variables)

         self._run_workers("update_environment_variables",
-                          self._get_env_vars_to_be_updated())
+                          all_args=self._get_env_vars_to_be_updated())

         if len(node_gpus) == 1:
             # in single node case, we don't need to get the IP address.
@@ -390,7 +384,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
             )
             all_kwargs.append(kwargs)
-        self._run_workers("init_worker", all_kwargs)
+        self._run_workers("init_worker", all_kwargs=all_kwargs)

         self._run_workers("init_device")
         self._run_workers("load_model",
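Why the comprehension now builds 1-tuples: with `_run_workers` gaining `all_args`, each list entry is the positional-argument tuple for one worker, so the env dict sits at `args[0]`, which is exactly what the `args[0][name]` assignment above mutates. A toy sketch of the same shape, with a made-up node/GPU layout and example env-var names:

    import os
    from typing import Dict, List, Tuple

    # Made-up topology: which GPUs live on which node, one node entry per rank.
    node_gpus = {"node-a": [0, 1], "node-b": [2, 3]}
    worker_nodes = ["node-a", "node-a", "node-b", "node-b"]

    # One positional-args tuple per worker; the env dict is the only element.
    all_args: List[Tuple[Dict[str, str]]] = [
        ({"CUDA_VISIBLE_DEVICES": ",".join(map(str, node_gpus[node]))}, )
        for node in worker_nodes
    ]

    # Copy selected driver-side env vars into every per-worker dict.  Each
    # entry is a tuple, so the dict to mutate is args[0].
    for name in ("HF_TOKEN", "VLLM_USE_V1"):   # example names only
        if name in os.environ:
            for args in all_args:
                args[0][name] = os.environ[name]

    for rank, args in enumerate(all_args):
        print(rank, args)   # e.g. 0 ({'CUDA_VISIBLE_DEVICES': '0,1'},)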
@@ -466,6 +460,8 @@ def _run_workers(
         method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
+        all_args: Optional[List[Tuple[Any, ...]]] = None,
+        all_kwargs: Optional[List[Dict[str, Any]]] = None,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
@@ -479,11 +475,6 @@ def _run_workers(
               rather than blocking on the results.
             - args/kwargs: All workers share the same args/kwargs
         """
-        if isinstance(method, str):
-            sent_method = method
-        else:
-            sent_method = cloudpickle.dumps(method)
-        del method
         if self.use_ray_spmd_worker:
             assert not async_run_tensor_parallel_workers_only, (
                 "async_run_tensor_parallel_workers_only is not supported for "
@@ -493,13 +484,27 @@ def _run_workers(
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")

+        count = len(self.workers) if not \
+            async_run_tensor_parallel_workers_only \
+            else len(self.non_driver_workers)
+
+        # If using SPMD worker, all workers are the same, so we should execute
+        # the args on all workers. Otherwise, we skip the first worker's args
+        # because those args will go to the driver worker.
+        first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
+        all_worker_args = repeat(args, count) if all_args is None \
+            else islice(all_args, first_worker_args_index, None)
+        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
+            else islice(all_kwargs, first_worker_args_index, None)
+
         # Start the ray workers first.
         ray_workers = self.workers
         if async_run_tensor_parallel_workers_only:
             ray_workers = self.non_driver_workers
         ray_worker_outputs = [
-            worker.execute_method.remote(sent_method, *args, **kwargs)
-            for worker in ray_workers
+            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
+            for (worker, worker_args, worker_kwargs
+                 ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
         ]

         if async_run_tensor_parallel_workers_only:
@@ -511,9 +516,13 @@ def _run_workers(
         # so we only explicitly execute on the driver worker if using a
         # non-SPMD worker class.
         if not self.use_ray_spmd_worker:
+            driver_args = args if all_args is None else all_args[0]
+            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
             # Start the driver worker after all the ray workers.
             driver_worker_output = [
-                self.driver_worker.execute_method(sent_method, *args, **kwargs)
+                self.driver_worker.execute_method(method, *driver_args,
+                                                  **driver_kwargs)
             ]

         # Get the results of the ray workers.
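A standalone rendering of the fan-out logic added to `_run_workers` above, using stand-in callables instead of Ray actors (the names `run_workers`, `EchoWorker`, and `use_spmd` are illustrative only): shared `*args`/`**kwargs` are repeated for every worker, while per-worker `all_args`/`all_kwargs` are sliced so that entry 0 stays with the driver in non-SPMD mode.

    from itertools import islice, repeat
    from typing import Any, Dict, List, Optional, Tuple


    class EchoWorker:
        """Stand-in for a Ray actor; just records what it was called with."""

        def __init__(self, rank: int) -> None:
            self.rank = rank

        def init_worker(self, *args: Any, **kwargs: Any) -> str:
            return f"worker {self.rank}: args={args} kwargs={kwargs}"


    def run_workers(workers: List[EchoWorker],
                    method: str,
                    *args: Any,
                    use_spmd: bool = False,
                    all_args: Optional[List[Tuple[Any, ...]]] = None,
                    all_kwargs: Optional[List[Dict[str, Any]]] = None,
                    **kwargs: Any) -> List[str]:
        count = len(workers)
        # In SPMD mode every list entry targets a remote worker; otherwise
        # entry 0 is reserved for the driver, so remote workers start at 1.
        first = 0 if use_spmd else 1
        worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, first, None)
        worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, first, None)
        return [
            getattr(w, method)(*a, **kw)
            for w, a, kw in zip(workers, worker_args, worker_kwargs)
        ]


    remote_workers = [EchoWorker(r) for r in range(1, 4)]
    per_rank_kwargs = [{"rank": r, "local_rank": r} for r in range(4)]  # 0 = driver
    print(run_workers(remote_workers, "init_worker", all_kwargs=per_rank_kwargs))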
""" - kwargs = all_kwargs[self.rpc_rank] - self.vllm_config = kwargs.get("vllm_config", None) - assert self.vllm_config is not None, ( - "vllm_config is required to initialize the worker") enable_trace_function_call_for_thread(self.vllm_config) from vllm.plugins import load_general_plugins @@ -591,13 +560,9 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: worker_extension_cls, worker_class, extended_calls) with set_current_vllm_config(self.vllm_config): # To make vLLM config available during worker initialization - self.worker = worker_class(**kwargs) + self.worker = worker_class(*args, **kwargs) assert self.worker is not None - def initialize_from_config(self, kv_cache_configs: List[Any]) -> None: - kv_cache_config = kv_cache_configs[self.rpc_rank] - self.worker.initialize_from_config(kv_cache_config) # type: ignore - def init_device(self): with set_current_vllm_config(self.vllm_config): # To make vLLM config available during device initialization