Adding FP8 weight export #907

Merged
46 commits merged on May 29, 2024
Changes from 33 commits
Commits
5b168c5
placeholder version
costigt-dev Apr 11, 2024
d2b7d2d
checkpoint commit
costigt-dev Apr 12, 2024
e10e630
first working flow end to end
costigt-dev Apr 12, 2024
84e70f7
formatting
costigt-dev Apr 12, 2024
ef4c737
changes to tests
costigt-dev Apr 12, 2024
4aa4b21
added version check for test
costigt-dev Apr 12, 2024
3b05883
using existing functionality over homespun
costigt-dev Apr 12, 2024
cad5802
corrected mistake in copying and restored FloatClipMixin
costigt-dev Apr 12, 2024
4848248
fixed mistake
costigt-dev Apr 12, 2024
5188aa6
first pass activation fp8 export
costigt-dev Apr 16, 2024
29cb952
beginnings of activation fp8 export and change name of QCDQCastFloatW…
costigt-dev Apr 16, 2024
9bf9240
more changes to make naming scheme more consistent
costigt-dev Apr 16, 2024
f9406f1
added FloatFusedActivationQuantProxy
costigt-dev Apr 16, 2024
991ddb7
replaced zero_point workaround with placeholder implementation of fp8…
costigt-dev Apr 17, 2024
520db85
removed verbose flag
costigt-dev Apr 17, 2024
2bb2895
created context manager for fp8 workaround
costigt-dev Apr 17, 2024
8ffce48
added check that objects being compared are tensors in the fp8 workar…
costigt-dev Apr 17, 2024
7edf5bd
General equal implementation
Giuseppe5 May 14, 2024
bbd5362
fallback to fp32 if fp8
Giuseppe5 May 14, 2024
4bc126d
Fix for PT < 2.1
Giuseppe5 May 14, 2024
a55dcd0
Remove non existent destroy
Giuseppe5 May 14, 2024
cd6cad6
Merge branch 'dev' into feat/export_fp8
Giuseppe5 May 23, 2024
fabc8ae
Remove import
Giuseppe5 May 23, 2024
74b65a9
Fixed imports
Giuseppe5 May 23, 2024
cf1ea02
Fixed imports
Giuseppe5 May 23, 2024
cda7f1f
Fix export
Giuseppe5 May 23, 2024
8349391
more testing
Giuseppe5 May 23, 2024
11387d3
Fix
Giuseppe5 May 24, 2024
592ccd3
Fix
Giuseppe5 May 24, 2024
1fc5642
fix
Giuseppe5 May 25, 2024
58f46bc
Fix minifloat check
Giuseppe5 May 25, 2024
bd657b8
Last fix
Giuseppe5 May 25, 2024
630a3e3
Fix minifloat
Giuseppe5 May 27, 2024
38a37fb
Review
Giuseppe5 May 28, 2024
76b3193
Review 2
Giuseppe5 May 28, 2024
529470f
Merge branch 'dev' into feat/export_fp8
Giuseppe5 May 28, 2024
f2f8969
fix
Giuseppe5 May 28, 2024
44579f8
Typo
Giuseppe5 May 28, 2024
038cba9
fix tests
Giuseppe5 May 28, 2024
198c5af
Typo
Giuseppe5 May 28, 2024
c3d7d3c
fix
Giuseppe5 May 28, 2024
fef531d
last fix
Giuseppe5 May 28, 2024
6431882
Fix JIT
Giuseppe5 May 29, 2024
4b78543
Fix import
Giuseppe5 May 29, 2024
d762c99
Last fix
Giuseppe5 May 29, 2024
ac5e58c
correct skip
Giuseppe5 May 29, 2024
23 changes: 22 additions & 1 deletion src/brevitas/export/common/handler/base.py
@@ -4,6 +4,7 @@
from abc import ABC
from abc import abstractmethod
import math
from warnings import warn

import torch
from torch import Tensor
@@ -12,7 +13,8 @@
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int

__all__ = ['BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin']
__all__ = [
    'BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin', 'FloatZeroPointHandlerMixin']


class BaseHandler(Module, ABC):
@@ -38,6 +40,13 @@ def quant_axis(cls, scale):
        return None


class FloatClipMixin(ABC):

    @classmethod
    def clip_symbolic_kwargs(cls, narrow, signed, exponent_bit_width, mantissa_bit_width):
        return None


class ClipMixin(ABC):

    @classmethod
@@ -112,6 +121,18 @@ def validate_neg_scalar_int_exponent(cls, scale: Tensor):
        return -cls.validate_scalar_int_exponent(scale)


class FloatZeroPointHandlerMixin(ABC):

    @classmethod
    def zero_point_with_dtype(cls, signed, exponent_bit_width, mantissa_bit_width, zero_point):
        if exponent_bit_width == 4 and mantissa_bit_width == 3:
            return zero_point.type(torch.float8_e4m3fn)
        elif exponent_bit_width == 5 and mantissa_bit_width == 2:
            return zero_point.type(torch.float8_e5m2)
        else:
            return zero_point.type(torch.float32)


class ZeroPointHandlerMixin(ABC):

    @classmethod
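The new mixin simply picks the torch float8 dtype that matches the minifloat format of the zero point. A standalone sketch of that mapping (assumes torch >= 2.1, which provides the float8 dtypes; the helper name below is illustrative, not part of the diff):

```python
import torch

# Illustrative re-statement of FloatZeroPointHandlerMixin.zero_point_with_dtype:
# map (exponent_bit_width, mantissa_bit_width) to a torch float8 dtype.
def fp8_dtype(exponent_bit_width: int, mantissa_bit_width: int) -> torch.dtype:
    if exponent_bit_width == 4 and mantissa_bit_width == 3:
        return torch.float8_e4m3fn
    elif exponent_bit_width == 5 and mantissa_bit_width == 2:
        return torch.float8_e5m2
    return torch.float32  # any other minifloat format falls back to fp32

zero_point = torch.zeros(1)
print(zero_point.type(fp8_dtype(4, 3)).dtype)  # torch.float8_e4m3fn
print(zero_point.type(fp8_dtype(5, 2)).dtype)  # torch.float8_e5m2
```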
346 changes: 342 additions & 4 deletions src/brevitas/export/common/handler/qcdq.py

Large diffs are not rendered by default.

41 changes: 40 additions & 1 deletion src/brevitas/export/onnx/manager.py
Expand Up @@ -30,6 +30,43 @@
from ..manager import ExportContext


# workaround for fp8 not having many operators implemented
class Fp8Workaround():

    def __init__(self):
        self.lib = None

    def __enter__(self):
        if torch_version >= version.parse('2.1.0'):
            self.lib = torch.library.Library("aten", "IMPL")

            def equal_cpu(self, other):
                if (isinstance(self, Tensor) and
                        self.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)) or (
                            isinstance(other, Tensor) and
                            other.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)):
                    self = self.to(torch.float32)
                    other = other.to(torch.float32)
                    return torch.equal(self, other)
                else:
                    res = True
                    if not isinstance(self, Tensor):
                        self = torch.tensor(self)
                    if not isinstance(other, Tensor):
                        other = torch.tensor(other)
                    if self.dim() > 0:
                        for x, y in zip(self.flatten(), other.flatten()):
                            res &= x == y
                    else:
                        res = self.item() == other.item()
                    return torch.tensor([res])

            self.lib.impl("equal", equal_cpu, "CPU")

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.lib = None


class ONNXBaseManager(BaseManager, ABC):

    model_transforms = []
@@ -127,7 +164,9 @@ def export_onnx(
        else:
            model_bytes = BytesIO()
            export_target = model_bytes
        torch.onnx.export(module, args, export_target, **onnx_export_kwargs)

        with Fp8Workaround():
            torch.onnx.export(module, args, export_target, **onnx_export_kwargs)

        # restore the model to previous properties
        module.apply(lambda m: _restore_act_caching_mode(m))
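The context manager exists because aten::equal has no CPU kernel for the float8 dtypes on the torch versions targeted here, and ONNX export can hit it while comparing tensors. A small sketch of the failure mode and the patched behaviour (assumes torch >= 2.1 on CPU; Fp8Workaround is the class added in this file):

```python
import torch
from brevitas.export.onnx.manager import Fp8Workaround  # class added in this diff

a = torch.randn(4).to(torch.float8_e4m3fn)
b = a.clone()

# Without the patch, torch.equal on float8 CPU tensors may raise
# NotImplementedError because the aten::equal CPU kernel is missing.
try:
    print(torch.equal(a, b))
except Exception as e:
    print(f"torch.equal failed: {e}")

# Inside the context, equal is re-registered to compare in fp32 instead.
with Fp8Workaround():
    print(torch.equal(a, b))  # True
```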
37 changes: 37 additions & 0 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from warnings import warn

import torch

@@ -10,6 +11,9 @@
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import DynamicQDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DynamicQMixin
from brevitas.export.common.handler.qcdq import FloatQCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import FloatQCDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import FloatQMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import \
@@ -47,12 +51,33 @@ def validate(self, module):
        assert module.bit_width() > 1., 'Binary quant not supported'


class StdFloatDQCastONNXMixin(StdDQCastONNXMixin, ABC):

    def validate(self, module):
        pass


class StdFloatCDQCastONNXMixin(CDQCastMixin, StdFloatDQCastONNXMixin, ABC):

    def clip_fn(self, x, min_val, max_val):
        return IntClipFn.apply(x, min_val, max_val)


class StdCDQCastONNXMixin(CDQCastMixin, StdDQCastONNXMixin, ABC):

    def clip_fn(self, x, min_val, max_val):
        return IntClipFn.apply(x, min_val, max_val)


class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC):

    def validate(self, module):
        pass

    def quantize_fn(self, x, scale, zero_point, dtype, axis):
        return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC):

    @classmethod
@@ -112,6 +137,12 @@ def quantize_fn(self, x, dtype):
        return DynamicQuantizeLinearFn.apply(x, dtype)


class StdFloatQCDQCastONNXWeightQuantProxyHandler(StdFloatQCDQCastONNXMixin,
                                                  FloatQCDQCastWeightQuantProxyHandlerMixin,
                                                  ONNXBaseHandler):
    _export_q_node = False


class StdQCDQCastONNXWeightQuantProxyHandler(StdQCDQCastONNXMixin,
                                             QCDQCastWeightQuantProxyHandlerMixin,
                                             ONNXBaseHandler):
@@ -130,6 +161,12 @@ class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
    _export_q_node = False


class StdFloatQCDQCastONNXActQuantProxyHandler(StdFloatQCDQCastONNXMixin,
                                               FloatQCDQCastActQuantProxyHandlerMixin,
                                               ONNXBaseHandler):
    pass


class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
                                          QCDQCastActQuantProxyHandlerMixin,
                                          ONNXBaseHandler):
4 changes: 4 additions & 0 deletions src/brevitas/export/onnx/standard/qcdq/manager.py
@@ -17,6 +17,8 @@
from ..manager import StdONNXBaseManager
from .handler import StdCDQCastONNXBiasQuantProxyHandler
from .handler import StdDynamicQDQCastONNXActQuantProxyHandler
from .handler import StdFloatQCDQCastONNXActQuantProxyHandler
from .handler import StdFloatQCDQCastONNXWeightQuantProxyHandler
from .handler import StdQCDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler
from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
@@ -36,8 +38,10 @@ class StdQCDQONNXManager(StdONNXBaseManager):

    handlers = [
        StdQCDQCastONNXWeightQuantProxyHandler,
        StdFloatQCDQCastONNXWeightQuantProxyHandler,
        StdCDQCastONNXBiasQuantProxyHandler,
        StdQCDQCastONNXActQuantProxyHandler,
        StdFloatQCDQCastONNXActQuantProxyHandler,
        StdQCDQCastONNXDecoupledWeightQuantProxyHandler,
        StdDynamicQDQCastONNXActQuantProxyHandler,
        StdQCDQCastONNXTruncQuantProxyHandler,
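With the two float handlers registered on StdQCDQONNXManager, an FP8-quantized layer should export through the usual QCDQ entry point. A hedged end-to-end sketch (the quantizer name Fp8e4m3WeightPerTensorFloat and the export_onnx_qcdq helper are assumed from Brevitas' experimental float quantizers and export API; exact names may differ between versions):

```python
import torch
import brevitas.nn as qnn
from brevitas.export import export_onnx_qcdq  # assumed entry point
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat  # assumed quantizer

# Linear layer whose weights are quantized to FP8 E4M3 (exponent 4 / mantissa 3).
model = qnn.QuantLinear(16, 8, bias=True, weight_quant=Fp8e4m3WeightPerTensorFloat)
model.eval()

# The FP8 weight handler added in this PR emits the weight through the QCDQ path
# with a float8 dtype instead of the usual int8 Q/DQ pair.
export_onnx_qcdq(model, args=torch.randn(1, 16), export_path='quant_linear_fp8.onnx')
```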
2 changes: 2 additions & 0 deletions src/brevitas/proxy/__init__.py
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from .float_parameter_quant import WeightFloatQuantProxyFromInjector
from .float_runtime_quant import ActFloatQuantProxyFromInjector
from .parameter_quant import BiasQuantProxyFromInjector
from .parameter_quant import BiasQuantProxyFromInjectorBase
from .parameter_quant import DecoupledWeightQuantProxyFromInjector
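Both float proxies are now re-exported from the package root, which is what the simplified imports in float_base.py further below rely on:

```python
from brevitas.proxy import ActFloatQuantProxyFromInjector
from brevitas.proxy import WeightFloatQuantProxyFromInjector
```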
56 changes: 21 additions & 35 deletions src/brevitas/proxy/float_runtime_quant.py
@@ -14,43 +14,28 @@
class ActFloatQuantProxyFromInjector(ActQuantProxyFromInjectorBase):

    def scale(self, force_eval=True):
        if self.is_quant_enabled:
            current_status = self.training
            if force_eval:
                self.eval()
            out = self.__call__(self._zero_hw_sentinel())
            self.train(current_status)
            return out.scale
        elif self._cached_act is not None:
            return self._cached_act.scale
        elif self._cached_act is None:
            return None
        return self.retrieve_attribute('scale', force_eval)

    def zero_point(self, force_eval=True):
        if self.is_quant_enabled:
            current_status = self.training
            if force_eval:
                self.eval()
            out = self.__call__(self._zero_hw_sentinel())
            self.train(current_status)
            return out.zero_point
        elif self._cached_act is not None:
            return self._cached_act.zero_point
        elif self._cached_act is None:
            return None
        return self.retrieve_attribute('zero_point', force_eval)

    def bit_width(self, force_eval=True):
        if self.is_quant_enabled:
            current_status = self.training
            if force_eval:
                self.eval()
            out = self.__call__(self._zero_hw_sentinel())
            self.train(current_status)
            return out.bit_width
        elif self._cached_act is not None:
            return self._cached_act.bit_width
        elif self._cached_act is None:
            return None
    def exponent_bit_width(self, force_eval=True):
        return self.retrieve_attribute('exponent_bit_width', force_eval)

    def mantissa_bit_width(self, force_eval=True):
        return self.retrieve_attribute('mantissa_bit_width', force_eval)

    def exponent_bias(self, force_eval=True):
        return self.retrieve_attribute('exponent_bias', force_eval)

    def saturating(self, force_eval=True):
        return self.retrieve_attribute('saturating', force_eval)

    def inf_values(self, force_eval=True):
        return self.retrieve_attribute('inf_values', force_eval)

    def nan_values(self, force_eval=True):
        return self.retrieve_attribute('nan_values', force_eval)

    def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]:
        out = x
@@ -68,7 +53,8 @@ def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuan
                y = self.fused_activation_quant_proxy(y)
            # If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy,
            # otherwise return a simple Tensor
            if isinstance(y, tuple) and not any(map(lambda f: f is None, y)):
            # We exclude the last two values (inf_values and nan_values)
            if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
                out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training)
            elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
                if isinstance(y, tuple):
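All of these accessors now go through the shared retrieve_attribute helper (added to the base proxy in runtime_quant.py below), so the float proxy exposes its minifloat metadata with the same caching and eval semantics as the integer one. A hedged usage sketch (Fp8e4m3ActPerTensorFloat and the act_quant attribute name are assumptions about the surrounding Brevitas API, not part of this diff):

```python
import torch
import brevitas.nn as qnn
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat  # assumed quantizer

act = qnn.QuantIdentity(act_quant=Fp8e4m3ActPerTensorFloat, return_quant_tensor=True)
act.eval()
_ = act(torch.randn(2, 8))  # run once so the proxy can report quant metadata

proxy = act.act_quant  # ActFloatQuantProxyFromInjector
print(proxy.exponent_bit_width())  # 4 for the e4m3 format
print(proxy.mantissa_bit_width())  # 3 for the e4m3 format
print(proxy.scale())               # per-tensor scale of the quantized activation
```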
54 changes: 21 additions & 33 deletions src/brevitas/proxy/runtime_quant.py
@@ -22,6 +22,7 @@
    'ActQuantProxyProtocol',
    'AccQuantProxyProtocol',
    'ActQuantProxyFromInjector',
    'FloatActQuantProxyFromInjector',
    'TruncQuantProxyFromInjector',
    'ClampQuantProxyFromInjector']

@@ -95,6 +96,23 @@ def __init__(self, quant_layer, quant_injector):
        self.cache_inference_quant_act = False
        self.cache_quant_io_metadata_only = True

    def internal_forward(self, force_eval):
        current_status = self.training
        if force_eval:
            self.eval()
        out = self.__call__(self._zero_hw_sentinel())
        self.train(current_status)
        return out

    def retrieve_attribute(self, attribute, force_eval):
        if self.is_quant_enabled:
            out = self.internal_forward(force_eval)
            return getattr(out, attribute)
        elif self._cached_act is not None:
            return getattr(self._cached_act, attribute)
        elif self._cached_act is None:
            return None

    @property
    def is_quant_enabled(self):
        return self._is_quant_enabled and not self.disable_quant
@@ -132,43 +150,13 @@ def init_tensor_quant(self):
class ActQuantProxyFromInjector(ActQuantProxyFromInjectorBase):

    def scale(self, force_eval=True):
        if self.is_quant_enabled:
            current_status = self.training
            if force_eval:
                self.eval()
            out = self.__call__(self._zero_hw_sentinel())
            self.train(current_status)
            return out.scale
        elif self._cached_act is not None:
            return self._cached_act.scale
        elif self._cached_act is None:
            return None
        return self.retrieve_attribute('scale', force_eval)

    def zero_point(self, force_eval=True):
        if self.is_quant_enabled:
            current_status = self.training
            if force_eval:
                self.eval()
            out = self.__call__(self._zero_hw_sentinel())
            self.train(current_status)
            return out.zero_point
        elif self._cached_act is not None:
            return self._cached_act.zero_point
        elif self._cached_act is None:
            return None
        return self.retrieve_attribute('zero_point', force_eval)

    def bit_width(self, force_eval=True):
        if self.is_quant_enabled:
            current_status = self.training
            if force_eval:
                self.eval()
            out = self.__call__(self._zero_hw_sentinel())
            self.train(current_status)
            return out.bit_width
        elif self._cached_act is not None:
            return self._cached_act.bit_width
        elif self._cached_act is None:
            return None
        return self.retrieve_attribute('bit_width', force_eval)

    def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTensor]:
        out = x
6 changes: 2 additions & 4 deletions src/brevitas/quant/experimental/float_base.py
@@ -7,10 +7,8 @@
from brevitas.core.scaling.float_scaling import FloatScaling
from brevitas.inject import ExtendedInjector
from brevitas.inject import value
from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.proxy import ActFloatQuantProxyFromInjector
from brevitas.proxy import WeightFloatQuantProxyFromInjector
from brevitas.quant.solver import ActQuantSolver
from brevitas.quant.solver import WeightQuantSolver
from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum