
Commit 573694e

compilade authored and Minh141120 committed
kv-cache : avoid modifying recurrent cells when setting inputs (ggml-org#13834)
* kv-cache : avoid modifying recurrent cells when setting inputs

* kv-cache : remove inp_s_mask

  It was replaced with equivalent and simpler functionality with rs_z (the first zeroed state) and the already-existing inp_s_copy.

* kv-cache : fix non-consecutive token pos warning for recurrent models

  The problem was apparently caused by how the tail cells were swapped.

* graph : simplify logic for recurrent state copies

* kv-cache : use cell without src refs for rs_z in recurrent cache

* llama-graph : fix recurrent state copy

  The `state_copy` shuffle assumes everything is moved at once, which is not true when `states_extra` is copied back to the cache before copying the range of states between `head` and `head + n_seqs`. This is only a problem if any of the cells in [`head`, `head + n_seqs`) have an `src` in [`head + n_seqs`, `head + n_kv`), which does happen when `n_ubatch > 1` in the `llama-parallel` example. Changing the order of the operations avoids the potential overwrite before use, although when copies are avoided (like with Mamba2), this will require further changes.

* llama-graph : rename n_state to state_size in build_recurrent_state

  This naming should reduce confusion between the state size and the number of states.
1 parent 97d02ff commit 573694e
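For context on the rs_z change described above: instead of multiplying every recurrent state by the old inp_s_mask, the graph now clears exactly one cell (the first zeroed state, rs_z) and lets the existing s_copy gather propagate that cleared row to every other cell that needs resetting. A minimal sketch of the zero-then-copy step, reusing the names from the build_recurrent_state hunk in src/llama-graph.cpp below; the (rs_zero >= 0) factors make the view zero-sized, and the scale a no-op, when no cell needs clearing:

    // Sketch only; ctx0, gf, s, state_size and kv_state are as in src/llama-graph.cpp below.
    const auto rs_zero = kv_state->get_rs_z();   // index of the first zeroed state, or -1 if none

    // view all recurrent cells as rows of width state_size
    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());

    // view exactly one row when rs_zero >= 0, otherwise a zero-sized view
    ggml_tensor * state_zero = ggml_view_1d(ctx0, states,
            state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));

    // clearing that single row in-graph replaces the old state_mask multiply
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));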

4 files changed: +133 -156 lines changed


src/llama-graph.cpp

Lines changed: 54 additions & 63 deletions
@@ -191,21 +191,23 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

         uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
+        memset(cls->data, 0, n_tokens * ggml_element_size(cls));

-        std::vector<int> last_pos(n_seqs_unq, -1);
-        std::vector<int> last_row(n_seqs_unq, -1);
+        std::vector<int> last_pos(n_tokens, -1);
+        std::vector<int> last_row(n_tokens, -1);

-        for (int i = 0; i < n_tokens; ++i) {
-            const llama_pos pos = ubatch->pos[i];
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];

-            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[i][s];
-                const int32_t seq_idx = ubatch->seq_idx[seq_id];
+            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");

-                if (pos >= last_pos[seq_idx]) {
-                    last_pos[seq_idx] = pos;
-                    last_row[seq_idx] = i;
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
+
+                if (pos >= last_pos[seq_id]) {
+                    last_pos[seq_id] = pos;
+                    last_row[seq_id] = s*n_seq_tokens + i;
                 }
             }
         }
@@ -228,8 +230,8 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
         int32_t * data = (int32_t *) s_copy->data;

         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mctx->s_copy(i);
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            data[i] = kv_state->s_copy(i);
         }
     }
 }
@@ -962,7 +964,24 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {

     auto & cur = inp->cls;

-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_s_copy() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
+
+    const auto n_kv = kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
     ggml_set_input(cur);

     res->add_input(std::move(inp));
@@ -1425,19 +1444,27 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-ggml_tensor * llm_graph_context::build_copy_mask_state(
+ggml_tensor * llm_graph_context::build_recurrent_state(
         ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
-            int32_t   n_state,
-            int32_t   n_seqs) const {
+            int32_t   state_size,
+            int32_t   n_seqs,
+               bool   avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);

     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());

-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+    ggml_tensor * output_states;

     // copy states
     // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
@@ -1448,7 +1475,8 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
     // FIXME: zero-out NANs?
     states = ggml_mul(ctx0, states, state_mask);

-    // copy states which won't be changed further (between n_seqs and n_kv)
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
@@ -1457,47 +1485,10 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
     return output_states;
 }

-llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-
-    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
-
-    const auto n_rs = mctx_cur->get_n_rs();
-
-    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
-    ggml_set_input(inp->s_copy);
-
-    return (llm_graph_input_rs *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_rs(
-        llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-            int32_t   state_size,
-            int32_t   n_seqs,
-               bool   avoid_copies) const {
-    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
-}
-
-ggml_tensor * llm_graph_context::build_rs(
-        llm_graph_input_mem_hybrid * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-            int32_t   state_size,
-            int32_t   n_seqs,
-               bool   avoid_copies) const {
-    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
-
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
-}
-
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-        llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
-        const llama_ubatch & ubatch,
+        ggml_cgraph * gf,
+        ggml_tensor * state_copy,
+        const llama_ubatch & ubatch,
         int il) const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

@@ -1507,9 +1498,9 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(

     ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);

-    ggml_tensor * token_shift = build_rs(
-            inp, gf, token_shift_all,
-            hparams.n_embd_r(), n_seqs);
+    ggml_tensor * token_shift = build_recurrent_state(
+            gf, token_shift_all, state_copy,
+            hparams.n_embd_k_s(), n_seqs);

     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
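The copy-ordering fix called out in the commit message can be pictured without ggml at all: if a cell in [head, head + n_seqs) reads its state from a cell in [head + n_seqs, head + n_kv), then writing the extra range back to the cache before the main range has been gathered overwrites a source before it is used. A small self-contained sketch of that hazard with plain arrays (the indices and values are invented for illustration; this is not the graph code):

    #include <cstdio>
    #include <vector>

    // Schematic only: cell i gets its new state from cell src[i].
    // Cell 0 is in the "main" range [head, head + n_seqs); cell 1 is an "extra" cell.
    int main() {
        const std::vector<int> src = {1, 0};   // cell 0 reads from cell 1, cell 1 from cell 0

        // Hazardous order: write the extra range back first, then gather the main range.
        std::vector<int> bad = {10, 11};
        bad[1] = bad[src[1]];                  // extra write-back: bad[1] becomes 10
        bad[0] = bad[src[0]];                  // main gather now sees the moved value (10, wrong)

        // Safe order: gather everything from the original buffer before any write-back.
        const std::vector<int> orig = {10, 11};
        std::vector<int> good(2);
        for (int i = 0; i < 2; ++i) good[i] = orig[src[i]];

        printf("hazardous: %d %d   safe: %d %d\n", bad[0], bad[1], good[0], good[1]);
        // prints: hazardous: 10 10   safe: 11 10
        return 0;
    }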

src/llama-graph.h

Lines changed: 12 additions & 39 deletions
@@ -199,7 +199,7 @@ class llm_graph_input_rs : public llm_graph_input_i {

     ggml_tensor * s_copy; // I32 [kv_size]

-    const llama_memory_recurrent_context * mctx;
+    const llama_kv_cache_recurrent_state * kv_state;
 };

 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -547,6 +547,7 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
+    ggml_tensor * build_inp_s_copy() const;

     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -646,46 +647,18 @@ struct llm_graph_context {
     // recurrent
     //

-    // TODO: avoid notion of "kv"
-    // TODO: move this implementation to llama_memory_recurrent.
-    // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
-    // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
-    // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
-    // `llama_memory_recurrent`
-    ggml_tensor * build_rs(
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            ggml_tensor * state_copy,
-                int32_t   state_size,
-                int32_t   n_seqs,
-               uint32_t   n_kv,
-               uint32_t   kv_head,
-               uint32_t   kv_size,
-                int32_t   rs_zero,
-                   bool   avoid_copies = false) const;
-
-    llm_graph_input_rs * build_rs_inp() const;
-
-    ggml_tensor * build_rs(
-            llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-                int32_t   state_size,
-                int32_t   n_seqs,
-                   bool   avoid_copies = false) const;
-
-    ggml_tensor * build_rs(
-            llm_graph_input_mem_hybrid * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-                int32_t   state_size,
-                int32_t   n_seqs,
-                   bool   avoid_copies = false) const;
+    ggml_tensor * build_recurrent_state(
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+                int32_t   state_size,
+                int32_t   n_seqs,
+                   bool   avoid_copies = false) const;

     ggml_tensor * build_rwkv_token_shift_load(
-            llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
-            const llama_ubatch & ubatch,
+            ggml_cgraph * gf,
+            ggml_tensor * state_copy,
+            const llama_ubatch & ubatch,
             int il) const;

     ggml_tensor * build_rwkv_token_shift_store(

src/llama-memory-recurrent.cpp

Lines changed: 19 additions & 19 deletions
@@ -429,7 +429,9 @@ bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches)
     return success;
 }

-bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
+bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
+    const uint32_t n_seqs = ubatch.n_seqs;
+
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
     const uint32_t n_seqs = ubatch.n_seqs;

@@ -539,7 +541,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
                 seq_meta.tail = next_empty_cell;
                 // find next empty cell
                 if (s + 1 < n_seqs) {
-                    for (uint32_t j = 0; j < size; ++j) {
+                    for (uint32_t i = 0; i < size; ++i) {
                         next_empty_cell += 1;
                         if (next_empty_cell >= size) { next_empty_cell -= size; }
                         auto & cell = cells[next_empty_cell];
@@ -553,9 +555,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {

     // gather and re-order
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        const uint32_t i = s*n_seq_tokens;
         const int32_t dst_id = s + min;
-        const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
+        const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
         if (dst_id != src_id) {
             auto & dst_cell = cells[dst_id];
             auto & src_cell = cells[src_id];
@@ -565,8 +566,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
             std::swap(dst_cell.seq_id, src_cell.seq_id);

             // swap tails
-            for (uint32_t j = 0; j < size; ++j) {
-                int32_t & tail = cells[j].tail;
+            for (uint32_t i = 0; i < size; ++i) {
+                int32_t & tail = cells[i].tail;
                 if (tail == src_id) {
                     tail = dst_id;
                 } else if (tail == dst_id) {
@@ -578,10 +579,9 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {

     // update the pos of the used seqs
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        const uint32_t i = s*n_seq_tokens;
-        const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
+        const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
         const int32_t cell_id = s + min;
-        auto & cell = cells[cell_id];
+        kv_cell & cell = cells[cell_id];

         if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
             // What should happen when the pos backtracks or skips a value?
@@ -634,13 +634,13 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
     head = min;
     n = max - min + 1;
     used = std::count_if(cells.begin(), cells.end(),
-        [](const mem_cell & cell){ return !cell.is_empty(); });
+        [](const kv_cell & cell){ return !cell.is_empty(); });

     // sanity check
     return n >= n_seqs;
 }

-bool llama_memory_recurrent::get_can_shift() const {
+bool llama_kv_cache_recurrent::get_can_shift() const {
     // shifting the pos is trivial for recurrent models
     return true;
 }
@@ -1104,8 +1104,12 @@ uint32_t llama_memory_recurrent_context::get_head() const {
     return is_full ? 0 : mem->head;
 }

-int32_t llama_memory_recurrent_context::get_rs_z() const {
-    return is_full ? 0 : mem->rs_z;
+int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
+    return is_full ? 0 : kv->rs_z;
+}
+
+uint32_t llama_kv_cache_recurrent_state::get_size() const {
+    return kv->size;
 }

 uint32_t llama_memory_recurrent_context::get_size() const {
@@ -1116,10 +1120,6 @@ ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
     return mem->r_l[il];
 }

-ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
-    return mem->s_l[il];
-}
-
-int32_t llama_memory_recurrent_context::s_copy(int i) const {
-    return mem->cells[i + mem->head].src0;
+int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
+    return kv->cells[i + kv->head].src0;
 }
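Read together with the input-setting hunk in src/llama-graph.cpp above, the copy-index flow is: s_copy(i) resolves to the src0 of the cell at head + i, and llm_graph_input_rs::set_input streams those indices into the I32 s_copy tensor that build_recurrent_state later consumes via ggml_get_rows. A condensed sketch of that flow, assuming the same accessors shown in the hunks above:

    // Sketch only; kv_state, s_copy and n_kv are as in the diffs above.
    int32_t * data = (int32_t *) s_copy->data;   // I32 [kv_size] graph input

    for (uint32_t i = 0; i < n_kv; ++i) {
        // on the cache side this is kv->cells[i + kv->head].src0
        data[i] = kv_state->s_copy(i);
    }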
