
Conversation

@tlrmchlsmth (Member) commented Oct 15, 2025

This fixes a race condition where the KV transfer times out before the request is marked finished, and adds a unit test.

Here's a log from the failure:

(EngineCore_DP3 pid=654) DEBUG 10-15 12:29:26 [distributed/.../v1/nixl_connector.py:322] NIXLConnector update_state_after_alloc: num_external_tokens=0, kv_transfer_params={'do_remote_decode': True, 'do_remote_prefill': False, 'remote_block_ids': None, 'remote_engine_id': None, 'remote_host': None, 'remote_port': None}
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710] EngineCore encountered a fatal error.
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710] Traceback (most recent call last):
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]   File "/app/vllm/vllm/v1/engine/core.py", line 701, in run_engine_core
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]     engine_core.run_busy_loop()
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]   File "/app/vllm/vllm/v1/engine/core.py", line 1045, in run_busy_loop
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]     executed = self._process_engine_step()
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]   File "/app/vllm/vllm/v1/engine/core.py", line 754, in _process_engine_step
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]     outputs, model_executed = self.step_fn()
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]                               ^^^^^^^^^^^^^^
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]   File "/app/vllm/vllm/v1/engine/core.py", line 349, in step_with_batch_queue
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]     engine_core_outputs = self.scheduler.update_from_output(
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]   File "/app/vllm/vllm/v1/core/sched/scheduler.py", line 990, in update_from_output
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]     self._update_from_kv_xfer_finished(
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]   File "/app/vllm/vllm/v1/core/sched/scheduler.py", line 1296, in _update_from_kv_xfer_finished
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]     self._free_blocks(self.requests[req_id])
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]   File "/app/vllm/vllm/v1/core/sched/scheduler.py", line 1163, in _free_blocks
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]     assert request.is_finished()
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP7 pid=658) ERROR 10-15 12:29:26 [v1/engine/core.py:710] AssertionError

… finishes

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@gemini-code-assist bot left a comment

Code Review

This pull request correctly addresses a critical race condition in the scheduler. The race condition occurred when a KV transfer timed out, causing a worker to report a request as finished_sending before the scheduler considered the request to be in a finished state. This led to an AssertionError and a crash when attempting to free blocks for a still-running request. The fix introduces robust checks to verify a request's status before attempting to free its blocks, replacing a dangerous assertion with graceful handling and a warning log. A targeted test case has been added to simulate this specific race condition, ensuring the fix is effective and preventing future regressions. The changes are well-implemented and I see no issues.
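
For illustration, here is a minimal sketch of a regression test along the lines the review describes. It is a sketch only: the helper names create_scheduler / create_requests, the import path, and the _update_from_kv_xfer_finished call signature are assumptions modeled on vLLM's existing v1 scheduler tests, not the PR's actual test code.

from vllm.v1.outputs import KVConnectorOutput  # assumed import path

def test_finished_sending_for_running_request():
    # Assumed test helpers, modeled on vLLM's existing v1 scheduler tests.
    scheduler = create_scheduler(use_kv_connector=True)
    request = create_requests(num_requests=1)[0]
    scheduler.add_request(request)
    scheduler.schedule()

    # The worker reports finished_sending (e.g. after a NIXL timeout) while
    # the request is still RUNNING on the scheduler side.
    output = KVConnectorOutput(finished_sending={request.request_id})
    scheduler._update_from_kv_xfer_finished(output)

    # Before the fix this path hit `assert request.is_finished()` inside
    # _free_blocks(); afterwards the scheduler should handle it gracefully
    # and keep tracking the request.
    assert request.request_id in scheduler.requests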

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 1366 to +1385
 for req_id in kv_connector_output.finished_sending or ():
     logger.debug("Finished sending KV transfer for request %s", req_id)
-    assert req_id in self.requests
-    self._free_blocks(self.requests[req_id])
+    request = self.requests.get(req_id)
+    if request is None:
+        logger.warning(
+            "Got finished sending KV transfer for request %s, "
+            "but the request is already freed.",
+            req_id,
+        )
+    elif not request.is_finished():
+        logger.warning(
+            "Got finished sending KV transfer for request %s, "
+            "but the request is not finished (status=%s). "
+            "This may indicate the request was aborted or the KV "
+            "transfer timed out before the request completed.",
+            req_id,
+            request.status,
+        )
+    else:
+        self._free_blocks(request)


P1: Track premature KV transfers to avoid leaked blocks

When a worker reports finished_sending before the scheduler marks the request as finished, the new branch only logs a warning and returns. No state is recorded that the transfer has already completed. Once the request does eventually finish, _free_request() still calls _connector_finished(), which typically returns delay_free_blocks=True for remote decode, so _free_blocks() is never invoked unless another finished_sending arrives. Since the worker already emitted its only finished_sending event during the timeout, the request is left in self.requests and its KV blocks remain allocated indefinitely, leaking cache space and preventing the scheduler from recycling memory. The handler needs to persist that the transfer already finished (or immediately free when the request later finishes) rather than merely warn.
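
As a rough sketch of the direction Codex is suggesting (the attribute _early_finished_sending and the simplified method bodies are hypothetical, not vLLM's actual scheduler code), the early event could be recorded and consumed once the request actually finishes:

def _update_from_kv_xfer_finished(self, kv_connector_output) -> None:
    for req_id in kv_connector_output.finished_sending or ():
        request = self.requests.get(req_id)
        if request is None or not request.is_finished():
            # The worker's only finished_sending event arrived early (e.g. a
            # NIXL timeout); remember it instead of just logging a warning.
            self._early_finished_sending.add(req_id)
        else:
            self._free_blocks(request)

def _free_request(self, request):
    delay_free_blocks, kv_xfer_params = self._connector_finished(request)
    if delay_free_blocks and request.request_id in self._early_finished_sending:
        # The send already completed, so there is nothing left to wait for;
        # free the blocks now instead of leaking them.
        self._early_finished_sending.discard(request.request_id)
        delay_free_blocks = False
    if not delay_free_blocks:
        self._free_blocks(request)
    return kv_xfer_params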



Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@markmc (Member) left a comment


My sense is there's a bug to be fixed here, not papered over

"the KV transfer times out before the request is marked finished"

That shouldn't happen (AIUI at least). If that is what's happening, I think we should figure out why

but not actually in a finished state on the scheduler side.
This can happen when:
1. Worker-side NIXL connector times out waiting for decode workers

But this timeout should only be started once the request has finished

The timeout is started at the delay_free_blocks spot in ConnectorScheduler.request_finished()
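
Very roughly, the ordering being described looks like the following simplified sketch. The helpers _start_send_timeout and _build_kv_transfer_params are hypothetical; only the delay_free_blocks return value and the timeout relationship come from this discussion.

def request_finished(self, request, block_ids):
    # Only remote-decode requests ask the scheduler to delay freeing blocks.
    if not (request.kv_transfer_params or {}).get("do_remote_decode"):
        return False, None

    # delay_free_blocks path: the decode side still has to pull these blocks,
    # so the VLLM_NIXL_ABORT_REQUEST_TIMEOUT countdown effectively starts
    # here, i.e. only after the request has finished on the scheduler side.
    self._start_send_timeout(request.request_id)                      # hypothetical
    return True, self._build_kv_transfer_params(request, block_ids)   # hypothetical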

This can happen when:
1. Worker-side NIXL connector times out waiting for decode workers
2. Worker reports request in finished_sending to prevent stranding blocks

Requests are added to finished_sending by returning them from ConnectorWorker.get_finished() in two cases:

  1. We've received the required number of xfer notifications
  2. The timeout expired

Did you mean one of these? Or some other way that a request is added to finished_sending?

(e.g. are you thinking of some connector other than NIXL, or ...?)
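
Sketching the two cases above (sending side only; _reqs_to_send matches the attribute name that appears in the traceback further down this thread, while _done_sending_count and _expected_notifications are illustrative stand-ins):

import time

def get_finished(self) -> set[str]:
    finished_sending: set[str] = set()
    now = time.monotonic()
    for req_id, expiry in list(self._reqs_to_send.items()):
        if self._done_sending_count[req_id] >= self._expected_notifications:
            # Case 1: the required number of xfer notifications arrived.
            finished_sending.add(req_id)
        elif now >= expiry:
            # Case 2: the timeout expired; report the request anyway so the
            # scheduler can release its blocks instead of stranding them.
            finished_sending.add(req_id)
    for req_id in finished_sending:
        self._reqs_to_send.pop(req_id, None)
    return finished_sending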

This can happen when:
1. Worker-side NIXL connector times out waiting for decode workers
2. Worker reports request in finished_sending to prevent stranding blocks
3. Scheduler-side request hasn't reached a finished state yet

The decode side is somehow getting these block IDs before the prefill side has finished?


Just thinking about this possibility - on the prefill side we get a notification from decode that the blocks for this request have been transferred. But the request is still not finished on the prefill side ...

Assumption: it is impossible for the decode side to be notifying prefill about a request before prefill returns kv_transfer_params, which happens here:

    def _free_request(self, request: Request) -> dict[str, Any] | None:
        assert request.is_finished()

        delay_free_blocks, kv_xfer_params = self._connector_finished(request)
        ...
        if not delay_free_blocks:
            self._free_blocks(request)

        return kv_xfer_params

If the assumption above is correct, then the scenario looks impossible - the request must be finished before prefill returns kv_transfer_params?


Another theory - the request is preempted after it finished and is waiting for its KV blocks to be fetched? Doesn't seem possible - we only choose requests to preempt from Scheduler.running

"Got finished sending KV transfer for request %s, "
"but the request is already freed.",
req_id,
)

This was deliberately removed by #25067 - it's a bug if it happens; the warning is not actionable by users, and at best it should be a "fix this bug!" debug statement

"Got finished sending KV transfer for request %s, "
"but the request is not finished (status=%s). "
"This may indicate the request was aborted or the KV "
"transfer timed out before the request completed.",

Again, not actionable by a user - it's either an expected scenario that we can safely ignore (with no logging) or something that's a bug if it happens

Comment on lines 1366 to +1385
 for req_id in kv_connector_output.finished_sending or ():
     logger.debug("Finished sending KV transfer for request %s", req_id)
-    assert req_id in self.requests
-    self._free_blocks(self.requests[req_id])
+    request = self.requests.get(req_id)
+    if request is None:
+        logger.warning(
+            "Got finished sending KV transfer for request %s, "
+            "but the request is already freed.",
+            req_id,
+        )
+    elif not request.is_finished():
+        logger.warning(
+            "Got finished sending KV transfer for request %s, "
+            "but the request is not finished (status=%s). "
+            "This may indicate the request was aborted or the KV "
+            "transfer timed out before the request completed.",
+            req_id,
+            request.status,
+        )
+    else:
+        self._free_blocks(request)

It's a good question to wonder what will happen if ConnectorScheduler.request_finished() gets called after the request was already reported as finished_sending

# reports this request as finished_sending, even though the request
# is still RUNNING on the scheduler side.
# This simulates the timeout scenario in NIXL connector.
kv_connector_output = KVConnectorOutput(finished_sending={request.request_id})

It's a similar situation to the abort-after-finished race condition in #25067 - I wrote a unit test to artificially simulate it, but it took quite a bit of digging to understand how it could happen and figure out how we wanted to handle it

@markmc (Member) commented Oct 16, 2025

Here's a log from the failure:

Was this preceded by a "Releasing expired KV blocks for request ..." warning?

@njhill (Member) commented Oct 16, 2025

Thanks @markmc, I reached a similar conclusion: the error in question should not be possible, and we should get to the bottom of it and fix the root cause rather than adding these new checks.

My only theory as to how it could have happened is if the prefiller received more than one request with the same id.

@markmc (Member) commented Oct 17, 2025

Ok, I think I can reproduce your repeated-request-ID theory @njhill with some effort

Here's the server:

CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5559 VLLM_NIXL_ABORT_REQUEST_TIMEOUT=30 VLLM_LOGGING_LEVEL=DEBUG UCX_LOG_LEVEL=debug  vllm serve meta-llama/Llama-3.2-1B-Instruct --port 8100 --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' --enable-log-requests --distributed-executor-backend mp | tee prefill-$(date +%s).log

And I spin up lots of long-random-context requests with the same request ID, hit Ctrl-C, repeat until crash:

$ while true; do bash req.sh; done

See req.sh here

At some point, I seem to have gotten lucky and (running this PR) got:

16082:(EngineCore_DP0 pid=2122813) WARNING 10-17 10:43:57 [v1/core/sched/scheduler.py:1379] Got finished sending KV transfer for request cmpl-1760710278-0, but the request is not finished (status=RUNNING). This may indicate the request was aborted or the KV transfer timed out before the request completed.
18251:(EngineCore_DP0 pid=2122813) WARNING 10-17 10:52:58 [v1/core/sched/scheduler.py:1379] Got finished sending KV transfer for request cmpl-1760710278-0, but the request is not finished (status=RUNNING). This may indicate the request was aborted or the KV transfer timed out before the request completed.

but mostly I'm seeing this crash:


(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]   File "/home/markmc/vllm-project/vllm/vllm/v1/worker/gpu_model_runner.py", line 2420, in execute_model
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]     return self.kv_connector_no_forward(
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]   File "/home/markmc/vllm-project/vllm/vllm/v1/worker/kv_connector_model_runner_mixin.py", line 79, in kv_connector_no_forward
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]     with (
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]   File "/usr/lib64/python3.12/contextlib.py", line 137, in __enter__
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]     return next(self.gen)
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]            ^^^^^^^^^^^^^^
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]   File "/home/markmc/vllm-project/vllm/vllm/v1/worker/kv_connector_model_runner_mixin.py", line 123, in _get_kv_connector_output
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]     kv_connector.start_load_kv(get_forward_context())
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]   File "/home/markmc/vllm-project/vllm/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py", line 261, in start_load_kv
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]     self.connector_worker.start_load_kv(self._connector_metadata)
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]   File "/home/markmc/vllm-project/vllm/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py", line 1469, in start_load_kv
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]     assert req_id not in self._reqs_to_send
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=2127426) ERROR 10-17 11:28:10 [v1/executor/multiproc_executor.py:700] AssertionError

@markmc (Member) commented Oct 17, 2025

At some point, I seem to have gotten lucky and (running this PR) got:

16082:(EngineCore_DP0 pid=2122813) WARNING 10-17 10:43:57 [v1/core/sched/scheduler.py:1379] Got finished sending KV transfer for request cmpl-1760710278-0, but the request is not finished (status=RUNNING). This may indicate the request was aborted or the KV transfer timed out before the request completed.

To answer my own question, these were preceded by:

WARNING 10-17 10:43:57 [distributed/.../v1/nixl_connector.py:1330] Releasing expired KV blocks for request cmpl-1760710278-0 which were retrieved by 0 decode worker(s) within 30 seconds.
