From 9dce58a28015e8c01b9fa3955d48cd9ac3a7a074 Mon Sep 17 00:00:00 2001
From: Sanju C Sudhakaran
Date: Sun, 5 Jan 2025 18:14:48 +0200
Subject: [PATCH] Enable long-contexts + LoRA support for Intel Gaudi

Signed-off-by: Sanju C Sudhakaran
---
 vllm/lora/punica_wrapper/utils.py             | 26 +++++++++++++++----
 .../model_executor/layers/rotary_embedding.py |  3 ++-
 vllm/worker/hpu_model_runner.py               | 17 ++++++++++--
 3 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py
index dbc2d27c597f2..b1759a489223b 100644
--- a/vllm/lora/punica_wrapper/utils.py
+++ b/vllm/lora/punica_wrapper/utils.py
@@ -88,10 +88,18 @@ def convert_mapping(
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
     long_lora_offsets: Optional[torch.Tensor] = None
+
+    # Updating each element in `long_lora_offsets` with `lora_offset` slows
+    # down perf in HPU due to a series of `strided_insert` ops during lazy
+    # graph accumulation. Hence HPU appends `lora_offset` to a list and
+    # converts it to a tensor only after it is ready.
     if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=device,
-                                        dtype=torch.long)
+        if device == torch.device("hpu"):
+            long_lora_offsets_list: List[int] = []
+        else:
+            long_lora_offsets = torch.zeros(len(index_mapping_indices),
+                                            device=device,
+                                            dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -104,10 +112,18 @@ def convert_mapping(
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
         if long_lora_context:
-            assert long_lora_offsets is not None
             lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                 index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset
+            if device == torch.device("hpu"):
+                long_lora_offsets_list.append(lora_offset)
+            else:
+                assert long_lora_offsets is not None
+                long_lora_offsets[i] = lora_offset
+
+    if long_lora_context and device == torch.device("hpu"):
+        long_lora_offsets = torch.tensor(long_lora_offsets_list,
+                                         device=device,
+                                         dtype=torch.long)
 
     indices_list: List[Union[List[int], torch.Tensor]] = [
         index_mapping_indices,
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index b3b9b0e876057..2fffcb7b6e6e1 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -206,9 +206,10 @@ def forward_hpu(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-        positions = positions.flatten()
         if offsets is not None:
+            offsets = offsets.view(positions.shape[0], -1)
             positions = positions + offsets
+        positions = positions.flatten()
         num_tokens = positions.shape[0]
         cos_sin = self.cos_sin_cache.index_select(0, positions).view(
             num_tokens, 1, -1)
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index b846d4387ba58..774049a5281ee 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -639,12 +639,25 @@ def load_model(self) -> None:
                 "Bias support in LoRA is not enabled in HPU yet."
             assert not self.lora_config.fully_sharded_loras, \
                 "Fully sharded LoRAs is not enabled in HPU yet."
+            # It's necessary to distinguish between the
+            # max_position_embeddings of VLMs and LLMs.
+            if hasattr(self.model.config, "max_position_embeddings"):
+                max_pos_embeddings = (
+                    self.model.config.max_position_embeddings)
+            else:
+                max_pos_embeddings = (
+                    self.model.config.text_config.max_position_embeddings)
+
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens,
-                self.vocab_size, self.lora_config, self.device,
+                self.vocab_size,
+                self.lora_config,
+                self.device,
                 self.model.embedding_modules,
-                self.model.embedding_padding_modules)
+                self.model.embedding_padding_modules,
+                max_position_embeddings=max_pos_embeddings,
+            )
             self.model = self.lora_manager.create_lora_manager(self.model)
 
         if self.model_config.quantization == 'inc':
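
The perf rationale in the utils.py hunk (per-element writes into a preallocated tensor become a series of `strided_insert` ops under HPU lazy-graph accumulation, so HPU builds a Python list and converts it once) is easiest to see in isolation. Below is a minimal, hypothetical sketch of that pattern, assuming only a plain PyTorch install; the helper `gather_long_lora_offsets` and its arguments mirror the patch's logic but are not part of vLLM's API.

# Sketch only: contrasts the HPU list-accumulation path with the original
# per-element tensor writes. Inputs are illustrative stand-ins for the
# values convert_mapping works with.
from typing import Dict, List

import torch


def gather_long_lora_offsets(
    index_mapping_indices: List[int],
    offsets_by_lora_id: Dict[int, int],
    device: torch.device,
) -> torch.Tensor:
    if device == torch.device("hpu"):
        # HPU path: accumulate offsets on the host and materialize a
        # tensor once, avoiding a strided_insert per element.
        offsets_list: List[int] = [
            offsets_by_lora_id.get(idx, 0) for idx in index_mapping_indices
        ]
        return torch.tensor(offsets_list, device=device, dtype=torch.long)
    # Other devices keep the original behaviour: preallocate and write
    # each element in place.
    offsets = torch.zeros(len(index_mapping_indices),
                          device=device,
                          dtype=torch.long)
    for i, idx in enumerate(index_mapping_indices):
        offsets[i] = offsets_by_lora_id.get(idx, 0)
    return offsets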