Skip to content

Commit bdf1210

Browse files
committed
Align _choose_qparams_affine with _choose_scale_float8 behavior
Changes the keepdim default from False to True in _choose_qparams_affine to match the behavior of _choose_scale_float8. This ensures scale/zero_point maintain the same rank as the input tensor, making downstream handling more consistent. Fixes #3324
1 parent aa21b80 commit bdf1210

File tree

3 files changed

+36
-16
lines changed

3 files changed

+36
-16
lines changed

test/quantization/test_quant_primitives.py

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -302,11 +302,15 @@ def test_choose_qparams_token_asym(self):
302302
input, dtype
303303
)
304304
)
305+
# With keepdim=True, scale and zero_point now keep dimensions
306+
# Match reference shapes for comparison
305307
scale_ref = scale_ref.squeeze()
306308
zp_ref = zp_ref.squeeze()
309+
scale_squeezed = scale.squeeze()
310+
zp_squeezed = zero_point.squeeze()
307311

308-
torch.testing.assert_close(scale, scale_ref, atol=10e-3, rtol=10e-3)
309-
self.assertTrue(torch.equal(zero_point, zp_ref))
312+
torch.testing.assert_close(scale_squeezed, scale_ref, atol=10e-3, rtol=10e-3)
313+
self.assertTrue(torch.equal(zp_squeezed, zp_ref))
310314

311315
@unittest.skipIf(is_fbcode(), "broken in fbcode")
312316
def test_choose_qparams_tensor_asym(self):
@@ -324,11 +328,14 @@ def test_choose_qparams_tensor_asym(self):
324328
scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(
325329
input, quant_min, quant_max, eps, dtype
326330
)
331+
# With keepdim=True, scale and zero_point now keep dimensions
327332
scale_ref = scale_ref.squeeze()
328333
zp_ref = zp_ref.squeeze()
334+
scale_squeezed = scale.squeeze()
335+
zp_squeezed = zero_point.squeeze()
329336

330-
self.assertTrue(torch.equal(scale, scale_ref))
331-
self.assertTrue(torch.equal(zero_point, zp_ref))
337+
self.assertTrue(torch.equal(scale_squeezed, scale_ref))
338+
self.assertTrue(torch.equal(zp_squeezed, zp_ref))
332339

333340
@unittest.skipIf(is_fbcode(), "broken in fbcode")
334341
def test_choose_qparams_tensor_sym(self):
@@ -346,11 +353,14 @@ def test_choose_qparams_tensor_sym(self):
346353
scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(
347354
input, quant_min, quant_max, eps, dtype
348355
)
356+
# With keepdim=True, scale and zero_point now keep dimensions
349357
scale_ref = scale_ref.squeeze()
350358
zp_ref = zp_ref.squeeze()
359+
scale_squeezed = scale.squeeze()
360+
zp_squeezed = zero_point.squeeze()
351361

352-
self.assertTrue(torch.equal(scale, scale_ref))
353-
self.assertTrue(torch.equal(zero_point, zp_ref))
362+
self.assertTrue(torch.equal(scale_squeezed, scale_ref))
363+
self.assertTrue(torch.equal(zp_squeezed, zp_ref))
354364

355365
def test_quantize_activation_per_token_abs_max(self):
356366
input = torch.randn(10, 10)

torchao/quantization/quant_primitives.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1217,7 +1217,7 @@ def choose_qparams_affine(
12171217
eps: Optional[float] = None,
12181218
scale_dtype: Optional[torch.dtype] = None,
12191219
zero_point_dtype: Optional[torch.dtype] = torch.int32,
1220-
keepdim: bool = False,
1220+
keepdim: bool = True,
12211221
) -> Tuple[torch.Tensor, torch.Tensor]:
12221222
"""
12231223
Args:
@@ -1231,6 +1231,7 @@ def choose_qparams_affine(
12311231
eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
12321232
scale_dtype (torch.dtype): dtype for scale Tensor
12331233
zero_point_dtype (torch.dtype): dtype for zero_point Tensor, defaults to torch.int32
1234+
keepdim (bool): whether to keep dimensions with size 1 in output (aligned with _choose_scale_float8)
12341235
Now removed params:
12351236
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, defaults to Integer or None
12361237
preserve_zero (bool): whether to preserve zero in the quantized Tensor, defaults to True
@@ -1523,7 +1524,7 @@ def _choose_qparams_affine(
15231524
eps: Optional[float] = None,
15241525
scale_dtype: Optional[torch.dtype] = None,
15251526
zero_point_dtype: Optional[torch.dtype] = None,
1526-
keepdim: bool = False,
1527+
keepdim: bool = True,
15271528
) -> Tuple[torch.Tensor, torch.Tensor]:
15281529
"""op definition that has compatible signatures with custom op library
15291530
@@ -1532,6 +1533,10 @@ def _choose_qparams_affine(
15321533
2. find min_val/max_val based on the dimension for reduction
15331534
3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
15341535
and `zero_point_domain`
1536+
1537+
Note:
1538+
keepdim defaults to True to align with _choose_scale_float8 behavior. This ensures
1539+
scale/zero_point maintain the same rank as input, making it easier to handle downstream.
15351540
"""
15361541
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
15371542
assert mapping_type in [
@@ -1548,6 +1553,8 @@ def _choose_qparams_affine(
15481553
assert len(block_size) == input.dim(), (
15491554
f"Got input dim:{input.dim()}, block_size: {block_size}"
15501555
)
1556+
# Save original input size before reshaping for later use
1557+
original_input_size = input.size()
15511558
shape_for_reduction, reduction_dims = _get_reduction_params(
15521559
block_size, input.size()
15531560
)
@@ -1591,6 +1598,15 @@ def _choose_qparams_affine(
15911598
if zero_point_dtype is None:
15921599
zero_point_dtype = torch.int32
15931600

1601+
# Reshape scale and zero_point to match expected output shape
1602+
# This aligns with _choose_scale_float8 behavior
1603+
if keepdim:
1604+
output_shape = [
1605+
original_input_size[i] // block_size[i] for i in range(len(block_size))
1606+
]
1607+
scale = scale.reshape(output_shape)
1608+
zero_point = zero_point.reshape(output_shape)
1609+
15941610
return scale.to(dtype=scale_dtype, device=input.device), zero_point.to(
15951611
dtype=zero_point_dtype
15961612
)

torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -244,14 +244,8 @@ def from_hp(
244244
f"Unsupported IntxChooseQParamsAlgorithm: {intx_choose_qparams_algorithm}"
245245
)
246246

247-
# Reshape scale and zero_point to be compatible with block_size
248-
# This is asserted in IntxUnpackedToInt8Tensor's __init__
249-
n_blocks = []
250-
for i in range(len(block_size)):
251-
assert qdata.shape[i] % block_size[i] == 0
252-
n_blocks.append(qdata.shape[i] // block_size[i])
253-
scale = scale.reshape(*n_blocks)
254-
zero_point = zero_point.reshape(*n_blocks)
247+
# Note: scale and zero_point already have the correct shape from choose_qparams_affine
248+
# which now uses keepdim=True and reshapes to match block_size expectations
255249

256250
return IntxUnpackedToInt8Tensor(
257251
qdata=qdata,

0 commit comments

Comments
 (0)