22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33import contextlib
44import copy
5+ import enum
56import logging
67import math
78import os
@@ -113,13 +114,17 @@ class ReqMeta:
113114 tp_size : int
114115
115116
class ReqState(enum.Enum):
    """Lifecycle state of a remote-prefill request, passed from the
    scheduler-side connector to the worker-side connector."""

    SCHEDULED = enum.auto()  # request entered the batch; begin tracking it
    FINISHED = enum.auto()   # prefill complete; awaiting read by the decode side
    ABORTED = enum.auto()    # request aborted; stop tracking it
122+
116123class NixlConnectorMetadata (KVConnectorMetadata ):
    def __init__(self):
        """Per-step transfer metadata built by the scheduler-side connector
        and consumed by the worker-side connector."""
        # Requests whose prefilled KV blocks must be received (read) from
        # the remote agent.
        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
        # Requests whose prefilled blocks must be saved to host memory
        # before transfer (used when the accelerator is not directly
        # supported by Nixl).
        self.reqs_to_save: dict[ReqId, ReqMeta] = {}
        # State updates (SCHEDULED / FINISHED / ABORTED) for requests whose
        # blocks will be read by the remote decode side.
        self.reqs_to_send: dict[ReqId, ReqState] = {}
123128
124129 def add_new_req (
125130 self ,
@@ -299,12 +304,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
299304 # the scheduler. Used to make metadata passed to Worker.
300305 self ._reqs_need_recv : dict [ReqId , tuple [Request , list [int ]]] = {}
301306 self ._reqs_need_save : dict [ReqId , tuple [Request , list [int ]]] = {}
302- # Reqs to send and their expiration time
303- self ._reqs_need_send : dict [ReqId , float ] = {}
304- self ._reqs_in_batch : set [ReqId ] = set ()
305- # Reqs to remove from processed set because they're not to send after
306- # remote prefill or aborted.
307- self ._reqs_not_processed : set [ReqId ] = set ()
307+ # Reqs to send state updates
308+ self ._reqs_need_send : dict [ReqId , ReqState ] = {}
308309
309310 def get_num_new_matched_tokens (
310311 self , request : "Request" , num_computed_tokens : int
@@ -356,7 +357,7 @@ def update_state_after_alloc(
356357 return
357358
358359 if params .get ("do_remote_decode" ):
359- self ._reqs_in_batch . add ( request .request_id )
360+ self ._reqs_need_send [ request .request_id ] = ReqState . SCHEDULED
360361 if self .use_host_buffer and params .get ("do_remote_decode" ):
361362 # NOTE: when accelerator is not directly supported by Nixl,
362363 # prefilled blocks need to be saved to host memory before transfer.
@@ -429,14 +430,10 @@ def build_connector_meta(
429430 )
430431
431432 meta .reqs_to_send = self ._reqs_need_send
432- meta .reqs_in_batch = self ._reqs_in_batch
433- meta .reqs_not_processed = self ._reqs_not_processed
434433
435434 # Clear the list once workers start the transfers
436435 self ._reqs_need_recv .clear ()
437436 self ._reqs_need_save .clear ()
438- self ._reqs_in_batch = set ()
439- self ._reqs_not_processed = set ()
440437 self ._reqs_need_send = {}
441438
442439 return meta
@@ -477,7 +474,7 @@ def request_finished(
477474 if request .status != RequestStatus .FINISHED_LENGTH_CAPPED :
478475 # Also include the case of a P/D Prefill request with immediate
479476 # block free (eg abort). Stop tracking this request.
480- self ._reqs_not_processed . add ( request .request_id )
477+ self ._reqs_need_send [ request .request_id ] = ReqState . ABORTED
481478 return False , None
482479
483480 # TODO: check whether block_ids actually ever be 0. If not we could
@@ -486,9 +483,7 @@ def request_finished(
486483
487484 if delay_free_blocks :
488485 # Prefill request on remote. It will be read from D upon completion
489- self ._reqs_need_send [request .request_id ] = (
490- time .perf_counter () + envs .VLLM_NIXL_ABORT_REQUEST_TIMEOUT
491- )
486+ self ._reqs_need_send [request .request_id ] = ReqState .FINISHED
492487
493488 return delay_free_blocks , dict (
494489 do_remote_prefill = True ,
@@ -1221,7 +1216,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
12211216 self .sync_recved_kv_to_device (req_id , meta )
12221217
12231218 # Handle timeout to avoid stranding blocks on remote.
1224- now = time .perf_counter ()
1219+ now = time .monotonic ()
12251220 while self ._reqs_to_send :
12261221 req_id , expires = next (iter (self ._reqs_to_send .items ()))
12271222 # Sorted dict, oldest requests are put first so we can exit early.
@@ -1339,17 +1334,19 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
13391334 # which blocks are read from D. As P can now more easily lag behind D
13401335 # while processing the next batch, we make sure to only set an
13411336 # expiration for requests that have not been read from D yet.
1342- for req_id in metadata .reqs_in_batch :
1343- self ._reqs_to_process .add (req_id )
1344-
1345- # Remove all requests that are not to be processed (eg aborted).
1346- for req_id in metadata .reqs_not_processed :
1347- self ._reqs_to_process .discard (req_id )
1348-
1349- # Add to requests that are waiting to be read and track expiration.
1350- for req_id , expiration_time in metadata .reqs_to_send .items ():
1351- if req_id in self ._reqs_to_process :
1352- self ._reqs_to_send [req_id ] = expiration_time
1337+ for req_id , req_state in metadata .reqs_to_send .items ():
1338+ if req_state == ReqState .SCHEDULED :
1339+ self ._reqs_to_process .add (req_id )
1340+ elif req_state == ReqState .ABORTED :
1341+ # Remove all requests that are not to be processed (eg aborted).
1342+ self ._reqs_to_process .discard (req_id )
1343+ # We should never get an abort after setting an expiry timer
1344+ assert req_id not in self ._reqs_to_send
1345+ elif req_state == ReqState .FINISHED and req_id in self ._reqs_to_process :
1346+ # Add to requests that are waiting to be read and track expiration.
1347+ self ._reqs_to_send [req_id ] = (
1348+ time .monotonic () + envs .VLLM_NIXL_ABORT_REQUEST_TIMEOUT
1349+ )
13531350
13541351 def _read_blocks_for_req (self , req_id : str , meta : ReqMeta ):
13551352 logger .debug (
0 commit comments