diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/base.py b/src/brevitas/export/onnx/standard/qoperator/handler/base.py
index ce9336d30..f8d5e1e7a 100644
--- a/src/brevitas/export/onnx/standard/qoperator/handler/base.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/base.py
@@ -79,7 +79,7 @@ def input_clip_symbolic_kwargs(cls, module):
         if module.is_input_quant_enabled:
             narrow = module.is_quant_input_narrow_range
             signed = module.input_quant.signed()
-            bit_width = module.quant_input_bit_width()
+            bit_width = module.input_quant.bit_width()
             return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
         else:
             return None
@@ -95,10 +95,10 @@ def output_dequant_symbolic_kwargs(cls, module):
     def input_quant_symbolic_kwargs(cls, module):
         if module.is_input_quant_enabled:
             return {
-                'output_scale': module.quant_input_scale(),
-                'output_zero_point': cls.input_quant.zero_point(module),
+                'output_scale': module.input_quant.scale(),
+                'output_zero_point': cls.quant_input_zero_point(module),
                 'output_dtype': cls.torch_8b_dtype(module.input_quant.signed()),
-                'output_axis': cls.quant_axis(module.quant_input_scale())}
+                'output_axis': cls.quant_axis(module.input_quant.scale())}
         else:
             return None
 
diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
index 696f2c30b..a892b63b3 100644
--- a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
@@ -109,10 +109,10 @@ class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC):
     def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]):
         conv_symbolic_kwargs = {
             'input_scale': module.quant_input_scale(),
-            'input_zero_point': self.input_quant.zero_point(module),
+            'input_zero_point': self.quant_input_zero_point(module),
             'int_weight': self.int_weight(module),
             'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()),
-            'weight_zero_point': to_0dim_if_scalar(self.weight_quant.zero_point(module).flatten()),
+            'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
             'output_scale': module.output_quant.scale(),
             'output_zero_point': self.quant_output_zero_point(module),
             'output_dtype': self.torch_8b_dtype(module.output_quant.signed()),
@@ -146,10 +146,10 @@ class StdQOpONNXQuantLinearHandler(StdQOpONNXQuantWBIOLHandler):
     def op_symbolic_kwargs(self, module: QuantLinear):
         conv_symbolic_kwargs = {
             'input_scale': module.quant_input_scale(),
-            'input_zero_point': self.input_quant.zero_point(module),
+            'input_zero_point': self.quant_input_zero_point(module),
             'int_weight': self.int_weight(module).view(module.out_features, module.in_features, 1),
             'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()),
-            'weight_zero_point': to_0dim_if_scalar(self.weight_quant.zero_point(module).flatten()),
+            'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
             'output_scale': module.output_quant.scale(),
             'output_zero_point': self.quant_output_zero_point(module),
             'output_dtype': self.torch_8b_dtype(module.output_quant.signed()),
diff --git a/src/brevitas/export/torch/qoperator/handler/base.py b/src/brevitas/export/torch/qoperator/handler/base.py
index 82a72cf81..42900495f 100644
--- a/src/brevitas/export/torch/qoperator/handler/base.py
+++ b/src/brevitas/export/torch/qoperator/handler/base.py
@@ -56,7 +56,7 @@ def gen_quant_impl_kwargs(
     @classmethod
     def prepare_input_quant(cls, module):
         scale = module.input_quant.scale()
-        zero_point = cls.input_quant.zero_point(module)
+        zero_point = cls.quant_input_zero_point(module)
         signed = module.input_quant.signed()
         quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed)
         return quant_impl, quant_kwargs
diff --git a/src/brevitas/export/torch/qoperator/handler/parameter.py b/src/brevitas/export/torch/qoperator/handler/parameter.py
index d45752de0..9c3cc8011 100644
--- a/src/brevitas/export/torch/qoperator/handler/parameter.py
+++ b/src/brevitas/export/torch/qoperator/handler/parameter.py
@@ -52,7 +52,7 @@ def prepare_weight_quant(cls, module: QuantWBIOL):
         cls.validate_8b_bit_width(module.input_quant.bit_width(), le_then=False)
         cls.validate_8b_bit_width(module.output_quant.bit_width(), le_then=False)
         scale = module.quant_weight_scale()
-        zero_point = cls.weight_quant.zero_point(module)
+        zero_point = cls.quant_weight_zero_point(module)
         signed = module.weight_quant.is_signed
         weight = module.weight.detach()
         quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed)
diff --git a/src/brevitas/nn/mixin/parameter.py b/src/brevitas/nn/mixin/parameter.py
index 80cd4e726..50511723a 100644
--- a/src/brevitas/nn/mixin/parameter.py
+++ b/src/brevitas/nn/mixin/parameter.py
@@ -91,10 +91,10 @@ def quant_weight(
         input_is_signed = None
         if self.weight_quant_requires_quant_input:
             if self.is_weight_quant_enabled:
-                if quant_input is None and self.weight_quant._cached_act is not None:
-                    input_bit_width = self.weight_quant._cached_act.bit_width
-                    input_is_signed = self.weight_quant._cached_act.signed
-                elif quant_input is not None:
+                if quant_input is None:
+                    input_bit_width = self.input_quant.bit_width()
+                    input_is_signed = self.input_quant.signed()
+                else:
                     input_bit_width = quant_input.bit_width
                     input_is_signed = quant_input.signed
                 assert input_bit_width is not None, "Input bit-width needs to be specified."
diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py
index e55879aef..d3fc4bdce 100644
--- a/tests/brevitas/nn/test_wbiol.py
+++ b/tests/brevitas/nn/test_wbiol.py
@@ -96,7 +96,7 @@ def test_default_wbiol_quant_bias_signed(default_wbiol_layer: QuantWBIOL):
 
 
 def test_default_wbiol_quant_weight_signed(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.quant_weight.is_signed
+    assert default_wbiol_layer.weight_quant.is_signed
 
 
 def test_default_wbiol_quant_bias_narrow_range(default_wbiol_layer: QuantWBIOL):