Skip to content

Commit

Permalink
add parallel tp for starcoder2
Browse files Browse the repository at this point in the history
  • Loading branch information
VladOS95-cyber committed Dec 2, 2024
1 parent cfda3c1 commit 2a11213
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ class Starcoder2Config(PretrainedConfig):

model_type = "starcoder2"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Starcoder2`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.c_fc": "colwise",
"layers.*.mlp.c_proj": "colwise",
}

def __init__(
self,
Expand Down
7 changes: 4 additions & 3 deletions src/transformers/models/starcoder2/modeling_starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
Expand Down Expand Up @@ -1043,6 +1043,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config):
super().__init__(config)
Expand Down

0 comments on commit 2a11213

Please sign in to comment.