Refactor Cohere Model #30027

Merged · 3 commits · Apr 4, 2024

Changes from all commits

4 changes: 4 additions & 0 deletions src/transformers/models/cohere/configuration_cohere.py
@@ -85,6 +85,8 @@ class CohereConfig(PretrainedConfig):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
use_qk_norm (`bool`, *optional*, defaults to `False`):
Whether to use query-key normalization in the attention.

```python
>>> from transformers import CohereModel, CohereConfig
@@ -123,6 +125,7 @@ def __init__(
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
use_qk_norm=False,
**kwargs,
):
self.vocab_size = vocab_size
@@ -145,6 +148,7 @@ def __init__(
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.use_qk_norm = use_qk_norm

super().__init__(
pad_token_id=pad_token_id,
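As a quick illustration (a minimal sketch, not tied to any released checkpoint): the new flag defaults to `False`, so existing configurations are unaffected, and passing `use_qk_norm=True` opts a model into query-key normalization.

```python
from transformers import CohereConfig

# Minimal sketch: the flag defaults to False, so existing configs keep their
# current behaviour; setting it enables query-key normalization.
config = CohereConfig(use_qk_norm=True)
print(config.use_qk_norm)  # True
```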
62 changes: 42 additions & 20 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -76,10 +76,10 @@ def _get_unpad_data(attention_mask):


class CohereLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5, bias=False):
def __init__(self, hidden_size=None, eps=1e-5, bias=False):
"""The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
self.variance_epsilon = eps

def forward(self, hidden_states):
@@ -89,8 +89,6 @@ def forward(self, hidden_states):
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = self.weight.to(torch.float32) * hidden_states
if self.bias is not None:
hidden_states = hidden_states + self.bias.to(torch.float32)
return hidden_states.to(input_dtype)
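A small sketch of how the reworked `CohereLayerNorm` supports QK-Norm (toy shapes, assumed here for illustration): when `hidden_size` is a tuple, the learned weight spans `(num_heads, head_dim)` and the normalization still runs over the last dimension, i.e. independently per attention head.

```python
import torch
from transformers.models.cohere.modeling_cohere import CohereLayerNorm

# Toy shapes, chosen only for illustration.
num_heads, head_dim = 8, 64
norm = CohereLayerNorm(hidden_size=(num_heads, head_dim), eps=1e-5)

# (batch, seq_len, num_heads, head_dim) — the layout used before the transpose in attention.
states = torch.randn(2, 16, num_heads, head_dim)
print(norm(states).shape)  # torch.Size([2, 16, 8, 64])
```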


@@ -122,7 +120,7 @@ def forward(self, x, position_ids):
emb = torch.repeat_interleave(freqs, 2, dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
return cos, sin


def rotate_half(x):
@@ -133,7 +131,6 @@ def rotate_half(x):
return rot_x


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -154,11 +151,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
dtype = q.dtype
q = q.float()
k = k.float()
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
Comment on lines +154 to +161
Collaborator: if this is done outside apply_rotary_pos_emb we can keep the copied from 😉 but it's not an issue
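A quick sanity check of the new dtype handling in `apply_rotary_pos_emb` (a sketch only: `cos`/`sin` below are placeholders, not real rotary frequencies from `CohereRotaryEmbedding`): the rotation is computed in float32 and the result is cast back to the input dtype.

```python
import torch
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb

# Toy tensors; shapes follow (batch, num_heads, seq_len, head_dim).
q = torch.randn(1, 2, 4, 8, dtype=torch.float16)
k = torch.randn(1, 2, 4, 8, dtype=torch.float16)
# Placeholder cos/sin — real values come from the rotary embedding module.
cos = torch.ones(1, 4, 8)
sin = torch.zeros(1, 4, 8)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.dtype, k_rot.dtype)  # torch.float16 torch.float16 (rotation ran in float32)
```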



# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
@@ -192,7 +192,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# Copied from transformers.models.llama.modeling_llama.LlamaAttention Llama->Cohere
class CohereAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -216,13 +215,21 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.use_qk_norm = config.use_qk_norm

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)

if self.use_qk_norm:
# When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
Collaborator: Not sure if this comment is relevant as the model is not sharded by default
Contributor Author: Oh this is to warn others who port this to other frameworks from HF.

self.q_norm = CohereLayerNorm(hidden_size=(self.num_heads, self.head_dim), eps=config.layer_norm_eps)
self.k_norm = CohereLayerNorm(
hidden_size=(self.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
)

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
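To make the resulting weight shapes concrete, a hedged sketch with a hypothetical toy configuration (sizes chosen only for readability, not those of a released checkpoint): with `use_qk_norm=True`, the attention layer holds one per-head norm weight for queries and one per key-value head for keys.

```python
from transformers import CohereConfig
from transformers.models.cohere.modeling_cohere import CohereAttention

# Hypothetical toy sizes: head_dim = 128 / 4 = 32.
cfg = CohereConfig(hidden_size=128, num_attention_heads=4, num_key_value_heads=2, use_qk_norm=True)
attn = CohereAttention(cfg, layer_idx=0)
print(attn.q_norm.weight.shape)  # torch.Size([4, 32])  -> (num_heads, head_dim)
print(attn.k_norm.weight.shape)  # torch.Size([2, 32])  -> (num_key_value_heads, head_dim)
```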
@@ -255,8 +262,14 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
if self.use_qk_norm:
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

past_key_value = getattr(self, "past_key_value", past_key_value)
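The reordering above is the heart of the change: the projections are first viewed as `(batch, seq, heads, head_dim)` so the per-head norm can be applied, and only then transposed into the `(batch, heads, seq, head_dim)` layout that RoPE and the attention kernels expect. A small, self-contained illustration of that shape flow (toy sizes, no model weights involved):

```python
import torch

bsz, q_len, num_heads, head_dim = 2, 5, 4, 8
query_states = torch.randn(bsz, q_len, num_heads * head_dim)  # stand-in for the q_proj output

q = query_states.view(bsz, q_len, num_heads, head_dim)  # per-head layout: q_norm would run here
q = q.transpose(1, 2)                                   # (bsz, num_heads, q_len, head_dim) for RoPE/attention
print(q.shape)  # torch.Size([2, 4, 5, 8])
```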
@@ -335,11 +348,14 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
if self.use_qk_norm:
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids)
@@ -505,7 +521,7 @@ class CohereSdpaAttention(CohereAttention):
SDPA API.
"""

# Adapted from CohereAttention.forward
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
@@ -538,8 +554,14 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
if self.use_qk_norm:
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids)
@@ -599,7 +621,7 @@ def __init__(self, config: CohereConfig, layer_idx: int):
self.self_attn = COHERE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

self.mlp = CohereMLP(config)
self.input_layernorm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)

def forward(
self,
@@ -822,7 +844,7 @@ def __init__(self, config: CohereConfig):
self.layers = nn.ModuleList(
[CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
self.gradient_checkpointing = False

# Initialize weights and apply final processing
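Finally, an end-to-end smoke test of the refactored model path, sketched with a tiny made-up configuration rather than a released checkpoint:

```python
import torch
from transformers import CohereConfig, CohereForCausalLM

# Tiny, hypothetical configuration; sizes are illustrative only.
config = CohereConfig(
    vocab_size=512,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    use_qk_norm=True,
)
model = CohereForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape)  # torch.Size([1, 8, 512])
```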