@@ -655,45 +655,49 @@ def forward_oot(
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.forward_native(positions, query, key, offsets)
-        # from habana_frameworks.torch.hpex.kernels import (RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-
-        # # Prepare cos-sin caches for long-context + LoRA with offsets for every
-        # # forward, since the offset information wasn't available previously
-        # if not hasattr(self, "sin") or self.recompute_cos_sin:
-        #     self.prepare_cos_sin(positions, offsets, recompute_cos_sin=True)
-        # if hasattr(self, "scaling_factors") or hasattr(self, "scaling_factor") or self.sin is None:
-        #     self.prepare_cos_sin(positions, offsets)
-        # num_tokens = positions.shape[0] * positions.shape[1]
-        # # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
-        # # to query hidden dimension, so the original tensors need to be
-        # # expanded
-        # # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
-        # # and expansion of cos/sin tensors via concatenation
-        # # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
-        # # and expansion of cos/sin tensors via repeat_interleave
-        # rope_mode: RotaryPosEmbeddingMode
-        # rope_mode = RotaryPosEmbeddingMode.BLOCKWISE if self.is_neox_style else RotaryPosEmbeddingMode.PAIRWISE
-        # sin = self.sin
-        # cos = self.cos
-        # query_shape = query.shape
-        # key_shape = key.shape
-        # query = query.view(num_tokens, -1, self.head_size)
-        # key = key.view(num_tokens, -1, self.head_size)
-
-        # if self.head_size == self.rotary_dim:
-        #     # Avoid unnecessary slicing and concatenation
-        #     query = apply_rotary_pos_emb(query, cos, sin, None, 0, rope_mode)
-        #     key = apply_rotary_pos_emb(key, cos, sin, None, 0, rope_mode)
-        #     return query.reshape(query_shape), key.reshape(key_shape)
-
-        # query_rot = query[..., :self.rotary_dim]
-        # query_pass = query[..., self.rotary_dim:]
-        # query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
-        # query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
-        # key_rot = key[..., :self.rotary_dim]
-        # key_pass = key[..., self.rotary_dim:]
-        # key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
-        # key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
-        # return query, key
+        from habana_frameworks.torch.hpex.kernels import (RotaryPosEmbeddingMode, apply_rotary_pos_emb)
+
+        num_tokens = positions.shape[-1]
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if positions.ndim == 2:
+            assert self.mrope_section
+
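+            # M-RoPE: the i-th section of the rotary dims takes its cos/sin
+            # from the i-th positional axis (e.g. temporal/height/width)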
+            cos = torch.cat([m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], dim=-1)
+            sin = torch.cat([m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], dim=-1)
+        if self.is_neox_style:
+            cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
+            sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
+        else:
+            sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1]).unsqueeze(-2)
+            cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1]).unsqueeze(-2)
+        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
+        # to query hidden dimension, so the original tensors need to be
+        # expanded
+        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
+        # and expansion of cos/sin tensors via concatenation
+        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
+        # and expansion of cos/sin tensors via repeat_interleave
+        rope_mode: RotaryPosEmbeddingMode
+        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE if self.is_neox_style else RotaryPosEmbeddingMode.PAIRWISE
+        query_shape = query.shape
+        key_shape = key.shape
+        query = query.view(num_tokens, -1, self.head_size)
+        key = key.view(num_tokens, -1, self.head_size)
+
+        if self.head_size == self.rotary_dim:
+            # Avoid unnecessary slicing and concatenation
+            query = apply_rotary_pos_emb(query, cos, sin, None, 0, rope_mode)
+            key = apply_rotary_pos_emb(key, cos, sin, None, 0, rope_mode)
+            return query.reshape(query_shape), key.reshape(key_shape)
+
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
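
For readers unfamiliar with the M-RoPE branch above: when positions is 2D it carries one row per positional axis, and each section of the rotary dims takes its cos/sin from its own axis. A standalone sketch of that selection step, with made-up sizes (num_tokens, half_rotary, and mrope_section here are illustrative, not the model's real values):

import torch

# Hypothetical sizes for illustration only.
num_tokens, half_rotary, max_pos = 4, 8, 32
mrope_section = [4, 2, 2]  # half-dim sections; must sum to half_rotary

cos_sin_cache = torch.randn(max_pos, 2 * half_rotary)
positions = torch.randint(max_pos, (3, num_tokens))  # rows: temporal/height/width

cos_sin = cos_sin_cache[positions]   # (3, num_tokens, 2 * half_rotary)
cos, sin = cos_sin.chunk(2, dim=-1)  # each (3, num_tokens, half_rotary)

# Section i keeps the cos/sin computed from positional axis i.
cos = torch.cat([m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
sin = torch.cat([m[i] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
assert cos.shape == (num_tokens, half_rotary)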
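The two kernel modes correspond to the standard NeoX and GPT-J rotation layouts, which is also why the code expands cos/sin by concatenation in one case and repeat_interleave in the other. A native-PyTorch sketch of the formulas the two modes compute (the textbook RoPE identities, not the HPU kernel itself):

import torch

def rotate_blockwise(x, cos, sin):
    # NeoX style: rotate the two halves of the rotary dims as blocks.
    # cos/sin are expected pre-duplicated via torch.cat((c, c), dim=-1).
    x1, x2 = x.chunk(2, dim=-1)
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin

def rotate_pairwise(x, cos, sin):
    # GPT-J style: rotate adjacent (even, odd) pairs of dims.
    # cos/sin are expected pre-expanded via torch.repeat_interleave(c, 2, dim=-1).
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return x * cos + torch.stack((-x2, x1), dim=-1).flatten(-2) * sin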