vllm_gaudi/ops/hpu_rotary_embedding.py (83 changes: 42 additions & 41 deletions)
@@ -655,44 +655,45 @@ def forward_oot(
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        from habana_frameworks.torch.hpex.kernels import (RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-
-        # Prepare cos-sin caches for long-context + LoRA with offsets for every
-        # forward, since the offset information wasn't available previously
-        if not hasattr(self, "sin") or self.recompute_cos_sin:
-            self.prepare_cos_sin(positions, offsets, recompute_cos_sin=True)
-        if hasattr(self, "scaling_factors") or hasattr(self, "scaling_factor") or self.sin is None:
-            self.prepare_cos_sin(positions, offsets)
-        num_tokens = positions.shape[0] * positions.shape[1]
-        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
-        # to query hidden dimension, so the original tensors need to be
-        # expanded
-        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
-        # and expansion of cos/sin tensors via concatenation
-        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
-        # and expansion of cos/sin tensors via repeat_interleave
-        rope_mode: RotaryPosEmbeddingMode
-        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE if self.is_neox_style else RotaryPosEmbeddingMode.PAIRWISE
-        sin = self.sin
-        cos = self.cos
-        query_shape = query.shape
-        key_shape = key.shape
-        query = query.view(num_tokens, -1, self.head_size)
-        key = key.view(num_tokens, -1, self.head_size)
-
-        if self.head_size == self.rotary_dim:
-            # Avoid unnecessary slicing and concatenation
-            query = apply_rotary_pos_emb(query, cos, sin, None, 0, rope_mode)
-            key = apply_rotary_pos_emb(key, cos, sin, None, 0, rope_mode)
-            return query.reshape(query_shape), key.reshape(key_shape)
-
-        query_rot = query[..., :self.rotary_dim]
-        query_pass = query[..., self.rotary_dim:]
-        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
-        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
-        key_rot = key[..., :self.rotary_dim]
-        key_pass = key[..., self.rotary_dim:]
-        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
-        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
-        return query, key
+        return self.forward_native(positions, query, key, offsets)
+        # from habana_frameworks.torch.hpex.kernels import (RotaryPosEmbeddingMode, apply_rotary_pos_emb)
+
+        # # Prepare cos-sin caches for long-context + LoRA with offsets for every
+        # # forward, since the offset information wasn't available previously
+        # if not hasattr(self, "sin") or self.recompute_cos_sin:
+        # self.prepare_cos_sin(positions, offsets, recompute_cos_sin=True)
+        # if hasattr(self, "scaling_factors") or hasattr(self, "scaling_factor") or self.sin is None:
+        # self.prepare_cos_sin(positions, offsets)
+        # num_tokens = positions.shape[0] * positions.shape[1]
+        # # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
+        # # to query hidden dimension, so the original tensors need to be
+        # # expanded
+        # # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
+        # # and expansion of cos/sin tensors via concatenation
+        # # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
+        # # and expansion of cos/sin tensors via repeat_interleave
+        # rope_mode: RotaryPosEmbeddingMode
+        # rope_mode = RotaryPosEmbeddingMode.BLOCKWISE if self.is_neox_style else RotaryPosEmbeddingMode.PAIRWISE
+        # sin = self.sin
+        # cos = self.cos
+        # query_shape = query.shape
+        # key_shape = key.shape
+        # query = query.view(num_tokens, -1, self.head_size)
+        # key = key.view(num_tokens, -1, self.head_size)
+
+        # if self.head_size == self.rotary_dim:
+        # # Avoid unnecessary slicing and concatenation
+        # query = apply_rotary_pos_emb(query, cos, sin, None, 0, rope_mode)
+        # key = apply_rotary_pos_emb(key, cos, sin, None, 0, rope_mode)
+        # return query.reshape(query_shape), key.reshape(key_shape)
+
+        # query_rot = query[..., :self.rotary_dim]
+        # query_pass = query[..., self.rotary_dim:]
+        # query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+        # query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+        # key_rot = key[..., :self.rotary_dim]
+        # key_pass = key[..., self.rotary_dim:]
+        # key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+        # key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        # return query, key
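
The comments kept in the commented-out HPU path describe how the cos/sin caches must be widened to the query hidden dimension before the kernel call: by concatenation for the GPT-NeoX (BLOCKWISE) layout, and by repeat_interleave for the GPT-J (PAIRWISE) layout. Below is a minimal standalone sketch of that expansion step, assuming half-width caches of shape [num_tokens, rotary_dim // 2]; the helper name and shapes are illustrative only and are not the repo's actual prepare_cos_sin implementation.

import torch


def expand_cos_sin(cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool):
    """Widen [num_tokens, rotary_dim // 2] caches to [num_tokens, rotary_dim]."""
    if is_neox_style:
        # GPT-NeoX / BLOCKWISE layout: channel i in the first half and channel i
        # in the second half share the same angle, so duplicate by concatenation.
        return torch.cat((cos, cos), dim=-1), torch.cat((sin, sin), dim=-1)
    # GPT-J / PAIRWISE layout: adjacent even/odd channels share the same angle,
    # so duplicate by interleaving.
    return (torch.repeat_interleave(cos, 2, dim=-1),
            torch.repeat_interleave(sin, 2, dim=-1))


# Example with rotary_dim = 8: both layouts end up [num_tokens, 8],
# but the duplicated values are arranged differently.
cos_half = torch.randn(16, 4)
sin_half = torch.randn(16, 4)
cos_neox, sin_neox = expand_cos_sin(cos_half, sin_half, is_neox_style=True)
cos_gptj, sin_gptj = expand_cos_sin(cos_half, sin_half, is_neox_style=False)
assert cos_neox.shape == cos_gptj.shape == (16, 8)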