llama : fix session saving/loading (ggerganov#3400)
* llama : fix session saving/loading

* llama : temp fix for clearing "future" tokens from the KV cache

* llama : fix handling of "future" tokens when loading sessions

* llama : fix comments for llama_kv_cache API
ggerganov authored and yusiwen committed Oct 7, 2023
1 parent 6487ea7 commit dfde801
Showing 7 changed files with 106 additions and 59 deletions.
8 changes: 4 additions & 4 deletions examples/chat-persistent.sh
@@ -9,7 +9,7 @@ if [[ -z "${PROMPT_CACHE_FILE+x}" || -z "${CHAT_SAVE_DIR+x}" ]]; then
     exit 1
 fi
 
-MODEL="${MODEL:-./models/13B/ggml-model-q4_0.bin}"
+MODEL="${MODEL:-./models/llama-13b/ggml-model-q4_0.gguf}"
 PROMPT_TEMPLATE="${PROMPT_TEMPLATE:-./prompts/chat.txt}"
 USER_NAME="${USER_NAME:-User}"
 AI_NAME="${AI_NAME:-ChatLLaMa}"
@@ -61,9 +61,9 @@ fi
 
 if [[ ! -e "$PROMPT_CACHE_FILE" ]]; then
     echo 'Prompt cache does not exist, building...'
-    # Default batch_size to 8 here for better user feedback during initial prompt processing
+    # Default batch_size to 64 here for better user feedback during initial prompt processing
     ./main 2>>"$LOG" \
-        --batch_size 8 \
+        --batch_size 64 \
         "${OPTS[@]}" \
         --prompt-cache "$PROMPT_CACHE_FILE" \
         --file "$CUR_PROMPT_FILE" \
@@ -132,7 +132,7 @@ while read -e line; do
     # HACK get num tokens from debug message
     # TODO get both messages in one go
     if ! session_size_msg="$(tail -n30 "$LOG" | grep -oE "$SESSION_SIZE_MSG_PATTERN")" ||
-       ! sample_time_msg="$( tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then
+       ! sample_time_msg="$(tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then
         echo >&2 "Couldn't get number of tokens from ./main output!"
         exit 1
     fi
3 changes: 3 additions & 0 deletions examples/main/main.cpp
@@ -543,6 +543,9 @@ int main(int argc, char ** argv) {
             if (i > 0) {
                 embd.erase(embd.begin(), embd.begin() + i);
             }
+
+            // remove any "future" tokens that we might have inherited from the session from the KV cache
+            llama_kv_cache_tokens_rm(ctx, n_past, -1);
         }
 
         // evaluate tokens in batches
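The new llama_kv_cache_tokens_rm call matters when a prompt cache (session) is restored: the session may hold more evaluated tokens than the portion of the new prompt that actually matches it. A minimal sketch of the pattern, where the helper name reuse_session_prefix and the variable names session_tokens / embd_inp are hypothetical stand-ins for the ones used in main.cpp:

```cpp
#include <vector>
#include "llama.h"

// Sketch: reuse the longest prefix of a restored session that matches the new
// prompt, then evict the leftover "future" tokens from the KV cache so the
// next llama_decode() starts from a consistent state.
static int reuse_session_prefix(llama_context * ctx,
                                const std::vector<llama_token> & session_tokens,
                                const std::vector<llama_token> & embd_inp) {
    size_t n_matching = 0;
    while (n_matching < session_tokens.size() &&
           n_matching < embd_inp.size() &&
           session_tokens[n_matching] == embd_inp[n_matching]) {
        n_matching++;
    }

    const int n_past = (int) n_matching;

    // cells at positions [n_past, inf) were inherited from the session but are
    // not part of the current evaluation, so drop them
    llama_kv_cache_tokens_rm(ctx, n_past, -1);

    return n_past;
}
```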
2 changes: 1 addition & 1 deletion examples/parallel/parallel.cpp
@@ -332,7 +332,7 @@ int main(int argc, char ** argv) {
         }
 
         // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-        llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx);
+        llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
 
         const auto t_main_end = ggml_time_us();
2 changes: 1 addition & 1 deletion examples/server/server.cpp
@@ -448,7 +448,7 @@ struct llama_server_context
         n_past = common_part(embd, prompt_tokens);
 
         // since #3228 we now have to manually manage the KV cache
-        llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
+        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
 
         embd = prompt_tokens;
         if (n_past == num_prompt_tokens)
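Here common_part() returns the length of the shared prefix between the cached tokens and the new prompt; everything after it has to be re-evaluated. A sketch of such a helper (the actual function in server.cpp may differ in detail):

```cpp
#include <vector>
#include "llama.h"

// Length of the common prefix of two token sequences; cells past this point
// are stale and get removed from the KV cache.
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}
```

Passing -1 as p1 removes every cell of sequence 0 from position n_past onward, without assuming that all cached positions are smaller than params.n_ctx.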
6 changes: 3 additions & 3 deletions examples/speculative/speculative.cpp
@@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
             LOG("out of drafted tokens\n");
         }
 
-        llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
+        llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
         llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
         ++n_past_dft;
 
@@ -257,7 +257,7 @@ int main(int argc, char ** argv) {
             }
 
             // evaluate the drafted token on the draft model
-            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
+            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1);
             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
             ++n_past_cur;
 
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
         }
 
         // evaluate the target model on the drafted tokens
-        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
+        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
        llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
         ++n_past_tgt;
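All three call sites follow the same rollback idiom: discard everything a context decoded past the last agreed position, then continue from there. Under the negative-bound convention introduced in llama.cpp below, the two forms compare as follows (illustrative only):

```cpp
// old: upper bound tied to the context size, removes only positions in [n_past_dft, n_ctx)
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);

// new: p1 < 0 means "to infinity", removes all positions >= n_past_dft
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
```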
134 changes: 85 additions & 49 deletions llama.cpp
@@ -1283,8 +1283,8 @@ static bool llama_kv_cache_init(
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
 static bool llama_kv_cache_find_slot(
-        struct llama_kv_cache & cache,
-        const struct llama_batch & batch) {
+           struct llama_kv_cache & cache,
+        const struct llama_batch & batch) {
     const uint32_t n_ctx = cache.size;
     const uint32_t n_tokens = batch.n_tokens;
 
@@ -1352,10 +1352,13 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
 }
 
 static void llama_kv_cache_seq_rm(
-        struct llama_kv_cache & cache,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id,
+                    llama_pos   p0,
+                    llama_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.erase(seq_id);
@@ -1367,11 +1370,14 @@ static void llama_kv_cache_seq_rm(
 }
 
 static void llama_kv_cache_seq_cp(
-        struct llama_kv_cache & cache,
-        llama_seq_id seq_id_src,
-        llama_seq_id seq_id_dst,
-        llama_pos p0,
-        llama_pos p1) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id_src,
+                 llama_seq_id   seq_id_dst,
+                    llama_pos   p0,
+                    llama_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.insert(seq_id_dst);
@@ -1389,11 +1395,14 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
 }
 
 static void llama_kv_cache_seq_shift(
-        struct llama_kv_cache & cache,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1,
-        llama_pos delta) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id,
+                    llama_pos   p0,
+                    llama_pos   p1,
+                    llama_pos   delta) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].pos += delta;
@@ -7209,16 +7218,6 @@ struct llama_data_file_context : llama_data_context {
  *
 */
 static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
-    // TODO: does not support multi-sequence states
-    {
-        const auto & kv_self = ctx->kv_self;
-        for (uint32_t i = 0; i < kv_self.head; ++i) {
-            GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i);
-            GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1);
-            GGML_ASSERT(kv_self.cells[i].has_seq_id(0));
-        }
-    }
-
     // copy rng
     {
         std::stringstream rng_ss;
@@ -7271,36 +7270,38 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
         const auto & hparams = ctx->model.hparams;
         const auto & cparams = ctx->cparams;
 
-        const int n_layer = hparams.n_layer;
-        const int n_embd = hparams.n_embd_gqa();
-        const int n_ctx = cparams.n_ctx;
+        const auto n_layer = hparams.n_layer;
+        const auto n_embd = hparams.n_embd_gqa();
+        const auto n_ctx = cparams.n_ctx;
 
-        const size_t kv_size = kv_self.buf.size;
-        const int kv_ntok = kv_self.head;
+        const size_t kv_buf_size = kv_self.buf.size;
+        const uint32_t kv_head = kv_self.head;
+        const uint32_t kv_size = kv_self.size;
 
-        data_ctx->write(&kv_size, sizeof(kv_size));
-        data_ctx->write(&kv_ntok, sizeof(kv_ntok));
+        data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
+        data_ctx->write(&kv_head, sizeof(kv_head));
+        data_ctx->write(&kv_size, sizeof(kv_size));
 
-        if (kv_size) {
+        if (kv_buf_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
 
             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};
 
-            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
             std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
             kout3d->data = kout3d_data.data();
 
-            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
             std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
             vout3d->data = vout3d_data.data();
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
+                n_embd, kv_head, n_layer,
                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
+                kv_head, n_embd, n_layer,
                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
@@ -7314,6 +7315,20 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
             data_ctx->write(kout3d_data.data(), kout3d_data.size());
             data_ctx->write(vout3d_data.data(), vout3d_data.size());
         }
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            const auto & cell = kv_self.cells[i];
+
+            const llama_pos pos = cell.pos;
+            const size_t seq_id_size = cell.seq_id.size();
+
+            data_ctx->write(&pos, sizeof(pos));
+            data_ctx->write(&seq_id_size, sizeof(seq_id_size));
+
+            for (auto seq_id : cell.seq_id) {
+                data_ctx->write(&seq_id, sizeof(seq_id));
+            }
+        }
     }
 }
 
@@ -7385,34 +7400,36 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         const int n_embd = hparams.n_embd_gqa();
         const int n_ctx = cparams.n_ctx;
 
-        size_t kv_size;
-        int kv_ntok;
+        size_t kv_buf_size;
+        uint32_t kv_head;
+        uint32_t kv_size;
 
-        memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
-        memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok);
+        memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
+        memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
+        memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
 
-        if (kv_size) {
-            GGML_ASSERT(kv_self.buf.size == kv_size);
+        if (kv_buf_size) {
+            GGML_ASSERT(kv_self.buf.size == kv_buf_size);
 
             const size_t elt_size = ggml_element_size(kv_self.k);
 
             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};
 
-            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
             kin3d->data = (void *) inp;
             inp += ggml_nbytes(kin3d);
 
-            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
             vin3d->data = (void *) inp;
             inp += ggml_nbytes(vin3d);
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
+                n_embd, kv_head, n_layer,
                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
+                kv_head, n_embd, n_layer,
                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
@@ -7422,8 +7439,27 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
             ggml_free(cpy_ctx);
         }
 
-        ctx->kv_self.head = kv_ntok;
+        ctx->kv_self.head = kv_head;
+        ctx->kv_self.size = kv_size;
 
+        ctx->kv_self.cells.resize(kv_size);
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            llama_pos pos;
+            size_t seq_id_size;
+
+            memcpy(&pos, inp, sizeof(pos)); inp += sizeof(pos);
+            memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);
+
+            ctx->kv_self.cells[i].pos = pos;
+
+            llama_seq_id seq_id;
+
+            for (size_t j = 0; j < seq_id_size; ++j) {
+                memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
+                ctx->kv_self.cells[i].seq_id.insert(seq_id);
+            }
+        }
     }
 
     const size_t nread = inp - src;
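Taken together, the save and load paths above define the KV-cache section of the version 2 session format. The following layout is a reading aid derived from the code above, not a normative spec:

```cpp
// KV-cache portion of the state blob, in write order:
//
//   size_t       kv_buf_size;          // bytes of K/V tensor data that follow
//   uint32_t     kv_head;              // index of the first unused cell
//   uint32_t     kv_size;              // total number of cells in the cache
//   uint8_t      k[...];               // K tensor data: n_embd x kv_head x n_layer
//   uint8_t      v[...];               // V tensor data: kv_head x n_embd x n_layer
//
//   // then, for each of the kv_size cells:
//   llama_pos    pos;                  // position stored in the cell
//   size_t       seq_id_size;          // number of sequences referencing the cell
//   llama_seq_id seq_id[seq_id_size];  // the sequence ids themselves
```

Because each cell's position and sequence ids are now serialized explicitly, the removed GGML_ASSERT block (which restricted sessions to a single contiguous sequence) is no longer needed, and LLAMA_SESSION_VERSION is bumped to 2 in llama.h below.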
10 changes: 9 additions & 1 deletion llama.h
@@ -42,7 +42,7 @@
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 
 #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 1
+#define LLAMA_SESSION_VERSION 2
 
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
@@ -333,12 +333,16 @@ extern "C" {
            "avoid using this, it will be removed in the future, instead - count the tokens in user code");
 
     // Remove all tokens data of cells in [c0, c1)
+    // c0 < 0 : [0, c1]
+    // c1 < 0 : [c0, inf)
     LLAMA_API void llama_kv_cache_tokens_rm(
             struct llama_context * ctx,
                          int32_t   c0,
                          int32_t   c1);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // p0 < 0 : [0, p1]
+    // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_rm(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
@@ -347,6 +351,8 @@ extern "C" {
 
     // Copy all tokens that belong to the specified sequence to another sequence
     // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
+    // p0 < 0 : [0, p1]
+    // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_cp(
             struct llama_context * ctx,
                     llama_seq_id   seq_id_src,
@@ -361,6 +367,8 @@ extern "C" {
 
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly
+    // p0 < 0 : [0, p1]
+    // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_shift(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
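A few illustrative calls showing the new negative-bound convention (sketches, assuming a valid llama_context * ctx):

```cpp
llama_kv_cache_seq_rm   (ctx, 0, 128,  -1);     // remove seq 0 cells at positions [128, inf)
llama_kv_cache_seq_rm   (ctx, 0,  -1, 128);     // remove seq 0 cells at positions [0, 128)
llama_kv_cache_seq_cp   (ctx, 0, 1,  -1,  -1);  // let seq 1 share all of seq 0's cells
llama_kv_cache_seq_shift(ctx, 0, 128, -1, -32); // shift seq 0 positions >= 128 left by 32
```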
