Skip to content

Commit 28bd36c

Browse files
committed
deduplicate code for get_group_qparams_symmetric
Summary: This just removes the duplicated implementation; we can have follow-up PRs to remove the call altogether after we have replaced all implementations with the new blockwise quant code. Test Plan: CI. Reviewers: — Subscribers: — Tasks: — Tags: —
1 parent f05c215 commit 28bd36c

File tree

1 file changed

+32
-19
lines changed

1 file changed

+32
-19
lines changed

torchao/quantization/quant_primitives.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -729,25 +729,38 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float
729729
assert w.shape[-1] % groupsize == 0
730730
assert w.dim() == 2
731731

732-
to_quant = w.reshape(-1, groupsize)
733-
assert torch.isnan(to_quant).sum() == 0
734-
735-
max_val = to_quant.amax(dim=1, keepdim=True)
736-
min_val = to_quant.amin(dim=1, keepdim=True)
737-
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
738-
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
739-
740-
max_val_abs = torch.max(-min_val_neg, max_val_pos)
741-
max_int = 2 ** (n_bit - 1) - 1
742-
min_int = -(2 ** (n_bit - 1))
743-
744-
scales = max_val_abs / (float(max_int - min_int) / 2)
745-
scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps))
746-
# TODO: make sure abs(scales) is not too small?
747-
zeros = torch.full_like(scales, 0)
748-
return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape(
749-
w.shape[0], -1
750-
)
732+
block_size = (w.shape[0], groupsize)
733+
mapping_type = MappingType.SYMMETRIC
734+
eps = torch.finfo(torch.float32).eps
735+
if TORCH_VERSION_AFTER_2_3:
736+
bit_to_dtype = {
737+
1: torch.uint1,
738+
2: torch.uint2,
739+
3: torch.uint3,
740+
4: torch.uint4,
741+
5: torch.uint5,
742+
6: torch.uint6,
743+
7: torch.uint7,
744+
8: torch.uint8,
745+
}
746+
assert n_bit in bit_to_dtype, f"unsupported bit: {n_bit}"
747+
target_dtype = bit_to_dtype[n_bit]
748+
return choose_qparams_affine(w, mapping_type, block_size, target_dtype=target_dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
749+
else:
750+
ranges = {
751+
1: (0, 2**1-1),
752+
2: (0, 2**2-1),
753+
3: (0, 2**3-1),
754+
4: (0, 2**4-1),
755+
5: (0, 2**5-1),
756+
6: (0, 2**6-1),
757+
7: (0, 2**7-1),
758+
8: (0, 2**8-1),
759+
}
760+
assert n_bit in ranges, f"unsupported bit: {n_bit}"
761+
quant_min, quant_max = ranges[n_bit]
762+
# using uint8 to simulate uint4
763+
return choose_qparams_affine(w, mapping_type, block_size, target_dtype=torch.uint8, quant_min=quant_min, quant_max=quant_max, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
751764

752765

753766
if TORCH_VERSION_AFTER_2_3:

0 commit comments

Comments
 (0)