@@ -323,13 +323,18 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        # Currently assuming is_k_full is always True
-        # (input size per partition is the same as full input size)
-        # Supports only sym for now (no zp)
+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full")
+
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size_full)
+
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            scales_size2 = (intermediate_size_per_partition //
-                            self.quant_config.group_size)
+            w2_scales_size = (intermediate_size_full
+                              if self.quant_config.desc_act else
+                              intermediate_size_per_partition)
+            scales_size2 = (w2_scales_size // self.quant_config.group_size)
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
             scales_size13 = 1
@@ -385,6 +390,9 @@ def create_weights(
         )
         layer.register_parameter("w2_scales", w2_scales)
         set_weight_attrs(w2_scales, extra_weight_attrs)
+        # don't shard the w2 scales when running act_order
+        set_weight_attrs(w2_scales,
+                         {"load_full_w2": self.quant_config.desc_act})
         # up_proj scales
         w13_qzeros = torch.nn.Parameter(
             torch.empty(num_experts,
@@ -406,6 +414,9 @@ def create_weights(
         )
         layer.register_parameter("w2_qzeros", w2_qzeros)
         set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        # don't shard the w2 qzeros when running act_order
+        set_weight_attrs(w2_qzeros,
+                         {"load_full_w2": self.quant_config.desc_act})
         w13_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
@@ -575,4 +586,4 @@ def apply(
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.quant_config.quant_type.size_bits,
-        ).to(orig_dtype)
+            is_k_full=self.is_k_full).to(orig_dtype)
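
For reviewers, a minimal standalone sketch of the `is_k_full` rule this diff introduces and then forwards to the fused Marlin MoE kernel. The helper name and the Mixtral-like sizes (intermediate size 14336, tensor parallelism 2) are made up for illustration; only the boolean logic mirrors the code above.

```python
# Standalone sketch of the is_k_full rule added in create_weights above.
# The function name and the example sizes are illustrative only.

def is_k_full(desc_act: bool, intermediate_size_per_partition: int,
              intermediate_size_full: int) -> bool:
    # Without act_order (desc_act=False), K can always be treated as full.
    # With act_order, K is only full when w2's input dimension is not
    # sharded across tensor-parallel ranks.
    return (not desc_act) or (
        intermediate_size_per_partition == intermediate_size_full)

assert is_k_full(False, 7168, 14336)      # no act_order: always full
assert is_k_full(True, 14336, 14336)      # act_order, TP=1: full K
assert not is_k_full(True, 7168, 14336)   # act_order, TP=2: partial K
```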
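The `load_full_w2` attribute set on `w2_scales` and `w2_qzeros` is a hint consumed by the weight loader, which sits outside this diff. As a hedged sketch of the intended effect, assuming a hypothetical per-parameter loader and a `(num_experts, k_groups, n)` layout for the loaded tensor:

```python
import torch

# Hypothetical loader behavior for load_full_w2; not part of this diff.
def load_w2_scales(param: torch.nn.Parameter, loaded: torch.Tensor,
                   tp_rank: int, tp_size: int) -> None:
    if getattr(param, "load_full_w2", False):
        # act_order: every rank keeps the full, unsharded tensor, since
        # g_idx may permute K across partition boundaries.
        param.data.copy_(loaded)
    else:
        # default: shard along w2's K (group) dimension.
        shard = loaded.shape[1] // tp_size
        param.data.copy_(loaded[:, tp_rank * shard:(tp_rank + 1) * shard])
```

Replicating the full scales and zero points presumably trades a small amount of memory for correctness when `desc_act` reorders the K dimension.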