diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 8f690fc9b..2a93f1226 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -13,7 +13,6 @@ from brevitas.nn import QuantHardTanh from brevitas.nn import QuantLinear from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL -from brevitas.nn.utils import compute_channel_view_shape from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 5ce256515..ad6f24118 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -17,9 +17,9 @@ from brevitas.common import ExportMixin from brevitas.inject import ExtendedInjector from brevitas.inject import Injector -from brevitas.nn.utils import compute_channel_view_shape from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor +from brevitas.utils.torch_utils import compute_channel_view_shape from .utils import filter_kwargs diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 5af432af7..09da76528 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -203,11 +203,15 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option else: return self._conv_forward(x, quant_weight, quant_bias) - def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor): - max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) - max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) - group_size = self.in_channels // self.groups - kernel_size = self.kernel_size[0] * self.kernel_size[1] - max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size - 
max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) - return max_output_bit_width + # def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor): + # max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + # max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) + # group_size = self.in_channels // self.groups + # kernel_size = self.kernel_size[0] * self.kernel_size[1] + # max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size + # max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + # return max_output_bit_width + # def max_acc_bit_width(self, input, weight): + # input_cls = type(input) + # max_output_bit_width = input_cls.max_acc_bit_width(input, weight, self.in_features) + # return max_output_bit_width diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py index 64fbe8eb6..ba04c0161 100644 --- a/src/brevitas/nn/quant_convtranspose.py +++ b/src/brevitas/nn/quant_convtranspose.py @@ -100,7 +100,14 @@ def compute_output_padding(self, inp, output_size): def conv_transpose1d_zeros_pad( self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding): out = conv_transpose1d( - x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + x, + weight, + bias, + stride=self.stride, + padding=self.padding, + output_padding=output_padding, + groups=self.groups, + dilation=self.dilation) return out def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): @@ -197,7 +204,14 @@ def compute_output_padding(self, inp, output_size): def conv_transpose2d_zeros_pad( self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding): out = conv_transpose2d( - x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + x, + weight, + bias, + stride=self.stride, + padding=self.padding, + output_padding=output_padding, + 
groups=self.groups, + dilation=self.dilation) return out def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 110c1a394..5d9b707c1 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -9,12 +9,11 @@ from torch import Tensor from torch.nn import Module -from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor +from brevitas.utils.torch_utils import compute_channel_view_shape from .mixin import * from .mixin.base import _CachedIO -from .utils import compute_channel_view_shape from .utils import merge_bn from .utils import rename_state_dict_by_prefix @@ -56,7 +55,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): quant_input = self.input_quant(input) # shortcut execution through the export impl during export if self.export_mode: - out = self.export_handler(_unpack_quant_tensor(quant_input)) + out = self.export_handler(quant_input) self._set_global_is_quant_layer(False) return out out = self.act_quant(quant_input) @@ -130,7 +129,8 @@ def max_acc_bit_width(self, input_bit_width: Tensor, quant_weight_bit_width: Ten def quant_output_scale_impl( self, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor): - output_scale_shape = compute_channel_view_shape(inp, channel_dim=1) + channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 + output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim) output_scale = quant_weight_scale.view(output_scale_shape) output_scale = output_scale * quant_input_scale.view(output_scale_shape) return output_scale @@ -149,16 +149,12 @@ def merge_bn_in(self, bn): merge_bn(self, bn, output_channel_dim=self.output_channel_dim) def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: - output_scale = None - output_bit_width = None - output_zero_point = None - output_signed = None 
inp = self.unpack_input(inp) # shortcut execution through the export impl during export if self.export_mode: - out = self.export_handler(_unpack_quant_tensor(inp)) + out = self.export_handler(inp) self._set_global_is_quant_layer(False) return out @@ -172,60 +168,25 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe self.is_output_quant_enabled) and self.return_quant_tensor: raise RuntimeError("QuantLayer is not correctly configured") + output_scale = None if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): - output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width) output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale) - output_signed = quant_input.signed or quant_weight.signed if self.bias is not None: quant_bias = self.bias_quant(self.bias, output_scale) if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias, QuantTensor): - self.bias_quant._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False) - output_tensor = self.inner_forward_impl( - _unpack_quant_tensor(quant_input), - _unpack_quant_tensor(quant_weight), - _unpack_quant_tensor(quant_bias)) - - if output_scale is not None: - if (isinstance(quant_bias, QuantTensor) and - quant_bias.scale.data_ptr() != output_scale.data_ptr()) or not isinstance( - quant_bias, QuantTensor): - channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 - output_scale_broadcast_shape = compute_channel_view_shape( - inp, channel_dim=channel_dim) - output_zero_point = -_unpack_quant_tensor(quant_bias).view( - output_scale_broadcast_shape) / output_scale - - if output_bit_width is not None and isinstance(quant_bias, QuantTensor): - output_bit_width = torch.where( - quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width) - output_bit_width = output_bit_width + 1 + self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False) + 
output_tensor = self.inner_forward_impl(quant_input, quant_weight, quant_bias) + else: - output_tensor = self.inner_forward_impl( - _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None) + output_tensor = self.inner_forward_impl(quant_input, quant_weight, None) if not self.is_output_quant_enabled and self.return_quant_tensor: if compute_output_quant_tensor: if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any(): raise RuntimeError( "Computing zero point of output accumulator not supported yet.") - elif output_zero_point is None: - output_zero_point = quant_input.zero_point - - elif output_zero_point is None: - output_zero_point = torch.zeros(1).type_as(output_tensor) - - if compute_output_quant_tensor: - quant_output = QuantTensor( - output_tensor, - scale=output_scale, - zero_point=output_zero_point, - bit_width=output_bit_width, - signed=output_signed, - training=self.training) - else: - quant_output = output_tensor - quant_output = self.output_quant(quant_output) + quant_output = self.output_quant(output_tensor) return self.pack_output(quant_output) diff --git a/src/brevitas/nn/utils.py b/src/brevitas/nn/utils.py index ed5e87302..fccfbc4de 100644 --- a/src/brevitas/nn/utils.py +++ b/src/brevitas/nn/utils.py @@ -2,17 +2,9 @@ # SPDX-License-Identifier: BSD-3-Clause import torch -from torch import Tensor from torch.nn import Parameter -from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector -from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector - - -def compute_channel_view_shape(tensor: Tensor, channel_dim: int): - broadcast_shape = [1] * len(tensor.size()) - broadcast_shape[channel_dim] = -1 - return tuple(broadcast_shape) +from brevitas.utils.torch_utils import compute_channel_view_shape def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias): @@ -23,6 +15,8 @@ def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias): def merge_bn(layer, bn, output_channel_dim=0): 
+ from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector + from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector out = mul_add_from_bn( bn_mean=bn.running_mean, bn_var=bn.running_var, diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index c66690c50..201f33109 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -77,7 +77,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if (func not in QUANT_TENSOR_FN_HANDLER or - not all(issubclass(t, QuantTensor) for t in types)): + not any(issubclass(t, QuantTensor) for t in types)): args = _unpack_quant_tensor(args) kwargs = _unpack_quant_tensor(kwargs) return func(*args, **kwargs) @@ -443,3 +443,19 @@ def __abs__(self): def __pos__(self): return self + + @classmethod + def max_acc_bit_width(cls, *args): + + def _max_int_or_tensor(args): + if isinstance(args, QuantTensor): + return max_int(bit_width=args.bit_width, signed=False, narrow_range=False) + else: + return args + + args = map(_max_int_or_tensor, args) + res = 1 + for arg in args: + res *= arg + res = ceil_ste(torch.log2(res)) + return res diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py index 1b6e43a37..48a028787 100644 --- a/src/brevitas/quant_tensor/torch_handler.py +++ b/src/brevitas/quant_tensor/torch_handler.py @@ -2,11 +2,15 @@ # SPDX-License-Identifier: BSD-3-Clause import functools +import math import torch import torch.nn.functional as F import brevitas +from brevitas.function.ops import max_int +from brevitas.function.ops_ste import ceil_ste +from brevitas.utils.torch_utils import compute_channel_view_shape QUANT_TENSOR_FN_HANDLER = {} @@ -156,3 +160,175 @@ def pixel_shuffle_handler(*args, **kwargs): @implements(F.pixel_unshuffle) def pixel_unshuffle_handler(*args, **kwargs): return quant_invariant_handler(F.pixel_unshuffle, *args, 
**kwargs) + + +@implements(F.conv2d) +def conv2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv2d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements(F.conv1d) +def conv1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv1d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements(F.conv_transpose2d) +def conv_transpose2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv_transpose2d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements(F.conv_transpose1d) +def conv_transpose1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv_transpose1d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements(F.linear) +def linear_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.linear, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +def quant_layer(cls, quant_input, quant_weight, bias, *args, **kwargs): + from brevitas.quant_tensor import _unpack_quant_tensor + from brevitas.quant_tensor import QuantTensor + + output_scale = None + output_bit_width = None + output_zero_point = None + output_signed = None + max_acc_bit_width = IMPLS[cls] + + compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance( + quant_weight, QuantTensor) + + if bias is None: + output = cls( + _unpack_quant_tensor(quant_input), + _unpack_quant_tensor(quant_weight), + None, + *args, + **kwargs) + else: + output = cls( + _unpack_quant_tensor(quant_input), + _unpack_quant_tensor(quant_weight), + _unpack_quant_tensor(bias), + *args, + **kwargs) + + if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): + output_bit_width = max_acc_bit_width( + quant_input.bit_width, + quant_weight.bit_width, + 
quant_weight.value.shape, +        *args, +        **kwargs) +        output_scale = quant_output_scale_impl( +            cls, quant_input.value, quant_input.scale, quant_weight.scale) +        output_signed = quant_input.signed or quant_weight.signed +        output_training = quant_input.training or quant_weight.training + +    if bias is not None: +        if output_scale is not None: +            if (isinstance(bias, QuantTensor) and +                    not torch.allclose(bias.scale, output_scale)) or not isinstance(bias, +                                                                                   QuantTensor): +                channel_dim = -1 if cls == F.linear else 1 +                output_scale_broadcast_shape = compute_channel_view_shape( +                    quant_input, channel_dim=channel_dim) +                output_zero_point = -_unpack_quant_tensor(bias).view( +                    output_scale_broadcast_shape) / output_scale +        if output_bit_width is not None and isinstance(bias, QuantTensor): +            output_bit_width = torch.where( +                bias.bit_width > output_bit_width, bias.bit_width, output_bit_width) +            output_bit_width = output_bit_width + 1 + +    if compute_output_quant_tensor: +        if output_zero_point is None: +            output_zero_point = torch.zeros(1).type_as(output) +        return create_quant_tensor( +            output, +            output_scale, +            output_bit_width, +            output_zero_point, +            output_signed, +            output_training) +    else: +        return output + + +def create_quant_tensor(tensor, scale, bit_width, zero_point, signed, training): +    from brevitas.quant_tensor import QuantTensor +    return QuantTensor( +        tensor, +        scale=scale, +        zero_point=zero_point, +        bit_width=bit_width, +        signed=signed, +        training=training) + + +def quant_output_scale_impl(cls, inp, quant_input_scale, quant_weight_scale): +    channel_dim = -1 if cls == F.linear else 1 +    output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim) +    output_scale = quant_weight_scale.view(output_scale_shape) +    output_scale = output_scale * quant_input_scale.view(output_scale_shape) +    return output_scale + + +def max_acc_bit_width_convnd(input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): +    max_uint_input = 
max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) + in_channel = weight_shape[1] + kernel_size = math.prod(weight_shape[2:]) + max_uint_output = max_uint_input * max_kernel_val * kernel_size * in_channel + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +def max_acc_bit_width_linear(input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) + in_channel = weight_shape[1] + max_uint_output = max_uint_input * max_kernel_val * in_channel + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +def max_acc_bit_width_convtranspose1d( + input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): + stride = kwargs['stride'] + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) + out_channel = weight_shape[1] + kernel_shape = weight_shape[2:] + overlapping_sums = max(round(kernel_shape[0] / stride[0]), 1) + max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * out_channel + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +def max_acc_bit_width_convtranspose2d( + input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): + stride = kwargs['stride'] + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) + out_channel = weight_shape[1] + kernel_shape = weight_shape[2:] + overlapping_sums = max(round(kernel_shape[0] / stride[0]), 1) + overlapping_sums += max(round(kernel_shape[1] / 
stride[1]), 1) + max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * out_channel + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +IMPLS = { + F.conv2d: max_acc_bit_width_convnd, + F.conv1d: max_acc_bit_width_convnd, + F.linear: max_acc_bit_width_linear, + F.conv_transpose2d: max_acc_bit_width_convtranspose2d, + F.conv_transpose1d: max_acc_bit_width_convtranspose1d} diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index ec7d6fac4..9392c001d 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -80,3 +80,9 @@ def kthvalue( if x.dtype != dtype: x = x.type(dtype) return (x, indices) + + +def compute_channel_view_shape(tensor: torch.Tensor, channel_dim: int): + broadcast_shape = [1] * len(tensor.size()) + broadcast_shape[channel_dim] = -1 + return tuple(broadcast_shape)