From 9e6536d73faaf0bdc1044d5cd34e3565140bf6c0 Mon Sep 17 00:00:00 2001
From: Vladislav
Date: Fri, 29 Nov 2024 18:33:44 +0100
Subject: [PATCH] fix modular conversion

---
 src/transformers/models/starcoder2/modular_starcoder2.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py
index b323a3ce9e4d5b..95703257e1f3dc 100644
--- a/src/transformers/models/starcoder2/modular_starcoder2.py
+++ b/src/transformers/models/starcoder2/modular_starcoder2.py
@@ -136,9 +136,9 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         if position_embeddings is None:
             logger.warning_once(
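
Note (not part of the patch): a minimal sketch of why the `-1` reshape is equivalent to the explicit head counts. `view` infers the missing dimension from the projection's hidden size, so the same line handles both the query heads and the (typically smaller) number of key/value heads. The tensor dimensions below are hypothetical, chosen only for illustration.

    import torch

    bsz, q_len, head_dim = 2, 8, 64
    num_heads, num_key_value_heads = 16, 4  # hypothetical GQA configuration

    # Outputs of the q/k/v projections: (bsz, q_len, n_heads * head_dim)
    query_states = torch.randn(bsz, q_len, num_heads * head_dim)
    key_states = torch.randn(bsz, q_len, num_key_value_heads * head_dim)

    # `-1` lets view infer 16 heads for queries and 4 heads for keys/values,
    # matching what the explicit self.num_heads / self.num_key_value_heads did.
    q = query_states.view(bsz, q_len, -1, head_dim).transpose(1, 2)
    k = key_states.view(bsz, q_len, -1, head_dim).transpose(1, 2)

    print(q.shape)  # torch.Size([2, 16, 8, 64])
    print(k.shape)  # torch.Size([2, 4, 8, 64])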