@@ -1007,7 +1007,8 @@ struct llama_layer {
1007
1007
};
1008
1008
1009
1009
struct llama_kv_cell {
1010
- llama_pos pos = -1 ;
1010
+ llama_pos pos = -1 ;
1011
+ llama_pos delta = 0 ;
1011
1012
1012
1013
std::set<llama_seq_id> seq_id;
1013
1014
@@ -1018,7 +1019,7 @@ struct llama_kv_cell {
1018
1019
1019
1020
// ring-buffer of cached KV data
1020
1021
struct llama_kv_cache {
1021
- bool is_roped = false ;
1022
+ bool has_shift = false ;
1022
1023
1023
1024
uint32_t head = 0 ;
1024
1025
uint32_t size = 0 ;
@@ -1333,9 +1334,13 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t
1333
1334
}
1334
1335
}
1335
1336
1336
- void llama_kv_cache_rm_seq (struct llama_kv_cache & cache, llama_seq_id seq_id) {
1337
+ void llama_kv_cache_rm_seq (
1338
+ struct llama_kv_cache & cache,
1339
+ llama_seq_id seq_id,
1340
+ llama_pos p0,
1341
+ llama_pos p1) {
1337
1342
for (uint32_t i = 0 ; i < cache.size ; ++i) {
1338
- if (cache.cells [i].has_seq_id (seq_id)) {
1343
+ if (cache.cells [i].has_seq_id (seq_id) && cache. cells [i]. pos >= p0 && cache. cells [i]. pos < p1 ) {
1339
1344
cache.cells [i].seq_id .erase (seq_id);
1340
1345
if (cache.cells [i].seq_id .empty ()) {
1341
1346
cache.cells [i].pos = -1 ;
@@ -1353,18 +1358,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id)
1353
1358
}
1354
1359
}
1355
1360
1356
- void llama_kv_cache_shift (
1357
- struct llama_context & ctx ,
1361
+ void llama_kv_cache_shift_seq (
1362
+ struct llama_kv_cache & cache ,
1358
1363
llama_seq_id seq_id,
1359
1364
llama_pos p0,
1360
1365
llama_pos p1,
1361
1366
llama_pos delta) {
1362
- auto & hparams = ctx.model .hparams ;
1363
- auto & cache = ctx.kv_self ;
1364
-
1365
1367
for (uint32_t i = 0 ; i < cache.size ; ++i) {
1366
1368
if (cache.cells [i].has_seq_id (seq_id) && cache.cells [i].pos >= p0 && cache.cells [i].pos < p1) {
1367
1369
cache.cells [i].pos += delta;
1370
+ if (cache.cells [i].pos < 0 ) {
1371
+ cache.cells [i].pos = -1 ;
1372
+ cache.cells [i].seq_id .clear ();
1373
+ } else {
1374
+ cache.has_shift = true ;
1375
+ cache.cells [i].delta = delta;
1376
+ }
1368
1377
}
1369
1378
}
1370
1379
}
@@ -2595,6 +2604,8 @@ static struct ggml_cgraph * llm_build_llama(
2595
2604
const int32_t n_tokens = batch.n_tokens ;
2596
2605
const int32_t n_kv = llama_kv_cache_cell_max (kv_self);
2597
2606
2607
+ const bool do_rope_shift = kv_self.has_shift ;
2608
+
2598
2609
auto & buf_compute = lctx.buf_compute ;
2599
2610
2600
2611
struct ggml_init_params params = {
@@ -2698,6 +2709,16 @@ static struct ggml_cgraph * llm_build_llama(
2698
2709
}
2699
2710
}
2700
2711
2712
+ // K_shift
2713
+ struct ggml_tensor * K_shift = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_ctx);
2714
+ ggml_allocr_alloc (lctx.alloc , K_shift);
2715
+ if (!ggml_allocr_is_measure (lctx.alloc ) && do_rope_shift) {
2716
+ int * data = (int *) K_shift->data ;
2717
+ for (int i = 0 ; i < n_ctx; ++i) {
2718
+ data[i] = kv_self.cells [i].delta ;
2719
+ }
2720
+ }
2721
+
2701
2722
for (int il = 0 ; il < n_layer; ++il) {
2702
2723
ggml_format_name (inpL, " layer_inp_%d" , il);
2703
2724
@@ -2723,6 +2744,17 @@ static struct ggml_cgraph * llm_build_llama(
2723
2744
ggml_set_name (cur, " attention_norm_0" );
2724
2745
}
2725
2746
2747
+ if (do_rope_shift) {
2748
+ ggml_build_forward_expand (gf,
2749
+ ggml_rope_custom_inplace (ctx0,
2750
+ ggml_view_3d (ctx0, kv_self.k ,
2751
+ n_embd_head, n_head_kv, n_ctx,
2752
+ ggml_element_size (kv_self.k )*n_embd_head,
2753
+ ggml_element_size (kv_self.k )*n_embd_gqa,
2754
+ ggml_element_size (kv_self.k )*n_embd_gqa*n_ctx*il),
2755
+ K_shift, n_embd_head, 0 , 0 , freq_base, freq_scale));
2756
+ }
2757
+
2726
2758
// self-attention
2727
2759
{
2728
2760
// compute Q and K and RoPE them
@@ -4033,7 +4065,8 @@ static bool llama_eval_internal(
4033
4065
#endif
4034
4066
4035
4067
// update the kv ring buffer
4036
- lctx.kv_self .head += n_tokens;
4068
+ lctx.kv_self .head += n_tokens;
4069
+ lctx.kv_self .has_shift = false ;
4037
4070
4038
4071
#ifdef GGML_PERF
4039
4072
// print timing information per ggml operation (for debugging purposes)
@@ -6562,10 +6595,6 @@ struct llama_context * llama_new_context_with_model(
6562
6595
return nullptr ;
6563
6596
}
6564
6597
6565
- if (model->arch == LLM_ARCH_LLAMA) {
6566
- ctx->kv_self .is_roped = true ;
6567
- }
6568
-
6569
6598
{
6570
6599
const size_t memory_size = ggml_nbytes (ctx->kv_self .k ) + ggml_nbytes (ctx->kv_self .v );
6571
6600
LLAMA_LOG_INFO (" %s: kv self size = %7.2f MB\n " , __func__, memory_size / 1024.0 / 1024.0 );
@@ -6803,16 +6832,16 @@ void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1
6803
6832
llama_kv_cache_rm_tokens (ctx->kv_self , c0, c1);
6804
6833
}
6805
6834
6806
- void llama_kv_cache_rm_seq (struct llama_context * ctx, llama_seq_id seq_id) {
6807
- llama_kv_cache_rm_seq (ctx->kv_self , seq_id);
6835
+ void llama_kv_cache_rm_seq (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1 ) {
6836
+ llama_kv_cache_rm_seq (ctx->kv_self , seq_id, p0, p1 );
6808
6837
}
6809
6838
6810
6839
void llama_kv_cache_keep_seq (struct llama_context * ctx, llama_seq_id seq_id) {
6811
6840
llama_kv_cache_keep_seq (ctx->kv_self , seq_id);
6812
6841
}
6813
6842
6814
- void llama_kv_cache_shift (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6815
- llama_kv_cache_shift (* ctx, seq_id, p0, p1, delta);
6843
+ void llama_kv_cache_shift_seq (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6844
+ llama_kv_cache_shift_seq ( ctx-> kv_self , seq_id, p0, p1, delta);
6816
6845
}
6817
6846
6818
6847
// Returns the *maximum* size of the state
0 commit comments