Skip to content

Commit 2aacb34

Browse files
Improved LoRA output masking
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
1 parent 7418b5a commit 2aacb34

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

vllm/lora/ops/xla_ops/pallas.py

Lines changed: 9 additions and 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,24 @@
1010
from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard,
1111
make_kernel_from_pallas)
1212

13-
def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref,
13+
def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref,
1414
acc_ref, mask_ref):
1515
@pl.when(pl.program_id(2) == 0)
1616
def _():
1717
acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32)
1818

1919
t = pl.program_id(0)
20-
21-
for i in range(bT):
22-
idx = idx_ref[i + bT * t]
20+
21+
for i in range(max_num_loras):
2322
mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32)
24-
mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32)
23+
for j in range(bT):
24+
@pl.when(idx_ref[j + bT * t] == i)
25+
def _():
26+
mask_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32)
2527

2628
acc_ref[...] += jax.lax.dot_general(
2729
inp_ref[...],
28-
lora_ref[idx, ...], (((1, ), (1, )), ((), ())),
30+
lora_ref[i, ...], (((1, ), (1, )), ((), ())),
2931
preferred_element_type=jnp.float32) * mask_ref[...]
3032

3133
@pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
@@ -47,7 +49,7 @@ def _bgmv(
4749
N, L, _ = loras.shape
4850

4951
return pl.pallas_call(
50-
kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE),
52+
kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE, N),
5153
out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype),
5254
grid_spec=pltpu.PrefetchScalarGridSpec(
5355
num_scalar_prefetch=1,

0 commit comments

Comments
 (0)