
Commit e33ee23

[Bugfix] [AITER] [ROCm] Fix Quark MoE Quant Config and AITER Fused MoE quant type logic (vllm-project#27029)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
1 parent b10c64c commit e33ee23

File tree

3 files changed: +14 -3 lines changed

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 2 additions & 0 deletions
@@ -492,6 +492,8 @@ def rocm_aiter_fused_experts(
         assert quant_config.w1_scale is not None
         assert quant_config.w2_scale is not None
         quant_method = QuantMethod.BLOCK_128x128.value
+    elif quant_config.use_fp8_w8a8 and quant_config.per_out_ch_quant:
+        quant_method = QuantMethod.PER_TOKEN.value
     elif quant_config.use_fp8_w8a8:
         # Currently only per tensor quantization method is enabled.
         quant_method = QuantMethod.PER_TENSOR.value
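Why the new branch is inserted above the existing one: the elif chain dispatches on the first match, so the more specific per-output-channel case has to be tested before the generic fp8 per-tensor fallback. A minimal, self-contained sketch of that ordering (hypothetical QuantMethod members and helper name, not the actual AITER definitions):

```python
from enum import Enum

class QuantMethod(Enum):
    # Hypothetical members for illustration; AITER defines its own values.
    PER_TENSOR = 0
    PER_TOKEN = 1

def select_quant_method(use_fp8_w8a8: bool, per_out_ch_quant: bool) -> QuantMethod:
    # The specific check must precede the generic one: before this fix,
    # per-output-channel fp8 configs fell through to PER_TENSOR.
    if use_fp8_w8a8 and per_out_ch_quant:
        return QuantMethod.PER_TOKEN
    if use_fp8_w8a8:
        return QuantMethod.PER_TENSOR
    raise ValueError("unsupported quantization config")

assert select_quant_method(True, True) is QuantMethod.PER_TOKEN
assert select_quant_method(True, False) is QuantMethod.PER_TENSOR
```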

vllm/model_executor/layers/quantization/quark/quark_moe.py

Lines changed: 2 additions & 1 deletion
@@ -342,7 +342,8 @@ def get_fused_moe_quant_config(
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            per_act_token_quant=self.weight_qscheme == "per_channel",
+            per_act_token_quant=self.input_qscheme == "per_channel",
+            per_out_ch_quant=self.weight_qscheme == "per_channel",
         )

     def apply(
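The one-line fix here untangles two flags that were previously derived from the same field: per-token activation quantization should follow the input quantization scheme, while per-output-channel quantization follows the weight scheme. A hedged sketch of that mapping (standalone dataclass with hypothetical names, not Quark's actual config class):

```python
from dataclasses import dataclass

@dataclass
class MoEQuantFlags:
    per_act_token_quant: bool
    per_out_ch_quant: bool

def flags_from_schemes(input_qscheme: str, weight_qscheme: str) -> MoEQuantFlags:
    # Activations are quantized per token iff the *input* scheme is
    # per-channel; weights per output channel iff the *weight* scheme is.
    return MoEQuantFlags(
        per_act_token_quant=(input_qscheme == "per_channel"),
        per_out_ch_quant=(weight_qscheme == "per_channel"),
    )

# Before the fix both flags tracked weight_qscheme, so per-channel weights
# with per-tensor inputs wrongly enabled per-token activation quantization.
print(flags_from_schemes("per_tensor", "per_channel"))
# MoEQuantFlags(per_act_token_quant=False, per_out_ch_quant=True)
```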

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 10 additions & 2 deletions
@@ -464,8 +464,16 @@ def apply(
         else:
             qinput, x_scale = input_2d, input_scale

-        per_tensor_weights = weight_scale.numel() == 1
-        per_tensor_activations = x_scale.numel() == 1
+        # We must also check dim().
+        # In the per-token quant scenario, when the number of tokens is 1,
+        # the scale has only one element.
+        # Without checking dim(),
+        # we cannot distinguish between per-tensor and per-token quant.
+        # Example:
+        # when the number of tokens is 1, the per-token scale is [[1]],
+        # while a per-tensor scale is [1] or a 0-dim tensor ().
+        per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
+        per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2

         # TODO(luka) do this dispatch during init (after ScaledMM refactor)
         w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
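The ambiguity the new dim() check resolves is easy to reproduce: with a single token, a per-token scale and a per-tensor scale both have numel() == 1, and only the dimensionality tells them apart. A small standalone PyTorch demo (illustrative, not vLLM code):

```python
import torch

per_token_scale = torch.ones(1, 1)  # one scale per token row: shape [1, 1]
per_tensor_scale = torch.ones(1)    # a single scale for the whole tensor

for name, s in [("per-token", per_token_scale), ("per-tensor", per_tensor_scale)]:
    numel_only = s.numel() == 1                # old check: both look per-tensor
    with_dim = s.numel() == 1 and s.dim() < 2  # fixed check
    print(f"{name}: numel_only={numel_only}, with_dim={with_dim}")

# per-token: numel_only=True, with_dim=False
# per-tensor: numel_only=True, with_dim=True
```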

0 commit comments