Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hardware][Intel-Gaudi] Enable long-contexts + LoRA support for Intel Gaudi #12812

Merged
merged 2 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion vllm/lora/punica_wrapper/punica_hpu.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple, Union, final
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final

import torch
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
dispatch_bgmv_linear)

from .punica_base import PunicaWrapperBase
from .utils import convert_mapping

if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext


@final
Expand All @@ -19,6 +25,55 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
max_batches, device)

def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
):
(
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_offsets_tensor,
indices_len,
) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
extra_vocab_size, self.device, None)
# Updating each element in `long_lora_offsets` with `lora_offset` slows
# down perf in HPU due to a series of `strided_insert` ops during lazy
# graph accumulation. Hence HPU appends `lora_offset` to a list and
# converts it to a tensor only after it is ready.
if long_lora_context:
index_mapping_indices: List[int] = list(
mapping.index_mapping).copy()
long_lora_offsets: List[int] = []
for i in range(len(index_mapping_indices)):
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets.append(lora_offset)
long_lora_offsets_tensor = torch.tensor(long_lora_offsets,
device=self.device,
dtype=torch.long)
indices_len[-1] = long_lora_offsets_tensor.shape[-1]

self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
sampler_indices_padded)
self._embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices)
if long_lora_offsets_tensor is not None:
self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
long_lora_offsets_tensor)
else:
self._long_lora_indices.zero_()
self.indices_len[:] = indices_len

def add_lora_embedding(self,
y: torch.Tensor,
x: torch.Tensor,
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,10 @@ def forward_hpu(
) -> Tuple[torch.Tensor, torch.Tensor]:
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
positions = positions.flatten()
if offsets is not None:
offsets = offsets.view(positions.shape[0], -1)
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions).view(
num_tokens, 1, -1)
Expand Down
17 changes: 15 additions & 2 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,12 +639,25 @@ def load_model(self) -> None:
"Bias support in LoRA is not enabled in HPU yet."
assert not self.lora_config.fully_sharded_loras, \
"Fully sharded LoRAs is not enabled in HPU yet."
# It's necessary to distinguish between the
# max_position_embeddings of VLMs and LLMs.
if hasattr(self.model.config, "max_position_embeddings"):
max_pos_embeddings = (
self.model.config.max_position_embeddings)
else:
max_pos_embeddings = (
self.model.config.text_config.max_position_embeddings)

self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size, self.lora_config, self.device,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules)
self.model.embedding_padding_modules,
max_position_embeddings=max_pos_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)

if self.model_config.quantization == 'inc':
Expand Down