[Experimental] Float8 support in AQT #671

Merged
merged 35 commits into main from experimental_float8_aqt on Aug 28, 2024
Changes from all commits
35 commits
12d0ac2
updates for float
jainapurva Aug 13, 2024
9173ee1
Added optional param
jainapurva Aug 14, 2024
32bf45f
updates for compatible type
jainapurva Aug 14, 2024
ddc1bc0
updates to support torch=2.2
jainapurva Aug 14, 2024
b3e4e79
updates to quant api
jainapurva Aug 19, 2024
e778bca
updates
jainapurva Aug 19, 2024
d4b057f
Float8 updates
jainapurva Aug 20, 2024
9d86df3
Merge branch 'main' into experimental_float8_aqt
jainapurva Aug 20, 2024
04c471e
todos
jainapurva Aug 20, 2024
d86d798
version check
jainapurva Aug 21, 2024
1e5dfd9
Updates for removing float8 API
jainapurva Aug 21, 2024
e25a8e9
Add float8wo to hf_eval
jainapurva Aug 22, 2024
29316fa
remove fpx layout
jainapurva Aug 22, 2024
1fb4a2b
revert changes to hf_eval
jainapurva Aug 22, 2024
ceb3275
Revert changes to main
jainapurva Aug 22, 2024
81eb91f
Fp8 upgrades
jainapurva Aug 23, 2024
c017186
Test for float8
jainapurva Aug 23, 2024
ca9fdbf
Updates
jainapurva Aug 23, 2024
fce977f
Merge branch 'main' into experimental_float8_aqt
jainapurva Aug 23, 2024
bfb8b3b
Remove from_float_float8
jainapurva Aug 23, 2024
146c328
Update optional[int]
jainapurva Aug 23, 2024
9b0c2aa
remove from_float8
jainapurva Aug 23, 2024
3c300eb
Review fixes
jainapurva Aug 24, 2024
efd480b
Review fixes
jainapurva Aug 24, 2024
0abdcc1
typos
jainapurva Aug 24, 2024
5fc0dd8
typos
jainapurva Aug 24, 2024
526a282
Added constraints
jainapurva Aug 24, 2024
83b7356
Added constraints
jainapurva Aug 24, 2024
3ac9f72
Remove eps check
jainapurva Aug 27, 2024
8bd3849
Separate floatx from_float
jainapurva Aug 27, 2024
02a89b3
Typing fixes
jainapurva Aug 27, 2024
8e19598
init fixes
jainapurva Aug 27, 2024
56a4eb5
Updates for fixes
jainapurva Aug 27, 2024
33ae19a
Review fixes
jainapurva Aug 27, 2024
482f537
Revert a doc string change
jainapurva Aug 27, 2024
6 changes: 4 additions & 2 deletions test/dtypes/test_affine_quantized.py
@@ -8,6 +8,7 @@
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int8_semi_sparse_weight,
float8_weight_only,
)
from torchao.dtypes import (
to_affine_quantized,
@@ -18,6 +19,7 @@
import unittest
import tempfile


class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tensor_core_layout_transpose(self):
@@ -40,7 +42,8 @@ def test_tensor_core_layout_transpose(self):

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_weights_only(self):
for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight()]:
for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(),
int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight(), float8_weight_only()]:
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
with tempfile.NamedTemporaryFile() as f:
@@ -69,6 +72,5 @@ def test_to_device(self):
ql.cuda()



if __name__ == "__main__":
run_tests()
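
For orientation, a minimal standalone sketch of what the new float8 entry in test_weights_only exercises; the import path, the CUDA requirement, and the need for a recent PyTorch (for float8 dtypes and weights_only-safe loading) are assumptions taken from the surrounding test code rather than part of this diff:

```python
import tempfile

import torch
from torchao.quantization.quant_api import float8_weight_only

# Quantize a bf16 Linear's weight to float8_e4m3fn, then round-trip the
# state_dict through torch.save / torch.load with weights_only=True.
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantized = float8_weight_only()(linear)  # returns the module with its weight replaced

with tempfile.NamedTemporaryFile() as f:
    torch.save(quantized.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f, weights_only=True)
```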
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -5,6 +5,7 @@
AffineQuantizedTensor,
to_affine_quantized,
to_affine_quantized_static,
to_affine_quantized_floatx,
LayoutType,
PlainLayoutType,
SemiSparseLayoutType,
@@ -18,6 +19,7 @@
"AffineQuantizedTensor",
"to_affine_quantized",
"to_affine_quantized_static",
"to_affine_quantized_floatx",
"LayoutType",
"PlainLayoutType",
"SemiSparseLayoutType",
32 changes: 30 additions & 2 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -11,6 +11,7 @@
MappingType,
int_scaled_matmul,
quantize_affine_hqq,
FP8_TYPES,
)
from torchao.quantization.utils import (
pack_tinygemm_scales_and_zeros,
@@ -36,7 +37,6 @@

aten = torch.ops.aten


###############################
# Base Layout Tensor Subclass #
###############################
@@ -91,7 +91,7 @@ class AffineQuantizedTensor(TorchAOBaseTensor):
shape (torch.Size): the shape for the Tensor
quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
@@ -260,6 +260,33 @@ def from_float_static(
dtype=input_float.dtype,
)

@classmethod
def from_float_to_floatx(
cls,
input_float: torch.Tensor,
block_size: Tuple[int, ...],
target_dtype: torch.dtype = torch.float8_e4m3fn,
layout_type: LayoutType = PlainLayoutType(),
):
if target_dtype in FP8_TYPES:
return cls.from_float(
input_float=input_float,
mapping_type=MappingType.SYMMETRIC,
block_size=block_size,
target_dtype=target_dtype,
quant_min=math.ceil(torch.finfo(target_dtype).min),
quant_max=math.ceil(torch.finfo(target_dtype).max),
eps=torch.finfo(torch.float32).eps,
scale_dtype=None,
zero_point_dtype=None,
preserve_zero=True,
zero_point_domain=ZeroPointDomain.INT,
layout_type=layout_type,
use_hqq=False,
)
else:
raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_float_to_floatx")

@property
def layout_type(self) -> LayoutType:
return self.layout_tensor.layout_type
@@ -974,6 +1001,7 @@ def _(func, types, args, kwargs):

to_affine_quantized = AffineQuantizedTensor.from_float
to_affine_quantized_static = AffineQuantizedTensor.from_float_static
to_affine_quantized_floatx = AffineQuantizedTensor.from_float_to_floatx

if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
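
As a quick illustration, a hedged sketch of calling the new to_affine_quantized_floatx alias directly; the argument values mirror the per-channel setup used by float8_weight_only further down, and running this on a plain CPU bf16 tensor is an assumption, not something this diff tests:

```python
import torch
from torchao.dtypes import to_affine_quantized_floatx

weight = torch.randn(256, 128, dtype=torch.bfloat16)

# One quantization block per output channel: each block spans the whole
# input dimension, so every row gets its own float8 scale.
fp8_weight = to_affine_quantized_floatx(
    input_float=weight,
    block_size=(1, weight.shape[1]),
    target_dtype=torch.float8_e4m3fn,
)

# The subclass keeps the original "external" dtype; the float8 data lives in
# the plain layout tensor underneath.
print(type(fp8_weight).__name__, fp8_weight.dtype, fp8_weight.shape)
```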
17 changes: 15 additions & 2 deletions torchao/quantization/quant_api.py
@@ -27,7 +27,8 @@
TensorCoreTiledLayoutType,
PlainLayoutType,
AffineQuantizedTensor,
SemiSparseLayoutType
SemiSparseLayoutType,
to_affine_quantized_floatx
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
@@ -57,7 +58,6 @@
import logging
from .autoquant import autoquant, AutoQuantizableLinearWeight


__all__ = [
"swap_conv2d_1x1_to_linear",
"Quantizer",
@@ -72,6 +72,7 @@
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
"float8_weight_only",
]

from .GPTQ import (
@@ -488,6 +489,18 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
"""
return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())

def float8_weight_only(target_dtype: torch.dtype = torch.float8_e4m3fn):
"""
Applies float8 weight-only symmetric per-channel quantization to linear layers.
"""
from torchao.dtypes import to_affine_quantized_floatx
def apply_float8wo_quant(weight):
# avoid circular dep
Review comment (Contributor): maybe you want to import to_affine_quantized_floatx here; I'm also refactoring this file to change the imports to avoid the circular dep as well.

block_size = (1, weight.shape[1])
return to_affine_quantized_floatx(input_float=weight, block_size=block_size, target_dtype=target_dtype)

return _get_linear_subclass_inserter(apply_float8wo_quant)


def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
"""
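
An end-to-end sketch of the new config in use; quantize_ itself is not touched by this PR and is assumed to already be exported from quant_api, and the bf16/CUDA setup is illustrative only:

```python
import torch
from torchao.quantization.quant_api import quantize_, float8_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Swap each nn.Linear weight for a float8_e4m3fn AffineQuantizedTensor;
# activations stay in bfloat16 (weight-only quantization).
quantize_(model, float8_weight_only())

x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
y = model(x)
```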
15 changes: 12 additions & 3 deletions torchao/quantization/quant_primitives.py
@@ -8,7 +8,6 @@
from typing import List, Optional, Tuple, Dict, Callable, Union
import torch, math


from torchao.kernel.intmm import int_scaled_matmul
from torchao.kernel.intmm import safe_int_mm
from torchao.utils import (
@@ -58,6 +57,13 @@ class ZeroPointDomain(Enum):
if TORCH_VERSION_AT_LEAST_2_5:
torch.serialization.add_safe_globals([MappingType, ZeroPointDomain])

FP8_TYPES = {
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
}

"""
Map from dtype to the bound value of integers
TODO: maybe can replace this with call to torch.iinfo
@@ -95,9 +101,12 @@ def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
verify that they are within the range of possible quant_min/quant_max
for dtype
"""
if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
if dtype in FP8_TYPES:
quant_min_lower_bound, quant_max_upper_bound = torch.finfo(dtype).min, torch.finfo(dtype).max
elif dtype not in _DTYPE_TO_QVALUE_BOUNDS:
raise ValueError(f"Unsupported dtype: {dtype}")
quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
else:
quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
if quant_min is None:
quant_min = quant_min_lower_bound
if quant_max is None:
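
To make the new bounds handling concrete, a small sketch of the ranges _get_and_check_qmin_qmax now derives for float8 dtypes; the printed values come from torch.finfo, and math.ceil mirrors what from_float_to_floatx does when it fills in quant_min/quant_max:

```python
import math

import torch

# Representable ranges for the FP8_TYPES handled by the new branch.
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    lo, hi = torch.finfo(dtype).min, torch.finfo(dtype).max
    print(dtype, math.ceil(lo), math.ceil(hi))
# torch.float8_e4m3fn -> -448, 448
# torch.float8_e5m2   -> -57344, 57344
```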