 
 import torch
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization import QuantizationStrategy
+from compressed_tensors.quantization import (ActivationOrdering,
+                                             QuantizationStrategy)
 
 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
@@ -30,9 +31,11 @@ class GPTQMarlinState(Enum):
 
 
 __all__ = [
-    "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsMoEMethod",
+    "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsW8A8Fp8MoECutlassMethod",
-    "CompressedTensorsWNA16MoEMethod"
+    "CompressedTensorsWNA16MarlinMoEMethod",
+    "CompressedTensorsWNA16MoEMethod",
 ]
 
 
@@ -41,8 +44,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
-        activation: str,
-        expert_map: Optional[torch.Tensor],
+        layer: torch.nn.Module,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -51,9 +53,21 @@ def get_moe_method(
5153 "input_activations" )
5254
5355 if quant_config ._is_wNa16_group_channel (weight_quant , input_quant ):
54- return CompressedTensorsWNA16MoEMethod (quant_config )
56+ # Prefer to use the non-marlin kernel when:
57+ # 1. Many experts (MarlinMoE gives poor performance when >= 16)
58+ # 2. Non-FP16 dtype (MarlinMoE only supports FP16)
59+ # 3. Actorder is not group/dynamic (g_idx is unsupported)
+            # 4. Scales are grouped (channelwise is unsupported)
+            if ((layer.local_num_experts >= 16
+                 or layer.params_dtype != torch.float16) and
+                    weight_quant.actorder not in (ActivationOrdering.GROUP,
+                                                  ActivationOrdering.DYNAMIC)
+                    and weight_quant.strategy in QuantizationStrategy.GROUP):
+                return CompressedTensorsWNA16MoEMethod(quant_config)
+            else:
+                return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-              and activation == "silu" and expert_map is None):
+              and layer.activation == "silu" and layer.expert_map is None):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
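
Note: the new branch above reduces to a single predicate. Below is a minimal standalone sketch of that rule (illustrative only, not part of the patch; the parameter names mirror the layer/config fields read in the diff, and actorder is passed as a plain string or None rather than an ActivationOrdering member):

    from typing import Optional

    import torch

    def prefers_moe_wna16(local_num_experts: int,
                          params_dtype: torch.dtype,
                          actorder: Optional[str],
                          strategy: str) -> bool:
        # Marlin MoE is a poor fit with many experts or non-FP16 dtypes ...
        marlin_poor_fit = (local_num_experts >= 16
                           or params_dtype != torch.float16)
        # ... and moe_wna16 cannot handle g_idx (act-order) or channelwise
        # scales, so those configs still go to the Marlin method.
        no_g_idx = actorder not in ("group", "dynamic")
        grouped_scales = strategy == "group"
        return marlin_poor_fit and no_g_idx and grouped_scales

    # e.g. a 128-expert FP16 checkpoint with group-wise scales takes the
    # non-Marlin (CompressedTensorsWNA16MoEMethod) path:
    assert prefers_moe_wna16(128, torch.float16, None, "group")
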
@@ -482,7 +496,7 @@ def apply(
         )
 
 
-class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(
         self,
@@ -823,3 +837,215 @@ def apply(
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
             is_k_full=self.is_k_full)
+
+
+class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+        self,
+        quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        config = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.num_bits = config.num_bits
+        self.packed_factor = 32 // config.num_bits
+        self.strategy = config.strategy
+        # channelwise is not supported by this kernel
+        assert config.strategy == "group"
+        self.group_size = config.group_size
+        # grouped actorder isn't supported by this kernel
+        assert config.actorder != "group"
+        assert config.symmetric, (
+            "Only symmetric quantization is supported for MoE")
+
+        if not (self.quant_config.quant_format
+                == CompressionFormat.pack_quantized.value
+                and self.num_bits in WNA16_SUPPORTED_BITS):
+            raise ValueError("For Fused MoE layers, only ",
+                             f"{CompressionFormat.pack_quantized.value} ",
+                             "is supported for the following bits: ",
+                             f"{WNA16_SUPPORTED_BITS}")
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Will transpose the loaded weight along the
+        # intermediate and hidden dim sizes. Will
+        # shard for TP along the transposed dims
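+        # (concretely, the parameters registered below are int32-packed as
+        #  w13: [num_experts, hidden_size // packed_factor, 2 * intermediate]
+        #  w2:  [num_experts, intermediate // packed_factor, hidden_size])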
+        extra_weight_attrs.update({
+            "is_transposed": True,
+            "quant_method": self.strategy
+        })
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight_packed", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight_packed", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w2_scales_size = intermediate_size_per_partition
+
+        if self.strategy == "channel":
+            num_groups_w2 = num_groups_w13 = 1
+            self.group_size = -1
+        else:
+            num_groups_w2 = w2_scales_size // self.group_size
+            num_groups_w13 = hidden_size // self.group_size
+
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_scale)
+        set_weight_attrs(w13_scale, extra_weight_attrs)
+
+        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                 num_groups_w2,
+                                                 hidden_size,
+                                                 dtype=params_dtype),
+                                      requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_scale)
+        set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": False})
+
+        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_shape", w2_weight_shape)
+        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
+        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                              requires_grad=False)
+
+        layer.register_parameter("w13_weight_shape", w13_weight_shape)
+        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices",
+                                 w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices",
+                                 w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+        layer.a13_scale = None
+        layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Reconfigure packed weights and scales to match moe_wna16 format
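+        # (the weights were created as [E, K // packed_factor, N] int32; they
+        # are transposed here to [E, N, K // packed_factor] and the int32
+        # packing is reinterpreted as uint8 for the int4/int8 fused_experts
+        # path used in apply() below)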
+        layer.w13_weight_packed = torch.nn.Parameter(
+            layer.w13_weight_packed.transpose(1, 2).contiguous().view(
+                torch.uint8),
+            requires_grad=False)
+        layer.w2_weight_packed = torch.nn.Parameter(
+            layer.w2_weight_packed.transpose(1, 2).contiguous().view(
+                torch.uint8),
+            requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(
+            layer.w13_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(
+            layer.w2_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe import fused_experts
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        return fused_experts(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_int4_w4a16=self.num_bits == 4,
+            use_int8_w8a16=self.num_bits == 8,
+            global_num_experts=global_num_experts,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size])
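
As a closing note on the new method: the layout transformation done in process_weights_after_loading can be reproduced on dummy tensors. A minimal sketch under assumed 4-bit quantization (packed_factor = 8) and made-up sizes; the variable names are local to the sketch, not vLLM APIs:

    import torch

    num_experts, hidden, inter = 8, 64, 32
    packed_factor = 32 // 4  # eight 4-bit values per int32 word

    # Checkpoint layout created in create_weights: [E, K // packed_factor, 2N].
    w13 = torch.zeros(num_experts, hidden // packed_factor, 2 * inter,
                      dtype=torch.int32)

    # process_weights_after_loading: transpose to [E, 2N, K // packed_factor]
    # and reinterpret every int32 word as 4 uint8 bytes, giving K // 2 bytes
    # per output row for 4-bit weights.
    w13_repacked = w13.transpose(1, 2).contiguous().view(torch.uint8)
    assert w13_repacked.shape == (num_experts, 2 * inter, hidden // 2)

The group size itself reaches the kernel through the block_shape=[0, self.group_size] argument in apply(); the zero in the first slot appears to be a placeholder, with only the second entry (the group size along the packed dimension) being meaningful for the int4/int8 weight-only path.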