
Commit d384771

[NIXL] Refactor scheduler->worker request state synchronization
Use a SCHEDULED/FINISHED/ABORTED enum rather than in_batch, to_send, and not_processed. Also move the expiry timestamp calculation to the worker side so we're not sending timestamps across process boundaries.

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 335b28f

2 files changed: +32 -32 lines changed
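The refactor described above reduces the scheduler->worker handoff to a single per-request ReqState value, with the abort-timeout deadline now derived on the worker from its own clock. Below is a minimal sketch of that pattern under simplified assumptions: ReqState mirrors the enum added in the diff, but WorkerStateSketch, apply_updates, and ABORT_REQUEST_TIMEOUT are illustrative stand-ins, not vLLM APIs.

import enum
import time

# Stand-in for envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT (value chosen for illustration).
ABORT_REQUEST_TIMEOUT = 480.0


class ReqState(enum.Enum):
    SCHEDULED = 1
    FINISHED = 2
    ABORTED = 3


class WorkerStateSketch:
    """Illustrative worker-side bookkeeping, not the real NixlConnectorWorker."""

    def __init__(self) -> None:
        self.reqs_to_process: set[str] = set()
        # req_id -> expiry deadline, computed locally on the worker side
        self.reqs_to_send: dict[str, float] = {}

    def apply_updates(self, updates: dict[str, ReqState]) -> None:
        for req_id, state in updates.items():
            if state == ReqState.SCHEDULED:
                self.reqs_to_process.add(req_id)
            elif state == ReqState.ABORTED:
                # Aborted before any expiry timer could have been armed.
                self.reqs_to_process.discard(req_id)
            elif state == ReqState.FINISHED and req_id in self.reqs_to_process:
                # Deadline comes from a local monotonic clock, so no timestamp
                # ever crosses the process boundary.
                self.reqs_to_send[req_id] = time.monotonic() + ABORT_REQUEST_TIMEOUT


# Example: SCHEDULED followed by FINISHED arms the timer; ABORTED drops tracking.
worker = WorkerStateSketch()
worker.apply_updates({"req-1": ReqState.SCHEDULED})
worker.apply_updates({"req-1": ReqState.FINISHED})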

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 5 additions & 2 deletions

@@ -30,6 +30,7 @@
     NixlConnectorMetadata,
     NixlConnectorWorker,
     NixlKVConnectorStats,
+    ReqState,
 )
 from vllm.distributed.kv_transfer.kv_transfer_state import (
     ensure_kv_transfer_shutdown,
@@ -1109,7 +1110,8 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
     kv_meta = sched_out.kv_connector_metadata
     assert kv_meta is not None
     assert isinstance(kv_meta, NixlConnectorMetadata)
-    assert req.request_id in kv_meta.reqs_in_batch
+    assert req.request_id in kv_meta.reqs_to_send
+    assert kv_meta.reqs_to_send[req.request_id] == ReqState.SCHEDULED

     #### Model Runner start ####
     # Bind scheduler-produced metadata and start worker processing.
@@ -1134,7 +1136,8 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
     kv_meta2 = sched_out2.kv_connector_metadata
     assert kv_meta2 is not None
     assert isinstance(kv_meta2, NixlConnectorMetadata)
-    assert req.request_id not in kv_meta2.reqs_in_batch
+    assert req.request_id in kv_meta2.reqs_to_send
+    assert kv_meta2.reqs_to_send[req.request_id] == ReqState.ABORTED

     # Bind empty/abort metadata and run worker step
     #### Model Runner start ####

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 27 additions & 30 deletions

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
 import copy
+import enum
 import logging
 import math
 import os
@@ -113,13 +114,17 @@ class ReqMeta:
     tp_size: int


+class ReqState(enum.Enum):
+    SCHEDULED = 1
+    FINISHED = 2
+    ABORTED = 3
+
+
 class NixlConnectorMetadata(KVConnectorMetadata):
     def __init__(self):
         self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
         self.reqs_to_save: dict[ReqId, ReqMeta] = {}
-        self.reqs_to_send: dict[ReqId, float] = {}
-        self.reqs_in_batch: set[ReqId] = set()
-        self.reqs_not_processed: set[ReqId] = set()
+        self.reqs_to_send: dict[ReqId, ReqState] = {}

     def add_new_req(
         self,
@@ -299,12 +304,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
         self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
-        # Reqs to send and their expiration time
-        self._reqs_need_send: dict[ReqId, float] = {}
-        self._reqs_in_batch: set[ReqId] = set()
-        # Reqs to remove from processed set because they're not to send after
-        # remote prefill or aborted.
-        self._reqs_not_processed: set[ReqId] = set()
+        # Reqs to send state updates
+        self._reqs_need_send: dict[ReqId, ReqState] = {}

     def get_num_new_matched_tokens(
         self, request: "Request", num_computed_tokens: int
@@ -356,7 +357,7 @@ def update_state_after_alloc(
             return

         if params.get("do_remote_decode"):
-            self._reqs_in_batch.add(request.request_id)
+            self._reqs_need_send[request.request_id] = ReqState.SCHEDULED
         if self.use_host_buffer and params.get("do_remote_decode"):
             # NOTE: when accelerator is not directly supported by Nixl,
             # prefilled blocks need to be saved to host memory before transfer.
@@ -429,14 +430,10 @@ def build_connector_meta(
             )

         meta.reqs_to_send = self._reqs_need_send
-        meta.reqs_in_batch = self._reqs_in_batch
-        meta.reqs_not_processed = self._reqs_not_processed

         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
         self._reqs_need_save.clear()
-        self._reqs_in_batch = set()
-        self._reqs_not_processed = set()
         self._reqs_need_send = {}

         return meta
@@ -477,7 +474,7 @@ def request_finished(
         if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
             # Also include the case of a P/D Prefill request with immediate
             # block free (eg abort). Stop tracking this request.
-            self._reqs_not_processed.add(request.request_id)
+            self._reqs_need_send[request.request_id] = ReqState.ABORTED
             return False, None

         # TODO: check whether block_ids actually ever be 0. If not we could
@@ -486,9 +483,7 @@

         if delay_free_blocks:
             # Prefill request on remote. It will be read from D upon completion
-            self._reqs_need_send[request.request_id] = (
-                time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
-            )
+            self._reqs_need_send[request.request_id] = ReqState.FINISHED

         return delay_free_blocks, dict(
             do_remote_prefill=True,
@@ -1221,7 +1216,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                 self.sync_recved_kv_to_device(req_id, meta)

         # Handle timeout to avoid stranding blocks on remote.
-        now = time.perf_counter()
+        now = time.monotonic()
         while self._reqs_to_send:
             req_id, expires = next(iter(self._reqs_to_send.items()))
             # Sorted dict, oldest requests are put first so we can exit early.
@@ -1339,17 +1334,19 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         # which blocks are read from D. As P can now more easily lag behind D
         # while processing the next batch, we make sure to only set an
         # expiration for requests that have not been read from D yet.
-        for req_id in metadata.reqs_in_batch:
-            self._reqs_to_process.add(req_id)
-
-        # Remove all requests that are not to be processed (eg aborted).
-        for req_id in metadata.reqs_not_processed:
-            self._reqs_to_process.discard(req_id)
-
-        # Add to requests that are waiting to be read and track expiration.
-        for req_id, expiration_time in metadata.reqs_to_send.items():
-            if req_id in self._reqs_to_process:
-                self._reqs_to_send[req_id] = expiration_time
+        for req_id, req_state in metadata.reqs_to_send.items():
+            if req_state == ReqState.SCHEDULED:
+                self._reqs_to_process.add(req_id)
+            elif req_state == ReqState.ABORTED:
+                # Remove all requests that are not to be processed (eg aborted).
+                self._reqs_to_process.discard(req_id)
+                # We should never get an abort after setting an expiry timer
+                assert req_id not in self._reqs_to_send
+            elif req_state == ReqState.FINISHED and req_id in self._reqs_to_process:
+                # Add to requests that are waiting to be read and track expiration.
+                self._reqs_to_send[req_id] = (
+                    time.monotonic() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
+                )

     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
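The get_finished() hunk above still relies on deadlines being inserted in arrival order so the scan can exit early. Here is a stand-alone sketch of that scan, assuming deadlines were stored with time.monotonic() as in start_load_kv; expired_requests is an illustrative helper, not a vLLM function.

import time


def expired_requests(reqs_to_send: dict[str, float]) -> set[str]:
    """Pop and return request ids whose expiry deadline has passed."""
    now = time.monotonic()
    expired: set[str] = set()
    for req_id, expires in reqs_to_send.items():
        # Dicts preserve insertion order and deadlines are nondecreasing,
        # so the first unexpired entry lets us stop scanning early.
        if now < expires:
            break
        expired.add(req_id)
    for req_id in expired:
        del reqs_to_send[req_id]
    return expired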
