Skip to content

Commit 2ddadc4

Browse files
committed
n_ff_shexp calculation and further minor changes
1 parent bb23dd0 commit 2ddadc4

File tree

3 files changed

+25
-5
lines changed

3 files changed

+25
-5
lines changed

.gitignore

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,17 @@ poetry.toml
146146
# Local scripts
147147
/run-vim.sh
148148
/run-chat.sh
149+
ernieconv.txt
150+
ErnieMoE/.gitattributes
151+
ErnieMoE/added_tokens.json
152+
ErnieMoE/config.json
153+
ErnieMoE/configuration_ernie4_5_moe.py
154+
ErnieMoE/generation_config.json
155+
ErnieMoE/LICENSE
156+
ErnieMoE/model.safetensors.index.json
157+
ErnieMoE/modeling_ernie4_5_moe.py
158+
ErnieMoE/README.md
159+
ErnieMoE/special_tokens_map.json
160+
ErnieMoE/tokenization_ernie4_5.py
161+
ErnieMoE/tokenizer_config.json
162+
ErnieMoE/tokenizer.model

convert_hf_to_gguf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2830,6 +2830,8 @@ def set_gguf_parameters(self):
28302830
self.gguf_writer.add_interleave_moe_layer_step(self.hparams["moe_layer_interval"])
28312831
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
28322832
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
2833+
if (shared_expert_intermediate_size := self.hparams.get('intermediate_size')) is not None and (num_key_value_heads := self.hparams.get('num_key_value_heads')) is not None:
2834+
self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size // num_key_value_heads)
28332835

28342836
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
28352837
if "exps" in new_name:

src/llama-model.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,8 +1611,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
16111611
{
16121612
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
16131613
if (arch == LLM_ARCH_ERNIE4_5_MOE) {
1614-
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
1615-
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
1614+
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
1615+
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
1616+
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
16161617
}
16171618
switch (hparams.n_layer) {
16181619
case 18: type = LLM_TYPE_0_3B; break;
@@ -4787,7 +4788,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
47874788
int n_ff_exp = hparams.n_ff_exp;
47884789

47894790
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
4790-
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
4791+
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
4792+
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
47914793
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0);
47924794
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
47934795

@@ -8433,7 +8435,9 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
84338435
cb(ffn_inp, "ffn_inp", il);
84348436

84358437
// feed-forward network
8436-
if (model.layers[il].ffn_gate_inp == nullptr) {
8438+
bool is_moe_layer = arch == LLM_ARCH_ERNIE4_5_MOE && hparams.n_moe_layer_step > 0 && (il + 1) % hparams.n_moe_layer_step == 0;
8439+
8440+
if (!is_moe_layer) {
84378441
cur = build_norm(ffn_inp,
84388442
model.layers[il].ffn_norm, NULL,
84398443
LLM_NORM_RMS, il);
@@ -8458,7 +8462,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
84588462
model.layers[il].ffn_up_exps,
84598463
model.layers[il].ffn_gate_exps,
84608464
model.layers[il].ffn_down_exps,
8461-
nullptr,
8465+
model.layers[il].ffn_exp_probs_b,
84628466
n_expert, n_expert_used,
84638467
LLM_FFN_SILU, true,
84648468
false, 0.0,

0 commit comments

Comments
 (0)