fix group size min max tracking by adding tensor ids (#60)
* fix group size min max tracking by adding tensor ids

* propagate change to  in base

* bug

* lint

* add back reduce_dims

* fix

* fix

* comment

---------

Co-authored-by: George Ohashi <george@neuralmagic.com>
bfineran and horheynm authored May 22, 2024
1 parent b76acf4 commit 2c64578
Showing 6 changed files with 58 additions and 32 deletions.
21 changes: 14 additions & 7 deletions src/compressed_tensors/quantization/observers/base.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.quant_args import (
@@ -93,15 +93,18 @@ def get_qparams(
         elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
             columns = observed.shape[1]
             scales, zero_points = [], []
-            for i in range(0, columns, self.quantization_args.group_size):
+            group_idxs = range(0, columns, self.quantization_args.group_size)
+            for group_id, group_idx in enumerate(group_idxs):
                 scale, zero_point = self.get_qparams_along_dim(
-                    observed[:, i : (i + group_size)],
+                    observed[:, group_idx : (group_idx + group_size)],
                     0,
+                    tensor_id=group_id,
                 )
                 scales.append(scale)
                 zero_points.append(zero_point)
-            self._scale = torch.stack(scales, dim=1, out=self._scale)
-            self._zero_point = torch.stack(zero_points, dim=1, out=self._zero_point)
+
+            self._scale = torch.cat(scales, dim=1, out=self._scale)
+            self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)
 
         elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # assume observed is transposed, because its the output, hence use dim 0
@@ -116,6 +119,10 @@
 
         return self._scale, self._zero_point
 
-    def get_qparams_along_dim(self, observed, dim: int):
+    def get_qparams_along_dim(
+        self, observed, dim: int, tensor_id: Optional[Any] = None
+    ):
         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
-        return self.calculate_qparams(observed, reduce_dims=reduce_dims)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
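
Two behavioral points in the hunk above: every column group now gets its own tensor_id, and the per-group results are concatenated (torch.cat) instead of stacked, so scales come out shaped (rows, num_groups) with no trailing singleton dimension. A minimal standalone sketch of that shape logic, with toy scale math and illustrative names rather than the library's calculate_qparams:

import torch

observed = torch.randn(4, 8)   # toy weight matrix: 4 rows, 8 columns
group_size = 4

per_group_scales = []
for group_id, start in enumerate(range(0, observed.shape[1], group_size)):
    group = observed[:, start : start + group_size]
    # reduce over the group's columns, keeping a (rows, 1) shape,
    # analogous to get_qparams_along_dim(group, 0, tensor_id=group_id)
    min_val = torch.amin(group, dim=1, keepdim=True)
    max_val = torch.amax(group, dim=1, keepdim=True)
    scale = (max_val - min_val).clamp(min=1e-8) / 255  # placeholder 8-bit range math
    per_group_scales.append(scale)

scales = torch.cat(per_group_scales, dim=1)
print(scales.shape)  # torch.Size([4, 2]): (rows, num_groups), no trailing 1

The dropped trailing dimension is also why the shape assertions in test_strategies.py further down no longer expect a final 1.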
4 changes: 3 additions & 1 deletion src/compressed_tensors/quantization/observers/memoryless.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -33,12 +33,14 @@ class MemorylessObserver(Observer):
     def calculate_qparams(
         self,
         observed: Tensor,
+        tensor_id: Optional[Any] = None,
         reduce_dims: Optional[Tuple[int]] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Returns the min and max values of observed tensor
         :param observed: observed tensor to calculate quantization parameters for
+        :param tensor_id: optional id for tensor; not used for memoryless
         :param reduce_dims: optional tuple of dimensions to reduce along,
             returned scale and zero point will be shaped (1,) along the
             reduced dimensions
42 changes: 31 additions & 11 deletions src/compressed_tensors/quantization/observers/min_max.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -36,14 +36,15 @@ def __init__(
     ):
         super().__init__(quantization_args=quantization_args)
 
-        self.min_val = None
-        self.max_val = None
+        self.min_val = {}
+        self.max_val = {}
         self.averaging_constant = averaging_constant
 
     def calculate_qparams(
         self,
         observed: Tensor,
         reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the observed min and max using a moving average smoothed by the
@@ -53,24 +54,43 @@
         :param reduce_dims: optional tuple of dimensions to reduce along,
             returned scale and zero point will be shaped (1,) along the
             reduced dimensions
+        :param tensor_id: Optional id if different ranges of observed tensors are
+            passed, useful for sharding tensors by group_size
         :return: tuple of scale and zero point derived from the observed tensor
         """
+        tensor_id = tensor_id or "default"
 
         if not reduce_dims:
             min_val, max_val = torch.aminmax(observed)
         else:
             min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
             max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
 
-        if self.min_val is None and self.max_val is None:
-            self.min_val = min_val
-            self.max_val = max_val
+        running_min_val = self.min_val.get(tensor_id, None)
+        running_max_val = self.max_val.get(tensor_id, None)
+
+        if running_min_val is None or running_max_val is None:
+            updated_min_val = min_val
+            updated_max_val = max_val
         else:
-            self.min_val = self.min_val + self.averaging_constant * (
-                min_val - self.min_val
+            updated_min_val = running_min_val + self.averaging_constant * (
+                min_val - running_min_val
             )
-            self.max_val = self.max_val + self.averaging_constant * (
-                max_val - self.max_val
+            updated_max_val = running_max_val + self.averaging_constant * (
+                max_val - running_max_val
             )
 
-        return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
+        self.min_val[tensor_id] = updated_min_val
+        self.max_val[tensor_id] = updated_max_val
+
+        return calculate_qparams(
+            updated_min_val, updated_max_val, self.quantization_args
+        )
+
+    def get_qparams_along_dim(
+        self, observed, dim: int, tensor_id: Optional[Any] = None
+    ):
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
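
The core of the fix is in the hunk above: with a single shared min_val/max_val, statistics from every tensor shard passed through the observer were blended into one moving average; keying the running values by tensor_id keeps each range separate. A minimal standalone sketch of that behavior, using a toy class with illustrative names rather than MovingAverageMinMaxObserver itself:

import torch

class ToyMovingMinMax:
    """Toy stand-in for the dict-keyed running min/max shown in the diff above."""

    def __init__(self, averaging_constant: float = 0.01):
        self.min_val = {}   # tensor_id -> running min
        self.max_val = {}   # tensor_id -> running max
        self.averaging_constant = averaging_constant

    def update(self, observed: torch.Tensor, tensor_id=None):
        tensor_id = tensor_id or "default"
        min_val, max_val = torch.aminmax(observed)
        running_min = self.min_val.get(tensor_id)
        running_max = self.max_val.get(tensor_id)
        if running_min is None or running_max is None:
            updated_min, updated_max = min_val, max_val
        else:
            # exponential moving average, as in the diff above
            updated_min = running_min + self.averaging_constant * (min_val - running_min)
            updated_max = running_max + self.averaging_constant * (max_val - running_max)
        self.min_val[tensor_id] = updated_min
        self.max_val[tensor_id] = updated_max
        return updated_min, updated_max

obs = ToyMovingMinMax()
obs.update(torch.randn(4, 4) * 10.0, tensor_id="wide")    # wide-range shard
obs.update(torch.randn(4, 4) * 0.1, tensor_id="narrow")   # narrow-range shard
# each shard keeps its own range instead of collapsing into one shared average
print(obs.max_val["wide"].item(), obs.max_val["narrow"].item())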
17 changes: 8 additions & 9 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -37,32 +37,31 @@ class QuantizationScheme(BaseModel):
     weights: Optional[QuantizationArgs] = None
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
 
     @classmethod
     def default_scheme(
         cls,
         targets: Optional[List[str]] = None,
     ):
 
         if targets is None:
             # default to quantizing all Linear layers
             targets = ["Linear"]
 
         # default to 8 bit integer symmetric quantization
         # for weights
         weights = QuantizationArgs(num_bits=8, symmetric=True)
 
         # default to 8 bit integer asymmetric quantization
         input_activations = QuantizationArgs(num_bits=8, symmetric=True)
 
         # Do not quantize the output activations
         # by default
         output_activations = None
 
         return cls(
             targets=targets,
             weights=weights,
             input_activations=input_activations,
-            output_activations=output_activations,)
-
-
+            output_activations=output_activations,
+        )
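
The quant_scheme.py hunk is a formatting-only cleanup of the closing parenthesis. For context, a hedged usage sketch of default_scheme (assuming the package is installed and this import path is still current; the targets list here is illustrative):

from compressed_tensors.quantization.quant_scheme import QuantizationScheme

# default: 8-bit symmetric weights and input activations on all Linear layers
scheme = QuantizationScheme.default_scheme()
print(scheme.weights.num_bits, scheme.output_activations)  # 8 None

# targets can also be supplied explicitly
custom = QuantizationScheme.default_scheme(targets=["Linear", "Embedding"])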
2 changes: 0 additions & 2 deletions tests/test_quantization/test_configs/test_strategies.py
@@ -88,12 +88,10 @@ def test_group(input_symmetry, weight_symmetry, model_shape, group_size):
     assert list(model.weight_scale.shape) == [
         model_shape[1],
         int(model_shape[0] / group_size),
-        1,
     ]
     assert list(model.weight_zero_point.shape) == [
         model_shape[1],
         int(model_shape[0] / group_size),
-        1,
     ]
 
 
4 changes: 2 additions & 2 deletions tests/test_quantization/test_observers/test_min_max.py
@@ -77,8 +77,8 @@ def test_min_max_observer_value_update():
     curr_min = 1
     for i, tensor in enumerate(tensors):
         observer(tensor)
-        curr_max = max(observer.max_val, curr_max)
-        curr_min = min(observer.min_val, curr_max)
+        curr_max = max(observer.max_val.get("default"), curr_max)
+        curr_min = min(observer.min_val.get("default"), curr_max)
 
         if i < 2:
             assert curr_max == 1
