Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[low-bit optim] Fix Adam4bit support on PyTorch 2.3 and 2.4. Update AdamFp8 torch requirement #755

Merged
merged 7 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions test/prototype/test_low_bit_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,16 @@ def test_quantize_4bit_with_qmap_compile(self, device):


class TestOptim(TestCase):
@pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("device", _DEVICES)
def test_optim_smoke(self, optim_name, dtype, device):
if optim_name.endswith("Fp8") and device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
pytest.skip("FP8 requires compute capability >= 8.9")
if optim_name.endswith("4bit") and not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("4-bit Adam requires PyTorch > 2.4")
if optim_name.endswith("Fp8") and device == "cuda":
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
if torch.cuda.get_device_capability() < (8, 9):
pytest.skip("FP8 requires compute capability >= 8.9")

# reset cache to avoid hitting cache_size_limit, since the function will re-compile for each test
torch._dynamo.reset_code_caches()
Expand All @@ -100,7 +101,7 @@ def test_optim_smoke(self, optim_name, dtype, device):

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
@pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
def test_optim_8bit_correctness(self, optim_name):
device = "cuda"
Expand All @@ -126,9 +127,10 @@ def test_optim_8bit_correctness(self, optim_name):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

# this will not run in CI because we can't install lpmm
@pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
@pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
def test_optim_4bit_correctness(self, optim_name):
device = "cuda"
Expand Down
3 changes: 2 additions & 1 deletion torchao/prototype/low_bit_optim/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. Y
**Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.

NOTE:
- The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9.
- The low-bit optimizers require PyTorch >= 2.3
- For FP8 optimizers on CUDA, PyTorch >= 2.4 and CUDA compute capability >= 8.9 are required.
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
- The first training step is expected to be slow since the optimizer needs to be compiled.

Expand Down
26 changes: 18 additions & 8 deletions torchao/prototype/low_bit_optim/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def __init__(

@staticmethod
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)

@staticmethod
def _unwrap_dtensor(p: Tensor):
Expand All @@ -216,6 +216,11 @@ def step(self, closure=None):
# NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
# thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.

# NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
# PyTorch 2.3 and 2.4
# calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
# correctly for the tensor subclass.

# unwrap DTensor since DTensor does not work well with dynamic compile
# flatten p, grad, and optim state to avoid recompilation
for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
Expand All @@ -227,9 +232,9 @@ def step(self, closure=None):
self._unwrap_dtensor(p).view(-1),
self._unwrap_dtensor(grad).view(-1),
step,
self._unwrap_dtensor(exp_avg).view(-1),
self._unwrap_dtensor(exp_avg_sq).view(-1),
self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None,
self._unwrap_dtensor(exp_avg),
self._unwrap_dtensor(exp_avg_sq),
self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
lr,
beta1,
beta2,
Expand Down Expand Up @@ -296,7 +301,7 @@ def __init__(

@staticmethod
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)

@staticmethod
def _unwrap_dtensor(p: Tensor):
Expand All @@ -314,6 +319,11 @@ def step(self, closure=None):
# NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
# thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.

# NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
# PyTorch 2.3 and 2.4
# calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
# correctly for the tensor subclass.

# unwrap DTensor since DTensor does not work well with dynamic compile
# flatten p, grad, and optim state to avoid recompilation
for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
Expand All @@ -325,9 +335,9 @@ def step(self, closure=None):
self._unwrap_dtensor(p).view(-1),
self._unwrap_dtensor(grad).view(-1),
step,
self._unwrap_dtensor(exp_avg).view(-1),
self._unwrap_dtensor(exp_avg_sq).view(-1),
self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None,
self._unwrap_dtensor(exp_avg),
self._unwrap_dtensor(exp_avg_sq),
self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
lr,
beta1,
beta2,
Expand Down
Loading