
Commit c6237c7

Merge pull request #1 from SamuelOliveirads/glm4-moe-mtp
feat: implemented sampling for MTP
2 parents 9fab53e + 8742ce0 commit c6237c7

7 files changed: +62 −84 lines


common/sampling.cpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -582,3 +582,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
 
     return samplers;
 }
+
+void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
+    llama_sampler_apply(gsmpl->chain, cur_p);
+}
```

common/sampling.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -105,3 +105,5 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
 
 llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
     const char * grammar_kind, const char * grammar_data);
+
+void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);
```
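The new helper applies only the configured sampler chain to a caller-provided candidate array. A minimal usage sketch, not part of this commit, assuming the chain can run over an externally allocated array and the current `llama_token_data_array` layout (`data`, `size`, `selected`, `sorted`):

```cpp
#include "sampling.h"

#include <vector>

// Hypothetical helper: fill a candidate array from the raw logits at position
// `idx` and run the sampler chain over it via the new common_sampler_apply_chain.
static llama_token pick_with_chain(common_sampler * smpl, llama_context * ctx, int32_t idx) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const int n_vocab = llama_n_vocab(vocab);

    const float * logits = llama_get_logits_ith(ctx, idx);

    std::vector<llama_token_data> cand(n_vocab);
    for (int i = 0; i < n_vocab; ++i) {
        cand[i] = llama_token_data{ (llama_token) i, logits[i], 0.0f };
    }

    llama_token_data_array cur_p = { cand.data(), cand.size(), /*selected*/ -1, /*sorted*/ false };

    // apply the chain without accepting the result into the sampler history
    common_sampler_apply_chain(smpl, &cur_p);

    // mirror the commit: take the top candidate after the chain has run
    return cur_p.data[0].id;
}
```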

common/speculative.cpp

Lines changed: 20 additions & 41 deletions
```diff
@@ -370,56 +370,35 @@ llama_token mtp_speculative_gen_draft(
     int32_t n_past,
     int32_t last_tok_idx) {
 
-    llama_token token_data[] = { id_last };
-    llama_pos pos_data[] = { n_past };
-    int32_t n_seq_id_data[] = { 1 };
-    llama_seq_id seq_id_data_internal[] = { 0 };
-    llama_seq_id* seq_id_data[] = {seq_id_data_internal};
-    int8_t logits_data[] = { (int8_t) (smpl != nullptr) };
-
-    llama_batch batch = {
-        /*.n_tokens = */ 1,
-        /*.token    = */ token_data,
-        /*.embd     = */ nullptr,
-        /*.pos      = */ pos_data,
-        /*.n_seq_id = */ n_seq_id_data,
-        /*.seq_id   = */ seq_id_data,
-        /*.logits   = */ logits_data
-    };
-
-    return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
-    //LOG_INF("updating kv cache for n_past: %d\n", n_past);
-
-    /*
     if (!smpl) {
         return -1;
     }
-    else {
-        common_sampler_sample(smpl, ctx, last_tok_idx, true);
-        const auto* cur_p = common_sampler_get_candidates(smpl);
 
-        //for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
-        //    LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-        //        k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
-        //}
+    llama_batch batch = llama_batch_init(1, 0, 1);
+    common_batch_add(batch, id_last, n_past, {0}, true);
 
-        const llama_token id = cur_p->data[0].id;
-        return id;
+    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    const int n_vocab = llama_n_vocab(vocab);
+
+    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
+
+    cur_p->size = n_vocab;
+    for (int i = 0; i < n_vocab; ++i) {
+        cur_p->data[i].id    = i;
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
     }
-    */
-    // LOG_INF("cur_p->size: %d\n", cur_p->size);
+    cur_p->sorted = false;
 
+    common_sampler_apply_chain(smpl, cur_p);
 
-    // add drafted token for each sequence
+    const llama_token id = cur_p->data[0].id;
 
-    // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
-    // smpl will accept the token if it doesn't get rejected by main model later
-    // common_sampler_accept(smpl, id, true);
+    llama_batch_free(batch);
 
-    //llama_tokens result;
-    //result.reserve(1);
-    //result.push_back(id);
-    //return result;
+    return id;
 }
 
 
@@ -438,4 +417,4 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_d
     }
 
     tokens.clear();
-}
+}
```
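The hand-built, stack-allocated `llama_batch` is replaced with the standard batch helpers, which also leaves room to queue more than one draft token later. A small sketch of that pattern in isolation, where `id_last` and `n_past` are assumed from the surrounding function and the decode step is only a placeholder:

```cpp
// Build a batch with a single token at position n_past for sequence 0 and
// request logits for it; llama_batch_init allocates, llama_batch_free releases.
llama_batch batch = llama_batch_init(/*n_tokens*/ 1, /*embd*/ 0, /*n_seq_max*/ 1);
common_batch_add(batch, id_last, n_past, { 0 }, /*need logits*/ true);

// ... run the MTP graph (or llama_decode) with this batch ...

llama_batch_free(batch);
```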

common/speculative.h

Lines changed: 4 additions & 4 deletions
Whitespace-only change to the parameter indentation of `common_speculative_gen_draft`:

```diff
@@ -44,9 +44,9 @@ llama_token mtp_speculative_gen_draft(
 
 // sample up to n_draft tokens and add them to the batch using the draft model
 llama_tokens common_speculative_gen_draft(
-        struct common_speculative * spec,
-        struct common_speculative_params params,
-        const llama_tokens & prompt,
-        llama_token id_last);
+    struct common_speculative * spec,
+    struct common_speculative_params params,
+    const llama_tokens & prompt,
+    llama_token id_last);
 
 void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = -1);
```

include/llama.h

Lines changed: 2 additions & 2 deletions
```diff
@@ -1454,8 +1454,8 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
-    LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
+    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
 
 #ifdef __cplusplus
 }
```
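With the return type changed to `void`, the drafted token is no longer produced inside the graph call; the caller samples from the MTP logits itself (see `common/speculative.cpp` above). A short sketch of the calling convention, where `ctx`, `batch`, `id_last`, `n_past`, and `last_tok_idx` are assumed from that file:

```cpp
// run the MTP graph; on success it writes the draft logits into the context's
// logits buffer at position last_tok_idx
llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);

// the caller then reads those logits and applies its own sampling
const float * draft_logits = llama_get_logits_ith(ctx, last_tok_idx);
```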

src/llama-context.cpp

Lines changed: 26 additions & 25 deletions
```diff
@@ -2995,7 +2995,7 @@ void llama_opt_epoch(
             callback_eval);
 }
 
-llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
         const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
 
     const auto * model = llama_get_model(ctx);
@@ -3033,6 +3033,12 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
 
     auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
 
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: ERROR - the construction of the MTP graph failed (returned null)\n", __func__);
+        if (sched) ggml_backend_sched_free(sched);
+        return;
+    }
+
     ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
     ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
 
@@ -3044,29 +3050,24 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
 
     ggml_backend_sched_graph_compute(sched, gf); // execute the graph
 
-    //struct ggml_tensor * logits_mtp = res_mtp->get_logits();
-
-    //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
-
-    //if (logits_mtp) {
-    //    ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
-    //}
-    struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result");
-
-
-    llama_token token_id = 0; // The C++ variable to hold the result.
-
-    // ggml_backend_tensor_get is the function for GPU->CPU copies.
-    // We are copying a single 32-bit integer.
-    ggml_backend_tensor_get(
-        token_id_tensor,
-        &token_id,          // Pointer to our C++ variable
-        0,                  // Starting offset in bytes
-        sizeof(llama_token) // Number of bytes to copy
-    );
+    struct ggml_tensor * logits_mtp = res_mtp->get_logits();
+
+    if (logits_mtp) {
+        float * logits_dest = ctx->get_logits_ith(last_tok_idx);
+        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
+        if (backend_res) {
+            // ggml_backend_tensor_get performs the backend (GPU) -> CPU copy;
+            // here it copies the full logits row for this position.
+            ggml_backend_tensor_get(logits_mtp,
+                logits_dest,              // destination in the context's logits buffer
+                0,                        // starting offset in bytes
+                ggml_nbytes(logits_mtp)); // number of bytes to copy
+        } else {
+            LLAMA_LOG_ERROR("%s: ERROR - could not obtain the backend for the logits tensor\n", __func__);
+        }
+    } else {
+        LLAMA_LOG_WARN("%s: WARNING - the MTP graph did not produce a logits tensor\n", __func__);
+    }
 
     ggml_backend_sched_free(sched);
-
-    return token_id;
-}
-
+}
```
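The copy above is the generic ggml pattern for reading a computed graph output back to the host. A minimal standalone sketch, assuming `t` is an F32 output tensor that the scheduler has already computed:

```cpp
#include "ggml.h"
#include "ggml-backend.h"

#include <vector>

// Copy a computed F32 graph output from its backend buffer (e.g. VRAM) into host
// memory; offset 0 and ggml_nbytes(t) select the whole tensor.
static std::vector<float> read_tensor_f32(const ggml_tensor * t) {
    std::vector<float> host(ggml_nelements(t));
    ggml_backend_tensor_get(t, host.data(), 0, ggml_nbytes(t));
    return host;
}
```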

src/llama-model.cpp

Lines changed: 4 additions & 12 deletions
```diff
@@ -13950,6 +13950,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
         llama_token last_token_id, int n_past
     ) : llm_graph_context(params) {
+
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -13964,8 +13965,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
         auto * inp_attn = build_attn_inp_kv_unified();
 
-        ggml_tensor * cur;
-
         // get MTP embedding for last (conventionally sampled) token
         // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
         // LLAMA_LOG_INFO("step: '%d'\n", 5641);
@@ -13979,7 +13978,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
 
         //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id);
         //ggml_set_no_alloc(ctx0, true);
-
+
         ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
         ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
 
@@ -13994,9 +13993,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
 
         ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat
 
-
-        cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
-
+        ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
 
         // now proceed through last layer (skipped in main model)
         ggml_tensor * inpSA = cur;
@@ -14096,14 +14093,9 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
 
         cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
         cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
-
+
         res->t_logits = cur;
-
         ggml_build_forward_expand(gf, res->t_logits);
-
-        struct ggml_tensor * token_id_tensor = ggml_argmax(ctx0, cur);
-        ggml_set_name(token_id_tensor, "mtp_argmax_result");
-        ggml_build_forward_expand(gf, token_id_tensor);
     }
 };
```
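With the in-graph `ggml_argmax` node removed, a greedy pick (when one is wanted) now happens host-side over the copied logits. An equivalent CPU-side sketch, assuming a contiguous row of `n_vocab` float logits:

```cpp
#include "llama.h"

#include <algorithm>

// index of the largest logit == greedy token id, the host-side equivalent of
// the removed in-graph argmax over the MTP head's output row
static llama_token greedy_pick(const float * logits, int n_vocab) {
    return (llama_token) (std::max_element(logits, logits + n_vocab) - logits);
}
```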
