@@ -79,15 +79,16 @@ class ReqMeta:
7979class NixlConnectorMetadata (KVConnectorMetadata ):
8080
8181 def __init__ (self ):
82- self .requests : dict [ReqId , ReqMeta ] = {}
82+ self .reqs_to_recv : dict [ReqId , ReqMeta ] = {}
83+ self .reqs_to_send : dict [ReqId , float ] = {}
8384
8485 def add_new_req (
8586 self ,
8687 request_id : ReqId ,
8788 local_block_ids : list [int ],
8889 kv_transfer_params : dict [str , Any ],
8990 ):
90- self .requests [request_id ] = ReqMeta (
91+ self .reqs_to_recv [request_id ] = ReqMeta (
9192 local_block_ids = local_block_ids ,
9293 remote_block_ids = kv_transfer_params ["remote_block_ids" ],
9394 remote_engine_id = kv_transfer_params ["remote_engine_id" ],
@@ -194,10 +195,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
194195 vllm_config .parallel_config .tensor_parallel_size )
195196 logger .info ("Initializing NIXL Scheduler %s" , engine_id )
196197
197- # Requests that need to start recv.
198+ # Requests that need to start recv/send .
198199 # New requests are added by update_state_after_alloc in
199200 # the scheduler. Used to make metadata passed to Worker.
200201 self ._reqs_need_recv : dict [ReqId , tuple [Request , list [int ]]] = {}
202+ # Reqs to send and their expiration time
203+ self ._reqs_need_send : dict [ReqId , float ] = {}
201204
202205 def get_num_new_matched_tokens (
203206 self , request : "Request" ,
@@ -284,6 +287,9 @@ def build_connector_meta(
284287 # Clear the list once workers start the transfers
285288 self ._reqs_need_recv .clear ()
286289
290+ meta .reqs_to_send = self ._reqs_need_send
291+ self ._reqs_need_send = {}
292+
287293 return meta
288294
289295 def request_finished (
@@ -325,6 +331,11 @@ def request_finished(
325331 # If prompt < block_size, no xfer so free blocks immediately.
326332 delay_free_blocks = len (computed_block_ids ) > 0
327333
334+ if delay_free_blocks :
335+ # Prefill request on remote. It will be read from D upon completion
336+ self ._reqs_need_send [request .request_id ] = time .perf_counter (
337+ ) + envs .VLLM_NIXL_ABORT_REQUEST_TIMEOUT
338+
328339 return delay_free_blocks , dict (
329340 do_remote_prefill = True ,
330341 do_remote_decode = False ,
@@ -394,6 +405,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
394405 # In progress transfers.
395406 # [req_id -> list[handle]]
396407 self ._recving_transfers = defaultdict [ReqId , list [Transfer ]](list )
408+ # Track the expiration time of requests that are waiting to be sent.
409+ self ._reqs_to_send : dict [ReqId , float ] = {}
397410
398411 # Complete transfer tracker. Used by the rank 0 to track finished
399412 # transactions on ranks 1 to N-1.
@@ -826,6 +839,16 @@ def get_finished(self) -> tuple[set[str], set[str]]:
826839 "and %s requests done recving" , self .tp_rank ,
827840 len (done_sending ), len (done_recving ))
828841
842+ # Handle timeout to avoid stranding blocks on remote.
843+ now = time .perf_counter ()
844+ while self ._reqs_to_send :
845+ req_id , expires = next (iter (self ._reqs_to_send .items ()))
846+ # Plain dict keeps insertion order; requests are added with increasing expiry, so the oldest come first and we can exit early.
847+ if now < expires :
848+ break
849+ del self ._reqs_to_send [req_id ]
850+ done_sending .add (req_id )
851+
829852 if self .world_size == 1 :
830853 return done_sending , done_recving
831854
@@ -857,7 +880,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
857880
858881 all_done_sending : set [str ] = set ()
859882 for req_id in list (self ._done_sending_count .keys ()):
860- if self ._done_sending_count [req_id ] = = self .world_size :
883+ if self ._done_sending_count [req_id ] > = self .world_size :
861884 del self ._done_sending_count [req_id ]
862885 all_done_sending .add (req_id )
863886
@@ -887,6 +910,7 @@ def _get_new_notifs(self) -> set[str]:
887910 tp_ratio ):
888911 notified_req_ids .add (req_id )
889912 del self .consumer_notification_counts_by_req [req_id ]
913+ del self ._reqs_to_send [req_id ]
890914 return notified_req_ids
891915
892916 def _pop_done_transfers (
@@ -921,7 +945,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
921945 Start loading by triggering non-blocking nixl_xfer.
922946 We check for these transfers to complete in each step().
923947 """
924- for req_id , meta in metadata .requests .items ():
948+ for req_id , meta in metadata .reqs_to_recv .items ():
925949 remote_engine_id = meta .remote_engine_id
926950 logger .debug (
927951 "start_load_kv for request %s from remote engine %s. "
@@ -943,6 +967,9 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
943967 while not self ._ready_requests .empty ():
944968 self ._read_blocks_for_req (* self ._ready_requests .get_nowait ())
945969
970+ # Add to requests that are waiting to be read and track expiration.
971+ self ._reqs_to_send .update (metadata .reqs_to_send )
972+
946973 def _read_blocks_for_req (self , req_id : str , meta : ReqMeta ):
947974 logger .debug (
948975 "Remote agent %s available, calling _read_blocks for req %s" ,
0 commit comments