
Commit 3d4cb8d

Dequantize fp8 rowwise in backward

1 parent: 19500bf

File tree: 1 file changed (+15, -2 lines)


torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 15 additions & 2 deletions
@@ -286,6 +286,20 @@ def _float8_linear_impl(
         f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
     )

+    # TODO: make this better
+    # During the backward pass, we transpose the weight tensor,
+    # so if the weight tensor was originally rowwise quantized,
+    # now it becomes colwise. In this case, simply dequantize
+    # the tensor and do a bf16 matmul.
+    is_backward = (
+        weight_tensor.block_size[0] == weight_tensor.shape[0] and
+        weight_tensor.block_size[1] == 1
+    )
+    if is_backward:
+        return torch.nn.functional.linear(
+            input_tensor, weight_tensor.dequantize(), bias,
+        )
+
     act_quant_kwargs = weight_tensor.act_quant_kwargs
     # quantizing activation, if `act_quant_kwargs` is specified
     if act_quant_kwargs is not None:
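Why this check identifies the backward pass: a minimal sketch, assuming the convention the hunk above implies, namely that a rowwise-quantized weight of shape (N, K) carries block_size (1, K), one scale per row. Transposing the weight for the gradient matmul swaps both tuples, which produces exactly the colwise pattern the is_backward check detects. The concrete shapes below are illustrative, not taken from the commit.

# Illustrative values, not torchao API: a rowwise fp8 weight and its
# block_size, before and after the transpose done in the backward pass.
shape = (128, 64)        # (out_features, in_features) in the forward pass
block_size = (1, 64)     # rowwise: one scale per output row

# After weight_tensor.t() in the backward pass, both tuples are reversed:
t_shape = (shape[1], shape[0])                   # (64, 128)
t_block_size = (block_size[1], block_size[0])    # (64, 1): now colwise

# This is the condition the diff adds: a colwise-scaled weight is treated
# as the backward case, so the code falls back to dequantize + bf16 linear.
is_backward = t_block_size[0] == t_shape[0] and t_block_size[1] == 1
assert is_backward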
@@ -321,8 +335,7 @@ def _float8_linear_impl(
     wq = weight_tensor.qdata
     x_scale = input_tensor.scale
     w_scale = weight_tensor.scale
-    # TODO: fix this?
-    if True: # _is_rowwise_scaled(weight_tensor):
+    if _is_rowwise_scaled(weight_tensor):
         assert _is_rowwise_scaled(input_tensor), (
             "Input tensor must be rowwise block size"
         )
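For context on the tightened condition, here is a hypothetical sketch of what a rowwise check along the lines of _is_rowwise_scaled could reduce to for 2-D tensors; the real torchao helper may differ in detail. Under the same convention as above, rowwise scaling means block_size == (1, K) for a tensor of shape (N, K).

# Hypothetical stand-in for _is_rowwise_scaled, written against plain
# (shape, block_size) tuples rather than torchao tensor subclasses.
def is_rowwise_scaled_sketch(shape, block_size):
    # Every leading block dim is 1 and the last block dim spans a full
    # row, i.e. there is exactly one scale per row.
    return (
        tuple(block_size[:-1]) == (1,) * (len(shape) - 1)
        and block_size[-1] == shape[-1]
    )

assert is_rowwise_scaled_sketch((128, 64), (1, 64))      # rowwise weight
assert not is_rowwise_scaled_sketch((64, 128), (64, 1))  # transposed, colwise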
