diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index b184735566a0a..1a9f4e3159f94 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -582,21 +582,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             continue;
         }
 
-        // keep track of what the minimum sequence positions would be if we accept the ubatch
-        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
-            seq_pos_min[s] = cells.seq_pos_min(s);
-        }
-
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            const llama_pos    pos    = ubatch.pos[i];
-            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+            //const llama_pos    pos    = ubatch.pos[i];
+            //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
             // can we use this cell? either:
             // - the cell is empty
             // - the cell is occupied only by one sequence:
-            //   - mask causally, if the sequence is the same as the one we are inserting
+            //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
             //   - mask SWA, using current max pos for that sequence in the cache
             //     always insert in the cell with minimum pos
             bool can_use = cells.is_empty(head_cur + i);
@@ -604,21 +598,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             if (!can_use && cells.seq_count(head_cur + i) == 1) {
                 const llama_pos pos_cell = cells.pos_get(head_cur + i);
 
-                // causal mask
-                if (cells.seq_has(head_cur + i, seq_id)) {
-                    can_use = pos_cell >= pos;
-                }
+                // (disabled) causal mask
+                // note: it's better to purge any "future" tokens beforehand
+                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //    can_use = pos_cell >= pos;
+                //}
 
                 if (!can_use) {
                     const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
 
                     // SWA mask
-                    // note: we insert only in the cell with minimum pos in order to preserve the invariant that
-                    //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
-                    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-                    if (pos_cell == seq_pos_min[seq_id_cell] &&
-                            is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                        seq_pos_min[seq_id_cell]++;
+                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                         can_use = true;
                     }
                 }
@@ -646,8 +636,22 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    // keep track of the max sequence position that we would overwrite with this ubatch
+    // for a non-SWA cache, nothing is ever overwritten, so this stays all -1
+    llama_pos seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        seq_pos_max_rm[s] = -1;
+    }
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
         if (!cells.is_empty(head_cur + i)) {
+            assert(cells.seq_count(head_cur + i) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
+            const llama_pos    pos    = cells.pos_get(head_cur + i);
+
+            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
             cells.rm(head_cur + i);
         }
 
@@ -658,6 +662,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         }
     }
 
+    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+    //       will be present in the cache. so we have to purge any cached position <= the highest one we overwrite
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos_max_rm[s] == -1) {
+            continue;
+        }
+
+        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+        }
+    }
+
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
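To see why the relaxed `find_slot()` test is safe, here is a minimal standalone sketch, assuming a "standard" sliding window where a key at `p0` is masked for a query at `p1` once `p1 - p0 >= n_swa`. The helper below is a simplified stand-in, not the actual `is_masked_swa()` implementation:

```cpp
#include <cstdint>
#include <cstdio>

// simplified standard-SWA condition: a key at position p0 can never be
// attended to again once the query position reaches p1 with p1 - p0 >= n_swa
static bool is_masked_swa(int32_t p0, int32_t p1, int32_t n_swa) {
    return p1 - p0 >= n_swa;
}

int main() {
    const int32_t n_swa       = 4;
    const int32_t seq_pos_max = 9; // highest cached position of the sequence

    // after this patch, any single-sequence cell whose position is masked
    // w.r.t. the sequence's next position (seq_pos_max + 1) is reusable --
    // not just the cell holding the minimum position
    for (int32_t pos_cell = 0; pos_cell <= seq_pos_max; ++pos_cell) {
        const bool can_use = is_masked_swa(pos_cell, seq_pos_max + 1, n_swa);
        printf("pos_cell = %d -> %s\n", pos_cell, can_use ? "reusable" : "keep");
    }
    // -> positions 0..6 are reusable (10 - pos_cell >= 4), 7..9 must be kept

    return 0;
}
```

Dropping the minimum-pos restriction means cell reuse can now punch a hole in the middle of a sequence's position range, which is exactly what the new purge step in `apply_ubatch()` repairs.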
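And a self-contained toy of that repair step, showing how recording the highest overwritten position per sequence (the patch's `seq_pos_max_rm`) and then purging `[pos_min, seq_pos_max_rm]` keeps each sequence's cached positions contiguous. `toy_kv_cache` and `apply_overwrite` are invented names for illustration; only the purge logic mirrors the patch:

```cpp
#include <algorithm>
#include <cstdio>
#include <map>
#include <set>

struct toy_kv_cache {
    std::map<int, std::set<int>> seq_pos; // seq_id -> cached positions

    int seq_pos_min(int s) const { return *seq_pos.at(s).begin(); }

    // drop positions [p0, p1) of sequence s -- mirrors seq_rm() semantics
    void seq_rm(int s, int p0, int p1) {
        auto & ps = seq_pos[s];
        for (auto it = ps.lower_bound(p0); it != ps.end() && *it < p1; ) {
            it = ps.erase(it);
        }
    }

    // reuse the cells of sequence s that currently hold the given positions
    void apply_overwrite(int s, const std::set<int> & overwritten) {
        int pos_max_rm = -1; // mirrors seq_pos_max_rm[s]

        for (int p : overwritten) {
            seq_pos[s].erase(p);                  // like cells.rm(...)
            pos_max_rm = std::max(pos_max_rm, p);
        }

        // repair the invariant: purge everything at or below the highest
        // overwritten position, so no hole is left inside [pos_min, pos_max]
        if (pos_max_rm != -1 && !seq_pos[s].empty() && seq_pos_min(s) <= pos_max_rm) {
            printf("purging positions [%d, %d] of sequence %d\n", seq_pos_min(s), pos_max_rm, s);
            seq_rm(s, seq_pos_min(s), pos_max_rm + 1);
        }
    }
};

int main() {
    toy_kv_cache kv;
    for (int p = 0; p < 10; ++p) {
        kv.seq_pos[0].insert(p); // sequence 0 holds positions [0, 9]
    }

    // an SWA-masked slot reuses the cells holding positions 3..5; without the
    // purge, sequence 0 would be left with {0, 1, 2, 6, ..., 9} -- a hole
    kv.apply_overwrite(0, {3, 4, 5});

    printf("remaining range: [%d, %d]\n", *kv.seq_pos[0].begin(), *kv.seq_pos[0].rbegin());
    // -> purging positions [0, 5] of sequence 0
    // -> remaining range: [6, 9]

    return 0;
}
```

With the purge centralized in `apply_ubatch()`, `find_slot()` no longer has to simulate future `seq_pos_min` values while scanning candidate slots, which is why the `seq_pos_min[]` bookkeeping could be deleted.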