# SPDX-License-Identifier: Apache-2.0
import functools
+ import math
+ from typing import List

import jax
import jax.numpy as jnp
from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard,
                                                  make_kernel_from_pallas)

- def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref,
-                  acc_ref, mask_ref):
+
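+ # Pallas kernel for batched grouped matrix-vector multiply (BGMV): each token
+ # block is multiplied against the LoRA adapter selected by its entry in
+ # idx_ref, and a per-row mask zeroes out tokens that belong to other adapters.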
+ def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref,
+                  lora_ref, out_ref, acc_ref, mask_ref):
+
    @pl.when(pl.program_id(2) == 0)
    def _():
        acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32)

    t = pl.program_id(0)

    ones = jnp.ones((bL, ), dtype=jnp.float32)

    for i in range(max_num_loras):
@@ -42,16 +46,16 @@ def _():
        out_ref[...] = acc_ref[...].astype(out_ref.dtype)


- @functools.partial(jax.jit, static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"])
+ @functools.partial(jax.jit,
+                    static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"])
def _bgmv(
-     idxs: jax.Array,  # (T, ) int32
-     inputs: jax.Array,  # (T, D) model dtype
-     loras: jax.Array,  # (N, L, D) model dtype
-     *,
-     TOKEN_BLOCK: int,
-     LORA_BLOCK: int,
-     DIM_BLOCK: int
- ) -> jax.Array:  # (T, L) model dtype
+         idxs: jax.Array,  # (T, ) int32
+         inputs: jax.Array,  # (T, D) model dtype
+         loras: jax.Array,  # (N, L, D) model dtype
+         *,
+         TOKEN_BLOCK: int,
+         LORA_BLOCK: int,
+         DIM_BLOCK: int) -> jax.Array:  # (T, L) model dtype
    T, D = inputs.shape
    N, L, _ = loras.shape

@@ -60,8 +64,7 @@ def _bgmv(
        out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
-             grid=(T // TOKEN_BLOCK, L // LORA_BLOCK,
-                   D // DIM_BLOCK),
+             grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, D // DIM_BLOCK),
            in_specs=[
                pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK),
                             lambda i, j, k, block_idx: (i, k)),
@@ -103,22 +106,18 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
    jax_import_guard()

    TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128)
    if is_expand:  # Expand
        LORA_BLOCK = min(1024, next_multiple_of(L, 256))
        DIM_BLOCK = 256
    else:  # Shrink
        LORA_BLOCK = 256
        DIM_BLOCK = min(1024, next_multiple_of(D, 256))

    kernel = make_kernel_from_pallas(
-         functools.partial(
-             _bgmv,
-             TOKEN_BLOCK=TOKEN_BLOCK,
-             LORA_BLOCK=LORA_BLOCK,
-             DIM_BLOCK=DIM_BLOCK
-         ),
-         bgmv_shape_function
-     )
+         functools.partial(_bgmv,
+                           TOKEN_BLOCK=TOKEN_BLOCK,
+                           LORA_BLOCK=LORA_BLOCK,
+                           DIM_BLOCK=DIM_BLOCK), bgmv_shape_function)

    # Pad the LoRA rank (L) if it's too low, so that it fits in a TPU register.
    # This has to happen in PyTorch; doing it in JAX leads to NaNs.
@@ -157,10 +156,163 @@ def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
    return torch.empty((T, L), device=inputs.device)


+ # This kernel is similar to the one above but it assumes that the LoRA
+ # adapters have been pre-transposed. This lets us skip the data copies
+ # involved in transposing.
+ # We only need this for the expand op since the LoRA dimensions in the shrink
+ # op are small enough that the TPU can gather them without a data copy.
+ def _bgmv_pre_transpose_kernel(bT: int, bL: int, max_num_loras: int, idx_ref,
+                                inp_ref, lora_ref, out_ref, acc_ref, mask_ref):
+
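+     # Zero the accumulator on the first step of the contraction (D) axis.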
+     @pl.when(pl.program_id(2) == 0)
+     def _():
+         acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32)
+
+     t = pl.program_id(0)
+
+     ones = jnp.ones((bL, ), dtype=jnp.float32)
+
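+     # For each adapter, build a row mask selecting the tokens in this block
+     # whose index matches adapter i, then accumulate the masked product.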
+     for i in range(max_num_loras):
+         mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32)
+         valid = False
+         for j in range(bT):
+             valid |= idx_ref[j + bT * t] == i
+
+             @pl.when(idx_ref[j + bT * t] == i)
+             def _():
+                 mask_ref[j, :] = ones
+
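+         # Skip the matmul entirely when no token in this block uses adapter i.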
+         @pl.when(valid)
+         def _():
+             acc_ref[...] += jax.lax.dot(
+                 inp_ref[...],
+                 lora_ref[i, ...],
+                 preferred_element_type=jnp.float32) * mask_ref[...]
+
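+     # After the final D step, cast the float32 accumulator to the output dtype.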
+     @pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
+     def _():
+         out_ref[...] = acc_ref[...].astype(out_ref.dtype)
+
+
+ @functools.partial(jax.jit,
+                    static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"])
+ def _bgmv_pre_transpose(
+         idxs: jax.Array,  # (T, ) int32
+         inputs: jax.Array,  # (T, D) model dtype
+         loras: jax.Array,  # (N, D, L) model dtype, pre-transposed
+         *,
+         TOKEN_BLOCK: int,
+         LORA_BLOCK: int,
+         DIM_BLOCK: int) -> jax.Array:  # (T, L) model dtype
+     T, D = inputs.shape
+     N, _, L = loras.shape
+
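+     # Grid: (token tiles, LoRA-rank tiles, D tiles). The D axis is marked
+     # "arbitrary", so the accumulator is carried across its sequential steps.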
+     return pl.pallas_call(
+         kernel=functools.partial(_bgmv_pre_transpose_kernel, TOKEN_BLOCK,
+                                  LORA_BLOCK, N),
+         out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype),
+         grid_spec=pltpu.PrefetchScalarGridSpec(
+             num_scalar_prefetch=1,
+             grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, D // DIM_BLOCK),
+             in_specs=[
+                 pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK),
+                              lambda i, j, k, block_idx: (i, k)),
+                 pl.BlockSpec((N, DIM_BLOCK, LORA_BLOCK),
+                              lambda i, j, k, block_idx: (0, k, j)),
+             ],
+             out_specs=pl.BlockSpec((TOKEN_BLOCK, LORA_BLOCK),
+                                    lambda i, j, k, block_idx: (i, j)),
+             scratch_shapes=[
+                 pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32),
+                 pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32)
+             ]),
+         compiler_params=pltpu.TPUCompilerParams(
+             dimension_semantics=("parallel", "parallel", "arbitrary")),
+         name="bgmv_pre_transpose")(idxs, inputs, loras)
+
+
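+ # Shape function for make_kernel_from_pallas: the op produces a (T, L) tensor
+ # in the input dtype.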
+ def bgmv_pre_transpose_shape_function(idxs, inputs, loras):
+     T, _ = inputs.shape
+     _, _, L = loras.shape
+
+     return [((T, L), inputs.dtype)]
+
+
+ XLA_LIB.define(
+     "bgmv_pre_transpose(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", )
+
+
+ @impl(XLA_LIB, "bgmv_pre_transpose", "XLA")
+ def bgmv_pre_transpose_xla(inputs: torch.Tensor, loras: torch.Tensor,
+                            idxs: torch.IntTensor):
+     inputs = inputs.to(dtype=loras.dtype)
+
+     if len(loras.shape) == 4:
+         loras = loras.squeeze(axis=1)
+
+     T, _ = inputs.shape
+     _, D, L = loras.shape
+
+     jax_import_guard()
+
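+     # Block sizes follow the expand heuristic from bgmv_xla above, since the
+     # pre-transposed path is only needed for the expand op.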
+     TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128)
+     LORA_BLOCK = min(1024, next_multiple_of(L, 256))
+     DIM_BLOCK = 256
+
+     kernel = make_kernel_from_pallas(
+         functools.partial(_bgmv_pre_transpose,
+                           TOKEN_BLOCK=TOKEN_BLOCK,
+                           LORA_BLOCK=LORA_BLOCK,
+                           DIM_BLOCK=DIM_BLOCK),
+         bgmv_pre_transpose_shape_function)
+
+     # Pad the LoRA rank (L) if it's too low, so that it fits in a TPU register.
+     # This has to happen in PyTorch; doing it in JAX leads to NaNs.
+     pad_L = 0
+     if LORA_BLOCK > L or L % LORA_BLOCK != 0:
+         pad_L = next_multiple_of(L, LORA_BLOCK) - L
+
+     pad_D = 0
+     if DIM_BLOCK > D or D % DIM_BLOCK != 0:
+         pad_D = next_multiple_of(D, DIM_BLOCK) - D
+
+     pad_T = 0
+     if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0:
+         pad_T = next_multiple_of(T, TOKEN_BLOCK) - T
+
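+     # F.pad pads trailing dimensions first: loras is (N, D, L) here and
+     # inputs is (T, D).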
+     if pad_D != 0 or pad_L != 0:
+         loras = torch.nn.functional.pad(loras, (0, pad_L, 0, pad_D, 0, 0))
+     if pad_D != 0 or pad_T != 0:
+         inputs = torch.nn.functional.pad(inputs, (0, pad_D, 0, pad_T))
+     if pad_T != 0:
+         idxs = torch.nn.functional.pad(idxs, (0, pad_T))
+
+     return kernel(idxs, inputs, loras)[:T, :L]
+
+
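+ # Non-XLA fallback: returns an uninitialized (T, L) tensor so that shape
+ # inference sees the correct output without running the TPU kernel.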
+ @impl(XLA_LIB, "bgmv_pre_transpose", "CompositeExplicitAutograd")
+ def bgmv_pre_transpose_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
+                                idxs: torch.IntTensor):
+     T, _ = inputs.shape
+
+     if len(loras.shape) == 4:
+         loras = loras.squeeze(axis=1)
+
+     _, _, L = loras.shape
+
+     return torch.empty((T, L), device=inputs.device)
+
+
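+ # Returns the largest entry of divs that evenly divides n, falling back to
+ # max(divs) when none does.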
+ def largest_divisor(n: int, divs: List[int]) -> int:
+     for div in sorted(divs, reverse=True):
+         if n % div == 0:
+             return div
+     return max(divs)
+
+
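+ # Round n up to the nearest multiple of mult (n is returned unchanged when it
+ # is already a multiple).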
def next_multiple_of(n: int, mult: int) -> int:
-     if n % mult == 0:
-         return n
-     return (n // mult + 1) * mult
+     return math.ceil(n / mult) * mult
+

def get_bounded_value(_min: int, val: int, _max: int) -> int:
    return min(max(_min, val), _max)