@@ -484,11 +484,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         return;
     }
 
-    ml.get_key(LLM_KV_CONTEXT_LENGTH,    hparams.n_ctx_train);
-    ml.get_key(LLM_KV_EMBEDDING_LENGTH,  hparams.n_embd);
-    ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
-    ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
-    ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
+    ml.get_key(LLM_KV_CONTEXT_LENGTH,          hparams.n_ctx_train);
+    ml.get_key(LLM_KV_EMBEDDING_LENGTH,        hparams.n_embd);
+    ml.get_key(LLM_KV_BLOCK_COUNT,             hparams.n_layer);
+    ml.get_key(LLM_KV_EXPERT_COUNT,            hparams.n_expert,        false);
+    ml.get_key(LLM_KV_EXPERT_USED_COUNT,       hparams.n_expert_used,   false);
+    ml.get_key(LLM_KV_EXPERT_GROUP_COUNT,      hparams.n_expert_groups, false);
+    ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used,    false);
 
     if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
         ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
@@ -504,8 +506,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
     if (hparams.n_expert > 0) {
         GGML_ASSERT(hparams.n_expert_used > 0);
+        GGML_ASSERT(hparams.n_expert_groups < hparams.n_expert);
+        if (hparams.n_expert_groups > 0) {
+            GGML_ASSERT(hparams.n_expert % hparams.n_expert_groups == 0);
+            GGML_ASSERT(hparams.n_group_used > 0);
+            GGML_ASSERT(hparams.n_group_used < hparams.n_expert_groups);
+        }
     } else {
         GGML_ASSERT(hparams.n_expert_used == 0);
+        GGML_ASSERT(hparams.n_expert_groups == 0);
     }
 
     std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
@@ -1896,8 +1905,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
                 ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
                 ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,               hparams.n_expert_shared);
-                ml.get_key(LLM_KV_EXPERT_GROUP_COUNT,                hparams.n_expert_groups);
-                ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT,           hparams.n_group_used);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,              hparams.expert_weights_scale);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,               hparams.expert_weights_norm, false);
                 ml.get_key(LLM_KV_EXPERT_GATING_FUNC,                hparams.expert_gating_func);
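
Below is a minimal standalone sketch (not part of the commit) of the expert-group invariants that the new asserts encode, using hypothetical values: with n_expert = 256 and n_expert_groups = 8, the experts split into 32 per group, and n_group_used must be at least 1 and strictly less than the group count.

```cpp
// Hedged sketch: mirrors the hyperparameter checks added around line 504,
// with illustrative values that are not taken from any real model.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_expert        = 256; // LLM_KV_EXPERT_COUNT
    const uint32_t n_expert_used   = 8;   // LLM_KV_EXPERT_USED_COUNT
    const uint32_t n_expert_groups = 8;   // LLM_KV_EXPERT_GROUP_COUNT
    const uint32_t n_group_used    = 4;   // LLM_KV_EXPERT_GROUP_USED_COUNT

    assert(n_expert_used <= n_expert);
    if (n_expert > 0) {
        assert(n_expert_used > 0);
        assert(n_expert_groups < n_expert);
        if (n_expert_groups > 0) {
            assert(n_expert % n_expert_groups == 0); // groups partition the experts evenly
            assert(n_group_used > 0);                // at least one group is routed to
            assert(n_group_used < n_expert_groups);  // but never all of them
        }
    } else {
        // dense model: no experts, so no expert groups either
        assert(n_expert_used == 0);
        assert(n_expert_groups == 0);
    }

    printf("experts per group: %u\n", n_expert / n_expert_groups); // prints 32
    return 0;
}
```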