From 27b1d162b25be3a4ef29d7afb3c02ca7de9ec1f4 Mon Sep 17 00:00:00 2001
From: Driss Guessous <32754868+drisspg@users.noreply.github.com>
Date: Tue, 5 Nov 2024 19:10:28 -0800
Subject: [PATCH] Fix for weights-only load (#1228)

stack-info: PR: https://github.com/pytorch/ao/pull/1228, branch: drisspg/stack/19
---
 test/prototype/test_low_bit_optim.py    |  5 +++--
 torchao/prototype/low_bit_optim/adam.py | 12 +++++++++++-
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index c87f482dbb..880fb1b109 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -345,8 +345,9 @@ def test_optim_bf16_stochastic_round_correctness(self):
             optim2.step()
             optim2.zero_grad()
 
-        torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")
-
+        torch.testing.assert_close(
+            loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}"
+        )
 
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
 @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index 0895841473..053c9cb943 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -13,7 +13,17 @@
 
 class _AdamBase(Optimizer):
     def __init__(
-        self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size, bf16_stochastic_round, is_adamw
+        self,
+        params,
+        lr,
+        betas,
+        eps,
+        weight_decay,
+        amsgrad,
+        *,
+        block_size,
+        bf16_stochastic_round,
+        is_adamw,
     ) -> None:
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
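
For reviewers who want to exercise the behavior the subject line refers to, below is a minimal round-trip sketch; it is not code from this patch. The `AdamW8bit` import path comes from this repo's prototype package, and the CUDA device and warm-up step mirror `test/prototype/test_low_bit_optim.py`; treat the exact setup as an assumption.

```python
# Minimal sketch of a weights-only load round trip (not part of this patch).
# Assumes a CUDA device and torchao's prototype low-bit optimizers.
import tempfile

import torch
from torchao.prototype.low_bit_optim import AdamW8bit  # assumed import path

device = "cuda"
model = torch.nn.Linear(128, 128, device=device)
optim = AdamW8bit(model.parameters())

# Run one optimizer step so the quantized optimizer state actually exists.
model(torch.randn(4, 128, device=device)).sum().backward()
optim.step()
optim.zero_grad()

# Round-trip the state dict through torch.save / torch.load.
with tempfile.NamedTemporaryFile(suffix=".pt") as f:
    torch.save(optim.state_dict(), f.name)
    # weights_only=True switches torch.load to a restricted unpickler;
    # this is the load path the patch title refers to.
    state_dict = torch.load(f.name, weights_only=True)

optim.load_state_dict(state_dict)
```

Since `weights_only=True` only admits allowlisted types, optimizer state built on tensor subclasses, as the low-bit states here are, is exactly the kind of payload such a fix has to keep loadable.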