Commit 935ac1a
Skip expanding scales for rowwise fp8 quantize
**Summary:** #2253 added a step in `_quantize_affine_float8` to expand the scales for blockwise quantization, so that the scales are always broadcastable with the input tensor. However, this is unnecessary for rowwise quantization, whose scales already have broadcastable shapes, e.g.

```
scale = [32, 1]
input = [32, 16]
```

Before this change, we would `repeat_interleave` the above scale to pad it until it reached `[32, 16]`, which adds non-trivial memory and latency overhead. This commit adds a fast path that skips the expansion step when we detect rowwise quantization.

**Test Plan:**

```
python test/quantization/test_quant_primitives.py -k test_maybe_expand_scale_to_tensor_shape
```

Also compared fine-tuning Qwen3-1.7B with fp8-fp8 QAT using batch size 32 on a single H100 GPU:
- Before: 25.34 GB peak memory, 3047.25 tok/s
- After: 22.53 GB peak memory, 3358.49 tok/s
- This PR uses 11.1% less memory and is 10.2% faster
Parent: 4872c4f
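To make the shape reasoning concrete, here is a minimal standalone sketch (illustrative only, not torchao code; the per-row e4m3 scale computation and the tensor sizes are assumptions) contrasting the rowwise broadcast fast path with the old expansion behavior:

```python
import torch

# A rowwise scale of shape [32, 1] already broadcasts against a [32, 16]
# input, so materializing a padded [32, 16] copy of the scale is overhead.
x = torch.randn(32, 16)
scale = x.abs().amax(dim=-1, keepdim=True) / 448.0  # 448 = e4m3 max (assumed scheme)

# Fast path: rely on broadcasting; scale stays [32, 1].
q_broadcast = x / scale

# Old behavior: pad the scale to the full input shape first.
scale_expanded = scale.repeat_interleave(16, dim=-1)  # materializes [32, 16]
q_expanded = x / scale_expanded

assert torch.equal(q_broadcast, q_expanded)  # identical values either way
assert scale_expanded.numel() == 16 * scale.numel()  # 16x the scale memory
```

The broadcast path yields bit-identical results without ever materializing the padded scale, which is where the memory and throughput improvements in the benchmark above come from.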

File tree

2 files changed: 23 insertions(+), 3 deletions(-)
test/quantization/test_quant_primitives.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -16,6 +16,7 @@
     _choose_qparams_affine_tinygemm,
     _fake_quantize_affine,
     _fake_quantize_affine_cachemask,
+    _maybe_expand_scale_to_tensor_shape,
     choose_qparams_affine,
     dequantize_affine,
     quantize_affine,
@@ -771,6 +772,20 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
+    def test_maybe_expand_scale_to_tensor_shape(self):
+        # rowwise quantization: if all dimensions match except for the last one,
+        # and the last dimension is 1, then just return the scale as is
+        scale = torch.randn([3, 2, 1])
+        target_shape = torch.Size([3, 2, 8])
+        new_scale = _maybe_expand_scale_to_tensor_shape(scale, target_shape)
+        self.assertIs(scale, new_scale)
+        # blockwise quantization: scales are repeated to fit target_shape
+        scale = torch.randn([3, 2, 2])
+        target_shape = torch.Size([3, 2, 8])
+        new_scale = _maybe_expand_scale_to_tensor_shape(scale, target_shape)
+        self.assertEqual(new_scale.shape, torch.Size([3, 2, 8]))
+        self.assertEqual(new_scale.unique(dim=-1).shape, torch.Size([3, 2, 2]))
+
 
 if __name__ == "__main__":
     unittest.main()
```

torchao/quantization/quant_primitives.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -2221,11 +2221,12 @@ def _choose_scale_float8(
     return scale.to(dtype=torch.float32)
 
 
-def _expand_scale_to_tensor_shape(
+def _maybe_expand_scale_to_tensor_shape(
     scale: torch.Tensor, target_shape: torch.Size
 ) -> torch.Tensor:
     """
     Expand a scale tensor to match the target tensor shape for block-wise quantization.
+    If this is rowwise quantization, however, just return the scale as is.
 
     Args:
         scale (torch.Tensor): Scale tensor with shape corresponding to block structure
@@ -2242,6 +2243,10 @@ def _expand_scale_to_tensor_shape(
         # Scalar scale - can broadcast naturally
         return scale
 
+    # For rowwise quantization, just return the scale as is
+    if scale.shape[:-1] == target_shape[:-1] and scale.shape[-1] == 1:
+        return scale
+
     # Calculate block sizes from shape difference
     if len(scale.shape) != len(target_shape):
         raise ValueError(
@@ -2283,7 +2288,7 @@ def _quantize_affine_float8(
     tensor_fp32 = tensor.to(torch.float32)
 
     # Expand scale to match tensor dimensions for block-wise quantization
-    scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape)
+    scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, tensor.shape)
 
     tensor_scaled = tensor_fp32 / scale_expanded
     max_value = torch.finfo(float8_dtype).max
@@ -2306,7 +2311,7 @@ def _dequantize_affine_float8(
     fp8_tensor = tensor.to(torch.float32)
 
     # Expand scale to match tensor dimensions for block-wise quantization
-    scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape)
+    scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, tensor.shape)
 
     hp_tensor = fp8_tensor * scale_expanded
     return hp_tensor.to(output_dtype)
```
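The diff above elides the blockwise expansion loop itself. As a rough sketch of the helper's control flow after this change (the blockwise branch is an assumption based on the summary's mention of `repeat_interleave`, not the exact torchao implementation; the `_sketch` name is hypothetical):

```python
import torch

def _maybe_expand_scale_to_tensor_shape_sketch(
    scale: torch.Tensor, target_shape: torch.Size
) -> torch.Tensor:
    # Scalar scale: broadcasts naturally against any shape.
    if scale.numel() == 1:
        return scale
    # New fast path: a rowwise scale ([..., 1] with matching leading dims)
    # already broadcasts against the target, so return it untouched.
    if scale.shape[:-1] == target_shape[:-1] and scale.shape[-1] == 1:
        return scale
    if len(scale.shape) != len(target_shape):
        raise ValueError("scale and target_shape must have the same rank")
    # Blockwise case (assumed, per the commit summary): repeat each scale
    # element across its block along every dimension where the shapes differ.
    expanded = scale
    for dim, (s, t) in enumerate(zip(scale.shape, target_shape)):
        if t != s:
            expanded = expanded.repeat_interleave(t // s, dim=dim)
    return expanded
```

With `scale = torch.randn(3, 2, 1)` and `target_shape = torch.Size([3, 2, 8])` this returns the scale object itself, matching the `assertIs` check in the new test; a `[3, 2, 2]` scale instead materializes a `[3, 2, 8]` tensor with each value repeated four times along the last dimension.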
