Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Support Multi-GPU for Qwen-MoE model #2573

Merged
merged 1 commit into from
Jun 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.nn.expert import MixtralExperts
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp

logger = logging.getLogger(__name__)

# TODO(mlc-team): Support Tensor Parallel.


@dataclasses.dataclass
class Qwen2MoeConfig(QWen2Config): # pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -68,10 +67,7 @@ def __init__(self, config: Qwen2MoeConfig):
)
self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards
self.norm_topk_prob = config.norm_topk_prob
self.share_expert_intermediate_size = (
config.shared_expert_intermediate_size // config.tensor_parallel_shards
)
self.shared_expert = Qwen2MoeMLP(config, self.share_expert_intermediate_size)
self.shared_expert = Qwen2MoeMLP(config, config.shared_expert_intermediate_size)
self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False)

self.gate = nn.Linear(
Expand Down Expand Up @@ -154,7 +150,42 @@ def __init__(self, config: Qwen2MoeConfig):
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, -1, config.rms_norm_eps, bias=False
)

def _set_tp():
def _set(layer, hint):
layer.attrs["shard_strategy"] = hint

hd = config.head_dim
q = self.self_attn.num_attention_heads * hd
k = self.self_attn.num_key_value_heads * hd
v = self.self_attn.num_key_value_heads * hd
si = self.mlp.shared_expert.intermediate_size
mi = self.mlp.moe_intermediate_size
_set(
self.self_attn.c_attn.weight,
tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
)
_set(
self.self_attn.c_attn.bias,
tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]),
)
_set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
_set(
self.mlp.shared_expert.gate_up_proj.weight,
tp.ShardSingleDim("_shard_shared_mlp_up", segs=[si, si], dim=0),
)
_set(
self.mlp.shared_expert.down_proj.weight,
tp.ShardSingleDim("_shard_shared_mlp_down", dim=1),
)
_set(
self.mlp.moe_gate_up_proj.weight,
tp.ShardSingleDim("_shard_moe_mlp_up", segs=[mi, mi], dim=1),
)
_set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_moe_mlp_down", dim=2))

self.tensor_parallel_shards = config.tensor_parallel_shards
_set_tp()

def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
out = self.input_layernorm(hidden_states)
Expand Down Expand Up @@ -202,8 +233,6 @@ def __init__(self, config: Qwen2MoeConfig):
self.vocab_size = config.vocab_size
self.tensor_parallel_shards = config.tensor_parallel_shards
self.head_dim = config.head_dim
if self.tensor_parallel_shards != 1:
raise ValueError("Currently only support tensor_parallel_shards=1.")

def to(self, dtype: Optional[str] = None):
super().to(dtype=dtype)
Expand Down
Loading