Skip to content

Commit 6289ed6

Browse files
committed
llama : add llama_kv_cache_shift_seq + no more context swaps
1 parent 86c90e3 commit 6289ed6

File tree

4 files changed

+64
-29
lines changed

4 files changed

+64
-29
lines changed

Diff for: common/common.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
781781

782782
std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
783783
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
784+
llama_kv_cache_keep_seq(lctx, -1);
784785
llama_reset_timings(lctx);
785786
}
786787

Diff for: examples/main/main.cpp

+13-8
Original file line numberDiff line numberDiff line change
@@ -499,17 +499,22 @@ int main(int argc, char ** argv) {
499499
break;
500500
}
501501

502-
const int n_left = n_past - params.n_keep;
503-
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep);
502+
const int n_left = n_past - params.n_keep - 1;
503+
const int n_discard = n_left/2;
504504

505-
// always keep the first token - BOS
506-
n_past = std::max(1, params.n_keep);
507-
n_past_guidance = std::max(1, params.n_keep + guidance_offset);
505+
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
506+
n_past, n_left, n_ctx, params.n_keep, n_discard);
508507

509-
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
508+
llama_kv_cache_rm_seq (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
509+
llama_kv_cache_shift_seq(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
510+
511+
n_past -= n_discard;
510512

511-
// insert n_left/2 tokens at the start of embd from last_tokens
512-
embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());
513+
if (ctx_guidance) {
514+
n_past_guidance -= n_discard;
515+
}
516+
517+
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
513518

514519
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
515520

Diff for: llama.cpp

+47-18
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,8 @@ struct llama_layer {
10071007
};
10081008

10091009
struct llama_kv_cell {
1010-
llama_pos pos = -1;
1010+
llama_pos pos = -1;
1011+
llama_pos delta = 0;
10111012

10121013
std::set<llama_seq_id> seq_id;
10131014

@@ -1018,7 +1019,7 @@ struct llama_kv_cell {
10181019

10191020
// ring-buffer of cached KV data
10201021
struct llama_kv_cache {
1021-
bool is_roped = false;
1022+
bool has_shift = false;
10221023

10231024
uint32_t head = 0;
10241025
uint32_t size = 0;
@@ -1333,9 +1334,13 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t
13331334
}
13341335
}
13351336

1336-
void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
1337+
void llama_kv_cache_rm_seq(
1338+
struct llama_kv_cache & cache,
1339+
llama_seq_id seq_id,
1340+
llama_pos p0,
1341+
llama_pos p1) {
13371342
for (uint32_t i = 0; i < cache.size; ++i) {
1338-
if (cache.cells[i].has_seq_id(seq_id)) {
1343+
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
13391344
cache.cells[i].seq_id.erase(seq_id);
13401345
if (cache.cells[i].seq_id.empty()) {
13411346
cache.cells[i].pos = -1;
@@ -1353,18 +1358,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id)
13531358
}
13541359
}
13551360

1356-
void llama_kv_cache_shift(
1357-
struct llama_context & ctx,
1361+
void llama_kv_cache_shift_seq(
1362+
struct llama_kv_cache & cache,
13581363
llama_seq_id seq_id,
13591364
llama_pos p0,
13601365
llama_pos p1,
13611366
llama_pos delta) {
1362-
auto & hparams = ctx.model.hparams;
1363-
auto & cache = ctx.kv_self;
1364-
13651367
for (uint32_t i = 0; i < cache.size; ++i) {
13661368
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
13671369
cache.cells[i].pos += delta;
1370+
if (cache.cells[i].pos < 0) {
1371+
cache.cells[i].pos = -1;
1372+
cache.cells[i].seq_id.clear();
1373+
} else {
1374+
cache.has_shift = true;
1375+
cache.cells[i].delta = delta;
1376+
}
13681377
}
13691378
}
13701379
}
@@ -2595,6 +2604,8 @@ static struct ggml_cgraph * llm_build_llama(
25952604
const int32_t n_tokens = batch.n_tokens;
25962605
const int32_t n_kv = llama_kv_cache_cell_max(kv_self);
25972606

2607+
const bool do_rope_shift = kv_self.has_shift;
2608+
25982609
auto & buf_compute = lctx.buf_compute;
25992610

26002611
struct ggml_init_params params = {
@@ -2698,6 +2709,16 @@ static struct ggml_cgraph * llm_build_llama(
26982709
}
26992710
}
27002711

2712+
// K_shift
2713+
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
2714+
ggml_allocr_alloc(lctx.alloc, K_shift);
2715+
if (!ggml_allocr_is_measure(lctx.alloc) && do_rope_shift) {
2716+
int * data = (int *) K_shift->data;
2717+
for (int i = 0; i < n_ctx; ++i) {
2718+
data[i] = kv_self.cells[i].delta;
2719+
}
2720+
}
2721+
27012722
for (int il = 0; il < n_layer; ++il) {
27022723
ggml_format_name(inpL, "layer_inp_%d", il);
27032724

@@ -2723,6 +2744,17 @@ static struct ggml_cgraph * llm_build_llama(
27232744
ggml_set_name(cur, "attention_norm_0");
27242745
}
27252746

2747+
if (do_rope_shift) {
2748+
ggml_build_forward_expand(gf,
2749+
ggml_rope_custom_inplace(ctx0,
2750+
ggml_view_3d(ctx0, kv_self.k,
2751+
n_embd_head, n_head_kv, n_ctx,
2752+
ggml_element_size(kv_self.k)*n_embd_head,
2753+
ggml_element_size(kv_self.k)*n_embd_gqa,
2754+
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
2755+
K_shift, n_embd_head, 0, 0, freq_base, freq_scale));
2756+
}
2757+
27262758
// self-attention
27272759
{
27282760
// compute Q and K and RoPE them
@@ -4033,7 +4065,8 @@ static bool llama_eval_internal(
40334065
#endif
40344066

40354067
// update the kv ring buffer
4036-
lctx.kv_self.head += n_tokens;
4068+
lctx.kv_self.head += n_tokens;
4069+
lctx.kv_self.has_shift = false;
40374070

40384071
#ifdef GGML_PERF
40394072
// print timing information per ggml operation (for debugging purposes)
@@ -6562,10 +6595,6 @@ struct llama_context * llama_new_context_with_model(
65626595
return nullptr;
65636596
}
65646597

6565-
if (model->arch == LLM_ARCH_LLAMA) {
6566-
ctx->kv_self.is_roped = true;
6567-
}
6568-
65696598
{
65706599
const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v);
65716600
LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
@@ -6803,16 +6832,16 @@ void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1
68036832
llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1);
68046833
}
68056834

6806-
void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id) {
6807-
llama_kv_cache_rm_seq(ctx->kv_self, seq_id);
6835+
void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
6836+
llama_kv_cache_rm_seq(ctx->kv_self, seq_id, p0, p1);
68086837
}
68096838

68106839
void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
68116840
llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
68126841
}
68136842

6814-
void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6815-
llama_kv_cache_shift(*ctx, seq_id, p0, p1, delta);
6843+
void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6844+
llama_kv_cache_shift_seq(ctx->kv_self, seq_id, p0, p1, delta);
68166845
}
68176846

68186847
// Returns the *maximum* size of the state

Diff for: llama.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -324,15 +324,15 @@ extern "C" {
324324
// Remove all tokens data of cells in [c0, c1)
325325
LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
326326

327-
// Removes all tokens that belong to the specified sequence
328-
LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id);
327+
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
328+
LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
329329

330330
// Removes all tokens that do not belong to the specified sequence
331331
LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);
332332

333333
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
334334
// If the KV cache is RoPEd, the KV data is updated accordingly
335-
LLAMA_API void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
335+
LLAMA_API void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
336336

337337
//
338338
// State / sessions

0 commit comments

Comments
 (0)