Conversation

@luyuzhe111 (Contributor) commented Mar 18, 2025

In this PR, we allow the EAGLE architecture to optionally use proper RMS normalizations, similar to the DeepSeek MTP module. This has two main benefits:

  1. It improves acceptance length by ~5% when the pre-attention norm, the output norm, and the parallel norms (highlighted in red in the diagram below, taken from the DeepSeek-V3 technical report) are used during training. Through ablation studies, we found that all of these RMS norms contribute to the improvement.

[Diagram: DeepSeek-V3 MTP module, with the added RMS norms highlighted in red]

  2. It alleviates the approximated KV cache bug currently in the vLLM EAGLE implementation, reducing the resulting performance degradation from 15% to 9% ([Bug]: EAGLE / MTP Doesn't Overwrite Approximated Hidden States / KV Cache, 8%-15% Acceptance Length Degradation #14649). Essentially, these additional RMS norms ensure the approximated KV cache is not too far off, since the hidden states from which that KV cache is computed are normalized. The following table shows the acceptance lengths for Llama 3 8B on GSM8K, measured using this example script.
| Number of Speculated Tokens | 1 | 2 | 3 | 4 | 5 |
|---|---|---|---|---|---|
| EAGLE Repo | 1.84 | 2.45 | 2.83 | 3.07 | 3.17 |
| vLLM | 1.84 | 2.37 | 2.65 | 2.84 | 2.89 |
| Acceptance Length Drop | 0% | 3% | 6% | 7% | 9% |

With these normalizations, the performance degradation due to the approximated KV cache is now 0~9%, compared to the 8%~15% drop without them. However, I want to emphasize that I do not intend this PR as a fix for #14649, but rather as a way to show the community how to alleviate the bug temporarily while it is being fixed.
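To make the description concrete, below is a minimal sketch of where these norms sit in the draft model's input path, in the spirit of the DeepSeek MTP module. All class and attribute names here (`EagleInputFusion`, `enorm`, `hnorm`, `fc`) are illustrative assumptions, not the PR's actual code:

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Llama-style RMS normalization."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


class EagleInputFusion(nn.Module):
    """Fuses the token embedding with the target model's hidden state.
    With add_para_norm=True, each input is RMS-normalized first (the
    "parallel norms" from the DeepSeek-V3 diagram above)."""

    def __init__(self, hidden_size: int, add_para_norm: bool):
        super().__init__()
        self.add_para_norm = add_para_norm
        if add_para_norm:
            self.enorm = RMSNorm(hidden_size)  # norm over the embedding
            self.hnorm = RMSNorm(hidden_size)  # norm over the hidden state
        self.fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    def forward(self, emb: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        if self.add_para_norm:
            emb, hidden = self.enorm(emb), self.hnorm(hidden)
        return self.fc(torch.cat([emb, hidden], dim=-1))
```

Normalizing both inputs keeps the draft layer's hidden states, and hence the KV cache computed from them, on a consistent scale, which is why the same norms also bound the approximation error described above.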

cc @LiuXiaoxuanPKU Would appreciate your review. Thanks!

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
@github-actions (bot) commented

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@luyuzhe111 luyuzhe111 changed the title [Feature] Enhance EAGLE architecture [Feature] Enhance EAGLE Architecture with Proper RMS Norms Mar 18, 2025
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
@luyuzhe111 (Contributor, Author) commented

@WoosukKwon would appreciate your review as well!

vllm/config.py (Outdated), comment on lines 808 to 814 (the excerpt starts mid-expression):

```python
    ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp')) \
    and (self.hf_text_config.kv_lora_rank is not None) or \
    (hasattr(self.hf_text_config, "model_type") \
    and self.hf_text_config.model_type == 'eagle' \
    and self.hf_text_config.model.model_type in \
    ('deepseek_v2', 'deepseek_v3') \
    and self.hf_text_config.kv_lora_rank is not None)
```
Collaborator:

nit: Can we actually clean this up?

  1. It'd be nice to have some comments here, since the code is difficult to follow.
  2. Maybe we can have some if statements like this?

```python
if ...:
    return True
elif ...:
    return True

return False
```

Contributor (Author):

sure! just committed a simplified version!
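
For readers following the thread, here is a hypothetical sketch of what the restructured check could look like under this suggestion. It mirrors the snippet above but is not necessarily the committed version:

```python
def uses_mla(cfg) -> bool:
    """Hypothetical restructuring of the MLA-detection check above;
    illustrative only, not the actual committed code."""
    # Native DeepSeek models (including the MTP module) use MLA when
    # kv_lora_rank is configured.
    if (getattr(cfg, "model_type", None) in
            ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp')
            and cfg.kv_lora_rank is not None):
        return True

    # An EAGLE draft head wrapping a DeepSeek target model also uses MLA.
    if (getattr(cfg, "model_type", None) == 'eagle'
            and cfg.model.model_type in ('deepseek_v2', 'deepseek_v3')
            and cfg.kv_lora_rank is not None):
        return True

    return False
```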

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
@DarkLight1337 (Member) left a comment:

Just a nit, otherwise LGTM

@WoosukKwon (Collaborator) commented

Thank you @DarkLight1337!

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) March 26, 2025 05:20
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 26, 2025
@LiuXiaoxuanPKU (Collaborator) left a comment:

Took a pass. But I do have several concerns about this PR:

  1. This PR's EAGLE implementation is different from standard EAGLE.
  2. Users need to provide the RMS norm weights and also change the model config. I am concerned that very few users might use this.

Let me know your thoughts on this @luyuzhe111, thanks!


```python
self.add_para_norm = False
if hasattr(self.config.model,
           "add_para_norm") and self.config.model.add_para_norm:
```
Collaborator:

Are you saying add_para_norm will be added by the user in the model config file? And users need to provide the weights as well?

Contributor (Author):

Hi Lily @LiuXiaoxuanPKU, thanks for reviewing and raising these great questions!

  1. It indeed adds the option to load additional normalization layers, but it does not alter the default behavior. Thus, I think it should be fine?
  2. Yes, users need to provide trained weights and a corresponding model config. One immediate use case is actually DeepSeek: with the few lines of change in this PR, one can load MTP weights after some conversion (a rough sketch of such a conversion follows this list). I thought this would be helpful since EAGLE will be added to V1 first and MTP later; thus, with this PR, users can begin using MTP immediately, even with only EAGLE support in V1. Finally, as mentioned in this PR, adding these norms actually improves EAGLE training (DeepSeek added them for good reasons). I expect people to realize this soon, so it would be great if vLLM can anticipate and cater to user needs in this fast-paced field.
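
For illustration, here is a rough sketch of the kind of key remapping such a conversion might involve. The checkpoint key prefix and layer index below are assumptions for illustration only; the conversion script referenced in this thread is the authoritative reference:

```python
import torch


def convert_mtp_to_eagle(mtp_ckpt_path: str, out_path: str) -> None:
    """Hypothetical sketch: remap DeepSeek MTP checkpoint keys into an
    EAGLE-style draft-model layout. All key names are assumptions."""
    state_dict = torch.load(mtp_ckpt_path, map_location="cpu")
    prefix = "model.layers.61."  # hypothetical prefix of the MTP module
    converted = {
        name[len(prefix):]: tensor
        for name, tensor in state_dict.items()
        if name.startswith(prefix)
    }
    torch.save(converted, out_path)
```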

Contributor (Author):

Yes, the user needs to specify add_para_norm (as well as skip_output_norm and skip_prenorm) if they want to change from the original EAGLE model architecture; a hypothetical sketch of the relevant config fields follows below. One can reference the conversion script linked above. Again, I want to emphasize that this is backward compatible with the original EAGLE architecture & configs.
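
For reference, a hypothetical sketch of the relevant draft-model config fields, shown as a Python dict. The flag names and the nesting under "model" come from this thread; the surrounding fields and the idea that omitting the flags preserves the original EAGLE behavior are inferred from the backward-compatibility remark, not from the committed code:

```python
# Hypothetical EAGLE draft-model config enabling DeepSeek-MTP-style norms.
draft_model_config = {
    "model_type": "eagle",
    "model": {
        "model_type": "llama",        # the wrapped draft backbone
        "add_para_norm": True,        # add the parallel input norms
        "skip_prenorm": False,        # keep the pre-attention RMS norm
        "skip_output_norm": False,    # keep the output RMS norm
    },
}
```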


```python
from transformers import AutoConfig, PretrainedConfig

from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
```
Collaborator:

From a code-style perspective, it's confusing to import the DeepSeek config in the EAGLE file...

Contributor (Author):

It is very awkward indeed... but I did not find a better solution, since AutoConfig does not support the DeepSeek config. Open to any suggestions.

```python
target_archs = ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]
if any(target_arch in archs for target_arch in target_archs):
    # AutoConfig does not support DeepSeek MoE models yet
    model_config = DeepseekV2Config(**model)
```
Collaborator:

Why do we need the change here?

Contributor (Author):

If you were referring to why we need to single out the DeepSeek config, it's because AutoConfig does not support it. Were you referring to something else?

Thanks again for taking the time to review this PR!

@DarkLight1337 DarkLight1337 merged commit 781d056 into vllm-project:main Mar 26, 2025
43 of 44 checks passed
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
…ect#14990)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
…ect#14990)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
…ect#14990)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…ect#14990)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>