 #include <array>
 #include <vector>
 #include <set>
+#include <bitset>
+#include <unordered_map>
 
+// keep this struct lightweight
+// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
@@ -19,105 +23,127 @@ struct llama_ubatch {
     uint32_t n_seqs;       // sequence sets in the ubatch
     uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
 
-    llama_token  *  token;    // [n_tokens]
-    float        *  embd;     // [n_embd, n_tokens]
-    llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
-    int8_t       *  output;   // [n_tokens]
-};
-
-struct llama_sbatch_seq {
-    int32_t n_seq_id;
-
-    llama_seq_id * seq_id;
-
-    size_t offset;
-    size_t length;
-};
-
-// sequence-length-aware batch splitting
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    // only for debugging purposes
-    const llama_vocab * vocab;
-
-    // sorted indices into the batch
-    std::vector<int64_t> ids;
-    // batch indices of the output
-    std::vector<int64_t> out_ids;
-    std::vector<llama_sbatch_seq> seq;
-
-    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
-
-    // buffers for the ubatches
-    // TODO: very hacky, this needs a complete rework
-    struct ubatch_data {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<int8_t>         output;
-    };
-
-    std::vector<ubatch_data> udatas;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch);
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch);
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch);
-
-    llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
+    // seq_id_unq: unique sequence ids in the ubatch
+    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+    //             used for extracting sequence pooled embeddings
+
+    //                          // size               | idx | val
+    llama_token  *  token;      // [n_tokens]         | i   | id, token
+    float        *  embd;       // [n_embd, n_tokens] | i   | embd
+    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
+    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t       *  output;     // [n_tokens]         | i   | -
 };
 
-// a helper for sanitizing and fulfilling a batch
+// a helper for sanitizing, fulfilling and splitting a batch
 class llama_batch_allocr {
 public:
-    llama_batch_allocr();
+    llama_batch_allocr(uint32_t n_pos_per_embd);
 
     // sanitize and auto-gen missing data in the input batch
     // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
             const llama_memory_i * memory,
-            bool embd_all);
+            uint32_t n_embd,
+            bool output_all);
 
     const llama_batch & get_batch() const;
 
+    uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;
 
+    // the array of output indices in the order they were encountered during the ubatch splitting
+    std::vector<int32_t> & get_out_ids();
+
+    // min/max positions of each sequence in the current ubatch
     llama_pos seq_pos_min(llama_seq_id seq_id) const;
     llama_pos seq_pos_max(llama_seq_id seq_id) const;
 
+    // call once before splitting the batch to reset the internal state
+    void split_reset();
+
+    // simple split, unknown number of sequence sets of unequal lengths
+    llama_ubatch split_simple(uint32_t n_ubatch);
+
+    // make ubatches of equal-length sequence sets
+    llama_ubatch split_equal(uint32_t n_ubatch);
+
+    // sequence-set-wise split - each ubatch contains a single sequence-set
+    llama_ubatch split_seq(uint32_t n_ubatch);
+
+    // a helper method for creating a well-defined ubatch of tokens
+    // TODO: support embeddings if needed in the future
+    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+
 private:
     void clear();
 
+    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+    // for debugging, start with LLAMA_BATCH_DEBUG=2
+    void ubatch_print(const llama_ubatch & ubatch, int debug);
+
     llama_batch batch;
 
+    // only for debugging purposes
+    const llama_vocab * vocab;
+
+    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+    const uint32_t n_pos_per_embd;
+
+    uint32_t n_embd;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
 
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
+    std::vector<llama_seq_id>   seq_id_unq;
+    std::vector<int32_t>        seq_idx;
     std::vector<int8_t>         output;
 
-    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
-    std::vector<std::vector<bool>>   seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+    using pos_set_t = std::set<llama_pos>;
+    using seq_cpl_t = std::vector<bool>;
+
+    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
+    using idx_vec_t = std::vector<int32_t>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+
+    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+    // batch indices of the output
+    std::vector<int32_t> out_ids;
+
+    // used[i] indicates if token i has already been used in a previous ubatch
+    std::vector<bool> used;
+
+    // llama_ubatch points to this data:
+    struct ubatch {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // current splitting state:
+    std::vector<ubatch> ubatches;
 
     int debug;
 };
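
To make the reworked workflow concrete, the sketch below shows how a caller inside llama.cpp could drive the new splitting interface declared above: init() sanitizes the input batch, split_reset() clears the splitting state, and the split_* methods are called until the returned ubatch has n_tokens == 0. It is an illustration only, not the actual decode path; the helper name process_batch_sketch, the choice of the simple split, and the omitted model evaluation are assumptions.

#include "llama-batch.h"

// illustration only: a simplified consumer of the reworked llama_batch_allocr API
static bool process_batch_sketch(
        llama_batch_allocr   & balloc,
        const llama_batch    & batch_inp,
        const llama_vocab    & vocab,
        const llama_memory_i * memory,
        uint32_t n_embd,
        uint32_t n_ubatch) {
    // sanitize the input batch and auto-generate any missing pos/seq_id/output data
    if (!balloc.init(batch_inp, vocab, memory, n_embd, /*output_all =*/ false)) {
        return false;
    }

    // reset the internal splitting state before producing ubatches
    balloc.split_reset();

    while (true) {
        // split_equal() and split_seq() are the alternatives for models that need
        // equal-length sequence sets or a single sequence set per ubatch
        llama_ubatch ubatch = balloc.split_simple(n_ubatch);

        if (ubatch.n_tokens == 0) {
            break; // the entire batch has been consumed
        }

        // ubatch.token/pos/n_seq_id/seq_id/output are indexed per token, while
        // ubatch.seq_id_unq and ubatch.seq_idx describe the unique sequences in
        // the ubatch (e.g. for extracting sequence pooled embeddings)

        // ... evaluate the ubatch with the model ...
    }

    // batch indices of the outputs, in the order they were encountered during splitting
    const std::vector<int32_t> & out_ids = balloc.get_out_ids();
    (void) out_ids;

    return true;
}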