@@ -105,6 +105,7 @@ def __init__(self):
105105 self .reqs_to_recv : dict [ReqId , ReqMeta ] = {}
106106 self .reqs_to_save : dict [ReqId , ReqMeta ] = {}
107107 self .reqs_to_send : dict [ReqId , float ] = {}
108+ self .reqs_in_batch : set [ReqId ] = set ()
108109
109110 def add_new_req (
110111 self ,
@@ -278,6 +279,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
278279 self ._reqs_need_save : dict [ReqId , tuple [Request , list [int ]]] = {}
279280 # Reqs to send and their expiration time
280281 self ._reqs_need_send : dict [ReqId , float ] = {}
282+ self ._reqs_in_batch : set [ReqId ] = set ()
281283
282284 def get_num_new_matched_tokens (
283285 self , request : "Request" ,
@@ -324,6 +326,9 @@ def update_state_after_alloc(self, request: "Request",
324326
325327 if not params :
326328 return
329+
330+ if params .get ("do_remote_decode" ):
331+ self ._reqs_in_batch .add (request .request_id )
327332 if self .use_host_buffer and params .get ("do_remote_decode" ):
328333 # NOTE: when accelerator is not directly supported by Nixl,
329334 # prefilled blocks need to be saved to host memory before transfer.
@@ -373,6 +378,8 @@ def build_connector_meta(
373378 request_id = req_id ,
374379 local_block_ids = block_ids ,
375380 kv_transfer_params = req .kv_transfer_params ,
381+ load_remote_cache = True ,
382+ save_to_host = False ,
376383 )
377384
378385 for req_id , (req , block_ids ) in self ._reqs_need_save .items ():
@@ -386,10 +393,12 @@ def build_connector_meta(
386393 )
387394
388395 meta .reqs_to_send = self ._reqs_need_send
396+ meta .reqs_in_batch = self ._reqs_in_batch
389397
390398 # Clear the list once workers start the transfers
391399 self ._reqs_need_recv .clear ()
392400 self ._reqs_need_save .clear ()
401+ self ._reqs_in_batch = set ()
393402 self ._reqs_need_send = {}
394403
395404 return meta
@@ -546,6 +555,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
546555 self ._recving_transfers = defaultdict [ReqId , list [Transfer ]](list )
547556 # Track the expiration time of requests that are waiting to be sent.
548557 self ._reqs_to_send : dict [ReqId , float ] = {}
558+ # Set of requests that have been part of a batch, regardless of status.
559+ self ._reqs_to_process : set [ReqId ] = set ()
549560
550561 # Background thread for handling new handshake requests.
551562 self ._nixl_handshake_listener_t : Optional [threading .Thread ] = None
@@ -1082,6 +1093,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
10821093 "Releasing expired KV blocks for request %s which were "
10831094 "retrieved by %d decode worker(s) within %d seconds." , req_id ,
10841095 count , envs .VLLM_NIXL_ABORT_REQUEST_TIMEOUT )
1096+ self ._reqs_to_process .remove (req_id )
10851097 del self ._reqs_to_send [req_id ]
10861098 done_sending .add (req_id )
10871099
@@ -1097,7 +1109,8 @@ def _get_new_notifs(self) -> set[str]:
10971109 for notifs in self .nixl_wrapper .get_new_notifs ().values ():
10981110 for notif in notifs :
10991111 req_id , tp_ratio = notif .decode ("utf-8" ).rsplit (":" , 1 )
1100- if req_id not in self ._reqs_to_send :
1112+ if (req_id not in self ._reqs_to_send
1113+ and req_id not in self ._reqs_to_process ):
11011114 logger .error (
11021115 "Potentially invalid KV blocks for "
11031116 "unrecognized request %s were retrieved by "
@@ -1110,7 +1123,8 @@ def _get_new_notifs(self) -> set[str]:
11101123 tp_ratio ):
11111124 notified_req_ids .add (req_id )
11121125 del self .consumer_notification_counts_by_req [req_id ]
1113- del self ._reqs_to_send [req_id ]
1126+ self ._reqs_to_process .remove (req_id )
1127+ self ._reqs_to_send .pop (req_id , None )
11141128 return notified_req_ids
11151129
11161130 def _pop_done_transfers (
@@ -1171,8 +1185,19 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
11711185 while not self ._ready_requests .empty ():
11721186 self ._read_blocks_for_req (* self ._ready_requests .get_nowait ())
11731187
1188+ # Keep track of the requests that have been part of a batch. This is
1189+ # needed because async scheduling widens the gap between the moment at
1190+ # which a request's expiration is set (P side) and the moment at which
1191+ # its blocks are read by D. As P can now more easily lag behind D while
1192+ # processing the next batch, we make sure to only set an expiration for
1193+ # requests that have not been read by D yet.
1194+ for req_id in metadata .reqs_in_batch :
1195+ self ._reqs_to_process .add (req_id )
1196+
11741197 # Add to requests that are waiting to be read and track expiration.
1175- self ._reqs_to_send .update (metadata .reqs_to_send )
1198+ for req_id , expiration_time in metadata .reqs_to_send .items ():
1199+ if req_id in self ._reqs_to_process :
1200+ self ._reqs_to_send [req_id ] = expiration_time
11761201
11771202 def _read_blocks_for_req (self , req_id : str , meta : ReqMeta ):
11781203 logger .debug (
0 commit comments