From f5f0e8eb0353e6833aa79ce5fc50e6cb59748519 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 6 Oct 2025 10:33:48 +0300
Subject: [PATCH] server : improve context checkpoint logic

---
 src/llama-memory-recurrent.cpp |  5 +-
 tools/server/server.cpp        | 86 ++++++++++++++++++++--------------
 2 files changed, 56 insertions(+), 35 deletions(-)

diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 9402d9cb8df9f..d67f5a5f47b87 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -861,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:
 bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
     if (dest_seq_id != -1) {
         // single sequence
-
         seq_rm(dest_seq_id, -1, -1);
 
+        if (cell_count == 0) {
+            return true;
+        }
+
         llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
         llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index a21147613db00..5b288e65c766b 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3676,6 +3676,20 @@ struct server_context {
                         alora_disabled_id = enabled_loras[0];
                     }
 
+                    bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
+                    // make a checkpoint of the parts of the memory that cannot be rolled back.
+                    // checkpoints are created only if:
+                    // - the model uses SWA and we are not using `swa_full`
+                    // - the model architecture is marked as recurrent or hybrid
+                    //
+                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                    do_checkpoint = do_checkpoint && (
+                        llama_model_is_recurrent(model) ||
+                        llama_model_is_hybrid(model) ||
+                        (llama_model_n_swa(model) > 0 && !params_base.swa_full)
+                    );
+
                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                         // get next token to process
@@ -3700,6 +3714,11 @@ struct server_context {
 
                         slot.n_prompt_tokens_processed++;
                         slot.n_past++;
+
+                        // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
+                        if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
+                            break;
+                        }
                     }
 
                     // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
@@ -3730,6 +3749,39 @@ struct server_context {
                         slot.i_batch = batch.n_tokens - 1;
 
                         SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+
+                        const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+                        const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+                        // no need for empty or small checkpoints
+                        do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+
+                        // no need to create checkpoints that are too close together
+                        do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
+
+                        if (do_checkpoint) {
+                            while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                                // make room for the new checkpoint, if needed
+                                const auto & cur = slot.ctx_checkpoints.front();
+                                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                                slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
+                            }
+
+                            const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                            auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
+                                /*.pos_min = */ pos_min,
+                                /*.pos_max = */ pos_max,
+                                /*.data    = */ std::vector<uint8_t>(checkpoint_size),
+                            });
+
+                            llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                            SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                        }
                     }
                 }
 
@@ -3853,40 +3905,6 @@ struct server_context {
 
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
-
-                    // make a checkpoint of the parts of the memory that cannot be rolled back.
-                    // checkpoints are created only if:
-                    // - the model uses SWA and we are not using `swa_full`
-                    // - the model architecture is marked as recurrent or hybrid
-                    //
-                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
-                    const bool do_checkpoint =
-                        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
-                        (llama_model_n_swa(model) > 0 && !params_base.swa_full);
-
-                    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
-                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
-                            // make room for the new checkpoint, if needed
-                            const auto & cur = slot.ctx_checkpoints.front();
-                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-
-                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
-                        }
-
-                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
-                            /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
-                            /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                            /*.data    = */ std::vector<uint8_t>(checkpoint_size),
-                        });
-
-                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                    }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
                 }
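
Note (not part of the patch): the save path above always uses the same two-step pattern with the `_ext` state API and the LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY flag: query the size of the part of the sequence state that cannot be rolled back, allocate a buffer of that size, then copy the state into it. Below is a minimal standalone sketch of that pattern, assuming only the llama.h declarations already called in the patch; the helper and struct names are illustrative, not part of the server code.

    #include <cstdint>
    #include <vector>

    #include "llama.h"

    // illustrative container, mirroring the fields the server keeps per checkpoint
    struct seq_checkpoint {
        llama_pos pos_min;
        llama_pos pos_max;
        std::vector<uint8_t> data;
    };

    // capture the partial (SWA/recurrent) state of one sequence into a buffer
    static seq_checkpoint make_seq_checkpoint(llama_context * ctx, llama_seq_id seq_id) {
        seq_checkpoint cp;

        cp.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), seq_id);
        cp.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), seq_id);

        // step 1: ask how many bytes the partial state needs
        const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        // step 2: copy the state into the allocated buffer
        cp.data.resize(size);
        llama_state_seq_get_data_ext(ctx, cp.data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        return cp;
    }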
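
Note (not part of the patch): taken together, the new conditions space checkpoints out. The prompt loop breaks once exactly 64 prompt tokens remain, which, per the comment in the patch, leaves room for a checkpoint to be created before the final tokens are processed; a checkpoint is then saved only when the sequence state is non-empty and reaches at least position 64, and only when it ends more than 64 positions past the newest existing checkpoint of the slot. A small sketch of that gating as a pure helper follows; the function name and the use of plain int32_t for llama_pos-style positions are illustrative.

    #include <cstdint>

    // decide whether a new checkpoint should be saved for a sequence;
    // last_pos_max is the pos_max of the newest existing checkpoint, or -1 if there is none
    static bool should_save_checkpoint(int32_t pos_min, int32_t pos_max, int32_t last_pos_max) {
        // skip empty or very small states
        if (pos_min < 0 || pos_max < 64) {
            return false;
        }

        // keep at least 64 positions between consecutive checkpoints
        if (last_pos_max >= 0 && pos_max <= last_pos_max + 64) {
            return false;
        }

        return true;
    }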