# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
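
# The bgmv_* ops used below are batched gather matrix-vector multiplies
# (BGMV, in the terminology of the Punica paper cited above): each token row
# of the input is multiplied against the LoRA weight slice selected for it by
# an index tensor, so a single launch serves many adapters at once.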

from typing import Optional, Union, final

import torch

from vllm.lora.layers import LoRAMapping
from vllm.lora.ops.ipex_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink

from .punica_base import PunicaWrapperBase


@final
class PunicaWrapperXPU(PunicaWrapperBase):
    """
    PunicaWrapperXPU manages and provides metadata for the punica kernels.
    Its main function is to maintain state information for Multi-LoRA and
    to expose the interface to the punica IPEX kernels.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)
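        # Mark the batch-dependent dimension of these index tensors as
        # dynamic so torch.compile does not re-specialize (and hence
        # recompile) for every new batch size.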
        torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
        torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
        torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)

    def update_metadata(self, mapping: LoRAMapping,
                        lora_index_to_id: list[Optional[int]], max_loras: int,
                        vocab_size: int, extra_vocab_size: int, **kwargs):

        self.is_prefill = mapping.is_prefill
        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                   vocab_size, extra_vocab_size)

    def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
        return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))

    def _apply_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x), scale)

    def _apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        token_lora_indices = self._get_token_lora_indices(x)
        bgmv_expand_slice(x, w_t_all, y, token_lora_indices, y_offset,
                          y_slice_size, add_inputs)

    def add_shrink(self, y: torch.Tensor, x: torch.Tensor,
                   lora_a_stacked: tuple[torch.Tensor,
                                         ...], scale: float, **kwargs):
        """
        Performs GEMM for multiple slices of lora_a.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (torch.Tensor): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """

        x = x.view(-1, x.shape[-1])
        for slice_idx in range(len(lora_a_stacked)):
            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
                               scale)
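
    # A minimal shape sketch for add_shrink (hypothetical sizes): with two
    # slices, 8 tokens, hidden size 4096 and max LoRA rank 16,
    #   x: (8, 4096)
    #   lora_a_stacked[i]: (max_loras, 1, 16, 4096)
    #   y: (2, 8, 16), i.e. one rank-r projection per slice.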

    def add_expand(self,
                   y: torch.Tensor,
                   x: torch.Tensor,
                   lora_b_stacked: tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                   output_slices: tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        Semantics:
            offset = offset_start
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensors
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weights
            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
                bias's weight
            output_slices (tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        if lora_bias_stacked is not None:
            token_lora_indices = self._get_token_lora_indices(y)
            self._apply_bias(token_lora_indices, y, output_slices,
                             lora_bias_stacked)

        assert x.ndim == 3
        assert x.size(0) == len(output_slices)

        # TODO: fuse these kernels into a single launch
        for slice_idx in range(len(lora_b_stacked)):
            self._apply_expand(
                y,
                x[slice_idx],
                lora_b_stacked[slice_idx],
                offset_start,
                output_slices[slice_idx],
                add_inputs=add_inputs,
            )
            offset_start += output_slices[slice_idx]
        y = y.view_as(y_org)
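
    # Sketch of the offset bookkeeping above (hypothetical fused QKV case):
    # with output_slices = (q, k, v) and offset_start = 0, slice 0 lands in
    # y[:, 0:q], slice 1 in y[:, q:q+k], and slice 2 in y[:, q+k:q+k+v].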

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Defaults to True.
        """
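        # Per-token gather: each row of x is multiplied by the expand weight
        # selected for it by token_lora_indices; with add_inputs=True the
        # result is accumulated into y rather than overwriting it.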
        token_lora_indices = self._get_token_lora_indices(x)
        bgmv_expand(x, lora_b_stacked, y, token_lora_indices, add_inputs)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: tuple[torch.Tensor, ...],
                        lora_b_stacked: tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: tuple[int, ...],
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                ).squeeze(0) + lora_bias_stacked[i]

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[torch.Tensor]): Defaults to None.
        """

        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            token_lora_indices = self._get_token_lora_indices(y)
            y = self._apply_bias(token_lora_indices, y, output_slices,
                                 lora_bias_stacked)

        if buffer is None:
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros(  # type: ignore
                (len(output_slices), x.size(0), r),
                dtype=torch.float32,
                device=x.device,
            )
        self.add_shrink(
            buffer,  # type: ignore
            x,
            lora_a_stacked,
            scale,
            **kwargs)
        self.add_expand(
            y,
            buffer,  # type: ignore
            lora_b_stacked,
            None,
            output_slices,
            add_inputs=True,
            **kwargs)
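
    # Hedged sketch of the shrink/expand pipeline above (hypothetical shapes,
    # not executed here): for a fused QKV projection with
    # output_slices = (q, k, v), each slice i effectively computes
    #   buffer[i] = (x @ A_i^T) * scale                   # shrink to rank r
    #   y[:, off_i:off_i+slice_i] += buffer[i] @ B_i^T    # expand per slice
    # where A_i / B_i are the per-token weights gathered from
    # lora_a_stacked[i] / lora_b_stacked[i] by token_lora_indices.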

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale: float,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Defaults to None.
        """
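        # Unlike the token-wise methods above, logits LoRA is indexed by
        # self.sampler_indices, which maps each row to be sampled (one per
        # sequence) to its LoRA id.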
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer,
                    lora_b_stacked,
                    y,
                    self.sampler_indices,
                    add_inputs=True)
        y = y.view_as(y_org)