
Commit

add base tp plan for qwen2 and qwen2moe
VladOS95-cyber committed Nov 29, 2024
1 parent 89d7bf5 commit 73d5df4
Showing 4 changed files with 30 additions and 6 deletions.
11 changes: 11 additions & 0 deletions src/transformers/models/qwen2/configuration_qwen2.py
@@ -129,6 +129,17 @@ class Qwen2Config(PretrainedConfig):
     model_type = "qwen2"
     keys_to_ignore_at_inference = ["past_key_values"]
 
+    # Default tensor parallel plan for base model `Qwen2`
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+
     def __init__(
         self,
         vocab_size=151936,
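The keys in `base_model_tp_plan` are wildcard module paths relative to the base `Qwen2Model`, and the values name a sharding strategy: "colwise" shards a Linear over its output features, "rowwise" over its input features, and "colwise_rep" (used for `lm_head` further down) is column-wise with the output replicated. A minimal sketch of how such strings are conventionally realized with `torch.distributed.tensor.parallel`, as my own illustration rather than the actual resolution code inside transformers:

# Illustration only: one plausible mapping from the plan's string values to
# torch.distributed.tensor.parallel styles. Import paths assume PyTorch >= 2.4;
# the real resolution logic inside transformers may differ in detail.
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

STYLE_FOR_NAME = {
    "colwise": ColwiseParallel(),                                # shard weight over output features
    "rowwise": RowwiseParallel(),                                # shard weight over input features
    "colwise_rep": ColwiseParallel(output_layouts=Replicate()),  # colwise, but replicate the output
}

def resolve_tp_plan(base_model_tp_plan: dict[str, str]) -> dict:
    """Expand string strategies into style objects; the '*' wildcards in the
    keys are matched against fully qualified module names by the caller."""
    return {pattern: STYLE_FOR_NAME[name] for pattern, name in base_model_tp_plan.items()}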
7 changes: 4 additions & 3 deletions src/transformers/models/qwen2/modeling_qwen2.py
@@ -292,9 +292,9 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         if position_embeddings is None:
             logger.warning_once(
@@ -1077,6 +1077,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 
 class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
 
     def __init__(self, config):
        super().__init__(config)
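The switch to `view(..., -1, self.head_dim)` above is what lets the same forward pass run on a tensor-parallel shard: with `q_proj`/`k_proj`/`v_proj` split column-wise, each rank only produces `num_heads // tp_size` (resp. `num_key_value_heads // tp_size`) heads, so hard-coding `self.num_heads` in the reshape would no longer match the local tensor. A standalone shape check, with made-up sizes and no distributed setup needed:

# Standalone illustration of why the reshape infers the head count with -1.
# Sizes are made up for the demo; tp_size is hypothetical.
import torch

bsz, q_len, head_dim = 2, 5, 128
num_heads, tp_size = 28, 4
local_heads = num_heads // tp_size  # heads actually present on this rank

# Under colwise sharding, q_proj on one rank emits only its slice of the heads:
query_states = torch.randn(bsz, q_len, local_heads * head_dim)

# The old view(bsz, q_len, num_heads, head_dim) would raise a RuntimeError here,
# since the local tensor holds tp_size times fewer elements than a full projection.
query_states = query_states.view(bsz, q_len, -1, head_dim).transpose(1, 2)
assert query_states.shape == (bsz, local_heads, q_len, head_dim)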
11 changes: 11 additions & 0 deletions src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
@@ -150,6 +150,17 @@ class Qwen2MoeConfig(PretrainedConfig):
     model_type = "qwen2_moe"
     keys_to_ignore_at_inference = ["past_key_values"]
 
+    # Default tensor parallel plan for base model `Qwen2Moe`
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+
     def __init__(
         self,
         vocab_size=151936,
7 changes: 4 additions & 3 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -376,9 +376,9 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         if position_embeddings is None:
             logger.warning_once(
@@ -1257,6 +1257,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 
 class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
 
     def __init__(self, config):
        super().__init__(config)
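With the plan entries in place on both architectures, the intended consumer is the tensor-parallel loading path in transformers that landed around the same time. A hedged usage sketch follows; the `tp_plan="auto"` argument, the script name, and the launch command are assumptions about that companion API, not part of this diff:

# run_tp.py -- hypothetical file name. Launch with something like:
#   torchrun --nproc-per-node 4 run_tp.py
# The tp_plan="auto" kwarg is assumed from the TP support added to
# from_pretrained around this time; it is what would pick up
# base_model_tp_plan (and the per-model _tp_plan for lm_head).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2-7B-Instruct"  # any Qwen2/Qwen2Moe-architecture checkpoint

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    tp_plan="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Tensor parallelism shards", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))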
