|
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 
 from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
+from vllm.lora.punica_wrapper.utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid circular import
+    from vllm.lora.layers import LoRAMapping
+    from vllm.lora.models import LongContextLoRAContext
 
 from .punica_base import PunicaWrapperBase
 
@@ -284,6 +290,52 @@ def add_lora_logits(self, |
                     self.sampler_indices,
                     add_inputs=True)
         return y.view_as(y_org)
+
+    # This performs the same tensor ops as the base method, except it does them
+    # on the CPU then transfers the results to the TPU
+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: List[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+        extra_vocab_size: int,
+        long_lora_context: Optional["LongContextLoRAContext"] = None,
+    ):
+        # Pad the prompt mapping to avoid running into recompiles on the TPU
+        pad_len = len(mapping.index_mapping) - len(mapping.prompt_mapping)
+        padding = [-1] * pad_len
+        mapping.prompt_mapping = tuple(list(mapping.prompt_mapping) + padding)
+
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            long_lora_offsets_tensor,
+            indices_len,
+        ) = convert_mapping(
+            mapping,
+            lora_index_to_id,
+            max_loras,
+            vocab_size,
+            extra_vocab_size,
+            "cpu",
+            long_lora_context,
+        )
+        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices.to(self.device))
+        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices.to(self.device))
+        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
+            sampler_indices_padded.to(self.device))
+        self._embeddings_indices[:embeddings_indices.
+                                 shape[0], :embeddings_indices.shape[1]].copy_(
+                                     embeddings_indices.to(self.device))
+        if long_lora_offsets_tensor is not None:
+            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
+                long_lora_offsets_tensor.to(self.device))
+        else:
+            self._long_lora_indices.zero_()
+        self.indices_len[:] = indices_len
 
     def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
         self.batch_size = 1
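
Note (editor's sketch, not part of the diff): the new `_update_base_metadata` relies on two ideas, padding `mapping.prompt_mapping` to the length of `mapping.index_mapping` so its shape stays constant, and computing the index tensors on the CPU with `convert_mapping` before copying them into the prefix of preallocated, fixed-shape device buffers. The minimal standalone sketch below illustrates that pattern with plain PyTorch; the buffer name, its size, and the CPU stand-in for the TPU device are illustrative assumptions, not vLLM code.

```python
import torch

MAX_NUM_TOKENS = 8                       # assumed fixed upper bound per batch
device = torch.device("cpu")             # stands in for the TPU/XLA device

# Preallocated, fixed-shape buffer that would live on the device; its shape
# never changes, so XLA never sees a new signature and does not recompile.
token_lora_indices = torch.full((MAX_NUM_TOKENS, ), -1, dtype=torch.int32,
                                device=device)


def update_token_lora_indices(new_indices_cpu: torch.Tensor) -> None:
    """Copy a variable-length CPU tensor into the prefix of the fixed buffer."""
    n = new_indices_cpu.shape[0]
    token_lora_indices[:n].copy_(new_indices_cpu.to(device))


# A batch of 5 tokens, all mapped to LoRA slot 2. The prompt mapping is padded
# with -1 up to the token count, mirroring the padding in the method above.
index_mapping = [2] * 5
prompt_mapping = [2]
prompt_mapping = prompt_mapping + [-1] * (len(index_mapping) -
                                          len(prompt_mapping))

update_token_lora_indices(torch.tensor(index_mapping, dtype=torch.int32))
print(prompt_mapping)       # [2, -1, -1, -1, -1]
print(token_lora_indices)   # tensor([ 2,  2,  2,  2,  2, -1, -1, -1], ...)
```

The same prefix-copy idea is what the `copy_(... .to(self.device))` calls in the diff do for each of the wrapper's preallocated index buffers.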
|