@@ -370,56 +370,35 @@ llama_token mtp_speculative_gen_draft(
     int32_t  n_past,
     int32_t  last_tok_idx) {

-    llama_token token_data[] = { id_last };
-    llama_pos pos_data[] = { n_past };
-    int32_t n_seq_id_data[] = { 1 };
-    llama_seq_id seq_id_data_internal[] = { 0 };
-    llama_seq_id * seq_id_data[] = { seq_id_data_internal };
-    int8_t logits_data[] = { (int8_t) (smpl != nullptr) };
-
-    llama_batch batch = {
-        /* .n_tokens = */ 1,
-        /* .token    = */ token_data,
-        /* .embd     = */ nullptr,
-        /* .pos      = */ pos_data,
-        /* .n_seq_id = */ n_seq_id_data,
-        /* .seq_id   = */ seq_id_data,
-        /* .logits   = */ logits_data,
-    };
-
-    return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
-    // LOG_INF("updating kv cache for n_past: %d\n", n_past);
-
-    /*
     if (!smpl) {
         return -1;
     }
-    else {
-        common_sampler_sample(smpl, ctx, last_tok_idx, true);
-        const auto * cur_p = common_sampler_get_candidates(smpl);

-        //for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
-        //    LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-        //        k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
-        //}
+    llama_batch batch = llama_batch_init(1, 0, 1);
+    common_batch_add(batch, id_last, n_past, { 0 }, true);

-        const llama_token id = cur_p->data[0].id;
-        return id;
+    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    const int n_vocab = llama_n_vocab(vocab);
+
+    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
+
+    cur_p->size = n_vocab;
+    for (int i = 0; i < n_vocab; ++i) {
+        cur_p->data[i].id    = i;
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
     }
-    */
-    // LOG_INF("cur_p->size: %d\n", cur_p->size);
+    cur_p->sorted = false;

+    common_sampler_apply_chain(smpl, cur_p);

-    // add drafted token for each sequence
+    const llama_token id = cur_p->data[0].id;

-    // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
-    // smpl will accept the token if it doesn't get rejected by main model later
-    // common_sampler_accept(smpl, id, true);
+    llama_batch_free(batch);

-    // llama_tokens result;
-    // result.reserve(1);
-    // result.push_back(id);
-    // return result;
+    return id;
 }


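Note on the hunk above: the hand-rolled llama_batch (stack arrays plus a
field-by-field initializer) is replaced by the llama_batch_init() /
common_batch_add() helpers, and the drafted token is now produced by filling
the sampler's candidate array with the raw logits for the entire vocabulary,
applying the sampler chain, and taking cur_p->data[0].id. One nit:
llama_get_logits_ith(ctx, last_tok_idx) is re-evaluated on every iteration of
the fill loop; hoisting the returned pointer out of the loop would avoid
n_vocab redundant calls. A minimal call-site sketch, assuming the signature in
this hunk (the verification step and names below are illustrative, not part of
this commit):

    // Hypothetical call site: draft one token with the MTP head; the main
    // model verifies it on its next decode step.
    const llama_token draft = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, last_tok_idx);
    if (draft != -1) {
        // decode id_last and draft with the main model in one batch, keeping
        // the draft only if the main model samples that same token there
    }
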
@@ -438,4 +417,4 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_d
     }

     tokens.clear();
-}
+}
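
The hunk above is whitespace-only: it touches the closing brace of
mtp_update_kv_cache (a newline-at-end-of-file fix). Since the function ends by
clearing its tokens vector, callers presumably accumulate pending entries and
flush them in one call. A hedged sketch of that pattern, assuming the element
type is the mtp_kv_update_data struct whose name is truncated in the hunk
header; the fields shown are hypothetical:

    // Hypothetical usage: queue MTP KV-cache updates, then flush them once.
    // The field names here are assumptions, not taken from this commit.
    std::vector<mtp_kv_update_data> updates;
    updates.push_back({ /* .id = */ id_last, /* .n_past = */ n_past });
    mtp_update_kv_cache(ctx, updates); // processes the entries and clears the vector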