1010from torch_xla .experimental .custom_kernel import (XLA_LIB , jax_import_guard ,
1111 make_kernel_from_pallas )
1212
13- def _bgmv_kernel (bT : int , bL : int , idx_ref , inp_ref , lora_ref , out_ref ,
13+ def _bgmv_kernel (bT : int , bL : int , max_num_loras : int , idx_ref , inp_ref , lora_ref , out_ref ,
1414 acc_ref , mask_ref ):
1515 @pl .when (pl .program_id (2 ) == 0 )
1616 def _ ():
1717 acc_ref [...] = jnp .zeros_like (acc_ref [...], dtype = jnp .float32 )
1818
1919 t = pl .program_id (0 )
20-
21- for i in range (bT ):
22- idx = idx_ref [i + bT * t ]
20+
21+ for i in range (max_num_loras ):
2322 mask_ref [...] = jnp .zeros_like (mask_ref [...], dtype = jnp .float32 )
24- mask_ref [i , :] = jnp .ones ((bL , ), dtype = jnp .float32 )
23+ for j in range (bT ):
24+ @pl .when (idx_ref [j + bT * t ] == i )
25+ def _ ():
26+ mask_ref [j , :] = jnp .ones ((bL , ), dtype = jnp .float32 )
2527
2628 acc_ref [...] += jax .lax .dot_general (
2729 inp_ref [...],
28- lora_ref [idx , ...], (((1 , ), (1 , )), ((), ())),
30+ lora_ref [i , ...], (((1 , ), (1 , )), ((), ())),
2931 preferred_element_type = jnp .float32 ) * mask_ref [...]
3032
3133 @pl .when (pl .program_id (2 ) == pl .num_programs (2 ) - 1 )
@@ -47,7 +49,7 @@ def _bgmv(
4749 N , L , _ = loras .shape
4850
4951 return pl .pallas_call (
50- kernel = functools .partial (_bgmv_kernel , TOKEN_BLOCK_SIZE , LORA_RANK_BLOCK_SIZE ),
52+ kernel = functools .partial (_bgmv_kernel , TOKEN_BLOCK_SIZE , LORA_RANK_BLOCK_SIZE , N ),
5153 out_shape = jax .ShapeDtypeStruct ((T , L ), dtype = inputs .dtype ),
5254 grid_spec = pltpu .PrefetchScalarGridSpec (
5355 num_scalar_prefetch = 1 ,
0 commit comments