Skip to content

Commit

Permalink
Skip QLoRA unit tests broken by latest ao nightly (#568)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebsmothers authored Mar 23, 2024
1 parent 3e6f9c1 commit 81d93bb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tests/torchtune/models/test_lora_llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def test_lora_llama2_state_dict_parity(
assert not unexpected
assert all(["lora" in key for key in missing])

@pytest.mark.skip(reason="broken by ao nightly")
def test_lora_linear_quantize_base(self):
model = self.get_lora_llama2(
lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
Expand All @@ -195,6 +196,7 @@ def test_lora_linear_quantize_base(self):
if isinstance(module, LoRALinear):
assert module._quantize_base

@pytest.mark.skip(reason="broken by ao nightly")
def test_qlora_llama2_parity(self, inputs):
with utils.set_default_dtype(torch.bfloat16):
model_ref = self.get_lora_llama2(
Expand Down Expand Up @@ -222,6 +224,7 @@ def test_qlora_llama2_parity(self, inputs):
output = qlora(inputs)
torch.testing.assert_close(ref_output, output)

@pytest.mark.skip(reason="broken by ao nightly")
def test_qlora_llama2_state_dict(self):
with utils.set_default_dtype(torch.bfloat16):
model_ref = self.get_lora_llama2(
Expand Down Expand Up @@ -255,6 +258,7 @@ def test_qlora_llama2_state_dict(self):
for v in qlora_sd.values():
assert v.dtype == torch.bfloat16

@pytest.mark.skip(reason="broken by ao nightly")
def test_qlora_llama2_merged_state_dict(self):
with utils.set_default_dtype(torch.bfloat16):
qlora = self.get_lora_llama2(
Expand Down
1 change: 1 addition & 0 deletions tests/torchtune/modules/peft/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def test_forward(self, inputs, lora_linear, out_dim) -> None:
assert actual.shape == (BSZ, SEQ_LEN, out_dim)
torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)

@pytest.mark.skip(reason="broken by ao nightly")
def test_lora_weight_nf4_when_quantized(self, qlora_linear):
assert isinstance(qlora_linear.weight, NF4Tensor)

Expand Down

0 comments on commit 81d93bb

Please sign in to comment.