@@ -1283,8 +1283,8 @@ static bool llama_kv_cache_init(
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
 static bool llama_kv_cache_find_slot(
-           struct llama_kv_cache & cache,
-        const struct llama_batch & batch) {
+           struct llama_kv_cache & cache,
+        const struct llama_batch & batch) {
     const uint32_t n_ctx    = cache.size;
     const uint32_t n_tokens = batch.n_tokens;

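For orientation: `llama_kv_cache_find_slot` looks for a run of `n_tokens` consecutive free cells and updates `cache.head` to point at it. The body is not part of this diff; the following is a minimal sketch of that kind of search under assumed simplifications (free cells marked by `pos == -1`, no wrap-around handling or logging), not the actual function:

```cpp
#include <cstdint>
#include <vector>

using llama_pos = int32_t; // as typedef'd in llama.h

// Sketch: find n_tokens consecutive free cells (pos == -1) and set head to
// the start of the run. Returns false when no such run exists.
static bool kv_find_slot_sketch(const std::vector<llama_pos> & cell_pos,
                                uint32_t & head, uint32_t n_tokens) {
    const uint32_t n_ctx = (uint32_t) cell_pos.size();
    if (n_tokens > n_ctx) {
        return false;
    }
    for (uint32_t start = 0; start + n_tokens <= n_ctx; ++start) {
        bool free_run = true;
        for (uint32_t i = 0; i < n_tokens; ++i) {
            if (cell_pos[start + i] >= 0) { free_run = false; break; }
        }
        if (free_run) {
            head = start; // the batch will occupy cells [head, head + n_tokens)
            return true;
        }
    }
    return false;
}
```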
@@ -1352,10 +1352,13 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
 }

 static void llama_kv_cache_seq_rm(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id,
+                    llama_pos   p0,
+                    llama_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.erase(seq_id);
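The added clamping gives negative positions a sentinel meaning: `p0 < 0` reads as "from the start" and `p1 < 0` as "to the end", so a caller can clear a whole sequence without tracking its length. Hypothetical call sites:

```cpp
// Drop everything stored for sequence 7, regardless of position:
// p0 < 0 clamps to 0, p1 < 0 clamps to numeric_limits<llama_pos>::max().
llama_kv_cache_seq_rm(cache, /*seq_id=*/7, /*p0=*/-1, /*p1=*/-1);

// Drop only the cells of sequence 7 holding positions [32, 64):
llama_kv_cache_seq_rm(cache, 7, 32, 64);
```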
@@ -1367,11 +1370,14 @@ static void llama_kv_cache_seq_rm(
 }

 static void llama_kv_cache_seq_cp(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id_src,
-                 llama_seq_id   seq_id_dst,
-                    llama_pos   p0,
-                    llama_pos   p1) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id_src,
+                 llama_seq_id   seq_id_dst,
+                    llama_pos   p0,
+                    llama_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.insert(seq_id_dst);
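Note that `llama_kv_cache_seq_cp` only inserts `seq_id_dst` into the matching cells' id sets, so the underlying K/V data is shared rather than duplicated. A hypothetical use is fanning a shared prompt out to several sequences for parallel decoding:

```cpp
// After evaluating the prompt as sequence 0, make the same cells visible to
// sequences 1..3 as well; the negative sentinels cover the full range.
for (llama_seq_id dst = 1; dst <= 3; ++dst) {
    llama_kv_cache_seq_cp(cache, /*seq_id_src=*/0, dst, -1, -1);
}
```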
@@ -1389,11 +1395,14 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
 }

 static void llama_kv_cache_seq_shift(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                    llama_pos   delta) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id,
+                    llama_pos   p0,
+                    llama_pos   p1,
+                    llama_pos   delta) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].pos += delta;
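Shifting positions is what enables context shifting without re-evaluating the whole context: remove a span of old tokens, then slide the remainder back by the same amount. A hypothetical pattern, with `n_keep` and `n_discard` chosen by the caller:

```cpp
// Keep the first n_keep tokens of sequence 0, discard the next n_discard,
// and shift what remains so positions stay contiguous.
llama_kv_cache_seq_rm   (cache, 0, n_keep,             n_keep + n_discard);
llama_kv_cache_seq_shift(cache, 0, n_keep + n_discard, -1, -n_discard);
```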
@@ -7209,16 +7218,6 @@ struct llama_data_file_context : llama_data_context {
 *
 */
 static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
-    // TODO: does not support multi-sequence states
-    {
-        const auto & kv_self = ctx->kv_self;
-        for (uint32_t i = 0; i < kv_self.head; ++i) {
-            GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i);
-            GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1);
-            GGML_ASSERT(kv_self.cells[i].has_seq_id(0));
-        }
-    }
-
     // copy rng
     {
         std::stringstream rng_ss;
@@ -7271,36 +7270,38 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         const auto & hparams = ctx->model.hparams;
         const auto & cparams = ctx->cparams;

-        const int    n_layer = hparams.n_layer;
-        const int    n_embd  = hparams.n_embd_gqa();
-        const int    n_ctx   = cparams.n_ctx;
+        const auto   n_layer = hparams.n_layer;
+        const auto   n_embd  = hparams.n_embd_gqa();
+        const auto   n_ctx   = cparams.n_ctx;

-        const size_t kv_size = kv_self.buf.size;
-        const int    kv_ntok = kv_self.head;
+        const size_t   kv_buf_size = kv_self.buf.size;
+        const uint32_t kv_head     = kv_self.head;
+        const uint32_t kv_size     = kv_self.size;

-        data_ctx->write(&kv_size, sizeof(kv_size));
-        data_ctx->write(&kv_ntok, sizeof(kv_ntok));
+        data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
+        data_ctx->write(&kv_head,     sizeof(kv_head));
+        data_ctx->write(&kv_size,     sizeof(kv_size));

-        if (kv_size) {
+        if (kv_buf_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);

             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};

-            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
             std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
             kout3d->data = kout3d_data.data();

-            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
             std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
             vout3d->data = vout3d_data.data();

             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
+                n_embd, kv_head, n_layer,
                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0);

             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
+                kv_head, n_embd, n_layer,
                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);

             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
@@ -7314,6 +7315,20 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             data_ctx->write(kout3d_data.data(), kout3d_data.size());
             data_ctx->write(vout3d_data.data(), vout3d_data.size());
         }
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            const auto & cell = kv_self.cells[i];
+
+            const llama_pos pos         = cell.pos;
+            const size_t    seq_id_size = cell.seq_id.size();
+
+            data_ctx->write(&pos,         sizeof(pos));
+            data_ctx->write(&seq_id_size, sizeof(seq_id_size));
+
+            for (auto seq_id : cell.seq_id) {
+                data_ctx->write(&seq_id, sizeof(seq_id));
+            }
+        }
     }
 }

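Put together, the KV section of the state stream now has the following layout (a sketch implied by the writes above; field widths follow the declared types):

```cpp
// size_t   kv_buf_size;  // bytes of K/V tensor data that follow (may be 0)
// uint32_t kv_head;      // index of the first unused cell
// uint32_t kv_size;      // total number of cells
// uint8_t  k_data[];     // n_embd x kv_head x n_layer K values, if kv_buf_size > 0
// uint8_t  v_data[];     // kv_head x n_embd x n_layer V values, if kv_buf_size > 0
// then, repeated kv_size times (one record per cell):
//   llama_pos    pos;                  // position held by the cell
//   size_t       seq_id_size;          // number of sequence ids in the cell
//   llama_seq_id seq_id[seq_id_size];  // the ids themselves
```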
@@ -7385,34 +7400,36 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         const int n_embd  = hparams.n_embd_gqa();
         const int n_ctx   = cparams.n_ctx;

-        size_t kv_size;
-        int    kv_ntok;
+        size_t   kv_buf_size;
+        uint32_t kv_head;
+        uint32_t kv_size;

-        memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
-        memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok);
+        memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
+        memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
+        memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);

-        if (kv_size) {
-            GGML_ASSERT(kv_self.buf.size == kv_size);
+        if (kv_buf_size) {
+            GGML_ASSERT(kv_self.buf.size == kv_buf_size);

             const size_t elt_size = ggml_element_size(kv_self.k);

             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};

-            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
             kin3d->data = (void *) inp;
             inp += ggml_nbytes(kin3d);

-            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
             vin3d->data = (void *) inp;
             inp += ggml_nbytes(vin3d);

             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
+                n_embd, kv_head, n_layer,
                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0);

             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
+                kv_head, n_embd, n_layer,
                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);

             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
@@ -7422,8 +7439,27 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
             ggml_free(cpy_ctx);
         }

-        ctx->kv_self.head = kv_ntok;
+        ctx->kv_self.head = kv_head;
         ctx->kv_self.size = kv_size;
+
+        ctx->kv_self.cells.resize(kv_size);
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            llama_pos pos;
+            size_t    seq_id_size;
+
+            memcpy(&pos,         inp, sizeof(pos));         inp += sizeof(pos);
+            memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);
+
+            ctx->kv_self.cells[i].pos = pos;
+
+            llama_seq_id seq_id;
+
+            for (size_t j = 0; j < seq_id_size; ++j) {
+                memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
+                ctx->kv_self.cells[i].seq_id.insert(seq_id);
+            }
+        }
     }

     const size_t nread = inp - src;
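With both sides in sync, a multi-sequence cache now survives a save/load round trip through the public API. A minimal sketch, assuming `ctx` and `ctx2` were created from the same model with identical context parameters:

```cpp
#include <cstdint>
#include <vector>
#include "llama.h"

void roundtrip(llama_context * ctx, llama_context * ctx2) {
    // Snapshot rng, logits, embeddings and the KV cache, including the
    // per-cell pos/seq_id records added by this change.
    std::vector<uint8_t> state(llama_get_state_size(ctx));
    llama_copy_state_data(ctx, state.data());

    // Restore into the second context; reads back exactly what was written.
    llama_set_state_data(ctx2, state.data());
}
```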