Skip to content

Commit 439d562

Browse files
committed
kv-cells : fix tracking of seq_pos during cache reuse
ggml-ci
1 parent 238005c commit 439d562

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

src/llama-batch.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ bool llama_batch_allocr::init(
247247
if (memory) {
248248
if (batch.token) {
249249
if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
250-
LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
250+
LLAMA_LOG_ERROR("%s: sequence %d (min = %d) does not start from the last position (%d) stored in the memory\n",
251+
__func__, s, seq_pos_min(s), memory->seq_pos_max(s));
251252
return false;
252253
}
253254
} else {
@@ -256,7 +257,8 @@ bool llama_batch_allocr::init(
256257
// for embeddings (typically used as vision input), we allow them to have repeating positions
257258
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
258259
if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
259-
LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
260+
LLAMA_LOG_ERROR("%s: sequence %d (min = %d) does not start from the last position (%d) stored in the memory\n",
261+
__func__, s, seq_pos_min(s), memory->seq_pos_max(s));
260262
return false;
261263
}
262264
}

src/llama-kv-cells.h

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <cassert>
88
#include <vector>
99
#include <set>
10+
#include <map>
1011

1112
// meta information about KV cells that can be part of multiple sequences at the same time
1213
// TODO: add unit tests
@@ -164,7 +165,7 @@ class llama_kv_cells_unified {
164165
assert(seq_id >= 0);
165166

166167
seq[i].reset(seq_id);
167-
seq_pos[seq_id].erase(pos[i]);
168+
seq_pos_dec(seq_id, pos[i]);
168169

169170
if (seq[i].none()) {
170171
pos[i] = -1;
@@ -187,7 +188,7 @@ class llama_kv_cells_unified {
187188
seq[i].reset();
188189

189190
seq[i].set(seq_id);
190-
seq_pos[seq_id].insert(pos[i]);
191+
seq_pos_inc(seq_id, pos[i]);
191192

192193
return false;
193194
}
@@ -232,7 +233,7 @@ class llama_kv_cells_unified {
232233
assert(!seq[i].test(seq_id));
233234

234235
seq[i].set(seq_id);
235-
seq_pos[seq_id].insert(pos[i]);
236+
seq_pos_inc(seq_id, pos[i]);
236237
}
237238

238239
// return the sequence id of this cell
@@ -259,7 +260,9 @@ class llama_kv_cells_unified {
259260
return -1;
260261
}
261262

262-
return *seq_pos[seq_id].begin();
263+
assert(seq_pos[seq_id].begin()->second > 0);
264+
265+
return seq_pos[seq_id].begin()->first;
263266
}
264267

265268
// the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ class llama_kv_cells_unified {
272275
return -1;
273276
}
274277

275-
return *seq_pos[seq_id].rbegin();
278+
assert(seq_pos[seq_id].rbegin()->second > 0);
279+
280+
return seq_pos[seq_id].rbegin()->first;
276281
}
277282

278283
// note: call only if the cell is not empty
@@ -391,15 +396,31 @@ class llama_kv_cells_unified {
391396

392397
// the set seq_pos[s] tells us which positions are currently present for sequence s
393398
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394-
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
399+
//
400+
// note that we cannot use an std::set because in some cases a position can occur more than once for the same seq:
401+
// - while performing a cache reuse via (add + rm)
402+
// - some vision models have input embeddings with repeating positions
403+
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
395404

396405
// helper functions for updating `seq_pos`, once cell at a time:
397406

407+
void seq_pos_dec(llama_seq_id s, llama_pos p) {
408+
auto it = seq_pos[s].find(p);
409+
assert(it != seq_pos[s].end());
410+
if (--it->second == 0) {
411+
seq_pos[s].erase(it);
412+
}
413+
}
414+
415+
void seq_pos_inc(llama_seq_id s, llama_pos p) {
416+
seq_pos[s][p]++;
417+
}
418+
398419
// remove cell i
399420
void seq_pos_rm(uint32_t i) {
400421
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
401422
if (seq[i].test(s)) {
402-
seq_pos[s].erase(pos[i]);
423+
seq_pos_dec(s, pos[i]);
403424
}
404425
}
405426
}
@@ -408,7 +429,7 @@ class llama_kv_cells_unified {
408429
void seq_pos_add(uint32_t i) {
409430
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
410431
if (seq[i].test(s)) {
411-
seq_pos[s].insert(pos[i]);
432+
seq_pos_inc(s, pos[i]);
412433
}
413434
}
414435
}

0 commit comments

Comments
 (0)