Skip to content

Commit bf881f6

Browse files
committed
Add bfloat16 cast in scales
1 parent d8c8cd7 commit bf881f6

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -147,11 +147,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
147147
w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
148148

149149
w2_weight = t2j(layer.w2_weight, use_dlpack=False)
150-
w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
150+
w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16), use_dlpack=False)
151151
w1_weight = t2j(w1_weight, use_dlpack=False)
152-
w1_weight_scale = t2j(w1_weight_scale, use_dlpack=False)
152+
w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16), use_dlpack=False)
153153
w3_weight = t2j(w3_weight, use_dlpack=False)
154-
w3_weight_scale = t2j(w3_weight_scale, use_dlpack=False)
154+
w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16), use_dlpack=False)
155155

156156
if layer.use_ep:
157157
format = Format(

0 commit comments

Comments (0)