11# SPDX-License-Identifier: Apache-2.0
22
3+ import jax
4+ import jax .numpy as jnp
35import torch
6+ import torch .nn .functional as F
7+ import torch_xla .core .xla_builder as xb
8+ from torch .library import impl
9+ from torch_xla .experimental .custom_kernel import XLA_LIB , jax_import_guard
410
5- # Required to register the custom ops
6- import vllm .lora .ops .xla_ops .pallas # noqa # pylint: disable=unused-import
711
@jax.jit
def bgmv_jax(inputs, loras, idxs):
    """Multiply each row of ``inputs`` by the ``loras`` matrix picked by ``idxs``.

    Args:
        inputs: 2-D array indexed as ``[t, d]``.
        loras: 3-D array indexed as ``[X, l, d]``; its leading axis is the
            one selected by ``idxs``.
        idxs: 1-D integer array with one entry per row of ``inputs``, giving
            the index along the leading axis of ``loras`` to use for that row.

    Returns:
        Array indexed as ``[t, l]``.
    """
    # Encode the per-row choice as a one-hot matrix of shape [t, X] so the
    # selection and the matrix-vector product become a single contraction.
    selector = jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype)
    return jnp.einsum("td,tX,Xld->tl", inputs, selector, loras)
820
9- def bgmv_expand (inputs : torch .Tensor ,
10- lora_b_weights : torch .Tensor ,
11- output_tensor : torch .Tensor ,
12- lora_indices_tensor : torch .Tensor ,
13- add_inputs : bool = True ):
21+
22+ XLA_LIB .define ("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor" )
23+
24+
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
    """Implementation of the ``bgmv`` op registered for the "XLA" dispatch key.

    Args:
        inputs: Tensor forwarded unchanged to ``bgmv_jax``.
        loras: Weight tensor; when it is 4-D, axis 1 is squeezed out before
            the call so that ``bgmv_jax`` receives a 3-D tensor.
        idxs: Index tensor forwarded unchanged to ``bgmv_jax``.

    Returns:
        The result of running ``bgmv_jax`` through ``xb.call_jax``.
    """
    # Drop the extra axis 1 of a 4-D weight tensor; other ranks pass through.
    weights = loras.squeeze(dim=1) if loras.ndim == 4 else loras

    jax_import_guard()
    return xb.call_jax(bgmv_jax, (inputs, weights, idxs))
32+
33+
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
                 idxs: torch.IntTensor):
    """Shape-only fallback for the ``bgmv`` op on non-XLA dispatch.

    No computation is performed: the function only allocates an
    uninitialized tensor with the shape the real op would produce.

    Args:
        inputs: 2-D tensor; its first dimension gives the number of output
            rows.
        loras: 3-D tensor, or 4-D tensor whose axis 1 is squeezed out first;
            its middle dimension gives the number of output columns.
        idxs: Unused here; present to match the op schema.

    Returns:
        Uninitialized tensor of shape ``[inputs.shape[0], loras.shape[1]]``
        on ``inputs.device``.
    """
    T, _ = inputs.shape
    if len(loras.shape) == 4:
        loras = loras.squeeze(axis=1)
    _, L, _ = loras.shape

    # Without an explicit dtype torch.empty falls back to the default dtype,
    # so the placeholder would report float32 even for bf16/fp16 operands.
    # NOTE(review): assumes the XLA path's einsum result follows the usual
    # promotion of the operand dtypes -- confirm if mixed dtypes are used.
    out_dtype = torch.promote_types(inputs.dtype, loras.dtype)
    return torch.empty((T, L), dtype=out_dtype, device=inputs.device)
43+
44+
def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
):
    """
    Run the ``bgmv`` op and fit its result to the width of ``output_tensor``.

    Args:
        inputs (torch.Tensor): Input tensor passed to ``torch.ops.xla.bgmv``;
            2-D, one row per token.

        lora_b_weights (torch.Tensor): LoRA weights passed to
            ``torch.ops.xla.bgmv``. NOTE(review): the op contracts the last
            axis of ``inputs`` with the last axis of the weights, so this
            looks like [num_loras, out_features, in_features] -- confirm
            against the caller.

        output_tensor (torch.Tensor): Tensor whose first two dimensions
            define the row limit and the column count of the result.

        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.

        add_inputs (bool): When True, the fitted result is added to
            ``output_tensor``; when False, the fitted result is returned
            on its own.

    Returns:
        ``output_tensor + result`` if ``add_inputs`` else ``result``, where
        ``result`` is the op output zero-padded on the right (or cut) to
        ``output_tensor.shape[1]`` columns.
    """

    projected = torch.ops.xla.bgmv(inputs, lora_b_weights,
                                   lora_indices_tensor)

    out_rows = output_tensor.shape[0]
    out_cols = output_tensor.shape[1]

    # A single-row op result is kept as one row even when output_tensor has
    # more rows; otherwise rows are limited to output_tensor's row count.
    single_row = projected.shape[0] == 1 and out_rows != 1
    row_limit = 1 if single_row else out_rows

    # Zero-pad on the right when the op produced fewer columns than needed.
    missing_cols = out_cols - projected.shape[1]
    if missing_cols > 0:
        projected = F.pad(projected, (0, missing_cols, 0, 0))

    update = projected[:row_limit, :out_cols]
    return output_tensor + update if add_inputs else update
4782
4883
49- def bgmv_shrink (inputs : torch .Tensor ,
50- lora_b_weights : torch .Tensor ,
51- output_tensor : torch .Tensor ,
52- lora_indices_tensor : torch .Tensor ,
53- scaling : float = 1.0 ):
84+ def bgmv_shrink (
85+ inputs : torch .Tensor ,
86+ lora_b_weights : torch .Tensor ,
87+ lora_indices_tensor : torch .Tensor ,
88+ scaling : float = 1.0 ,
89+ ):
5490 """
5591 Args:
5692 inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
57- lora_b_weights (torch.Tensor): LoRA weights of shape
93+ lora_b_weights (torch.Tensor): LoRA weights of shape
5894 [num_loras, lora_rank, hidden_size].
5995 output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
60- lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
96+ lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
6197 indicating which LoRA matrix to use for each token.
6298 scaling (float, optional): Scalar multiplier applied to the output.
6399 """
@@ -66,39 +102,41 @@ def bgmv_shrink(inputs: torch.Tensor,
66102 lora_indices_tensor )
67103
68104
69- def bgmv_expand_slice (inputs : torch .Tensor ,
70- lora_b_weights : torch .Tensor ,
71- output_tensor : torch .Tensor ,
72- lora_indices_tensor : torch .Tensor ,
73- slice_offset : int ,
74- slice_size : int ,
75- add_inputs : bool = True ):
105+ def bgmv_expand_slice (
106+ inputs : torch .Tensor ,
107+ lora_b_weights : torch .Tensor ,
108+ output_tensor : torch .Tensor ,
109+ lora_indices_tensor : torch .Tensor ,
110+ slice_offset : int ,
111+ slice_size : int ,
112+ add_inputs : bool = True ,
113+ ):
76114 """
77115 Args:
78116 inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
79-
80- lora_b_weights (torch.Tensor): LoRA weights of shape
117+
118+ lora_b_weights (torch.Tensor): LoRA weights of shape
81119 [num_loras, lora_rank, hidden_size].
82-
83- output_tensor (torch.Tensor): output tensor of shape
120+
121+ output_tensor (torch.Tensor): output tensor of shape
84122 [num_tokens, hidden_size * num_slices].
85-
86- lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
123+
124+ lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
87125 indicating which LoRA matrix to use for each token.
88- add_inputs (bool): Whether or not to add the input tensor to the output
126+ add_inputs (bool): Whether or not to add the input tensor to the output
89127 tensor.
90128 """
91129 outputs = torch .ops .xla .bgmv (inputs , lora_b_weights , lora_indices_tensor )
92- n_tokens = outputs .size (0 )
93130
94- outputs = torch .cat ((
95- torch .zeros ((n_tokens , slice_offset ), device = outputs .device ),
131+ outputs = F .pad (
96132 outputs ,
97- torch .zeros (
98- (n_tokens , output_tensor .shape [1 ] - (slice_offset + slice_size )),
99- device = outputs .device ),
100- ),
101- dim = 1 )
133+ (
134+ slice_offset ,
135+ output_tensor .shape [1 ] - (slice_offset + slice_size ),
136+ 0 ,
137+ 0 ,
138+ ),
139+ )
102140
103141 if add_inputs :
104142 return output_tensor + outputs
0 commit comments