[low-bit optim] Fix Adam4bit support on PyTorch 2.3 and 2.4. Update AdamFp8 torch requirement #755

Merged — 7 commits, Sep 2, 2024
10 changes: 5 additions & 5 deletions test/prototype/test_low_bit_optim.py
@@ -14,7 +14,7 @@
from torch.testing._internal.common_fsdp import FSDPTest
from torchao.prototype import low_bit_optim
from torchao.prototype.low_bit_optim.quant_utils import quantize_8bit_with_qmap, quantize_4bit_with_qmap
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5

try:
import bitsandbytes as bnb
@@ -75,15 +75,15 @@ def test_quantize_4bit_with_qmap_compile(self, device):


class TestOptim(TestCase):
-@pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
+@pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("device", _DEVICES)
def test_optim_smoke(self, optim_name, dtype, device):
if optim_name.endswith("Fp8") and device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
pytest.skip("FP8 requires compute capability >= 8.9")
if optim_name.endswith("4bit") and not TORCH_VERSION_AT_LEAST_2_5:
-pytest.skip("4-bit Adam requires PyTorch > 2.4")
+pytest.skip("4-bit Adam requires PyTorch >= 2.5")

# reset cache to avoid hitting cache_size_limit, since the function will re-compile for each test
torch._dynamo.reset_code_caches()
@@ -100,7 +100,7 @@ def test_optim_smoke(self, optim_name, dtype, device):

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
-@pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
+@pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
def test_optim_8bit_correctness(self, optim_name):
device = "cuda"
@@ -128,7 +128,7 @@ def test_optim_8bit_correctness(self, optim_name):

@pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
-@pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
+@pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_5, reason="requires PyTorch >= 2.5")
@parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
def test_optim_4bit_correctness(self, optim_name):
device = "cuda"
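The tests above gate on boolean flags such as `TORCH_VERSION_AT_LEAST_2_3` and `TORCH_VERSION_AT_LEAST_2_5` imported from `torchao.utils`. As a rough sketch of what such a flag could look like (the actual `torchao.utils` implementation may parse versions differently, e.g. to handle nightly builds):

```python
# Hypothetical sketch of a TORCH_VERSION_AT_LEAST_* style flag;
# torchao.utils' real implementation may differ.
import torch

def _torch_version_at_least(min_version: str) -> bool:
    # Compare only the numeric (major, minor) components, ignoring
    # local version suffixes such as "+cu121".
    current = torch.__version__.split("+")[0].split(".")[:2]
    wanted = min_version.split(".")[:2]
    return [int(x) for x in current] >= [int(x) for x in wanted]

TORCH_VERSION_AT_LEAST_2_5 = _torch_version_at_least("2.5")
```

With such a flag, a test can be marked `@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="4-bit Adam requires PyTorch >= 2.5")` — the same gating pattern this PR adjusts.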
2 changes: 1 addition & 1 deletion torchao/prototype/low_bit_optim/README.md
@@ -24,7 +24,7 @@ To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. Y
**Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.

NOTE:
-- The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9.
+- 8-bit optimizers require PyTorch >= 2.3. 4-bit optimizers require PyTorch >= 2.5. FP8 optimizers require CUDA compute capability >= 8.9.
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
- The first training step is expected to be slow since the optimizer needs to be compiled.
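
To make the README snippet concrete, here is a minimal usage sketch; the model, data, and hyperparameters are hypothetical placeholders, while the `low_bit_optim` import mirrors the one used in the tests above:

```python
# Minimal usage sketch; the model, data, and learning rate are placeholders.
import torch
from torchao.prototype import low_bit_optim

model = torch.nn.Linear(128, 128).cuda()
# Adam8bit needs PyTorch >= 2.3; on PyTorch >= 2.5 you can swap in
# low_bit_optim.Adam4bit, per the requirements listed above.
optim = low_bit_optim.Adam8bit(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 128, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()
```

The first call to `optim.step()` triggers compilation, which is why the README warns that the first training step is expected to be slow.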
