Feat: functionalize QuantTensor
Giuseppe5 committed Feb 28, 2024
1 parent f5cc575 commit 0176483
Showing 9 changed files with 242 additions and 72 deletions.
1 change: 0 additions & 1 deletion src/brevitas/graph/calibrate.py
@@ -13,7 +13,6 @@
from brevitas.nn import QuantHardTanh
from brevitas.nn import QuantLinear
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
from brevitas.nn.utils import compute_channel_view_shape
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
2 changes: 1 addition & 1 deletion src/brevitas/nn/mixin/base.py
@@ -17,9 +17,9 @@
from brevitas.common import ExportMixin
from brevitas.inject import ExtendedInjector
from brevitas.inject import Injector
from brevitas.nn.utils import compute_channel_view_shape
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.torch_utils import compute_channel_view_shape

from .utils import filter_kwargs

20 changes: 12 additions & 8 deletions src/brevitas/nn/quant_conv.py
@@ -203,11 +203,15 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
else:
return self._conv_forward(x, quant_weight, quant_bias)

def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.in_channels // self.groups
kernel_size = self.kernel_size[0] * self.kernel_size[1]
max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width
# def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
# max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
# max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
# group_size = self.in_channels // self.groups
# kernel_size = self.kernel_size[0] * self.kernel_size[1]
# max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size
# max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
# return max_output_bit_width
# def max_acc_bit_width(self, input, weight):
# input_cls = type(input)
# max_output_bit_width = input_cls.max_acc_bit_width(input, weight, self.in_features)
# return max_output_bit_width
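The commented-out per-layer accumulator math above (max_uint_input * max_kernel_val * kernel_size * group_size) is superseded by the QuantTensor.max_acc_bit_width classmethod added later in this diff. A hypothetical conv-side call, illustrative only and assuming quant_input and quant_weight are QuantTensor instances, might look like:

from brevitas.quant_tensor import QuantTensor

def conv2d_acc_bit_width(self, quant_input, quant_weight):
    # same reduction size as the commented-out implementation above
    kernel_size = self.kernel_size[0] * self.kernel_size[1]
    group_size = self.in_channels // self.groups
    return QuantTensor.max_acc_bit_width(quant_input, quant_weight, kernel_size * group_size)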
18 changes: 16 additions & 2 deletions src/brevitas/nn/quant_convtranspose.py
@@ -100,7 +100,14 @@ def compute_output_padding(self, inp, output_size):
def conv_transpose1d_zeros_pad(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
out = conv_transpose1d(
x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
x,
weight,
bias,
stride=self.stride,
padding=self.padding,
output_padding=output_padding,
groups=self.groups,
dilation=self.dilation)
return out

def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
@@ -197,7 +204,14 @@ def compute_output_padding(self, inp, output_size):
def conv_transpose2d_zeros_pad(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
out = conv_transpose2d(
x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
x,
weight,
bias,
stride=self.stride,
padding=self.padding,
output_padding=output_padding,
groups=self.groups,
dilation=self.dilation)
return out

def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
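Both calls above are rewritten from positional to keyword arguments; behaviour is unchanged, but the binding of stride, padding, output_padding, groups and dilation is now explicit. A standalone sketch of the same call shape with plain torch tensors (shapes chosen arbitrarily for illustration):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8)   # (batch, in_channels, length)
w = torch.randn(4, 2, 3)   # (in_channels, out_channels // groups, kernel_size)
out = F.conv_transpose1d(
    x, w, bias=None, stride=2, padding=1, output_padding=1, groups=1, dilation=1)
print(out.shape)  # torch.Size([1, 2, 16])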
61 changes: 11 additions & 50 deletions src/brevitas/nn/quant_layer.py
@@ -9,12 +9,11 @@
from torch import Tensor
from torch.nn import Module

from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.torch_utils import compute_channel_view_shape

from .mixin import *
from .mixin.base import _CachedIO
from .utils import compute_channel_view_shape
from .utils import merge_bn
from .utils import rename_state_dict_by_prefix

@@ -56,7 +55,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
quant_input = self.input_quant(input)
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(_unpack_quant_tensor(quant_input))
out = self.export_handler(quant_input)
self._set_global_is_quant_layer(False)
return out
out = self.act_quant(quant_input)
@@ -130,7 +129,8 @@ def max_acc_bit_width(self, input_bit_width: Tensor, quant_weight_bit_width: Tensor):

def quant_output_scale_impl(
self, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor):
output_scale_shape = compute_channel_view_shape(inp, channel_dim=1)
channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
output_scale = quant_weight_scale.view(output_scale_shape)
output_scale = output_scale * quant_input_scale.view(output_scale_shape)
return output_scale
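The scale product computed here follows from the affine quantization identity: with zero zero-points, x = s_x * q_x and w = s_w * q_w, so the accumulator x * w carries scale s_x * s_w. A quick numeric check (values are illustrative):

s_x, s_w = 0.02, 0.001      # input and weight scales
q_x, q_w = 57, -12          # integer codes
print((s_x * q_x) * (s_w * q_w))   # real-valued product, about -0.01368
print((s_x * s_w) * (q_x * q_w))   # same value with the scales factored out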
@@ -149,16 +149,12 @@ def merge_bn_in(self, bn):
merge_bn(self, bn, output_channel_dim=self.output_channel_dim)

def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
output_scale = None
output_bit_width = None
output_zero_point = None
output_signed = None

inp = self.unpack_input(inp)

# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(_unpack_quant_tensor(inp))
out = self.export_handler(inp)
self._set_global_is_quant_layer(False)
return out

@@ -172,60 +168,25 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
self.is_output_quant_enabled) and self.return_quant_tensor:
raise RuntimeError("QuantLayer is not correctly configured")

output_scale = None
if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width)
output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale)
output_signed = quant_input.signed or quant_weight.signed

if self.bias is not None:
quant_bias = self.bias_quant(self.bias, output_scale)
if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias,
QuantTensor):
self.bias_quant._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)
output_tensor = self.inner_forward_impl(
_unpack_quant_tensor(quant_input),
_unpack_quant_tensor(quant_weight),
_unpack_quant_tensor(quant_bias))

if output_scale is not None:
if (isinstance(quant_bias, QuantTensor) and
quant_bias.scale.data_ptr() != output_scale.data_ptr()) or not isinstance(
quant_bias, QuantTensor):
channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
output_scale_broadcast_shape = compute_channel_view_shape(
inp, channel_dim=channel_dim)
output_zero_point = -_unpack_quant_tensor(quant_bias).view(
output_scale_broadcast_shape) / output_scale

if output_bit_width is not None and isinstance(quant_bias, QuantTensor):
output_bit_width = torch.where(
quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width)
output_bit_width = output_bit_width + 1
self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)
output_tensor = self.inner_forward_impl(quant_input, quant_weight, quant_bias)

else:
output_tensor = self.inner_forward_impl(
_unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None)
output_tensor = self.inner_forward_impl(quant_input, quant_weight, None)

if not self.is_output_quant_enabled and self.return_quant_tensor:
if compute_output_quant_tensor:
if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any():
raise RuntimeError(
"Computing zero point of output accumulator not supported yet.")
elif output_zero_point is None:
output_zero_point = quant_input.zero_point

elif output_zero_point is None:
output_zero_point = torch.zeros(1).type_as(output_tensor)

if compute_output_quant_tensor:
quant_output = QuantTensor(
output_tensor,
scale=output_scale,
zero_point=output_zero_point,
bit_width=output_bit_width,
signed=output_signed,
training=self.training)
else:
quant_output = output_tensor

quant_output = self.output_quant(quant_output)
quant_output = self.output_quant(output_tensor)
return self.pack_output(quant_output)
12 changes: 3 additions & 9 deletions src/brevitas/nn/utils.py
@@ -2,17 +2,9 @@
# SPDX-License-Identifier: BSD-3-Clause

import torch
from torch import Tensor
from torch.nn import Parameter

from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector


def compute_channel_view_shape(tensor: Tensor, channel_dim: int):
broadcast_shape = [1] * len(tensor.size())
broadcast_shape[channel_dim] = -1
return tuple(broadcast_shape)
from brevitas.utils.torch_utils import compute_channel_view_shape


def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias):
@@ -23,6 +15,8 @@ def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias):


def merge_bn(layer, bn, output_channel_dim=0):
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
out = mul_add_from_bn(
bn_mean=bn.running_mean,
bn_var=bn.running_var,
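compute_channel_view_shape itself is unchanged (its old definition is visible in the deletion above); it now lives in brevitas.utils.torch_utils, which is what the new imports in base.py and quant_layer.py point to. For reference, it just builds a broadcast shape with -1 at the channel dimension:

import torch
from brevitas.utils.torch_utils import compute_channel_view_shape

x = torch.randn(8, 16, 32, 32)
print(compute_channel_view_shape(x, channel_dim=1))   # (1, -1, 1, 1), conv-style layers
print(compute_channel_view_shape(x, channel_dim=-1))  # (1, 1, 1, -1), the nn.Linear path above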
18 changes: 17 additions & 1 deletion src/brevitas/quant_tensor/__init__.py
@@ -77,7 +77,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if (func not in QUANT_TENSOR_FN_HANDLER or
not all(issubclass(t, QuantTensor) for t in types)):
not any(issubclass(t, QuantTensor) for t in types)):
args = _unpack_quant_tensor(args)
kwargs = _unpack_quant_tensor(kwargs)
return func(*args, **kwargs)
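The check changes from all(...) to any(...): a registered QUANT_TENSOR_FN_HANDLER entry is now used as soon as at least one argument type is a QuantTensor, instead of requiring every argument to be one; otherwise arguments are unpacked and the plain torch op runs. A self-contained sketch of that dispatch rule, using a hypothetical stand-in class rather than Brevitas' actual QuantTensor:

import torch

HANDLERS = {}

class FakeQuantTensor:
    def __init__(self, value):
        self.value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in HANDLERS and any(issubclass(t, FakeQuantTensor) for t in types):
            return HANDLERS[func](*args, **kwargs)
        # fallback: unpack to plain tensors and run the original op
        args = tuple(a.value if isinstance(a, FakeQuantTensor) else a for a in args)
        return func(*args, **kwargs)

def handled_add(x, y):
    x = x.value if isinstance(x, FakeQuantTensor) else x
    y = y.value if isinstance(y, FakeQuantTensor) else y
    return FakeQuantTensor(x + y)

HANDLERS[torch.add] = handled_add

# mixed plain / "quant" arguments still route to the handler thanks to any(...)
out = torch.add(torch.ones(2), FakeQuantTensor(torch.full((2,), 3.0)))
print(type(out).__name__)  # FakeQuantTensor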
@@ -443,3 +443,19 @@ def __abs__(self):

def __pos__(self):
return self

@classmethod
def max_acc_bit_width(cls, *args):

def _max_int_or_tensor(args):
if isinstance(args, QuantTensor):
return max_int(bit_width=args.bit_width, signed=False, narrow_range=False)
else:
return args

args = map(_max_int_or_tensor, args)
res = 1
for arg in args:
res *= arg
res = ceil_ste(torch.log2(res))
return res
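The classmethod treats each QuantTensor argument as an unsigned value at its bit width (via max_int) and multiplies plain arguments in directly, so a call like the commented-out one in quant_conv.py above, input_cls.max_acc_bit_width(input, weight, self.in_features), bounds the accumulator bit width. A worked example with illustrative numbers:

import math

# 8-bit unsigned bound is 255; reducing over 512 products of 8-bit values:
max_acc = 255 * 255 * 512             # 33_292_800
print(math.ceil(math.log2(max_acc)))  # 25 bits needed for the accumulator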