From f5cc57587a71b8590c18076f5a71b7f0a9270b79 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 27 Feb 2024 15:17:56 +0000
Subject: [PATCH] Feat: remove quant metadata from quantlayer

---
 src/brevitas/export/common/handler/base.py   |  23 +--
 .../onnx/standard/qoperator/handler/base.py  |  30 ++--
 .../standard/qoperator/handler/parameter.py  |  26 +--
 .../export/torch/qoperator/handler/act.py    |   2 +-
 .../export/torch/qoperator/handler/base.py   |  13 +-
 .../torch/qoperator/handler/parameter.py     |  10 +-
 src/brevitas/graph/gpxq.py                   |   8 +-
 src/brevitas/graph/quantize.py               |   4 +-
 src/brevitas/graph/quantize_impl.py          |   4 +-
 src/brevitas/nn/mixin/act.py                 |  68 +++-----
 src/brevitas/nn/mixin/base.py                |  61 +------
 src/brevitas/nn/mixin/parameter.py           |  79 ++-------
 src/brevitas/nn/quant_layer.py               | 152 +-----------------
 src/brevitas/proxy/parameter_quant.py        |  39 ++++-
 src/brevitas/proxy/runtime_quant.py          |  65 +++++---
 .../super_resolution/utils/evaluate.py       |   2 +-
 tests/brevitas/export/test_torch_qop.py      |  10 +-
 tests/brevitas/graph/test_calibration.py     |   2 +-
 tests/brevitas/nn/test_linear.py             |   2 +-
 tests/brevitas/nn/test_wbiol.py              |  22 +--
 tests/brevitas/proxy/test_act_scaling.py     |   2 +-
 tests/brevitas/proxy/test_weight_scaling.py  |  10 +-
 22 files changed, 211 insertions(+), 423 deletions(-)

diff --git a/src/brevitas/export/common/handler/base.py b/src/brevitas/export/common/handler/base.py
index f7706f450..e346639d0 100644
--- a/src/brevitas/export/common/handler/base.py
+++ b/src/brevitas/export/common/handler/base.py
@@ -11,6 +11,7 @@
 
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int
+from brevitas.nn.quant_layer import QuantNonLinearActLayer
 
 __all__ = ['BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin']
@@ -130,21 +131,25 @@ def zero_point_with_dtype(cls, signed, bit_width, zero_point):
 
     @classmethod
     def quant_input_zero_point(cls, module):
-        signed = module.is_quant_input_signed
-        zero_point = module.quant_input_zero_point()
-        bit_width = module.quant_input_bit_width()
+        signed = module.input_quant.signed()
+        zero_point = module.input_quant.zero_point()
+        bit_width = module.input_quant.bit_width()
         return cls.zero_point_with_dtype(signed, bit_width, zero_point)
 
     @classmethod
     def quant_weight_zero_point(cls, module):
-        signed = module.is_quant_weight_signed
-        zero_point = module.quant_weight_zero_point()
-        bit_width = module.quant_weight_bit_width()
+        signed = module.weight_quant.is_signed
+        zero_point = module.quant_weight().zero_point
+        bit_width = module.weight_quant.bit_width()
         return cls.zero_point_with_dtype(signed, bit_width, zero_point)
 
     @classmethod
     def quant_output_zero_point(cls, module):
-        signed = module.is_quant_output_signed
-        zero_point = module.quant_output_zero_point()
-        bit_width = module.quant_output_bit_width()
+        if isinstance(module, QuantNonLinearActLayer):
+            quant = module.act_quant
+        else:
+            quant = module.output_quant
+        signed = quant.signed()
+        zero_point = quant.zero_point()
+        bit_width = quant.bit_width()
         return cls.zero_point_with_dtype(signed, bit_width, zero_point)
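
The common handlers now read sign, zero-point and bit-width from the proxies, or from the QuantTensor returned by quant_weight(), instead of the removed layer-level getters. A minimal sketch of the new call pattern, assuming a stock QuantLinear with its default signed 8-bit weight quantizer:

    import brevitas.nn as qnn

    layer = qnn.QuantLinear(16, 8, bias=True)

    # Weight metadata lives on the weight proxy and on the quantized tensor.
    signed = layer.weight_quant.is_signed          # True for the default quantizer
    bit_width = layer.weight_quant.bit_width()     # tensor(8.)
    zero_point = layer.quant_weight().zero_point   # tensor(0.)
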
diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/base.py b/src/brevitas/export/onnx/standard/qoperator/handler/base.py
index e614d2ed5..35f713936 100644
--- a/src/brevitas/export/onnx/standard/qoperator/handler/base.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/base.py
@@ -48,7 +48,7 @@ def torch_8b_dtype(cls, is_signed):
 
     @classmethod
     def quant_output_shape(cls, module):
-        cached_out = module._cached_out
+        cached_out = module.output_quant._cached_act
         if cached_out is None:
             raise RuntimeError("Caching of outputs is required to export")
         return cached_out.shape
@@ -57,10 +57,10 @@ def output_quant_symbolic_kwargs(cls, module):
         if module.is_output_quant_enabled:
             return {
-                'output_scale': module.quant_output_scale(),
+                'output_scale': module.output_quant.scale(),
                 'output_zero_point': cls.quant_output_zero_point(module),
-                'output_dtype': cls.torch_8b_dtype(module.is_quant_output_signed),
-                'output_axis': cls.quant_axis(module.quant_output_scale())}
+                'output_dtype': cls.torch_8b_dtype(module.output_quant.signed()),
+                'output_axis': cls.quant_axis(module.output_quant.scale())}
         else:
             return None
 
     @classmethod
@@ -68,8 +68,8 @@ def output_quant_symbolic_kwargs(cls, module):
     def output_clip_symbolic_kwargs(cls, module):
         if module.is_output_quant_enabled:
             narrow = module.is_quant_output_narrow_range
-            signed = module.is_quant_output_signed
-            bit_width = module.quant_output_bit_width()
+            signed = module.output_quant.signed()
+            bit_width = module.output_quant.bit_width()
             return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
         else:
             return None
@@ -78,8 +78,8 @@ def output_clip_symbolic_kwargs(cls, module):
     def input_clip_symbolic_kwargs(cls, module):
         if module.is_input_quant_enabled:
             narrow = module.is_quant_input_narrow_range
-            signed = module.is_quant_input_signed
-            bit_width = module.quant_input_bit_width()
+            signed = module.input_quant.signed()
+            bit_width = module.input_quant.bit_width()
             return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
         else:
             return None
@@ -87,25 +87,25 @@ def input_clip_symbolic_kwargs(cls, module):
     @classmethod
     def output_dequant_symbolic_kwargs(cls, module):
         return {
-            'input_scale': module.quant_output_scale(),
+            'input_scale': module.output_quant.scale(),
             'input_zero_point': cls.quant_output_zero_point(module),
-            'input_axis': cls.quant_axis(module.quant_output_scale())}
+            'input_axis': cls.quant_axis(module.output_quant.scale())}
 
     @classmethod
     def input_quant_symbolic_kwargs(cls, module):
         if module.is_input_quant_enabled:
             return {
-                'output_scale': module.quant_input_scale(),
+                'output_scale': module.input_quant.scale(),
                 'output_zero_point': cls.quant_input_zero_point(module),
-                'output_dtype': cls.torch_8b_dtype(module.is_quant_input_signed),
-                'output_axis': cls.quant_axis(module.quant_input_scale())}
+                'output_dtype': cls.torch_8b_dtype(module.input_quant.signed()),
+                'output_axis': cls.quant_axis(module.input_quant.scale())}
         else:
             return None
 
     @classmethod
     def input_dequant_symbolic_kwargs(cls, module):
-        if module._cached_inp is not None:
-            return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp)
+        if module.input_quant._cached_act is not None:
+            return cls.dequant_symbolic_kwargs_from_cached_io(module.input_quant._cached_act)
         else:
             return None
diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
index c4c552c65..376820102 100644
--- a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
@@ -26,7 +26,7 @@ class StdQOpONNXQuantWBIOLHandler(StdQOpONNXQuantLayerHandler, ABC):
     @staticmethod
     def int_weight(module: QuantWBIOL):
         int_weight = module.int_weight(float_datatype=False).detach()
-        if module.is_quant_weight_signed:
+        if module.weight_quant.is_signed:
             return int_weight.type(torch.int8)
         else:
             return int_weight.type(torch.uint8)
@@ -48,13 +48,13 @@ def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
         assert not module.is_quant_output_narrow_range, 'Narrow output quant not supported'
         if module.is_input_quant_enabled:
             assert not module.is_quant_input_narrow_range, 'Narrow output quant not supported'
-        cls.validate_8b_bit_width(module.quant_weight_bit_width(), le_then=True)
-        cls.validate_8b_bit_width(module.quant_input_bit_width(), le_then=True)
-        cls.validate_8b_bit_width(module.quant_output_bit_width(), le_then=True)
+        cls.validate_8b_bit_width(module.weight_quant.bit_width(), le_then=True)
+        cls.validate_8b_bit_width(module.input_quant.bit_width(), le_then=True)
+        cls.validate_8b_bit_width(module.output_quant.bit_width(), le_then=True)
         if module.bias is not None and requires_quant_bias:
             assert module.is_bias_quant_enabled
             assert module.is_quant_bias_signed
-            cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True)
+            cls.validate_32b_bit_width(module.bias_quant.bit_width(), le_then=True)
 
     def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d]):
         self.validate(module)
@@ -108,14 +108,14 @@ class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC):
 
     def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]):
         conv_symbolic_kwargs = {
-            'input_scale': module.quant_input_scale(),
+            'input_scale': module.input_quant.scale(),
             'input_zero_point': self.quant_input_zero_point(module),
             'int_weight': self.int_weight(module),
-            'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()),
+            'weight_scale': to_0dim_if_scalar(module.quant_weight().scale.flatten()),
             'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
-            'output_scale': module.quant_output_scale(),
+            'output_scale': module.output_quant.scale(),
             'output_zero_point': self.quant_output_zero_point(module),
-            'output_dtype': self.torch_8b_dtype(module.is_quant_output_signed),
+            'output_dtype': self.torch_8b_dtype(module.output_quant.signed()),
             'int_bias': self.int_bias(module),
             'out_shape': self.quant_output_shape(module),
             'kernel_size': list(module.kernel_size),
@@ -145,14 +145,14 @@ class StdQOpONNXQuantLinearHandler(StdQOpONNXQuantWBIOLHandler):
     # Convert linear to conv1d to handle bias
     def op_symbolic_kwargs(self, module: QuantLinear):
         conv_symbolic_kwargs = {
-            'input_scale': module.quant_input_scale(),
+            'input_scale': module.input_quant.scale(),
             'input_zero_point': self.quant_input_zero_point(module),
             'int_weight': self.int_weight(module).view(module.out_features, module.in_features, 1),
-            'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()),
+            'weight_scale': to_0dim_if_scalar(module.quant_weight().scale.flatten()),
             'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
-            'output_scale': module.quant_output_scale(),
+            'output_scale': module.output_quant.scale(),
             'output_zero_point': self.quant_output_zero_point(module),
-            'output_dtype': self.torch_8b_dtype(module.is_quant_output_signed),
+            'output_dtype': self.torch_8b_dtype(module.output_quant.signed()),
             'int_bias': self.int_bias(module),
             'out_shape': self.quant_output_shape(module) + (1,),
             'kernel_size': [1],
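
With _cached_out gone from the layer, quant_output_shape() and the dequant kwargs above read the cache held by the output proxy, so an inference pass must run before export. A hedged sketch of populating that cache (the quantizer choice and toy shapes are illustrative, not part of this patch):

    import torch
    import brevitas.nn as qnn
    from brevitas.quant import Int8ActPerTensorFloat

    model = qnn.QuantLinear(16, 8, bias=True, output_quant=Int8ActPerTensorFloat)
    model.cache_inference_quant_out = True
    model.eval()

    # One forward pass stores a _CachedIO on the proxy for the exporter to read.
    model(torch.randn(2, 16))
    assert model.output_quant._cached_act is not None
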
diff --git a/src/brevitas/export/torch/qoperator/handler/act.py b/src/brevitas/export/torch/qoperator/handler/act.py
index 4a0d9cf2e..d2b7f8061 100644
--- a/src/brevitas/export/torch/qoperator/handler/act.py
+++ b/src/brevitas/export/torch/qoperator/handler/act.py
@@ -24,7 +24,7 @@ def explicit_output_dtype(cls) -> bool:
     @classmethod
     def validate(cls, module: QuantNLAL):
         assert not module.is_input_quant_enabled, 'Input quantization not supported'
-        cls.validate_8b_bit_width(module.quant_act_bit_width(), le_then=False)
+        cls.validate_8b_bit_width(module.act_quant.bit_width(), le_then=False)
 
     def prepare_for_export(self, module: QuantNLAL):
         self.validate(module)
diff --git a/src/brevitas/export/torch/qoperator/handler/base.py b/src/brevitas/export/torch/qoperator/handler/base.py
index c5b569ff6..e0a9503c8 100644
--- a/src/brevitas/export/torch/qoperator/handler/base.py
+++ b/src/brevitas/export/torch/qoperator/handler/base.py
@@ -10,6 +10,7 @@
 from brevitas.export.common.handler.base import BaseHandler
 from brevitas.export.common.handler.base import BitWidthHandlerMixin
 from brevitas.export.common.handler.base import ZeroPointHandlerMixin
+from brevitas.nn.quant_layer import QuantNonLinearActLayer
 
 SCALAR_SHAPE = ()
@@ -55,17 +56,21 @@ def gen_quant_impl_kwargs(
 
     @classmethod
     def prepare_input_quant(cls, module):
-        scale = module.quant_input_scale()
+        scale = module.input_quant.scale()
         zero_point = cls.quant_input_zero_point(module)
-        signed = module.is_quant_input_signed
+        signed = module.input_quant.signed()
         quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed)
         return quant_impl, quant_kwargs
 
     @classmethod
     def prepare_output_quant(cls, module):
-        scale = module.quant_output_scale()
+        if isinstance(module, QuantNonLinearActLayer):
+            quant = module.act_quant
+        else:
+            quant = module.output_quant
+        scale = quant.scale()
         zero_point = cls.quant_output_zero_point(module)
-        signed = module.is_quant_output_signed
+        signed = quant.signed()
         incl_dtype = cls.explicit_output_dtype()
         quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed, incl_dtype)
         return quant_impl, quant_kwargs
diff --git a/src/brevitas/export/torch/qoperator/handler/parameter.py b/src/brevitas/export/torch/qoperator/handler/parameter.py
index fa110a84c..6314320cf 100644
--- a/src/brevitas/export/torch/qoperator/handler/parameter.py
+++ b/src/brevitas/export/torch/qoperator/handler/parameter.py
@@ -48,12 +48,12 @@ def prepare_bias(cls, module: QuantWBIOL):
 
     @classmethod
     def prepare_weight_quant(cls, module: QuantWBIOL):
-        cls.validate_bit_width(module.quant_weight_bit_width(), 7, le_then=True)
-        cls.validate_8b_bit_width(module.quant_input_bit_width(), le_then=False)
-        cls.validate_8b_bit_width(module.quant_output_bit_width(), le_then=False)
-        scale = module.quant_weight_scale()
+        cls.validate_bit_width(module.weight_quant.bit_width(), 7, le_then=True)
+        cls.validate_8b_bit_width(module.input_quant.bit_width(), le_then=False)
+        cls.validate_8b_bit_width(module.output_quant.bit_width(), le_then=False)
+        scale = module.weight_quant.scale()
         zero_point = cls.quant_weight_zero_point(module)
-        signed = module.is_quant_weight_signed
+        signed = module.weight_quant.is_signed
         weight = module.weight.detach()
         quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed)
         return quant_impl, (weight,), quant_kwargs
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index e9641a5a8..9f517847a 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -227,10 +227,10 @@ def process_input(self, inp):
         # if the quant_input is not already cached, then get
         # metadata from QuantWBIOL module
         if self.quant_input is None:
-            inp_scale = self.layer.quant_input_scale()
-            inp_zero_point = self.layer.quant_input_zero_point()
-            inp_bit_width = self.layer.quant_input_bit_width()
-            inp_signed = self.layer.is_quant_input_signed
+            inp_scale = self.layer.input_quant.scale()
+            inp_zero_point = self.layer.input_quant.zero_point()
+            inp_bit_width = self.layer.input_quant.bit_width()
+            inp_signed = self.layer.input_quant.signed()
             inp_training = self.layer.training
 
         # If using quantized activations, inp could be QuantTensor. In
diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py
index 4ac324c8b..a4a786e12 100644
--- a/src/brevitas/graph/quantize.py
+++ b/src/brevitas/graph/quantize.py
@@ -211,10 +211,10 @@ def align_input_quant(
     # If it is a QuantIdentity already, simply modify tensor_quant or the scaling implementations
     # based on whether we need to align the sign or not
     if isinstance(module, qnn.QuantIdentity):
-        if align_sign or module.is_quant_act_signed == shared_quant_identity.is_quant_act_signed:
+        if align_sign or module.input_quant.signed() == shared_quant_identity.input_quant.signed():
             return shared_quant_identity
         else:
-            assert not module.is_quant_act_signed and shared_quant_identity.is_quant_act_signed
+            assert not module.input_quant.signed() and shared_quant_identity.input_quant.signed()
             quant_module_class, quant_module_kwargs = quant_identity_map['unsigned']
             return (
                 quant_module_class,
diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py
index 4fc8e5c66..fe6c5ec12 100644
--- a/src/brevitas/graph/quantize_impl.py
+++ b/src/brevitas/graph/quantize_impl.py
@@ -82,8 +82,8 @@ def are_inputs_unsigned(model, node, is_unsigned_list, quant_act_map, unsigned_a
         elif isinstance(inp_module, tuple(SIGN_PRESERVING_MODULES)):
             are_inputs_unsigned(
                 model, inp_node, is_unsigned_list, quant_act_map, unsigned_act_tuple)
-        elif hasattr(inp_module, 'is_quant_act_signed'):
-            is_unsigned_list.append(not inp_module.is_quant_act_signed)
+        elif hasattr(inp_module, 'input_quant'):
+            is_unsigned_list.append(not inp_module.input_quant.signed())
         else:
             is_unsigned_list.append(False)
     elif inp_node.op == 'call_function':
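
Activation layers keep their quantizer on act_quant, while every other layer exposes output_quant, and both export backends now dispatch on the layer type. The same dispatch distilled into a sketch (the helper name output_proxy is hypothetical):

    import brevitas.nn as qnn
    from brevitas.nn.quant_layer import QuantNonLinearActLayer

    def output_proxy(module):
        # Mirrors quant_output_zero_point() and prepare_output_quant() above.
        if isinstance(module, QuantNonLinearActLayer):
            return module.act_quant
        return module.output_quant

    relu = qnn.QuantReLU()
    linear = qnn.QuantLinear(16, 8, bias=False)
    assert output_proxy(relu) is relu.act_quant
    assert output_proxy(linear) is linear.output_quant
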
diff --git a/src/brevitas/nn/mixin/act.py b/src/brevitas/nn/mixin/act.py
index d00b6f35e..e1c5f9393 100644
--- a/src/brevitas/nn/mixin/act.py
+++ b/src/brevitas/nn/mixin/act.py
@@ -4,13 +4,11 @@
 from abc import ABCMeta
 from abc import abstractmethod
 from typing import Optional, Type, Union
-from warnings import warn
 
 from torch.nn import Module
 
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import Injector
-from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
 from brevitas.proxy.runtime_quant import ActQuantProxyProtocol
 from brevitas.quant import NoneActQuant
@@ -42,22 +40,10 @@ def is_input_quant_enabled(self):
     def is_quant_input_narrow_range(self):  # TODO make abstract once narrow range can be cached
         return self.input_quant.is_narrow_range
 
-    @property
-    @abstractmethod
-    def is_quant_input_signed(self):
-        pass
-
-    @abstractmethod
-    def quant_input_scale(self):
-        pass
-
-    @abstractmethod
-    def quant_input_zero_point(self):
-        pass
-
-    @abstractmethod
-    def quant_input_bit_width(self):
-        pass
+    # @property
+    # @abstractmethod
+    # def is_quant_input_signed(self):
+    #     pass
 
 
 class QuantOutputMixin(QuantProxyMixin):
@@ -83,22 +69,10 @@ def is_output_quant_enabled(self):
     def is_quant_output_narrow_range(self):  # TODO make abstract once narrow range can be cached
         return self.output_quant.is_narrow_range
 
-    @property
-    @abstractmethod
-    def is_quant_output_signed(self):
-        pass
-
-    @abstractmethod
-    def quant_output_scale(self):
-        pass
-
-    @abstractmethod
-    def quant_output_zero_point(self):
-        pass
-
-    @abstractmethod
-    def quant_output_bit_width(self):
-        pass
+    # @property
+    # @abstractmethod
+    # def is_quant_output_signed(self):
+    #     pass
 
 
 class QuantNonLinearActMixin(QuantProxyMixin):
@@ -133,19 +107,19 @@ def is_act_quant_enabled(self):
     def is_quant_act_narrow_range(self):  # TODO make abstract once narrow range can be cached
         return self.act_quant.is_narrow_range
 
-    @property
-    @abstractmethod
-    def is_quant_act_signed(self):
-        pass
+    # @property
+    # @abstractmethod
+    # def is_quant_act_signed(self):
+    #     pass
 
-    @abstractmethod
-    def quant_act_scale(self):
-        pass
+    # @abstractmethod
+    # def quant_act_scale(self):
+    #     pass
 
-    @abstractmethod
-    def quant_act_zero_point(self):
-        pass
+    # @abstractmethod
+    # def quant_act_zero_point(self):
+    #     pass
 
-    @abstractmethod
-    def quant_act_bit_width(self):
-        pass
+    # @abstractmethod
+    # def quant_act_bit_width(self):
+    #     pass
diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py
index 55a5c8150..5ce256515 100644
--- a/src/brevitas/nn/mixin/base.py
+++ b/src/brevitas/nn/mixin/base.py
@@ -94,67 +94,15 @@ def __init__(
         self.cache_inference_quant_inp = cache_inference_quant_inp
         self.cache_inference_quant_out = cache_inference_quant_out
         self.cache_quant_io_metadata_only = cache_quant_io_metadata_only
-        self._cached_inp = None
-        self._cached_out = None
 
     @property
     @abstractmethod
     def channelwise_separable(self) -> bool:
         pass
 
-    @property
-    def is_quant_input_signed(self) -> Optional[bool]:  # tri-valued logic output
-        if self._cached_inp is not None:
-            return self._cached_inp.signed
-        else:
-            return None
-
     def _set_global_is_quant_layer(self, value):
         config._IS_INSIDE_QUANT_LAYER = value
 
-    def quant_input_scale(self):
-        if self._cached_inp is not None:
-            return self._cached_inp.scale
-        else:
-            return None
-
-    def quant_input_zero_point(self):
-        if self._cached_inp is not None:
-            return self._cached_inp.zero_point
-        else:
-            return None
-
-    def quant_input_bit_width(self):
-        if self._cached_inp is not None:
-            return self._cached_inp.bit_width
-        else:
-            return None
-
-    @property
-    def is_quant_output_signed(self) -> Optional[bool]:  # tri-valued logic output
-        if self._cached_out is not None:
-            return self._cached_out.signed
-        else:
-            return None
-
-    def quant_output_scale(self):
-        if self._cached_out is not None:
-            return self._cached_out.scale
-        else:
-            return None
-
-    def quant_output_zero_point(self):
-        if self._cached_out is not None:
-            return self._cached_out.zero_point
-        else:
-            return None
-
-    def quant_output_bit_width(self):
-        if self._cached_out is not None:
-            return self._cached_out.bit_width
-        else:
-            return None
-
     def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         self._set_global_is_quant_layer(True)
         # Hack to recognize a QuantTensor that has decayed to a tuple
@@ -166,7 +114,10 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
         # don't cache values during export pass
         if not self.training and not self._export_mode and self.cache_inference_quant_inp:
             cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
-            self._cached_inp = cached_inp
+            if hasattr(self, 'input_quant'):
+                self.input_quant._cached_act = cached_inp
+            if hasattr(self, 'weight_quant') and self.weight_quant_requires_quant_input:
+                self.weight_quant._cached_act = cached_inp
         if not torch._C._get_tracing_state():
             if isinstance(inp, QuantTensor):
                 inp = inp.set(value=inp.value.rename(None))
@@ -177,7 +128,9 @@
     def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         if not self.training and self.cache_inference_quant_out and isinstance(quant_output,
                                                                                QuantTensor):
-            self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only)
+            if hasattr(self, 'output_quant'):
+                self.output_quant._cached_act = _CachedIO(
+                    quant_output.detach(), self.cache_quant_io_metadata_only)
         self._set_global_is_quant_layer(False)
         if self.return_quant_tensor:
             assert isinstance(quant_output, QuantTensor)
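
unpack_input() and pack_output() now park their _CachedIO objects on the proxies, so a proxy can answer metadata queries from the cache even when its own quantizer is disabled. An illustrative sketch (feeding a QuantTensor through a separate QuantIdentity is an assumption for the example, not part of the patch):

    import torch
    import brevitas.nn as qnn

    source = qnn.QuantIdentity(return_quant_tensor=True)
    layer = qnn.QuantLinear(16, 8, bias=True)
    layer.cache_inference_quant_inp = True
    source.eval()
    layer.eval()

    # The incoming QuantTensor is cached on the input proxy; input_quant is
    # otherwise disabled, so signed()/bit_width() fall back to the cache.
    layer(source(torch.randn(2, 16)))
    print(layer.input_quant._cached_act.bit_width)  # tensor(8.)
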
diff --git a/src/brevitas/nn/mixin/parameter.py b/src/brevitas/nn/mixin/parameter.py
index a752c35ec..caa6f5e24 100644
--- a/src/brevitas/nn/mixin/parameter.py
+++ b/src/brevitas/nn/mixin/parameter.py
@@ -6,6 +6,8 @@
 from typing import List, Optional, Tuple, Type, Union
 from warnings import warn
 
+from torch import Tensor
+
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import Injector
 from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
@@ -49,9 +51,9 @@ def is_weight_quant_enabled(self):
     def is_quant_weight_narrow_range(self):
         return self.weight_quant.is_narrow_range
 
-    @property
-    def is_quant_weight_signed(self):
-        return self.weight_quant.is_signed
+    # @property
+    # def is_quant_weight_signed(self):
+    #     return self.weight_quant.is_signed
 
     @property
     def weight_quant_requires_quant_input(self):
@@ -85,18 +87,17 @@ def quant_weight(
         else:
             weight_slice_tuple = slice(None)
         if self.weight_quant_requires_quant_input:
+            input_bit_width = None
+            input_is_signed = None
             if self.is_weight_quant_enabled:
                 if quant_input is None:
-                    input_bit_width = self.quant_input_bit_width()
-                    input_is_signed = self.is_quant_input_signed
+                    input_bit_width = self.input_quant.bit_width()
+                    input_is_signed = self.input_quant.signed()
                 else:
                     input_bit_width = quant_input.bit_width
                     input_is_signed = quant_input.signed
                 assert input_bit_width is not None, "Input bit-width needs to be specified."
                 assert input_is_signed is not None, "Input sign needs to be specified."
-            else:
-                input_bit_width = None
-                input_is_signed = None
             out = self.weight_quant(
                 weights_to_quantize[weight_slice_tuple], input_bit_width, input_is_signed)
         else:
@@ -112,18 +113,6 @@ def quant_weight(
     def int_weight(self, float_datatype=False):
         return self.quant_weight().int(float_datatype)
 
-    def quant_weight_scale(self):
-        scale = self.quant_weight().scale
-        return scale
-
-    def quant_weight_zero_point(self):
-        scale = self.quant_weight().zero_point
-        return scale
-
-    def quant_weight_bit_width(self):
-        bit_width = self.quant_weight().bit_width
-        return bit_width
-
     def register_parameter(self, name, value):
         super(QuantWeightMixin, self).register_parameter(name, value)
         if hasattr(self, 'weight_quant') and name == 'weight':
@@ -148,7 +137,6 @@ def __init__(
             proxy_prefix='bias_',
             **kwargs)
         self.cache_inference_quant_bias = cache_inference_bias
-        self._cached_bias = None
 
     @property
     def is_bias_quant_enabled(self):
@@ -175,57 +163,10 @@ def int_bias(self, float_datatype=False):
     def quant_bias(self):
         if self.bias is None:
             return None
-        scale = self.quant_bias_scale()
+        scale = self.bias_quant.scale()
         quant_bias = self.bias_quant(self.bias, scale)
         return quant_bias
 
-    def quant_bias_scale(self):
-        if self.bias is None or not self.is_bias_quant_enabled:
-            return None
-        if not self.bias_quant.requires_input_scale:
-            return self.bias_quant(self.bias).scale
-        else:
-            if self._cached_bias is None:
-                raise RuntimeError(
-                    "No quant bias cache found, set cache_inference_quant_bias=True and run an "
-                    "inference pass first")
-            if self.training:
-                warn("Cached quant bias scale is being used in training mode.")
-            return self._cached_bias.scale
-
-    def quant_bias_zero_point(self):
-        if self.bias is None:
-            return None
-
-        if not self.bias_quant.requires_input_scale:
-            bias_quant = self.bias_quant(self.bias)
-            if isinstance(bias_quant, QuantTensor):
-                return bias_quant.zero_point
-            else:
-                return None
-        else:
-            if self._cached_bias is None:
-                raise RuntimeError(
-                    "No quant bias cache found, set cache_inference_quant_bias=True and run an "
-                    "inference pass first")
-            if self.training:
-                warn("Cached quant bias zero-point is being used in training mode.")
-            return self._cached_bias.bit_width
-
-    def quant_bias_bit_width(self):
-        if self.bias is None or not self.is_bias_quant_enabled:
-            return None
-        if not self.bias_quant.requires_input_scale:
-            return self.bias_quant(self.bias).bit_width
-        else:
-            if self._cached_bias is None:
-                raise RuntimeError(
-                    "No quant bias cache found, set cache_inference_quant_bias=True and run an "
-                    "inference pass first")
-            if self.training:
-                warn("Cached quant bias bit-width is being used in training mode.")
-            return self._cached_bias.bit_width
-
     def register_parameter(self, name, value):
         super(QuantBiasMixin, self).register_parameter(name, value)
         if hasattr(self, 'bias_quant') and name == 'bias':
diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py
index 5cf29602f..110c1a394 100644
--- a/src/brevitas/nn/quant_layer.py
+++ b/src/brevitas/nn/quant_layer.py
@@ -3,12 +3,11 @@
 
 from abc import ABCMeta
 from abc import abstractmethod
-from typing import Callable, Optional, Type, Union
+from typing import Optional, Type, Union
 
 import torch
 from torch import Tensor
 from torch.nn import Module
-from torch.nn import Parameter
 
 from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor import QuantTensor
@@ -44,24 +43,6 @@ def channelwise_separable(self) -> bool:
     def requires_export_handler(self):
         return self.is_input_quant_enabled or self.is_act_quant_enabled
 
-    @property
-    def is_quant_input_signed(self) -> Optional[bool]:  # tri-valued logic output
-        if self.is_input_quant_enabled:
-            return self.input_quant.is_signed
-        elif self._cached_inp is not None:
-            return self._cached_inp.signed
-        else:
-            return None
-
-    @property
-    def is_quant_act_signed(self) -> Optional[bool]:  # tri-valued logic output
-        if self.is_act_quant_enabled:
-            return self.act_quant.is_signed
-        elif self._cached_out is not None:
-            return self._cached_out.signed
-        else:
-            return None
-
     @property
     def is_output_quant_enabled(self):
         return self.is_act_quant_enabled
@@ -70,67 +51,6 @@ def is_output_quant_enabled(self):
     def is_quant_output_narrow_range(self):
         return self.is_quant_act_narrow_range
 
-    @property
-    def is_quant_output_signed(self):  # overrides from QuantLayerMixin
-        return self.is_quant_act_signed
-
-    def quant_input_scale(self):
-        if self.is_input_quant_enabled:
-            return self.input_quant.scale()
-        elif self._cached_inp is not None:
-            return self._cached_inp.scale
-        else:
-            return None
-
-    def quant_act_scale(self):
-        if self.is_act_quant_enabled:
-            return self.act_quant.scale()
-        elif self._cached_out is not None:
-            return self._cached_out.scale
-        else:
-            return None
-
-    def quant_output_scale(self):  # overrides from QuantLayerMixin
-        return self.quant_act_scale()
-
-    def quant_input_zero_point(self):
-        if self.is_input_quant_enabled:
-            return self.input_quant.zero_point()
-        elif self._cached_inp is not None:
-            return self._cached_inp.zero_point
-        else:
-            return None
-
-    def quant_act_zero_point(self):
-        if self.is_act_quant_enabled:
-            return self.act_quant.zero_point()
-        elif self._cached_out is not None:
-            return self._cached_out.zero_point
-        else:
-            return None
-
-    def quant_output_zero_point(self):  # overrides from QuantLayerMixin
-        return self.quant_act_zero_point()
-
-    def quant_input_bit_width(self):
-        if self.is_input_quant_enabled:
-            return self.input_quant.bit_width()
-        elif self._cached_inp is not None:
-            return self._cached_inp.bit_width
-        else:
-            return None
-
-    def quant_act_bit_width(self):
-        if self.is_act_quant_enabled:
-            return self.act_quant.bit_width()
-        elif self._cached_out is not None:
-            return self._cached_out.bit_width
-        else:
-            return None
-
-    def quant_output_bit_width(self):  # overrides from QuantLayerMixin
-        return self.quant_act_bit_width()
-
     def forward(self, input: Union[Tensor, QuantTensor]):
         input = self.unpack_input(input)
         quant_input = self.input_quant(input)
@@ -178,72 +98,6 @@ def __init__(
     def requires_export_handler(self):
         return self.is_input_quant_enabled or self.is_output_quant_enabled
 
-    @property
-    def is_quant_input_signed(self) -> Optional[bool]:  # tri-valued logic output
-        if self.is_input_quant_enabled:
-            return self.input_quant.is_signed
-        elif self._cached_inp is not None:
-            return self._cached_inp.signed
-        else:
-            return None
-
-    @property
-    def is_quant_output_signed(self) -> Optional[bool]:  # tri-valued logic output:
-        if self.is_output_quant_enabled:
-            return self.output_quant.is_signed
-        elif self._cached_out is not None:
-            return self._cached_out.signed
-        else:
-            return None
-
-    def quant_input_scale(self):
-        if self.is_input_quant_enabled:
-            return self.input_quant.scale()
-        elif self._cached_inp is not None:
-            return self._cached_inp.scale
-        else:
-            return None
-
-    def quant_output_scale(self):
-        if self.is_output_quant_enabled:
-            return self.output_quant.scale()
-        elif self._cached_out is not None:
-            return self._cached_out.scale
-        else:
-            return None
-
-    def quant_input_zero_point(self):
-        if self.is_input_quant_enabled:
-            return self.input_quant.zero_point()
-        elif self._cached_inp is not None:
-            return self._cached_inp.zero_point
-        else:
-            return None
-
-    def quant_output_zero_point(self):
-        if self.is_output_quant_enabled:
-            return self.output_quant.zero_point()
-        elif self._cached_out is not None:
-            return self._cached_out.zero_point
-        else:
-            return None
-
-    def quant_input_bit_width(self):
-        if self.is_input_quant_enabled:
-            return self.input_quant.bit_width()
-        elif self._cached_inp is not None:
-            return self._cached_inp.bit_width
-        else:
-            return None
-
-    def quant_output_bit_width(self):
-        if self.is_output_quant_enabled:
-            return self.output_quant.bit_width()
-        elif self._cached_out is not None:
-            return self._cached_out.bit_width
-        else:
-            return None
-
 
 class QuantWeightBiasInputOutputLayer(QuantBiasMixin, QuantWeightMixin, QuantInputOutputLayer):
     __metaclass__ = ABCMeta
@@ -309,6 +163,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
             return out
 
         quant_input = self.input_quant(inp)
+
         quant_weight = self.quant_weight(quant_input)
 
         compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance(
@@ -326,8 +181,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
             quant_bias = self.bias_quant(self.bias, output_scale)
             if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias,
                                                                                     QuantTensor):
-
-                self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)
+                self.bias_quant._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)
             output_tensor = self.inner_forward_impl(
                 _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight),
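
With quant_weight_scale(), quant_weight_zero_point() and quant_weight_bit_width() removed, per-weight metadata is read off the QuantTensor that quant_weight() returns. For example:

    import brevitas.nn as qnn

    layer = qnn.QuantLinear(16, 8, bias=True)
    qw = layer.quant_weight()
    print(qw.scale, qw.zero_point, qw.bit_width)
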
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 3c8185535..f05b95f0f 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -4,14 +4,17 @@
 from abc import ABCMeta
 from abc import abstractmethod
 from typing import Optional, Union
+from warnings import warn
 
 import torch
 from torch import Tensor
+import torch.nn as nn
 from typing_extensions import Protocol
 from typing_extensions import runtime_checkable
 
 from brevitas import config
 from brevitas.function import max_int
+from brevitas.inject import BaseInjector as Injector
 from brevitas.quant_tensor import QuantTensor
 
 from .quant_proxy import QuantProxyFromInjector
@@ -135,6 +138,11 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
 
 class DecoupledWeightQuantWithInputProxyFromInjector(DecoupledWeightQuantProxyFromInjector):
 
+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
+        super().__init__(quant_layer, quant_injector)
+        # Necessary for export
+        self._cached_act = None
+
     @property
     def requires_quant_input(self):
         return True
@@ -166,6 +174,10 @@ def forward(self, x: torch.Tensor, input_bit_width: torch.Tensor,
 
 class BiasQuantProxyFromInjector(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol):
 
+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
+        super().__init__(quant_layer, quant_injector)
+        self._cached_bias = None
+
     @property
     def tracked_parameter_list(self):
         return [m.bias for m in self.tracked_module_list if m.bias is not None]
@@ -179,21 +191,42 @@ def requires_input_scale(self) -> bool:
 
     def scale(self):
         if self.requires_input_scale or not self.is_quant_enabled:
-            return None
+            if self._cached_bias is None:
+                warn(
+                    "No quant bias cache found, set cache_inference_quant_bias=True and run an "
+                    "inference pass first")
+                return None
+            if self.training:
+                warn("Cached quant bias scale is being used in training mode.")
+            return self._cached_bias.scale
         zhs = self._zero_hw_sentinel()
         scale = self.__call__(self.tracked_parameter_list[0], zhs).scale
         return scale
 
     def zero_point(self):
         if not self.is_quant_enabled:
-            return None
+            if self._cached_bias is None:
+                warn(
+                    "No quant bias cache found, set cache_inference_quant_bias=True and run an "
+                    "inference pass first")
+                return None
+            if self.training:
+                warn("Cached quant bias zero-point is being used in training mode.")
+            return self._cached_bias.zero_point
         zhs = self._zero_hw_sentinel()
         zero_point = self.__call__(self.tracked_parameter_list[0], zhs).zero_point
         return zero_point
 
     def bit_width(self):
         if not self.is_quant_enabled:
-            return None
+            if self._cached_bias is None:
+                warn(
+                    "No quant bias cache found, set cache_inference_quant_bias=True and run an "
+                    "inference pass first")
+                return None
+            if self.training:
+                warn("Cached quant bias bit-width is being used in training mode.")
+            return self._cached_bias.bit_width
         zhs = self._zero_hw_sentinel()
         out = self.__call__(self.tracked_parameter_list[0], zhs)
         return out.bit_width
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index 4b04230de..edc5ab087 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -89,6 +89,7 @@ def __init__(self, quant_layer, quant_injector):
         QuantProxyFromInjector.__init__(self, quant_layer, quant_injector)
         ActQuantProxyProtocol.__init__(self)
         self.is_passthrough_act = _is_passthrough_act(quant_injector)
+        self._cached_act = None
 
     @property
     def is_quant_enabled(self):
@@ -118,34 +119,56 @@ def init_tensor_quant(self):
         self.fused_activation_quant_proxy = None
 
     def scale(self, force_eval=True):
-        if not self.is_quant_enabled:
+        if self.is_quant_enabled:
+            current_status = self.training
+            if force_eval:
+                self.eval()
+            out = self.__call__(self._zero_hw_sentinel())
+            self.train(current_status)
+            return out.scale
+        elif self._cached_act is not None:
+            return self._cached_act.scale
+        elif self._cached_act is None:
             return None
-        current_status = self.training
-        if force_eval:
-            self.eval()
-        out = self.__call__(self._zero_hw_sentinel())
-        self.train(current_status)
-        return out.scale
 
     def zero_point(self, force_eval=True):
-        if not self.is_quant_enabled:
+        if self.is_quant_enabled:
+            current_status = self.training
+            if force_eval:
+                self.eval()
+            out = self.__call__(self._zero_hw_sentinel())
+            self.train(current_status)
+            return out.zero_point
+        elif self._cached_act is not None:
+            return self._cached_act.zero_point
+        elif self._cached_act is None:
            return None
-        current_status = self.training
-        if force_eval:
-            self.eval()
-        out = self.__call__(self._zero_hw_sentinel())
-        self.train(current_status)
-        return out.zero_point
 
     def bit_width(self, force_eval=True):
-        if not self.is_quant_enabled:
+        if self.is_quant_enabled:
+            current_status = self.training
+            if force_eval:
+                self.eval()
+            out = self.__call__(self._zero_hw_sentinel())
+            self.train(current_status)
+            return out.bit_width
+        elif self._cached_act is not None:
+            return self._cached_act.bit_width
+        elif self._cached_act is None:
+            return None
+
+    def signed(self, force_eval=True):
+        if self.is_quant_enabled:
+            current_status = self.training
+            if force_eval:
+                self.eval()
+            out = self.__call__(self._zero_hw_sentinel())
+            self.train(current_status)
+            return out.signed
+        elif self._cached_act is not None:
+            return self._cached_act.signed
+        elif self._cached_act is None:
             return None
-        current_status = self.training
-        if force_eval:
-            self.eval()
-        out = self.__call__(self._zero_hw_sentinel())
-        self.train(current_status)
-        return out.bit_width
 
     def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         if self.fused_activation_quant_proxy is not None:
diff --git a/src/brevitas_examples/super_resolution/utils/evaluate.py b/src/brevitas_examples/super_resolution/utils/evaluate.py
index 2eb9e3627..881c10eea 100644
--- a/src/brevitas_examples/super_resolution/utils/evaluate.py
+++ b/src/brevitas_examples/super_resolution/utils/evaluate.py
@@ -56,7 +56,7 @@ def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor:
 
     # bit-width and sign need to come from the quant tensor of the preceding layer if no io_quant
     input_bit_width = module.quant_input_bit_width()
-    input_is_signed = float(module.is_quant_input_signed)
+    input_is_signed = float(module.input_quant.signed())
 
     # the tensor quantizer requires a QuantTensor with specified bit-width and sign
    quant_weight = module.quant_weight()
diff --git a/tests/brevitas/export/test_torch_qop.py b/tests/brevitas/export/test_torch_qop.py
index e01bd93cb..4003600dc 100644
--- a/tests/brevitas/export/test_torch_qop.py
+++ b/tests/brevitas/export/test_torch_qop.py
@@ -57,7 +57,7 @@ def forward(self, x):
 
     brevitas_out = model(inp)
     pytorch_qf_model = export_torch_qop(model, input_t=inp)
     pytorch_out = pytorch_qf_model(inp)
-    atol = model.conv.quant_output_scale().item() * TOLERANCE
+    atol = model.conv.output_quant.scale().item() * TOLERANCE
     assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()
@@ -94,7 +94,7 @@ def forward(self, x):
 
     brevitas_out = model(inp)
     pytorch_qf_model = export_torch_qop(model, input_t=inp)
     pytorch_out = pytorch_qf_model(inp)
-    atol = model.linear.quant_output_scale().item() * TOLERANCE
+    atol = model.linear.output_quant.scale().item() * TOLERANCE
     assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()
@@ -132,7 +132,7 @@ def forward(self, x):
 
     brevitas_out = model(inp)
     pytorch_qf_model = export_torch_qop(model, input_t=inp)
     pytorch_out = pytorch_qf_model(inp)
-    atol = model.linear.quant_output_scale().item() * TOLERANCE
+    atol = model.linear.output_quant.scale().item() * TOLERANCE
     assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()
@@ -172,7 +172,7 @@ def forward(self, x):
 
     brevitas_out = model(inp)
     pytorch_qf_model = export_torch_qop(model, input_t=inp)
     pytorch_out = pytorch_qf_model(inp)
-    atol = model.conv.quant_output_scale().item() * TOLERANCE
+    atol = model.conv.output_quant.scale().item() * TOLERANCE
     assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()
@@ -200,5 +200,5 @@ def forward(self, x):
     brevitas_out = model(inp)
     pytorch_qf_model = export_torch_qop(model, input_t=inp)
     pytorch_out = pytorch_qf_model(inp)
-    atol = model.act2.quant_output_scale().item() * TOLERANCE
+    atol = model.act2.act_quant.scale().item() * TOLERANCE
     assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index 0b6303a8b..1775c68d6 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -59,7 +59,7 @@ def forward(self, x):
         model(inp)
 
     expected_scale = reference_implementation_scale_factors_po2(inp)
-    scale = model.act.quant_act_scale()
+    scale = model.act.act_quant.scale()
 
     assert torch.allclose(expected_scale, scale)
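
Bias quantizers that derive their scale from the input can now answer scale(), zero_point() and bit_width() from _cached_bias, warning instead of raising when no cache is available. A sketch of the intended flow, with Int8Bias picked for illustration because it has requires_input_scale=True:

    import torch
    import brevitas.nn as qnn
    from brevitas.quant import Int8Bias

    source = qnn.QuantIdentity(return_quant_tensor=True)
    layer = qnn.QuantLinear(16, 8, bias=True, bias_quant=Int8Bias)
    layer.cache_inference_quant_bias = True
    source.eval()
    layer.eval()

    # forward_impl() caches the quantized bias on the proxy, after which the
    # metadata getters can answer without re-running quantization.
    layer(source(torch.randn(2, 16)))
    print(layer.bias_quant.scale())
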
diff --git a/tests/brevitas/nn/test_linear.py b/tests/brevitas/nn/test_linear.py
index 8fd2d6e04..62799281b 100644
--- a/tests/brevitas/nn/test_linear.py
+++ b/tests/brevitas/nn/test_linear.py
@@ -36,7 +36,7 @@ def test_module_init_scale_impl_type_override(self):
             in_features=INPUT_FEATURES,
             bias=True,
             weight_scaling_impl_type='HE')
-        assert mod.quant_weight_scale()
+        assert mod.weight_quant.scale()
 
 
 class TestQuantLinearFwd:
diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py
index 3577d25a1..d3fc4bdce 100644
--- a/tests/brevitas/nn/test_wbiol.py
+++ b/tests/brevitas/nn/test_wbiol.py
@@ -84,7 +84,7 @@ def test_default_wbiol_weight_quant_enabled(default_wbiol_layer: QuantWBIOL):
 
 
 def test_default_wbiol_weight_bit_width_enabled(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.quant_weight_bit_width() == torch.tensor(8.)
+    assert default_wbiol_layer.weight_quant.bit_width() == torch.tensor(8.)
 
 
 def test_default_wbiol_return_quant(default_wbiol_layer: QuantWBIOL):
@@ -96,7 +96,7 @@ def test_default_wbiol_quant_bias_signed(default_wbiol_layer: QuantWBIOL):
 
 
 def test_default_wbiol_quant_weight_signed(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.is_quant_weight_signed
+    assert default_wbiol_layer.weight_quant.is_signed
 
 
 def test_default_wbiol_quant_bias_narrow_range(default_wbiol_layer: QuantWBIOL):
@@ -108,11 +108,11 @@ def test_default_wbiol_quant_weight_narrow_range(default_wbiol_layer: QuantWBIOL
 
 
 def test_default_wbiol_quant_input_signed(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.is_quant_input_signed is None
+    assert default_wbiol_layer.input_quant.signed() is None
 
 
 def test_default_wbiol_quant_output_signed(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.is_quant_output_signed is None
+    assert default_wbiol_layer.output_quant.signed() is None
 
 
 def test_default_wbiol_quant_input_narrow_range(default_wbiol_layer: QuantWBIOL):
@@ -124,31 +124,31 @@ def test_default_wbiol_quant_output_narrow_range(default_wbiol_layer: QuantWBIOL
 
 
 def test_default_wbiol_quant_input_zero_point(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.quant_input_zero_point() is None
+    assert default_wbiol_layer.input_quant.zero_point() is None
 
 
 def test_default_wbiol_quant_output_zero_point(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.quant_output_zero_point() is None
+    assert default_wbiol_layer.output_quant.zero_point() is None
 
 
 def test_default_wbiol_quant_weight_zero_point(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.quant_weight_zero_point() == torch.tensor(0.)
+    assert default_wbiol_layer.weight_quant.zero_point() == torch.tensor(0.)
 
 
 def test_default_wbiol_quant_bias_zero_point(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.quant_bias_zero_point() is None
+    assert default_wbiol_layer.bias_quant.zero_point() is None
 
 
 def test_default_wbiol_quant_input_scale(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.quant_input_scale() is None
+    assert default_wbiol_layer.input_quant.scale() is None
 
 
 def test_default_wbiol_quant_output_scale(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.quant_output_scale() is None
+    assert default_wbiol_layer.output_quant.scale() is None
 
 
 def test_default_wbiol_quant_bias_scale(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.quant_bias_scale() is None
+    assert default_wbiol_layer.bias_quant.scale() is None
 
 
 def test_default_wbiol_weight_quant_proxy(default_wbiol_layer: QuantWBIOL):
diff --git a/tests/brevitas/proxy/test_act_scaling.py b/tests/brevitas/proxy/test_act_scaling.py
index 17cee628c..3b4537610 100644
--- a/tests/brevitas/proxy/test_act_scaling.py
+++ b/tests/brevitas/proxy/test_act_scaling.py
@@ -41,7 +41,7 @@ def test_scaling_stats_to_parameter(self):
         stats_act.eval()
         param_act.eval()
 
-        assert (torch.allclose(stats_act.quant_act_scale(), param_act.quant_act_scale()))
+        assert (torch.allclose(stats_act.act_quant.scale(), param_act.act_quant.scale()))
 
     def test_scaling_parameter_grad(self):
         stats_act = QuantReLU(
diff --git a/tests/brevitas/proxy/test_weight_scaling.py b/tests/brevitas/proxy/test_weight_scaling.py
index 074ca7c61..7b9d48259 100644
--- a/tests/brevitas/proxy/test_weight_scaling.py
+++ b/tests/brevitas/proxy/test_weight_scaling.py
@@ -18,10 +18,10 @@ def test_parameter_from_stats_update():
         weight_quant_type='binary',
         weight_scaling_impl_type='parameter_from_stats')
     l_max = linear.weight.abs().max()
-    old_scale = q_linear.quant_weight_scale()
+    old_scale = q_linear.weight_quant.scale()
     old_ql_max = q_linear.weight.abs().max()
     q_linear.load_state_dict(linear.state_dict())
-    new_scale = q_linear.quant_weight_scale()
+    new_scale = q_linear.weight_quant.scale()
     new_ql_max = q_linear.weight.abs().max()
     assert old_scale == old_ql_max
     assert new_scale == l_max
@@ -43,10 +43,10 @@ def test_parameter_from_stats_state_dict():
         weight_quant_type='binary',
         weight_scaling_impl_type='parameter',
         weight_scaling_init=0.001)
-    q_linear1_old_scale = q_linear1.quant_weight_scale()
+    q_linear1_old_scale = q_linear1.weight_quant.scale()
     q_linear1.load_state_dict(q_linear2.state_dict())
-    q_linear1_new_scale = q_linear1.quant_weight_scale()
-    q_linear2_scale = q_linear2.quant_weight_scale()
+    q_linear1_new_scale = q_linear1.weight_quant.scale()
+    q_linear2_scale = q_linear2.weight_quant.scale()
     assert q_linear1_old_scale != q_linear2_scale
     assert q_linear1_old_scale != q_linear1_new_scale
     assert q_linear1_new_scale == q_linear2_scale
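
For downstream code, the migration reduces to swapping the removed layer-level getters for proxy calls. The mapping below summarizes the replacements applied throughout this patch (layer stands for any quant layer; activation layers use act_quant where linear layers use output_quant):

    # Removed                        # Replacement
    layer.quant_weight_scale()       layer.weight_quant.scale()
    layer.quant_weight_bit_width()   layer.weight_quant.bit_width()
    layer.is_quant_weight_signed     layer.weight_quant.is_signed
    layer.quant_input_scale()        layer.input_quant.scale()
    layer.quant_input_zero_point()   layer.input_quant.zero_point()
    layer.quant_input_bit_width()    layer.input_quant.bit_width()
    layer.is_quant_input_signed      layer.input_quant.signed()
    layer.quant_output_scale()       layer.output_quant.scale()
    layer.quant_output_bit_width()   layer.output_quant.bit_width()
    layer.is_quant_output_signed     layer.output_quant.signed()
    layer.quant_act_scale()          layer.act_quant.scale()
    layer.quant_bias_scale()         layer.bias_quant.scale()
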