if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+    from vllm.v1.outputs import KVConnectorOutput
    from vllm.v1.request import Request

Transfer = tuple[int, float]  # (xfer_handle, start_time)
@@ -117,9 +118,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
    def __init__(self):
        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
        self.reqs_to_save: dict[ReqId, ReqMeta] = {}
-        self.reqs_to_send: dict[ReqId, float] = {}
-        self.reqs_in_batch: set[ReqId] = set()
-        self.reqs_not_processed: set[ReqId] = set()

    def add_new_req(
        self,
@@ -210,6 +208,13 @@ def build_connector_meta(
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

+    def update_connector_output(
+        self,
+        connector_output: "KVConnectorOutput",
+    ):
+        assert self.connector_scheduler is not None
+        return self.connector_scheduler.update_connector_output(connector_output)
+
    def request_finished(
        self,
        request: "Request",
@@ -278,6 +283,99 @@ def shutdown(self):
        self.connector_worker.shutdown()


+class ReqsNeedSendTracker:
+    @dataclass
+    class RequestTimer:
+        """Timer for requests that need to be sent for remote decode."""
+
+        expiry_time: float
+        """Expiry time to avoid stranded KV blocks that are never fetched."""
+        consumer_count: int
+        """Consumer notification count - with heterogeneous TP, P must wait
+        for all assigned D TP workers to finish reading before safely freeing
+        the blocks."""
+
+    def __init__(self):
+        self._reqs_need_send: dict[ReqId, ReqsNeedSendTracker.RequestTimer] = {}
+        self._timeout = envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
+
+    def start_timer(self, req_id):
+        self._reqs_need_send[req_id] = self.RequestTimer(
+            time.monotonic() + self._timeout, 0
+        )
+
+    def delete_timer(self, req_id):
+        if req_id not in self._reqs_need_send:
+            return
+        logger.debug("Deleting KV transfer timeout for request %s", req_id)
+        del self._reqs_need_send[req_id]
+
+    def _process_finished_notifs(self, finished_notifs: set[str]) -> set[str]:
+        """Process notifications from D and track consumer completion.
+
+        The notification strings are in the format "req_id:tp_ratio".
+
+        Return request IDs that have completed sending to all consumers, to be
+        used by the scheduler via KVConnectorOutput.finished_sending.
+        """
+        finished_sending: set[str] = set()
+        for notif in finished_notifs or ():
+            try:
+                req_id, tp_ratio = notif.rsplit(":", 1)
+            except (ValueError, TypeError) as e:
+                raise ValueError(f"Invalid notification: {notif}") from e
+
+            # Skip notifications that arrive after we already timed out.
+            if req_id not in self._reqs_need_send:
+                logger.debug(
+                    "Already finished or expired KV transfer for request %s", req_id
+                )
+                continue
+
+            # Wait for all consumers (D) to be done reading before freeing.
+            request_timer = self._reqs_need_send[req_id]
+            request_timer.consumer_count += 1
+            if request_timer.consumer_count < int(tp_ratio):
+                continue
+
+            logger.debug(
+                "KV transfer finished for request %s after retrieval by %d "
+                "decode worker(s).",
+                req_id,
+                request_timer.consumer_count,
+            )
+            del self._reqs_need_send[req_id]
+            finished_sending.add(req_id)
+
+        return finished_sending
+
+    def _abort_expired_requests(self, finished_sending: set[str]) -> set[str]:
+        """Abort requests that have passed their expiry timeout.
+
+        Adds aborted requests to KVConnectorOutput.finished_sending.
+        """
+        now = time.monotonic()
+        while self._reqs_need_send:
+            req_id, request_timer = next(iter(self._reqs_need_send.items()))
+            # Insertion-ordered dict; oldest first so we can exit early.
+            if now < request_timer.expiry_time:
+                break
+            logger.warning(
+                "Releasing expired KV blocks for request %s which were "
+                "retrieved by %d decode worker(s) within %d seconds.",
+                req_id,
+                request_timer.consumer_count,
+                self._timeout,
+            )
+            del self._reqs_need_send[req_id]
+            finished_sending.add(req_id)
+        return finished_sending
+
+    def reqs_finished_sending(self, finished_notifs: set[str]) -> set[str]:
+        finished_sending = self._process_finished_notifs(finished_notifs)
+        return self._abort_expired_requests(finished_sending)
+
+
class NixlConnectorScheduler:
    """Implementation of Scheduler side methods"""

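To make the new tracking flow concrete, here is a minimal, self-contained sketch that mimics what ReqsNeedSendTracker does, outside of vLLM (the ReqId alias, the timeout value, and the request IDs are illustrative stand-ins, not vLLM's actual objects):

```python
import time
from dataclasses import dataclass

ReqId = str  # stand-in for vLLM's request-id alias


class MiniSendTracker:
    """Simplified mimic of ReqsNeedSendTracker, for illustration only."""

    @dataclass
    class RequestTimer:
        expiry_time: float     # when P gives up waiting for D to read the blocks
        consumer_count: int    # how many D TP workers have finished reading

    def __init__(self, timeout: float = 120.0):
        self._reqs: dict[ReqId, MiniSendTracker.RequestTimer] = {}
        self._timeout = timeout

    def start_timer(self, req_id: ReqId) -> None:
        # Called when a remote-decode prefill finishes and block freeing is delayed.
        self._reqs[req_id] = self.RequestTimer(time.monotonic() + self._timeout, 0)

    def reqs_finished_sending(self, notifs: set[str]) -> set[ReqId]:
        done: set[ReqId] = set()
        # Count "req_id:tp_ratio" notifications until every consumer has read.
        for notif in notifs or ():
            req_id, tp_ratio = notif.rsplit(":", 1)
            timer = self._reqs.get(req_id)
            if timer is None:
                continue  # already finished or already expired
            timer.consumer_count += 1
            if timer.consumer_count >= int(tp_ratio):
                del self._reqs[req_id]
                done.add(req_id)
        # Expire stale entries (insertion order means oldest first).
        now = time.monotonic()
        for req_id, timer in list(self._reqs.items()):
            if now < timer.expiry_time:
                break
            del self._reqs[req_id]
            done.add(req_id)
        return done


tracker = MiniSendTracker()
tracker.start_timer("req-0")
assert tracker.reqs_finished_sending({"req-0:2"}) == set()      # 1 of 2 consumers done
assert tracker.reqs_finished_sending({"req-0:2"}) == {"req-0"}  # all consumers done
```

Freeing a request only after the tp_ratio-th notification is what keeps P from releasing blocks while some decode workers are still reading them under heterogeneous TP; the expiry pass covers consumers that never show up.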
@@ -299,12 +397,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
        # the scheduler. Used to make metadata passed to Worker.
        self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
        self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
-        # Reqs to send and their expiration time
-        self._reqs_need_send: dict[ReqId, float] = {}
-        self._reqs_in_batch: set[ReqId] = set()
-        # Reqs to remove from processed set because they're not to send after
-        # remote prefill or aborted.
-        self._reqs_not_processed: set[ReqId] = set()
+
+        self._reqs_need_send = ReqsNeedSendTracker()

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
@@ -355,8 +449,6 @@ def update_state_after_alloc(
        if not params:
            return

-        if params.get("do_remote_decode"):
-            self._reqs_in_batch.add(request.request_id)
        if self.use_host_buffer and params.get("do_remote_decode"):
            # NOTE: when accelerator is not directly supported by Nixl,
            # prefilled blocks need to be saved to host memory before transfer.
@@ -428,19 +520,20 @@ def build_connector_meta(
            save_to_host=True,
        )

-        meta.reqs_to_send = self._reqs_need_send
-        meta.reqs_in_batch = self._reqs_in_batch
-        meta.reqs_not_processed = self._reqs_not_processed
-
        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()
        self._reqs_need_save.clear()
-        self._reqs_in_batch = set()
-        self._reqs_not_processed = set()
-        self._reqs_need_send = {}

        return meta

+    def update_connector_output(
+        self,
+        connector_output: "KVConnectorOutput",
+    ):
+        connector_output.finished_sending = self._reqs_need_send.reqs_finished_sending(
+            connector_output.finished_sending
+        )
+
    def request_finished(
        self,
        request: "Request",
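For context, a rough sketch of the data flow this new hook implements; the output class below is a simplified stand-in with only a finished_sending field, not vLLM's actual KVConnectorOutput:

```python
from dataclasses import dataclass, field


@dataclass
class FakeConnectorOutput:
    # Simplified stand-in for vllm.v1.outputs.KVConnectorOutput.
    finished_sending: set[str] = field(default_factory=set)


def update_connector_output(output: FakeConnectorOutput, tracker) -> None:
    # The worker side forwards raw "req_id:tp_ratio" notifications in
    # finished_sending; the scheduler-side hook replaces them with the
    # request IDs whose KV blocks are now safe to free (all consumers
    # done reading, or the abort timeout has elapsed).
    output.finished_sending = tracker.reqs_finished_sending(output.finished_sending)
```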
@@ -474,10 +567,10 @@ def request_finished(

        if not params.get("do_remote_decode"):
            return False, None
+
        if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
-            # Also include the case of a P/D Prefill request with immediate
-            # block free (eg abort). Stop tracking this request.
-            self._reqs_not_processed.add(request.request_id)
+            # Request aborted after we delayed freeing the blocks?
+            self._reqs_need_send.delete_timer(request.request_id)
            return False, None

        # TODO: check whether block_ids actually ever be 0. If not we could
@@ -486,9 +579,7 @@ def request_finished(

        if delay_free_blocks:
            # Prefill request on remote. It will be read from D upon completion
-            self._reqs_need_send[request.request_id] = (
-                time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
-            )
+            self._reqs_need_send.start_timer(request.request_id)

        return delay_free_blocks, dict(
            do_remote_prefill=True,
@@ -609,10 +700,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
        # [req_id -> list[handle]]
        self._recving_metadata: dict[ReqId, ReqMeta] = {}
        self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
-        # Track the expiration time of requests that are waiting to be sent.
-        self._reqs_to_send: dict[ReqId, float] = {}
-        # Set of requests that have been part of a batch, regardless of status.
-        self._reqs_to_process: set[ReqId] = set()

        # Background thread for handling new handshake requests.
        self._nixl_handshake_listener_t: Optional[threading.Thread] = None
@@ -654,9 +741,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
        logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

        self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
-        # With heterogeneous TP, P must wait for all assigned D TP workers to
-        # finish reading before safely freeing the blocks.
-        self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
        self.xfer_stats = NixlKVConnectorStats()

    @staticmethod
@@ -1220,25 +1304,6 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                assert meta, f"{req_id} not found in recving_metadata list"
                self.sync_recved_kv_to_device(req_id, meta)

-        # Handle timeout to avoid stranding blocks on remote.
-        now = time.perf_counter()
-        while self._reqs_to_send:
-            req_id, expires = next(iter(self._reqs_to_send.items()))
-            # Sorted dict, oldest requests are put first so we can exit early.
-            if now < expires:
-                break
-            count = self.consumer_notification_counts_by_req.pop(req_id, 0)
-            logger.warning(
-                "Releasing expired KV blocks for request %s which were "
-                "retrieved by %d decode worker(s) within %d seconds.",
-                req_id,
-                count,
-                envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
-            )
-            self._reqs_to_process.remove(req_id)
-            del self._reqs_to_send[req_id]
-            done_sending.add(req_id)
-
        return done_sending, done_recving

    def _get_new_notifs(self) -> set[str]:
@@ -1250,26 +1315,8 @@ def _get_new_notifs(self) -> set[str]:
        notified_req_ids: set[str] = set()
        for notifs in self.nixl_wrapper.get_new_notifs().values():
            for notif in notifs:
-                req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
-                if (
-                    req_id not in self._reqs_to_send
-                    and req_id not in self._reqs_to_process
-                ):
-                    logger.error(
-                        "Potentially invalid KV blocks for "
-                        "unrecognized request %s were retrieved by "
-                        "a decode worker. They may have expired.",
-                        req_id,
-                    )
-                    continue
-
-                self.consumer_notification_counts_by_req[req_id] += 1
-                # Wait all consumers (D) to be done reading before freeing.
-                if self.consumer_notification_counts_by_req[req_id] == int(tp_ratio):
-                    notified_req_ids.add(req_id)
-                    del self.consumer_notification_counts_by_req[req_id]
-                    self._reqs_to_process.remove(req_id)
-                    self._reqs_to_send.pop(req_id, None)
+                # Note - this is in req_id:tp_ratio format
+                notified_req_ids.add(notif.decode("utf-8"))
        return notified_req_ids

    def _pop_done_transfers(
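The notification payload itself is just a UTF-8 string; a minimal sketch of the round trip the comment above refers to (the request ID and tp_ratio values are illustrative):

```python
# D side: each decode TP worker sends one notification per request, tagged
# with the number of decode workers expected to read the blocks (tp_ratio).
req_id, tp_ratio = "req-42", 2
notif_bytes = f"{req_id}:{tp_ratio}".encode("utf-8")

# Scheduler side: split from the right so request IDs containing ':' still
# parse, then count notifications until tp_ratio consumers have been seen.
parsed_req_id, parsed_ratio = notif_bytes.decode("utf-8").rsplit(":", 1)
assert (parsed_req_id, int(parsed_ratio)) == ("req-42", 2)
```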
@@ -1333,24 +1380,6 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
        while not self._ready_requests.empty():
            self._read_blocks_for_req(*self._ready_requests.get_nowait())

-        # Keep around the requests that have been part of a batch. This is
-        # needed because async scheduling pushes the misalignment between the
-        # moment in which requests expiration is set (P side) and the moment in
-        # which blocks are read from D. As P can now more easily lag behind D
-        # while processing the next batch, we make sure to only set an
-        # expiration for requests that have not been read from D yet.
-        for req_id in metadata.reqs_in_batch:
-            self._reqs_to_process.add(req_id)
-
-        # Remove all requests that are not to be processed (eg aborted).
-        for req_id in metadata.reqs_not_processed:
-            self._reqs_to_process.discard(req_id)
-
-        # Add to requests that are waiting to be read and track expiration.
-        for req_id, expiration_time in metadata.reqs_to_send.items():
-            if req_id in self._reqs_to_process:
-                self._reqs_to_send[req_id] = expiration_time
-
    def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
        logger.debug(
            "Remote agent %s available, calling _read_blocks for req %s",