Skip to content

Commit 85c8f78

Browse files
ggerganov authored and Minh141120 committed
kv-cache : relax SWA masking condition (ggml-org#14119)
ggml-ci
1 parent 0ca0399 commit 85c8f78

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -687,8 +687,8 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
687687
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
688688
// keep track of the max sequence position that we would overwrite with this ubatch
689689
// for non-SWA cache, this would be always empty
690-
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
691-
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
690+
llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
691+
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
692692
seq_pos_max_rm[s] = -1;
693693
}
694694

@@ -706,15 +706,15 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
706706

707707
cells.pos_set(head_cur + i, ubatch.pos[i]);
708708

709-
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
710-
cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
709+
for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
710+
cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
711711
}
712712
}
713713

714714
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
715715
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
716716
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
717-
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
717+
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
718718
if (seq_pos_max_rm[s] == -1) {
719719
continue;
720720
}

0 commit comments

Comments
 (0)