diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 31295582c1b1..b3b14f5acbd6 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -418,10 +418,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             shuffle_weights)

         if self.rocm_aiter_moe_enabled:
-            # use 2stage ck moe layout
-            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
-                                                        layer.w2_weight.data,
-                                                        layout=(32, 32))
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data)

             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index a92081862bfa..10b61fcda176 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+from enum import IntEnum
 from functools import cache
 from typing import Optional

@@ -9,6 +10,28 @@
 from vllm.utils import direct_register_custom_op


+class QuantMethod(IntEnum):
+    # This allows interfacing with the AITER QuantType enum
+    # without importing QuantType from AITER globally.
+
+    # Note that all of these quantization methods are
+    # supported by the AITER package; however,
+    # not all of them are used in this module.
+
+    NO = 0  # a16w16
+    PER_TENSOR = 1  # w8a8 (per_Tensor)
+    PER_TOKEN = 2  # w8a8/w8a4 (per_Token)
+    BLOCK_1X128 = 3  # block quantized w8a8 (per_1x128)
+    BLOCK_128x128 = 4  # block quantized w8a8 (per_128x128)
+
+
+class ActivationMethod(IntEnum):
+    # This allows interfacing with the AITER ActivationType enum
+    # without importing ActivationType from AITER globally.
+ SILU = 0 + GELU = 1 + + @cache def is_rocm_aiter_moe_enabled() -> bool: return current_platform.is_rocm() \ @@ -29,13 +52,12 @@ def rocm_aiter_asm_moe_tkw1_impl( a16: bool = False, per_tensor_quant_scale: Optional[torch.Tensor] = None, expert_mask: Optional[torch.Tensor] = None, - activation_str: str = "silu") -> torch.Tensor: + activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor: from aiter import ActivationType from aiter.fused_moe_bf16_asm import asm_moe_tkw1 - activation = \ - ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu + activation = ActivationType(activation_method) return asm_moe_tkw1(hidden_states, w1, @@ -65,163 +87,7 @@ def rocm_aiter_asm_moe_tkw1_fake( a16: bool = False, per_tensor_quant_scale: Optional[torch.Tensor] = None, expert_mask: Optional[torch.Tensor] = None, - activation_str: str = "silu") -> torch.Tensor: - return torch.empty_like(hidden_states) - - -def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl( - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - hidden_states_dtype: torch.dtype, - expert_mask: torch.Tensor, - a1: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - a1_scale: torch.Tensor, - block_shape: list[int], - smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - from aiter import fmoe_fp8_blockscale_g1u1 - from aiter.fused_moe_bf16_asm import moe_sorting_ck - - topk = topk_ids.shape[1] - model_dim = w1.shape[-1] - local_E = E = w1.shape[0] - if expert_mask is not None: - E = expert_mask.numel() - - ( - sorted_token_ids, - sorted_weight_buf, - sorted_expert_ids, - num_valid_ids, - out_asm, - ) = moe_sorting_ck(topk_ids, - topk_weights, - E, - model_dim, - hidden_states_dtype, - expert_mask=expert_mask) - - fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids, - sorted_weight_buf, sorted_expert_ids, - num_valid_ids, topk, - a1_scale.t().contiguous(), - w1_scale.view(local_E, -1), - w2_scale.view(local_E, - -1), *block_shape, smooth_scale) - - return out_asm - - -def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake( - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - hidden_states_dtype: torch.dtype, - expert_mask: torch.Tensor, - a1: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - a1_scale: torch.Tensor, - block_shape: list[int], - smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - - return torch.empty_like(a1, dtype=hidden_states_dtype) - - -def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - fc1_smooth_scale: Optional[torch.Tensor] = None, - fc2_smooth_scale: Optional[torch.Tensor] = None, - a16: bool = False, - activation: str = "silu") -> torch.Tensor: - import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe - from aiter import ActivationType - - assert activation in ["silu", "gelu"], "The given activation:" \ - f" {activation}" \ - " is not supported in" \ - " AITER." 
- if activation == "silu": - aiter_activation = ActivationType.Silu - else: - aiter_activation = ActivationType.Gelu - - return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weight=topk_weights, - topk_ids=topk_ids, - fc1_scale=fc1_scale, - fc2_scale=fc2_scale, - fc1_smooth_scale=fc1_smooth_scale, - fc2_smooth_scale=fc2_smooth_scale, - a16=a16, - activation=aiter_activation) - - -def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - fc1_smooth_scale: Optional[torch.Tensor] = None, - fc2_smooth_scale: Optional[torch.Tensor] = None, - a16: bool = False, - activation: str = "silu") -> torch.Tensor: - return torch.empty_like(hidden_states) - - -def rocm_aiter_ck_moe_2stages_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_size: Optional[list[int]] = None, - expert_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - from aiter.fused_moe_bf16_asm import ck_moe_2stages - return ck_moe_2stages(a1=hidden_states, - w1=w1, - w2=w2, - topk_weight=topk_weights, - topk_ids=topk_ids, - fc1_scale=fc1_scale, - fc2_scale=fc2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_size=block_size, - expert_mask=expert_mask) - - -def rocm_aiter_ck_moe_2stages_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_size: Optional[list[int]] = None, - expert_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: + activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -274,6 +140,50 @@ def rocm_aiter_biased_grouped_topk_fake( pass +def rocm_aiter_fused_moe_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value, + quant_method: int = QuantMethod.NO.value, + doweight_stage1: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + + activation = ActivationType(activation_method) + quant_type = QuantType(quant_method) + + return fused_moe(hidden_states, w1, w2, topk_weight, topk_ids, expert_mask, + activation, quant_type, doweight_stage1, w1_scale, + w2_scale, a1_scale, a2_scale) + + +def rocm_aiter_fused_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value, + quant_method: int = QuantMethod.NO.value, + doweight_stage1: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + 
a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + if current_platform.is_rocm(): direct_register_custom_op( @@ -285,26 +195,10 @@ def rocm_aiter_biased_grouped_topk_fake( ) direct_register_custom_op( - op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1", - op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl, - mutates_args=[], - fake_impl=rocm_aiter_fmoe_fp8_blockscale_g1u1_fake, - dispatch_key=current_platform.dispatch_key, - ) - - direct_register_custom_op( - op_name="rocm_aiter_asm_moe", - op_func=rocm_aiter_asm_moe_impl, - mutates_args=[], - fake_impl=rocm_aiter_asm_moe_fake, - dispatch_key=current_platform.dispatch_key, - ) - - direct_register_custom_op( - op_name="rocm_aiter_ck_moe_2stages", - op_func=rocm_aiter_ck_moe_2stages_impl, + op_name="rocm_aiter_fused_moe", + op_func=rocm_aiter_fused_moe_impl, mutates_args=[], - fake_impl=rocm_aiter_ck_moe_2stages_fake, + fake_impl=rocm_aiter_fused_moe_fake, dispatch_key=current_platform.dispatch_key, ) @@ -373,32 +267,14 @@ def rocm_aiter_fused_experts( a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None) -> torch.Tensor: - from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) - + activation_method = (ActivationMethod.SILU + if activation == "silu" else ActivationMethod.GELU) # All AITER Fused MoE kernels are expecting the following datatypes topk_weights = topk_weights.to(torch.float32) topk_ids = topk_ids.to(torch.int32) - # w8a8 block-scaled - if block_shape is not None and use_fp8_w8a8: - assert not apply_router_weight_on_input, ( - "apply_router_weight_on_input is not supported for block scaled moe" - ) - assert w1_scale is not None - assert w2_scale is not None - - # The default block sizes are 128 in AITER. - block_shape = [128, 128] if block_shape is None else block_shape - - a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1]) - - return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1( - topk_ids, topk_weights, hidden_states.dtype, None, a1, w1, w2, - w1_scale, w2_scale, a1_scale, block_shape, None) - # w8a8 per-channel quantization - elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: + if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # This applies topk_weights on the GEMM output of the first FC layer # rather than the second FC. 
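
The hunks above collapse the three AITER-specific custom ops that previously existed (`rocm_aiter_asm_moe`, `rocm_aiter_ck_moe_2stages`, and `rocm_aiter_fmoe_fp8_blockscale_g1u1`) into the single `rocm_aiter_fused_moe` op, and the kernel variant is now chosen through the integer-valued `QuantMethod` and `ActivationMethod` enums instead of by dispatching to different ops. Below is a minimal, self-contained sketch of that bridging pattern; `LibActivationType` is a hypothetical stand-in for AITER's `ActivationType`, used only so the example runs without ROCm or AITER installed.

from enum import IntEnum


class LibActivationType(IntEnum):
    # Hypothetical stand-in for the enum defined inside the heavyweight
    # library (AITER's ActivationType in the real code), imported lazily.
    Silu = 0
    Gelu = 1


class ActivationMethod(IntEnum):
    # Lightweight mirror that callers import instead of the library enum,
    # analogous to ActivationMethod in rocm_aiter_fused_moe.py.
    SILU = 0
    GELU = 1


def op_impl(activation_method: int = ActivationMethod.SILU.value) -> str:
    # Only plain ints cross the custom-op boundary; the implementation
    # rebuilds the library enum from the int, mirroring
    # ActivationType(activation_method) in rocm_aiter_fused_moe_impl.
    activation = LibActivationType(activation_method)
    return activation.name


assert op_impl(ActivationMethod.GELU.value) == "Gelu"
assert op_impl() == "Silu"

Keeping the two enums value-aligned is what makes `QuantType(quant_method)` and `ActivationType(activation_method)` in the op implementation valid; a mismatched value would raise a ValueError at dispatch time.
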
@@ -421,60 +297,44 @@ def rocm_aiter_fused_experts( a16=False, per_tensor_quant_scale=None, expert_mask=None, - activation_str=activation) - - # w8a8 per-tensor activation per-tensor weight - elif use_fp8_w8a8: - assert not apply_router_weight_on_input, ( - "apply_router_weight_on_input is not supported for fp8_w8a8") - - # - faster static per-tensor-activation static per-tensor-weight - # fp8 quantization w8a8 - if a1_scale is not None and a2_scale is not None: - return torch.ops.vllm.rocm_aiter_ck_moe_2stages( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - fc1_scale=w1_scale, - fc2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale) - - # - fallback static per-tensor-activation static per-tensor-weight - # fp8 quantization w8a8 - # - dynamic per-tensor activation static per-tensor-weight - # fp8 quantization w8a8 - return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - fc1_scale=w1_scale, - fc2_scale=w2_scale, - fc1_smooth_scale=None, - fc2_smooth_scale=None, - a16=False, - activation=activation) - if apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - - hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) - topk_ids = topk_ids.to(torch.int32) - topk_weights = torch.ones_like(topk_weights, dtype=torch.float32) + activation_method=activation_method) - return torch.ops.vllm.rocm_aiter_ck_moe_2stages( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids) + else: + quant_method = QuantMethod.NO.value + + # w8a8 block-scaled + if block_shape is not None and use_fp8_w8a8: + assert not apply_router_weight_on_input, ( + "apply_router_weight_on_input is\ + not supported for block scaled moe") + assert w1_scale is not None + assert w2_scale is not None + quant_method = QuantMethod.BLOCK_128x128.value + elif use_fp8_w8a8: + # Currently only per tensor quantization method is enabled. + quant_method = QuantMethod.PER_TENSOR.value + + if apply_router_weight_on_input: + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + + return torch.ops.vllm.rocm_aiter_fused_moe( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + quant_method=quant_method, + activation_method=activation_method, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + doweight_stage1=apply_router_weight_on_input) def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, @@ -488,14 +348,21 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, return topk_weights, topk_indices -def shuffle_weights(*tensors: torch.Tensor, - layout: tuple[int, int]) -> tuple[torch.Tensor, ...]: +def shuffle_weights( + *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) +) -> tuple[torch.Tensor, ...]: """ Applies shuffle_weight function from AITER to each input tensor and returns them. + + Rearranges (shuffles) the input tensor/s + into a specified block layout for optimized computation. Args: - *tensors: Variable number of torch.Tensor objects. + *tensors: Variable number of torch.Tensor objects. 
+ layout: A pair of integers specifying the + block sizes used to divide the tensors during shuffling. + Default is (16, 16). Returns: A Tuple of shuffled tensors. @@ -503,25 +370,3 @@ def shuffle_weights(*tensors: torch.Tensor, from aiter.ops.shuffle import shuffle_weight return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) - - -def expand_weights(*tensors: torch.Tensor, - expansion_dims: list[int]) -> tuple[torch.Tensor, ...]: - """ - Expands the dimensions of input tensors. - - Args: - *tensors: A variable number of torch.Tensor objects. - expansion_dims: A list of expansion dimensions - corresponding to each tensor. - - Returns: - A Tuple of tensors with expanded dimensions. - """ - - assert len(tensors) == len(expansion_dims), \ - "Number of tensors must match the number of expansion dimensions." - - return tuple( - tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1)) - for tensor, dim in zip(tensors, expansion_dims)) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index fa0067c44802..9241ceeb4db2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -286,9 +286,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: rocm_aiter_fused_experts, shuffle_weights) # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data, - layer.w2_weight.data, - layout=(16, 16)) + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data) layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 652bf76673c5..fb5cf961b970 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -596,7 +596,7 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights) + is_rocm_aiter_moe_enabled, shuffle_weights) self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() @@ -628,9 +628,7 @@ def process_weights_after_loading(self, layer: Module) -> None: if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, - layer.w2_weight.data, - layout=(16, 16)) + layer.w13_weight.data, layer.w2_weight.data) layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) @@ -676,20 +674,8 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. 
- w13_scales, w2_scales = expand_weights( - layer.w13_weight_scale.data, - layer.w2_weight_scale.data, - expansion_dims=[ - layer.w13_weight.shape[1], layer.w2_weight.shape[1] - ]) - layer.w13_weight_scale = torch.nn.Parameter( - w13_scales.contiguous(), requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter( - w2_scales.contiguous(), requires_grad=False) - - shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight, - layer.w2_weight, - layout=(16, 16)) + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight, layer.w2_weight) layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) @@ -761,20 +747,8 @@ def process_weights_after_loading(self, layer: Module) -> None: start += shard_size if self.rocm_aiter_moe_enabled: - # reshaping weights is required for aiter moe kernel. - expansion_dims = [ - layer.w13_weight.shape[1], layer.w2_weight.shape[1] - ] - max_w13_scales, w2_scales = expand_weights( - max_w13_scales, - layer.w2_weight_scale.data, - expansion_dims=expansion_dims) - layer.w2_weight_scale = torch.nn.Parameter( - w2_scales.contiguous(), requires_grad=False) - - shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight, - layer.w2_weight, - layout=(32, 32)) + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight, layer.w2_weight) layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
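
With the unified op consuming the scales directly, the `expand_weights` helper and the per-scheme `layout=` arguments are dropped everywhere: each call site above now passes the scale tensors unchanged and only reorders the weights, relying on the new `shuffle_weights` default of `layout=(16, 16)`. For reference, here is a torch-only sketch of what the deleted helper used to do, with made-up shapes.

import torch

# Illustrative shapes only: 4 experts, intermediate dimension 8.
num_experts, intermediate = 4, 8

# One quantization scale per expert, as in per-tensor FP8 weights.
w2_scale = torch.rand(num_experts)

# What the removed expand_weights helper did: broadcast each per-expert
# scalar across the intermediate dimension, (E,) -> (E, 1, 1) -> (E, N, 1).
expanded = w2_scale.unsqueeze(-1).unsqueeze(-1).expand(-1, intermediate, -1)

assert expanded.shape == (num_experts, intermediate, 1)
# Each expert's rows all carry that expert's single scale value.
assert torch.all(expanded[1] == w2_scale[1])

As the updated call sites suggest, the AITER fused_moe path accepts the per-expert scales in their original shape, so this broadcast and the accompanying `.contiguous()` copies are no longer needed.
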