Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Observers] group size + channel wise + per token #32

Merged
merged 25 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions src/compressed_tensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ def quantize(
q_max: torch.Tensor,
) -> torch.Tensor:
return torch.clamp(
torch.round(
x / scale + zero_point,
),
torch.round(x / scale + zero_point),
q_min,
q_max,
)
Expand All @@ -60,9 +58,20 @@ def fake_quantize(
bit_range = 2**args.num_bits
max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
min_q = torch.tensor(-bit_range / 2, device=x.device)
Q = torch.zeros_like(x)
Q = quantize(x, scale, zero_point, min_q, max_q)
return dequantize(Q, scale, zero_point)
# Q = torch.zeros_like(x)
DQ = torch.zeros_like(x)
num_groups = len(scale)
horheynm marked this conversation as resolved.
Show resolved Hide resolved
group_size = int(x.shape[1] / num_groups)
for i in range(num_groups):
horheynm marked this conversation as resolved.
Show resolved Hide resolved
sc = scale[i]
zp = zero_point[i]

idx = i * group_size
Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
breakpoint()
# Q = quantize(x, scale, zero_point, min_q, max_q)
return DQ


def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
Expand Down
33 changes: 31 additions & 2 deletions src/compressed_tensors/quantization/observers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import Optional, Tuple

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.registry.registry import RegistryMixin
from torch import FloatTensor, IntTensor, Tensor
Expand Down Expand Up @@ -64,6 +65,34 @@ def get_qparams(
:return: tuple of scale and zero point based on last observed value
"""
if observed is not None:
# re-calculate scale and zero point, update the stored value
self._scale, self._zero_point = self.calculate_qparams(observed)
group_size = self.quantization_args.group_size
if group_size is None:

# re-calculate scale and zero point, update the stored value
Satrat marked this conversation as resolved.
Show resolved Hide resolved
self._scale, self._zero_point = self.calculate_qparams(observed)
if hasattr(self, "inc"):
self.inc()
horheynm marked this conversation as resolved.
Show resolved Hide resolved

elif group_size > 0: # quantize by groups
columns = observed.shape[1]
scales, zero_points = [], []
for i in range(0, columns, self.quantization_args.group_size):
scale, zero_point = self.calculate_qparams(
observed[:, i : (i + group_size)]
)
scales.append(scale)
zero_points.append(zero_point)

if hasattr(self, "inc"):
self.inc()

self._scale = torch.cat(scales)
self._zero_point = torch.cat(zero_points)

elif group_size < 0: # channel-wise quantization
# TODO: Import channel wise logic here

if hasattr(self, "inc"):
self.inc()

return self._scale, self._zero_point
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/observers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def calculate_qparams(
if quantization_args.symmetric:
symmetric_range = 2 * max(min_vals.abs(), max_vals.abs())
scales = symmetric_range / bit_range
zero_points = torch.tensor(0).to(torch.int8)
zero_points = torch.tensor([0]).to(torch.int8)
else:
# non-symmetric
observed_range = max_vals - min_vals
Expand Down
10 changes: 8 additions & 2 deletions src/compressed_tensors/quantization/observers/min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ def __init__(self, quantization_args: QuantizationArgs):
self.max_val = -float("inf")
self.counter = 0

def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
def calculate_qparams(
self,
observed: Tensor,
) -> Tuple[FloatTensor, IntTensor]:
"""

:param observed: observed tensor to calculate quantization parameters for
:return: tuple of scale and zero point derived from the observed tensor
"""
Expand All @@ -59,5 +63,7 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))

self.counter += 1
return calculate_qparams(min_val, max_val, self.quantization_args)

def inc(self):
self.counter += 1
Loading