Skip to content

Commit

Permalink
Feat: remove quant metadata from quantlayer
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Mar 5, 2024
1 parent 2fdcb42 commit a36d65f
Show file tree
Hide file tree
Showing 27 changed files with 283 additions and 569 deletions.
6 changes: 3 additions & 3 deletions docs/tutorials/quant_tensor_quant_conv2d_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@
}
],
"source": [
"print(f'Is weight quant enabled: {default_quant_conv.is_weight_quant_enabled}')\n",
"print(f'Is bias quant enabled: {default_quant_conv.is_bias_quant_enabled}')\n",
"print(f'Is weight quant enabled: {default_quant_conv.weight_quant.is_quant_enabled}')\n",
"print(f'Is bias quant enabled: {default_quant_conv.bias_quant.is_quant_enabled}')\n",
"print(f'Is input quant enabled: {default_quant_conv.is_input_quant_enabled}')\n",
"print(f'Is output quant enabled: {default_quant_conv.is_output_quant_enabled}')"
"print(f'Is output quant enabled: {default_quant_conv.output_quant.is_quant_enabled}')"
]
},
{
Expand Down
42 changes: 17 additions & 25 deletions notebooks/01_quant_tensor_quant_conv2d_overview.ipynb

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions notebooks/03_anatomy_of_a_quantizer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Indeed we can verify that `quant_weight_scale()` is equal to `weight.abs().max()`:"
"Indeed we can verify that `weight_quant.scale()` is equal to `weight.abs().max()`:"
]
},
{
Expand All @@ -792,7 +792,7 @@
}
],
"source": [
"assert_with_message((param_from_max_quant_conv.quant_weight_scale() == param_from_max_quant_conv.weight.abs().max()).item())"
"assert_with_message((param_from_max_quant_conv.weight_quant.scale() == param_from_max_quant_conv.weight.abs().max()).item())"
]
},
{
Expand Down Expand Up @@ -1024,7 +1024,7 @@
}
],
"source": [
"assert_with_message((quant_conv1.quant_weight_scale() == quant_conv2.quant_weight_scale()).item())"
"assert_with_message((quant_conv1.weight_quant.scale() == quant_conv2.weight_quant.scale()).item())"
]
},
{
Expand Down Expand Up @@ -1059,9 +1059,9 @@
" return module.weight.abs().mean()\n",
" \n",
"quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=SharedParamFromMeanWeightQuantizer)\n",
"old_quant_conv1_scale = quant_conv1.quant_weight_scale()\n",
"old_quant_conv1_scale = quant_conv1.weight_quant.scale()\n",
"quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=quant_conv1.weight_quant)\n",
"new_quant_conv1_scale = quant_conv1.quant_weight_scale()\n",
"new_quant_conv1_scale = quant_conv1.weight_quant.scale()\n",
"\n",
"assert_with_message(not (old_quant_conv1_scale == new_quant_conv1_scale).item())"
]
Expand All @@ -1080,7 +1080,7 @@
}
],
"source": [
"assert_with_message((new_quant_conv1_scale == quant_conv2.quant_weight_scale()).item())"
"assert_with_message((new_quant_conv1_scale == quant_conv2.weight_quant.scale()).item())"
]
},
{
Expand Down Expand Up @@ -1134,7 +1134,7 @@
"quant_conv_w_init = QuantConv2d(3, 2, (3, 3), weight_quant=ParamFromMaxWeightQuantizer)\n",
"torch.nn.init.uniform_(quant_conv_w_init.weight)\n",
"\n",
"assert_with_message(not (quant_conv_w_init.weight.abs().max() == quant_conv_w_init.quant_weight_scale()).item())"
"assert_with_message(not (quant_conv_w_init.weight.abs().max() == quant_conv_w_init.weight_quant.scale()).item())"
]
},
{
Expand All @@ -1160,7 +1160,7 @@
"source": [
"quant_conv_w_init.weight_quant.init_tensor_quant()\n",
"\n",
"assert_with_message((quant_conv_w_init.weight.abs().max() == quant_conv_w_init.quant_weight_scale()).item())"
"assert_with_message((quant_conv_w_init.weight.abs().max() == quant_conv_w_init.weight_quant.scale()).item())"
]
},
{
Expand Down
14 changes: 7 additions & 7 deletions notebooks/Brevitas_TVMCon2021.ipynb

Large diffs are not rendered by default.

23 changes: 14 additions & 9 deletions src/brevitas/export/common/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.nn.quant_layer import QuantNonLinearActLayer

__all__ = ['BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin']

Expand Down Expand Up @@ -130,21 +131,25 @@ def zero_point_with_dtype(cls, signed, bit_width, zero_point):

@classmethod
def quant_input_zero_point(cls, module):
signed = module.is_quant_input_signed
zero_point = module.quant_input_zero_point()
bit_width = module.quant_input_bit_width()
signed = module.input_quant.signed()
zero_point = module.input_quant.zero_point()
bit_width = module.input_quant.bit_width()
return cls.zero_point_with_dtype(signed, bit_width, zero_point)

@classmethod
def quant_weight_zero_point(cls, module):
signed = module.is_quant_weight_signed
zero_point = module.quant_weight_zero_point()
bit_width = module.quant_weight_bit_width()
signed = module.weight_quant.is_signed
zero_point = module.quant_weight().zero_point
bit_width = module.weight_quant.bit_width()
return cls.zero_point_with_dtype(signed, bit_width, zero_point)

@classmethod
def quant_output_zero_point(cls, module):
signed = module.is_quant_output_signed
zero_point = module.quant_output_zero_point()
bit_width = module.quant_output_bit_width()
if isinstance(module, QuantNonLinearActLayer):
quant = module.act_quant
else:
quant = module.output_quant
signed = quant.signed()
zero_point = quant.zero_point()
bit_width = quant.bit_width()
return cls.zero_point_with_dtype(signed, bit_width, zero_point)
14 changes: 7 additions & 7 deletions src/brevitas/export/onnx/standard/qoperator/handler/act.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ class StdQOpONNXQuantNLALHandler(StdQOpONNXQuantLayerHandler, ABC):

@classmethod
def validate(cls, module: QuantNLAL):
if cls.input_quant_supported and module.is_input_quant_enabled:
assert not module.is_quant_input_narrow_range, "Narrow range quant not supported."
elif not cls.input_quant_supported and module.is_input_quant_enabled:
if cls.input_quant_supported and module.input_quant.is_quant_enabled:
assert not module.input_quant.is_narrow_range, "Narrow range quant not supported."
elif not cls.input_quant_supported and module.input_quant.is_quant_enabled:
raise RuntimeError("Input quant not supported.")
if module.is_act_quant_enabled:
assert not module.is_quant_act_narrow_range, "Narrow range quant not supported."
input_bit_width = module.quant_input_bit_width()
act_bit_width = module.quant_act_bit_width()
if module.act_quant.is_quant_enabled:
assert not module.act_quant.is_narrow_range, "Narrow range quant not supported."
input_bit_width = module.input_quant.bit_width()
act_bit_width = module.act_quant.bit_width()
if input_bit_width is not None:
cls.validate_8b_bit_width(input_bit_width, le_then=True)
if act_bit_width is not None:
Expand Down
47 changes: 26 additions & 21 deletions src/brevitas/export/onnx/standard/qoperator/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from brevitas.export.onnx.standard.function import DequantizeLinearFn
from brevitas.export.onnx.standard.function import IntClipFn
from brevitas.export.onnx.standard.function import QuantizeLinearFn
from brevitas.nn.quant_layer import QuantNonLinearActLayer


class StdQOpONNXQuantLayerHandler(ONNXBaseHandler,
Expand Down Expand Up @@ -48,64 +49,68 @@ def torch_8b_dtype(cls, is_signed):

@classmethod
def quant_output_shape(cls, module):
cached_out = module._cached_out
cached_out = module.output_quant._cached_act
if cached_out is None:
raise RuntimeError("Caching of outputs is required to export")
return cached_out.shape

@classmethod
def output_quant_symbolic_kwargs(cls, module):
if module.is_output_quant_enabled:
quant_proxy = module.act_quant if isinstance(
module, QuantNonLinearActLayer) else module.output_quant
if quant_proxy.is_quant_enabled:
return {
'output_scale': module.quant_output_scale(),
'output_scale': quant_proxy.scale(),
'output_zero_point': cls.quant_output_zero_point(module),
'output_dtype': cls.torch_8b_dtype(module.is_quant_output_signed),
'output_axis': cls.quant_axis(module.quant_output_scale())}
'output_dtype': cls.torch_8b_dtype(quant_proxy.signed()),
'output_axis': cls.quant_axis(quant_proxy.scale())}
else:
return None

@classmethod
def output_clip_symbolic_kwargs(cls, module):
if module.is_output_quant_enabled:
narrow = module.is_quant_output_narrow_range
signed = module.is_quant_output_signed
bit_width = module.quant_output_bit_width()
quant_proxy = module.act_quant if isinstance(
module, QuantNonLinearActLayer) else module.output_quant
if quant_proxy.is_quant_enabled:
narrow = quant_proxy.is_narrow_range
signed = quant_proxy.signed()
bit_width = quant_proxy.bit_width()
return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
else:
return None

@classmethod
def input_clip_symbolic_kwargs(cls, module):
if module.is_input_quant_enabled:
narrow = module.is_quant_input_narrow_range
signed = module.is_quant_input_signed
bit_width = module.quant_input_bit_width()
if module.input_quant.is_quant_enabled:
narrow = module.input_quant.is_narrow_range
signed = module.input_quant.signed()
bit_width = module.input_quant.bit_width()
return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
else:
return None

@classmethod
def output_dequant_symbolic_kwargs(cls, module):
return {
'input_scale': module.quant_output_scale(),
'input_scale': module.output_quant.scale(),
'input_zero_point': cls.quant_output_zero_point(module),
'input_axis': cls.quant_axis(module.quant_output_scale())}
'input_axis': cls.quant_axis(module.output_quant.scale())}

@classmethod
def input_quant_symbolic_kwargs(cls, module):
if module.is_input_quant_enabled:
if module.input_quant.is_quant_enabled:
return {
'output_scale': module.quant_input_scale(),
'output_scale': module.input_quant.scale(),
'output_zero_point': cls.quant_input_zero_point(module),
'output_dtype': cls.torch_8b_dtype(module.is_quant_input_signed),
'output_axis': cls.quant_axis(module.quant_input_scale())}
'output_dtype': cls.torch_8b_dtype(module.input_quant.signed()),
'output_axis': cls.quant_axis(module.input_quant.scale())}
else:
return None

@classmethod
def input_dequant_symbolic_kwargs(cls, module):
if module._cached_inp is not None:
return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp)
if module.input_quant._cached_act is not None:
return cls.dequant_symbolic_kwargs_from_cached_io(module.input_quant._cached_act)
else:
return None

Expand Down
40 changes: 20 additions & 20 deletions src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class StdQOpONNXQuantWBIOLHandler(StdQOpONNXQuantLayerHandler, ABC):
@staticmethod
def int_weight(module: QuantWBIOL):
int_weight = module.int_weight(float_datatype=False).detach()
if module.is_quant_weight_signed:
if module.weight_quant.is_signed:
return int_weight.type(torch.int8)
else:
return int_weight.type(torch.uint8)
Expand All @@ -41,20 +41,20 @@ def int_bias(module: QuantWBIOL):

@classmethod
def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
assert module.is_weight_quant_enabled, 'Weight quant required'
assert module.is_output_quant_enabled, 'Output quant required'
assert module.weight_quant.is_quant_enabled, 'Weight quant required'
assert module.output_quant.is_quant_enabled, 'Output quant required'
# Handling narrow_range across the network is difficult due to the fact that
# it's not part of QuantTensor, and so it can't be cached
assert not module.is_quant_output_narrow_range, 'Narrow output quant not supported'
if module.is_input_quant_enabled:
assert not module.is_quant_input_narrow_range, 'Narrow output quant not supported'
cls.validate_8b_bit_width(module.quant_weight_bit_width(), le_then=True)
cls.validate_8b_bit_width(module.quant_input_bit_width(), le_then=True)
cls.validate_8b_bit_width(module.quant_output_bit_width(), le_then=True)
assert not module.output_quant.is_narrow_range, 'Narrow output quant not supported'
if module.input_quant.is_quant_enabled:
assert not module.input_quant.is_narrow_range, 'Narrow output quant not supported'
cls.validate_8b_bit_width(module.weight_quant.bit_width(), le_then=True)
cls.validate_8b_bit_width(module.input_quant.bit_width(), le_then=True)
cls.validate_8b_bit_width(module.output_quant.bit_width(), le_then=True)
if module.bias is not None and requires_quant_bias:
assert module.is_bias_quant_enabled
assert module.is_quant_bias_signed
cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True)
assert module.bias_quant.is_quant_enabled
assert module.bias_quant.is_signed
cls.validate_32b_bit_width(module.bias_quant.bit_width(), le_then=True)

def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d]):
self.validate(module)
Expand Down Expand Up @@ -108,14 +108,14 @@ class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC):

def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]):
conv_symbolic_kwargs = {
'input_scale': module.quant_input_scale(),
'input_scale': module.input_quant.scale(),
'input_zero_point': self.quant_input_zero_point(module),
'int_weight': self.int_weight(module),
'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()),
'weight_scale': to_0dim_if_scalar(module.quant_weight().scale.flatten()),
'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
'output_scale': module.quant_output_scale(),
'output_scale': module.output_quant.scale(),
'output_zero_point': self.quant_output_zero_point(module),
'output_dtype': self.torch_8b_dtype(module.is_quant_output_signed),
'output_dtype': self.torch_8b_dtype(module.output_quant.signed()),
'int_bias': self.int_bias(module),
'out_shape': self.quant_output_shape(module),
'kernel_size': list(module.kernel_size),
Expand Down Expand Up @@ -145,14 +145,14 @@ class StdQOpONNXQuantLinearHandler(StdQOpONNXQuantWBIOLHandler):
# Convert linear to conv1d to handle bias
def op_symbolic_kwargs(self, module: QuantLinear):
conv_symbolic_kwargs = {
'input_scale': module.quant_input_scale(),
'input_scale': module.input_quant.scale(),
'input_zero_point': self.quant_input_zero_point(module),
'int_weight': self.int_weight(module).view(module.out_features, module.in_features, 1),
'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()),
'weight_scale': to_0dim_if_scalar(module.quant_weight().scale.flatten()),
'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
'output_scale': module.quant_output_scale(),
'output_scale': module.output_quant.scale(),
'output_zero_point': self.quant_output_zero_point(module),
'output_dtype': self.torch_8b_dtype(module.is_quant_output_signed),
'output_dtype': self.torch_8b_dtype(module.output_quant.signed()),
'int_bias': self.int_bias(module),
'out_shape': self.quant_output_shape(module) + (1,),
'kernel_size': [1],
Expand Down
6 changes: 3 additions & 3 deletions src/brevitas/export/torch/qoperator/handler/act.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ def explicit_output_dtype(cls) -> bool:

@classmethod
def validate(cls, module: QuantNLAL):
assert not module.is_input_quant_enabled, 'Input quantization not supported'
cls.validate_8b_bit_width(module.quant_act_bit_width(), le_then=False)
assert not module.input_quant.is_quant_enabled, 'Input quantization not supported'
cls.validate_8b_bit_width(module.act_quant.bit_width(), le_then=False)

def prepare_for_export(self, module: QuantNLAL):
self.validate(module)
self.qf_impl, self.qf_kwargs = self.prepare_qf(module)
if module.is_act_quant_enabled:
if module.act_quant.is_quant_enabled:
self.output_quant_impl, self.output_quant_kwargs = self.prepare_output_quant(module)
self.return_quant_tensor = module.return_quant_tensor

Expand Down
13 changes: 9 additions & 4 deletions src/brevitas/export/torch/qoperator/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.common.handler.base import BitWidthHandlerMixin
from brevitas.export.common.handler.base import ZeroPointHandlerMixin
from brevitas.nn.quant_layer import QuantNonLinearActLayer

SCALAR_SHAPE = ()

Expand Down Expand Up @@ -55,17 +56,21 @@ def gen_quant_impl_kwargs(

@classmethod
def prepare_input_quant(cls, module):
scale = module.quant_input_scale()
scale = module.input_quant.scale()
zero_point = cls.quant_input_zero_point(module)
signed = module.is_quant_input_signed
signed = module.input_quant.signed()
quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed)
return quant_impl, quant_kwargs

@classmethod
def prepare_output_quant(cls, module):
scale = module.quant_output_scale()
if isinstance(module, QuantNonLinearActLayer):
quant = module.act_quant
else:
quant = module.output_quant
scale = quant.scale()
zero_point = cls.quant_output_zero_point(module)
signed = module.is_quant_output_signed
signed = quant.signed()
incl_dtype = cls.explicit_output_dtype()
quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed, incl_dtype)
return quant_impl, quant_kwargs
Loading

0 comments on commit a36d65f

Please sign in to comment.