From abaf7dcc825bcb2bb80bd10ec6b9b6a6b407c230 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 25 Mar 2024 16:08:16 +0000
Subject: [PATCH] Feat (minifloat): cleanup minifloat impl

---
 src/brevitas/core/function_wrapper/clamp.py    | 46 +++++++++---
 src/brevitas/core/quant/float.py               | 16 ++--
 src/brevitas/core/scaling/__init__.py          |  1 +
 src/brevitas/core/scaling/float_scaling.py     | 32 +++++---
 src/brevitas/function/ops.py                   |  3 +-
 src/brevitas/quant/experimental/float_base.py  | 17 -----
 .../quant/experimental/float_quant_ocp.py      | 27 +++++++
 src/brevitas/utils/float_quant_utils.py        | 14 +++-
 tests/brevitas/core/test_clamp.py              | 23 +++++-
 tests/brevitas/core/test_float_quant.py        | 73 +++++++++++++------
 10 files changed, 175 insertions(+), 77 deletions(-)

diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py
index 0bfb79374..9bc545c5c 100644
--- a/src/brevitas/core/function_wrapper/clamp.py
+++ b/src/brevitas/core/function_wrapper/clamp.py
@@ -13,6 +13,7 @@
 import brevitas
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function import tensor_clamp
+from brevitas.function.ops import max_float
 
 
 class TensorClamp(brevitas.jit.ScriptModule):
@@ -90,39 +91,62 @@ class FloatClamp(brevitas.jit.ScriptModule):
 
     def __init__(
             self,
-            max_value: float,
             tensor_clamp_impl: Module,
+            signed: bool,
             inf_values: Optional[Tuple[str]] = None,
-            saturating: bool = True) -> None:
+            nan_values: Optional[Tuple[str]] = None,
+            max_available_float: Optional[Tensor] = None,
+            saturating: bool = True,
+            device: Optional[str] = None,
+            dtype: Optional[torch.dtype] = None) -> None:
         super(FloatClamp, self).__init__()
 
         self.tensor_clamp_impl = tensor_clamp_impl
-
-        self.max_value = StatelessBuffer(torch.tensor(max_value))
         self.saturating = saturating
-        self.has_inf_values = bool(inf_values)
+        self.inf_values = inf_values
+        self.nan_values = nan_values
+        self.signed = signed
+
+        if max_available_float:
+            max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype)
+            self.max_available_float = StatelessBuffer(max_available_float)
+        else:
+            self.max_available_float = None
 
     @brevitas.jit.script_method
-    def forward(self, x: Tensor):
+    def forward(
+            self,
+            x: Tensor,
+            exponent_bit_width: Tensor,
+            mantissa_bit_width: Tensor,
+            exponent_bias: Tensor):
         inf_mask = x.isinf()
-        p_max_val_mask = x > self.max_value()
-        n_max_val_mask = -x > self.max_value()
+        max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
+        max_value = max_value if self.max_available_float is None else torch.min(
+            max_value, self.max_available_float())
+        p_max_val_mask = x > max_value
+        n_max_val_mask = -x > max_value
+        min_float = torch.tensor(0.) if not self.signed else -max_value
 
         # first clamp everything to +- max_value, basically the saturating case
-        x = self.tensor_clamp_impl(x, min_val=-self.max_value(), max_val=self.max_value())
+        x = self.tensor_clamp_impl(x, min_val=min_float, max_val=max_value)
 
         if not self.saturating:
             # if non-saturating, we need to map values greater than max_val to nan or inf
-            if self.has_inf_values:
+            if self.inf_values:
                 # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
                 x[p_max_val_mask] = torch.tensor(float('inf'))
                 x[n_max_val_mask] = torch.tensor(float('-inf'))
-            else:
+            elif self.nan_values:
                 # no inf values, so we need to map them to NaN
                 full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask)
                 x[full_max_val_mask] = torch.tensor(float('nan'))
 
                 # we also map the inf values to NaN in this case
                 x[inf_mask] = torch.tensor(float('nan'))
+            else:
+                raise RuntimeError(
+                    "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified"
+                )
 
         return x
diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index 11da5864b..26b12814f 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -1,7 +1,7 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -46,8 +46,7 @@ def __init__(
             (torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype)))
         self.exponent_bias = StatelessBuffer(
             torch.tensor(float(exponent_bias), device=device, dtype=dtype))
-        self.fp_max_val = StatelessBuffer(
-            max_float(self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()))
+        self.fp_internal_scale_min = StatelessBuffer(
             1. - self.exponent_bias() - self.mantissa_bit_width())
 
         if float_scaling_impl is None:
@@ -69,14 +68,12 @@ def internal_scale(self, x):
 
     @brevitas.jit.script_method
     def quantize(self, x: torch.Tensor):
-        scale = self.scaling_impl(x) / self.float_scaling_impl(x)
+        scale_impl_value = self.scaling_impl(
+            self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
+        scale = scale_impl_value / self.float_scaling_impl(x)
         scaled_x = x / scale
         internal_scale = self.internal_scale(scaled_x)
         val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale)
-        if self.signed:
-            val_fp_quant = torch.clip(val_fp_quant, -1. * self.fp_max_val(), self.fp_max_val())
-        else:
-            val_fp_quant = torch.clip(val_fp_quant, 0., self.fp_max_val())
         return val_fp_quant, scale
 
     @brevitas.jit.script_method
@@ -87,7 +84,8 @@ def dequantize(self, y, scale):
     def forward(self, x):
         y, scale = self.quantize(x)
         # after quantizing, clamp to special cases like NaN/inf if they are set
-        y = self.float_clamp_impl(y)
+        y = self.float_clamp_impl(
+            y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
         y = self.dequantize(y, scale)
         # This is to respect the current interface of proxies
         return y, scale, self.zero_point_impl(), self.bit_width()
diff --git a/src/brevitas/core/scaling/__init__.py b/src/brevitas/core/scaling/__init__.py
index 6187bd262..1be86be55 100644
--- a/src/brevitas/core/scaling/__init__.py
+++ b/src/brevitas/core/scaling/__init__.py
@@ -4,6 +4,7 @@
 from brevitas.core.stats import SCALAR_SHAPE
 
+from .float_scaling import FloatScaling
 from .int_scaling import IntScaling
 from .int_scaling import PowerOfTwoIntScaling
 from .pre_scaling import AccumulatorAwareParameterPreScaling
diff --git a/src/brevitas/core/scaling/float_scaling.py b/src/brevitas/core/scaling/float_scaling.py
index 89fd46362..e082589a0 100644
--- a/src/brevitas/core/scaling/float_scaling.py
+++ b/src/brevitas/core/scaling/float_scaling.py
@@ -1,7 +1,7 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -15,18 +15,28 @@ class FloatScaling(brevitas.jit.ScriptModule):
 
     def __init__(
             self,
-            exponent_bit_width: int,
-            mantissa_bit_width: int,
-            exponent_bias: int,
+            max_available_float: Optional[float] = None,
+            inf_values: Optional[Tuple[str]] = None,
+            nan_values: Optional[Tuple[str]] = None,
+            saturating: bool = True,
             device: Optional[str] = None,
             dtype: Optional[torch.dtype] = None):
         super(FloatScaling, self).__init__()
-        exponent_bit_width = torch.tensor(exponent_bit_width, device=device, dtype=dtype)
-        mantissa_bit_width = torch.tensor(mantissa_bit_width, device=device, dtype=dtype)
-        exponent_bias = torch.tensor(exponent_bias, device=device, dtype=dtype)
-        self.max_float_val = StatelessBuffer(
-            max_float(exponent_bit_width, mantissa_bit_width, exponent_bias))
+        self.inf_values = inf_values
+        self.nan_values = nan_values
+        self.saturating = saturating
+
+        if max_available_float:
+            max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype)
+            self.max_available_float = StatelessBuffer(max_available_float)
+        else:
+            self.max_available_float = None
 
     @brevitas.jit.script_method
-    def forward(self, input: torch.Tensor) -> Tensor:
-        return self.max_float_val()
+    def forward(
+            self, exponent_bit_width: Tensor, mantissa_bit_width: Tensor,
+            exponent_bias: Tensor) -> Tensor:
+        max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
+        max_value = max_value if self.max_available_float is None else torch.min(
+            max_value, self.max_available_float())
+        return max_value
diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py
index 7bbffaec7..10717774c 100644
--- a/src/brevitas/function/ops.py
+++ b/src/brevitas/function/ops.py
@@ -5,12 +5,13 @@
 Implementation of various core operations often performed as part of quantization.
 The implemented functions adheres to the restriction imposed by Pytorch 1.1.0's TorchScript compiler.
 """
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor
 
 import brevitas
+from brevitas.utils.float_quant_utils import get_minifloat_value
 
 @brevitas.jit.script
diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py
index 61201578e..9a2893039 100644
--- a/src/brevitas/quant/experimental/float_base.py
+++ b/src/brevitas/quant/experimental/float_base.py
@@ -12,7 +12,6 @@
 from brevitas.quant.solver import ActQuantSolver
 from brevitas.quant.solver import WeightQuantSolver
 from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum
-from brevitas.utils.float_quant_utils import get_max_value
 
 
 class FloatBase(SolveTensorQuantFloatToIntImplFromEnum):
@@ -27,22 +26,6 @@ class FloatBase(SolveTensorQuantFloatToIntImplFromEnum):
     def exponent_bias(exponent_bit_width):
         return 2 ** (exponent_bit_width - 1) - 1
 
-    @value
-    def max_value(
-            exponent_bit_width,
-            mantissa_bit_width,
-            exponent_bias,
-            nan_values=None,
-            inf_values=None,
-            saturating=True):
-        return get_max_value(
-            exponent_bit_width,
-            mantissa_bit_width,
-            exponent_bias,
-            nan_values,
-            inf_values,
-            saturating)
-
 
 class FloatWeightBase(FloatBase):
     proxy_class = WeightQuantProxyFromInjector
diff --git a/src/brevitas/quant/experimental/float_quant_ocp.py b/src/brevitas/quant/experimental/float_quant_ocp.py
index 6dfda1304..f2b148482 100644
--- a/src/brevitas/quant/experimental/float_quant_ocp.py
+++ b/src/brevitas/quant/experimental/float_quant_ocp.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from dependencies import value
+
 from brevitas.quant.base import MSESymmetricScale
 from brevitas.quant.experimental.float_base import FloatActBase
 from brevitas.quant.experimental.float_base import FloatWeightBase
@@ -8,17 +10,42 @@
 from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
 from brevitas.quant.experimental.float_base import ScaledFloatActBase
 from brevitas.quant.experimental.float_base import ScaledFloatWeightBase
+from brevitas.utils.float_quant_utils import get_max_available_float
 
 
 class Fp8e4m3OCPMixin(Fp8e4m3Mixin):
     nan_values = (('111',))
     inf_values = None
 
+    @value
+    def max_available_float(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
+            saturating):
+        return get_max_available_float(
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            nan_values,
+            inf_values,
+            saturating)
+
 
 class Fp8e5m2OCPMixin(Fp8e5m2Mixin):
     nan_values = ('01', '11', '10')
     inf_values = (('00',))
 
+    @value
+    def max_available_float(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
+            saturating):
+        return get_max_available_float(
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            nan_values,
+            inf_values,
+            saturating)
+
 
 class Fp8e4m3OCPWeight(Fp8e4m3OCPMixin, FloatWeightBase):
     """
diff --git a/src/brevitas/utils/float_quant_utils.py b/src/brevitas/utils/float_quant_utils.py
index 5d5c4037f..b108c37ee 100644
--- a/src/brevitas/utils/float_quant_utils.py
+++ b/src/brevitas/utils/float_quant_utils.py
@@ -1,5 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
+from typing import Tuple
+
+import torch
 
 
 def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
@@ -21,11 +24,16 @@ def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> flo
     """
     exponent_value = int(exponent, 2)
     mantissa_value = mantissa_bits_to_float(mantissa)
-    return 2 ** (exponent_value - exponent_bias) * mantissa_value
+    return (2 ** (exponent_value - exponent_bias)) * mantissa_value
 
 
-def get_max_value(
-        exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values, saturating):
+def get_max_available_float(
+        exponent_bit_width: torch.Tensor,
+        mantissa_bit_width: torch.Tensor,
+        exponent_bias: torch.Tensor,
+        nan_values: Tuple[str],
+        inf_values: Tuple[str],
+        saturating: bool) -> torch.Tensor:
     # Idea: take the smallest NaN/inf value, set max_value to the next smaller one
     # inf without NaN not possible
     if inf_values is None and nan_values is None:
diff --git a/tests/brevitas/core/test_clamp.py b/tests/brevitas/core/test_clamp.py
index 5ba5a0a32..335fac2f5 100644
--- a/tests/brevitas/core/test_clamp.py
+++ b/tests/brevitas/core/test_clamp.py
@@ -3,7 +3,9 @@
 
 from hypothesis import given
 import pytest
+import torch
 
+from brevitas.function.ops import max_float
 from brevitas.quant.experimental.float import Fp8e4m3Weight
 from brevitas.quant.experimental.float import Fp8e5m2Weight
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
@@ -20,19 +22,34 @@
     'minifloat, expected_max_val',
     ((format, max_val) for format, max_val in FORMAT_MAXVAL_MAP.items()))
 def test_max_value(minifloat, expected_max_val):
-    max_val = minifloat.float_clamp_impl.max_value()
+    max_val = max_float(
+        torch.tensor(minifloat.exponent_bit_width, dtype=torch.float32),
+        torch.tensor(minifloat.mantissa_bit_width, dtype=torch.float32),
+        torch.tensor(minifloat.exponent_bias, dtype=torch.float32))
+    max_available_float = minifloat.float_clamp_impl.max_available_float
+    max_val = max_val if max_available_float is None else torch.min(max_val, max_available_float())
 
     assert expected_max_val == max_val
 
 
 @given(inp=float_tensor_random_shape_st())
 def test_float_clamp(inp, fp8_clamp):
-    max_val = fp8_clamp.float_clamp_impl.max_value()
+
+    max_val = max_float(
+        torch.tensor(fp8_clamp.exponent_bit_width, dtype=torch.float32),
+        torch.tensor(fp8_clamp.mantissa_bit_width, dtype=torch.float32),
+        torch.tensor(fp8_clamp.exponent_bias, dtype=torch.float32))
+    max_available_float = fp8_clamp.float_clamp_impl.max_available_float
+    max_val = max_val if max_available_float is None else torch.min(max_val, max_available_float())
     # get values that exceed max_val
    over_limit_mask = inp.abs() > max_val
     # clamp inp
-    inp = fp8_clamp.float_clamp_impl(inp)
+    inp = fp8_clamp.float_clamp_impl(
+        inp,
+        torch.tensor(fp8_clamp.exponent_bit_width),
+        torch.tensor(fp8_clamp.mantissa_bit_width),
+        torch.tensor(fp8_clamp.exponent_bias))
 
     if fp8_clamp.float_clamp_impl.saturating:
         # should be clamped to +- max val
diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py
index 1e4058fb8..1958715f6 100644
--- a/tests/brevitas/core/test_float_quant.py
+++ b/tests/brevitas/core/test_float_quant.py
@@ -11,7 +11,8 @@
 from brevitas.core.function_wrapper import TensorClamp
 from brevitas.core.quant.float import FloatQuant
 from brevitas.core.scaling import ConstScaling
-from brevitas.utils.float_quant_utils import get_max_value
+from brevitas.core.scaling import FloatScaling
+from brevitas.function.ops import max_float
 from tests.brevitas.hyp_helper import float_st
 from tests.brevitas.hyp_helper import float_tensor_random_shape_st
 from tests.brevitas.hyp_helper import random_minifloat_format
@@ -32,12 +33,17 @@ def test_float_quant_defaults(minifloat_format):
                 signed=signed,
                 float_clamp_impl=None)
     else:
-        max_value = get_max_value(
-            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
         # init FloatClamp
-        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
+        float_clamp = FloatClamp(
+            tensor_clamp_impl=TensorClamp(),
+            signed=signed,
+            inf_values=None,
+            nan_values=None,
+            saturating=True)
+        scaling = FloatScaling(None, None, True)
         float_quant = FloatQuant(
             bit_width=bit_width,
+            scaling_impl=scaling,
             exponent_bit_width=exponent_bit_width,
             mantissa_bit_width=mantissa_bit_width,
             exponent_bias=exponent_bias,
@@ -45,7 +51,7 @@ def test_float_quant_defaults(minifloat_format):
             float_clamp_impl=float_clamp)
     assert isinstance(float_quant.float_to_int_impl, RoundSte)
     assert isinstance(float_quant.float_scaling_impl, ConstScaling)
-    assert isinstance(float_quant.scaling_impl, ConstScaling)
+    assert isinstance(float_quant.scaling_impl, FloatScaling)
 
 
 @given(minifloat_format=random_minifloat_format())
@@ -57,6 +63,7 @@ def test_minifloat(minifloat_format):
 @given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
 def test_float_to_quant_float(inp, minifloat_format):
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
+
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
             float_quant = FloatQuant(
@@ -67,12 +74,17 @@ def test_float_to_quant_float(inp, minifloat_format):
                 signed=signed,
                 float_clamp_impl=None)
     else:
-        max_value = get_max_value(
-            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
         # init FloatClamp
-        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
+        float_clamp = FloatClamp(
+            tensor_clamp_impl=TensorClamp(),
+            signed=signed,
+            inf_values=None,
+            nan_values=None,
+            saturating=True)
+        scaling = FloatScaling(None, None, True)
         float_quant = FloatQuant(
             bit_width=bit_width,
+            scaling_impl=scaling,
             exponent_bit_width=exponent_bit_width,
             mantissa_bit_width=mantissa_bit_width,
             exponent_bias=exponent_bias,
@@ -81,6 +93,9 @@ def test_float_to_quant_float(inp, minifloat_format):
             float_clamp_impl=float_clamp)
         expected_out, _, _, bit_width_out = float_quant(inp)
 
         out_quant, scale = float_quant.quantize(inp)
+        exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float)
+        out_quant = float_quant.float_clamp_impl(
+            out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias)
         assert bit_width_out == bit_width
         assert torch.equal(expected_out, out_quant * scale)
@@ -89,7 +104,7 @@
 @jit_disabled_for_mock()
 def test_scaling_impls_called_once(inp, minifloat_format):
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
-    scaling_impl = mock.Mock(side_effect=lambda x: 1.)
+    scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
     float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
@@ -103,10 +118,13 @@
             float_scaling_impl=float_scaling_impl,
             float_clamp_impl=None)
     else:
-        max_value = get_max_value(
-            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
         # init FloatClamp
-        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
+        float_clamp = FloatClamp(
+            tensor_clamp_impl=TensorClamp(),
+            signed=signed,
+            inf_values=None,
+            nan_values=None,
+            saturating=True)
         float_quant = FloatQuant(
             bit_width=bit_width,
             exponent_bit_width=exponent_bit_width,
@@ -116,9 +134,12 @@
             scaling_impl=scaling_impl,
             float_scaling_impl=float_scaling_impl,
             float_clamp_impl=float_clamp)
-        output = float_quant.quantize(inp)
+        _ = float_quant.quantize(inp)
         # scaling implementations should be called exaclty once on the input
-        scaling_impl.assert_called_once_with(inp)
+        scaling_impl.assert_called_once_with(
+            torch.tensor(exponent_bit_width),
+            torch.tensor(mantissa_bit_width),
+            torch.tensor(exponent_bias))
         float_scaling_impl.assert_called_once_with(inp)
 
 
@@ -130,7 +151,7 @@
 def test_inner_scale(inp, minifloat_format, scale):
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
     # set scaling_impl to scale and float_scaling_impl to 1 to use the same scale as we are here
-    scaling_impl = mock.Mock(side_effect=lambda x: scale)
+    scaling_impl = mock.Mock(side_effect=lambda x, y, z: scale)
     float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
@@ -144,10 +165,13 @@
             float_scaling_impl=float_scaling_impl,
             float_clamp_impl=None)
     else:
-        max_value = get_max_value(
-            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
         # init FloatClamp
-        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
+        float_clamp = FloatClamp(
+            tensor_clamp_impl=TensorClamp(),
+            signed=signed,
+            inf_values=None,
+            nan_values=None,
+            saturating=True)
         float_quant = FloatQuant(
             bit_width=bit_width,
             exponent_bit_width=exponent_bit_width,
@@ -160,15 +184,20 @@
 
         # scale inp manually
         scaled_inp = inp / scale
-
+        max_val = max_float(
+            torch.tensor(exponent_bit_width),
+            torch.tensor(mantissa_bit_width),
+            torch.tensor(exponent_bias))
+        max_available_float = float_clamp.max_available_float
+        max_value = max_val if max_available_float is None else torch.min(
+            max_val, max_available_float())
         # call internal scale
         internal_scale = float_quant.internal_scale(scaled_inp)
         val_fp_quant = internal_scale * float_quant.float_to_int_impl(scaled_inp / internal_scale)
         if signed:
-            val_fp_quant = torch.clip(
-                val_fp_quant, -1. * float_quant.fp_max_val(), float_quant.fp_max_val())
+            val_fp_quant = torch.clip(val_fp_quant, -1. * max_val, max_val)
         else:
-            val_fp_quant = torch.clip(val_fp_quant, 0., max_val)
+            val_fp_quant = torch.clip(val_fp_quant, 0., max_val)
 
         # dequantize manually
         out = val_fp_quant * scale
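
For readers skimming the diff, the following standalone sketch summarises the clamping behaviour that the reworked FloatClamp.forward implements. It is illustrative only and not part of the patch: it uses plain PyTorch, the helper name clamp_minifloat is made up, and 448.0 is used as the example maximum simply because it is the OCP FP8 E4M3 saturation value.

import torch

def clamp_minifloat(x, max_value, signed=True, saturating=True,
                    has_inf_values=False, has_nan_values=True):
    # Masks are computed before clamping, mirroring FloatClamp.forward.
    p_over = x > max_value
    n_over = -x > max_value
    inf_mask = x.isinf()
    min_value = -max_value if signed else torch.tensor(0.)
    # Saturating case: everything is pulled into [min_value, max_value].
    x = torch.clamp(x, min=min_value, max=max_value)
    if not saturating:
        if has_inf_values:
            # Overflow maps to +-inf; pre-existing infs stay inf.
            x[p_over] = float('inf')
            x[n_over] = float('-inf')
        elif has_nan_values:
            # No inf encodings: overflow and pre-existing infs both become NaN.
            x[p_over | n_over] = float('nan')
            x[inf_mask] = float('nan')
        else:
            raise RuntimeError("non-saturating clamping needs inf or NaN encodings")
    return x

# Example with an E4M3-like maximum of 448:
x = torch.tensor([500., -500., 10., float('inf')])
print(clamp_minifloat(x, torch.tensor(448.)))                    # [448., -448., 10., 448.]
print(clamp_minifloat(x, torch.tensor(448.), saturating=False))  # [nan, nan, 10., nan]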