Skip to content

Commit

Permalink
Feat: remove quant metadata from quantlayer
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 28, 2024
1 parent 3be48cb commit f5cc575
Show file tree
Hide file tree
Showing 22 changed files with 211 additions and 423 deletions.
23 changes: 14 additions & 9 deletions src/brevitas/export/common/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.nn.quant_layer import QuantNonLinearActLayer

__all__ = ['BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin']

Expand Down Expand Up @@ -130,21 +131,25 @@ def zero_point_with_dtype(cls, signed, bit_width, zero_point):

@classmethod
def quant_input_zero_point(cls, module):
signed = module.is_quant_input_signed
zero_point = module.quant_input_zero_point()
bit_width = module.quant_input_bit_width()
signed = module.input_quant.signed()
zero_point = module.input_quant.zero_point()
bit_width = module.input_quant.bit_width()
return cls.zero_point_with_dtype(signed, bit_width, zero_point)

@classmethod
def quant_weight_zero_point(cls, module):
signed = module.is_quant_weight_signed
zero_point = module.quant_weight_zero_point()
bit_width = module.quant_weight_bit_width()
signed = module.weight_quant.is_signed
zero_point = module.quant_weight().zero_point
bit_width = module.weight_quant.bit_width()
return cls.zero_point_with_dtype(signed, bit_width, zero_point)

@classmethod
def quant_output_zero_point(cls, module):
signed = module.is_quant_output_signed
zero_point = module.quant_output_zero_point()
bit_width = module.quant_output_bit_width()
if isinstance(module, QuantNonLinearActLayer):
quant = module.act_quant
else:
quant = module.output_quant
signed = quant.signed()
zero_point = quant.zero_point()
bit_width = quant.bit_width()
return cls.zero_point_with_dtype(signed, bit_width, zero_point)
30 changes: 15 additions & 15 deletions src/brevitas/export/onnx/standard/qoperator/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def torch_8b_dtype(cls, is_signed):

@classmethod
def quant_output_shape(cls, module):
cached_out = module._cached_out
cached_out = module.output_quant._cached_act
if cached_out is None:
raise RuntimeError("Caching of outputs is required to export")
return cached_out.shape
Expand All @@ -57,19 +57,19 @@ def quant_output_shape(cls, module):
def output_quant_symbolic_kwargs(cls, module):
if module.is_output_quant_enabled:
return {
'output_scale': module.quant_output_scale(),
'output_scale': module.output_quant.scale(),
'output_zero_point': cls.quant_output_zero_point(module),
'output_dtype': cls.torch_8b_dtype(module.is_quant_output_signed),
'output_axis': cls.quant_axis(module.quant_output_scale())}
'output_dtype': cls.torch_8b_dtype(module.output_quant.signed()),
'output_axis': cls.quant_axis(module.output_quant.scale())}
else:
return None

@classmethod
def output_clip_symbolic_kwargs(cls, module):
if module.is_output_quant_enabled:
narrow = module.is_quant_output_narrow_range
signed = module.is_quant_output_signed
bit_width = module.quant_output_bit_width()
signed = module.output_quant.signed()
bit_width = module.output_quant.bit_width()
return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
else:
return None
Expand All @@ -78,34 +78,34 @@ def output_clip_symbolic_kwargs(cls, module):
def input_clip_symbolic_kwargs(cls, module):
if module.is_input_quant_enabled:
narrow = module.is_quant_input_narrow_range
signed = module.is_quant_input_signed
bit_width = module.quant_input_bit_width()
signed = module.input_quant.signed()
bit_width = module.input_quant.bit_width()
return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
else:
return None

@classmethod
def output_dequant_symbolic_kwargs(cls, module):
return {
'input_scale': module.quant_output_scale(),
'input_scale': module.output_quant.scale(),
'input_zero_point': cls.quant_output_zero_point(module),
'input_axis': cls.quant_axis(module.quant_output_scale())}
'input_axis': cls.quant_axis(module.output_quant.scale())}

@classmethod
def input_quant_symbolic_kwargs(cls, module):
if module.is_input_quant_enabled:
return {
'output_scale': module.quant_input_scale(),
'output_scale': module.input_quant.scale(),
'output_zero_point': cls.quant_input_zero_point(module),
'output_dtype': cls.torch_8b_dtype(module.is_quant_input_signed),
'output_axis': cls.quant_axis(module.quant_input_scale())}
'output_dtype': cls.torch_8b_dtype(module.input_quant.signed()),
'output_axis': cls.quant_axis(module.input_quant.scale())}
else:
return None

@classmethod
def input_dequant_symbolic_kwargs(cls, module):
if module._cached_inp is not None:
return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp)
if module.input_quant._cached_act is not None:
return cls.dequant_symbolic_kwargs_from_cached_io(module.input_quant._cached_act)
else:
return None

Expand Down
26 changes: 13 additions & 13 deletions src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class StdQOpONNXQuantWBIOLHandler(StdQOpONNXQuantLayerHandler, ABC):
@staticmethod
def int_weight(module: QuantWBIOL):
int_weight = module.int_weight(float_datatype=False).detach()
if module.is_quant_weight_signed:
if module.weight_quant.is_signed:
return int_weight.type(torch.int8)
else:
return int_weight.type(torch.uint8)
Expand All @@ -48,13 +48,13 @@ def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
assert not module.is_quant_output_narrow_range, 'Narrow output quant not supported'
if module.is_input_quant_enabled:
assert not module.is_quant_input_narrow_range, 'Narrow output quant not supported'
cls.validate_8b_bit_width(module.quant_weight_bit_width(), le_then=True)
cls.validate_8b_bit_width(module.quant_input_bit_width(), le_then=True)
cls.validate_8b_bit_width(module.quant_output_bit_width(), le_then=True)
cls.validate_8b_bit_width(module.weight_quant.bit_width(), le_then=True)
cls.validate_8b_bit_width(module.input_quant.bit_width(), le_then=True)
cls.validate_8b_bit_width(module.output_quant.bit_width(), le_then=True)
if module.bias is not None and requires_quant_bias:
assert module.is_bias_quant_enabled
assert module.is_quant_bias_signed
cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True)
cls.validate_32b_bit_width(module.bias_quant.bit_width(), le_then=True)

def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d]):
self.validate(module)
Expand Down Expand Up @@ -108,14 +108,14 @@ class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC):

def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]):
conv_symbolic_kwargs = {
'input_scale': module.quant_input_scale(),
'input_scale': module.input_quant.scale(),
'input_zero_point': self.quant_input_zero_point(module),
'int_weight': self.int_weight(module),
'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()),
'weight_scale': to_0dim_if_scalar(module.quant_weight().scale.flatten()),
'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
'output_scale': module.quant_output_scale(),
'output_scale': module.output_quant.scale(),
'output_zero_point': self.quant_output_zero_point(module),
'output_dtype': self.torch_8b_dtype(module.is_quant_output_signed),
'output_dtype': self.torch_8b_dtype(module.output_quant.signed()),
'int_bias': self.int_bias(module),
'out_shape': self.quant_output_shape(module),
'kernel_size': list(module.kernel_size),
Expand Down Expand Up @@ -145,14 +145,14 @@ class StdQOpONNXQuantLinearHandler(StdQOpONNXQuantWBIOLHandler):
# Convert linear to conv1d to handle bias
def op_symbolic_kwargs(self, module: QuantLinear):
conv_symbolic_kwargs = {
'input_scale': module.quant_input_scale(),
'input_scale': module.input_quant.scale(),
'input_zero_point': self.quant_input_zero_point(module),
'int_weight': self.int_weight(module).view(module.out_features, module.in_features, 1),
'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()),
'weight_scale': to_0dim_if_scalar(module.quant_weight().scale.flatten()),
'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
'output_scale': module.quant_output_scale(),
'output_scale': module.output_quant.scale(),
'output_zero_point': self.quant_output_zero_point(module),
'output_dtype': self.torch_8b_dtype(module.is_quant_output_signed),
'output_dtype': self.torch_8b_dtype(module.output_quant.signed()),
'int_bias': self.int_bias(module),
'out_shape': self.quant_output_shape(module) + (1,),
'kernel_size': [1],
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/export/torch/qoperator/handler/act.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def explicit_output_dtype(cls) -> bool:
@classmethod
def validate(cls, module: QuantNLAL):
assert not module.is_input_quant_enabled, 'Input quantization not supported'
cls.validate_8b_bit_width(module.quant_act_bit_width(), le_then=False)
cls.validate_8b_bit_width(module.act_quant.bit_width(), le_then=False)

def prepare_for_export(self, module: QuantNLAL):
self.validate(module)
Expand Down
13 changes: 9 additions & 4 deletions src/brevitas/export/torch/qoperator/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.common.handler.base import BitWidthHandlerMixin
from brevitas.export.common.handler.base import ZeroPointHandlerMixin
from brevitas.nn.quant_layer import QuantNonLinearActLayer

SCALAR_SHAPE = ()

Expand Down Expand Up @@ -55,17 +56,21 @@ def gen_quant_impl_kwargs(

@classmethod
def prepare_input_quant(cls, module):
scale = module.quant_input_scale()
scale = module.input_quant.scale()
zero_point = cls.quant_input_zero_point(module)
signed = module.is_quant_input_signed
signed = module.input_quant.signed()
quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed)
return quant_impl, quant_kwargs

@classmethod
def prepare_output_quant(cls, module):
scale = module.quant_output_scale()
if isinstance(module, QuantNonLinearActLayer):
quant = module.act_quant
else:
quant = module.output_quant
scale = quant.scale()
zero_point = cls.quant_output_zero_point(module)
signed = module.is_quant_output_signed
signed = quant.signed()
incl_dtype = cls.explicit_output_dtype()
quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed, incl_dtype)
return quant_impl, quant_kwargs
10 changes: 5 additions & 5 deletions src/brevitas/export/torch/qoperator/handler/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def prepare_bias(cls, module: QuantWBIOL):

@classmethod
def prepare_weight_quant(cls, module: QuantWBIOL):
cls.validate_bit_width(module.quant_weight_bit_width(), 7, le_then=True)
cls.validate_8b_bit_width(module.quant_input_bit_width(), le_then=False)
cls.validate_8b_bit_width(module.quant_output_bit_width(), le_then=False)
scale = module.quant_weight_scale()
cls.validate_bit_width(module.weight_quant.bit_width(), 7, le_then=True)
cls.validate_8b_bit_width(module.input_quant.bit_width(), le_then=False)
cls.validate_8b_bit_width(module.output_quant.bit_width(), le_then=False)
scale = module.weight_quant.scale()
zero_point = cls.quant_weight_zero_point(module)
signed = module.is_quant_weight_signed
signed = module.weight_quant.is_signed
weight = module.weight.detach()
quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed)
return quant_impl, (weight,), quant_kwargs
Expand Down
8 changes: 4 additions & 4 deletions src/brevitas/graph/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,10 @@ def process_input(self, inp):
# if the quant_input is not already cached, then get
# metadata from QuantWBIOL module
if self.quant_input is None:
inp_scale = self.layer.quant_input_scale()
inp_zero_point = self.layer.quant_input_zero_point()
inp_bit_width = self.layer.quant_input_bit_width()
inp_signed = self.layer.is_quant_input_signed
inp_scale = self.layer.input_quant.scale()
inp_zero_point = self.layer.input_quant.zero_point()
inp_bit_width = self.layer.input_quant.bit_width()
inp_signed = self.layer.input_quant.signed()
inp_training = self.layer.training

# If using quantized activations, inp could be QuantTensor. In
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/graph/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,10 @@ def align_input_quant(
# If it is a QuantIdentity already, simply modify tensor_quant or the scaling implementations
# based on whether we need to align the sign or not
if isinstance(module, qnn.QuantIdentity):
if align_sign or module.is_quant_act_signed == shared_quant_identity.is_quant_act_signed:
if align_sign or module.input_quant.signed() == shared_quant_identity.input_quant.signed():
return shared_quant_identity
else:
assert not module.is_quant_act_signed and shared_quant_identity.is_quant_act_signed
assert not module.input_quant.signed() and shared_quant_identity.input_quant.signed()
quant_module_class, quant_module_kwargs = quant_identity_map['unsigned']
return (
quant_module_class,
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/graph/quantize_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def are_inputs_unsigned(model, node, is_unsigned_list, quant_act_map, unsigned_a
elif isinstance(inp_module, tuple(SIGN_PRESERVING_MODULES)):
are_inputs_unsigned(
model, inp_node, is_unsigned_list, quant_act_map, unsigned_act_tuple)
elif hasattr(inp_module, 'is_quant_act_signed'):
is_unsigned_list.append(not inp_module.is_quant_act_signed)
elif hasattr(inp_module, 'input_quant'):
is_unsigned_list.append(not inp_module.input_quant.signed())
else:
is_unsigned_list.append(False)
elif inp_node.op == 'call_function':
Expand Down
68 changes: 21 additions & 47 deletions src/brevitas/nn/mixin/act.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from abc import ABCMeta
from abc import abstractmethod
from typing import Optional, Type, Union
from warnings import warn

from torch.nn import Module

from brevitas.inject import ExtendedInjector
from brevitas.inject import Injector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyProtocol
from brevitas.quant import NoneActQuant

Expand Down Expand Up @@ -42,22 +40,10 @@ def is_input_quant_enabled(self):
def is_quant_input_narrow_range(self): # TODO make abstract once narrow range can be cached
return self.input_quant.is_narrow_range

@property
@abstractmethod
def is_quant_input_signed(self):
pass

@abstractmethod
def quant_input_scale(self):
pass

@abstractmethod
def quant_input_zero_point(self):
pass

@abstractmethod
def quant_input_bit_width(self):
pass
# @property
# @abstractmethod
# def is_quant_input_signed(self):
# pass


class QuantOutputMixin(QuantProxyMixin):
Expand All @@ -83,22 +69,10 @@ def is_output_quant_enabled(self):
def is_quant_output_narrow_range(self): # TODO make abstract once narrow range can be cached
return self.output_quant.is_narrow_range

@property
@abstractmethod
def is_quant_output_signed(self):
pass

@abstractmethod
def quant_output_scale(self):
pass

@abstractmethod
def quant_output_zero_point(self):
pass

@abstractmethod
def quant_output_bit_width(self):
pass
# @property
# @abstractmethod
# def is_quant_output_signed(self):
# pass


class QuantNonLinearActMixin(QuantProxyMixin):
Expand Down Expand Up @@ -133,19 +107,19 @@ def is_act_quant_enabled(self):
def is_quant_act_narrow_range(self): # TODO make abstract once narrow range can be cached
return self.act_quant.is_narrow_range

@property
@abstractmethod
def is_quant_act_signed(self):
pass
# @property
# @abstractmethod
# def is_quant_act_signed(self):
# pass

@abstractmethod
def quant_act_scale(self):
pass
# @abstractmethod
# def quant_act_scale(self):
# pass

@abstractmethod
def quant_act_zero_point(self):
pass
# @abstractmethod
# def quant_act_zero_point(self):
# pass

@abstractmethod
def quant_act_bit_width(self):
pass
# @abstractmethod
# def quant_act_bit_width(self):
# pass
Loading

0 comments on commit f5cc575

Please sign in to comment.