Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
_choose_qparams_affine_tinygemm,
_fake_quantize_affine,
_fake_quantize_affine_cachemask,
_maybe_expand_scale_to_tensor_shape,
choose_qparams_affine,
dequantize_affine,
quantize_affine,
Expand Down Expand Up @@ -771,6 +772,32 @@ def test_fake_quantize_affine_cachemask(self):
torch.testing.assert_close(dequantized, fake_quantized)
torch.testing.assert_close(expected_mask, mask)

def test_maybe_expand_scale_to_tensor_shape(self):
    """Verify `_maybe_expand_scale_to_tensor_shape` pass-through vs. expansion.

    Scales that already broadcast against the target shape (e.g. rowwise
    quantization) must be returned as the identical tensor object, while
    blockwise scales must be materialized to the full target shape.
    """
    target_shape = torch.Size([3, 2, 8])
    # Broadcastable layouts: the helper should hand back the very same
    # tensor (identity, not just equality) — no copy, no expansion.
    passthrough_shapes = [
        [3, 2, 1],  # rowwise: every dim matches except a trailing 1
        [3, 1, 1],
        [1, 2, 1],
        [1, 1, 8],
        [1, 1, 1],
    ]
    for shape in passthrough_shapes:
        candidate = torch.randn(shape)
        result = _maybe_expand_scale_to_tensor_shape(candidate, target_shape)
        self.assertIs(candidate, result)
    # Blockwise quantization: each scale entry covers a block, so the
    # helper repeats entries until the result matches target_shape while
    # keeping only the original number of distinct slices per dimension.
    block_scale = torch.randn([3, 2, 2])
    expanded = _maybe_expand_scale_to_tensor_shape(block_scale, target_shape)
    self.assertEqual(expanded.shape, torch.Size([3, 2, 8]))
    self.assertEqual(expanded.unique(dim=-1).shape, torch.Size([3, 2, 2]))


if __name__ == "__main__":
    # Allow running this test module directly (e.g. `python test_quant_primitives.py`);
    # discovers and runs all TestCase classes defined in the file.
    unittest.main()
12 changes: 9 additions & 3 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2221,11 +2221,12 @@ def _choose_scale_float8(
return scale.to(dtype=torch.float32)


def _expand_scale_to_tensor_shape(
def _maybe_expand_scale_to_tensor_shape(
scale: torch.Tensor, target_shape: torch.Size
) -> torch.Tensor:
"""
Expand a scale tensor to match the target tensor shape for block-wise quantization.
If this is rowwise quantization, however, just return the scale as is.

Args:
scale (torch.Tensor): Scale tensor with shape corresponding to block structure
Expand All @@ -2242,6 +2243,11 @@ def _expand_scale_to_tensor_shape(
# Scalar scale - can broadcast naturally
return scale

# If the scale can be broadcast as is, then we don't need to expand it
# E.g. for rowwise quantization, scale = [256, 1] and target_shape = [256, 512]
if all(a == b or a == 1 for a, b in zip(scale.shape, target_shape)):
return scale

# Calculate block sizes from shape difference
if len(scale.shape) != len(target_shape):
raise ValueError(
Expand Down Expand Up @@ -2283,7 +2289,7 @@ def _quantize_affine_float8(
tensor_fp32 = tensor.to(torch.float32)

# Expand scale to match tensor dimensions for block-wise quantization
scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape)
scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, tensor.shape)

tensor_scaled = tensor_fp32 / scale_expanded
max_value = torch.finfo(float8_dtype).max
Expand All @@ -2306,7 +2312,7 @@ def _dequantize_affine_float8(
fp8_tensor = tensor.to(torch.float32)

# Expand scale to match tensor dimensions for block-wise quantization
scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape)
scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, tensor.shape)

hp_tensor = fp8_tensor * scale_expanded
return hp_tensor.to(output_dtype)
Expand Down
Loading