gante committed Sep 10, 2024
1 parent 66cbc7a commit 14766d0
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -136,6 +136,9 @@ def forward(self, hidden_states):


 class CohereRotaryEmbedding(nn.Module):
+    # Note: the forward pass of this RoPE is slightly different from Llama's, resulting in different `sin`/`cos` for
+    # the same parameterization. The differences are highlighted with a comment.
+
     def __init__(
         self,
         dim=None,
@@ -211,7 +214,7 @@ def forward(self, x, position_ids):
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.repeat_interleave(freqs, 2, dim=-1)  # Note that this line is different from e.g. Llama.
+            emb = torch.repeat_interleave(freqs, 2, dim=-1)  # This line differs from Llama's implementation
             cos = emb.cos()
             sin = emb.sin()

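For context on what the changed comment refers to: Cohere's rotary embedding builds `emb` with `torch.repeat_interleave`, whereas Llama's builds it with `torch.cat((freqs, freqs), dim=-1)`. Below is a minimal sketch of the difference, using a toy `freqs` tensor in place of the real per-position frequencies (an illustration only, not code from this commit):

import torch

# Toy stand-in for the per-position rotary frequencies, shape [seq_len, dim // 2].
freqs = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Llama-style layout: the half-dim frequencies repeat block-wise,
# [f0, f1, f2, f0, f1, f2] along the last dimension.
emb_llama = torch.cat((freqs, freqs), dim=-1)

# Cohere-style layout (this file): each frequency repeats element-wise,
# [f0, f0, f1, f1, f2, f2] along the last dimension.
emb_cohere = torch.repeat_interleave(freqs, 2, dim=-1)

print(emb_llama[0])   # tensor([0., 1., 2., 0., 1., 2.])
print(emb_cohere[0])  # tensor([0., 0., 1., 1., 2., 2.])

Because `cos = emb.cos()` and `sin = emb.sin()` are taken over these differently ordered tensors, the two models produce different `sin`/`cos` values for the same parameterization, which is exactly what the added class-level note calls out.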
