From bf64e23c6062d3b2c16f0842c5c8ee0d42a35f32 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 8 Jul 2024 02:58:04 +0800 Subject: [PATCH] Add FP8 Adam (#482) * update benchmark * add rank1 option to lpmm * add comma * update readme * remove unwanted file * update * add Adam fp8 * add FP8 AdamW and test * update readme * change reason to xfail, since 2.2 also have float8 * at guard for FP8 * update readme * fix guard --- benchmarks/benchmark_low_bit_adam.py | 7 +- test/prototype/test_low_bit_optim.py | 16 +++ torchao/prototype/low_bit_optim/README.md | 8 +- torchao/prototype/low_bit_optim/__init__.py | 4 +- torchao/prototype/low_bit_optim/adam.py | 20 ++++ torchao/prototype/low_bit_optim/adamw.py | 20 ++++ .../prototype/low_bit_optim/subclass_fp8.py | 106 ++++++++++++++++++ 7 files changed, 173 insertions(+), 8 deletions(-) create mode 100644 torchao/prototype/low_bit_optim/subclass_fp8.py diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index 716541ee9c..6517eac8fb 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -28,15 +28,16 @@ from torchvision.transforms import v2 from tqdm import tqdm -from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit +from torchao.prototype import low_bit_optim # lpmm doesn't have Adam, only AdamW OPTIM_MAP = dict( Adam=torch.optim.Adam, Adam8bitBnb=bnb.optim.Adam8bit, - Adam8bitAo=Adam8bit, + Adam8bitAo=low_bit_optim.Adam8bit, + AdamFp8Ao=low_bit_optim.AdamFp8, Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True), - Adam4bitAo=Adam4bit, + Adam4bitAo=low_bit_optim.Adam4bit, Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")), ) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index c995d83c85..e037ef8b11 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -139,6 +139,22 @@ def test_optim_4bit_correctness(self, optim_name): for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) + @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") + @parametrize("optim_name", ["AdamFp8", "AdamWFp8"]) + @parametrize("device", _DEVICES) + def test_optim_fp8_smoke(self, optim_name, device): + if device == "cuda" and torch.cuda.get_device_capability() < (8, 9): + pytest.skip("FP8 requires compute capability >= 8.9") + + model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + optim = getattr(low_bit_optim, optim_name)(model.parameters()) + + x = torch.randn(4, 32, device=device) + loss = model(x).sum() + loss.backward() + optim.step() + optim.zero_grad() + instantiate_parametrized_tests(TestQuantize) instantiate_parametrized_tests(TestOptim) diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index 5c1d631d2c..db8575af7b 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -4,6 +4,7 @@ This folder implements: - 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861 - 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507 +- FP8 optimizers using the native `torch.float8_e4m3fn` dtype (experimental) The implementation is fully done in Python (with tensor subclass) and relies on `torch.compile()` to generate efficient fused kernel. @@ -18,12 +19,12 @@ model = ... 
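 # The FP8 variants added in this PR drop in the same way, e.g.
 # `optim = AdamFp8(model.parameters())` or `AdamWFp8(...)`; as noted below,
 # they need PyTorch >= 2.3 and a GPU with compute capability >= 8.9.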
optim = Adam8bit(model.parameters()) ``` -To use 4-bit Adam, replace the above with `Adam4bit`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit optimizers, and 128 for 4-bit optimizers. +To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers. -**Other optimizers**: AdamW is also available as `AdamW8bit` and `AdamW4bit`. Other optimizers can be added based on demand. +**Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand. NOTE: -- The low-bit optimizers require PyTorch >= 2.3 +- The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9. - For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper. - **Known issue**: When learning rate is updated every step (e.g. using cosine learning rate scheduler), training speed is slower. This is because we have to convert learning rate to a CUDA tensor (which incurs expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger recompile whenever the value is changed. @@ -38,6 +39,7 @@ Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy PyTorch | 12.94 | 8m 18s | 91.14 bnb 8-bit | 8.31 | 6m 50s | 90.67 ao 8-bit | 8.32 | 9m 04s | 90.71 +ao FP8 E4M3 | 8.32 | 6m 38s | 91.08 lpmm 4-bit | 7.72 | 5m 59s | 89.97 ao 4-bit | 7.72 | 7m 00s | 89.94 lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71 diff --git a/torchao/prototype/low_bit_optim/__init__.py b/torchao/prototype/low_bit_optim/__init__.py index ab7d8fd99b..962726b967 100644 --- a/torchao/prototype/low_bit_optim/__init__.py +++ b/torchao/prototype/low_bit_optim/__init__.py @@ -1,2 +1,2 @@ -from .adam import Adam8bit, Adam4bit -from .adamw import AdamW8bit, AdamW4bit +from .adam import Adam8bit, Adam4bit, AdamFp8 +from .adamw import AdamW8bit, AdamW4bit, AdamWFp8 diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 49223b48e9..6595711138 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -6,6 +6,7 @@ from .subclass_8bit import maybe_new_8bit_zero_buffer from .subclass_4bit import maybe_new_4bit_zero_buffer +from .subclass_fp8 import maybe_new_fp8_zero_buffer class _Adam(Optimizer): @@ -155,3 +156,22 @@ def __init__( super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) _new_buffer = staticmethod(maybe_new_4bit_zero_buffer) + + +class AdamFp8(_Adam): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + *, + block_size=2048 + ) -> None: + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) + + @staticmethod + def _new_buffer(p: Tensor, signed: bool, block_size: int): + return maybe_new_fp8_zero_buffer(p, block_size) diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index 440f75620b..9397f04c3c 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -6,6 +6,7 @@ from .subclass_8bit import maybe_new_8bit_zero_buffer from .subclass_4bit import maybe_new_4bit_zero_buffer +from 
.subclass_fp8 import maybe_new_fp8_zero_buffer class _AdamW(Optimizer): @@ -154,3 +155,22 @@ def __init__( super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) _new_buffer = staticmethod(maybe_new_4bit_zero_buffer) + + +class AdamWFp8(_AdamW): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + *, + block_size=2048 + ) -> None: + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) + + @staticmethod + def _new_buffer(p: Tensor, signed: bool, block_size: int): + return maybe_new_fp8_zero_buffer(p, block_size) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py new file mode 100644 index 0000000000..43d45b0d5b --- /dev/null +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -0,0 +1,106 @@ +import torch +from torch import Tensor +from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE + + +aten = torch.ops.aten +DTYPE = torch.float8_e4m3fn + + +def quantize_fp8(input: Tensor, block_size: int): + input = input.view(-1, block_size) + scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max + input = input / scale.view(-1, 1) + codes = input.to(DTYPE).view(-1) + return codes, scale + + +class OptimStateFp8(Tensor): + implements = classmethod(_implements) + tensor_attrs = ["codes", "scale"] + + @staticmethod + def __new__(cls, codes: Tensor, scale: Tensor): + return Tensor._make_wrapper_subclass( + cls, + codes.shape, + device=codes.device, + requires_grad=False, + ) + + def __init__(self, codes: Tensor, scale: Tensor): + assert codes.dtype is DTYPE + self.codes = codes + self.scale = scale + + @property + def block_size(self): + return self.codes.numel() // self.scale.numel() + + def __tensor_flatten__(self): + return self.tensor_attrs, [] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes) + + def dequantize(self, output_dtype=None): + float_data = self.codes.float() + float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1) + + dtype = output_dtype or torch.get_default_dtype() + return float_data.view(self.codes.shape).to(dtype) + + @classmethod + def zeros(cls, shape, block_size: int = 2048, device=None): + codes = torch.zeros(shape, dtype=DTYPE, device=device) + scale = torch.zeros(codes.numel() // block_size, device=device) + return cls(codes, scale) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(block_size={self.block_size}, " + f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})" + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: + return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs) + + raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported") + + +@OptimStateFp8.implements(aten.copy_.default) +def _(func, *args, **kwargs): + dst = args[0] + src = args[1] + + if isinstance(dst, OptimStateFp8) and isinstance(src, OptimStateFp8): + assert dst.block_size == src.block_size + dst.codes.copy_(src.codes) + dst.scale.copy_(src.scale) + + elif isinstance(dst, OptimStateFp8): + codes, scale = quantize_fp8(src, dst.block_size) + dst.codes.copy_(codes) + dst.scale.copy_(scale) + + else: + dst.copy_(src.dequantize()) + + 
return dst + + +@OptimStateFp8.implements(aten.lerp.Scalar) +def _(func, *args, **kwargs): + args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args] + return func(*args, **kwargs) + + +def maybe_new_fp8_zero_buffer(p: Tensor, block_size: int = 2048): + if p.numel() >= 4096 and p.numel() % block_size == 0: + out = OptimStateFp8.zeros(p.shape, block_size, device=p.device) + else: + out = torch.zeros_like(p) + return out
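
For quick reference, below is a minimal usage sketch of the new FP8 optimizers, along the lines of the `test_optim_fp8_smoke` test added above. This is a hedged example rather than part of the patch: it assumes this branch of torchao is installed, PyTorch >= 2.3, and a CUDA device with compute capability >= 8.9; the model shapes are illustrative only.

```python
import torch
from torch import nn

from torchao.prototype.low_bit_optim import AdamWFp8

# Parameters with >= 4096 elements (and numel divisible by the block size)
# get FP8 optimizer state; smaller ones (e.g. biases) fall back to fp32 buffers.
device = "cuda"  # the FP8 path needs compute capability >= 8.9
model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
optim = AdamWFp8(model.parameters(), lr=1e-3)

x = torch.randn(4, 32, device=device)
model(x).sum().backward()
optim.step()
optim.zero_grad()
```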
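A second sketch exercises the `OptimStateFp8` subclass in isolation, i.e. the block-wise quantize/dequantize round trip implemented in `subclass_fp8.py`. Same caveats as above; the tensor shape and block size here are illustrative.

```python
import torch

from torchao.prototype.low_bit_optim.subclass_fp8 import (
    OptimStateFp8,
    maybe_new_fp8_zero_buffer,
    quantize_fp8,
)

# A tensor with >= 4096 elements whose numel is divisible by the block size,
# so maybe_new_fp8_zero_buffer returns an FP8 state buffer instead of fp32 zeros.
p = torch.randn(16, 1024)
buf = maybe_new_fp8_zero_buffer(p, block_size=2048)
print(type(buf).__name__, buf.block_size)  # OptimStateFp8 2048

# quantize_fp8 returns flat FP8 codes plus one fp32 scale per 2048-element
# block; OptimStateFp8 wraps them and dequantize() reconstructs the values.
codes, scale = quantize_fp8(p, block_size=2048)
state = OptimStateFp8(codes, scale)
err = (state.dequantize().view_as(p) - p).abs().max()
print(f"max abs round-trip error: {err.item():.4f}")
```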