Conversation

CLFutureX (Contributor) commented Oct 14, 2025

When an already-scheduled request is evicted, an element is removed from the range [0, req_index) of the running queue.
All not-yet-scheduled requests then shift forward by one position, so req_index must be adjusted accordingly by stepping it back one position (i.e., req_index -= 1). Failing to do so causes the next request in the queue to be skipped.
I identified this defect previously, but due to other factors it had not been properly resolved until now.

Signed-off-by: CLFutureX <chenyongqyl@163.com>
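
To make the off-by-one concrete, here is a minimal, self-contained sketch of the indexing pattern (the names and loop structure are simplified stand-ins, not the actual scheduler code): removing an element that sits before req_index shifts the remaining entries left by one, so the index has to be stepped back or the next request is silently skipped.

# Simplified illustration of the indexing bug (not vLLM code).
# Entries in running[:req_index] have already been handled this cycle;
# preemption removes one of them while the loop is still walking the list.
def walk(running, preempt_at, fix_index):
    visited = []
    req_index = 0
    while req_index < len(running):
        visited.append(running[req_index])
        if req_index == preempt_at:
            running.pop(0)       # evict an already-handled request
            if fix_index:
                req_index -= 1   # compensate for the left shift
        req_index += 1
    return visited

print(walk(["A", "B", "C", "D"], preempt_at=1, fix_index=False))  # ['A', 'B', 'D'] -- 'C' is skipped
print(walk(["A", "B", "C", "D"], preempt_at=1, fix_index=True))   # ['A', 'B', 'C', 'D']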

gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in the scheduler's priority-based preemption logic. When a request is preempted to make space for a new one, if the preempted request had already been evaluated in the current scheduling cycle, the index for iterating over the running queue was not correctly adjusted. This could lead to the scheduler skipping the next request in the queue. The fix introduces a decrement to the loop index in this specific scenario, which correctly compensates for the removal of the element from the list. The change is accurate and resolves the bug.

CLFutureX (Author) commented:

@WoosukKwon @heheda12345 PTAL

token_budget += num_scheduled_tokens[preempted_req.request_id]
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id)
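# An already-scheduled request was removed from before req_index,
# so step the index back to avoid skipping the next request.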
req_index -= 1
Member commented:

Thanks for catching this @CLFutureX and sorry for not getting to it sooner!

I think whether we decrement req_index here should depend on whether the index of the preempted request is less than req_index?

Member commented:

Oh nevermind, I see that it will only be in scheduled_running_reqs iff this is the case, so I think this looks correct.

njhill (Member) left a comment

Thanks @CLFutureX

njhill added the ready and bug labels Oct 22, 2025
njhill merged commit 243ed7d into vllm-project:main Oct 23, 2025
46 checks passed

Labels: bug, ready, v1