Commit 9894745

NIXL: re-work send timeout tracking on prefill side
In a prefill instance, we need to free KV blocks that have not been fetched after a timeout. See #20139.

In #26012, we're trying to deal with corner cases that arise from doing this request timeout tracking on the worker side. This PR proposes moving all of it to the scheduler side, hopefully making the logic simpler.

Note the expiry timer is switched back to monotonic time because the timestamp is no longer sent across process boundaries.

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent: fc67969
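A minimal, self-contained sketch of the scheduler-side pattern this commit adopts: monotonic-clock deadlines kept in an insertion-ordered dict, scanned oldest-first so the expiry pass can stop at the first unexpired entry. The DeadlineTracker name and timeout_s parameter are illustrative only; the real implementation is the ReqsNeedSendTracker class added in the diff below.

import time


class DeadlineTracker:
    """Sketch: per-request deadlines, oldest first."""

    def __init__(self, timeout_s: float):
        self.timeout_s = timeout_s
        # Python dicts preserve insertion order, so with a constant
        # timeout the first entry always has the earliest deadline.
        self._deadlines: dict[str, float] = {}

    def start(self, req_id: str) -> None:
        self._deadlines[req_id] = time.monotonic() + self.timeout_s

    def cancel(self, req_id: str) -> None:
        self._deadlines.pop(req_id, None)

    def pop_expired(self) -> set[str]:
        expired: set[str] = set()
        now = time.monotonic()
        while self._deadlines:
            req_id, deadline = next(iter(self._deadlines.items()))
            if now < deadline:
                break  # later entries expire even later; stop scanning
            del self._deadlines[req_id]
            expired.add(req_id)
        return expired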

File tree

2 files changed: +117 -155 lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 0 additions & 67 deletions
@@ -40,7 +40,6 @@
 from vllm.sampling_params import SamplingParams
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
 from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
-from vllm.v1.request import RequestStatus

 from .utils import create_request, create_scheduler, create_vllm_config

@@ -1078,69 +1077,3 @@ def test_shutdown_cleans_up_resources(dist_init):
     assert mock_dereg.call_count == 2
     mock_dereg.assert_any_call("desc1")
     mock_dereg.assert_any_call("desc2")
-
-
-@patch(
-    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
-    FakeNixlWrapper,
-)
-def test_aborted_request_removed_from_worker_in_batch(dist_init):
-    """
-    Create and schedule a request so that P adds it to in-batch tracking via
-    the real scheduler, then simulate an abort (request not in next scheduler
-    iteration) and verify the worker no longer tracks it as in-batch.
-    """
-    vllm_config = create_vllm_config()
-
-    scheduler = create_scheduler(vllm_config)
-    # KVConnector Worker in P
-    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
-    connector.connector_worker = FakeNixlConnectorWorker(
-        vllm_config, connector.engine_id, hand_shake_latency=0
-    )
-
-    # Create a request that triggers do_remote_decode so that
-    # the scheduler adds it to reqs_in_batch
-    req = create_request(request_id=1, do_remote_decode=True, max_tokens=1)
-    scheduler.add_request(req)
-
-    # First scheduling pass - examine build_connector_meta output
-    sched_out = scheduler.schedule()
-    kv_meta = sched_out.kv_connector_metadata
-    assert kv_meta is not None
-    assert isinstance(kv_meta, NixlConnectorMetadata)
-    assert req.request_id in kv_meta.reqs_in_batch
-
-    #### Model Runner start ####
-    # Bind scheduler-produced metadata and start worker processing.
-    connector.bind_connector_metadata(kv_meta)
-
-    dummy_ctx = ForwardContext(
-        no_compile_layers={},
-        attn_metadata={},
-        virtual_engine=0,
-    )
-    connector.start_load_kv(dummy_ctx)
-
-    # Ensure it was tracked by the worker
-    assert req.request_id in connector.connector_worker._reqs_to_process
-
-    #### Model Runner end ####
-
-    # Abort request - request_finished call in connector scheduler
-    scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED)
-    # Second scheduling pass - build metadata with aborted request
-    sched_out2 = scheduler.schedule()
-    kv_meta2 = sched_out2.kv_connector_metadata
-    assert kv_meta2 is not None
-    assert isinstance(kv_meta2, NixlConnectorMetadata)
-    assert req.request_id not in kv_meta2.reqs_in_batch
-
-    # Bind empty/abort metadata and run worker step
-    #### Model Runner start ####
-    connector.bind_connector_metadata(kv_meta2)
-    connector.start_load_kv(dummy_ctx)
-
-    # After abort, the worker should not keep tracking it as "in-batch"
-    assert req.request_id not in connector.connector_worker._reqs_to_process
-    #### Model Runner end ####

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 117 additions & 88 deletions
@@ -47,6 +47,7 @@
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+    from vllm.v1.outputs import KVConnectorOutput
     from vllm.v1.request import Request

 Transfer = tuple[int, float]  # (xfer_handle, start_time)
@@ -117,9 +118,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
     def __init__(self):
         self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
         self.reqs_to_save: dict[ReqId, ReqMeta] = {}
-        self.reqs_to_send: dict[ReqId, float] = {}
-        self.reqs_in_batch: set[ReqId] = set()
-        self.reqs_not_processed: set[ReqId] = set()

     def add_new_req(
         self,
@@ -210,6 +208,13 @@ def build_connector_meta(
         assert self.connector_scheduler is not None
         return self.connector_scheduler.build_connector_meta(scheduler_output)

+    def update_connector_output(
+        self,
+        connector_output: "KVConnectorOutput",
+    ):
+        assert self.connector_scheduler is not None
+        return self.connector_scheduler.update_connector_output(connector_output)
+
     def request_finished(
         self,
         request: "Request",
@@ -278,6 +283,99 @@ def shutdown(self):
         self.connector_worker.shutdown()


+class ReqsNeedSendTracker:
+    @dataclass
+    class RequestTimer:
+        """Timer for requests that need to be sent for remote decode."""
+
+        expiry_time: float
+        """Expiry time to avoid stranded KV blocks that are never fetched."""
+        consumer_count: int
+        """Consumer notification count - with heterogeneous TP, P must wait
+        for all assigned D TP workers to finish reading before safely freeing
+        the blocks."""
+
+    def __init__(self):
+        self._reqs_need_send: dict[ReqId, ReqsNeedSendTracker.RequestTimer] = {}
+        self._timeout = envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
+
+    def start_timer(self, req_id):
+        self._reqs_need_send[req_id] = self.RequestTimer(
+            time.monotonic() + self._timeout, 0
+        )
+
+    def delete_timer(self, req_id):
+        if req_id not in self._reqs_need_send:
+            return
+        logger.debug("Deleting KV transfer timeout for request %s", req_id)
+        del self._reqs_need_send[req_id]
+
+    def _process_finished_notifs(self, finished_notifs: set[str]) -> set[str]:
+        """Process notifications from D and track consumer completion.
+
+        The notification strings are in "req_id:tp_ratio" format.
+
+        Return request IDs that have completed sending to all consumers, to be
+        used by the scheduler via KVConnectorOutput.finished_sending.
+        """
+        finished_sending: set[str] = set()
+        for notif in finished_notifs or ():
+            try:
+                req_id, tp_ratio = notif.rsplit(":", 1)
+            except (ValueError, TypeError) as e:
+                raise ValueError(f"Invalid notification: {notif}") from e
+
+            # "Sent" notifications received after we already timed out
+            if req_id not in self._reqs_need_send:
+                logger.debug(
+                    "Already finished or expired KV transfer for request %s", req_id
+                )
+                continue
+
+            # Wait for all consumers (D) to be done reading before freeing.
+            request_timer = self._reqs_need_send[req_id]
+            request_timer.consumer_count += 1
+            if request_timer.consumer_count < int(tp_ratio):
+                continue
+
+            logger.debug(
+                "KV transfer finished for request %s after retrieval by %d "
+                "decode worker(s).",
+                req_id,
+                request_timer.consumer_count,
+            )
+            del self._reqs_need_send[req_id]
+            finished_sending.add(req_id)
+
+        return finished_sending
+
+    def _abort_expired_requests(self, finished_sending: set[str]) -> set[str]:
+        """Abort requests that have passed their expiry timeout.
+
+        Adds aborted requests to KVConnectorOutput.finished_sending.
+        """
+        now = time.monotonic()
+        while self._reqs_need_send:
+            req_id, request_timer = next(iter(self._reqs_need_send.items()))
+            # Insertion-ordered dict; oldest first so we can exit early.
+            if now < request_timer.expiry_time:
+                break
+            logger.warning(
+                "Releasing expired KV blocks for request %s which were "
+                "retrieved by %d decode worker(s) within %d seconds.",
+                req_id,
+                request_timer.consumer_count,
+                self._timeout,
+            )
+            del self._reqs_need_send[req_id]
+            finished_sending.add(req_id)
+        return finished_sending
+
+    def reqs_finished_sending(self, finished_notifs: set[str]) -> set[str]:
+        finished_sending = self._process_finished_notifs(finished_notifs)
+        return self._abort_expired_requests(finished_sending)
+
+
 class NixlConnectorScheduler:
     """Implementation of Scheduler side methods"""

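To make the consumer-count accounting concrete, here is how the tracker above behaves when driven by hand. This is a sketch, not a test from the commit; it assumes envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT is set to its usual positive value so nothing expires mid-example.

tracker = ReqsNeedSendTracker()

# P finished prefilling req-1; its KV blocks must stay alive until D reads them.
tracker.start_timer("req-1")

# Heterogeneous TP with tp_ratio=2: two decode workers each send a
# "req_id:tp_ratio" notification, and only the second completes the send.
assert tracker.reqs_finished_sending({"req-1:2"}) == set()
assert tracker.reqs_finished_sending({"req-1:2"}) == {"req-1"}

# An aborted request simply has its timer dropped so it can never expire.
tracker.start_timer("req-2")
tracker.delete_timer("req-2")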
@@ -299,12 +397,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
         self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
-        # Reqs to send and their expiration time
-        self._reqs_need_send: dict[ReqId, float] = {}
-        self._reqs_in_batch: set[ReqId] = set()
-        # Reqs to remove from processed set because they're not to send after
-        # remote prefill or aborted.
-        self._reqs_not_processed: set[ReqId] = set()
+
+        self._reqs_need_send = ReqsNeedSendTracker()

     def get_num_new_matched_tokens(
         self, request: "Request", num_computed_tokens: int
@@ -355,8 +449,6 @@ def update_state_after_alloc(
         if not params:
             return

-        if params.get("do_remote_decode"):
-            self._reqs_in_batch.add(request.request_id)
         if self.use_host_buffer and params.get("do_remote_decode"):
             # NOTE: when accelerator is not directly supported by Nixl,
             # prefilled blocks need to be saved to host memory before transfer.
@@ -428,19 +520,20 @@
                 save_to_host=True,
             )

-        meta.reqs_to_send = self._reqs_need_send
-        meta.reqs_in_batch = self._reqs_in_batch
-        meta.reqs_not_processed = self._reqs_not_processed
-
         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
         self._reqs_need_save.clear()
-        self._reqs_in_batch = set()
-        self._reqs_not_processed = set()
-        self._reqs_need_send = {}

         return meta

+    def update_connector_output(
+        self,
+        connector_output: "KVConnectorOutput",
+    ):
+        connector_output.finished_sending = self._reqs_need_send.reqs_finished_sending(
+            connector_output.finished_sending
+        )
+
     def request_finished(
         self,
         request: "Request",
@@ -474,10 +567,10 @@

         if not params.get("do_remote_decode"):
             return False, None
+
         if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
-            # Also include the case of a P/D Prefill request with immediate
-            # block free (eg abort). Stop tracking this request.
-            self._reqs_not_processed.add(request.request_id)
+            # Request aborted after we delayed freeing the blocks?
+            self._reqs_need_send.delete_timer(request.request_id)
             return False, None

         # TODO: check whether block_ids actually ever be 0. If not we could
486579

487580
if delay_free_blocks:
488581
# Prefill request on remote. It will be read from D upon completion
489-
self._reqs_need_send[request.request_id] = (
490-
time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
491-
)
582+
self._reqs_need_send.start_timer(request.request_id)
492583

493584
return delay_free_blocks, dict(
494585
do_remote_prefill=True,
@@ -609,10 +700,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # [req_id -> list[handle]]
         self._recving_metadata: dict[ReqId, ReqMeta] = {}
         self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
-        # Track the expiration time of requests that are waiting to be sent.
-        self._reqs_to_send: dict[ReqId, float] = {}
-        # Set of requests that have been part of a batch, regardless of status.
-        self._reqs_to_process: set[ReqId] = set()

         # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
@@ -654,9 +741,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
-        # With heterogeneous TP, P must wait for all assigned D TP workers to
-        # finish reading before safely freeing the blocks.
-        self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
         self.xfer_stats = NixlKVConnectorStats()

     @staticmethod
@@ -1220,25 +1304,6 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                 assert meta, f"{req_id} not found in recving_metadata list"
                 self.sync_recved_kv_to_device(req_id, meta)

-        # Handle timeout to avoid stranding blocks on remote.
-        now = time.perf_counter()
-        while self._reqs_to_send:
-            req_id, expires = next(iter(self._reqs_to_send.items()))
-            # Sorted dict, oldest requests are put first so we can exit early.
-            if now < expires:
-                break
-            count = self.consumer_notification_counts_by_req.pop(req_id, 0)
-            logger.warning(
-                "Releasing expired KV blocks for request %s which were "
-                "retrieved by %d decode worker(s) within %d seconds.",
-                req_id,
-                count,
-                envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
-            )
-            self._reqs_to_process.remove(req_id)
-            del self._reqs_to_send[req_id]
-            done_sending.add(req_id)
-
         return done_sending, done_recving

     def _get_new_notifs(self) -> set[str]:
@@ -1250,26 +1315,8 @@ def _get_new_notifs(self) -> set[str]:
         notified_req_ids: set[str] = set()
         for notifs in self.nixl_wrapper.get_new_notifs().values():
             for notif in notifs:
-                req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
-                if (
-                    req_id not in self._reqs_to_send
-                    and req_id not in self._reqs_to_process
-                ):
-                    logger.error(
-                        "Potentially invalid KV blocks for "
-                        "unrecognized request %s were retrieved by "
-                        "a decode worker. They may have expired.",
-                        req_id,
-                    )
-                    continue
-
-                self.consumer_notification_counts_by_req[req_id] += 1
-                # Wait all consumers (D) to be done reading before freeing.
-                if self.consumer_notification_counts_by_req[req_id] == int(tp_ratio):
-                    notified_req_ids.add(req_id)
-                    del self.consumer_notification_counts_by_req[req_id]
-                    self._reqs_to_process.remove(req_id)
-                    self._reqs_to_send.pop(req_id, None)
+                # Note - this is in req_id:tp_ratio format
+                notified_req_ids.add(notif.decode("utf-8"))
         return notified_req_ids

     def _pop_done_transfers(
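One detail worth calling out in the new notification path: the scheduler-side parser splits from the right with maxsplit=1, so a request ID that itself contains ":" still round-trips correctly; only the trailing tp_ratio field is peeled off. A plain-Python illustration (the request ID here is made up):

notif = "cmpl-1234:abcd:2"  # hypothetical req_id containing a colon
req_id, tp_ratio = notif.rsplit(":", 1)
assert (req_id, int(tp_ratio)) == ("cmpl-1234:abcd", 2)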
@@ -1333,24 +1380,6 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         while not self._ready_requests.empty():
             self._read_blocks_for_req(*self._ready_requests.get_nowait())

-        # Keep around the requests that have been part of a batch. This is
-        # needed because async scheduling pushes the misalignment between the
-        # moment in which requests expiration is set (P side) and the moment in
-        # which blocks are read from D. As P can now more easily lag behind D
-        # while processing the next batch, we make sure to only set an
-        # expiration for requests that have not been read from D yet.
-        for req_id in metadata.reqs_in_batch:
-            self._reqs_to_process.add(req_id)
-
-        # Remove all requests that are not to be processed (eg aborted).
-        for req_id in metadata.reqs_not_processed:
-            self._reqs_to_process.discard(req_id)
-
-        # Add to requests that are waiting to be read and track expiration.
-        for req_id, expiration_time in metadata.reqs_to_send.items():
-            if req_id in self._reqs_to_process:
-                self._reqs_to_send[req_id] = expiration_time
-
     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
             "Remote agent %s available, calling _read_blocks for req %s",
