Commit 5638e7d

Removed some recompilations when updating LoRA metadata
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
1 parent: df69c52

File tree: 1 file changed (+53, −1)

vllm/lora/punica_wrapper/punica_tpu.py

Lines changed: 53 additions & 1 deletion
@@ -1,10 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 
 from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
+from vllm.lora.punica_wrapper.utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid circular import
+    from vllm.lora.layers import LoRAMapping
+    from vllm.lora.models import LongContextLoRAContext
 
 from .punica_base import PunicaWrapperBase
@@ -284,6 +290,52 @@ def add_lora_logits(self,
                      self.sampler_indices,
                      add_inputs=True)
         return y.view_as(y_org)
+
+    # This performs the same tensor ops as the base method, except it does them
+    # on the CPU then transfers the results to the TPU
+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: List[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+        extra_vocab_size: int,
+        long_lora_context: Optional["LongContextLoRAContext"] = None,
+    ):
+        # Pad the prompt mapping to avoid running into recompiles on the TPU
+        pad_len = len(mapping.index_mapping) - len(mapping.prompt_mapping)
+        padding = [-1] * pad_len
+        mapping.prompt_mapping = tuple(list(mapping.prompt_mapping) + padding)
+
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            long_lora_offsets_tensor,
+            indices_len,
+        ) = convert_mapping(
+            mapping,
+            lora_index_to_id,
+            max_loras,
+            vocab_size,
+            extra_vocab_size,
+            "cpu",
+            long_lora_context,
+        )
+        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices.to(self.device))
+        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices.to(self.device))
+        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
+            sampler_indices_padded.to(self.device))
+        self._embeddings_indices[:embeddings_indices.
+                                 shape[0], :embeddings_indices.shape[1]].copy_(
+                                     embeddings_indices.to(self.device))
+        if long_lora_offsets_tensor is not None:
+            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
+                long_lora_offsets_tensor.to(self.device))
+        else:
+            self._long_lora_indices.zero_()
+        self.indices_len[:] = indices_len
 
     def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
         self.batch_size = 1
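
Why the padding matters: XLA compiles a separate graph for every distinct
tensor shape it encounters, so a prompt_mapping whose length varies from step
to step would trigger a recompile on each new length. Padding it to the
already-fixed length of index_mapping keeps every tensor derived from it at
one static shape. A minimal standalone sketch of the idea follows;
pad_prompt_mapping is a hypothetical helper named for illustration, not a
function from this commit.

# Sketch of the shape-padding idea. pad_prompt_mapping is hypothetical;
# the commit inlines the same logic in _update_base_metadata.
from typing import Sequence, Tuple

def pad_prompt_mapping(prompt_mapping: Sequence[int],
                       index_mapping: Sequence[int],
                       pad_value: int = -1) -> Tuple[int, ...]:
    """Pad prompt_mapping to len(index_mapping) so its length is static."""
    pad_len = len(index_mapping) - len(prompt_mapping)
    return tuple(prompt_mapping) + (pad_value,) * pad_len

# Two batches with different prompt counts pad to the same length, so any
# tensor built from the result keeps one shape and XLA reuses one graph.
assert pad_prompt_mapping([0, 1], range(8)) == (0, 1, -1, -1, -1, -1, -1, -1)
assert pad_prompt_mapping([2], range(8)) == (2, -1, -1, -1, -1, -1, -1, -1)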

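The copy_ calls follow the same principle from the other direction: the
per-step index tensors are assembled on the CPU and then copied in place into
device buffers that were allocated once at a fixed maximum size, so the
TPU-side shapes never change between updates. A hedged, device-agnostic
sketch of that pattern follows; MAX_NUM_TOKENS and the buffer name are
illustrative assumptions, not values from the commit.

# Sketch of "compute on CPU, copy_ into a preallocated fixed-shape device
# buffer". MAX_NUM_TOKENS and the CPU stand-in device are assumptions.
import torch

MAX_NUM_TOKENS = 16
device = torch.device("cpu")  # stands in for the TPU device in this sketch

# Allocated once; its shape never changes afterwards.
token_lora_indices = torch.zeros(MAX_NUM_TOKENS, dtype=torch.long,
                                 device=device)

def update_token_lora_indices(new_indices: torch.Tensor) -> None:
    # In-place slice copy: only the leading entries are overwritten and the
    # buffer keeps its preallocated shape, so no new graph has to be traced.
    n = new_indices.shape[0]
    token_lora_indices[:n].copy_(new_indices.to(device))

update_token_lora_indices(torch.tensor([3, 0, 2]))
print(token_lora_indices[:3])  # tensor([3, 0, 2])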