
Commit db193bf

Revert "kv-cache : remove LLAMA_SET_ROWS checks (ggml-org#15505) ggml-ci"
1 parent ea6846b commit db193bf

File tree

src/llama-context.cpp
src/llama-context.h
src/llama-graph.cpp
src/llama-kv-cache.cpp
src/llama-kv-cache.h

5 files changed: +124 −20 lines changed

src/llama-context.cpp

Lines changed: 22 additions & 0 deletions
@@ -102,6 +102,16 @@ llama_context::llama_context(
     cparams.op_offload = params.op_offload;
     cparams.kv_unified = params.kv_unified;
 
+    {
+        const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+        supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;
+
+        if (!supports_set_rows && !cparams.kv_unified) {
+            LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
+            cparams.kv_unified = true;
+        }
+    }
+
     {
         const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
         graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
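
Both blocks above use the same environment-variable override pattern. A minimal standalone sketch of that pattern, with a hypothetical helper name (env_flag is not part of llama.cpp):

#include <cstdlib>

// Read a boolean flag from the environment; keep the given default when the
// variable is unset, and treat any value that atoi() parses as non-zero as "on".
static bool env_flag(const char * name, bool def) {
    const char * val = std::getenv(name);
    return val ? std::atoi(val) != 0 : def;
}

// e.g. supports_set_rows = env_flag("LLAMA_SET_ROWS", /*def=*/true);
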
@@ -880,6 +890,12 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
+    if (!supports_set_rows) {
+        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+        // overlap with device computation.
+        ggml_backend_sched_reset(sched.get());
+    }
+
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;
@@ -1210,6 +1226,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
+    if (!supports_set_rows) {
+        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+        // overlap with device computation.
+        ggml_backend_sched_reset(sched.get());
+    }
+
     return 0;
 }
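
The reset is issued before the backend sync in both encode() and decode() so that the CPU-side work of resetting the scheduler overlaps with computation still running on the device. A rough standard-C++ analogue of that ordering (illustrative only; the helper names are made up and no ggml calls are involved):

#include <future>

static void cpu_reset_bookkeeping() { /* clear per-step scheduler state */ }

static void step() {
    auto device = std::async(std::launch::async, [] { /* submitted graph computes here */ });
    cpu_reset_bookkeeping(); // runs while the device work is still in flight
    device.wait();           // the "backend sync": results are needed only after this point
}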

src/llama-context.h

Lines changed: 4 additions & 0 deletions
@@ -283,6 +283,10 @@ struct llama_context {
 
     bool has_evaluated_once = false;
 
+    // env: LLAMA_SET_ROWS (temporary)
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    bool supports_set_rows = true;
+
     // env: LLAMA_GRAPH_REUSE_DISABLE
     bool graph_reuse_disable = false;

src/llama-graph.cpp

Lines changed: 4 additions & 0 deletions
@@ -314,6 +314,8 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
     res &= self_kq_mask->ne[0] == mctx->get_n_kv();
     res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
 
+    res &= mctx->get_supports_set_rows(); // TODO: tmp
+
     return res;
 }

@@ -348,6 +350,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
     res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
     res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
 
+    res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
+
     return res;
 }
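
Both can_reuse() hooks make graph reuse conditional on ggml_set_rows() support. A plausible reading (not stated in the diff): the ggml_cpy() fallback in llama-kv-cache.cpp below bakes the slot offset (sinfo.head()) into the K/V view tensors at graph-build time, so a cached graph is only valid for that exact slot, while the set-rows path receives its destination indices as runtime inputs. A stripped-down sketch of the accumulate-and-return check pattern used above (names are hypothetical):

// Hypothetical, simplified shape of the reuse check.
static bool can_reuse_sketch(bool kq_mask_shapes_match, bool supports_set_rows) {
    bool res = true;

    res &= kq_mask_shapes_match; // mask tensors must have the same extents as before
    res &= supports_set_rows;    // TODO: tmp - never reuse graphs on the ggml_cpy() fallback path

    return res;
}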

src/llama-kv-cache.cpp

Lines changed: 84 additions & 20 deletions
@@ -197,6 +197,18 @@ llama_kv_cache::llama_kv_cache(
 
     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows;
+
+    if (!supports_set_rows) {
+        // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+        GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
+    }
+
+    if (!supports_set_rows) {
+        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+    }
 }
 
 void llama_kv_cache::clear(bool data) {
@@ -539,8 +551,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
     bool success = true;
 
     for (const auto & ubatch : ubatches) {
+        // non-continuous slots require support for ggml_set_rows()
+        const bool cont = supports_set_rows ? false : true;
+
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const auto sinfo_new = find_slot(ubatch, true);
+        const auto sinfo_new = find_slot(ubatch, cont);
         if (sinfo_new.empty()) {
             success = false;
             break;
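
The cont flag asks find_slot() for one contiguous run of cells whenever ggml_set_rows() is unavailable, because the ggml_cpy() fallback can only write into a single dense view of the cache. A toy illustration of the difference (this is not the real find_slot(), which also tracks sequences, streams and cell reuse):

#include <cstdint>
#include <vector>

// Select n_tokens free cells from an occupancy map. With cont == true the cells must
// form one contiguous run (what the ggml_cpy() fallback needs); with cont == false any
// free cells will do, because ggml_set_rows() scatters rows to arbitrary indices.
static std::vector<uint32_t> pick_cells(const std::vector<bool> & used, uint32_t n_tokens, bool cont) {
    std::vector<uint32_t> idxs;
    for (uint32_t i = 0; i < used.size() && idxs.size() < n_tokens; ++i) {
        if (used[i]) {
            if (cont) {
                idxs.clear(); // a used cell breaks the contiguous run - start over
            }
            continue;
        }
        idxs.push_back(i);
    }
    if (idxs.size() < n_tokens) {
        idxs.clear(); // empty result means no suitable slot, mirroring sinfo_new.empty()
    }
    return idxs;
}
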
@@ -961,6 +976,10 @@ uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
     return result;
 }
 
+bool llama_kv_cache::get_supports_set_rows() const {
+    return supports_set_rows;
+}
+
 ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
@@ -1014,26 +1033,36 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
 }
 
 ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
-    GGML_UNUSED(sinfo);
-
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
+    const int64_t n_embd_k_gqa = k->ne[0];
     const int64_t n_tokens = k_cur->ne[2];
 
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
-    if (k->ne[2] > 1) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+    if (k_idxs && supports_set_rows) {
+        if (k->ne[2] > 1) {
+            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        }
+
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
-    return ggml_set_rows(ctx, k, k_cur, k_idxs);
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    // will be removed when ggml_set_rows() is adopted by all backends
+
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+
+    ggml_tensor * k_view = ggml_view_1d(ctx, k,
+            n_tokens*n_embd_k_gqa,
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
+
+    return ggml_cpy(ctx, k_cur, k_view);
 }
 
 ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
-    GGML_UNUSED(sinfo);
-
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -1043,25 +1072,48 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
-    if (!v_trans) {
-        if (v->ne[2] > 1) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+    if (v_idxs && supports_set_rows) {
+        if (!v_trans) {
+            if (v->ne[2] > 1) {
+                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            }
+
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
-        return ggml_set_rows(ctx, v, v_cur, v_idxs);
-    }
+        // [TAG_V_CACHE_VARIABLE]
+        if (n_embd_v_gqa < v->ne[0]) {
+            v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+        }
 
-    // [TAG_V_CACHE_VARIABLE]
-    if (n_embd_v_gqa < v->ne[0]) {
-        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
-    // the row becomes a single element
-    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    // will be removed when ggml_set_rows() is adopted by all backends
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
 
-    return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+    ggml_tensor * v_view = nullptr;
+
+    if (!v_trans) {
+        v_view = ggml_view_1d(ctx, v,
+                n_tokens*n_embd_v_gqa,
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
+    } else {
+        v_cur = ggml_transpose(ctx, v_cur);
+
+        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
+                (v->ne[1]    )*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
+    }
+
+    return ggml_cpy(ctx, v_cur, v_view);
 }
 
 ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
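
In the non-transposed fallback paths above, the new K/V rows are written as one dense block starting at the slot head, which is why the fallback needs a contiguous slot and a single stream (see the GGML_ASSERT on n_stream). A plain-C++ analogue of that copy for an F32, row-major cache (an illustration of the addressing, not ggml code; the transposed V case uses a strided 2-D view instead):

#include <cstddef>
#include <cstring>
#include <vector>

// Equivalent of ggml_cpy() into the 1-D view: n_tokens rows of width n_embd are copied
// to consecutive rows of the cache starting at row `head`; in bytes the start offset is
// row_size * head, matching ggml_row_size(type, n_embd)*sinfo.head() above.
static void store_rows_contiguous(std::vector<float> & cache, const std::vector<float> & cur,
                                  std::size_t n_embd, std::size_t n_tokens, std::size_t head) {
    std::memcpy(cache.data() + head * n_embd,
                cur.data(),
                n_tokens * n_embd * sizeof(float)); // one dense block, no per-row indices
}

// ggml_set_rows(), by contrast, takes an index tensor and can place each row at an
// arbitrary position - which is what non-contiguous slots and multiple streams rely on.
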
@@ -1091,6 +1143,10 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama
 }
 
 void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
     const uint32_t n_tokens = ubatch->n_tokens;
     GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
 
@@ -1107,6 +1163,10 @@ void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ub
 }
 
 void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
     const uint32_t n_tokens = ubatch->n_tokens;
     GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
 
@@ -1944,6 +2004,10 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
     return n_kv;
 }
 
+bool llama_kv_cache_context::get_supports_set_rows() const {
+    return kv->get_supports_set_rows();
+}
+
 ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
     return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }

src/llama-kv-cache.h

Lines changed: 10 additions & 0 deletions
@@ -141,6 +141,9 @@ class llama_kv_cache : public llama_memory_i {
 
     uint32_t get_n_kv(const slot_info & sinfo) const;
 
+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
@@ -212,6 +215,10 @@
     // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
+    // env: LLAMA_SET_ROWS (temporary)
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    bool supports_set_rows = true;
+
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
     std::vector<ggml_context_ptr> ctxs;
@@ -311,6 +318,9 @@ class llama_kv_cache_context : public llama_memory_context_i {
 
     uint32_t get_n_kv() const;
 
+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
