talk-llama: only copy used KV cache in get / set state #890

Merged
99 changes: 78 additions & 21 deletions examples/talk-llama/llama.cpp
@@ -1270,6 +1270,9 @@ static bool llama_eval_internal(
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

+// update kv token count
+lctx.model.kv_self.n = n_past + N;

// extract logits
{
auto & logits_out = lctx.logits;
@@ -2386,7 +2389,7 @@ void llama_set_rng_seed(struct llama_context * ctx, int seed) {
ctx->rng.seed(seed);
}

-// Returns the size of the state
+// Returns the *maximum* size of the state
size_t llama_get_state_size(struct llama_context * ctx) {
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
// for reference, std::mt19937(1337) serializes to 6701 bytes.
@@ -2465,21 +2468,51 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {

// copy kv cache
{
-const size_t kv_size = ctx->model.kv_self.buf.size;
+const auto & kv_self = ctx->model.kv_self;
+const auto & hparams = ctx->model.hparams;
+const int n_layer = hparams.n_layer;
+const int n_embd = hparams.n_embd;
+const int n_ctx = hparams.n_ctx;

+const size_t kv_size = kv_self.buf.size;
const int kv_ntok = llama_get_kv_cache_token_count(ctx);

memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);

if (kv_size) {
-memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+const size_t elt_size = ggml_element_size(kv_self.k);
+char buffer[4096];
+ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+ggml_cgraph gf{};
+gf.n_threads = 1;

+ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+kout3d->data = out;
+out += ggml_nbytes(kout3d);

+ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+vout3d->data = out;
+out += ggml_nbytes(vout3d);

+ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+    n_embd, kv_ntok, n_layer,
+    elt_size*n_embd, elt_size*n_embd*n_ctx, 0);

+ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+    kv_ntok, n_embd, n_layer,
+    elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);

+ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
+ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
+ggml_graph_compute(cpy_ctx, &gf);
}
}

const size_t written = out - dest;
-const size_t expected = llama_get_state_size(ctx);
+const size_t max_size = llama_get_state_size(ctx);

-LLAMA_ASSERT(written == expected);
+LLAMA_ASSERT(written <= max_size);

return written;
}
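A note on why the used KV slice is copied through a small one-off ggml graph instead of a flat memcpy: per layer, the V cache stores each embedding row across the full context window (the ggml_view_3d strides above are elt_size*n_ctx between rows and elt_size*n_ctx*n_embd between layers), so the first kv_ntok token slots are not one contiguous block of memory. Below is a conceptual sketch of the equivalent strided copy for a single layer of V; it is plain C++ for illustration only, not ggml, and the helper name copy_used_v is hypothetical.

// Conceptual sketch only: gather the first kv_ntok entries of every row of a
// row-major [n_embd x n_ctx] V slab into a packed [n_embd x kv_ntok] buffer.
// This is what the ggml_view_3d + ggml_cpy graph above expresses per layer.
#include <cstddef>
#include <cstring>

static void copy_used_v(const float * v, float * out,
                        size_t n_embd, size_t n_ctx, size_t kv_ntok) {
    for (size_t row = 0; row < n_embd; ++row) {
        // each row holds n_ctx token slots, but only the first kv_ntok are in use
        std::memcpy(out + row * kv_ntok, v + row * n_ctx, kv_ntok * sizeof(float));
    }
}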
@@ -2537,32 +2570,56 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {

// set kv cache
{
+const auto & kv_self = ctx->model.kv_self;
+const auto & hparams = ctx->model.hparams;
+const int n_layer = hparams.n_layer;
+const int n_embd = hparams.n_embd;
+const int n_ctx = hparams.n_ctx;

size_t kv_size;
int kv_ntok;

memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);

if (kv_size) {
-LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+LLAMA_ASSERT(kv_self.buf.size == kv_size);

+const size_t elt_size = ggml_element_size(kv_self.k);
+char buffer[4096];
+ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+ggml_cgraph gf{};
+gf.n_threads = 1;

+ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+kin3d->data = (void *) in;
+in += ggml_nbytes(kin3d);

+ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+vin3d->data = (void *) in;
+in += ggml_nbytes(vin3d);

-void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+    n_embd, kv_ntok, n_layer,
+    elt_size*n_embd, elt_size*n_embd*n_ctx, 0);

-memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+    kv_ntok, n_embd, n_layer,
+    elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);

-ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-ctx->model.kv_self.v->data = v_data;
+ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
+ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
+ggml_graph_compute(cpy_ctx, &gf);

}

ctx->model.kv_self.n = kv_ntok;
}

const size_t nread = in - src;
-const size_t expected = llama_get_state_size(ctx);
+const size_t max_size = llama_get_state_size(ctx);

-LLAMA_ASSERT(nread == expected);
+LLAMA_ASSERT(nread <= max_size);

return nread;
}
@@ -2733,14 +2790,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
// restore the context state
{
const size_t n_state_size_cur = file.size - file.tell();
-const size_t n_state_size_exp = llama_get_state_size(ctx);
+const size_t n_state_size_max = llama_get_state_size(ctx);

-if (n_state_size_cur != n_state_size_exp) {
-fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+if (n_state_size_cur > n_state_size_max) {
+fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
return false;
}

-std::vector<uint8_t> state_data(n_state_size_cur);
+std::vector<uint8_t> state_data(n_state_size_max);
file.read_raw(state_data.data(), n_state_size_cur);

llama_set_state_data(ctx, state_data.data());
@@ -2763,12 +2820,12 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi

// save the context state
{
-const size_t n_state_size = llama_get_state_size(ctx);
+const size_t n_state_size_max = llama_get_state_size(ctx);

-std::vector<uint8_t> state_data(n_state_size);
-llama_copy_state_data(ctx, state_data.data());
+std::vector<uint8_t> state_data(n_state_size_max);
+const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());

-file.write_raw(state_data.data(), n_state_size);
+file.write_raw(state_data.data(), n_state_size_cur);
}

return true;
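For callers of these APIs, the contract after this change is: llama_get_state_size() gives an upper bound, llama_copy_state_data() returns the number of bytes it actually wrote, and llama_set_state_data() only consumes that many bytes. The sketch below shows that save/restore pattern using only the functions shown in the diff; the save_state / load_state helper names and the plain stdio file handling are illustrative assumptions, not part of this PR.

// Sketch only: hypothetical helpers built on the public state API.
// Size the buffer by the maximum, but persist only what was actually written.
#include <cstdint>
#include <cstdio>
#include <vector>
#include "llama.h"

static bool save_state(struct llama_context * ctx, const char * path) {
    std::vector<uint8_t> buf(llama_get_state_size(ctx));             // maximum size
    const size_t n_written = llama_copy_state_data(ctx, buf.data()); // actual size

    FILE * f = std::fopen(path, "wb");
    if (!f) return false;
    const bool ok = std::fwrite(buf.data(), 1, n_written, f) == n_written;
    std::fclose(f);
    return ok;
}

static bool load_state(struct llama_context * ctx, const char * path) {
    // Reading fewer bytes than the maximum is fine: llama_set_state_data()
    // only consumes what llama_copy_state_data() originally produced.
    std::vector<uint8_t> buf(llama_get_state_size(ctx));
    FILE * f = std::fopen(path, "rb");
    if (!f) return false;
    const size_t n_read = std::fread(buf.data(), 1, buf.size(), f);
    std::fclose(f);
    if (n_read == 0) return false;
    llama_set_state_data(ctx, buf.data());
    return true;
}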
5 changes: 3 additions & 2 deletions examples/talk-llama/llama.h
@@ -23,7 +23,7 @@
#define LLAMA_FILE_MAGIC 'ggjt'
#define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
#define LLAMA_SESSION_MAGIC 'ggsn'
-#define LLAMA_SESSION_VERSION 0
+#define LLAMA_SESSION_VERSION 1

#ifdef __cplusplus
extern "C" {
@@ -127,7 +127,8 @@ extern "C" {
// Sets the current rng seed.
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);

-// Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+// Returns the maximum size in bytes of the state (rng, logits, embedding
+// and kv_cache) - will often be smaller after compacting tokens
LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);

// Copies the state to the specified destination address.
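As the updated header comment notes, the serialized state will often be much smaller than the reported maximum, because only the first kv_ntok token slots of the K and V caches are copied. A rough back-of-the-envelope comparison follows; all dimension values below are assumptions chosen for illustration, not taken from this PR.

// Illustrative size comparison for the KV cache slice that gets serialized.
// Every parameter value here is an assumption for the sake of the example.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t n_layer  = 32;   // assumed number of layers
    const size_t n_embd   = 4096; // assumed embedding size
    const size_t n_ctx    = 2048; // assumed context window
    const size_t elt_size = 2;    // assumed f16 cache element size in bytes
    const size_t kv_ntok  = 128;  // example: tokens currently in the cache

    // full buffers: K and V, each n_embd * n_ctx * n_layer elements
    const size_t full_size = 2 * n_embd * n_ctx   * n_layer * elt_size;
    // compacted copy: only the first kv_ntok token columns of each
    const size_t used_size = 2 * n_embd * kv_ntok * n_layer * elt_size;

    std::printf("full KV cache   : %zu MiB\n", full_size / (1024 * 1024));
    std::printf("serialized slice: %zu MiB\n", used_size / (1024 * 1024));
    return 0;
}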