diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 8547532b78..139116def2 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -240,6 +240,16 @@ def test_choose_qparams_tensor_asym_eps(self):
         eps = torch.finfo(torch.float32).eps
         self.assertEqual(scale, eps)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "skipping when cuda is not available")
+    def test_get_group_qparams_symmetric_memory(self):
+        """Check the memory usage of the op"""
+        weight = torch.randn(1024, 1024).to(device="cuda")
+        original_mem_use = torch.cuda.memory_allocated()
+        n_bit = 4
+        groupsize = 128
+        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
+        after_choose_qparams_mem_use = torch.cuda.memory_allocated()
+        self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use)
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 42f32a9c78..4d6a7666aa 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -260,6 +260,8 @@ def choose_qparams_affine(
         Tuple of scales and zero_points Tensor with requested dtype
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
+    assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
+
     if scale_dtype is None:
         scale_dtype = input.dtype
     if zero_point_dtype is None:
@@ -269,23 +271,20 @@ def choose_qparams_affine(
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
     input = input.view(shape_for_reduction)
 
-    if mapping_type == MappingType.SYMMETRIC:
-        amax = torch.amax(torch.abs(input), dim=reduction_dims, keepdim=False)
-        scale = amax / (float(quant_max - quant_min) / 2)
-        zero_point = torch.ones_like(scale)
-        zero_point *= int((quant_min + quant_max + 1) / 2)
-    elif mapping_type == MappingType.ASYMMETRIC:
-        min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
-        max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
+    min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
+    max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
 
-        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
 
+    if mapping_type == MappingType.SYMMETRIC:
+        max_val_pos = torch.max(-min_val_neg, max_val_pos)
+        scale = max_val_pos / (float(quant_max - quant_min) / 2)
+        zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
+    else:
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         zero_point = quant_min - torch.round(min_val_neg / scale)
         zero_point = torch.clamp(zero_point, quant_min, quant_max)
-    else:
-        raise RuntimeError(f"Unsupported mapping type: {mapping_type}")
 
     if eps is None:
         eps = torch.finfo(input.dtype).eps
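
A minimal sketch (not part of the patch; the shape, reduction dim, and variable names are illustrative) of why the shared min/max path above lowers peak memory: the old symmetric branch materialized torch.abs(input), a temporary as large as the input, while the refactored path only allocates the reduced per-group min/max tensors and recovers the same per-group amax.

import torch

input = torch.randn(1024, 1024)
reduction_dims = [1]  # one group per row, purely for illustration

# Old symmetric branch: torch.abs(input) allocates a full 1024x1024 temporary.
amax_old = torch.amax(torch.abs(input), dim=reduction_dims, keepdim=False)

# Refactored path: only the reduced [1024]-sized min/max tensors are allocated.
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
max_val_pos = torch.max(-min_val_neg, max_val_pos)

# Both formulations yield the same symmetric range, hence the same scale.
assert torch.allclose(amax_old, max_val_pos)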