@@ -197,6 +197,18 @@ llama_kv_cache::llama_kv_cache(

     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows;
+
+    if (!supports_set_rows) {
+        // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+        GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
+    }
+
+    if (!supports_set_rows) {
+        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+    }
 }

 void llama_kv_cache::clear(bool data) {
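The constructor change wires in a runtime kill switch: LLAMA_SET_ROWS=0 forces the old ggml_cpy() write path, while any nonzero value (or leaving the variable unset) keeps the compiled-in default. The same override idiom, as a minimal standalone sketch (the helper name env_flag is illustrative, not part of llama.cpp):

    #include <cstdlib>

    // Override a boolean default from an environment variable:
    // unset -> keep the default; otherwise atoi() the value, so "0" (or a
    // non-numeric string) disables and any nonzero integer enables.
    static bool env_flag(const char * name, bool def) {
        const char * val = std::getenv(name);
        return val ? std::atoi(val) != 0 : def;
    }

    // e.g. supports_set_rows = env_flag("LLAMA_SET_ROWS", supports_set_rows);

The assertion guards the combination that cannot work: the legacy path only supports a single KV stream (see the n_stream == 1 asserts further down), so a non-unified cache requires ggml_set_rows() support.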
@@ -539,8 +551,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
     bool success = true;

     for (const auto & ubatch : ubatches) {
+        // non-continuous slots require support for ggml_set_rows()
+        const bool cont = supports_set_rows ? false : true;
+
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const auto sinfo_new = find_slot(ubatch, true);
+        const auto sinfo_new = find_slot(ubatch, cont);
         if (sinfo_new.empty()) {
             success = false;
             break;
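With scattered writes available, prepare() no longer needs the slot's cells to be adjacent, so find_slot() is asked for a contiguous slot only on the legacy path. The ternary spells that out; the equivalent shorter form is a plain negation:

    // a contiguous slot is required only on the legacy ggml_cpy() path
    const bool cont = !supports_set_rows;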
@@ -961,6 +976,10 @@ uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
     return result;
 }

+bool llama_kv_cache::get_supports_set_rows() const {
+    return supports_set_rows;
+}
+
 ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);

@@ -1014,26 +1033,36 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
 }

 ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
-    GGML_UNUSED(sinfo);
-
     const int32_t ikv = map_layer_ids.at(il);

     auto * k = layers[ikv].k;

+    const int64_t n_embd_k_gqa = k->ne[0];
     const int64_t n_tokens = k_cur->ne[2];

     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);

-    if (k->ne[2] > 1) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+    if (k_idxs && supports_set_rows) {
+        if (k->ne[2] > 1) {
+            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        }
+
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }

-    return ggml_set_rows(ctx, k, k_cur, k_idxs);
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+
+    ggml_tensor * k_view = ggml_view_1d(ctx, k,
+            n_tokens*n_embd_k_gqa,
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
+
+    return ggml_cpy(ctx, k_cur, k_view);
 }

 ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
-    GGML_UNUSED(sinfo);
-
     const int32_t ikv = map_layer_ids.at(il);

     auto * v = layers[ikv].v;
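cpy_k() now dispatches between the two write strategies: ggml_set_rows() scatters each token's K row to an arbitrary cell named by k_idxs, while the fallback copies all rows into one dense window of the buffer starting at cell sinfo.head() (hence the contiguous-slot requirement in prepare()). A toy CPU model of the difference, assuming one row of n_embd floats per KV cell (plain C++, not ggml graph code):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // ggml_set_rows() style: each token's row may land in an arbitrary cell.
    void scatter_rows(std::vector<float> & cache, const std::vector<float> & cur,
                      const std::vector<int64_t> & idxs, int64_t n_embd) {
        for (int64_t t = 0; t < (int64_t) idxs.size(); ++t) {
            std::copy(cur.begin() + t*n_embd, cur.begin() + (t + 1)*n_embd,
                      cache.begin() + idxs[t]*n_embd);
        }
    }

    // legacy ggml_cpy() style: all rows land in the contiguous window
    // [head, head + n_tokens), mirroring the ggml_view_1d() offset above.
    void copy_rows(std::vector<float> & cache, const std::vector<float> & cur,
                   int64_t head, int64_t n_tokens, int64_t n_embd) {
        std::copy(cur.begin(), cur.begin() + n_tokens*n_embd,
                  cache.begin() + head*n_embd);
    }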
@@ -1043,25 +1072,48 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm

     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);

-    if (!v_trans) {
-        if (v->ne[2] > 1) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+    if (v_idxs && supports_set_rows) {
+        if (!v_trans) {
+            if (v->ne[2] > 1) {
+                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            }
+
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }

-        return ggml_set_rows(ctx, v, v_cur, v_idxs);
-    }
+        // [TAG_V_CACHE_VARIABLE]
+        if (n_embd_v_gqa < v->ne[0]) {
+            v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+        }

-    // [TAG_V_CACHE_VARIABLE]
-    if (n_embd_v_gqa < v->ne[0]) {
-        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }

-    // the row becomes a single element
-    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends

-    v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");

-    return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+    ggml_tensor * v_view = nullptr;
+
+    if (!v_trans) {
+        v_view = ggml_view_1d(ctx, v,
+                n_tokens*n_embd_v_gqa,
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
+    } else {
+        v_cur = ggml_transpose(ctx, v_cur);
+
+        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
+                (v->ne[1])*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
+    }
+
+    return ggml_cpy(ctx, v_cur, v_view);
 }

 ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
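The V path adds one wrinkle: when v_trans is set (the non-flash-attention layout), the cache is stored transposed, so a single token's values are strided across the buffer and no destination row corresponds to one token. The code therefore reshapes the cache to 1-element rows and lets ggml_set_rows() place every value independently; the legacy branch instead transposes v_cur and writes through a strided 2-D view. A toy model of the scattered transposed write, assuming a channel-major layout where element (channel j, cell c) sits at flat index j*n_cells + c (that layout is an assumption for the sketch, not shown in this diff):

    #include <cstdint>
    #include <vector>

    // Transposed V cache: one token's n_embd values are strided across the
    // buffer, so each value is written as its own 1-element "row".
    void scatter_v_trans(std::vector<float> & cache, const std::vector<float> & v_cur,
                         const std::vector<int64_t> & cells,
                         int64_t n_embd, int64_t n_cells) {
        for (int64_t t = 0; t < (int64_t) cells.size(); ++t) {
            for (int64_t j = 0; j < n_embd; ++j) {
                cache[j*n_cells + cells[t]] = v_cur[t*n_embd + j];
            }
        }
    }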
@@ -1091,6 +1143,10 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama
 }

 void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
     const uint32_t n_tokens = ubatch->n_tokens;
     GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());

@@ -1107,6 +1163,10 @@ void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ub
 }

 void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
     const uint32_t n_tokens = ubatch->n_tokens;
     GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());

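Both set_input_* helpers fill the I64 index tensors that ggml_set_rows() consumes on the device; on the legacy path those tensors are never read by the graph, so the functions return early instead of touching them. For the straight (non-transposed) case the fill is essentially one destination cell per token; a simplified single-stream sketch (the real code iterates over sinfo's streams and cell indices):

    #include <cstdint>

    // Token i's K row goes to KV cell cells[i]; dst points at the data of
    // the I64 k_idxs tensor created by build_input_k_idxs().
    void fill_k_idxs(int64_t * dst, const uint32_t * cells, uint32_t n_tokens) {
        for (uint32_t i = 0; i < n_tokens; ++i) {
            dst[i] = cells[i];
        }
    }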
@@ -1944,6 +2004,10 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
     return n_kv;
 }

+bool llama_kv_cache_context::get_supports_set_rows() const {
+    return kv->get_supports_set_rows();
+}
+
 ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
     return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }
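The context-level getter simply forwards the flag so graph-building code can decide whether to create the index inputs at all. A hypothetical call-site sketch (the graph code is outside this diff; the variable kvc and the control flow are illustrative):

    // Build k_idxs only when the ggml_set_rows() path is active; passing a
    // null index tensor makes cpy_k() take the legacy ggml_cpy() branch.
    ggml_tensor * k_idxs = kvc->get_supports_set_rows()
            ? kvc->build_input_k_idxs(ctx, ubatch)
            : nullptr;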