@@ -771,8 +771,8 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
             GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
         }
 
-        res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
-        res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
+        res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
+        res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
 
         res.strm[s] = seq_to_stream[seq_id];
         res.idxs[s].reserve(n_tokens);
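The template-argument fix above is more than cosmetic: `seq_to_stream[seq_id]` yields a stream index, and the new `<uint32_t>` argument matches it, whereas the old `<llama_seq_id>` (a signed 32-bit type) forced both operands through a signed conversion. A minimal sketch of the pitfall, with a hypothetical sentinel value chosen only to show the sign flip:

```cpp
#include <algorithm>
#include <cstdint>

using llama_seq_id = int32_t; // signed, as in llama.h

int main() {
    const uint32_t s0   = UINT32_MAX; // hypothetical "not yet set" sentinel
    const uint32_t strm = 2;          // stream index of the current sequence

    // Old form: the explicit signed template argument converts both operands,
    // so UINT32_MAX becomes -1 on typical targets and the min keeps the sentinel.
    const auto bad  = std::min<llama_seq_id>(s0, strm); // -1 (wrong)

    // New form: the comparison stays unsigned and picks the real stream index.
    const auto good = std::min<uint32_t>(s0, strm);     // 2

    return (bad == -1 && good == 2) ? 0 : 1;
}
```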
@@ -964,11 +964,11 @@ bool llama_kv_cache::get_has_shift() const {
     return result;
 }
 
-uint32_t llama_kv_cache::get_n_kv() const {
+uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
     uint32_t result = 0;
 
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        const auto & cells = v_cells[s];
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        const auto & cells = v_cells[sinfo.strm[s]];
 
         result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
     }
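The reworked loop only visits the streams named by the slot, via `sinfo.strm[s]`, instead of all `n_stream` streams of the cache, so `n_kv` reflects just the streams the current ubatch touches. The per-stream arithmetic itself is unchanged; here is a worked example with hypothetical sizes (`GGML_PAD` is reproduced from ggml.h and rounds up to a multiple of a power of two):

```cpp
#include <algorithm>
#include <cstdint>

// Round x up to a multiple of n (n must be a power of two), as in ggml.h.
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    // Hypothetical stream state: 4096-cell cache, padding of 32,
    // highest used cell index + 1 == 100.
    const uint32_t kv_size     = 4096; // cells.size()
    const uint32_t n_pad       = 32;
    const uint32_t used_max_p1 = 100;  // cells.used_max_p1()

    uint32_t result = 0; // running max over the slot's streams

    // One stream's contribution: 100 pads up to 128, clamped to [n_pad, kv_size].
    result = std::max(std::min(kv_size, std::max(n_pad, GGML_PAD(used_max_p1, n_pad))), result);

    return result == 128 ? 0 : 1;
}
```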
@@ -1017,18 +1017,18 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
         // note: v->nb[1] <= v->nb[2]
         return ggml_view_4d(ctx, v,
                 hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
-                ggml_row_size(v->type, hparams.n_embd_head_v),             // v->nb[1]
-                ggml_row_size(v->type, n_embd_v_gqa),         // v->nb[2]
-                ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
+                ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
+                ggml_row_size(v->type, n_embd_v_gqa),          // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size),  // v->nb[3]
                 ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
     }
 
     // note: v->nb[1] > v->nb[2]
     return ggml_view_4d(ctx, v,
             n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
-            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),     // v->nb[1]
-            ggml_row_size(v->type, kv_size),                           // v->nb[2]
-            ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, kv_size),                       // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa),          // v->nb[3]
             ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
 }
 
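This hunk only realigns the trailing stride comments, but the stride math is worth spelling out: in both layouts `nb[3]` is the byte size of one complete stream, so the final `*sinfo.s0` offsets the view to the slot's first stream. A worked example for the non-transposed branch, under hypothetical dimensions (for F16, `ggml_row_size(type, n)` reduces to `n * 2` bytes):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    // Hypothetical F16 cache: 128-dim value heads, 8 KV heads, 4096 cells.
    const size_t type_size     = 2;                         // bytes per F16
    const size_t n_embd_head_v = 128;
    const size_t n_head_kv     = 8;
    const size_t n_embd_v_gqa  = n_embd_head_v * n_head_kv; // 1024
    const size_t kv_size       = 4096;

    const size_t nb1 = type_size * n_embd_head_v;           // 256 B : one head of one token
    const size_t nb2 = type_size * n_embd_v_gqa;            // 2 KiB : one token, all heads
    const size_t nb3 = type_size * n_embd_v_gqa * kv_size;  // 8 MiB : one full stream

    // A slot starting at stream s0 therefore begins at byte offset s0 * nb3,
    // matching ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0 above.
    printf("nb1 = %zu B, nb2 = %zu B, nb3 = %zu B\n", nb1, nb2, nb3);
    return 0;
}
```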
@@ -1985,8 +1985,7 @@ bool llama_kv_cache_context::apply() {
     }
 
     kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
-
-    n_kv = kv->get_n_kv();
+    n_kv = kv->get_n_kv(sinfos[i_cur]);
 
     return true;
 }
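With the `apply()` change, `n_kv` is recomputed from the slot that was just applied, so the graph's KV extent stays consistent with the streams that slot actually covers. For reference, a simplified sketch of the `slot_info` fields the diff relies on; the member names (`s0`, `s1`, `strm`, `idxs`, `n_stream()`) come from the hunks above, while the container types are assumptions for illustration:

```cpp
#include <cstdint>
#include <vector>

// Simplified sketch of slot_info; the real definition lives in the
// llama.cpp KV-cache sources.
struct slot_info {
    uint32_t s0 = 0; // first stream touched by the ubatch (min over sequences)
    uint32_t s1 = 0; // last stream touched by the ubatch (max over sequences)

    std::vector<uint32_t>              strm; // stream id per sequence set
    std::vector<std::vector<uint32_t>> idxs; // target cell indices per sequence set

    // Number of streams the slot spans; get_n_kv() iterates exactly these.
    uint32_t n_stream() const { return (uint32_t) strm.size(); }
};
```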