From dbe7c997c9f10a403df091c92b61e9be6a360fc2 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 16 Apr 2025 16:06:46 +0000 Subject: [PATCH 01/11] register aiter fmoe as custom ops Co-authored-by: tjtanaa Signed-off-by: vllmellm --- .../layers/fused_moe/rocm_aiter_fused_moe.py | 341 +++++++++++++----- 1 file changed, 258 insertions(+), 83 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index ac158a7eee53..2fb9a8295d5c 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from typing import List, Optional, Tuple import torch import vllm.envs as envs from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op def is_rocm_aiter_moe_enabled() -> bool: @@ -18,109 +19,283 @@ def is_rocm_aiter_block_scaled_moe_enabled() -> bool: envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE -def rocm_aiter_fused_experts( - *, - hidden_states: torch.Tensor, +def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + import aiter as rocm_aiter + return rocm_aiter.ck_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids) + + +def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + return torch.empty((topk_ids.size(0), hidden_states.size(1)), + dtype=hidden_states.dtype, + device=hidden_states.device) + + +def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + hidden_states_dtype: torch.dtype, + expert_mask: torch.Tensor, + a1: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + block_shape: List[int], + smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + from aiter import fmoe_fp8_blockscale_g1u1 + from aiter.fused_moe_bf16_asm import moe_sorting_ck + + assert len(block_shape) == 2 + + topk = topk_ids.shape[1] + model_dim = w1.shape[-1] + local_E = E = w1.shape[0] + if expert_mask is not None: + E = expert_mask.numel() + + ( + sorted_token_ids, + sorted_weight_buf, + sorted_expert_ids, + num_valid_ids, + out_asm, + ) = moe_sorting_ck(topk_ids, + topk_weights, + E, + model_dim, + hidden_states_dtype, + expert_mask=expert_mask) + + fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids, + sorted_weight_buf, sorted_expert_ids, + num_valid_ids, topk, w1_scale.view(local_E, -1), + w2_scale.view(local_E, -1), + a1_scale.t().contiguous(), *block_shape, + smooth_scale) + + return out_asm + + +def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake( topk_ids: torch.Tensor, - use_fp8_w8a8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - expert_mask: Optional[torch.Tensor] = None, - **kwagrs # Ignore additional keyword arguments -) -> torch.Tensor: - - import aiter as rocm_aiter + topk_weights: torch.Tensor, + hidden_states_dtype: torch.dtype, + expert_mask: torch.Tensor, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: 
torch.Tensor, + block_shape: List[int], + smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + from aiter.fused_moe_bf16_asm import moe_sorting_ck + + model_dim = w1.shape[-1] + E = w1.shape[0] + _, _, _, _, out_asm = moe_sorting_ck(topk_ids, + topk_weights, + E, + model_dim, + hidden_states_dtype, + expert_mask=expert_mask) + return out_asm + + +def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + activation: str = "silu") -> torch.Tensor: import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe + from aiter import ActivationType + + assert activation in ["silu", "gelu"], "The given activation:" \ + f" {activation}" \ + " is not supported in" \ + " AITER." + if activation == "silu": + aiter_activation = ActivationType.Silu + else: + aiter_activation = ActivationType.Gelu + + return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weight=topk_weight, + topk_ids=topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + activation=aiter_activation) + + +def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + activation: str = "silu") -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> None: + from aiter import topk_softmax + topk_softmax(topk_weights, topk_indices, token_expert_indices, + gating_output, renormalize) + + +def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> None: + pass + + +def rocm_aiter_shuffle_weight_impl(tensor: torch.Tensor) -> torch.Tensor: + from aiter.ops.shuffle import shuffle_weight + return shuffle_weight(tensor) + + +def rocm_aiter_shuffle_weight_fake(tensor: torch.Tensor) -> torch.Tensor: + return tensor + + +if current_platform.is_rocm(): + + direct_register_custom_op( + op_name="rocm_aiter_ck_moe", + op_func=rocm_aiter_ck_moe_impl, + mutates_args=[], + fake_impl=rocm_aiter_ck_moe_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1", + op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl, + mutates_args=[], + fake_impl=rocm_aiter_fmoe_fp8_blockscale_g1u1_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_asm_moe", + op_func=rocm_aiter_asm_moe_impl, + mutates_args=[], + fake_impl=rocm_aiter_asm_moe_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_topk_softmax", + op_func=rocm_aiter_topk_softmax_impl, + mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], + fake_impl=rocm_aiter_topk_softmax_fake, + 
dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op(op_name="rocm_aiter_shuffle_weight", + op_func=rocm_aiter_shuffle_weight_impl, + mutates_args=[], + fake_impl=rocm_aiter_shuffle_weight_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.inplace_view, )) + + +def rocm_aiter_fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + expert_mask: Optional[torch.Tensor] = None, + activation: str = "silu", + **kwargs) -> torch.Tensor: + + if is_rocm_aiter_block_scaled_moe_enabled() and use_fp8_w8a8: + + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) - from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) - - if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8: assert w1_scale is not None assert w2_scale is not None - local_E = E = w1.shape[0] - if expert_mask is not None: - E = expert_mask.numel() - - topk = topk_ids.shape[1] - model_dim = w1.shape[-1] - dtype = hidden_states.dtype # The default block sizes are 128 in AITER. if block_shape is None: block_shape = [128, 128] scale_blk_k = block_shape[1] - ( - sorted_token_ids, - sorted_weight_buf, - sorted_expert_ids, - num_valid_ids, - out_asm, - ) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids, - topk_weights, - E, - model_dim, - dtype, - expert_mask=expert_mask) - a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k) - rocm_aiter.fmoe_fp8_blockscale_g1u1( - out_asm, - a1, - w1, - w2, - sorted_token_ids, - sorted_weight_buf, - sorted_expert_ids, - num_valid_ids, - topk, - w1_scale.view(local_E, -1), - w2_scale.view(local_E, -1), - a1_scale.t().contiguous(), - block_shape[0], - block_shape[1], - None, - ) - return out_asm - elif use_fp8_w8a8: - return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weight=topk_weights, - topk_ids=topk_ids, - fc1_scale=w1_scale, - fc2_scale=w2_scale, - fc1_smooth_scale=None, - fc2_smooth_scale=None, - a16=False) + return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1( + topk_ids, topk_weights, hidden_states.dtype, expert_mask, a1, w1, + w2, w1_scale, w2_scale, a1_scale, block_shape, None) - return rocm_aiter.ck_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids) + elif use_fp8_w8a8: + return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weight=topk_weights, + topk_ids=topk_ids, + fc1_scale=w1_scale, + fc2_scale=w2_scale, + fc1_smooth_scale=None, + fc2_smooth_scale=None, + a16=False, + activation=activation) + + return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids) def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, - renormalize: bool) -> tuple[torch.Tensor, ...]: - import aiter as rocm_aiter - rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices, - gating_output, renormalize) - + renormalize: bool) -> Tuple[torch.Tensor, ...]: + torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices, + token_expert_indices, gating_output, + 
renormalize) return topk_weights, topk_indices -def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: +def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]: """ Applies shuffle_weight function from AITER to each input tensor and returns them. @@ -129,15 +304,15 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: *tensors: Variable number of torch.Tensor objects. Returns: - A tuple of shuffled tensors. + A Tuple of shuffled tensors. """ - from aiter.ops.shuffle import shuffle_weight + shuffle_weigth_func = torch.ops.vllm.rocm_aiter_shuffle_weight - return tuple(shuffle_weight(tensor) for tensor in tensors) + return tuple(shuffle_weigth_func(tensor) for tensor in tensors) def expand_weights(*tensors: torch.Tensor, - expansion_dims: list[int]) -> tuple[torch.Tensor, ...]: + expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]: """ Expands the dimensions of input tensors. @@ -147,7 +322,7 @@ def expand_weights(*tensors: torch.Tensor, corresponding to each tensor. Returns: - A tuple of tensors with expanded dimensions. + A Tuple of tensors with expanded dimensions. """ assert len(tensors) == len(expansion_dims), \ From d0ac8e66f4e86ca7e3a2f4ddf76417790d1125ca Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 23 Apr 2025 04:20:35 +0000 Subject: [PATCH 02/11] revert aiter moe en check func Signed-off-by: vllmellm --- .../layers/fused_moe/rocm_aiter_fused_moe.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 89e7e8b0a103..af96e9e13fec 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -3,10 +3,17 @@ import torch +from vllm import envs from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op +def is_rocm_aiter_moe_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_MOE \ + and envs.VLLM_ROCM_USE_AITER \ + + def rocm_aiter_asm_moe_tkw1(hidden_states, w1, w2, From fe1711b2367485e900507b666f6525d8410de6b9 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 23 Apr 2025 07:00:11 +0000 Subject: [PATCH 03/11] register moe tkw1 kernel Signed-off-by: vllmellm --- .../layers/fused_moe/rocm_aiter_fused_moe.py | 87 ++++++++++++------- 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index af96e9e13fec..0b6a3377ec5e 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +from functools import cache from typing import List, Optional, Tuple import torch @@ -8,25 +9,27 @@ from vllm.utils import direct_register_custom_op +@cache def is_rocm_aiter_moe_enabled() -> bool: return current_platform.is_rocm() \ and envs.VLLM_ROCM_USE_AITER_MOE \ - and envs.VLLM_ROCM_USE_AITER \ - - -def rocm_aiter_asm_moe_tkw1(hidden_states, - w1, - w2, - topk_weight, - topk_ids, - fc1_scale=None, - fc2_scale=None, - fc1_smooth_scale=None, - fc2_smooth_scale=None, - a16=False, - per_tensor_quant_scale=None, - expert_mask=None, - activation_str: str = "silu") -> None: + and envs.VLLM_ROCM_USE_AITER + + +def rocm_aiter_asm_moe_tkw1_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: 
torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, + expert_mask: Optional[torch.Tensor] = None, + activation_str: str = "silu") -> torch.Tensor: from aiter import ActivationType from aiter.fused_moe_bf16_asm import asm_moe_tkw1 @@ -49,6 +52,23 @@ def rocm_aiter_asm_moe_tkw1(hidden_states, activation=activation) +def rocm_aiter_asm_moe_tkw1_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, + expert_mask: Optional[torch.Tensor] = None, + activation_str: str = "silu") -> torch.Tensor: + return torch.empty_like(hidden_states) + + def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor: @@ -218,6 +238,14 @@ def rocm_aiter_shuffle_weight_fake(tensor: torch.Tensor) -> torch.Tensor: if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_aiter_asm_moe_tkw1", + op_func=rocm_aiter_asm_moe_tkw1_impl, + mutates_args=[], + fake_impl=rocm_aiter_asm_moe_tkw1_fake, + dispatch_key=current_platform.dispatch_key, + ) + direct_register_custom_op( op_name="rocm_aiter_ck_moe", op_func=rocm_aiter_ck_moe_impl, @@ -306,19 +334,20 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor, "Only support topk=1 when" " `apply_router_weight_on_input` is True") - return rocm_aiter_asm_moe_tkw1(hidden_states, - w1, - w2, - topk_weights, - topk_ids, - fc1_scale=w1_scale, - fc2_scale=w2_scale, - fc1_smooth_scale=None, - fc2_smooth_scale=None, - a16=False, - per_tensor_quant_scale=None, - expert_mask=expert_map, - activation_str=activation) + return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale=w1_scale, + fc2_scale=w2_scale, + fc1_smooth_scale=None, + fc2_smooth_scale=None, + a16=False, + per_tensor_quant_scale=None, + expert_mask=expert_map, + activation_str=activation) elif use_fp8_w8a8: assert not apply_router_weight_on_input, ( From 7423790baffa2b2d36ed8cebfbfc828d7e6306f7 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 23 Apr 2025 07:44:34 +0000 Subject: [PATCH 04/11] make mypy happy Signed-off-by: vllmellm --- .../model_executor/layers/fused_moe/rocm_aiter_fused_moe.py | 3 ++- .../compressed_tensors/compressed_tensors_moe.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 0b6a3377ec5e..64a8b7e31e6b 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -286,7 +286,8 @@ def rocm_aiter_shuffle_weight_fake(tensor: torch.Tensor) -> torch.Tensor: tags=(torch.Tag.inplace_view, )) -def rocm_aiter_fused_experts(hidden_states: torch.Tensor, +def rocm_aiter_fused_experts(*, + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, diff 
--git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index d74d4e9273b7..721e36af2b28 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -304,9 +304,9 @@ def apply( e_score_correction_bias=e_score_correction_bias) return self.fused_experts_func( - x, - layer.w13_weight, - layer.w2_weight, + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, From 7c15ebb90f61628a266d1f99dfaa805d1ef2800b Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 23 Apr 2025 07:45:31 +0000 Subject: [PATCH 05/11] remove comment Signed-off-by: vllmellm --- vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 64a8b7e31e6b..707df5c83a12 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -364,7 +364,6 @@ def rocm_aiter_fused_experts(*, fc2_smooth_scale=None, a16=False, activation=activation) - # Temp fix if apply_router_weight_on_input: assert (topk_weights.dim() == 2 ), "`topk_weights` should be in shape (num_tokens, topk)" From beaff1ae5aec885a1fbe6a843a5b1e6363045736 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 23 Apr 2025 09:47:21 +0000 Subject: [PATCH 06/11] fix unit tests Signed-off-by: vllmellm --- tests/model_executor/test_enabled_custom_ops.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index ac2e0f3542e7..8cfc50bad0ca 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -2,6 +2,7 @@ import pytest +import vllm from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, @@ -100,11 +101,11 @@ def test_enabled_ops_invalid(env: str): def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) topk_func = dispatch_topk_func() - + vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.is_rocm_aiter_moe_enabled.cache_clear( + ) if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_topk_softmax) - assert topk_func == rocm_aiter_topk_softmax else: assert topk_func == vllm_topk_softmax @@ -116,11 +117,12 @@ def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) + vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.is_rocm_aiter_moe_enabled.cache_clear( + ) fused_experts_func = dispatch_fused_experts_func(inplace) if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_fused_experts) - assert fused_experts_func == rocm_aiter_fused_experts elif inplace: assert fused_experts_func == torch_vllm_inplace_fused_experts From 091128d2f3ddf66eae94a1dae95175d8c480758f Mon Sep 17 00:00:00 
2001 From: vllmellm Date: Wed, 23 Apr 2025 09:53:58 +0000 Subject: [PATCH 07/11] make mypy happy Signed-off-by: vllmellm --- tests/model_executor/test_enabled_custom_ops.py | 9 ++++----- .../layers/fused_moe/rocm_aiter_fused_moe.py | 14 +++++++++++--- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 8cfc50bad0ca..2d9cf1d48fd5 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -2,7 +2,6 @@ import pytest -import vllm from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, @@ -12,6 +11,8 @@ dispatch_fused_experts_func, dispatch_topk_func, torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts, vllm_topk_softmax) +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.layernorm import ( RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) @@ -101,8 +102,7 @@ def test_enabled_ops_invalid(env: str): def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) topk_func = dispatch_topk_func() - vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.is_rocm_aiter_moe_enabled.cache_clear( - ) + is_rocm_aiter_moe_enabled.cache_clear() if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_topk_softmax) @@ -117,8 +117,7 @@ def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.is_rocm_aiter_moe_enabled.cache_clear( - ) + is_rocm_aiter_moe_enabled.cache_clear() fused_experts_func = dispatch_fused_experts_func(inplace) if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 707df5c83a12..345b47370ec2 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -286,8 +286,7 @@ def rocm_aiter_shuffle_weight_fake(tensor: torch.Tensor) -> torch.Tensor: tags=(torch.Tag.inplace_view, )) -def rocm_aiter_fused_experts(*, - hidden_states: torch.Tensor, +def rocm_aiter_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, @@ -301,7 +300,16 @@ def rocm_aiter_fused_experts(*, block_shape: Optional[List[int]] = None, expert_map: Optional[torch.Tensor] = None, activation: str = "silu", - **kwargs) -> torch.Tensor: + inplace: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + global_num_experts: int = -1, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + allow_deep_gemm: bool = False) -> torch.Tensor: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) From bb19420445a6f8e10903c473950aed0e1c0ea634 Mon Sep 17 00:00:00 2001 From: vllmellm Date: 
Wed, 23 Apr 2025 10:07:58 +0000 Subject: [PATCH 08/11] revert rocm_aiter_fused_experts args Signed-off-by: vllmellm --- .../layers/fused_moe/rocm_aiter_fused_moe.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 345b47370ec2..580f04fd4849 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -291,23 +291,23 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, - apply_router_weight_on_input: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - expert_map: Optional[torch.Tensor] = None, - activation: str = "silu", inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, allow_deep_gemm: bool = False) -> torch.Tensor: from vllm.model_executor.layers.quantization.utils.fp8_utils import ( From 22e9459bd6ccc0dfb10ead3c7bd86b91ad9f8c47 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 24 Apr 2025 09:44:58 +0000 Subject: [PATCH 09/11] clean up fake functions; remove unnecessary args Signed-off-by: vllmellm --- .../layers/fused_moe/rocm_aiter_fused_moe.py | 84 +++++++------------ 1 file changed, 28 insertions(+), 56 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 580f04fd4849..b53e72fc354e 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -83,9 +83,7 @@ def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor, def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor: - return torch.empty((topk_ids.size(0), hidden_states.size(1)), - dtype=hidden_states.dtype, - device=hidden_states.device) + return torch.empty_like(hidden_states) def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl( @@ -146,17 +144,8 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake( a1_scale: torch.Tensor, block_shape: List[int], smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - from aiter.fused_moe_bf16_asm import moe_sorting_ck - model_dim = w1.shape[-1] - E = w1.shape[0] - _, _, _, _, out_asm = moe_sorting_ck(topk_ids, - topk_weights, - E, - model_dim, - hidden_states_dtype, - expert_mask=expert_mask) - return out_asm + return torch.empty_like(a1, dtype=torch.bf16) def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor, @@ -227,15 +216,6 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor, pass -def 
rocm_aiter_shuffle_weight_impl(tensor: torch.Tensor) -> torch.Tensor: - from aiter.ops.shuffle import shuffle_weight - return shuffle_weight(tensor) - - -def rocm_aiter_shuffle_weight_fake(tensor: torch.Tensor) -> torch.Tensor: - return tensor - - if current_platform.is_rocm(): direct_register_custom_op( @@ -278,37 +258,29 @@ def rocm_aiter_shuffle_weight_fake(tensor: torch.Tensor) -> torch.Tensor: dispatch_key=current_platform.dispatch_key, ) - direct_register_custom_op(op_name="rocm_aiter_shuffle_weight", - op_func=rocm_aiter_shuffle_weight_impl, - mutates_args=[], - fake_impl=rocm_aiter_shuffle_weight_fake, - dispatch_key=current_platform.dispatch_key, - tags=(torch.Tag.inplace_view, )) - - -def rocm_aiter_fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False) -> torch.Tensor: + +def rocm_aiter_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> torch.Tensor: from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) @@ -343,6 +315,7 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor, "Only support topk=1 when" " `apply_router_weight_on_input` is True") + # per_channel_quant return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( hidden_states, w1, @@ -413,9 +386,8 @@ def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]: Returns: A Tuple of shuffled tensors. 
""" - shuffle_weigth_func = torch.ops.vllm.rocm_aiter_shuffle_weight - - return tuple(shuffle_weigth_func(tensor) for tensor in tensors) + from aiter.ops.shuffle import shuffle_weight + return tuple(shuffle_weight(tensor) for tensor in tensors) def expand_weights(*tensors: torch.Tensor, From 1fa4d071780b50f657e0b72dc6c9e8d0ceef1869 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 24 Apr 2025 15:23:29 +0000 Subject: [PATCH 10/11] add comments Signed-off-by: vllmellm --- vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index b53e72fc354e..93c894d80bcc 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -289,6 +289,7 @@ def rocm_aiter_fused_experts( topk_weights = topk_weights.to(torch.float32) topk_ids = topk_ids.to(torch.int32) + # w8a8 block-scaled if block_shape is not None and use_fp8_w8a8: assert not apply_router_weight_on_input, ( "apply_router_weight_on_input is not supported for block scaled moe" @@ -305,6 +306,7 @@ def rocm_aiter_fused_experts( topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1, w2, w1_scale, w2_scale, a1_scale, block_shape, None) + # w8a8 per-channel quantization elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # This applies topk_weights on the GEMM output of the first FC layer @@ -315,7 +317,6 @@ def rocm_aiter_fused_experts( "Only support topk=1 when" " `apply_router_weight_on_input` is True") - # per_channel_quant return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( hidden_states, w1, @@ -331,6 +332,7 @@ def rocm_aiter_fused_experts( expert_mask=expert_map, activation_str=activation) + # w8a8 per-tensor activation per-tensor weight elif use_fp8_w8a8: assert not apply_router_weight_on_input, ( "apply_router_weight_on_input is not supported for fp8_w8a8") @@ -357,6 +359,7 @@ def rocm_aiter_fused_experts( topk_ids = topk_ids.to(torch.int32) topk_weights = torch.ones_like(topk_weights, dtype=torch.float32) + # w16a16 fallback to rocm_aiter_ck_moe w16a16 return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states, w1=w1, w2=w2, From d3fe808fecef9239c7e4bcf89f31c5c5311580f3 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 24 Apr 2025 15:41:32 +0000 Subject: [PATCH 11/11] revert rocm_aiter_fused_experts args Signed-off-by: vllmellm --- .../layers/fused_moe/rocm_aiter_fused_moe.py | 45 ++++++++++--------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 93c894d80bcc..acaa93f5a23e 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -259,28 +259,29 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor, ) -def rocm_aiter_fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - 
expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: +def rocm_aiter_fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> torch.Tensor: from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8)
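
For context, the registration pattern applied to every AITER kernel in this series reduces to the following minimal sketch. It is illustrative only and not part of any patch above: the op name "my_rocm_aiter_op" is hypothetical, while direct_register_custom_op, current_platform.dispatch_key, and the torch.ops.vllm dispatch namespace are the same pieces that appear in the diffs.

# Minimal sketch (assumption: illustrative op, not an op added by this series).
import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def my_rocm_aiter_op_impl(gating_output: torch.Tensor) -> torch.Tensor:
    # Real implementation: this is where the AITER kernel would be invoked.
    return torch.softmax(gating_output, dim=-1)


def my_rocm_aiter_op_fake(gating_output: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation: returns a tensor of the right shape and
    # dtype without launching a kernel, so torch.compile can trace the op.
    return torch.empty_like(gating_output)


if current_platform.is_rocm():
    direct_register_custom_op(
        op_name="my_rocm_aiter_op",
        op_func=my_rocm_aiter_op_impl,
        mutates_args=[],  # names of arguments mutated in place, if any
        fake_impl=my_rocm_aiter_op_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    # Callers then dispatch through the registered namespace instead of
    # calling the implementation directly:
    # out = torch.ops.vllm.my_rocm_aiter_op(gating_output)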