Skip to content

Conversation

@LucasWilkinson
Copy link
Collaborator

@LucasWilkinson LucasWilkinson commented Sep 15, 2025

Purpose

  • Add support for having prefill requests inside ubatches; this is done by allowing a request to be split across ubatches with the second batch effectively chunked-prefill continuing on from the the first (since it always runs after)
  • Add support for the DeepEP high-throughput kernels with SM control for DeepGEMM and DeepEP (shout-out to @yewentao256 for the HT support)
  • General cleanups

Test Plan

lm_eval

Test Result

export VLLM_ALL2ALL_BACKEND=deepep_high_throughput

lm_eval --model local-completions --model_args "base_url=http://0.0.0.0:8000/v1/completions,model=deepseek-ai/DeepSeek-R1,num_concurrent=256" --tasks gsm8k
....
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9522|±  |0.0059|
|     |       |strict-match    |     5|exact_match|↑  |0.9500|±  |0.0060|

export VLLM_ALL2ALL_BACKEND=deepep_low_latency

lm_eval --model local-completions --model_args "base_url=http://0.0.0.0:8000/v1/completions,model=deepseek-ai/DeepSeek-R1,num_concurrent=256" --tasks gsm8k
....
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9522|±  |0.0059|
|     |       |strict-match    |     5|exact_match|↑  |0.9507|±  |0.0060|

HT Overlap Trace (2x8xH100)
Screenshot 2025-09-20 at 4 43 31 PM


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

SageMoore and others added 30 commits June 2, 2025 19:04
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
…pping asserts

Signed-off-by: Sage Moore <sage@neuralmagic.com>
…sult in an empty second ubatch

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
LucasWilkinson and others added 2 commits September 22, 2025 11:14
…e.py

Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@mergify
Copy link

mergify bot commented Sep 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 22, 2025
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
@mergify mergify bot removed the needs-rebase label Sep 23, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) September 23, 2025 02:37
@mergify
Copy link

mergify bot commented Sep 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 23, 2025
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
@mergify mergify bot removed the needs-rebase label Sep 23, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@tlrmchlsmth tlrmchlsmth merged commit cc1dc7e into vllm-project:main Sep 23, 2025
51 checks passed
@voipmonitor
Copy link

@LucasWilkinson this commit introduces weired behaviour - the first http request with larger context is working normally but the subsequent requests are signifficantly slower. I have verified that it is this commit: cc1dc7e which should be this PR. a903669 is working normally

CUDA_VISIBLE_DEVICES=4,5 NCCL_DEBUG=INFO vllm serve /mnt/gpt-oss-120b --async-scheduling --tensor-parallel-size 2 --port 4999 --host 0.0.0.0 --config  GPT-OSS_Blackwell.yaml --served-model-name default --gpu-memory-utilization 0.8 --max-num-seqs 1024 --max-num-batched-tokens 32000

cat GPT-OSS_Blackwell.yaml
compilation-config: '{"pass_config":{"enable_fi_allreduce_fusion":false,"enable_noop":true},"custom_ops":["+rms_norm"],"cudagraph_mode":"FULL_AND_PIECEWISE"}'
async-scheduling: true
cuda-graph-sizes: 2048
max-num-batched-tokens: 32000

@jmkuebler jmkuebler mentioned this pull request Sep 24, 2025
5 tasks
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
… and Prefill support (vllm-project#24845)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
yewentao256 added a commit that referenced this pull request Oct 3, 2025
… and Prefill support (#24845)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

if not should_ubatch:
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removing this doesn't make the padding happen.

gjc0824 pushed a commit to gjc0824/vllm that referenced this pull request Oct 10, 2025
… and Prefill support (vllm-project#24845)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: gaojc <1055866782@qq.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
… and Prefill support (vllm-project#24845)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
… and Prefill support (vllm-project#24845)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
… and Prefill support (vllm-project#24845)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
… and Prefill support (vllm-project#24845)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants