@@ -518,7 +518,7 @@ def __init__(
             assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
             logger.debug("WQ = %s", str(self.weight_quant))
             self.weight_block_size = self.weight_quant.block_structure
-            # TODO: self.weight_quant.dynamic
+            assert self.weight_quant.dynamic is not None
         else:
             self.weight_block_size = None
         self.block_quant = self.weight_block_size is not None
@@ -550,15 +550,9 @@ def __init__(
                             or self.is_fp8_w8a8_sm100)
         self.disable_expert_map = False
 
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size_per_partition: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
 
         layer.intermediate_size_per_partition = intermediate_size_per_partition
         layer.hidden_size = hidden_size
@@ -668,6 +662,8 @@ def create_weights(
                 {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
             set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
 
         # INPUT_SCALES
         if self.static_input_scales:
@@ -690,7 +686,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
-            # TODO(bnell): Is this assert right?
             assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
@@ -793,7 +788,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 device=device,
                 dtype=torch.int64)
 
-        # XXXXXXXXXXXXX
         if is_deep_gemm_e8m0_used() and self.block_quant:
             assert layer.weight_block_size is not None
             # Re-quantise the expert weights so their scales are UE8M0.
@@ -810,11 +804,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             )
 
             # Ensure column-major TMA alignment expected by DeepGEMM.
-            if _is_col_major(layer.w13_weight_scale):
-                layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
+            if _is_col_major(layer.w13_weight_scale_inv):
+                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                     layer.w13_weight_scale)
-            if _is_col_major(layer.w2_weight_scale):
-                layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
+            if _is_col_major(layer.w2_weight_scale_inv):
+                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                     layer.w2_weight_scale)
 
     def maybe_make_prepare_finalize(
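Note on the `create_weights` change: the two added `register_parameter` calls expose the block-wise scale Parameters under a second `*_weight_scale_inv` name, presumably so checkpoints that store block-quantized FP8 scales under the `_inv` naming load into the same tensors the kernels read. A minimal standalone sketch (not vLLM code; `TinyLayer` and the shapes are illustrative) of what registering one Parameter under two names does:

```python
# Sketch only: one Parameter registered under two names is a true alias,
# so loading or mutating through either attribute touches the same storage.
import torch


class TinyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        scale = torch.nn.Parameter(torch.ones(4, 4), requires_grad=False)
        self.register_parameter("w13_weight_scale", scale)
        # Alias: same Parameter object, second checkpoint-compatible name.
        self.register_parameter("w13_weight_scale_inv", scale)


layer = TinyLayer()
# A write through one name is visible through the other.
layer.w13_weight_scale_inv.data.fill_(0.5)
assert layer.w13_weight_scale.data_ptr() == layer.w13_weight_scale_inv.data_ptr()
assert torch.all(layer.w13_weight_scale == 0.5)
```

Because both names resolve to one storage at load time, later code can check `layer.w13_weight_scale_inv` while the unchanged continuation line still passes `layer.w13_weight_scale`, as the DeepGEMM alignment fix-up in the last hunk does.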