@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """
-This module implements a PyNccl pipe for sending and receiving
-Optional[torch.Tensor] between distributed ranks with advanced
+This module implements a PyNccl pipe for sending and receiving
+Optional[torch.Tensor] between distributed ranks with advanced
 communication features.

 Key Features:
@@ -59,11 +59,13 @@ def __init__(self,
         self.device = self._select_device(device)

         # build distributed connection and send/recv implementation
+        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
         self.group = StatelessProcessGroup.create(
             host=self.config.kv_ip,
             port=self.config.kv_port + port_offset,
             rank=self.kv_rank,
             world_size=self.kv_parallel_size,
+            store_timeout=store_timeout,
         )
         # add a barrier to make sure the connection is initiated properly
         self.group.barrier()
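Note: the added lookup lets deployments raise the TCPStore timeout for large or slow-starting clusters through the KV-transfer extra config, falling back to 300 seconds. A minimal sketch of that fallback pattern, assuming a dict-backed extra config (ExtraConfig and its field name below are illustrative stand-ins, not the real vLLM config class):

    from dataclasses import dataclass, field

    @dataclass
    class ExtraConfig:
        # Hypothetical stand-in for the KV-transfer config object.
        kv_connector_extra_config: dict = field(default_factory=dict)

        def get_from_extra_config(self, key, default):
            # Missing keys fall back to the caller-supplied default.
            return self.kv_connector_extra_config.get(key, default)

    cfg = ExtraConfig({"store_timeout": 600})
    assert cfg.get_from_extra_config("store_timeout", 300) == 600
    assert ExtraConfig().get_from_extra_config("store_timeout", 300) == 300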
@@ -134,11 +136,11 @@ def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
         Create a buffer to receive the tensor based on the provided metadata.

         Parameters:
-            - metadata: A dictionary with keys "dtype" and "shape", describing
+            - metadata: A dictionary with keys "dtype" and "shape", describing
               the tensor's data type and shape.

         Returns:
-            - buffer: A tensor of the specified type and shape, allocated on
+            - buffer: A tensor of the specified type and shape, allocated on
               self.device.
         """
         return torch.empty(metadata["shape"],
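_prepare_recv_buffer allocates an uninitialized tensor sized from the sender's metadata; torch.empty suffices because the subsequent receive overwrites every element. A runnable sketch of the same allocation, using device="cpu" only so it runs without a GPU (the pipe itself allocates on self.device):

    import torch

    # Allocate a receive buffer purely from the announced metadata.
    metadata = {"dtype": torch.float16, "shape": (2, 8)}
    buffer = torch.empty(metadata["shape"],
                         dtype=metadata["dtype"],
                         device="cpu")
    assert buffer.shape == (2, 8)
    assert buffer.dtype == torch.float16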
@@ -159,18 +161,18 @@ def _recv_metadata(self) -> Metadata:
         Receive the metadata dictionary from the target rank.

         Returns:
-            - metadata: A dictionary with keys "dtype" and "shape" describing
+            - metadata: A dictionary with keys "dtype" and "shape" describing
               the tensor.
         """
         return self.group.recv_obj(self.target_rank_for_recv)

     def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
         """
-        The actual implementation of sending the tensor and its metadata to the
+        The actual implementation of sending the tensor and its metadata to the
         target rank.

         Parameters:
-            - tensor: The input tensor to be sent, or None if no tensor is
+            - tensor: The input tensor to be sent, or None if no tensor is
               being sent.
         """
         metadata = self._make_metadata(tensor)
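_make_metadata itself is outside this hunk; a plausible sketch of what it produces, where a None tensor is encoded as metadata with no dtype/shape so the receiver knows to skip the payload (make_metadata below is an illustration, not the module's actual helper):

    from typing import Optional

    import torch

    def make_metadata(tensor: Optional[torch.Tensor]) -> dict:
        if tensor is None:
            # No payload will follow; the receiver returns None.
            return {"dtype": None, "shape": None}
        return {"dtype": tensor.dtype, "shape": tensor.shape}

    assert make_metadata(None)["shape"] is None
    assert make_metadata(torch.zeros(2, 4))["shape"] == (2, 4)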
@@ -181,7 +183,7 @@ def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:

     def _recv_impl(self) -> Optional[torch.Tensor]:
         """
-        The actual implementation of receiving a tensor and its metadata from
+        The actual implementation of receiving a tensor and its metadata from
         the target rank.

         Returns:
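Taken together, _send_impl and _recv_impl form a metadata-first protocol: a small metadata object always travels over the CPU-side store, and the device payload follows only when the tensor is real. A self-contained sketch of that round trip, with an in-process queue standing in for the PyNccl/StatelessProcessGroup transport:

    import queue

    import torch

    channel: queue.Queue = queue.Queue()  # stand-in for the real transport

    def send_impl(tensor):
        meta = ({"dtype": None, "shape": None} if tensor is None
                else {"dtype": tensor.dtype, "shape": tensor.shape})
        channel.put(meta)                # metadata always travels first
        if tensor is not None:
            channel.put(tensor.clone())  # payload only for real tensors

    def recv_impl():
        meta = channel.get()
        if meta["shape"] is None:
            return None                  # sender announced a None tensor
        buf = torch.empty(meta["shape"], dtype=meta["dtype"])
        buf.copy_(channel.get())         # stand-in for the NCCL recv into buf
        return buf

    send_impl(torch.ones(3))
    send_impl(None)
    assert recv_impl().sum().item() == 3.0
    assert recv_impl() is None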
@@ -213,7 +215,7 @@ def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],

     def block_if_full(self):
         """
-        Block the current thread if the buffer size is larger than the
+        Block the current thread if the buffer size is larger than the
         threshold.
         """
         while self.buffer_size > self.buffer_size_thresh:
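block_if_full applies backpressure to producers: the calling thread sleeps in short intervals, rather than busy-spinning, until the transport thread drains the tracked buffer below the threshold. A runnable single-process sketch of the same loop (the sizes and threshold here are illustrative, not the pipe's defaults):

    import threading
    import time

    buffer_size = 150          # bytes currently in flight (illustrative)
    buffer_size_thresh = 100   # backpressure threshold (illustrative)

    def block_if_full():
        # Sleep briefly between checks instead of busy-spinning.
        while buffer_size > buffer_size_thresh:
            time.sleep(0.05)

    def drain():
        global buffer_size
        time.sleep(0.1)
        buffer_size -= 80      # simulate a completed transfer

    threading.Thread(target=drain).start()
    block_if_full()            # returns once buffer_size drops to 70
    print("unblocked at", buffer_size)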
@@ -222,7 +224,7 @@ def block_if_full(self):

     def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
         """
-        Sends a tensor and its metadata to the destination rank in a
+        Sends a tensor and its metadata to the destination rank in a
         non-blocking way.

         Parameters:
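The non-blocking behavior comes from accounting for the payload size (which block_if_full consults) and then handing the blocking send to a worker thread. The executor setup is outside this diff, so the following is a hedged sketch of that pattern: send_fn stands in for the blocking _send_impl, and the single worker is assumed in order to preserve send ordering.

    from concurrent.futures import ThreadPoolExecutor

    import torch

    transport_thread = ThreadPoolExecutor(max_workers=1)  # one ordered send lane
    buffer_size = 0

    def send_tensor_nonblocking(tensor, send_fn):
        # Track the payload size, then return immediately while the
        # worker thread performs the blocking send.
        global buffer_size
        size = 0 if tensor is None else tensor.element_size() * tensor.numel()
        buffer_size += size
        return transport_thread.submit(send_fn, tensor)

    future = send_tensor_nonblocking(torch.ones(4), lambda t: t.sum().item())
    print(future.result())  # 4.0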