diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
index ec72905a0d3e..3dadfa595ef1 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@@ -178,6 +178,9 @@ def inject_kv_into_layer(
 
         # Load the KV for each request each layer
         for request in metadata.requests:
+            request_id = request.request_id
+            ip, port = self.parse_request_id(request_id, False)
+            remote_address = ip + ":" + str(port + self._rank)
             for layer_name in forward_context.no_compile_layers:
                 layer = forward_context.no_compile_layers[layer_name]
 
@@ -191,7 +194,7 @@
                 layer = kv_cache[forward_context.virtual_engine]
 
                 kv_cache = self.p2p_nccl_engine.recv_tensor(
-                    request.request_id + "#" + layer_name)
+                    request.request_id + "#" + layer_name, remote_address)
 
                 if kv_cache is None:
                     logger.warning("🚧kv_cache is None, %s", request.request_id)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
index fa7cc66ab654..959bf0277a3f 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
@@ -134,7 +134,6 @@ def __init__(self,
         # PUT or PUT_ASYNC
         # tensor_id: torch.Tensor
         self.send_queue: deque[SendQueueItem] = deque()
-        self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
         if self.send_type == "PUT_ASYNC":
             self._send_thread = threading.Thread(target=self.send_async,
                                                  daemon=True)
@@ -143,6 +142,7 @@
         # tensor_id: torch.Tensor/(addr, dtype, shape)
         self.recv_store: dict[str, Any] = {}
         self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
+        self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
         self.socks: dict[str, Any] = {}  # remote_address: client socket
         self.comms: dict[str, Any] = {}  # remote_address: (ncclComm_t, rank)
 
@@ -223,18 +223,26 @@ def send_tensor(
         # GET
         with self.send_store_cv:
             tensor_size = tensor.element_size() * tensor.numel()
+            if tensor_size > self.buffer_size_threshold:
+                logger.warning(
+                    "❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
+                    " buffer size threshold:%d, skip send to %s, rank:%d",
+                    tensor_id, tensor_size, self.buffer_size_threshold,
+                    remote_address, self.rank)
+                return False
             while (self.buffer_size + tensor_size
                    > self.buffer_size_threshold):
-                oldest_tenser_id = next(iter(self.send_store))
-                oldest_tenser = self.send_store.pop(oldest_tenser_id)
-                oldest_tenser_size = oldest_tenser.element_size(
-                ) * oldest_tenser.numel()
-                self.buffer_size -= oldest_tenser_size
-                logger.info(
+                assert len(self.send_store) > 0
+                oldest_tensor_id = next(iter(self.send_store))
+                oldest_tensor = self.send_store.pop(oldest_tensor_id)
+                oldest_tensor_size = oldest_tensor.element_size(
+                ) * oldest_tensor.numel()
+                self.buffer_size -= oldest_tensor_size
+                logger.debug(
                     "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
-                    " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
+                    " buffer_size:%d, oldest_tensor_size:%d, rank:%d",
                     remote_address, tensor_id, tensor_size, self.buffer_size,
-                    oldest_tenser_size, self.rank)
+                    oldest_tensor_size, self.rank)
 
             self.send_store[tensor_id] = tensor
             self.buffer_size += tensor_size
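
The connector change makes the decode side pull each layer's KV from an explicit peer: `parse_request_id` (already part of `P2pNcclConnector`) recovers the peer's ip/port from the request id, and the port is offset by `self._rank` so tensor-parallel ranks pair up one-to-one. A minimal sketch of that derivation, assuming the proxy embeds both endpoints in the request id with `___prefill_addr_`/`___decode_addr_` markers (the marker names and regex here are illustrative, not necessarily the exact vLLM format):

```python
import re

def parse_request_id(request_id: str, is_prefill: bool) -> tuple[str, int]:
    # The flag names this worker's role; the *peer's* endpoint is what we
    # extract (a decode worker looks up the prefill address, and vice versa).
    marker = "___decode_addr_" if is_prefill else "___prefill_addr_"
    # IPv4 only, for brevity of the sketch.
    match = re.search(marker + r"(\d+\.\d+\.\d+\.\d+):(\d+)", request_id)
    if match is None:
        raise ValueError(f"cannot parse remote address from {request_id!r}")
    return match.group(1), int(match.group(2))

rid = ("cmpl-___prefill_addr_10.0.1.2:21001"
       "___decode_addr_10.0.1.3:22001_0")
ip, port = parse_request_id(rid, is_prefill=False)  # decode worker's view
rank = 1                                            # this worker's TP rank
# Each TP rank listens on base_port + rank, so peers pair rank-to-rank.
remote_address = ip + ":" + str(port + rank)        # "10.0.1.2:21002"
```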
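
The `send_tensor` hunk fixes a latent crash on the GET path: if a single tensor exceeded `buffer_size_threshold`, the old eviction loop would pop until `send_store` was empty and then raise `StopIteration`. The patch rejects such tensors up front and asserts the store is non-empty before each eviction. A self-contained sketch of the resulting policy, using a hypothetical `SendBuffer` class (not the vLLM API):

```python
from collections import OrderedDict

import torch

class SendBuffer:
    """Hypothetical stand-in for the engine's send_store bookkeeping."""

    def __init__(self, threshold_bytes: int):
        self.threshold = threshold_bytes
        self.size = 0
        self.store: OrderedDict[str, torch.Tensor] = OrderedDict()

    def put(self, tensor_id: str, tensor: torch.Tensor) -> bool:
        nbytes = tensor.element_size() * tensor.numel()
        # New guard from the patch: a tensor bigger than the whole
        # threshold can never fit, so skip it instead of evicting
        # everything and crashing.
        if nbytes > self.threshold:
            return False
        while self.size + nbytes > self.threshold:
            # Mirrors the patch's assert: the guard above ensures the
            # store is non-empty whenever eviction runs.
            assert self.store
            _, oldest = self.store.popitem(last=False)  # FIFO eviction
            self.size -= oldest.element_size() * oldest.numel()
        self.store[tensor_id] = tensor
        self.size += nbytes
        return True

buf = SendBuffer(threshold_bytes=1024)
assert buf.put("req-0#layer.0", torch.zeros(128))      # 512 B, fits
assert buf.put("req-0#layer.1", torch.zeros(128))      # fills the buffer
assert buf.put("req-1#layer.0", torch.zeros(128))      # evicts layer.0
assert not buf.put("req-2#layer.0", torch.zeros(512))  # 2048 B > threshold
assert len(buf.store) == 2
```

FIFO order matches the patched code, where `next(iter(self.send_store))` returns the oldest key by dict insertion order.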