@@ -210,7 +210,7 @@ bool llama_batch_allocr::init(
210210 LLAMA_LOG_DEBUG (" %s: input batch info:\n " , __func__);
211211
212212 llama_ubatch ubatch {
213- /* .equal_seqs =*/ false ,
213+ /* .b_equal_seqs =*/ false ,
214214 /* .n_tokens =*/ (uint32_t ) batch.n_tokens ,
215215 /* .n_seq_tokens =*/ (uint32_t ) 1 ,
216216 /* .n_seqs =*/ (uint32_t ) batch.n_tokens ,
@@ -223,6 +223,7 @@ bool llama_batch_allocr::init(
223223 /* .seq_id_unq =*/ this ->seq_id_unq .data (),
224224 /* .seq_idx =*/ this ->seq_idx .data (),
225225 /* .output =*/ batch.logits ,
226+ /* .data =*/ {},
226227 };
227228
228229 ubatch_print (ubatch, debug);
@@ -366,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
366367 clear ();
367368 split_reset ();
368369
369- ubatches. emplace_back ();
370+ auto udata = std::make_shared<llama_ubatch:: data_t > ();
370371
371- auto & ubatch = ubatches.back ();
372-
373- ubatch.token .resize (n_tokens);
374- ubatch.embd .clear ();
375- ubatch.pos .resize (n_tokens);
376- ubatch.n_seq_id .resize (n_tokens);
377- ubatch.seq_id .resize (n_tokens);
378- ubatch.seq_id_unq .resize (0 );
379- ubatch.seq_idx .resize (LLAMA_MAX_SEQ, -1 );
380- ubatch.output .resize (n_tokens);
372+ udata->token .resize (n_tokens);
373+ udata->embd .clear ();
374+ udata->pos .resize (n_tokens);
375+ udata->n_seq_id .resize (n_tokens);
376+ udata->seq_id .resize (n_tokens);
377+ udata->seq_id_unq .resize (0 );
378+ udata->seq_idx .resize (LLAMA_MAX_SEQ, -1 );
379+ udata->output .resize (n_tokens);
381380
382381 for (uint32_t s = 0 ; s < n_seqs; ++s) {
383- ubatch. seq_idx [s] = s;
384- ubatch. seq_id_unq .push_back (s);
382+ udata-> seq_idx [s] = s;
383+ udata-> seq_id_unq .push_back (s);
385384 }
386385
387386 llama_ubatch res {
388- /* .equal_seqs =*/ true ,
387+ /* .b_equal_seqs =*/ true ,
389388 /* .n_tokens =*/ n_tokens,
390389 /* .n_seq_tokens =*/ n_seq_tokens,
391390 /* .n_seqs =*/ n_seqs,
392391 /* .n_seqs_unq =*/ n_seqs,
393392
394- /* .token =*/ ubatch. token .data (),
393+ /* .token =*/ udata-> token .data (),
395394 /* .embd =*/ nullptr ,
396- /* .pos =*/ ubatch.pos .data (),
397- /* .n_seq_id =*/ ubatch.n_seq_id .data (),
398- /* .seq_id =*/ ubatch.seq_id .data (),
399- /* .seq_id_unq =*/ ubatch.seq_id_unq .data (),
400- /* .seq_idx =*/ ubatch.seq_idx .data (),
401- /* .output =*/ ubatch.output .data (),
395+ /* .pos =*/ udata->pos .data (),
396+ /* .n_seq_id =*/ udata->n_seq_id .data (),
397+ /* .seq_id =*/ udata->seq_id .data (),
398+ /* .seq_id_unq =*/ udata->seq_id_unq .data (),
399+ /* .seq_idx =*/ udata->seq_idx .data (),
400+ /* .output =*/ udata->output .data (),
401+ /* .data =*/ std::move (udata),
402402 };
403403
404404 return res;
@@ -439,8 +439,6 @@ void llama_batch_allocr::split_reset() {
439439
440440 used.clear ();
441441 used.resize (get_n_tokens (), false );
442-
443- ubatches.clear ();
444442}
445443
446444llama_ubatch llama_batch_allocr::split_simple (uint32_t n_ubatch) {
@@ -655,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
655653
656654 assert (n_tokens%n_seqs == 0 );
657655
658- ubatches.emplace_back ();
659-
660- auto & ubatch = ubatches.back ();
656+ auto udata = std::make_shared<llama_ubatch::data_t >();
661657
662658 const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1 ;
663659
664660 const int64_t n_embd_all = batch.embd ? (int64_t ) n_tokens*n_embd : 0 ;
665661 const int64_t n_pos_all = (int64_t ) n_tokens*n_pos_cur;
666662
667- ubatch. token .resize (n_tokens);
668- ubatch. embd .resize (n_embd_all);
669- ubatch. pos .resize (n_pos_all);
670- ubatch. n_seq_id .resize (n_tokens);
671- ubatch. seq_id .resize (n_tokens);
672- ubatch. seq_id_unq .resize (0 );
673- ubatch. seq_idx .resize (LLAMA_MAX_SEQ, -1 );
674- ubatch. output .resize (n_tokens);
663+ udata-> token .resize (n_tokens);
664+ udata-> embd .resize (n_embd_all);
665+ udata-> pos .resize (n_pos_all);
666+ udata-> n_seq_id .resize (n_tokens);
667+ udata-> seq_id .resize (n_tokens);
668+ udata-> seq_id_unq .resize (0 );
669+ udata-> seq_idx .resize (LLAMA_MAX_SEQ, -1 );
670+ udata-> output .resize (n_tokens);
675671
676672 seq_set_t seq_set_unq;
677673
678674 for (size_t i = 0 ; i < idxs.size (); ++i) {
679675 if (batch.token ) {
680- ubatch. token [i] = batch.token [idxs[i]];
676+ udata-> token [i] = batch.token [idxs[i]];
681677 }
682678
683679 if (batch.embd ) {
684- memcpy (ubatch. embd .data () + i*n_embd, batch.embd + (int64_t ) idxs[i]*n_embd, n_embd*sizeof (float ));
680+ memcpy (udata-> embd .data () + i*n_embd, batch.embd + (int64_t ) idxs[i]*n_embd, n_embd*sizeof (float ));
685681 }
686682
687683 for (int j = 0 ; j < n_pos_cur; ++j) {
688- ubatch. pos [j*n_tokens + i] = batch.pos [j*batch.n_tokens + idxs[i]];
684+ udata-> pos [j*n_tokens + i] = batch.pos [j*batch.n_tokens + idxs[i]];
689685 }
690686
691- ubatch. n_seq_id [i] = batch.n_seq_id [idxs[i]];
692- ubatch. seq_id [i] = batch.seq_id [idxs[i]];
693- ubatch. output [i] = batch.logits [idxs[i]];
687+ udata-> n_seq_id [i] = batch.n_seq_id [idxs[i]];
688+ udata-> seq_id [i] = batch.seq_id [idxs[i]];
689+ udata-> output [i] = batch.logits [idxs[i]];
694690
695- for (int s = 0 ; s < ubatch. n_seq_id [i]; ++s) {
696- seq_set_unq.set (ubatch. seq_id [i][s]);
691+ for (int s = 0 ; s < udata-> n_seq_id [i]; ++s) {
692+ seq_set_unq.set (udata-> seq_id [i][s]);
697693 }
698694
699- if (ubatch. output [i]) {
695+ if (udata-> output [i]) {
700696 out_ids.push_back (idxs[i]);
701697 }
702698 }
703699
704700 for (uint32_t s = 0 ; s < n_seq_max; ++s) {
705701 if (seq_set_unq.test (s)) {
706- ubatch. seq_idx [s] = ubatch. seq_id_unq .size ();
707- ubatch. seq_id_unq .push_back (s);
702+ udata-> seq_idx [s] = udata-> seq_id_unq .size ();
703+ udata-> seq_id_unq .push_back (s);
708704 }
709705 }
710706
711707 llama_ubatch res {
712- /* .equal_seqs =*/ equal_seqs,
708+ /* .b_equal_seqs =*/ equal_seqs,
713709 /* .n_tokens =*/ n_tokens,
714710 /* .n_seq_tokens =*/ n_tokens/n_seqs,
715711 /* .n_seqs =*/ n_seqs,
716- /* .n_seqs_unq =*/ (uint32_t ) ubatch.seq_id_unq .size (),
717-
718- /* .token =*/ batch.token ? ubatch.token .data () : nullptr ,
719- /* .embd =*/ batch.embd ? ubatch.embd .data () : nullptr ,
720- /* .pos =*/ ubatch.pos .data (),
721- /* .n_seq_id =*/ ubatch.n_seq_id .data (),
722- /* .seq_id =*/ ubatch.seq_id .data (),
723- /* .seq_id_unq =*/ ubatch.seq_id_unq .data (),
724- /* .seq_idx =*/ ubatch.seq_idx .data (),
725- /* .output =*/ ubatch.output .data (),
712+ /* .n_seqs_unq =*/ (uint32_t ) udata->seq_id_unq .size (),
713+
714+ /* .token =*/ batch.token ? udata->token .data () : nullptr ,
715+ /* .embd =*/ batch.embd ? udata->embd .data () : nullptr ,
716+ /* .pos =*/ udata->pos .data (),
717+ /* .n_seq_id =*/ udata->n_seq_id .data (),
718+ /* .seq_id =*/ udata->seq_id .data (),
719+ /* .seq_id_unq =*/ udata->seq_id_unq .data (),
720+ /* .seq_idx =*/ udata->seq_idx .data (),
721+ /* .output =*/ udata->output .data (),
722+ /* .data =*/ std::move (udata),
726723 };
727724
728725 if (debug > 0 ) {
729- LLAMA_LOG_DEBUG (" %s: added ubatch %d to split:\n " , __func__, ( int ) ubatches. size () - 1 );
726+ LLAMA_LOG_DEBUG (" %s: added ubatch to split:\n " , __func__);
730727
731728 ubatch_print (res, debug);
732729 }
@@ -736,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
736733
737734void llama_batch_allocr::ubatch_print (const llama_ubatch & ubatch, int debug) {
738735 if (debug > 0 ) {
739- LLAMA_LOG_DEBUG (" %s: equal_seqs = %d\n " , __func__, ubatch.equal_seqs );
736+ LLAMA_LOG_DEBUG (" %s: equal_seqs = %d\n " , __func__, ubatch.equal_seqs () );
740737 LLAMA_LOG_DEBUG (" %s: n_tokens = %d\n " , __func__, ubatch.n_tokens );
741738 LLAMA_LOG_DEBUG (" %s: n_seq_tokens = %d\n " , __func__, ubatch.n_seq_tokens );
742739 LLAMA_LOG_DEBUG (" %s: n_seqs = %d\n " , __func__, ubatch.n_seqs );
0 commit comments