diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index f3d265e14a..bed8421671 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -16,6 +16,7 @@
     _choose_qparams_affine_tinygemm,
     _fake_quantize_affine,
     _fake_quantize_affine_cachemask,
+    _maybe_expand_scale_to_tensor_shape,
     choose_qparams_affine,
     dequantize_affine,
     quantize_affine,
@@ -771,6 +772,32 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
+    def test_maybe_expand_scale_to_tensor_shape(self):
+        # rowwise quantization: if all dimensions match except for the last one,
+        # and the last dimension is 1, then just return the scale as is
+        scale = torch.randn([3, 2, 1])
+        target_shape = torch.Size([3, 2, 8])
+        new_scale = _maybe_expand_scale_to_tensor_shape(scale, target_shape)
+        self.assertIs(scale, new_scale)
+        # other broadcastable shapes
+        scale1 = torch.randn([3, 1, 1])
+        scale2 = torch.randn([1, 2, 1])
+        scale3 = torch.randn([1, 1, 8])
+        scale4 = torch.randn([1, 1, 1])
+        new_scale1 = _maybe_expand_scale_to_tensor_shape(scale1, target_shape)
+        new_scale2 = _maybe_expand_scale_to_tensor_shape(scale2, target_shape)
+        new_scale3 = _maybe_expand_scale_to_tensor_shape(scale3, target_shape)
+        new_scale4 = _maybe_expand_scale_to_tensor_shape(scale4, target_shape)
+        self.assertIs(scale1, new_scale1)
+        self.assertIs(scale2, new_scale2)
+        self.assertIs(scale3, new_scale3)
+        self.assertIs(scale4, new_scale4)
+        # blockwise quantization: scales are repeated to fit target_shape
+        scale5 = torch.randn([3, 2, 2])
+        new_scale5 = _maybe_expand_scale_to_tensor_shape(scale5, target_shape)
+        self.assertEqual(new_scale5.shape, torch.Size([3, 2, 8]))
+        self.assertEqual(new_scale5.unique(dim=-1).shape, torch.Size([3, 2, 2]))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index c118e0b4ce..d9b5efd7f2 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -2221,11 +2221,12 @@ def _choose_scale_float8(
     return scale.to(dtype=torch.float32)
 
 
-def _expand_scale_to_tensor_shape(
+def _maybe_expand_scale_to_tensor_shape(
     scale: torch.Tensor, target_shape: torch.Size
 ) -> torch.Tensor:
     """
     Expand a scale tensor to match the target tensor shape for block-wise quantization.
+    If this is rowwise quantization, however, just return the scale as is.
 
     Args:
         scale (torch.Tensor): Scale tensor with shape corresponding to block structure
@@ -2242,6 +2243,11 @@ def _expand_scale_to_tensor_shape(
         # Scalar scale - can broadcast naturally
         return scale
 
+    # If the scale can be broadcast as is, then we don't need to expand it
+    # E.g. for rowwise quantization, scale = [256, 1] and target_shape = [256, 512]
+    if all(a == b or a == 1 for a, b in zip(scale.shape, target_shape)):
+        return scale
+
     # Calculate block sizes from shape difference
     if len(scale.shape) != len(target_shape):
         raise ValueError(
@@ -2283,7 +2289,7 @@ def _quantize_affine_float8(
     tensor_fp32 = tensor.to(torch.float32)
 
     # Expand scale to match tensor dimensions for block-wise quantization
-    scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape)
+    scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, tensor.shape)
     tensor_scaled = tensor_fp32 / scale_expanded
     max_value = torch.finfo(float8_dtype).max
 
@@ -2306,7 +2312,7 @@ def _dequantize_affine_float8(
     fp8_tensor = tensor.to(torch.float32)
 
     # Expand scale to match tensor dimensions for block-wise quantization
-    scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape)
+    scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, tensor.shape)
 
     hp_tensor = fp8_tensor * scale_expanded
     return hp_tensor.to(output_dtype)
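For reference, a minimal standalone sketch (not part of the diff) of the behavior the renamed helper is expected to have after this change. It assumes the function is importable from torchao.quantization.quant_primitives, where the diff defines it; the assertions simply restate what test_maybe_expand_scale_to_tensor_shape checks.

import torch

from torchao.quantization.quant_primitives import (
    _maybe_expand_scale_to_tensor_shape,
)

target_shape = torch.Size([3, 2, 8])

# Broadcastable (e.g. rowwise) scale: returned unchanged, no expanded copy is materialized.
rowwise_scale = torch.randn([3, 2, 1])
assert _maybe_expand_scale_to_tensor_shape(rowwise_scale, target_shape) is rowwise_scale

# Blockwise scale: scale values are repeated along the block dimension to fill target_shape.
blockwise_scale = torch.randn([3, 2, 2])
expanded = _maybe_expand_scale_to_tensor_shape(blockwise_scale, target_shape)
assert expanded.shape == target_shape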