Skip to content

Commit 784c231

Browse files
authored
[NIXL] Ignore abort on already-finished request (vllm-project#25067)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 606b00e commit 784c231

File tree

4 files changed

+64
-16
lines changed

4 files changed

+64
-16
lines changed

tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_basic_lifecycle():
4343
# STEP (1): Prefill.
4444
# (1a): schedule()
4545
scheduler_output = scheduler.schedule()
46+
assert len(scheduler.requests) == 1
4647
assert len(scheduler.running) == 1
4748
assert len(scheduler_output.scheduled_new_reqs) == 1
4849

@@ -67,6 +68,7 @@ def test_basic_lifecycle():
6768
assert len(scheduler.waiting) == 0
6869

6970
# ... but blocks should not be freed.
71+
assert len(scheduler.requests) == 1
7072
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
7173
0
7274
].req_to_blocks[request_id]
@@ -76,6 +78,7 @@ def test_basic_lifecycle():
7678
# STEP (2): Send Finished to PB.
7779
# (2a): schedule() - pass finished request to PB.
7880
scheduler_output = scheduler.schedule()
81+
assert len(scheduler.requests) == 1
7982
assert len(scheduler.running) == 0
8083
assert len(scheduler_output.finished_req_ids) == 1
8184
assert request_id in scheduler_output.finished_req_ids
@@ -92,6 +95,7 @@ def test_basic_lifecycle():
9295
# STEP (3): Finished sending.
9396
# (3a): schedule() - pass finished request to PB.
9497
scheduler_output = scheduler.schedule()
98+
assert len(scheduler.requests) == 1
9599
assert len(scheduler.running) == 0
96100
assert len(scheduler_output.finished_req_ids) == 0
97101
assert len(scheduler_output.scheduled_new_reqs) == 0
@@ -133,6 +137,7 @@ def test_short_prompt_lifecycle():
133137
# STEP (1): Prefill.
134138
# (1a): schedule()
135139
scheduler_output = scheduler.schedule()
140+
assert len(scheduler.requests) == 1
136141
assert len(scheduler.running) == 1
137142
assert len(scheduler_output.scheduled_new_reqs) == 1
138143

@@ -178,7 +183,7 @@ def test_prefix_cache_lifecycle():
178183
reqs=[request_normal], use_eos=True
179184
)
180185
scheduler.update_from_output(scheduler_output, model_runner_output)
181-
scheduler.schedule()
186+
scheduler_output = scheduler.schedule()
182187
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
183188

184189
#####################
@@ -213,3 +218,45 @@ def test_prefix_cache_lifecycle():
213218
)
214219
scheduler.update_from_output(scheduler_output, model_runner_output)
215220
assert_scheduler_empty(scheduler)
221+
222+
223+
def test_abort_during_kv_transfer():
224+
"""Test aborting request does not release blocks for remote decode."""
225+
226+
vllm_config = create_vllm_config()
227+
scheduler = create_scheduler(vllm_config)
228+
229+
# Prime the KVCache.
230+
BLOCK_SIZE = vllm_config.cache_config.block_size
231+
NUM_EXTERNAL_FULL_BLOCKS = 2
232+
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
233+
234+
request = create_request(
235+
request_id=1,
236+
block_size=BLOCK_SIZE,
237+
num_tokens=NUM_TOKENS,
238+
do_remote_decode=True,
239+
)
240+
241+
scheduler.add_request(request)
242+
scheduler_output = scheduler.schedule()
243+
model_runner_output = create_model_runner_output(reqs=[request])
244+
scheduler.update_from_output(scheduler_output, model_runner_output)
245+
scheduler_output = scheduler.schedule()
246+
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
247+
248+
# Request removed from PB but blocks should not be freed.
249+
assert len(scheduler.requests) == 1
250+
251+
# Abort the request, and check the blocks are still not freed
252+
scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED)
253+
assert len(scheduler.requests) == 1
254+
255+
# Simulate a finished sending notification
256+
scheduler_output = scheduler.schedule()
257+
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
258+
model_runner_output.kv_connector_output = KVConnectorOutput(
259+
finished_sending=[request.request_id]
260+
)
261+
scheduler.update_from_output(scheduler_output, model_runner_output)
262+
assert_scheduler_empty(scheduler)

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
temporary buffer alloc by the CacheManager.
1515
update_connector_output() - update KVConnector state after
1616
output is received from worker-side connectors.
17-
request_finished() - called when a request is finished, with
18-
the computed kv cache blocks for the request.
19-
Returns whether KV cache should be freed now or will be
20-
freed asynchronously and optionally returns KV transfer
21-
params.
17+
request_finished() - called once when a request is finished,
18+
with the computed kv cache blocks for the request.
19+
Returns whether KV cache should be freed now or if the
20+
connector now assumes responsibility for freeing the
21+
blocks asynchronously. Also optionally returns KV
22+
transfer params.
2223
take_events() - returns new KV events that were collected
2324
by the connector since the last call.
2425
@@ -362,7 +363,11 @@ def request_finished(
362363
block_ids: list[int],
363364
) -> tuple[bool, Optional[dict[str, Any]]]:
364365
"""
365-
Called when a request has finished, before its blocks are freed.
366+
Called exactly once when a request has finished, before its blocks are
367+
freed.
368+
369+
The connector may assume responsibility for freeing the blocks
370+
asynchronously by returning True.
366371
367372
Returns:
368373
True if the request is being saved/sent asynchronously and blocks

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
13451345
# Remove all requests that are not to be processed (eg aborted).
13461346
for req_id in metadata.reqs_not_processed:
13471347
self._reqs_to_process.discard(req_id)
1348+
# We should never get an abort after setting an expiry timer
1349+
assert req_id not in self._reqs_to_send
13481350

13491351
# Add to requests that are waiting to be read and track expiration.
13501352
for req_id, expiration_time in metadata.reqs_to_send.items():

vllm/v1/core/sched/scheduler.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,7 @@ def finish_requests(
11871187
# First pass: collect requests to remove from queues
11881188
for req_id in request_ids:
11891189
request = self.requests.get(req_id)
1190-
if request is None:
1190+
if request is None or request.is_finished():
11911191
# Invalid request ID, or the request has already finished.
11921192
continue
11931193

@@ -1365,14 +1365,8 @@ def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
13651365
self.finished_recving_kv_req_ids.add(req_id)
13661366
for req_id in kv_connector_output.finished_sending or ():
13671367
logger.debug("Finished sending KV transfer for request %s", req_id)
1368-
if req_id not in self.requests:
1369-
logger.warning(
1370-
"Got finished sending KV transfer for request %s,"
1371-
"but the request is already freed.",
1372-
req_id,
1373-
)
1374-
else:
1375-
self._free_blocks(self.requests[req_id])
1368+
assert req_id in self.requests
1369+
self._free_blocks(self.requests[req_id])
13761370

13771371
def _update_requests_with_invalid_blocks(
13781372
self, requests: Iterable[Request], invalid_block_ids: set[int]

0 commit comments

Comments
 (0)