77#include < cassert>
88#include < vector>
99#include < set>
10+ #include < map>
1011
1112// meta information about KV cells that can be part of multiple sequences at the same time
1213// TODO: add unit tests
@@ -216,7 +217,7 @@ class llama_kv_cells_unified {
216217 assert (seq_id >= 0 );
217218
218219 seq[i].reset (seq_id);
219- seq_pos[seq_id]. erase ( pos[i]);
220+ seq_pos_dec (seq_id, pos[i]);
220221
221222 if (seq[i].none ()) {
222223 pos[i] = -1 ;
@@ -239,7 +240,7 @@ class llama_kv_cells_unified {
239240 seq[i].reset ();
240241
241242 seq[i].set (seq_id);
242- seq_pos[seq_id]. insert ( pos[i]);
243+ seq_pos_inc (seq_id, pos[i]);
243244
244245 return false ;
245246 }
@@ -284,7 +285,7 @@ class llama_kv_cells_unified {
284285 assert (!seq[i].test (seq_id));
285286
286287 seq[i].set (seq_id);
287- seq_pos[seq_id]. insert ( pos[i]);
288+ seq_pos_inc (seq_id, pos[i]);
288289 }
289290
290291 // return the sequence id of this cell
@@ -311,7 +312,9 @@ class llama_kv_cells_unified {
311312 return -1 ;
312313 }
313314
314- return *seq_pos[seq_id].begin ();
315+ assert (seq_pos[seq_id].begin ()->second > 0 );
316+
317+ return seq_pos[seq_id].begin ()->first ;
315318 }
316319
317320 // the maximum position of sequence seq_id currently present in any of the cells
@@ -324,7 +327,9 @@ class llama_kv_cells_unified {
324327 return -1 ;
325328 }
326329
327- return *seq_pos[seq_id].rbegin ();
330+ assert (seq_pos[seq_id].rbegin ()->second > 0 );
331+
332+ return seq_pos[seq_id].rbegin ()->first ;
328333 }
329334
330335 // note: call only if the cell is not empty
@@ -441,17 +446,36 @@ class llama_kv_cells_unified {
441446 // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
442447 std::vector<seq_set_t > seq;
443448
444- // the set seq_pos[s] tells us which positions are currently present for sequence s
449+ // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
450+ // if the position p is not present, seq_pos[s][p] is not set
445451 // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
446- std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
452+ //
453+ // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
454+ // - during performing a cache reuse via (rm + add)
455+ // - some vision models have input embeddings with repeating positions
456+ //
457+ std::map<llama_pos, int > seq_pos[LLAMA_MAX_SEQ];
447458
448459 // helper functions for updating `seq_pos`, once cell at a time:
449460
461+ void seq_pos_dec (llama_seq_id s, llama_pos p) {
462+ auto it = seq_pos[s].find (p);
463+ assert (it != seq_pos[s].end ());
464+
465+ if (--it->second == 0 ) {
466+ seq_pos[s].erase (it);
467+ }
468+ }
469+
470+ void seq_pos_inc (llama_seq_id s, llama_pos p) {
471+ seq_pos[s][p]++;
472+ }
473+
450474 // remove cell i
451475 void seq_pos_rm (uint32_t i) {
452476 for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
453477 if (seq[i].test (s)) {
454- seq_pos[s]. erase ( pos[i]);
478+ seq_pos_dec (s, pos[i]);
455479 }
456480 }
457481 }
@@ -460,7 +484,7 @@ class llama_kv_cells_unified {
460484 void seq_pos_add (uint32_t i) {
461485 for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
462486 if (seq[i].test (s)) {
463- seq_pos[s]. insert ( pos[i]);
487+ seq_pos_inc (s, pos[i]);
464488 }
465489 }
466490 }
0 commit comments