Commit 8f1b7b7

Fix test
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 4fc3209 commit 8f1b7b7

File tree: 2 files changed, +18 -16 lines


tests/lora/test_layers.py

Lines changed: 14 additions & 12 deletions
@@ -164,8 +164,8 @@ def populate_loras(
                         weight=layer_weights,
                         generate_embeddings_tensor=generate_embeddings_tensor,
                     )
-                sublora.lora_b = sublora.lora_b[:, (sublora_len *
-                                                    i):(sublora_len * (i + 1))]
+                sublora.lora_b = sublora.lora_b[(sublora_len *
+                                                 i):(sublora_len * (i + 1)), :]
                 sublora.optimize()
                 subloras.append(sublora)

@@ -304,9 +304,9 @@ def create_random_embedding_layer():
             result = embedding(input_)
             after_a = F.embedding(
                 input_,
-                lora.lora_a,
+                lora.lora_a.T,
             )
-            result += (after_a @ lora.lora_b)
+            result += (after_a @ lora.lora_b.T)
             expected_results.append(result)
         expected_result = torch.cat(expected_results)

@@ -445,9 +445,9 @@ def create_random_embedding_layer():
             result = expanded_embedding(input_)
             after_a = F.embedding(
                 original_input_,
-                lora.lora_a,
+                lora.lora_a.T,
             )
-            result += (after_a @ lora.lora_b)
+            result += (after_a @ lora.lora_b.T)
             expected_results.append(result)
         expected_result = torch.cat(expected_results)

@@ -575,7 +575,7 @@ def _pretest():
                                              lm_head=linear,
                                              embedding_bias=None)
             result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
-            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
         logits_processor.org_vocab_size = vocab_size
@@ -692,9 +692,10 @@ def create_random_linear_replicated_layer():
 
         expected_results: list[torch.Tensor] = []
         for input_, lora_id in zip(inputs, prompt_mapping):
+
             lora = lora_dict[lora_id]
             result = linear(input_)[0]
-            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
@@ -817,7 +818,7 @@ def create_random_linear_parallel_layer():
         for input_, lora_id in zip(inputs, prompt_mapping):
             lora = lora_dict[lora_id]
             result = linear(input_)[0]
-            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
@@ -965,9 +966,10 @@ class FakeConfig:
             result = linear(input_)[0]
             subloras = sublora_dict[lora_id]
             for i, sublora in enumerate(subloras):
-                result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
-                       (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
-                                    sublora.scaling)
+                result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
+                       (i + 1)] += (
+                           input_ @ sublora.lora_a.T @ sublora.lora_b.T *
+                           sublora.scaling)
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
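
All of the hunks above follow from the same layout change: these tests now assume lora_a is shaped (rank, in_features) and lora_b is shaped (out_features, rank), so the reference computation transposes both before applying them, and packed modules slice the stacked lora_b along its first (output) dimension. A minimal, self-contained sketch of that expected-result math, with illustrative shapes rather than values taken from the vLLM test helpers:

import torch

# Illustrative shapes only; the real tests derive these from the layer under test.
batch, in_features, out_features, rank = 4, 16, 32, 8
scaling = 0.5

x = torch.randn(batch, in_features)
lora_a = torch.randn(rank, in_features)   # assumed layout: (rank, in_features)
lora_b = torch.randn(out_features, rank)  # assumed layout: (out_features, rank)

# Reference LoRA delta as written in the updated tests:
# x @ A.T -> (batch, rank), then @ B.T -> (batch, out_features)
delta = x @ lora_a.T @ lora_b.T * scaling
assert delta.shape == (batch, out_features)

# Packed modules slice the stacked lora_b by rows (output features),
# mirroring sublora.lora_b[(sublora_len * i):(sublora_len * (i + 1)), :]
sublora_len = out_features // 2
first_slice = lora_b[0 * sublora_len:1 * sublora_len, :]
assert first_slice.shape == (sublora_len, rank)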

tests/lora/utils.py

Lines changed: 4 additions & 4 deletions
@@ -36,10 +36,10 @@ def init_random_lora(
             module_name,
             rank=rank,
             lora_alpha=1,
-            lora_a=torch.rand([weight.shape[1], rank],
+            lora_a=torch.rand([rank, weight.shape[1]],
                               dtype=weight.dtype,
                               device=self._device),
-            lora_b=torch.rand([rank, weight.shape[0]],
+            lora_b=torch.rand([weight.shape[0], rank],
                               dtype=weight.dtype,
                               device=self._device),
         )
@@ -67,8 +67,8 @@ def init_lora(
             module_name,
             rank=rank,
             lora_alpha=1,
-            lora_a=torch.rand([input_dim, rank], device="cuda"),
-            lora_b=torch.rand([rank, output_dim], device="cuda"),
+            lora_a=torch.rand([rank, input_dim], device="cuda"),
+            lora_b=torch.rand([output_dim, input_dim], device="cuda"),
             embeddings_tensor=embeddings_tensor,
         )
         self.set_module_lora(module_name, lora)
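
The helper changes mirror the same convention when fabricating random LoRA weights: in init_random_lora, lora_a is now created as [rank, weight.shape[1]] and lora_b as [weight.shape[0], rank]. A rough sketch of that construction, assuming a base weight shaped (out_features, in_features) as in torch.nn.Linear; the function name here is illustrative, not the actual DummyLoRAManager API:

import torch

def make_random_lora(weight: torch.Tensor, rank: int):
    # weight is assumed to be (out_features, in_features).
    out_features, in_features = weight.shape
    lora_a = torch.rand(rank, in_features,
                        dtype=weight.dtype, device=weight.device)
    lora_b = torch.rand(out_features, rank,
                        dtype=weight.dtype, device=weight.device)
    return lora_a, lora_b

# The effective weight update B @ A then matches the base weight's shape.
w = torch.zeros(32, 16)
a, b = make_random_lora(w, rank=8)
assert (b @ a).shape == w.shape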
