Commit eb804a0

Added the LoRA Laning optimisation + tests + explanation
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
1 parent dc8b940 commit eb804a0

File tree

4 files changed: +243 -47 lines changed

tests/lora/tpu/test_pallas_kernels.py

Lines changed: 27 additions & 2 deletions
@@ -30,12 +30,12 @@ def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
         D: Input dim
         L: LoRA Dim
         N: N LoRAs
-
+
     Outputs:
         inputs: torch.Tensor - shape (T, D)
         loras: torch.Tensor - shape (N, 1, L, D)
         idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)
-
+
         ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
     """
     torch.manual_seed(seed)
@@ -84,3 +84,28 @@ def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):

     # Compare with reference output
     assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)
+
+# Parameterize tests with various shapes and dtypes
+@pytest.mark.parametrize("T", N_TOKENS)
+@pytest.mark.parametrize("D", HIDDEN_SIZES)
+@pytest.mark.parametrize("L", RANKS)
+@pytest.mark.parametrize("N", NUM_LORA)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", [0])
+def test_lora_laning_correctness(T, D, L, N, dtype, seed):
+    inputs, loras_a, idxs, _ = generate_test_data(T, D, L, N, seed, dtype)
+    _, loras_b, _, _ = generate_test_data(T, L, D, N, seed, dtype)
+
+    r1 = ref_bgmv(inputs, loras_a, idxs)
+    r2 = ref_bgmv(r1, loras_b, idxs)
+
+    o1 = torch.ops.xla.bgmv_shrink(inputs, loras_a, idxs)
+    o2 = torch.ops.xla.bgmv_expand(
+        o1,
+        loras_b.transpose(2, 3),
+        idxs,
+        True
+    )
+
+    # Compare with reference output
+    assert torch.allclose(o2, r2, rtol=1e-2, atol=1e-2)
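
For context, the reference path the new test compares against is a plain per-token LoRA matmul; per the docstring in the first hunk, the expected output is inputs @ loras[idxs].T. Below is a minimal sketch of that reference contract, assuming this is what ref_bgmv computes (the helper itself is not part of this diff, and the name ref_bgmv_sketch is illustrative):

import torch

def ref_bgmv_sketch(inputs: torch.Tensor, loras: torch.Tensor,
                    idxs: torch.Tensor) -> torch.Tensor:
    # loras has shape (N, 1, L, D); gather one LoRA matrix per token.
    selected = loras[idxs].squeeze(1)               # (T, L, D)
    # Per-token matmul: inputs[t] @ selected[t].T -> row t of a (T, L) result.
    return torch.einsum("td,tld->tl", inputs, selected)

The new test chains two such reference applications (shrink to rank L, then expand back to D) and checks that bgmv_shrink followed by bgmv_expand with laning enabled matches within the same rtol/atol used by the existing correctness test.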

vllm/lora/ops/xla_ops/lora_ops.py

Lines changed: 8 additions & 4 deletions
@@ -11,10 +11,12 @@ def bgmv_expand(inputs: torch.Tensor,
                 lora_b_weights: torch.Tensor,
                 output_tensor: torch.Tensor,
                 lora_indices_tensor: torch.Tensor,
-                add_inputs: bool = True):
+                add_inputs: bool = True,
+                *,
+                enable_laning: bool = False):

     outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3),
-                                        lora_indices_tensor)
+                                        lora_indices_tensor, enable_laning)

     limit = output_tensor.shape[0]
     if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
@@ -46,9 +48,11 @@ def bgmv_expand_slice(inputs: torch.Tensor,
                       lora_indices_tensor: torch.Tensor,
                       slice_offset: int,
                       slice_size: int,
-                      add_inputs: bool = True):
+                      add_inputs: bool = True,
+                      *,
+                      enable_laning: bool = False):
     outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3),
-                                        lora_indices_tensor)
+                                        lora_indices_tensor, enable_laning)

     outputs = F.pad(outputs, (slice_offset, output_tensor.shape[1] -
                               (slice_offset + slice_size), 0, 0))
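
With this change, laning is opt-in at the Python wrapper level: the new enable_laning flag is keyword-only (note the bare * in the signature) and defaults to False, so existing callers keep their current behaviour. A minimal usage sketch under assumed shapes follows; it requires a TPU host with the vLLM XLA custom ops registered, and the tensor names and sizes are illustrative, not taken from this diff:

import torch
from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand

T, L, D, N = 16, 8, 1024, 4                # tokens, LoRA rank, output dim, num LoRAs
hidden = torch.zeros(T, L)                 # output of the shrink step, shape (T, L)
lora_b = torch.zeros(N, 1, D, L)           # stacked LoRA B weights, shape (N, 1, D, L)
out_buf = torch.zeros(T, D)                # buffer the expanded result is merged into
idxs = torch.zeros(T, dtype=torch.int32)   # per-token LoRA index in [0, N)

# enable_laning must be passed by keyword; omitting it keeps the old behaviour.
result = bgmv_expand(hidden, lora_b, out_buf, idxs, add_inputs=True,
                     enable_laning=True)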
