Skip to content

Reduce memory usage for symmetric choose_qparams #210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,16 @@ def test_choose_qparams_tensor_asym_eps(self):
eps = torch.finfo(torch.float32).eps
self.assertEqual(scale, eps)

@unittest.skipIf(not torch.cuda.is_available(), "skipping when cuda is not available")
def test_get_group_qparams_symmetric_memory(self):
"""Check the memory usage of the op"""
weight = torch.randn(1024, 1024).to(device="cuda")
original_mem_use = torch.cuda.memory_allocated()
n_bit = 4
groupsize = 128
(scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
after_choose_qparams_mem_use = torch.cuda.memory_allocated()
self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use)

if __name__ == "__main__":
unittest.main()
23 changes: 11 additions & 12 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,8 @@ def choose_qparams_affine(
Tuple of scales and zero_points Tensor with requested dtype
"""
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"

if scale_dtype is None:
scale_dtype = input.dtype
if zero_point_dtype is None:
Expand All @@ -269,23 +271,20 @@ def choose_qparams_affine(
shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
input = input.view(shape_for_reduction)

if mapping_type == MappingType.SYMMETRIC:
amax = torch.amax(torch.abs(input), dim=reduction_dims, keepdim=False)
scale = amax / (float(quant_max - quant_min) / 2)
zero_point = torch.ones_like(scale)
zero_point *= int((quant_min + quant_max + 1) / 2)
elif mapping_type == MappingType.ASYMMETRIC:
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)

min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

if mapping_type == MappingType.SYMMETRIC:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
else:
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
else:
raise RuntimeError(f"Unsupported mapping type: {mapping_type}")

if eps is None:
eps = torch.finfo(input.dtype).eps
Expand Down
Loading