@@ -680,14 +680,14 @@ struct llm_graph_context {
     //

     ggml_tensor * build_attn_mha(
-            ggml_tensor * q,     // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k,     // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v,     // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
-            ggml_tensor * kq_b,
-            ggml_tensor * kq_mask,
-            ggml_tensor * sinks,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float         kq_scale) const;
+            ggml_tensor * q,     // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k,     // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v,     // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
+            ggml_tensor * kq_b,
+            ggml_tensor * kq_mask,
+            ggml_tensor * sinks, // [n_head_q]
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float         kq_scale) const;

     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

@@ -699,6 +699,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float         kq_scale,
             int           il) const;
@@ -713,6 +714,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float         kq_scale,
             int           il) const;
@@ -728,21 +730,8 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float         kq_scale,
-            int           il) const;
-
-    // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
-    ggml_tensor * build_attn_with_sinks(
-            llm_graph_input_attn_kv_iswa * inp,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             ggml_tensor * sinks, // [n_head_q]
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float         kq_scale,
             int           il) const;

@@ -756,6 +745,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float         kq_scale,
             int           il) const;
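
Note on the API change above: every `build_attn` overload now takes the `sinks` tensor directly, which is what lets the diff drop the temporary `build_attn_with_sinks` wrapper and its TODO. The `[n_head_q]` shape comment indicates one sink logit per query head; as I understand it, each sink acts as an extra softmax column with no corresponding value row, so it absorbs probability mass without contributing to the attention output. Below is a standalone sketch of that math under exactly that per-head-scalar assumption; it is illustrative C++ only, not the ggml implementation, and `softmax_with_sink` is a made-up helper:

```cpp
// Illustrative sketch only, not ggml code. Assumes "sinks" holds one
// learned logit per query head, as the [n_head_q] shape comment suggests.
#include <algorithm>
#include <cmath>
#include <vector>

// Softmax over one head's KQ scores for a single query position, with the
// head's sink logit appended as a virtual column. The sink takes a share of
// the probability mass but has no value row, so the returned weights sum to
// less than 1 and the attention output is scaled down accordingly.
std::vector<float> softmax_with_sink(const std::vector<float> & kq, float sink) {
    float max_logit = sink;
    for (float s : kq) {
        max_logit = std::max(max_logit, s);
    }

    float denom = std::exp(sink - max_logit); // the sink's share of the mass
    std::vector<float> w(kq.size());
    for (size_t i = 0; i < kq.size(); ++i) {
        w[i]   = std::exp(kq[i] - max_logit);
        denom += w[i];
    }
    for (float & wi : w) {
        wi /= denom;
    }
    return w;
}
```

In the real graph the equivalent adjustment would presumably be applied inside the attention ops themselves; the sketch only shows why the per-row weights no longer sum to 1 once a finite sink logit is present.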