
HACK: now it's 3-bits
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
andrewor14 committed May 21, 2024
1 parent afd86bd commit b50609c
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
@@ -119,7 +119,7 @@ def _set_ptq_weight(
Set the weight to the quantized version of the given fp32 weights,
for making linear outputs comparable with QAT.
"""
- n_bit = 2
+ n_bit = 3
(qmin, qmax) = self._get_qmin_qmax(n_bit)
(s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
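For context on the hunk above: with n_bit = 3, the symmetric signed range returned by _get_qmin_qmax is (-4, 3). Below is a minimal plain-PyTorch sketch of per-group symmetric quantization along those lines; the helper names here are hypothetical stand-ins, not the torchao ops used in the test.

import torch

def get_qmin_qmax(n_bit):
    # Symmetric signed range; for n_bit = 3 this is (-4, 3)
    return -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1

def quantize_per_group_symmetric(weight, n_bit, group_size):
    # One common convention: scale each group so its max |value| maps to qmax
    qmin, qmax = get_qmin_qmax(n_bit)
    grouped = weight.reshape(weight.shape[0], -1, group_size)
    scales = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
    q = torch.clamp(torch.round(grouped / scales), qmin, qmax).to(torch.int8)
    return q.reshape(weight.shape), scales.squeeze(-1)

fp32_weight = torch.randn(8, 64)
q_weight, scales = quantize_per_group_symmetric(fp32_weight, n_bit=3, group_size=32)
print(q_weight.min().item(), q_weight.max().item())  # stays within [-4, 3]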
6 changes: 3 additions & 3 deletions torchao/quantization/GPTQ.py
@@ -1017,7 +1017,7 @@ def linear_forward_8da4w(

# TODO: better API
# weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
- n_bit = 2
+ n_bit = 3
quant_min = -(2 ** (n_bit - 1))
quant_max = 2 ** (n_bit - 1) - 1
w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
@@ -1196,7 +1196,7 @@ def __init__(
self.padding_allowed: bool = padding_allowed
self.precision: torch.dtype = precision
self.scales_precision: torch.dtype = scales_precision
print("PTQ is running 2-bit weight only quant!")
print("PTQ is running 3-bit weight only quant!")

@torch.no_grad()
def _create_quantized_state_dict(
@@ -1237,7 +1237,7 @@ def _create_quantized_state_dict(
zeros,
) = group_quantize_tensor_symmetric(
weight.to(self.precision),
- 2, # n_bit
+ 3, # n_bit
self.groupsize,
self.scales_precision,
)
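The linear_forward_8da4w hunk above derives quant_min = -4 and quant_max = 3 from n_bit = 3 and then dequantizes the grouped integer weights. A minimal sketch of that dequantization step in plain PyTorch; dequantize_per_group_symmetric is a hypothetical stand-in, not the quantized_decomposed op used in the file.

import torch

def dequantize_per_group_symmetric(q_weight, scales, group_size):
    n_bit = 3
    quant_min = -(2 ** (n_bit - 1))   # -4
    quant_max = 2 ** (n_bit - 1) - 1  #  3
    assert int(q_weight.min()) >= quant_min and int(q_weight.max()) <= quant_max
    # Scale each group of integers back to float with its per-group scale
    grouped = q_weight.reshape(q_weight.shape[0], -1, group_size).to(torch.float32)
    return (grouped * scales.unsqueeze(-1)).reshape(q_weight.shape)

# Round-trips with the quantizer sketched after the test_qat.py diff:
# w_dq = dequantize_per_group_symmetric(q_weight, scales, group_size=32)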
10 changes: 5 additions & 5 deletions torchao/quantization/prototype/qat.py
@@ -109,7 +109,7 @@ def _convert_qat_linear_8da4w(
setattr(module, name, quantized_linear)

# Load weights and qparams into quantized linear
- n_bit = 2
+ n_bit = 3
(qmin, qmax) = child._get_qmin_qmax(n_bit)
(s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
@@ -159,7 +159,7 @@ def __init__(
# TODO: make this configurable?
self.zero_points_precision = torch.int32
self._fake_quant_enabled = True
print("QAT: I'm runining 2-bit weight only quant!")
print("QAT: I'm runining 3-bit weight only quant!")

def enable_fake_quant(self, enabled: bool = True):
self._fake_quant_enabled = enabled
@@ -181,14 +181,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# x_fq = x
x_fq = x

- # weights: int2 grouped per channel symmetric quant
+ # weights: int3 grouped per channel symmetric quant
if self._fake_quant_enabled:
(weight_scales, weight_zp) = get_group_qparams_symmetric(
- self.weight, 2, self.groupsize, self.scales_precision,
+ self.weight, 3, self.groupsize, self.scales_precision,
)
# TODO: pass zp dtype to `get_group_qparams_symmetric` instead
weight_zp = weight_zp.to(self.zero_points_precision)
- (weight_qmin, weight_qmax) = self._get_qmin_qmax(2)
+ (weight_qmin, weight_qmax) = self._get_qmin_qmax(3)
w_fq = fake_quantize_per_channel_group(
self.weight,
weight_scales,
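The forward hunk above fake-quantizes the weights during QAT: values are rounded onto the 3-bit grid and immediately scaled back to float, so the model trains against the quantization error while gradients still flow through the weights. A minimal sketch of that idea using a straight-through estimator; this is an illustrative stand-in, not torchao's fake_quantize_per_channel_group.

import torch

def fake_quantize_symmetric(weight, scales, qmin, qmax, group_size):
    grouped = weight.reshape(weight.shape[0], -1, group_size)
    s = scales.unsqueeze(-1)
    # Quantize, clamp to [qmin, qmax], then dequantize back to float
    w_q = torch.clamp(torch.round(grouped / s), qmin, qmax) * s
    # Straight-through estimator: forward uses w_q, backward sees identity
    w_fq = grouped + (w_q - grouped).detach()
    return w_fq.reshape(weight.shape)

# e.g. for 3-bit weights: qmin, qmax = -4, 3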
