Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,12 +1044,12 @@ def get_finished(self) -> tuple[set[str], set[str]]:
if now < expires:
break
count = self.consumer_notification_counts_by_req.pop(req_id, 0)
logger.warning(
"Releasing expired KV blocks for request %s which were "
"retrieved by %d decode worker(s) within %d seconds.", req_id,
count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
del self._reqs_to_send[req_id]
done_sending.add(req_id)
if self.try_remove_request(req_id, "timeout"):
done_sending.add(req_id)
logger.warning(
"Releasing expired KV blocks for request %s which were "
"retrieved by %d decode worker(s) within %d seconds.",
req_id, count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)

return done_sending, done_recving

Expand All @@ -1074,9 +1074,14 @@ def _get_new_notifs(self) -> set[str]:
# Wait all consumers (D) to be done reading before freeing.
if self.consumer_notification_counts_by_req[req_id] == int(
tp_ratio):
notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id]
del self._reqs_to_send[req_id]
if self.try_remove_request(req_id, "consumer_complete"):
notified_req_ids.add(req_id)
else:
logger.debug(
"Request %s completed by all consumers but was"
"already removed (likely timed out)", req_id)

return notified_req_ids

def _pop_done_transfers(
Expand Down Expand Up @@ -1298,6 +1303,24 @@ def get_backend_aware_kv_block_len(self):
block_len = self.block_len
return block_len

def try_remove_request(self, req_id: str, reason: str) -> bool:
    """Remove ``req_id`` from the pending-send table if it is still there.

    The call is safe to repeat: a second attempt for the same request
    simply reports that the entry is already gone instead of raising.

    Args:
        req_id: Identifier of the request to drop from
            ``self._reqs_to_send``.
        reason: Short label describing why the removal is attempted.
            It is only used in the debug log messages.

    Returns:
        True if an entry for ``req_id`` existed and was removed by this
        call, False if no entry was present.
    """
    try:
        # Decide on key presence, not on the stored value: a single pop
        # both checks for and removes the entry, and an entry is counted
        # as removed whatever value it held.
        timeout_value = self._reqs_to_send.pop(req_id)
    except KeyError:
        logger.debug("Request %s already removed when attempting %s",
                     req_id, reason)
        return False

    # The stored value is formatted with %.2f, i.e. it is expected to be
    # a numeric deadline.
    logger.debug("Removed request %s (reason: %s, was due at: %.2f)",
                 req_id, reason, timeout_value)
    return True


@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
Expand Down