6 changes: 6 additions & 0 deletions vllm/config.py
@@ -2820,6 +2820,9 @@ class KVTransferConfig(BaseModel):
# The KV connector port, used to build distributed connection
kv_port: int = 14579

# any extra config that the connector may need
kv_connector_extra_config: dict[str, Any] = {}

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -2879,6 +2882,9 @@ def is_kv_consumer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"]

def get_from_extra_config(self, key, default) -> Any:
return self.kv_connector_extra_config.get(key, default)


class CompilationLevel:
# constants for the levels of the compilation process
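The new `kv_connector_extra_config` field is a free-form dict, and `get_from_extra_config` is a thin wrapper around `dict.get`. A minimal sketch of how a caller might use it (the connector and role values below are illustrative, not prescribed by this PR; in practice the config is normally built from the engine arguments):

```python
from vllm.config import KVTransferConfig

# Illustrative values; "PyNcclConnector"/"kv_producer" are assumptions here.
cfg = KVTransferConfig(
    kv_connector="PyNcclConnector",
    kv_role="kv_producer",
    kv_connector_extra_config={"store_timeout": 120},
)

# Present key: returns the configured value.
assert cfg.get_from_extra_config("store_timeout", 300) == 120
# Missing key: falls back to the supplied default.
assert cfg.get_from_extra_config("buffer_size", 1e9) == 1e9
```

With this in place, connector code such as `PyNcclPipe` can read connector-specific knobs (here `store_timeout`, defaulting to 300) without the config schema growing a dedicated field for each one.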
@@ -6,7 +6,7 @@
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Use CPU signal pipe to avoid racing condition
- Handles buffer size constraints and provide backpressure mechanism to
stop the prefill instance when the decode instance is slow.
"""
import threading
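For readers unfamiliar with the buffer semantics named in this docstring (non-blocking `insert`, blocking `drop_select`), here is a much-simplified, single-process sketch of the idea; the class and method bodies below are purely illustrative and are not the vLLM implementation:

```python
import threading
from collections import deque
from typing import Any, Optional


class TinyLookupBuffer:
    """Toy single-process stand-in for the KV lookup buffer semantics."""

    def __init__(self) -> None:
        self._items: deque = deque()
        self._cond = threading.Condition()

    def insert(self, key: Any, value: Any) -> None:
        # Non-blocking: enqueue and wake up any waiting consumer.
        with self._cond:
            self._items.append((key, value))
            self._cond.notify()

    def drop_select(self, key: Any, timeout: Optional[float] = None) -> Optional[Any]:
        # Blocking: wait until a matching key arrives, then remove and return it.
        with self._cond:
            while True:
                for i, (k, v) in enumerate(self._items):
                    if k == key:
                        del self._items[i]
                        return v
                if not self._cond.wait(timeout=timeout):
                    return None  # timed out without a match
```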
22 changes: 12 additions & 10 deletions vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.

Key Features:
@@ -59,11 +59,13 @@ def __init__(self,
self.device = self._select_device(device)

# build distributed connection and send/recv implementation
store_timeout = self.config.get_from_extra_config("store_timeout", 300)
self.group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port + port_offset,
rank=self.kv_rank,
world_size=self.kv_parallel_size,
store_timeout=store_timeout,
)
# add a barrier to make sure the connection is initiated properly
self.group.barrier()
@@ -134,11 +136,11 @@ def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
Create a buffer to receive the tensor based on the provided metadata.

Parameters:
- metadata: A dictionary with keys "dtype" and "shape", describing
the tensor's data type and shape.

Returns:
- buffer: A tensor of the specified type and shape, allocated on
self.device.
"""
return torch.empty(metadata["shape"],
@@ -159,18 +161,18 @@ def _recv_metadata(self) -> Metadata:
Receive the metadata dictionary from the target rank.

Returns:
- metadata: A dictionary with keys "dtype" and "shape" describing
the tensor.
"""
return self.group.recv_obj(self.target_rank_for_recv)

def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
"""
The actual implementation of sending the tensor and its metadata to the
target rank.

Parameters:
- tensor: The input tensor to be sent, or None if no tensor is
being sent.
"""
metadata = self._make_metadata(tensor)
@@ -181,7 +183,7 @@ def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:

def _recv_impl(self) -> Optional[torch.Tensor]:
"""
The actual implementation of receiving a tensor and its metadata from
the target rank.

Returns:
@@ -213,7 +215,7 @@ def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],

def block_if_full(self):
"""
Block the current thread if the buffer size is larger than the
threshold.
"""
while self.buffer_size > self.buffer_size_thresh:
@@ -222,7 +224,7 @@ def block_if_full(self):

def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
"""
Sends a tensor and its metadata to the destination rank in a
non-blocking way.

Parameters:
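Taken together, `send_tensor` hands the actual transfer to a background thread and `block_if_full` provides the backpressure: the producer sleeps while the tracked buffer size exceeds the threshold. A stripped-down sketch of that pattern (the class name and byte-counting heuristic are assumptions, not the actual pipe):

```python
import queue
import threading
import time


class BackpressuredSender:
    """Toy illustration of the send_tensor/block_if_full pattern."""

    def __init__(self, buffer_size_thresh: int = 1 << 30) -> None:
        self.buffer_size = 0                      # bytes currently queued
        self.buffer_size_thresh = buffer_size_thresh
        self.buffer_size_lock = threading.Lock()
        self._queue: queue.Queue = queue.Queue()
        threading.Thread(target=self._drain, daemon=True).start()

    def block_if_full(self) -> None:
        # Backpressure: stall the producer while too many bytes are in flight.
        while self.buffer_size > self.buffer_size_thresh:
            time.sleep(0.05)

    def send_tensor(self, payload: bytes) -> None:
        # Non-blocking from the caller's point of view, unless the buffer is full.
        self.block_if_full()
        with self.buffer_size_lock:
            self.buffer_size += len(payload)
        self._queue.put(payload)

    def _drain(self) -> None:
        while True:
            payload = self._queue.get()
            # ... the actual network send would happen here ...
            with self.buffer_size_lock:
                self.buffer_size -= len(payload)
```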
3 changes: 3 additions & 0 deletions vllm/distributed/utils.py
@@ -5,6 +5,7 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import datetime
import pickle
import time
from collections import deque
@@ -217,6 +218,7 @@ def create(
rank: int,
world_size: int,
data_expiration_seconds: int = 3600,
store_timeout: int = 300,
) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not
pollute the global state.
@@ -238,6 +240,7 @@
port=port,
world_size=world_size,
is_master=(rank == 0),
timeout=datetime.timedelta(seconds=store_timeout),
)

return StatelessProcessGroup(
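The new `store_timeout` argument is forwarded as the `timeout` of the underlying `torch.distributed.TCPStore`, which controls how long the store waits during initialization and for `get()`/`wait()` calls (PyTorch's own default is also 300 seconds). A minimal sketch of the effect, with an illustrative host and port:

```python
import datetime

from torch.distributed import TCPStore

# Master side (rank 0); host and port are illustrative.
store = TCPStore(
    host_name="127.0.0.1",
    port=14579,
    world_size=2,
    is_master=True,
    # Non-master ranks created with the same timeout will raise instead of
    # hanging indefinitely if the store (or a waited-for key) is not
    # available within this window.
    timeout=datetime.timedelta(seconds=600),
)
```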