Commit ac2219f

llama : fix session saving/loading (#3400)
* llama : fix session saving/loading
* llama : temp fix for clearing "future" tokens from the KV cache
* llama : fix handling of "future" tokens when loading sessions
* llama : fix comments for llama_kv_cache API
1 parent 48be797 commit ac2219f
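For orientation, here is a minimal sketch of the save side of the API this commit touches. It is illustrative only (the file path and the token bookkeeping are placeholders, not code from this commit): llama_save_session_file() serializes the context state, and with this change that state additionally carries each KV cell's position and sequence ids, which is what bumps the session version to 2.

// Sketch: persist a prompt cache / session after evaluating some tokens.
// Fragment of an example-style program; assumes `ctx` is a valid
// llama_context and `session_tokens` holds the tokens evaluated so far.
// "prompt-cache.bin" is a placeholder path.
std::vector<llama_token> session_tokens; // filled during evaluation
if (!llama_save_session_file(ctx, "prompt-cache.bin",
                             session_tokens.data(), session_tokens.size())) {
    fprintf(stderr, "warning: failed to save session file\n");
}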

File tree

7 files changed, +106 -59 lines changed


examples/chat-persistent.sh

+4-4
@@ -9,7 +9,7 @@ if [[ -z "${PROMPT_CACHE_FILE+x}" || -z "${CHAT_SAVE_DIR+x}" ]]; then
     exit 1
 fi
 
-MODEL="${MODEL:-./models/13B/ggml-model-q4_0.bin}"
+MODEL="${MODEL:-./models/llama-13b/ggml-model-q4_0.gguf}"
 PROMPT_TEMPLATE="${PROMPT_TEMPLATE:-./prompts/chat.txt}"
 USER_NAME="${USER_NAME:-User}"
 AI_NAME="${AI_NAME:-ChatLLaMa}"
@@ -61,9 +61,9 @@ fi
 
 if [[ ! -e "$PROMPT_CACHE_FILE" ]]; then
     echo 'Prompt cache does not exist, building...'
-    # Default batch_size to 8 here for better user feedback during initial prompt processing
+    # Default batch_size to 64 here for better user feedback during initial prompt processing
     ./main 2>>"$LOG" \
-        --batch_size 8 \
+        --batch_size 64 \
         "${OPTS[@]}" \
         --prompt-cache "$PROMPT_CACHE_FILE" \
         --file "$CUR_PROMPT_FILE" \
@@ -132,7 +132,7 @@ while read -e line; do
     # HACK get num tokens from debug message
     # TODO get both messages in one go
     if ! session_size_msg="$(tail -n30 "$LOG" | grep -oE "$SESSION_SIZE_MSG_PATTERN")" ||
-       ! sample_time_msg="$( tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then
+       ! sample_time_msg="$(tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then
        echo >&2 "Couldn't get number of tokens from ./main output!"
        exit 1
    fi

examples/main/main.cpp

+3
@@ -543,6 +543,9 @@ int main(int argc, char ** argv) {
             if (i > 0) {
                 embd.erase(embd.begin(), embd.begin() + i);
             }
+
+            // remove any "future" tokens that we might have inherited from the session from the KV cache
+            llama_kv_cache_tokens_rm(ctx, n_past, -1);
         }
 
         // evaluate tokens in batches
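The new call matters when a session (prompt cache) is reloaded: only the prefix of the cached tokens that matches the current prompt should be reused, and anything the session evaluated beyond that point must not linger in the KV cache. A rough sketch of the surrounding flow in the main example (variable names such as embd_inp and session_tokens mirror the example's conventions, but the snippet is illustrative, not the literal code):

// Sketch: after llama_load_session_file(), reuse only the matching prefix.
int n_past = 0;
while (n_past < (int) embd_inp.size() && n_past < (int) session_tokens.size() &&
       embd_inp[n_past] == session_tokens[n_past]) {
    n_past++;
}

// Drop the "future" cells the session left behind; in this single-sequence
// example the cell index matches the token position, so everything at
// index >= n_past is cleared (c1 = -1 means "to the end of the cache").
llama_kv_cache_tokens_rm(ctx, n_past, -1);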

examples/parallel/parallel.cpp

+1-1
@@ -332,7 +332,7 @@ int main(int argc, char ** argv) {
                 }
 
                 // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx);
+                llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
 
                 const auto t_main_end = ggml_time_us();
 

examples/server/server.cpp

+1-1
@@ -448,7 +448,7 @@ struct llama_server_context
         n_past = common_part(embd, prompt_tokens);
 
         // since #3228 we now have to manually manage the KV cache
-        llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
+        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
 
         embd = prompt_tokens;
         if (n_past == num_prompt_tokens)

examples/speculative/speculative.cpp

+3-3
@@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
             LOG("out of drafted tokens\n");
         }
 
-        llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
+        llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
         llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
         ++n_past_dft;
 
@@ -257,7 +257,7 @@ int main(int argc, char ** argv) {
             }
 
             // evaluate the drafted token on the draft model
-            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
+            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1);
             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
             ++n_past_cur;
 
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
         }
 
         // evaluate the target model on the drafted tokens
-        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
+        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
         ++n_past_tgt;
 
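The parallel, server and speculative examples all make the same substitution: the upper bound of the removal range is now the open-ended -1 sentinel instead of n_ctx, so the call clears everything a sequence produced after a given position without depending on the configured context size. Roughly (the seq_id and n_keep names below are placeholders, not code from this commit):

// Before: the caller had to supply a "large enough" upper bound, typically n_ctx.
//   llama_kv_cache_seq_rm(ctx, seq_id, n_keep, n_ctx);
// After: -1 is clamped internally to the maximum position, i.e. "to infinity".
llama_kv_cache_seq_rm(ctx, seq_id, n_keep, -1);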

llama.cpp

+85-49
@@ -1283,8 +1283,8 @@ static bool llama_kv_cache_init(
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
 static bool llama_kv_cache_find_slot(
-        struct llama_kv_cache & cache,
-        const struct llama_batch & batch) {
+           struct llama_kv_cache & cache,
+        const struct llama_batch & batch) {
     const uint32_t n_ctx    = cache.size;
     const uint32_t n_tokens = batch.n_tokens;
 
@@ -1352,10 +1352,13 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
 }
 
 static void llama_kv_cache_seq_rm(
-        struct llama_kv_cache & cache,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id,
+                    llama_pos   p0,
+                    llama_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.erase(seq_id);
@@ -1367,11 +1370,14 @@ static void llama_kv_cache_seq_rm(
 }
 
 static void llama_kv_cache_seq_cp(
-        struct llama_kv_cache & cache,
-        llama_seq_id seq_id_src,
-        llama_seq_id seq_id_dst,
-        llama_pos p0,
-        llama_pos p1) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id_src,
+                 llama_seq_id   seq_id_dst,
+                    llama_pos   p0,
+                    llama_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.insert(seq_id_dst);
@@ -1389,11 +1395,14 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
 }
 
 static void llama_kv_cache_seq_shift(
-        struct llama_kv_cache & cache,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1,
-        llama_pos delta) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id,
+                    llama_pos   p0,
+                    llama_pos   p1,
+                    llama_pos   delta) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].pos += delta;
@@ -7209,16 +7218,6 @@ struct llama_data_file_context : llama_data_context {
  *
 */
 static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
-    // TODO: does not support multi-sequence states
-    {
-        const auto & kv_self = ctx->kv_self;
-        for (uint32_t i = 0; i < kv_self.head; ++i) {
-            GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i);
-            GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1);
-            GGML_ASSERT(kv_self.cells[i].has_seq_id(0));
-        }
-    }
-
     // copy rng
     {
         std::stringstream rng_ss;
@@ -7271,36 +7270,38 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         const auto & hparams = ctx->model.hparams;
         const auto & cparams = ctx->cparams;
 
-        const int    n_layer = hparams.n_layer;
-        const int    n_embd  = hparams.n_embd_gqa();
-        const int    n_ctx   = cparams.n_ctx;
+        const auto   n_layer = hparams.n_layer;
+        const auto   n_embd  = hparams.n_embd_gqa();
+        const auto   n_ctx   = cparams.n_ctx;
 
-        const size_t kv_size = kv_self.buf.size;
-        const int    kv_ntok = kv_self.head;
+        const size_t   kv_buf_size = kv_self.buf.size;
+        const uint32_t kv_head     = kv_self.head;
+        const uint32_t kv_size     = kv_self.size;
 
-        data_ctx->write(&kv_size, sizeof(kv_size));
-        data_ctx->write(&kv_ntok, sizeof(kv_ntok));
+        data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
+        data_ctx->write(&kv_head,     sizeof(kv_head));
+        data_ctx->write(&kv_size,     sizeof(kv_size));
 
-        if (kv_size) {
+        if (kv_buf_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
 
             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};
 
-            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
             std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
             kout3d->data = kout3d_data.data();
 
-            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
             std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
             vout3d->data = vout3d_data.data();
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
+                n_embd, kv_head, n_layer,
                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
+                kv_head, n_embd, n_layer,
                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
@@ -7314,6 +7315,20 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             data_ctx->write(kout3d_data.data(), kout3d_data.size());
             data_ctx->write(vout3d_data.data(), vout3d_data.size());
         }
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            const auto & cell = kv_self.cells[i];
+
+            const llama_pos pos         = cell.pos;
+            const size_t    seq_id_size = cell.seq_id.size();
+
+            data_ctx->write(&pos,         sizeof(pos));
+            data_ctx->write(&seq_id_size, sizeof(seq_id_size));
+
+            for (auto seq_id : cell.seq_id) {
+                data_ctx->write(&seq_id, sizeof(seq_id));
+            }
+        }
     }
 }
 
@@ -7385,34 +7400,36 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         const int    n_embd  = hparams.n_embd_gqa();
         const int    n_ctx   = cparams.n_ctx;
 
-        size_t kv_size;
-        int    kv_ntok;
+        size_t   kv_buf_size;
+        uint32_t kv_head;
+        uint32_t kv_size;
 
-        memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
-        memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok);
+        memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
+        memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
+        memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
 
-        if (kv_size) {
-            GGML_ASSERT(kv_self.buf.size == kv_size);
+        if (kv_buf_size) {
+            GGML_ASSERT(kv_self.buf.size == kv_buf_size);
 
             const size_t elt_size = ggml_element_size(kv_self.k);
 
             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};
 
-            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
             kin3d->data = (void *) inp;
             inp += ggml_nbytes(kin3d);
 
-            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
             vin3d->data = (void *) inp;
             inp += ggml_nbytes(vin3d);
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
+                n_embd, kv_head, n_layer,
                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
+                kv_head, n_embd, n_layer,
                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
@@ -7422,8 +7439,27 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
             ggml_free(cpy_ctx);
         }
 
-        ctx->kv_self.head = kv_ntok;
+        ctx->kv_self.head = kv_head;
         ctx->kv_self.size = kv_size;
+
+        ctx->kv_self.cells.resize(kv_size);
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            llama_pos pos;
+            size_t    seq_id_size;
+
+            memcpy(&pos,         inp, sizeof(pos));         inp += sizeof(pos);
+            memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);
+
+            ctx->kv_self.cells[i].pos = pos;
+
+            llama_seq_id seq_id;
+
+            for (size_t j = 0; j < seq_id_size; ++j) {
+                memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
+                ctx->kv_self.cells[i].seq_id.insert(seq_id);
+            }
+        }
     }
 
     const size_t nread = inp - src;
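Taken together, the llama.cpp changes define the version-2 layout of the KV cache section in the serialized session state. Summarized below as an illustrative, struct-like view of the write order (this is not a declaration that exists in the code base; it only restates the writes performed above):

// Session v2 KV-cache section, in write order:
//
//   size_t   kv_buf_size;          // bytes of K/V tensor data that follow (may be 0)
//   uint32_t kv_head;              // number of used cells at the front of the cache
//   uint32_t kv_size;              // total number of cells
//   <K data>                       // kv_head columns per layer, present if kv_buf_size != 0
//   <V data>                       // kv_head rows per layer,    present if kv_buf_size != 0
//   repeated kv_size times:
//     llama_pos    pos;            // position stored in the cell
//     size_t       n_seq_id;       // number of sequence ids in the cell
//     llama_seq_id seq_id[n_seq_id];
//
// Version 1 stored only the buffer size and the head ("kv_ntok") and assumed a
// single sequence with pos == cell index, which is why the asserts at the top
// of llama_copy_state_data_internal could be dropped.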

llama.h

+9-1
@@ -42,7 +42,7 @@
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 1
+#define LLAMA_SESSION_VERSION 2
 
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
@@ -333,12 +333,16 @@ extern "C" {
             "avoid using this, it will be removed in the future, instead - count the tokens in user code");
 
     // Remove all tokens data of cells in [c0, c1)
+    // c0 < 0 : [0,  c1]
+    // c1 < 0 : [c0, inf)
     LLAMA_API void llama_kv_cache_tokens_rm(
             struct llama_context * ctx,
                          int32_t   c0,
                          int32_t   c1);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_rm(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
@@ -347,6 +351,8 @@ extern "C" {
 
     // Copy all tokens that belong to the specified sequence to another sequence
     // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_cp(
             struct llama_context * ctx,
                     llama_seq_id   seq_id_src,
@@ -361,6 +367,8 @@ extern "C" {
 
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_shift(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
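With the negative-bound convention now documented, the llama_kv_cache_* calls compose naturally for open-ended operations such as a context shift. A brief sketch under stated assumptions (seq_id, n_keep and n_discard are illustrative values, not part of this commit):

// Sketch: discard a chunk after the kept prefix, then slide the rest left.
// p1 = -1 makes the shift apply from p0 all the way to the end of the sequence.
const llama_seq_id seq_id    = 0;
const llama_pos    n_keep    = 64;   // prefix to preserve
const llama_pos    n_discard = 128;  // tokens to drop after the prefix

llama_kv_cache_seq_rm   (ctx, seq_id, n_keep, n_keep + n_discard);
llama_kv_cache_seq_shift(ctx, seq_id, n_keep + n_discard, -1, -n_discard);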
