From 5025c3c8a95fd12c23d485943e8e42607ef43c50 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 22 Nov 2024 13:59:11 -0800 Subject: [PATCH 1/3] add rewrites for hoisting constants from ite expressions Signed-off-by: Nikolaj Bjorner --- src/ast/rewriter/arith_rewriter.cpp | 66 ++++++++++++++++++++++++++++ src/ast/rewriter/arith_rewriter.h | 1 + src/ast/rewriter/poly_rewriter_def.h | 2 + src/ast/rewriter/th_rewriter.cpp | 10 ++++- 4 files changed, 77 insertions(+), 2 deletions(-) diff --git a/src/ast/rewriter/arith_rewriter.cpp b/src/ast/rewriter/arith_rewriter.cpp index b67e873c002..93949f4db49 100644 --- a/src/ast/rewriter/arith_rewriter.cpp +++ b/src/ast/rewriter/arith_rewriter.cpp @@ -804,6 +804,72 @@ bool arith_rewriter::is_arith_term(expr * n) const { return n->get_kind() == AST_APP && to_app(n)->get_family_id() == get_fid(); } +br_status arith_rewriter::mk_ite_core(expr* c, expr* t, expr* e, expr_ref & result) { + numeral v1, v2; + bool is_int; + bool is_num1 = m_util.is_numeral(t, v1, is_int); + bool is_num2 = m_util.is_numeral(e, v2, is_int); + if (is_num1 && is_num2 && v1 == 0 && v2 != 1) { + result = m_util.mk_mul(e, m.mk_ite(c, t, m_util.mk_numeral(rational(1), is_int))); + return BR_DONE; + } + if (is_num1 && is_num2 && v2 == 0 && v1 != 1) { + result = m_util.mk_mul(t, m.mk_ite(c, m_util.mk_numeral(rational(1), is_int), e)); + return BR_DONE; + } + if (is_num1 && is_num2 && is_int && gcd(v1, v2) != 1) { + auto g = gcd(v1, v2); + if (g > 0 && v1 < 0 && v2 < 0) + g = -g; + + result = m_util.mk_numeral(g, is_int); + result = m_util.mk_mul(result, m.mk_ite(c, m_util.mk_numeral(v1/g, true), m_util.mk_numeral(v2/g, true))); + return BR_REWRITE2; + } + if (is_num1 && is_num2 && v1 != 0 && v2 != 0 && v1 != v2) { + if (v1 > v2) + result = m_util.mk_add(e, m.mk_ite(c, m_util.mk_numeral(v1 - v2, is_int), m_util.mk_numeral(rational::zero(), is_int))); + else + result = m_util.mk_add(e, m.mk_ite(c, m_util.mk_numeral(rational::zero(), is_int), m_util.mk_numeral(v2 - v1, is_int))); + return BR_DONE; + } + expr* x, *y; + if (is_num1 && m_util.is_mul(e, x, y) && m_util.is_numeral(x, v2, is_int) && v2 != 0) { + if (v1 == 0) { + result = m_util.mk_mul(x, m.mk_ite(c, t, y)); + return BR_DONE; + } + if (is_int && divides(v2, v1)) { + result = m_util.mk_mul(x, m.mk_ite(c, m_util.mk_numeral(v1/v2, true), y)); + return BR_DONE; + } + + } + if (is_num2 && m_util.is_mul(t, x, y) && m_util.is_numeral(x, v1, is_int) && v1 != 0) { + if (v2 == 0) { + result = m_util.mk_mul(x, m.mk_ite(c, y, e)); + return BR_DONE; + } + if (is_int && divides(v1, v2)) { + result = m_util.mk_mul(x, m.mk_ite(c, y, m_util.mk_numeral(v2/v1, true))); + return BR_DONE; + } + + } + if (is_num1 && m_util.is_add(e, x, y) && m_util.is_numeral(x, v2, is_int)) { + result = m_util.mk_add(x, m.mk_ite(c, m_util.mk_numeral(v1 - v2, is_int), y)); + return BR_REWRITE2; + } + if (is_num2 && m_util.is_add(t, x, y) && m_util.is_numeral(x, v1, is_int)) { + result = m_util.mk_add(x, m.mk_ite(c, y, m_util.mk_numeral(v2 - v1, is_int))); + return BR_REWRITE2; + } + + + + return BR_FAILED; +} + br_status arith_rewriter::mk_eq_core(expr * arg1, expr * arg2, expr_ref & result) { br_status st = BR_FAILED; if (m_eq2ineq) { diff --git a/src/ast/rewriter/arith_rewriter.h b/src/ast/rewriter/arith_rewriter.h index a1aadfa7f10..cfdd1e58f87 100644 --- a/src/ast/rewriter/arith_rewriter.h +++ b/src/ast/rewriter/arith_rewriter.h @@ -137,6 +137,7 @@ class arith_rewriter : public poly_rewriter { br_status mk_lt_core(expr * arg1, expr * arg2, expr_ref & result); br_status mk_ge_core(expr * arg1, expr * arg2, expr_ref & result); br_status mk_gt_core(expr * arg1, expr * arg2, expr_ref & result); + br_status mk_ite_core(expr* c, expr* t, expr* e, expr_ref & result); br_status mk_add_core(unsigned num_args, expr * const * args, expr_ref & result); br_status mk_mul_core(unsigned num_args, expr * const * args, expr_ref & result); diff --git a/src/ast/rewriter/poly_rewriter_def.h b/src/ast/rewriter/poly_rewriter_def.h index f739579e6f9..a2c6b2a2f6b 100644 --- a/src/ast/rewriter/poly_rewriter_def.h +++ b/src/ast/rewriter/poly_rewriter_def.h @@ -1017,7 +1017,9 @@ bool poly_rewriter::hoist_ite(expr_ref& e) { ++i; } if (!pinned.empty()) { + TRACE("poly_rewriter", tout << e << "\n"); e = mk_add_app(adds.size(), adds.data()); + TRACE("poly_rewriter", tout << e << "\n"); return true; } return false; diff --git a/src/ast/rewriter/th_rewriter.cpp b/src/ast/rewriter/th_rewriter.cpp index 3af887008a8..50f0d3c6546 100644 --- a/src/ast/rewriter/th_rewriter.cpp +++ b/src/ast/rewriter/th_rewriter.cpp @@ -172,6 +172,9 @@ struct th_rewriter_cfg : public default_rewriter_cfg { family_id s_fid = args[1]->get_sort()->get_family_id(); if (s_fid == m_bv_rw.get_fid()) st = m_bv_rw.mk_ite_core(args[0], args[1], args[2], result); + if (st == BR_FAILED && s_fid == m_a_rw.get_fid()) + st = m_a_rw.mk_ite_core(args[0], args[1], args[2], result); + CTRACE("th_rewriter_step", st != BR_FAILED, tout << result << "\n"); if (st != BR_FAILED) return st; } @@ -197,7 +200,9 @@ struct th_rewriter_cfg : public default_rewriter_cfg { return st; } - return m_b_rw.mk_app_core(f, num, args, result); + st = m_b_rw.mk_app_core(f, num, args, result); + CTRACE("th_rewriter_step", st != BR_FAILED, tout << result << "\n"); + return st; } if (fid == m_a_rw.get_fid() && OP_LE == f->get_decl_kind() && m_seq_rw.u().has_seq()) { st = m_seq_rw.mk_le_core(args[0], args[1], result); @@ -315,7 +320,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { return pull_ite_core(f, to_app(args[1]), to_app(args[0]), result); } family_id fid = f->get_family_id(); - if (num == 2 && (fid == m().get_basic_family_id() || fid == m_a_rw.get_fid() || fid == m_bv_rw.get_fid())) { + if (num == 2 && (fid == m().get_basic_family_id() || fid == m_bv_rw.get_fid())) { // (f v3 (ite c v1 v2)) --> (ite v (f v3 v1) (f v3 v2)) if (m().is_value(args[0]) && is_ite_value_tree(args[1])) return pull_ite_core(f, to_app(args[1]), to_app(args[0]), result); @@ -554,6 +559,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { result = m().mk_app(f_prime, common, m().mk_ite(c, new_t, new_e)); else result = m().mk_app(f_prime, m().mk_ite(c, new_t, new_e), common); + TRACE("push_ite", tout << result << "\n";); return BR_DONE; } TRACE("push_ite", tout << "failed\n";); From 0bf9369414ec577a8eb2e4345ae99e1c7e618178 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 22 Nov 2024 14:56:41 -0800 Subject: [PATCH 2/3] fix bug in rewriter Signed-off-by: Nikolaj Bjorner --- src/ast/rewriter/arith_rewriter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ast/rewriter/arith_rewriter.cpp b/src/ast/rewriter/arith_rewriter.cpp index 93949f4db49..e868a6eb204 100644 --- a/src/ast/rewriter/arith_rewriter.cpp +++ b/src/ast/rewriter/arith_rewriter.cpp @@ -830,7 +830,7 @@ br_status arith_rewriter::mk_ite_core(expr* c, expr* t, expr* e, expr_ref & resu if (v1 > v2) result = m_util.mk_add(e, m.mk_ite(c, m_util.mk_numeral(v1 - v2, is_int), m_util.mk_numeral(rational::zero(), is_int))); else - result = m_util.mk_add(e, m.mk_ite(c, m_util.mk_numeral(rational::zero(), is_int), m_util.mk_numeral(v2 - v1, is_int))); + result = m_util.mk_add(t, m.mk_ite(c, m_util.mk_numeral(rational::zero(), is_int), m_util.mk_numeral(v2 - v1, is_int))); return BR_DONE; } expr* x, *y; From 65bfcec146132193669b3902dceeabb35510e18f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 22 Nov 2024 16:56:08 -0800 Subject: [PATCH 3/3] resurrect rewriting of equality over ite Signed-off-by: Nikolaj Bjorner --- src/ast/rewriter/arith_rewriter.cpp | 14 ++++++++++++++ src/ast/rewriter/th_rewriter.cpp | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/ast/rewriter/arith_rewriter.cpp b/src/ast/rewriter/arith_rewriter.cpp index e868a6eb204..b1782484359 100644 --- a/src/ast/rewriter/arith_rewriter.cpp +++ b/src/ast/rewriter/arith_rewriter.cpp @@ -740,6 +740,20 @@ br_status arith_rewriter::mk_le_ge_eq_core(expr * arg1, expr * arg2, op_kind kin case EQ: result = m.mk_ite(c, m.mk_eq(t, arg2), m.mk_eq(e, arg2)); return BR_REWRITE2; } } + if (m.is_ite(arg2, c, t, e) && is_numeral(t, a2) && is_numeral(arg1, a1)) { + switch (kind) { + case LE: result = a1 <= a2 ? m.mk_or(c, m_util.mk_le(arg1, e)) : m.mk_and(m.mk_not(c), m_util.mk_le(arg1, e)); return BR_REWRITE2; + case GE: result = a1 >= a2 ? m.mk_or(c, m_util.mk_ge(arg1, e)) : m.mk_and(m.mk_not(c), m_util.mk_ge(arg1, e)); return BR_REWRITE2; + case EQ: result = a1 == a2 ? m.mk_or(c, m.mk_eq(e, arg1)) : m.mk_and(m.mk_not(c), m_util.mk_eq(e, arg1)); return BR_REWRITE2; + } + } + if (m.is_ite(arg2, c, t, e) && is_numeral(e, a2) && is_numeral(arg1, a1)) { + switch (kind) { + case LE: result = a1 <= a2 ? m.mk_or(m.mk_not(c), m_util.mk_le(arg1, t)) : m.mk_and(c, m_util.mk_le(arg1, t)); return BR_REWRITE2; + case GE: result = a1 >= a2 ? m.mk_or(m.mk_not(c), m_util.mk_ge(arg1, e)) : m.mk_and(c, m_util.mk_ge(arg1, t)); return BR_REWRITE2; + case EQ: result = a1 == a2 ? m.mk_or(m.mk_not(c), m.mk_eq(t, arg1)) : m.mk_and(c, m_util.mk_eq(t, arg1)); return BR_REWRITE2; + } + } if (m_util.is_to_int(arg2) && is_numeral(arg1)) { kind = inv(kind); std::swap(arg1, arg2); diff --git a/src/ast/rewriter/th_rewriter.cpp b/src/ast/rewriter/th_rewriter.cpp index 50f0d3c6546..87ab739fcde 100644 --- a/src/ast/rewriter/th_rewriter.cpp +++ b/src/ast/rewriter/th_rewriter.cpp @@ -320,7 +320,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { return pull_ite_core(f, to_app(args[1]), to_app(args[0]), result); } family_id fid = f->get_family_id(); - if (num == 2 && (fid == m().get_basic_family_id() || fid == m_bv_rw.get_fid())) { + if (num == 2 && (fid == m().get_basic_family_id())) { // (f v3 (ite c v1 v2)) --> (ite v (f v3 v1) (f v3 v2)) if (m().is_value(args[0]) && is_ite_value_tree(args[1])) return pull_ite_core(f, to_app(args[1]), to_app(args[0]), result);