Refactor KV cache, RoPE, reduce common code (#1148)
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
abhilash1910 and regisss committed Dec 5, 2024
1 parent c1bb5a5 commit d49ca3b
Showing 11 changed files with 158 additions and 485 deletions.
3 changes: 3 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -188,6 +188,9 @@
GaudiMllamaVisionSdpaAttention,
)
from .modeling_all_models import (
KVCache,
Matmul,
apply_customized_rope_module,
gaudi_check_and_enable_sdpa,
gaudi_conv1d_forward,
gaudi_get_extended_attention_mask,
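The `__init__.py` change above re-exports the new shared building blocks, so model-specific files stop carrying their own copies. A minimal import sketch, assuming optimum-habana at this commit (paths taken from the diff):

```python
# Shared helpers introduced by this refactor; model files import them relatively
# (from ..modeling_all_models import ...), external code can use the package path.
from optimum.habana.transformers.models.modeling_all_models import (
    KVCache,
    Matmul,
    apply_customized_rope_module,
)
```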
10 changes: 2 additions & 8 deletions optimum/habana/transformers/models/clip/modeling_clip.py
@@ -14,6 +14,8 @@
CLIPVisionTransformer,
)

from ..modeling_all_models import Matmul


try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
@@ -47,14 +49,6 @@ def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, sof
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)


class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.matmul(x, y)
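The per-file `Matmul` wrapper removed above now lives in `modeling_all_models` (imported at the top of this file). A hedged sketch of that shared wrapper, assumed to match the deleted class:

```python
import torch


# Assumed to mirror the per-model Matmul removed in this diff: wrapping
# torch.matmul in an nn.Module gives a stable hook point (e.g. for quantization
# or profiling) without touching the attention code that calls it.
class Matmul(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.matmul(x, y)
```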


class Softmax(nn.Module):
def __init__(self):
super().__init__()
85 changes: 12 additions & 73 deletions optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -12,19 +12,18 @@
print("Not using HPU fused kernel for scaled_dot_product_attention")
FusedSDPA = None

try:
from habana_frameworks.torch.hpu import sdp_kernel

SDPContext = True
except ImportError:
SDPContext = False

try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
except ImportError:
print("Not using HPU fused kernel for apply_rotary_pos_emb")
FusedRoPE = None

try:
from habana_frameworks.torch.hpu import sdp_kernel

SDPContext = True
except ImportError:
SDPContext = False

import habana_frameworks.torch.core as htcore
from torch import nn
@@ -53,6 +52,7 @@
_gaudi_prepare_4d_causal_attention_mask,
)
from ...modeling_rope_utils import GaudiRotaryEmbedding
from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module


logger = logging.get_logger(__name__)
@@ -72,14 +72,9 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
return residual


def apply_customized_rope(q, k, cos, sin, position_ids):
def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
if q.device.type == "hpu" and FusedRoPE:
# TODO: remove `.clone()` when it is fixed in SynapseAI
return FusedRoPE.apply(
q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
), FusedRoPE.apply(
k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
)
return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
else:
return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids])

@@ -141,14 +136,6 @@ def forward(self, x, dim=None, invAttnHead=None):
return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead)


class Matmul(nn.Module):
def __init__(self):
super().__init__()

def forward(self, *args, **kwargs):
return torch.matmul(*args, **kwargs)


# ScaledDotProductAttention is based on torch.nn.functional.scaled_dot_product_attention
class ScaledDotProductAttention(nn.Module):
def __init__(self, config: FalconConfig):
@@ -185,56 +172,6 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa
return attn_output


def update(prev, cur, dim, idx, inp_seq_len):
orig_cur = cur
cur = cur.to(dtype=prev.dtype)

if prev.shape == cur.shape:
prev.copy_(cur)
return orig_cur

if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]:
# Initialize
prev[:, :, :inp_seq_len, :].copy_(cur)
return orig_cur
assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
if idx is not None:
prev.index_copy_(dim, idx - 1, cur)
prev_cast = prev.to(orig_cur.dtype)
return prev_cast
else:
return torch.cat((prev, cur), dim=dim)


class KVCache(torch.nn.Module):
def __init__(self):
super(KVCache, self).__init__()
self.cache = None
self.inp_seq_len = -1

def allocate(self, inp_seq_len, dtype, device, shape):
if self.cache is None or self.cache.shape != shape:
self.inp_seq_len = inp_seq_len
self.cache = torch.zeros(shape, dtype=dtype, device=device)
else:
assert (
self.inp_seq_len == inp_seq_len
), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
self.cache.fill_(0)

def get_shape(self):
if self.cache is None:
return None
return self.cache.shape

def forward(self, cur, dim, idx):
return self.update(self.cache, cur, dim, idx, self.inp_seq_len)

@staticmethod
def update(prev, cur, dim, idx, inp_seq_len):
return update(prev, cur, dim, idx, inp_seq_len)
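The Falcon-local `update`/`KVCache` pair removed above is replaced by the shared `KVCache` imported from `..modeling_all_models`. A hedged usage sketch of the flow that implementation encodes: allocate a fixed-size cache once, copy the prompt into its leading slots, then write one token per decode step with `index_copy_` at `idx - 1`.

```python
import torch

# Assumes optimum-habana at this commit; KVCache is re-exported as shown in the
# __init__.py hunk above.
from optimum.habana.transformers.models import KVCache

bs, heads, head_dim = 1, 8, 64
max_len, inp_seq_len = 128, 16

cache = KVCache()
# Zero-filled cache covering the full decode length (illustrative shapes; on
# Gaudi the device would be "hpu", "cpu" keeps the sketch runnable anywhere).
cache.allocate(inp_seq_len, torch.bfloat16, "cpu", (bs, heads, max_len, head_dim))

# Prefill: cur spans the whole prompt, so update() copies it into the first
# inp_seq_len slots and returns the input unchanged.
prompt_kv = torch.randn(bs, heads, inp_seq_len, head_dim)
cache(prompt_kv, dim=2, idx=None)

# Decode: cur is a single token; idx is the 1-based token position, so the first
# generated token lands at slot inp_seq_len via prev.index_copy_(dim, idx - 1, cur).
token_kv = torch.randn(bs, heads, 1, head_dim)
cache(token_kv, dim=2, idx=torch.tensor([inp_seq_len + 1]))
```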


class GaudiFalconAttention(FalconAttention):
"""
Inherits from FalconAttention: https://github.com/huggingface/transformers/blob/838b87abe231fd70be5132088d0dee72a7bb8d62/src/transformers/models/falcon/modeling_falcon.py#L267
@@ -383,7 +320,9 @@ def pre_attn_forward(

if alibi is None:
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids)
query_layer, key_layer = apply_customized_rope(
query_layer, key_layer, cos, sin, position_ids, self.training
)

if use_cache:
if self.training:
41 changes: 14 additions & 27 deletions optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -25,6 +25,8 @@
print("Not using HPU fused kernel for apply_rotary_pos_emb")
FusedRoPE = None

from ..modeling_all_models import apply_customized_rope_module


class GaudiGPTNeoXAttention(GPTNeoXAttention):
def __init__(self, config: GPTNeoXConfig, layer_idx=None):
@@ -455,34 +457,19 @@ def prepare_inputs_for_generation(
return model_inputs


def gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()


def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
if q.device.type == "hpu" and FusedRoPE:
if training:
rope_q = FusedRoPE.apply(
q.to(torch.float), cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids
)
rope_k = FusedRoPE.apply(
k.to(torch.float), cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids
)
else:
if q.dtype == torch.bfloat16:
rope_q = FusedRoPE.apply(
q,
cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
position_ids,
)
else:
rope_q = FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
if k.dtype == torch.bfloat16:
rope_k = FusedRoPE.apply(
k,
cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
position_ids,
)
else:
rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
return rope_q, rope_k
return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
else:
return apply_rotary_pos_emb(q.to(torch.float), k.to(torch.float), cos[position_ids], sin[position_ids])
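The training/bfloat16 branching deleted above is what the shared helper now centralizes. A hedged reconstruction of `apply_customized_rope_module`, based on the GPT-NeoX variant removed here (the actual shared implementation may differ; it also requires an HPU stack, since `FusedRoPE` is `None` elsewhere):

```python
import torch

try:
    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
except ImportError:
    FusedRoPE = None


def apply_customized_rope_module(q, k, cos, sin, position_ids, training=True):
    # Reconstruction of the removed GPT-NeoX logic: in training the fused kernel
    # runs in float32; at inference, bfloat16 inputs stay in bfloat16 by casting
    # cos/sin to match, avoiding per-step up/down-casts.
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    if training:
        rope_q = FusedRoPE.apply(q.to(torch.float), cos, sin, position_ids)
        rope_k = FusedRoPE.apply(k.to(torch.float), cos, sin, position_ids)
        return rope_q, rope_k
    if q.dtype == torch.bfloat16:
        rope_q = FusedRoPE.apply(q, cos.to(torch.bfloat16), sin.to(torch.bfloat16), position_ids)
    else:
        rope_q = FusedRoPE.apply(q, cos, sin, position_ids)
    if k.dtype == torch.bfloat16:
        rope_k = FusedRoPE.apply(k, cos.to(torch.bfloat16), sin.to(torch.bfloat16), position_ids)
    else:
        rope_k = FusedRoPE.apply(k, cos, sin, position_ids)
    return rope_q, rope_k
```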
73 changes: 5 additions & 68 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -30,11 +30,12 @@
from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
)
from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module
from .configuration_llama import LlamaConfig


try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE # noqa

has_fused_rope = True
except ImportError:
@@ -348,7 +349,7 @@ def gaudi_llama_repeat_kv(
return query_states, key_states, value_states, attention_mask


# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8):
super().__init__()
@@ -387,55 +388,6 @@ def forward(
)


class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.matmul(x, y)


class KVCache(torch.nn.Module):
def __init__(self):
super(KVCache, self).__init__()
self.cache = None
self.inp_seq_len = -1

def allocate(self, inp_seq_len, dtype, device, shape):
if self.cache is None or self.cache.shape != shape:
self.inp_seq_len = inp_seq_len
self.cache = torch.zeros(shape, dtype=dtype, device=device)
else:
assert (
self.inp_seq_len == inp_seq_len
), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
self.cache.fill_(0)

@staticmethod
def update(prev, cur, dim, idx, inp_seq_len):
orig_cur = cur
if prev.shape == cur.shape:
prev.copy_(cur)
return orig_cur
if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
# Initialize
prev[:, :, :inp_seq_len, :].copy_(cur)
return orig_cur
if idx is not None:
prev.index_copy_(dim, idx - 1, cur)
return prev
else:
return torch.cat((prev, cur), dim=dim)

def get_shape(self):
if self.cache is None:
return None
return self.cache.shape

def forward(self, cur, dim, idx):
return self.update(self.cache, cur, dim, idx, self.inp_seq_len)


def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
return fused_scaled_dot_product_attention_distributed
@@ -1568,23 +1520,8 @@ def _reorder_cache(past_key_values, beam_idx):
return reordered_past


def apply_customized_rope(q, k, cos, sin, position_ids):
def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
if q.device.type == "hpu" and has_fused_rope:
# TODO: remove `.clone()` when it is fixed in SynapseAI
if k.dtype == torch.bfloat16:
return FusedRoPE.apply(
q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
), FusedRoPE.apply(
k,
cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
position_ids,
)
return FusedRoPE.apply(
q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
), FusedRoPE.apply(
k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
)
return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
else:
# keep the same implementation as Transformers v4.37.2
return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids])
