Attempt fix for quant scaled bias
Giuseppe5 committed Mar 28, 2024
1 parent e372ad4 commit dbdb7e3
Showing 5 changed files with 33 additions and 46 deletions.
6 changes: 0 additions & 6 deletions src/brevitas/nn/quant_layer.py
@@ -169,11 +169,5 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         else:
             output_tensor = self.inner_forward_impl(quant_input, quant_weight, None)
 
-        if not self.output_quant.is_quant_enabled and self.return_quant_tensor:
-            if compute_output_quant_tensor:
-                if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any():
-                    raise RuntimeError(
-                        "Computing zero point of output accumulator not supported yet.")
-
         quant_output = self.output_quant(output_tensor)
         return self.pack_output(quant_output)
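The deleted guard raised whenever a QuantTensor output was requested while either operand carried a non-zero zero point; the equivalent check now lives in quant_layer (last file below), demoted to a warning. A minimal sketch of the caller-visible difference, with hypothetical zero points standing in for real quantized operands:

import torch

# Hypothetical zero points of the quantized input and weight.
input_zp = torch.tensor([1.0])   # non-zero
weight_zp = torch.tensor([0.0])

zp_nonzero = (input_zp != 0.0).any() or (weight_zp != 0.0).any()
# Before this commit: forward_impl raised RuntimeError("Computing zero point of
# output accumulator not supported yet.") when zp_nonzero held.
# After this commit: quant_layer warns with the same message and returns a
# plain Tensor instead of a QuantTensor.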
23 changes: 11 additions & 12 deletions src/brevitas/nn/quant_scale_bias.py
@@ -92,18 +92,17 @@ def biased_mul(input, weight, bias):
                 out += bias
             return out
 
-        # TODO: when implementing new types of QuantTensor, this should be revised
-        if isinstance(input, QuantTensor):
-            from brevitas.quant_tensor.torch_handler import quant_layer
-
-            output_tensor = quant_layer(
-                biased_mul,
-                input,
-                quant_weight,
-                bias=quant_bias,
-                external_acc_bit_width_fn=self.max_acc_bit_width)
-        else:
-            output_tensor = biased_mul(input, quant_weight, quant_bias)
+        # # TODO: when implementing new types of QuantTensor, this should be revised
+        # if isinstance(input, QuantTensor):
+        #     from brevitas.quant_tensor.torch_handler import quant_layer
+
+        #     output_tensor = quant_layer(
+        #         biased_mul,
+        #         input,
+        #         quant_weight,
+        #         bias=quant_bias)
+        # else:
+        output_tensor = biased_mul(input, quant_weight, quant_bias)
 
         return output_tensor
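With the quant_layer path commented out, the layer always takes the plain biased_mul route; quantized operands now participate through QuantTensor arithmetic (the __mul__/__add__ changes below). A rough sketch of that route; the opening of biased_mul is hidden above the hunk, so the first line and the bias guard are assumptions:

import torch

def biased_mul(input, weight, bias):
    out = input * weight        # assumed first line, not shown in the hunk
    if bias is not None:        # assumed guard; only `out += bias` is shown
        out += bias
    return out

# Plain-tensor stand-ins for input, quant_weight and quant_bias:
y = biased_mul(torch.randn(2, 8), torch.full((8,), 0.5), torch.zeros(8))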

12 changes: 6 additions & 6 deletions src/brevitas/quant_tensor/base_quant_tensor.py
@@ -1,15 +1,15 @@
-from typing import NamedTuple, Optional
+from typing import NamedTuple
 
 from torch import Tensor
 
 
 class QuantTensorBase(NamedTuple):
     value: Tensor
-    scale: Optional[Tensor]
-    zero_point: Optional[Tensor]
-    bit_width: Optional[Tensor]
-    signed_t: Optional[Tensor]
-    training_t: Optional[Tensor]
+    scale: Tensor
+    zero_point: Tensor
+    bit_width: Tensor
+    signed_t: Tensor
+    training_t: Tensor
 
 
 def _unpack_quant_tensor(input_data):
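Since QuantTensorBase is a NamedTuple, dropping Optional changes the documented contract rather than runtime behavior, but it commits every instance to carrying concrete metadata tensors. A hypothetical construction under the tightened annotations:

import torch
from brevitas.quant_tensor.base_quant_tensor import QuantTensorBase

qt = QuantTensorBase(
    value=torch.randn(4),
    scale=torch.tensor(0.1),
    zero_point=torch.tensor(0.),
    bit_width=torch.tensor(8.),
    signed_t=torch.tensor(True),
    training_t=torch.tensor(False))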
31 changes: 9 additions & 22 deletions src/brevitas/quant_tensor/int_quant_tensor.py
@@ -332,8 +332,15 @@ def __add__(self, other):
                 bit_width=output_bit_width,
                 signed=output_signed,
                 training=output_training)
-        elif isinstance(other, QuantTensor):
-            output = self.value + other.value
+        elif self.value.shape == other.shape:
+            output = QuantTensor(
+                value=self.value + other,
+                scale=output_scale,
+                zero_point=self.zero_point - other / self.scale,
+                bit_width=output_bit_width,
+                signed=output_signed,
+                training=output_training)
+
         else:
             output = self.value + other
         return output
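The new branch folds a same-shaped plain tensor into the sum by shifting the zero point instead of requantizing. Writing the integer code as int_repr = value / scale + zero_point, adding other to value while subtracting other / scale from zero_point leaves int_repr unchanged, so the existing integer codes now decode to the shifted values. A quick numeric check of that identity (values made up):

import torch

value = torch.tensor([0.2, -0.4])   # dequantized values
scale = torch.tensor(0.1)
zero_point = torch.tensor(3.0)
other = torch.tensor([0.5, 0.5])    # same-shaped plain tensor being added

int_before = value / scale + zero_point
int_after = (value + other) / scale + (zero_point - other / scale)
assert torch.allclose(int_before, int_after)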
@@ -362,8 +369,6 @@ def __mul__(self, other):
                 bit_width=output_bit_width,
                 signed=output_signed,
                 training=output_training)
-        elif isinstance(other, QuantTensor):
-            output = self.value * other.value
         else:
             output = self.value * other
         return output
@@ -393,8 +398,6 @@ def __truediv__(self, other):
                 bit_width=output_bit_width,
                 signed=output_signed,
                 training=output_training)
-        elif isinstance(other, QuantTensor):
-            output = self.value / other.value
         else:
             output = self.value / other
         return output
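A plausible reading of these two deletions, given the now non-Optional metadata above: the elif isinstance(other, QuantTensor) fallbacks covered QuantTensors with missing metadata, which can no longer occur, so the first branch (outside the hunk) handles fully-populated QuantTensors and everything else is a plain tensor. A runnable toy of the surviving two-way dispatch, under that assumption:

import torch

class TinyQuant:
    # Toy stand-in for IntQuantTensor's surviving __mul__ dispatch.
    def __init__(self, value):
        self.value = value

    def __mul__(self, other):
        if isinstance(other, TinyQuant):    # quantized x quantized: first branch
            return TinyQuant(self.value * other.value)
        else:                               # quantized x plain tensor: else branch
            return self.value * other

q = TinyQuant(torch.ones(3))
print(q * torch.arange(3.))  # plain-tensor path returns a raw Tensor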
@@ -416,19 +419,3 @@ def __abs__(self):
 
     def __pos__(self):
         return self
-
-    @classmethod
-    def max_acc_bit_width(cls, *args):
-
-        def _max_int_or_tensor(args):
-            if isinstance(args, QuantTensor):
-                return max_int(bit_width=args.bit_width, signed=False, narrow_range=False)
-            else:
-                return args
-
-        args = map(_max_int_or_tensor, args)
-        res = 1
-        for arg in args:
-            res *= arg
-        res = ceil_ste(torch.log2(res))
-        return res
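For reference, the removed max_acc_bit_width bounded the accumulator width as the ceiling of log2 of the product of each operand's maximum magnitude (max_int for QuantTensor arguments, the raw value otherwise). A plain-Python restatement of that computation, using a simplified unsigned max_int; the real brevitas max_int also handles signed and narrow_range:

import math

def max_int(bit_width):
    # Simplified unsigned maximum for illustration: 2**bit_width - 1.
    return (1 << bit_width) - 1

def max_acc_bit_width(*maxima):
    # Each argument is the maximum magnitude of one operand,
    # e.g. max_int(8) for an 8-bit unsigned operand.
    res = 1
    for m in maxima:
        res *= m
    return math.ceil(math.log2(res))

print(max_acc_bit_width(max_int(8), max_int(8)))  # 16: an 8x8-bit product needs ~16 bits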
7 changes: 7 additions & 0 deletions src/brevitas/quant_tensor/torch_handler.py
@@ -3,6 +3,7 @@
 
 import functools
 import math
+import warnings
 
 import torch
 import torch.nn.functional as F
@@ -309,9 +310,15 @@ def quant_layer(
             bias.bit_width > output_bit_width, bias.bit_width, output_bit_width)
         output_bit_width = output_bit_width + 1
 
+    if compute_output_quant_tensor and (quant_input.zero_point != 0.0).any() or (
+            quant_weight.zero_point != 0.0).any():
+        warnings.warn("Computing zero point of output accumulator not supported yet.")
+        compute_output_quant_tensor = False
+
     if compute_output_quant_tensor:
         if output_zero_point is None:
             output_zero_point = torch.zeros(1).type_as(output)
+
     return create_quant_tensor(
         output,
         output_scale,
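One observation on the added guard, as written in the commit: since and binds tighter than or in Python, the condition parses as (compute_output_quant_tensor and input_zp_nonzero) or weight_zp_nonzero, so a non-zero weight zero point fires the warning even when no QuantTensor output was requested. A minimal demonstration with stand-in booleans:

compute_output_quant_tensor = False
input_zp_nonzero = False
weight_zp_nonzero = True

# As committed: `A and B or C` groups as `(A and B) or C`.
print(compute_output_quant_tensor and input_zp_nonzero or weight_zp_nonzero)  # True

# Parenthesization that gates both zero-point checks on the flag:
print(compute_output_quant_tensor and (input_zp_nonzero or weight_zp_nonzero))  # False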
