From c51662d37290aeec294b1564fd26028f8299d5e8 Mon Sep 17 00:00:00 2001
From: Sijia Chen
Date: Tue, 15 Apr 2025 09:55:35 -0700
Subject: [PATCH 1/4] enable aiter fused moe bf16 for llama4

---
 .../layers/fused_moe/rocm_aiter_fused_moe.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index ac158a7eee53..f4606b506227 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -26,6 +26,7 @@ def rocm_aiter_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
+    apply_router_weight_on_input: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -39,6 +40,14 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
 
+    if apply_router_weight_on_input:
+        _, topk = topk_weights.shape
+        assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None

From 5f057ca3eeee7e35c3cbfaab8b0608be3d6157a1 Mon Sep 17 00:00:00 2001
From: Sijia Chen
Date: Tue, 15 Apr 2025 11:51:09 -0700
Subject: [PATCH 2/4] fix lint

---
 vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index f4606b506227..449e0adf1c44 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -42,7 +42,9 @@ def rocm_aiter_fused_experts(
 
     if apply_router_weight_on_input:
         _, topk = topk_weights.shape
-        assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
 
         hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
         topk_ids = topk_ids.to(torch.int32)

From 8dbbb1105a5b9b52f7a30f96cd4b72c558e9b016 Mon Sep 17 00:00:00 2001
From: Sijia Chen
Date: Wed, 16 Apr 2025 11:43:55 -0700
Subject: [PATCH 3/4] add assert

---
 vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 449e0adf1c44..709d43823fb5 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -41,6 +41,9 @@ def rocm_aiter_fused_experts(
         per_token_group_quant_fp8)
 
     if apply_router_weight_on_input:
+        assert (
+            topk_weights.dim() == 2
+        ), "`topk_weights` should be in shape (num_tokens, topk)"
         _, topk = topk_weights.shape
         assert (
             topk == 1

From 2ee3ec73f08b66133a8e8ab1c715838705e745e3 Mon Sep 17 00:00:00 2001
From: Sijia Chen
Date: Wed, 16 Apr 2025 14:46:01 -0700
Subject: [PATCH 4/4] lint

---
 vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 709d43823fb5..4214e8944821 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -41,9 +41,8 @@ def rocm_aiter_fused_experts(
         per_token_group_quant_fp8)
 
     if apply_router_weight_on_input:
-        assert (
-            topk_weights.dim() == 2
-        ), "`topk_weights` should be in shape (num_tokens, topk)"
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
         _, topk = topk_weights.shape
         assert (
             topk == 1
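
Note for reviewers: the snippet below is a minimal, standalone sketch of what the
`apply_router_weight_on_input` branch does once all four patches are applied. It
is illustrative only: the tensor shapes are invented for the example and the
AITER fused-expert kernel call is elided. For top-1 routing, the router weight
is folded into the activations up front, and the weights handed to the fused
kernel are replaced with ones so they are not applied a second time.

    # Standalone sketch of the `apply_router_weight_on_input` pre-scaling
    # (illustrative shapes; the fused AITER kernel call itself is elided).
    import torch

    num_tokens, hidden_dim, num_experts, topk = 4, 8, 16, 1
    hidden_states = torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16)
    topk_weights = torch.rand(num_tokens, topk, dtype=torch.float32)
    topk_ids = torch.randint(0, num_experts, (num_tokens, topk))

    assert (topk_weights.dim() == 2
            ), "`topk_weights` should be in shape (num_tokens, topk)"
    _, topk = topk_weights.shape
    assert (
        topk == 1
    ), "Only support topk=1 when `apply_router_weight_on_input` is True"

    # Fold the top-1 router weight into the activations...
    hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
    topk_ids = topk_ids.to(torch.int32)
    # ...then neutralize the weights so the kernel does not apply them again.
    topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

Scaling the input rather than the expert output matches, as I understand it,
the Llama4 MoE formulation, where the router score multiplies the expert's
input; restricting this path to topk=1 keeps the per-token scaling unambiguous.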