add LlamaRotary because generation is not good otherwise
3outeille committed Jun 14, 2024
1 parent 3e169c5 commit 31c12e8
Showing 1 changed file with 92 additions and 6 deletions.
98 changes: 92 additions & 6 deletions src/nanotron/models/llama.py
@@ -117,6 +117,73 @@ def forward(
return x_out.type(dtype)


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
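# Illustrative doctest-style example (values assumed for demonstration):
#   >>> rotate_half(torch.tensor([1., 2., 3., 4.]))
#   tensor([-3., -4.,  1.,  2.])
# i.e. the two halves of the last dim are swapped and the leading half negated
# (the non-interleaved rotation layout consumed by apply_rotary_pos_emb below).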


### llama
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 500000.0):
super().__init__()
self.dim = dim
self.end = end
self.theta = theta
self.init_rotary_embeddings()

def init_rotary_embeddings(self):
inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda") / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

@torch.no_grad()
def forward(
self,
x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
):
# x is only read here for its dtype and device
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
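# Shape sketch (illustrative; assumes d_qk = 128, batch_size = 1, seq_length = 16):
#   rope = LlamaRotaryEmbedding(dim=128, end=4096)
#   position_ids = torch.arange(16, device="cuda")[None, :]  # [1, 16]
#   cos, sin = rope(query_states, position_ids)              # each [1, 16, 128]
# The 64 base frequencies are duplicated along the last dim, so cos/sin line up
# with the two halves produced by rotate_half.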


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 2):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and
sin so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos and sin have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes
cos and sin broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
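# Usage sketch (illustrative; mirrors the call site added below in forward()):
#   cos, sin = self.rotary_embedding(value_states, position_ids)  # [b, s, d_qk]
#   query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# With nanotron's [batch, seq, heads, d_qk] layout, the default unsqueeze_dim=2
# turns cos/sin into [b, s, 1, d_qk] so they broadcast across the heads dim.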


class GLUActivation(nn.Module):
def __init__(self, act_fn_name: str):
super().__init__()
@@ -303,10 +370,17 @@ def __init__(
contiguous_chunks=qkv_contiguous_chunks,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
self.rotary_embedding = RotaryEmbedding(
dim=self.d_qk,
end=config.max_position_embeddings,
)
if config.rope_interleaved:
self.rotary_embedding = RotaryEmbedding(
dim=self.d_qk,
end=config.max_position_embeddings,
)
else:
self.rotary_embedding = LlamaRotaryEmbedding(
dim=self.d_qk,
end=config.max_position_embeddings,
)
self.rope_interleaved = config.rope_interleaved

# NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
self.flash_rotary_embedding = FlashRotaryEmbedding(
@@ -336,6 +410,7 @@ def forward(
self,
hidden_states, # [seq_length, batch_size, hidden_size]
sequence_mask, # [batch_size, seq_length]
position_ids: Optional[torch.LongTensor] = None,
):
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import (
@@ -390,8 +465,19 @@ def forward(
# Compute rotary embeddings
# Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
old_rotary_embed_end = self.rotary_embedding.end
query_states = self.rotary_embedding(query_states, position_ids=position_ids)
key_states = self.rotary_embedding(key_states, position_ids=position_ids)

# Rotate half rotary_embedding
# cos, sin = self.rotary_embedding(value_states, position_ids)
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

# Interleaved RoPE (nanotron RotaryEmbedding)
if self.rope_interleaved:
query_states = self.rotary_embedding(query_states, position_ids=position_ids)
key_states = self.rotary_embedding(key_states, position_ids=position_ids)
# Non-interleaved (Llama/HF-style) rotary position embedding
else:
cos, sin = self.rotary_embedding(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if "key" not in store:
# First inference iteration (Prefill)
