Skip to content

Commit

Permalink
fix float quant
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 8, 2024
1 parent b3117dd commit d3b520f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
13 changes: 6 additions & 7 deletions src/brevitas/core/quant/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,6 @@ def __init__(

@brevitas.jit.script_method
def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.float_scaling_impl is not None:
float_scaling_impl_value = self.float_scaling_impl(
self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
else:
float_scaling_impl_value = None
scale = self.scaling_impl(x, float_scaling_impl_value)
x = self.input_view_impl(x)
scaled_x = x / scale
internal_scale = float_internal_scale(
Expand All @@ -87,7 +81,12 @@ def dequantize(self, y, scale):

@brevitas.jit.script_method
def forward(self, x):
scale = self.scaling_impl(x)
if self.float_scaling_impl is not None:
float_scaling_impl_value = self.float_scaling_impl(
self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
else:
float_scaling_impl_value = None
scale = self.scaling_impl(x, float_scaling_impl_value)
if self.observer_only:
y = x
saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values
Expand Down
3 changes: 2 additions & 1 deletion tests/brevitas/core/test_float_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def test_scaling_impls_called_once(inp, minifloat_format):
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl,
float_clamp_impl=float_clamp)
scale = float_quant.scaling_impl(inp)
float_scaling = float_scaling_impl(exponent_bit_width, mantissa_bit_width, exponent_bias)
scale = float_quant.scaling_impl(inp, float_scaling)
_ = float_quant.quantize(inp, scale)
# scaling implementations should be called exactly once on the input
float_scaling_impl.assert_called_once_with(
Expand Down

0 comments on commit d3b520f

Please sign in to comment.