[Experimental] Float8 support in AQT #671
@@ -0,0 +1,52 @@
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)
from torchao.quantization.quant_api import (
    float8_weight_only
)
import torch
import unittest
import tempfile
from torchao.utils import (
    TORCH_VERSION_AFTER_2_5,
)


class TestAffineQuantizedFloat(TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_tensor_core_layout_transpose(self):
        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
        t = l.weight
        shape = t.shape
        apply_float8_weight_only_quant = float8_weight_only()
        ql = apply_float8_weight_only_quant(l)
        aqt = ql.weight
        aqt_shape = aqt.shape
        self.assertEqual(aqt_shape, shape)

        # transpose shape test
        for _ in range(10):
            t = t.t()
            aqt = aqt.t()
            shape = t.shape
            aqt_shape = aqt.shape
            self.assertEqual(aqt_shape, shape)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_weights_only(self):
        for apply_quant in [float8_weight_only()]:
            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
            ql = apply_quant(l)
            with tempfile.NamedTemporaryFile() as f:
                torch.save(ql.state_dict(), f)
                f.seek(0)
                # `weights_only=True` is enabled for torch 2.5+
                if TORCH_VERSION_AFTER_2_5:
                    _ = torch.load(f, weights_only=True)
                else:
                    _ = torch.load(f, weights_only=False)


if __name__ == "__main__":
    run_tests()
@@ -1,5 +1,5 @@
 import torch
-from typing import Dict, Callable, Any, Tuple, Optional
+from typing import Dict, Callable, Any, Tuple, Optional, Union
 from collections import defaultdict
 import functools
 import math

@@ -25,6 +25,7 @@
     _get_layout_tensor_constructor,
     LayoutType,
     PlainLayoutType,
+    FpxLayoutType,
     is_device,
 )
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@@ -116,8 +117,8 @@ def __new__(
         layout_tensor: AQTLayout,
         block_size: Tuple[int, ...],
         shape: torch.Size,
-        quant_min: Optional[int] = None,
-        quant_max: Optional[int] = None,
+        quant_min: Optional[Union[int, float]] = None,
+        quant_max: Optional[Union[int, float]] = None,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         dtype=None,
         strides=None,

@@ -138,8 +139,8 @@ def __init__(
         layout_tensor: AQTLayout,
         block_size: Tuple[int, ...],
         shape: torch.Size,
-        quant_min: Optional[int] = None,
-        quant_max: Optional[int] = None,
+        quant_min: Optional[Union[int, float]] = None,
+        quant_max: Optional[Union[int, float]] = None,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         dtype=None,
         strides=None,
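For context on why quant_min/quant_max are widened to Union[int, float]: float8 dtypes carry floating-point range bounds rather than integer ones, so the bounds can no longer be assumed to be ints. A quick illustration with plain PyTorch, independent of this PR:

import torch

# float8 ranges are floats, int8 ranges are ints, hence Union[int, float]
print(torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max)  # -448.0 448.0
print(torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max)                    # -128 127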
@@ -269,6 +270,42 @@ def from_float_static(
             dtype=input_float.dtype,
         )
 
+    @classmethod
+    def from_float_float8(
+        cls,
+        input_float: torch.Tensor,
+        mapping_type: MappingType,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        quant_min: Optional[float] = None,
+        quant_max: Optional[float] = None,
+        eps: Optional[float] = None,
+        scale_dtype: Optional[torch.dtype] = None,
+        zero_point_dtype: Optional[torch.dtype] = None,
+        preserve_zero: bool = True,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        layout_type: LayoutType = FpxLayoutType(),
+    ):
+        original_shape = input_float.shape
+        input_float = layout_type.pre_process(input_float)
+
+        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
+        float8_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+
+        float8_data = layout_type.post_process(float8_data)
+
+        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
+        # layout constructor expects (data, scale, zero_point, layout_type)
+        layout_tensor = layout_tensor_ctr(float8_data, scale, None, layout_type)
+        return cls(
+            layout_tensor,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype
+        )
 
     @property
     def layout_type(self) -> LayoutType:
         return self.layout_tensor.layout_type

Review thread on from_float_float8:
- "@jerryzh168 can you help with a design on how to have ..."
- "A combined function is better, will be refactoring it after testing float8"
- "yeah sure, I think we could have the two following final state: ..."
- "makes sense. one thought, the ..."
- "This is also why I like the idea of extending ..." (references line 16 in 5c0e060)
- "@vkuzo yeah makes sense, we can rename ..." (references ao/torchao/dtypes/affine_quantized_tensor.py, lines 990 to 991 in 5c0e060)
- "@jainapurva as discussed from the meeting, let's merge this into ..."
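To make the thread above concrete, here is a minimal, self-contained sketch of what symmetric per-row float8 quantization computes for block_size = (1, in_features). It approximates the choose_qparams_affine/quantize_affine path used by from_float_float8, but it is not the torchao implementation; the helper name and the eps handling are illustrative only.

import torch

def float8_symmetric_per_row(weight: torch.Tensor):
    # one scale per output row, chosen so the row's max magnitude maps to the
    # float8_e4m3fn maximum representable value (448.0)
    f8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = weight.abs().amax(dim=1, keepdim=True)
    scale = (amax / f8_max).clamp(min=torch.finfo(torch.float32).eps)
    float8_data = (weight / scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)
    return float8_data, scale

w = torch.randn(256, 128)
float8_data, scale = float8_symmetric_per_row(w)
dequant = float8_data.to(torch.float32) * scale
print((dequant - w).abs().max())  # small per-element quantization error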
@@ -663,6 +700,99 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     def get_layout_type(self) -> LayoutType:
         return self.layout_type
 
 
+@register_layout_cls(FpxLayoutType)
+class FpxAQTLayout(AQTLayout):
+    """
+    Layout tensor for float8 data: stores the raw float8 tensor (in the
+    `int_data` slot used by the other AQT layouts) together with the scale.
+    """
+
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        layout_type: LayoutType,
+    ):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        kwargs["dtype"] = int_data.dtype
+        kwargs["requires_grad"] = False
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        layout_type: LayoutType,
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self.layout_type = layout_type
+
+    def __tensor_flatten__(self):
+        return ["int_data", "scale"], [self.layout_type]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"]
+        layout_type, = tensor_attributes
+        return cls(int_data, scale, layout_type)
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.int_data, self.scale
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        layout_type: LayoutType,
+    ):
+        assert isinstance(layout_type, FpxLayoutType)
+        return cls(int_data, scale, layout_type)
+
+    def __repr__(self):
+        int_data, scale = self.get_plain()
+        layout_type = self.get_layout_type()
+        return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, layout_type={layout_type})"
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.scale),
+            self.layout_type,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.t.default:
+            tensor = args[0]
+            # shape-only transpose: reinterpret the packed data with the
+            # reversed shape (matches the shape checks in the tests)
+            new = tensor.__class__(
+                tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.layout_type
+            )
+            return return_and_correct_aliasing(func, args, kwargs, new)
+
+        raise NotImplementedError(
+            f"FpxAQTLayout dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
+
 #####################################################
 # torch functional and aten operator implementation #
 #####################################################
@@ -989,6 +1119,7 @@ def _(func, types, args, kwargs):
 
 to_affine_quantized = AffineQuantizedTensor.from_float
 to_affine_quantized_static = AffineQuantizedTensor.from_float_static
+to_affine_quantized_float8 = AffineQuantizedTensor.from_float_float8
 
 if TORCH_VERSION_AT_LEAST_2_5:
     # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
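With the new alias in place, a weight can also be quantized by calling the constructor directly. The sketch below mirrors the argument choices used by float8_weight_only in quant_api.py further down; it assumes MappingType is importable from torchao.quantization.quant_primitives and that the layout_tensor attribute and get_plain() behave as in the diff above.

import torch
from torchao.dtypes import to_affine_quantized_float8
from torchao.quantization.quant_primitives import MappingType

weight = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
f8_weight = to_affine_quantized_float8(
    input_float=weight,
    mapping_type=MappingType.SYMMETRIC,
    block_size=(1, weight.shape[1]),   # per-row scales, as in float8_weight_only
    target_dtype=torch.float8_e4m3fn,
    eps=torch.finfo(torch.float32).eps,
    zero_point_dtype=torch.float32,
)
float8_data, scale = f8_weight.layout_tensor.get_plain()
print(float8_data.dtype, scale.shape, f8_weight.shape)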
@@ -72,6 +72,7 @@
     "int8_dynamic_activation_int8_semi_sparse_weight",
     "int4_weight_only",
     "int8_weight_only",
+    "float8_weight_only",
 ]
 
 from .GPTQ import (
@@ -488,6 +489,24 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     """
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
 
+def float8_weight_only():
+    """
+    Applies float8 weight-only symmetric per-channel quantization to linear layers.
+    """
+    def apply_float8wo_quant(weight):
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized_float8
+
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.float8_e4m3fn
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.float32
+        block_size = (1, weight.shape[1])
+        return to_affine_quantized_float8(input_float=weight, mapping_type=mapping_type, block_size=block_size, target_dtype=target_dtype,
+                                          eps=eps, zero_point_dtype=zero_point_dtype)
+
+    return _get_linear_subclass_inserter(apply_float8wo_quant)
+
+
 def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
     """

Review comment on the local import: "maybe you want to import ..."