Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 28, 2024
1 parent 596c1af commit 16ac867
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 15 deletions.
8 changes: 4 additions & 4 deletions src/brevitas/export/onnx/standard/qoperator/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def input_clip_symbolic_kwargs(cls, module):
if module.is_input_quant_enabled:
narrow = module.is_quant_input_narrow_range
signed = module.input_quant.signed()
bit_width = module.quant_input_bit_width()
bit_width = module.input_quant.bit_width()
return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
else:
return None
Expand All @@ -95,10 +95,10 @@ def output_dequant_symbolic_kwargs(cls, module):
def input_quant_symbolic_kwargs(cls, module):
if module.is_input_quant_enabled:
return {
'output_scale': module.quant_input_scale(),
'output_zero_point': cls.input_quant.zero_point(module),
'output_scale': module.input_quant.scale(),
'output_zero_point': cls.quant_input_zero_point(module),
'output_dtype': cls.torch_8b_dtype(module.input_quant.signed()),
'output_axis': cls.quant_axis(module.quant_input_scale())}
'output_axis': cls.quant_axis(module.input_quant.scale())}
else:
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC):
def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]):
conv_symbolic_kwargs = {
'input_scale': module.quant_input_scale(),
'input_zero_point': self.input_quant.zero_point(module),
'input_zero_point': self.quant_input_zero_point(module),
'int_weight': self.int_weight(module),
'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()),
'weight_zero_point': to_0dim_if_scalar(self.weight_quant.zero_point(module).flatten()),
'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
'output_scale': module.output_quant.scale(),
'output_zero_point': self.quant_output_zero_point(module),
'output_dtype': self.torch_8b_dtype(module.output_quant.signed()),
Expand Down Expand Up @@ -146,10 +146,10 @@ class StdQOpONNXQuantLinearHandler(StdQOpONNXQuantWBIOLHandler):
def op_symbolic_kwargs(self, module: QuantLinear):
conv_symbolic_kwargs = {
'input_scale': module.quant_input_scale(),
'input_zero_point': self.input_quant.zero_point(module),
'input_zero_point': self.quant_input_zero_point(module),
'int_weight': self.int_weight(module).view(module.out_features, module.in_features, 1),
'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()),
'weight_zero_point': to_0dim_if_scalar(self.weight_quant.zero_point(module).flatten()),
'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
'output_scale': module.output_quant.scale(),
'output_zero_point': self.quant_output_zero_point(module),
'output_dtype': self.torch_8b_dtype(module.output_quant.signed()),
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/export/torch/qoperator/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def gen_quant_impl_kwargs(
@classmethod
def prepare_input_quant(cls, module):
scale = module.input_quant.scale()
zero_point = cls.input_quant.zero_point(module)
zero_point = cls.quant_input_zero_point(module)
signed = module.input_quant.signed()
quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed)
return quant_impl, quant_kwargs
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/export/torch/qoperator/handler/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def prepare_weight_quant(cls, module: QuantWBIOL):
cls.validate_8b_bit_width(module.input_quant.bit_width(), le_then=False)
cls.validate_8b_bit_width(module.output_quant.bit_width(), le_then=False)
scale = module.quant_weight_scale()
zero_point = cls.weight_quant.zero_point(module)
zero_point = cls.quant_weight_zero_point(module)
signed = module.weight_quant.is_signed
weight = module.weight.detach()
quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed)
Expand Down
8 changes: 4 additions & 4 deletions src/brevitas/nn/mixin/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def quant_weight(
input_is_signed = None
if self.weight_quant_requires_quant_input:
if self.is_weight_quant_enabled:
if quant_input is None and self.weight_quant._cached_act is not None:
input_bit_width = self.weight_quant._cached_act.bit_width
input_is_signed = self.weight_quant._cached_act.signed
elif quant_input is not None:
if quant_input is None:
input_bit_width = self.input_quant.bit_width()
input_is_signed = self.input_quant.signed()
else:
input_bit_width = quant_input.bit_width
input_is_signed = quant_input.signed
assert input_bit_width is not None, "Input bit-width needs to be specified."
Expand Down
2 changes: 1 addition & 1 deletion tests/brevitas/nn/test_wbiol.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_default_wbiol_quant_bias_signed(default_wbiol_layer: QuantWBIOL):


def test_default_wbiol_quant_weight_signed(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.quant_weight.is_signed
assert default_wbiol_layer.weight_quant.is_signed


def test_default_wbiol_quant_bias_narrow_range(default_wbiol_layer: QuantWBIOL):
Expand Down

0 comments on commit 16ac867

Please sign in to comment.