Commit 1bd32bc

[Config][Disaggregated] Add timeout configuration for the torch.store and add KVTransferConfig.kv_connector_extra_config (#14367)
Signed-off-by: Mathis Felardos <mathis@mistral.ai>
1 parent 128bf75 commit 1bd32bc

File tree

4 files changed: +22 -11 lines changed

vllm/config.py

Lines changed: 6 additions & 0 deletions
@@ -2837,6 +2837,9 @@ class KVTransferConfig(BaseModel):
     # The KV connector port, used to build distributed connection
     kv_port: int = 14579
 
+    # any extra config that the connector may need
+    kv_connector_extra_config: dict[str, Any] = {}
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -2896,6 +2899,9 @@ def is_kv_consumer(self) -> bool:
         return self.kv_connector is not None and \
             self.kv_role in ["kv_consumer", "kv_both"]
 
+    def get_from_extra_config(self, key, default) -> Any:
+        return self.kv_connector_extra_config.get(key, default)
+
 
 class CompilationLevel:
     # constants for the levels of the compilation process

vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
     - Distributed KV cache transmission using PyNccl pipes.
     - Non-blocking `insert`, blocking `drop_select`.
     - Use CPU signal pipe to avoid racing condition
-    - Handles buffer size constraints and provide backpressure mechanism to 
+    - Handles buffer size constraints and provide backpressure mechanism to
       stop the prefill instance when the decode instance is slow.
 """
 import threading

vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py

Lines changed: 12 additions & 10 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """
-This module implements a PyNccl pipe for sending and receiving 
-Optional[torch.Tensor] between distributed ranks with advanced 
+This module implements a PyNccl pipe for sending and receiving
+Optional[torch.Tensor] between distributed ranks with advanced
 communication features.
 
 Key Features:
@@ -59,11 +59,13 @@ def __init__(self,
         self.device = self._select_device(device)
 
         # build distributed connection and send/recv implementation
+        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
         self.group = StatelessProcessGroup.create(
             host=self.config.kv_ip,
             port=self.config.kv_port + port_offset,
             rank=self.kv_rank,
             world_size=self.kv_parallel_size,
+            store_timeout=store_timeout,
         )
         # add a barrier to make sure the connection is initiated properly
         self.group.barrier()
@@ -134,11 +136,11 @@ def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
         Create a buffer to receive the tensor based on the provided metadata.
 
         Parameters:
-            - metadata: A dictionary with keys "dtype" and "shape", describing 
+            - metadata: A dictionary with keys "dtype" and "shape", describing
               the tensor's data type and shape.
 
         Returns:
-            - buffer: A tensor of the specified type and shape, allocated on 
+            - buffer: A tensor of the specified type and shape, allocated on
               self.device.
         """
         return torch.empty(metadata["shape"],
@@ -159,18 +161,18 @@ def _recv_metadata(self) -> Metadata:
         Receive the metadata dictionary from the target rank.
 
         Returns:
-            - metadata: A dictionary with keys "dtype" and "shape" describing 
+            - metadata: A dictionary with keys "dtype" and "shape" describing
               the tensor.
         """
         return self.group.recv_obj(self.target_rank_for_recv)
 
     def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
         """
-        The actual implementation of sending the tensor and its metadata to the 
+        The actual implementation of sending the tensor and its metadata to the
         target rank.
 
         Parameters:
-            - tensor: The input tensor to be sent, or None if no tensor is 
+            - tensor: The input tensor to be sent, or None if no tensor is
               being sent.
         """
         metadata = self._make_metadata(tensor)
@@ -181,7 +183,7 @@ def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
 
     def _recv_impl(self) -> Optional[torch.Tensor]:
         """
-        The actual implementation of receiving a tensor and its metadata from 
+        The actual implementation of receiving a tensor and its metadata from
         the target rank.
 
         Returns:
@@ -213,7 +215,7 @@ def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
 
     def block_if_full(self):
         """
-        Block the current thread if the buffer size is larger than the 
+        Block the current thread if the buffer size is larger than the
         threshold.
         """
         while self.buffer_size > self.buffer_size_thresh:
@@ -222,7 +224,7 @@ def block_if_full(self):
 
     def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
         """
-        Sends a tensor and its metadata to the destination rank in a 
+        Sends a tensor and its metadata to the destination rank in a
         non-blocking way.
 
         Parameters:
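Taken together, PyNcclPipe now threads a configurable TCPStore timeout from the KV transfer config into the stateless process group instead of relying on the store's default. A hedged sketch of how the setting could be supplied at launch time, assuming the usual disaggregated-prefill flow where the KV transfer config is passed as a JSON string via --kv-transfer-config (that flag is pre-existing and not part of this diff; all values are placeholders):

import json

# Hypothetical producer-side config; only the "store_timeout" key relies on
# the kv_connector_extra_config field added in this commit.
kv_transfer_config = json.dumps({
    "kv_connector": "PyNcclConnector",
    "kv_role": "kv_producer",
    "kv_rank": 0,
    "kv_parallel_size": 2,
    "kv_connector_extra_config": {"store_timeout": 1200},
})

# e.g.  vllm serve <model> --kv-transfer-config "$KV_TRANSFER_CONFIG"
# PyNcclPipe.__init__ then resolves the value with
# get_from_extra_config("store_timeout", 300).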

vllm/distributed/utils.py

Lines changed: 3 additions & 0 deletions
@@ -5,6 +5,7 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
+import datetime
 import pickle
 import time
 from collections import deque
@@ -217,6 +218,7 @@ def create(
         rank: int,
         world_size: int,
         data_expiration_seconds: int = 3600,
+        store_timeout: int = 300,
     ) -> "StatelessProcessGroup":
         """A replacement for `torch.distributed.init_process_group` that does not
         pollute the global state.
@@ -238,6 +240,7 @@ def create(
             port=port,
             world_size=world_size,
             is_master=(rank == 0),
+            timeout=datetime.timedelta(seconds=store_timeout),
         )
 
         return StatelessProcessGroup(
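For callers that build the group directly, the new store_timeout parameter simply bounds how long the underlying TCPStore waits for peers to connect. A minimal sketch, assuming placeholder host/port values and that each of the world_size ranks runs the same call with its own rank index:

from vllm.distributed.utils import StatelessProcessGroup

group = StatelessProcessGroup.create(
    host="127.0.0.1",
    port=14579,
    rank=0,              # each rank passes its own index
    world_size=2,
    store_timeout=600,   # seconds; defaults to 300 per this commit
)
group.barrier()          # blocks until all world_size ranks have joined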
