Expose zero_point_domain as arguments #1401

Merged · 10 commits · Dec 17, 2024
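This PR exposes `zero_point_domain` as an argument on `int4_weight_only` and on the groupwise affine quantize/dequantize helpers, so callers can pick a float- or integer-domain zero point where the chosen layout supports it. A minimal usage sketch based on the diff below (`quantize_` is torchao's existing public entry point, not part of this change):

```python
import torch
from torchao.quantization import int4_weight_only, quantize_
from torchao.quantization.quant_primitives import ZeroPointDomain

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)
# Explicitly request the float zero-point domain; per the support matrix
# added in quant_api.py below, this is the default (and only) choice for
# the default TensorCoreTiledLayout.
quantize_(model, int4_weight_only(group_size=128, zero_point_domain=ZeroPointDomain.FLOAT))
```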
Changes from 7 commits
10 changes: 7 additions & 3 deletions test/dtypes/test_affine_quantized.py
@@ -16,15 +16,15 @@
int8_dynamic_activation_int8_weight,
int8_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_sm_at_least_89,
)


def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False):
base_functions = [
int8_weight_only(),
int8_dynamic_activation_int4_weight(),
@@ -36,6 +36,10 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cu
base_functions.append(
int4_weight_only(group_size=32, layout=Int4CPULayout())
)
if int4_zp_int:
base_functions.append(
int4_weight_only(group_size=32, layout=Int4CPULayout(), zero_point_domain=ZeroPointDomain.INT)
)
else:
base_functions.append(int4_weight_only(group_size=32))

@@ -71,7 +75,7 @@ def test_tensor_core_layout_transpose(self):
self.assertEqual(aqt_shape, shape)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
@common_utils.parametrize("apply_quant", get_quantization_functions(True, True, "cuda", True))
def test_weights_only(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
149 changes: 93 additions & 56 deletions test/quantization/test_quant_primitives.py
@@ -57,7 +57,7 @@ def check_idempotent(self, fn, *args, **kwargs):


# Legacy tinygemm ops
def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16, zero_point_domain=ZeroPointDomain.FLOAT):
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
@@ -70,11 +70,18 @@ def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat1
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
quant_min = 0
quant_max = max_int
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
dtype=dtype
).reshape(w.shape[0], -1)
if (zero_point_domain == ZeroPointDomain.FLOAT):
zeros = min_val + scales * (2 ** (n_bit - 1))
zeros = zeros.to(dtype=dtype).reshape(w.shape[0], -1)
else:
zeros = quant_min - torch.round(min_val / scales)
zeros = torch.clamp(zeros, quant_min, quant_max)
zeros = zeros.to(dtype=dtype).reshape(w.shape[0], -1)
scales = scales.to(dtype=dtype).reshape(w.shape[0], -1)
return scales, zeros


def _groupwise_affine_quantize_tensor_from_qparams(
@@ -83,8 +90,10 @@ def _groupwise_affine_quantize_tensor_from_qparams(
zeros,
n_bit=4,
groupsize=128,
zero_point_domain=ZeroPointDomain.FLOAT
):
assert groupsize > 1
assert n_bit == 4
# needed for GPTQ single column quantize
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
groupsize = w.shape[-1]
@@ -97,17 +106,28 @@

scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
min_val = zeros - scales * (2 ** (n_bit - 1))
max_int = 2**n_bit - 1
min_int = 0
w_int4x8 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)
if zero_point_domain == ZeroPointDomain.FLOAT:
min_val = zeros - scales * (2 ** (n_bit - 1))
w_int4x8 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)
else:
w_int4x8 = (
to_quant.div(scales)
.round()
.add(zeros)
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)

if TORCH_VERSION_AT_LEAST_2_5:
if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
@@ -121,6 +141,7 @@ def _groupwise_affine_dequantize_tensor_from_qparams(
zeros,
n_bit=4,
groupsize=128,
zero_point_domain=ZeroPointDomain.FLOAT
):
assert groupsize > 1
# needed for GPTQ single column dequantize
@@ -133,12 +154,19 @@ def _groupwise_affine_dequantize_tensor_from_qparams(
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)

w_dq = (
w_int4x8_grouped.sub(2 ** (n_bit - 1))
.mul(scales)
.add(zeros)
.reshape_as(w_int4x8)
)
if zero_point_domain == ZeroPointDomain.FLOAT:
w_dq = (
w_int4x8_grouped.sub(2 ** (n_bit - 1))
.mul(scales)
.add(zeros)
.reshape_as(w_int4x8)
)
else:
w_dq = (
w_int4x8_grouped.sub(zeros)
.mul(scales)
.reshape_as(w_int4x8)
)
return w_dq


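The two branches above implement different zero-point conventions: in the FLOAT domain the dequantized value is `(q - 2**(n_bit - 1)) * scale + zero`, with `zero` living in the dequantized (float) space, while in the INT domain it is `(q - zero) * scale`, with `zero` an integer code. A tiny numeric sketch with hypothetical values:

```python
import torch

n_bit = 4
scale = torch.tensor(0.1)
q = torch.tensor([3, 8, 12])  # uint4 codes

# FLOAT domain: codes are centered around the mid value 2**(n_bit - 1) = 8,
# then scaled, then shifted by a float zero point.
zero_float = torch.tensor(0.5)
w_float = (q - 2 ** (n_bit - 1)) * scale + zero_float  # tensor([0.0, 0.5, 0.9])

# INT domain: the zero point is itself a code, subtracted before scaling.
zero_int = torch.tensor(8)
w_int = (q - zero_int) * scale  # tensor([-0.5, 0.0, 0.4])
```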
@@ -650,10 +678,8 @@ def test_not_preserve_zero_not_supported(self):
def test_get_groupwise_affine_qparams(self):
input = torch.randn(10, 256)
n_bit = 4
scale_ref, zero_point_ref = _get_groupwise_affine_qparams(
input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16
)

zero_point_domains = [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (1, 128)
@@ -662,19 +688,24 @@
eps = 1e-6
scale_dtype = torch.bfloat16
zero_point_dtype = torch.bfloat16
scale, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=False,
zero_point_domain=ZeroPointDomain.FLOAT,
)
for zero_point_domain in zero_point_domains:
scale_ref, zero_point_ref = _get_groupwise_affine_qparams(
input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16,
zero_point_domain=zero_point_domain
)
scale, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=zero_point_domain == ZeroPointDomain.INT,
zero_point_domain=zero_point_domain,
)

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zero_point_ref))
@@ -686,14 +717,15 @@ def test_groupwise_affine_quantize_tensor_from_qparams(self):
n_bit = 4
groupsize = 128

w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)
w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)
for zero_point_domain in [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]:
w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize, zero_point_domain
)
w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize, zero_point_domain
)

self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))
self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))

def test_groupwise_affine_dequantize_tensor_from_qparams(self):
input = torch.randint(0, 15, (10, 256), dtype=torch.int32)
@@ -702,20 +734,25 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
n_bit = 4
groupsize = 128

if TORCH_VERSION_AT_LEAST_2_5:
input_tmp = input
if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize
)
else:
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
for zero_point_domain in [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]:
if zero_point_domain == ZeroPointDomain.INT:
zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
if TORCH_VERSION_AT_LEAST_2_5:
input_tmp = input
if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
)
else:
if zero_point_domain == ZeroPointDomain.INT:
continue
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize, zero_point_domain
)
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)

self.assertTrue(torch.equal(w_bf16, w_bf16_ref))

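Note how the test regenerates `zeros` when switching to the INT domain: float-domain zeros are bfloat16 values in the dequantized space, while INT-domain zeros are integer codes inside the quantized range. A short sketch of the distinction (hypothetical shapes):

```python
import torch

n_bit, n_groups = 4, 2
# FLOAT domain: zeros are arbitrary values in the dequantized space.
zeros_float = torch.randn(10, n_groups, dtype=torch.bfloat16)
# INT domain: zeros are codes in [0, 2**n_bit - 1].
zeros_int = torch.randint(0, 2**n_bit - 1, (10, n_groups), dtype=torch.int32)
```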
29 changes: 24 additions & 5 deletions torchao/quantization/quant_api.py
@@ -33,6 +33,7 @@
PlainLayout,
SemiSparseLayout,
TensorCoreTiledLayout,
Int4CPULayout,
UintxLayout,
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
@@ -110,6 +111,18 @@
"Int8DynActInt4WeightGPTQQuantizer",
]

# update according to the support matrix
LAYOUT_TO_ZERO_POINT_DOMAIN = {
TensorCoreTiledLayout: [ZeroPointDomain.FLOAT],
MarlinSparseLayout: [ZeroPointDomain.INT],
Int4CPULayout: [ZeroPointDomain.FLOAT]
}

LAYOUT_TO_PRESERVE_ZEROS = {
TensorCoreTiledLayout: False,
MarlinSparseLayout: True,
Int4CPULayout: False
}

######
# TO BE DEPRECATED START
@@ -630,7 +643,8 @@ def int8_dynamic_activation_int4_weight(


def int4_weight_only(
group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False
group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False,
zero_point_domain=None
):
"""
Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
Review comment (Contributor):
nit: please update docs for zero_point_domain before landing

Reply (Collaborator, author):
updated in 51a4505
@@ -665,17 +679,22 @@ def apply_int4_weight_only_quant(weight):
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT

nonlocal zero_point_domain
assert type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys(), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
if zero_point_domain is None:
# the first value is the default one
zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
else:
assert zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)], f"Layout only supports {LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]}"

# Sparse Marlin only supports symmetric quantization.
# NOTE: If we start having lots of layouts that require different configurations,
# we should consider moving this logic somewhere else.
if isinstance(layout, MarlinSparseLayout):
mapping_type = MappingType.SYMMETRIC
preserve_zero = True
zero_point_domain = ZeroPointDomain.INT
assert (
group_size == 128 or group_size == weight.shape[-1]
), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"
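With the support matrix above, `zero_point_domain=None` selects each layout's default, and explicit values are validated when the config is applied. A sketch of the resulting call patterns (a MarlinSparseLayout model additionally needs 2:4-sparse weights, elided here):

```python
from torchao.dtypes import MarlinSparseLayout
from torchao.quantization import int4_weight_only
from torchao.quantization.quant_primitives import ZeroPointDomain

# None -> the first entry in LAYOUT_TO_ZERO_POINT_DOMAIN for the layout,
# i.e. ZeroPointDomain.FLOAT for the default TensorCoreTiledLayout.
default_cfg = int4_weight_only(group_size=128)

# MarlinSparseLayout only lists ZeroPointDomain.INT, so this is the one
# explicit value that passes the assert when quantize_ applies the config.
marlin_cfg = int4_weight_only(
    group_size=128,
    layout=MarlinSparseLayout(),
    zero_point_domain=ZeroPointDomain.INT,
)

# An unsupported pair (e.g. TensorCoreTiledLayout with ZeroPointDomain.INT)
# fails the assert inside apply_int4_weight_only_quant at apply time.
```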
6 changes: 4 additions & 2 deletions torchao/quantization/utils.py
@@ -378,6 +378,7 @@ def groupwise_affine_quantize_tensor_from_qparams(
zeros,
n_bit=4,
groupsize=128,
zero_point_domain=ZeroPointDomain.FLOAT
):
assert groupsize > 1
# needed for GPTQ single column quantize
@@ -400,7 +401,7 @@
output_dtype,
quant_min,
quant_max,
zero_point_domain=ZeroPointDomain.FLOAT,
zero_point_domain=zero_point_domain
)
if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1:
if not (is_device(int_data.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
@@ -414,6 +415,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
zeros,
n_bit=4,
groupsize=128,
zero_point_domain=ZeroPointDomain.FLOAT
):
assert groupsize > 1
assert w_int4x8.dim() == 2
@@ -452,7 +454,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
input_dtype,
quant_min,
quant_max,
zero_point_domain=ZeroPointDomain.FLOAT,
zero_point_domain=zero_point_domain,
output_dtype=scales.dtype,
)

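End to end, the public utils helpers now round-trip with an explicit domain. A sketch under the assumption that `get_groupwise_affine_qparams` (which produces float-domain qparams) lives in the same module:

```python
import torch
from torchao.quantization.quant_primitives import ZeroPointDomain
from torchao.quantization.utils import (
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
)

w = torch.randn(10, 256, dtype=torch.bfloat16)
scales, zeros = get_groupwise_affine_qparams(w, n_bit=4, groupsize=128)
q = groupwise_affine_quantize_tensor_from_qparams(
    w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT
)
w_dq = groupwise_affine_dequantize_tensor_from_qparams(
    q, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT
)
# Quantization error is bounded by the per-group scale (loose tolerance).
assert torch.allclose(w, w_dq, atol=scales.max().item())
```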