@@ -34,6 +34,7 @@ class GPTQMarlinState(Enum):
3434 "CompressedTensorsMoEMethod" ,
3535 "CompressedTensorsW8A8Fp8MoEMethod" ,
3636 "CompressedTensorsW8A8Fp8MoECutlassMethod" ,
37+ "CompressedTensorsW8A8Int8MoEMethod" ,
3738 "CompressedTensorsWNA16MarlinMoEMethod" ,
3839 "CompressedTensorsWNA16MoEMethod" ,
3940]
@@ -71,6 +72,8 @@ def get_moe_method(
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
+        elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
+            return CompressedTensorsW8A8Int8MoEMethod(quant_config)
         else:
             raise RuntimeError(
                 f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@@ -545,6 +548,151 @@ def apply(
         )
 
 
+class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
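+    """MoE method for INT8 weights with channelwise weight scales and
+    dynamic per-token INT8 activation quantization."""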
+
+    def __init__(
+            self,
+            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
+            "weights")
+        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
+            "input_activations")
+
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not per_channel:
+            raise ValueError(
+                "For INT8 Fused MoE layers, we require channelwise, "
+                "dynamic per token quantization. Found "
+                f"{self.weight_quant}, {self.input_quant}")
+
+        self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales:
+            raise ValueError(
+                "For INT8 Fused MoE layers, we require channelwise, "
+                "dynamic per token quantization. Found static input scales.")
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
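+        # Weights are stored as INT8 regardless of the serving dtype;
+        # activations are quantized to INT8 dynamically at runtime.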
+        params_dtype = torch.int8
+
+        # WEIGHTS
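+        # w13 packs the fused gate and up projections per expert, hence
+        # the 2 * intermediate_size_per_partition output dimension.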
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
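+        # One float32 scale per output channel of each expert.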
+        assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+        w13_weight_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            1,
+            dtype=torch.float32),
+                                              requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                        hidden_size,
+                                                        1,
+                                                        dtype=torch.float32),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
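+        # Input scales are computed dynamically per token by the kernel,
+        # so no scale parameters need to be created here.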
+        assert not self.static_input_scales
+        layer.w13_input_scale = None
+        layer.w2_input_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
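+        # No post-load repacking is needed; the INT8 weights and channelwise
+        # scales are consumed as-is by the fused MoE kernel.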
+        pass
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
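+        # a1_scale/a2_scale are None: activations are quantized dynamically
+        # per token, with per-channel (channelwise) weight scales.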
+        return fused_experts(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            use_int8_w8a8=True,
+            per_channel_quant=True,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale)
+
+
 class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(