diff --git a/src/llama.cpp b/src/llama.cpp index b39906fd53b6f..1a65faef9bedc 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4625,16 +4625,6 @@ static void llm_load_hparams( // non-transformer models do not have attention heads if (hparams.n_head() > 0) { - // sanity check for n_rot (optional) - hparams.n_rot = hparams.n_embd / hparams.n_head(); - - ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { - if (hparams.n_rot != hparams.n_embd / hparams.n_head()) { - throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head())); - } - } // gpt-neox n_rot = rotary_pct * (n_embd / n_head) // gpt-j n_rot = rotary_dim @@ -4643,6 +4633,17 @@ static void llm_load_hparams( hparams.n_embd_head_v = hparams.n_embd / hparams.n_head(); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); + + // sanity check for n_rot (optional) + hparams.n_rot = hparams.n_embd_head_k; + + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); + + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { + if (hparams.n_rot != hparams.n_embd_head_k) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); + } + } } else { hparams.n_rot = 0; hparams.n_embd_head_k = 0; @@ -11490,7 +11491,7 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr, - n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); @@ -11499,7 +11500,7 @@ struct llm_build_context { Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr, - n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -11603,7 +11604,7 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr, - n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); @@ -11612,7 +11613,7 @@ struct llm_build_context { Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr, - n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il);