
Commit e9d39da

amirkl94, tlrmchlsmth, and mgoin authored and committed
Bugfix: Cutlass FP8 FusedMoE bad scaling factors (vllm-project#27255)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 0b6cd58 commit e9d39da

4 files changed: +40, -14 lines changed


tests/kernels/moe/test_flashinfer.py

Lines changed: 18 additions & 6 deletions
@@ -6,7 +6,10 @@
 import torch
 
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEQuantConfig,
+    fp8_w8a8_moe_quant_config,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
@@ -22,10 +25,10 @@
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 
 if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
-    100
+    90
 ):
     pytest.skip(
-        "Requires flashinfer_cutlass_fused_moe and nvfp4 support",
+        "Supported for sm >= 90",
         allow_module_level=True,
     )
 
@@ -131,6 +134,8 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
     topk: int,
     monkeypatch,
 ):
+    if not current_platform.has_device_capability(100):
+        pytest.skip("Test is only supported for sm >= 100")
     current_platform.seed_everything(7)
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
@@ -184,9 +189,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
         torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
 
 
-@pytest.mark.skip(
-    "Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
-)
 @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
@@ -216,9 +218,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
 
         quant_config = fp8_w8a8_moe_quant_config(
             w1_scale=td.w13_weight_scale,
+            g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
             w2_scale=td.w2_weight_scale,
+            g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
             a1_scale=td.a1_scale,
+            a1_gscale=td.a1_scale,
             a2_scale=td.a2_scale,
+            a2_gscale=1.0 / td.a2_scale,
             per_act_token_quant=False,
         )
 
@@ -238,6 +244,12 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
 
         td.layer.dp_size = 1
 
+        def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
+            return quant_config
+
+        td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
+        td.layer.quant_method = td.layer
+
         flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
             td.hidden_states,
             td.layer,
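
Note: the last hunk stubs attributes that a real FusedMoE layer normally gets from its quantization method; flashinfer_cutlass_moe_fp8 appears to resolve the config through layer.quant_method.get_fused_moe_quant_config(...). Below is a minimal, self-contained illustration of the same stubbing pattern using a hypothetical stand-in object rather than vLLM types; it is a sketch of the idea, not the test itself.

from types import SimpleNamespace

# Hypothetical stand-ins: in the real test, `layer` is the TestData layer and
# `quant_config` is the fp8_w8a8_moe_quant_config built earlier in the test.
quant_config = {"placeholder": True}
layer = SimpleNamespace()


def get_fused_moe_quant_config(module) -> dict:
    # Hand back the pre-built config regardless of which module is passed in.
    return quant_config


# Let the bare layer act as its own quant_method, so code that resolves
# layer.quant_method.get_fused_moe_quant_config(layer) finds the stub.
layer.get_fused_moe_quant_config = get_fused_moe_quant_config
layer.quant_method = layer

assert layer.quant_method.get_fused_moe_quant_config(layer) is quant_config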

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 8 additions & 0 deletions
@@ -463,16 +463,24 @@ def fp8_w8a8_moe_quant_config(
     per_act_token_quant: bool = False,
     per_out_ch_quant: bool = False,
     block_shape: list[int] | None = None,
+    a1_gscale: torch.Tensor | None = None,
+    a2_gscale: torch.Tensor | None = None,
+    g1_alphas: torch.Tensor | None = None,
+    g2_alphas: torch.Tensor | None = None,
 ) -> FusedMoEQuantConfig:
     """
     Construct a quant config for fp8 activations and fp8 weights.
     """
     return FusedMoEQuantConfig.make(
         torch.float8_e4m3fn,
         w1_scale=w1_scale,
+        g1_alphas=g1_alphas,
         w2_scale=w2_scale,
+        g2_alphas=g2_alphas,
         a1_scale=a1_scale,
+        a1_gscale=a1_gscale,
         a2_scale=a2_scale,
+        a2_gscale=a2_gscale,
         per_act_token_quant=per_act_token_quant,
         per_out_ch_quant=per_out_ch_quant,
         block_shape=block_shape,
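
Note: the four new keyword arguments default to None, so existing callers are unaffected; only the Cutlass/FlashInfer FP8 path needs them. A minimal sketch of a call that opts in, with placeholder scale tensors; the formulas mirror how the test and ModelOpt changes in this commit compute the values.

import torch

from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config

# Placeholder per-tensor scales, for illustration only.
w1_scale = torch.tensor(0.02)
w2_scale = torch.tensor(0.03)
a1_scale = torch.tensor(0.5)
a2_scale = torch.tensor(0.25)

quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=w1_scale,
    g1_alphas=(w1_scale * a1_scale).squeeze(),  # weight scale times input scale, GEMM 1
    w2_scale=w2_scale,
    g2_alphas=(w2_scale * a2_scale).squeeze(),  # weight scale times input scale, GEMM 2
    a1_scale=a1_scale,
    a1_gscale=a1_scale,                         # input scale forwarded for GEMM 1
    a2_scale=a2_scale,
    a2_gscale=1.0 / a2_scale,                   # reciprocal of the second input scale
    per_act_token_quant=False,
)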

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py

Lines changed: 8 additions & 6 deletions
@@ -170,7 +170,7 @@ def prepare(
         self._apply_router_weight_on_input(
             a1, topk_weights, topk_ids, apply_router_weight_on_input
         )
-        if not self.use_dp:
+        if not self.use_dp and quant_config.quant_dtype == "nvfp4":
             return a1, None, None, topk_ids, topk_weights
 
         a1q, a1q_scale = moe_kernel_quantize_input(
@@ -181,11 +181,13 @@
             quant_config.block_shape,
             is_fp4_scale_swizzled=not self.use_dp,
         )
-        topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
-            [topk_weights, topk_ids, a1q, a1q_scale],
-            dim=0,
-            sizes=get_local_sizes(),
-        )
+
+        if self.use_dp:
+            topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
+                [topk_weights, topk_ids, a1q, a1q_scale],
+                dim=0,
+                sizes=get_local_sizes(),
+            )
         if quant_config.quant_dtype == "nvfp4":
             a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
 
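
Note: after this change, the early return applies only to the non-DP NVFP4 case, quantization always runs for FP8, and the cross-rank all_gatherv happens only under data parallelism. A simplified, self-contained sketch of that ordering follows; the helper and signature are stand-ins, not the real prepare() method.

import torch


def quantize_stub(a: torch.Tensor):
    # Stand-in for moe_kernel_quantize_input: per-tensor scale plus fp8 cast.
    scale = a.abs().amax().clamp(min=1e-6) / 448.0
    return (a / scale).to(torch.float8_e4m3fn), scale


def prepare_sketch(a1, topk_weights, topk_ids, use_dp: bool, quant_dtype: str):
    # Non-DP NVFP4 hands the raw activations straight to the kernel.
    if not use_dp and quant_dtype == "nvfp4":
        return a1, None, topk_ids, topk_weights

    # FP8 (and DP NVFP4) activations are quantized here.
    a1q, a1q_scale = quantize_stub(a1)

    # Gather across data-parallel ranks only when DP is actually in use.
    if use_dp:
        pass  # all_gatherv of [topk_weights, topk_ids, a1q, a1q_scale] would go here

    return a1q, a1q_scale, topk_ids, topk_weights


# Single-rank FP8 run: quantization happens, no gather.
out = prepare_sketch(torch.randn(4, 8), None, None, use_dp=False, quant_dtype="fp8")
print(out[0].dtype)  # torch.float8_e4m3fn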

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 6 additions & 2 deletions
@@ -567,9 +567,13 @@ def get_fused_moe_quant_config(
 
         return fp8_w8a8_moe_quant_config(
             w1_scale=layer.w13_weight_scale,
+            g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
             w2_scale=layer.w2_weight_scale,
+            g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
             a1_scale=layer.w13_input_scale,
+            a1_gscale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
+            a2_gscale=1.0 / layer.w2_input_scale,
             per_act_token_quant=False,
         )
 
@@ -1138,8 +1142,8 @@ def __init__(
         moe: FusedMoEConfig,
         layer: torch.nn.Module,
     ) -> None:
-        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
-            detect_nvfp4_moe_support,
+        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
+            detect_nvfp4_moe_support,  # noqa: E501
         )
 
        super().__init__(moe)
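
Note: the alphas added in the first hunk are standard per-tensor FP8 dequantization factors. If x ≈ x_fp8 * a_scale and w ≈ w_fp8 * w_scale, then x @ w ≈ (x_fp8 @ w_fp8) * (a_scale * w_scale), which is exactly the weight-scale times input-scale product passed as g1_alphas/g2_alphas, while 1 / a2_scale re-quantizes the intermediate activations ahead of the second GEMM. A small self-contained check of the identity on random data (not vLLM code):

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
w = torch.randn(8, 3)

# Per-tensor scales chosen so the quantized values fit the fp8 e4m3 range (~448).
a_scale = x.abs().amax() / 448.0
w_scale = w.abs().amax() / 448.0

x_fp8 = (x / a_scale).to(torch.float8_e4m3fn)
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)

# The per-GEMM dequant alpha is the product of the two scales.
alpha = a_scale * w_scale
approx = (x_fp8.float() @ w_fp8.float()) * alpha

rel_err = (approx - x @ w).norm() / (x @ w).norm()
print(f"relative error: {rel_err.item():.4f}")  # small, at the level of fp8 rounding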
