Commit da63274 (parent: c216119)

[Bugfix][NIXL] Fix Async Scheduler timeout issue (#25808)
Signed-off-by: NickLucche <nlucches@redhat.com>

File tree: 1 file changed (+28, −3 lines)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 28 additions & 3 deletions
(Columns below: original file line number | new file line number | diff line change)
@@ -105,6 +105,7 @@ def __init__(self):
105105
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
106106
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
107107
self.reqs_to_send: dict[ReqId, float] = {}
108+
self.reqs_in_batch: set[ReqId] = set()
108109

109110
def add_new_req(
110111
self,
@@ -278,6 +279,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
278279
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
279280
# Reqs to send and their expiration time
280281
self._reqs_need_send: dict[ReqId, float] = {}
282+
self._reqs_in_batch: set[ReqId] = set()
281283

282284
def get_num_new_matched_tokens(
283285
self, request: "Request",
@@ -324,6 +326,9 @@ def update_state_after_alloc(self, request: "Request",
324326

325327
if not params:
326328
return
329+
330+
if params.get("do_remote_decode"):
331+
self._reqs_in_batch.add(request.request_id)
327332
if self.use_host_buffer and params.get("do_remote_decode"):
328333
# NOTE: when accelerator is not directly supported by Nixl,
329334
# prefilled blocks need to be saved to host memory before transfer.
@@ -373,6 +378,8 @@ def build_connector_meta(
373378
request_id=req_id,
374379
local_block_ids=block_ids,
375380
kv_transfer_params=req.kv_transfer_params,
381+
load_remote_cache=True,
382+
save_to_host=False,
376383
)
377384

378385
for req_id, (req, block_ids) in self._reqs_need_save.items():
@@ -386,10 +393,12 @@ def build_connector_meta(
386393
)
387394

388395
meta.reqs_to_send = self._reqs_need_send
396+
meta.reqs_in_batch = self._reqs_in_batch
389397

390398
# Clear the list once workers start the transfers
391399
self._reqs_need_recv.clear()
392400
self._reqs_need_save.clear()
401+
self._reqs_in_batch = set()
393402
self._reqs_need_send = {}
394403

395404
return meta
@@ -546,6 +555,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
546555
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
547556
# Track the expiration time of requests that are waiting to be sent.
548557
self._reqs_to_send: dict[ReqId, float] = {}
558+
# Set of requests that have been part of a batch, regardless of status.
559+
self._reqs_to_process: set[ReqId] = set()
549560

550561
# Background thread for handling new handshake requests.
551562
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
@@ -1082,6 +1093,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
10821093
"Releasing expired KV blocks for request %s which were "
10831094
"retrieved by %d decode worker(s) within %d seconds.", req_id,
10841095
count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
1096+
self._reqs_to_process.remove(req_id)
10851097
del self._reqs_to_send[req_id]
10861098
done_sending.add(req_id)
10871099

@@ -1097,7 +1109,8 @@ def _get_new_notifs(self) -> set[str]:
10971109
for notifs in self.nixl_wrapper.get_new_notifs().values():
10981110
for notif in notifs:
10991111
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
1100-
if req_id not in self._reqs_to_send:
1112+
if (req_id not in self._reqs_to_send
1113+
and req_id not in self._reqs_to_process):
11011114
logger.error(
11021115
"Potentially invalid KV blocks for "
11031116
"unrecognized request %s were retrieved by "
@@ -1110,7 +1123,8 @@ def _get_new_notifs(self) -> set[str]:
11101123
tp_ratio):
11111124
notified_req_ids.add(req_id)
11121125
del self.consumer_notification_counts_by_req[req_id]
1113-
del self._reqs_to_send[req_id]
1126+
self._reqs_to_process.remove(req_id)
1127+
self._reqs_to_send.pop(req_id, None)
11141128
return notified_req_ids
11151129

11161130
def _pop_done_transfers(
@@ -1171,8 +1185,19 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
11711185
while not self._ready_requests.empty():
11721186
self._read_blocks_for_req(*self._ready_requests.get_nowait())
11731187

1188+
# Keep around the requests that have been part of a batch. This is
1189+
# needed because async scheduling pushes the misalignment between the
1190+
# moment in which requests expiration is set (P side) and the moment in
1191+
# which blocks are read from D. As P can now more easily lag behind D
1192+
# while processing the next batch, we make sure to only set an
1193+
# expiration for requests that have not been read from D yet.
1194+
for req_id in metadata.reqs_in_batch:
1195+
self._reqs_to_process.add(req_id)
1196+
11741197
# Add to requests that are waiting to be read and track expiration.
1175-
self._reqs_to_send.update(metadata.reqs_to_send)
1198+
for req_id, expiration_time in metadata.reqs_to_send.items():
1199+
if req_id in self._reqs_to_process:
1200+
self._reqs_to_send[req_id] = expiration_time
11761201

11771202
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
11781203
logger.debug(

Comments (0)