
Commit b94548a

Bug fix: hpu mrope (#167)

attafosu and xuechendi authored
- The HPU MRoPE implementation had a bug, which was exposed by vllm-project/vllm#24444.
- The initial workaround was to use the default implementation: #162.
- This PR fixes the bug in the HPU MRoPE.

---------

Signed-off-by: attafosu <thomas.atta-fosu@intel.com>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>
1 parent b6a1c7c · commit b94548a
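For context on what the fixed path computes: MRoPE feeds the rotary embedding a 2-D positions tensor (one row of position ids per temporal/height/width axis) and selects one section of the cos/sin cache per axis according to mrope_section. The sketch below reproduces just that selection step in plain PyTorch; the sizes, the toy cos/sin cache, and the mrope_section values are illustrative assumptions, not the module's real configuration.

```python
import torch

# Illustrative sizes; assumptions for this sketch, not the model's real config.
rotary_dim = 16                   # per-head rotary dimension
mrope_section = [3, 3, 2]         # per-axis section widths, summing to rotary_dim // 2
num_tokens = 5
max_position = 64

# MRoPE positions are 2-D: one row of position ids per axis (temporal, height, width).
positions = torch.arange(num_tokens).repeat(3, 1)            # (3, num_tokens)

# Stand-in cos/sin cache laid out as [cos halves | sin halves] per position id.
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(max_position).float(), inv_freq)   # (max_position, rotary_dim // 2)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)       # (max_position, rotary_dim)

cos_sin = cos_sin_cache[positions]        # (3, num_tokens, rotary_dim)
cos, sin = cos_sin.chunk(2, dim=-1)       # each (3, num_tokens, rotary_dim // 2)

# Sectioned selection: take section i from axis i, then concatenate back to half width.
cos = torch.cat([m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
sin = torch.cat([m[i] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
print(cos.shape, sin.shape)               # both (num_tokens, rotary_dim // 2)
```

For 1-D (text-only) positions the `positions.ndim == 2` branch is skipped and the looked-up cos/sin are used directly.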

File tree: 1 file changed (+46, -42 lines)

vllm_gaudi/ops/hpu_rotary_embedding.py

Lines changed: 46 additions & 42 deletions
@@ -655,45 +655,49 @@ def forward_oot(
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.forward_native(positions, query, key, offsets)
-        # from habana_frameworks.torch.hpex.kernels import (RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-
-        # # Prepare cos-sin caches for long-context + LoRA with offsets for every
-        # # forward, since the offset information wasn't available previously
-        # if not hasattr(self, "sin") or self.recompute_cos_sin:
-        #     self.prepare_cos_sin(positions, offsets, recompute_cos_sin=True)
-        # if hasattr(self, "scaling_factors") or hasattr(self, "scaling_factor") or self.sin is None:
-        #     self.prepare_cos_sin(positions, offsets)
-        # num_tokens = positions.shape[0] * positions.shape[1]
-        # # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
-        # # to query hidden dimension, so the original tensors need to be
-        # # expanded
-        # # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
-        # # and expansion of cos/sin tensors via concatenation
-        # # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
-        # # and expansion of cos/sin tensors via repeat_interleave
-        # rope_mode: RotaryPosEmbeddingMode
-        # rope_mode = RotaryPosEmbeddingMode.BLOCKWISE if self.is_neox_style else RotaryPosEmbeddingMode.PAIRWISE
-        # sin = self.sin
-        # cos = self.cos
-        # query_shape = query.shape
-        # key_shape = key.shape
-        # query = query.view(num_tokens, -1, self.head_size)
-        # key = key.view(num_tokens, -1, self.head_size)
-
-        # if self.head_size == self.rotary_dim:
-        #     # Avoid unnecessary slicing and concatenation
-        #     query = apply_rotary_pos_emb(query, cos, sin, None, 0, rope_mode)
-        #     key = apply_rotary_pos_emb(key, cos, sin, None, 0, rope_mode)
-        #     return query.reshape(query_shape), key.reshape(key_shape)
-
-        # query_rot = query[..., :self.rotary_dim]
-        # query_pass = query[..., self.rotary_dim:]
-        # query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
-        # query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
-        # key_rot = key[..., :self.rotary_dim]
-        # key_pass = key[..., self.rotary_dim:]
-        # key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
-        # key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
-        # return query, key
+        from habana_frameworks.torch.hpex.kernels import (RotaryPosEmbeddingMode, apply_rotary_pos_emb)
+
+        num_tokens = positions.shape[-1]
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if positions.ndim == 2:
+            assert self.mrope_section
+
+            cos = torch.cat([m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], dim=-1)
+            sin = torch.cat([m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], dim=-1)
+        if self.is_neox_style:
+            cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
+            sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
+        else:
+            sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1]).unsqueeze(-2)
+            cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1]).unsqueeze(-2)
+        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
+        # to query hidden dimension, so the original tensors need to be
+        # expanded
+        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
+        # and expansion of cos/sin tensors via concatenation
+        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
+        # and expansion of cos/sin tensors via repeat_interleave
+        rope_mode: RotaryPosEmbeddingMode
+        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE if self.is_neox_style else RotaryPosEmbeddingMode.PAIRWISE
+        query_shape = query.shape
+        key_shape = key.shape
+        query = query.view(num_tokens, -1, self.head_size)
+        key = key.view(num_tokens, -1, self.head_size)
+
+        if self.head_size == self.rotary_dim:
+            # Avoid unnecessary slicing and concatenation
+            query = apply_rotary_pos_emb(query, cos, sin, None, 0, rope_mode)
+            key = apply_rotary_pos_emb(key, cos, sin, None, 0, rope_mode)
+            return query.reshape(query_shape), key.reshape(key_shape)
+
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
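The cos/sin expansion in the new code exists because the HPU kernel expects cos and sin at full rotary width, and the two RotaryPosEmbeddingMode values pair channels differently. The snippet below is a plain-PyTorch illustration of the two pairings (BLOCKWISE for GPT-NeoX style, PAIRWISE for GPT-J style) using the standard RoPE formulation; it is not the Habana kernel itself, and all shapes are made up for the example.

```python
import torch

def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # BLOCKWISE pairing: channel i rotates with channel i + rotary_dim // 2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # PAIRWISE pairing: adjacent even/odd channels rotate together.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

num_tokens, num_heads, rotary_dim = 4, 2, 8
q = torch.randn(num_tokens, num_heads, rotary_dim)
angles = torch.randn(num_tokens, rotary_dim // 2)   # half-width per-token angles
cos, sin = angles.cos(), angles.sin()

# NeoX / BLOCKWISE: expand cos/sin by concatenation, broadcast over heads.
cos_neox = torch.cat((cos, cos), dim=-1).unsqueeze(-2)   # (num_tokens, 1, rotary_dim)
sin_neox = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
q_neox = q * cos_neox + rotate_neox(q) * sin_neox

# GPT-J / PAIRWISE: expand cos/sin by repeat_interleave, broadcast over heads.
cos_gptj = torch.repeat_interleave(cos, 2, dim=-1).unsqueeze(-2)
sin_gptj = torch.repeat_interleave(sin, 2, dim=-1).unsqueeze(-2)
q_gptj = q * cos_gptj + rotate_gptj(q) * sin_gptj

print(q_neox.shape, q_gptj.shape)   # both (num_tokens, num_heads, rotary_dim)
```

Under this formulation, concatenating cos/sin for the NeoX path and repeat-interleaving them for the GPT-J path lines each cos/sin value up with the channel pair it rotates, which is what the `unsqueeze(-2)` broadcast over heads relies on.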
