@@ -255,7 +255,7 @@ def prepare_communication_buffer_for_model(self,
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
-            module.quant_method.init_prepare_finalize()
+            module.quant_method.init_prepare_finalize(module)

def dispatch(
self, hidden_states: torch.Tensor,
6 changes: 6 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -450,6 +450,12 @@ def make(
if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn

from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Config)
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
quant_dtype = "mxfp8"

from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
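
For readers tracing the selection logic, here is a standalone sketch (not the vLLM code itself; the helper name `resolve_quant_dtype` and the string comparison on the config class name are illustrative only) of the resolution order that `make()` now follows: fp8 checkpoints resolve to `torch.float8_e4m3fn`, and mxfp4 checkpoints resolve to the `"mxfp8"` activation quant dtype only when `VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8` is set.

```python
# Standalone sketch of the quant_dtype resolution chain; resolve_quant_dtype
# and the string-based dispatch on the config class name are illustrative only.
import os
from typing import Optional, Union

import torch


def resolve_quant_dtype(quant_config_name: str) -> Optional[Union[torch.dtype, str]]:
    if quant_config_name == "Fp8Config":
        return torch.float8_e4m3fn
    if (quant_config_name == "Mxfp4Config"
            and os.environ.get("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8") == "1"):
        # Activations are quantized to mxfp8 on the FlashInfer TRT-LLM path.
        return "mxfp8"
    return None


print(resolve_quant_dtype("Fp8Config"))    # torch.float8_e4m3fn
print(resolve_quant_dtype("Mxfp4Config"))  # "mxfp8" when the env flag is "1", else None
```
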
6 changes: 4 additions & 2 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -200,7 +200,7 @@ def maybe_make_prepare_finalize(

# Note: init_prepare_finalize should only be called by
# prepare_communication_buffer_for_model.
-    def init_prepare_finalize(self):
+    def init_prepare_finalize(self, layer: torch.nn.Module):
assert self.moe is not None
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)

@@ -211,7 +211,7 @@ def init_prepare_finalize(self):
assert self.fused_experts is None, \
f"Attempt to override experts for {id(self)}!"
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
-        experts = self.select_gemm_impl(prepare_finalize, self.moe)
+        experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
@@ -221,6 +221,7 @@ def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
@@ -273,6 +274,7 @@ def select_gemm_impl(
prepare_finalize: FusedMoEPrepareAndFinalize,
# TODO(bnell): Remove. Every layer should have an moe config object.
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
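
The reason for threading `layer` through is that experts implementations may need per-layer tensors (for example the biases and activation parameters consumed by `TrtLlmGenExperts` in the new file below) that live on the FusedMoE module rather than on `FusedMoEConfig`. A minimal, self-contained sketch of the plumbing, with simplified names that are not the real vLLM class hierarchy:

```python
# Minimal, self-contained sketch of the plumbing (simplified names, not the
# real vLLM class hierarchy): the FusedMoE module is forwarded from
# init_prepare_finalize into select_gemm_impl so that per-layer tensors are
# reachable even though they are not part of the MoE config object.
import torch


class ToyQuantMethod:

    def init_prepare_finalize(self, layer: torch.nn.Module) -> None:
        prepare_finalize = object()  # stands in for maybe_make_prepare_finalize(...)
        self.fused_experts = self.select_gemm_impl(prepare_finalize, moe=None, layer=layer)

    def select_gemm_impl(self, prepare_finalize, moe, layer: torch.nn.Module):
        # With the module in hand, per-layer parameters such as a bias tensor
        # can be read directly off the layer.
        return {"w2_bias": getattr(layer, "w2_bias", None)}


layer = torch.nn.Module()
layer.w2_bias = torch.zeros(8)
method = ToyQuantMethod()
method.init_prepare_finalize(layer)
print(method.fused_experts["w2_bias"].shape)  # torch.Size([8])
```
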
197 changes: 197 additions & 0 deletions vllm/model_executor/layers/fused_moe/trtllm_moe.py
@@ -0,0 +1,197 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.utils import next_power_of_2


class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):

def __init__(
self,
moe: FusedMoEConfig,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
w13_bias,
w2_bias,
max_capture_size,
Comment on lines +19 to +24 (Member): Missing types
):
super().__init__(moe.quant_config)
self.moe = moe
self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit
self.w13_bias = w13_bias
self.w2_bias = w2_bias
self.max_capture_size = max_capture_size

@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)

def supports_chunking(self) -> bool:
return True

def supports_expert_map(self) -> bool:
return True

def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()

def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# The workspaces for this implementation are managed by flashinfer.
        # TODO(varun): workspace1 could be used as the output tensor. This is
        # error-prone. Allow `workspace_shapes` to return None workspaces.
workspace1 = (M, K)
workspace2 = (0, 0)
output = (M, K)
return (workspace1, workspace2, output, a.dtype)

def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int,
local_num_experts: int):
# Number of tokens in the input tensor.
num_tokens = x.shape[0]
# Factor to account for the imbalance of the experts.
        # The factor equals
        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert.
# 1.0 means perfect expert distribution.
# > 1.0 means some experts have more tokens than the perfect
# distribution.
# < 1.0 does not make sense.
imbalance_factor = 1.3
# Calculate the number of tokens per expert assuming perfect
# distribution.
num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the
# kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

return tile_tokens_dim

def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
topk = topk_ids.size(-1)
local_num_experts = w1.size(0)
intermediate_size = w2.size(1)
local_expert_offset = self.moe.ep_rank * local_num_experts

x_quant = hidden_states
x_scale = a1q_scale
if x_scale is not None:
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
*x_quant.shape[:-1], -1)

packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16).view(torch.int16)

assert w1_scale is not None
assert w2_scale is not None
kwargs = {
"topk_ids":
packed_tensor,
"routing_bias":
None,
"hidden_states":
x_quant,
"hidden_states_scale":
x_scale,
"gemm1_weights":
w1,
"gemm1_weights_scale":
w1_scale,
"gemm1_bias":
self.w13_bias,
"gemm1_alpha":
self.gemm1_alpha,
"gemm1_beta":
self.gemm1_beta,
"gemm1_clamp_limit":
self.gemm1_clamp_limit,
"gemm2_weights":
w2,
"gemm2_weights_scale":
w2_scale,
"gemm2_bias":
self.w2_bias,
"output1_scale_scalar":
None,
"output1_scale_gate_scalar":
None,
"output2_scale_scalar":
None,
"num_experts":
global_num_experts,
"top_k":
topk,
"n_group":
None,
"topk_group":
None,
"intermediate_size":
intermediate_size,
"local_expert_offset":
local_expert_offset,
"local_num_experts":
local_num_experts,
"routed_scaling_factor":
None,
"tile_tokens_dim":
self._get_tile_tokens_dim(x_quant, topk, local_num_experts),
"routing_method_type":
1,
"do_finalize":
True,
"output":
output,
"tune_max_num_tokens":
self.max_capture_size,
}

from flashinfer import trtllm_fp4_block_scale_routed_moe
trtllm_fp4_block_scale_routed_moe(**kwargs)
return output
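
Two pieces of the new wrapper can be checked in isolation. The sketch below is standalone (`next_power_of_2` is re-implemented locally rather than imported from `vllm.utils`): it works through the `_get_tile_tokens_dim` heuristic for one concrete shape, and shows the routed-topk packing used in `apply()`, where each int32 slot carries the expert id in the upper 16 bits and the bfloat16 routing weight, bit-cast to int16, in the lower 16 bits.

```python
# Standalone sketch; next_power_of_2 is a local stand-in for vllm.utils.next_power_of_2.
import torch


def next_power_of_2(n: int) -> int:
    return 1 if n <= 0 else 1 << (n - 1).bit_length()


def tile_tokens_dim(num_tokens: int, top_k: int, local_num_experts: int) -> int:
    # Tokens per expert under perfect routing, inflated by the 1.3 imbalance
    # factor, rounded up to a power of two, clamped to the kernel's 8-64 range.
    per_expert = (num_tokens * top_k) // local_num_experts
    per_expert = int(per_expert * 1.3)
    return min(max(next_power_of_2(per_expert), 8), 64)


# 128 tokens, top_k=4, 32 local experts: 16 tokens/expert * 1.3 = 20 -> pow2 = 32 -> 32
assert tile_tokens_dim(128, 4, 32) == 32

# Routed-topk packing from apply(): expert id in the upper 16 bits, bfloat16
# routing weight bit-cast to int16 in the lower 16 bits of one int32 slot.
# Routing weights are non-negative, so the int16 view never sign-extends.
topk_ids = torch.tensor([[3, 7]], dtype=torch.int32)
topk_weights = torch.tensor([[0.625, 0.375]])
packed = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(torch.int16)
assert (packed >> 16).equal(topk_ids)
```
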
16 changes: 16 additions & 0 deletions vllm/model_executor/layers/fused_moe/utils.py
@@ -12,6 +12,8 @@
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
quant_dequant_mxfp4)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_quantize)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
@@ -177,6 +179,18 @@ def _mxfp4_quantize(
return A, None


def _mxfp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert A_scale is None
assert not per_act_token_quant
assert block_shape is None
return mxfp8_quantize(A)


def moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
@@ -195,6 +209,8 @@ def moe_kernel_quantize_input(
is_sf_swizzled_layout=is_fp4_scale_swizzled)
elif quant_dtype == "mxfp4":
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp8":
return _mxfp8_quantize(A, A_scale, per_act_token_quant, block_shape)
else:
return A, A_scale

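The new "mxfp8" branch delegates to the FlashInfer-backed `mxfp8_quantize` helper, which computes its own block scales; that is why `_mxfp8_quantize` rejects caller-provided scales, per-token quantization, and block shapes. Purely as an illustration of the MX format (not the kernel's actual algorithm: the scale-selection rule below is one common choice, and real E8M0 scales are usually stored as uint8 exponents rather than floats), a reference-style quantizer looks like this:

```python
# Reference-style MXFP8 quantizer for illustration only: 32-element blocks
# share one power-of-two scale and elements are stored as float8_e4m3fn.
# The FlashInfer kernel behind mxfp8_quantize may choose scales differently
# and returns scales in its own layout.
import torch


def mxfp8_quantize_ref(a: torch.Tensor, block: int = 32):
    assert a.shape[-1] % block == 0
    blocks = a.reshape(*a.shape[:-1], -1, block).float()
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-30)
    # Pick the smallest power-of-two scale so the block max fits in E4M3
    # (max normal value 448). Real E8M0 scales store just the exponent.
    scale = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))
    q = (blocks / scale).to(torch.float8_e4m3fn)
    return q.reshape(a.shape), scale.squeeze(-1)


x = torch.randn(4, 128, dtype=torch.bfloat16)
xq, xs = mxfp8_quantize_ref(x)
print(xq.dtype, xs.shape)  # torch.float8_e4m3fn torch.Size([4, 4])
```
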
@@ -3,7 +3,7 @@

import enum
from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, Optional

import torch
from compressed_tensors import CompressionFormat
@@ -292,6 +292,7 @@ def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: Any,
Comment (Collaborator): nit: all these layer args should be torch.nn.Module also.

) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return the appropriate GEMM experts implementation."""
experts = select_nvfp4_gemm_impl(
@@ -688,11 +689,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device=device,
dtype=torch.int64)

-    def select_gemm_impl(
-        self,
-        prepare_finalize: FusedMoEPrepareAndFinalize,
-        moe: FusedMoEConfig,
-    ) -> FusedMoEPermuteExpertsUnpermute:
+    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
+                         moe: FusedMoEConfig,
+                         layer: Any) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import (
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -897,6 +897,7 @@ def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: Any,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -311,6 +311,7 @@ def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: Any,
) -> mk.FusedMoEPermuteExpertsUnpermute:
experts = select_cutlass_fp8_gemm_impl(
moe,
@@ -1034,6 +1035,7 @@ def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: Any,
) -> mk.FusedMoEPermuteExpertsUnpermute:
experts = select_nvfp4_gemm_impl(
moe,