3 files changed under vllm/model_executor/layers: +14, -3 lines changed

```diff
@@ -492,6 +492,8 @@ def rocm_aiter_fused_experts(
         assert quant_config.w1_scale is not None
         assert quant_config.w2_scale is not None
         quant_method = QuantMethod.BLOCK_128x128.value
+    elif quant_config.use_fp8_w8a8 and quant_config.per_out_ch_quant:
+        quant_method = QuantMethod.PER_TOKEN.value
     elif quant_config.use_fp8_w8a8:
         # Currently only per tensor quantization method is enabled.
         quant_method = QuantMethod.PER_TENSOR.value
```
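The new branch routes fp8 configs with per-output-channel scales to the `PER_TOKEN` quant method; because it sits above the generic `use_fp8_w8a8` arm, the branch ordering is what keeps such configs from silently falling through to `PER_TENSOR`. A minimal sketch of that ordering, assuming a simplified stand-in config object and enum (the real `QuantMethod` and block-quant condition live in vLLM/AITER):

```python
# Sketch only: QuantMethod member names mirror the diff; the enum values,
# config class, and block_shape condition are assumptions for illustration.
from dataclasses import dataclass
from enum import Enum


class QuantMethod(Enum):
    PER_TENSOR = 1
    PER_TOKEN = 2
    BLOCK_128x128 = 3


@dataclass
class QuantConfig:
    use_fp8_w8a8: bool = False
    per_out_ch_quant: bool = False
    block_shape: tuple[int, int] | None = None


def select_quant_method(cfg: QuantConfig) -> QuantMethod:
    if cfg.block_shape is not None:
        return QuantMethod.BLOCK_128x128
    # The new branch must precede the generic fp8 branch; otherwise
    # per-output-channel configs would fall through to PER_TENSOR.
    elif cfg.use_fp8_w8a8 and cfg.per_out_ch_quant:
        return QuantMethod.PER_TOKEN
    elif cfg.use_fp8_w8a8:
        return QuantMethod.PER_TENSOR
    raise ValueError("unsupported quantization config")


assert select_quant_method(
    QuantConfig(use_fp8_w8a8=True, per_out_ch_quant=True)
) is QuantMethod.PER_TOKEN
```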
```diff
@@ -342,7 +342,8 @@ def get_fused_moe_quant_config(
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            per_act_token_quant=self.weight_qscheme == "per_channel",
+            per_act_token_quant=self.input_qscheme == "per_channel",
+            per_out_ch_quant=self.weight_qscheme == "per_channel",
         )
 
     def apply(
```
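Previously `per_act_token_quant` was derived from `weight_qscheme`, so a per-channel weight scheme wrongly switched on per-token activation quantization. A small sketch of the corrected derivation, using the diff's attribute names inside a stand-in helper (not vLLM's real quantization method):

```python
# Illustrative only: the two flags must key off different qschemes.
def moe_quant_flags(input_qscheme: str, weight_qscheme: str) -> dict[str, bool]:
    return {
        # Activation (per-token) quantization follows the *input* qscheme...
        "per_act_token_quant": input_qscheme == "per_channel",
        # ...while per-output-channel quantization follows the *weight*
        # qscheme. Before this fix, both keyed off weight_qscheme.
        "per_out_ch_quant": weight_qscheme == "per_channel",
    }


# A per-channel-weight / per-tensor-activation checkpoint no longer turns
# on per-token activation quantization by mistake:
assert moe_quant_flags("per_tensor", "per_channel") == {
    "per_act_token_quant": False,
    "per_out_ch_quant": True,
}
```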
```diff
@@ -464,8 +464,16 @@ def apply(
         else:
             qinput, x_scale = input_2d, input_scale
 
-        per_tensor_weights = weight_scale.numel() == 1
-        per_tensor_activations = x_scale.numel() == 1
+        # Must also check dim().
+        # In the per-token quant scenario, when the number of tokens is 1,
+        # the scale will only have one element.
+        # Without checking dim(),
+        # we cannot distinguish between per-tensor and per-token quant.
+        # Example:
+        # When the number of tokens is 1, the per-token scale is [[1]],
+        # while a per-tensor scale is [1] or ().
+        per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
+        per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
 
         # TODO(luka) do this dispatch during init (after ScaledMM refactor)
         w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
```
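The edge case this hunk fixes is easy to reproduce in isolation: with a single token, a per-token scale and a per-tensor scale both have `numel() == 1` and differ only in rank. A runnable sketch of the ambiguity, using plain torch tensors (the helper name is invented for this illustration):

```python
import torch

per_tensor_scale = torch.tensor(2.0)     # dim() == 0, numel() == 1
one_token_scale = torch.tensor([[2.0]])  # dim() == 2, numel() == 1

# numel() alone cannot tell the two layouts apart:
assert per_tensor_scale.numel() == one_token_scale.numel() == 1


def is_per_tensor(scale: torch.Tensor) -> bool:
    # Mirrors the fixed condition: one element *and* fewer than 2 dims.
    return scale.numel() == 1 and scale.dim() < 2


assert is_per_tensor(per_tensor_scale)
assert not is_per_tensor(one_token_scale)  # [[1]]-shaped per-token scale
```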