From ade8febfdd5f96868b9ad9debcabf68119cf9e76 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 16 Jul 2024 22:09:00 +0800 Subject: [PATCH] Add FSDP2 support for low-bit optimizers (#484) --- test/prototype/test_low_bit_optim.py | 157 +++++++++++++----- torchao/prototype/low_bit_optim/README.md | 6 +- torchao/prototype/low_bit_optim/adam.py | 141 ++++++++++++---- torchao/prototype/low_bit_optim/adamw.py | 141 ++++++++++++---- .../prototype/low_bit_optim/quant_utils.py | 112 +++++++++++++ .../prototype/low_bit_optim/subclass_4bit.py | 69 ++------ .../prototype/low_bit_optim/subclass_8bit.py | 115 +------------ .../prototype/low_bit_optim/subclass_fp8.py | 13 +- 8 files changed, 467 insertions(+), 287 deletions(-) create mode 100644 torchao/prototype/low_bit_optim/quant_utils.py diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index e037ef8b11..94cfe34096 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -1,5 +1,4 @@ import copy -from functools import partial import pytest import torch @@ -10,9 +9,11 @@ parametrize, run_tests, ) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest from torchao.prototype import low_bit_optim -from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit -from torchao.utils import TORCH_VERSION_AFTER_2_3 +from torchao.prototype.low_bit_optim.quant_utils import quantize_8bit_with_qmap, quantize_4bit_with_qmap +from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 try: import bitsandbytes as bnb @@ -31,52 +32,69 @@ class TestQuantize(TestCase): @parametrize("device", _DEVICES) def test_quantize_8bit_with_qmap_correctness(self, device): - x = torch.randn(32, 1024, device=device) - qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device) + x = torch.rand(32, 1024, device=device) + qmap = torch.rand(256, device=device).sort().values - actual_codes, actual_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=1) - expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=0) + actual = (x.unsqueeze(-1) - qmap).abs().argmin(-1).to(torch.uint8) + expected = quantize_8bit_with_qmap(x, qmap) - torch.testing.assert_close(actual_codes, expected_codes) - torch.testing.assert_close(actual_scale, expected_scale) + torch.testing.assert_close(actual, expected) @parametrize("device", _DEVICES) def test_quantize_8bit_with_qmap_compile(self, device): - x = torch.randn(32, 1024, device=device) - qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device) + x = torch.rand(32, 1024, device=device) + qmap = torch.rand(256, device=device).sort().values - compiled_f = torch.compile(subclass_8bit.quantize_8bit_with_qmap, fullgraph=True) - actual_codes, actual_scale = compiled_f(x, qmap, 256) - expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256) + compiled_f = torch.compile(quantize_8bit_with_qmap, fullgraph=True) + actual = compiled_f(x, qmap) + expected = quantize_8bit_with_qmap(x, qmap) - torch.testing.assert_close(actual_codes, expected_codes) - torch.testing.assert_close(actual_scale, expected_scale) + torch.testing.assert_close(actual, expected) @parametrize("device", _DEVICES) def test_quantize_4bit_with_qmap_correctness(self, device): - x = torch.randn(32, 1024, device=device) - qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device) + x = torch.rand(32, 1024, device=device) + 
qmap = torch.rand(16, device=device).sort().values - actual_codes, actual_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=1) - expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=0) + actual = (x.unsqueeze(-1) - qmap).abs().argmin(-1).to(torch.uint8) + expected = quantize_4bit_with_qmap(x, qmap) - torch.testing.assert_close(actual_codes, expected_codes) - torch.testing.assert_close(actual_scale, expected_scale) + torch.testing.assert_close(actual, expected) @parametrize("device", _DEVICES) def test_quantize_4bit_with_qmap_compile(self, device): - x = torch.randn(32, 1024, device=device) - qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device) + x = torch.rand(32, 1024, device=device) + qmap = torch.rand(16, device=device).sort().values - compiled_f = torch.compile(subclass_4bit.quantize_4bit_with_qmap, fullgraph=True) - actual_codes, actual_scale = compiled_f(x, qmap, 256) - expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256) + compiled_f = torch.compile(quantize_4bit_with_qmap, fullgraph=True) + actual = compiled_f(x, qmap) + expected = quantize_4bit_with_qmap(x, qmap) - torch.testing.assert_close(actual_codes, expected_codes) - torch.testing.assert_close(actual_scale, expected_scale) + torch.testing.assert_close(actual, expected) class TestOptim(TestCase): + @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") + @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"]) + @parametrize("dtype", [torch.float32, torch.bfloat16]) + @parametrize("device", _DEVICES) + def test_optim_smoke(self, optim_name, dtype, device): + if optim_name.endswith("Fp8") and device == "cuda" and torch.cuda.get_device_capability() < (8, 9): + pytest.skip("FP8 requires compute capability >= 8.9") + + # reset cache to avoid hitting cache_size_limit, since the function will re-compile for each test + torch._dynamo.reset_code_caches() + + model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32)) + model.to(device=device, dtype=dtype) + optim = getattr(low_bit_optim, optim_name)(model.parameters()) + + x = torch.randn(4, 32, device=device, dtype=dtype) + loss = model(x).sum() + loss.backward() + optim.step() + optim.zero_grad() + @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle") @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA") @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") @@ -139,21 +157,74 @@ def test_optim_4bit_correctness(self, optim_name): for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) - @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") - @parametrize("optim_name", ["AdamFp8", "AdamWFp8"]) - @parametrize("device", _DEVICES) - def test_optim_fp8_smoke(self, optim_name, device): - if device == "cuda" and torch.cuda.get_device_capability() < (8, 9): - pytest.skip("FP8 requires compute capability >= 8.9") - - model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) - optim = getattr(low_bit_optim, optim_name)(model.parameters()) - x = torch.randn(4, 32, device=device) - loss = model(x).sum() - loss.backward() - optim.step() - optim.zero_grad() +class TestFSDP2(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + 
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="torch >= 2.4 required") + @skip_if_lt_x_gpu(2) + def test_fsdp2(self): + optim_classes = [low_bit_optim.Adam8bit, low_bit_optim.Adam4bit] + if torch.cuda.get_device_capability() >= (8, 9): + optim_classes.append(low_bit_optim.AdamFp8) + + self.run_subtests( + {"optim_cls": optim_classes}, + self._test_fsdp2, + ) + + def _test_fsdp2(self, optim_cls): + from torch.distributed._composable.fsdp import fully_shard + from torch.testing._internal.distributed._tensor.common_dtensor import ( + ModelArgs, + Transformer, + TransformerBlock, + ) + + # seems like cache_size_limit is shared between FSDP processes? + torch._dynamo.config.cache_size_limit = 8 * self.world_size + + batch_size = 3 + vocab_size = 1024 + seq_len = 64 + model_args = ModelArgs( + n_layers=3, + n_heads=4, + dim=1024, + vocab_size=vocab_size, + max_seq_len=seq_len, + dropout_p=0, + ) + torch.manual_seed(42) + with torch.device("cuda"): + base_model = Transformer(model_args) + base_optim = optim_cls(base_model.parameters(), lr=1e-2) + + fsdp_model = copy.deepcopy(base_model) + for m in fsdp_model.modules(): + if isinstance(m, TransformerBlock): + fully_shard(m) + fully_shard(fsdp_model) + fsdp_optim = optim_cls(fsdp_model.parameters(), lr=1e-2) + + torch.manual_seed(42 + self.rank + 1) + for iter_idx in range(5): + inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + fsdp_loss = fsdp_model(inp).mean() + fsdp_loss.backward() + fsdp_optim.step() + + base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + base_loss = base_model(inp).mean() + base_loss.backward() + for param in base_model.parameters(): + if param.grad is not None: + torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG) + base_optim.step() + self.assertEqual(fsdp_loss, base_loss) instantiate_parametrized_tests(TestQuantize) diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index db8575af7b..6412074140 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -38,10 +38,10 @@ Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy ---------------|-----------------|--------------------------|---------- PyTorch | 12.94 | 8m 18s | 91.14 bnb 8-bit | 8.31 | 6m 50s | 90.67 -ao 8-bit | 8.32 | 9m 04s | 90.71 -ao FP8 E4M3 | 8.32 | 6m 38s | 91.08 +ao 8-bit | 8.31 | 6m 44s | 90.63 +ao FP8 E4M3 | 8.32 | 6m 35s | 90.98 lpmm 4-bit | 7.72 | 5m 59s | 89.97 -ao 4-bit | 7.72 | 7m 00s | 89.94 +ao 4-bit | 7.72 | 7m 13s | 90.05 lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71 (*) means rank-1 normalization is used for 2nd optimizer state. Refer to [paper](https://arxiv.org/abs/2309.01507) for more details. 
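
The FSDP2 test added above drives these optimizers through `fully_shard`, so each sharded parameter and its quantized optimizer state live inside a DTensor. Outside the test harness, the same pattern looks roughly like the sketch below. This is a minimal illustration, not part of the patch: it assumes PyTorch >= 2.4, a torchao build that includes this change, two CUDA GPUs, and a `torchrun --nproc_per_node=2` launch; the model and hyperparameters are placeholders.

```python
# Sketch only (not part of this patch): low-bit optimizer under FSDP2.
# Launch with: torchrun --nproc_per_node=2 example.py
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # requires torch >= 2.4
from torchao.prototype import low_bit_optim

torch.distributed.init_process_group("nccl")
torch.cuda.set_device(torch.distributed.get_rank())

# placeholder model; any module works as long as large params are sharded
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()
for layer in model:
    if isinstance(layer, nn.Linear):
        fully_shard(layer)  # parameters become DTensors holding local shards
fully_shard(model)

# optimizer state for DTensor params is created via _new_buffer() below:
# the quantized subclass wraps the local shard inside the DTensor
optim = low_bit_optim.Adam8bit(model.parameters(), lr=1e-4)

for _ in range(3):
    x = torch.randn(8, 1024, device="cuda")
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()

torch.distributed.destroy_process_group()
```

The AdamW variants and `Adam4bit`/`AdamFp8` follow the same pattern; per the test above, the FP8 states additionally require compute capability >= 8.9.
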
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 6595711138..b3b7eeb6f3 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -3,10 +3,11 @@ import torch from torch import Tensor from torch.optim import Optimizer +from torch.distributed._tensor import DTensor -from .subclass_8bit import maybe_new_8bit_zero_buffer -from .subclass_4bit import maybe_new_4bit_zero_buffer -from .subclass_fp8 import maybe_new_fp8_zero_buffer +from .subclass_8bit import OptimState8bit +from .subclass_4bit import OptimState4bit +from .subclass_fp8 import OptimStateFp8 class _Adam(Optimizer): @@ -28,18 +29,34 @@ def __setstate__(self, state): for group in self.param_groups: group.setdefault("amsgrad", False) + # bring your own function to create zero-filled subclass @staticmethod - def _new_buffer(p: Tensor, signed: bool, block_size: int): + def _subclass_zeros(p: Tensor, signed: bool, block_size: int): raise NotImplementedError - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() + # follow bitsandbytes, only quantize tensors >= 4096 values + # also wrap subclass in DTensor when needed + def _new_buffer(self, p: Tensor, signed: bool): + if p.numel() >= 4096 and p.numel() % self.block_size == 0: + if isinstance(p, DTensor): + out = torch.empty_like(p) + out._local_tensor = self._subclass_zeros( + out._local_tensor, + signed, + self.block_size, + ) + else: + out = self._subclass_zeros(p, signed, self.block_size) + else: + out = torch.zeros_like(p) + return out + + def _prepare_param_groups(self): + param_groups = [] for group in self.param_groups: + _group = [] + for p in group["params"]: if p.grad is None: continue @@ -51,42 +68,56 @@ def step(self, closure=None): state = self.state[p] # State initialization - # state is flattened so that torch.compile won't recompile for tensors with different ndim if len(state) == 0: state["step"] = torch.tensor(0.0, device=p.device) - state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size) - state["exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) + state["exp_avg"] = self._new_buffer(p, True) + state["exp_avg_sq"] = self._new_buffer(p, False) if group["amsgrad"]: - state["max_exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) + state["max_exp_avg_sq"] = self._new_buffer(p, False) state["step"] += 1 - # must explicitly convert lr to Tensor since torch.compile() will treat it as a constant - # if it is a python float. practically, only lr is changed during training. - # NOTE: if lr is change at every step, moving lr to CUDA will be a bottleneck. + # must explicitly convert lr to Tensor since torch.compile() will treat Python float as constant. + # practically, only lr is changed during training. + # NOTE: if lr is changed at every step, moving lr to CUDA can slow down training 3-4%. 
if not isinstance(group["lr"], Tensor): group["lr"] = torch.tensor(group["lr"], device=p.device) - # flatten p and grad so that torch.compile won't recompile for tensors with different ndim - single_param_adam( - p.view(-1), - grad.view(-1), + p_grad_state = ( + p, + grad, state["step"], state["exp_avg"], state["exp_avg_sq"], state.get("max_exp_avg_sq", None), - group["lr"], - group["betas"][0], - group["betas"][1], - group["weight_decay"], - group["eps"], ) + _group.append(p_grad_state) + + param_groups.append((_group, group["lr"], group["betas"], group["weight_decay"], group["eps"])) + + return param_groups + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + param_groups = self._prepare_param_groups() + + # static compile optim step for all params in a single graph + torch.compile(param_groups_adam, fullgraph=True)(param_groups) return loss +def param_groups_adam(param_groups): + for group, lr, (beta1, beta2), weight_decay, eps in param_groups: + for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group: + single_param_adam(p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, weight_decay, eps) + + # this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default -@torch.compile(fullgraph=True, dynamic=True) def single_param_adam( p: Tensor, grad: Tensor, @@ -134,11 +165,13 @@ def __init__( weight_decay=0, amsgrad=False, *, - block_size=2048 + block_size=2048, ) -> None: super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) - _new_buffer = staticmethod(maybe_new_8bit_zero_buffer) + @staticmethod + def _subclass_zeros(p: Tensor, signed: bool, block_size: int): + return OptimState8bit.zeros(p.shape, signed, block_size, p.device) class Adam4bit(_Adam): @@ -155,7 +188,49 @@ def __init__( ) -> None: super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) - _new_buffer = staticmethod(maybe_new_4bit_zero_buffer) + @staticmethod + def _subclass_zeros(p: Tensor, signed: bool, block_size: int): + return OptimState4bit.zeros(p.shape, signed, block_size, p.device) + + @staticmethod + def _unwrap_dtensor(p: Tensor): + return p._local_tensor if isinstance(p, DTensor) else p + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + param_groups = self._prepare_param_groups() + + # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim. + # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param. 
+ + # unwrap DTensor since DTensor does not work well with dynamic compile + # flatten p, grad, and optim state to avoid recompilation + for group, lr, (beta1, beta2), weight_decay, eps in param_groups: + for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group: + # DTensor._local_tensor has .requires_grad = False + # to avoid recompilation, set p.requires_grad = False and restore it after optim step + p.requires_grad_(False) + torch.compile(single_param_adam, fullgraph=True, dynamic=True)( + self._unwrap_dtensor(p).view(-1), + self._unwrap_dtensor(grad).view(-1), + step, + self._unwrap_dtensor(exp_avg).view(-1), + self._unwrap_dtensor(exp_avg_sq).view(-1), + self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None, + lr, + beta1, + beta2, + weight_decay, + eps, + ) + p.requires_grad_(True) + + return loss class AdamFp8(_Adam): @@ -168,10 +243,10 @@ def __init__( weight_decay=0, amsgrad=False, *, - block_size=2048 + block_size=2048, ) -> None: super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) @staticmethod - def _new_buffer(p: Tensor, signed: bool, block_size: int): - return maybe_new_fp8_zero_buffer(p, block_size) + def _subclass_zeros(p: Tensor, signed: bool, block_size: int): + return OptimStateFp8.zeros(p.shape, block_size, p.device) diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index 9397f04c3c..ad60caa435 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -3,10 +3,11 @@ import torch from torch import Tensor from torch.optim import Optimizer +from torch.distributed._tensor import DTensor -from .subclass_8bit import maybe_new_8bit_zero_buffer -from .subclass_4bit import maybe_new_4bit_zero_buffer -from .subclass_fp8 import maybe_new_fp8_zero_buffer +from .subclass_8bit import OptimState8bit +from .subclass_4bit import OptimState4bit +from .subclass_fp8 import OptimStateFp8 class _AdamW(Optimizer): @@ -28,18 +29,34 @@ def __setstate__(self, state): for group in self.param_groups: group.setdefault("amsgrad", False) + # bring your own function to create zero-filled subclass @staticmethod - def _new_buffer(p: Tensor, signed: bool, block_size: int): + def _subclass_zeros(p: Tensor, signed: bool, block_size: int): raise NotImplementedError - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() + # follow bitsandbytes, only quantize tensors >= 4096 values + # also wrap subclass in DTensor when needed + def _new_buffer(self, p: Tensor, signed: bool): + if p.numel() >= 4096 and p.numel() % self.block_size == 0: + if isinstance(p, DTensor): + out = torch.empty_like(p) + out._local_tensor = self._subclass_zeros( + out._local_tensor, + signed, + self.block_size, + ) + else: + out = self._subclass_zeros(p, signed, self.block_size) + else: + out = torch.zeros_like(p) + return out + + def _prepare_param_groups(self): + param_groups = [] for group in self.param_groups: + _group = [] + for p in group["params"]: if p.grad is None: continue @@ -51,42 +68,56 @@ def step(self, closure=None): state = self.state[p] # State initialization - # state is flattened so that torch.compile won't recompile for tensors with different ndim if len(state) == 0: state["step"] = torch.tensor(0.0, device=p.device) - state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size) - state["exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) + 
state["exp_avg"] = self._new_buffer(p, True) + state["exp_avg_sq"] = self._new_buffer(p, False) if group["amsgrad"]: - state["max_exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) + state["max_exp_avg_sq"] = self._new_buffer(p, False) state["step"] += 1 - # must explicitly convert lr to Tensor since torch.compile() will treat it as a constant - # if it is a python float. practically, only lr is changed during training. - # NOTE: if lr is change at every step, moving lr to CUDA will be a bottleneck. + # must explicitly convert lr to Tensor since torch.compile() will treat Python float as constant. + # practically, only lr is changed during training. + # NOTE: if lr is changed at every step, moving lr to CUDA can slow down training 3-4%. if not isinstance(group["lr"], Tensor): group["lr"] = torch.tensor(group["lr"], device=p.device) - # flatten p and grad so that torch.compile won't recompile for tensors with different ndim - single_param_adamw( - p.view(-1), - grad.view(-1), + p_grad_state = ( + p, + grad, state["step"], state["exp_avg"], state["exp_avg_sq"], state.get("max_exp_avg_sq", None), - group["lr"], - group["betas"][0], - group["betas"][1], - group["weight_decay"], - group["eps"], ) + _group.append(p_grad_state) + + param_groups.append((_group, group["lr"], group["betas"], group["weight_decay"], group["eps"])) + + return param_groups + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + param_groups = self._prepare_param_groups() + + # static compile optim step for all params in a single graph + torch.compile(param_groups_adamw, fullgraph=True)(param_groups) return loss +def param_groups_adamw(param_groups): + for group, lr, (beta1, beta2), weight_decay, eps in param_groups: + for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group: + single_param_adamw(p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, weight_decay, eps) + + # this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default -@torch.compile(fullgraph=True, dynamic=True) def single_param_adamw( p: Tensor, grad: Tensor, @@ -133,11 +164,13 @@ def __init__( weight_decay=1e-2, amsgrad=False, *, - block_size=2048 + block_size=2048, ) -> None: super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) - _new_buffer = staticmethod(maybe_new_8bit_zero_buffer) + @staticmethod + def _subclass_zeros(p: Tensor, signed: bool, block_size: int): + return OptimState8bit.zeros(p.shape, signed, block_size, p.device) class AdamW4bit(_AdamW): @@ -154,7 +187,49 @@ def __init__( ) -> None: super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) - _new_buffer = staticmethod(maybe_new_4bit_zero_buffer) + @staticmethod + def _subclass_zeros(p: Tensor, signed: bool, block_size: int): + return OptimState4bit.zeros(p.shape, signed, block_size, p.device) + + @staticmethod + def _unwrap_dtensor(p: Tensor): + return p._local_tensor if isinstance(p, DTensor) else p + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + param_groups = self._prepare_param_groups() + + # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim. + # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param. 
+ + # unwrap DTensor since DTensor does not work well with dynamic compile + # flatten p, grad, and optim state to avoid recompilation + for group, lr, (beta1, beta2), weight_decay, eps in param_groups: + for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group: + # DTensor._local_tensor has .requires_grad = False + # to avoid recompilation, set p.requires_grad = False and restore it after optim step + p.requires_grad_(False) + torch.compile(single_param_adamw, fullgraph=True, dynamic=True)( + self._unwrap_dtensor(p).view(-1), + self._unwrap_dtensor(grad).view(-1), + step, + self._unwrap_dtensor(exp_avg).view(-1), + self._unwrap_dtensor(exp_avg_sq).view(-1), + self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None, + lr, + beta1, + beta2, + weight_decay, + eps, + ) + p.requires_grad_(True) + + return loss class AdamWFp8(_AdamW): @@ -167,10 +242,10 @@ def __init__( weight_decay=1e-2, amsgrad=False, *, - block_size=2048 + block_size=2048, ) -> None: super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) @staticmethod - def _new_buffer(p: Tensor, signed: bool, block_size: int): - return maybe_new_fp8_zero_buffer(p, block_size) + def _subclass_zeros(p: Tensor, signed: bool, block_size: int): + return OptimStateFp8.zeros(p.shape, block_size, p.device) diff --git a/torchao/prototype/low_bit_optim/quant_utils.py b/torchao/prototype/low_bit_optim/quant_utils.py new file mode 100644 index 0000000000..0dc262ed40 --- /dev/null +++ b/torchao/prototype/low_bit_optim/quant_utils.py @@ -0,0 +1,112 @@ +import torch +from torch import Tensor + + +# https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391 +# NOTE: zero padding is removed so this function can work with 4-bit qmap +def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): + """ + Creates the dynamic quantiztion map. + + The dynamic data type is made up of a dynamic exponent and + fraction. As the exponent increase from 0 to -7 the number + of bits available for the fraction shrinks. + + This is a generalization of the dynamic type where a certain + number of the bits and be reserved for the linear quantization + region (the fraction). n determines the maximum number of + exponent bits. 
+ + For more details see + (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] + """ + + data = [] + # these are additional items that come from the case + # where all the exponent bits are zero and no + # indicator bit is present + non_sign_bits = total_bits - (1 if signed else 1) + additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 + for i in range(max_exponent_bits): + fraction_items = int( + 2 ** (i + non_sign_bits - max_exponent_bits) + 1 + if signed + else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1, + ) + boundaries = torch.linspace(0.1, 1, fraction_items) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + if signed: + data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + + if additional_items > 0: + boundaries = torch.linspace(0.1, 1, additional_items + 1) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + if signed: + data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + + data.append(0) + data.append(1.0) + + assert len(data) == 2**total_bits + + data.sort() + return data + + +def scale_tensor(input: Tensor, block_size: int): + """Scale tensor so that max(abs(input)) = 1""" + shape = input.shape + + # section 2.1 from https://arxiv.org/abs/2110.02861 + input = input.view(-1, block_size) + scale = input.abs().amax(-1).clip(1e-12) + input = input / scale.view(-1, 1) + return input.view(shape), scale + + +def quantize_8bit_with_qmap(input: Tensor, qmap: Tensor): + # GPU-friendly binary search + # https://blog.demofox.org/2017/06/20/simd-gpu-friendly-branchless-binary-search/ + codes = torch.where(input >= qmap[128], 128, 0) + codes += torch.where(input >= qmap[codes + 64], 64, 0) + codes += torch.where(input >= qmap[codes + 32], 32, 0) + codes += torch.where(input >= qmap[codes + 16], 16, 0) + codes += torch.where(input >= qmap[codes + 8], 8, 0) + codes += torch.where(input >= qmap[codes + 4], 4, 0) + codes += torch.where(input >= qmap[codes + 2], 2, 0) + codes += torch.where(input >= qmap[codes + 1], 1, 0) + + # rounding + codes_up = (codes + 1).clip(max=255) + val_down = qmap[codes] + val_up = qmap[codes_up] + residual = input - val_down + codes = torch.where(residual >= (val_up - val_down) * 0.5, codes_up, codes) + + return codes.to(torch.uint8) + + +def quantize_4bit_with_qmap(input: Tensor, qmap: Tensor): + # GPU-friendly binary search + # https://blog.demofox.org/2017/06/20/simd-gpu-friendly-branchless-binary-search/ + codes = torch.where(input >= qmap[8], 8, 0) + codes += torch.where(input >= qmap[codes + 4], 4, 0) + codes += torch.where(input >= qmap[codes + 2], 2, 0) + codes += torch.where(input >= qmap[codes + 1], 1, 0) + + # rounding + codes_up = (codes + 1).clip(max=15) + val_down = qmap[codes] + val_up = qmap[codes_up] + residual = input - val_down + codes = torch.where(residual >= (val_up - val_down) * 0.5, codes_up, codes) + + return codes.to(torch.uint8) + + +def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor): + # torch.compile() cannot use uint8 as index + out = qmap[codes.int()].view(scale.shape[0], -1) * scale.view(-1, 1) + return out.view(codes.shape) diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index 2b3608ce1d..9550b3d51c 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -4,7 +4,7 @@ 
from torch import Tensor from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE -from .subclass_8bit import create_dynamic_map +from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap aten = torch.ops.aten @@ -12,47 +12,11 @@ # https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml # NOTE: power-1 is linear +# TODO: since QMAP_UNSIGNED is linear, perhaps doing affine quantize is faster? QMAP_SIGNED = create_dynamic_map(True, 3, 4) QMAP_UNSIGNED = torch.linspace(0, 1, 17)[1:].tolist() # no zero -def quantize_4bit_with_qmap(input: Tensor, qmap: Tensor, block_size: int, implementation: int = 1): - # section 2.1 from https://arxiv.org/abs/2110.02861 - input = input.view(-1, block_size) - scale = input.abs().amax(-1).clip(1e-12) - input = input / scale.view(-1, 1) - - # reference implementation. equation 4 from https://arxiv.org/abs/2110.02861 - if implementation == 0: - codes = (qmap.view(1, -1) - input.view(-1, 1)).abs().argmin(-1) - codes = codes.to(torch.uint8) - - # GPU-friendly binary search - # https://blog.demofox.org/2017/06/20/simd-gpu-friendly-branchless-binary-search/ - elif implementation == 1: - input = input.view(-1) - codes = torch.where(input >= qmap[8], 8, 0) - codes += torch.where(input >= qmap[codes + 4], 4, 0) - codes += torch.where(input >= qmap[codes + 2], 2, 0) - codes += torch.where(input >= qmap[codes + 1], 1, 0) - - # rounding - codes_up = (codes + 1).clip(max=15) - val_down = qmap[codes] - val_up = qmap[codes_up] - residual = input - val_down - codes = torch.where(residual >= (val_up - val_down) * 0.5, codes_up, codes) - - codes = codes.to(torch.uint8) - - else: - raise ValueError(f"Unsupported implementation={implementation}") - - # packing - codes = (codes[::2] << 4) | codes[1::2] - return codes, scale - - class OptimState4bit(Tensor): implements = classmethod(_implements) tensor_attrs = ["codes", "scale", "qmap"] @@ -87,13 +51,8 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes) def dequantize(self, output_dtype=None): - # unpack - codes = torch.stack([self.codes >> 4, self.codes & 0b1111], dim=-1) - - # torch.compile() cannot use uint8 as index - float_data = self.qmap[codes.int()] - float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1) - + codes = torch.stack([self.codes >> 4, self.codes & 0b1111], dim=-1) # unpack + float_data = dequant_with_qmap(codes, self.qmap, self.scale) dtype = output_dtype or torch.get_default_dtype() return float_data.view(self._shape).to(dtype) @@ -137,8 +96,9 @@ def _(func, *args, **kwargs): # qmap should be the same, don't need to copy elif isinstance(dst, OptimState4bit): - codes, scale = quantize_4bit_with_qmap(src, dst.qmap, dst.block_size) - dst.codes.copy_(codes) + scaled_src, scale = scale_tensor(src.view(-1), dst.block_size) + codes = quantize_4bit_with_qmap(scaled_src, dst.qmap) + dst.codes.copy_((codes[::2] << 4) | codes[1::2]) # packing dst.scale.copy_(scale) else: @@ -153,12 +113,9 @@ def _(func, *args, **kwargs): return func(*args, **kwargs) -# https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/config.py#L37 -# only apply quantization for tensor with more than 4096 values -# TODO: also skip 1D tensor? e.g. 
biases and norm scales -def maybe_new_4bit_zero_buffer(p: Tensor, signed: bool = True, block_size: int = 128): - if p.numel() >= 4096 and p.numel() % block_size == 0: - out = OptimState4bit.zeros(p.shape, signed, block_size, device=p.device) - else: - out = torch.zeros_like(p) - return out +@OptimState4bit.implements(aten.view.default) +def _(func, *args, **kwargs): + x, shape = args + if len(shape) > 1 or shape[0] != -1: + raise ValueError(f"{x.__class__.__name__} only supports .view() with shape=[-1]") + return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),)) diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 44a3d593cf..5b16f6363f 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -2,106 +2,15 @@ from torch import Tensor from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE +from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap -aten = torch.ops.aten - - -# https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391 -# NOTE: zero padding is removed so this function can work with 4-bit qmap -def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): - """ - Creates the dynamic quantiztion map. - - The dynamic data type is made up of a dynamic exponent and - fraction. As the exponent increase from 0 to -7 the number - of bits available for the fraction shrinks. - - This is a generalization of the dynamic type where a certain - number of the bits and be reserved for the linear quantization - region (the fraction). n determines the maximum number of - exponent bits. - - For more details see - (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] - """ - - data = [] - # these are additional items that come from the case - # where all the exponent bits are zero and no - # indicator bit is present - non_sign_bits = total_bits - (1 if signed else 1) - additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 - for i in range(max_exponent_bits): - fraction_items = int( - 2 ** (i + non_sign_bits - max_exponent_bits) + 1 - if signed - else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1, - ) - boundaries = torch.linspace(0.1, 1, fraction_items) - means = (boundaries[:-1] + boundaries[1:]) / 2.0 - data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - if signed: - data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - - if additional_items > 0: - boundaries = torch.linspace(0.1, 1, additional_items + 1) - means = (boundaries[:-1] + boundaries[1:]) / 2.0 - data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - if signed: - data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - - data.append(0) - data.append(1.0) - - assert len(data) == 2**total_bits - - data.sort() - return data +aten = torch.ops.aten QMAP_SIGNED = create_dynamic_map(signed=True) QMAP_UNSIGNED = create_dynamic_map(signed=False) -def quantize_8bit_with_qmap(input: Tensor, qmap: Tensor, block_size: int, implementation: int = 1): - # section 2.1 from https://arxiv.org/abs/2110.02861 - input = input.view(-1, block_size) - scale = input.abs().amax(-1).clip(1e-12) - input = input / scale.view(-1, 1) - - # reference implementation. 
equation 4 from https://arxiv.org/abs/2110.02861 - if implementation == 0: - codes = (qmap.view(1, -1) - input.view(-1, 1)).abs().argmin(-1) - codes = codes.to(torch.uint8) - - # GPU-friendly binary search - # https://blog.demofox.org/2017/06/20/simd-gpu-friendly-branchless-binary-search/ - elif implementation == 1: - input = input.view(-1) - codes = torch.where(input >= qmap[128], 128, 0) - codes += torch.where(input >= qmap[codes + 64], 64, 0) - codes += torch.where(input >= qmap[codes + 32], 32, 0) - codes += torch.where(input >= qmap[codes + 16], 16, 0) - codes += torch.where(input >= qmap[codes + 8], 8, 0) - codes += torch.where(input >= qmap[codes + 4], 4, 0) - codes += torch.where(input >= qmap[codes + 2], 2, 0) - codes += torch.where(input >= qmap[codes + 1], 1, 0) - - # rounding - codes_up = (codes + 1).clip(max=255) - val_down = qmap[codes] - val_up = qmap[codes_up] - residual = input - val_down - codes = torch.where(residual >= (val_up - val_down) * 0.5, codes_up, codes) - - codes = codes.to(torch.uint8) - - else: - raise ValueError(f"Unsupported implementation={implementation}") - - return codes, scale - - # dynamic tree quantization # https://arxiv.org/pdf/1511.04561 # https://arxiv.org/abs/2110.02861 @@ -137,12 +46,8 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes) def dequantize(self, output_dtype=None): - # torch.compile() cannot use uint8 as index - float_data = self.qmap[self.codes.int()] - float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1) - dtype = output_dtype or torch.get_default_dtype() - return float_data.view(self.codes.shape).to(dtype) + return dequant_with_qmap(self.codes, self.qmap, self.scale).to(dtype) @classmethod def zeros(cls, shape, signed: bool = True, block_size: int = 2048, device=None): @@ -177,7 +82,8 @@ def _(func, *args, **kwargs): # qmap should be the same, don't need to copy elif isinstance(dst, OptimState8bit): - codes, scale = quantize_8bit_with_qmap(src, dst.qmap, dst.block_size) + scaled_src, scale = scale_tensor(src, dst.block_size) + codes = quantize_8bit_with_qmap(scaled_src, dst.qmap) dst.codes.copy_(codes) dst.scale.copy_(scale) @@ -191,14 +97,3 @@ def _(func, *args, **kwargs): def _(func, *args, **kwargs): args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args] return func(*args, **kwargs) - - -# follow bitsandbytes -# only apply quantization for tensor with more than 4096 values -# TODO: also skip 1D tensor? e.g. biases and norm scales -def maybe_new_8bit_zero_buffer(p: Tensor, signed: bool = True, block_size: int = 2048): - if p.numel() >= 4096 and p.numel() % block_size == 0: - out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device) - else: - out = torch.zeros_like(p) - return out diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 43d45b0d5b..e3116e20f8 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -8,13 +8,16 @@ def quantize_fp8(input: Tensor, block_size: int): + shape = input.shape input = input.view(-1, block_size) scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max input = input / scale.view(-1, 1) codes = input.to(DTYPE).view(-1) - return codes, scale + return codes.view(shape), scale +# NOTE: FP8 sign bit is redundant for unsigned optim state. 
+# we may investigate how to use it to increase range/precision for unsigned optim state. class OptimStateFp8(Tensor): implements = classmethod(_implements) tensor_attrs = ["codes", "scale"] @@ -96,11 +99,3 @@ def _(func, *args, **kwargs): def _(func, *args, **kwargs): args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args] return func(*args, **kwargs) - - -def maybe_new_fp8_zero_buffer(p: Tensor, block_size: int = 2048): - if p.numel() >= 4096 and p.numel() % block_size == 0: - out = OptimStateFp8.zeros(p.shape, block_size, device=p.device) - else: - out = torch.zeros_like(p) - return out
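
For reference, the helpers consolidated into `quant_utils.py` compose into a straightforward quantize/dequantize round trip. The sketch below is not part of the patch; it simply exercises the new helpers (`create_dynamic_map`, `scale_tensor`, `quantize_8bit_with_qmap`, `dequant_with_qmap`) on CPU, assuming a torchao build that includes this change.

```python
# Sketch only (not part of this patch): round-trip through the 8-bit qmap helpers.
import torch
from torchao.prototype.low_bit_optim.quant_utils import (
    create_dynamic_map,
    scale_tensor,
    quantize_8bit_with_qmap,
    dequant_with_qmap,
)

block_size = 2048
x = torch.randn(4096)

# 256-entry signed dynamic-tree qmap (sorted, values in [-1, 1])
qmap = torch.tensor(create_dynamic_map(signed=True))

# per-block scaling so max(|x|) == 1 within each block (section 2.1 of
# https://arxiv.org/abs/2110.02861), then branchless binary search into the qmap
scaled, scale = scale_tensor(x, block_size)
codes = quantize_8bit_with_qmap(scaled, qmap)   # uint8 codes, one per element
x_hat = dequant_with_qmap(codes, qmap, scale)   # approximate reconstruction

print(codes.dtype, x_hat.shape, (x - x_hat).abs().max())
```

The 4-bit path is the same except that `quantize_4bit_with_qmap` searches a 16-entry qmap, and the optimizer subclass packs two 4-bit codes per byte (see the `aten.copy_.default` handler in `subclass_4bit.py`) before storing them.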