Commit 868c546

Authored by mgoin
Support W8A8 INT8 MoE for compressed-tensors (#16745)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 99404f5 commit 868c546
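
For context, this commit lets vLLM serve MoE checkpoints quantized with compressed-tensors to W8A8 INT8 (channelwise weight scales, dynamic per-token activation scales). A minimal serving sketch follows; the model ID is a placeholder for any checkpoint quantized with that scheme, not a name taken from this commit.

# Hypothetical usage sketch: serving a compressed-tensors W8A8 INT8 MoE checkpoint.
# The model ID below is a placeholder, not a checkpoint referenced by this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="your-org/your-int8-w8a8-moe-model")  # placeholder model ID
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Explain dynamic per-token quantization in one sentence."], params)
print(outputs[0].outputs[0].text)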

File tree: 2 files changed (+136, -1 lines)


vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 135 additions & 0 deletions
@@ -34,6 +34,7 @@ class GPTQMarlinState(Enum):
     "CompressedTensorsMoEMethod",
     "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsW8A8Fp8MoECutlassMethod",
+    "CompressedTensorsW8A8Int8MoEMethod",
     "CompressedTensorsWNA16MarlinMoEMethod",
     "CompressedTensorsWNA16MoEMethod",
 ]
@@ -71,6 +72,8 @@ def get_moe_method(
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
+        elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
+            return CompressedTensorsW8A8Int8MoEMethod(quant_config)
         else:
             raise RuntimeError(
                 f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@@ -545,6 +548,138 @@ def apply(
         )


+class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+            self,
+            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
+            "weights")
+        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
+            "input_activations")
+
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not per_channel:
+            raise ValueError(
+                "For INT8 Fused MoE layers, we require channelwise, "
+                "dynamic per token quantization. Found "
+                f"{self.weight_quant}, {self.input_quant}")
+
+        self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales:
+            raise ValueError(
+                "For INT8 Fused MoE layers, we require channelwise, "
+                "dynamic per token quantization. Found static input scales.")
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        params_dtype = torch.int8
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+        w13_weight_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            1,
+            dtype=torch.float32),
+                                              requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                        hidden_size,
+                                                        1,
+                                                        dtype=torch.float32),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        assert not self.static_input_scales
+        layer.w13_input_scale = None
+        layer.w2_input_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        pass
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        return fused_experts(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            use_int8_w8a8=True,
+            per_channel_quant=True,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale)
+
+
 class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):

     def __init__(
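
For reference, the scheme this new method accepts (and validates in __init__) is symmetric INT8 with one scale per weight output channel and one scale per activation token, the latter computed at runtime. The sketch below is a plain-PyTorch illustration of those numerics with made-up shapes; it is not the fused_experts kernel path used above.

# Plain-PyTorch illustration (assumed shapes, not the fused kernel) of symmetric INT8
# quantization with channelwise weight scales and dynamic per-token activation scales.
import torch

def quant_per_channel(w: torch.Tensor):
    # w: (out_features, in_features); one scale per output channel (row).
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    return torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8), scale

def quant_per_token(x: torch.Tensor):
    # x: (num_tokens, hidden_size); one scale per token, computed at runtime ("dynamic").
    scale = x.abs().amax(dim=1, keepdim=True) / 127.0
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8), scale

w = torch.randn(256, 128)
x = torch.randn(4, 128)
w_q, w_s = quant_per_channel(w)
x_q, x_s = quant_per_token(x)
# The INT8 matmul plus rescaling approximates the original floating-point matmul.
y_approx = (x_q.float() @ w_q.float().t()) * x_s * w_s.t()
print("max abs error:", (y_approx - x @ w.t()).abs().max().item())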

vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def apply_weights(self,
         # * dynamic, i_s is None and x_s computed from x.
         # * static, i_s is scalar and x_s is i_s.
         symmetric = azp_adj is None
-        x_q, x_s, x_zp = ops.scaled_int8_quant(x,
+        x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
                                                i_s,
                                                i_zp,
                                                symmetric=symmetric)
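
The one-line change above passes x.contiguous() into ops.scaled_int8_quant, presumably because this kernel can now receive an activation tensor that is a non-contiguous view. A generic PyTorch illustration of that distinction (not the actual call path) follows.

# Generic illustration: slicing or transposing yields a view with non-unit strides;
# .contiguous() copies it into the dense layout that low-level kernels expect.
import torch

x = torch.randn(8, 16)
view = x[:, ::2]                      # strided view of every other column
print(view.is_contiguous())           # False
dense = view.contiguous()             # copies into contiguous memory
print(dense.is_contiguous())          # True
print(torch.equal(view, dense))       # same values, different memory layout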
