
Commit 6ee0b57

Skipped matmuls where no loras are needed
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
1 parent: 2aacb34

1 file changed (+9, −4 lines)


vllm/lora/ops/xla_ops/pallas.py

Lines changed: 9 additions & 4 deletions
@@ -20,15 +20,20 @@ def _():
 
     for i in range(max_num_loras):
         mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32)
+        valid = False
         for j in range(bT):
+            valid |= idx_ref[j + bT * t] == i
+
             @pl.when(idx_ref[j + bT * t] == i)
             def _():
                 mask_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32)
 
-        acc_ref[...] += jax.lax.dot_general(
-            inp_ref[...],
-            lora_ref[i, ...], (((1, ), (1, )), ((), ())),
-            preferred_element_type=jnp.float32) * mask_ref[...]
+        @pl.when(valid)
+        def _():
+            acc_ref[...] += jax.lax.dot_general(
+                inp_ref[...],
+                lora_ref[i, ...], (((1, ), (1, )), ((), ())),
+                preferred_element_type=jnp.float32) * mask_ref[...]
 
     @pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
     def _():
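
For context: the change wraps the per-LoRA matmul in @pl.when(valid), where valid is a traced boolean that becomes true only if some row of the current bT-token tile maps to LoRA i, so the dot_general is skipped at run time for adapters no token in the tile uses. Below is a minimal plain-JAX sketch of the masked computation one such step performs; the function name bgmv_reference and the shapes are assumptions made for illustration, not the Pallas kernel itself:

import jax
import jax.numpy as jnp

def bgmv_reference(idx, inp, lora):
    """Dense reference for the masked bgmv computation (assumed shapes).

    idx:  (T,)       per-token LoRA index
    inp:  (T, D)     token activations
    lora: (N, L, D)  stacked LoRA matrices
    """
    T, _ = inp.shape
    N, L, _ = lora.shape
    acc = jnp.zeros((T, L), dtype=jnp.float32)
    for i in range(N):
        # Zero out rows whose token does not use LoRA i.
        mask = (idx == i).astype(jnp.float32)[:, None]
        # Same contraction as the kernel: contract dim 1 of inp with
        # dim 1 of lora[i] (i.e. inp @ lora[i].T), accumulated in fp32.
        acc += jax.lax.dot_general(
            inp, lora[i], (((1, ), (1, )), ((), ())),
            preferred_element_type=jnp.float32) * mask
    return acc

This dense reference always pays for all N matmuls even when a mask is all zeros; the kernel's @pl.when(valid) guard instead makes the cost proportional to the number of adapters actually present in the tile, which is the point of the commit.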
