Enable long-contexts + LoRA support for Intel Gaudi
Signed-off-by: Sanju C Sudhakaran <scsudhakaran@habana.ai>
SanjuCSudhakaran committed Feb 6, 2025
1 parent 1a6fcad commit 9dce58a
Showing 3 changed files with 38 additions and 8 deletions.
26 changes: 21 additions & 5 deletions vllm/lora/punica_wrapper/utils.py
@@ -88,10 +88,18 @@ def convert_mapping(
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
     long_lora_offsets: Optional[torch.Tensor] = None
+
+    # Updating each element in `long_lora_offsets` with `lora_offset` slows
+    # down perf in HPU due to a series of `strided_insert` ops during lazy
+    # graph accumulation. Hence HPU appends `lora_offset` to a list and
+    # converts it to a tensor only after it is ready.
     if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=device,
-                                        dtype=torch.long)
+        if device == torch.device("hpu"):
+            long_lora_offsets_list: List[int] = []
+        else:
+            long_lora_offsets = torch.zeros(len(index_mapping_indices),
+                                            device=device,
+                                            dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -104,10 +112,18 @@ def convert_mapping(
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
         if long_lora_context:
-            assert long_lora_offsets is not None
             lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                 index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset
+            if device == torch.device("hpu"):
+                long_lora_offsets_list.append(lora_offset)
+            else:
+                assert long_lora_offsets is not None
+                long_lora_offsets[i] = lora_offset
+
+    if long_lora_context and device == torch.device("hpu"):
+        long_lora_offsets = torch.tensor(long_lora_offsets_list,
+                                         device=device,
+                                         dtype=torch.long)
 
     indices_list: List[Union[List[int], torch.Tensor]] = [
         index_mapping_indices,
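
For illustration, here is a minimal, self-contained sketch of the pattern these hunks introduce. It uses plain torch on CPU with made-up index_mapping_indices and offsets_by_lora_id data (not vLLM's real structures) to contrast per-element tensor writes with the list-accumulate-then-convert approach the HPU path takes to avoid strided_insert ops during lazy graph accumulation.

from typing import Dict, List

import torch

# Made-up example data standing in for vLLM's mapping structures.
index_mapping_indices: List[int] = [1, 1, 2, 0, 2]
offsets_by_lora_id: Dict[int, int] = {1: 4096, 2: 8192}

# Non-HPU path: preallocate a tensor and write each offset into it in place.
long_lora_offsets = torch.zeros(len(index_mapping_indices), dtype=torch.long)
for i, lora_id in enumerate(index_mapping_indices):
    long_lora_offsets[i] = offsets_by_lora_id.get(lora_id, 0)

# HPU path: every in-place write would become a strided_insert op in the
# lazily accumulated graph, so collect the offsets in a Python list and
# build the tensor once at the end.
long_lora_offsets_list: List[int] = [
    offsets_by_lora_id.get(lora_id, 0) for lora_id in index_mapping_indices
]
long_lora_offsets_hpu = torch.tensor(long_lora_offsets_list, dtype=torch.long)

# Both paths produce the same offsets tensor.
assert torch.equal(long_lora_offsets, long_lora_offsets_hpu)
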
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
@@ -206,9 +206,10 @@ def forward_hpu(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-        positions = positions.flatten()
         if offsets is not None:
+            offsets = offsets.view(positions.shape[0], -1)
             positions = positions + offsets
+        positions = positions.flatten()
         num_tokens = positions.shape[0]
         cos_sin = self.cos_sin_cache.index_select(0, positions).view(
             num_tokens, 1, -1)
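
For illustration, a minimal sketch of the reordered offset handling, with made-up shapes (2 sequences of 4 tokens and a constant offset) rather than a real rotary-embedding call: the long-LoRA offsets are reshaped to line up with the leading dimension of positions, added, and only then is positions flattened for the cos/sin cache lookup. Previously positions was flattened before the addition; the new order keeps its original shape until the reshaped offsets have been applied.

from typing import Optional

import torch

# Made-up example shapes: 2 sequences of 4 tokens each.
positions = torch.arange(8, dtype=torch.long).view(2, 4)
offsets: Optional[torch.Tensor] = torch.full((8,), 100, dtype=torch.long)

if offsets is not None:
    # Reshape the flat offsets so they line up with the un-flattened positions.
    offsets = offsets.view(positions.shape[0], -1)
    positions = positions + offsets
# Flatten only after the offsets have been applied.
positions = positions.flatten()

num_tokens = positions.shape[0]  # 8 position ids, each shifted by 100
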
17 changes: 15 additions & 2 deletions vllm/worker/hpu_model_runner.py
@@ -639,12 +639,25 @@ def load_model(self) -> None:
                 "Bias support in LoRA is not enabled in HPU yet."
             assert not self.lora_config.fully_sharded_loras, \
                 "Fully sharded LoRAs is not enabled in HPU yet."
+            # It's necessary to distinguish between the
+            # max_position_embeddings of VLMs and LLMs.
+            if hasattr(self.model.config, "max_position_embeddings"):
+                max_pos_embeddings = (
+                    self.model.config.max_position_embeddings)
+            else:
+                max_pos_embeddings = (
+                    self.model.config.text_config.max_position_embeddings)
+
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens,
-                self.vocab_size, self.lora_config, self.device,
+                self.vocab_size,
+                self.lora_config,
+                self.device,
                 self.model.embedding_modules,
-                self.model.embedding_padding_modules)
+                self.model.embedding_padding_modules,
+                max_position_embeddings=max_pos_embeddings,
+            )
             self.model = self.lora_manager.create_lora_manager(self.model)
 
         if self.model_config.quantization == 'inc':
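
For illustration, a minimal sketch of the max_position_embeddings lookup added above, using simplified stand-in config classes rather than the real Hugging Face config types: LLM configs expose the field at the top level, while VLM configs typically nest it inside text_config.

from dataclasses import dataclass, field


@dataclass
class TextConfig:
    max_position_embeddings: int = 4096


@dataclass
class LLMConfig:
    max_position_embeddings: int = 32768


@dataclass
class VLMConfig:
    # VLM-style configs often omit a top-level max_position_embeddings and
    # nest it inside the text model's config instead.
    text_config: TextConfig = field(default_factory=TextConfig)


def get_max_pos_embeddings(config) -> int:
    # Same fallback as in load_model: prefer the top-level attribute, fall
    # back to the nested text_config for VLM-style configs.
    if hasattr(config, "max_position_embeddings"):
        return config.max_position_embeddings
    return config.text_config.max_position_embeddings


assert get_max_pos_embeddings(LLMConfig()) == 32768
assert get_max_pos_embeddings(VLMConfig()) == 4096
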
