diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 90fd8f8bf0..87c12c60d5 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -279,5 +279,24 @@ def test_get_group_qparams_symmetric_memory(self):
         after_choose_qparams_mem_use = torch.cuda.memory_allocated()
         self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use)
 
+    def test_raises(self):
+        """Make sure some errors are raised when user requested an unsupported type of quantization
+        """
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (10, 10)
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype)
+
+
+        # make sure we can't quantize int32 tensors:
+        with self.assertRaisesRegex(AssertionError, "Unsupported input dtype:"):
+            _ = quantize_affine(input.to(torch.int32), block_size, scale, zero_point, dtype)
+
+        # block_size and scale/zero_point shape mismatch
+        block_size = (1, 1)
+        with self.assertRaisesRegex(RuntimeError, "is invalid for input of size 1"):
+            _ = quantize_affine(input, block_size, scale, zero_point, dtype)
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index f59144becd..b435d5a893 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -145,9 +145,9 @@ def quantize_affine(
 ):
     """
     Args:
-      input (torch.Tensor): original float32 or bfloat16 Tensor
+      input (torch.Tensor): original float32, float16 or bfloat16 Tensor
       block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
-      e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
+      e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
       scale (float): quantization parameter for affine quantization
       zero_point (int): quantization parameter for affine quantization
       output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
@@ -171,6 +171,8 @@ def quantize_affine(
       quantized tensor with requested dtype
     """
     # TODO: validations
+    # TODO: validate scale/zero_point dimensions are compatible with block_size
+    assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
     quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
     original_shape = input.shape
@@ -198,7 +200,7 @@ def dequantize_affine(
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
     *,
-    output_dtype: Optional[torch.dtype] = None,
+    output_dtype: torch.dtype = torch.float32,
 ):
     """
     Args:
@@ -210,13 +212,15 @@ def dequantize_affine(
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
       quant_min (Optional[int]): minimum quantized value for input Tensor
       quant_max (Optional[int]): maximum quantized value for input Tensor
-      output_dtype (torch.dtype?): optional dtype for output Tensor, default is fp32
+      output_dtype (torch.dtype): dtype for output Tensor, default is fp32
 
     Output:
       dequantized Tensor, with requested dtype or fp32
     """
     # TODO: validations
+    # TODO: validate scale/zero_point dimensions are compatible with block_size
     assert input.dtype == input_dtype
+    assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
     quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
 
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
@@ -229,9 +233,10 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    dequant = input.to(output_dtype)
+    dequant = input.to(torch.int32)
     if zero_point is not None:
-        dequant -= zero_point
+        dequant -= zero_point.to(torch.int32)
+    dequant = dequant.to(output_dtype)
     dequant *= scale
     dequant = dequant.view(original_shape)
     return dequant.to(output_dtype)
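
For reference, a minimal usage sketch of the changed primitives from the caller's side (not part of the patch; it assumes choose_qparams_affine, quantize_affine, dequantize_affine and MappingType are importable from torchao.quantization.quant_primitives, as the test above does):

# Minimal sketch of the new behavior; names and import path as assumed above.
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)

input = torch.randn(10, 10)   # float32 input, now explicitly required by the assert
block_size = (10, 10)         # block covers the whole tensor -> per-tensor quantization
scale, zero_point = choose_qparams_affine(
    input, MappingType.ASYMMETRIC, block_size, torch.int8
)

quant = quantize_affine(input, block_size, scale, zero_point, torch.int8)

# dequantize_affine now defaults output_dtype to float32 instead of None
dequant = dequantize_affine(quant, block_size, scale, zero_point, torch.int8)
assert dequant.dtype == torch.float32

# integer inputs are rejected by the new dtype check in quantize_affine
try:
    quantize_affine(input.to(torch.int32), block_size, scale, zero_point, torch.int8)
except AssertionError as e:
    print(e)  # Unsupported input dtype: torch.int32

Note that dequantize_affine now subtracts zero_point in int32 and only then casts to output_dtype, presumably so the integer subtraction stays exact rather than being performed in a low-precision float such as bfloat16.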