@@ -282,7 +282,7 @@ bool llama_batch_allocr::init(
282282 }
283283 }
284284
285- // disallow disjoint sequence sets:
285+ // disallow partial sequence sub- sets:
286286 //
287287 // invalid: x
288288 // i: 0 1 2 ...
@@ -291,28 +291,46 @@ bool llama_batch_allocr::init(
291291 // seq_id[i][1]: 1 1 2
292292 // seq_id[i][2]: 2
293293 //
294+ // disallow decreasing sequence positions:
295+ //
296+ // invalid: x
297+ // i: 0 1 2 3 4 5 6 ...
298+ // ---------------------------------------
299+ // pos[i]: 4 5 0 1 6 2 3
300+ // seq_id[i][0]: 0 0 1 1 0 1 0
301+ //
294302 {
295303 seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
296304 for (int32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
297305 cur_seq_set[s].set ();
298306 }
299307
308+ llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
309+ for (int32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
310+ cur_seq_pos[s] = -1 ;
311+ }
312+
300313 for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
314+ const llama_pos pos = batch.pos [i];
315+
301316 for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
302317 const llama_seq_id seq_id = batch.seq_id [i][s];
303318
304319 cur_seq_set[seq_id] &= seq_set[i];
305320
306321 if (cur_seq_set[seq_id].none ()) {
307- LLAMA_LOG_ERROR (" %s: sequence %d belongs to incompatible sequence sets\n " , __func__, seq_id);
322+ LLAMA_LOG_ERROR (" %s: sequence %d belongs to incompatible sequence sets (not allowed)\n " , __func__, seq_id);
323+ return false ;
324+ }
325+
326+ if (pos < cur_seq_pos[seq_id]) {
327+ LLAMA_LOG_ERROR (" %s: sequence %d positions are decreasing (not allowed)\n " , __func__, seq_id);
308328 return false ;
309329 }
310330 }
311331 }
312332 }
313333
314- // TODO: check that positions are increasing
315-
316334 split_reset ();
317335
318336 return true ;
0 commit comments