diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 3c3da41c3abf3..eeb6eaa2165bc 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -34,10 +34,14 @@ steps: mirror_hardwares: [amd] commands: - pytest -v -s distributed/test_pynccl_library.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - label: Distributed Tests (Multiple Groups) working_dir: "/vllm-workspace/tests" diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index d06e407c73b98..5178bc5dae566 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -24,6 +24,7 @@ MODELS = [ "meta-llama/Llama-2-7b-hf", ] +DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @@ -40,19 +41,21 @@ def test_models( dtype: str, max_tokens: int, ) -> None: - enforce_eager = False + distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) - if backend_by_env_var == "FLASHINFER": - enforce_eager = True + enforce_eager = backend_by_env_var == "FLASHINFER" hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, - dtype=dtype, - tensor_parallel_size=2, - enforce_eager=enforce_eager) + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=enforce_eager, + distributed_executor_backend=distributed_executor_backend) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 209d03084c3e5..9bc2553a7ffed 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -21,6 +21,7 @@ MODELS = [ "meta-llama/Llama-2-7b-hf", ] +DISTRIBUTED_EXECUTOR_BACKEND 
= "DISTRIBUTED_EXECUTOR_BACKEND" @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -38,6 +39,8 @@ def test_models( max_tokens: int, chunked_prefill_token_size: int, ) -> None: + distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + # Add a chunked prefill config. max_num_seqs = min(chunked_prefill_token_size, 256) assert chunked_prefill_token_size != -1 @@ -55,6 +58,7 @@ def test_models( max_num_seqs=max_num_seqs, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index ba47581cb4422..53d49a8dbc813 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -40,8 +40,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): enable_lora=True, max_num_seqs=16, max_loras=4, - tensor_parallel_size=tp_size, - worker_use_ray=True) + tensor_parallel_size=tp_size) expected_lora_output = [ "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501 diff --git a/vllm/config.py b/vllm/config.py index 9098fdf336374..d457710ba64e8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -565,9 +565,7 @@ class ParallelConfig: Args: pipeline_parallel_size: Number of pipeline parallel groups. tensor_parallel_size: Number of tensor parallel groups. - worker_use_ray: Whether to use Ray for model workers. Will be set to - True if either pipeline_parallel_size or tensor_parallel_size is - greater than 1. + worker_use_ray: Deprecated, use distributed_executor_backend instead. max_parallel_loading_workers: Maximum number of multiple batches when load model sequentially. To avoid RAM OOM when using tensor parallel and large models. @@ -577,22 +575,27 @@ class ParallelConfig: If None, will use synchronous tokenization. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + distributed_executor_backend: Backend to use for distributed model + workers, either "ray" or "mp" (multiprocessing). If either + pipeline_parallel_size or tensor_parallel_size is greater than 1, + will default to "ray" if Ray is installed or "mp" otherwise. 
""" def __init__( self, pipeline_parallel_size: int, tensor_parallel_size: int, - worker_use_ray: bool, + worker_use_ray: Optional[bool] = None, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, ray_workers_use_nsight: bool = False, placement_group: Optional["PlacementGroup"] = None, + distributed_executor_backend: Optional[str] = None, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size - self.worker_use_ray = worker_use_ray + self.distributed_executor_backend = distributed_executor_backend self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce self.tokenizer_pool_config = tokenizer_pool_config @@ -600,14 +603,29 @@ def __init__( self.placement_group = placement_group self.world_size = pipeline_parallel_size * self.tensor_parallel_size - if self.world_size > 1: - self.worker_use_ray = True + if worker_use_ray: + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "ray" + elif self.distributed_executor_backend != "ray": + raise ValueError(f"worker-use-ray can't be used with " + f"distributed executor backend " + f"'{self.distributed_executor_backend}'.") + + if self.distributed_executor_backend is None and self.world_size > 1: + from vllm.executor import ray_utils + ray_found = ray_utils.ray is not None + self.distributed_executor_backend = "ray" if ray_found else "mp" + self._verify_args() def _verify_args(self) -> None: if self.pipeline_parallel_size > 1: raise NotImplementedError( "Pipeline parallelism is not supported yet.") + if self.distributed_executor_backend not in ("ray", "mp", None): + raise ValueError( + "Unrecognized distributed executor backend. Supported values " + "are 'ray' or 'mp'.") if not self.disable_custom_all_reduce and self.world_size > 1: if is_hip(): self.disable_custom_all_reduce = True @@ -619,7 +637,8 @@ def _verify_args(self) -> None: logger.info( "Disabled the custom all-reduce kernel because it is not " "supported with pipeline parallelism.") - if self.ray_workers_use_nsight and not self.worker_use_ray: + if self.ray_workers_use_nsight and ( + not self.distributed_executor_backend == "ray"): raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") @@ -931,7 +950,8 @@ def create_draft_parallel_config( pipeline_parallel_size=target_parallel_config. pipeline_parallel_size, tensor_parallel_size=target_parallel_config.tensor_parallel_size, - worker_use_ray=target_parallel_config.worker_use_ray, + distributed_executor_backend=target_parallel_config. + distributed_executor_backend, max_parallel_loading_workers=target_parallel_config. max_parallel_loading_workers, disable_custom_all_reduce=target_parallel_config. 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a64b0308966e5..a8e914d668769 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -36,6 +36,7 @@ class EngineArgs: seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False + distributed_executor_backend: Optional[str] = None pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None @@ -225,10 +226,17 @@ def add_cli_args( ' Can be overridden per request via guided_decoding_backend' ' parameter.') # Parallel arguments - parser.add_argument('--worker-use-ray', - action='store_true', - help='Use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=EngineArgs.distributed_executor_backend, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') + parser.add_argument( + '--worker-use-ray', + action='store_true', + help='Deprecated, use --distributed-executor-backend=ray.') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a31f10b7748d3..8a37bac02823a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -348,27 +348,31 @@ def from_engine_args( """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. engine_config = engine_args.create_engine_config() + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync elif engine_config.device_config.device_type == "cpu": - assert not engine_config.parallel_config.worker_use_ray, ( - "Ray is not supported with the CPU backend.") + assert distributed_executor_backend is None, ( + "Distributed execution is not supported with the CPU backend.") from vllm.executor.cpu_executor import CPUExecutorAsync executor_class = CPUExecutorAsync - elif engine_config.parallel_config.worker_use_ray: + elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutorAsync) + executor_class = MultiprocessingGPUExecutorAsync else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutorAsync executor_class = GPUExecutorAsync # Create the async LLM engine. engine = cls( - engine_config.parallel_config.worker_use_ray, + distributed_executor_backend == "ray", engine_args.engine_use_ray, **engine_config.to_dict(), executor_class=executor_class, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b8c7a16ea5bcd..4e84a19198021 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -277,6 +277,8 @@ def from_engine_args( """Creates an LLM engine from the engine arguments.""" # Create the engine configs. 
engine_config = engine_args.create_engine_config() + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. if engine_config.device_config.device_type == "neuron": @@ -285,13 +287,15 @@ def from_engine_args( elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor - elif engine_config.parallel_config.worker_use_ray: + elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor executor_class = RayGPUExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutor) + executor_class = MultiprocessingGPUExecutor else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py new file mode 100644 index 0000000000000..2a7b99c9dcbe1 --- /dev/null +++ b/vllm/executor/multiproc_gpu_executor.py @@ -0,0 +1,140 @@ +import asyncio +import os +from functools import partial +from typing import Any, Dict, Optional, Tuple + +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + get_vllm_instance_id, make_async) + +logger = init_logger(__name__) + + +class MultiprocessingGPUExecutor(DistributedGPUExecutor): + """Python multiprocessing-based multi-GPU executor""" + + def _init_executor(self) -> None: + assert ( + not self.speculative_config + ), "Speculative decoding not yet supported for MultiProcGPU backend." + + # Create the parallel GPU workers. + world_size = self.parallel_config.tensor_parallel_size + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = (",".join( + map(str, range(world_size)))) + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + + from torch.cuda import device_count + assert world_size <= device_count(), ( + "please set tensor_parallel_size to less than max local gpu count") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + if world_size == 1: + self.workers = [] + else: + result_handler = ResultHandler() + self.workers = [ + ProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + distributed_init_method=distributed_init_method, + )) for rank in range(1, world_size) + ] + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + self.driver_worker = self._create_worker( + distributed_init_method=distributed_init_method) + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. 
+ max_parallel_loading_workers) + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Start the driver worker after all the ray workers. + driver_worker_method = getattr(self.driver_worker, method) + driver_worker_output = driver_worker_method(*driver_args, + **driver_kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if not self.worker_monitor.is_alive(): + raise RuntimeError("Worker processes are not running") + + +class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor, + DistributedGPUExecutorAsync): + + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + driver_executor = make_async(getattr(self.driver_worker, method)) + + # Run all the workers asynchronously. + coros = [driver_executor(*driver_args, **driver_kwargs)] + [ + worker.execute_method_async(method, *args, **kwargs) + for worker in self.workers + ] + + return await asyncio.gather(*coros) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index afc1c886722e6..9cb03ec8c3f5a 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -31,7 +31,7 @@ def _init_executor(self) -> None: assert (not self.speculative_config ), "Speculative decoding not yet supported for RayGPU backend." - assert self.parallel_config.worker_use_ray + assert self.parallel_config.distributed_executor_backend == "ray" placement_group = self.parallel_config.placement_group # Disable Ray usage stats collection. @@ -264,7 +264,7 @@ def _compiled_ray_dag(self): f"required, but found {current_version}") from ray.dag import InputNode, MultiOutputNode - assert self.parallel_config.worker_use_ray + assert self.parallel_config.distributed_executor_backend == "ray" # Right now, compiled DAG requires at least 1 arg. We send # a dummy value for now. It will be fixed soon. diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 9db3ae2ff8298..4704f5f1b1a10 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -44,7 +44,7 @@ def execute_model_compiled_dag_remote(self, ignored): except ImportError as e: logger.warning( - "Failed to import Ray with %r. For distributed inference, " + "Failed to import Ray with %r. 
For multi-node inference, " "please install Ray with `pip install ray`.", e) ray = None # type: ignore RayWorkerWrapper = None # type: ignore @@ -67,7 +67,7 @@ def initialize_ray_cluster( """ if ray is None: raise ImportError( - "Ray is not installed. Please install Ray to use distributed " + "Ray is not installed. Please install Ray to use multi-node " "serving.") # Connect to a ray cluster.
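
Taken together, the new backend can be selected either programmatically or via the CLI flag added in arg_utils.py. A brief usage sketch, assuming two visible GPUs; the model name and token counts are placeholders, and the keyword reaches EngineArgs the same way the updated tests pass it through vllm_runner:

from vllm import LLM, SamplingParams

# Force the multiprocessing backend for a 2-GPU tensor-parallel run
# instead of letting the default fall back to Ray.
llm = LLM(model="facebook/opt-125m",
          tensor_parallel_size=2,
          distributed_executor_backend="mp")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)

# The equivalent server invocation uses the new flag:
#   python -m vllm.entrypoints.openai.api_server \
#       --model facebook/opt-125m \
#       --tensor-parallel-size 2 \
#       --distributed-executor-backend mp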