diff --git a/vllm/config.py b/vllm/config.py
index 3f1bff498129..a175890fe087 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2820,6 +2820,9 @@ class KVTransferConfig(BaseModel):
     # The KV connector port, used to build distributed connection
     kv_port: int = 14579
 
+    # any extra config that the connector may need
+    kv_connector_extra_config: dict[str, Any] = {}
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -2879,6 +2882,9 @@ def is_kv_consumer(self) -> bool:
         return self.kv_connector is not None and \
             self.kv_role in ["kv_consumer", "kv_both"]
 
+    def get_from_extra_config(self, key, default) -> Any:
+        return self.kv_connector_extra_config.get(key, default)
+
 
 class CompilationLevel:
     # constants for the levels of the compilation process
diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
index 3462f7de020e..10bbfe1ddd8a 100644
--- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
+++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
@@ -6,7 +6,7 @@
     - Distributed KV cache transmission using PyNccl pipes.
     - Non-blocking `insert`, blocking `drop_select`.
     - Use CPU signal pipe to avoid racing condition
-    - Handles buffer size constraints and provide backpressure mechanism to 
+    - Handles buffer size constraints and provide backpressure mechanism to
       stop the prefill instance when the decode instance is slow.
 """
 import threading
diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
index 7aa53d07a9ef..e8bf607eb899 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """
-    This module implements a PyNccl pipe for sending and receiving 
-    Optional[torch.Tensor] between distributed ranks with advanced 
+    This module implements a PyNccl pipe for sending and receiving
+    Optional[torch.Tensor] between distributed ranks with advanced
     communication features.
 
     Key Features:
@@ -59,11 +59,13 @@ def __init__(self,
         self.device = self._select_device(device)
 
         # build distributed connection and send/recv implementation
+        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
         self.group = StatelessProcessGroup.create(
             host=self.config.kv_ip,
             port=self.config.kv_port + port_offset,
             rank=self.kv_rank,
             world_size=self.kv_parallel_size,
+            store_timeout=store_timeout,
         )
         # add a barrier to make sure the connection is initiated properly
         self.group.barrier()
@@ -134,11 +136,11 @@ def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
         Create a buffer to receive the tensor based on the provided metadata.
 
         Parameters:
-            - metadata: A dictionary with keys "dtype" and "shape", describing 
+            - metadata: A dictionary with keys "dtype" and "shape", describing
              the tensor's data type and shape.
 
         Returns:
-            - buffer: A tensor of the specified type and shape, allocated on 
+            - buffer: A tensor of the specified type and shape, allocated on
              self.device.
         """
         return torch.empty(metadata["shape"],
@@ -159,18 +161,18 @@ def _recv_metadata(self) -> Metadata:
         Receive the metadata dictionary from the target rank.
 
         Returns:
-            - metadata: A dictionary with keys "dtype" and "shape" describing 
+            - metadata: A dictionary with keys "dtype" and "shape" describing
              the tensor.
         """
         return self.group.recv_obj(self.target_rank_for_recv)
 
     def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
         """
-        The actual implementation of sending the tensor and its metadata to the 
+        The actual implementation of sending the tensor and its metadata to the
         target rank.
 
         Parameters:
-            - tensor: The input tensor to be sent, or None if no tensor is 
+            - tensor: The input tensor to be sent, or None if no tensor is
              being sent.
         """
         metadata = self._make_metadata(tensor)
@@ -181,7 +183,7 @@ def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
 
     def _recv_impl(self) -> Optional[torch.Tensor]:
         """
-        The actual implementation of receiving a tensor and its metadata from 
+        The actual implementation of receiving a tensor and its metadata from
         the target rank.
 
         Returns:
@@ -213,7 +215,7 @@ def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
 
     def block_if_full(self):
         """
-        Block the current thread if the buffer size is larger than the 
+        Block the current thread if the buffer size is larger than the
         threshold.
         """
         while self.buffer_size > self.buffer_size_thresh:
@@ -222,7 +224,7 @@ def block_if_full(self):
 
     def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
         """
-        Sends a tensor and its metadata to the destination rank in a 
+        Sends a tensor and its metadata to the destination rank in a
         non-blocking way.
 
         Parameters:
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index d6fca4f0221b..25202062e975 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -5,6 +5,7 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
+import datetime
 import pickle
 import time
 from collections import deque
@@ -217,6 +218,7 @@ def create(
         rank: int,
         world_size: int,
         data_expiration_seconds: int = 3600,
+        store_timeout: int = 300,
     ) -> "StatelessProcessGroup":
         """A replacement for `torch.distributed.init_process_group` that does not
         pollute the global state.
@@ -238,6 +240,7 @@ def create(
             port=port,
             world_size=world_size,
             is_master=(rank == 0),
+            timeout=datetime.timedelta(seconds=store_timeout),
         )
 
         return StatelessProcessGroup(