@@ -134,7 +134,6 @@ def __init__(self,
134134 # PUT or PUT_ASYNC
135135 # tensor_id: torch.Tensor
136136 self .send_queue : deque [SendQueueItem ] = deque ()
137- self .send_request_id_to_tensor_ids : dict [str , set [str ]] = {}
138137 if self .send_type == "PUT_ASYNC" :
139138 self ._send_thread = threading .Thread (target = self .send_async ,
140139 daemon = True )
@@ -143,6 +142,7 @@ def __init__(self,
143142 # tensor_id: torch.Tensor/(addr, dtype, shape)
144143 self .recv_store : dict [str , Any ] = {}
145144 self .recv_request_id_to_tensor_ids : dict [str , set [str ]] = {}
145+ self .send_request_id_to_tensor_ids : dict [str , set [str ]] = {}
146146 self .socks : dict [str , Any ] = {} # remote_address: client socket
147147 self .comms : dict [str , Any ] = {} # remote_address: (ncclComm_t, rank)
148148
@@ -223,18 +223,26 @@ def send_tensor(
223223 # GET
224224 with self .send_store_cv :
225225 tensor_size = tensor .element_size () * tensor .numel ()
226+ if tensor_size > self .buffer_size_threshold :
227+ logger .warning (
228+ "❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
229+ "buffer size threshold :%d, skip send to %s, rank:%d" ,
230+ tensor_id , tensor_size , self .buffer_size_threshold ,
231+ remote_address , self .rank )
232+ return False
226233 while (self .buffer_size + tensor_size
227234 > self .buffer_size_threshold ):
228- oldest_tenser_id = next (iter (self .send_store ))
229- oldest_tenser = self .send_store .pop (oldest_tenser_id )
230- oldest_tenser_size = oldest_tenser .element_size (
231- ) * oldest_tenser .numel ()
232- self .buffer_size -= oldest_tenser_size
233- logger .info (
235+ assert len (self .send_store ) > 0
236+ oldest_tensor_id = next (iter (self .send_store ))
237+ oldest_tensor = self .send_store .pop (oldest_tensor_id )
238+ oldest_tensor_size = oldest_tensor .element_size (
239+ ) * oldest_tensor .numel ()
240+ self .buffer_size -= oldest_tensor_size
241+ logger .debug (
234242 "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
235- " buffer_size:%d, oldest_tenser_size :%d, rank:%d" ,
243+ " buffer_size:%d, oldest_tensor_size :%d, rank:%d" ,
236244 remote_address , tensor_id , tensor_size , self .buffer_size ,
237- oldest_tenser_size , self .rank )
245+ oldest_tensor_size , self .rank )
238246
239247 self .send_store [tensor_id ] = tensor
240248 self .buffer_size += tensor_size
0 commit comments