Add PLM GGUF Conversion & Inference Support #12457

Merged · 41 commits · Mar 27, 2025
Commits (41)
563ec88
add edgellm model arch[conversation feature doesn't work]
Si1w Feb 1, 2025
f006d42
Merge branch 'ggerganov:master' into master
Si1w Feb 2, 2025
c14cad9
remove output.weight layer for edgellm arch
Si1w Feb 3, 2025
1a47cee
Merge branch 'master' of github.com:Si1w/llama.cpp
Si1w Feb 3, 2025
21ed73d
Merge branch 'ggerganov:master' into master
Si1w Feb 3, 2025
9a54239
Merge branch 'ggerganov:master' into master
Si1w Feb 3, 2025
08b5a57
Merge branch 'ggerganov:master' into master
Si1w Feb 13, 2025
7813da4
Merge branch 'ggml-org:master' into master
Si1w Feb 15, 2025
731ed0a
Merge branch 'ggml-org:master' into master
Si1w Feb 24, 2025
b808f00
Merge branch 'ggml-org:master' into master
Si1w Feb 25, 2025
f687e8e
Merge branch 'ggml-org:master' into master
Si1w Mar 3, 2025
5646eb9
[Model] update the name of the model
Si1w Mar 3, 2025
2518841
update the name of model arch in convert gguf
Si1w Mar 3, 2025
ff3d94f
Merge branch 'ggml-org:master' into master
Si1w Mar 11, 2025
444dfe5
Merge remote-tracking branch 'upstream/master'
Si1w Mar 13, 2025
22d35ac
[Model] Refarctor the model arch into llama-model
Si1w Mar 13, 2025
93cf1e4
Merge branch 'ggml-org:master' into master
Si1w Mar 13, 2025
850d301
Merge branch 'ggml-org:master' into master
Si1w Mar 14, 2025
4235644
Merge branch 'ggml-org:master' into master
Si1w Mar 15, 2025
0fcce31
Merge branch 'ggml-org:master' into master
Si1w Mar 18, 2025
55b8674
Merge branch 'ggml-org:master' into master
Si1w Mar 18, 2025
69d61ee
[Bug] Fix the bug in create attn kv
Si1w Mar 18, 2025
a7f4a68
Merge branch 'ggml-org:master' into master
Si1w Mar 18, 2025
066901e
Merge branch 'ggml-org:master' into master
Si1w Mar 19, 2025
9d47a39
Merge branch 'ggml-org:master' into master
Si1w Mar 19, 2025
95de3c6
[Code] Fix editorconfig erros
Si1w Mar 19, 2025
d7a2fc0
[Code] Remove Trailing whitespace
Si1w Mar 19, 2025
91f06a7
Merge branch 'ggml-org:master' into master
Si1w Mar 19, 2025
4bd85c6
[Code] Remove Trailing whitespace
Si1w Mar 19, 2025
5f75445
Merge branch 'master' of github.com:Si1w/llama.cpp
Si1w Mar 19, 2025
cd460ab
[Code] Change the order of model arch in list
Si1w Mar 19, 2025
6d3ac9a
[Code] Fix flake8 Lint errors
Si1w Mar 20, 2025
0b8de3f
Merge branch 'ggml-org:master' into master
Si1w Mar 20, 2025
646521e
Remove trailing white space
Si1w Mar 20, 2025
f5b5271
Merge branch 'master' of github.com:Si1w/llama.cpp
Si1w Mar 20, 2025
2339115
Merge branch 'ggml-org:master' into master
Si1w Mar 21, 2025
7772d4f
[Code] Remove call in model arch
Si1w Mar 21, 2025
e9c7ff4
Merge branch 'ggml-org:master' into master
Si1w Mar 21, 2025
1ec1c1e
Merge branch 'master' of github.com:Si1w/llama.cpp
Si1w Mar 21, 2025
82889bb
Merge branch 'ggml-org:master' into master
Si1w Mar 24, 2025
3a07979
Merge branch 'ggml-org:master' into master
Si1w Mar 27, 2025
23 changes: 23 additions & 0 deletions convert_hf_to_gguf.py
@@ -4419,6 +4419,29 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("PLMForCausalLM")
class PLMModel(Model):
model_arch = gguf.MODEL_ARCH.PLM

def set_vocab(self):
self._set_vocab_gpt2()

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length(hparams["v_head_dim"])
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()


@Model.register("T5WithLMHeadModel")
@Model.register("T5ForConditionalGeneration")
@Model.register("MT5ForConditionalGeneration")
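For context, a minimal standalone sketch of the metadata that PLMModel.set_gguf_parameters() writes: the key_length it records is the concatenation of the no-PE and RoPE head dimensions, while rope_dimension_count covers only the RoPE part. The config values below are placeholders for illustration, not numbers taken from this PR.

```python
# Hypothetical PLM config.json values, for illustration only.
hparams = {
    "vocab_size": 151552,
    "kv_lora_rank": 512,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "v_head_dim": 128,
}

# Mirrors the add_* calls in PLMModel.set_gguf_parameters() above.
metadata = {
    "vocab_size":           hparams["vocab_size"],
    "kv_lora_rank":         hparams["kv_lora_rank"],
    "key_length":           hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"],  # 192 here
    "value_length":         hparams["v_head_dim"],
    "rope_dimension_count": hparams["qk_rope_head_dim"],
}
print(metadata)
```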
16 changes: 16 additions & 0 deletions gguf-py/gguf/constants.py
@@ -286,6 +286,7 @@ class MODEL_ARCH(IntEnum):
GRANITE_MOE = auto()
CHAMELEON = auto()
WAVTOKENIZER_DEC = auto()
PLM = auto()


class MODEL_TENSOR(IntEnum):
@@ -488,6 +489,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.CHAMELEON: "chameleon",
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
MODEL_ARCH.PLM: "plm",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1464,6 +1466,20 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.PLM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_DOWN,
],
MODEL_ARCH.CHATGLM : [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.ROPE_FREQS,
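As a quick check of the registration above, the new entries can be read back from the gguf-py package; this is a sketch assuming the top-level tables (MODEL_ARCH_NAMES, MODEL_TENSORS, TENSOR_NAMES) are re-exported by the `gguf` module, as in the current gguf-py layout.

```python
import gguf

# New architecture enum value and its registered string name ("plm").
arch = gguf.MODEL_ARCH.PLM
print(gguf.MODEL_ARCH_NAMES[arch])

# Tensor kinds declared for PLM, resolved to their base GGUF tensor names.
for tensor in gguf.MODEL_TENSORS[arch]:
    print(gguf.TENSOR_NAMES[tensor])
```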
17 changes: 17 additions & 0 deletions src/llama-arch.cpp
@@ -65,6 +65,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -1043,6 +1044,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_PLM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_CHATGLM,
{
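As a sanity check of the naming map above, the per-layer GGUF tensor names for PLM can be expanded from these format strings. The layer count of 32 mirrors the LLM_TYPE_1_8B case added to load_hparams() further down and is otherwise a placeholder.

```python
# Per-layer format strings copied from the LLM_ARCH_PLM entry in the map above.
PLM_LAYER_TENSORS = [
    "blk.%d.attn_norm",
    "blk.%d.attn_q",
    "blk.%d.attn_kv_a_mqa",
    "blk.%d.attn_kv_a_norm",
    "blk.%d.attn_kv_b",
    "blk.%d.attn_output",
    "blk.%d.ffn_norm",
    "blk.%d.ffn_down",
    "blk.%d.ffn_up",
]

n_layer = 32  # matches the 1.8B case in load_hparams(); placeholder otherwise
names = [fmt % i for i in range(n_layer) for fmt in PLM_LAYER_TENSORS]
# 9 * 32 = 288 per-layer tensors, plus the global token_embd and output_norm weights.
print(len(names), names[:3])
```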
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -69,6 +69,7 @@ enum llm_arch {
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON,
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM,
LLM_ARCH_UNKNOWN,
};

216 changes: 216 additions & 0 deletions src/llama-model.cpp
@@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_1_4B: return "1.4B";
case LLM_TYPE_1_5B: return "1.5B";
case LLM_TYPE_1_6B: return "1.6B";
case LLM_TYPE_1_8B: return "1.8B";
case LLM_TYPE_2B: return "2B";
case LLM_TYPE_2_8B: return "2.8B";
case LLM_TYPE_2_9B: return "2.9B";
@@ -1144,6 +1145,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_PLM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
switch (hparams.n_layer) {
case 32: type = LLM_TYPE_1_8B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_CHATGLM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -3068,6 +3078,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
}
} break;
case LLM_ARCH_PLM:
{
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const int64_t kv_lora_rank = hparams.n_lora_kv;

tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
// output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_BITNET:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -11615,6 +11654,178 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context {
}
};

struct llm_build_plm : public llm_graph_context {
llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k));

const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;

ggml_tensor * cur;
ggml_tensor * inpL;

// {n_embd, n_tokens}
inpL = build_inp_embd(model.tok_embd);

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);

// self_attention
{
ggml_tensor * q = NULL;
q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(q, "q", il);

// split into {n_head * n_embd_head_qk_nope, n_tokens}
ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
ggml_row_size(q->type, hparams.n_embd_head_k),
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
0);
cb(q_nope, "q_nope", il);

// and {n_head * n_embd_head_qk_rope, n_tokens}
ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
ggml_row_size(q->type, hparams.n_embd_head_k),
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
ggml_row_size(q->type, n_embd_head_qk_nope));
cb(q_pe, "q_pe", il);

// {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
cb(kv_pe_compresseed, "kv_pe_compresseed", il);

// split into {kv_lora_rank, n_tokens}
ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
kv_pe_compresseed->nb[1],
0);
cb(kv_compressed, "kv_compressed", il);

// and {n_embd_head_qk_rope, n_tokens}
ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
kv_pe_compresseed->nb[1],
kv_pe_compresseed->nb[1],
ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
cb(k_pe, "k_pe", il);

kv_compressed = build_norm(kv_compressed,
model.layers[il].attn_kv_a_norm, NULL,
LLM_NORM_RMS, il);
cb(kv_compressed, "kv_compressed", il);

// {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
cb(kv, "kv", il);

// split into {n_head * n_embd_head_qk_nope, n_tokens}
ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
0);
cb(k_nope, "k_nope", il);

// and {n_head * n_embd_head_v, n_tokens}
ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
ggml_row_size(kv->type, (n_embd_head_qk_nope)));
cb(v_states, "v_states", il);

v_states = ggml_cont(ctx0, v_states);
cb(v_states, "v_states", il);

v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
0);
cb(v_states, "v_states", il);

q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(q_pe, "q_pe", il);

// shared RoPE key
k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(k_pe, "k_pe", il);

ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
cb(q_states, "q_states", il);

ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
cb(k_states, "k_states", il);

cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
q_states, k_states, v_states, nullptr, kq_scale, il);
}

if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}

ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);

cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
NULL, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
cb(cur, "ffn_out", il);

cur = ggml_add(ctx0, cur, ffn_inp);

cur = build_cvec(cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);

cb(cur, "result_norm", -1);
res->t_embd = cur;

cur = build_lora_mm(model.output, cur);

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}
};

llama_memory_i * llama_model::create_memory() const {
llama_memory_i * res;

@@ -11886,6 +12097,10 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
} break;
case LLM_ARCH_PLM:
{
llm = std::make_unique<llm_build_plm>(*this, params, gf);
} break;
default:
GGML_ABORT("fatal error");
}
@@ -12012,6 +12227,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ARCTIC:
case LLM_ARCH_DEEPSEEK:
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_PLM:
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
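To make the low-rank KV split in llm_build_plm easier to follow, here is a self-contained sketch of the shape bookkeeping. All dimensions are placeholders chosen for illustration (they are not PLM's actual hyperparameters); only the relationships mirror the tensor shapes created in load_tensors() above.

```python
# Placeholder dimensions for illustration only.
n_embd           = 2048
n_head           = 16
qk_nope_head_dim = 128   # no-position part of each key/query head
qk_rope_head_dim = 64    # RoPE part of each key/query head
v_head_dim       = 128
kv_lora_rank     = 512

# What the converter stores as key_length (n_embd_head_k on the C++ side).
n_embd_head_k = qk_nope_head_dim + qk_rope_head_dim

# wq:        {n_embd, n_head * n_embd_head_k}                         -> per-head q, split into nope + rope parts
# wkv_a_mqa: {n_embd, kv_lora_rank + qk_rope_head_dim}                -> compressed KV plus one shared RoPE key
# wkv_b:     {kv_lora_rank, n_head * (qk_nope_head_dim + v_head_dim)} -> per-head k_nope and v
q_cols    = n_head * n_embd_head_k
kv_a_cols = kv_lora_rank + qk_rope_head_dim
kv_b_cols = n_head * (qk_nope_head_dim + v_head_dim)

# After concatenating k_nope with the broadcast k_pe, every head attends with
# keys of size n_embd_head_k and values of size v_head_dim, scaled by 1/sqrt(n_embd_head_k).
print(q_cols, kv_a_cols, kv_b_cols)  # 3072 576 4096 with the placeholder values
```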
1 change: 1 addition & 0 deletions src/llama-model.h
@@ -44,6 +44,7 @@ enum llm_type {
LLM_TYPE_1_4B,
LLM_TYPE_1_5B,
LLM_TYPE_1_6B,
LLM_TYPE_1_8B,
LLM_TYPE_2B,
LLM_TYPE_2_8B,
LLM_TYPE_2_9B,