From db60623e7926fb151b3cc63f029929122cac342a Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Sun, 10 Aug 2025 23:52:54 -0400 Subject: [PATCH 01/35] added getter for nextn layer count and server slot has_mtp property --- include/llama.h | 2 ++ src/llama-model.cpp | 4 ++++ tools/server/server.cpp | 12 +++++++++++- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/include/llama.h b/include/llama.h index 545e957e5f5..3bade3ae71c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -495,6 +495,8 @@ extern "C" { LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab); + LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model); + // Functions to access the model's GGUF metadata scalar values // - The functions return the length of the string on success, or -1 on failure // - The output string is always null-terminated and cleared on failure diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 58ca7df707e..2351478c2f0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18587,6 +18587,10 @@ const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) return nullptr; } +int32_t llama_model_n_nextn_layer(const llama_model * model) { + return model->hparams.nextn_predict_layers; +} + // deprecated int32_t llama_n_ctx_train(const llama_model * model) { return llama_model_n_ctx_train(model); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index a255d481a4d..7a931cc6b07 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1294,7 +1294,8 @@ struct server_slot { mtmd_context * mctx = nullptr; common_speculative * spec = nullptr; - + bool has_mtp = false; + std::vector lora; // the index relative to completion multi-task request @@ -2121,6 +2122,15 @@ struct server_context { common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); } } + else if (llama_model_n_nextn_layer(model) > 0) { + SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model)); + slot.has_mtp = true; + + // assume one speculative token (true of all well-known MTP models so far) + slot.batch_spec = llama_batch_init(2, 0, 1); + params_base.speculative.n_min = 0; + params_base.speculative.n_max = 1; + } SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); From e434f87cc739a1901931d88e33f777170a4e18e7 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Mon, 11 Aug 2025 01:21:47 -0400 Subject: [PATCH 02/35] some work towards building mtp layer graph --- src/llama-model.cpp | 139 ++++++++++++++++++++++++++++++++++++++++ tools/server/server.cpp | 18 +++--- 2 files changed, 149 insertions(+), 8 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2351478c2f0..9e09e7e0a4f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4507,6 +4507,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // but only PROCESS up to last layer (skipping final NextN layer) in forward pass for (int i = 0; i < n_layer; ++i) { int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; @@ -13919,6 +13920,144 @@ struct llm_build_glm4_moe : public llm_graph_context { } }; +struct llm_build_glm4_moe_mtp : public llm_graph_context { + llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params, + // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization + ggml_tensor * 
hidden_state_inp, llama_token last_token_id, int n_past + ) : llm_graph_context(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // Assuming a single MTP layer at the end + const int il = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il]; + + ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); + ggml_set_i32(inp_pos, n_past); + llm_graph_input_attn_no_cache * inp_attn = nullptr; + + ggml_tensor * cur; + + // get MTP embedding for last (conventionally sampled) token + ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); + ggml_set_i32(inp_token_id, last_token_id); + ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id); + ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); + + // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states) + ggml_tensor * hidden_state_norm = build_norm(hidden_state_inp, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + + ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat + cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj + + + // now proceed through last layer (skipped in main model) + ggml_tensor * inpSA = cur; + + // Pre-attention norm for the MTP block + ggml_tensor* attn_inp = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur); + if (mtp_layer.bq) { + Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq); + } + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur); + if (mtp_layer.bk) { + Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk); + } + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur); + if (mtp_layer.bv) { + Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv); + } + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (mtp_layer.attn_q_norm) { + Qcur = build_norm(Qcur, mtp_layer.attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (mtp_layer.attn_k_norm) { + Kcur = build_norm(Kcur, mtp_layer.attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + mtp_layer.wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + + cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il); + + // moe ffn for nextn block + { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + mtp_layer.ffn_gate_inp, + mtp_layer.ffn_up_exps, + mtp_layer.ffn_gate_exps, + mtp_layer.ffn_down_exps, + mtp_layer.ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, 
hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); + + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + mtp_layer.ffn_up_shexp, NULL, NULL, + mtp_layer.ffn_gate_shexp, NULL, NULL, + mtp_layer.ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); + + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); + cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, res->t_logits); + } +}; + struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 7a931cc6b07..08ffb25d241 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1432,7 +1432,7 @@ struct server_slot { } bool can_speculate() const { - return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; } void add_token(const completion_token_output & token) { @@ -2122,14 +2122,16 @@ struct server_context { common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); } } + + // if model has MTP and no draft model is specified... else if (llama_model_n_nextn_layer(model) > 0) { - SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model)); - slot.has_mtp = true; - - // assume one speculative token (true of all well-known MTP models so far) - slot.batch_spec = llama_batch_init(2, 0, 1); - params_base.speculative.n_min = 0; - params_base.speculative.n_max = 1; + SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model)); + slot.has_mtp = true; + + // assume one speculative token (true of all well-known MTP models so far) + slot.batch_spec = llama_batch_init(2, 0, 1); + params_base.speculative.n_min = 0; + params_base.speculative.n_max = 1; } SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); From 1f477b375504aa557ed21066aa6783b11781a179 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Mon, 11 Aug 2025 20:54:45 -0400 Subject: [PATCH 03/35] make nextn weights loadable without a crash --- src/llama-arch.cpp | 13 +++++++------ src/llama-model.cpp | 27 ++++++++++++++++++++++++++- tools/server/server.cpp | 3 ++- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 18dcc6ddfe5..4b6fa3e6059 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2240,12 +2240,13 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - 
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // Changed to LLM_TENSOR_LAYER_REPEATING because we saved these under a blk with a non-negative id + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 9e09e7e0a4f..a9310a60905 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4510,7 +4510,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; + // flags |= TENSOR_SKIP; } auto & layer = layers[i]; @@ -4574,12 +4574,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + + // our input/output layer sanity check prevents us from loading the eh_proj layer! + // this is because eh_proj is labelled with a layer number in existing GGUFs, + // so we need to set bid == to successfully load the tensors, but our io layer sanity check requires bid == -1. + // this function is a hack that creates the nextn layers as LLM_TENSOR_LAYER_REPEATING instead. 
+ /* auto create_tensor_override_io_sanity_check = + [&](llm_tensor type_enum, const char * suffix, int bid, const std::initializer_list& ne, int flags) -> ggml_tensor * { + + auto tn_orig = tn(type_enum, suffix, bid); + llm_tensor_info info_override = *tn_orig.info; + info_override.layer = LLM_TENSOR_LAYER_REPEATING; + + auto tn_override = tn_orig; + tn_override.info = &info_override; + + return create_tensor(tn_override, ne, flags); + };*/ + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + + // layer.nextn.eh_proj = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i, { 2 * n_embd, n_embd }, flags); + // layer.nextn.embed_tokens = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i, { n_embd, n_vocab }, flags); + // layer.nextn.enorm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_ENORM, "weight", i, { n_embd }, flags); + // layer.nextn.hnorm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_HNORM, "weight", i, { n_embd }, flags); + // layer.nextn.shared_head_head = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i, { n_embd, n_vocab }, flags); + // layer.nextn.shared_head_norm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i, { n_embd }, flags); } } } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 08ffb25d241..a9ad900ce39 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1432,7 +1432,8 @@ struct server_slot { } bool can_speculate() const { - return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; + // return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; + return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt; } void add_token(const completion_token_output & token) { From 03231da69eec20677e25e2307d4fe31ac2ede034 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Tue, 12 Aug 2025 01:03:59 -0400 Subject: [PATCH 04/35] add model member function to build mtp graph, to be called from speculative.cpp --- src/llama-model.cpp | 16 ++++++++++++++++ src/llama-model.h | 2 ++ 2 files changed, 18 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a9310a60905..667d9e442b3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18673,6 +18673,22 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { return llm->res->get_gf(); } +ggml_cgraph* llama_model::build_mtp_graph(const llm_graph_params& params, + ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const { + std::unique_ptr llm; + + switch (arch) { + case LLM_ARCH_GLM4_MOE: + { + llm = std::make_unique(*this, params, hidden_state_inp, last_token_id, n_past); + } break; + default: + GGML_ABORT("fatal error"); + } + + return llm->res->get_gf(); +} + // // interface implementation // diff --git 
a/src/llama-model.h b/src/llama-model.h index 6fcd74d57fd..77a18aca716 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -475,6 +475,8 @@ struct llama_model { // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; + ggml_cgraph * build_mtp_graph(const llm_graph_params & params, + ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const; private: struct impl; From cf0f7c0448c2c1736588673114558e5829db7879 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Wed, 13 Aug 2025 02:21:17 -0400 Subject: [PATCH 05/35] broad thrust of the mtp implementation --- common/speculative.cpp | 126 ++++++++++++++++++++++++++++++++++++++++ common/speculative.h | 9 +++ include/llama.h | 17 ++++++ src/llama-context.cpp | 59 +++++++++++++++++++ src/llama-context.h | 7 +++ src/llama-graph.cpp | 4 ++ src/llama-graph.h | 1 + src/llama-model.cpp | 12 +++- tools/server/server.cpp | 36 ++++++++---- 9 files changed, 260 insertions(+), 11 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 262b2c23e72..e46a0968bde 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -5,6 +5,7 @@ #include "log.h" #include "common.h" #include "sampling.h" +#include "../src/llama-graph.h" #include #include @@ -359,3 +360,128 @@ llama_tokens common_speculative_gen_draft( } return result; } + + +llama_tokens mtp_speculative_gen_draft( + struct common_sampler * smpl, + struct llama_context * ctx, + llama_token id_last, + int32_t n_past, + int32_t last_tok_idx) { + + llama_tokens result; + + LOG_INF("step: '%d'\n", 1); + + // sample one token from the draft model -- this does NOT generalize to >1 MTP head + result.reserve(1); + + // need to determine which architecture we're using so we call the correct MTP model + const auto * model = llama_get_model(ctx); + + LOG_INF("step: '%d'\n", 2); + + //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + //auto * gf = model.build_graph(gparams); + + LOG_INF("step: '%d'\n", 3); + + /*if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + }*/ + + //llm_graph_result res_mtp(ctx->graph_max_nodes()); + llm_graph_result * res_mtp; + llama_ubatch ubatch_mtp; + ubatch_mtp.n_tokens = 1; + ubatch_mtp.pos = &n_past; // Critical for positional encoding + + // We also need a minimal ubatch to provide positional context (RoPE) + // ubatch_mtp.tokens = &last_token_id; + // ubatch_mtp.seq_id = llama_get_main_seq_id(ctx); // Assuming a helper + // ubatch_mtp.logits = nullptr; + // ubatch_mtp.all_pos_0 = -1; + // ubatch_mtp.all_pos_1 = -1; + // ubatch_mtp.all_seq_id = -1; + + // Manually construct the graph parameters + //const llm_graph_params params_mtp = { + // /*.arch =*/ model->arch, + // /*.hparams =*/ model->hparams, + // /*.cparams =*/ ctx->cparams, + // /*.ubatch =*/ ubatch_mtp, + // /*.gtype =*/ LLM_GRAPH_TYPE_DECODER, + // /*.sched =*/ ctx->sched.get(), + // /*.backend_cpu =*/ ctx->backend_cpu, + // /*.cvec =*/ &ctx->cvec, + // /*.loras =*/ &ctx->loras, + // /*.mctx =*/ llama_get_memory(ctx), // Use the KV cache's memory context + // /*.cross =*/ &ctx->cross, + // /*.n_outputs =*/ 1, + // /*.cb =*/ ctx->graph_get_cb(), + // /*.res =*/ &res_mtp, // Point to our temporary result object + //}; + llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp); + + LOG_INF("step: '%d'\n", 4); + + // ggml_cgraph* 
build_mtp_graph(const llm_graph_params & params, + // ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const; + auto * last_embd = llama_get_embeddings_tensor(ctx); + + LOG_INF("step: '%d'\n", 5); + + GGML_ASSERT(model != nullptr); + GGML_ASSERT(last_embd != nullptr); + + auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past); + + if (!gf) { + LOG_INF("%s: failed to initialize graph\n", __func__); + //ret = GGML_STATUS_FAILED; + return result; + } + + LOG_INF("step: '%d'\n", 6); + + const auto status = llama_graph_compute(ctx, gf, false); + + LOG_INF("step: '%d'\n", 7); + + struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp); + float * ctx_logit_pointer = llama_get_logits(ctx); + + LOG_INF("step: '%d'\n", 8); + + if (logits_mtp) { + llama_set_logits(ctx, logits_mtp); + } + + LOG_INF("step: '%d'\n", 9); + + { + common_sampler_sample(smpl, ctx, last_tok_idx, true); + + LOG_INF("step: '%d'\n", 10); + + const auto * cur_p = common_sampler_get_candidates(smpl); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + // skip accepting draft token -- since we're only drafting one token this can't affect future outputs + // smpl will accept the token if it doesn't get rejected by main model later + // common_sampler_accept(smpl, id, true); + + result.push_back(id); + } + + return result; +} diff --git a/common/speculative.h b/common/speculative.h index e69d7aaa1eb..3b048900738 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -27,6 +27,15 @@ void common_speculative_add_replacement_tgt_dft( struct common_speculative * spec, const char *source, const char *dest); + +// sample up to n_draft tokens and add them to the batch using the draft model +llama_tokens mtp_speculative_gen_draft( + struct common_sampler* smpl, + struct llama_context* ctx, + llama_token id_last, + int32_t n_past, + int32_t last_tok_idx); + // sample up to n_draft tokens and add them to the batch using the draft model llama_tokens common_speculative_gen_draft( struct common_speculative * spec, diff --git a/include/llama.h b/include/llama.h index 3bade3ae71c..2134f62d527 100644 --- a/include/llama.h +++ b/include/llama.h @@ -544,12 +544,17 @@ extern "C" { // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); + LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params, + struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, const char * fname_out, const llama_model_quantize_params * params); + + // // Adapters // @@ -972,6 +977,8 @@ extern "C" { // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + LLAMA_API void llama_set_logits(struct llama_context* ctx, struct ggml_tensor* logit_override); + // Get all output token embeddings. 
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously @@ -994,6 +1001,8 @@ extern "C" { // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx); + // // Vocab // @@ -1452,6 +1461,14 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); + LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context* ctx, class llm_graph_result * res, const struct llama_ubatch& ubatch); + + LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched); + + LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res); + + + #ifdef __cplusplus } #endif diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 26a5cf9c3f8..26c3e639d8a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,6 +6,7 @@ #include "llama-memory.h" #include "llama-mmap.h" #include "llama-model.h" +#include "llama-graph.h" #include #include @@ -522,6 +523,14 @@ float * llama_context::get_logits() { return logits; } +void llama_context::set_logits(struct ggml_tensor * logit_override) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + ggml_backend_tensor_get_async(backend_res, logit_override, logits, 0, model.vocab.n_tokens() * sizeof(float)); +} + float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; @@ -617,6 +626,10 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +ggml_tensor * llama_context::get_embeddings_tensor() { + return embd_tensor; +} + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { @@ -1113,6 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) { auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? 
res->get_embd() : nullptr; + embd_tensor = res->get_embd(); if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1429,6 +1443,27 @@ llm_graph_params llama_context::graph_params( }; } +llm_graph_params llama_context::mtp_graph_params( + llm_graph_result* res, + const llama_ubatch& ubatch) const { + return { + /*.arch =*/ model.arch, + /*.hparams =*/ model.hparams, + /*.cparams =*/ cparams, + /*.ubatch =*/ ubatch, + /*.gtype =*/ LLM_GRAPH_TYPE_DECODER, + /*.sched =*/ sched.get(), + /*.backend_cpu =*/ backend_cpu, + /*.cvec =*/ &cvec, + /*.loras =*/ &loras, + /*.mctx =*/ memory->init_batch(*balloc, 1, false).get(), + /*.cross =*/ &cross, + /*.n_outputs =*/ 1, + /*.cb =*/ graph_get_cb(), + /*.res =*/ res, + }; +} + ggml_status llama_context::graph_compute( ggml_cgraph * gf, bool batched) { @@ -2233,6 +2268,7 @@ void llama_context::opt_epoch( llama_batch_free(batch); } + // // interface implementation // @@ -2274,6 +2310,8 @@ llama_context_params llama_context_default_params() { return result; } + + llama_context * llama_init_from_model( llama_model * model, llama_context_params params) { @@ -2412,6 +2450,11 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) { return ctx->get_logits_ith(i); } +void llama_set_logits(llama_context* ctx, struct ggml_tensor* logit_override) { + ctx->set_logits(logit_override); +} + + float * llama_get_embeddings(llama_context * ctx) { ctx->synchronize(); @@ -2430,6 +2473,13 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_tensor(); +} + + // llama adapter API int32_t llama_set_adapter_lora( @@ -2926,3 +2976,12 @@ void llama_opt_epoch( callback_train, callback_eval); } + +llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) { + return ctx->mtp_graph_params(res, ubatch); +} + + +ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) { + return ctx->graph_compute(gf, batched); +} diff --git a/src/llama-context.h b/src/llama-context.h index 25c143d56df..44bcdf6d952 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -59,6 +59,7 @@ struct llama_context { float * get_embeddings(); float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + ggml_tensor * get_embeddings_tensor(); void attach_threadpool( ggml_threadpool_t threadpool, @@ -199,6 +200,10 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); + llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch) const; + + void set_logits(struct ggml_tensor* logit_override); + private: llm_graph_params graph_params( llm_graph_result * res, @@ -240,6 +245,7 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE size_t embd_size = 0; // capacity (of floats) for embeddings float * embd = nullptr; + ggml_tensor * embd_tensor = nullptr; // sequence embeddings output (map of [n_embd] vectors) // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE @@ -308,3 +314,4 @@ struct llama_context { mutable int32_t n_reused = 0; // number of times the previous graph was reused }; + diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 053c72d6dc8..b5184e4559d 
100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1911,3 +1911,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck return relative_bucket; } + +ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) { + return res->get_logits(); +} diff --git a/src/llama-graph.h b/src/llama-graph.h index 6ff49de3a1c..10702ed219c 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -818,3 +818,4 @@ struct llm_graph_context { // TODO: better name int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional); + diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 667d9e442b3..8a9ba848032 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18673,19 +18673,21 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { return llm->res->get_gf(); } -ggml_cgraph* llama_model::build_mtp_graph(const llm_graph_params& params, +ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params, ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const { std::unique_ptr llm; switch (arch) { case LLM_ARCH_GLM4_MOE: { + printf("step: '%d'\n", 56); llm = std::make_unique(*this, params, hidden_state_inp, last_token_id, n_past); } break; default: GGML_ABORT("fatal error"); } + printf("step: '%d'\n", 57); return llm->res->get_gf(); } @@ -19004,3 +19006,11 @@ bool llama_model_is_diffusion(const llama_model * model) { const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } + +ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params, + ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) { + printf("step: '%d'\n", 55); + + return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past); +} + diff --git a/tools/server/server.cpp b/tools/server/server.cpp index a9ad900ce39..29d551ea513 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1294,7 +1294,8 @@ struct server_slot { mtmd_context * mctx = nullptr; common_speculative * spec = nullptr; - bool has_mtp = false; + bool has_mtp = false; + int32_t last_tok_idx = -1; std::vector lora; @@ -1432,8 +1433,8 @@ struct server_slot { } bool can_speculate() const { - // return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; - return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt; + return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; + // return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt; } void add_token(const completion_token_output & token) { @@ -1993,7 +1994,7 @@ struct server_context { SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } - + vocab = llama_model_get_vocab(model); n_ctx = llama_n_ctx(ctx); @@ -3531,6 +3532,7 @@ struct server_context { const int tok_idx = slot.i_batch - i; llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + slot.last_tok_idx = tok_idx; slot.i_batch = -1; @@ -3567,6 +3569,8 @@ struct server_context { } } + SRV_DBG("starting speculative decoding: %d\n", 1); + // do speculative decoding for (auto & slot : slots) { if (!slot.is_processing() || !slot.can_speculate()) { @@ -3583,7 +3587,9 @@ struct server_context { } // determine the max draft that fits the current slot state + SLT_DBG(slot, "starting mtp draft: %d\n", 2); int n_draft_max = slot.params.speculative.n_max; + SLT_DBG(slot, 
"starting mtp draft: %d\n", 3); // note: n_past is not yet increased for the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts @@ -3601,15 +3607,25 @@ struct server_context { continue; } + SLT_DBG(slot, "slot has mtp: %d\n", slot.has_mtp); + llama_token id = slot.sampled; - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; - params_spec.p_min = slot.params.speculative.p_min; + llama_tokens draft; + if (slot.has_mtp) { + SLT_DBG(slot, "starting mtp draft: %d\n", 1); + llama_tokens draft = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); + } + else { + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens(); - const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + } // ignore small drafts if (slot.params.speculative.n_min > (int) draft.size()) { From 6e9bafc7a738b4c99f9440c0ec461e08cf6ce702 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Fri, 15 Aug 2025 23:13:56 -0400 Subject: [PATCH 06/35] failed attempt to implement MTP; outputs tokens but KV cache management is unreasonable --- common/sampling.cpp | 5 ++ common/speculative.cpp | 135 ++++++++-------------------------------- common/speculative.h | 2 +- include/llama.h | 5 +- src/llama-context.cpp | 70 ++++++++++++++++----- src/llama-context.h | 8 ++- src/llama-model.cpp | 37 ++++++++--- tools/server/server.cpp | 26 +++++--- 8 files changed, 141 insertions(+), 147 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 9c04d35fd00..a5824ebeedb 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -348,6 +348,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co llama_sampler_apply(chain, &cur_p); + /*for (int k = 0; k < (int)cur_p.size; ++k) { + LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n", + k, 0, cur_p.data[k].id, cur_p.data[k].p); + }*/ + GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); const llama_token id = cur_p.data[cur_p.selected].id; diff --git a/common/speculative.cpp b/common/speculative.cpp index e46a0968bde..fa784f62f69 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -6,6 +6,7 @@ #include "common.h" #include "sampling.h" #include "../src/llama-graph.h" +#include "../src/llama-context.h" #include #include @@ -362,126 +363,40 @@ llama_tokens common_speculative_gen_draft( } -llama_tokens mtp_speculative_gen_draft( - struct common_sampler * smpl, - struct llama_context * ctx, - llama_token id_last, - int32_t n_past, - int32_t last_tok_idx) { +llama_token mtp_speculative_gen_draft( + struct common_sampler* smpl, + struct llama_context* ctx, + llama_token id_last, + int32_t n_past, + int32_t last_tok_idx) { - llama_tokens result; - - LOG_INF("step: '%d'\n", 1); - - // sample one token from the draft model -- this does NOT generalize to >1 MTP head - result.reserve(1); - - // need to determine which architecture we're using so we call 
the correct MTP model const auto * model = llama_get_model(ctx); - - LOG_INF("step: '%d'\n", 2); - - //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); - //auto * gf = model.build_graph(gparams); - - LOG_INF("step: '%d'\n", 3); - - /*if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); - ret = GGML_STATUS_ALLOC_FAILED; - return nullptr; - }*/ - - //llm_graph_result res_mtp(ctx->graph_max_nodes()); - llm_graph_result * res_mtp; - llama_ubatch ubatch_mtp; - ubatch_mtp.n_tokens = 1; - ubatch_mtp.pos = &n_past; // Critical for positional encoding - - // We also need a minimal ubatch to provide positional context (RoPE) - // ubatch_mtp.tokens = &last_token_id; - // ubatch_mtp.seq_id = llama_get_main_seq_id(ctx); // Assuming a helper - // ubatch_mtp.logits = nullptr; - // ubatch_mtp.all_pos_0 = -1; - // ubatch_mtp.all_pos_1 = -1; - // ubatch_mtp.all_seq_id = -1; - - // Manually construct the graph parameters - //const llm_graph_params params_mtp = { - // /*.arch =*/ model->arch, - // /*.hparams =*/ model->hparams, - // /*.cparams =*/ ctx->cparams, - // /*.ubatch =*/ ubatch_mtp, - // /*.gtype =*/ LLM_GRAPH_TYPE_DECODER, - // /*.sched =*/ ctx->sched.get(), - // /*.backend_cpu =*/ ctx->backend_cpu, - // /*.cvec =*/ &ctx->cvec, - // /*.loras =*/ &ctx->loras, - // /*.mctx =*/ llama_get_memory(ctx), // Use the KV cache's memory context - // /*.cross =*/ &ctx->cross, - // /*.n_outputs =*/ 1, - // /*.cb =*/ ctx->graph_get_cb(), - // /*.res =*/ &res_mtp, // Point to our temporary result object - //}; - llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp); - - LOG_INF("step: '%d'\n", 4); - - // ggml_cgraph* build_mtp_graph(const llm_graph_params & params, - // ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const; auto * last_embd = llama_get_embeddings_tensor(ctx); - LOG_INF("step: '%d'\n", 5); - GGML_ASSERT(model != nullptr); GGML_ASSERT(last_embd != nullptr); + llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx); - auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past); - - if (!gf) { - LOG_INF("%s: failed to initialize graph\n", __func__); - //ret = GGML_STATUS_FAILED; - return result; - } - - LOG_INF("step: '%d'\n", 6); - - const auto status = llama_graph_compute(ctx, gf, false); - - LOG_INF("step: '%d'\n", 7); - - struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp); - float * ctx_logit_pointer = llama_get_logits(ctx); + common_sampler_sample(smpl, ctx, last_tok_idx, true); - LOG_INF("step: '%d'\n", 8); + const auto* cur_p = common_sampler_get_candidates(smpl); + /*LOG_INF("cur_p->size: %d\n", cur_p->size); - if (logits_mtp) { - llama_set_logits(ctx, logits_mtp); - } - - LOG_INF("step: '%d'\n", 9); - - { - common_sampler_sample(smpl, ctx, last_tok_idx, true); - - LOG_INF("step: '%d'\n", 10); - - const auto * cur_p = common_sampler_get_candidates(smpl); - - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); - } - - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, 
cur_p->data[k].id).c_str()); + }*/ - // skip accepting draft token -- since we're only drafting one token this can't affect future outputs - // smpl will accept the token if it doesn't get rejected by main model later - // common_sampler_accept(smpl, id, true); + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; - result.push_back(id); - } + // skip accepting draft token -- since we're only drafting one token this can't affect future outputs + // smpl will accept the token if it doesn't get rejected by main model later + // common_sampler_accept(smpl, id, true); - return result; + //llama_tokens result; + //result.reserve(1); + //result.push_back(id); + //return result; + return id; } diff --git a/common/speculative.h b/common/speculative.h index 3b048900738..6ff9e822f8d 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -29,7 +29,7 @@ void common_speculative_add_replacement_tgt_dft( // sample up to n_draft tokens and add them to the batch using the draft model -llama_tokens mtp_speculative_gen_draft( +llama_token mtp_speculative_gen_draft( struct common_sampler* smpl, struct llama_context* ctx, llama_token id_last, diff --git a/include/llama.h b/include/llama.h index 2134f62d527..16dc10d4032 100644 --- a/include/llama.h +++ b/include/llama.h @@ -977,8 +977,6 @@ extern "C" { // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); - LLAMA_API void llama_set_logits(struct llama_context* ctx, struct ggml_tensor* logit_override); - // Get all output token embeddings. // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously @@ -1465,6 +1463,9 @@ extern "C" { LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched); + LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx, + ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx); + LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 26c3e639d8a..ca713fa3890 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -523,12 +523,16 @@ float * llama_context::get_logits() { return logits; } -void llama_context::set_logits(struct ggml_tensor * logit_override) { - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override); +void llama_context::set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i) { + output_reorder(); + + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched_override, logit_override); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); - ggml_backend_tensor_get_async(backend_res, logit_override, logits, 0, model.vocab.n_tokens() * sizeof(float)); + int64_t j = output_ids[i]; + + ggml_backend_tensor_get_async(backend_res, logit_override, logits + j*model.vocab.n_tokens(), 0, model.vocab.n_tokens() * sizeof(float)); } float * llama_context::get_logits_ith(int32_t i) { @@ -1445,21 +1449,23 @@ llm_graph_params llama_context::graph_params( llm_graph_params llama_context::mtp_graph_params( llm_graph_result* res, - const llama_ubatch& ubatch) const { + const llama_ubatch& ubatch) { + size_t n_nodes = std::max(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / 
model.hparams.n_layer)); + ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes); return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, /*.cparams =*/ cparams, /*.ubatch =*/ ubatch, /*.gtype =*/ LLM_GRAPH_TYPE_DECODER, - /*.sched =*/ sched.get(), + /*.sched =*/ temp_sched, /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, /*.mctx =*/ memory->init_batch(*balloc, 1, false).get(), /*.cross =*/ &cross, /*.n_outputs =*/ 1, - /*.cb =*/ graph_get_cb(), + /*.cb =*/ graph_get_cb(temp_sched), /*.res =*/ res, }; } @@ -1491,8 +1497,10 @@ ggml_status llama_context::graph_compute( return status; } -llm_graph_cb llama_context::graph_get_cb() const { - return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { +llm_graph_cb llama_context::graph_get_cb(ggml_backend_sched * sched_override) const { + ggml_backend_sched * cb_sched = sched_override ? sched_override : sched.get(); + + return [=](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { if (il >= 0) { ggml_format_name(cur, "%s-%d", name, il); } else { @@ -1502,7 +1510,7 @@ llm_graph_cb llama_context::graph_get_cb() const { if (!cparams.offload_kqv) { if (strcmp(name, "kqv_merged_cont") == 0) { // all nodes between the KV store and the attention output are run on the CPU - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); + ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend_cpu); } } @@ -1515,7 +1523,7 @@ llm_graph_cb llama_context::graph_get_cb() const { for (const auto & backend : backends) { if (ggml_backend_get_device(backend.get()) == dev_layer) { if (ggml_backend_supports_op(backend.get(), cur)) { - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get()); + ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend.get()); } } } @@ -1524,6 +1532,10 @@ llm_graph_cb llama_context::graph_get_cb() const { }; } +ggml_backend_sched_t llama_context::create_temp_scheduler(size_t n_nodes) { + return ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), n_nodes, false, cparams.op_offload); +} + // // state save/load // @@ -2450,10 +2462,6 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) { return ctx->get_logits_ith(i); } -void llama_set_logits(llama_context* ctx, struct ggml_tensor* logit_override) { - ctx->set_logits(logit_override); -} - float * llama_get_embeddings(llama_context * ctx) { ctx->synchronize(); @@ -2985,3 +2993,37 @@ llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* re ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) { return ctx->graph_compute(gf, batched); } + +void llama_build_and_execute_mtp_graph(struct llama_context * ctx, + ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { + + const auto * model = llama_get_model(ctx); + + auto res_mtp = std::make_unique(ctx->graph_max_nodes()); + + llama_ubatch ubatch_mtp; + ubatch_mtp.n_tokens = 1; + ubatch_mtp.pos = &n_past; + + auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp)); + + auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past); + + ggml_backend_sched_t sched = params_mtp->sched; + + ggml_backend_sched_reset(sched); // clear the allocation of the previous graph + ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it + + ggml_tensor * mtp_token_id_input = 
ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input"); + + ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors + ggml_backend_sched_graph_compute(sched, gf); // execute the graph + + struct ggml_tensor * logits_mtp = res_mtp->get_logits();; + LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp); + + if (logits_mtp) { + ctx->set_logits_ith(logits_mtp, sched, last_tok_idx); + } +} + diff --git a/src/llama-context.h b/src/llama-context.h index 44bcdf6d952..20314304c07 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -200,9 +200,11 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); - llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch) const; + llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch); - void set_logits(struct ggml_tensor* logit_override); + void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i); + + ggml_backend_sched_t create_temp_scheduler(size_t n_nodes); private: llm_graph_params graph_params( @@ -211,7 +213,7 @@ struct llama_context { const llama_memory_context_i * mctx, llm_graph_type gtype) const; - llm_graph_cb graph_get_cb() const; + llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const; // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8a9ba848032..b0c096dec65 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13950,7 +13950,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past ) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13958,22 +13957,43 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { const int il = hparams.n_layer - 1; const auto & mtp_layer = model.layers[il]; - ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); - ggml_set_i32(inp_pos, n_past); - llm_graph_input_attn_no_cache * inp_attn = nullptr; + // ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); + // ggml_set_i32(inp_pos, n_past); + ggml_tensor * inp_pos = build_inp_pos(); + + llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr; ggml_tensor * cur; // get MTP embedding for last (conventionally sampled) token + // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); + // LLAMA_LOG_INFO("step: '%d'\n", 5641); + // ggml_set_i32(inp_token_id, last_token_id); + //ggml_set_no_alloc(ctx0, false); + //LLAMA_LOG_INFO("last token id: '%d'\n", last_token_id); + ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); - ggml_set_i32(inp_token_id, last_token_id); + ggml_set_name(inp_token_id, "mtp_token_id_input"); + ggml_set_input(inp_token_id); + + //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id); + //ggml_set_no_alloc(ctx0, true); + ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id); ggml_tensor * token_emb_norm = build_norm(token_emb, 
mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * prev_embedding_leaf = ggml_dup_tensor(ctx0, hidden_state_inp); + ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_leaf"); + ggml_cpy(ctx0, hidden_state_inp, prev_embedding_leaf); + // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states) - ggml_tensor * hidden_state_norm = build_norm(hidden_state_inp, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + //token_emb_norm = ggml_cont(ctx0, token_emb_norm); + //hidden_state_norm = ggml_cont(ctx0, hidden_state_norm); ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat + + cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj @@ -14071,7 +14091,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { cur = ggml_add(ctx0, routed_out, shared_out); cb(cur, "ffn_out", il); } - cur = ggml_add(ctx0, cur, ffn_inp); cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); @@ -18680,14 +18699,12 @@ ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params, switch (arch) { case LLM_ARCH_GLM4_MOE: { - printf("step: '%d'\n", 56); llm = std::make_unique(*this, params, hidden_state_inp, last_token_id, n_past); } break; default: GGML_ABORT("fatal error"); } - printf("step: '%d'\n", 57); return llm->res->get_gf(); } @@ -19009,8 +19026,8 @@ const std::vector> & llama_internal_get_te ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params, ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) { - printf("step: '%d'\n", 55); return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past); } + diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 29d551ea513..e5039fe86ae 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2132,6 +2132,8 @@ struct server_context { // assume one speculative token (true of all well-known MTP models so far) slot.batch_spec = llama_batch_init(2, 0, 1); + SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens); + params_base.speculative.n_min = 0; params_base.speculative.n_max = 1; } @@ -3587,9 +3589,7 @@ struct server_context { } // determine the max draft that fits the current slot state - SLT_DBG(slot, "starting mtp draft: %d\n", 2); int n_draft_max = slot.params.speculative.n_max; - SLT_DBG(slot, "starting mtp draft: %d\n", 3); // note: n_past is not yet increased for the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts @@ -3607,14 +3607,13 @@ struct server_context { continue; } - SLT_DBG(slot, "slot has mtp: %d\n", slot.has_mtp); - llama_token id = slot.sampled; llama_tokens draft; if (slot.has_mtp) { - SLT_DBG(slot, "starting mtp draft: %d\n", 1); - llama_tokens draft = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); + llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); + draft.reserve(1); + draft.push_back(draft_id); } else { struct common_speculative_params params_spec; @@ -3624,7 +3623,16 @@ struct server_context { const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + } + + 
//llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); + //llama_tokens draft; + //draft.reserve(1); + //draft.push_back(draft_id); + + for (const auto& str : draft) { + SLT_DBG(slot, "%s\n", str); } // ignore small drafts @@ -3636,6 +3644,7 @@ struct server_context { // keep track of total number of drafted tokens tested slot.n_draft_total += draft.size(); + SLT_DBG(slot, "draft size = %d\n", draft.size()); // construct the speculation batch common_batch_clear(slot.batch_spec); @@ -3652,6 +3661,9 @@ struct server_context { // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + // if slot has mtp + // call + slot.n_past += ids.size(); slot.n_decoded += ids.size(); From 6870f9790c1bb1d0254241267b1a6c8a7fc82830 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Sun, 17 Aug 2025 04:59:36 -0400 Subject: [PATCH 07/35] added proper KV cache management for MTP layers and slightly refactored --- common/speculative.cpp | 58 ++++++++++++++++++++++-------- common/speculative.h | 8 +++++ include/llama.h | 15 +------- src/llama-batch.cpp | 6 ++-- src/llama-context.cpp | 66 +++++++++++++++++++--------------- src/llama-context.h | 4 ++- src/llama-graph.cpp | 4 --- src/llama-kv-cache-unified.cpp | 2 +- src/llama-model.cpp | 22 +++++------- src/llama-model.h | 2 +- tools/server/server.cpp | 43 ++++++++++++++-------- 11 files changed, 136 insertions(+), 94 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index fa784f62f69..9f8384abb13 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -370,25 +370,45 @@ llama_token mtp_speculative_gen_draft( int32_t n_past, int32_t last_tok_idx) { - const auto * model = llama_get_model(ctx); - auto * last_embd = llama_get_embeddings_tensor(ctx); + llama_token token_data[] = { id_last }; + llama_pos pos_data[] = { n_past }; + int32_t n_seq_id_data[] = { 1 }; + llama_seq_id seq_id_data_internal[] = { 0 }; + llama_seq_id* seq_id_data[] = {seq_id_data_internal}; + int8_t logits_data[] = { (int8_t) (smpl != nullptr) }; + + llama_batch batch = { + /*.n_tokens = */ 1, + /*.token = */ token_data, + /*.embd = */ nullptr, + /*.pos = */ pos_data, + /*.n_seq_id = */ n_seq_id_data, + /*.seq_id = */ seq_id_data, + /*.logits = */ logits_data + }; + + llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx); + //LOG_INF("updating kv cache for n_past: %d\n", n_past); - GGML_ASSERT(model != nullptr); - GGML_ASSERT(last_embd != nullptr); - llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx); + if (!smpl) { + return -1; + } + else { + common_sampler_sample(smpl, ctx, last_tok_idx, true); + const auto* cur_p = common_sampler_get_candidates(smpl); - common_sampler_sample(smpl, ctx, last_tok_idx, true); + //for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) { + // LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + // k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + //} - const auto* cur_p = common_sampler_get_candidates(smpl); - /*LOG_INF("cur_p->size: %d\n", cur_p->size); + const llama_token id = cur_p->data[0].id; + return id; + } + // LOG_INF("cur_p->size: %d\n", cur_p->size); - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); - }*/ // add drafted 
token for each sequence - const llama_token id = cur_p->data[0].id; // skip accepting draft token -- since we're only drafting one token this can't affect future outputs // smpl will accept the token if it doesn't get rejected by main model later @@ -398,5 +418,15 @@ llama_token mtp_speculative_gen_draft( //result.reserve(1); //result.push_back(id); //return result; - return id; +} + + +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens) { + mtp_kv_update_data token; + for (int i = 0; i < tokens.size(); ++i) { + token = tokens[i]; + mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx); + } + + tokens.clear(); } diff --git a/common/speculative.h b/common/speculative.h index 6ff9e822f8d..786f3ad1e8d 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -12,6 +12,12 @@ struct common_speculative_params { float p_min = 0.75f; // min probability required to accept a token in the draft }; +struct mtp_kv_update_data { + llama_token id; + int32_t n_past; + int32_t tok_idx; +}; + struct common_speculative * common_speculative_init( struct llama_context * ctx_tgt, struct llama_context * ctx_dft @@ -42,3 +48,5 @@ llama_tokens common_speculative_gen_draft( struct common_speculative_params params, const llama_tokens & prompt, llama_token id_last); + +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens); diff --git a/include/llama.h b/include/llama.h index 16dc10d4032..1de8a963cc0 100644 --- a/include/llama.h +++ b/include/llama.h @@ -544,9 +544,6 @@ extern "C" { // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); - LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params, - struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past); - // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, @@ -999,8 +996,6 @@ extern "C" { // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); - LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx); - // // Vocab // @@ -1459,16 +1454,8 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); - LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context* ctx, class llm_graph_result * res, const struct llama_ubatch& ubatch); - - LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched); - LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx, - ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx); - - LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res); - - + const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx); #ifdef __cplusplus } diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 8698d89acec..ff73429301d 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -275,7 +275,9 @@ bool llama_batch_allocr::init( } } - if (!ok) { + // TEMPORARILY DISABLING THIS SANITY CHECK + // TODO: UNDO THIS IF IT WORKS + /*if (!ok) { LLAMA_LOG_ERROR( "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" " - the last position stored in the memory module of the context (i.e. 
the KV cache) for sequence %d is X = %d\n" @@ -284,7 +286,7 @@ bool llama_batch_allocr::init( __func__, s, s, p0, s, seq_pos_min(s)); return false; - } + }*/ } if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ca713fa3890..34d514387b9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1448,8 +1448,9 @@ llm_graph_params llama_context::graph_params( } llm_graph_params llama_context::mtp_graph_params( - llm_graph_result* res, - const llama_ubatch& ubatch) { + llm_graph_result * res, + const llama_ubatch& ubatch, + const llama_memory_context_i * mctx) { size_t n_nodes = std::max(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer)); ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes); return { @@ -1462,7 +1463,7 @@ llm_graph_params llama_context::mtp_graph_params( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.mctx =*/ memory->init_batch(*balloc, 1, false).get(), + /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.n_outputs =*/ 1, /*.cb =*/ graph_get_cb(temp_sched), @@ -1470,6 +1471,21 @@ llm_graph_params llama_context::mtp_graph_params( }; } +std::unique_ptr llama_context::mtp_memory_batch(const llama_batch& batch_inp) { + const auto& vocab = model.vocab; + const auto& hparams = model.hparams; + + const int64_t n_vocab = vocab.n_tokens(); + const int64_t n_embd = hparams.n_embd; + + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, false)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return nullptr; + } + + return memory->init_batch(*balloc, 1, false); +} + ggml_status llama_context::graph_compute( ggml_cgraph * gf, bool batched) { @@ -2481,13 +2497,6 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } -ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) { - ctx->synchronize(); - - return ctx->get_embeddings_tensor(); -} - - // llama adapter API int32_t llama_set_adapter_lora( @@ -2985,42 +2994,43 @@ void llama_opt_epoch( callback_eval); } -llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) { - return ctx->mtp_graph_params(res, ubatch); -} - - -ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) { - return ctx->graph_compute(gf, batched); -} - void llama_build_and_execute_mtp_graph(struct llama_context * ctx, - ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { + const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { const auto * model = llama_get_model(ctx); auto res_mtp = std::make_unique(ctx->graph_max_nodes()); + llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp); + const auto& ubatch_mtp = mctx->get_ubatch(); - llama_ubatch ubatch_mtp; - ubatch_mtp.n_tokens = 1; - ubatch_mtp.pos = &n_past; + //llama_ubatch ubatch_mtp; + //ubatch_mtp.n_tokens = 1; + //ubatch_mtp.pos = &n_past; - auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp)); + auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get())); + ggml_backend_sched_t sched = params_mtp->sched; - auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past); + auto * last_embd = 
ctx->get_embeddings_ith(last_tok_idx); - ggml_backend_sched_t sched = params_mtp->sched; + if (mctx && !mctx->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); + } + + auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past); ggml_backend_sched_reset(sched); // clear the allocation of the previous graph ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input"); - ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors + + ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input"); + ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors + ggml_backend_sched_graph_compute(sched, gf); // execute the graph struct ggml_tensor * logits_mtp = res_mtp->get_logits();; - LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp); + //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp); if (logits_mtp) { ctx->set_logits_ith(logits_mtp, sched, last_tok_idx); diff --git a/src/llama-context.h b/src/llama-context.h index 20314304c07..e8ea3a4c9be 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -200,12 +200,14 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); - llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch); + llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx); void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i); ggml_backend_sched_t create_temp_scheduler(size_t n_nodes); + std::unique_ptr mtp_memory_batch(const llama_batch& batch_inp); + private: llm_graph_params graph_params( llm_graph_result * res, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b5184e4559d..053c72d6dc8 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1911,7 +1911,3 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck return relative_bucket; } - -ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) { - return res->get_logits(); -} diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index e539142e6b8..ed6cf969d4a 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -41,7 +41,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( } if (model.arch == LLM_ARCH_GLM4_MOE) { // GLM-4.5: Only process up to last layer, skip final NextN layer - n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers; + n_layer_cache = hparams.n_layer;// - hparams.nextn_predict_layers; } // create a context for each buffer type diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b0c096dec65..04743e01f37 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13948,7 +13948,7 @@ struct llm_build_glm4_moe : public llm_graph_context { struct llm_build_glm4_moe_mtp : public llm_graph_context { llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params, // For v0, let's rebuild the computational graph for every step + this mimics the vLLM 
impl parameterization - ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past + llama_token last_token_id, int n_past ) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13961,7 +13961,8 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { // ggml_set_i32(inp_pos, n_past); ggml_tensor * inp_pos = build_inp_pos(); - llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr; + //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr; + auto * inp_attn = build_attn_inp_kv_unified(); ggml_tensor * cur; @@ -13982,9 +13983,9 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id); ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); - ggml_tensor * prev_embedding_leaf = ggml_dup_tensor(ctx0, hidden_state_inp); - ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_leaf"); - ggml_cpy(ctx0, hidden_state_inp, prev_embedding_leaf); + ggml_tensor* prev_embedding_leaf = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, model.hparams.n_embd); + ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_input"); + ggml_set_input(prev_embedding_leaf); // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states) ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); @@ -18693,13 +18694,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params, - ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const { + llama_token last_token_id, int n_past) const { std::unique_ptr llm; switch (arch) { case LLM_ARCH_GLM4_MOE: { - llm = std::make_unique(*this, params, hidden_state_inp, last_token_id, n_past); + llm = std::make_unique(*this, params, last_token_id, n_past); } break; default: GGML_ABORT("fatal error"); @@ -19024,10 +19025,3 @@ const std::vector> & llama_internal_get_te return model->tensors_by_name; } -ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params, - ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) { - - return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past); -} - - diff --git a/src/llama-model.h b/src/llama-model.h index 77a18aca716..b28a37488f7 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -476,7 +476,7 @@ struct llama_model { // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; ggml_cgraph * build_mtp_graph(const llm_graph_params & params, - ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const; + llama_token last_token_id, int n_past) const; private: struct impl; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index e5039fe86ae..b85fa4e7691 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1278,6 +1278,7 @@ struct server_task_result_apply_lora : server_task_result { } }; + struct server_slot { int id; int id_task = -1; @@ -1295,8 +1296,9 @@ struct server_slot { common_speculative * spec = nullptr; bool has_mtp = false; + std::vector mtp_kv_update_batch; int32_t last_tok_idx = -1; - + std::vector lora; // the index relative to completion multi-task request @@ -1393,7 +1395,7 @@ struct server_slot { } bool 
need_embd() const { - return server_task_type_need_embd(task_type); + return server_task_type_need_embd(task_type) || has_mtp; } bool need_logits() const { @@ -1569,6 +1571,7 @@ struct server_slot { } }; + struct server_metrics { int64_t t_start = 0; @@ -1994,7 +1997,7 @@ struct server_context { SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } - + vocab = llama_model_get_vocab(model); n_ctx = llama_n_ctx(ctx); @@ -2124,18 +2127,21 @@ struct server_context { common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); } } - + // if model has MTP and no draft model is specified... else if (llama_model_n_nextn_layer(model) > 0) { SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model)); slot.has_mtp = true; - + // assume one speculative token (true of all well-known MTP models so far) slot.batch_spec = llama_batch_init(2, 0, 1); SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens); params_base.speculative.n_min = 0; params_base.speculative.n_max = 1; + + SRV_INF("%s\n", "MTP needs embeddings on decode, enabling"); + llama_set_embeddings(ctx, true); } SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); @@ -3383,7 +3389,11 @@ struct server_context { // embedding requires all tokens in the batch to be output const bool need_embd = server_task_type_need_embd(slot.task_type); + if (slot.has_mtp) { + slot.mtp_kv_update_batch.push_back({ cur_tok, slot.n_past, batch.n_tokens }); + } common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); + slot.cache_tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; @@ -3533,6 +3543,11 @@ struct server_context { const int tok_idx = slot.i_batch - i; + // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation + if (slot.has_mtp) { + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); + } + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); slot.last_tok_idx = tok_idx; @@ -3571,8 +3586,6 @@ struct server_context { } } - SRV_DBG("starting speculative decoding: %d\n", 1); - // do speculative decoding for (auto & slot : slots) { if (!slot.is_processing() || !slot.can_speculate()) { @@ -3631,13 +3644,9 @@ struct server_context { //draft.reserve(1); //draft.push_back(draft_id); - for (const auto& str : draft) { - SLT_DBG(slot, "%s\n", str); - } - // ignore small drafts - if (slot.params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + if (slot.params.speculative.n_min > (int)draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min); continue; } @@ -3661,8 +3670,12 @@ struct server_context { // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); - // if slot has mtp - // call + if (slot.has_mtp) { + for (int32_t i = 0; i < ids.size(); ++i) { + slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + 1 + i, i }); + } + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); + } slot.n_past += ids.size(); slot.n_decoded += ids.size(); From 382135aa3619294ab8bf87b0de4b1255ab7942f0 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Sun, 17 Aug 2025 21:54:45 -0400 Subject: [PATCH 08/35] fixed mtp kv cache update sequencing after prompt processing --- tools/server/server.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git 
a/tools/server/server.cpp b/tools/server/server.cpp index b85fa4e7691..e323f7b5210 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3543,18 +3543,19 @@ struct server_context { const int tok_idx = slot.i_batch - i; - // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation - if (slot.has_mtp) { - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); - } - llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); slot.last_tok_idx = tok_idx; + SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str()); slot.i_batch = -1; common_sampler_accept(slot.smpl, id, true); + // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation + if (slot.has_mtp) { + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); + } + slot.n_decoded += 1; const int64_t t_current = ggml_time_us(); From d72f9d5691054958cd1b139f228e5e588d3974cf Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Tue, 19 Aug 2025 01:50:34 -0400 Subject: [PATCH 09/35] kludge-y kv cache management of mtp layer --- src/llama-context.cpp | 23 +++++++++++++++++++---- src/llama-kv-cache-unified.cpp | 9 +++++++++ src/llama-kv-cache-unified.h | 3 +++ tools/server/server.cpp | 2 +- 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 34d514387b9..a09ac6d447f 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -7,6 +7,7 @@ #include "llama-mmap.h" #include "llama-model.h" #include "llama-graph.h" +#include "llama-kv-cache-unified.h" #include #include @@ -3000,7 +3001,20 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx, const auto * model = llama_get_model(ctx); auto res_mtp = std::make_unique(ctx->graph_max_nodes()); - llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp); + std::unique_ptr mctx = ctx->mtp_memory_batch(batch_inp); + + std::vector idxs; + idxs.push_back(n_past); + llama_kv_cache_unified::slot_info sinfo = { + /*.s0 =*/ 0, + /*.s1 =*/ 0, + /*.strm =*/ { 0 }, + /*.idxs =*/ { idxs }, + }; + llama_kv_cache_unified::slot_info_vec_t sinfos; + sinfos.push_back(sinfo); + + static_cast(mctx.get())->set_sinfos(sinfos); const auto& ubatch_mtp = mctx->get_ubatch(); //llama_ubatch ubatch_mtp; @@ -3012,9 +3026,10 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx, auto * last_embd = ctx->get_embeddings_ith(last_tok_idx); - if (mctx && !mctx->apply()) { - LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); - } + //if (mctx && !mctx->set_n_kv()) { + // LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); + //} + static_cast(mctx.get())->set_n_kv(); auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past); diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index ed6cf969d4a..53466264cd9 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -2322,6 +2322,11 @@ bool llama_kv_cache_unified_context::apply() { return true; } +void llama_kv_cache_unified_context::set_n_kv() { + n_kv = kv->get_n_kv(); +} + + llama_memory_status llama_kv_cache_unified_context::get_status() const { return status; } @@ -2384,6 +2389,10 @@ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, con kv->set_input_pos_bucket(dst, ubatch); } +void llama_kv_cache_unified_context::set_sinfos(llama_kv_cache_unified::slot_info_vec_t new_sinfos) { + sinfos = new_sinfos; +} + 
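// The hand-built slot_info above is what keeps the MTP head's KV entries in step
// with the main model: rather than letting the cache search for a free slot, the
// single drafted position is pinned to cell n_past, so the NextN layer writes its
// K/V exactly where that token already lives for the other layers. A small sketch
// of the same idea for a contiguous run of positions (illustrative only; it mirrors
// the slot_info layout shown above and assumes one cache cell per position, with
// idxs holding uint32_t cell indices):
//
//   static llama_kv_cache_unified::slot_info make_sinfo_range(uint32_t p0, uint32_t n) {
//       std::vector<uint32_t> idxs(n);
//       for (uint32_t i = 0; i < n; ++i) {
//           idxs[i] = p0 + i;               // cell index == absolute position
//       }
//       return { /*.s0 =*/ 0, /*.s1 =*/ 0, /*.strm =*/ { 0 }, /*.idxs =*/ { idxs } };
//   }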
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { // the FA kernels require padding to avoid extra runtime boundary checks return cparams.flash_attn ? 256u : 32u; diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 342a675962e..c02607c2d0f 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -340,6 +340,7 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { // uint32_t get_n_kv() const; + void set_n_kv(); // TODO: temporary bool get_supports_set_rows() const; @@ -362,6 +363,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_sinfos(slot_info_vec_t new_sinfos); + private: llama_memory_status status; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index e323f7b5210..1191564dd2b 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3545,7 +3545,7 @@ struct server_context { llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); slot.last_tok_idx = tok_idx; - SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str()); + //SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str()); slot.i_batch = -1; From 471e026327cca9f6f58aeefe32129a6cb9390f4f Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Tue, 19 Aug 2025 23:10:56 -0400 Subject: [PATCH 10/35] fixed vram leak --- src/llama-context.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a09ac6d447f..62d7898b5fe 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -3050,5 +3050,7 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx, if (logits_mtp) { ctx->set_logits_ith(logits_mtp, sched, last_tok_idx); } + + ggml_backend_sched_free(sched); } From 98bc0c6bf223f425f4ecea14f13fc46101f1b44a Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Tue, 26 Aug 2025 01:26:51 -0400 Subject: [PATCH 11/35] replace standard sampler with greedy sampler for mtp draft --- common/speculative.cpp | 4 +++- include/llama.h | 2 +- src/llama-context.cpp | 26 +++++++++++++++++++++----- src/llama-model.cpp | 4 ++++ 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 9f8384abb13..edeffe2d8ee 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -387,9 +387,10 @@ llama_token mtp_speculative_gen_draft( /*.logits = */ logits_data }; - llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx); + return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx); //LOG_INF("updating kv cache for n_past: %d\n", n_past); + /* if (!smpl) { return -1; } @@ -405,6 +406,7 @@ llama_token mtp_speculative_gen_draft( const llama_token id = cur_p->data[0].id; return id; } + */ // LOG_INF("cur_p->size: %d\n", cur_p->size); diff --git a/include/llama.h b/include/llama.h index 1de8a963cc0..015c777763b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1454,7 +1454,7 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); - LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx, + LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx, const llama_batch batch_inp, llama_token last_token_id, int32_t 
n_past, int32_t last_tok_idx); #ifdef __cplusplus diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 62d7898b5fe..1f04b72145b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2995,7 +2995,7 @@ void llama_opt_epoch( callback_eval); } -void llama_build_and_execute_mtp_graph(struct llama_context * ctx, +llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx, const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { const auto * model = llama_get_model(ctx); @@ -3044,13 +3044,29 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx, ggml_backend_sched_graph_compute(sched, gf); // execute the graph - struct ggml_tensor * logits_mtp = res_mtp->get_logits();; + //struct ggml_tensor * logits_mtp = res_mtp->get_logits(); + //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp); - if (logits_mtp) { - ctx->set_logits_ith(logits_mtp, sched, last_tok_idx); - } + //if (logits_mtp) { + // ctx->set_logits_ith(logits_mtp, sched, last_tok_idx); + //} + struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result"); + + + llama_token token_id = 0; // The C++ variable to hold the result. + + // ggml_backend_tensor_get is the function for GPU->CPU copies. + // We are copying a single 32-bit integer. + ggml_backend_tensor_get( + token_id_tensor, + &token_id, // Pointer to our C++ variable + 0, // Starting offset in bytes + sizeof(llama_token) // Number of bytes to copy + ); ggml_backend_sched_free(sched); + + return token_id; } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 04743e01f37..f9921e4b6d4 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -14100,6 +14100,10 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { res->t_logits = cur; ggml_build_forward_expand(gf, res->t_logits); + + struct ggml_tensor * token_id_tensor = ggml_argmax(ctx0, cur); + ggml_set_name(token_id_tensor, "mtp_argmax_result"); + ggml_build_forward_expand(gf, token_id_tensor); } }; From 9fab53e4388c20aef497efd82e86dcb99ca58064 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Tue, 2 Sep 2025 17:14:09 -0400 Subject: [PATCH 12/35] fixed mtp kv cache update step in cases where prompt size > n_batch and n_ubatch --- common/speculative.cpp | 13 ++++++++++--- common/speculative.h | 2 +- tools/server/server.cpp | 20 ++++++++++++++++---- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index edeffe2d8ee..c1d9149ea13 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -423,11 +423,18 @@ llama_token mtp_speculative_gen_draft( } -void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens) { +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start, size_t n_tokens) { mtp_kv_update_data token; - for (int i = 0; i < tokens.size(); ++i) { + + if (n_tokens < 0) { + n_tokens = tokens.size(); + } + + for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) { token = tokens[i]; - mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx); + //fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start)); + + mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start); } tokens.clear(); diff --git a/common/speculative.h b/common/speculative.h index 786f3ad1e8d..bb29c07bb6a 100644 --- a/common/speculative.h +++ 
b/common/speculative.h @@ -49,4 +49,4 @@ llama_tokens common_speculative_gen_draft( const llama_tokens & prompt, llama_token id_last); -void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens); +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start = 0, size_t n_tokens = -1); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 1191564dd2b..34053cd0403 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1405,9 +1405,14 @@ struct server_slot { // if the context does not have a memory module then all embeddings have to be computed within a single ubatch // also we cannot split if the pooling would require any past tokens bool can_split() const { + //fprintf(stderr, "need_embd() %d\n", need_embd()); + //fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr); + //fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + return !need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) || + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for whole batch? } bool can_batch_with(server_slot & other_slot) const { @@ -3508,6 +3513,13 @@ struct server_context { continue; // continue loop of n_batch } + for (auto & slot : slots) { + // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation + if (slot.has_mtp) { + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens); + } + } + // move the head of the batch forward with the number of tokens we just processed i_next = i + n_tokens; @@ -3552,9 +3564,9 @@ struct server_context { common_sampler_accept(slot.smpl, id, true); // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation - if (slot.has_mtp) { - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); - } + //if (slot.has_mtp) { + // mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); + //} slot.n_decoded += 1; From 07670a22c63b1fa335d6ec1c4a1e4255a920848c Mon Sep 17 00:00:00 2001 From: samuel Date: Wed, 3 Sep 2025 13:25:21 -0300 Subject: [PATCH 13/35] feat: implemented sampling for MTP --- common/speculative.cpp | 50 ++++++----------------------------------- common/speculative.h | 8 +++---- include/llama.h | 4 ++-- src/llama-context.cpp | 51 +++++++++++++++++++++--------------------- src/llama-model.cpp | 16 ++++--------- 5 files changed, 43 insertions(+), 86 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index c1d9149ea13..8d849df94b8 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -370,56 +370,20 @@ llama_token mtp_speculative_gen_draft( int32_t n_past, int32_t last_tok_idx) { - llama_token token_data[] = { id_last }; - llama_pos pos_data[] = { n_past }; - int32_t n_seq_id_data[] = { 1 }; - llama_seq_id seq_id_data_internal[] = { 0 }; - llama_seq_id* seq_id_data[] = {seq_id_data_internal}; - int8_t logits_data[] = { (int8_t) (smpl != nullptr) }; - - llama_batch batch = { - /*.n_tokens = */ 1, - /*.token = */ token_data, - /*.embd = */ nullptr, - /*.pos = */ pos_data, - /*.n_seq_id = */ n_seq_id_data, - /*.seq_id = */ seq_id_data, - /*.logits = */ logits_data - }; - - return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx); - //LOG_INF("updating kv cache for n_past: 
%d\n", n_past); - - /* if (!smpl) { return -1; } - else { - common_sampler_sample(smpl, ctx, last_tok_idx, true); - const auto* cur_p = common_sampler_get_candidates(smpl); - //for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) { - // LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - // k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); - //} - - const llama_token id = cur_p->data[0].id; - return id; - } - */ - // LOG_INF("cur_p->size: %d\n", cur_p->size); + llama_batch batch = llama_batch_init(1, 0, 1); + common_batch_add(batch, id_last, n_past, {0}, true); + llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx); - // add drafted token for each sequence + llama_token id = common_sampler_sample(smpl, ctx, last_tok_idx, true); - // skip accepting draft token -- since we're only drafting one token this can't affect future outputs - // smpl will accept the token if it doesn't get rejected by main model later - // common_sampler_accept(smpl, id, true); + llama_batch_free(batch); - //llama_tokens result; - //result.reserve(1); - //result.push_back(id); - //return result; + return id; } @@ -438,4 +402,4 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start = 0, size_t n_tokens = -1); diff --git a/include/llama.h b/include/llama.h index 015c777763b..e43cd83468d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1454,8 +1454,8 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); - LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx, - const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx); + LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx, + const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx); #ifdef __cplusplus } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 1f04b72145b..fb285a8d297 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2995,7 +2995,7 @@ void llama_opt_epoch( callback_eval); } -llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx, +void llama_build_and_execute_mtp_graph(struct llama_context * ctx, const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { const auto * model = llama_get_model(ctx); @@ -3033,6 +3033,12 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx, auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past); + if (!gf) { + LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__); + if (sched) ggml_backend_sched_free(sched); + return; + } + ggml_backend_sched_reset(sched); // clear the allocation of the previous graph ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it @@ -3044,29 +3050,24 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx, ggml_backend_sched_graph_compute(sched, gf); // execute the graph - //struct ggml_tensor * logits_mtp = res_mtp->get_logits(); - - //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp); - - //if (logits_mtp) { - // ctx->set_logits_ith(logits_mtp, sched, last_tok_idx); - //} - struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result"); - - - llama_token token_id = 0; // The C++ variable to hold the result. 
- - // ggml_backend_tensor_get is the function for GPU->CPU copies. - // We are copying a single 32-bit integer. - ggml_backend_tensor_get( - token_id_tensor, - &token_id, // Pointer to our C++ variable - 0, // Starting offset in bytes - sizeof(llama_token) // Number of bytes to copy - ); + struct ggml_tensor * logits_mtp = res_mtp->get_logits(); + + if (logits_mtp) { + float * logits_dest = ctx->get_logits_ith(last_tok_idx); + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp); + if (backend_res) { + // ggml_backend_tensor_get is the function for GPU->CPU copies. + // We are copying a single 32-bit integer. + ggml_backend_tensor_get(logits_mtp, + logits_dest, // Pointer to our C++ variable + 0, // Starting offset in bytes + ggml_nbytes(logits_mtp)); // Number of bytes to copy + } else { + LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__); + } + } else { + LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__); + } ggml_backend_sched_free(sched); - - return token_id; -} - +} \ No newline at end of file diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f9921e4b6d4..dd4bf211b7e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13950,6 +13950,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization llama_token last_token_id, int n_past ) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13964,8 +13965,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr; auto * inp_attn = build_attn_inp_kv_unified(); - ggml_tensor * cur; - // get MTP embedding for last (conventionally sampled) token // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); // LLAMA_LOG_INFO("step: '%d'\n", 5641); @@ -13979,7 +13978,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id); //ggml_set_no_alloc(ctx0, true); - + ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id); ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); @@ -13994,9 +13993,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat - - cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj - + ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj // now proceed through last layer (skipped in main model) ggml_tensor * inpSA = cur; @@ -14096,14 +14093,9 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur); - + res->t_logits = cur; - ggml_build_forward_expand(gf, res->t_logits); - - struct ggml_tensor * token_id_tensor = ggml_argmax(ctx0, cur); - ggml_set_name(token_id_tensor, "mtp_argmax_result"); - ggml_build_forward_expand(gf, token_id_tensor); } }; From 5a5bce85777041d841393b4396e28f8e3065bb10 Mon Sep 17 00:00:00 2001 From: samuel Date: Wed, 3 Sep 2025 17:56:14 -0300 Subject: [PATCH 14/35] fix: add sample acceptance --- common/speculative.cpp | 8 ++++++++ 1 file changed, 8 
insertions(+) diff --git a/common/speculative.cpp b/common/speculative.cpp index 8d849df94b8..5edd4aa815b 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -381,6 +381,14 @@ llama_token mtp_speculative_gen_draft( llama_token id = common_sampler_sample(smpl, ctx, last_tok_idx, true); + const auto * cur_p = common_sampler_get_candidates(smpl); + for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + } + + common_sampler_accept(smpl, id, true); + llama_batch_free(batch); return id; From 8742ce0e39823eeb101bb5b6099ff4ca7be10c6e Mon Sep 17 00:00:00 2001 From: samuel Date: Sat, 6 Sep 2025 00:21:18 -0300 Subject: [PATCH 15/35] feat: apply logits + greedy sampler --- common/sampling.cpp | 4 ++++ common/sampling.h | 2 ++ common/speculative.cpp | 19 +++++++++++++------ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index a5824ebeedb..452cefee3b9 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -582,3 +582,7 @@ std::vector common_sampler_types_from_chars(const std::stri return samplers; } + +void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) { + llama_sampler_apply(gsmpl->chain, cur_p); +} \ No newline at end of file diff --git a/common/sampling.h b/common/sampling.h index 2064421db4e..b424d7d6d70 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -105,3 +105,5 @@ std::vector common_sampler_types_from_chars(const std: llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, const char * grammar_data); + +void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p); \ No newline at end of file diff --git a/common/speculative.cpp b/common/speculative.cpp index 5edd4aa815b..77ed75913d5 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -379,15 +379,22 @@ llama_token mtp_speculative_gen_draft( llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx); - llama_token id = common_sampler_sample(smpl, ctx, last_tok_idx, true); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int n_vocab = llama_n_vocab(vocab); - const auto * cur_p = common_sampler_get_candidates(smpl); - for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + llama_token_data_array * cur_p = common_sampler_get_candidates(smpl); + + cur_p->size = n_vocab; + for (int i = 0; i < n_vocab; ++i) { + cur_p->data[i].id = i; + cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i]; } + cur_p->sorted = false; + + common_sampler_apply_chain(smpl, cur_p); - common_sampler_accept(smpl, id, true); + const llama_token id = cur_p->data[0].id; llama_batch_free(batch); From 1318b2de82716710b9853e07bd640443a5a025bb Mon Sep 17 00:00:00 2001 From: samuel Date: Sun, 14 Sep 2025 10:22:59 -0300 Subject: [PATCH 16/35] mtp-batch (wip): move mtp execution to batch format --- common/speculative.cpp | 47 +++++++------ include/llama.h | 5 +- src/llama-batch.cpp | 15 ++-- src/llama-context.cpp | 152 +++++++++++++++++++++++++---------------- src/llama-graph.cpp | 20 ++++++ src/llama-graph.h | 1 + 
src/llama-model.cpp | 52 ++++---------- src/llama-model.h | 3 +- 8 files changed, 166 insertions(+), 129 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 77ed75913d5..d13666c9f95 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -374,47 +374,54 @@ llama_token mtp_speculative_gen_draft( return -1; } - llama_batch batch = llama_batch_init(1, 0, 1); - common_batch_add(batch, id_last, n_past, {0}, true); + llama_batch mtp_batch = llama_batch_init(1, 0, 1); + common_batch_add(mtp_batch, id_last, n_past, {0}, true); + mtp_batch.update_mtp_kv = true; - llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx); + llama_decode(ctx, mtp_batch); + llama_batch_free(mtp_batch); const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_n_vocab(vocab); - llama_token_data_array * cur_p = common_sampler_get_candidates(smpl); - cur_p->size = n_vocab; for (int i = 0; i < n_vocab; ++i) { cur_p->data[i].id = i; - cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i]; + cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // TODO: check if position 0 is the right } cur_p->sorted = false; - common_sampler_apply_chain(smpl, cur_p); - - const llama_token id = cur_p->data[0].id; - - llama_batch_free(batch); - - return id; + + return cur_p->data[0].id; } void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start, size_t n_tokens) { - mtp_kv_update_data token; - + if (tokens.empty()) { + tokens.clear(); + return; + } if (n_tokens < 0) { n_tokens = tokens.size(); } + const size_t n_to_process = std::min((size_t)tokens.size(), n_tokens); + + LOG_DBG( + "[MTP BATCHING] mtp_update_kv_cache call for %zu tokens.\n", + n_to_process + ); + llama_batch mtp_batch = llama_batch_init(n_to_process, 0, 1); + + for (size_t i = 0; i < n_to_process; ++i) { + const mtp_kv_update_data& token_data = tokens[i]; + common_batch_add(mtp_batch, token_data.id, token_data.n_past, {0}, false); + } - for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) { - token = tokens[i]; - //fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start)); + mtp_batch.update_mtp_kv = true; - mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start); - } + llama_decode(ctx, mtp_batch); + llama_batch_free(mtp_batch); tokens.clear(); } \ No newline at end of file diff --git a/include/llama.h b/include/llama.h index e43cd83468d..0916bb9c5f2 100644 --- a/include/llama.h +++ b/include/llama.h @@ -230,6 +230,7 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" + bool update_mtp_kv; } llama_batch; enum llama_model_kv_override_type { @@ -1454,8 +1455,8 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); - LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx, - const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx); + // LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx, + // const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx); #ifdef __cplusplus } diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index ff73429301d..589b138531b 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -834,13 +834,14 @@ struct llama_batch 
llama_batch_get_one(
 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
 llama_batch batch = {
- /*n_tokens =*/ 0,
- /*tokens =*/ nullptr,
- /*embd =*/ nullptr,
- /*pos =*/ nullptr,
- /*n_seq_id =*/ nullptr,
- /*seq_id =*/ nullptr,
- /*logits =*/ nullptr,
+ /*n_tokens =*/ 0,
+ /*tokens =*/ nullptr,
+ /*embd =*/ nullptr,
+ /*pos =*/ nullptr,
+ /*n_seq_id =*/ nullptr,
+ /*seq_id =*/ nullptr,
+ /*logits =*/ nullptr,
+ /*update_mtp_kv =*/ false,
 };
 if (embd) {
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index fb285a8d297..69549edb1c5 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1070,6 +1070,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 };
 int64_t n_outputs_prev = 0;
+ const bool do_mtp_kv_update = batch_inp.update_mtp_kv;
 do {
 const auto & ubatch = mctx->get_ubatch();
@@ -1129,6 +1130,39 @@ int llama_context::decode(const llama_batch & batch_inp) {
 // ggml_graph_dump_dot(gf, NULL, "llama.dot");
 //}
+ if (do_mtp_kv_update) {
+ LLAMA_LOG_INFO(
+ "[MTP BATCHING] Processing MTP KV update for a ubatch of %u tokens.\n",
+ ubatch.n_tokens
+ );
+ auto res_mtp = std::make_unique(graph_max_nodes());
+
+ auto params_mtp = mtp_graph_params(res_mtp.get(), ubatch, mctx.get());
+ ggml_backend_sched_t sched_mtp = params_mtp.sched;
+
+ auto * gf_mtp = model.build_mtp_graph(params_mtp);
+ if (gf_mtp) {
+ ggml_backend_sched_alloc_graph(sched_mtp, gf_mtp);
+
+ ggml_tensor* prev_embedding_tensor = res->get_embd();
+ ggml_tensor* embd_input_mtp = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embeddings_batch_input");
+
+ // ggml_backend_tensor_set(embd_input_mtp, prev_embedding_tensor->data, 0, ggml_nbytes(prev_embedding_tensor));
+ ggml_backend_tensor_copy(prev_embedding_tensor, embd_input_mtp);
+
+ ggml_backend_sched_graph_compute(sched_mtp, gf_mtp);
+
+ if (ubatch.output[0]) {
+ struct ggml_tensor * logits_mtp = res_mtp->get_logits();
+ if (logits_mtp) {
+ float * logits_dest = logits + n_outputs_prev * n_vocab;
+ ggml_backend_tensor_get(logits_mtp, logits_dest, 0, ggml_nbytes(logits_mtp));
+ }
+ }
+ }
+ ggml_backend_sched_free(sched_mtp);
+ }
+
 auto * t_logits = res->get_logits();
 auto * t_embd = cparams.embeddings ?
res->get_embd() : nullptr; embd_tensor = res->get_embd(); @@ -2995,79 +3029,79 @@ void llama_opt_epoch( callback_eval); } -void llama_build_and_execute_mtp_graph(struct llama_context * ctx, - const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { +// void llama_build_and_execute_mtp_graph(struct llama_context * ctx, +// const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { - const auto * model = llama_get_model(ctx); +// const auto * model = llama_get_model(ctx); - auto res_mtp = std::make_unique(ctx->graph_max_nodes()); - std::unique_ptr mctx = ctx->mtp_memory_batch(batch_inp); +// auto res_mtp = std::make_unique(ctx->graph_max_nodes()); +// std::unique_ptr mctx = ctx->mtp_memory_batch(batch_inp); - std::vector idxs; - idxs.push_back(n_past); - llama_kv_cache_unified::slot_info sinfo = { - /*.s0 =*/ 0, - /*.s1 =*/ 0, - /*.strm =*/ { 0 }, - /*.idxs =*/ { idxs }, - }; - llama_kv_cache_unified::slot_info_vec_t sinfos; - sinfos.push_back(sinfo); +// std::vector idxs; +// idxs.push_back(n_past); +// llama_kv_cache_unified::slot_info sinfo = { +// /*.s0 =*/ 0, +// /*.s1 =*/ 0, +// /*.strm =*/ { 0 }, +// /*.idxs =*/ { idxs }, +// }; +// llama_kv_cache_unified::slot_info_vec_t sinfos; +// sinfos.push_back(sinfo); - static_cast(mctx.get())->set_sinfos(sinfos); - const auto& ubatch_mtp = mctx->get_ubatch(); +// static_cast(mctx.get())->set_sinfos(sinfos); +// const auto& ubatch_mtp = mctx->get_ubatch(); - //llama_ubatch ubatch_mtp; - //ubatch_mtp.n_tokens = 1; - //ubatch_mtp.pos = &n_past; +// //llama_ubatch ubatch_mtp; +// //ubatch_mtp.n_tokens = 1; +// //ubatch_mtp.pos = &n_past; - auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get())); - ggml_backend_sched_t sched = params_mtp->sched; +// auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get())); +// ggml_backend_sched_t sched = params_mtp->sched; - auto * last_embd = ctx->get_embeddings_ith(last_tok_idx); +// auto * last_embd = ctx->get_embeddings_ith(last_tok_idx); - //if (mctx && !mctx->set_n_kv()) { - // LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); - //} - static_cast(mctx.get())->set_n_kv(); +// //if (mctx && !mctx->set_n_kv()) { +// // LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); +// //} +// static_cast(mctx.get())->set_n_kv(); - auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past); +// auto * gf = model->build_mtp_graph(*params_mtp); - if (!gf) { - LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__); - if (sched) ggml_backend_sched_free(sched); - return; - } +// if (!gf) { +// LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__); +// if (sched) ggml_backend_sched_free(sched); +// return; +// } - ggml_backend_sched_reset(sched); // clear the allocation of the previous graph - ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it +// ggml_backend_sched_reset(sched); // clear the allocation of the previous graph +// ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it - ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input"); - ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors +// ggml_tensor * 
mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input"); +// ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors - ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input"); - ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors +// ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input"); +// ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors - ggml_backend_sched_graph_compute(sched, gf); // execute the graph +// ggml_backend_sched_graph_compute(sched, gf); // execute the graph - struct ggml_tensor * logits_mtp = res_mtp->get_logits(); +// struct ggml_tensor * logits_mtp = res_mtp->get_logits(); - if (logits_mtp) { - float * logits_dest = ctx->get_logits_ith(last_tok_idx); - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp); - if (backend_res) { - // ggml_backend_tensor_get is the function for GPU->CPU copies. - // We are copying a single 32-bit integer. - ggml_backend_tensor_get(logits_mtp, - logits_dest, // Pointer to our C++ variable - 0, // Starting offset in bytes - ggml_nbytes(logits_mtp)); // Number of bytes to copy - } else { - LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__); - } - } else { - LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__); - } +// if (logits_mtp) { +// float * logits_dest = ctx->get_logits_ith(last_tok_idx); +// ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp); +// if (backend_res) { +// // ggml_backend_tensor_get is the function for GPU->CPU copies. +// // We are copying a single 32-bit integer. 
+// ggml_backend_tensor_get(logits_mtp, +// logits_dest, // Pointer to our C++ variable +// 0, // Starting offset in bytes +// ggml_nbytes(logits_mtp)); // Number of bytes to copy +// } else { +// LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__); +// } +// } else { +// LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__); +// } - ggml_backend_sched_free(sched); -} \ No newline at end of file +// ggml_backend_sched_free(sched); +// } \ No newline at end of file diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 053c72d6dc8..be7de40454e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1074,6 +1074,26 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { return cur; } + +ggml_tensor * llm_graph_context::build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const { + auto inp = std::make_unique(); + ggml_tensor * cur = nullptr; + + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_name(inp->tokens, "mtp_inp_tokens"); + ggml_set_input(inp->tokens); + + cur = ggml_get_rows(ctx0, mtp_tok_embd, inp->tokens); + } else { + GGML_ABORT("fatal error: MTP update expects token IDs, not embeddings"); + } + + cb(cur, "mtp_inp_embd", -1); + res->add_input(std::move(inp)); + return cur; +} + ggml_tensor * llm_graph_context::build_inp_pos() const { auto inp = std::make_unique(hparams.n_pos_per_embd()); diff --git a/src/llama-graph.h b/src/llama-graph.h index 10702ed219c..57772d9c158 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -664,6 +664,7 @@ struct llm_graph_context { // ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; + ggml_tensor * build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const; ggml_tensor * build_inp_pos() const; ggml_tensor * build_inp_attn_scale() const; ggml_tensor * build_inp_out_ids() const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index dd4bf211b7e..cce99ef3b18 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13946,54 +13946,29 @@ struct llm_build_glm4_moe : public llm_graph_context { }; struct llm_build_glm4_moe_mtp : public llm_graph_context { - llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params, - // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization - llama_token last_token_id, int n_past - ) : llm_graph_context(params) { + llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - // Assuming a single MTP layer at the end const int il = hparams.n_layer - 1; const auto & mtp_layer = model.layers[il]; - // ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); - // ggml_set_i32(inp_pos, n_past); ggml_tensor * inp_pos = build_inp_pos(); - - //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr; auto * inp_attn = build_attn_inp_kv_unified(); - // get MTP embedding for last (conventionally sampled) token - // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); - // LLAMA_LOG_INFO("step: '%d'\n", 5641); - // ggml_set_i32(inp_token_id, last_token_id); - //ggml_set_no_alloc(ctx0, false); - //LLAMA_LOG_INFO("last token id: '%d'\n", last_token_id); - - ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); - ggml_set_name(inp_token_id, 
"mtp_token_id_input"); - ggml_set_input(inp_token_id); - - //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id); - //ggml_set_no_alloc(ctx0, true); + ggml_tensor* prev_embeddings_batch = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_embd, n_tokens); + ggml_set_name(prev_embeddings_batch, "mtp_prev_embeddings_batch_input"); + ggml_set_input(prev_embeddings_batch); - ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id); - ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); - - ggml_tensor* prev_embedding_leaf = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, model.hparams.n_embd); - ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_input"); - ggml_set_input(prev_embedding_leaf); + ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens); - // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states) - ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); - //token_emb_norm = ggml_cont(ctx0, token_emb_norm); - //hidden_state_norm = ggml_cont(ctx0, hidden_state_norm); + ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * hidden_state_norm = build_norm(prev_embeddings_batch, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); - ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat + ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); - ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj + ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // now proceed through last layer (skipped in main model) ggml_tensor * inpSA = cur; @@ -14090,11 +14065,11 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); - cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur); - + res->t_logits = cur; + ggml_build_forward_expand(gf, res->t_logits); } }; @@ -18689,14 +18664,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { return llm->res->get_gf(); } -ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params, - llama_token last_token_id, int n_past) const { +ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params) const { std::unique_ptr llm; switch (arch) { case LLM_ARCH_GLM4_MOE: { - llm = std::make_unique(*this, params, last_token_id, n_past); + llm = std::make_unique(*this, params); } break; default: GGML_ABORT("fatal error"); diff --git a/src/llama-model.h b/src/llama-model.h index b28a37488f7..f5f9452a5b1 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -475,8 +475,7 @@ struct llama_model { // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; - ggml_cgraph * build_mtp_graph(const llm_graph_params & params, - llama_token last_token_id, int n_past) const; + ggml_cgraph * build_mtp_graph(const llm_graph_params& params) const; private: struct impl; From 042eb8a829876ed175320df9c8133bcea0c40460 Mon Sep 17 00:00:00 2001 From: samuel Date: Sun, 21 Sep 2025 21:29:00 -0300 Subject: [PATCH 17/35] mtp-batch (wip): merge mtp and model graph --- src/llama-context.cpp | 84 ++++++++--------------------------- src/llama-context.h | 8 ++-- src/llama-graph.h | 1 + 
src/llama-model.cpp | 98 +++++++++++++++++------------------------ src/llama-model.h | 1 - tools/server/server.cpp | 8 +++- 6 files changed, 70 insertions(+), 130 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 69549edb1c5..754ad6a041c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -729,7 +729,8 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { +llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret, + bool do_mtp_kv_update) { if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; @@ -741,7 +742,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // the new graph parameters // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype); + const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update); if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); @@ -781,7 +782,15 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); } + const int64_t t_exec_start_us = ggml_time_us(); const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); + const int64_t t_exec_end_us = ggml_time_us(); + LLAMA_LOG_INFO( + "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n", + (t_exec_end_us - t_exec_start_us) / 1000.0, + ubatch.n_tokens, + do_mtp_kv_update ? 
"yes" : "no" + ); if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); ret = status; @@ -850,7 +859,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false); cparams.causal_attn = causal_attn_org; @@ -1092,7 +1101,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache @@ -1130,39 +1139,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - if (do_mtp_kv_update) { - LLAMA_LOG_INFO( - "[MTP BATCHING] Processando MTP KV update para um ubatch de %u tokens.\n", - ubatch.n_tokens - ); - auto res_mtp = std::make_unique(graph_max_nodes()); - - auto params_mtp = mtp_graph_params(res_mtp.get(), ubatch, mctx.get()); - ggml_backend_sched_t sched_mtp = params_mtp.sched; - - auto * gf_mtp = model.build_mtp_graph(params_mtp); - if (gf_mtp) { - ggml_backend_sched_alloc_graph(sched_mtp, gf_mtp); - - ggml_tensor* prev_embedding_tensor = res->get_embd(); - ggml_tensor* embd_input_mtp = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embeddings_batch_input"); - - // ggml_backend_tensor_set(embd_input_mtp, prev_embedding_tensor->data, 0, ggml_nbytes(prev_embedding_tensor)); - ggml_backend_tensor_copy(prev_embedding_tensor, embd_input_mtp); - - ggml_backend_sched_graph_compute(sched_mtp, gf_mtp); - - if (ubatch.output[0]) { - struct ggml_tensor * logits_mtp = res_mtp->get_logits(); - if (logits_mtp) { - float * logits_dest = logits + n_outputs_prev * n_vocab; - ggml_backend_tensor_get(logits_mtp, logits_dest, 0, ggml_nbytes(logits_mtp)); - } - } - } - ggml_backend_sched_free(sched_mtp); - } - auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? 
res->get_embd() : nullptr; embd_tensor = res->get_embd(); @@ -1442,7 +1418,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false); res->reset(); @@ -1462,8 +1438,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u llm_graph_params llama_context::graph_params( llm_graph_result * res, const llama_ubatch & ubatch, - const llama_memory_context_i * mctx, - llm_graph_type gtype) const { + const llama_memory_context_i * mctx, + llm_graph_type gtype, + bool update_mtp_kv) const { return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, @@ -1476,36 +1453,13 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ &loras, /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.update_mtp_kv =*/ update_mtp_kv, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, }; } -llm_graph_params llama_context::mtp_graph_params( - llm_graph_result * res, - const llama_ubatch& ubatch, - const llama_memory_context_i * mctx) { - size_t n_nodes = std::max(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer)); - ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes); - return { - /*.arch =*/ model.arch, - /*.hparams =*/ model.hparams, - /*.cparams =*/ cparams, - /*.ubatch =*/ ubatch, - /*.gtype =*/ LLM_GRAPH_TYPE_DECODER, - /*.sched =*/ temp_sched, - /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ &cvec, - /*.loras =*/ &loras, - /*.mctx =*/ mctx, - /*.cross =*/ &cross, - /*.n_outputs =*/ 1, - /*.cb =*/ graph_get_cb(temp_sched), - /*.res =*/ res, - }; -} - std::unique_ptr llama_context::mtp_memory_batch(const llama_batch& batch_inp) { const auto& vocab = model.vocab; const auto& hparams = model.hparams; @@ -2240,7 +2194,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false); res->reset(); diff --git a/src/llama-context.h b/src/llama-context.h index e8ea3a4c9be..88f63e88d73 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -99,7 +99,8 @@ struct llama_context { const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, - ggml_status & ret); + ggml_status & ret, + const bool do_mtp_kv_update); int encode(const llama_batch & batch_inp); int decode(const llama_batch & batch_inp); @@ -200,8 +201,6 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); - llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx); - void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i); ggml_backend_sched_t create_temp_scheduler(size_t n_nodes); @@ -213,7 +212,8 @@ struct llama_context { llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx, - llm_graph_type gtype) const; + llm_graph_type gtype, + bool update_mtp_kv) const; llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const; diff --git a/src/llama-graph.h b/src/llama-graph.h index 
57772d9c158..3f8fe8e979b 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -402,6 +402,7 @@ struct llm_graph_params { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + bool update_mtp_kv; uint32_t n_outputs; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index cce99ef3b18..c4998707107 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13787,7 +13787,8 @@ struct llm_build_glm4 : public llm_graph_context { }; struct llm_build_glm4_moe : public llm_graph_context { - llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params, bool build_mtp_path) + : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13932,68 +13933,57 @@ struct llm_build_glm4_moe : public llm_graph_context { cur = inpL; cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); - cb(cur, "result_norm", -1); + // cb(cur, "result_norm", -1); res->t_embd = cur; - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; + if (build_mtp_path) { + const int il_mtp = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il_mtp]; + + ggml_tensor * mtp_logits = build_mtp_tail(mtp_layer, cur, n_embd_head); + res->t_logits = mtp_logits; + } else { + // lm_head + cur = build_lora_mm(model.output, cur); + res->t_logits = cur; + } - ggml_build_forward_expand(gf, cur); + ggml_build_forward_expand(gf, res->t_logits); } -}; - -struct llm_build_glm4_moe_mtp : public llm_graph_context { - llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); +private: + ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, + int64_t n_embd_head + ) { const int il = hparams.n_layer - 1; - const auto & mtp_layer = model.layers[il]; ggml_tensor * inp_pos = build_inp_pos(); auto * inp_attn = build_attn_inp_kv_unified(); - - ggml_tensor* prev_embeddings_batch = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_embd, n_tokens); - ggml_set_name(prev_embeddings_batch, "mtp_prev_embeddings_batch_input"); - ggml_set_input(prev_embeddings_batch); - ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens); ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); - ggml_tensor * hidden_state_norm = build_norm(prev_embeddings_batch, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); - + ggml_tensor * hidden_state_norm = build_norm(prev_embeddings, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); - ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // now proceed through last layer (skipped in main model) ggml_tensor * inpSA = cur; - // Pre-attention norm for the MTP block - ggml_tensor* attn_inp = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il); + cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il); // self-attention { ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur); - if (mtp_layer.bq) { - Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq); - } + if (mtp_layer.bq) Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq); cb(Qcur, "Qcur", il); 
ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur); - if (mtp_layer.bk) { - Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk); - } + if (mtp_layer.bk) Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk); cb(Kcur, "Kcur", il); ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur); - if (mtp_layer.bv) { - Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv); - } + if (mtp_layer.bv) Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv); cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); @@ -14025,10 +14015,10 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - + cur = build_attn(inp_attn, - mtp_layer.wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + mtp_layer.wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); @@ -14068,9 +14058,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur); - res->t_logits = cur; - - ggml_build_forward_expand(gf, res->t_logits); + return cur; } }; @@ -18299,8 +18287,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { + const int64_t t_start_us = ggml_time_us(); + std::unique_ptr llm; + const bool build_mtp = params.update_mtp_kv; + switch (arch) { case LLM_ARCH_LLAMA: { @@ -18519,7 +18511,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_GLM4_MOE: { - llm = std::make_unique(*this, params); + llm = std::make_unique(*this, params, build_mtp); } break; case LLM_ARCH_BITNET: { @@ -18660,22 +18652,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); - - return llm->res->get_gf(); -} - -ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params) const { - std::unique_ptr llm; - - switch (arch) { - case LLM_ARCH_GLM4_MOE: - { - llm = std::make_unique(*this, params); - } break; - default: - GGML_ABORT("fatal error"); - } - + const int64_t t_end_us = ggml_time_us(); // Fim do cronômetro + LLAMA_LOG_INFO( + "[PERF] Graph build time: %.2f ms (MTP path: %s)\n", + (t_end_us - t_start_us) / 1000.0, + build_mtp ? 
"yes" : "no" + ); return llm->res->get_gf(); } diff --git a/src/llama-model.h b/src/llama-model.h index f5f9452a5b1..6fcd74d57fd 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -475,7 +475,6 @@ struct llama_model { // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; - ggml_cgraph * build_mtp_graph(const llm_graph_params& params) const; private: struct impl; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 34053cd0403..84a0e6fc158 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1739,7 +1739,7 @@ struct server_queue { while (true) { QUE_DBG("%s", "processing new tasks\n"); - + const int64_t t_turn_start_us = ggml_time_us(); while (true) { std::unique_lock lock(mutex_tasks); if (!running) { @@ -1762,7 +1762,11 @@ struct server_queue { QUE_DBG("%s", "update slots\n"); callback_update_slots(); - + const int64_t t_turn_end_us = ggml_time_us(); + SRV_DBG( + "[PERF] Server turn time: %.2f ms\n", + (t_turn_end_us - t_turn_start_us) / 1000.0 + ); QUE_DBG("%s", "waiting for new tasks\n"); { std::unique_lock lock(mutex_tasks); From df64508b937784112168aa099644b60fef015f05 Mon Sep 17 00:00:00 2001 From: samuel Date: Sun, 21 Sep 2025 21:55:41 -0300 Subject: [PATCH 18/35] mtp-batch (wip): merge glm graphs --- common/speculative.cpp | 19 ++- include/llama.h | 6 +- src/llama-batch.cpp | 1 + src/llama-context.cpp | 161 ++++++++++----------- src/llama-context.h | 8 +- src/llama-graph.h | 18 +++ src/llama-model.cpp | 306 +++++++++++++++++++++++----------------- tools/server/server.cpp | 21 ++- 8 files changed, 310 insertions(+), 230 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index d13666c9f95..1604dbd48ad 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -373,10 +373,24 @@ llama_token mtp_speculative_gen_draft( if (!smpl) { return -1; } - + const float * draft_input_hidden_state = llama_get_embeddings_ith(ctx, last_tok_idx); + llama_set_draft_input_hidden_state(ctx, draft_input_hidden_state); + llama_batch mtp_batch = llama_batch_init(1, 0, 1); common_batch_add(mtp_batch, id_last, n_past, {0}, true); - mtp_batch.update_mtp_kv = true; + + LOG_INF( + "[DEBUG-DRAFT-IN] Generating draft. id_last=%d, n_past=%d, last_tok_idx=%d\n", + id_last, n_past, last_tok_idx + ); + + mtp_batch.update_mtp_kv = false; + mtp_batch.use_mtp_head = true; + + LOG_INF("[DEBUG-DRAFT-CALL] Calling llama_decode for draft. update_mtp_kv=%s, use_mtp_head=%s\n", + mtp_batch.update_mtp_kv ? "true" : "false", + mtp_batch.use_mtp_head ? 
"true" : "false" + ); llama_decode(ctx, mtp_batch); llama_batch_free(mtp_batch); @@ -419,6 +433,7 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vectorapply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; @@ -742,7 +742,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // the new graph parameters // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update); + const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update, use_mtp_head); if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); @@ -773,6 +773,29 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } } + if (do_mtp_kv_update || (use_mtp_head && !do_mtp_kv_update)) { // If it is any MTP operation + const char * target_tensor_name = "result_embd_pooled"; + ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name); + + const float * source_hidden_state = nullptr; + if (do_mtp_kv_update) { + // Cache warming uses the entire embeddings buffer + source_hidden_state = this->embd; + } else { + // Draft generation uses the specific state + source_hidden_state = this->draft_input_hidden_state; + } + + if (source_hidden_state != nullptr && hidden_states_input != nullptr) { + ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input)); + } else { + LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n", + __func__, target_tensor_name); + ret = GGML_STATUS_FAILED; + return nullptr; + } + } + // set the input data for the input tensors { //const auto t_start_us = ggml_time_us(); @@ -798,7 +821,12 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } ret = GGML_STATUS_SUCCESS; - + if (do_mtp_kv_update || use_mtp_head) { + ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum"); + if (sum_tensor) { + LLAMA_LOG_WARN("[DEBUG-SUM] MTP input sum node successfully created.\n"); + } + } return res; } @@ -859,7 +887,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false); cparams.causal_attn = causal_attn_org; @@ -972,6 +1000,10 @@ int llama_context::encode(const llama_batch & batch_inp) { int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n", + batch_inp.update_mtp_kv ? "true" : "false", + batch_inp.use_mtp_head ? 
"true" : "false" + ); if (!memory) { LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); @@ -1080,9 +1112,24 @@ int llama_context::decode(const llama_batch & batch_inp) { int64_t n_outputs_prev = 0; const bool do_mtp_kv_update = batch_inp.update_mtp_kv; - + const bool use_mtp_head = batch_inp.use_mtp_head; + const bool is_prompt_warmup = batch_inp.n_tokens > 1 && (this->model.hparams.nextn_predict_layers > 0); + do { const auto & ubatch = mctx->get_ubatch(); + if (ubatch.n_tokens > 0) { + std::string pos_str; + for (uint32_t i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) { + pos_str += std::to_string(ubatch.pos[i]) + " "; + } + LLAMA_LOG_WARN( + "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Posições: %s...\n", + ubatch.n_tokens, + batch_inp.update_mtp_kv ? "true" : "false", + batch_inp.use_mtp_head ? "true" : "false", + pos_str.c_str() + ); + } // count the outputs in this ubatch { @@ -1101,7 +1148,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache @@ -1139,6 +1186,17 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} + // if (is_prompt_warmup) { + // auto res_mtp = std::make_unique(graph_max_nodes()); + // ggml_status status_mtp; + + // process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status_mtp, do_mtp_kv_update, use_mtp_head); + + // if (status_mtp != GGML_STATUS_SUCCESS) { + // LLAMA_LOG_WARN("%s: Failure in MTP heating ubatch\n", __func__); + // } + // } + auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; embd_tensor = res->get_embd(); @@ -1278,7 +1336,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // overlap with device computation. ggml_backend_sched_reset(sched.get()); } - + if (!do_mtp_kv_update && !use_mtp_head) { + LLAMA_LOG_WARN("[DEBUG-EMBD-WRITE] Main decode completed. 
ctx->embd (%p) now contains the hidden state for the next draft.\n", (void*)this->embd); + } return 0; } @@ -1418,7 +1478,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false); + const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false, false); res->reset(); @@ -1440,7 +1500,8 @@ llm_graph_params llama_context::graph_params( const llama_ubatch & ubatch, const llama_memory_context_i * mctx, llm_graph_type gtype, - bool update_mtp_kv) const { + bool update_mtp_kv, + bool use_mtp_head) const { return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, @@ -1454,6 +1515,7 @@ llm_graph_params llama_context::graph_params( /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.update_mtp_kv =*/ update_mtp_kv, + /*.use_mtp_head =*/ use_mtp_head, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, @@ -2194,7 +2256,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false); + const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false, false); res->reset(); @@ -2983,79 +3045,6 @@ void llama_opt_epoch( callback_eval); } -// void llama_build_and_execute_mtp_graph(struct llama_context * ctx, -// const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { - -// const auto * model = llama_get_model(ctx); - -// auto res_mtp = std::make_unique(ctx->graph_max_nodes()); -// std::unique_ptr mctx = ctx->mtp_memory_batch(batch_inp); - -// std::vector idxs; -// idxs.push_back(n_past); -// llama_kv_cache_unified::slot_info sinfo = { -// /*.s0 =*/ 0, -// /*.s1 =*/ 0, -// /*.strm =*/ { 0 }, -// /*.idxs =*/ { idxs }, -// }; -// llama_kv_cache_unified::slot_info_vec_t sinfos; -// sinfos.push_back(sinfo); - -// static_cast(mctx.get())->set_sinfos(sinfos); -// const auto& ubatch_mtp = mctx->get_ubatch(); - -// //llama_ubatch ubatch_mtp; -// //ubatch_mtp.n_tokens = 1; -// //ubatch_mtp.pos = &n_past; - -// auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get())); -// ggml_backend_sched_t sched = params_mtp->sched; - -// auto * last_embd = ctx->get_embeddings_ith(last_tok_idx); - -// //if (mctx && !mctx->set_n_kv()) { -// // LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); -// //} -// static_cast(mctx.get())->set_n_kv(); - -// auto * gf = model->build_mtp_graph(*params_mtp); - -// if (!gf) { -// LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__); -// if (sched) ggml_backend_sched_free(sched); -// return; -// } - -// ggml_backend_sched_reset(sched); // clear the allocation of the previous graph -// ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it - -// ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input"); -// ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors - -// ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input"); -// ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors - -// ggml_backend_sched_graph_compute(sched, gf); // execute the 
graph - -// struct ggml_tensor * logits_mtp = res_mtp->get_logits(); - -// if (logits_mtp) { -// float * logits_dest = ctx->get_logits_ith(last_tok_idx); -// ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp); -// if (backend_res) { -// // ggml_backend_tensor_get is the function for GPU->CPU copies. -// // We are copying a single 32-bit integer. -// ggml_backend_tensor_get(logits_mtp, -// logits_dest, // Pointer to our C++ variable -// 0, // Starting offset in bytes -// ggml_nbytes(logits_mtp)); // Number of bytes to copy -// } else { -// LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__); -// } -// } else { -// LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__); -// } - -// ggml_backend_sched_free(sched); -// } \ No newline at end of file +void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) { + ctx->draft_input_hidden_state = hidden_state; +} \ No newline at end of file diff --git a/src/llama-context.h b/src/llama-context.h index 88f63e88d73..1df3574c27c 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -61,6 +61,8 @@ struct llama_context { float * get_embeddings_seq(llama_seq_id seq_id); ggml_tensor * get_embeddings_tensor(); + const float * draft_input_hidden_state = nullptr; + void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); @@ -100,7 +102,8 @@ struct llama_context { llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret, - const bool do_mtp_kv_update); + const bool do_mtp_kv_update, + const bool use_mtp_head); int encode(const llama_batch & batch_inp); int decode(const llama_batch & batch_inp); @@ -213,7 +216,8 @@ struct llama_context { const llama_ubatch & ubatch, const llama_memory_context_i * mctx, llm_graph_type gtype, - bool update_mtp_kv) const; + bool update_mtp_kv, + bool use_mtp_head) const; llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const; diff --git a/src/llama-graph.h b/src/llama-graph.h index 3f8fe8e979b..40dd83f0bc1 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -29,6 +29,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DRAFT, }; enum llm_ffn_op_type { @@ -94,6 +95,20 @@ class llm_graph_input_i { using llm_graph_input_ptr = std::unique_ptr; +class llm_graph_input_mtp_states : public llm_graph_input_i { +public: + llm_graph_input_mtp_states() = default; + virtual ~llm_graph_input_mtp_states() = default; + + void set_input(const llama_ubatch * /*ubatch*/) override {} + + bool can_reuse(const llm_graph_params & /*params*/) override { + return true; + } + + ggml_tensor * states = nullptr; +}; + class llm_graph_input_embd : public llm_graph_input_i { public: llm_graph_input_embd() = default; @@ -403,6 +418,7 @@ struct llm_graph_params { const llama_memory_context_i * mctx; const llama_cross * cross; bool update_mtp_kv; + bool use_mtp_head; uint32_t n_outputs; @@ -451,6 +467,8 @@ struct llm_graph_params { cvec == other.cvec && loras == other.loras && cross == other.cross && + update_mtp_kv == other.update_mtp_kv && + use_mtp_head == other.use_mtp_head && n_outputs == other.n_outputs; } }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c4998707107..82c7be49cbb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13787,168 +13787,204 @@ struct llm_build_glm4 : public llm_graph_context { }; struct llm_build_glm4_moe : public 
llm_graph_context { - llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params, bool build_mtp_path) - : llm_graph_context(params) { + llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); ggml_tensor * cur; - ggml_tensor * inpL; - inpL = build_inp_embd(model.tok_embd); + LLAMA_LOG_WARN( + "[DEBUG-GRAPH-STATE] Building graph. MTP Head=%s, MTP KV Update=%s, n_tokens=%d\n", + params.use_mtp_head ? "true" : "false", + params.update_mtp_kv ? "true" : "false", + n_tokens + ); + // for (int i = 0; i < n_tokens; ++i) { + // LLAMA_LOG_WARN(" - ubatch token[%d]: ID=%d, Pos=%d\n", i, ubatch.token[i], ubatch.pos[i]); + // } + if (n_tokens > 0) { + LLAMA_LOG_WARN( + " - ubatch tokens: [ID=%d, Pos=%d] ... [ID=%d, Pos=%d]\n", + ubatch.token[0], ubatch.pos[0], + ubatch.token[n_tokens-1], ubatch.pos[n_tokens-1] + ); + } - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); + if (params.use_mtp_head) { + ggml_tensor* hidden_states_from_main_model; - auto * inp_attn = build_attn_inp_kv_unified(); + if (params.update_mtp_kv) { + hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); - ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); + } else { + hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd); + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); - // Only process up to last layer (skip final NextN layer) - // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { - ggml_tensor * inpSA = inpL; + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); + } + res->t_embd = hidden_states_from_main_model; - // Pre-attention norm - cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); + const int il_mtp = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il_mtp]; + res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head); - // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - cb(Qcur, "Qcur", il); + } else { + ggml_tensor * inpL = build_inp_embd(model.tok_embd); + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_unified(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + // if (params.use_mtp_head) { + // LLAMA_LOG_ERROR("[DEBUG-KV-ERROR] MTP path is running the main layer %d!\n", il); + // } else { + // LLAMA_LOG_WARN("[DEBUG-KV] Main Head Path: Accessing layer %d\n", il); + // } + 
ggml_tensor * inpSA = inpL; - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - cb(Kcur, "Kcur", il); + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - cb(Vcur, "Vcur", il); + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); - // Apply Q/K norm if available (GLM-4.5 355B variant) - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - } - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); - } + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + cb(Vcur, "Vcur", il); - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - } + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - if (il == n_transformer_layers - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } - // Post-attention norm - cur = 
build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "post_attn_norm", il); + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } - // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) - if (static_cast(il) < hparams.n_layer_dense_lead) { - // Dense FFN layer - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - } else { - // Process routed experts using existing MoE infrastructure - ggml_tensor * routed_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, - n_expert, n_expert_used, - LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, - (llama_expert_gating_func_type) hparams.expert_gating_func, - il); - cb(routed_out, "ffn_moe_out", il); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); - // Process shared expert on original input - ggml_tensor * shared_out = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(shared_out, "ffn_shexp_out", il); + // Post-attention norm + cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) + if (static_cast(il) < hparams.n_layer_dense_lead) { + // Dense FFN layer + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); - // Final output: routed_output + shared_output - cur = ggml_add(ctx0, routed_out, shared_out); - cb(cur, "ffn_out", il); - } + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); - cur = ggml_add(ctx0, cur, ffn_inp); + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } - cur = build_cvec(cur, il); - cb(cur, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); - // input for next layer - inpL = cur; - } + cur = build_cvec(cur, il); + cb(cur, "l_out", il); - cur = inpL; - cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + // input for next layer + inpL = cur; + } - // cb(cur, "result_norm", -1); - res->t_embd = cur; + cur = inpL; + cur = 
build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); - if (build_mtp_path) { - const int il_mtp = hparams.n_layer - 1; - const auto & mtp_layer = model.layers[il_mtp]; - - ggml_tensor * mtp_logits = build_mtp_tail(mtp_layer, cur, n_embd_head); - res->t_logits = mtp_logits; - } else { - // lm_head - cur = build_lora_mm(model.output, cur); - res->t_logits = cur; + // cb(cur, "result_norm", -1); + res->t_embd = cur; + + // Use the main model header + res->t_logits = build_lora_mm(model.output, cur); } - ggml_build_forward_expand(gf, res->t_logits); + ggml_build_forward_expand(gf, res->t_logits); + } private: @@ -13956,6 +13992,10 @@ struct llm_build_glm4_moe : public llm_graph_context { int64_t n_embd_head ) { const int il = hparams.n_layer - 1; + // LLAMA_LOG_WARN("[DEBUG-KV] MTP Head Path: Accessing layer %d\n", il); + ggml_tensor * sum_node = ggml_sum(ctx0, prev_embeddings); + + ggml_set_name(sum_node, "mtp_input_sum"); ggml_tensor * inp_pos = build_inp_pos(); auto * inp_attn = build_attn_inp_kv_unified(); @@ -14015,7 +14055,11 @@ struct llm_build_glm4_moe : public llm_graph_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - + // LLAMA_LOG_WARN("[DEBUG-MTP-ATTN] Inputs for build_attn in the layer %d:\n", il); + // LLAMA_LOG_WARN(" - Qcur shape: [%d, %d, %d]\n", Qcur->ne[0], Qcur->ne[1], Qcur->ne[2]); + // LLAMA_LOG_WARN(" - Kcur shape: [%d, %d, %d]\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2]); + // LLAMA_LOG_WARN(" - Vcur shape: [%d, %d, %d]\n", Vcur->ne[0], Vcur->ne[1], Vcur->ne[2]); + cur = build_attn(inp_attn, mtp_layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); @@ -18511,7 +18555,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_GLM4_MOE: { - llm = std::make_unique(*this, params, build_mtp); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BITNET: { diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 84a0e6fc158..7070a56159e 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3387,6 +3387,15 @@ struct server_context { slot.n_prompt_tokens_processed += n_pos; } + const size_t n_to_log = slot.mtp_kv_update_batch.size(); + if (n_to_log > 0) { + SLT_INF(slot, + "DEBUG-KV-REQ Cache Warm-up: Requesting KV update for %zu tokens. Positions: %d ... 
%d\n", + n_to_log, + slot.mtp_kv_update_batch.front().n_past, + slot.mtp_kv_update_batch.back().n_past + ); + } // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { // get next token to process @@ -3517,12 +3526,12 @@ struct server_context { continue; // continue loop of n_batch } - for (auto & slot : slots) { - // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation - if (slot.has_mtp) { - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens); - } - } + // for (auto & slot : slots) { + // // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation + // if (slot.has_mtp) { + // mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens); + // } + // } // move the head of the batch forward with the number of tokens we just processed i_next = i + n_tokens; From 3da7e7f3309dbb576538850c92c1cbf8fdc6d6ee Mon Sep 17 00:00:00 2001 From: samuel Date: Tue, 23 Sep 2025 22:45:11 -0300 Subject: [PATCH 19/35] mtp-batch (fix): warm mtp cache for small batch size --- common/speculative.cpp | 17 ++++++++--------- common/speculative.h | 2 +- src/llama-context.cpp | 9 ++++++++- tools/server/server.cpp | 20 +++++++++++++------- 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 1604dbd48ad..950a9a54bc5 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -373,15 +373,17 @@ llama_token mtp_speculative_gen_draft( if (!smpl) { return -1; } - const float * draft_input_hidden_state = llama_get_embeddings_ith(ctx, last_tok_idx); + const float * draft_input_hidden_state = llama_get_embeddings_ith(ctx, -1); llama_set_draft_input_hidden_state(ctx, draft_input_hidden_state); + LOG_INF("[DEBUG-DRAFT-STATE] Main model final embd pointer: %p, State being used for draft: %p\n", + (void*)llama_get_embeddings(ctx), (void*)draft_input_hidden_state); llama_batch mtp_batch = llama_batch_init(1, 0, 1); common_batch_add(mtp_batch, id_last, n_past, {0}, true); LOG_INF( "[DEBUG-DRAFT-IN] Generating draft. 
id_last=%d, n_past=%d, last_tok_idx=%d\n", - id_last, n_past, last_tok_idx + id_last, n_past, draft_input_hidden_state ); mtp_batch.update_mtp_kv = false; @@ -411,15 +413,12 @@ llama_token mtp_speculative_gen_draft( } -void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start, size_t n_tokens) { +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens) { if (tokens.empty()) { - tokens.clear(); return; } - if (n_tokens < 0) { - n_tokens = tokens.size(); - } - const size_t n_to_process = std::min((size_t)tokens.size(), n_tokens); + + const size_t n_to_process = tokens.size(); LOG_DBG( "[MTP BATCHING] mtp_update_kv_cache call for %zu tokens.\n", @@ -438,5 +437,5 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start = 0, size_t n_tokens = -1); +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 9d77ce30790..5427f29eb7d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1146,7 +1146,14 @@ int llama_context::decode(const llama_batch & batch_inp) { // needs to happen before the graph is built n_outputs = n_outputs_new; } - + if (do_mtp_kv_update) { + LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens); + std::string positions_str; + for (int i = 0; i < ubatch.n_tokens; ++i) { + positions_str += std::to_string(ubatch.pos[i]) + " "; + } + LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s\n", positions_str.c_str()); + } ggml_status status; const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 7070a56159e..ddd7b6afa82 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3525,13 +3525,18 @@ struct server_context { continue; // continue loop of n_batch } + for (auto & slot : slots) { + if (slot.has_mtp && slot.n_past == slot.n_prompt_tokens) { + SLT_INF(slot, "Prompt processing finished. 
Warming up MTP KV cache for %d tokens.\n", slot.n_prompt_tokens); + slot.mtp_kv_update_batch.clear(); + + for (int j = 0; j < slot.n_prompt_tokens; ++j) { + slot.mtp_kv_update_batch.push_back({ slot.prompt_tokens[j], (llama_pos)j, j }); + } - // for (auto & slot : slots) { - // // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation - // if (slot.has_mtp) { - // mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens); - // } - // } + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); + } + } // move the head of the batch forward with the number of tokens we just processed i_next = i + n_tokens; @@ -3697,8 +3702,9 @@ struct server_context { const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); if (slot.has_mtp) { + slot.mtp_kv_update_batch.clear(); for (int32_t i = 0; i < ids.size(); ++i) { - slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + 1 + i, i }); + slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + i, i }); } mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); } From 75dc25e6fe781c1b65038d69390fb778d760e3a1 Mon Sep 17 00:00:00 2001 From: samuel Date: Sat, 27 Sep 2025 17:17:00 -0300 Subject: [PATCH 20/35] mtp-batch (wip): organize batch for mtp cache --- common/speculative.cpp | 21 +++++++++++++------- common/speculative.h | 2 +- src/llama-context.cpp | 13 +++---------- tools/server/server.cpp | 43 ++++++++++++++--------------------------- 4 files changed, 33 insertions(+), 46 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 950a9a54bc5..503da981944 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -373,7 +373,7 @@ llama_token mtp_speculative_gen_draft( if (!smpl) { return -1; } - const float * draft_input_hidden_state = llama_get_embeddings_ith(ctx, -1); + const float * draft_input_hidden_state = llama_get_embeddings(ctx); llama_set_draft_input_hidden_state(ctx, draft_input_hidden_state); LOG_INF("[DEBUG-DRAFT-STATE] Main model final embd pointer: %p, State being used for draft: %p\n", (void*)llama_get_embeddings(ctx), (void*)draft_input_hidden_state); @@ -413,17 +413,24 @@ llama_token mtp_speculative_gen_draft( } -void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens) { +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, const char* tag) { if (tokens.empty()) { return; } const size_t n_to_process = tokens.size(); - - LOG_DBG( - "[MTP BATCHING] mtp_update_kv_cache call for %zu tokens.\n", - n_to_process - ); + std::string details_str; + for (size_t i = 0; i < std::min((size_t)5, n_to_process); ++i) { + details_str += " {id: " + std::to_string(tokens[i].id) + ", pos: " + std::to_string(tokens[i].n_past) + "}"; + } + LOG_INF("[MTP-UPDATE|%s] Updating %zu tokens. Details:%s ...\n", tag, n_to_process, details_str.c_str()); + + // LOG_INF("[DEBUG-CHUNK] Warming up MTP model chunk. 
Batch size: %zu\n", n_to_process); + // std::string positions_str; + // for (size_t i = 0; i < std::min((size_t)5, n_to_process); ++i) { + // positions_str += std::to_string(tokens[i].n_past) + " "; + // } + // LOG_INF("[DEBUG-CHUNK] MTP warm-up positions: %s...\n", positions_str.c_str()); llama_batch mtp_batch = llama_batch_init(n_to_process, 0, 1); for (size_t i = 0; i < n_to_process; ++i) { diff --git a/common/speculative.h b/common/speculative.h index 827600c33d7..c60bd97ac3f 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -49,4 +49,4 @@ llama_tokens common_speculative_gen_draft( const llama_tokens & prompt, llama_token id_last); -void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens); +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, const char* tag); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5427f29eb7d..070c1b738f6 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -778,13 +778,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name); const float * source_hidden_state = nullptr; - if (do_mtp_kv_update) { - // Cache warming uses the entire embeddings buffer - source_hidden_state = this->embd; - } else { - // Draft generation uses the specific state - source_hidden_state = this->draft_input_hidden_state; - } + source_hidden_state = this->draft_input_hidden_state; if (source_hidden_state != nullptr && hidden_states_input != nullptr) { ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input)); @@ -1149,14 +1143,13 @@ int llama_context::decode(const llama_batch & batch_inp) { if (do_mtp_kv_update) { LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens); std::string positions_str; - for (int i = 0; i < ubatch.n_tokens; ++i) { + for (int i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) { positions_str += std::to_string(ubatch.pos[i]) + " "; } - LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s\n", positions_str.c_str()); + LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str()); } ggml_status status; const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head); - if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache llama_pos pos_min[LLAMA_MAX_SEQ]; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index ddd7b6afa82..aba5859acfa 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3387,14 +3387,8 @@ struct server_context { slot.n_prompt_tokens_processed += n_pos; } - const size_t n_to_log = slot.mtp_kv_update_batch.size(); - if (n_to_log > 0) { - SLT_INF(slot, - "DEBUG-KV-REQ Cache Warm-up: Requesting KV update for %zu tokens. Positions: %d ... %d\n", - n_to_log, - slot.mtp_kv_update_batch.front().n_past, - slot.mtp_kv_update_batch.back().n_past - ); + if (slot.has_mtp) { + slot.mtp_kv_update_batch.clear(); } // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { @@ -3484,6 +3478,7 @@ struct server_context { batch.seq_id + i, batch.logits + i, }; + LOG_INF("\n[DEBUG-CHUNK] Processing main model chunk. 
Batch size: %d\n", n_tokens); const int ret = llama_decode(ctx, batch_view); @@ -3525,16 +3520,18 @@ struct server_context { continue; // continue loop of n_batch } + + // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation + // Aquece o cache MTP para os pedaços do prompt que acabaram de ser processados. + // Esta lógica SÓ deve ser executada durante o processamento do prompt. for (auto & slot : slots) { - if (slot.has_mtp && slot.n_past == slot.n_prompt_tokens) { - SLT_INF(slot, "Prompt processing finished. Warming up MTP KV cache for %d tokens.\n", slot.n_prompt_tokens); - slot.mtp_kv_update_batch.clear(); - - for (int j = 0; j < slot.n_prompt_tokens; ++j) { - slot.mtp_kv_update_batch.push_back({ slot.prompt_tokens[j], (llama_pos)j, j }); - } - - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); + if (slot.state == SLOT_STATE_PROCESSING_PROMPT && slot.has_mtp && !slot.mtp_kv_update_batch.empty()) { + SLT_INF(slot, "DEBUG-KV-REQ: Warming up MTP cache for prompt chunk of size %zu. Positions: %d ... %d\n", + slot.mtp_kv_update_batch.size(), + slot.mtp_kv_update_batch.front().n_past, + slot.mtp_kv_update_batch.back().n_past + ); + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, "PROMPT_WARMUP"); } } @@ -3581,11 +3578,6 @@ struct server_context { common_sampler_accept(slot.smpl, id, true); - // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation - //if (slot.has_mtp) { - // mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); - //} - slot.n_decoded += 1; const int64_t t_current = ggml_time_us(); @@ -3670,11 +3662,6 @@ struct server_context { draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); } - //llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); - //llama_tokens draft; - //draft.reserve(1); - //draft.push_back(draft_id); - // ignore small drafts if (slot.params.speculative.n_min > (int)draft.size()) { SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min); @@ -3706,7 +3693,7 @@ struct server_context { for (int32_t i = 0; i < ids.size(); ++i) { slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + i, i }); } - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, "GEN_ACCEPTED"); } slot.n_past += ids.size(); From 67c6c069e0a5496adfd7d8aa6ca7514db5a6f437 Mon Sep 17 00:00:00 2001 From: samuel Date: Sat, 27 Sep 2025 19:42:32 -0300 Subject: [PATCH 21/35] mtp-batch (wip): Isolate MTP graph to prevent host embedding buffer corruption --- common/speculative.cpp | 38 +++++----- common/speculative.h | 5 +- include/llama.h | 1 + src/llama-batch.cpp | 1 + src/llama-context.cpp | 164 +++++++++++++++++++++++----------------- src/llama-context.h | 3 +- src/llama-model.cpp | 39 +++++----- tools/server/server.cpp | 36 +++++++-- 8 files changed, 171 insertions(+), 116 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 503da981944..7a17b8d965e 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -373,19 +373,9 @@ llama_token mtp_speculative_gen_draft( if (!smpl) { return -1; } - const float * draft_input_hidden_state = llama_get_embeddings(ctx); - llama_set_draft_input_hidden_state(ctx, draft_input_hidden_state); - LOG_INF("[DEBUG-DRAFT-STATE] Main model final embd pointer: %p, State being used for draft: %p\n", - (void*)llama_get_embeddings(ctx), 
(void*)draft_input_hidden_state); - llama_batch mtp_batch = llama_batch_init(1, 0, 1); common_batch_add(mtp_batch, id_last, n_past, {0}, true); - LOG_INF( - "[DEBUG-DRAFT-IN] Generating draft. id_last=%d, n_past=%d, last_tok_idx=%d\n", - id_last, n_past, draft_input_hidden_state - ); - mtp_batch.update_mtp_kv = false; mtp_batch.use_mtp_head = true; @@ -413,7 +403,9 @@ llama_token mtp_speculative_gen_draft( } -void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, const char* tag) { +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, + bool is_prompt_warmup) { + if (tokens.empty()) { return; } @@ -423,26 +415,34 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, const char* tag); +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, + bool is_prompt_warmup); + +double calculate_vector_sum_double(const float* vec, size_t size); \ No newline at end of file diff --git a/include/llama.h b/include/llama.h index 6a71c29bffd..e6d6aadf7e5 100644 --- a/include/llama.h +++ b/include/llama.h @@ -232,6 +232,7 @@ extern "C" { int8_t * logits; // TODO: rename this to "output" bool update_mtp_kv; bool use_mtp_head; + bool is_mtp_prompt_warmup; } llama_batch; enum llama_model_kv_override_type { diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 273b2b7010d..3268467f274 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -843,6 +843,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ /*logits =*/ nullptr, /*.use_mtp_head =*/ false, /*update_mtp_kv =*/ false, + /*.is_mtp_prompt_warmup =*/ false, }; if (embd) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 070c1b738f6..81c7a48d0e3 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -13,6 +13,7 @@ #include #include #include +#include // // llama_context @@ -729,8 +730,19 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } +static double calculate_vector_sum(const float* vec, size_t size) { + if (!vec) { + return 0.0; + } + double sum = 0.0; + for (size_t i = 0; i < size; ++i) { + sum += vec[i]; + } + return sum; +} + llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret, - bool do_mtp_kv_update, bool use_mtp_head) { + bool do_mtp_kv_update, bool use_mtp_head, bool is_mtp_prompt_warmup) { if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; @@ -778,9 +790,20 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name); const float * source_hidden_state = nullptr; - source_hidden_state = this->draft_input_hidden_state; + if (is_mtp_prompt_warmup || (do_mtp_kv_update && !is_mtp_prompt_warmup)) { + source_hidden_state = this->embd; + } else { + source_hidden_state = this->draft_input_hidden_state; + } if (source_hidden_state != nullptr && hidden_states_input != nullptr) { + const size_t n_embd = this->model.hparams.n_embd; + const size_t n_tokens_for_sum = (do_mtp_kv_update && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1; + double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd); + const char * op_type = (do_mtp_kv_update) ? 
"MTP_UPDATE" : "DRAFT_GEN"; + + LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum); + ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input)); } else { LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n", @@ -881,7 +904,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false, false); cparams.causal_attn = causal_attn_org; @@ -1107,7 +1130,7 @@ int llama_context::decode(const llama_batch & batch_inp) { int64_t n_outputs_prev = 0; const bool do_mtp_kv_update = batch_inp.update_mtp_kv; const bool use_mtp_head = batch_inp.use_mtp_head; - const bool is_prompt_warmup = batch_inp.n_tokens > 1 && (this->model.hparams.nextn_predict_layers > 0); + const bool is_prompt_warmup = batch_inp.is_mtp_prompt_warmup; do { const auto & ubatch = mctx->get_ubatch(); @@ -1117,7 +1140,7 @@ int llama_context::decode(const llama_batch & batch_inp) { pos_str += std::to_string(ubatch.pos[i]) + " "; } LLAMA_LOG_WARN( - "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Posições: %s...\n", + "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n", ubatch.n_tokens, batch_inp.update_mtp_kv ? "true" : "false", batch_inp.use_mtp_head ? "true" : "false", @@ -1149,7 +1172,7 @@ int llama_context::decode(const llama_batch & batch_inp) { LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str()); } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head, is_prompt_warmup); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache llama_pos pos_min[LLAMA_MAX_SEQ]; @@ -1186,20 +1209,8 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - // if (is_prompt_warmup) { - // auto res_mtp = std::make_unique(graph_max_nodes()); - // ggml_status status_mtp; - - // process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status_mtp, do_mtp_kv_update, use_mtp_head); - - // if (status_mtp != GGML_STATUS_SUCCESS) { - // LLAMA_LOG_WARN("%s: Failure in MTP heating ubatch\n", __func__); - // } - // } - auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; - embd_tensor = res->get_embd(); if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1220,58 +1231,69 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + if (use_mtp_head) { + if (t_embd != nullptr) { + LLAMA_LOG_ERROR("[MTP-GRAPH-BUG] The MTP graph returned an embedding tensor when it shouldn't have! 
This will cause corruption.\n"); + } else { + LLAMA_LOG_WARN("[MTP-GRAPH-OK] The MTP graph correctly did not return an embedding tensor.\n"); + } + } + // extract embeddings if (t_embd && n_outputs > 0) { - ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); - GGML_ASSERT(backend_embd != nullptr); - - switch (cparams.pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // extract token embeddings - GGML_ASSERT(embd != nullptr); - float * embd_out = embd + n_outputs_prev*n_embd; - - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_CLS: - case LLAMA_POOLING_TYPE_LAST: - { - // extract sequence embeddings (cleared before processing each batch) - auto & embd_seq_out = embd_seq; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_RANK: - { - // extract the rerank score - n_cls_out floats per sequence - auto & embd_seq_out = embd_seq; - - const uint32_t n_cls_out = hparams.n_cls_out; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_cls_out); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + if (!use_mtp_head) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd + n_outputs_prev*n_embd; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings (cleared before processing each batch) + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - n_cls_out floats per sequence + auto & embd_seq_out = embd_seq; + const uint32_t n_cls_out = hparams.n_cls_out; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), 
n_cls_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); } - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ABORT("unknown pooling type"); - } + } + } else { + LLAMA_LOG_WARN("[DEBUG-EMBD-COPY] Skipping embedding buffer copy for MTP operation (use_mtp_head=true).\n"); } } @@ -1336,8 +1358,12 @@ int llama_context::decode(const llama_batch & batch_inp) { // overlap with device computation. ggml_backend_sched_reset(sched.get()); } - if (!do_mtp_kv_update && !use_mtp_head) { - LLAMA_LOG_WARN("[DEBUG-EMBD-WRITE] Main decode completed. ctx->embd (%p) now contains the hidden state for the next draft.\n", (void*)this->embd); + + if (!use_mtp_head) { + synchronize(); + const size_t n_embd = this->model.hparams.n_embd; + double full_buffer_sum = calculate_vector_sum(this->embd, n_outputs_all * n_embd); + LLAMA_LOG_WARN("[INTEGRITY-CHECK|A] After main decode. ubatch_size=%d. Checksum: %e\n", n_outputs_all, full_buffer_sum); } return 0; } diff --git a/src/llama-context.h b/src/llama-context.h index 1df3574c27c..aa6ced79471 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -103,7 +103,8 @@ struct llama_context { llama_memory_context_i * mctx, ggml_status & ret, const bool do_mtp_kv_update, - const bool use_mtp_head); + const bool use_mtp_head, + bool is_mtp_prompt_warmup); int encode(const llama_batch & batch_inp); int decode(const llama_batch & batch_inp); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 82c7be49cbb..1f00bb7dd76 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13821,20 +13821,19 @@ struct llm_build_glm4_moe : public llm_graph_context { auto inp_mtp = std::make_unique(); inp_mtp->states = hidden_states_from_main_model; res->add_input(std::move(inp_mtp)); - } else { - hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd); - ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); - ggml_set_input(hidden_states_from_main_model); + } else { + hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd); + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); - auto inp_mtp = std::make_unique(); - inp_mtp->states = hidden_states_from_main_model; - res->add_input(std::move(inp_mtp)); - } - res->t_embd = hidden_states_from_main_model; + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); + } - const int il_mtp = hparams.n_layer - 1; - const auto & mtp_layer = model.layers[il_mtp]; - res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head); + const int il_mtp = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il_mtp]; + res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head); } else { ggml_tensor * inpL = build_inp_embd(model.tok_embd); @@ -13991,9 +13990,11 @@ struct llm_build_glm4_moe : public llm_graph_context { ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, int64_t n_embd_head ) { + ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings); + const int il = hparams.n_layer - 1; // LLAMA_LOG_WARN("[DEBUG-KV] MTP Head Path: Accessing layer %d\n", il); - ggml_tensor * sum_node = ggml_sum(ctx0, prev_embeddings); + ggml_tensor * sum_node = ggml_sum(ctx0, embd_copy); ggml_set_name(sum_node, "mtp_input_sum"); @@ -14002,7 +14003,7 @@ struct llm_build_glm4_moe : public 
llm_graph_context { ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens); ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); - ggml_tensor * hidden_state_norm = build_norm(prev_embeddings, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); @@ -18694,13 +18695,15 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { GGML_ABORT("fatal error"); } - // add on pooling layer - llm->build_pooling(cls, cls_b, cls_out, cls_out_b); - const int64_t t_end_us = ggml_time_us(); // Fim do cronômetro + if (!params.use_mtp_head) { + // add on pooling layer + llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + } + const int64_t t_end_us = ggml_time_us(); LLAMA_LOG_INFO( "[PERF] Graph build time: %.2f ms (MTP path: %s)\n", (t_end_us - t_start_us) / 1000.0, - build_mtp ? "yes" : "no" + params.use_mtp_head ? "yes" : "no" ); return llm->res->get_gf(); } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index aba5859acfa..3d025b21202 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3522,8 +3522,6 @@ struct server_context { } // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation - // Aquece o cache MTP para os pedaços do prompt que acabaram de ser processados. - // Esta lógica SÓ deve ser executada durante o processamento do prompt. for (auto & slot : slots) { if (slot.state == SLOT_STATE_PROCESSING_PROMPT && slot.has_mtp && !slot.mtp_kv_update_batch.empty()) { SLT_INF(slot, "DEBUG-KV-REQ: Warming up MTP cache for prompt chunk of size %zu. Positions: %d ... %d\n", @@ -3531,7 +3529,7 @@ struct server_context { slot.mtp_kv_update_batch.front().n_past, slot.mtp_kv_update_batch.back().n_past ); - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, "PROMPT_WARMUP"); + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, true); } } @@ -3569,13 +3567,16 @@ struct server_context { } const int tok_idx = slot.i_batch - i; - + // Sets the initial state for the first draft generation. + if (slot.has_mtp) { + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1)); + } llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); slot.last_tok_idx = tok_idx; //SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str()); slot.i_batch = -1; - + SLT_INF(slot, "[SAMPLER-ACCEPT] Accepting token ID %d at index %zu\n", id, i); common_sampler_accept(slot.smpl, id, true); slot.n_decoded += 1; @@ -3647,6 +3648,7 @@ struct server_context { llama_tokens draft; if (slot.has_mtp) { + SLT_INF(slot, "[POS-SYNC] Before draft gen. n_past = %d\n", slot.n_past); llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); draft.reserve(1); draft.push_back(draft_id); @@ -3682,21 +3684,39 @@ struct server_context { } SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); - + SLT_INF(slot, "[POS-SYNC] Before validation decode. 
n_past = %d, spec_batch_size = %d\n", slot.n_past, slot.batch_spec.n_tokens); llama_decode(ctx, slot.batch_spec); + const size_t n_embd = llama_n_embd(llama_get_model(ctx)); + const size_t golden_buffer_size_in_floats = slot.batch_spec.n_tokens * n_embd; + const float* golden_embd_ptr = llama_get_embeddings(ctx); + double golden_checksum = calculate_vector_sum_double(golden_embd_ptr, golden_buffer_size_in_floats); + SLT_INF(slot, "[VERIFY] Golden checksum after validation: %e (size: %zu tokens)\n", golden_checksum, slot.batch_spec.n_tokens); + // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); - + SLT_INF(slot, "[POS-SYNC] Tokens accepted: %zu\n", ids.size()); + if (slot.has_mtp) { + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1)); + + const float* embd_after_draft_ptr = llama_get_embeddings(ctx); + double checksum_after_draft = calculate_vector_sum_double(embd_after_draft_ptr, golden_buffer_size_in_floats); + SLT_INF(slot, "[VERIFY] Checksum after draft gen (should be unchanged): %e\n", checksum_after_draft); + slot.mtp_kv_update_batch.clear(); for (int32_t i = 0; i < ids.size(); ++i) { slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + i, i }); } - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, "GEN_ACCEPTED"); + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, false); + + const float* embd_after_update_ptr = llama_get_embeddings(ctx); + double checksum_after_update = calculate_vector_sum_double(embd_after_update_ptr, golden_buffer_size_in_floats); + SLT_INF(slot, "[VERIFY] Checksum after MTP update (should be unchanged): %e\n", checksum_after_update); } slot.n_past += ids.size(); + SLT_INF(slot, "[POS-SYNC] After n_past update. New n_past = %d\n", slot.n_past); slot.n_decoded += ids.size(); // update how many tokens out of those tested were accepted From febd8235d27fe9174ee4b54ea7a10e630939fee0 Mon Sep 17 00:00:00 2001 From: samuel Date: Sun, 5 Oct 2025 14:43:40 -0300 Subject: [PATCH 22/35] mtp-batch (wip): fix how to warmup kv cache for MTP --- common/speculative.cpp | 29 ++++++++--------------------- common/speculative.h | 3 +-- src/llama-model.cpp | 17 +++++++++++++++-- tools/server/server.cpp | 40 ++++++++++++++++++---------------------- 4 files changed, 42 insertions(+), 47 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 7a17b8d965e..2e0b91a4e26 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -403,36 +403,23 @@ llama_token mtp_speculative_gen_draft( } -void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, - bool is_prompt_warmup) { - - if (tokens.empty()) { +void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) { + if (batch.n_tokens == 0) { return; } - const size_t n_to_process = tokens.size(); - std::string details_str; - for (size_t i = 0; i < std::min((size_t)5, n_to_process); ++i) { - details_str += " {id: " + std::to_string(tokens[i].id) + ", pos: " + std::to_string(tokens[i].n_past) + "}"; - } - LOG_INF("[MTP-UPDATE|%s] Updating %zu tokens. Details:%s ...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", n_to_process, details_str.c_str()); - - llama_batch mtp_batch = llama_batch_init(n_to_process, 0, 1); - - for (size_t i = 0; i < n_to_process; ++i) { - const mtp_kv_update_data& token_data = tokens[i]; - // Check seq_id {0}, it may be a problem with multiple sequences. 
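The hard-coded `{0}` sequence id flagged in the removed comment above is part of what this refactor sidesteps: with the batch-based signature the caller owns the batch and supplies the real sequence id per token (the server-side change later in this commit passes `{ slot.id }`). A minimal sketch of the new calling convention follows, assuming the `mtp_update_kv_cache(ctx, batch, is_prompt_warmup)` overload introduced here; the helper name `mtp_push_accepted` and the inputs `accepted`, `n_past`, and `slot_seq_id` are illustrative placeholders, not part of the patch.

#include <vector>
#include "common.h"       // common_batch_add
#include "speculative.h"  // batch-based mtp_update_kv_cache from this commit

// Illustrative only: push the tokens accepted after a speculation round through
// the MTP head so its KV cache catches up with the main model.
static void mtp_push_accepted(llama_context * ctx,
                              const std::vector<llama_token> & accepted,
                              llama_pos n_past,
                              llama_seq_id slot_seq_id) {
    if (accepted.empty()) {
        return;
    }
    llama_batch batch = llama_batch_init((int32_t) accepted.size(), 0, 1);
    for (size_t i = 0; i < accepted.size(); ++i) {
        // logits are not needed: this pass only refreshes the MTP layer's KV cells
        common_batch_add(batch, accepted[i], n_past + (llama_pos) i, { slot_seq_id }, false);
    }
    mtp_update_kv_cache(ctx, batch, /*is_prompt_warmup=*/false);
    llama_batch_free(batch);
}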
- common_batch_add(mtp_batch, token_data.id, token_data.n_past, {0}, false); - } + LOG_INF("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens); + llama_batch mtp_batch = batch; mtp_batch.update_mtp_kv = true; mtp_batch.use_mtp_head = true; mtp_batch.is_mtp_prompt_warmup = is_prompt_warmup; - llama_decode(ctx, mtp_batch); + for (int i = 0; i < mtp_batch.n_tokens; ++i) { + mtp_batch.logits[i] = false; + } - llama_batch_free(mtp_batch); - tokens.clear(); + llama_decode(ctx, mtp_batch); } // Debug function - It will be removed later diff --git a/common/speculative.h b/common/speculative.h index 11c0d4553e6..e121e8ed146 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -49,7 +49,6 @@ llama_tokens common_speculative_gen_draft( const llama_tokens & prompt, llama_token id_last); -void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, - bool is_prompt_warmup); +void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup); double calculate_vector_sum_double(const float* vec, size_t size); \ No newline at end of file diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1f00bb7dd76..6ca53a80cd1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13788,6 +13788,13 @@ struct llm_build_glm4 : public llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + LLAMA_LOG_WARN( + "[GRAPH_BUILD] Building graph. Path: %s, MTP_Update: %s, UBatch_Tokens: %d, First_Pos: %d\n", + params.use_mtp_head ? "MTP" : "MAIN", + params.update_mtp_kv ? "true" : "false", + n_tokens, + n_tokens > 0 ? ubatch.pos[0] : -1 + ); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13906,7 +13913,10 @@ struct llm_build_glm4_moe : public llm_graph_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - + if (ubatch.n_tokens > 0) { + LLAMA_LOG_WARN("[KV_WRITE] path=MAIN, layer=%d, n_tokens=%d, pos_start=%d, pos_end=%d\n", + il, ubatch.n_tokens, ubatch.pos[0], ubatch.pos[ubatch.n_tokens-1]); + } cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); @@ -14060,7 +14070,10 @@ struct llm_build_glm4_moe : public llm_graph_context { // LLAMA_LOG_WARN(" - Qcur shape: [%d, %d, %d]\n", Qcur->ne[0], Qcur->ne[1], Qcur->ne[2]); // LLAMA_LOG_WARN(" - Kcur shape: [%d, %d, %d]\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2]); // LLAMA_LOG_WARN(" - Vcur shape: [%d, %d, %d]\n", Vcur->ne[0], Vcur->ne[1], Vcur->ne[2]); - + if (ubatch.n_tokens > 0) { + LLAMA_LOG_WARN("[KV_WRITE] path=MTP, layer=%d, n_tokens=%d, pos_start=%d, pos_end=%d\n", + il, ubatch.n_tokens, ubatch.pos[0], ubatch.pos[ubatch.n_tokens-1]); + } cur = build_attn(inp_attn, mtp_layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 3d025b21202..3399e16823a 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1296,7 +1296,6 @@ struct server_slot { common_speculative * spec = nullptr; bool has_mtp = false; - std::vector mtp_kv_update_batch; int32_t last_tok_idx = -1; std::vector lora; @@ -3387,9 +3386,6 @@ struct server_context { slot.n_prompt_tokens_processed += n_pos; } - if (slot.has_mtp) { - slot.mtp_kv_update_batch.clear(); - } // add prompt tokens for 
processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { // get next token to process @@ -3401,9 +3397,6 @@ struct server_context { // embedding requires all tokens in the batch to be output const bool need_embd = server_task_type_need_embd(slot.task_type); - if (slot.has_mtp) { - slot.mtp_kv_update_batch.push_back({ cur_tok, slot.n_past, batch.n_tokens }); - } common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); slot.cache_tokens.push_back(cur_tok); @@ -3520,19 +3513,17 @@ struct server_context { continue; // continue loop of n_batch } - - // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation - for (auto & slot : slots) { - if (slot.state == SLOT_STATE_PROCESSING_PROMPT && slot.has_mtp && !slot.mtp_kv_update_batch.empty()) { - SLT_INF(slot, "DEBUG-KV-REQ: Warming up MTP cache for prompt chunk of size %zu. Positions: %d ... %d\n", - slot.mtp_kv_update_batch.size(), - slot.mtp_kv_update_batch.front().n_past, - slot.mtp_kv_update_batch.back().n_past - ); - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, true); + + bool needs_mtp_warmup = false; + if (slot_batched && slot_batched->has_mtp) { + if (slot_batched->state == SLOT_STATE_PROCESSING_PROMPT || slot_batched->state == SLOT_STATE_DONE_PROMPT) { + needs_mtp_warmup = true; } } + if (needs_mtp_warmup) { + mtp_update_kv_cache(ctx, batch_view, true); + } // move the head of the batch forward with the number of tokens we just processed i_next = i + n_tokens; @@ -3704,12 +3695,17 @@ struct server_context { double checksum_after_draft = calculate_vector_sum_double(embd_after_draft_ptr, golden_buffer_size_in_floats); SLT_INF(slot, "[VERIFY] Checksum after draft gen (should be unchanged): %e\n", checksum_after_draft); - slot.mtp_kv_update_batch.clear(); - for (int32_t i = 0; i < ids.size(); ++i) { - slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + i, i }); - } - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, false); + if (!ids.empty()) { + llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1); + + for (size_t i = 0; i < ids.size(); ++i) { + common_batch_add(accepted_batch, ids[i], slot.n_past + i, { slot.id }, false); + } + mtp_update_kv_cache(ctx, accepted_batch, false); + + llama_batch_free(accepted_batch); + } const float* embd_after_update_ptr = llama_get_embeddings(ctx); double checksum_after_update = calculate_vector_sum_double(embd_after_update_ptr, golden_buffer_size_in_floats); SLT_INF(slot, "[VERIFY] Checksum after MTP update (should be unchanged): %e\n", checksum_after_update); From 5e1d719beffccf8c22784c24b52ff6f5ab56b9ff Mon Sep 17 00:00:00 2001 From: samuel Date: Thu, 9 Oct 2025 15:21:23 -0300 Subject: [PATCH 23/35] mtp-batch (feat): Create and manage sinfo for MTP --- common/speculative.cpp | 27 ++++++- common/speculative.h | 7 ++ include/llama.h | 6 +- src/llama-context.cpp | 69 ++++++++++++++-- src/llama-context.h | 10 +++ src/llama-kv-cache-unified.cpp | 141 +++++++++++++++++++++++---------- src/llama-kv-cache-unified.h | 15 +++- tools/server/server.cpp | 22 ++--- 8 files changed, 232 insertions(+), 65 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 2e0b91a4e26..f71982f9e4a 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -418,10 +418,35 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b for (int i = 0; i < mtp_batch.n_tokens; ++i) { mtp_batch.logits[i] = false; } - llama_decode(ctx, 
mtp_batch); } +void mtp_accept_tokens( + struct llama_context * ctx, + const std::vector & ids, + int32_t n_past_base, + llama_seq_id seq_id +) { + if (ids.empty()) { + return; + } + + if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) { + return; + } + + llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1); + for (size_t i = 0; i < ids.size(); ++i) { + common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, false); + } + + mtp_update_kv_cache(ctx, accepted_batch, false); + + llama_mtp_cancel_sinfo_update(ctx); + + llama_batch_free(accepted_batch); +} + // Debug function - It will be removed later double calculate_vector_sum_double(const float* vec, size_t size) { if (!vec) { diff --git a/common/speculative.h b/common/speculative.h index e121e8ed146..d361e69d077 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -51,4 +51,11 @@ llama_tokens common_speculative_gen_draft( void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup); +void mtp_accept_tokens( + struct llama_context * ctx, + const std::vector & ids, + int32_t n_past_base, + llama_seq_id seq_id +); + double calculate_vector_sum_double(const float* vec, size_t size); \ No newline at end of file diff --git a/include/llama.h b/include/llama.h index e6d6aadf7e5..024d53f21cc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1457,7 +1457,11 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); - LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state); + LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state); + + LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted); + + LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx); #ifdef __cplusplus } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 81c7a48d0e3..edf5d747f10 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -18,6 +18,11 @@ // // llama_context // +struct llama_context_kv_cache_data { + llama_kv_cache_unified::slot_info_vec_t last_main_model_sinfos; + llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force; + const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr; +}; llama_context::llama_context( const llama_model & model, @@ -106,6 +111,8 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + kv_cache_data = new llama_context_kv_cache_data(); + { const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows; @@ -371,6 +378,7 @@ llama_context::llama_context( llama_context::~llama_context() { ggml_opt_free(opt_ctx); + delete static_cast(kv_cache_data); } void llama_context::synchronize() { @@ -1017,6 +1025,8 @@ int llama_context::encode(const llama_batch & batch_inp) { int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + + auto * kvd = static_cast(kv_cache_data); LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n", batch_inp.update_mtp_kv ? "true" : "false", batch_inp.use_mtp_head ? 
"true" : "false" @@ -1076,10 +1086,31 @@ int llama_context::decode(const llama_batch & batch_inp) { // handle any pending defrags/shifts kv_self_update(false); - llama_memory_context_ptr mctx; + std::unique_ptr mctx; while (true) { - mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + if (cparams.warmup) { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + } else { + if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) { + LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n"); + + mctx = static_cast(memory.get())->init_batch_with_sinfos( + *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true + ); + } else { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + + if (!batch_inp.use_mtp_head && !batch_inp.update_mtp_kv) { + if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) { + kvd->last_main_model_sinfos = static_cast(mctx.get())->get_sinfos(); + } else { + kvd->last_main_model_sinfos.clear(); + } + } + } + } + if (!mctx) { return -2; } @@ -1091,29 +1122,28 @@ int llama_context::decode(const llama_batch & batch_inp) { case LLAMA_MEMORY_STATUS_NO_UPDATE: { LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status()); - return -2; } case LLAMA_MEMORY_STATUS_FAILED_PREPARE: { + // if (use_last_main_model_sinfos) { + // LLAMA_LOG_ERROR("%s: Mismatch between ubatches and sinfos during reuse.\n", __func__); + // return -1; + // } + if (!did_optimize) { did_optimize = true; - if (kv_self_update(true)) { LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens()); - continue; } } - LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens()); - return 1; } case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: { LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens()); - return -2; } } @@ -3073,4 +3103,27 @@ void llama_opt_epoch( void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) { ctx->draft_input_hidden_state = hidden_state; +} + +bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) { + auto * kvd = static_cast(ctx->kv_cache_data); + const auto & last_sinfo = kvd->last_main_model_sinfos; + + if (last_sinfo.empty() || last_sinfo[0].idxs.empty()) { + LLAMA_LOG_ERROR("%s: The sinfo for the last main call is not available.", __func__); + return false; + } + + kvd->resized_sinfo_for_force = last_sinfo; + + kvd->resized_sinfo_for_force[0].idxs[0].resize(n_accepted); + + kvd->forced_sinfos = &kvd->resized_sinfo_for_force; + + return true; +} + +void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) { + auto * kvd = static_cast(ctx->kv_cache_data); + kvd->forced_sinfos = nullptr; } \ No newline at end of file diff --git a/src/llama-context.h b/src/llama-context.h index aa6ced79471..654409cb6cc 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -20,6 +20,8 @@ class llama_io_write_i; struct llama_memory_i; struct llama_memory_context_i; +struct llama_context_kv_cache_data; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -27,6 +29,11 @@ struct llama_context { llama_context_params params); ~llama_context(); + + llama_context(const llama_context &) = delete; + llama_context & operator=(const llama_context &) = delete; + llama_context(llama_context &&) = delete; + llama_context & operator=(llama_context &&) = 
delete; void synchronize(); @@ -211,6 +218,9 @@ struct llama_context { std::unique_ptr mtp_memory_batch(const llama_batch& batch_inp); + // For MTP KV cache cell reuse + void * kv_cache_data; + private: llm_graph_params graph_params( llm_graph_result * res, diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 53466264cd9..787fb8d9a55 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -508,6 +508,34 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } +llama_memory_context_ptr llama_kv_cache_unified::init_batch_with_sinfos( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + const slot_info_vec_t & sinfos, + bool is_inplace_update) { + + if (sinfos.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + balloc.split_reset(); + std::vector ubatches; + while (true) { + auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); + if (ubatch.n_tokens == 0) { + break; + } + ubatches.push_back(std::move(ubatch)); + } + + if (ubatches.size() != sinfos.size()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, sinfos, std::move(ubatches), is_inplace_update); +} + llama_memory_context_ptr llama_kv_cache_unified::init_full() { return std::make_unique(this); } @@ -738,6 +766,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d } llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { + LLAMA_LOG_WARN("%s: Entering find_slot for ubatch of %d tokens.\n", __func__, ubatch.n_tokens); if (debug > 0) { const auto & cells = v_cells[seq_to_stream[1]]; @@ -928,72 +957,95 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ } assert(res.s1 >= res.s0); + if (!res.empty()) { + std::string idxs_str; + for (const auto& vec : res.idxs) { + if (!vec.empty()) { + if (vec.size() > 8) { + idxs_str += " [" + std::to_string(vec.front()) + "..." + std::to_string(vec.back()) + " (" + std::to_string(vec.size()) + " cells)]"; + } else { + idxs_str += " ["; + for(size_t i = 0; i < vec.size(); ++i) { + idxs_str += std::to_string(vec[i]) + (i == vec.size() - 1 ? "" : ", "); + } + idxs_str += "]"; + } + } + } + LLAMA_LOG_WARN("%s: find_slot SUCCEEDED for ubatch of %d tokens. 
Idxs:%s\n", __func__, ubatch.n_tokens, idxs_str.c_str()); + } else { + LLAMA_LOG_ERROR("%s: find_slot FAILED to allocate cells for ubatch of %d tokens.\n", __func__, ubatch.n_tokens); + } return res; } -void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { - // keep track of the max sequence position that we would overwrite with this ubatch - // for non-SWA cache, this would be always empty - llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; - for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { - seq_pos_max_rm[s] = -1; - } +void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) { + if (!is_inplace_update) { + // keep track of the max sequence position that we would overwrite with this ubatch + // for non-SWA cache, this would be always empty + llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + seq_pos_max_rm[s] = -1; + } - assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); + assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { - for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { - const uint32_t i = s*sinfo.size() + ii; + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { + const uint32_t i = s*sinfo.size() + ii; - auto & cells = v_cells[sinfo.strm[s]]; + auto & cells = v_cells[sinfo.strm[s]]; - const auto idx = sinfo.idxs[s][ii]; + const auto idx = sinfo.idxs[s][ii]; - if (!cells.is_empty(idx)) { - assert(cells.seq_count(idx) == 1); + if (!is_inplace_update) { + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); - const llama_seq_id seq_id = cells.seq_get(idx); - const llama_pos pos = cells.pos_get(idx); + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); - seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); - cells.rm(idx); - } + cells.rm(idx); + } + } - cells.pos_set(idx, ubatch.pos[i]); + cells.pos_set(idx, ubatch.pos[i]); - for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { - cells.seq_add(idx, ubatch.seq_id[i][s]); + for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { + cells.seq_add(idx, ubatch.seq_id[i][s]); + } } } - } - // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence - // will be present in the cache. so we have to purge any position which is less than those we would overwrite - // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 - for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { - if (seq_pos_max_rm[s] == -1) { - continue; - } + // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence + // will be present in the cache. 
so we have to purge any position which is less than those we would overwrite + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq_pos_max_rm[s] == -1) { + continue; + } - GGML_ASSERT(s < seq_to_stream.size()); + GGML_ASSERT(s < seq_to_stream.size()); - auto & cells = v_cells[seq_to_stream[s]]; + auto & cells = v_cells[seq_to_stream[s]]; - if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { - LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", - __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { + LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", + __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); - seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + } } - } - // move the head at the end of the slot - for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { - auto & head = v_heads[sinfo.strm[s]]; + // move the head at the end of the slot + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + auto & head = v_heads[sinfo.strm[s]]; - head = sinfo.idxs[s].back() + 1; + head = sinfo.idxs[s].back() + 1; + } } } @@ -2290,7 +2342,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_kv_cache_unified::slot_info_vec_t sinfos, - std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) { + std::vector ubatches, + bool is_inplace_update) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)), is_inplace_update(is_inplace_update) { } llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; @@ -2315,7 +2368,7 @@ bool llama_kv_cache_unified_context::apply() { return true; } - kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); + kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur], is_inplace_update); n_kv = kv->get_n_kv(); diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index c02607c2d0f..f64f7faa5c0 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -116,6 +116,12 @@ class llama_kv_cache_unified : public llama_memory_i { llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; + + llama_memory_context_ptr init_batch_with_sinfos( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + const slot_info_vec_t & sinfos, + bool is_inplace_update); llama_memory_context_ptr init_full() override; @@ -181,7 +187,7 @@ class llama_kv_cache_unified : public llama_memory_i { slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] - void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch); + void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update = false); // // input API @@ -321,7 +327,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { llama_kv_cache_unified_context( llama_kv_cache_unified * kv, slot_info_vec_t sinfos, - std::vector ubatches); + std::vector ubatches, + bool is_inplace_update = false); virtual ~llama_kv_cache_unified_context(); @@ -365,6 +372,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { void set_sinfos(slot_info_vec_t new_sinfos); + const 
slot_info_vec_t & get_sinfos() const { return sinfos; } + private: llama_memory_status status; @@ -399,4 +408,6 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { // a heuristic, to avoid attending the full cache if it is not yet utilized // as the cache gets filled, the benefit from this heuristic disappears int32_t n_kv; + + bool is_inplace_update = false; }; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 3399e16823a..844805d0ced 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3522,8 +3522,15 @@ struct server_context { } if (needs_mtp_warmup) { - mtp_update_kv_cache(ctx, batch_view, true); + if (llama_mtp_prepare_sinfo_for_update(ctx, batch_view.n_tokens)) { + mtp_update_kv_cache(ctx, batch_view, true); + + llama_mtp_cancel_sinfo_update(ctx); + } else { + LOG_ERR("%s: Failed to prepare the MTP symphony for warmup.", __func__); + } } + // move the head of the batch forward with the number of tokens we just processed i_next = i + n_tokens; @@ -3696,16 +3703,13 @@ struct server_context { SLT_INF(slot, "[VERIFY] Checksum after draft gen (should be unchanged): %e\n", checksum_after_draft); if (!ids.empty()) { - llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1); - - for (size_t i = 0; i < ids.size(); ++i) { - common_batch_add(accepted_batch, ids[i], slot.n_past + i, { slot.id }, false); - } + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1)); + } else { + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0)); + } - mtp_update_kv_cache(ctx, accepted_batch, false); + mtp_accept_tokens(ctx, ids, slot.n_past, slot.id); - llama_batch_free(accepted_batch); - } const float* embd_after_update_ptr = llama_get_embeddings(ctx); double checksum_after_update = calculate_vector_sum_double(embd_after_update_ptr, golden_buffer_size_in_floats); SLT_INF(slot, "[VERIFY] Checksum after MTP update (should be unchanged): %e\n", checksum_after_update); From 6f74ba38070d62d37bc0fb71ce9871e1a4ffabcc Mon Sep 17 00:00:00 2001 From: samuel Date: Thu, 9 Oct 2025 22:27:18 -0300 Subject: [PATCH 24/35] mtp-batch (fix): prevent mtp draft from polluting the cache --- common/speculative.cpp | 4 ++++ include/llama.h | 6 +++++- src/llama-context.cpp | 26 +++++++++++++++++++++++++- src/llama-context.h | 2 ++ tools/server/server.cpp | 5 ++--- 5 files changed, 38 insertions(+), 5 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index f71982f9e4a..8249a3a52c7 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -374,6 +374,8 @@ llama_token mtp_speculative_gen_draft( return -1; } llama_batch mtp_batch = llama_batch_init(1, 0, 1); + const llama_pos draft_pos = n_past; + const llama_seq_id draft_seq_id = 0; common_batch_add(mtp_batch, id_last, n_past, {0}, true); mtp_batch.update_mtp_kv = false; @@ -387,6 +389,8 @@ llama_token mtp_speculative_gen_draft( llama_decode(ctx, mtp_batch); llama_batch_free(mtp_batch); + llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1); + const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_n_vocab(vocab); diff --git a/include/llama.h b/include/llama.h index 024d53f21cc..01e75cea62c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1460,9 +1460,13 @@ extern "C" { LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state); LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct 
llama_context * ctx, size_t n_accepted); - + + LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx); + LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx); + LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + #ifdef __cplusplus } #endif diff --git a/src/llama-context.cpp b/src/llama-context.cpp index edf5d747f10..8939edabaa7 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -3105,6 +3105,20 @@ void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float ctx->draft_input_hidden_state = hidden_state; } +bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx) { + auto * kvd = static_cast(ctx->kv_cache_data); + const auto & last_sinfo = kvd->last_main_model_sinfos; + + if (last_sinfo.empty()) { + LLAMA_LOG_ERROR("%s: The main call sinfo is not available for warmup.\n", __func__); + return false; + } + + kvd->forced_sinfos = &last_sinfo; + return true; +} + + bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) { auto * kvd = static_cast(ctx->kv_cache_data); const auto & last_sinfo = kvd->last_main_model_sinfos; @@ -3126,4 +3140,14 @@ bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_acc void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) { auto * kvd = static_cast(ctx->kv_cache_data); kvd->forced_sinfos = nullptr; -} \ No newline at end of file +} + +void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (memory) { + static_cast(memory.get())->seq_rm(seq_id, p0, p1); + } +} + +void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + ctx->kv_cache_seq_rm(seq_id, p0, p1); +} diff --git a/src/llama-context.h b/src/llama-context.h index 654409cb6cc..e15a336938a 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -100,6 +100,8 @@ struct llama_context { int32_t il_start, int32_t il_end); + void kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1); + // process a single ubatch with a specific graph type // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 844805d0ced..91cc438dccc 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3520,11 +3520,10 @@ struct server_context { needs_mtp_warmup = true; } } - + if (needs_mtp_warmup) { - if (llama_mtp_prepare_sinfo_for_update(ctx, batch_view.n_tokens)) { + if (llama_mtp_prepare_sinfo_for_warmup(ctx)) { mtp_update_kv_cache(ctx, batch_view, true); - llama_mtp_cancel_sinfo_update(ctx); } else { LOG_ERR("%s: Failed to prepare the MTP symphony for warmup.", __func__); From 913af8f48d2dab1d9e907cf6c48c921a229a295c Mon Sep 17 00:00:00 2001 From: samuel Date: Fri, 10 Oct 2025 16:44:28 -0300 Subject: [PATCH 25/35] mtp-batch(refactor): Replace MTP boolean flags with an explicit operation enum --- common/speculative.cpp | 26 ++++++---- include/llama.h | 15 ++++-- src/llama-batch.cpp | 4 +- src/llama-context.cpp | 104 ++++++++++++++++++++-------------------- src/llama-context.h | 7 +-- src/llama-graph.h | 6 +-- src/llama-model.cpp | 44 ++++++++--------- tools/server/server.cpp | 15 +++--- 8 files changed, 113 insertions(+), 108 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 8249a3a52c7..2b63e16e02d 100644 --- a/common/speculative.cpp 
+++ b/common/speculative.cpp @@ -378,17 +378,21 @@ llama_token mtp_speculative_gen_draft( const llama_seq_id draft_seq_id = 0; common_batch_add(mtp_batch, id_last, n_past, {0}, true); - mtp_batch.update_mtp_kv = false; - mtp_batch.use_mtp_head = true; + mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN; - LOG_INF("[DEBUG-DRAFT-CALL] Calling llama_decode for draft. update_mtp_kv=%s, use_mtp_head=%s\n", - mtp_batch.update_mtp_kv ? "true" : "false", - mtp_batch.use_mtp_head ? "true" : "false" - ); + // LOG_INF("[DEBUG-DRAFT-CALL] Calling llama_decode for draft. update_mtp_kv=%s, use_mtp_head=%s\n", + // mtp_batch.update_mtp_kv ? "true" : "false", + // mtp_batch.use_mtp_head ? "true" : "false" + // ); + // Perform the MTP draft generation decode. This writes the MTP layer's + // KV state for the draft token into the cache. llama_decode(ctx, mtp_batch); llama_batch_free(mtp_batch); + // CRITICAL: Purge the metadata for the draft token we just wrote. + // This makes the physical cell available again for the main model's validation pass, + // preventing a cache state corruption where two cells map to the same logical position. llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1); const llama_model * model = llama_get_model(ctx); @@ -398,7 +402,7 @@ llama_token mtp_speculative_gen_draft( cur_p->size = n_vocab; for (int i = 0; i < n_vocab; ++i) { cur_p->data[i].id = i; - cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // TODO: check if position 0 is the right + cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0. } cur_p->sorted = false; common_sampler_apply_chain(smpl, cur_p); @@ -415,9 +419,11 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b LOG_INF("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? 
"PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens); llama_batch mtp_batch = batch; - mtp_batch.update_mtp_kv = true; - mtp_batch.use_mtp_head = true; - mtp_batch.is_mtp_prompt_warmup = is_prompt_warmup; + if (is_prompt_warmup) { + mtp_batch.mtp_params.op_type = MTP_OP_WARMUP; + } else { + mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED; + } for (int i = 0; i < mtp_batch.n_tokens; ++i) { mtp_batch.logits[i] = false; diff --git a/include/llama.h b/include/llama.h index 01e75cea62c..89c85103109 100644 --- a/include/llama.h +++ b/include/llama.h @@ -221,6 +221,17 @@ extern "C" { // - if not: only the last token is output // ) // + typedef enum { + MTP_OP_NONE, + MTP_OP_WARMUP, + MTP_OP_UPDATE_ACCEPTED, + MTP_OP_DRAFT_GEN, + } llama_mtp_op_type; + + typedef struct llama_mtp_params { + llama_mtp_op_type op_type; + } llama_mtp_params; + typedef struct llama_batch { int32_t n_tokens; @@ -230,9 +241,7 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" - bool update_mtp_kv; - bool use_mtp_head; - bool is_mtp_prompt_warmup; + llama_mtp_params mtp_params; } llama_batch; enum llama_model_kv_override_type { diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 3268467f274..c01960c55ea 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -841,9 +841,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, - /*.use_mtp_head =*/ false, - /*update_mtp_kv =*/ false, - /*.is_mtp_prompt_warmup =*/ false, + /*.mtp_params =*/ { MTP_OP_NONE }, }; if (embd) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 8939edabaa7..f22a3980489 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -750,7 +750,7 @@ static double calculate_vector_sum(const float* vec, size_t size) { } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret, - bool do_mtp_kv_update, bool use_mtp_head, bool is_mtp_prompt_warmup) { + const llama_mtp_params & mtp_params) { if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; @@ -762,7 +762,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // the new graph parameters // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update, use_mtp_head); + const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); @@ -793,12 +793,12 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } } - if (do_mtp_kv_update || (use_mtp_head && !do_mtp_kv_update)) { // If it is any MTP operation + if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation const char * target_tensor_name = "result_embd_pooled"; ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name); const float * source_hidden_state = nullptr; - if (is_mtp_prompt_warmup || (do_mtp_kv_update && !is_mtp_prompt_warmup)) { + if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { source_hidden_state = this->embd; } else { source_hidden_state = this->draft_input_hidden_state; @@ -806,9 +806,9 @@ 
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll if (source_hidden_state != nullptr && hidden_states_input != nullptr) { const size_t n_embd = this->model.hparams.n_embd; - const size_t n_tokens_for_sum = (do_mtp_kv_update && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1; + const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1; double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd); - const char * op_type = (do_mtp_kv_update) ? "MTP_UPDATE" : "DRAFT_GEN"; + const char * op_type = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) ? "MTP_UPDATE" : "DRAFT_GEN"; LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum); @@ -833,12 +833,12 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll const int64_t t_exec_start_us = ggml_time_us(); const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); const int64_t t_exec_end_us = ggml_time_us(); - LLAMA_LOG_INFO( - "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n", - (t_exec_end_us - t_exec_start_us) / 1000.0, - ubatch.n_tokens, - do_mtp_kv_update ? "yes" : "no" - ); + // LLAMA_LOG_INFO( + // "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n", + // (t_exec_end_us - t_exec_start_us) / 1000.0, + // ubatch.n_tokens, + // do_mtp_kv_update ? "yes" : "no" + // ); if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); ret = status; @@ -846,7 +846,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } ret = GGML_STATUS_SUCCESS; - if (do_mtp_kv_update || use_mtp_head) { + if (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum"); if (sum_tensor) { LLAMA_LOG_WARN("[DEBUG-SUM] MTP input sum node successfully created.\n"); @@ -912,7 +912,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false, false); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, { MTP_OP_NONE }); cparams.causal_attn = causal_attn_org; @@ -1027,10 +1027,10 @@ int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT auto * kvd = static_cast(kv_cache_data); - LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n", - batch_inp.update_mtp_kv ? "true" : "false", - batch_inp.use_mtp_head ? "true" : "false" - ); + // LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n", + // batch_inp.update_mtp_kv ? "true" : "false", + // batch_inp.use_mtp_head ? 
"true" : "false" + // ); if (!memory) { LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); @@ -1101,7 +1101,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } else { mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); - if (!batch_inp.use_mtp_head && !batch_inp.update_mtp_kv) { + if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) { kvd->last_main_model_sinfos = static_cast(mctx.get())->get_sinfos(); } else { @@ -1158,9 +1158,9 @@ int llama_context::decode(const llama_batch & batch_inp) { }; int64_t n_outputs_prev = 0; - const bool do_mtp_kv_update = batch_inp.update_mtp_kv; - const bool use_mtp_head = batch_inp.use_mtp_head; - const bool is_prompt_warmup = batch_inp.is_mtp_prompt_warmup; + // const bool do_mtp_kv_update = batch_inp.update_mtp_kv; + // const bool use_mtp_head = batch_inp.use_mtp_head; + // const bool is_prompt_warmup = batch_inp.is_mtp_prompt_warmup; do { const auto & ubatch = mctx->get_ubatch(); @@ -1169,13 +1169,13 @@ int llama_context::decode(const llama_batch & batch_inp) { for (uint32_t i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) { pos_str += std::to_string(ubatch.pos[i]) + " "; } - LLAMA_LOG_WARN( - "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n", - ubatch.n_tokens, - batch_inp.update_mtp_kv ? "true" : "false", - batch_inp.use_mtp_head ? "true" : "false", - pos_str.c_str() - ); + // LLAMA_LOG_WARN( + // "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n", + // ubatch.n_tokens, + // batch_inp.update_mtp_kv ? "true" : "false", + // batch_inp.use_mtp_head ? "true" : "false", + // pos_str.c_str() + // ); } // count the outputs in this ubatch @@ -1193,16 +1193,16 @@ int llama_context::decode(const llama_batch & batch_inp) { // needs to happen before the graph is built n_outputs = n_outputs_new; } - if (do_mtp_kv_update) { - LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens); - std::string positions_str; - for (int i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) { - positions_str += std::to_string(ubatch.pos[i]) + " "; - } - LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str()); - } + // if (do_mtp_kv_update) { + // LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens); + // std::string positions_str; + // for (int i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) { + // positions_str += std::to_string(ubatch.pos[i]) + " "; + // } + // LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str()); + // } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head, is_prompt_warmup); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, batch_inp.mtp_params); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache llama_pos pos_min[LLAMA_MAX_SEQ]; @@ -1261,17 +1261,17 @@ int llama_context::decode(const llama_batch & batch_inp) { } } - if (use_mtp_head) { - if (t_embd != nullptr) { - LLAMA_LOG_ERROR("[MTP-GRAPH-BUG] The MTP graph returned an embedding tensor when it shouldn't have! 
This will cause corruption.\n"); - } else { - LLAMA_LOG_WARN("[MTP-GRAPH-OK] The MTP graph correctly did not return an embedding tensor.\n"); - } - } + // if (use_mtp_head) { + // if (t_embd != nullptr) { + // LLAMA_LOG_ERROR("[MTP-GRAPH-BUG] The MTP graph returned an embedding tensor when it shouldn't have! This will cause corruption.\n"); + // } else { + // LLAMA_LOG_WARN("[MTP-GRAPH-OK] The MTP graph correctly did not return an embedding tensor.\n"); + // } + // } // extract embeddings if (t_embd && n_outputs > 0) { - if (!use_mtp_head) { + if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1389,7 +1389,7 @@ int llama_context::decode(const llama_batch & batch_inp) { ggml_backend_sched_reset(sched.get()); } - if (!use_mtp_head) { + if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { synchronize(); const size_t n_embd = this->model.hparams.n_embd; double full_buffer_sum = calculate_vector_sum(this->embd, n_outputs_all * n_embd); @@ -1534,7 +1534,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false, false); + const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE }); res->reset(); @@ -1556,8 +1556,7 @@ llm_graph_params llama_context::graph_params( const llama_ubatch & ubatch, const llama_memory_context_i * mctx, llm_graph_type gtype, - bool update_mtp_kv, - bool use_mtp_head) const { + const llama_mtp_params & mtp_params) const { return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, @@ -1570,8 +1569,7 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ &loras, /*.mctx =*/ mctx, /*.cross =*/ &cross, - /*.update_mtp_kv =*/ update_mtp_kv, - /*.use_mtp_head =*/ use_mtp_head, + /*.mtp_params =*/ mtp_params, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, @@ -2312,7 +2310,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false, false); + const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE }); res->reset(); diff --git a/src/llama-context.h b/src/llama-context.h index e15a336938a..70ca4e0832d 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -111,9 +111,7 @@ struct llama_context { llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret, - const bool do_mtp_kv_update, - const bool use_mtp_head, - bool is_mtp_prompt_warmup); + const llama_mtp_params & mtp_params); int encode(const llama_batch & batch_inp); int decode(const llama_batch & batch_inp); @@ -229,8 +227,7 @@ struct llama_context { const llama_ubatch & ubatch, const llama_memory_context_i * mctx, llm_graph_type gtype, - bool update_mtp_kv, - bool use_mtp_head) const; + const llama_mtp_params & mtp_params) const; llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const; diff --git a/src/llama-graph.h b/src/llama-graph.h index 40dd83f0bc1..3c5feadfdc7 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -417,8 +417,7 @@ struct llm_graph_params { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; - bool update_mtp_kv; - bool use_mtp_head; + llama_mtp_params mtp_params; uint32_t n_outputs; @@ -467,8 +466,7 @@ struct llm_graph_params { cvec 
== other.cvec && loras == other.loras && cross == other.cross && - update_mtp_kv == other.update_mtp_kv && - use_mtp_head == other.use_mtp_head && + mtp_params.op_type == other.mtp_params.op_type && n_outputs == other.n_outputs; } }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6ca53a80cd1..77bdf83edae 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13788,24 +13788,24 @@ struct llm_build_glm4 : public llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - LLAMA_LOG_WARN( - "[GRAPH_BUILD] Building graph. Path: %s, MTP_Update: %s, UBatch_Tokens: %d, First_Pos: %d\n", - params.use_mtp_head ? "MTP" : "MAIN", - params.update_mtp_kv ? "true" : "false", - n_tokens, - n_tokens > 0 ? ubatch.pos[0] : -1 - ); + // LLAMA_LOG_WARN( + // "[GRAPH_BUILD] Building graph. Path: %s, MTP_Update: %s, UBatch_Tokens: %d, First_Pos: %d\n", + // params.use_mtp_head ? "MTP" : "MAIN", + // params.update_mtp_kv ? "true" : "false", + // n_tokens, + // n_tokens > 0 ? ubatch.pos[0] : -1 + // ); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); ggml_tensor * cur; - LLAMA_LOG_WARN( - "[DEBUG-GRAPH-STATE] Building graph. MTP Head=%s, MTP KV Update=%s, n_tokens=%d\n", - params.use_mtp_head ? "true" : "false", - params.update_mtp_kv ? "true" : "false", - n_tokens - ); + // LLAMA_LOG_WARN( + // "[DEBUG-GRAPH-STATE] Building graph. MTP Head=%s, MTP KV Update=%s, n_tokens=%d\n", + // params.use_mtp_head ? "true" : "false", + // params.update_mtp_kv ? "true" : "false", + // n_tokens + // ); // for (int i = 0; i < n_tokens; ++i) { // LLAMA_LOG_WARN(" - ubatch token[%d]: ID=%d, Pos=%d\n", i, ubatch.token[i], ubatch.pos[i]); // } @@ -13817,10 +13817,10 @@ struct llm_build_glm4_moe : public llm_graph_context { ); } - if (params.use_mtp_head) { + if (params.mtp_params.op_type != MTP_OP_NONE) { ggml_tensor* hidden_states_from_main_model; - if (params.update_mtp_kv) { + if (params.mtp_params.op_type == MTP_OP_WARMUP || params.mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); ggml_set_input(hidden_states_from_main_model); @@ -18349,7 +18349,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { std::unique_ptr llm; - const bool build_mtp = params.update_mtp_kv; + const bool build_mtp = params.mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED; switch (arch) { case LLM_ARCH_LLAMA: @@ -18708,16 +18708,16 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { GGML_ABORT("fatal error"); } - if (!params.use_mtp_head) { + if (params.mtp_params.op_type == MTP_OP_NONE) { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); } const int64_t t_end_us = ggml_time_us(); - LLAMA_LOG_INFO( - "[PERF] Graph build time: %.2f ms (MTP path: %s)\n", - (t_end_us - t_start_us) / 1000.0, - params.use_mtp_head ? "yes" : "no" - ); + // LLAMA_LOG_INFO( + // "[PERF] Graph build time: %.2f ms (MTP path: %s)\n", + // (t_end_us - t_start_us) / 1000.0, + // params.use_mtp_head ? 
"yes" : "no" + // ); return llm->res->get_gf(); } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 91cc438dccc..527d84cd081 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3514,16 +3514,15 @@ struct server_context { continue; // continue loop of n_batch } - bool needs_mtp_warmup = false; - if (slot_batched && slot_batched->has_mtp) { - if (slot_batched->state == SLOT_STATE_PROCESSING_PROMPT || slot_batched->state == SLOT_STATE_DONE_PROMPT) { - needs_mtp_warmup = true; - } - } - - if (needs_mtp_warmup) { + if (slot_batched && slot_batched->has_mtp && + (slot_batched->state == SLOT_STATE_PROCESSING_PROMPT || slot_batched->state == SLOT_STATE_DONE_PROMPT)) { + + // Prepare the context to reuse the exact sinfo layout (including multiple u-batches) + // from the main model's prompt processing pass. This ensures the MTP layer's + // KV cache is perfectly aligned. if (llama_mtp_prepare_sinfo_for_warmup(ctx)) { mtp_update_kv_cache(ctx, batch_view, true); + // Clean up the forced state to not affect subsequent decodes. llama_mtp_cancel_sinfo_update(ctx); } else { LOG_ERR("%s: Failed to prepare the MTP symphony for warmup.", __func__); From a99709d0c1401d0b447dce1bd0101fb56390f50e Mon Sep 17 00:00:00 2001 From: samuel Date: Fri, 10 Oct 2025 17:24:34 -0300 Subject: [PATCH 26/35] mtp-batch(refactor): Extract decode context and MTP input logic into helper methods --- src/llama-context.cpp | 119 +++++++++++++++++++++++++++--------------- src/llama-context.h | 8 +++ 2 files changed, 84 insertions(+), 43 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index f22a3980489..4bdbee951d8 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -794,28 +794,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation - const char * target_tensor_name = "result_embd_pooled"; - ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name); - - const float * source_hidden_state = nullptr; - if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { - source_hidden_state = this->embd; - } else { - source_hidden_state = this->draft_input_hidden_state; - } - - if (source_hidden_state != nullptr && hidden_states_input != nullptr) { - const size_t n_embd = this->model.hparams.n_embd; - const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1; - double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd); - const char * op_type = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) ? 
"MTP_UPDATE" : "DRAFT_GEN"; - - LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum); - - ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input)); - } else { - LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n", - __func__, target_tensor_name); + if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) { ret = GGML_STATUS_FAILED; return nullptr; } @@ -1089,27 +1068,7 @@ int llama_context::decode(const llama_batch & batch_inp) { std::unique_ptr mctx; while (true) { - if (cparams.warmup) { - mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); - } else { - if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) { - LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n"); - - mctx = static_cast(memory.get())->init_batch_with_sinfos( - *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true - ); - } else { - mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); - - if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { - if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) { - kvd->last_main_model_sinfos = static_cast(mctx.get())->get_sinfos(); - } else { - kvd->last_main_model_sinfos.clear(); - } - } - } - } + mctx = this->initialize_decode_context(batch_inp, output_all); if (!mctx) { return -2; @@ -3149,3 +3108,77 @@ void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { ctx->kv_cache_seq_rm(seq_id, p0, p1); } + +/* + Initializes the memory context for a decode operation. + The logic follows a specific priority: + 1. Warmup: Always use a standard batch initialization. + 2. Forced S-Info (MTP Updates): If a specific KV cache layout is forced, use it. + 3. Default: Use a standard batch initialization, and if it's a main model pass, + save the resulting s-info for potential future reuse by MTP. 
+*/ +std::unique_ptr llama_context::initialize_decode_context(const llama_batch & batch_inp, const bool output_all) { + auto * kvd = static_cast(kv_cache_data); + std::unique_ptr mctx; + + if (cparams.warmup) { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) { + LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n"); + mctx = static_cast(memory.get())->init_batch_with_sinfos( + *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true + ); + } else { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + + if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { + if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) { + kvd->last_main_model_sinfos = static_cast(mctx.get())->get_sinfos(); + } else { + kvd->last_main_model_sinfos.clear(); + } + } + } + + return mctx; +} + + +bool llama_context::prepare_mtp_graph_inputs( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_mtp_params & mtp_params) { + + const char * target_tensor_name = "result_embd_pooled"; + ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name); + + const float * source_hidden_state = nullptr; + if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + source_hidden_state = this->embd; + } else { // MTP_OP_DRAFT_GEN + source_hidden_state = this->draft_input_hidden_state; + } + + if (source_hidden_state != nullptr && hidden_states_input != nullptr) { + const size_t n_embd = this->model.hparams.n_embd; + const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1; + double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd); + + const char * op_type; + if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + op_type = "MTP_UPDATE"; + } else { // MTP_OP_DRAFT_GEN + op_type = "DRAFT_GEN"; + } + + LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum); + + ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input)); + } else { + LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n", + __func__, target_tensor_name); + return false; + } + + return true; +} diff --git a/src/llama-context.h b/src/llama-context.h index 70ca4e0832d..ab854c1af1a 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -231,6 +231,14 @@ struct llama_context { llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const; + // Methods for MTP decode + std::unique_ptr initialize_decode_context(const llama_batch & batch_inp, const bool output_all); + + bool prepare_mtp_graph_inputs( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_mtp_params & mtp_params); + // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); From b4cbe030ac25056717763b812d1dd89681c08522 Mon Sep 17 00:00:00 2001 From: samuel Date: Sat, 11 Oct 2025 18:37:40 -0300 Subject: [PATCH 27/35] mtp-batch(chore): Fix logit flags for speculative sampling and remove debug logs --- common/speculative.cpp | 23 +++---------------- common/speculative.h | 2 -- src/llama-context.cpp | 42 +--------------------------------- src/llama-kv-cache-unified.cpp | 4 ---- src/llama-model.cpp | 38 ++---------------------------- 
tools/server/server.cpp | 29 ++--------------------- 6 files changed, 8 insertions(+), 130 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 2b63e16e02d..02eca967caf 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -380,11 +380,6 @@ llama_token mtp_speculative_gen_draft( mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN; - // LOG_INF("[DEBUG-DRAFT-CALL] Calling llama_decode for draft. update_mtp_kv=%s, use_mtp_head=%s\n", - // mtp_batch.update_mtp_kv ? "true" : "false", - // mtp_batch.use_mtp_head ? "true" : "false" - // ); - // Perform the MTP draft generation decode. This writes the MTP layer's // KV state for the draft token into the cache. llama_decode(ctx, mtp_batch); @@ -416,7 +411,7 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b return; } - LOG_INF("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens); + LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens); llama_batch mtp_batch = batch; if (is_prompt_warmup) { @@ -426,7 +421,7 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b } for (int i = 0; i < mtp_batch.n_tokens; ++i) { - mtp_batch.logits[i] = false; + mtp_batch.logits[i] = true; } llama_decode(ctx, mtp_batch); } @@ -447,7 +442,7 @@ void mtp_accept_tokens( llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1); for (size_t i = 0; i < ids.size(); ++i) { - common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, false); + common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true); } mtp_update_kv_cache(ctx, accepted_batch, false); @@ -456,15 +451,3 @@ void mtp_accept_tokens( llama_batch_free(accepted_batch); } - -// Debug function - It will be removed later -double calculate_vector_sum_double(const float* vec, size_t size) { - if (!vec) { - return 0.0; - } - double sum = 0.0; - for (size_t i = 0; i < size; ++i) { - sum += vec[i]; - } - return sum; -} \ No newline at end of file diff --git a/common/speculative.h b/common/speculative.h index d361e69d077..8b81f4ac77d 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -57,5 +57,3 @@ void mtp_accept_tokens( int32_t n_past_base, llama_seq_id seq_id ); - -double calculate_vector_sum_double(const float* vec, size_t size); \ No newline at end of file diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4bdbee951d8..7c9aff2826a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -809,15 +809,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); } - const int64_t t_exec_start_us = ggml_time_us(); const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); - const int64_t t_exec_end_us = ggml_time_us(); - // LLAMA_LOG_INFO( - // "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n", - // (t_exec_end_us - t_exec_start_us) / 1000.0, - // ubatch.n_tokens, - // do_mtp_kv_update ? 
"yes" : "no" - // ); if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); ret = status; @@ -827,9 +819,6 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll ret = GGML_STATUS_SUCCESS; if (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum"); - if (sum_tensor) { - LLAMA_LOG_WARN("[DEBUG-SUM] MTP input sum node successfully created.\n"); - } } return res; } @@ -1123,20 +1112,6 @@ int llama_context::decode(const llama_batch & batch_inp) { do { const auto & ubatch = mctx->get_ubatch(); - if (ubatch.n_tokens > 0) { - std::string pos_str; - for (uint32_t i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) { - pos_str += std::to_string(ubatch.pos[i]) + " "; - } - // LLAMA_LOG_WARN( - // "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n", - // ubatch.n_tokens, - // batch_inp.update_mtp_kv ? "true" : "false", - // batch_inp.use_mtp_head ? "true" : "false", - // pos_str.c_str() - // ); - } - // count the outputs in this ubatch { int32_t n_outputs_new = 0; @@ -1281,8 +1256,6 @@ int llama_context::decode(const llama_batch & batch_inp) { GGML_ABORT("unknown pooling type"); } } - } else { - LLAMA_LOG_WARN("[DEBUG-EMBD-COPY] Skipping embedding buffer copy for MTP operation (use_mtp_head=true).\n"); } } @@ -1347,13 +1320,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // overlap with device computation. ggml_backend_sched_reset(sched.get()); } - - if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { - synchronize(); - const size_t n_embd = this->model.hparams.n_embd; - double full_buffer_sum = calculate_vector_sum(this->embd, n_outputs_all * n_embd); - LLAMA_LOG_WARN("[INTEGRITY-CHECK|A] After main decode. ubatch_size=%d. Checksum: %e\n", n_outputs_all, full_buffer_sum); - } return 0; } @@ -3124,7 +3090,7 @@ std::unique_ptr llama_context::initialize_decode_context if (cparams.warmup) { mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) { - LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n"); + LLAMA_LOG_DEBUG("%s: Forcing sinfos, bypassing find_slot.\n", __func__); mctx = static_cast(memory.get())->init_batch_with_sinfos( *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true ); @@ -3160,10 +3126,6 @@ bool llama_context::prepare_mtp_graph_inputs( } if (source_hidden_state != nullptr && hidden_states_input != nullptr) { - const size_t n_embd = this->model.hparams.n_embd; - const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? 
ubatch.n_tokens : 1; - double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd); - const char * op_type; if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { op_type = "MTP_UPDATE"; @@ -3171,8 +3133,6 @@ bool llama_context::prepare_mtp_graph_inputs( op_type = "DRAFT_GEN"; } - LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum); - ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input)); } else { LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n", diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 787fb8d9a55..90ee8f726ef 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -766,7 +766,6 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d } llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { - LLAMA_LOG_WARN("%s: Entering find_slot for ubatch of %d tokens.\n", __func__, ubatch.n_tokens); if (debug > 0) { const auto & cells = v_cells[seq_to_stream[1]]; @@ -972,9 +971,6 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ } } } - LLAMA_LOG_WARN("%s: find_slot SUCCEEDED for ubatch of %d tokens. Idxs:%s\n", __func__, ubatch.n_tokens, idxs_str.c_str()); - } else { - LLAMA_LOG_ERROR("%s: find_slot FAILED to allocate cells for ubatch of %d tokens.\n", __func__, ubatch.n_tokens); } return res; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 77bdf83edae..56f2bae06cd 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13788,35 +13788,11 @@ struct llm_build_glm4 : public llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - // LLAMA_LOG_WARN( - // "[GRAPH_BUILD] Building graph. Path: %s, MTP_Update: %s, UBatch_Tokens: %d, First_Pos: %d\n", - // params.use_mtp_head ? "MTP" : "MAIN", - // params.update_mtp_kv ? "true" : "false", - // n_tokens, - // n_tokens > 0 ? ubatch.pos[0] : -1 - // ); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); ggml_tensor * cur; - // LLAMA_LOG_WARN( - // "[DEBUG-GRAPH-STATE] Building graph. MTP Head=%s, MTP KV Update=%s, n_tokens=%d\n", - // params.use_mtp_head ? "true" : "false", - // params.update_mtp_kv ? "true" : "false", - // n_tokens - // ); - // for (int i = 0; i < n_tokens; ++i) { - // LLAMA_LOG_WARN(" - ubatch token[%d]: ID=%d, Pos=%d\n", i, ubatch.token[i], ubatch.pos[i]); - // } - if (n_tokens > 0) { - LLAMA_LOG_WARN( - " - ubatch tokens: [ID=%d, Pos=%d] ... 
[ID=%d, Pos=%d]\n", - ubatch.token[0], ubatch.pos[0], - ubatch.token[n_tokens-1], ubatch.pos[n_tokens-1] - ); - } - if (params.mtp_params.op_type != MTP_OP_NONE) { ggml_tensor* hidden_states_from_main_model; @@ -13913,10 +13889,7 @@ struct llm_build_glm4_moe : public llm_graph_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - if (ubatch.n_tokens > 0) { - LLAMA_LOG_WARN("[KV_WRITE] path=MAIN, layer=%d, n_tokens=%d, pos_start=%d, pos_end=%d\n", - il, ubatch.n_tokens, ubatch.pos[0], ubatch.pos[ubatch.n_tokens-1]); - } + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); @@ -14066,14 +14039,7 @@ struct llm_build_glm4_moe : public llm_graph_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - // LLAMA_LOG_WARN("[DEBUG-MTP-ATTN] Inputs for build_attn in the layer %d:\n", il); - // LLAMA_LOG_WARN(" - Qcur shape: [%d, %d, %d]\n", Qcur->ne[0], Qcur->ne[1], Qcur->ne[2]); - // LLAMA_LOG_WARN(" - Kcur shape: [%d, %d, %d]\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2]); - // LLAMA_LOG_WARN(" - Vcur shape: [%d, %d, %d]\n", Vcur->ne[0], Vcur->ne[1], Vcur->ne[2]); - if (ubatch.n_tokens > 0) { - LLAMA_LOG_WARN("[KV_WRITE] path=MTP, layer=%d, n_tokens=%d, pos_start=%d, pos_end=%d\n", - il, ubatch.n_tokens, ubatch.pos[0], ubatch.pos[ubatch.n_tokens-1]); - } + cur = build_attn(inp_attn, mtp_layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 527d84cd081..4ff69f005f5 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1738,7 +1738,7 @@ struct server_queue { while (true) { QUE_DBG("%s", "processing new tasks\n"); - const int64_t t_turn_start_us = ggml_time_us(); + while (true) { std::unique_lock lock(mutex_tasks); if (!running) { @@ -1761,11 +1761,7 @@ struct server_queue { QUE_DBG("%s", "update slots\n"); callback_update_slots(); - const int64_t t_turn_end_us = ggml_time_us(); - SRV_DBG( - "[PERF] Server turn time: %.2f ms\n", - (t_turn_end_us - t_turn_start_us) / 1000.0 - ); + QUE_DBG("%s", "waiting for new tasks\n"); { std::unique_lock lock(mutex_tasks); @@ -3471,7 +3467,6 @@ struct server_context { batch.seq_id + i, batch.logits + i, }; - LOG_INF("\n[DEBUG-CHUNK] Processing main model chunk. Batch size: %d\n", n_tokens); const int ret = llama_decode(ctx, batch_view); @@ -3569,10 +3564,8 @@ struct server_context { } llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); slot.last_tok_idx = tok_idx; - //SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str()); slot.i_batch = -1; - SLT_INF(slot, "[SAMPLER-ACCEPT] Accepting token ID %d at index %zu\n", id, i); common_sampler_accept(slot.smpl, id, true); slot.n_decoded += 1; @@ -3644,7 +3637,6 @@ struct server_context { llama_tokens draft; if (slot.has_mtp) { - SLT_INF(slot, "[POS-SYNC] Before draft gen. n_past = %d\n", slot.n_past); llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); draft.reserve(1); draft.push_back(draft_id); @@ -3680,26 +3672,14 @@ struct server_context { } SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); - SLT_INF(slot, "[POS-SYNC] Before validation decode. 
n_past = %d, spec_batch_size = %d\n", slot.n_past, slot.batch_spec.n_tokens); llama_decode(ctx, slot.batch_spec); - const size_t n_embd = llama_n_embd(llama_get_model(ctx)); - const size_t golden_buffer_size_in_floats = slot.batch_spec.n_tokens * n_embd; - const float* golden_embd_ptr = llama_get_embeddings(ctx); - double golden_checksum = calculate_vector_sum_double(golden_embd_ptr, golden_buffer_size_in_floats); - SLT_INF(slot, "[VERIFY] Golden checksum after validation: %e (size: %zu tokens)\n", golden_checksum, slot.batch_spec.n_tokens); - // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); - SLT_INF(slot, "[POS-SYNC] Tokens accepted: %zu\n", ids.size()); if (slot.has_mtp) { llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1)); - const float* embd_after_draft_ptr = llama_get_embeddings(ctx); - double checksum_after_draft = calculate_vector_sum_double(embd_after_draft_ptr, golden_buffer_size_in_floats); - SLT_INF(slot, "[VERIFY] Checksum after draft gen (should be unchanged): %e\n", checksum_after_draft); - if (!ids.empty()) { llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1)); } else { @@ -3707,14 +3687,9 @@ struct server_context { } mtp_accept_tokens(ctx, ids, slot.n_past, slot.id); - - const float* embd_after_update_ptr = llama_get_embeddings(ctx); - double checksum_after_update = calculate_vector_sum_double(embd_after_update_ptr, golden_buffer_size_in_floats); - SLT_INF(slot, "[VERIFY] Checksum after MTP update (should be unchanged): %e\n", checksum_after_update); } slot.n_past += ids.size(); - SLT_INF(slot, "[POS-SYNC] After n_past update. New n_past = %d\n", slot.n_past); slot.n_decoded += ids.size(); // update how many tokens out of those tested were accepted From 4bcc9e261ef57ee4cfaa65d06bcd0fcdeacf7797 Mon Sep 17 00:00:00 2001 From: samuel Date: Sat, 11 Oct 2025 18:51:22 -0300 Subject: [PATCH 28/35] mtp-batch(fix): Correctly advance cache head and add MTP documentation --- common/speculative.cpp | 4 ++++ include/llama.h | 22 ++++++++++++++++++++++ src/llama-context.h | 4 ++++ src/llama-kv-cache-unified.cpp | 33 +++++++++++++++------------------ 4 files changed, 45 insertions(+), 18 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 02eca967caf..a7a40426821 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -436,10 +436,13 @@ void mtp_accept_tokens( return; } + // Prepare a resized copy of the validation sinfo to match the number of accepted tokens. + // This sets up the context for a "forced sinfo" decode. if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) { return; } + // Build a new batch containing only the accepted tokens. llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1); for (size_t i = 0; i < ids.size(); ++i) { common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true); @@ -447,6 +450,7 @@ void mtp_accept_tokens( mtp_update_kv_cache(ctx, accepted_batch, false); + // Clean up the forced state to not affect subsequent, normal decode calls. 
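Taken together, the server-side changes above amount to the following per-token cycle for a slot with has_mtp enabled. This is a condensed sketch for orientation only: error handling is omitted and the construction of slot.batch_spec is shown schematically, so the real code may differ in details.

    // one MTP speculation step; `id` is the token just sampled from the main model
    llama_tokens draft;
    draft.push_back(mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx));

    // validate the draft with the main model
    common_batch_clear(slot.batch_spec);
    common_batch_add(slot.batch_spec, id,       slot.n_past,     { slot.id }, true);
    common_batch_add(slot.batch_spec, draft[0], slot.n_past + 1, { slot.id }, true);
    llama_decode(ctx, slot.batch_spec);

    // keep the accepted prefix, feed its hidden state to the MTP head, update the MTP KV cache
    const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
    llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));
    mtp_accept_tokens(ctx, ids, slot.n_past, slot.id);

    slot.n_past    += ids.size();
    slot.n_decoded += ids.size();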
llama_mtp_cancel_sinfo_update(ctx); llama_batch_free(accepted_batch); diff --git a/include/llama.h b/include/llama.h index 89c85103109..0b15d4bf1cd 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1466,14 +1466,36 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); + // + // MTP + // + LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state); + /** + * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo. + * This is used after speculative validation when only a subset of draft tokens are accepted. + * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized. + * @return true on success. + */ LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted); + /** + * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode. + * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned. + * @return true on success. + */ LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx); + /** + * @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo. + */ LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx); + /** + * @brief Removes KV cache metadata for a specified sequence and token range. + * This makes the physical cells logically available again without deleting the tensor data. + */ LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); #ifdef __cplusplus diff --git a/src/llama-context.h b/src/llama-context.h index ab854c1af1a..4d77d5d81ae 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -30,6 +30,10 @@ struct llama_context { ~llama_context(); + // The llama_context manages significant resources (GPU memory, file handles, PImpl data) + // and is fundamentally a non-copyable, non-movable object. Deleting these special + // member functions enforces this rule and is also technically required to allow the + // PImpl pattern (via unique_ptr or void*) with an incomplete type in the header. llama_context(const llama_context &) = delete; llama_context & operator=(const llama_context &) = delete; llama_context(llama_context &&) = delete; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 90ee8f726ef..8d9b1f631f7 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -977,6 +977,10 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ } void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) { + // For "in-place" updates (MTP warmup/accept), we only update the tensor data. + // The cell metadata (logical position, sequence ID) has already been set + // by the main model's pass. We must skip all metadata modifications + // to prevent `pos_set` from asserting on an already-set cell. 
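The same cell-metadata discipline is what keeps the single-token draft path safe: the draft decode writes the MTP layer's KV state for one position and then immediately releases that cell's metadata, so the main model's validation pass can claim the same position without tripping the assertions below. A rough sketch of that pattern, mirroring mtp_speculative_gen_draft from this series (simplified, single sequence assumed):

    llama_batch mtp_batch = llama_batch_init(1, 0, 1);
    common_batch_add(mtp_batch, id_last, n_past, { 0 }, true);
    mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;

    llama_decode(ctx, mtp_batch);                  // writes the MTP layer's KV state for position n_past
    llama_batch_free(mtp_batch);

    // release the metadata so the validation pass can reuse position n_past
    llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, n_past, n_past + 1);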
if (!is_inplace_update) { // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty @@ -995,17 +999,12 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u const auto idx = sinfo.idxs[s][ii]; - if (!is_inplace_update) { - if (!cells.is_empty(idx)) { - assert(cells.seq_count(idx) == 1); - - const llama_seq_id seq_id = cells.seq_get(idx); - const llama_pos pos = cells.pos_get(idx); - - seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); - - cells.rm(idx); - } + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + cells.rm(idx); } cells.pos_set(idx, ubatch.pos[i]); @@ -1029,19 +1028,17 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u auto & cells = v_cells[seq_to_stream[s]]; if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { - LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", - __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); } } + } - // move the head at the end of the slot - for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { - auto & head = v_heads[sinfo.strm[s]]; + // move the head at the end of the slot + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + auto & head = v_heads[sinfo.strm[s]]; - head = sinfo.idxs[s].back() + 1; - } + head = sinfo.idxs[s].back() + 1; } } From 0127c6beeb384ec3abbc18b22dbe830f22fcf4b4 Mon Sep 17 00:00:00 2001 From: samuel Date: Sat, 11 Oct 2025 22:20:54 -0300 Subject: [PATCH 29/35] mtp-batch(chore): Remove final MTP debug logs and dead code --- src/llama-context.cpp | 43 ++++------------------------------------- src/llama-model.cpp | 16 --------------- tools/server/server.cpp | 2 +- 3 files changed, 5 insertions(+), 56 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 7c9aff2826a..a5345ee2a40 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -13,7 +13,6 @@ #include #include #include -#include // // llama_context @@ -738,17 +737,6 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -static double calculate_vector_sum(const float* vec, size_t size) { - if (!vec) { - return 0.0; - } - double sum = 0.0; - for (size_t i = 0; i < size; ++i) { - sum += vec[i]; - } - return sum; -} - llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret, const llama_mtp_params & mtp_params) { if (mctx && !mctx->apply()) { @@ -995,10 +983,6 @@ int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT auto * kvd = static_cast(kv_cache_data); - // LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n", - // batch_inp.update_mtp_kv ? "true" : "false", - // batch_inp.use_mtp_head ? 
"true" : "false" - // ); if (!memory) { LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); @@ -1074,10 +1058,10 @@ int llama_context::decode(const llama_batch & batch_inp) { } case LLAMA_MEMORY_STATUS_FAILED_PREPARE: { - // if (use_last_main_model_sinfos) { - // LLAMA_LOG_ERROR("%s: Mismatch between ubatches and sinfos during reuse.\n", __func__); - // return -1; - // } + if (kvd->forced_sinfos) { + LLAMA_LOG_ERROR("%s: Mismatch between ubatches and sinfos during reuse.\n", __func__); + return -1; + } if (!did_optimize) { did_optimize = true; @@ -1106,9 +1090,6 @@ int llama_context::decode(const llama_batch & batch_inp) { }; int64_t n_outputs_prev = 0; - // const bool do_mtp_kv_update = batch_inp.update_mtp_kv; - // const bool use_mtp_head = batch_inp.use_mtp_head; - // const bool is_prompt_warmup = batch_inp.is_mtp_prompt_warmup; do { const auto & ubatch = mctx->get_ubatch(); @@ -1127,14 +1108,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // needs to happen before the graph is built n_outputs = n_outputs_new; } - // if (do_mtp_kv_update) { - // LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens); - // std::string positions_str; - // for (int i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) { - // positions_str += std::to_string(ubatch.pos[i]) + " "; - // } - // LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str()); - // } ggml_status status; const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, batch_inp.mtp_params); if (!res) { @@ -1195,14 +1168,6 @@ int llama_context::decode(const llama_batch & batch_inp) { } } - // if (use_mtp_head) { - // if (t_embd != nullptr) { - // LLAMA_LOG_ERROR("[MTP-GRAPH-BUG] The MTP graph returned an embedding tensor when it shouldn't have! 
This will cause corruption.\n"); - // } else { - // LLAMA_LOG_WARN("[MTP-GRAPH-OK] The MTP graph correctly did not return an embedding tensor.\n"); - // } - // } - // extract embeddings if (t_embd && n_outputs > 0) { if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 56f2bae06cd..ab7daee356a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13829,11 +13829,6 @@ struct llm_build_glm4_moe : public llm_graph_context { // Final layer tensors are loaded but not processed in forward pass const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; for (int il = 0; il < n_transformer_layers; ++il) { - // if (params.use_mtp_head) { - // LLAMA_LOG_ERROR("[DEBUG-KV-ERROR] MTP path is running the main layer %d!\n", il); - // } else { - // LLAMA_LOG_WARN("[DEBUG-KV] Main Head Path: Accessing layer %d\n", il); - // } ggml_tensor * inpSA = inpL; // Pre-attention norm @@ -13976,7 +13971,6 @@ struct llm_build_glm4_moe : public llm_graph_context { ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings); const int il = hparams.n_layer - 1; - // LLAMA_LOG_WARN("[DEBUG-KV] MTP Head Path: Accessing layer %d\n", il); ggml_tensor * sum_node = ggml_sum(ctx0, embd_copy); ggml_set_name(sum_node, "mtp_input_sum"); @@ -18311,12 +18305,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { - const int64_t t_start_us = ggml_time_us(); std::unique_ptr llm; - - const bool build_mtp = params.mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED; - switch (arch) { case LLM_ARCH_LLAMA: { @@ -18678,12 +18668,6 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); } - const int64_t t_end_us = ggml_time_us(); - // LLAMA_LOG_INFO( - // "[PERF] Graph build time: %.2f ms (MTP path: %s)\n", - // (t_end_us - t_start_us) / 1000.0, - // params.use_mtp_head ? "yes" : "no" - // ); return llm->res->get_gf(); } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 4ff69f005f5..a24532c6939 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3520,7 +3520,7 @@ struct server_context { // Clean up the forced state to not affect subsequent decodes. llama_mtp_cancel_sinfo_update(ctx); } else { - LOG_ERR("%s: Failed to prepare the MTP symphony for warmup.", __func__); + LOG_ERR("%s: Failed to prepare the MTP for warmup.", __func__); } } From 171346c742c310bbcfbd786b61250638ccf8b44d Mon Sep 17 00:00:00 2001 From: samuel Date: Sun, 12 Oct 2025 16:33:01 -0300 Subject: [PATCH 30/35] mtp-graph(feat): Reactivate graph reuse only for main model path --- src/llama-context.cpp | 102 +++++++++++++++++++++++++++++++++--------- 1 file changed, 82 insertions(+), 20 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a5345ee2a40..22d3f043471 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -17,10 +17,24 @@ // // llama_context // +// Key for the graph cache. It contains all parameters that define the graph topology. 
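The intent is that the key behaves like a tuple in an ordered map: decodes that differ in any of these fields, most importantly the MTP op type, resolve to distinct cached graphs, while repeats of the same shape hit the same entry. A hypothetical illustration using the struct defined just below (the int payload stands in for the real cached graph result and is not part of the patch):

    std::map<llama_graph_cache_key, int> cache;

    const llama_graph_cache_key draft_key  = { /*n_tokens=*/1, /*n_outputs=*/1, MTP_OP_DRAFT_GEN,       /*causal_attn=*/true };
    const llama_graph_cache_key accept_key = { /*n_tokens=*/1, /*n_outputs=*/1, MTP_OP_UPDATE_ACCEPTED, /*causal_attn=*/true };

    cache[draft_key]  = 1;   // first draft decode of this shape builds a graph
    cache[accept_key] = 2;   // accepted-token update of the same shape gets its own entry
    // a later lookup with draft_key finds entry 1 again instead of rebuilding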
+struct llama_graph_cache_key { + uint32_t n_tokens; + uint32_t n_outputs; + llama_mtp_op_type op_type; + bool causal_attn; + + bool operator<(const llama_graph_cache_key& other) const { + return std::tie(n_tokens, n_outputs, op_type, causal_attn) < + std::tie(other.n_tokens, other.n_outputs, other.op_type, other.causal_attn); + } +}; + struct llama_context_kv_cache_data { llama_kv_cache_unified::slot_info_vec_t last_main_model_sinfos; llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force; const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr; + std::map graph_cache; }; llama_context::llama_context( @@ -745,40 +759,88 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll return nullptr; } - auto * res = gf_res_prev.get(); - auto * gf = res->get_gf(); - - // the new graph parameters - // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); - - if (!graph_reuse_disable && res->can_reuse(gparams)) { + auto * kvd = static_cast(kv_cache_data); + llm_graph_result * res; + + if (mtp_params.op_type != MTP_OP_NONE) { + int32_t n_outputs = 0; + for (int i = 0; i < ubatch.n_tokens; ++i) { if (ubatch.output[i]) n_outputs++; } + const llama_graph_cache_key key = { ubatch.n_tokens, (uint32_t)n_outputs, mtp_params.op_type, cparams.causal_attn }; + + auto & res_ptr = kvd->graph_cache[key]; + if (!res_ptr) { + LLAMA_LOG_INFO("[GRAPH-CACHE] Creating a new graph container for key (op=%d, tok=%d, out=%d)\n", + (int)key.op_type, key.n_tokens, key.n_outputs); + res_ptr = std::make_unique(graph_max_nodes()); + } + res = res_ptr.get(); + + // the new graph parameters + // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters + const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); + + // if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); - - n_reused++; - } else { - res->reset(); - + // LLAMA_LOG_INFO("[GRAPH-CACHE] HIT, reusing graph STRUCTURE for key (op=%d, tok=%d, out=%d)\n", + // (int)key.op_type, key.n_tokens, key.n_outputs); + // n_reused++; + // } else { + LLAMA_LOG_INFO("[GRAPH-CACHE] MISS, RECONSTRUCTING THE STRUCTURE of the graph for key (op=%d, tok=%d, out=%d)\n", + (int)key.op_type, key.n_tokens, key.n_outputs); + ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - //const auto t_start_us = ggml_time_us(); - - gf = model.build_graph(gparams); - - //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + res->reset(); + res->set_params(gparams); + res->gf = model.build_graph(gparams); - if (!gf) { + if (!res->gf) { LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); ret = GGML_STATUS_FAILED; return nullptr; } - if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + if (!ggml_backend_sched_alloc_graph(sched.get(), res->gf)) { LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); ret = GGML_STATUS_ALLOC_FAILED; return nullptr; } + // } + + } else { + res = gf_res_prev.get(); + const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); + + if (!graph_reuse_disable && res->can_reuse(gparams)) { + LLAMA_LOG_INFO("%s: reusing previous graph\n", __func__); + n_reused++; + } else { + LLAMA_LOG_INFO("%s: RECONSTRUCTED graph...\n", __func__); + + 
ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + res->reset(); + res->set_params(gparams); + //const auto t_start_us = ggml_time_us(); + + res->gf = model.build_graph(gparams); + + //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + + if (!res->gf) { + LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + if (!ggml_backend_sched_alloc_graph(sched.get(), res->gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + } + } } if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation From cae85fe531876762ee02524fc4c3f6c5e7824c63 Mon Sep 17 00:00:00 2001 From: samuel Date: Thu, 16 Oct 2025 13:42:31 -0300 Subject: [PATCH 31/35] mtp-batch(fix): avoid logits for mtp kv cache operations --- src/llama-context.cpp | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a5345ee2a40..fb35d6c79de 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1155,16 +1155,25 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract logits if (t_logits && n_outputs > 0) { - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); - GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); - - float * logits_out = logits + n_outputs_prev*n_vocab; - - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); - ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + // MTP operations that are purely for updating the KV cache + // (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logit tensor + // as a side effect of running the graph. If these logits are copied + // back to the main context buffer, they will overwrite the valid logits + // produced by the main model's pass, leading to incorrect sampling. + // This condition explicitly prevents that copy for cache-only operations. 
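Seen from the calling side (for example mtp_update_kv_cache), a cache-only MTP decode is issued like any other decode, and the guard that follows is what keeps its throwaway outputs from clobbering the main model's logits. A rough sketch of such a call, using the helpers from earlier in this series:

    llama_batch upd = batch;                        // accepted tokens, positions already set by the caller
    upd.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
    for (int32_t i = 0; i < upd.n_tokens; ++i) {
        upd.logits[i] = true;                       // outputs are needed to build the graph...
    }
    llama_decode(ctx, upd);                         // ...but only the MTP KV cache is updated
    // do not read logits from this decode; the main model's logits buffer is left untouched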
+ if (batch_inp.mtp_params.op_type != MTP_OP_WARMUP && + batch_inp.mtp_params.op_type != MTP_OP_UPDATE_ACCEPTED) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + float * logits_out = logits + n_outputs_prev*n_vocab; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + } } } From ea77394183b8e6c368af969b8274039a54b11486 Mon Sep 17 00:00:00 2001 From: samuel Date: Sat, 6 Dec 2025 13:47:54 -0300 Subject: [PATCH 32/35] mtp-graph (fix): move llama_get_logits_ith outside the loop --- common/speculative.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index a7a40426821..97187a70e39 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -374,7 +374,6 @@ llama_token mtp_speculative_gen_draft( return -1; } llama_batch mtp_batch = llama_batch_init(1, 0, 1); - const llama_pos draft_pos = n_past; const llama_seq_id draft_seq_id = 0; common_batch_add(mtp_batch, id_last, n_past, {0}, true); @@ -382,23 +381,30 @@ llama_token mtp_speculative_gen_draft( // Perform the MTP draft generation decode. This writes the MTP layer's // KV state for the draft token into the cache. - llama_decode(ctx, mtp_batch); + if (llama_decode(ctx, mtp_batch) != 0) { + llama_batch_free(mtp_batch); + return -1; + } llama_batch_free(mtp_batch); // CRITICAL: Purge the metadata for the draft token we just wrote. // This makes the physical cell available again for the main model's validation pass, // preventing a cache state corruption where two cells map to the same logical position. - llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1); + llama_kv_cache_seq_rm(ctx, draft_seq_id, n_past, n_past + 1); const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_n_vocab(vocab); + llama_token_data_array * cur_p = common_sampler_get_candidates(smpl); + float * logits = llama_get_logits_ith(ctx, 0); cur_p->size = n_vocab; + for (int i = 0; i < n_vocab; ++i) { cur_p->data[i].id = i; - cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0. + cur_p->data[i].logit = logits[i]; } + cur_p->sorted = false; common_sampler_apply_chain(smpl, cur_p); From 6de0ecf55db8567db4faa99b0152b72c9e854548 Mon Sep 17 00:00:00 2001 From: samuel Date: Sat, 6 Dec 2025 14:40:13 -0300 Subject: [PATCH 33/35] mtp (feat): add mtp arg --- common/arg.cpp | 7 +++++++ common/common.h | 1 + tools/server/server.cpp | 2 +- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index 0f01bb31454..61c1abc3bcc 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3315,6 +3315,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.cache_type_k = kv_cache_type_from_str(value); } ).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT")); + add_opt(common_arg( + {"-mtp", "--multi-token-prediction"}, + string_format("Activate multi-token-prediction (if supported) (default: %s)", params.mtp ? 
"true" : "false"), + [](common_params & params) { + params.mtp = true; + } + )); add_opt(common_arg( {"-ctvd", "--cache-type-v-draft"}, "TYPE", string_format( diff --git a/common/common.h b/common/common.h index 5eab199af55..eb43003ed83 100644 --- a/common/common.h +++ b/common/common.h @@ -362,6 +362,7 @@ struct common_params { bool check_tensors = false; // validate tensor data bool no_op_offload = false; // globally disable offload host tensor operations to device bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking) + bool mtp = false; // use mtp is supported bool single_turn = false; // single turn chat conversation diff --git a/tools/server/server.cpp b/tools/server/server.cpp index a24532c6939..541ff07882f 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2133,7 +2133,7 @@ struct server_context { } // if model has MTP and no draft model is specified... - else if (llama_model_n_nextn_layer(model) > 0) { + else if (llama_model_n_nextn_layer(model) > 0 && params_base.mtp) { SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model)); slot.has_mtp = true; From a91980a8f3475a6bbac0a64d8be06dd4b613020e Mon Sep 17 00:00:00 2001 From: samuel Date: Sat, 6 Dec 2025 15:18:19 -0300 Subject: [PATCH 34/35] mtp (chore): clean old code --- common/sampling.cpp | 5 ----- src/llama-batch.cpp | 6 ++---- src/llama-context.cpp | 33 ++++----------------------------- src/llama-context.h | 2 -- tools/server/server.cpp | 7 +------ 5 files changed, 7 insertions(+), 46 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 452cefee3b9..8668c4d7157 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -348,11 +348,6 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co llama_sampler_apply(chain, &cur_p); - /*for (int k = 0; k < (int)cur_p.size; ++k) { - LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n", - k, 0, cur_p.data[k].id, cur_p.data[k].p); - }*/ - GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); const llama_token id = cur_p.data[cur_p.selected].id; diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index c01960c55ea..3687058c82e 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -275,9 +275,7 @@ bool llama_batch_allocr::init( } } - // TEMPORARILY DISABLING THIS SANITY CHECK - // TODO: UNDO THIS IF IT WORKS - /*if (!ok) { + if (!ok) { LLAMA_LOG_ERROR( "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" " - the last position stored in the memory module of the context (i.e. 
the KV cache) for sequence %d is X = %d\n" @@ -286,7 +284,7 @@ bool llama_batch_allocr::init( __func__, s, s, p0, s, seq_pos_min(s)); return false; - }*/ + } } if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 7e1e699d790..1a7f148a2c3 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -546,18 +546,6 @@ float * llama_context::get_logits() { return logits; } -void llama_context::set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i) { - output_reorder(); - - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched_override, logit_override); - GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); - - int64_t j = output_ids[i]; - - ggml_backend_tensor_get_async(backend_res, logit_override, logits + j*model.vocab.n_tokens(), 0, model.vocab.n_tokens() * sizeof(float)); -} - float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; @@ -769,8 +757,8 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll auto & res_ptr = kvd->graph_cache[key]; if (!res_ptr) { - LLAMA_LOG_INFO("[GRAPH-CACHE] Creating a new graph container for key (op=%d, tok=%d, out=%d)\n", - (int)key.op_type, key.n_tokens, key.n_outputs); + LLAMA_LOG_DEBUG("%s: Creating a new graph container for key (op=%d, tok=%d, out=%d)\n", + __func__, (int)key.op_type, key.n_tokens, key.n_outputs); res_ptr = std::make_unique(graph_max_nodes()); } res = res_ptr.get(); @@ -779,15 +767,6 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); - // if (!graph_reuse_disable && res->can_reuse(gparams)) { - //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); - // LLAMA_LOG_INFO("[GRAPH-CACHE] HIT, reusing graph STRUCTURE for key (op=%d, tok=%d, out=%d)\n", - // (int)key.op_type, key.n_tokens, key.n_outputs); - // n_reused++; - // } else { - LLAMA_LOG_INFO("[GRAPH-CACHE] MISS, RECONSTRUCTING THE STRUCTURE of the graph for key (op=%d, tok=%d, out=%d)\n", - (int)key.op_type, key.n_tokens, key.n_outputs); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -806,17 +785,16 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll ret = GGML_STATUS_ALLOC_FAILED; return nullptr; } - // } } else { res = gf_res_prev.get(); const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); if (!graph_reuse_disable && res->can_reuse(gparams)) { - LLAMA_LOG_INFO("%s: reusing previous graph\n", __func__); + LLAMA_LOG_DEBUG("%s: Reusing previous graph\n", __func__); n_reused++; } else { - LLAMA_LOG_INFO("%s: RECONSTRUCTED graph...\n", __func__); + LLAMA_LOG_DEBUG("%s: Reconstructed graph...\n", __func__); ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -867,9 +845,6 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } ret = GGML_STATUS_SUCCESS; - if (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { - ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum"); - } return res; } diff --git a/src/llama-context.h b/src/llama-context.h index 4d77d5d81ae..7297c4c5d16 100644 --- 
a/src/llama-context.h +++ b/src/llama-context.h @@ -216,8 +216,6 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); - void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i); - ggml_backend_sched_t create_temp_scheduler(size_t n_nodes); std::unique_ptr mtp_memory_batch(const llama_batch& batch_inp); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 541ff07882f..8198258b254 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1404,14 +1404,10 @@ struct server_slot { // if the context does not have a memory module then all embeddings have to be computed within a single ubatch // also we cannot split if the pooling would require any past tokens bool can_split() const { - //fprintf(stderr, "need_embd() %d\n", need_embd()); - //fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr); - //fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); - return !need_embd() || (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for whole batch? + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); } bool can_batch_with(server_slot & other_slot) const { @@ -1440,7 +1436,6 @@ struct server_slot { bool can_speculate() const { return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; - // return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt; } void add_token(const completion_token_output & token) { From bdf72d9552e3da64ffc85f175664713388752914 Mon Sep 17 00:00:00 2001 From: samuel Date: Sat, 6 Dec 2025 16:10:16 -0300 Subject: [PATCH 35/35] sampling (feat): optimize speculative drafting with fast-path selection --- common/sampling.cpp | 39 +++++++++++++++++++++++++++++++++++++-- common/sampling.h | 2 +- common/speculative.cpp | 27 +++++---------------------- common/speculative.h | 3 +-- tools/server/server.cpp | 2 +- 5 files changed, 45 insertions(+), 28 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 8668c4d7157..8422023a540 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -578,6 +578,41 @@ std::vector common_sampler_types_from_chars(const std::stri return samplers; } -void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) { - llama_sampler_apply(gsmpl->chain, cur_p); +/** + * Specialized sampling for speculative drafting. + * + * Prioritizes performance by using a direct ArgMax loop (Greedy) when no + * penalties (repetition, frequency, presence, DRY) are configured. + * Falls back to the full sampler chain if penalties are active to prevent + * generative loops or adhere to constraints. 
+ */ +llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) { + const auto & params = gsmpl->params; + + bool use_heavy_sampler = + (params.penalty_last_n > 0 && ( + params.penalty_repeat != 1.0f || + params.penalty_freq != 0.0f || + params.penalty_present != 0.0f + )) || + (params.dry_allowed_length > 0 && params.dry_multiplier != 0.0f); + + if (use_heavy_sampler) { + return common_sampler_sample(gsmpl, ctx, idx, false); + } + + float * logits = llama_get_logits_ith(ctx, idx); + const int n_vocab = llama_n_vocab(llama_model_get_vocab(llama_get_model(ctx))); + + int best_id = 0; + float max_val = logits[0]; + + for (int i = 1; i < n_vocab; ++i) { + if (logits[i] > max_val) { + max_val = logits[i]; + best_id = i; + } + } + + return best_id; } \ No newline at end of file diff --git a/common/sampling.h b/common/sampling.h index b424d7d6d70..81a89727384 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -106,4 +106,4 @@ std::vector common_sampler_types_from_chars(const std: llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, const char * grammar_data); -void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p); \ No newline at end of file +llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx); \ No newline at end of file diff --git a/common/speculative.cpp b/common/speculative.cpp index 97187a70e39..9d6b7ec1cbb 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -367,14 +367,13 @@ llama_token mtp_speculative_gen_draft( struct common_sampler* smpl, struct llama_context* ctx, llama_token id_last, - int32_t n_past, - int32_t last_tok_idx) { + int32_t n_past) { + + if (!smpl) return -1; - if (!smpl) { - return -1; - } llama_batch mtp_batch = llama_batch_init(1, 0, 1); const llama_seq_id draft_seq_id = 0; + common_batch_add(mtp_batch, id_last, n_past, {0}, true); mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN; @@ -392,23 +391,7 @@ llama_token mtp_speculative_gen_draft( // preventing a cache state corruption where two cells map to the same logical position. 
llama_kv_cache_seq_rm(ctx, draft_seq_id, n_past, n_past + 1); - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_vocab = llama_n_vocab(vocab); - - llama_token_data_array * cur_p = common_sampler_get_candidates(smpl); - float * logits = llama_get_logits_ith(ctx, 0); - cur_p->size = n_vocab; - - for (int i = 0; i < n_vocab; ++i) { - cur_p->data[i].id = i; - cur_p->data[i].logit = logits[i]; - } - - cur_p->sorted = false; - common_sampler_apply_chain(smpl, cur_p); - - return cur_p->data[0].id; + return common_sampler_sample_speculative(smpl, ctx, 0); } diff --git a/common/speculative.h b/common/speculative.h index 8b81f4ac77d..4720c50cfde 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -39,8 +39,7 @@ llama_token mtp_speculative_gen_draft( struct common_sampler* smpl, struct llama_context* ctx, llama_token id_last, - int32_t n_past, - int32_t last_tok_idx); + int32_t n_past); // sample up to n_draft tokens and add them to the batch using the draft model llama_tokens common_speculative_gen_draft( diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 8198258b254..e26e1d6fecc 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3632,7 +3632,7 @@ struct server_context { llama_tokens draft; if (slot.has_mtp) { - llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); + llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past); draft.reserve(1); draft.push_back(draft_id); }
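
Usage note for reviewers (a sketch, not part of any patch in this series): the MTP path is opt-in. A slot only sets has_mtp when the loaded model reports NextN layers (llama_model_n_nextn_layer(model) > 0) and the server was started with the new flag, e.g. llama-server -m <model.gguf> -mtp (the model path is a placeholder). Under those assumptions, the snippet below illustrates the one-token draft step introduced here; the wrapper name draft_one_mtp_token and the early exit on a negative id are illustrative additions, the rest mirrors the server.cpp hunk in PATCH 35/35.

    #include "common.h"
    #include "sampling.h"
    #include "speculative.h"

    // Generate at most one MTP draft token for the token the main model just
    // sampled (id_last, stored at position n_past - 1) and append it to draft.
    // mtp_speculative_gen_draft() decodes the NextN layer for id_last, purges
    // the draft token's KV-cache metadata so the validation pass can reuse the
    // cell, and selects the draft id via common_sampler_sample_speculative()
    // (greedy fast path unless penalties or DRY are configured).
    static void draft_one_mtp_token(common_sampler * smpl, llama_context * ctx,
                                    llama_token id_last, int32_t n_past,
                                    llama_tokens & draft) {
        const llama_token draft_id = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past);
        if (draft_id >= 0) { // -1 signals "no draft" (missing sampler or failed decode)
            draft.push_back(draft_id);
        }
    }

The drafted token is then validated through the server's existing speculative decoding path, the same way a token proposed by a conventional draft model would be.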