
Conversation

@tomasruizt
Contributor

@tomasruizt tomasruizt commented Sep 5, 2025

Purpose

Enable draft models for speculative decoding (SD), e.g. Qwen3-1.7B as the draft model and Qwen3-32B as the target model.
This type of SD requires no specially trained heads (unlike EAGLE or Medusa).

Example usage:

vllm serve \
    --model=Qwen/Qwen3-4B \
    --speculative-config '{"model": "Qwen/Qwen3-0.6B", "method": "draft_model", "num_speculative_tokens": 3, "max-model-len": 2000, "disable_padded_drafter_batch": true}' \
    --max-model-len 2000

Get a generation:

curl -X POST http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"prompt": "Capital of France", "max_tokens": 16}'

Status

  • The acceptance rates in greedy decoding are great for Qwen3 models (see the corresponding section below).
  • Using SD with Qwen3 gives higher throughput (lower TPOT) than not using SD.

Acceptance Length

As suggested by @ekagra-ranjan, I benchmarked acceptance length (AL) with the command below:

VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py \
    --model-dir Qwen/Qwen3-32B \
    --draft-model Qwen/Qwen3-1.7B \
    --method draft_model \
    --num_spec_tokens 3 \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num_prompts 100 \
    --temp 1.0 \
    --gpu-memory-utilization 0.9

The AL values within the Qwen3 family look good, both at temperature 0.0 (greedy) and 1.0.
As a sanity check, I benchmarked Llama-3.2-1B as both target and draft, which yielded an almost perfect AL (3.97/4), suggesting it's working as intended.
I have not run the default model meta-llama/Llama-3.1-8B-Instruct, because I didn't find a good draft model for it, but feel free to suggest one and I can run the benchmarks.

Temperature t=0:

Target         Draft          K   Temperature   AL
Qwen3-32B      Qwen3-1.7B     3   0.0           2.79
Qwen3-32B      Qwen3-1.7B     4   0.0           3.12
Llama-3.2-1B   Llama-3.2-1B   3   0.0           3.97

Temperature t=1.0:

Target         Draft          K   Temperature   AL
Qwen3-32B      Qwen3-1.7B     3   1.0           2.61
Qwen3-32B      Qwen3-1.7B     4   1.0           2.85
Llama-3.2-1B   Llama-3.2-1B   3   1.0           2.82

Using t=1.0, the AL metric degrades. However, spec decode with probabilities, which is needed for lossless rejection sampling, is not yet implemented; it is being addressed in #20459. After that PR is merged, the AL for non-greedy spec decode should improve.
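
As a back-of-the-envelope check (not part of the benchmarks above), the measured AL translates into an expected per-request speedup once the drafter's cost is accounted for. The sketch below assumes a simple cost model in which one decoding round costs one target forward pass plus K draft passes and yields AL tokens on average; the draft-to-target cost ratio c is an illustrative assumption, not a measurement.

# Back-of-the-envelope speedup estimate from acceptance length (AL).
# Assumption: one round = 1 target pass + K draft passes, producing AL tokens;
# c = draft/target per-pass cost ratio (illustrative values, not measured here).
def expected_speedup(al: float, k: int, c: float) -> float:
    return al / (1.0 + k * c)

for al, k in [(2.79, 3), (3.12, 4)]:
    for c in (0.0, 0.1, 0.2):
        print(f"AL={al}, K={k}, c={c}: ~{expected_speedup(al, k, c):.2f}x")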

All scripts and logs used for the benchmarks can be found in this Google Drive.

Online Throughput Metrics

I measured online throughput metrics using the commands below. Hardware was an RTX PRO 6000 96GB. After making sure the draft model also uses CUDA graphs, SD has higher throughput than not using SD. See the tables below.

VLLM_USE_V1=1 vllm serve Qwen/Qwen3-32B \
  --max-model-len 20000 \
  --disable-uvicorn-access-log

# or 

VLLM_USE_V1=1 vllm serve Qwen/Qwen3-32B \
  --speculative_config '{"method": "draft_model", "model": "Qwen/Qwen3-1.7B", "num_speculative_tokens": 3, "max_model_len": 20000, "disable_padded_drafter_batch": true}' \
  --max-model-len 20000 \
  --disable-uvicorn-access-log

nohup vllm bench serve \
  --model Qwen/Qwen3-32B \
  --dataset-name hf \
  --dataset-path philschmid/mt-bench \
  --num-prompts 80 \
  --max-concurrency (100|1) \
  --temperature 0.0 \
  --top-p 1.0 2>&1 > results/qwen3-32b-t0-run1.out &

  • The tables show shorter runtime and higher throughput for SD (at both batch size 1 and 100).
  • With SD, TPOT is ~50% lower (better) at batch size 1, and 26% to 33% lower at batch size 100. The reason is the higher throughput of the draft model.
  • With SD, TTFT and ITL are higher (worse), because tokens are produced in batches by speculative decoding (see the sketch after the metric definitions below). Nevertheless, total runtimes are shorter overall when using SD.

The metrics (lower is better) are:

  • TTFT: Time-to-first-token
  • TPOT: Time-per-output-token
  • ITL: Inter-token-latency
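
A rough sketch of how these metrics interact under SD (not vLLM's exact benchmark code, just the usual definitions): TPOT averages over the whole generation, while ITL measures the gap between consecutive streamed chunks. Since SD emits several accepted tokens per engine step, each gap grows (higher ITL) even though the per-token average drops (lower TPOT).

# Rough metric definitions (illustrative, not vLLM's benchmark implementation).
def tpot(e2e_latency_s: float, ttft_s: float, num_output_tokens: int) -> float:
    # Average time per output token after the first one.
    return (e2e_latency_s - ttft_s) / max(num_output_tokens - 1, 1)

def itl(chunk_arrival_times_s: list[float]) -> list[float]:
    # Gaps between consecutive streamed chunks; with SD each chunk can hold
    # several accepted tokens, so individual gaps exceed the TPOT average.
    return [b - a for a, b in zip(chunk_arrival_times_s, chunk_arrival_times_s[1:])]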
Batch Size = 1, Temperature = 0.0:

Target      Draft        Runtime   TTFT       TPOT      ITL
Qwen3-32B   -            943s      69.14ms    46.06ms   45.88ms
Qwen3-32B   Qwen3-1.7B   466s      76.77ms    22.62ms   62.44ms

With SD, runtime and TPOT are shorter by ~50%.

Batch Size = 100, Temperature = 0.0:

Target      Draft        Runtime   TTFT       TPOT      ITL
Qwen3-32B   -            16.88s    262.59ms   65.09ms   64.84ms
Qwen3-32B   Qwen3-1.7B   13.04s    284.45ms   43.70ms   121.48ms

For Temperature = 1.0:

Target      Draft        Runtime   TTFT       TPOT      ITL
Qwen3-32B   -            16.83s    230.04ms   64.95ms   64.70ms
Qwen3-32B   Qwen3-1.7B   14.84s    272.51ms   48.00ms   122.27ms

This scenario with batch size 100 is a more realistic inference case.
With SD, runtimes and TPOT are shorter.
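
As a cross-check, the quoted reductions follow directly from the table values (a small computation, not additional benchmark data):

# Recomputing the TPOT reductions quoted above from the table values (ms).
pairs = {
    "bsz=1,   t=0.0": (46.06, 22.62),
    "bsz=100, t=0.0": (65.09, 43.70),
    "bsz=100, t=1.0": (64.95, 48.00),
}
for name, (baseline, sd) in pairs.items():
    print(f"{name}: TPOT reduced by {(1 - sd / baseline) * 100:.0f}%")
# -> ~51%, ~33%, and ~26%, matching the ~50% and 26%-33% figures above.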

Profiling

This section was removed, since using CUDA graphs on the draft model significantly improved its speed.

Profiling script: I used the command below to profile the generation process and identify that the draft model was previously running too slowly.
export VLLM_USE_V1=1
export VLLM_TORCH_PROFILER_DIR=./profiles/
export CUDA_LAUNCH_BLOCKING=1

vllm bench throughput \
    --model=Qwen/Qwen3-32B \
    --speculative-config '{"model": "Qwen/Qwen3-1.7B", "method": "draft_model", "num_speculative_tokens": 3, "max_model_len": 2048}' \
    --dataset-name=hf \
    --dataset-path=likaixin/InstructCoder \
    --max-num-seqs=100 \
    --num-prompts=10 \
    --input-len=1000 \
    --output-len=10 \
    --max-model-len=2048 \
    --gpu-memory-utilization=0.95 \
    --profile

Note: The command uses the --profile flag, which I introduced in #24575.
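
For completeness, the resulting traces in VLLM_TORCH_PROFILER_DIR are standard torch-profiler traces; below is a minimal sketch for locating them (the exact file naming is an assumption here, so adjust the glob if needed).

# List the profiler traces written by the run above; they can be opened in
# Perfetto (https://ui.perfetto.dev) or chrome://tracing to compare draft vs.
# target model kernels. The *.json* pattern is an assumption about the naming.
from pathlib import Path

for trace in sorted(Path("./profiles").glob("**/*.json*")):
    print(trace)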

Test Plan

The added unit test checks the correctness metrics. To run it:

cd tests/v1/e2e/
pytest test_spec_decode.py -k test_draft_model_correctness

EAGLE testing

I tested that the EAGLE implementation stays unaffected using the command below:

VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py \
    --model-dir meta-llama/Llama-3.1-8B-Instruct \
    --eagle-dir yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \
    --method eagle3 \
    --num_spec_tokens 7 \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num_prompts 80 \
    --temp 0.0 \
    --gpu-memory-utilization 0.9

The results are in line with previous measurements such as #17504 (comment).

total_num_output_tokens: 16990
num_drafts: 4816
num_draft_tokens: 33712
num_accepted_tokens: 12208
mean acceptance length: 3.53
--------------------------------------------------
acceptance at token 0: 0.74
acceptance at token 1: 0.54
acceptance at token 2: 0.41
acceptance at token 3: 0.31
acceptance at token 4: 0.23
acceptance at token 5: 0.17
acceptance at token 6: 0.13
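
The reported numbers are internally consistent, as a quick recomputation from the counters above shows:

# Sanity-checking the printed metrics from the counters above.
num_drafts = 4816
num_draft_tokens = 33712            # = num_drafts * K, with K = 7
num_accepted_tokens = 12208

# Mean acceptance length = 1 bonus token per draft + accepted draft tokens per draft.
print(round(1 + num_accepted_tokens / num_drafts, 2))   # -> 3.53, as reported

# The per-position acceptance rates should sum to the accepted tokens per draft.
per_position = [0.74, 0.54, 0.41, 0.31, 0.23, 0.17, 0.13]
print(round(sum(per_position), 2), round(num_accepted_tokens / num_drafts, 2))  # -> 2.53 2.53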

Follow-up Optimizations

  • Include the tokens in next_token_ids together with target_token_ids in the first forward pass of the draft model. This reduces the number of forward passes needed in each drafting phase by one, speeding up drafting (see the illustrative sketch below).
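
A purely illustrative sketch of that optimization (hypothetical names, not the code in this PR): if the drafter's first pass is fed the verified target tokens together with the sampled bonus token, it already yields the first draft token, so each drafting phase needs K passes instead of K + 1.

# Illustrative pseudocode only; drafter_step and the variable names are
# hypothetical and do not correspond to functions in this PR.
def drafting_phase(drafter_step, target_token_ids, next_token_id, k):
    # drafter_step: hypothetical callable, token ids in -> greedy next token out.
    # Proposed: the first pass consumes the verified tokens AND the bonus token,
    # so it directly produces draft token #1 instead of needing a second pass.
    input_ids = list(target_token_ids) + [next_token_id]
    draft_tokens = []
    for _ in range(k):
        token = drafter_step(input_ids)
        draft_tokens.append(token)
        input_ids = [token]   # subsequent passes are single-token decode steps
    return draft_tokens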

Qwen3 Metrics

I compare Qwen3-32B against Qwen3-32B with Qwen3-1.7B as drafter model, on a single H100 GPU (TP=1).
The benchmarks show (left to right):

  • An increase in Total Token Throughput across 1 to 100 concurrent requests (i.e., larger batch sizes).
  • A decrease in Benchmark Duration.
  • Increases in Total Token Throughput of almost 2x.
(figures not reproduced here)

Broken down, we find:

  • A steep improvement (reduction) in Time Per Output Token (TPOT)
  • Higher Inter-Token Latency (ITL), because tokens are generated and verified in batches
  • Equivalent Time-To-First-Token (TTFT), except at high batch sizes.
(figure not reproduced here)

Llama3 Multi-GPU Metrics

I benchmarked meta-llama/Meta-Llama-3-70B, with meta-llama/Llama-3.2-1B as a draft model, on 4 x H100 GPUs (TP=4). The goal was to measure the effect of tensor parallelism on throughput, since the small draft model runs at the same tensor-parallel degree as the target model. I found an acceleration in this setup only up to bsz=32, and a slowdown at bsz=64. This shows that using TP > 1 degrades the acceleration of draft_model, presumably because of the inter-GPU communication overhead for the draft model. This should be optimized in a follow-up PR.

serving-script.sh
benchmark-script.sh

(figures not reproduced here)

The TPOT metric improves in general, showing similar results to Qwen3.

(figure not reproduced here)

EAGLE3 Metrics (Reference)

Below are benchmark values using method="eagle3", with target model meta-llama/Llama-3.3-70B-Instruct and draft model yuhuili/EAGLE3-LLaMA3.3-Instruct-70B on 4 x H100 GPUs (TP=4). EAGLE achieves over 2x acceleration at batch_size=1. In contrast to draft_model, the EAGLE drafter runs with tensor_parallelism = 1.

serve-script.sh
bench-script.sh

(figures not reproduced here)

@github-actions

github-actions bot commented Sep 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added performance Performance-related issues speculative-decoding v1 labels Sep 5, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for speculative decoding using a draft model. The changes are comprehensive, touching configuration, model loading, scheduling, and the core speculative decoding logic. New tests and benchmark modifications are also included to validate and measure the new feature. The overall implementation appears solid. However, I've identified a critical issue in a refactoring of the bind_kv_cache utility function, which removes an important safety check and could lead to incorrect behavior for certain model architectures.

@ekagra-ranjan
Contributor

ekagra-ranjan commented Sep 5, 2025

@tomasruizt - Thank you for the PR!

  1. Can you also report Acceptance Length (AL) and the K used? It is more informative than AR. For example, with EAGLE at K=3 we get AL ~2.29 on MT-Bench, so we can expect a ~2.29x speedup assuming zero draft overhead. AL gives a more holistic overview of the speedup for a given K than AR.
  2. Can you run the metrics on MT-Bench instead of a one-off sample for a less noisy measurement? This would also allow us to compare the AL and TPOT improvements with other SD methods: [Benchmark][V1][Spec Decode][EAGLE] Tracking benchmark for V1 EAGLE #17812. You can find the commands to run offline inference for AL and online inference for TPOT using MT-Bench in vLLM here: [V1][Spec Decode][Feature] Spec decode with probs #20459. You will need to update the offline inference script to use a separate draft model.

@ekagra-ranjan
Contributor

ekagra-ranjan commented Sep 5, 2025

As a main model, Qwen3-1.7B runs a decoding forward pass in 9-10ms, while when it's used as a draft model, each forward pass takes 19-22ms.

What is the TP you are using for Qwen3-32B? By default, draft model TP is equal to target model TP. Since Qwen3-1.7B is a small model, running it on high TP might be incurring nccl communication cost. Try setting draft TP to 1.

@tomasruizt
Contributor Author

What is the TP you are using for Qwen3-32B? By default, draft model TP is equal to target model TP. Since Qwen3-1.7B is a small model, running it on high TP might be incurring nccl communication cost. Try setting draft TP to 1.

I ran the benchmarks with TP=1 and num_draft_tokens=3. So we can rule out TP communication issues.

@tomasruizt tomasruizt requested a review from 22quinn as a code owner September 6, 2025 08:34
@mergify mergify bot added the documentation Improvements or additions to documentation label Sep 6, 2025
@mergify

mergify bot commented Sep 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tomasruizt.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 8, 2025
@tomasruizt tomasruizt force-pushed the feature/spec-decode-draft-model branch from 7de2ae1 to 2e0fb65 Compare September 8, 2025 09:17
@tomasruizt
Contributor Author

@ggg-s setting different TP sizes for target and draft model will raise an error in the latest commit. If you don't specify the TP size for the draft model, it will by default be equal to the target model TP.

@ggg-s

ggg-s commented Oct 15, 2025

@tomasruizt yep. I pulled the latest commit and locally removed self._raise_if_draft_tp_mismatch() to test heterogeneous TP. With "draft_tensor_parallel_size": 1 set in the speculative-config, the service starts without errors, but the drafter still gets instantiated on every TP rank — logs show each worker printing "Loading drafter model…", and GPU memory usage is identical across GPUs.

So even though the check is bypassed, the behavior remains “replicated-per-rank” rather than a single TP=1 drafter. If there’s a flag or code path to pin the drafter to rank0 (and broadcast results), I’m happy to try it.

@tomasruizt
Contributor Author

@ggg-s that's expected. As I said:

If you don't specify the TP size for the draft model, it will by default be equal to the target model TP.

The only way to have TP = 1 on the drafter is to have TP = 1 on the target model, atm.

@ggg-s

ggg-s commented Oct 15, 2025

@tomasruizt Got it, thank you for your explanation!

@QingNagi

I encountered a new issue involving tensor dimension mismatches that occur under high concurrency conditions. What causes this problem?
My command:

vllm serve Qwen3/Qwen3-32B-FP8 \
    --served-model-name "qwen3" \
    --gpu-memory-utilization 0.9 \
    --tensor-parallel-size 1 \
    --trust-remote-code \
    --max-model-len 10000 \
    --speculative-config '{"method": "draft_model", "model": "Qwen3/Qwen3-0.6B", "num_speculative_tokens": 5, "max-model-len": 10000, "disable_padded_drafter_batch": true}' \
    --override-generation-config '{"temperature" : 0}' &

vllm bench serve --model qwen3fd \
    --dataset-name spec_bench \
    --tokenizer Qwen3/Qwen3-32B-FP8 \
    --percentile-metrics ttft,tpot,itl,e2el \
    --dataset-path ./question.jsonl \
    --num-prompts -1 \
    --max-concurrency 64

Error message:
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] EngineCore encountered a fatal error.
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] Traceback (most recent call last):
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 701, in run_engine_core
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] engine_core.run_busy_loop()
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 728, in run_busy_loop
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] self._process_engine_step()
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 754, in _process_engine_step
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] outputs, model_executed = self.step_fn()
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 284, in step
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] model_output = self.execute_model_with_error_logging(
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 270, in execute_model_with_error_logging
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] raise err
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 261, in execute_model_with_error_logging
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] return model_fn(scheduler_output)
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/executor/abstract.py", line 103, in execute_model
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] output = self.collective_rpc("execute_model",
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/executor/uniproc_executor.py", line 83, in collective_rpc
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] return [run_method(self.driver_worker, method, args, kwargs)]
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/utils/__init__.py", line 3122, in run_method
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] return func(*args, **kwargs)
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] return func(*args, **kwargs)
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/worker/gpu_worker.py", line 447, in execute_model
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] output = self.model_runner.execute_model(scheduler_output,
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] return func(*args, **kwargs)
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 2432, in execute_model
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] propose_draft_token_ids(valid_sampled_token_ids)
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 2379, in propose_draft_token_ids
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] self._draft_token_ids = self.propose_draft_token_ids(
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 2590, in propose_draft_token_ids
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] draft_token_ids = self.drafter.propose(
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/spec_decode/draft_model.py", line 71, in propose
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] draft_token_ids = super().propose(
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/spec_decode/eagle.py", line 216, in propose
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] self.set_input_ids_first_pass(
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] File "/opt/vllm_env/lib/python3.10/site-packages/vllm/v1/spec_decode/draft_model.py", line 127, in set_input_ids_first_pass
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] self.input_ids[:num_tokens] = target_token_ids
(EngineCore_DP0 pid=139546) ERROR 10-14 05:36:09 [core.py:710] RuntimeError: The expanded size of the tensor (2048) must match the existing size (2062) at non-singleton dimension 0. Target sizes: [2048]. Tensor sizes: [2062]

@tomasruizt
Contributor Author

I think I know precisely the line causing it. But to reproduce I would need the questions.jsonl file. Would it be possible to share it? @QingNagi

@QingNagi

I think I know precisely the line causing it. But to reproduce I would need the questions.jsonl file. Would it be possible to share it? @QingNagi

Sure! Download the dataset using: wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl

@tomasruizt
Contributor Author

@QingNagi can you share what commit you are testing on?

@QingNagi

@QingNagi can you share what commit you are testing on?

vllm==0.11.0

@tomasruizt
Contributor Author

@QingNagi I meant which commit of this branch?

@QingNagi

37f013e this commit.

@QingNagi

@QingNagi I meant which commit of this branch?

When I add the --request-rate parameter, the problem doesn't occur, which is strange.

@tomasruizt
Contributor Author

tomasruizt commented Oct 16, 2025

@QingNagi Are you on the vLLM Slack? We can chat over there

Edit: I'm not able to reproduce the issue. Can you try with the latest commit of this branch? It would help with reproduction if the error happens with a smaller model (e.g. target=Qwen3-1.7B-FP8, draft=Qwen3-0.6B), whether it also happens with --enforce-eager, and if you let me know what GPU you are running on.

@QingNagi

Ok, I will try it on the latest commit.

@ggg-s

ggg-s commented Oct 16, 2025

@tomasruizt I was wondering if this PR (#26937) could help improve the decoding speed of the draft model?

@benchislett
Collaborator

@ggg-s currently the EAGLE/draft model is forced to be piecewise. We can optimize this aspect as a follow-up feature.

@ggg-s

ggg-s commented Oct 17, 2025

Got it!

@ggg-s

ggg-s commented Oct 18, 2025

“ I measured online throughput metrics using the commands below. Hardware was an RTX PRO 6000 96GB. After making sure the draft model also uses CUDA graph, SD has higher throughput than not using SD. See tables below.”

@tomasruizt Does using CUDA graph here refer to using the FULL CUDA graph for SD? If so, could you please let me know which parameters need to be set to enable SD to use it? Thank you!

@tomasruizt
Contributor Author

@ggg-s Full CUDA graphs are not yet supported. That comment referred to piecewise CUDA graphs, which are used by default unless you pass --enforce-eager.

@ggg-s

ggg-s commented Oct 20, 2025

@tomasruizt Can I contact you on Slack?

@tomasruizt
Contributor Author

@LiuXiaoxuanPKU I ran throughput benchmarks on a multi-GPU (TP=4) setup with Llama-3-70B, and added the results to the bottom of the PR description. I found speedups up to batch_size=32. AFAIK you had somewhat different results. If you can share those details, I can attempt to reproduce. Thanks for your review :)

@benchislett
Collaborator

@tomasruizt do you have any comparison to EAGLE3? There's a head for Llama 3.3 70B: https://huggingface.co/yuhuili/EAGLE3-LLaMA3.3-Instruct-70B

@tomasruizt
Contributor Author

tomasruizt commented Oct 20, 2025

@benchislett I haven't run these benchmarks for EAGLE3, but I could compute that tomorrow for comparison.

@tomasruizt
Contributor Author

@benchislett I added the EAGLE3 benchmark for reference. It shows very good acceleration, strongest at small batch sizes. It's faster than the draft_model approach contributed here, but draft_model doesn't require training a separate drafter the way EAGLE does. Furthermore, there are follow-up optimizations that should substantially improve draft_model performance:

  • (1) Making the draft model run on TP=1, regardless of the TP of the target model. The Qwen3 benchmarks on TP=1 in the PR description show that we might reach 2x acceleration this way.
  • (2) Running the draft model with full CUDA graphs during pure decode steps (just like the target model). Profiling suggests that this could speed up draft_model decodes by ~20-30%.

I tried both for this PR, but found them non-trivial, so I stepped back. I will be away in November, so I think we should try to merge this PR before then.

@mergify

mergify bot commented Oct 24, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tomasruizt.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 24, 2025