diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index afeefa223..701d90e22 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -75,15 +75,16 @@ def test_quantize_4bit_with_qmap_compile(self, device):
 
 
 class TestOptim(TestCase):
-    @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
     def test_optim_smoke(self, optim_name, dtype, device):
-        if optim_name.endswith("Fp8") and device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
-            pytest.skip("FP8 requires compute capability >= 8.9")
-        if optim_name.endswith("4bit") and not TORCH_VERSION_AT_LEAST_2_5:
-            pytest.skip("4-bit Adam requires PyTorch > 2.4")
+        if optim_name.endswith("Fp8") and device == "cuda":
+            if not TORCH_VERSION_AT_LEAST_2_4:
+                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
+            if torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 requires compute capability >= 8.9")
 
         # reset cache to avoid hitting cache_size_limit, since the function will re-compile for each test
         torch._dynamo.reset_code_caches()
@@ -100,7 +101,7 @@ def test_optim_smoke(self, optim_name, dtype, device):
 
     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
-    @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
     def test_optim_8bit_correctness(self, optim_name):
         device = "cuda"
@@ -126,9 +127,10 @@ def test_optim_8bit_correctness(self, optim_name):
            for p1, p2 in zip(model1.parameters(), model2.parameters()):
                torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
 
+    # this will not run in CI because we can't install lpmm
     @pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
-    @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
     @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
     def test_optim_4bit_correctness(self, optim_name):
         device = "cuda"
diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md
index 5968b2a79..b1f955e65 100644
--- a/torchao/prototype/low_bit_optim/README.md
+++ b/torchao/prototype/low_bit_optim/README.md
@@ -24,7 +24,8 @@ To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. Y
 **Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.
 
 NOTE:
-- The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9.
+- The low-bit optimizers require PyTorch >= 2.3
+- For FP8 optimizers on CUDA, PyTorch >= 2.4 and CUDA compute capability >= 8.9 are required.
 - For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
 - The first training step is expected to be slow since the optimizer needs to be compiled.
 
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index 4b0b29534..e609696bd 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -198,7 +198,7 @@ def __init__(
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
+        return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)
 
     @staticmethod
     def _unwrap_dtensor(p: Tensor):
@@ -216,6 +216,11 @@ def step(self, closure=None):
 
         # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
         # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.
+
+        # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
+        # PyTorch 2.3 and 2.4
+        # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
+        # correctly for the tensor subclass.
         # unwrap DTensor since DTensor does not work well with dynamic compile
         # flatten p, grad, and optim state to avoid recompilation
         for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
@@ -227,9 +232,9 @@ def step(self, closure=None):
                    self._unwrap_dtensor(p).view(-1),
                    self._unwrap_dtensor(grad).view(-1),
                    step,
-                    self._unwrap_dtensor(exp_avg).view(-1),
-                    self._unwrap_dtensor(exp_avg_sq).view(-1),
-                    self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None,
+                    self._unwrap_dtensor(exp_avg),
+                    self._unwrap_dtensor(exp_avg_sq),
+                    self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
                    lr,
                    beta1,
                    beta2,
@@ -296,7 +301,7 @@ def __init__(
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
+        return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)
 
     @staticmethod
     def _unwrap_dtensor(p: Tensor):
@@ -314,6 +319,11 @@ def step(self, closure=None):
 
         # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
         # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.
+
+        # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
+        # PyTorch 2.3 and 2.4
+        # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
+        # correctly for the tensor subclass.
         # unwrap DTensor since DTensor does not work well with dynamic compile
         # flatten p, grad, and optim state to avoid recompilation
         for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
@@ -325,9 +335,9 @@ def step(self, closure=None):
                    self._unwrap_dtensor(p).view(-1),
                    self._unwrap_dtensor(grad).view(-1),
                    step,
-                    self._unwrap_dtensor(exp_avg).view(-1),
-                    self._unwrap_dtensor(exp_avg_sq).view(-1),
-                    self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None,
+                    self._unwrap_dtensor(exp_avg),
+                    self._unwrap_dtensor(exp_avg_sq),
+                    self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
                    lr,
                    beta1,
                    beta2,
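Not part of the patch above, just reviewer context: a minimal usage sketch of the optimizers whose requirements the README hunk documents. The import path is assumed from the `torchao/prototype/low_bit_optim` package layout shown in the diff, and the model/sizes are illustrative. Per the updated note, this needs PyTorch >= 2.3; the FP8 variants additionally need PyTorch >= 2.4 and CUDA compute capability >= 8.9.

```python
# Hedged usage sketch (not part of the patch): drop-in low-bit AdamW on CUDA.
# Assumes AdamW8bit is exported from torchao.prototype.low_bit_optim as the
# README describes; AdamW4bit / AdamWFp8 can be swapped in, subject to the
# version and compute-capability notes above.
import torch
from torchao.prototype.low_bit_optim import AdamW8bit

model = torch.nn.Linear(256, 10).cuda()
optim = AdamW8bit(model.parameters(), lr=1e-3)

for _ in range(3):
    loss = model(torch.randn(32, 256, device="cuda")).sum()
    loss.backward()
    optim.step()       # the first step is slow: the fused update is torch.compile'd
    optim.zero_grad()
```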
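The new NOTE in `adam.py` explains why `_subclass_zeros` now allocates the 4-bit state with `p.view(-1).shape` and why the `.view(-1)` calls on `exp_avg` / `exp_avg_sq` / `max_exp_avg_sq` are dropped at the call site: reshaping the tensor-subclass state inside `torch.compile(single_param_adam)` fails on PyTorch 2.3/2.4, so the state must arrive pre-flattened. Below is a toy illustration of that pattern using plain tensors (my own sketch, not torchao code): only `p` and `grad` are reshaped by the caller, while the state is created flat up front and passed through untouched.

```python
# Toy sketch of the "pre-flattened state" pattern from this diff (not torchao code):
# the compiled per-parameter update never reshapes the optimizer state itself.
import torch

@torch.compile(dynamic=True)
def single_param_step(p_flat, grad_flat, exp_avg, lr, beta):
    # exp_avg arrives already flattened, so no .view(-1) is needed here --
    # the call that trips up tensor subclasses under compile on 2.3/2.4.
    exp_avg.lerp_(grad_flat, 1 - beta)      # exp_avg = beta*exp_avg + (1-beta)*grad
    p_flat.add_(exp_avg, alpha=-lr)         # in-place update through the flat view of p

p = torch.randn(4, 8)
grad = torch.randn_like(p)
exp_avg = torch.zeros(p.numel())            # state allocated flat, mirroring p.view(-1).shape
single_param_step(p.view(-1), grad.view(-1), exp_avg, lr=1e-3, beta=0.9)
```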