|
16 | 16 | _choose_qparams_affine_tinygemm, |
17 | 17 | _fake_quantize_affine, |
18 | 18 | _fake_quantize_affine_cachemask, |
| 19 | + _maybe_expand_scale_to_tensor_shape, |
19 | 20 | choose_qparams_affine, |
20 | 21 | dequantize_affine, |
21 | 22 | quantize_affine, |
@@ -771,6 +772,32 @@ def test_fake_quantize_affine_cachemask(self): |
771 | 772 | torch.testing.assert_close(dequantized, fake_quantized) |
772 | 773 | torch.testing.assert_close(expected_mask, mask) |
773 | 774 |
|
def test_maybe_expand_scale_to_tensor_shape(self):
    """Check when a scale is passed through as-is and when it is expanded.

    For a target shape of [3, 2, 8] the test expects:
    * a rowwise scale ([3, 2, 1]) and any other scale whose dimensions are
      each either 1 or equal to the target come back as the very same
      tensor object;
    * a blockwise scale ([3, 2, 2]) comes back with the full target shape,
      containing only two distinct slices along the last dimension.
    """
    full_shape = torch.Size([3, 2, 8])

    # Rowwise quantization: every dimension matches except the last one,
    # which is 1, so the helper is expected to hand the input straight back.
    rowwise_scale = torch.randn([3, 2, 1])
    rowwise_result = _maybe_expand_scale_to_tensor_shape(rowwise_scale, full_shape)
    self.assertIs(rowwise_scale, rowwise_result)

    # Other shapes that already broadcast against the target shape.
    broadcastable_dims = ([3, 1, 1], [1, 2, 1], [1, 1, 8], [1, 1, 1])
    broadcastable_scales = [torch.randn(dims) for dims in broadcastable_dims]
    broadcastable_results = [
        _maybe_expand_scale_to_tensor_shape(candidate, full_shape)
        for candidate in broadcastable_scales
    ]
    for candidate, returned in zip(broadcastable_scales, broadcastable_results):
        # Identity, not equality: no new tensor should have been created.
        self.assertIs(candidate, returned)

    # Blockwise quantization: scales are repeated to fit the target shape.
    blockwise_scale = torch.randn([3, 2, 2])
    expanded = _maybe_expand_scale_to_tensor_shape(blockwise_scale, full_shape)
    self.assertEqual(expanded.shape, torch.Size([3, 2, 8]))
    # Collapsing duplicate slices along the last dim must leave two blocks.
    self.assertEqual(expanded.unique(dim=-1).shape, torch.Size([3, 2, 2]))
| 800 | + |
774 | 801 |
|
if __name__ == "__main__":
    # Executed directly as a script: hand control to unittest's runner so
    # every test case defined in this module is run.
    unittest.main()
0 commit comments