diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 66efe3ed3298..d96f0183bc67 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -134,7 +134,9 @@ steps:
   - tests/compile/test_basic_correctness
   - examples/offline_inference/rlhf.py
   - examples/offline_inference/rlhf_colocate.py
+  - examples/offline_inference/data_parallel.py
   commands:
+  - VLLM_USE_V1=1 python3 ../examples/offline_inference/data_parallel.py
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py
diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py
new file mode 100644
index 000000000000..a9544c8cf8a8
--- /dev/null
+++ b/examples/offline_inference/data_parallel.py
@@ -0,0 +1,76 @@
+# SPDX-License-Identifier: Apache-2.0
+# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
+# We need a launcher to create multiple data parallel ranks,
+# and each rank will create a vLLM instance to process its own prompts.
+import os
+
+from vllm import LLM, SamplingParams
+from vllm.utils import get_open_port
+
+
+def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
+    os.environ["VLLM_DP_RANK"] = str(dp_rank)
+    os.environ["VLLM_DP_SIZE"] = str(dp_size)
+    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
+    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
+    # set devices for each dp_rank
+    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
+        str(i) for i in range(dp_rank * GPUs_per_dp_rank, (dp_rank + 1) *
+                              GPUs_per_dp_rank))
+
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # With DP, each rank should process different prompts.
+    # Usually all the DP ranks process a full dataset,
+    # and each rank processes a different part of the dataset.
+    prompts_per_rank = len(prompts) // dp_size
+    start = dp_rank * prompts_per_rank
+    end = start + prompts_per_rank
+    prompts = prompts[start:end]
+    if len(prompts) == 0:
+        # If any rank has no prompts to process,
+        # we need to set a placeholder prompt.
+        prompts = ["Placeholder"]
+    print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")
+
+    # Create a sampling params object.
+    # Since we are doing data parallel, every rank can have different
+    # sampling params. Here we set different max_tokens for different
+    # ranks for demonstration.
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     max_tokens=16 * (dp_rank + 1))
+
+    # Create an LLM.
+    llm = LLM(model="facebook/opt-125m",
+              tensor_parallel_size=GPUs_per_dp_rank,
+              enforce_eager=True)
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(
+            f"DP rank {dp_rank}, Prompt: {prompt!r}, "
+            f"Generated text: {generated_text!r}")
+
+
+if __name__ == "__main__":
+    from multiprocessing import Process
+    dp_size = 2
+    GPUs_per_dp_rank = 2
+    dp_master_ip = "127.0.0.1"
+    dp_master_port = get_open_port()
+    procs = []
+    for i in range(dp_size):
+        proc = Process(target=main,
+                       args=(dp_size, i, dp_master_ip, dp_master_port,
+                             GPUs_per_dp_rank))
+        proc.start()
+        procs.append(proc)
+    for proc in procs:
+        proc.join()
diff --git a/vllm/config.py b/vllm/config.py
index d6e197fe988a..27abdaf9a828 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -16,6 +16,7 @@

 import torch
 from pydantic import BaseModel, Field, PrivateAttr
+from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig

 import vllm.envs as envs
@@ -1290,6 +1291,11 @@ class ParallelConfig:

     pipeline_parallel_size: int = 1  # Number of pipeline parallel groups.
     tensor_parallel_size: int = 1  # Number of tensor parallel groups.
+    data_parallel_size: int = 1  # Number of data parallel groups.
+    data_parallel_rank: int = 0  # Rank of the data parallel group.
+    # IP of the data parallel master.
+    data_parallel_master_ip: str = "127.0.0.1"
+    data_parallel_master_port: int = 29500  # Port of the data parallel master.

     # Maximum number of multiple batches
     # when load model sequentially. To avoid RAM OOM when using tensor
@@ -1323,10 +1329,55 @@ class ParallelConfig:
     worker_cls: str = "auto"
     sd_worker_cls: str = "auto"

+    # world_size is TPxPP, it affects the number of workers we create.
     world_size: int = field(init=False)
+    # world_size_across_dp is TPxPPxDP, it is the size of the world
+    # including data parallelism.
+    world_size_across_dp: int = field(init=False)

     rank: int = 0

+    def get_next_dp_init_port(self) -> int:
+        """
+        We might need to initialize process groups related to data
+        parallelism in multiple places, e.g. both in the worker and in
+        the engine, which can live in different processes. To avoid port
+        conflicts, we increment the port number each time we need to
+        initialize a new process group related to data parallelism.
+        """
+        answer = self.data_parallel_master_port
+        self.data_parallel_master_port += 1
+        return answer
+
+    def stateless_init_dp_group(self) -> "ProcessGroup":
+        from vllm.distributed.utils import (
+            stateless_init_torch_distributed_process_group)
+
+        # use gloo since the engine process might not have cuda device
+        dp_group = stateless_init_torch_distributed_process_group(
+            self.data_parallel_master_ip,
+            self.get_next_dp_init_port(),
+            self.data_parallel_rank,
+            self.data_parallel_size,
+            backend="gloo")
+
+        return dp_group
+
+    @staticmethod
+    def has_unfinished_dp(dp_group: "ProcessGroup",
+                          has_unfinished: bool) -> bool:
+        tensor = torch.tensor([has_unfinished],
+                              dtype=torch.int32,
+                              device="cpu")
+        # dp rank 0: has_unfinished_seqs=True
+        # dp rank 1: has_unfinished_seqs=False
+        # aggregated: has_unfinished_seqs=True
+        # so this is an OR operation, i.e. MAX in integers
+        torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
+        aggregated_has_unfinished = bool(tensor.item())
+        return aggregated_has_unfinished
+
     def compute_hash(self):
         """
         Provide a hash that uniquely identifies all the configs
@@ -1344,6 +1395,12 @@ def __post_init__(self) -> None:
         self.world_size = self.pipeline_parallel_size * \
             self.tensor_parallel_size

+        self.data_parallel_size = envs.VLLM_DP_SIZE
+        self.data_parallel_rank = envs.VLLM_DP_RANK
+        self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
+        self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
+        self.world_size_across_dp = self.world_size * self.data_parallel_size
+
         ray_only_devices = ["tpu"]
         from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index f806f8b39ef9..07c9ff506092 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -16,8 +16,8 @@ def __init__(self,
                  device_group: Optional[ProcessGroup] = None,
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)
-        if "pp" in unique_name:
-            # pipeline parallel does not need custom allreduce
+        if "tp" not in unique_name:
+            # only tp uses custom allreduce
             use_custom_allreduce = False
         else:
             from vllm.distributed.parallel_state import (
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index a2614ed5d0bd..90f7f2d0f982 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -87,6 +87,7 @@ def __init__(self,
             return

         rank = dist.get_rank(group=self.group)
+        self.rank = rank
         world_size = dist.get_world_size(group=self.group)
         if world_size == 1:
             # No need to initialize custom allreduce for single GPU case.
@@ -201,8 +202,10 @@ def create_shared_buffer(

     @staticmethod
     def free_shared_buffer(pointers: List[int],
-                           group: Optional[ProcessGroup] = None) -> None:
-        rank = dist.get_rank(group=group)
+                           group: Optional[ProcessGroup] = None,
+                           rank: Optional[int] = None) -> None:
+        if rank is None:
+            rank = dist.get_rank(group=group)
         lib = CudaRTLibrary()
         lib.cudaFree(ctypes.c_void_p(pointers[rank]))

@@ -298,8 +301,8 @@ def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
             self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs)
-            self.free_shared_buffer(self.buffer_ptrs)
+            self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
+            self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)

     def __del__(self):
         self.close()
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 781f870a756c..83484cd73550 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -750,6 +750,13 @@ def get_tp_group() -> GroupCoordinator:

 _PP: Optional[GroupCoordinator] = None

+_DP: Optional[GroupCoordinator] = None
+
+
+def get_dp_group() -> GroupCoordinator:
+    assert _DP is not None, ("data parallel group is not initialized")
+    return _DP
+

 def get_pp_group() -> GroupCoordinator:
     assert _PP is not None, (
@@ -811,6 +818,21 @@ def init_distributed_environment(
         "world_size=%d rank=%d local_rank=%d "
         "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
         distributed_init_method, backend)
+    from vllm.config import get_current_vllm_config
+    config = get_current_vllm_config()
+    if config is not None and config.parallel_config.data_parallel_size > 1:
+        parallel_config = config.parallel_config
+        # adjust to take into account data parallelism:
+        # offset the rank by the data parallel rank
+        rank = parallel_config.data_parallel_rank * world_size + rank
+        # adjust the world size to take into account data parallelism
+        world_size = parallel_config.world_size_across_dp
+        ip = parallel_config.data_parallel_master_ip
+        port = parallel_config.get_next_dp_init_port()
+        distributed_init_method = f"tcp://{ip}:{port}"  # noqa
+        logger.info(
+            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
+            world_size, rank, distributed_init_method)
     if not torch.distributed.is_initialized():
         assert distributed_init_method is not None, (
             "distributed_init_method must be provided when initializing "
@@ -870,20 +892,28 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
+    rank = torch.distributed.get_rank()
     backend = backend or torch.distributed.get_backend(
         get_world_group().device_group)

+    data_parallel_size = 1
+    from vllm.config import get_current_vllm_config
+    config = get_current_vllm_config()
+    if config is not None:
+        data_parallel_size = config.parallel_config.data_parallel_size
+
+    # the layout order is: DP x PP x TP
+    # to get group_ranks for each dimension, transpose that dimension to the
+    # last dimension, then reshape to 2D, then unbind the last dimension
+    all_ranks = torch.arange(world_size).reshape(
+        data_parallel_size, pipeline_model_parallel_size,
+        tensor_model_parallel_size)  # noqa
+
     # Build the tensor model-parallel groups.
-    num_tensor_model_parallel_groups: int = (world_size //
-                                             tensor_model_parallel_size)
     global _TP
     assert _TP is None, ("tensor model parallel group is already initialized")
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        ranks = list(
-            range(i * tensor_model_parallel_size,
-                  (i + 1) * tensor_model_parallel_size))
-        group_ranks.append(ranks)
+    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]

     # message queue broadcaster is only used in tensor model parallel group
     _TP = init_model_parallel_group(group_ranks,
@@ -893,20 +923,33 @@ def initialize_model_parallel(
                                     group_name="tp")

     # Build the pipeline model-parallel groups.
-    num_pipeline_model_parallel_groups: int = (world_size //
-                                               pipeline_model_parallel_size)
     global _PP
     assert _PP is None, (
         "pipeline model parallel group is already initialized")
-    group_ranks = []
-    for i in range(num_pipeline_model_parallel_groups):
-        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
-        group_ranks.append(ranks)
+    group_ranks = all_ranks.transpose(1, 2).reshape(
+        -1, pipeline_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="pp")

+    global _DP
+    assert _DP is None, ("data parallel group is already initialized")
+    group_ranks = all_ranks.transpose(0,
+                                      2).reshape(-1,
+                                                 data_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _DP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="dp")
+
+    logger.info(
+        "rank %s in world size %s is assigned as "
+        "DP rank %s, PP rank %s, TP rank %s", rank, world_size,
+        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
+

 def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
     """
@@ -1011,6 +1054,11 @@ def destroy_model_parallel():
         _PP.destroy()
     _PP = None

+    global _DP
+    if _DP:
+        _DP.destroy()
+    _DP = None
+

 def destroy_distributed_environment():
     global _WORLD
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 84f8c0a8e51c..79f9a84b476f 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -11,7 +11,11 @@
 from typing import Any, Deque, Dict, Optional, Sequence, Tuple

 import torch
-from torch.distributed import TCPStore
+from torch.distributed import ProcessGroup, TCPStore
+from torch.distributed.distributed_c10d import (Backend, PrefixStore,
+                                                _get_default_timeout,
+                                                is_nccl_available)
+from torch.distributed.rendezvous import rendezvous

 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -227,3 +231,88 @@ def create(
             world_size=world_size,
             store=store,
             data_expiration_seconds=data_expiration_seconds)
+
+
+def stateless_init_torch_distributed_process_group(
+        host: str, port: int, rank: int, world_size: int,
+        backend: str) -> ProcessGroup:
+    """
+    A replacement for `torch.distributed.init_process_group` that does not
+    pollute the global state. The created ProcessGroup object can be used for
+    some operations such as `allreduce`, because it does not depend on the
+    global rank. However, some operations such as `broadcast` cannot be used
+    because they depend on the global rank.
+
+    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.
+
+    This function is useful when we are not sure about the total number of
+    processes in the process group.
+    For example, we may have processes
+    1, 2, ..., 8 who want to communicate, and process 9 might be the same
+    process as process 1, or it might be a different process; process 10
+    might be the same process as process 5, or it might be a different process.
+    In this case, how can we reliably form a communication channel between
+    processes 9 and 10, without affecting the communication channel among
+    processes 1, 2, ..., 8?
+
+    One possible solution is to figure out beforehand whether processes 9 and
+    10 are the same as processes 1 and 5, and then form a communication
+    channel based on that information, adjusting the ranks and world_size etc.
+    However, figuring out that information is not always easy, and it will
+    interfere with the main communication channel.
+
+    Our solution is to always form a communication channel among processes 1,
+    2, ..., 8, and then use this function to form another communication
+    channel between processes 9 and 10. This way, regardless of whether
+    processes 9 and 10 are the same as processes 1 and 5, the main
+    communication channel is always formed among processes 1, 2, ..., 8, and
+    the additional communication channel is formed between processes 9 and 10.
+    """
+    init_method = f"tcp://{host}:{port}"
+    backend = Backend(backend)  # it is basically string
+    timeout = _get_default_timeout(backend)
+
+    store, rank, world_size = next(
+        rendezvous(init_method, rank, world_size, timeout=timeout))
+    store.set_timeout(timeout)
+
+    group_rank = rank
+    group_size = world_size
+
+    # Use a PrefixStore to avoid accidental overrides of keys used by
+    # different systems (e.g. RPC) in case the store is multi-tenant.
+    prefix_store = PrefixStore(init_method, store)
+
+    pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
+
+    pg: ProcessGroup = ProcessGroup(
+        prefix_store,
+        group_rank,
+        group_size,
+        pg_options,
+    )
+
+    if backend == "gloo":
+        from torch.distributed.distributed_c10d import ProcessGroupGloo
+        backend_class = ProcessGroupGloo(prefix_store,
+                                         group_rank,
+                                         group_size,
+                                         timeout=timeout)
+        backend_type = ProcessGroup.BackendType.GLOO
+        device = torch.device("cpu")
+    elif backend == "nccl":
+        assert is_nccl_available()
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+
+    return pg
diff --git a/vllm/envs.py b/vllm/envs.py
index 45547416314f..1eb9b9f1bbf5 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -90,6 +90,10 @@
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None
     VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
+    VLLM_DP_RANK: int = 0
+    VLLM_DP_SIZE: int = 1
+    VLLM_DP_MASTER_IP: str = ""
+    VLLM_DP_MASTER_PORT: int = 0


 def get_default_cache_root():
@@ -593,6 +597,22 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH":
     lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
     ("1", "true"),
+
+    # Rank of the process in the data parallel setting
+    "VLLM_DP_RANK":
+    lambda: int(os.getenv("VLLM_DP_RANK", "0")),
+
+    # World size of the data parallel setting
+    "VLLM_DP_SIZE":
+    lambda: int(os.getenv("VLLM_DP_SIZE", "1")),
+
+    # IP address of the master node in the data parallel setting
+    "VLLM_DP_MASTER_IP":
+    lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"),
+
+    # Port of the master node in the data parallel setting
+    "VLLM_DP_MASTER_PORT":
+    lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),
 }

 # end-env-vars-definition
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 10de8bc593ab..b91816af1b6d 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -4,9 +4,10 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import torch
+import torch.distributed as dist

 import vllm.envs as envs
 from vllm.config import VllmConfig
@@ -32,6 +33,8 @@ class ForwardContext:
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
+    num_tokens_across_dp: Optional[
+        List[int]] = None  # set dynamically for each forward pass


 _forward_context: Optional[ForwardContext] = None
@@ -48,7 +51,8 @@ def get_forward_context() -> ForwardContext:
 @contextmanager
 def set_forward_context(attn_metadata: Any,
                         vllm_config: VllmConfig,
-                        virtual_engine: int = 0):
+                        virtual_engine: int = 0,
+                        num_tokens: int = 0):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -57,12 +61,36 @@ def set_forward_context(attn_metadata: Any,
     need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
+    num_tokens_across_dp = None
+    if vllm_config.parallel_config.data_parallel_size > 1:
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        dp_rank = vllm_config.parallel_config.data_parallel_rank
+        if attn_metadata is not None:
+            if hasattr(attn_metadata, "num_prefill_tokens"):
+                # for v0 attention backends
+                batchsize = attn_metadata.num_prefill_tokens + \
+                    attn_metadata.num_decode_tokens
+            else:
+                # for v1 attention backends
+                batchsize = attn_metadata.num_input_tokens
+        else:
+            batchsize = num_tokens
+        num_tokens_across_dp = [0] * dp_size
+        num_tokens_across_dp[dp_rank] = batchsize
+        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
+                                         device="cpu",
+                                         dtype=torch.int32)
+        from vllm.distributed.parallel_state import get_dp_group
+        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
+        num_tokens_across_dp = num_tokens_tensor.tolist()
+
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
         attn_layers=vllm_config.compilation_config.static_forward_context,
         virtual_engine=virtual_engine,
-        attn_metadata=attn_metadata)
+        attn_metadata=attn_metadata,
+        num_tokens_across_dp=num_tokens_across_dp)
     try:
         yield
     finally:
diff --git a/vllm/utils.py b/vllm/utils.py
index 4d3f90c95a7d..640260411267 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -501,6 +501,24 @@ def get_open_zmq_ipc_path() -> str:


 def get_open_port() -> int:
+    """
+    Get an open port for the vLLM process to listen on.
+    An edge case to handle is when we run data parallel: we need to avoid
+    ports that are potentially used by the data parallel master process.
+    Right now we reserve 10 ports for the data parallel master process,
+    which currently uses 2 of them.
+    """
+    if "VLLM_DP_MASTER_PORT" in os.environ:
+        dp_port = envs.VLLM_DP_MASTER_PORT
+        while True:
+            port = _get_open_port()
+            if port >= dp_port and port < dp_port + 10:
+                continue
+            return port
+    return _get_open_port()
+
+
+def _get_open_port() -> int:
     port = envs.VLLM_PORT
     if port is not None:
         while True:
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 03825d6ea430..981d23237e2a 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -219,6 +219,9 @@ def sleep(self, level: int = 1):
     def wake_up(self):
         self.model_executor.wake_up()

+    def execute_dummy_batch(self):
+        self.model_executor.collective_rpc("execute_dummy_batch")
+
     def add_lora(self, lora_request: LoRARequest) -> None:
         self.model_executor.add_lora(lora_request)
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 43ba7583c662..e898a872c62b 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -87,6 +87,12 @@ def sleep(self, level: int = 1) -> None:
     def wake_up(self) -> None:
         raise NotImplementedError

+    def execute_dummy_batch(self) -> None:
+        raise NotImplementedError
+
+    async def execute_dummy_batch_async(self) -> None:
+        raise NotImplementedError
+
     def abort_requests(self, request_ids: List[str]) -> None:
         raise NotImplementedError
@@ -156,6 +162,9 @@ def sleep(self, level: int = 1) -> None:
     def wake_up(self) -> None:
         self.engine_core.wake_up()

+    def execute_dummy_batch(self) -> None:
+        self.engine_core.execute_dummy_batch()
+
     def add_lora(self, lora_request: LoRARequest) -> None:
         self.engine_core.add_lora(lora_request)
@@ -331,6 +340,8 @@ def sleep(self, level: int = 1) -> None:
     def wake_up(self) -> None:
         self._call_utility("wake_up")

+    def execute_dummy_batch(self) -> None:
+        self._call_utility("execute_dummy_batch")
+

 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -414,5 +425,8 @@ async def sleep_async(self, level: int = 1) -> None:
     async def wake_up_async(self) -> None:
         await self._call_utility_async("wake_up")

+    async def execute_dummy_batch_async(self) -> None:
+        await self._call_utility_async("execute_dummy_batch")
+
     async def add_lora_async(self, lora_request: LoRARequest) -> None:
         await self._call_utility_async("add_lora", lora_request)
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 6b7de4deed39..04c7ee109e0b 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -4,7 +4,7 @@

 from typing_extensions import TypeVar

-from vllm.config import VllmConfig
+from vllm.config import ParallelConfig, VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
@@ -47,6 +47,13 @@ def __init__(
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config

+        # important: init dp group before init the engine_core
+        self.parallel_config = vllm_config.parallel_config
+        self.dp_enabled = self.parallel_config.data_parallel_size > 1  # noqa
+        self.should_execute_dummy_batch = False
+        if self.dp_enabled:
+            self.dp_group = self.parallel_config.stateless_init_dp_group()
+
         # Tokenizer (+ ensure liveness if running in another process).
         self.tokenizer = init_tokenizer_from_configs(
             model_config=vllm_config.model_config,
@@ -106,7 +113,17 @@ def get_num_unfinished_requests(self) -> int:
         return self.output_processor.get_num_unfinished_requests()

     def has_unfinished_requests(self) -> bool:
-        return self.output_processor.has_unfinished_requests()
+        has_unfinished = self.output_processor.has_unfinished_requests()
+        if not self.dp_enabled:
+            return has_unfinished
+        return self.has_unfinished_requests_dp(has_unfinished)
+
+    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
+        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
+            self.dp_group, has_unfinished)
+        if not has_unfinished and aggregated_has_unfinished:
+            self.should_execute_dummy_batch = True
+        return aggregated_has_unfinished

     @classmethod
     def validate_outputs(cls, outputs, output_type):
@@ -145,6 +162,11 @@ def add_request(

     def step(self) -> List[RequestOutput]:

+        if self.should_execute_dummy_batch:
+            self.should_execute_dummy_batch = False
+            self.engine_core.execute_dummy_batch()
+            return []
+
         # 1) Get EngineCoreOutput from the EngineCore.
         outputs = self.engine_core.get_output()
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index e3f07172d8cd..14492f273ed3 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -239,7 +239,7 @@ def __init__(
             ready_socket.send_string(WorkerProc.READY_STR)
             ready_socket.send(payload)

-        self.worker.init_device()
+        wrapper.init_device()
         self.worker.load_model()

     @staticmethod
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 31fe095a91bc..e215cbae6f02 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1164,7 +1164,7 @@ def _dummy_run(
                 for k, v in self.intermediate_tensors.items()
             })

-        with set_forward_context(None, self.vllm_config):
+        with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
             hidden_states = model(
                 input_ids=input_ids,
                 positions=positions,
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 10154a752393..ece0fa555342 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -235,6 +235,9 @@ def profile(self, is_start: bool = True):
         else:
             self.profiler.stop()

+    def execute_dummy_batch(self) -> None:
+        self.model_runner._dummy_run(1)
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 190429074d56..44c26ed350a8 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -567,6 +567,11 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
         self.worker = worker_class(**kwargs)
         assert self.worker is not None

+    def init_device(self):
+        with set_current_vllm_config(self.vllm_config):
+            # To make vLLM config available during device initialization
+            self.worker.init_device()  # type: ignore
+
     def execute_method(self, method: Union[str, bytes], *args, **kwargs):
         try:
             target = self if self.worker is None else self.worker
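
To see how the new DP x PP x TP layout in initialize_model_parallel partitions ranks, the following is a minimal standalone sketch of the same reshape/transpose/unbind logic, not part of the patch above; the sizes (DP=2, PP=2, TP=2, world size 8) are illustrative.

import torch

dp, pp, tp = 2, 2, 2
world_size = dp * pp * tp
# layout order is DP x PP x TP, matching initialize_model_parallel above
all_ranks = torch.arange(world_size).reshape(dp, pp, tp)

# transpose the target dimension to the end, flatten, unbind into groups
tp_groups = [x.tolist() for x in all_ranks.view(-1, tp).unbind(0)]
pp_groups = [
    x.tolist() for x in all_ranks.transpose(1, 2).reshape(-1, pp).unbind(0)
]
dp_groups = [
    x.tolist() for x in all_ranks.transpose(0, 2).reshape(-1, dp).unbind(0)
]

print("TP groups:", tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print("PP groups:", pp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print("DP groups:", dp_groups)  # [[0, 4], [2, 6], [1, 5], [3, 7]]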
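
The rank and world-size adjustment in init_distributed_environment is plain arithmetic: each DP replica keeps its local TPxPP ranks and offsets them by data_parallel_rank * world_size. A small worked example, with illustrative sizes (TP=2, PP=1, DP=2), not taken from the patch:

# per-replica world size is TP x PP = 2; total world size across DP is 4
data_parallel_rank = 1
world_size = 2
world_size_across_dp = 4

for local_rank in range(world_size):
    # same offset formula as init_distributed_environment above
    global_rank = data_parallel_rank * world_size + local_rank
    print(f"replica-local rank {local_rank} -> "
          f"global rank {global_rank} of {world_size_across_dp}")
# replica-local rank 0 -> global rank 2 of 4
# replica-local rank 1 -> global rank 3 of 4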
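
On the engine side, the coordination combines ParallelConfig.stateless_init_dp_group (a gloo group created through the new stateless_init_torch_distributed_process_group helper) with ParallelConfig.has_unfinished_dp (an all-reduce with MAX acting as a logical OR across DP ranks). Below is a minimal sketch of the two pieces together; it assumes a vLLM build that already contains this patch, and the port number and per-rank flags are arbitrary illustration values.

from multiprocessing import Process

import torch
import torch.distributed as dist

from vllm.distributed.utils import (
    stateless_init_torch_distributed_process_group)


def engine(rank: int, world_size: int, port: int, has_unfinished: bool):
    # gloo backend, since the engine process may not own a CUDA device
    dp_group = stateless_init_torch_distributed_process_group(
        "127.0.0.1", port, rank, world_size, backend="gloo")
    tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
    # MAX over {0, 1} acts as a logical OR across DP ranks
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX, group=dp_group)
    aggregated = bool(tensor.item())
    # a rank with no local work but a True aggregate would schedule a dummy
    # batch, mirroring LLMEngine.has_unfinished_requests_dp above
    print(f"rank {rank}: local={has_unfinished}, aggregated={aggregated}")


if __name__ == "__main__":
    port = 29501  # arbitrary free port for the sketch
    procs = [
        Process(target=engine, args=(r, 2, port, r == 0)) for r in range(2)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()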
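
Inside set_forward_context, per-rank batch sizes are exchanged by having each DP rank contribute a one-hot vector with its own token count and all-reducing it; because the other slots are zero, either SUM (the torch default) or MAX recovers every rank's count. A single-process sketch of that idea, with made-up batch sizes standing in for the real all-reduce:

import torch

dp_size = 4
local_batch_sizes = [7, 0, 13, 2]  # what each rank would see locally

contributions = []
for dp_rank, batchsize in enumerate(local_batch_sizes):
    one_hot = [0] * dp_size
    one_hot[dp_rank] = batchsize
    contributions.append(torch.tensor(one_hot, dtype=torch.int32))

# stand-in for dist.all_reduce over the DP gloo group
num_tokens_across_dp = torch.stack(contributions).amax(dim=0).tolist()
print(num_tokens_across_dp)  # [7, 0, 13, 2]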
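
Finally, the port-selection rule documented in get_open_port can be illustrated without vLLM; find_free_port below is a stand-in helper written for this sketch, not a vLLM API, and the 10-port window mirrors the reservation described in the docstring above.

import socket


def find_free_port() -> int:
    # ask the OS for any free TCP port on localhost
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def open_port_avoiding_dp_window(dp_master_port: int) -> int:
    while True:
        port = find_free_port()
        if dp_master_port <= port < dp_master_port + 10:
            continue  # reserved for the data parallel master process
        return port


print(open_port_avoiding_dp_window(29500))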