Skip to content

Commit

Permalink
bug fixes to sls
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Nov 17, 2024
1 parent e380903 commit c7ea496
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 139 deletions.
121 changes: 23 additions & 98 deletions src/ast/sls/sls_arith_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ namespace sls {
}
}



template<typename num_t>
std::ostream& arith_base<num_t>::ineq::display(std::ostream& out) const {
bool first = true;
Expand Down Expand Up @@ -118,7 +116,7 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::save_best_values() {
for (auto& v : m_vars)
v.m_best_value = v.m_value;
v.set_best_value(v.value());
check_ineqs();
}

Expand Down Expand Up @@ -168,8 +166,8 @@ namespace sls {
template<typename num_t>
num_t arith_base<num_t>::dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const {
for (auto const& [coeff, w] : ineq.m_args)
if (w == v)
return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq);
if (w == v)
return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].value()), ineq);
return num_t(1);
}

Expand Down Expand Up @@ -444,17 +442,19 @@ namespace sls {

delta_out = delta;

if (m_last_var == v && m_last_delta == -delta)
return false;
if (m_last_var == v && m_last_delta == -delta)
return false;

if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta))
if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta))
return false;


auto old_value = value(v);
auto new_value = old_value + delta;
if (!vi.in_range(new_value))
return false;


if (m_use_tabu && !in_bounds(v, new_value) && in_bounds(v, old_value)) {
auto const& lo = m_vars[v].m_lo;
auto const& hi = m_vars[v].m_hi;
Expand Down Expand Up @@ -490,9 +490,7 @@ namespace sls {
void arith_base<num_t>::add_update(var_t v, num_t delta) {
num_t delta_out;
if (!is_permitted_update(v, delta, delta_out))
return;


return;
m_updates.push_back({ v, delta_out, 0 });
}

Expand Down Expand Up @@ -647,7 +645,7 @@ namespace sls {
bool arith_base<num_t>::update(var_t v, num_t const& new_value) {
auto& vi = m_vars[v];
expr* e = vi.m_expr;
auto old_value = vi.m_value;
auto old_value = vi.value();
if (old_value == new_value)
return true;
if (!vi.in_range(new_value))
Expand All @@ -665,15 +663,10 @@ namespace sls {
}
}
catch (overflow_exception const&) {
verbose_stream() << "overflow1\n";
return false;
}

#if 0
if (!check_update(v, new_value))
return false;
apply_checked_update();
#else

buffer<sat::bool_var> to_flip;
for (auto const& [coeff, bv] : vi.m_bool_vars) {
auto& ineq = *atom(bv);
Expand All @@ -687,12 +680,13 @@ namespace sls {

}
IF_VERBOSE(5, verbose_stream() << "repair: v" << v << " := " << old_value << " -> " << new_value << "\n");
vi.m_value = new_value;
vi.set_value(new_value);
ctx.new_value_eh(e);
m_last_var = v;

for (auto bv : to_flip) {
ctx.flip(bv);
if (dtt(sign(bv), *atom(bv)) != 0)
ctx.flip(bv);
SASSERT(dtt(sign(bv), *atom(bv)) == 0);
}

Expand All @@ -711,6 +705,7 @@ namespace sls {
prod *= power_of(value(w), p);
}
catch (overflow_exception const&) {
verbose_stream() << "overflow\n";
return false;
}
if (value(w) != prod && !update(w, prod))
Expand All @@ -727,82 +722,10 @@ namespace sls {
if (!update(ad.m_var, sum))
return false;
}
#endif

return true;
}

template<typename num_t>
bool arith_base<num_t>::check_update(var_t v, num_t new_value) {

++m_update_timestamp;
if (m_update_timestamp == 0) {
for (auto& vi : m_vars)
vi.set_update_value(num_t(0), 0);
++m_update_timestamp;
}
auto& vi = m_vars[v];
m_update_trail.reset();
m_update_trail.push_back(v);
vi.set_update_value(new_value, m_update_timestamp);

num_t delta;
for (unsigned i = 0; i < m_update_trail.size(); ++i) {
auto v = m_update_trail[i];
auto& vi = m_vars[v];
for (auto idx : vi.m_muls) {
auto const& [w, monomial] = m_muls[idx];
num_t prod(1);
try {
for (auto [w, p] : monomial)
prod *= power_of(get_update_value(w), p);
}
catch (overflow_exception const&) {
return false;
}
if (get_update_value(w) != prod && (!is_permitted_update(w, prod - value(w), delta) || prod - value(w) != delta))
return false;
m_update_trail.push_back(w);
m_vars[w].set_update_value(prod, m_update_timestamp);
}

for (auto idx : vi.m_adds) {
auto const& ad = m_adds[idx];
auto w = ad.m_var;
num_t sum(ad.m_coeff);
for (auto const& [coeff, w] : ad.m_args)
sum += coeff * get_update_value(w);
if (get_update_value(v) != sum && !(is_permitted_update(w, sum - value(w), delta) || sum - value(w) != delta))
return false;
m_update_trail.push_back(w);
m_vars[w].set_update_value(sum, m_update_timestamp);
}
}
return true;
}

template<typename num_t>
void arith_base<num_t>::apply_checked_update() {
for (auto v : m_update_trail) {
auto & vi = m_vars[v];
auto old_value = vi.m_value;
vi.m_value = vi.get_update_value(m_update_timestamp);
auto new_value = vi.m_value;
ctx.new_value_eh(vi.m_expr);
for (auto const& [coeff, bv] : vi.m_bool_vars) {
auto& ineq = *atom(bv);
bool old_sign = sign(bv);
sat::literal lit(bv, old_sign);
SASSERT(ctx.is_true(lit));
ineq.m_args_value += coeff * (new_value - old_value);
num_t dtt_new = dtt(old_sign, ineq);
if (dtt_new != 0)
ctx.flip(bv);
SASSERT(dtt(sign(bv), ineq) == 0);
}
}
}


template<typename num_t>
typename arith_base<num_t>::ineq& arith_base<num_t>::new_ineq(ineq_kind op, num_t const& coeff) {
auto* i = alloc(ineq);
Expand Down Expand Up @@ -906,7 +829,7 @@ namespace sls {
m_vars[w].m_muls.push_back(idx), prod *= power_of(value(w), p);
m_vars[v].m_def_idx = idx;
m_vars[v].m_op = arith_op_kind::OP_MUL;
m_vars[v].m_value = prod;
m_vars[v].set_value(prod);
add_arg(term, coeff, v);
break;
}
Expand Down Expand Up @@ -972,7 +895,7 @@ namespace sls {
m_ops.push_back({v, k, v, w});
m_vars[v].m_def_idx = idx;
m_vars[v].m_op = k;
m_vars[v].m_value = val;
m_vars[v].set_value(val);
return v;
}

Expand All @@ -993,7 +916,7 @@ namespace sls {
m_vars[w].m_adds.push_back(idx), sum += c * value(w);
m_vars[v].m_def_idx = idx;
m_vars[v].m_op = arith_op_kind::OP_ADD;
m_vars[v].m_value = sum;
m_vars[v].set_value(sum);
return v;
}

Expand Down Expand Up @@ -1055,6 +978,7 @@ namespace sls {
else {
SASSERT(!a.is_arith_expr(e));
}

}

template<typename num_t>
Expand Down Expand Up @@ -1345,6 +1269,7 @@ namespace sls {
hi_valid = false;
}
catch (overflow_exception&) {
verbose_stream() << "overflow3\n";
hi_valid = false;
}
}
Expand Down Expand Up @@ -2021,7 +1946,7 @@ namespace sls {
if (is_num(e, n))
return expr_ref(a.mk_numeral(n.to_rational(), a.is_int(e)), m);
auto v = mk_term(e);
return expr_ref(a.mk_numeral(m_vars[v].m_value.to_rational(), a.is_int(e)), m);
return expr_ref(a.mk_numeral(m_vars[v].value().to_rational(), a.is_int(e)), m);
}

template<typename num_t>
Expand Down Expand Up @@ -2112,7 +2037,7 @@ namespace sls {
auto const& vi = m_vars[v];
auto const& lo = vi.m_lo;
auto const& hi = vi.m_hi;
out << "v" << v << " := " << vi.m_value << " ";
out << "v" << v << " := " << vi.value() << " ";
if (lo || hi) {
if (lo)
out << (lo->is_strict ? "(": "[") << lo->value;
Expand Down
56 changes: 33 additions & 23 deletions src/ast/sls/sls_arith_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,14 @@ namespace sls {

class var_info {
num_t m_range{ 100000000 };
num_t m_update_value{ 0 };
unsigned m_update_timestamp = 0;
unsigned m_num_out_of_range = 0;
unsigned m_num_in_range = 0;
num_t m_value{ 0 };
num_t m_best_value{ 0 };
public:
var_info(expr* e, var_sort k): m_expr(e), m_sort(k) {}
expr* m_expr;
num_t m_value{ 0 };
num_t m_best_value{ 0 };

var_sort m_sort;
arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP;
unsigned m_def_idx = UINT_MAX;
Expand All @@ -91,23 +92,27 @@ namespace sls {
unsigned_vector m_adds;
optional<bound> m_lo, m_hi;

// retrieve temporary value during an update.
void set_update_value(num_t const& v, unsigned timestamp) {
m_update_value = v;
m_update_timestamp = timestamp;
}
num_t const& get_update_value(unsigned ts) const {
return ts == m_update_timestamp ? m_update_value : m_value;
}
num_t const& value() const { return m_value; }
void set_value(num_t const& v) { m_value = v; }

num_t const& best_value() const { return m_best_value; }
void set_best_value(num_t const& v) { m_best_value = v; }

bool in_range(num_t const& n) const {
bool in_range(num_t const& n) {
if (-m_range < n && n < m_range)
return true;
bool result = false;
if (m_lo && !m_hi)
return n < m_lo->value + m_range;
if (!m_lo && m_hi)
return n > m_hi->value - m_range;
return false;
result = n < m_lo->value + m_range;
else if (!m_lo && m_hi)
result = n > m_hi->value - m_range;
#if 0
if (!result)
out_of_range();
else
++m_num_in_range;
#endif
return result;
}
unsigned m_tabu_pos = 0, m_tabu_neg = 0;
unsigned m_last_pos = 0, m_last_neg = 0;
Expand All @@ -120,6 +125,15 @@ namespace sls {
else
m_tabu_neg = tabu_step, m_last_neg = step;
}
void out_of_range() {
++m_num_out_of_range;
if (m_num_out_of_range < 1000 * (1 + m_num_in_range))
return;
IF_VERBOSE(2, verbose_stream() << "increase range " << m_range << "\n");
m_range *= 2;
m_num_out_of_range = 0;
m_num_in_range = 0;
}
};

struct mul_def {
Expand Down Expand Up @@ -187,10 +201,7 @@ namespace sls {

void add_update(var_t v, num_t delta);
bool is_permitted_update(var_t v, num_t const& delta, num_t& delta_out);
unsigned m_update_timestamp = 0;
svector<var_t> m_update_trail;
bool check_update(var_t v, num_t new_value);
void apply_checked_update();


num_t value1(var_t v);

Expand Down Expand Up @@ -247,8 +258,7 @@ namespace sls {

bool is_int(var_t v) const { return m_vars[v].m_sort == var_sort::INT; }

num_t value(var_t v) const { return m_vars[v].m_value; }
num_t const& get_update_value(var_t v) const { return m_vars[v].get_update_value(m_update_timestamp); }
num_t value(var_t v) const { return m_vars[v].value(); }
bool is_num(expr* e, num_t& i);
expr_ref from_num(sort* s, num_t const& n);
void check_ineqs();
Expand Down
10 changes: 3 additions & 7 deletions src/ast/sls/sls_arith_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace sls {
return m_arith64->_fn_;\
}\
catch (overflow_exception&) {\
throw;\
IF_VERBOSE(1, verbose_stream() << "revert to bignum solver " << #_fn_ << "\n");\
init_backup();\
}\
}\
Expand All @@ -39,7 +39,7 @@ namespace sls {
m_arith64->_fn_;\
}\
catch (overflow_exception&) {\
throw;\
IF_VERBOSE(1, verbose_stream() << "revert to bignum solver " << #_fn_ << "\n");\
init_backup();\
}\
}\
Expand All @@ -49,11 +49,7 @@ namespace sls {
plugin(ctx), m_shared(ctx.get_manager()) {
m_arith64 = alloc(arith_base<checked_int64<true>>, ctx);
m_arith = alloc(arith_base<rational>, ctx);
m_arith64 = nullptr;
if (m_arith)
m_fid = m_arith->fid();
else
m_fid = m_arith64->fid();
m_fid = m_arith->fid();
}

void arith_plugin::init_backup() {
Expand Down
Loading

0 comments on commit c7ea496

Please sign in to comment.