@@ -191,21 +191,23 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
 
         uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
+        memset(cls->data, 0, n_tokens*ggml_element_size(cls));
 
-        std::vector<int> last_pos(n_seqs_unq, -1);
-        std::vector<int> last_row(n_seqs_unq, -1);
+        std::vector<int> last_pos(n_tokens, -1);
+        std::vector<int> last_row(n_tokens, -1);
 
-        for (int i = 0; i < n_tokens; ++i) {
-            const llama_pos pos = ubatch->pos[i];
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
-            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
-                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
+            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
 
-                if (pos >= last_pos[seq_idx]) {
-                    last_pos[seq_idx] = pos;
-                    last_row[seq_idx] = i;
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
+
+                if (pos >= last_pos[seq_id]) {
+                    last_pos[seq_id] = pos;
+                    last_row[seq_id] = s*n_seq_tokens + i;
                 }
            }
        }
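
Note on the hunk above: on both sides of the change, the loop is selecting, for each sequence, the ubatch row that holds its last token (the largest position); that row index is what later fills the cls input for LLAMA_POOLING_TYPE_LAST. A minimal standalone sketch of the same selection logic, using hypothetical plain vectors in place of the ubatch fields:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical flattened batch: one seq id and one position per token
    const std::vector<int32_t> seq_id = {0, 0, 1, 1, 1};
    const std::vector<int32_t> pos    = {0, 1, 0, 1, 2};
    const int n_tokens = (int) pos.size();

    std::vector<int32_t> last_pos(n_tokens, -1); // highest position seen per sequence
    std::vector<int32_t> last_row(n_tokens, -1); // batch row holding that token

    for (int i = 0; i < n_tokens; ++i) {
        const int32_t s = seq_id[i];
        if (pos[i] >= last_pos[s]) {
            last_pos[s] = pos[i];
            last_row[s] = i;
        }
    }

    // rows 1 and 4 hold the last tokens of sequences 0 and 1
    printf("seq 0 -> row %d, seq 1 -> row %d\n", last_row[0], last_row[1]);
    return 0;
}
```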
@@ -228,8 +230,8 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mctx->s_copy(i);
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            data[i] = kv_state->s_copy(i);
         }
     }
 }
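
For context, this loop just materializes the recurrent-state copy map into the s_copy input: entry i names the cell whose state should end up in slot i of the [head, head + n) window, and the graph later uses it as gather indices. A toy CPU-side sketch of applying such a map (hypothetical values, not the real cache layout):

```cpp
#include <cstdio>
#include <vector>

int main() {
    // hypothetical per-cell states, one float per cell for brevity
    const std::vector<float> states = {10.f, 11.f, 12.f, 13.f};

    // copy map over the active window: slot i takes its state from cell s_copy[i];
    // an identity entry (s_copy[i] == i) means "keep the state in place"
    const std::vector<int> s_copy = {2, 1, 0, 3};

    std::vector<float> gathered(states.size());
    for (size_t i = 0; i < s_copy.size(); ++i) {
        gathered[i] = states[s_copy[i]];
    }

    for (float v : gathered) printf("%g ", v); // 12 11 10 13
    printf("\n");
    return 0;
}
```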
@@ -962,7 +964,24 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 
     auto & cur = inp->cls;
 
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_s_copy() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
+
+    const auto n_kv = kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
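
Both builders follow the same input pattern: allocate a 1-D I32 tensor in the graph context, mark it with ggml_set_input() so its data is expected to be filled externally, and register the input object so set_input() can populate it right before evaluation. A minimal sketch of that pattern against the public ggml API (the context setup is simplified relative to the real backend scheduler, and the identity fill only stands in for set_input()):

```cpp
#include "ggml.h"

#include <cstdint>
#include <cstdio>

int main() {
    // small CPU-side context; the real code allocates into ctx0 owned by llm_graph_context
    ggml_init_params params = {16*1024*1024, nullptr, false};
    ggml_context * ctx = ggml_init(params);

    const int64_t n_kv = 4;

    // the s_copy / cls pattern: a 1-D I32 tensor whose data is written
    // from the ubatch right before the graph is evaluated
    ggml_tensor * idx = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_kv);
    ggml_set_input(idx);

    // set_input() would then do the equivalent of:
    int32_t * data = (int32_t *) idx->data;
    for (int64_t i = 0; i < n_kv; ++i) {
        data[i] = (int32_t) i; // identity copy map as a placeholder
    }

    printf("filled %lld entries of %zu bytes each\n", (long long) n_kv, ggml_element_size(idx));

    ggml_free(ctx);
    return 0;
}
```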
@@ -1425,19 +1444,27 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_copy_mask_state(
+ggml_tensor * llm_graph_context::build_recurrent_state(
          ggml_cgraph * gf,
          ggml_tensor * s,
          ggml_tensor * state_copy,
-         ggml_tensor * state_mask,
-             int32_t   n_state,
-             int32_t   n_seqs) const {
+             int32_t   state_size,
+             int32_t   n_seqs,
+                bool   avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto n_kv    = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+    ggml_tensor * output_states;
 
     // copy states
     // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
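
The (rs_zero >= 0) factors in the added lines implement a branch-free conditional clear: when there is no state to reset, both the view's element count and its byte offset collapse to zero, so the in-place scale-by-0 touches nothing. A self-contained sketch of the same trick on a toy 2-D state tensor, assuming the CPU backend helpers from ggml.h / ggml-cpu.h (including ggml_graph_compute_with_ctx) are available:

```cpp
#include "ggml.h"
#include "ggml-cpu.h"

#include <cstdio>

int main() {
    ggml_init_params params = {16*1024*1024, nullptr, false};
    ggml_context * ctx = ggml_init(params);

    const int64_t state_size = 4; // toy recurrent-state width
    const int64_t n_cells    = 3; // toy number of cache cells

    // states laid out one row per cell, as after the ggml_reshape_2d above
    ggml_tensor * states = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, state_size, n_cells);
    float * d = (float *) states->data;
    for (int i = 0; i < state_size*n_cells; ++i) {
        d[i] = 1.0f;
    }

    const int32_t rs_zero = 1; // cell to clear; -1 would mean "nothing to clear"

    // zero-sized view when rs_zero < 0, so the scale below becomes a no-op
    ggml_tensor * state_zero = ggml_view_1d(ctx, states,
            state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx, state_zero, 0.0f));
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    // row 1 is now all zeros, rows 0 and 2 are untouched
    for (int r = 0; r < n_cells; ++r) {
        for (int c = 0; c < state_size; ++c) {
            printf("%.0f ", d[r*state_size + c]);
        }
        printf("\n");
    }

    ggml_free(ctx);
    return 0;
}
```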
@@ -1448,7 +1475,8 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
     // FIXME: zero-out NANs?
     states = ggml_mul(ctx0, states, state_mask);
 
-    // copy states which won't be changed further (between n_seqs and n_kv)
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
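
The states_extra line slices the tail of the copy map with a byte-offset 1-D view (entries n_seqs..n_kv) and gathers the corresponding rows, so states that are not consumed by this ubatch still get written back. A small sketch of that view-plus-get_rows pattern with toy sizes, under the same CPU-backend assumptions as above:

```cpp
#include "ggml.h"
#include "ggml-cpu.h"

#include <cstdint>
#include <cstdio>

int main() {
    ggml_init_params params = {16*1024*1024, nullptr, false};
    ggml_context * ctx = ggml_init(params);

    const int64_t n_kv   = 4; // cells in the active window
    const int64_t n_seqs = 2; // cells consumed by the current ubatch

    // toy copy map and toy states (one value per cell)
    ggml_tensor * state_copy = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_kv);
    ggml_tensor * states     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_kv);
    for (int i = 0; i < n_kv; ++i) {
        ((int32_t *) state_copy->data)[i] = (int32_t) (n_kv - 1 - i); // reversing map
        ((float   *) states->data)[i]     = (float) (10 + i);
    }

    // keep only entries n_seqs..n_kv of the copy map via a byte-offset view,
    // then gather the corresponding source rows, as done for states_extra
    ggml_tensor * copy_extra   = ggml_view_1d(ctx, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]);
    ggml_tensor * states_extra = ggml_get_rows(ctx, states, copy_extra);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, states_extra);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    // prints "11 10": the rows picked by the last two copy-map entries (1 and 0)
    for (int i = 0; i < n_kv - n_seqs; ++i) {
        printf("%.0f ", ((float *) states_extra->data)[i]);
    }
    printf("\n");

    ggml_free(ctx);
    return 0;
}
```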
@@ -1457,47 +1485,10 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-
-    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
-
-    const auto n_rs = mctx_cur->get_n_rs();
-
-    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
-    ggml_set_input(inp->s_copy);
-
-    return (llm_graph_input_rs *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_rs(
-    llm_graph_input_rs * inp,
-           ggml_cgraph * gf,
-           ggml_tensor * s,
-               int32_t   state_size,
-               int32_t   n_seqs,
-                  bool   avoid_copies) const {
-    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
-}
-
-ggml_tensor * llm_graph_context::build_rs(
-    llm_graph_input_mem_hybrid * inp,
-                   ggml_cgraph * gf,
-                   ggml_tensor * s,
-                       int32_t   state_size,
-                       int32_t   n_seqs,
-                          bool   avoid_copies) const {
-    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
-
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
-}
-
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-    llm_graph_input_rs * inp,
-           ggml_cgraph * gf,
-    const llama_ubatch & ubatch,
+         ggml_cgraph * gf,
+         ggml_tensor * state_copy,
+  const llama_ubatch & ubatch,
     int il) const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
@@ -1507,9 +1498,9 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
-    ggml_tensor * token_shift = build_rs(
-            inp, gf, token_shift_all,
-            hparams.n_embd_r(), n_seqs);
+    ggml_tensor * token_shift = build_recurrent_state(
+            gf, token_shift_all, state_copy,
+            hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
 
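
For orientation, the token-shift state handled here is one flat vector of size n_embd * token_shift_count per sequence; after the recurrent-state load, the final ggml_reshape_3d merely re-exposes that per-shift structure for each sequence. A toy sketch of the reshape with hypothetical sizes, not tied to any real hparams:

```cpp
#include "ggml.h"

#include <cstdio>

int main() {
    ggml_init_params params = {16*1024*1024, nullptr, false};
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd            = 8;
    const int64_t token_shift_count = 2; // e.g. attention + ffn shift states
    const int64_t n_seqs            = 3;

    // what the recurrent-state load hands back: one flat row per sequence
    ggml_tensor * token_shift = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd*token_shift_count, n_seqs);

    // re-expose the per-shift structure, as in build_rwkv_token_shift_load
    ggml_tensor * shaped = ggml_reshape_3d(ctx, token_shift, n_embd, token_shift_count, n_seqs);

    printf("reshaped to [%lld, %lld, %lld]\n",
            (long long) shaped->ne[0], (long long) shaped->ne[1], (long long) shaped->ne[2]);

    ggml_free(ctx);
    return 0;
}
```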