Add FP8 Adam (pytorch#482)
* update benchmark

* add rank1 option to lpmm

* add comma

* update readme

* remove unwanted file

* update

* add Adam fp8

* add FP8 AdamW and test

* update readme

* change reason to xfail, since 2.2 also has float8

* add guard for FP8

* update readme

* fix guard
gau-nernst authored Jul 7, 2024
1 parent cc77513 commit bf64e23
Showing 7 changed files with 173 additions and 8 deletions.
7 changes: 4 additions & 3 deletions benchmarks/benchmark_low_bit_adam.py
@@ -28,15 +28,16 @@
from torchvision.transforms import v2
from tqdm import tqdm

from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
from torchao.prototype import low_bit_optim

# lpmm doesn't have Adam, only AdamW
OPTIM_MAP = dict(
Adam=torch.optim.Adam,
Adam8bitBnb=bnb.optim.Adam8bit,
Adam8bitAo=Adam8bit,
Adam8bitAo=low_bit_optim.Adam8bit,
AdamFp8Ao=low_bit_optim.AdamFp8,
Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True),
Adam4bitAo=Adam4bit,
Adam4bitAo=low_bit_optim.Adam4bit,
Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")),
)
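
For context, here is a minimal sketch of how `OPTIM_MAP` is meant to be consumed, trimmed down to the torchao entries so it runs without `bnb` or `lpmm` installed; the model and the selected key are placeholders.

```python
import torch
from torchao.prototype import low_bit_optim

# Trimmed copy of the benchmark's OPTIM_MAP, keeping only the torchao entries.
OPTIM_MAP = dict(
    Adam=torch.optim.Adam,
    Adam8bitAo=low_bit_optim.Adam8bit,
    AdamFp8Ao=low_bit_optim.AdamFp8,
    Adam4bitAo=low_bit_optim.Adam4bit,
)

model = torch.nn.Linear(4096, 4096)  # placeholder model
optimizer = OPTIM_MAP["AdamFp8Ao"](model.parameters(), lr=1e-3)
```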

16 changes: 16 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -139,6 +139,22 @@ def test_optim_4bit_correctness(self, optim_name):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
@parametrize("optim_name", ["AdamFp8", "AdamWFp8"])
@parametrize("device", _DEVICES)
def test_optim_fp8_smoke(self, optim_name, device):
if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
pytest.skip("FP8 requires compute capability >= 8.9")

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
optim = getattr(low_bit_optim, optim_name)(model.parameters())

x = torch.randn(4, 32, device=device)
loss = model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()


instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)
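
The new smoke test skips CUDA devices older than compute capability 8.9, where FP8 E4M3 support is limited. A standalone check of the same guard, outside pytest, might look like this:

```python
import torch

# Mirrors the skip condition used in test_optim_fp8_smoke.
if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()
    print("FP8 optimizer states supported on this GPU:", capability >= (8, 9))
else:
    print("No CUDA device found")
```
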
8 changes: 5 additions & 3 deletions torchao/prototype/low_bit_optim/README.md
@@ -4,6 +4,7 @@ This folder implements:

- 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861
- 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507
- FP8 optimizers using the native `torch.float8_e4m3fn` dtype (experimental)

The implementation is fully done in Python (with tensor subclasses) and relies on `torch.compile()` to generate efficient fused kernels.

@@ -18,12 +19,12 @@ model = ...
optim = Adam8bit(model.parameters())
```

To use 4-bit Adam, replace the above with `Adam4bit`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit optimizers, and 128 for 4-bit optimizers.
To use 4-bit Adam, replace the above with `Adam4bit`; similarly, use `AdamFp8` for the FP8 variant. You can also change the quantization block size by passing `block_size=value` to the optimizer. By default, the block size is 2048 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.

**Other optimizers**: AdamW is also available as `AdamW8bit` and `AdamW4bit`. Other optimizers can be added based on demand.
**Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.

NOTE:
- The low-bit optimizers require PyTorch >= 2.3
- The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9.
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
- **Known issue**: When the learning rate is updated every step (e.g. with a cosine learning rate scheduler), training is slower. This is because we have to convert the learning rate to a CUDA tensor (which incurs an expensive memory transfer), since `torch.compile()` treats a Python float as a constant and triggers a recompilation whenever the value changes.

@@ -38,6 +39,7 @@ Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy
PyTorch | 12.94 | 8m 18s | 91.14
bnb 8-bit | 8.31 | 6m 50s | 90.67
ao 8-bit | 8.32 | 9m 04s | 90.71
ao FP8 E4M3 | 8.32 | 6m 38s | 91.08
lpmm 4-bit | 7.72 | 5m 59s | 89.97
ao 4-bit | 7.72 | 7m 00s | 89.94
lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71
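
Putting the README changes together, a minimal FP8 training-step sketch follows; the layer sizes are placeholders, and per the notes above it needs PyTorch >= 2.3 and, on GPU, compute capability >= 8.9.

```python
import torch
from torchao.prototype.low_bit_optim import AdamFp8

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4096, 4096).to(device)
optim = AdamFp8(model.parameters())  # defaults: lr=1e-3, block_size=2048

x = torch.randn(8, 4096, device=device)
loss = model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()
```
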
4 changes: 2 additions & 2 deletions torchao/prototype/low_bit_optim/__init__.py
@@ -1,2 +1,2 @@
from .adam import Adam8bit, Adam4bit
from .adamw import AdamW8bit, AdamW4bit
from .adam import Adam8bit, Adam4bit, AdamFp8
from .adamw import AdamW8bit, AdamW4bit, AdamWFp8
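
With this change, the FP8 variants are importable from the package root alongside the existing 8-bit and 4-bit classes:

```python
# All six low-bit optimizers are re-exported at the package level.
from torchao.prototype.low_bit_optim import (
    Adam4bit, Adam8bit, AdamFp8,
    AdamW4bit, AdamW8bit, AdamWFp8,
)
```
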
20 changes: 20 additions & 0 deletions torchao/prototype/low_bit_optim/adam.py
@@ -6,6 +6,7 @@

from .subclass_8bit import maybe_new_8bit_zero_buffer
from .subclass_4bit import maybe_new_4bit_zero_buffer
from .subclass_fp8 import maybe_new_fp8_zero_buffer


class _Adam(Optimizer):
@@ -155,3 +156,22 @@ def __init__(
super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)

_new_buffer = staticmethod(maybe_new_4bit_zero_buffer)


class AdamFp8(_Adam):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
*,
block_size=2048
) -> None:
super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)

@staticmethod
def _new_buffer(p: Tensor, signed: bool, block_size: int):
return maybe_new_fp8_zero_buffer(p, block_size)
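
`AdamFp8` only overrides state-buffer allocation; the `signed` flag used by the 8-bit/4-bit buffers is ignored, since E4M3 is already a signed format. A small sketch of what `_new_buffer` hands back, assuming the module layout in this commit; small or oddly sized parameters fall back to a plain dense buffer.

```python
import torch
from torchao.prototype.low_bit_optim import AdamFp8
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8

large = torch.zeros(4096)  # numel >= 4096 and divisible by block_size -> FP8 state
small = torch.zeros(100)   # too small -> plain fp32 buffer

print(isinstance(AdamFp8._new_buffer(large, True, 2048), OptimStateFp8))  # True
print(isinstance(AdamFp8._new_buffer(small, True, 2048), OptimStateFp8))  # False
```
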
20 changes: 20 additions & 0 deletions torchao/prototype/low_bit_optim/adamw.py
@@ -6,6 +6,7 @@

from .subclass_8bit import maybe_new_8bit_zero_buffer
from .subclass_4bit import maybe_new_4bit_zero_buffer
from .subclass_fp8 import maybe_new_fp8_zero_buffer


class _AdamW(Optimizer):
@@ -154,3 +155,22 @@ def __init__(
super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)

_new_buffer = staticmethod(maybe_new_4bit_zero_buffer)


class AdamWFp8(_AdamW):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
*,
block_size=2048
) -> None:
super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)

@staticmethod
def _new_buffer(p: Tensor, signed: bool, block_size: int):
return maybe_new_fp8_zero_buffer(p, block_size)
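
`AdamWFp8` mirrors `AdamFp8`; the only default that differs is `weight_decay=1e-2` (decoupled weight decay, as in AdamW). A brief construction sketch with placeholder hyperparameters:

```python
import torch.nn as nn
from torchao.prototype.low_bit_optim import AdamFp8, AdamWFp8

model = nn.Linear(4096, 4096)
adam = AdamFp8(model.parameters())  # weight_decay defaults to 0
adamw = AdamWFp8(model.parameters(), weight_decay=0.05, block_size=256)  # placeholder values
```
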
106 changes: 106 additions & 0 deletions torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -0,0 +1,106 @@
import torch
from torch import Tensor
from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE


aten = torch.ops.aten
DTYPE = torch.float8_e4m3fn


def quantize_fp8(input: Tensor, block_size: int):
input = input.view(-1, block_size)
scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max
input = input / scale.view(-1, 1)
codes = input.to(DTYPE).view(-1)
return codes, scale


class OptimStateFp8(Tensor):
implements = classmethod(_implements)
tensor_attrs = ["codes", "scale"]

@staticmethod
def __new__(cls, codes: Tensor, scale: Tensor):
return Tensor._make_wrapper_subclass(
cls,
codes.shape,
device=codes.device,
requires_grad=False,
)

def __init__(self, codes: Tensor, scale: Tensor):
assert codes.dtype is DTYPE
self.codes = codes
self.scale = scale

@property
def block_size(self):
return self.codes.numel() // self.scale.numel()

def __tensor_flatten__(self):
return self.tensor_attrs, []

@classmethod
def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)

def dequantize(self, output_dtype=None):
float_data = self.codes.float()
float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1)

dtype = output_dtype or torch.get_default_dtype()
return float_data.view(self.codes.shape).to(dtype)

@classmethod
def zeros(cls, shape, block_size: int = 2048, device=None):
codes = torch.zeros(shape, dtype=DTYPE, device=device)
scale = torch.zeros(codes.numel() // block_size, device=device)
return cls(codes, scale)

def __repr__(self):
return (
f"{self.__class__.__name__}(block_size={self.block_size}, "
f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)

raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")


@OptimStateFp8.implements(aten.copy_.default)
def _(func, *args, **kwargs):
dst = args[0]
src = args[1]

if isinstance(dst, OptimStateFp8) and isinstance(src, OptimStateFp8):
assert dst.block_size == src.block_size
dst.codes.copy_(src.codes)
dst.scale.copy_(src.scale)

elif isinstance(dst, OptimStateFp8):
codes, scale = quantize_fp8(src, dst.block_size)
dst.codes.copy_(codes)
dst.scale.copy_(scale)

else:
dst.copy_(src.dequantize())

return dst


@OptimStateFp8.implements(aten.lerp.Scalar)
def _(func, *args, **kwargs):
args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
return func(*args, **kwargs)


def maybe_new_fp8_zero_buffer(p: Tensor, block_size: int = 2048):
if p.numel() >= 4096 and p.numel() % block_size == 0:
out = OptimStateFp8.zeros(p.shape, block_size, device=p.device)
else:
out = torch.zeros_like(p)
return out
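
To make the new subclass concrete, a small round-trip sketch: quantize a tensor block-wise to `float8_e4m3fn`, wrap it in `OptimStateFp8`, and dequantize. The tensor size is a placeholder (it must be divisible by the block size), and this assumes a PyTorch build with `torch.float8_e4m3fn` support.

```python
import torch
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8, quantize_fp8

x = torch.randn(8192)  # placeholder size, divisible by block_size
codes, scale = quantize_fp8(x, block_size=2048)

state = OptimStateFp8(codes, scale)  # 4 scales, one per 2048-element block
print(state.block_size)              # 2048
print((state.dequantize() - x).abs().max())  # small block-wise quantization error

# Copying a plain tensor into the subclass re-quantizes it in place.
state.copy_(torch.randn(8192))
```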
