Skip to content

Conversation

@luccafong
Copy link
Collaborator

@luccafong luccafong commented Feb 4, 2025

Implement DeepSeek MTP: #12181 to support DeepSeek MTP layers for next n prediction.

Online Serving
Add --num-speculative-tokens 1 for DeepSeek V3/R1:

python -m vllm.entrypoints.openai.api_server --disable-log-requests --gpu-memory-utilization 0.8  --max-model-len 65536 --max-num-seqs 128 --seed 0 --tensor-parallel-size 8 --model deepseek-ai/DeepSeek-R1 ---trust-remote-code --num-speculative-tokens 1

Offline Inference
Set num_speculative_tokens = 1

llm = LLM(
    model="deepseek-ai/DeepSeek-R1",
    tensor_parallel_size=8,
    max_model_len=8192, # If you have enough memory with your hardware, you can ignore this
    num_speculative_tokens=1, # only 1 is supported for now
   draft_tensor_parallel_size=8, # optional, by default it will be the same as tensor_parallel_size.
)

Note: This implementation validates on MTP k=1 models only.

Benchmark Results

The acceptance rate is 81% ~ 82.3% on R1 k=1.
The speedup depends on the QPS, with 1.63x speedup for QPS=1 and certain improvement with QPS<8 as shown in below table.

Results on various QPS

Draft TP=1

QPS Baseline TPOT k=1 TPOT Speedup
1 55.47 33.99 1.63x
2 57.58 48.8 1.18x
4 64.29 51.02 1.26x
6 122.93 108.15 1.14x
8 120.18 119.14 1.0x

Draft TP=8

QPS Baseline TPOT k=1,TP=8 TPOT Speedup
1 55.47 32.64 1.69x
2 57.58 43.6 1.32x
4 64.29 52.62 1.22x
6 122.93 129.5 < 1.0
8 120.18 139.49 < 1.0

Results on various Concurrency
Draft TP=8

MAX_CONCURRENCY Baseline TPOT k=1 TPOT Speedup
1 23.13 17.24 1.34x
2 28.10 17.07 1.64x
4 27.57 21.48 1.28x
8 38.57 34.62 1.11x
16 50.24 40.89 1.22x
32 70.88 56.63 1.25x

@github-actions
Copy link

github-actions bot commented Feb 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise LGTM. It's pretty clean so no concerns.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know how long does it take to run all tests in this file?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: where did we truncate the input_ids?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for 1st stage: position 0 is masked for MTP, but it only applies to k=1, I need to change the mask to the [position<=k-1],
for 2+ stage, previous tokens in last stage is marked pre-computed, this is a bit complicated for k>1 on different layers, need to look into.

in short, the current change works for k=1 (which deepseek v3 model set), but need more changes for k>1

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: what's the shape of input_ids here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the incremental length, here its [B*1]

Copy link

@Neo9061 Neo9061 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any way to put a MD file instructing examples on how to use the MTP for SD?

Especially,

  1. The num_nextn_predict_layers is 1, can we specify speculation length more than 1? and what are requirements on formatting the draft model artifacts?
  2. Is this code compatible with multi-node inference? assume so as the draft is loaded in single GPU?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The num_nextn_predict_layers in DeepSeek V3 has only 1. Will that mean you will reuse the MTP head if I specify MAX_SEC_TOKENS more than 1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a test file on dummy model. num_speculative_tokens should be <= num_nextn_predict_layers, the transformer blocks are different in different steps. I am adding some assertion for this case when user pass higher number.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to just re-use the MTP to predict tokens whose k > 1? as essentially they are the same right?

You can print some warning that this is not expected though.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't the mtp_start_layer_idx be num_hidden_layers -1?

num_hidden_layers is 61 in DeepSeek config. The index of last layer is 60.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://huggingface.co/deepseek-ai/DeepSeek-V3/raw/main/model.safetensors.index.json the last layer is 61 which is the mtp layer.

I see, thanks for clarifying!

@Neo9061
Copy link

Neo9061 commented Feb 5, 2025

@luccafong Sorry have to ask those questions as I hope to use your implementation.

  1. Have you tested it e2e with VLLM's multi-node distributed inference setting? asking as I can only deploy the model in multi-node settings.
  2. If I want to re-use the MTP head to do speculation length k > 1, what is the hacking implementation you would recommend to just make it work? as k=1 is too limited in my application.

@benchislett
Copy link
Collaborator

@luccafong I have been working on a similar implementation locally, and have faced a few challenges that I'm not sure are addressed here. Have you validated the acceptance rate for k=1 for real weights?

I believe that the final RMSNorm in the DeepSeekV3 main model is not necessary for speculative decoding since the hnorm already normalizes the previous hidden weights received from the main model. It's unclear to me how it is classified in the DeepSeek-V3 technical report, but I think that the norm might be included in the output head and therefore not normalized as input to the MTP module. Anecdotally, I observe a small increase in acceptance rate with this change.

Also, I have noticed the acceptance rate becomes very low (<50%) when I enable the recently added MLA attention. Have you noticed this also? I am not sure what could cause this, maybe it is a bug fixed in recent commits to vLLM. I would like to know if this is an issue for your implementation.

@luccafong
Copy link
Collaborator Author

@luccafong I have been working on a similar implementation locally, and have faced a few challenges that I'm not sure are addressed here. Have you validated the acceptance rate for k=1 for real weights?

I believe that the final RMSNorm in the DeepSeekV3 main model is not necessary for speculative decoding since the hnorm already normalizes the previous hidden weights received from the main model. It's unclear to me how it is classified in the DeepSeek-V3 technical report, but I think that the norm might be included in the output head and therefore not normalized as input to the MTP module. Anecdotally, I observe a small increase in acceptance rate with this change.

Also, I have noticed the acceptance rate becomes very low (<50%) when I enable the recently added MLA attention. Have you noticed this also? I am not sure what could cause this, maybe it is a bug fixed in recent commits to vLLM. I would like to know if this is an issue for your implementation.

The accept rate is around 56% during my testing, MLA attention could lead to different branch,
https://github.com/luccafong/vllm/blob/ds_mtp/vllm/spec_decode/multi_step_worker.py#L98 I fixed in a later commit.

regarding the norm, thanks for pointing out, let me try adjusting to see if there is an improvement.

@luccafong
Copy link
Collaborator Author

luccafong commented Feb 6, 2025

@luccafong Sorry have to ask those questions as I hope to use your implementation.

  1. Have you tested it e2e with VLLM's multi-node distributed inference setting? asking as I can only deploy the model in multi-node settings.
  2. If I want to re-use the MTP head to do speculation length k > 1, what is the hacking implementation you would recommend to just make it work? as k=1 is too limited in my application.

1.Not tested with multi node settings; 2. We can reuse if you do some model processing, e.g. duplicate the weights to different layers, the hacky changes will not be proper since for n predict layers >1, and we do a k > n predict layers, it is difficult to decide which layer to forward multiple times.
Note for now as commented in the other thread, some changes are needed for K>1, I am working in progress, let me update with you if it works.

@mergify
Copy link

mergify bot commented Feb 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @luccafong.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 6, 2025
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please excuse my multiple questions.

inputs_embeds[positions <= spec_step_index] = 0 is for pre-filling stage for each MTP head correct? as during draft model (MTP head) decoding stage, the inputs_embeds is a single hidden vector.

That is what I saw in EAGLE workflow. It firstly enters the code from here with num_steps being 1 for prefilling (that is where the mask is effective). Then the num_steps becomes to be speculation length k and inputs_embed for each forward pass is a single embed vector.

But I didn't see your logic is modified in

for step in range(num_steps):
to introduce spec_step_index, where do you introduce it then?

@yangchou19
Copy link

@luccafong Hi thanks for your great work! I ran deepseek-r1 in 2 x 8 H100 ray clusters, but encountered CUDA error: invalid device ordinal, could you help take a look this issue? The ray status is normal, thanks

Run script:

vllm serve  deepseek-ai/DeepSeek-R1\
        --host 0.0.0.0 \
        --port 8081 \
        --tensor-parallel-size 16 \
        --pipeline-parallel-size 1 \
        --gpu-memory-utilization 0.8 \
        --max-num-seqs 16 \
        --max-model-len 32768 \
        --served-model-name deepseek_r1 \
        --device cuda \
        --quantization fp8 \
        --trust-remote-code \
        --num-speculative-tokens 1

Error message:

ERROR 02-19 05:03:47 engine.py:389]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^M
ERROR 02-19 05:03:47 engine.py:389]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^M
ERROR 02-19 05:03:47 engine.py:389]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 582, in execute_method^M
ERROR 02-19 05:03:47 engine.py:389]     raise e^M
ERROR 02-19 05:03:47 engine.py:389]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 573, in execute_method^M
ERROR 02-19 05:03:47 engine.py:389]     return run_method(target, method, args, kwargs)^M
ERROR 02-19 05:03:47 engine.py:389]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^M
ERROR 02-19 05:03:47 engine.py:389]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2196, in run_method^M
ERROR 02-19 05:03:47 engine.py:389]     return func(*args, **kwargs)^M
ERROR 02-19 05:03:47 engine.py:389]            ^^^^^^^^^^^^^^^^^^^^^^M
ERROR 02-19 05:03:47 engine.py:389]   File "/usr/local/lib/python3.12/dist-packages/vllm/spec_decode/spec_decode_worker.py", line 368, in init_device^M
ERROR 02-19 05:03:47 engine.py:389]     self.spec_decode_sampler.init_tensors(self.rank,^M
ERROR 02-19 05:03:47 engine.py:389]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/spec_decode_base_sampler.py", line 56, in init_tensors^M
ERROR 02-19 05:03:47 engine.py:389]     self.num_accepted_tokens = torch.tensor(0,^M
ERROR 02-19 05:03:47 engine.py:389]                                ^^^^^^^^^^^^^^^^M
ERROR 02-19 05:03:47 engine.py:389] RuntimeError: CUDA error: invalid device ordinal^M
ERROR 02-19 05:03:47 engine.py:389] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.^M
ERROR 02-19 05:03:47 engine.py:389] For debugging consider passing CUDA_LAUNCH_BLOCKING=1^M
ERROR 02-19 05:03:47 engine.py:389] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.^M

I encountered an “out of memory” error while running DeepSeek-R1 on 2 machines with 8 *H20 GPUs each. Has anyone successfully run DeepSeek-R1 on H20 GPUs?

Here is the command I used:

 vllm serve  deepseek-ai/DeepSeek-R1\
        --host 0.0.0.0 \
        --port 8081 \
        --tensor-parallel-size 16 \
        --pipeline-parallel-size 1 \
        --gpu-memory-utilization 0.99 \
        --max-model-len 131072 \
        --served-model-name deepseek_r1 \
        --quantization fp8 \
        --trust-remote-code \
        --num-speculative-tokens 1

@benchislett
Copy link
Collaborator

@yangchou19 Your gpu memory utilization is set too high. As I understand, the weights for speculative decoding are not currently accounted-for in the memory profiler so there must be additional memory leftover from the vllm allocation for it to live in. As a workaround, try decreasing the gpu memory utilization progressively (try 0.95, 0.9, 0.88, 0.85) until it succeeds. If you want to use speculative decoding, you may need to decrease max-model-len to ensure there is enough memory available for a large KV cache.

@hxt365
Copy link

hxt365 commented Feb 27, 2025

Is this applicable for R1 Distill models? I got this error for deepseek-ai/DeepSeek-R1-Distill-Qwen-32B: ValueError: num_speculative_tokens was provided without speculative_model

@mgoin
Copy link
Member

mgoin commented Feb 27, 2025

@hxt365 No the Distill models do not have MTP modules like DeepSeek V3/R1

@KiroSummer
Copy link

Great work on the analysis! I wanted to clarify one point regarding the baseline vs. MTP performance comparison. Given that the speculative decoding worker implementation doesn't support asynchronous output processing or multi-step scheduling, could you confirm whether these two optimizations were utilized when calculating the TPOT metric for the baseline model?

@JoeyYoung
Copy link

Hi all, I try to use speculative decoding with tp=8 and pp=2 on the 2 x 8H20 testbed, with following command:

vllm serve /vllm-workspace/DeepSeek-R1/ \
        --host 0.0.0.0 \
        --port 8081 \
        --tensor-parallel-size 8 \
        --pipeline-parallel-size 2 \
        --gpu-memory-utilization 0.8 \
        --max-num-seqs 16 \
        --max-model-len 32768 \
        --served-model-name deepseek_r1 \
        --device cuda \
        --quantization fp8 \
        --trust-remote-code \
        --num-speculative-tokens 1

But it reports the error:

INFO 02-28 00:11:01 [config.py:334] Overriding HF config with <function SpeculativeConfig.hf_config_override at 0x7f1bcd2a7880>
Traceback (most recent call last):
  File "/root/anaconda3/envs/vllm/bin/vllm", line 8, in <module>
    sys.exit(main())
  File "/vllm-workspace/vllm/vllm/entrypoints/cli/main.py", line 73, in main
    args.dispatch_function(args)
  File "/vllm-workspace/vllm/vllm/entrypoints/cli/serve.py", line 34, in cmd
    uvloop.run(run_server(args))
  File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/uvloop/__init__.py", line 82, in run
    return loop.run_until_complete(wrapper())
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/uvloop/__init__.py", line 61, in wrapper
    return await main
  File "/vllm-workspace/vllm/vllm/entrypoints/openai/api_server.py", line 946, in run_server
    async with build_async_engine_client(args) as engine_client:
  File "/root/anaconda3/envs/vllm/lib/python3.10/contextlib.py", line 199, in __aenter__
    return await anext(self.gen)
  File "/vllm-workspace/vllm/vllm/entrypoints/openai/api_server.py", line 138, in build_async_engine_client
    async with build_async_engine_client_from_engine_args(
  File "/root/anaconda3/envs/vllm/lib/python3.10/contextlib.py", line 199, in __aenter__
    return await anext(self.gen)
  File "/vllm-workspace/vllm/vllm/entrypoints/openai/api_server.py", line 162, in build_async_engine_client_from_engine_args
    engine_client = AsyncLLMEngine.from_engine_args(
  File "/vllm-workspace/vllm/vllm/engine/async_llm_engine.py", line 639, in from_engine_args
    engine_config = engine_args.create_engine_config(usage_context)
  File "/vllm-workspace/vllm/vllm/engine/arg_utils.py", line 1237, in create_engine_config
    speculative_config = SpeculativeConfig.maybe_create_spec_config(
  File "/vllm-workspace/vllm/vllm/config.py", line 2025, in maybe_create_spec_config
    return SpeculativeConfig(
  File "/vllm-workspace/vllm/vllm/config.py", line 2199, in __init__
    self._verify_args()
  File "/vllm-workspace/vllm/vllm/config.py", line 2207, in _verify_args
    self.draft_model_config.verify_with_parallel_config(
  File "/vllm-workspace/vllm/vllm/config.py", line 762, in verify_with_parallel_config
    raise NotImplementedError(
NotImplementedError: Pipeline parallelism is not supported for this model. Supported models implement the `SupportsPP` interface.

Is pipeline parallelism not supported for the draft model?

Akshat-Tripathi pushed a commit to krai/vllm that referenced this pull request Mar 3, 2025
…12755)

Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
@BoyuanS
Copy link

BoyuanS commented Mar 6, 2025

11

When I started deepseek-r1-awq int4 and set --num-speculative-tokens 1, generated contents are all empty after . Any idea why this should happen? The model works well when num-speculative-tokens is not set

previous_hidden_states = self.hnorm(previous_hidden_states)

hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
Copy link

@Pokemons386 Pokemons386 Mar 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In 4tokens prefill case, the main model forward token[0:3] and get HS[0:3] & token4.Then MTP forward token[0:4] (token0 masked), how token[0:4] and HS[0:3] be torch.cat()?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See prepare_prefill_hidden_states, which rotates the hidden states such that they match up with the tokens.

@benchislett
Copy link
Collaborator

@BoyuanS It is likely that your AWQ quantization did not include weights for the MTP head, in which case this will not work.

lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
…12755)

Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
czhu15 added a commit to HabanaAI/vllm-fork that referenced this pull request Apr 21, 2025
Mainly refer to:
[Model][Speculative Decoding] DeepSeek MTP spec decode (vllm-project#12755)
Enable MTP for HPU from deepseek_r1_upstream branch
czhu15 added a commit to HabanaAI/vllm-fork that referenced this pull request Apr 22, 2025
[Model][Speculative Decoding] DeepSeek MTP spec decode (vllm-project#12755)

Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
czhu15 added a commit to HabanaAI/vllm-fork that referenced this pull request Apr 23, 2025
Mainly referred to:

401d1f2
    Author: Chendi Xue <chendi.xue@intel.com>
And
    [Model][Speculative Decoding] DeepSeek MTP spec decode (vllm-project#12755)

    Signed-off-by: Lu Fang <fanglu@fb.com>
    Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>

---------
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
…12755)

Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
@parambole
Copy link

Hey @luccafong @mgoin @LiuXiaoxuanPKU I am currently working on integrating Deepseek's Multi-Token Prediction into Maxtext.

Question:

As part of this PR, has the team been able to load the open MTP Deepseek V3 weights and analyze the implementation during pre-training and fine-tuning? I am curious to know about the observed behavior?

@mgoin
Copy link
Member

mgoin commented Jul 9, 2025

I have no experience with training, sorry!

@parambole
Copy link

parambole commented Jul 10, 2025

I have no experience with training, sorry!

Hey @mgoin - thanks for responding. With respect to inference how is the performance of the predicted token with loading the deepseek MTP published weights ?

Specifically I am observing: deepseek-ai/DeepSeek-V3#928

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build force-merge ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding

Projects

None yet

Development

Successfully merging this pull request may close these issues.