diff --git a/common/arg.cpp b/common/arg.cpp
index cbca8b5ac5abb..577048c201b76 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1932,13 +1932,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
     add_opt(common_arg(
-        {"--swa-checkpoints"}, "N",
-        string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
-            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
+        {"--ctx-checkpoints", "--swa-checkpoints"}, "N",
+        string_format("max number of context checkpoints to create per slot (default: %d)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
         [](common_params & params, int value) {
-            params.n_swa_checkpoints = value;
+            params.n_ctx_checkpoints = value;
         }
-    ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
+    ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--kv-unified", "-kvu"},
         string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
diff --git a/common/common.h b/common/common.h
index 40c6847f32ddb..d33788bd100b2 100644
--- a/common/common.h
+++ b/common/common.h
@@ -424,7 +424,7 @@ struct common_params {
     int32_t timeout_write = timeout_read; // http write timeout in seconds
     int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
-    int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
+    int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot
 
     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
diff --git a/include/llama.h b/include/llama.h
index 452d9ec5bf285..8fc3d7db5a917 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -543,6 +543,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
 
+    // Returns true if the model is hybrid (like Jamba, Granite, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
+
     // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
     LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
 
@@ -791,8 +794,12 @@ extern "C" {
             size_t   n_token_capacity,
             size_t * n_token_count_out);
 
+// for backwards-compat
 #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
 
+// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
+#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
+
 typedef uint32_t llama_state_seq_flags;
 
 LLAMA_API size_t llama_state_seq_get_size_ext(
diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp
index 827302e6d25bd..facba1d004012 100644
--- a/src/llama-kv-cache-iswa.cpp
+++ b/src/llama-kv-cache-iswa.cpp
@@ -220,7 +220,7 @@ bool llama_kv_cache_iswa::get_can_shift() const {
 }
 
 void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         kv_base->state_write(io, seq_id, flags);
     }
 
@@ -228,7 +228,7 @@ void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id
 }
 
 void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         kv_base->state_read(io, seq_id, flags);
     }
 
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index abf652483c202..cb8832a353b11 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -175,17 +175,17 @@ std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdo
 }
 
 void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    GGML_UNUSED(flags);
-
-    mem_attn->state_write(io, seq_id);
-    mem_recr->state_write(io, seq_id);
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
+        mem_attn->state_write(io, seq_id, flags);
+    }
+    mem_recr->state_write(io, seq_id, flags);
 }
 
 void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    GGML_UNUSED(flags);
-
-    mem_attn->state_read(io, seq_id);
-    mem_recr->state_read(io, seq_id);
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
+        mem_attn->state_read(io, seq_id, flags);
+    }
+    mem_recr->state_read(io, seq_id, flags);
 }
 
 llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 44645fcdd2d48..e23e74982b278 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) {
 }
 
 bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
     uint32_t new_head = size;
 
     if (p0 < 0) {
@@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         if (tail_id >= 0) {
             const auto & cell = cells[tail_id];
             // partial intersection is invalid
-            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+            if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
                 return false;
             }
             // invalidate tails which will be cleared
@@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
     } else {
         // seq_id is negative, then the range should include everything or nothing
        if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+            //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n");
             return false;
         }
     }
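For reference, here is a minimal sketch (not part of the patch) of how the new LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY flag is meant to be used together with the existing llama_state_seq_*_ext calls to snapshot and restore only the non-rollbackable part of a sequence (SWA KV cache or recurrent state). The helper names checkpoint_save and checkpoint_restore are hypothetical:

// hypothetical helpers illustrating the PARTIAL_ONLY flag; error handling kept minimal
#include <vector>

#include "llama.h"

// snapshot only the partial (non-rollbackable) state of a sequence
static std::vector<uint8_t> checkpoint_save(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    std::vector<uint8_t> data(size);
    llama_state_seq_get_data_ext(ctx, data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    return data;
}

// restore a previously saved snapshot; returns false if the full buffer could not be applied
static bool checkpoint_restore(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & data) {
    const size_t n = llama_state_seq_set_data_ext(ctx, data.data(), data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    return n == data.size();
}

With flags = 0 the full sequence state is written, which is why llama_memory_hybrid above skips its attention part only when the PARTIAL_ONLY bit is set.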
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index cce77a854bb2f..4c2d481a41d42 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -20151,6 +20151,10 @@ bool llama_model_is_recurrent(const llama_model * model) {
     return llm_arch_is_recurrent(model->arch);
 }
 
+bool llama_model_is_hybrid(const llama_model * model) {
+    return llm_arch_is_hybrid(model->arch);
+}
+
 bool llama_model_is_diffusion(const llama_model * model) {
     return llm_arch_is_diffusion(model->arch);
 }
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 6062904a8c7c0..a21147613db00 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -764,7 +764,7 @@ struct completion_token_output {
     }
 };
 
-struct swa_checkpoint {
+struct ctx_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;
 
@@ -1460,7 +1460,7 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
-    std::vector<swa_checkpoint> swa_checkpoints;
+    std::vector<ctx_checkpoint> ctx_checkpoints;
 
     bool has_next_token = true;
     bool has_new_line = false;
@@ -3541,7 +3541,11 @@ struct server_context {
                     slot.n_past = 0;
                 }
 
-                const auto n_swa = llama_model_n_swa(model);
+                // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+                const auto n_swa = std::max(1, llama_model_n_swa(model));
+
+                // the largest pos_min required for a checkpoint to be useful
+                const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
 
                 if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                     const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3550,66 +3554,62 @@ struct server_context {
                         GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                     }
 
-                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
                     if (pos_min > pos_min_thold) {
                         SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
 
-                        // search for a SWA checkpoint
+                        // search for a context checkpoint
                         const auto it = std::find_if(
-                            slot.swa_checkpoints.rbegin(),
-                            slot.swa_checkpoints.rend(),
+                            slot.ctx_checkpoints.rbegin(),
+                            slot.ctx_checkpoints.rend(),
                             [&](const auto & cur) {
-                                return cur.pos_min <= pos_min_thold;
+                                // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                                return cur.pos_min < pos_min_thold;
                             }
                         );
 
-                        bool do_reset = it == slot.swa_checkpoints.rend();
+                        bool do_reset = it == slot.ctx_checkpoints.rend();
+                        //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");
 
                         if (!do_reset) {
-                            // restore the checkpoint
-                            const size_t swa_size = it->data.size();
-                            const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                            // restore the context checkpoint
+                            const size_t ctx_checkpoint_size = it->data.size();
+                            const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                            if (n != swa_size) {
-                                SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                            if (n != ctx_checkpoint_size) {
+                                SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                                 do_reset = true;
+                                //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                             } else {
-                                slot.n_past = std::min(slot.n_past, it->pos_max);
-
-                                SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+                                SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                             }
                         }
 
                         if (do_reset) {
-                            SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                            SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                 "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-
                             slot.n_past = 0;
-                            slot.swa_checkpoints.clear();
                         }
                     }
                 }
 
-                if (n_swa > 0) {
-                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
+                {
                     // erase any checkpoints with pos_min > pos_min_thold
-                    for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
-                        const auto & cur = slot.swa_checkpoints[i];
+                    for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+                        const auto & cur = slot.ctx_checkpoints[i];
                         if (cur.pos_min > pos_min_thold) {
-                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
-
-                            SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                            SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
                         }
                     }
                 }
             }
 
+            // [TAG_PROMPT_LOGITS]
             if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-                SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
+                SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
                 slot.n_past--;
+                SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
             }
 
             slot.n_prompt_tokens_cache = slot.n_past;
@@ -3623,9 +3623,9 @@ struct server_context {
                 }
             }
 
-            // keep only the common part
+            // truncate any tokens that are beyond n_past for this slot
             if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
-                // could not partially delete (likely using a non-Transformer model)
+                SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
                 llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                 // there is no common part left
@@ -3633,7 +3633,7 @@ struct server_context {
                 slot.n_prompt_tokens_cache = 0;
             }
 
-            SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+            SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
 
             // remove the non-common part from the cache
             slot.cache_tokens.keep_first(slot.n_past);
@@ -3854,37 +3854,38 @@ struct server_context {
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
 
-                    // make a checkpoint with the SWA memory
-                    // checkpoints are needed only if we are not using "--swa-full"
-                    if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
-                        if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
-                            {
-                                const auto & cur = slot.swa_checkpoints.back();
-
-                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
-                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                            }
-
-                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                    // make a checkpoint of the parts of the memory that cannot be rolled back.
+                    // checkpoints are created only if:
+                    // - the model uses SWA and we are not using `swa_full`
+                    // - the model architecture is marked as recurrent or hybrid
+                    //
+                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                    const bool do_checkpoint =
+                        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+                        (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+
+                    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
+                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                            // make room for the new checkpoint, if needed
+                            const auto & cur = slot.ctx_checkpoints.front();
+                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
                         }
 
-                        const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                        auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
                             /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
                             /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                            /*.data    = */ std::vector<uint8_t>(swa_size),
+                            /*.data    = */ std::vector<uint8_t>(checkpoint_size),
                         });
 
-                        llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
-
-                        float size_total = 0.0f;
-                        for (const auto & checkpoint : slot.swa_checkpoints) {
-                            size_total += (float) checkpoint.data.size() / 1024 / 1024;
-                        }
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                        SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
-                            cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                            (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                     }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
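For reference, a self-contained sketch (not part of the patch) of the checkpoint-selection rule applied in update_slots() above: pick the most recent checkpoint whose pos_min is strictly below pos_min_thold, so that restoring it still leaves at least one prompt token to evaluate. The struct name and the values below are illustrative only:

// standalone illustration of the checkpoint-selection rule; compiles on its own
#include <algorithm>
#include <cstdio>
#include <vector>

struct ctx_checkpoint_view {
    int pos_min;
    int pos_max;
};

int main() {
    std::vector<ctx_checkpoint_view> checkpoints = { {0, 15}, {16, 31}, {32, 47} }; // oldest -> newest

    const int n_past        = 40;
    const int n_swa         = 8;  // a model without SWA behaves like a window of 1 here
    const int pos_min_thold = std::max(0, n_past - n_swa);

    // newest-first search, mirroring the std::find_if over slot.ctx_checkpoints
    auto it = std::find_if(checkpoints.rbegin(), checkpoints.rend(),
        [&](const ctx_checkpoint_view & cur) { return cur.pos_min < pos_min_thold; });

    if (it == checkpoints.rend()) {
        std::printf("no usable checkpoint -> full prompt re-processing\n");
    } else {
        // mirror of: slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
        const int new_n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
        std::printf("restore checkpoint [%d, %d], resume from n_past = %d\n", it->pos_min, it->pos_max, new_n_past);
    }
    return 0;
}

With the values above the middle checkpoint [16, 31] is chosen (pos_min_thold = 32), and n_past is rolled back to 31, leaving the remaining prompt tokens to be re-evaluated.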