77#include < cassert>
88#include < vector>
99#include < set>
10+ #include < map>
1011
1112// meta information about KV cells that can be part of multiple sequences at the same time
1213// TODO: add unit tests
@@ -164,7 +165,7 @@ class llama_kv_cells_unified {
164165 assert (seq_id >= 0 );
165166
166167 seq[i].reset (seq_id);
167- seq_pos[seq_id]. erase ( pos[i]);
168+ seq_pos_dec (seq_id, pos[i]);
168169
169170 if (seq[i].none ()) {
170171 pos[i] = -1 ;
@@ -187,7 +188,7 @@ class llama_kv_cells_unified {
187188 seq[i].reset ();
188189
189190 seq[i].set (seq_id);
190- seq_pos[seq_id]. insert ( pos[i]);
191+ seq_pos_inc (seq_id, pos[i]);
191192
192193 return false ;
193194 }
@@ -232,7 +233,7 @@ class llama_kv_cells_unified {
232233 assert (!seq[i].test (seq_id));
233234
234235 seq[i].set (seq_id);
235- seq_pos[seq_id]. insert ( pos[i]);
236+ seq_pos_inc (seq_id, pos[i]);
236237 }
237238
238239 // return the sequence id of this cell
@@ -259,7 +260,9 @@ class llama_kv_cells_unified {
259260 return -1 ;
260261 }
261262
262- return *seq_pos[seq_id].begin ();
263+ assert (seq_pos[seq_id].begin ()->second > 0 );
264+
265+ return seq_pos[seq_id].begin ()->first ;
263266 }
264267
265268 // the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ class llama_kv_cells_unified {
272275 return -1 ;
273276 }
274277
275- return *seq_pos[seq_id].rbegin ();
278+ assert (seq_pos[seq_id].rbegin ()->second > 0 );
279+
280+ return seq_pos[seq_id].rbegin ()->first ;
276281 }
277282
278283 // note: call only if the cell is not empty
@@ -391,15 +396,31 @@ class llama_kv_cells_unified {
391396
392397 // the map seq_pos[s] tells us which positions are currently present for sequence s (and in how many cells)
393398 // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394- std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
399+ //
400+ // note that we cannot use an std::set because in some cases a position can occur more than once for the same seq:
401+ // - while performing a cache reuse via (add + rm)
402+ // - some vision models have input embeddings with repeating positions
403+ std::map<llama_pos, int > seq_pos[LLAMA_MAX_SEQ];
395404
396405 // helper functions for updating `seq_pos`, one cell at a time:
397406
407+ void seq_pos_dec (llama_seq_id s, llama_pos p) {
408+ auto it = seq_pos[s].find (p);
409+ assert (it != seq_pos[s].end ());
410+ if (--it->second == 0 ) {
411+ seq_pos[s].erase (it);
412+ }
413+ }
414+
415+ void seq_pos_inc (llama_seq_id s, llama_pos p) {
416+ seq_pos[s][p]++;
417+ }
418+
398419 // remove cell i
399420 void seq_pos_rm (uint32_t i) {
400421 for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
401422 if (seq[i].test (s)) {
402- seq_pos[s]. erase ( pos[i]);
423+ seq_pos_dec (s, pos[i]);
403424 }
404425 }
405426 }
@@ -408,7 +429,7 @@ class llama_kv_cells_unified {
408429 void seq_pos_add (uint32_t i) {
409430 for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
410431 if (seq[i].test (s)) {
411- seq_pos[s]. insert ( pos[i]);
432+ seq_pos_inc (s, pos[i]);
412433 }
413434 }
414435 }
0 commit comments