
Commit 643622b

Authored by: Akshat-Tripathi, yaochengji, xihajun, jdefreitas02, Jorge de Freitas
[Hardware][TPU][V1] Multi-LoRA Optimisations for the V1 TPU backend (#15655)
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: xihajun <junfan@krai.ai>
Signed-off-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Signed-off-by: Jorge de Freitas <jorge@krai.ai>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: xihajun <junfan@krai.ai>
Co-authored-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Co-authored-by: Jorge de Freitas <jorge@krai.ai>
1 parent a09c7ca commit 643622b

File tree: 9 files changed, +325 -334 lines changed


.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh

Lines changed: 2 additions & 4 deletions
@@ -122,10 +122,8 @@ run_and_track_test 11 "test_struct_output_generate.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py"
 run_and_track_test 12 "test_moe_pallas.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
-
-# Disable the TPU LoRA tests until the feature is activated
-# run_and_track_test 13 "test_lora (directory)" \
-#   "python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/"
+run_and_track_test 13 "test_lora.py" \
+  "VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
 
 # After all tests have been attempted, exit with the overall status.
 if [ "$overall_script_exit_code" -ne 0 ]; then

tests/tpu/lora/test_pallas_kernels.py

Lines changed: 0 additions & 73 deletions
This file was deleted.

vllm/lora/models.py

Lines changed: 2 additions & 2 deletions
(Both hunks only strip trailing whitespace.)

@@ -200,7 +200,7 @@ def from_local_checkpoint(
             weights_mapper: Optional[WeightsMapper] = None,
             tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
         """Create a LoRAModel from a local checkpoint.
-
+
         Args:
             lora_dir: The local path that has lora data.
             expected_lora_modules: Name of modules that are expected to be
@@ -620,7 +620,7 @@ def _match_target_modules(self, module_name: str):
     def _filter_unsupported_mm_module(self, module_name: str) -> bool:
         """
         Regarding multimodal models, vLLM currently only supports adding LoRA to
-        language model. LoRA for other modules, such as the vision tower, will 
+        language model. LoRA for other modules, such as the vision tower, will
         be filtered out.
         """
         if self.supports_mm:

vllm/lora/ops/xla_ops/lora_ops.py

Lines changed: 89 additions & 51 deletions
@@ -1,63 +1,99 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import jax
+import jax.numpy as jnp
 import torch
+import torch.nn.functional as F
+import torch_xla.core.xla_builder as xb
+from torch.library import impl
+from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
 
-# Required to register the custom ops
-import vllm.lora.ops.xla_ops.pallas  # noqa # pylint: disable=unused-import
 
+@jax.jit
+def bgmv_jax(inputs, loras, idxs):
+    return jnp.einsum(
+        "td,tX,Xld->tl",
+        inputs,
+        jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
+        loras,
+    )
 
-def bgmv_expand(inputs: torch.Tensor,
-                lora_b_weights: torch.Tensor,
-                output_tensor: torch.Tensor,
-                lora_indices_tensor: torch.Tensor,
-                add_inputs: bool = True):
+
+XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
+
+
+@impl(XLA_LIB, "bgmv", "XLA")
+def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
+    if len(loras.shape) == 4:
+        loras = loras.squeeze(axis=1)
+
+    jax_import_guard()
+    return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
+
+
+@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
+def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
+                 idxs: torch.IntTensor):
+    T, _ = inputs.shape
+    if len(loras.shape) == 4:
+        loras = loras.squeeze(axis=1)
+    _, L, _ = loras.shape
+
+    return torch.empty((T, L), device=inputs.device)
+
+
+def bgmv_expand(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    add_inputs: bool = True,
+):
     """
     Args:
         inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
-
-        lora_b_weights (torch.Tensor): LoRA weights of shape 
+
+        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
-
-        output_tensor (torch.Tensor): output tensor of shape 
+
+        output_tensor (torch.Tensor): output tensor of shape
            [num_tokens, hidden_size * num_slices].
-
-        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
+
+        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
-        add_inputs (bool): Whether or not to add the input tensor to the output 
+        add_inputs (bool): Whether or not to add the input tensor to the output
            tensor.
    """
 
     outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
-    n_tokens = outputs.size(0)
 
     limit = output_tensor.shape[0]
     if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
         limit = 1
 
-    outputs = torch.cat(
-        (outputs,
-         torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]),
-                     device=outputs.device)),
-        dim=1)
+    if output_tensor.shape[1] > outputs.shape[1]:
+        outputs = F.pad(outputs,
+                        (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
 
     if add_inputs:
-        return output_tensor + outputs[:limit, :]
+        return output_tensor + outputs[:limit, :output_tensor.shape[1]]
     else:
-        return outputs[:limit, :]
+        return outputs[:limit, :output_tensor.shape[1]]
 
 
-def bgmv_shrink(inputs: torch.Tensor,
-                lora_b_weights: torch.Tensor,
-                output_tensor: torch.Tensor,
-                lora_indices_tensor: torch.Tensor,
-                scaling: float = 1.0):
+def bgmv_shrink(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    scaling: float = 1.0,
+):
     """
     Args:
         inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
-        lora_b_weights (torch.Tensor): LoRA weights of shape 
+        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
-        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
+        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        scaling (float, optional): Scalar multiplier applied to the output.
    """
@@ -66,39 +102,41 @@ def bgmv_shrink(inputs: torch.Tensor,
                                lora_indices_tensor)
 
 
-def bgmv_expand_slice(inputs: torch.Tensor,
-                      lora_b_weights: torch.Tensor,
-                      output_tensor: torch.Tensor,
-                      lora_indices_tensor: torch.Tensor,
-                      slice_offset: int,
-                      slice_size: int,
-                      add_inputs: bool = True):
+def bgmv_expand_slice(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    slice_offset: int,
+    slice_size: int,
+    add_inputs: bool = True,
+):
     """
     Args:
         inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
-
-        lora_b_weights (torch.Tensor): LoRA weights of shape 
+
+        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
-
-        output_tensor (torch.Tensor): output tensor of shape 
+
+        output_tensor (torch.Tensor): output tensor of shape
            [num_tokens, hidden_size * num_slices].
-
-        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
+
+        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
-        add_inputs (bool): Whether or not to add the input tensor to the output 
+        add_inputs (bool): Whether or not to add the input tensor to the output
            tensor.
     """
     outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
-    n_tokens = outputs.size(0)
 
-    outputs = torch.cat((
-        torch.zeros((n_tokens, slice_offset), device=outputs.device),
+    outputs = F.pad(
         outputs,
-        torch.zeros(
-            (n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)),
-            device=outputs.device),
-    ),
-                        dim=1)
+        (
+            slice_offset,
+            output_tensor.shape[1] - (slice_offset + slice_size),
+            0,
+            0,
+        ),
+    )
 
     if add_inputs:
         return output_tensor + outputs
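
For context, the sketch below is not part of the commit; it only illustrates the contraction that bgmv_jax performs. jax.nn.one_hot(idxs, num_loras) turns each token's LoRA index into a selector row, so the einsum "td,tX,Xld->tl" is equivalent to gathering loras[idxs[t]] for each token and applying it as a matrix-vector product. All array names and sizes below are made up for illustration.

# Hedged sketch (not from the commit): check that the one-hot einsum used by
# bgmv_jax matches a naive per-token gather of the LoRA matrix.
import jax
import jax.numpy as jnp


def bgmv_reference(inputs, loras, idxs):
    # Naive version: out[t] = loras[idxs[t]] @ inputs[t]
    return jnp.stack(
        [loras[idxs[t]] @ inputs[t] for t in range(inputs.shape[0])])


def bgmv_one_hot(inputs, loras, idxs):
    # Same contraction written as a single einsum, as in bgmv_jax above.
    one_hot = jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype)
    return jnp.einsum("td,tX,Xld->tl", inputs, one_hot, loras)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
inputs = jax.random.normal(k1, (4, 16))          # [num_tokens, hidden_size]
loras = jax.random.normal(k2, (3, 8, 16))        # [num_loras, lora_rank, hidden_size]
idxs = jnp.array([0, 2, 1, 0], dtype=jnp.int32)  # LoRA index per token

assert jnp.allclose(bgmv_reference(inputs, loras, idxs),
                    bgmv_one_hot(inputs, loras, idxs), atol=1e-5)

The padding change follows the same spirit: the torch.cat-based zero padding is replaced by F.pad, whose tuple (slice_offset, output_tensor.shape[1] - (slice_offset + slice_size), 0, 0) pads only the last dimension so the LoRA output lands in its column slice of output_tensor.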
