File tree: 1 file changed, +19 −0 lines changed.
Diff hunk header: @@ -7478,6 +7478,25 @@ (context: immediately after `void llama_batch_free(struct llama_batch batch) {`)
Original file line number | Diff line number | Diff line change
7478
7478
int llama_decode (
7479
7479
struct llama_context * ctx,
7480
7480
struct llama_batch batch) {
7481
+ // TODO: temporary solution to auto clear "future" tokens from the cache
7482
+ // ref: https://github.com/ggerganov/llama.cpp/pull/3400
7483
+ if (batch.pos ) {
7484
+ std::map<llama_seq_id, llama_pos> seq_min_pos;
7485
+ for (int i = 0 ; i < batch.n_tokens ; i++) {
7486
+ if (seq_min_pos.count (batch.seq_id [i]) == 0 ) {
7487
+ seq_min_pos[batch.seq_id [i]] = batch.pos [i];
7488
+ } else {
7489
+ seq_min_pos[batch.seq_id [i]] = std::min (seq_min_pos[batch.seq_id [i]], batch.pos [i]);
7490
+ }
7491
+ }
7492
+
7493
+ for (auto & kv : seq_min_pos) {
7494
+ llama_kv_cache_seq_rm (ctx->kv_self , kv.first , kv.second , ctx->cparams .n_ctx );
7495
+ }
7496
+ } else {
7497
+ llama_kv_cache_seq_rm (ctx->kv_self , batch.all_seq_id , batch.all_pos_0 , ctx->cparams .n_ctx );
7498
+ }
7499
+
7481
7500
const int ret = llama_decode_internal (*ctx, batch);
7482
7501
if (ret < 0 ) {
7483
7502
LLAMA_LOG_ERROR (" %s: failed to decode, ret = %d\n " , __func__, ret);
[Page footer] You can't perform that action at this time.
0 commit comments.