
[Experimental] Float8 support in AQT #671


Merged: 35 commits, Aug 28, 2024
Changes from 11 commits

Commits (35)
12d0ac2
updates for float
jainapurva Aug 13, 2024
9173ee1
AAdded optional param
jainapurva Aug 14, 2024
32bf45f
updates for compatible type
jainapurva Aug 14, 2024
ddc1bc0
updates to support torch=2.2
jainapurva Aug 14, 2024
b3e4e79
updates to quant api
jainapurva Aug 19, 2024
e778bca
updates
jainapurva Aug 19, 2024
d4b057f
Float8 updates
jainapurva Aug 20, 2024
9d86df3
Merge branch 'main' into experimental_float8_aqt
jainapurva Aug 20, 2024
04c471e
todos
jainapurva Aug 20, 2024
d86d798
version check
jainapurva Aug 21, 2024
1e5dfd9
Updates for removing float8 API
jainapurva Aug 21, 2024
e25a8e9
Add float8wo to hf_eval
jainapurva Aug 22, 2024
29316fa
remove fpx layout
jainapurva Aug 22, 2024
1fb4a2b
revert changes to hf_eval
jainapurva Aug 22, 2024
ceb3275
Revert changes to main
jainapurva Aug 22, 2024
81eb91f
Fp8 upgrades
jainapurva Aug 23, 2024
c017186
Test for float8
jainapurva Aug 23, 2024
ca9fdbf
Updates
jainapurva Aug 23, 2024
fce977f
Merge branch 'main' into experimental_float8_aqt
jainapurva Aug 23, 2024
bfb8b3b
Remove from_float_float8
jainapurva Aug 23, 2024
146c328
Update optional[int]
jainapurva Aug 23, 2024
9b0c2aa
remove from_float8
jainapurva Aug 23, 2024
3c300eb
Review fixes
jainapurva Aug 24, 2024
efd480b
Review fixes
jainapurva Aug 24, 2024
0abdcc1
typos
jainapurva Aug 24, 2024
5fc0dd8
typos
jainapurva Aug 24, 2024
526a282
Added constraints
jainapurva Aug 24, 2024
83b7356
Added constraints
jainapurva Aug 24, 2024
3ac9f72
Remove eps check
jainapurva Aug 27, 2024
8bd3849
Seperate floatx from_float
jainapurva Aug 27, 2024
02a89b3
Typing fixes
jainapurva Aug 27, 2024
8e19598
init fixes
jainapurva Aug 27, 2024
56a4eb5
Updates for fixes
jainapurva Aug 27, 2024
33ae19a
Review fixes
jainapurva Aug 27, 2024
482f537
Revert a doc string change
jainapurva Aug 27, 2024
52 changes: 52 additions & 0 deletions test/dtypes/test_affine_quantized_float.py
@@ -0,0 +1,52 @@
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)
from torchao.quantization.quant_api import (
float8_weight_only
)
import torch
import unittest
import tempfile
from torchao.utils import (
TORCH_VERSION_AFTER_2_5,
)


class TestAffineQuantizedFloat(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tensor_core_layout_transpose(self):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = l.weight
shape = t.shape
apply_float8_weight_only_quant = float8_weight_only()
ql = apply_float8_weight_only_quant(l)
aqt = ql.weight
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)

# transpose shape test
for _ in range(10):
t = t.t()
aqt = aqt.t()
shape = t.shape
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_weights_only(self):
for apply_quant in [float8_weight_only()]:
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
# `weights_only=True` is enabled for torch 2.5+
if TORCH_VERSION_AFTER_2_5:
_ = torch.load(f, weights_only=True)
else:
_ = torch.load(f, weights_only=False)


if __name__ == "__main__":
run_tests()
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -5,6 +5,7 @@
AffineQuantizedTensor,
to_affine_quantized,
to_affine_quantized_static,
to_affine_quantized_float8,
LayoutType,
PlainLayoutType,
SemiSparseLayoutType,
@@ -18,6 +19,7 @@
"AffineQuantizedTensor",
"to_affine_quantized",
"to_affine_quantized_static",
"to_affine_quantized_float8",
"LayoutType",
"PlainLayoutType",
"SemiSparseLayoutType",
141 changes: 136 additions & 5 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -1,5 +1,5 @@
import torch
from typing import Dict, Callable, Any, Tuple, Optional
from typing import Dict, Callable, Any, Tuple, Optional, Union
from collections import defaultdict
import functools
import math
@@ -25,6 +25,7 @@
_get_layout_tensor_constructor,
LayoutType,
PlainLayoutType,
FpxLayoutType,
is_device,
)
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@@ -116,8 +117,8 @@ def __new__(
layout_tensor: AQTLayout,
block_size: Tuple[int, ...],
shape: torch.Size,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
quant_min: Optional[Union[int, float]] = None,
quant_max: Optional[Union[int, float]] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
dtype=None,
strides=None,
@@ -138,8 +139,8 @@ def __init__(
layout_tensor: AQTLayout,
block_size: Tuple[int, ...],
shape: torch.Size,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
quant_min: Optional[Union[int, float]] = None,
quant_max: Optional[Union[int, float]] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
dtype=None,
strides=None,
@@ -269,6 +270,42 @@ def from_float_static(
dtype=input_float.dtype,
)

@classmethod
def from_float_float8(
Contributor:
@jerryzh168 can you help with a design on how to have from_float, from_float_static, etc extend to this use case? Ideally we shouldn't special case a set of dtypes (float8) to have their own function.

Contributor Author:
A combined function is better; I will refactor it after testing float8.

Contributor:
Yeah, sure. I think there are two possible final states:

  • have separate from_float_fpx and from_float_intx, since they have slightly different arg lists
  • if we manage to generalize the arg list enough that merging the two is reasonable, then we can merge them; I will discuss the args with Apurva and Driss, but at first glance preserve_zero is probably always going to be true and zero_point_domain may not apply here

Contributor:
Makes sense.

One thought: the from_float name will become more confusing if both the source and the target can be various floating-point bitwidths. To clarify this in torchao.float8, I went with the high_precision|hp and low_precision|lp naming scheme.

cpuhrsch (Contributor), Aug 21, 2024:
This is also why I like the idea of extending "to" and having our own factory functions that we can pass dtype enums to. For example:

"to_nf4",

Contributor:
@vkuzo yeah, makes sense; we can rename from_float to from_high_precision as well. As @cpuhrsch mentioned, this is not a user-facing API; we'll have factory functions for the various dtypes as the user-facing API:

to_affine_quantized = AffineQuantizedTensor.from_float
to_affine_quantized_static = AffineQuantizedTensor.from_float_static

Contributor:
@jainapurva as discussed in the meeting, let's merge this into from_float and add some guards on the arguments.

cls,
input_float: torch.Tensor,
mapping_type: MappingType,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[float] = None,
quant_max: Optional[float] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
layout_type: LayoutType = FpxLayoutType(),
):
original_shape = input_float.shape
input_float = layout_type.pre_process(input_float)

scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
float8_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)

float8_data = layout_type.post_process(float8_data)

layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
        layout_tensor = layout_tensor_ctr(float8_data, scale, None, layout_type)
return cls(
layout_tensor,
block_size,
original_shape,
quant_min,
quant_max,
zero_point_domain,
dtype=input_float.dtype
)

@property
def layout_type(self) -> LayoutType:
return self.layout_tensor.layout_type
@@ -663,6 +700,99 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def get_layout_type(self) -> LayoutType:
return self.layout_type


@register_layout_cls(FpxLayoutType)
class FpxAQTLayout(AQTLayout):

def __new__(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
layout_type: LayoutType,
):
kwargs = {}
kwargs["device"] = int_data.device
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
)
kwargs["dtype"] = int_data.dtype
kwargs["requires_grad"] = False
shape = int_data.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
layout_type: LayoutType,
):
self.int_data = int_data
self.scale = scale
self.layout_type = layout_type

def __tensor_flatten__(self):
return ["int_data", "scale"], [self.layout_type]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"]
layout_type, = tensor_attributes
return cls(int_data, scale, layout_type)

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.int_data, self.scale

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
layout_type: LayoutType,
):
assert isinstance(layout_type, FpxLayoutType)
return cls(int_data, scale, layout_type)

def __repr__(self):
int_data, scale = self.get_plain()
layout_type = self.get_layout_type()
return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, layout_type={layout_type})"

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
self.layout_type,
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

if func is aten.t.default:
tensor = args[0]
new = tensor.__class__(
tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.layout_type
)
return return_and_correct_aliasing(func, args, kwargs, new)

raise NotImplementedError(
f"FpxAQTLayout dispatch: attempting to run {func}, this is not supported"
)

__torch_function__ = torch._C._disabled_torch_function_impl

def get_layout_type(self) -> LayoutType:
return self.layout_type


#####################################################
# torch functional and aten operator implementation #
#####################################################
@@ -989,6 +1119,7 @@ def _(func, types, args, kwargs):

to_affine_quantized = AffineQuantizedTensor.from_float
to_affine_quantized_static = AffineQuantizedTensor.from_float_static
to_affine_quantized_float8 = AffineQuantizedTensor.from_float_float8

if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
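Not part of the diff, for orientation: the bottom of this file's diff aliases the new constructor as to_affine_quantized_float8 = AffineQuantizedTensor.from_float_float8, and the float8_weight_only helper further down calls it with symmetric per-channel settings. Below is a minimal sketch of calling the factory directly, using a stand-in random weight and assuming MappingType is importable from torchao.quantization.quant_primitives as it is in quant_api.py.

import torch
from torchao.dtypes import to_affine_quantized_float8
from torchao.quantization.quant_primitives import MappingType

# Stand-in weight; block_size=(1, in_features) gives one scale per output channel.
weight = torch.randn(256, 128, dtype=torch.bfloat16)

fp8_weight = to_affine_quantized_float8(
    input_float=weight,
    mapping_type=MappingType.SYMMETRIC,
    block_size=(1, weight.shape[1]),
    target_dtype=torch.float8_e4m3fn,
    eps=torch.finfo(torch.float32).eps,
    zero_point_dtype=torch.float32,
)

print(fp8_weight.shape)        # logical shape is preserved: torch.Size([256, 128])
print(fp8_weight.layout_type)  # FpxLayoutType()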
5 changes: 5 additions & 0 deletions torchao/dtypes/utils.py
@@ -88,6 +88,11 @@ def __repr__(self):
def extra_repr(self) -> str:
return ""

@dataclass(frozen=True)
class FpxLayoutType(LayoutType):
pass


"""
Plain LayoutType, the most basic LayoutType, also has no extra metadata, will typically be the default
"""
1 change: 0 additions & 1 deletion torchao/float8/__init__.py
@@ -27,7 +27,6 @@

# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals

add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig])

__all__ = [
19 changes: 19 additions & 0 deletions torchao/quantization/quant_api.py
@@ -72,6 +72,7 @@
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
"float8_weight_only",
]

from .GPTQ import (
@@ -488,6 +489,24 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
"""
return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())

def float8_weight_only():
"""
Applies float8 weight-only symmetric per-channel quantization to linear layers.
"""
def apply_float8wo_quant(weight):
# avoid circular dep
Contributor:
Maybe you want to import to_affine_quantized_floatx here. I'm also refactoring this file to change the import to avoid the circular dep as well.

from torchao.dtypes import to_affine_quantized_float8

mapping_type = MappingType.SYMMETRIC
target_dtype = torch.float8_e4m3fn
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.float32
block_size = (1, weight.shape[1])
return to_affine_quantized_float8(input_float=weight, mapping_type=mapping_type, block_size=block_size, target_dtype=target_dtype,
eps=eps, zero_point_dtype=zero_point_dtype)

return _get_linear_subclass_inserter(apply_float8wo_quant)


def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
"""
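Not part of the diff, for context: the new config is applied the same way the test added in this PR applies it, by calling the inserter returned by float8_weight_only() on a linear module. A minimal sketch follows (bf16 dtype and CUDA device taken from the test; the forward pass is not exercised here).

import torch
from torchao.quantization.quant_api import float8_weight_only

# Same setup as test_affine_quantized_float.py: a bf16 linear layer on CUDA.
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

# float8_weight_only() returns a function that swaps linear.weight for a
# float8 (e4m3) AffineQuantizedTensor: weight-only, symmetric, per-channel.
apply_quant = float8_weight_only()
quantized_linear = apply_quant(linear)

print(quantized_linear.weight.shape)           # unchanged: torch.Size([256, 128])
print(type(quantized_linear.weight).__name__)  # AffineQuantizedTensor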