Commit 6dd223b

make expert group selection generally available
The new LLaDA2Moe model uses this method too, so make it generally available regardless of architecture.
1 parent bc6c48a · commit 6dd223b

2 files changed: +15 −8 lines

src/llama-graph.cpp

Lines changed: 1 addition & 1 deletion
@@ -929,7 +929,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     }
 
     // select top n_group_used expert groups
-    if (arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
+    if (hparams.n_expert_groups > 0 && n_tokens > 0) {
         const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
 
         // organize experts into n_expert_groups
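The guard now keys off the n_expert_groups hyperparameter rather than the BailingMoE2 architecture, so any model that declares expert groups (such as LLaDA2Moe) takes this path. For context, the technique works as follows: the router first scores whole groups of experts, keeps only the top n_group_used groups, and then runs the usual top-k expert selection within the surviving groups. Below is a minimal standalone C++ sketch of that selection logic, not the ggml tensor code in build_moe_ffn; the per-group score used here (each group's best expert probability) and the name select_expert_groups are illustrative assumptions.

// Standalone sketch of grouped expert selection (illustrative; not the
// ggml tensor implementation in build_moe_ffn).
#include <algorithm>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

// Zero out the router probabilities of every expert whose group is not
// among the top n_group_used groups; regular top-k expert selection can
// then run unchanged on the returned vector.
// Assumption: a group is scored by its best expert probability.
std::vector<float> select_expert_groups(std::vector<float> probs,
                                        int n_expert_groups,
                                        int n_group_used) {
    const int n_expert        = (int) probs.size();
    const int n_exp_per_group = n_expert / n_expert_groups; // must divide evenly

    // score each group
    std::vector<std::pair<float, int>> groups; // (score, group index)
    for (int g = 0; g < n_expert_groups; ++g) {
        float best = 0.0f;
        for (int e = 0; e < n_exp_per_group; ++e) {
            best = std::max(best, probs[g*n_exp_per_group + e]);
        }
        groups.push_back({best, g});
    }

    // keep the n_group_used highest-scoring groups
    std::partial_sort(groups.begin(), groups.begin() + n_group_used,
                      groups.end(), std::greater<>());
    std::vector<bool> keep(n_expert_groups, false);
    for (int i = 0; i < n_group_used; ++i) {
        keep[groups[i].second] = true;
    }

    // mask experts belonging to the dropped groups
    for (int g = 0; g < n_expert_groups; ++g) {
        if (!keep[g]) {
            std::fill_n(probs.begin() + g*n_exp_per_group, n_exp_per_group, 0.0f);
        }
    }
    return probs;
}

int main() {
    // 8 experts in 4 groups of 2; keep the best 2 groups
    std::vector<float> masked = select_expert_groups(
        {0.05f, 0.30f, 0.10f, 0.02f, 0.25f, 0.08f, 0.15f, 0.05f}, 4, 2);
    for (float p : masked) printf("%.2f ", p); // groups 0 and 2 survive
    printf("\n");
}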

src/llama-model.cpp

Lines changed: 14 additions & 7 deletions
@@ -484,11 +484,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         return;
     }
 
-    ml.get_key(LLM_KV_CONTEXT_LENGTH,    hparams.n_ctx_train);
-    ml.get_key(LLM_KV_EMBEDDING_LENGTH,  hparams.n_embd);
-    ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
-    ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
-    ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
+    ml.get_key(LLM_KV_CONTEXT_LENGTH,          hparams.n_ctx_train);
+    ml.get_key(LLM_KV_EMBEDDING_LENGTH,        hparams.n_embd);
+    ml.get_key(LLM_KV_BLOCK_COUNT,             hparams.n_layer);
+    ml.get_key(LLM_KV_EXPERT_COUNT,            hparams.n_expert,        false);
+    ml.get_key(LLM_KV_EXPERT_USED_COUNT,       hparams.n_expert_used,   false);
+    ml.get_key(LLM_KV_EXPERT_GROUP_COUNT,      hparams.n_expert_groups, false);
+    ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used,    false);
 
     if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
         ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);

@@ -504,8 +506,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
     if (hparams.n_expert > 0) {
         GGML_ASSERT(hparams.n_expert_used > 0);
+        GGML_ASSERT(hparams.n_expert_groups < hparams.n_expert);
+        if (hparams.n_expert_groups > 0) {
+            GGML_ASSERT(hparams.n_expert % hparams.n_expert_groups == 0);
+            GGML_ASSERT(hparams.n_group_used > 0);
+            GGML_ASSERT(hparams.n_group_used < hparams.n_expert_groups);
+        }
     } else {
         GGML_ASSERT(hparams.n_expert_used == 0);
+        GGML_ASSERT(hparams.n_expert_groups == 0);
     }
 
     std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);

@@ -1896,8 +1905,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
             ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
             ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,               hparams.n_expert_shared);
-            ml.get_key(LLM_KV_EXPERT_GROUP_COUNT,                hparams.n_expert_groups);
-            ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT,           hparams.n_group_used);
             ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,              hparams.expert_weights_scale);
             ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,               hparams.expert_weights_norm, false);
             ml.get_key(LLM_KV_EXPERT_GATING_FUNC,                hparams.expert_gating_func);
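The group keys are now read generically for every architecture (the trailing false marks them optional, so models without groups keep the zero defaults), which is why the BailingMoE2-specific reads in the last hunk can be dropped and why the generic validation above must hold for all models. As a hedged restatement, here are the added GGML_ASSERT conditions expressed as a self-contained predicate; valid_expert_group_hparams is a hypothetical name for illustration, not a function in llama.cpp.

// Restates the assertions added in load_hparams as a runnable predicate
// (valid_expert_group_hparams is a hypothetical name, not llama.cpp API).
#include <cstdio>

bool valid_expert_group_hparams(int n_expert, int n_expert_used,
                                int n_expert_groups, int n_group_used) {
    if (n_expert_used > n_expert) return false;
    if (n_expert == 0) {
        // dense model: no used experts and no expert groups allowed
        return n_expert_used == 0 && n_expert_groups == 0;
    }
    if (n_expert_used == 0 || n_expert_groups >= n_expert) return false;
    if (n_expert_groups > 0) {
        // groups must partition the experts evenly, and only a proper
        // subset of the groups may be selected
        return n_expert % n_expert_groups == 0 &&
               n_group_used > 0 &&
               n_group_used < n_expert_groups;
    }
    return true; // grouping disabled
}

int main() {
    printf("%d\n", valid_expert_group_hparams(256, 8, 8, 4)); // 1: 8 groups of 32
    printf("%d\n", valid_expert_group_hparams(256, 8, 5, 4)); // 0: 256 % 5 != 0
    printf("%d\n", valid_expert_group_hparams(  0, 0, 0, 0)); // 1: dense model
}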
