# [ROCM] enable aiter fused moe kernel for llama4 bf16 checkpoints #16674
```diff
@@ -26,6 +26,7 @@ def rocm_aiter_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
+    apply_router_weight_on_input: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
```
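For illustration, a call with the new flag might look like the sketch below. The leading `hidden_states`, `w1`, and `w2` arguments are assumptions: they sit above this hunk and are not shown in the diff.

```python
# Hypothetical call sketch: every argument before `topk_weights` is assumed,
# since those parameters fall outside the visible hunk.
output = rocm_aiter_fused_experts(
    hidden_states,                      # bf16 activations, (num_tokens, hidden_dim)
    w1, w2,                             # expert weight tensors (assumed names)
    topk_weights=topk_weights,          # (num_tokens, 1) router scores
    topk_ids=topk_ids,
    apply_router_weight_on_input=True,  # new flag introduced by this PR
)
```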
```diff
@@ -39,6 +40,18 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
 
+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None
```

**Review comment** (on `topk_weights = torch.ones_like(...)`): Does AITER require fp32 weights?

**Reply:** Yes, the passed-in `topk_weights` must be fp32.
**Review comment:** Can we assert `topk_weights.dim() == 2`?
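To make the new flag's semantics concrete, here is a minimal pure-PyTorch sketch. It is not the AITER kernel: `moe_kernel`, `fused_experts`, and `experts` are hypothetical stand-ins. It shows why scaling `hidden_states` up front and resetting `topk_weights` to ones gives the Llama4 behavior of applying the router weight to the expert *input* rather than its output, which matters because the experts are nonlinear.

```python
import torch

def moe_kernel(hidden_states, topk_weights, topk_ids, experts):
    """Stand-in for the fused kernel, which weights expert OUTPUTS:
    out[t] = sum_k topk_weights[t, k] * experts[topk_ids[t, k]](x[t])."""
    out = torch.zeros_like(hidden_states)
    for t in range(hidden_states.shape[0]):
        for k in range(topk_ids.shape[1]):
            expert = experts[int(topk_ids[t, k])]
            out[t] += topk_weights[t, k].to(out.dtype) * expert(hidden_states[t])
    return out

def fused_experts(hidden_states, topk_weights, topk_ids, experts,
                  apply_router_weight_on_input=False):
    if apply_router_weight_on_input:
        assert topk_weights.dim() == 2, \
            "`topk_weights` should be in shape (num_tokens, topk)"
        assert topk_weights.shape[1] == 1, \
            "Only topk=1 is supported when applying the weight on the input"
        # Scale the expert INPUT by the router weight (Llama4 semantics) ...
        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
        # ... and neutralize the kernel's output-side weighting with fp32 ones.
        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
    return moe_kernel(hidden_states, topk_weights, topk_ids, experts)

# Tiny usage example with two toy experts:
experts = [torch.nn.Linear(8, 8) for _ in range(2)]
x = torch.randn(4, 8)
w = torch.rand(4, 1)                              # router scores, topk=1
ids = torch.randint(0, 2, (4, 1), dtype=torch.int32)
y = fused_experts(x, w, ids, experts, apply_router_weight_on_input=True)
```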