Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 27, 2024
1 parent ed4aa71 commit 4cea083
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions test/prototype/test_low_bit_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,10 @@ def test_optim_8bit_correctness(self, optim_name):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

# this will not run in CI because we can't install lpmm
@pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
def test_optim_4bit_correctness(self, optim_name):
device = "cuda"
Expand Down

0 comments on commit 4cea083

Please sign in to comment.