Feat (a2q+): improving accumulator-aware weight quantization (#797)
i-colbert authored Jan 25, 2024
1 parent 0ebbfca commit 56056ba
Showing 12 changed files with 257 additions and 99 deletions.
1 change: 1 addition & 0 deletions src/brevitas/core/scaling/__init__.py
@@ -7,6 +7,7 @@
from .int_scaling import IntScaling
from .int_scaling import PowerOfTwoIntScaling
from .pre_scaling import AccumulatorAwareParameterPreScaling
from .pre_scaling import AccumulatorAwareZeroCenterParameterPreScaling
from .pre_scaling import ParameterPreScalingWeightNorm
from .runtime import RuntimeStatsScaling
from .runtime import StatsFromParameterScaling
120 changes: 102 additions & 18 deletions src/brevitas/core/scaling/pre_scaling.py
@@ -13,10 +13,14 @@
from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.stats import SCALAR_SHAPE
from brevitas.core.stats.stats_wrapper import _Stats
from brevitas.core.zero_point import PreZeroCenterZeroPoint
from brevitas.function import abs_binary_sign_grad
from brevitas.function import get_upper_bound_on_l1_norm

__all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"]
__all__ = [
"ParameterPreScalingWeightNorm",
"AccumulatorAwareParameterPreScaling",
"AccumulatorAwareZeroCenterParameterPreScaling"]


class ParameterPreScalingWeightNorm(brevitas.jit.ScriptModule):
@@ -113,7 +117,7 @@ def _load_from_state_dict(
class AccumulatorAwareParameterPreScaling(ParameterPreScalingWeightNorm):
"""
ScriptModule implementation of learned pre-clipping scaling factor to support
accumulator-aware quantizaion (A2Q) as proposed in `A2Q: Accumulator-Aware Quantization
accumulator-aware quantization (A2Q) as proposed in `A2Q: Accumulator-Aware Quantization
with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig.
The module parameterizes the pre-clipping scaling factor (i.e., `pre_scale`) of the
@@ -150,38 +154,118 @@ class AccumulatorAwareParameterPreScaling(ParameterPreScalingWeightNorm):
"""

def __init__(
self,
scaling_impl: Module,
normalize_stats_impl: Module,
accumulator_bit_width_impl: Module,
scaling_stats_input_view_shape_impl: Module,
tracked_parameter_list: List[torch.nn.Parameter],
pre_scaling_shape: Optional[Tuple[int, ...]] = None,
restrict_pre_scaling_impl: Optional[Module] = None,
pre_scaling_min_val: Optional[float] = None,
) -> None:
self,
scaling_impl: Module,
normalize_stats_impl: Module,
accumulator_bit_width_impl: Module,
scaling_stats_input_view_shape_impl: Module,
tracked_parameter_list: List[torch.nn.Parameter],
pre_scaling_shape: Optional[Tuple[int, ...]] = None,
restrict_pre_scaling_impl: Optional[Module] = None,
pre_scaling_min_val: Optional[float] = None) -> None:
super().__init__(
scaling_impl,
normalize_stats_impl,
scaling_stats_input_view_shape_impl,
tracked_parameter_list,
pre_scaling_shape,
restrict_pre_scaling_impl,
pre_scaling_min_val,
)
pre_scaling_min_val)
self.accumulator_bit_width = accumulator_bit_width_impl

@brevitas.jit.script_method
def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Takes weights as input and returns the pre-clipping scaling factor"""
def calc_max_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
accumulator_bit_width = self.accumulator_bit_width()
upper_bound = get_upper_bound_on_l1_norm(
accumulator_bit_width, input_bit_width, input_is_signed)
return upper_bound

@brevitas.jit.script_method
def inner_forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool):
weights = self.stats_input_view_shape_impl(weights)
d_w = self.stats(weights) # denominator for weight normalization
s = self.scaling_impl(weights) # s
g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g
T = get_upper_bound_on_l1_norm(
self.accumulator_bit_width(), input_bit_width, input_is_signed) # T / s
T = self.calc_max_l1_norm(input_bit_width, input_is_signed) # T / s
g = torch.clamp_max(g / s, T)
value = d_w / g # calculating final pre-clipping scaling factor
# re-apply clamp_min_ste from restrict_scaling_impl to the specified pre_scaling_min_val
value = self.restrict_clamp_scaling.clamp_min_ste(value)
return value

@brevitas.jit.script_method
def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Takes weights, input bit-width, and input sign as input and returns the pre-clipping
scaling factor per-channel, which is $s \cdot \Vert v \Vert_1 / g$"""
value = self.inner_forward(weights, input_bit_width, input_is_signed)
return value
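
For intuition, here is a standalone sketch of the per-channel computation in `inner_forward` in plain PyTorch. The helper name, the tensor shapes, and the closed form of the bound T (taken as (2^(P-1) - 1) * 2^(signed - N), following the A2Q paper) are illustrative assumptions rather than the exact brevitas internals:

import torch

def a2q_pre_scale_sketch(
        weights,                      # (out_channels, fan_in)
        s,                            # post-clipping scale, (out_channels, 1)
        g,                            # learned norm parameter, (out_channels, 1)
        input_bit_width: float,
        accumulator_bit_width: float,
        input_is_signed: bool = True,
        min_val: float = 1e-10):
    d_w = weights.abs().sum(dim=1, keepdim=True)  # per-channel l1-norm of v
    signed = 1.0 if input_is_signed else 0.0
    # assumed closed form of the A2Q bound T on the l1-norm of the quantized weights
    T = (2.0 ** (accumulator_bit_width - 1.0) - 1.0) * 2.0 ** (signed - input_bit_width)
    g = torch.clamp_max(g / s, T)             # enforce the l1-norm constraint
    return torch.clamp_min(d_w / g, min_val)  # final pre-clipping scaling factor

pre_scale = a2q_pre_scale_sketch(
    torch.randn(8, 64), s=torch.ones(8, 1), g=torch.full((8, 1), 100.0),
    input_bit_width=8.0, accumulator_bit_width=16.0)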


class AccumulatorAwareZeroCenterParameterPreScaling(AccumulatorAwareParameterPreScaling):
"""
ScriptModule implementation of learned pre-clipping scaling factor to support
A2Q+ as proposed in `A2Q+: Improving Accumulator-Aware Weight Quantization`.
The module implements the zero-centering constraint as a pre-clipping zero-point
(i.e., `PreZeroCenterZeroPoint`) to relax the l1-norm constraint.
Args:
scaling_impl (Module): post-clipping scaling factor.
pre_zero_point_impl (Module): pre-clipping zero-point.
normalize_stats_impl (Module): calculate statistics for normalizing weight parameter.
accumulator_bit_width_impl (Module): module that returns the accumulator bit-width.
scaling_stats_input_view_shape_impl (Module): transforming scaling to a new shape.
tracked_parameter_list (List[torch.nn.Parameter]): list of tracked weight parameters
for tensor quantizer.
pre_scaling_shape (Tuple[int]): shape of pre-clipping scaling factor. Default: None.
restrict_pre_scaling_impl (Module): restrict pre_scaling_init according to some
criteria. Default: None.
pre_scaling_min_val (float): force a lower-bound on scaling_init. Default: None.
Returns:
Tensor: scaling factor wrapped in a float torch.Tensor.
"""

def __init__(
self,
scaling_impl: Module,
pre_zero_point_impl: Module,
normalize_stats_impl: Module,
accumulator_bit_width_impl: Module,
scaling_stats_input_view_shape_impl: Module,
tracked_parameter_list: List[Parameter],
pre_scaling_shape: Optional[Tuple[int, ...]] = None,
restrict_pre_scaling_impl: Optional[Module] = None,
pre_scaling_min_val: Optional[float] = None) -> None:
super().__init__(
scaling_impl,
normalize_stats_impl,
accumulator_bit_width_impl,
scaling_stats_input_view_shape_impl,
tracked_parameter_list,
pre_scaling_shape,
restrict_pre_scaling_impl,
pre_scaling_min_val)
assert isinstance(
pre_zero_point_impl, PreZeroCenterZeroPoint
), "Error: A2Q+ requires a pre-clipping zero-centering zero-point."
self.pre_zero_point = pre_zero_point_impl

@brevitas.jit.script_method
def calc_max_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
""" """
assert input_bit_width is not None, "A2Q+ relies on input bit-width."
max_accumulator_bit_width = self.accumulator_bit_width() # P
max_accumulator_mag = pow(2.0, max_accumulator_bit_width) - 2.0 # 2^P - 2
max_input_mag = pow(2.0, input_bit_width) - 1.0 # 2^N - 1
return max_accumulator_mag / max_input_mag

@brevitas.jit.script_method
def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Takes weights, input bit-width, and input sign as input and returns the pre-clipping
scaling factor per-channel, which is $s \cdot \Vert v - \mu_v \Vert_1 / g$"""
# NOTE: A2Q+ requires zero-centering the floating-point weights, which means that the
# calculation of the l1-norm needs to be done over the zero-centered weights.
z = self.pre_zero_point.get_zero_center(weights)
value = self.inner_forward(weights + z, input_bit_width, input_is_signed)
return value
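
The payoff of zero-centering is this looser bound: once each channel's weights sum to zero, an arbitrary constant can be subtracted from the inputs before bounding the dot product, which halves the worst-case input magnitude and removes the dependence on the input sign. A quick numeric comparison for P = 16 and N = 8, using the closed forms from the A2Q and A2Q+ papers (illustrative arithmetic, not brevitas output):

P, N = 16.0, 8.0  # accumulator and input bit-widths

# A2Q (assumed closed form): T = (2^(P-1) - 1) * 2^(signed - N)
t_a2q_signed = (2 ** (P - 1) - 1) * 2 ** (1 - N)    # ~255.99
t_a2q_unsigned = (2 ** (P - 1) - 1) * 2 ** (0 - N)  # ~128.00

# A2Q+ (calc_max_l1_norm above): T = (2^P - 2) / (2^N - 1), sign-independent
t_a2q_plus = (2 ** P - 2) / (2 ** N - 1)            # = 257.0

The relaxation is most pronounced for unsigned inputs, where the per-channel l1-norm budget roughly doubles.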
33 changes: 32 additions & 1 deletion src/brevitas/core/zero_point.py
@@ -24,7 +24,8 @@
'StatsFromParameterZeroPoint',
'ParameterFromRuntimeZeroPoint',
'ParameterZeroPoint',
'ParameterFromStatsFromParameterZeroPoint']
'ParameterFromStatsFromParameterZeroPoint',
'PreZeroCenterZeroPoint']


class ZeroZeroPoint(brevitas.jit.ScriptModule):
@@ -294,3 +295,33 @@ def _load_from_state_dict(
self.init_done = True
if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
missing_keys.remove(value_key)


class PreZeroCenterZeroPoint(brevitas.jit.ScriptModule):
"""Experimental ScriptModule implementation of a pre-scaling zero-point that zero-centers
the incoming tensors. This is intended to be used with `DecoupledIntQuant`."""

def __init__(
self,
stats_reduce_dim: int,
pre_zero_point_stats_input_view_shape_impl: Module,
pre_zero_point_shape: Optional[Tuple[int, ...]] = None) -> None:
super(PreZeroCenterZeroPoint, self).__init__()
self.stats_reduce_dim = stats_reduce_dim
self.stats_output_shape = pre_zero_point_shape
self.stats_input_view_shape_impl = pre_zero_point_stats_input_view_shape_impl

@brevitas.jit.script_method
def get_zero_center(self, x: Tensor) -> Tensor:
x = self.stats_input_view_shape_impl(x)
u = torch.mean(x, dim=self.stats_reduce_dim, keepdim=True)
z = -u.view(self.stats_output_shape)
return z

@brevitas.jit.script_method
def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
# NOTE: `DecoupledIntQuant` adds the `pre_zero_point` value to the scaled tensor,
# so this needs to return the negative of the scaled average value to perform
# pre-zero centering before rounding and clipping
z = self.get_zero_center(x) / scale # scale the zero-center offset into the integer domain
return z
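
In plain PyTorch terms, `get_zero_center` returns the negative per-channel mean, and `forward` rescales it so that `DecoupledIntQuant` can add it to the already-scaled tensor. A minimal sketch, assuming statistics are reduced over the fan-in dimension:

import torch

w = torch.randn(4, 16)            # (out_channels, fan_in)
scale = torch.full((4, 1), 0.05)  # per-channel scaling factor

z = -w.mean(dim=1, keepdim=True)  # what get_zero_center(w) computes
pre_zero_point = z / scale        # what forward(w, scale, bit_width) returns

# the tensor that is subsequently rounded and clipped is zero-centered per channel
centered = w / scale + pre_zero_point
assert torch.allclose(centered.mean(dim=1), torch.zeros(4), atol=1e-5)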
7 changes: 4 additions & 3 deletions src/brevitas/function/ops.py
@@ -205,9 +205,10 @@ def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_b

def get_upper_bound_on_l1_norm(
accumulator_bit_width: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Calculate the upper bound on the l1-norm of the weights using the derivations from
`Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
by I.Colbert, A.Pappalardo, and J.Petri-Koenig."""
"""Calculate the upper bound on the l1-norm of the weights needed to guarantee overflow avoidance
for a given accumulator bit width and input representation using the derivations from
`A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance` by I.Colbert,
A.Pappalardo, and J.Petri-Koenig. Note that this assumes integer quantization."""
assert input_bit_width is not None, "A2Q relies on input bit-width."
assert input_is_signed is not None, "A2Q relies on input sign."
assert accumulator_bit_width is not None, "A2Q relies on accumulator bit-width."
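
The guarantee in the docstring follows from a one-line bound on the dot product (the rest of the function body is collapsed in this view; the closed form below is as stated in the A2Q paper, with P the accumulator bit-width, N the input bit-width, and s = 1 for signed inputs, else 0):

$\left|\sum_{i=1}^{K} w_i x_i\right| \le \lVert w \rVert_1 \max_i |x_i| \le \lVert w \rVert_1 \cdot 2^{N-s} \quad\Longrightarrow\quad \lVert w \rVert_1 \le \frac{2^{P-1}-1}{2^{N-s}}$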
51 changes: 0 additions & 51 deletions src/brevitas/nn/utils.py
@@ -1,13 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from brevitas.function.ops_ste import ceil_ste
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector

@@ -79,51 +76,3 @@ def check_tensors_same_ptr(tensor_list):
else:
return False
return all(p == pointers[0] for p in pointers)


def calculate_min_accumulator_bit_width(
input_bit_width: Tensor,
input_is_signed: bool,
weight_max_l1_norm: Optional[Tensor] = None,
weight_bit_width: Optional[Tensor] = None,
n_elements: Optional[Tensor] = None,
min_val: Optional[float] = 1e-10):
"""Using the closed-form bounds on accumulator bit-width as derived in `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow
Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig. This function returns the minimum accumulator bit-width that can be used without risk of
overflow. It supports both the data-type bound as well as the weight-level bound.
Args:
input_bit_width (Tensor): the bit-width of the inputs to the layer.
input_is_signed (bool): whether the inputs to the layer are signed.
weight_max_l1_norm (Tensor): the maximum per-channel l1-norm of the weights.
weight_bit_width (Tensor): the bit-width of the weights to the layer.
n_elements (Tensor): the number of elements in the dot product.
min_val (float): the minimum value used for the l1-norm, used to avoid log2(0). Default: 1e-10.
Example (data-type bound):
>>> acc_bit_width = calculate_min_accumulator_bit_width(input_bit_width, input_is_signed, weight_bit_width=weight_bit_width, n_elements=n_elements)
Example (weight-level bound):
>>> acc_bit_width = calculate_min_accumulator_bit_width(input_bit_width, input_is_signed, weight_max_l1_norm=weight_max_l1_norm)
"""
input_is_signed = float(input_is_signed)
# if the l1-norm of the weights is specified, then use the weight-level bound
if weight_max_l1_norm is not None:
assert isinstance(weight_max_l1_norm, (float, Tensor)), "The l1-norm of the weights needs to be a float or a torch.Tensor instance."
if isinstance(weight_max_l1_norm, Tensor):
assert weight_max_l1_norm.numel() == 1, "The minimum accumulator bit-width calculation currently only supports scalars."
weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, min_val)
alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
# else use the data-type bound
else:
assert isinstance(weight_bit_width, (float, Tensor)), "If weight_max_l1_norm is un-specified, weight_bit_width needs to be specified."
assert isinstance(n_elements, (float, Tensor)), "If weight_max_l1_norm is un-specified, n_elements needs to be specified."
if isinstance(n_elements, Tensor):
assert n_elements.numel() == 1, "The minimum accumulator bit-width calculation currently only supports scalars."
assert n_elements > 0, "There needs to be at least one element considered in this evaluation."
alpha = torch.log2(n_elements) + input_bit_width + weight_bit_width - input_is_signed - 1.
phi = lambda x: torch.log2(1. + pow(2., -x))
min_bit_width = alpha + phi(alpha) + 1.
min_bit_width = ceil_ste(min_bit_width)
return min_bit_width # returns the minimum accumulator that can be used without risk of overflow
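
The removed helper is compact enough to reproduce; a sketch of the weight-level case in plain Python (math.ceil stands in for ceil_ste, which the original used to keep the computation differentiable; values are for illustration):

import math

def min_accumulator_bit_width_sketch(
        weight_max_l1_norm: float,
        input_bit_width: float,
        input_is_signed: bool,
        min_val: float = 1e-10) -> int:
    # weight-level bound: alpha = log2(||w||_1) + N - signed
    alpha = math.log2(max(weight_max_l1_norm, min_val)) + input_bit_width - float(input_is_signed)
    phi = math.log2(1.0 + 2.0 ** -alpha)
    return math.ceil(alpha + phi + 1.0)

print(min_accumulator_bit_width_sketch(255.0, 8.0, True))  # 16: fits a 16-bit accumulator
print(min_accumulator_bit_width_sketch(256.0, 8.0, True))  # 17: just over the ~255.99 budget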
17 changes: 17 additions & 0 deletions src/brevitas/quant/base.py
@@ -20,6 +20,7 @@
from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.restrict_val import LogFloatRestrictValue
from brevitas.core.scaling import AccumulatorAwareParameterPreScaling
from brevitas.core.scaling import AccumulatorAwareZeroCenterParameterPreScaling
from brevitas.core.scaling import IntScaling
from brevitas.core.scaling import ParameterFromStatsFromParameterScaling
from brevitas.core.scaling import ParameterPreScalingWeightNorm
@@ -38,6 +39,7 @@
from brevitas.core.utils import SingleArgStatelessBuffer
from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.core.zero_point import PreZeroCenterZeroPoint
from brevitas.core.zero_point import StatsFromParameterZeroPoint
from brevitas.core.zero_point import ZeroZeroPoint
from brevitas.inject import ExtendedInjector
@@ -76,6 +78,7 @@
'BatchQuantStatsScaling1d',
'BatchQuantStatsScaling2d',
'AccumulatorAwareWeightQuant',
'AccumulatorAwareZeroCenterWeightQuant',
'MSESymmetricScale',
'MSEAsymmetricScale',
'MSEWeightZeroPoint',
@@ -400,6 +403,20 @@ def accumulator_bit_width_impl(accumulator_bit_width):
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints


class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
"""Experimental zero-centered accumulator-aware weight quantized based on:
`A2Q+: Improving Accumulator-Aware Weight Quantization`.
When compared to A2Q, A2Q+ makes the following changes:
(1) adds a zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`)
(2) uses a more relaxed l1-norm bound, derived in the referenced paper
"""
pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
pre_zero_point_impl = PreZeroCenterZeroPoint
pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
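
A hedged usage sketch: the concrete, pre-configured quantizers live in the changed files not shown in this view, so the import path `brevitas.quant.scaled_int.Int8AccumulatorAwareZeroCenterWeightQuant` and the `weight_accumulator_bit_width` keyword below are assumptions about how this injector is exposed, extrapolated from existing brevitas conventions:

from brevitas.nn import QuantConv2d
# assumed name/location of a concrete 8-bit A2Q+ weight quantizer
from brevitas.quant.scaled_int import Int8AccumulatorAwareZeroCenterWeightQuant

conv = QuantConv2d(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    weight_quant=Int8AccumulatorAwareZeroCenterWeightQuant,
    weight_accumulator_bit_width=16)  # keyword override of accumulator_bit_width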


class MSESubInjectorBase(ExtendedInjector):

@value