Commit 0f332a9

llama : temp fix for clearing "future" tokens from the KV cache
1 parent 6a9fe3d commit 0f332a9

File tree

1 file changed: +19 −0 lines changed


llama.cpp (+19)
@@ -7478,6 +7478,25 @@ void llama_batch_free(struct llama_batch batch) {
 int llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
+    // TODO: temporary solution to auto clear "future" tokens from the cache
+    // ref: https://github.com/ggerganov/llama.cpp/pull/3400
+    if (batch.pos) {
+        std::map<llama_seq_id, llama_pos> seq_min_pos;
+        for (int i = 0; i < batch.n_tokens; i++) {
+            if (seq_min_pos.count(batch.seq_id[i]) == 0) {
+                seq_min_pos[batch.seq_id[i]] = batch.pos[i];
+            } else {
+                seq_min_pos[batch.seq_id[i]] = std::min(seq_min_pos[batch.seq_id[i]], batch.pos[i]);
+            }
+        }
+
+        for (auto & kv : seq_min_pos) {
+            llama_kv_cache_seq_rm(ctx->kv_self, kv.first, kv.second, ctx->cparams.n_ctx);
+        }
+    } else {
+        llama_kv_cache_seq_rm(ctx->kv_self, batch.all_seq_id, batch.all_pos_0, ctx->cparams.n_ctx);
+    }
+
     const int ret = llama_decode_internal(*ctx, batch);
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
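
For context, here is a minimal caller-side sketch (hypothetical, not part of the commit) of the situation this change handles, assuming the public llama.h API contemporary with this commit. The llama_batch_get_one helper of this era leaves batch.pos null and sets all_pos_0/all_seq_id, which is what the else branch above handles:

// Hypothetical sketch, assuming the llama.h API of this era: rewinding
// generation to an earlier position in a sequence.
#include "llama.h"

void redecode_at(llama_context * ctx, llama_token tok, llama_pos pos) {
    // Suppose sequence 0 already holds cached tokens at positions 0..9 and
    // we re-decode at pos = 5. The cache entries at positions 5..9 belong to
    // the abandoned continuation; if left in place, tokens decoded later
    // would still attend to these stale "future" entries.
    //
    // Before this commit the caller had to clear them explicitly:
    //     llama_kv_cache_seq_rm(ctx, 0, pos, llama_n_ctx(ctx));
    // With this commit, llama_decode derives the minimum position of each
    // sequence in the batch and performs the equivalent removal itself.
    llama_batch batch = llama_batch_get_one(&tok, 1, pos, 0);
    llama_decode(ctx, batch);
}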
