@@ -421,11 +421,38 @@ struct llm_graph_params {
421421 // TODO: temporary
422422 llm_graph_result_i * res;
423423
424- bool is_same (const llm_graph_params & other) const {
424+ // return true if the "other" params would result in a graph with the same topology as with the current params
425+ // having the same topology allows us to reuse the graph in some cases
426+ bool allow_reuse (const llm_graph_params & other) const {
427+ // first check the ubatch
428+ bool can_reuse_ubatch =
429+ ubatch.equal_seqs == other.ubatch .equal_seqs &&
430+ ubatch.n_tokens == other.ubatch .n_tokens &&
431+ ubatch.n_seq_tokens == other.ubatch .n_seq_tokens &&
432+ ubatch.n_seqs == other.ubatch .n_seqs &&
433+ ubatch.n_seqs_unq == other.ubatch .n_seqs_unq &&
434+ (
435+ (!ubatch.token && !other.ubatch .token ) ||
436+ (!ubatch.embd && !other.ubatch .embd )
437+ );
438+
439+ // TODO: this won't work because seq_id_unq ptr can point to an old balloc that has
440+ // been freed by this point. find a way to fix this
441+ // for (uint32_t s = 0; s < n_seqs_unq; ++s) {
442+ // can_reuse_ubatch &= seq_id_unq[s] == other.seq_id_unq[s];
443+ // }
444+
445+ // for now conservatively disallow, until the issue above is resolved
446+ // ref: https://github.com/ggml-org/llama.cpp/pull/14363
447+ can_reuse_ubatch = can_reuse_ubatch && !ubatch.equal_seqs ;
448+
449+ if (!can_reuse_ubatch) {
450+ return false ;
451+ }
452+
425453 return
426- hparams.is_same (other.hparams ) &&
427- cparams.is_same (other.cparams ) &&
428- ubatch .is_same (other.ubatch ) &&
454+ cparams.embeddings == other.cparams .embeddings &&
455+ cparams.causal_attn == other.cparams .causal_attn &&
429456 arch == other.arch &&
430457 gtype == other.gtype &&
431458 cvec == other.cvec &&
@@ -488,7 +515,7 @@ class llm_graph_result : public llm_graph_result_i {
488515 // contexts of the input tensors of the graph and we can reuse it for another computation
489516 // return true if the graph was updated and can be reused
490517 bool can_reuse (const llm_graph_params & params) override {
491- if (!this ->params .is_same (params)) {
518+ if (!this ->params .allow_reuse (params)) {
492519 return false ;
493520 }
494521
0 commit comments