@@ -1007,7 +1007,8 @@ struct llama_layer {
1007
1007
};
1008
1008
1009
1009
struct llama_kv_cell {
1010
- llama_pos pos = -1 ;
1010
+ llama_pos pos = -1 ;
1011
+ llama_pos delta = 0 ;
1011
1012
1012
1013
std::set<llama_seq_id> seq_id;
1013
1014
@@ -1018,7 +1019,7 @@ struct llama_kv_cell {
1018
1019
1019
1020
// ring-buffer of cached KV data
1020
1021
struct llama_kv_cache {
1021
- bool is_roped = false ;
1022
+ bool has_shift = false ;
1022
1023
1023
1024
uint32_t head = 0 ;
1024
1025
uint32_t size = 0 ;
@@ -1223,6 +1224,8 @@ static bool llama_kv_cache_init(
1223
1224
const int64_t n_mem = n_layer*n_ctx;
1224
1225
const int64_t n_elements = n_embd*n_mem;
1225
1226
1227
+ cache.has_shift = false ;
1228
+
1226
1229
cache.head = 0 ;
1227
1230
cache.size = n_ctx;
1228
1231
@@ -1333,9 +1336,13 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t
1333
1336
}
1334
1337
}
1335
1338
1336
- void llama_kv_cache_rm_seq (struct llama_kv_cache & cache, llama_seq_id seq_id) {
1339
+ void llama_kv_cache_rm_seq (
1340
+ struct llama_kv_cache & cache,
1341
+ llama_seq_id seq_id,
1342
+ llama_pos p0,
1343
+ llama_pos p1) {
1337
1344
for (uint32_t i = 0 ; i < cache.size ; ++i) {
1338
- if (cache.cells [i].has_seq_id (seq_id)) {
1345
+ if (cache.cells [i].has_seq_id (seq_id) && cache. cells [i]. pos >= p0 && cache. cells [i]. pos < p1 ) {
1339
1346
cache.cells [i].seq_id .erase (seq_id);
1340
1347
if (cache.cells [i].seq_id .empty ()) {
1341
1348
cache.cells [i].pos = -1 ;
@@ -1353,18 +1360,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id)
1353
1360
}
1354
1361
}
1355
1362
1356
- void llama_kv_cache_shift (
1357
- struct llama_context & ctx ,
1363
+ void llama_kv_cache_shift_seq (
1364
+ struct llama_kv_cache & cache ,
1358
1365
llama_seq_id seq_id,
1359
1366
llama_pos p0,
1360
1367
llama_pos p1,
1361
1368
llama_pos delta) {
1362
- auto & hparams = ctx.model .hparams ;
1363
- auto & cache = ctx.kv_self ;
1364
-
1365
1369
for (uint32_t i = 0 ; i < cache.size ; ++i) {
1366
1370
if (cache.cells [i].has_seq_id (seq_id) && cache.cells [i].pos >= p0 && cache.cells [i].pos < p1) {
1367
1371
cache.cells [i].pos += delta;
1372
+ if (cache.cells [i].pos < 0 ) {
1373
+ cache.cells [i].pos = -1 ;
1374
+ cache.cells [i].seq_id .clear ();
1375
+ } else {
1376
+ cache.has_shift = true ;
1377
+ cache.cells [i].delta = delta;
1378
+ }
1368
1379
}
1369
1380
}
1370
1381
}
@@ -2595,6 +2606,8 @@ static struct ggml_cgraph * llm_build_llama(
2595
2606
const int32_t n_tokens = batch.n_tokens ;
2596
2607
const int32_t n_kv = llama_kv_cache_cell_max (kv_self);
2597
2608
2609
+ const bool do_rope_shift = kv_self.has_shift || ggml_allocr_is_measure (lctx.alloc );
2610
+
2598
2611
auto & buf_compute = lctx.buf_compute ;
2599
2612
2600
2613
struct ggml_init_params params = {
@@ -2698,6 +2711,16 @@ static struct ggml_cgraph * llm_build_llama(
2698
2711
}
2699
2712
}
2700
2713
2714
+ // K_shift
2715
+ struct ggml_tensor * K_shift = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_ctx);
2716
+ ggml_allocr_alloc (lctx.alloc , K_shift);
2717
+ if (!ggml_allocr_is_measure (lctx.alloc )) {
2718
+ int * data = (int *) K_shift->data ;
2719
+ for (int i = 0 ; i < n_ctx; ++i) {
2720
+ data[i] = kv_self.cells [i].delta ;
2721
+ }
2722
+ }
2723
+
2701
2724
for (int il = 0 ; il < n_layer; ++il) {
2702
2725
ggml_format_name (inpL, " layer_inp_%d" , il);
2703
2726
@@ -2723,6 +2746,17 @@ static struct ggml_cgraph * llm_build_llama(
2723
2746
ggml_set_name (cur, " attention_norm_0" );
2724
2747
}
2725
2748
2749
+ if (do_rope_shift) {
2750
+ ggml_build_forward_expand (gf,
2751
+ ggml_rope_custom_inplace (ctx0,
2752
+ ggml_view_3d (ctx0, kv_self.k ,
2753
+ n_embd_head, n_head_kv, n_ctx,
2754
+ ggml_element_size (kv_self.k )*n_embd_head,
2755
+ ggml_element_size (kv_self.k )*n_embd_gqa,
2756
+ ggml_element_size (kv_self.k )*n_embd_gqa*n_ctx*il),
2757
+ K_shift, n_embd_head, 0 , 0 , freq_base, freq_scale));
2758
+ }
2759
+
2726
2760
// self-attention
2727
2761
{
2728
2762
// compute Q and K and RoPE them
@@ -4033,7 +4067,8 @@ static bool llama_eval_internal(
4033
4067
#endif
4034
4068
4035
4069
// update the kv ring buffer
4036
- lctx.kv_self .head += n_tokens;
4070
+ lctx.kv_self .head += n_tokens;
4071
+ lctx.kv_self .has_shift = false ;
4037
4072
4038
4073
#ifdef GGML_PERF
4039
4074
// print timing information per ggml operation (for debugging purposes)
@@ -6562,10 +6597,6 @@ struct llama_context * llama_new_context_with_model(
6562
6597
return nullptr ;
6563
6598
}
6564
6599
6565
- if (model->arch == LLM_ARCH_LLAMA) {
6566
- ctx->kv_self .is_roped = true ;
6567
- }
6568
-
6569
6600
{
6570
6601
const size_t memory_size = ggml_nbytes (ctx->kv_self .k ) + ggml_nbytes (ctx->kv_self .v );
6571
6602
LLAMA_LOG_INFO (" %s: kv self size = %7.2f MB\n " , __func__, memory_size / 1024.0 / 1024.0 );
@@ -6803,16 +6834,16 @@ void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1
6803
6834
llama_kv_cache_rm_tokens (ctx->kv_self , c0, c1);
6804
6835
}
6805
6836
6806
- void llama_kv_cache_rm_seq (struct llama_context * ctx, llama_seq_id seq_id) {
6807
- llama_kv_cache_rm_seq (ctx->kv_self , seq_id);
6837
+ void llama_kv_cache_rm_seq (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1 ) {
6838
+ llama_kv_cache_rm_seq (ctx->kv_self , seq_id, p0, p1 );
6808
6839
}
6809
6840
6810
6841
void llama_kv_cache_keep_seq (struct llama_context * ctx, llama_seq_id seq_id) {
6811
6842
llama_kv_cache_keep_seq (ctx->kv_self , seq_id);
6812
6843
}
6813
6844
6814
- void llama_kv_cache_shift (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6815
- llama_kv_cache_shift (* ctx, seq_id, p0, p1, delta);
6845
+ void llama_kv_cache_shift_seq (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6846
+ llama_kv_cache_shift_seq ( ctx-> kv_self , seq_id, p0, p1, delta);
6816
6847
}
6817
6848
6818
6849
// Returns the *maximum* size of the state
0 commit comments