Skip to content

Commit f6dae94

Browse files
committed
[NIXL] Fix KeyError on abort-after-finished
We have observed a rare scenario with AsyncLLM where a client disconnect triggers an abort request after the request has finished, but before AsyncLLM has processed the request output. See vllm-project#26012, vllm-project#25067, vllm-project#25844, and llm-d/llm-d#187. Without the fix, the unit test fails with: ``` logger.warning( "Releasing expired KV blocks for request %s which were " "retrieved by %d decode worker(s) within %d seconds.", req_id, count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, ) > self._reqs_to_process.remove(req_id) E KeyError: '0' vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:1238: KeyError ``` Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent e1098ce commit f6dae94

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,93 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
877877
assert "0" not in req_to_blocks
878878

879879

880+
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
881+
@patch(
882+
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
883+
FakeNixlWrapper,
884+
)
885+
def test_abort_after_finish_on_prefiller(monkeypatch, distributed_executor_backend):
886+
"""
887+
Simulate a rare scenario with AsyncLLM where a client disconnect
888+
triggers an abort request after the request has finished, but before
889+
AsyncLLM has processed the request output
890+
"""
891+
model_name = "Qwen/Qwen3-0.6B"
892+
kv_transfer_config = KVTransferConfig(
893+
kv_connector="NixlConnector",
894+
kv_role="kv_both",
895+
)
896+
llm_kwargs = {
897+
"model": model_name,
898+
"enforce_eager": True,
899+
"gpu_memory_utilization": 0.5,
900+
"kv_transfer_config": kv_transfer_config,
901+
"distributed_executor_backend": distributed_executor_backend,
902+
}
903+
904+
timeout = 6
905+
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
906+
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
907+
908+
# Build runtime_env only if we're using Ray
909+
if distributed_executor_backend == "ray":
910+
with _make_fake_nixl_pkg() as working_dir:
911+
runtime_env = {
912+
"working_dir": working_dir, # ship fake nixl package
913+
"env_vars": {
914+
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
915+
# TODO: for ray to carry over, remove once we set
916+
"NIXL_TELEMETRY_ENABLE": "1",
917+
},
918+
}
919+
ray.init(runtime_env=runtime_env)
920+
921+
_run_abort_after_finish_test(llm_kwargs, timeout)
922+
else:
923+
_run_abort_after_finish_test(llm_kwargs, timeout)
924+
925+
926+
def _run_abort_after_finish_test(llm_kwargs: dict, timeout: int):
927+
"""Helper function to run the abort after finish test."""
928+
llm = LLM(**llm_kwargs)
929+
remote_prefill_opts = {
930+
"do_remote_decode": True,
931+
"do_remote_prefill": False,
932+
"remote_engine_id": None,
933+
"remote_block_ids": None,
934+
"remote_host": None,
935+
"remote_port": None,
936+
}
937+
# Simulate sidecar request
938+
sampling_params = SamplingParams(
939+
temperature=0.0,
940+
max_tokens=1,
941+
extra_args={"kv_transfer_params": remote_prefill_opts},
942+
)
943+
scheduler = llm.llm_engine.engine_core.engine_core.scheduler
944+
req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
945+
0
946+
].req_to_blocks
947+
948+
padding = "Just making this request a little longer so that we're sure "
949+
"we're not hitting the small-request lower bound beneath which we don't "
950+
"actually trigger the whole kv transfer, but rather just recompute the "
951+
"blocks on D."
952+
_ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
953+
954+
# Request finished but not freed
955+
assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks
956+
957+
# Request aborted and freed
958+
llm.llm_engine.engine_core.abort_requests(["0"])
959+
assert "0" not in req_to_blocks
960+
961+
# Wait for timeout and trigger another scheduler loop
962+
time.sleep(timeout)
963+
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
964+
# Timeout logic hasn't crashed!
965+
966+
880967
def test_register_kv_caches(dist_init):
881968
"""
882969
Test that register_kv_caches() properly calls nixl_wrapper methods with

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,6 +1344,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
13441344

13451345
# Remove all requests that are not to be processed (eg aborted).
13461346
for req_id in metadata.reqs_not_processed:
1347+
self._reqs_to_send.pop(req_id, None)
13471348
self._reqs_to_process.discard(req_id)
13481349

13491350
# Add to requests that are waiting to be read and track expiration.

0 commit comments

Comments
 (0)