10 changes: 5 additions & 5 deletions common/arg.cpp
@@ -1932,13 +1932,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--swa-checkpoints"}, "N",
string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
{"--ctx-checkpoints", "--swa-checkpoints"}, "N",
string_format("max number of context checkpoints to create per slot (default: %d)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
[](common_params & params, int value) {
params.n_swa_checkpoints = value;
params.n_ctx_checkpoints = value;
}
).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
2 changes: 1 addition & 1 deletion common/common.h
@@ -424,7 +424,7 @@ struct common_params {
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot

std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
7 changes: 7 additions & 0 deletions include/llama.h
@@ -543,6 +543,9 @@ extern "C" {
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);

// Returns true if the model is hybrid (like Jamba, Granite, etc.)
LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);

// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);

@@ -791,8 +794,12 @@ extern "C" {
size_t n_token_capacity,
size_t * n_token_count_out);

// for backwards-compat
#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1

// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1

typedef uint32_t llama_state_seq_flags;

LLAMA_API size_t llama_state_seq_get_size_ext(
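The new `llama_model_is_hybrid()` query and the `LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY` flag (with `LLAMA_STATE_SEQ_FLAGS_SWA_ONLY` kept as a backwards-compatible alias of the same value) are the public surface of this change. Below is a minimal sketch of how a caller might snapshot only the non-rollbackable part of a sequence's state; the helper name and control flow are illustrative, not part of this PR:

```cpp
// Illustrative sketch only: save_partial_checkpoint() is not part of this PR.
#include "llama.h"

#include <vector>

static std::vector<uint8_t> save_partial_checkpoint(llama_context * ctx, const llama_model * model, llama_seq_id seq_id) {
    // partial checkpoints only make sense when some part of the memory cannot be rolled back:
    // an SWA KV cache, a recurrent state (Mamba, RWKV, ...), or the recurrent half of a hybrid model
    const bool has_partial_state =
        llama_model_is_recurrent(model) ||
        llama_model_is_hybrid(model)    ||
        llama_model_n_swa(model) > 0;

    if (!has_partial_state) {
        return {};
    }

    // size and serialize only the partial state of this sequence
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    std::vector<uint8_t> data(size);
    llama_state_seq_get_data_ext(ctx, data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    return data;
}
```

Restoring is the mirror image via `llama_state_seq_set_data_ext()` with the same flag, as done in the server.cpp changes further down.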
4 changes: 2 additions & 2 deletions src/llama-kv-cache-iswa.cpp
@@ -220,15 +220,15 @@ bool llama_kv_cache_iswa::get_can_shift() const {
}

void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
kv_base->state_write(io, seq_id, flags);
}

kv_swa->state_write(io, seq_id, flags);
}

void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
kv_base->state_read(io, seq_id, flags);
}

16 changes: 8 additions & 8 deletions src/llama-memory-hybrid.cpp
@@ -175,17 +175,17 @@ std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdo
}

void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);

mem_attn->state_write(io, seq_id);
mem_recr->state_write(io, seq_id);
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
mem_attn->state_write(io, seq_id, flags);
}
mem_recr->state_write(io, seq_id, flags);
}

void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);

mem_attn->state_read(io, seq_id);
mem_recr->state_read(io, seq_id);
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
mem_attn->state_read(io, seq_id, flags);
}
mem_recr->state_read(io, seq_id, flags);
}

llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
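With this change, hybrid memory serializes both its attention KV cache and its recurrent state when `flags == 0`, but skips the attention part when `LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY` is set. A hedged sketch of what that looks like from the caller's side (illustrative only, assuming a hybrid model):

```cpp
#include "llama.h"

#include <cstdio>

// Illustrative only (not part of the PR): for a hybrid model, a full per-sequence dump
// contains both the attention KV cache and the recurrent state, while a partial dump
// contains only the recurrent state handled by mem_recr above.
static void log_state_sizes(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size_full    = llama_state_seq_get_size_ext(ctx, seq_id, 0);
    const size_t size_partial = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    // size_partial is expected to be smaller, since the attention KV cache is skipped
    // when LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY is set
    printf("full = %zu bytes, partial = %zu bytes\n", size_full, size_partial);
}
```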
5 changes: 4 additions & 1 deletion src/llama-memory-recurrent.cpp
@@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) {
}

bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
//printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
uint32_t new_head = size;

if (p0 < 0) {
@@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
if (tail_id >= 0) {
const auto & cell = cells[tail_id];
// partial intersection is invalid
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
return false;
}
// invalidate tails which will be cleared
@@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
} else {
// seq_id is negative, then the range should include everything or nothing
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n");
return false;
}
}
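The boundary of the partial-intersection check is relaxed from `p0 <= cell.pos` to `p0 < cell.pos`, so removing `[cell.pos, inf)` from a recurrent sequence is no longer rejected. An illustrative restatement of the predicate (not code from this PR) with a worked boundary case:

```cpp
#include <cstdint>
#include <limits>

using llama_pos = int32_t;

// Restatement of the partial-intersection check in llama_memory_recurrent::seq_rm,
// before and after this change; the handling of p1 is unchanged.
static bool rejects_old(llama_pos p0, llama_pos p1, llama_pos cell_pos) {
    return (0 < p0 && p0 <= cell_pos) || (0 < p1 && p1 <= cell_pos);
}

static bool rejects_new(llama_pos p0, llama_pos p1, llama_pos cell_pos) {
    return (0 < p0 && p0 <  cell_pos) || (0 < p1 && p1 <= cell_pos);
}

int main() {
    const llama_pos cell_pos = 10;                                     // tail of the recurrent cell
    const llama_pos p1 = std::numeric_limits<llama_pos>::max();        // negative p1 is normalized to max

    // removing [10, inf) used to be rejected as a partial intersection, now it is allowed,
    // which lets a caller truncate from a position equal to the tail position
    const bool was = rejects_old(10, p1, cell_pos);   // true
    const bool now = rejects_new(10, p1, cell_pos);   // false

    // removing [5, inf) still intersects state that the recurrent cell cannot roll back
    const bool still = rejects_new(5, p1, cell_pos);  // true

    return (was && !now && still) ? 0 : 1;
}
```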
4 changes: 4 additions & 0 deletions src/llama-model.cpp
@@ -20151,6 +20151,10 @@ bool llama_model_is_recurrent(const llama_model * model) {
return llm_arch_is_recurrent(model->arch);
}

bool llama_model_is_hybrid(const llama_model * model) {
return llm_arch_is_hybrid(model->arch);
}

bool llama_model_is_diffusion(const llama_model * model) {
return llm_arch_is_diffusion(model->arch);
}
115 changes: 58 additions & 57 deletions tools/server/server.cpp
@@ -764,7 +764,7 @@ struct completion_token_output {
}
};

struct swa_checkpoint {
struct ctx_checkpoint {
llama_pos pos_min;
llama_pos pos_max;

@@ -1460,7 +1460,7 @@ struct server_slot {

std::vector<completion_token_output> generated_token_probs;

std::vector<swa_checkpoint> swa_checkpoints;
std::vector<ctx_checkpoint> ctx_checkpoints;

bool has_next_token = true;
bool has_new_line = false;
@@ -3541,7 +3541,11 @@ struct server_context {
slot.n_past = 0;
}

const auto n_swa = llama_model_n_swa(model);
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
const auto n_swa = std::max(1, llama_model_n_swa(model));

// the largest pos_min required for a checkpoint to be useful
const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3550,66 +3554,62 @@ struct server_context {
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
}

const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

if (pos_min > pos_min_thold) {
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);

// search for a SWA checkpoint
// search for a context checkpoint
const auto it = std::find_if(
slot.swa_checkpoints.rbegin(),
slot.swa_checkpoints.rend(),
slot.ctx_checkpoints.rbegin(),
slot.ctx_checkpoints.rend(),
[&](const auto & cur) {
return cur.pos_min <= pos_min_thold;
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
return cur.pos_min < pos_min_thold;
}
);

bool do_reset = it == slot.swa_checkpoints.rend();
bool do_reset = it == slot.ctx_checkpoints.rend();
//printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");

if (!do_reset) {
// restore the checkpoint
const size_t swa_size = it->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
// restore the context checkpoint
const size_t ctx_checkpoint_size = it->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

if (n != swa_size) {
SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
if (n != ctx_checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
slot.n_past = std::min(slot.n_past, it->pos_max);

SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
}
}

if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");

slot.n_past = 0;
slot.swa_checkpoints.clear();
}
}
}

if (n_swa > 0) {
const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

{
// erase any checkpoints with pos_min > pos_min_thold
for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
const auto & cur = slot.swa_checkpoints[i];
for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
const auto & cur = slot.ctx_checkpoints[i];
if (cur.pos_min > pos_min_thold) {
slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);

SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
}
}
}
}

// [TAG_PROMPT_LOGITS]
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);

SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
slot.n_past--;
SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
}

slot.n_prompt_tokens_cache = slot.n_past;
Expand All @@ -3623,17 +3623,17 @@ struct server_context {
}
}

// keep only the common part
// truncate any tokens that are beyond n_past for this slot
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);

// there is no common part left
slot.n_past = 0;
slot.n_prompt_tokens_cache = 0;
}

SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);

// remove the non-common part from the cache
slot.cache_tokens.keep_first(slot.n_past);
@@ -3854,37 +3854,38 @@ struct server_context {
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;

// make a checkpoint with the SWA memory
// checkpoints are needed only if we are not using "--swa-full"
if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
{
const auto & cur = slot.swa_checkpoints.back();

SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
}

slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
// make a checkpoint of the parts of the memory that cannot be rolled back.
// checkpoints are created only if:
// - the model uses SWA and we are not using `swa_full`
// - the model architecture is marked as recurrent or hybrid
//
// TODO: try to make this conditional on the context or the memory module, instead of the model type
const bool do_checkpoint =
(llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
(llama_model_n_swa(model) > 0 && !params_base.swa_full);

if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.ctx_checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);

slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
}

const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
/*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
/*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
/*.data = */ std::vector<uint8_t>(swa_size),
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});

llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

float size_total = 0.0f;
for (const auto & checkpoint : slot.swa_checkpoints) {
size_total += (float) checkpoint.data.size() / 1024 / 1024;
}
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
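Taken together, the server now keeps up to `--ctx-checkpoints` snapshots of the non-rollbackable memory per slot, restores the newest usable one when a new request only partially matches the cached tokens, and re-processes from there. A condensed, hypothetical sketch of that lifecycle (the free-standing helpers and struct layout are invented for illustration; the API calls are the ones used in server.cpp above):

```cpp
#include "llama.h"

#include <algorithm>
#include <cstdint>
#include <vector>

struct ctx_checkpoint {
    llama_pos pos_min;
    llama_pos pos_max;

    std::vector<uint8_t> data;
};

// after a prompt batch has been evaluated: snapshot the non-rollbackable part of the memory
static void checkpoint_create(llama_context * ctx, llama_seq_id seq_id,
                              std::vector<ctx_checkpoint> & checkpoints, int n_max) {
    while ((int) checkpoints.size() >= n_max) {
        checkpoints.erase(checkpoints.begin()); // drop the oldest checkpoint first
    }

    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    auto & cp = checkpoints.emplace_back(ctx_checkpoint{
        /*.pos_min =*/ llama_memory_seq_pos_min(llama_get_memory(ctx), seq_id),
        /*.pos_max =*/ llama_memory_seq_pos_max(llama_get_memory(ctx), seq_id),
        /*.data    =*/ std::vector<uint8_t>(size),
    });

    llama_state_seq_get_data_ext(ctx, cp.data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
}

// when only a prefix of the new prompt is cached: restore the newest checkpoint whose
// pos_min is strictly below the required threshold, and continue from its position;
// returns the new n_past (0 means no usable checkpoint -> full re-processing)
static llama_pos checkpoint_restore(llama_context * ctx, llama_seq_id seq_id,
                                    const std::vector<ctx_checkpoint> & checkpoints,
                                    llama_pos pos_min_thold, llama_pos n_past) {
    const auto it = std::find_if(checkpoints.rbegin(), checkpoints.rend(),
        [&](const auto & cur) { return cur.pos_min < pos_min_thold; });

    if (it == checkpoints.rend()) {
        return 0;
    }

    const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), it->data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    if (n != it->data.size()) {
        return 0;
    }

    // guarantee that at least one token will be re-processed [TAG_PROMPT_LOGITS]
    return std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
}
```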