
Commit d52358c

[Perf] Remove duplicated NVFP4 blockscales to save memory (#23379)

Authored by mgoin
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent: 6ace2f7
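
The change is a rename-with-reuse: rather than keeping both the raw blockscale parameter and a second *_blockscale_swizzled copy registered on each layer, the swizzled tensor is written back over the original *_weight_scale attribute, so only one copy stays resident after weight processing. A minimal sketch of the pattern, with a toy stand-in for vLLM's swizzle_blockscale and illustrative shapes (real NVFP4 blockscales are FP8-E4M3):

import torch


def toy_swizzle(scale: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for vLLM's swizzle_blockscale: just returns a copy.
    return scale.clone()


layer = torch.nn.Module()
layer.w13_weight_scale = torch.nn.Parameter(
    torch.randn(8, 16), requires_grad=False)

# Before (duplicated): both tensors stay registered on the module, so
# the raw blockscale can never be freed.
#   layer.w13_blockscale_swizzled = torch.nn.Parameter(
#       toy_swizzle(layer.w13_weight_scale), requires_grad=False)

# After (this commit): the swizzled copy replaces the raw parameter
# under the same name, and the raw tensor becomes collectible.
layer.w13_weight_scale = torch.nn.Parameter(
    toy_swizzle(layer.w13_weight_scale), requires_grad=False)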

3 files changed: 30 additions, 35 deletions


vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 10 additions & 10 deletions

@@ -246,13 +246,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             return
 
         # swizzle weight scales
-        layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+        layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale(
             layer.w13_weight_scale),
-                                                           requires_grad=False)
+                                                    requires_grad=False)
 
-        layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+        layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale(
             layer.w2_weight_scale),
-                                                          requires_grad=False)
+                                                   requires_grad=False)
 
         # w13
         w13_input_global_scale = layer.w13_input_global_scale.max(
@@ -383,8 +383,8 @@ def apply(
             activation=activation,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
-            w1_scale=layer.w13_blockscale_swizzled,
-            w2_scale=layer.w2_blockscale_swizzled,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
 
@@ -406,8 +406,8 @@ def apply(
             activation=activation,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
-            w1_scale=layer.w13_blockscale_swizzled,
-            w2_scale=layer.w2_blockscale_swizzled,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
             g1_alphas=layer.g1_alphas,
             g2_alphas=layer.g2_alphas,
             a1_gscale=layer.w13_input_scale_quant,
@@ -427,8 +427,8 @@ def apply(
             a=x,
             w1_fp4=layer.w13_weight,
             w2_fp4=layer.w2_weight,
-            w1_blockscale=layer.w13_blockscale_swizzled,
-            w2_blockscale=layer.w2_blockscale_swizzled,
+            w1_blockscale=layer.w13_weight_scale,
+            w2_blockscale=layer.w2_weight_scale,
             g1_alphas=layer.g1_alphas,
             g2_alphas=layer.g2_alphas,
             a1_gscale=layer.w13_input_scale_quant,
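
All three files route their scales through swizzle_blockscale before the overwrite. As a rough illustration of what a blockscale swizzle does, here is a runnable toy that assumes a 128x4 tile layout (a common choice for FP4 GEMM scale layouts); the exact padding and permutation are defined by vLLM's swizzle_blockscale and may differ:

import torch


def toy_swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
    """Pad a 2-D blockscale to 128x4 tiles and interleave it.

    Illustrative only: assumes a 128x4 tiling scheme; vLLM's real
    swizzle_blockscale defines the authoritative layout.
    """
    m, k = scale.shape
    m_pad = (m + 127) // 128 * 128  # rows to a multiple of 128
    k_pad = (k + 3) // 4 * 4        # cols to a multiple of 4
    padded = torch.zeros(m_pad, k_pad, dtype=scale.dtype)
    padded[:m, :k] = scale
    # Split rows into 128-row tiles (viewed as 4 x 32) and cols into
    # 4-col tiles, then interleave so each kernel tile is contiguous.
    tiled = padded.reshape(m_pad // 128, 4, 32, k_pad // 4, 4)
    return tiled.permute(0, 3, 2, 1, 4).contiguous().reshape(m_pad, k_pad)


swizzled = toy_swizzle_blockscale(torch.randn(100, 10))
print(swizzled.shape)  # torch.Size([128, 12])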

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py

Lines changed: 5 additions & 6 deletions

@@ -112,13 +112,12 @@ def process_weights_after_loading(self, layer) -> None:
                     torch.uint8), epilogue_tile_m).reshape(
                         weight_scale.shape).view(torch.float8_e4m3fn))
 
-            layer.weight_scale_swizzled = Parameter(weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.weight_packed = Parameter(weight, requires_grad=False)
         else:
             swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
-            layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(swizzled_weight_scale,
+                                           requires_grad=False)
             layer.weight_packed = Parameter(layer.weight_packed.data,
                                             requires_grad=False)
 
@@ -136,7 +135,7 @@ def apply_weights(self,
             x=x,
             input_global_scale=layer.input_global_scale,
             weight=layer.weight_packed,
-            weight_scale_swizzled=layer.weight_scale_swizzled,
+            weight_scale_swizzled=layer.weight_scale,
             weight_global_scale=layer.weight_global_scale)
         if bias is not None:
             out = out + bias
@@ -149,7 +148,7 @@ def apply_weights(self,
         x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
 
         mm_args = (x_fp4, layer.weight_packed, x_blockscale,
-                   layer.weight_scale_swizzled, layer.alpha, output_dtype)
+                   layer.weight_scale, layer.alpha, output_dtype)
         if self.backend == "flashinfer-trtllm":
             out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
         elif self.backend == "flashinfer-cutlass":
vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 15 additions & 19 deletions

@@ -907,20 +907,18 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     torch.uint8), epilogue_tile_m).reshape(
                         weight_scale.shape).view(torch.float8_e4m3fn))
 
-            layer.weight_scale_swizzled = Parameter(weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.weight = Parameter(weight, requires_grad=False)
         else:
             swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
-            layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(swizzled_weight_scale,
+                                           requires_grad=False)
             layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
         if self.backend == "marlin":
             prepare_fp4_layer_for_marlin(layer)
             del layer.alpha
             del layer.input_scale
-            del layer.weight_scale_swizzled
 
     def apply(
         self,
@@ -951,14 +949,14 @@ def apply(
         assert (x_fp4.dtype == torch.uint8)
         assert (layer.weight.dtype == torch.uint8)
         assert (x_blockscale.dtype == torch.float8_e4m3fn)
-        assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
+        assert (layer.weight_scale.dtype == torch.float8_e4m3fn)
         assert (layer.alpha.dtype == torch.float32)
 
         mm_args = (
             x_fp4,
             layer.weight,
             x_blockscale,
-            layer.weight_scale_swizzled,
+            layer.weight_scale,
             layer.alpha,
             output_dtype,
         )
@@ -1320,16 +1318,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             "Weight Blockscale must be represented as FP8-E4M3")
         w13_blockscale_swizzled = swizzle_blockscale(
             layer.w13_weight_scale)
-        layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
-                                                  requires_grad=False)
+        layer.w13_weight_scale = Parameter(w13_blockscale_swizzled,
+                                           requires_grad=False)
 
         assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
             "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Blockscale must be represented as FP8-E4M3")
         w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
-        layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
-                                                 requires_grad=False)
+        layer.w2_weight_scale = Parameter(w2_blockscale_swizzled,
+                                          requires_grad=False)
         layer.w2_weight = Parameter(layer.w2_weight.data,
                                     requires_grad=False)
 
@@ -1339,8 +1337,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         del layer.g2_alphas
         del layer.w13_input_scale_quant
         del layer.w2_input_scale_quant
-        del layer.w13_blockscale_swizzled
-        del layer.w2_blockscale_swizzled
 
     def apply(
         self,
@@ -1474,8 +1470,8 @@ def apply(
             activation=activation,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
-            w1_scale=layer.w13_blockscale_swizzled,
-            w2_scale=layer.w2_blockscale_swizzled,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
         elif (self.allow_flashinfer
@@ -1489,8 +1485,8 @@ def apply(
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            w1_scale=layer.w13_blockscale_swizzled,
-            w2_scale=layer.w2_blockscale_swizzled,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
             g1_alphas=layer.g1_alphas,
             g2_alphas=layer.g2_alphas,
             a1_gscale=layer.w13_input_scale_quant,
@@ -1510,8 +1506,8 @@ def apply(
             a=x,
             w1_fp4=layer.w13_weight,
             w2_fp4=layer.w2_weight,
-            w1_blockscale=layer.w13_blockscale_swizzled,
-            w2_blockscale=layer.w2_blockscale_swizzled,
+            w1_blockscale=layer.w13_weight_scale,
+            w2_blockscale=layer.w2_weight_scale,
             g1_alphas=layer.g1_alphas,
             g2_alphas=layer.g2_alphas,
             a1_gscale=layer.w13_input_scale_quant,
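
Because the swizzled copies now live under the original names, the cleanup paths above no longer need to del the *_blockscale_swizzled attributes. One way to see the saving is to total registered parameter storage; param_bytes below is a hypothetical helper, not a vLLM API:

import torch


def param_bytes(module: torch.nn.Module) -> int:
    # Hypothetical helper: bytes held by all registered parameters. With
    # the *_blockscale_swizzled duplicates gone, each scale counts once.
    return sum(p.numel() * p.element_size() for p in module.parameters())


layer = torch.nn.Module()
layer.w13_weight_scale = torch.nn.Parameter(
    torch.randn(1024, 256), requires_grad=False)
base = param_bytes(layer)  # one blockscale copy

# Old behavior: registering a swizzled duplicate doubles scale memory.
layer.w13_blockscale_swizzled = torch.nn.Parameter(
    layer.w13_weight_scale.data.clone(), requires_grad=False)
print(param_bytes(layer) / base)  # 2.0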
