Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat (a2q+): improving accumulator-aware weight quantization #797

Merged
merged 16 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/brevitas/core/scaling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .int_scaling import IntScaling
from .int_scaling import PowerOfTwoIntScaling
from .pre_scaling import AccumulatorAwareParameterPreScaling
from .pre_scaling import AccumulatorAwareZeroCenterParameterPreScaling
from .pre_scaling import ParameterPreScalingWeightNorm
from .runtime import RuntimeStatsScaling
from .runtime import StatsFromParameterScaling
Expand Down
120 changes: 102 additions & 18 deletions src/brevitas/core/scaling/pre_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.stats import SCALAR_SHAPE
from brevitas.core.stats.stats_wrapper import _Stats
from brevitas.core.zero_point import PreZeroCenterZeroPoint
from brevitas.function import abs_binary_sign_grad
from brevitas.function import get_upper_bound_on_l1_norm

__all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"]
__all__ = [
"ParameterPreScalingWeightNorm",
"AccumulatorAwareParameterPreScaling",
"AccumulatorAwareZeroCenterParameterPreScaling"]


class ParameterPreScalingWeightNorm(brevitas.jit.ScriptModule):
Expand Down Expand Up @@ -113,7 +117,7 @@ def _load_from_state_dict(
class AccumulatorAwareParameterPreScaling(ParameterPreScalingWeightNorm):
"""
ScriptModule implementation of learned pre-clipping scaling factor to support
accumulator-aware quantizaion (A2Q) as proposed in `A2Q: Accumulator-Aware Quantization
accumulator-aware quantization (A2Q) as proposed in `A2Q: Accumulator-Aware Quantization
with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig.
The module parameterizes the pre-clipping scaling factor (i.e., `pre_scale`) of the
Expand Down Expand Up @@ -150,38 +154,118 @@ class AccumulatorAwareParameterPreScaling(ParameterPreScalingWeightNorm):
"""

def __init__(
self,
scaling_impl: Module,
normalize_stats_impl: Module,
accumulator_bit_width_impl: Module,
scaling_stats_input_view_shape_impl: Module,
tracked_parameter_list: List[torch.nn.Parameter],
pre_scaling_shape: Optional[Tuple[int, ...]] = None,
restrict_pre_scaling_impl: Optional[Module] = None,
pre_scaling_min_val: Optional[float] = None,
) -> None:
self,
scaling_impl: Module,
normalize_stats_impl: Module,
accumulator_bit_width_impl: Module,
scaling_stats_input_view_shape_impl: Module,
tracked_parameter_list: List[torch.nn.Parameter],
pre_scaling_shape: Optional[Tuple[int, ...]] = None,
restrict_pre_scaling_impl: Optional[Module] = None,
pre_scaling_min_val: Optional[float] = None) -> None:
super().__init__(
scaling_impl,
normalize_stats_impl,
scaling_stats_input_view_shape_impl,
tracked_parameter_list,
pre_scaling_shape,
restrict_pre_scaling_impl,
pre_scaling_min_val,
)
pre_scaling_min_val)
self.accumulator_bit_width = accumulator_bit_width_impl

@brevitas.jit.script_method
def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Takes weights as input and returns the pre-clipping scaling factor"""
def calc_max_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
accumulator_bit_width = self.accumulator_bit_width()
upper_bound = get_upper_bound_on_l1_norm(
accumulator_bit_width, input_bit_width, input_is_signed)
return upper_bound

@brevitas.jit.script_method
def inner_forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool):
weights = self.stats_input_view_shape_impl(weights)
d_w = self.stats(weights) # denominator for weight normalization
s = self.scaling_impl(weights) # s
g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g
T = get_upper_bound_on_l1_norm(
self.accumulator_bit_width(), input_bit_width, input_is_signed) # T / s
T = self.calc_max_l1_norm(input_bit_width, input_is_signed) # T / s
g = torch.clamp_max(g / s, T)
value = d_w / g # calculating final pre-clipping scaling factor
# re-apply clamp_min_ste from restrict_scaling_impl to the specified pre_scaling_min_val
value = self.restrict_clamp_scaling.clamp_min_ste(value)
return value

@brevitas.jit.script_method
def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Takes weights, input bit-width, and input sign as input and returns the pre-clipping
scaling factor per-channel, which is $s \cdot \Vert v \Vert_1 / g$"""
value = self.inner_forward(weights, input_bit_width, input_is_signed)
return value


class AccumulatorAwareZeroCenterParameterPreScaling(AccumulatorAwareParameterPreScaling):
"""
ScriptModule implementation of learned pre-clipping scaling factor to support
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved
A2Q+ as proposed in `A2Q+: Improving Accumulator-Aware Weight Quantization`.
The module implements the zero-centering constraint as a pre-clipping zero-point
(i.e., `PreZeroCenterZeroPoint`) to relax the l1-norm constraint.
Args:
scaling_impl (Module): post-clipping scaling factor.
pre_zero_point_impl (Module): pre-clipping zero-point.
normalize_stats_impl (Module): calculate statistics for normalizing weight parameter.
accumulator_bit_width_impl (Module): module that returns the accumulator bit-width.
scaling_stats_input_view_shape_impl (Module): transforming scaling to a new shape.
tracked_parameter_list (List[torch.nn.Parameter]): list of tracked weight parameters
for tensor quantizer.
pre_scaling_shape (Tuple[int]): shape of pre-clipping scaling factor. Default: None.
restrict_pre_scaling_impl (Module): restrict pre_scaling_init according to some
criteria. Default: None.
pre_scaling_min_val (float): force a lower-bound on scaling_init. Default: None.
Returns:
Tensor: scaling factor wrapped in a float torch.Tensor.
"""

def __init__(
self,
scaling_impl: Module,
pre_zero_point_impl: Module,
normalize_stats_impl: Module,
accumulator_bit_width_impl: Module,
scaling_stats_input_view_shape_impl: Module,
tracked_parameter_list: List[Parameter],
pre_scaling_shape: Optional[Tuple[int, ...]] = None,
restrict_pre_scaling_impl: Optional[Module] = None,
pre_scaling_min_val: Optional[float] = None) -> None:
super().__init__(
scaling_impl,
normalize_stats_impl,
accumulator_bit_width_impl,
scaling_stats_input_view_shape_impl,
tracked_parameter_list,
pre_scaling_shape,
restrict_pre_scaling_impl,
pre_scaling_min_val)
assert isinstance(
pre_zero_point_impl, PreZeroCenterZeroPoint
), "Error: A2Q+ requires a pre-clipping zero-centering zero-point."
self.pre_zero_point = pre_zero_point_impl

@brevitas.jit.script_method
def calc_max_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
""" """
assert input_bit_width is not None, "A2Q+ relies on input bit-width."
max_accumulator_bit_width = self.accumulator_bit_width() # P
max_accumulator_mag = pow(2.0, max_accumulator_bit_width) - 2.0 # 2^P - 2
max_input_mag = pow(2.0, input_bit_width) - 1.0 # 2^N - 1
return max_accumulator_mag / max_input_mag

@brevitas.jit.script_method
def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Takes weights, input bit-width, and input sign as input and returns the pre-clipping
scaling factor per-channel, which is $s \cdot \Vert v - \mu_v \Vert_1 / g$"""
# NOTE: A2Q+ requires zero-centering the floating-point weights, which means that the
# calculation of the l1-norm needs to be done over the zero-centered weights.
z = self.pre_zero_point.get_zero_center(weights)
value = self.inner_forward(weights + z, input_bit_width, input_is_signed)
return value
33 changes: 32 additions & 1 deletion src/brevitas/core/zero_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
'StatsFromParameterZeroPoint',
'ParameterFromRuntimeZeroPoint',
'ParameterZeroPoint',
'ParameterFromStatsFromParameterZeroPoint']
'ParameterFromStatsFromParameterZeroPoint',
'PreZeroCenterZeroPoint']


class ZeroZeroPoint(brevitas.jit.ScriptModule):
Expand Down Expand Up @@ -294,3 +295,33 @@ def _load_from_state_dict(
self.init_done = True
if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
missing_keys.remove(value_key)


class PreZeroCenterZeroPoint(brevitas.jit.ScriptModule):
"""Experimental ScriptModule implementation of a pre-scaling zero-point that zero-centers
the incoming tensors. This is intended to be used with `DecoupledIntQuant`."""

def __init__(
self,
stats_reduce_dim: int,
pre_zero_point_stats_input_view_shape_impl: Module,
pre_zero_point_shape: Optional[Tuple[int, ...]] = None) -> None:
super(PreZeroCenterZeroPoint, self).__init__()
self.stats_reduce_dim = stats_reduce_dim
self.stats_output_shape = pre_zero_point_shape
self.stats_input_view_shape_impl = pre_zero_point_stats_input_view_shape_impl

@brevitas.jit.script_method
def get_zero_center(self, x: Tensor) -> Tensor:
x = self.stats_input_view_shape_impl(x)
u = torch.mean(x, dim=self.stats_reduce_dim, keepdim=True)
z = -u.view(self.stats_output_shape)
return z

@brevitas.jit.script_method
def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
# NOTE: `DecoupledIntQuant` adds the `pre_zero_point` value to the scaled tensor,
# so this needs to return the negative of the scaled average value to perform
# pre-zero centering before rounding and clipping
z = self.get_zero_center(x) / scale # need to scale the norm by s
return z
7 changes: 4 additions & 3 deletions src/brevitas/function/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,10 @@ def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_b

def get_upper_bound_on_l1_norm(
accumulator_bit_width: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Calculate the upper bound on the l1-norm of the weights using the derivations from
`Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
by I.Colbert, A.Pappalardo, and J.Petri-Koenig."""
"""Calculate the upper bound on the l1-norm of the weights needed to guarantee overflow avoidance
for a given accumulator bit width and input representation using the derivations from
`A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance` by I.Colbert,
A.Pappalardo, and J.Petri-Koenig. Note that this assumes integer quantization."""
assert input_bit_width is not None, "A2Q relies on input bit-width."
assert input_is_signed is not None, "A2Q relies on input sign."
assert accumulator_bit_width is not None, "A2Q relies on accumulator bit-width."
Expand Down
51 changes: 0 additions & 51 deletions src/brevitas/nn/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from brevitas.function.ops_ste import ceil_ste
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector

Expand Down Expand Up @@ -79,51 +76,3 @@ def check_tensors_same_ptr(tensor_list):
else:
return False
return all(p == pointers[0] for p in pointers)


def calculate_min_accumulator_bit_width(
input_bit_width: Tensor,
input_is_signed: bool,
weight_max_l1_norm: Optional[Tensor] = None,
weight_bit_width: Optional[Tensor] = None,
n_elements: Optional[Tensor] = None,
min_val: Optional[float] = 1e-10):
"""Using the closed-form bounds on accumulator bit-width as derived in `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow
Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig. This function returns the minimum accumulator bit-width that can be used without risk of
overflow. It supports both the data-type bound as well as the weight-level bound.
Args:
input_bit_width (Tensor): the bit-width of the inputs to the layer.
input_is_signed (bool): calculate statistics for normalizing weight parameter.
weight_max_l1_norm (Tensor): the maximum per-channel l1-norm of the weights.
weight_bit_width (Tensor): the bit-width of the weights to the layer.
n_elements (Tensor): the number of elements in the dot product.
min_val (float): the minimum value used for the l1-norm, used to avoid log2(0). Default: 1e-8.
Example (data-type bound):
>> acc_bit_width = calculate_min_accumulator_bit_width(input_bit_width, input_is_signed, weight_bit_width, n_elements)
Example (weight-level bound):
>> acc_bit_width = calculate_min_accumulator_bit_width(input_bit_width, input_is_signed, weight_max_l1_norm)
"""
input_is_signed = float(input_is_signed)
# if the l1-norm of the weights is specified, then use the weight-level bound
if weight_max_l1_norm is not None:
assert isinstance(weight_max_l1_norm, (float, Tensor)), "The l1-norm of the weights needs to be a float or a torch.Tensor instance."
if isinstance(weight_max_l1_norm, Tensor):
assert weight_max_l1_norm.numel() == 1, "The minimum accumulator bit-width calculation currently only supports scalars."
weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, min_val)
input_is_signed = float(input_is_signed)
alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
# else use the data-type bound
else:
assert isinstance(weight_bit_width, (float, Tensor)), "If weight_max_l1_norm is un-specified, weight_bit_width needs to be specified."
assert isinstance(n_elements, (float, Tensor)), "If weight_max_l1_norm is un-specified, n_elements needs to be specified."
if isinstance(n_elements, Tensor):
assert n_elements.numel() == 1, "The minimum accumulator bit-width calculation currently only supports scalars."
assert n_elements > 0, "There needs to be at least one element considered in this evaluation."
alpha = torch.log2(n_elements) + input_bit_width + weight_bit_width - input_is_signed - 1.
phi = lambda x: torch.log2(1. + pow(2., -x))
min_bit_width = alpha + phi(alpha) + 1.
min_bit_width = ceil_ste(min_bit_width)
return min_bit_width # returns the minimum accumulator that can be used without risk of overflow
17 changes: 17 additions & 0 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.restrict_val import LogFloatRestrictValue
from brevitas.core.scaling import AccumulatorAwareParameterPreScaling
from brevitas.core.scaling import AccumulatorAwareZeroCenterParameterPreScaling
from brevitas.core.scaling import IntScaling
from brevitas.core.scaling import ParameterFromStatsFromParameterScaling
from brevitas.core.scaling import ParameterPreScalingWeightNorm
Expand All @@ -38,6 +39,7 @@
from brevitas.core.utils import SingleArgStatelessBuffer
from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.core.zero_point import PreZeroCenterZeroPoint
from brevitas.core.zero_point import StatsFromParameterZeroPoint
from brevitas.core.zero_point import ZeroZeroPoint
from brevitas.inject import ExtendedInjector
Expand Down Expand Up @@ -76,6 +78,7 @@
'BatchQuantStatsScaling1d',
'BatchQuantStatsScaling2d',
'AccumulatorAwareWeightQuant',
'AccumulatorAwareZeroCenterWeightQuant',
'MSESymmetricScale',
'MSEAsymmetricScale',
'MSEWeightZeroPoint',
Expand Down Expand Up @@ -400,6 +403,20 @@ def accumulator_bit_width_impl(accumulator_bit_width):
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints


class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
"""Experimental zero-centered accumulator-aware weight quantized based on:
`A2Q+: Improving Accumulator-Aware Weight Quantization`.
When compared to A2Q, A2Q+ changes the following:
(1) added zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`)
(2) a more relaxed l1-norm bound that is derived in the referenced paper
"""
pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
pre_zero_point_impl = PreZeroCenterZeroPoint
pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl


class MSESubInjectorBase(ExtendedInjector):

@value
Expand Down
Loading
Loading