[Feat] Adapted mtp function to Qwen3-next #3918
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request adapts the Multi-Token Prediction (MTP) speculative decoding function for the Qwen3-next model. The changes involve registering the new MTP model variant, customizing the Qwen3-next model implementation for MTP, and adjusting the model runner logic. My review has identified two main issues. First, in mtp_proposer.py, a hardcoded layer index is used to access attention metadata, which is brittle and should be replaced with a dynamic lookup. Second, and more critically, the logic for determining the attention state in model_runner_v1.py has been inverted, which is a breaking change for other speculative decoding methods. I have provided suggestions to address both issues to improve maintainability and correctness.
This pull request has conflicts; please resolve them before we can evaluate the pull request.
Force-pushed: 845019f → 678d684, 334e475 → 810d097, 0a49317 → f2664cb.
self.model = DeepSeekMTP(
    vllm_config=self.vllm_config).to(target_device)

architecture = self.vllm_config.model_config.architecture
It's better to globally maintain a dictionary mapping from model architectures to specific module classes for better scalability.
Good advice! But I need to merge MTP support for Qwen3-next quickly, so I don't have time to fully test other model architectures.
Also, I think a function is a good choice here; I will revisit and refactor this in the next PR.
Oops, as I commented, the lazy import is crucial.
A module-level global map would cause a patch error, so it has to be a function here.
You can take vllm's model registration as an example:
https://github.com/vllm-project/vllm/blob/32257297dd4dcb996a0fb4641c2018289d20396b/vllm/model_executor/models/registry.py#L671
The corresponding dict they use is here:
https://github.com/vllm-project/vllm/blob/32257297dd4dcb996a0fb4641c2018289d20396b/vllm/model_executor/models/registry.py#L57
Yes, I have now added two maps in vllm_ascend/spec_decode/mtp_proposer.py.
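For illustration, a minimal sketch of the function-based lookup discussed in this thread, assuming lazy imports inside the function so runtime patching still takes effect; the module path for CustomQwen3NextMTP is an assumption, not necessarily the location used in this PR:

```python
# Sketch only, not the code in this PR: resolve the MTP module class from the
# model architecture. Imports happen lazily inside the function so that any
# patches applied to these classes at runtime are still honoured (a
# module-level dict would bind the classes too early).
def get_mtp_model_cls(architecture: str):
    if architecture == "Qwen3NextForCausalLM":
        # Assumed module path for the custom Qwen3-next MTP model.
        from vllm_ascend.models.qwen3_next_mtp import CustomQwen3NextMTP
        return CustomQwen3NextMTP
    # Fall back to the DeepSeek MTP implementation used before this change.
    from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
    return DeepSeekMTP


# Usage in the proposer, mirroring the diff above:
# self.model = get_mtp_model_cls(architecture)(
#     vllm_config=self.vllm_config).to(target_device)
```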
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
architecture = self.vllm_config.model_config.architecture
if architecture == "Qwen3NextForCausalLM":
    attn_metadata = attn_metadata['model.layers.3.self_attn.attn']
Ditto here. Don't hard-code the layer name. Instead, the global dict can map from model architectures to both their corresponding module classes and attention layer names.
You are right!
But I think a function is a better choice.
Yes, I have now added two maps in vllm_ascend/spec_decode/mtp_proposer.py.
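A sketch of what the attention-layer-name side of that mapping could look like; the Qwen3-next entry matches the diff above, and the default mirrors the previously hard-coded key:

```python
# Sketch only: per-architecture name of the attention layer whose metadata
# the MTP proposer reads. The Qwen3-next entry matches the diff above; the
# default is the key that was previously hard-coded.
_MTP_ATTN_LAYER_NAMES = {
    "Qwen3NextForCausalLM": "model.layers.3.self_attn.attn",
}


def get_mtp_attn_layer_name(architecture: str) -> str:
    return _MTP_ATTN_LAYER_NAMES.get(architecture,
                                     "model.layers.0.self_attn.attn")


# attn_metadata = attn_metadata[get_mtp_attn_layer_name(architecture)]
```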
with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
                tensor_parallel_size=4,
                max_model_len=4096,
                gpu_memory_utilization=0.8,
                distributed_executor_backend="mp",
                speculative_config={
                    "method": "qwen3_next_mtp",
                    "num_speculative_tokens": 1
                },
                enforce_eager=True) as vllm_model:
Have we tested graph mode yet? That is, without enforce_eager.
Yes, I removed enforce_eager=True and it works fine.
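For reference, the graph-mode run is just the same runner call with enforce_eager left at its default; a sketch mirroring the snippet above:

```python
# Sketch only: identical configuration to the test above, minus
# enforce_eager=True, so the model runs in graph (compiled) mode.
with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
                tensor_parallel_size=4,
                max_model_len=4096,
                gpu_memory_utilization=0.8,
                distributed_executor_backend="mp",
                speculative_config={
                    "method": "qwen3_next_mtp",
                    "num_speculative_tokens": 1
                }) as vllm_model:
    vllm_model.generate_greedy(example_prompts, max_tokens)
```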
@support_torch_compile
class CustomQwen3NextMTP(Qwen3NextMTP, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["up_proj", "down_proj"]
    }
I see you've added support_torch_compile to both classes; assuming you copied them from vLLM, that's all the more reason to test graph mode now.
Yes, I removed enforce_eager=True in test_qwen3_next.py and it works fine.
from tests.e2e.conftest import VllmRunner


def test_models_distributed_Qwen3_NEXT_MTP_TP4():
I deleted test_qwen3_next_mtp.py and moved the tests into test_qwen3_next.py.
| "num_speculative_tokens": 1 | ||
| }, | ||
| enforce_eager=True) as vllm_model: | ||
| vllm_model.generate_greedy(example_prompts, max_tokens) | 
Can we test the accept rate in the MTP scenario, like what's done in the DeepSeek MTP test case? https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py#L74-L86
I have added a third test test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY to do it.
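For context, a similarity check of this kind usually runs the model once with and once without the speculative config and compares greedy outputs; a minimal sketch (not the actual test added in this PR), reusing the runner arguments shown above:

```python
# Sketch only: compare greedy outputs with and without MTP speculative
# decoding; for greedy sampling the completions should match.
def check_mtp_matches_baseline(example_prompts, max_tokens=32):
    common = dict(tensor_parallel_size=4,
                  max_model_len=4096,
                  gpu_memory_utilization=0.8,
                  distributed_executor_backend="mp")

    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct", **common) as runner:
        ref_outputs = runner.generate_greedy(example_prompts, max_tokens)

    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
                    speculative_config={
                        "method": "qwen3_next_mtp",
                        "num_speculative_tokens": 1
                    },
                    **common) as runner:
        mtp_outputs = runner.generate_greedy(example_prompts, max_tokens)

    for ref, out in zip(ref_outputs, mtp_outputs):
        assert ref == out
```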
Signed-off-by: drslark <slarksblood@qq.com>
This pull request has conflicts; please resolve them before we can evaluate the pull request.
What this PR does / why we need it?
Adapts the MTP (Multi-Token Prediction) function to Qwen3-next.
Does this PR introduce any user-facing change?
N/A
How was this patch tested?
Run the code below.
Outputs:
Qwen3-next and Qwen3-next-mtp now produce the same results.