From 81084b8232cd9e2ec85d0d96e8d34ce3b16e86d3 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 1 Apr 2022 13:07:17 -0700 Subject: [PATCH] #5778 #5937 --- src/model/datatype_factory.cpp | 2 +- src/sat/smt/array_model.cpp | 162 ++++++++++++++++++++++----------- src/sat/smt/array_solver.h | 29 ++++++ src/sat/smt/sat_th.h | 1 + src/smt/theory_arith.h | 4 +- src/smt/theory_arith_core.h | 53 ++++++++--- 6 files changed, 184 insertions(+), 67 deletions(-) diff --git a/src/model/datatype_factory.cpp b/src/model/datatype_factory.cpp index 7ca3b5699bb..e58812a1f45 100644 --- a/src/model/datatype_factory.cpp +++ b/src/model/datatype_factory.cpp @@ -217,7 +217,7 @@ expr * datatype_factory::get_fresh_value(sort * s) { expr * maybe_new_arg = nullptr; if (!m_util.is_datatype(s_arg)) maybe_new_arg = m_model.get_fresh_value(s_arg); - else if (num_iterations <= 1 || s == s_arg) + else if (num_iterations <= 1 || m_util.is_recursive(s_arg)) maybe_new_arg = get_almost_fresh_value(s_arg); else maybe_new_arg = get_fresh_value(s_arg); diff --git a/src/sat/smt/array_model.cpp b/src/sat/smt/array_model.cpp index 1f1066121bf..cb9ffe229d2 100644 --- a/src/sat/smt/array_model.cpp +++ b/src/sat/smt/array_model.cpp @@ -24,6 +24,11 @@ namespace array { void solver::init_model() { collect_defaults(); + collect_selects(); + } + + void solver::finalize_model(model& mdl) { + std::for_each(m_selects_range.begin(), m_selects_range.end(), delete_proc()); } bool solver::add_dep(euf::enode* n, top_sort& dep) { @@ -103,17 +108,15 @@ namespace array { if (!get_else(v) && fi->get_else()) set_else(v, fi->get_else()); - - for (euf::enode* p : euf::enode_parents(n)) { - if (a.is_select(p->get_expr()) && p->get_arg(0)->get_root() == n) { - expr* value = values.get(p->get_root_id(), nullptr); - if (!value || value == fi->get_else()) - continue; - args.reset(); - for (unsigned i = 1; i < p->num_args(); ++i) - args.push_back(values.get(p->get_arg(i)->get_root_id())); - fi->insert_entry(args.data(), value); - } + + for (euf::enode* p : *get_select_set(n)) { + expr* value = values.get(p->get_root_id(), nullptr); + if (!value || value == fi->get_else()) + continue; + args.reset(); + for (unsigned i = 1; i < p->num_args(); ++i) + args.push_back(values.get(p->get_arg(i)->get_root_id())); + fi->insert_entry(args.data(), value); } TRACE("array", tout << "array-as-function " << ctx.bpp(n) << " := " << mk_pp(f, m) << "\n" << "default " << mk_pp(fi->get_else(), m) << "\n";); @@ -135,52 +138,103 @@ namespace array { return true; return false; -#if 0 - struct eq { - solver& s; - eq(solver& s) :s(s) {} - bool operator()(euf::enode* n1, euf::enode* n2) const { - SASSERT(s.a.is_select(n1->get_expr())); - SASSERT(s.a.is_select(n2->get_expr())); - for (unsigned i = n1->num_args(); i-- > 1; ) - if (n1->get_arg(i)->get_root() != n2->get_arg(i)->get_root()) - return false; - return true; - } - }; - struct hash { - solver& s; - hash(solver& s) :s(s) {} - unsigned operator()(euf::enode* n) const { - SASSERT(s.a.is_select(n->get_expr())); - unsigned h = 33; - for (unsigned i = n->num_args(); i-- > 1; ) - h = hash_u_u(h, n->get_arg(i)->get_root_id()); - return h; + } + + unsigned solver::sel_hash::operator()(euf::enode * n) const { + return get_composite_hash(n, n->num_args() - 1, sel_khasher(), sel_chasher()); + } + + bool solver::sel_eq::operator()(euf::enode * n1, euf::enode * n2) const { + SASSERT(n1->num_args() == n2->num_args()); + unsigned num_args = n1->num_args(); + for (unsigned i = 1; i < num_args; i++) + if (n1->get_arg(i)->get_root() != n2->get_arg(i)->get_root()) + return false; + return true; + } + + + void solver::collect_selects() { + int num_vars = get_num_vars(); + + m_selects.reset(); + m_selects_domain.reset(); + m_selects_range.reset(); + + for (theory_var v = 0; v < num_vars; ++v) { + euf::enode * r = var2enode(v)->get_root(); + if (is_representative(v) && ctx.is_relevant(r)) { + for (euf::enode * parent : euf::enode_parents(r)) { + if (parent->get_cg() == parent && + ctx.is_relevant(parent) && + a.is_select(parent->get_expr()) && + parent->get_arg(0)->get_root() == r) { + select_set * s = get_select_set(r); + SASSERT(!s->contains(parent) || (*(s->find(parent)))->get_root() == parent->get_root()); + s->insert(parent); + } + } } - }; - eq eq_proc(*this); - hash hash_proc(*this); - hashtable table(DEFAULT_HASHTABLE_INITIAL_CAPACITY, hash_proc, eq_proc); - euf::enode* p2 = nullptr; - auto maps_diff = [&](euf::enode* p, euf::enode* else_, euf::enode* r) { - return table.find(p, p2) ? p2->get_root() != r : (else_ && else_ != r); - }; - auto table_diff = [&](euf::enode* r1, euf::enode* r2, euf::enode* else1) { - table.reset(); - for (euf::enode* p : euf::enode_parents(r1)) - if (a.is_select(p->get_expr()) && r1 == p->get_arg(0)->get_root()) - table.insert(p); - for (euf::enode* p : euf::enode_parents(r2)) - if (a.is_select(p->get_expr()) && r2 == p->get_arg(0)->get_root()) - if (maps_diff(p, else1, p->get_root())) - return true; - return false; - }; + } + euf::enode_pair_vector todo; + for (euf::enode * r : m_selects_domain) + for (euf::enode* sel : *get_select_set(r)) + propagate_select_to_store_parents(r, sel, todo); + for (unsigned qhead = 0; qhead < todo.size(); qhead++) { + euf::enode_pair & pair = todo[qhead]; + euf::enode * r = pair.first; + euf::enode * sel = pair.second; + propagate_select_to_store_parents(r, sel, todo); + } + } + + void solver::propagate_select_to_store_parents(euf::enode* r, euf::enode* sel, euf::enode_pair_vector& todo) { + SASSERT(r->get_root() == r); + SASSERT(a.is_select(sel->get_expr())); + if (!ctx.is_relevant(r)) + return; - return table_diff(r1, r2, else1) || table_diff(r2, r1, else2); + for (euf::enode * parent : euf::enode_parents(r)) { + if (ctx.is_relevant(parent) && + a.is_store(parent->get_expr()) && + parent->get_arg(0)->get_root() == r) { + // propagate upward + select_set * parent_sel_set = get_select_set(parent); + euf::enode * parent_root = parent->get_root(); + + if (parent_sel_set->contains(sel)) + continue; -#endif + SASSERT(sel->num_args() + 1 == parent->num_args()); + + // check whether the sel idx was overwritten by the store + unsigned num_args = sel->num_args(); + unsigned i = 1; + for (; i < num_args; i++) { + if (sel->get_arg(i)->get_root() != parent->get_arg(i)->get_root()) + break; + } + + if (i < num_args) { + SASSERT(!parent_sel_set->contains(sel) || (*(parent_sel_set->find(sel)))->get_root() == sel->get_root()); + parent_sel_set->insert(sel); + todo.push_back(std::make_pair(parent_root, sel)); + } + } + } + } + + solver::select_set* solver::get_select_set(euf::enode* n) { + euf::enode * r = n->get_root(); + select_set * set = nullptr; + m_selects.find(r, set); + if (set == nullptr) { + set = alloc(select_set); + m_selects.insert(r, set); + m_selects_domain.push_back(r); + m_selects_range.push_back(set); + } + return set; } void solver::collect_defaults() { diff --git a/src/sat/smt/array_solver.h b/src/sat/smt/array_solver.h index 31bdba4a19c..511f971a362 100644 --- a/src/sat/smt/array_solver.h +++ b/src/sat/smt/array_solver.h @@ -218,11 +218,39 @@ namespace array { void pop_core(unsigned n) override; // models + // I need a set of select enodes where select(A,i) = select(B,j) if i->get_root() == j->get_root() + struct sel_khasher { + unsigned operator()(euf::enode const * n) const { return 0; } + }; + + struct sel_chasher { + unsigned operator()(euf::enode const * n, unsigned idx) const { + return n->get_arg(idx+1)->get_root()->hash(); + } + }; + + struct sel_hash { + unsigned operator()(euf::enode * n) const; + }; + + struct sel_eq { + bool operator()(euf::enode * n1, euf::enode * n2) const; + }; + + typedef ptr_hashtable select_set; euf::enode_vector m_defaults; // temporary field for model construction ptr_vector m_else_values; // svector m_parents; // temporary field for model construction + obj_map m_selects; // mapping from array -> relevant selects + ptr_vector m_selects_domain; + ptr_vector m_selects_range; + bool must_have_different_model_values(theory_var v1, theory_var v2); + select_set* get_select_set(euf::enode* n); void collect_defaults(); + void collect_selects(); // mapping from array -> relevant selects + void propagate_select_to_store_parents(euf::enode* r, euf::enode* sel, euf::enode_pair_vector& todo); + void mg_merge(theory_var u, theory_var v); theory_var mg_find(theory_var n); void set_default(theory_var v, euf::enode* n); @@ -254,6 +282,7 @@ namespace array { void new_diseq_eh(euf::th_eq const& eq) override; bool unit_propagate() override; void init_model() override; + void finalize_model(model& mdl) override; bool include_func_interp(func_decl* f) const override { return a.is_ext(f); } void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; bool add_dep(euf::enode* n, top_sort& dep) override; diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index b2d8d85b783..dbd042e9896 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -188,6 +188,7 @@ namespace euf { enode* expr2enode(expr* e) const; enode* var2enode(theory_var v) const { return m_var2enode[v]; } expr* var2expr(theory_var v) const { return var2enode(v)->get_expr(); } + bool is_representative(theory_var v) const { return v == get_representative(v); } expr* bool_var2expr(sat::bool_var v) const; expr_ref literal2expr(sat::literal lit) const; enode* bool_var2enode(sat::bool_var v) const { expr* e = bool_var2expr(v); return e ? expr2enode(e) : nullptr; } diff --git a/src/smt/theory_arith.h b/src/smt/theory_arith.h index 31acc2be0d7..6e1d77dd427 100644 --- a/src/smt/theory_arith.h +++ b/src/smt/theory_arith.h @@ -602,9 +602,11 @@ namespace smt { void add_row_entry(unsigned r_id, numeral const & coeff, theory_var v); uint_set& row_vars(); class scoped_row_vars; - + + void check_app(expr* e, expr* n); void internalize_internal_monomial(app * m, unsigned r_id); theory_var internalize_add(app * n); + theory_var internalize_sub(app * n); theory_var internalize_mul_core(app * m); theory_var internalize_mul(app * m); theory_var internalize_div(app * n); diff --git a/src/smt/theory_arith_core.h b/src/smt/theory_arith_core.h index 4a2963656c9..0168652cb16 100644 --- a/src/smt/theory_arith_core.h +++ b/src/smt/theory_arith_core.h @@ -302,6 +302,44 @@ namespace smt { } } + template + void theory_arith::check_app(expr* e, expr* n) { + if (is_app(e)) + return; + std::ostringstream strm; + strm << mk_pp(n, m) << " contains a " << (is_var(e) ? "free variable":"quantifier"); + throw default_exception(strm.str()); + } + + + template + theory_var theory_arith::internalize_sub(app * n) { + VERIFY(m_util.is_sub(n)); + bool first = true; + unsigned r_id = mk_row(); + scoped_row_vars _sc(m_row_vars, m_row_vars_top); + theory_var v; + for (expr* arg : *n) { + check_app(arg, n); + v = internalize_term_core(to_app(arg)); + if (first) + add_row_entry(r_id, numeral::one(), v); + else + add_row_entry(r_id, numeral::one(), v); + first = false; + } + enode * e = mk_enode(n); + v = e->get_th_var(get_id()); + if (v == null_theory_var) { + v = mk_var(e); + add_row_entry(r_id, numeral::one(), v); + init_row(r_id); + } + else + del_row(r_id); + return v; + } + /** \brief Internalize a polynomial (+ h t). Return an alias for the monomial, that is, a variable v such that v = (+ h t) is a new row in the tableau. @@ -314,11 +352,7 @@ namespace smt { unsigned r_id = mk_row(); scoped_row_vars _sc(m_row_vars, m_row_vars_top); for (expr* arg : *n) { - if (is_var(arg)) { - std::ostringstream strm; - strm << mk_pp(n, m) << " contains a free variable"; - throw default_exception(strm.str()); - } + check_app(arg, n); internalize_internal_monomial(to_app(arg), r_id); } enode * e = mk_enode(n); @@ -383,11 +417,7 @@ namespace smt { } unsigned r_id = mk_row(); scoped_row_vars _sc(m_row_vars, m_row_vars_top); - if (is_var(arg1)) { - std::ostringstream strm; - strm << mk_pp(m, get_manager()) << " contains a free variable"; - throw default_exception(strm.str()); - } + check_app(arg1, m); if (reflection_enabled()) internalize_term_core(to_app(arg0)); theory_var v = internalize_mul_core(to_app(arg1)); @@ -749,7 +779,6 @@ namespace smt { return e->get_th_var(get_id()); } - SASSERT(!m_util.is_sub(n)); SASSERT(!m_util.is_uminus(n)); if (m_util.is_add(n)) @@ -770,6 +799,8 @@ namespace smt { return internalize_to_int(n); else if (m_util.is_numeral(n)) return internalize_numeral(n); + else if (m_util.is_sub(n)) + return internalize_sub(n); if (m_util.is_power(n)) { // unsupported found_unsupported_op(n);