Commit 09491e6

fix failing unit test
1 parent b197368 commit 09491e6

File tree

1 file changed: 24 additions, 24 deletions


tests/torchtune/modules/peft/test_lora.py

Lines changed: 24 additions & 24 deletions
@@ -50,23 +50,27 @@ def inputs(self, in_dim) -> torch.Tensor:

     @pytest.fixture
     def lora_linear(self, in_dim, out_dim) -> LoRALinear:
-        lora_linear = LoRALinear(
-            in_dim=in_dim,
-            out_dim=out_dim,
-            rank=RANK,
-            alpha=ALPHA,
-            use_bias=True,
-        )
-        fixed_init_model(lora_linear)
-        return lora_linear
+        def create_lora_linear(use_bias, dtype, in_dim=in_dim, out_dim=out_dim):
+            with training.set_default_dtype(dtype):
+                lora_linear = LoRALinear(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    rank=RANK,
+                    alpha=ALPHA,
+                    use_bias=use_bias,
+                )
+                fixed_init_model(lora_linear)
+            return lora_linear
+
+        return create_lora_linear

     @pytest.fixture
     def qlora_linear(self):
-        def create_qlora_linear(use_bias, dtype):
+        def create_qlora_linear(use_bias, dtype, in_dim=512, out_dim=512):
             with training.set_default_dtype(dtype):
                 qlora_linear = LoRALinear(
-                    in_dim=512,
-                    out_dim=512,
+                    in_dim=in_dim,
+                    out_dim=out_dim,
                     rank=RANK,
                     alpha=ALPHA,
                     use_bias=use_bias,

@@ -95,6 +99,7 @@ def set_dummy_weights_for_merge(self, lora_module):
         lora_module.lora_b.weight[32, 1] = 12

     def test_forward(self, inputs, lora_linear, out_dim) -> None:
+        lora_linear = lora_linear(use_bias=True, dtype=torch.float32)
         expected = torch.tensor(EXPECTED_VAL)
         actual = lora_linear(inputs)
         assert actual.shape == (BSZ, SEQ_LEN, out_dim)

@@ -115,18 +120,13 @@ def test_lora_weight_nf4_when_quantized(self, use_bias, qlora_linear):
         "use_bias, dtype",
         [(False, torch.bfloat16), (True, torch.float32), (False, torch.float32)],
     )
-    def test_qlora_parity(self, use_bias, dtype, qlora_linear):
-        qlora_linear = qlora_linear(use_bias=use_bias, dtype=dtype)
-        with training.set_default_dtype(dtype):
-            lora_linear = LoRALinear(
-                in_dim=512,
-                out_dim=512,
-                rank=RANK,
-                alpha=ALPHA,
-                use_bias=use_bias,
-                quantize_base=False,
-            )
-        fixed_init_model(lora_linear, dtype=torch.bfloat16)
+    def test_qlora_parity(self, use_bias, dtype, qlora_linear, lora_linear):
+        qlora_linear = qlora_linear(
+            use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512
+        )
+        lora_linear = lora_linear(
+            use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512
+        )

         # set weight of lora_linear to unquantized weight of qlora_linear and check
         # parity.
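
The change above reworks the lora_linear and qlora_linear fixtures so that they return factory functions rather than pre-built modules, letting each test choose use_bias, dtype, and the layer dimensions at call time. Below is a minimal, self-contained sketch of that factory-fixture pattern under simplified assumptions: SimpleLinear and the test are hypothetical stand-ins, not torchtune's LoRALinear or its actual fixtures, and dtype is applied with .to(dtype) instead of torchtune's training.set_default_dtype context manager.

# Illustrative sketch of the factory-fixture pattern (not torchtune code).
import pytest
import torch
import torch.nn as nn


class SimpleLinear(nn.Module):
    # Hypothetical stand-in for a configurable module such as LoRALinear.
    def __init__(self, in_dim: int, out_dim: int, use_bias: bool):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=use_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


@pytest.fixture
def simple_linear():
    # Return a factory instead of a finished instance so each test picks
    # use_bias, dtype, and dimensions when it calls the fixture's result.
    def create_simple_linear(use_bias, dtype, in_dim=512, out_dim=512):
        module = SimpleLinear(in_dim=in_dim, out_dim=out_dim, use_bias=use_bias)
        return module.to(dtype=dtype)

    return create_simple_linear


@pytest.mark.parametrize(
    "use_bias, dtype",
    [(False, torch.bfloat16), (True, torch.float32)],
)
def test_forward_shape(use_bias, dtype, simple_linear):
    # The test materializes the module with exactly the configuration it needs.
    linear = simple_linear(use_bias=use_bias, dtype=dtype, in_dim=8, out_dim=4)
    out = linear(torch.randn(2, 8, dtype=dtype))
    assert out.shape == (2, 4)

The benefit of this shape is that parametrized tests such as test_qlora_parity can share one fixture while still controlling bias and dtype per case, instead of duplicating module construction inside each test.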
