88 changes: 46 additions & 42 deletions vllm_gaudi/ops/hpu_rotary_embedding.py
@@ -655,45 +655,49 @@ def forward_oot(
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.forward_native(positions, query, key, offsets)
        # from habana_frameworks.torch.hpex.kernels import (RotaryPosEmbeddingMode, apply_rotary_pos_emb)

        # # Prepare cos-sin caches for long-context + LoRA with offsets for every
        # # forward, since the offset information wasn't available previously
        # if not hasattr(self, "sin") or self.recompute_cos_sin:
        #     self.prepare_cos_sin(positions, offsets, recompute_cos_sin=True)
        # if hasattr(self, "scaling_factors") or hasattr(self, "scaling_factor") or self.sin is None:
        #     self.prepare_cos_sin(positions, offsets)
        # num_tokens = positions.shape[0] * positions.shape[1]
        # # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
        # # to query hidden dimension, so the original tensors need to be
        # # expanded
        # # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
        # # and expansion of cos/sin tensors via concatenation
        # # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
        # # and expansion of cos/sin tensors via repeat_interleave
        # rope_mode: RotaryPosEmbeddingMode
        # rope_mode = RotaryPosEmbeddingMode.BLOCKWISE if self.is_neox_style else RotaryPosEmbeddingMode.PAIRWISE
        # sin = self.sin
        # cos = self.cos
        # query_shape = query.shape
        # key_shape = key.shape
        # query = query.view(num_tokens, -1, self.head_size)
        # key = key.view(num_tokens, -1, self.head_size)

        # if self.head_size == self.rotary_dim:
        #     # Avoid unnecessary slicing and concatenation
        #     query = apply_rotary_pos_emb(query, cos, sin, None, 0, rope_mode)
        #     key = apply_rotary_pos_emb(key, cos, sin, None, 0, rope_mode)
        #     return query.reshape(query_shape), key.reshape(key_shape)

        # query_rot = query[..., :self.rotary_dim]
        # query_pass = query[..., self.rotary_dim:]
        # query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
        # query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        # key_rot = key[..., :self.rotary_dim]
        # key_pass = key[..., self.rotary_dim:]
        # key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
        # key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        # return query, key
        from habana_frameworks.torch.hpex.kernels import (RotaryPosEmbeddingMode, apply_rotary_pos_emb)

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if positions.ndim == 2:
            assert self.mrope_section

            cos = torch.cat([m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], dim=-1)
            sin = torch.cat([m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], dim=-1)
        if self.is_neox_style:
            cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
            sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
        else:
            sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1]).unsqueeze(-2)
            cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1]).unsqueeze(-2)
        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
        # to query hidden dimension, so the original tensors need to be
        # expanded
        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
        # and expansion of cos/sin tensors via concatenation
        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
        # and expansion of cos/sin tensors via repeat_interleave
        rope_mode: RotaryPosEmbeddingMode
        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE if self.is_neox_style else RotaryPosEmbeddingMode.PAIRWISE
        query_shape = query.shape
        key_shape = key.shape
        query = query.view(num_tokens, -1, self.head_size)
        key = key.view(num_tokens, -1, self.head_size)

        if self.head_size == self.rotary_dim:
            # Avoid unnecessary slicing and concatenation
            query = apply_rotary_pos_emb(query, cos, sin, None, 0, rope_mode)
            key = apply_rotary_pos_emb(key, cos, sin, None, 0, rope_mode)
            return query.reshape(query_shape), key.reshape(key_shape)

        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
Collaborator
I did a comparison with the existing forward_native; the major difference seems to be apply_rotary_pos_emb vs apply_rotary_emb_torch. Could you check whether we actually get a perf gain with the oot impl, or whether we can just use the native one?

@attafosu (Contributor Author), Sep 12, 2025
Yeah, I did some quick tests and there's some perf gain over the default:
forward_native: 11.53 tok/sec
forward_oot: 12.32 tok/sec
This is with a relatively small image input, and I expect the gap to be more pronounced with even larger inputs (text or image).
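
For reference, a comparison along these lines could be reproduced with a small harness like the sketch below. This is not part of the PR: `rope` is assumed to be an already-constructed MRotaryEmbedding-style module exposing forward_native and forward_oot with the signature used in this file, the shapes are made up, and on HPU in lazy mode a synchronization point (e.g. htcore.mark_step() plus a blocking copy of the outputs) would be needed around the timed region for the numbers to be meaningful.

# Hypothetical micro-benchmark sketch -- not part of this PR.
import time

import torch


def time_rope(fn, positions, query, key, iters=100, warmup=10):
    # Warm-up excludes one-time graph-compilation / cache-preparation costs.
    for _ in range(warmup):
        fn(positions, query, key)
    start = time.perf_counter()
    for _ in range(iters):
        fn(positions, query, key)
    # NOTE: on HPU in lazy mode, insert a sync here (e.g. htcore.mark_step()
    # and a blocking .cpu() on the outputs) before reading the clock.
    return (time.perf_counter() - start) / iters


def compare(rope, device="hpu", num_tokens=1024, num_heads=32, head_size=128):
    # 2-D positions of shape (3, num_tokens) exercise the MRoPE branch;
    # tensors must live on the HPU device for the oot kernel path
    # (requires habana_frameworks.torch to be imported).
    positions = torch.zeros(3, num_tokens, dtype=torch.long, device=device)
    query = torch.randn(num_tokens, num_heads * head_size, device=device)
    key = torch.randn(num_tokens, num_heads * head_size, device=device)
    t_native = time_rope(rope.forward_native, positions, query, key)
    t_oot = time_rope(rope.forward_oot, positions, query, key)
    print(f"forward_native: {t_native * 1e3:.3f} ms/iter")
    print(f"forward_oot:    {t_oot * 1e3:.3f} ms/iter")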

        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key
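
As a side note on the cos/sin preparation in forward_oot above: the standalone sketch below (pure torch, independent of the HPU kernel; the toy mrope_section and sizes are made up) illustrates the mrope_section gather used for 2-D positions, and the difference between the NeoX-style (BLOCKWISE, concatenated halves) and GPT-J-style (PAIRWISE, interleaved) cos/sin expansions.

# Standalone illustration -- not part of this PR; toy sizes are assumptions.
import torch

# mrope_section gather: for 2-D positions of shape (3, num_tokens), each of
# the 3 rotary "sections" (e.g. temporal / height / width) contributes its
# own slice of the cos/sin cache.
mrope_section = [2, 3, 3]                      # sums to rotary_dim // 2 = 8
num_tokens = 5
cos = torch.randn(3, num_tokens, sum(mrope_section))
parts = cos.split(mrope_section, dim=-1)       # slices of width 2, 3, 3
cos_gathered = torch.cat([m[i] for i, m in enumerate(parts)], dim=-1)
print(cos_gathered.shape)                      # torch.Size([5, 8])

# Expansion of the half-width cos/sin to the full rotary_dim expected by the
# HPU kernel, in the two layouts the kernel modes require.
half = torch.tensor([[0.0, 1.0]])              # toy rotary_dim // 2 = 2 values
blockwise = torch.cat((half, half), dim=-1)    # NeoX / BLOCKWISE: [c0, c1, c0, c1]
pairwise = torch.repeat_interleave(half, 2, dim=-1)  # GPT-J / PAIRWISE: [c0, c0, c1, c1]
print(blockwise)                               # tensor([[0., 1., 0., 1.]])
print(pairwise)                                # tensor([[0., 0., 1., 1.]])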