Skip to content

Commit

Permalink
fix modular conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
VladOS95-cyber committed Dec 2, 2024
1 parent 2a11213 commit 9e6536d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/transformers/models/starcoder2/modular_starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
Expand Down

0 comments on commit 9e6536d

Please sign in to comment.