@@ -30,9 +30,11 @@ class GPTQMarlinState(Enum):
 
 
 __all__ = [
-    "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsMoEMethod",
+    "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsW8A8Fp8MoECutlassMethod",
-    "CompressedTensorsWNA16MoEMethod"
+    "CompressedTensorsWNA16MarlinMoEMethod",
+    "CompressedTensorsWNA16MoEMethod",
 ]
 
 
@@ -41,8 +43,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
-        activation: str,
-        expert_map: Optional[torch.Tensor],
+        layer: torch.nn.Module,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -51,9 +52,20 @@ def get_moe_method(
             "input_activations")
 
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            return CompressedTensorsWNA16MoEMethod(quant_config)
+            # Prefer to use the non-marlin kernel when:
+            # 1. Many experts (MarlinMoE gives poor performance when >= 16)
+            # 2. Non-FP16 dtype (MarlinMoE only supports FP16)
+            # 3. Actorder is not group (g_idx is unsupported)
+            # 4. Scales are grouped (channelwise is unsupported)
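+            # e.g. a bfloat16 layer with 64 local experts and group-wise
+            # scales takes the non-Marlin path below, while an FP16 layer
+            # with 8 experts stays on the Marlin kernel.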
+            if ((layer.local_num_experts >= 16
+                 or layer.params_dtype != torch.float16)
+                    and weight_quant.actorder != "group"
+                    and weight_quant.strategy == "group"):
+                return CompressedTensorsWNA16MoEMethod(quant_config)
+            else:
+                return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-              and activation == "silu" and expert_map is None):
+              and layer.activation == "silu" and layer.expert_map is None):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@@ -482,7 +494,7 @@ def apply(
         )
 
 
-class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(
         self,
@@ -823,3 +835,215 @@ def apply(
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
             is_k_full=self.is_k_full)
+
+
+class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+        self,
+        quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        config = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.num_bits = config.num_bits
+        self.packed_factor = 32 // config.num_bits
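+        # i.e. eight 4-bit (or four 8-bit) values per packed int32 word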
+        self.strategy = config.strategy
+        # channelwise is not supported by this kernel
+        assert config.strategy == "group"
+        self.group_size = config.group_size
+        # grouped actorder isn't supported by this kernel
+        assert config.actorder != "group"
+        assert config.symmetric, (
+            "Only symmetric quantization is supported for MoE")
+
+        if not (self.quant_config.quant_format
+                == CompressionFormat.pack_quantized.value
+                and self.num_bits in WNA16_SUPPORTED_BITS):
+            raise ValueError("For Fused MoE layers, only ",
+                             f"{CompressionFormat.pack_quantized.value} ",
+                             "is supported for the following bits: ",
+                             f"{WNA16_SUPPORTED_BITS}")
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Will transpose the loaded weight along the
+        # intermediate and hidden dim sizes. Will
+        # shard for TP along the transposed dims
+        extra_weight_attrs.update({
+            "is_transposed": True,
+            "quant_method": self.strategy
+        })
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight_packed", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight_packed", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w2_scales_size = intermediate_size_per_partition
+
+        if self.strategy == "channel":
+            num_groups_w2 = num_groups_w13 = 1
+            self.group_size = -1
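+            # group_size == -1 follows the channelwise convention: a single
+            # scale group spans the whole quantized dim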
+        else:
+            num_groups_w2 = w2_scales_size // self.group_size
+            num_groups_w13 = hidden_size // self.group_size
+
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_scale)
+        set_weight_attrs(w13_scale, extra_weight_attrs)
+
+        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                 num_groups_w2,
+                                                 hidden_size,
+                                                 dtype=params_dtype),
+                                      requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_scale)
+        set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": False})
+
+        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_shape", w2_weight_shape)
+        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
+        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                              requires_grad=False)
+
+        layer.register_parameter("w13_weight_shape", w13_weight_shape)
+        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices",
+                                 w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices",
+                                 w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
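+        # WNA16 is weight-only quantization, so no input-activation scales
+        # are needed.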
+        layer.a13_scale = None
+        layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Reconfigure packed weights and scales to match moe_wna16 format
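+        # (the weights were loaded transposed, per is_transposed above; swap
+        # the last two dims back and view the packed int32 words as uint8)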
+        layer.w13_weight_packed = torch.nn.Parameter(
+            layer.w13_weight_packed.transpose(1, 2).contiguous().view(
+                torch.uint8),
+            requires_grad=False)
+        layer.w2_weight_packed = torch.nn.Parameter(
+            layer.w2_weight_packed.transpose(1, 2).contiguous().view(
+                torch.uint8),
+            requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(
+            layer.w13_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(
+            layer.w2_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe import fused_experts
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        return fused_experts(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_int4_w4a16=self.num_bits == 4,
+            use_int8_w8a16=self.num_bits == 8,
+            global_num_experts=global_num_experts,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size])