Commit ec261b0
[XPU] IPEX-optimized Punica Wrapper on XPU (#21703)
Signed-off-by: chzhang <chaojun.zhang@intel.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 04fe61a commit ec261b0


4 files changed: +321 −1 lines changed


vllm/lora/ops/ipex_ops/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.lora.ops.ipex_ops.lora_ops import (bgmv_expand, bgmv_expand_slice,
                                             bgmv_shrink)

__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]

vllm/lora/ops/ipex_ops/lora_ops.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)

try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    raise e


def bgmv_shrink(inputs: torch.Tensor,
                lora_a_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                scaling: float = 1.0) -> None:

    ipex.llm.functional.bgmv_shrink(inputs, lora_a_weights, output_tensor,
                                    lora_indices_tensor, scaling)


def bgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                add_inputs: bool = True) -> None:
    ipex.llm.functional.bgmv_expand(inputs, lora_b_weights, output_tensor,
                                    lora_indices_tensor, add_inputs)


def bgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = True) -> None:
    ipex.llm.functional.bgmv_expand_slice(inputs, lora_b_weights,
                                          output_tensor, lora_indices_tensor,
                                          slice_offset, slice_size, add_inputs)
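
Note: the three functions above are thin pass-throughs to the IPEX kernels. The sketch below is a pure-PyTorch reference of the semantics they are assumed to implement (each token gathers the adapter selected by lora_indices_tensor); the weight layouts, the -1 "no adapter" convention, and the Python loop are illustrative assumptions, not the IPEX implementation.

# Illustrative reference only -- NOT the IPEX kernel. Assumed layouts:
#   lora_a: (num_loras, rank, hidden)    lora_b: (num_loras, hidden_out, rank)
#   lora_indices[t] selects the adapter for token t (assumed -1 = no LoRA).
import torch


def bgmv_shrink_ref(inputs, lora_a, output, lora_indices, scaling=1.0):
    for t in range(inputs.size(0)):
        idx = int(lora_indices[t])
        if idx < 0:
            continue  # token has no active adapter
        output[t] = scaling * (lora_a[idx] @ inputs[t].to(lora_a.dtype))


def bgmv_expand_ref(inputs, lora_b, output, lora_indices, add_inputs=True):
    for t in range(inputs.size(0)):
        idx = int(lora_indices[t])
        if idx < 0:
            continue
        y = (lora_b[idx] @ inputs[t].to(lora_b.dtype)).to(output.dtype)
        output[t] = output[t] + y if add_inputs else y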
vllm/lora/punica_wrapper/punica_xpu.py

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""

from typing import Optional, Union, final

import torch

from vllm.lora.layers import LoRAMapping
from vllm.lora.ops.ipex_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink

from .punica_base import PunicaWrapperBase


@final
class PunicaWrapperXPU(PunicaWrapperBase):
    """
    PunicaWrapperXPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the punica ipex kernel.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)
        torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
        torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
        torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)

    def update_metadata(self, mapping: LoRAMapping,
                        lora_index_to_id: list[Optional[int]], max_loras: int,
                        vocab_size: int, extra_vocab_size: int, **kwargs):

        self.is_prefill = mapping.is_prefill
        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                   vocab_size, extra_vocab_size)

    def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
        return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))

    def _apply_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x), scale)

    def _apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        token_lora_indices = self._get_token_lora_indices(x)
        bgmv_expand_slice(x, w_t_all, y, token_lora_indices, y_offset,
                          y_slice_size, add_inputs)

    def add_shrink(self, y: torch.Tensor, x: torch.Tensor,
                   lora_a_stacked: tuple[torch.Tensor, ...], scale: float,
                   **kwargs):
        """
        Performs GEMM for multiple slices of lora_a.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (torch.Tensor): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """

        x = x.view(-1, x.shape[-1])
        for slice_idx in range(len(lora_a_stacked)):
            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
                               scale)

    def add_expand(self,
                   y: torch.Tensor,
                   x: torch.Tensor,
                   lora_b_stacked: tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                   output_slices: tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensors
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
                bias's weight
            output_slices (tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        if lora_bias_stacked is not None:
            token_lora_indices = self._get_token_lora_indices(y)
            self._apply_bias(token_lora_indices, y, output_slices,
                             lora_bias_stacked)

        assert x.ndim == 3
        assert x.size(0) == len(output_slices)

        # TODO fuse these kernels
        for slice_idx in range(len(lora_b_stacked)):
            self._apply_expand(
                y,
                x[slice_idx],
                lora_b_stacked[slice_idx],
                offset_start,
                output_slices[slice_idx],
                add_inputs=add_inputs,
            )
            offset_start += output_slices[slice_idx]
        y.view_as(y_org)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Default to True.
        """
        token_lora_indices = self._get_token_lora_indices(x)
        bgmv_expand(x, lora_b_stacked, y, token_lora_indices, add_inputs)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: tuple[torch.Tensor, ...],
                        lora_b_stacked: tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: tuple[int, ...],
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                ).squeeze(0) + lora_bias_stacked[i]

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[torch.Tensor]): Defaults to None.
        """

        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            token_lora_indices = self._get_token_lora_indices(y)
            y = self._apply_bias(token_lora_indices, y, output_slices,
                                 lora_bias_stacked)

        if buffer is None:
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros(  # type: ignore
                (len(output_slices), x.size(0), r),
                dtype=torch.float32,
                device=x.device,
            )
        self.add_shrink(
            buffer,  # type: ignore
            x,
            lora_a_stacked,
            scale,
            **kwargs)
        self.add_expand(
            y,
            buffer,  # type: ignore
            lora_b_stacked,
            None,
            output_slices,
            add_inputs=True,
            **kwargs)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Default to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer,
                    lora_b_stacked,
                    y,
                    self.sampler_indices,
                    add_inputs=True)
        return y.view_as(y_org)
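
A side note on the torch._dynamo.mark_dynamic calls in __init__ above: marking dim 0 of the index tensors as dynamic is intended to let torch.compile trace the wrapper's indexing once with a symbolic token count instead of re-specializing for every batch size. The snippet below is a minimal, hypothetical illustration of that effect and is not part of this diff.

# Hypothetical sketch: mark_dynamic makes dim 0 symbolic for torch.compile,
# so calls with a different dim-0 size should reuse the compiled graph
# rather than trigger a shape-specialized recompile.
import torch

@torch.compile
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

x = torch.zeros(128, dtype=torch.int32)
torch._dynamo.mark_dynamic(x, 0)              # dim 0 (token count) is dynamic
double(x)                                     # first compile: symbolic dim 0
double(torch.zeros(4096, dtype=torch.int32))  # intended to reuse the graph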

vllm/platforms/xpu.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def get_device_name(cls, device_id: int = 0) -> str:
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
-        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+        return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
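
The platform hook above returns a dotted path rather than a class object; vLLM resolves that string to the wrapper class at runtime. The exact helper that consumes it is not part of this diff, but the resolution is equivalent to a plain importlib lookup, sketched below as a hypothetical stand-in.

# Hypothetical stand-in for the qualname resolution vLLM performs; the real
# helper is not shown in this diff.
import importlib

def resolve_punica_wrapper(qualname: str):
    # e.g. "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
    module_name, _, class_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# e.g. cls = resolve_punica_wrapper(current_platform.get_punica_wrapper())
#      wrapper = cls(max_num_batched_tokens, max_batches, device="xpu")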
