Skip to content

Commit 50390a3

Browse files
NickLucche and xuebwang-amd
authored and committed
[Bugfix] Fix _reqs_to_process leak on abort (vllm-project#26012)
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent b2e52e5 commit 50390a3

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from vllm.sampling_params import SamplingParams
3434
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
3535
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
36+
from vllm.v1.request import RequestStatus
3637

3738
from .utils import create_request, create_scheduler, create_vllm_config
3839

@@ -1023,3 +1024,68 @@ def test_shutdown_cleans_up_resources(dist_init):
10231024
assert mock_dereg.call_count == 2
10241025
mock_dereg.assert_any_call("desc1")
10251026
mock_dereg.assert_any_call("desc2")
1027+
1028+
1029+
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)
def test_aborted_request_removed_from_worker_in_batch(dist_init):
    """
    Regression test for the ``_reqs_to_process`` leak on abort.

    Create and schedule a request so that P (prefill) adds it to in-batch
    tracking via the real scheduler, then simulate an abort (request not in
    the next scheduler iteration) and verify the worker no longer tracks it
    as in-batch.
    """
    vllm_config = create_vllm_config()

    scheduler = create_scheduler(vllm_config)
    # KVConnector Worker in P
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
                                                         connector.engine_id,
                                                         hand_shake_latency=0)

    # Create a request that triggers do_remote_decode so that
    # the scheduler adds it to reqs_in_batch
    req = create_request(request_id=1, do_remote_decode=True, max_tokens=1)
    scheduler.add_request(req)

    # First scheduling pass - examine build_connector_meta output
    sched_out = scheduler.schedule()
    kv_meta = sched_out.kv_connector_metadata
    assert kv_meta is not None
    assert isinstance(kv_meta, NixlConnectorMetadata)
    assert req.request_id in kv_meta.reqs_in_batch

    #### Model Runner start ####
    # Bind scheduler-produced metadata and start worker processing.
    connector.bind_connector_metadata(kv_meta)

    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(dummy_ctx)

    # Ensure it was tracked by the worker
    assert req.request_id in connector.connector_worker._reqs_to_process
    #### Model Runner end ####

    # Abort request - request_finished call in connector scheduler
    scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED)
    # Second scheduling pass - build metadata with aborted request
    sched_out2 = scheduler.schedule()
    kv_meta2 = sched_out2.kv_connector_metadata
    assert kv_meta2 is not None
    assert isinstance(kv_meta2, NixlConnectorMetadata)
    assert req.request_id not in kv_meta2.reqs_in_batch

    # Bind empty/abort metadata and run worker step
    #### Model Runner start ####
    connector.bind_connector_metadata(kv_meta2)
    connector.start_load_kv(dummy_ctx)

    # After abort, the worker should not keep tracking it as "in-batch"
    assert req.request_id not in connector.connector_worker._reqs_to_process
    #### Model Runner end ####

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(self):
113113
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
114114
self.reqs_to_send: dict[ReqId, float] = {}
115115
self.reqs_in_batch: set[ReqId] = set()
116+
self.reqs_not_processed: set[ReqId] = set()
116117

117118
def add_new_req(
118119
self,
@@ -287,6 +288,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
287288
# Reqs to send and their expiration time
288289
self._reqs_need_send: dict[ReqId, float] = {}
289290
self._reqs_in_batch: set[ReqId] = set()
291+
# Reqs to remove from processed set because they're not to send after
292+
# remote prefill or aborted.
293+
self._reqs_not_processed: set[ReqId] = set()
290294

291295
def get_num_new_matched_tokens(
292296
self, request: "Request",
@@ -401,11 +405,13 @@ def build_connector_meta(
401405

402406
meta.reqs_to_send = self._reqs_need_send
403407
meta.reqs_in_batch = self._reqs_in_batch
408+
meta.reqs_not_processed = self._reqs_not_processed
404409

405410
# Clear the list once workers start the transfers
406411
self._reqs_need_recv.clear()
407412
self._reqs_need_save.clear()
408413
self._reqs_in_batch = set()
414+
self._reqs_not_processed = set()
409415
self._reqs_need_send = {}
410416

411417
return meta
@@ -439,8 +445,12 @@ def request_finished(
439445
params["do_remote_prefill"] = False
440446
return False, None
441447

442-
if (not params.get("do_remote_decode")
443-
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
448+
if not params.get("do_remote_decode"):
449+
return False, None
450+
if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
451+
# Also include the case of a P/D Prefill request with immediate
452+
# block free (eg abort). Stop tracking this request.
453+
self._reqs_not_processed.add(request.request_id)
444454
return False, None
445455

446456
# TODO: check whether block_ids actually ever be 0. If not we could
@@ -1234,6 +1244,10 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
12341244
for req_id in metadata.reqs_in_batch:
12351245
self._reqs_to_process.add(req_id)
12361246

1247+
# Remove all requests that are not to be processed (eg aborted).
1248+
for req_id in metadata.reqs_not_processed:
1249+
self._reqs_to_process.discard(req_id)
1250+
12371251
# Add to requests that are waiting to be read and track expiration.
12381252
for req_id, expiration_time in metadata.reqs_to_send.items():
12391253
if req_id in self._reqs_to_process:

0 commit comments

Comments
 (0)