Feat: functionalize QuantTensor #878

Merged · 3 commits · Apr 18, 2024

Changes from 1 commit
1 change: 0 additions & 1 deletion src/brevitas/graph/calibrate.py
@@ -13,7 +13,6 @@
  from brevitas.nn import QuantHardTanh
  from brevitas.nn import QuantLinear
  from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
- from brevitas.nn.utils import compute_channel_view_shape
  from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
  from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
  from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
4 changes: 2 additions & 2 deletions src/brevitas/nn/mixin/base.py
@@ -17,9 +17,9 @@
  from brevitas.common import ExportMixin
  from brevitas.inject import ExtendedInjector
  from brevitas.inject import Injector
- from brevitas.nn.utils import compute_channel_view_shape
  from brevitas.quant_tensor import _unpack_quant_tensor
  from brevitas.quant_tensor import QuantTensor
+ from brevitas.utils.torch_utils import compute_channel_view_shape

  from .utils import filter_kwargs

@@ -86,7 +86,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
  def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
      self._set_global_is_quant_layer(False)
      if self.return_quant_tensor:
-         assert isinstance(quant_output, QuantTensor)
+         assert isinstance(quant_output, QuantTensor), 'QuantLayer is not correctly configured, check if warnings were raised'
          return quant_output
      else:
          return _unpack_quant_tensor(quant_output)
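Note: the assert now carries a message pointing at configuration warnings. A minimal sketch of the contract it enforces, assuming a standard brevitas install (QuantIdentity and return_quant_tensor are existing brevitas API; the rest is illustrative):

import torch
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor

layer = qnn.QuantIdentity(return_quant_tensor=True)
out = layer(torch.randn(2, 8))
# with return_quant_tensor=True, pack_output must receive (and return) a
# QuantTensor, otherwise the assert above fires with the new message
assert isinstance(out, QuantTensor)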
44 changes: 9 additions & 35 deletions src/brevitas/nn/quant_avg_pool.py
@@ -47,38 +47,19 @@ def channelwise_separable(self) -> bool:
  def requires_export_handler(self):
      return True

- @property
- def _avg_scaling(self):
-     if isinstance(self.kernel_size, tuple):
-         return self.kernel_size[0] * self.kernel_size[1]
-     else:
-         return self.kernel_size * self.kernel_size
-
  def forward(self, input: Union[Tensor, QuantTensor]):
      x = self.unpack_input(input)

      if self.export_mode:
          return self.export_handler(_unpack_quant_tensor(x))

-     if isinstance(x, QuantTensor):
-         x = x.set(value=super(TruncAvgPool2d, self).forward(x.value))
-         if self.is_trunc_quant_enabled:
-             # remove avg scaling
-             rescaled_value = x.value * self._avg_scaling
-             x = x.set(value=rescaled_value)
-             x = x.set(bit_width=self.max_acc_bit_width(x.bit_width))
-             x = self.trunc_quant(x)
+     if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
+         y = AvgPool2d.forward(self, x)
+         y = self.trunc_quant(y)
      else:
          assert not self.is_trunc_quant_enabled
-         x = super(TruncAvgPool2d, self).forward(x)
-
-     return self.pack_output(x)
+         y = AvgPool2d.forward(self, _unpack_quant_tensor(x))

- def max_acc_bit_width(self, input_bit_width):
-     max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
-     max_uint_output = max_uint_input * self._avg_scaling
-     max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
-     return max_output_bit_width
+     return self.pack_output(y)


class TruncAdaptiveAvgPool2d(TruncMixin, QuantLayerMixin, AdaptiveAvgPool2d):
@@ -130,18 +111,11 @@ def forward(self, input: Union[Tensor, QuantTensor]):
          self._set_global_is_quant_layer(False)
          return out

-     if isinstance(x, QuantTensor):
-         y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value))
-         k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
-         if self.is_trunc_quant_enabled:
-             reduce_size = reduce(mul, k_size, 1)
-             rescaled_value = y.value * reduce_size  # remove avg scaling
-             y = y.set(value=rescaled_value)
-             y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size))
-             y = self.trunc_quant(y)
+     if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
+         y = AdaptiveAvgPool2d.forward(self, x)
+         y = self.trunc_quant(y)
      else:
          assert not self.is_trunc_quant_enabled
-         y = super(TruncAdaptiveAvgPool2d, self).forward(x)
+         y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))

      return self.pack_output(y)

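Note: the deleted _avg_scaling property rescaled the averaged output by the kernel area so that truncation operated on the raw window sums; the new code pools the QuantTensor directly and applies trunc_quant to the result, so that bookkeeping presumably now lives behind the functionalized QuantTensor path. A standalone sketch of the removed rescaling arithmetic, on plain tensors with illustrative values:

import torch
from torch.nn.functional import avg_pool2d

kernel_size = 2
avg = avg_pool2d(torch.arange(16.).view(1, 1, 4, 4), kernel_size)
acc = avg * (kernel_size * kernel_size)  # undo the 1/(k*k) of averaging
print(acc)  # each entry is the exact sum over its 2x2 window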
8 changes: 0 additions & 8 deletions src/brevitas/nn/quant_conv.py
@@ -109,14 +109,6 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
      else:
          return self._conv_forward(x, quant_weight, quant_bias)

- def max_acc_bit_width(self, input_bit_width, weight_bit_width):
-     max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
-     max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
-     group_size = self.in_channels // self.groups
-     max_uint_output = max_uint_input * max_kernel_val * self.kernel_size[0] * group_size
-     max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
-     return max_output_bit_width
-

  class QuantConv2d(QuantWBIOL, Conv2d):

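Note: the max_acc_bit_width accumulator bound disappears from QuantConv*d here (and from QuantLinear further down, with in_features in place of the kernel/group product). For reference, the bound it computed, restated as a plain function under the assumption that max_int(bit_width, signed=False, narrow_range=False) equals 2**bit_width - 1, which matches the deleted code; the arguments are illustrative:

import math

def max_acc_bit_width(input_bit_width, max_kernel_val, kernel_size, group_size):
    # worst case: every input at its max unsigned value times the largest
    # kernel magnitude, summed over kernel_size * group_size products
    max_uint_input = 2 ** input_bit_width - 1
    max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size
    return math.ceil(math.log2(max_uint_output))

print(max_acc_bit_width(8, 2 ** 7 - 1, 3, 16))  # 21-bit accumulator for this config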
27 changes: 24 additions & 3 deletions src/brevitas/nn/quant_convtranspose.py
@@ -103,7 +103,14 @@ def compute_output_padding(self, inp, output_size):
  def conv_transpose1d_zeros_pad(
          self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
      out = conv_transpose1d(
-         x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
+         x,
+         weight,
+         bias,
+         stride=self.stride,
+         padding=self.padding,
+         output_padding=output_padding,
+         groups=self.groups,
+         dilation=self.dilation)
      return out

def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
@@ -200,7 +207,14 @@ def compute_output_padding(self, inp, output_size):
  def conv_transpose2d_zeros_pad(
          self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
      out = conv_transpose2d(
-         x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
+         x,
+         weight,
+         bias,
+         stride=self.stride,
+         padding=self.padding,
+         output_padding=output_padding,
+         groups=self.groups,
+         dilation=self.dilation)
      return out

def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
@@ -298,7 +312,14 @@ def compute_output_padding(self, inp, output_size):
  def conv_transpose3d_zeros_pad(
          self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
      out = conv_transpose3d(
-         x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
+         x,
+         weight,
+         bias,
+         stride=self.stride,
+         padding=self.padding,
+         output_padding=output_padding,
+         groups=self.groups,
+         dilation=self.dilation)
      return out

def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
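Note: all three zero-pad helpers now pass stride, padding, output_padding, groups and dilation by keyword, which guards against positional-order slips in the functional API. A self-contained call of the same torch functional (shapes illustrative):

import torch
from torch.nn.functional import conv_transpose2d

x = torch.randn(1, 4, 8, 8)
w = torch.randn(4, 2, 3, 3)  # (in_channels, out_channels // groups, kH, kW)
y = conv_transpose2d(
    x, w, bias=None, stride=2, padding=1, output_padding=1, groups=1, dilation=1)
print(y.shape)  # torch.Size([1, 2, 16, 16])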
65 changes: 9 additions & 56 deletions src/brevitas/nn/quant_layer.py
@@ -9,11 +9,10 @@
  from torch import Tensor
  from torch.nn import Module

- from brevitas.quant_tensor import _unpack_quant_tensor
  from brevitas.quant_tensor import QuantTensor
+ from brevitas.utils.torch_utils import compute_channel_view_shape

  from .mixin import *
- from .utils import compute_channel_view_shape
  from .utils import merge_bn
  from .utils import rename_state_dict_by_prefix

@@ -47,7 +46,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
      quant_input = self.input_quant(input)
      # shortcut execution through the export impl during export
      if self.export_mode:
-         out = self.export_handler(_unpack_quant_tensor(quant_input))
+         out = self.export_handler(quant_input)
          self._set_global_is_quant_layer(False)
          return out
      out = self.act_quant(quant_input)
@@ -121,7 +120,8 @@ def max_acc_bit_width(self, input_bit_width: Tensor, quant_weight_bit_width: Ten

  def quant_output_scale_impl(
          self, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor):
-     output_scale_shape = compute_channel_view_shape(inp, channel_dim=1)
+     channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
+     output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
      output_scale = quant_weight_scale.view(output_scale_shape)
      output_scale = output_scale * quant_input_scale.view(output_scale_shape)
      return output_scale
@@ -140,16 +140,12 @@ def merge_bn_in(self, bn):
      merge_bn(self, bn, output_channel_dim=self.output_channel_dim)

  def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
-     output_scale = None
-     output_bit_width = None
-     output_zero_point = None
-     output_signed = None

      inp = self.unpack_input(inp)

      # shortcut execution through the export impl during export
      if self.export_mode:
-         out = self.export_handler(_unpack_quant_tensor(inp))
+         out = self.export_handler(inp)
          self._set_global_is_quant_layer(False)
          return out

@@ -163,58 +159,15 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
              self.output_quant.is_quant_enabled) and self.return_quant_tensor:
          raise RuntimeError("QuantLayer is not correctly configured")

+     output_scale = None
      if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
-         output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width)
          output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale)
-         output_signed = quant_input.signed or quant_weight.signed

      if self.bias is not None:
          quant_bias = self.bias_quant(self.bias, output_scale)
-
-         output_tensor = self.inner_forward_impl(
-             _unpack_quant_tensor(quant_input),
-             _unpack_quant_tensor(quant_weight),
-             _unpack_quant_tensor(quant_bias))
-
-         if output_scale is not None:
-             if (isinstance(quant_bias, QuantTensor) and
-                     quant_bias.scale.data_ptr() != output_scale.data_ptr()) or not isinstance(
-                         quant_bias, QuantTensor):
-                 channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
-                 output_scale_broadcast_shape = compute_channel_view_shape(
-                     inp, channel_dim=channel_dim)
-                 output_zero_point = -_unpack_quant_tensor(quant_bias).view(
-                     output_scale_broadcast_shape) / output_scale
-
-         if output_bit_width is not None and isinstance(quant_bias, QuantTensor):
-             output_bit_width = torch.where(
-                 quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width)
-             output_bit_width = output_bit_width + 1
-     else:
-         output_tensor = self.inner_forward_impl(
-             _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None)
-
-     if not self.output_quant.is_quant_enabled and self.return_quant_tensor:
-         if compute_output_quant_tensor:
-             if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any():
-                 raise RuntimeError(
-                     "Computing zero point of output accumulator not supported yet.")
-             elif output_zero_point is None:
-                 output_zero_point = quant_input.zero_point
-
-         elif output_zero_point is None:
-             output_zero_point = torch.zeros(1).type_as(output_tensor)
-
-     if compute_output_quant_tensor:
-         quant_output = QuantTensor(
-             output_tensor,
-             scale=output_scale,
-             zero_point=output_zero_point,
-             bit_width=output_bit_width,
-             signed=output_signed,
-             training=self.training)
-     else:
-         quant_output = output_tensor
+     else:
+         quant_bias = None
+     output_tensor = self.inner_forward_impl(quant_input, quant_weight, quant_bias)

-     quant_output = self.output_quant(quant_output)
+     quant_output = self.output_quant(output_tensor)
      return self.pack_output(quant_output)
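Note: this hunk is the core of the functionalization. forward_impl no longer unpacks tensors or hand-assembles the output QuantTensor (scale, zero point, bit width, sign); it hands QuantTensors straight to inner_forward_impl and lets output_quant and pack_output deal with the result. A hedged sketch of the objects involved; the constructor fields are taken from the code removed above, the values are illustrative:

import torch
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor

qt = QuantTensor(
    torch.tensor([[0.1, -0.2]]),
    scale=torch.tensor(0.1),
    zero_point=torch.tensor(0.),
    bit_width=torch.tensor(8.),
    signed=True,
    training=False)
print(qt.scale, qt.bit_width)    # metadata travels with the tensor
print(_unpack_quant_tensor(qt))  # plain torch.Tensor, what pack_output returns when return_quant_tensor=False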
7 changes: 0 additions & 7 deletions src/brevitas/nn/quant_linear.py
@@ -79,10 +79,3 @@ def quant_output_scale_impl(
      quant_weight_scale = quant_weight_scale.view(weight_broadcast_shape)
      quant_output_scale = linear(quant_input_scale, quant_weight_scale)
      return quant_output_scale
-
- def max_acc_bit_width(self, input_bit_width, weight_bit_width):
-     max_input_val = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
-     max_fc_val = self.weight_quant.max_uint_value(weight_bit_width)
-     max_output_val = max_input_val * max_fc_val * self.in_features
-     output_bit_width = ceil_ste(torch.log2(max_output_val))
-     return output_bit_width
23 changes: 14 additions & 9 deletions src/brevitas/nn/quant_scale_bias.py
@@ -78,15 +78,20 @@ def channelwise_separable(self) -> bool:
  def forward(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
      return self.forward_impl(inp)

- def inner_forward_impl(self, input: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
+ def inner_forward_impl(
+         self,
+         input: Union[Tensor, QuantTensor],
+         quant_weight: Union[Tensor, QuantTensor],
+         quant_bias: Optional[Union[Tensor, QuantTensor]]):
      quant_weight = quant_weight.view(self.runtime_shape)
      quant_bias = quant_bias.view(self.runtime_shape)
-     output_tensor = input * quant_weight + quant_bias
-     return output_tensor

- def max_acc_bit_width(self, input_bit_width, weight_bit_width):
-     max_input_val = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
-     max_weight_val = self.weight_quant.max_uint_value(weight_bit_width)
-     max_output_val = max_input_val * max_weight_val
-     output_bit_width = ceil_ste(torch.log2(max_output_val))
-     return output_bit_width
+     def biased_mul(input, weight, bias):
+         out = input * weight
+         if bias is not None:
+             out += bias
+         return out
+
+     output_tensor = biased_mul(input, quant_weight, quant_bias)
+
+     return output_tensor
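Note: biased_mul makes the bias genuinely optional. The same helper, exercised standalone (shapes illustrative; in the layer, weight and bias are first viewed to self.runtime_shape for per-channel broadcasting):

import torch

def biased_mul(input, weight, bias):
    out = input * weight
    if bias is not None:
        out += bias
    return out

x = torch.randn(2, 3, 4, 4)
w = torch.randn(3).view(1, -1, 1, 1)  # per-channel scale, broadcastable to x
print(biased_mul(x, w, None).shape)   # bias omitted: torch.Size([2, 3, 4, 4])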
12 changes: 3 additions & 9 deletions src/brevitas/nn/utils.py
@@ -2,17 +2,9 @@
  # SPDX-License-Identifier: BSD-3-Clause

  import torch
- from torch import Tensor
  from torch.nn import Parameter

- from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
- from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
-
-
- def compute_channel_view_shape(tensor: Tensor, channel_dim: int):
-     broadcast_shape = [1] * len(tensor.size())
-     broadcast_shape[channel_dim] = -1
-     return tuple(broadcast_shape)
+ from brevitas.utils.torch_utils import compute_channel_view_shape


  def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias):
@@ -23,6 +15,8 @@ def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias):


  def merge_bn(layer, bn, output_channel_dim=0):
+     from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
+     from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
      out = mul_add_from_bn(
          bn_mean=bn.running_mean,
          bn_var=bn.running_var,
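Note: compute_channel_view_shape moves, unchanged, from brevitas.nn.utils to brevitas.utils.torch_utils (matching the import updates in calibrate.py, mixin/base.py and quant_layer.py above). Its body, copied from the lines removed here, plus an illustrative call:

import torch
from torch import Tensor

def compute_channel_view_shape(tensor: Tensor, channel_dim: int):
    # shape that is 1 everywhere except -1 at the channel dim,
    # for broadcasting per-channel scales and biases
    broadcast_shape = [1] * len(tensor.size())
    broadcast_shape[channel_dim] = -1
    return tuple(broadcast_shape)

x = torch.randn(2, 3, 4, 4)
print(compute_channel_view_shape(x, channel_dim=1))   # (1, -1, 1, 1)
print(compute_channel_view_shape(x, channel_dim=-1))  # (1, 1, 1, -1), the Linear case in quant_layer.py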