HACK: change 8da4w to 2-bit weight only quant
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
andrewor14 committed May 21, 2024
1 parent b5ebc94 commit 5e497ff
Showing 4 changed files with 22 additions and 20 deletions.
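For context, the changes below swap the hard-coded bit width in the 8da4w (int8 dynamic activation, int4 weight) path from 4 to 2 and disable the per-token activation quantization, turning it into a 2-bit weight-only scheme. The integer range implied by a bit width follows the same expressions used in the diff (quant_min = -(2 ** (n_bit - 1)), quant_max = 2 ** (n_bit - 1) - 1). A minimal sketch of that arithmetic, with a hypothetical helper name:

def qrange(n_bit: int) -> tuple[int, int]:
    # Symmetric signed integer range for an n_bit grid,
    # matching the quant_min/quant_max expressions in GPTQ.py below.
    return -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1

print(qrange(4))  # (-8, 7): the old 4-bit weight range
print(qrange(2))  # (-2, 1): the new 2-bit weight-only range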
test/quantization/test_qat.py (2 changes: 1 addition & 1 deletion)
@@ -119,7 +119,7 @@ def _set_ptq_weight(
Set the weight to the quantized version of the given fp32 weights,
for making linear outputs comparable with QAT.
"""
- n_bit = 4
+ n_bit = 2
(qmin, qmax) = self._get_qmin_qmax(n_bit)
(s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
torchao/quantization/GPTQ.py (7 changes: 4 additions & 3 deletions)
@@ -1010,14 +1010,14 @@ def linear_forward_8da4w(
groupsize,
precision,
):
- x = per_token_dynamic_quant(x)
+ #x = per_token_dynamic_quant(x)
# TODO: verify and remove following reshape code
# origin_x_size = x.size()
# x = x.reshape(-1, origin_x_size[-1])

# TODO: better API
# weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
- n_bit = 4
+ n_bit = 2
quant_min = -(2 ** (n_bit - 1))
quant_max = 2 ** (n_bit - 1) - 1
w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
@@ -1196,6 +1196,7 @@ def __init__(
self.padding_allowed: bool = padding_allowed
self.precision: torch.dtype = precision
self.scales_precision: torch.dtype = scales_precision
+ print("PTQ is running 2-bit weight only quant!")

@torch.no_grad()
def _create_quantized_state_dict(
@@ -1236,7 +1237,7 @@ def _create_quantized_state_dict(
zeros,
) = group_quantize_tensor_symmetric(
weight.to(self.precision),
- 4, # n_bit
+ 2, # n_bit
self.groupsize,
self.scales_precision,
)
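To make the effect on the PTQ path above concrete: with per_token_dynamic_quant commented out, linear_forward_8da4w effectively dequantizes the 2-bit, per-channel-group weight and runs a regular float linear. The sketch below is an illustrative plain-PyTorch equivalent under that assumption; the helper names and layouts are hypothetical and do not mirror the torch.ops.quantized_decomposed signatures.

import torch
import torch.nn.functional as F

def dequant_per_channel_group(w_int: torch.Tensor, scales: torch.Tensor,
                              group_size: int) -> torch.Tensor:
    # w_int: [out_features, in_features] integers on the 2-bit grid {-2, ..., 1}
    # scales: [out_features, in_features // group_size], one scale per group
    out_f, in_f = w_int.shape
    w = w_int.to(scales.dtype).reshape(out_f, in_f // group_size, group_size)
    w = w * scales.unsqueeze(-1)  # symmetric quant: zero point is 0
    return w.reshape(out_f, in_f)

def linear_forward_weight_only(x, w_int, scales, group_size):
    # x stays in floating point: the activation quant is disabled by this hack,
    # so only the weight is dequantized before the matmul.
    w_dq = dequant_per_channel_group(w_int, scales, group_size)
    return F.linear(x, w_dq)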
torchao/quantization/prototype/qat.py (32 changes: 17 additions & 15 deletions)
@@ -109,7 +109,7 @@ def _convert_qat_linear_8da4w(
setattr(module, name, quantized_linear)

# Load weights and qparams into quantized linear
- n_bit = 4
+ n_bit = 2
(qmin, qmax) = child._get_qmin_qmax(n_bit)
(s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
@@ -159,6 +159,7 @@ def __init__(
# TODO: make this configurable?
self.zero_points_precision = torch.int32
self._fake_quant_enabled = True
+ print("QAT: I'm running 2-bit weight only quant!")

def enable_fake_quant(self, enabled: bool = True):
self._fake_quant_enabled = enabled
@@ -168,25 +169,26 @@ def disable_fake_quant(self):

def forward(self, x: torch.Tensor) -> torch.Tensor:
# activations: int8 dynamic asymmetric quant
- if self._fake_quant_enabled:
-     (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
-         x, self.scales_precision, self.zero_points_precision,
-     )
-     (act_qmin, act_qmax) = self._get_qmin_qmax(8)
-     x_fq = fake_quantize_per_token(
-         x, act_scales, act_zp, act_qmin, act_qmax,
-     )
- else:
-     x_fq = x
-
- # weights: int4 grouped per channel symmetric quant
+ #if self._fake_quant_enabled:
+ #    (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
+ #        x, self.scales_precision, self.zero_points_precision,
+ #    )
+ #    (act_qmin, act_qmax) = self._get_qmin_qmax(8)
+ #    x_fq = fake_quantize_per_token(
+ #        x, act_scales, act_zp, act_qmin, act_qmax,
+ #    )
+ #else:
+ #    x_fq = x
+ x_fq = x
+
+ # weights: int2 grouped per channel symmetric quant
if self._fake_quant_enabled:
(weight_scales, weight_zp) = get_group_qparams_symmetric(
- self.weight, 4, self.groupsize, self.scales_precision,
+ self.weight, 2, self.groupsize, self.scales_precision,
)
# TODO: pass zp dtype to `get_group_qparams_symmetric` instead
weight_zp = weight_zp.to(self.zero_points_precision)
- (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
+ (weight_qmin, weight_qmax) = self._get_qmin_qmax(2)
w_fq = fake_quantize_per_channel_group(
self.weight,
weight_scales,
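For the QAT module above, the weight branch still fake-quantizes during training, now on a 2-bit grid. Fake quantization rounds the fp32 weight to the integer grid and immediately dequantizes it, so the forward sees quantization error while the stored weight stays in float. The snippet below is an illustrative stand-in, not the torchao fake_quantize_per_channel_group implementation; the scale convention (max-abs over each group divided by qmax) is an assumption.

import torch

def fake_quant_per_channel_group_sym(w: torch.Tensor, group_size: int,
                                     n_bit: int = 2) -> torch.Tensor:
    # Quantize-dequantize the weight on a symmetric n_bit grid with one scale
    # per (output channel, group), so training sees the rounding error in float.
    qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1
    out_f, in_f = w.shape
    wg = w.reshape(out_f, in_f // group_size, group_size)
    scales = wg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
    w_q = torch.clamp(torch.round(wg / scales), qmin, qmax)
    return (w_q * scales).reshape(out_f, in_f)

In a full QAT forward the round would be wrapped in a straight-through estimator so gradients flow to the fp32 weight unchanged.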
torchao/quantization/quant_primitives.py (1 change: 0 additions & 1 deletion)
@@ -795,7 +795,6 @@ def group_quantize_tensor_symmetric(
precision=torch.float32,
):
scales, zeros = get_group_qparams_symmetric(w, n_bit, group_size, precision)
- n_bit = 4
max_int = 2 ** (n_bit - 1) - 1
min_int = -(2 ** (n_bit - 1))
# TODO: currently we don't know how to express torch.int4, we'll
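The single deletion above matters because the removed n_bit = 4 was overriding the n_bit argument after the scales were computed; with it gone, group_quantize_tensor_symmetric honors the bit width the caller passes (2 in this hack). A rough sketch of what the function computes, under the same assumptions as the earlier snippets and not the actual torchao implementation:

import torch

def group_quantize_symmetric_sketch(w: torch.Tensor, n_bit: int, group_size: int):
    # Per-group symmetric quantization: one scale per group, zero point fixed at 0.
    # The real function also returns zeros and takes a scales precision argument.
    max_int = 2 ** (n_bit - 1) - 1
    min_int = -(2 ** (n_bit - 1))
    out_f, in_f = w.shape
    wg = w.reshape(out_f, in_f // group_size, group_size)
    scales = wg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / max_int
    w_int = torch.clamp(torch.round(wg / scales), min_int, max_int).to(torch.int8)
    return w_int.reshape(out_f, in_f), scales.squeeze(-1)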
