
Commit a82f3fe

Added fused lora transpose [experimental]
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
1 parent: 27ad793

File tree

5 files changed: 207 additions, 44 deletions

tests/lora/tpu/test_pallas_kernels.py

Lines changed: 4 additions & 1 deletion
@@ -74,7 +74,10 @@ def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
                                      T, D, L, N, seed, dtype)
 
     # Run bgmv
-    output = torch.ops.xla.bgmv(inputs, loras, idxs)
+    if op_type == "shrink":
+        output = torch.ops.xla.bgmv(inputs, loras, idxs)
+    else:
+        output = torch.ops.xla.bgmv_pre_transpose(inputs, loras.transpose(2, 3), idxs)
 
     # Make sure we have no NaNs
     assert not torch.any(torch.isnan(output))
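
For reference, a minimal sketch of the expand path the updated test takes, assuming a TPU with torch_xla available and the Pallas module imported so the custom ops are registered; the shapes, dtypes, and tolerances below are illustrative, not the test's parametrization. On the same stacked weights, the fused op applied to pre-transposed adapters should match plain bgmv:

import torch
import torch_xla.core.xla_model as xm

import vllm.lora.ops.xla_ops.pallas  # noqa: F401  (registers torch.ops.xla.bgmv*)

device = xm.xla_device()
T, D, L, N = 16, 256, 16, 4  # tokens, hidden dim, LoRA rank, number of adapters

inputs = torch.randn(T, D, dtype=torch.bfloat16, device=device)
loras = torch.randn(N, 1, L, D, dtype=torch.bfloat16, device=device)
idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device=device)

ref = torch.ops.xla.bgmv(inputs, loras, idxs)  # (T, L)
fused = torch.ops.xla.bgmv_pre_transpose(inputs, loras.transpose(2, 3), idxs)

torch.testing.assert_close(ref.cpu(), fused.cpu(), atol=1e-2, rtol=1e-2)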

vllm/lora/layers.py

Lines changed: 5 additions & 2 deletions
@@ -1049,6 +1049,9 @@ def create_lora_weights(
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
+
+        self.lora_b_stacked = torch.transpose(self.lora_b_stacked, 2, 3)
+
         self.embeddings_tensors = torch.full(
             (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
             fill_value=float("-inf"),
@@ -1081,8 +1084,8 @@ def set_lora(
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                 lora_a.T, non_blocking=True)
         self.lora_b_stacked[index,
-                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
-                                lora_b.T, non_blocking=True)
+                            0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
+                                lora_b, non_blocking=True)
         if embeddings_tensor is not None:
             self.embeddings_tensors[
                 index,
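
The net effect of the two hunks above: the stacked LoRA-B buffer is allocated as before, then kept in transposed form, so set_lora can copy lora_b in its native (rank, out_features) orientation instead of materialising lora_b.T. A CPU-only sketch of the layout, with shape names assumed for illustration:

import torch

max_loras, rank, out_features = 4, 16, 1000

# Allocated with the previous (max_loras, 1, out_features, rank) shape...
lora_b_stacked = torch.zeros(max_loras, 1, out_features, rank)
# ...then kept transposed for the fused expand kernel.
lora_b_stacked = torch.transpose(lora_b_stacked, 2, 3)  # (max_loras, 1, rank, out_features)

# set_lora can now copy lora_b directly, without materialising lora_b.T.
lora_b = torch.randn(rank, out_features)
index = 0
lora_b_stacked[index, 0, :lora_b.shape[0], :lora_b.shape[1]].copy_(lora_b)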

vllm/lora/ops/xla_ops/lora_ops.py

Lines changed: 13 additions & 10 deletions
@@ -11,18 +11,23 @@ def bgmv_expand(inputs: torch.Tensor,
                 lora_b_weights: torch.Tensor,
                 output_tensor: torch.Tensor,
                 lora_indices_tensor: torch.Tensor,
-                add_inputs: bool = True):
-    outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
+                add_inputs: bool = True,
+                fused_transpose: bool = False):
+
+    if fused_transpose:
+        outputs = torch.ops.xla.bgmv_pre_transpose(inputs, lora_b_weights,
+                                                   lora_indices_tensor)
+    else:
+        outputs = torch.ops.xla.bgmv(inputs, lora_b_weights,
+                                     lora_indices_tensor)
 
     limit = output_tensor.shape[0]
     if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
         limit = 1
 
     if output_tensor.shape[1] > outputs.shape[1]:
-        outputs = F.pad(
-            outputs,
-            (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)
-        )
+        outputs = F.pad(outputs,
+                        (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
 
     if add_inputs:
         return output_tensor + outputs[:limit, :output_tensor.shape[1]]
@@ -49,10 +54,8 @@ def bgmv_expand_slice(inputs: torch.Tensor,
                       add_inputs: bool = True):
    outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
 
-    outputs = F.pad(
-        outputs,
-        (slice_offset, output_tensor.shape[1] - (slice_offset + slice_size), 0, 0)
-    )
+    outputs = F.pad(outputs, (slice_offset, output_tensor.shape[1] -
+                              (slice_offset + slice_size), 0, 0))
 
     if add_inputs:
         return output_tensor + outputs
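
A usage sketch for the new flag, assuming a TPU/torch_xla runtime with the custom ops registered; shapes are illustrative. With fused_transpose=True the caller must pass lora_b_weights already laid out as (num_loras, 1, rank, out_features):

import torch
import torch_xla.core.xla_model as xm

import vllm.lora.ops.xla_ops.pallas  # noqa: F401  (registers the XLA custom ops)
from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand

device = xm.xla_device()
T, rank, out_dim, N = 16, 16, 1024, 4

inputs = torch.randn(T, rank, dtype=torch.bfloat16, device=device)
# Pre-transposed stacked LoRA-B weights: (N, 1, rank, out_dim).
lora_b = torch.randn(N, 1, rank, out_dim, dtype=torch.bfloat16, device=device)
output = torch.zeros(T, out_dim, dtype=torch.bfloat16, device=device)
idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device=device)

y = bgmv_expand(inputs, lora_b, output, idxs,
                add_inputs=True, fused_transpose=True)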

vllm/lora/ops/xla_ops/pallas.py

Lines changed: 180 additions & 28 deletions
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import functools
+import math
+from typing import List
 
 import jax
 import jax.numpy as jnp
@@ -10,14 +12,16 @@
 from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard,
                                                   make_kernel_from_pallas)
 
-def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref,
-                 acc_ref, mask_ref):
+
+def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref,
+                 lora_ref, out_ref, acc_ref, mask_ref):
+
     @pl.when(pl.program_id(2) == 0)
     def _():
         acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32)
 
     t = pl.program_id(0)
-
+
     ones = jnp.ones((bL, ), dtype=jnp.float32)
 
     for i in range(max_num_loras):
@@ -42,16 +46,16 @@ def _():
         out_ref[...] = acc_ref[...].astype(out_ref.dtype)
 
 
-@functools.partial(jax.jit, static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"])
+@functools.partial(jax.jit,
+                   static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"])
 def _bgmv(
-    idxs: jax.Array,  # (T, ) int32
-    inputs: jax.Array,  # (T, D) model dtype
-    loras: jax.Array,  # (N, L, D) model dtype
-    *,
-    TOKEN_BLOCK: int,
-    LORA_BLOCK: int,
-    DIM_BLOCK: int
-) -> jax.Array:  # (T, L) model dtype
+        idxs: jax.Array,  # (T, ) int32
+        inputs: jax.Array,  # (T, D) model dtype
+        loras: jax.Array,  # (N, L, D) model dtype
+        *,
+        TOKEN_BLOCK: int,
+        LORA_BLOCK: int,
+        DIM_BLOCK: int) -> jax.Array:  # (T, L) model dtype
     T, D = inputs.shape
     N, L, _ = loras.shape
 
@@ -60,8 +64,7 @@ def _bgmv(
         out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype),
         grid_spec=pltpu.PrefetchScalarGridSpec(
             num_scalar_prefetch=1,
-            grid=(T // TOKEN_BLOCK, L // LORA_BLOCK,
-                  D // DIM_BLOCK),
+            grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, D // DIM_BLOCK),
             in_specs=[
                 pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK),
                              lambda i, j, k, block_idx: (i, k)),
@@ -103,22 +106,18 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
     jax_import_guard()
 
     TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128)
-    if is_expand: # Expand
+    if is_expand:  # Expand
         LORA_BLOCK = min(1024, next_multiple_of(L, 256))
         DIM_BLOCK = 256
-    else: # Shrink
+    else:  # Shrink
         LORA_BLOCK = 256
         DIM_BLOCK = min(1024, next_multiple_of(D, 256))
 
     kernel = make_kernel_from_pallas(
-        functools.partial(
-            _bgmv,
-            TOKEN_BLOCK=TOKEN_BLOCK,
-            LORA_BLOCK=LORA_BLOCK,
-            DIM_BLOCK=DIM_BLOCK
-        ),
-        bgmv_shape_function
-    )
+        functools.partial(_bgmv,
+                          TOKEN_BLOCK=TOKEN_BLOCK,
+                          LORA_BLOCK=LORA_BLOCK,
+                          DIM_BLOCK=DIM_BLOCK), bgmv_shape_function)
 
     # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU
     # register. This has to happen in pytorch, doing it in Jax will lead to NaNs
@@ -157,10 +156,163 @@ def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
     return torch.empty((T, L), device=inputs.device)
 
 
+# This kernel is similar to the one above but it assumes that the LoRA adapters
+# have been pre-transposed. This lets us skip the data copies involved in
+# transposing.
+# We only need this for the expand op since the LoRA dimensions in the shrink op
+# are small enough that the TPU can gather them without a data copy.
+def _bgmv_pre_transpose_kernel(bT: int, bL: int, max_num_loras: int, idx_ref,
+                               inp_ref, lora_ref, out_ref, acc_ref, mask_ref):
+
+    @pl.when(pl.program_id(2) == 0)
+    def _():
+        acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32)
+
+    t = pl.program_id(0)
+
+    ones = jnp.ones((bL, ), dtype=jnp.float32)
+
+    for i in range(max_num_loras):
+        mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32)
+        valid = False
+        for j in range(bT):
+            valid |= idx_ref[j + bT * t] == i
+
+            @pl.when(idx_ref[j + bT * t] == i)
+            def _():
+                mask_ref.at[j, :].set(ones)
+
+        @pl.when(valid)
+        def _():
+            acc_ref[...] += jax.lax.dot(
+                inp_ref[...],
+                lora_ref[i, ...],
+                preferred_element_type=jnp.float32) * mask_ref[...]
+
+    @pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
+    def _():
+        out_ref[...] = acc_ref[...].astype(out_ref.dtype)
+
+
+@functools.partial(jax.jit,
+                   static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"])
+def _bgmv_pre_transpose(
+        idxs: jax.Array,  # (T, ) int32
+        inputs: jax.Array,  # (T, D) model dtype
+        loras: jax.Array,  # (N, D, L) model dtype
+        *,
+        TOKEN_BLOCK: int,
+        LORA_BLOCK: int,
+        DIM_BLOCK: int) -> jax.Array:  # (T, L) model dtype
+    T, D = inputs.shape
+    N, _, L = loras.shape
+
+    return pl.pallas_call(
+        kernel=functools.partial(_bgmv_pre_transpose_kernel, TOKEN_BLOCK,
+                                 LORA_BLOCK, N),
+        out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=1,
+            grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, D // DIM_BLOCK),
+            in_specs=[
+                pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK),
+                             lambda i, j, k, block_idx: (i, k)),
+                pl.BlockSpec((N, DIM_BLOCK, LORA_BLOCK),
+                             lambda i, j, k, block_idx: (0, k, j)),
+            ],
+            out_specs=pl.BlockSpec((TOKEN_BLOCK, LORA_BLOCK),
+                                   lambda i, j, k, block_idx: (i, j)),
+            scratch_shapes=[
+                pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32),
+                pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32)
+            ]),
+        compiler_params=pltpu.TPUCompilerParams(
+            dimension_semantics=("parallel", "parallel", "arbitrary")),
+        name="bgmv_pre_transpose")(idxs, inputs, loras)
+
+
+def bgmv_pre_transpose_shape_function(idxs, inputs, loras):
+    T, _ = inputs.shape
+    _, _, L = loras.shape
+
+    return [((T, L), inputs.dtype)]
+
+
+XLA_LIB.define(
+    "bgmv_pre_transpose(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", )
+
+
+@impl(XLA_LIB, "bgmv_pre_transpose", "XLA")
+def bgmv_pre_transpose_xla(inputs: torch.Tensor, loras: torch.Tensor,
+                           idxs: torch.IntTensor):
+    inputs = inputs.to(dtype=loras.dtype)
+
+    if len(loras.shape) == 4:
+        loras = loras.squeeze(axis=1)
+
+    T, _ = inputs.shape
+    _, D, L = loras.shape
+
+    jax_import_guard()
+
+    TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128)
+    LORA_BLOCK = min(1024, next_multiple_of(L, 256))
+    DIM_BLOCK = 256
+
+    kernel = make_kernel_from_pallas(
+        functools.partial(_bgmv_pre_transpose,
+                          TOKEN_BLOCK=TOKEN_BLOCK,
+                          LORA_BLOCK=LORA_BLOCK,
+                          DIM_BLOCK=DIM_BLOCK),
+        bgmv_pre_transpose_shape_function)
+
+    # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU
+    # register. This has to happen in pytorch, doing it in Jax will lead to NaNs
+    pad_L = 0
+    if LORA_BLOCK > L or L % LORA_BLOCK != 0:
+        pad_L = next_multiple_of(L, LORA_BLOCK) - L
+
+    pad_D = 0
+    if DIM_BLOCK > D or D % DIM_BLOCK != 0:
+        pad_D = next_multiple_of(D, DIM_BLOCK) - D
+
+    pad_T = 0
+    if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0:
+        pad_T = next_multiple_of(T, TOKEN_BLOCK) - T
+
+    if pad_D != 0 or pad_L != 0:
+        loras = torch.nn.functional.pad(loras, (0, pad_L, 0, pad_D, 0, 0))
+    if pad_D != 0 or pad_T != 0:
+        inputs = torch.nn.functional.pad(inputs, (0, pad_D, 0, pad_T))
+    if pad_T != T:
+        idxs = torch.nn.functional.pad(idxs, ((0, pad_T)))
+
+    return kernel(idxs, inputs, loras)[:T, :L]
+
+
+@impl(XLA_LIB, "bgmv_pre_transpose", "CompositeExplicitAutograd")
+def bgmv_pre_transpose_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
+                               idxs: torch.IntTensor):
+    T, _ = inputs.shape
+
+    if len(loras.shape) == 4:
+        loras = loras.squeeze(axis=1)
+
+    _, _, L = loras.shape
+
+    return torch.empty((T, L), device=inputs.device)
+
+
+def largest_divisor(n: int, divs: List[int]) -> int:
+    for div in sorted(divs, reverse=True):
+        if n % div == 0:
+            return div
+    return max(divs)
+
+
 def next_multiple_of(n: int, mult: int) -> int:
-    if n % mult == 0:
-        return n
-    return (n // mult + 1) * mult
+    return math.ceil(n / mult) * mult
 
 
 def get_bounded_value(_min: int, val: int, _max: int) -> int:
-    return min(max(_min, val), _max)
+    return min(max(_min, val), _max)
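
The block-size arithmetic above is plain Python, so it can be checked anywhere; the snippet below restates next_multiple_of and get_bounded_value and shows how they combine into the TOKEN_BLOCK choice used in bgmv_pre_transpose_xla (values picked for illustration):

import math


def next_multiple_of(n: int, mult: int) -> int:
    # Round n up to the next multiple of mult (returns n if already aligned).
    return math.ceil(n / mult) * mult


def get_bounded_value(_min: int, val: int, _max: int) -> int:
    # Clamp val into [_min, _max].
    return min(max(_min, val), _max)


assert next_multiple_of(300, 256) == 512
assert get_bounded_value(16, next_multiple_of(3, 16), 128) == 16     # tiny batch
assert get_bounded_value(16, next_multiple_of(200, 16), 128) == 128  # large batch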

vllm/lora/punica_wrapper/punica_tpu.py

Lines changed: 5 additions & 3 deletions
@@ -277,11 +277,12 @@ def add_lora_logits(self,
         y_org = y
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
-        r = lora_b_stacked.size(-1)
+
+        rank = lora_b_stacked.size(-1)
         if buffer is None:
             # We set the buffer to be float32 by default, consistent with the
             # triton op
-            buffer = torch.zeros((x.size(0), r),
+            buffer = torch.zeros((x.size(0), rank),
                                  dtype=torch.float32,
                                  device=x.device)
@@ -291,7 +292,8 @@ def add_lora_logits(self,
                         lora_b_stacked,
                         y,
                         self.sampler_indices,
-                        add_inputs=True)
+                        add_inputs=True,
+                        fused_transpose=True)
        return y.view_as(y_org)
 
    def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
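
For intuition, a CPU-only reference of what the fused logits path computes for a single active adapter, with lora_b kept in the pre-transposed (rank, vocab) orientation; shapes are illustrative and the kernel's scaling and per-token index gathering are omitted:

import torch

T, hidden, rank, vocab = 8, 64, 16, 1000
x = torch.randn(T, hidden)            # flattened hidden states
y = torch.randn(T, vocab)             # base logits
lora_a = torch.randn(hidden, rank)
lora_b = torch.randn(rank, vocab)     # pre-transposed layout, copied without .T

buffer = x @ lora_a                   # shrink step: (T, rank); kept in float32 on the real path
y_out = y + buffer @ lora_b           # expand step with add_inputs=True: (T, vocab)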
