From d3b520fcf7cb7684513970f06f0a8e63fda03aba Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 8 Oct 2024 15:15:42 +0100 Subject: [PATCH] fix float quant --- src/brevitas/core/quant/float.py | 13 ++++++------- tests/brevitas/core/test_float_quant.py | 3 ++- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index cd5c632c2..09dcc248a 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -68,12 +68,6 @@ def __init__( @brevitas.jit.script_method def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - if self.float_scaling_impl is not None: - float_scaling_impl_value = self.float_scaling_impl( - self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) - else: - float_scaling_impl_value = None - scale = self.scaling_impl(x, float_scaling_impl_value) x = self.input_view_impl(x) scaled_x = x / scale internal_scale = float_internal_scale( @@ -87,7 +81,12 @@ def dequantize(self, y, scale): @brevitas.jit.script_method def forward(self, x): - scale = self.scaling_impl(x) + if self.float_scaling_impl is not None: + float_scaling_impl_value = self.float_scaling_impl( + self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) + else: + float_scaling_impl_value = None + scale = self.scaling_impl(x, float_scaling_impl_value) if self.observer_only: y = x saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 51f1c0295..6c7e26f31 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -143,7 +143,8 @@ def test_scaling_impls_called_once(inp, minifloat_format): scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=float_clamp) - scale = float_quant.scaling_impl(inp) + float_scaling = float_scaling_impl(exponent_bit_width, mantissa_bit_width, exponent_bias) + scale = float_quant.scaling_impl(inp, float_scaling) _ = float_quant.quantize(inp, scale) # scaling implementations should be called exaclty once on the input float_scaling_impl.assert_called_once_with(