@@ -907,20 +907,18 @@ def process_weights_after_loading(self, layer: Module) -> None:
907907 torch .uint8 ), epilogue_tile_m ).reshape (
908908 weight_scale .shape ).view (torch .float8_e4m3fn ))
909909
910- layer .weight_scale_swizzled = Parameter (weight_scale ,
911- requires_grad = False )
910+ layer .weight_scale = Parameter (weight_scale , requires_grad = False )
912911 layer .weight = Parameter (weight , requires_grad = False )
913912 else :
914913 swizzled_weight_scale = swizzle_blockscale (layer .weight_scale )
915- layer .weight_scale_swizzled = Parameter (swizzled_weight_scale ,
916- requires_grad = False )
914+ layer .weight_scale = Parameter (swizzled_weight_scale ,
915+ requires_grad = False )
917916 layer .weight = Parameter (layer .weight .data , requires_grad = False )
918917
919918 if self .backend == "marlin" :
920919 prepare_fp4_layer_for_marlin (layer )
921920 del layer .alpha
922921 del layer .input_scale
923- del layer .weight_scale_swizzled
924922
925923 def apply (
926924 self ,
@@ -951,14 +949,14 @@ def apply(
951949 assert (x_fp4 .dtype == torch .uint8 )
952950 assert (layer .weight .dtype == torch .uint8 )
953951 assert (x_blockscale .dtype == torch .float8_e4m3fn )
954- assert (layer .weight_scale_swizzled .dtype == torch .float8_e4m3fn )
952+ assert (layer .weight_scale .dtype == torch .float8_e4m3fn )
955953 assert (layer .alpha .dtype == torch .float32 )
956954
957955 mm_args = (
958956 x_fp4 ,
959957 layer .weight ,
960958 x_blockscale ,
961- layer .weight_scale_swizzled ,
959+ layer .weight_scale ,
962960 layer .alpha ,
963961 output_dtype ,
964962 )
@@ -1320,16 +1318,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
13201318 "Weight Blockscale must be represented as FP8-E4M3" )
13211319 w13_blockscale_swizzled = swizzle_blockscale (
13221320 layer .w13_weight_scale )
1323- layer .w13_blockscale_swizzled = Parameter (w13_blockscale_swizzled ,
1324- requires_grad = False )
1321+ layer .w13_weight_scale = Parameter (w13_blockscale_swizzled ,
1322+ requires_grad = False )
13251323
13261324 assert (layer .w2_weight_scale .shape [2 ] % 16 == 0 ), (
13271325 "Expected weight_scale.dim(1) to be divisible by 16" )
13281326 assert (layer .w2_weight_scale .dtype == torch .float8_e4m3fn ), (
13291327 "Weight Blockscale must be represented as FP8-E4M3" )
13301328 w2_blockscale_swizzled = swizzle_blockscale (layer .w2_weight_scale )
1331- layer .w2_blockscale_swizzled = Parameter (w2_blockscale_swizzled ,
1332- requires_grad = False )
1329+ layer .w2_weight_scale = Parameter (w2_blockscale_swizzled ,
1330+ requires_grad = False )
13331331 layer .w2_weight = Parameter (layer .w2_weight .data ,
13341332 requires_grad = False )
13351333
@@ -1339,8 +1337,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
13391337 del layer .g2_alphas
13401338 del layer .w13_input_scale_quant
13411339 del layer .w2_input_scale_quant
1342- del layer .w13_blockscale_swizzled
1343- del layer .w2_blockscale_swizzled
13441340
13451341 def apply (
13461342 self ,
@@ -1474,8 +1470,8 @@ def apply(
14741470 activation = activation ,
14751471 global_num_experts = global_num_experts ,
14761472 expert_map = expert_map ,
1477- w1_scale = layer .w13_blockscale_swizzled ,
1478- w2_scale = layer .w2_blockscale_swizzled ,
1473+ w1_scale = layer .w13_weight_scale ,
1474+ w2_scale = layer .w2_weight_scale ,
14791475 apply_router_weight_on_input = apply_router_weight_on_input ,
14801476 )
14811477 elif (self .allow_flashinfer
@@ -1489,8 +1485,8 @@ def apply(
14891485 w2 = layer .w2_weight ,
14901486 topk_weights = topk_weights ,
14911487 topk_ids = topk_ids ,
1492- w1_scale = layer .w13_blockscale_swizzled ,
1493- w2_scale = layer .w2_blockscale_swizzled ,
1488+ w1_scale = layer .w13_weight_scale ,
1489+ w2_scale = layer .w2_weight_scale ,
14941490 g1_alphas = layer .g1_alphas ,
14951491 g2_alphas = layer .g2_alphas ,
14961492 a1_gscale = layer .w13_input_scale_quant ,
@@ -1510,8 +1506,8 @@ def apply(
15101506 a = x ,
15111507 w1_fp4 = layer .w13_weight ,
15121508 w2_fp4 = layer .w2_weight ,
1513- w1_blockscale = layer .w13_blockscale_swizzled ,
1514- w2_blockscale = layer .w2_blockscale_swizzled ,
1509+ w1_blockscale = layer .w13_weight_scale ,
1510+ w2_blockscale = layer .w2_weight_scale ,
15151511 g1_alphas = layer .g1_alphas ,
15161512 g2_alphas = layer .g2_alphas ,
15171513 a1_gscale = layer .w13_input_scale_quant ,
0 commit comments