Skip to content

Commit

Permalink
[CP-SAT] add automatic cast Literal -> LiteralIndex; implement genera…
Browse files Browse the repository at this point in the history
…l division where the denominator can have any domain that does not contain 0
  • Loading branch information
lperron committed Oct 12, 2023
1 parent bc39e45 commit f1bbd65
Show file tree
Hide file tree
Showing 14 changed files with 278 additions and 252 deletions.
172 changes: 84 additions & 88 deletions ortools/sat/clause.cc

Large diffs are not rendered by default.

17 changes: 8 additions & 9 deletions ortools/sat/clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class LiteralWatchers : public SatPropagator {
// This is exposed since some inprocessing code can heuristically exploit the
// currently watched literal and blocking literal to do some simplification.
const std::vector<Watcher>& WatcherListOnFalse(Literal false_literal) const {
return watchers_on_false_[false_literal.Index()];
return watchers_on_false_[false_literal];
}

private:
Expand Down Expand Up @@ -564,16 +564,16 @@ class BinaryImplicationGraph : public SatPropagator {
// Returns the list of literal "directly" implied by l. Beware that this can
// easily change behind your back if you modify the solver state.
const absl::InlinedVector<Literal, 6>& Implications(Literal l) const {
return implications_[l.Index()];
return implications_[l];
}

// Returns the representative of the equivalence class of l (or l itself if it
// is on its own). Note that DetectEquivalences() should have been called to
// get any non-trival results.
Literal RepresentativeOf(Literal l) const {
if (l.Index() >= representative_of_.size()) return l;
if (representative_of_[l.Index()] == kNoLiteralIndex) return l;
return Literal(representative_of_[l.Index()]);
if (representative_of_[l] == kNoLiteralIndex) return l;
return Literal(representative_of_[l]);
}

// Prunes the implication graph by calling first DetectEquivalences() to
Expand Down Expand Up @@ -637,7 +637,7 @@ class BinaryImplicationGraph : public SatPropagator {
// Note that the set (and thus number) of redundant literal can only grow over
// time. This is because we always use the lowest index as representative of
// an equivalent class, so a redundant literal will stay that way.
bool IsRedundant(Literal l) const { return is_redundant_[l.Index()]; }
bool IsRedundant(Literal l) const { return is_redundant_[l]; }
int64_t num_redundant_literals() const {
CHECK_EQ(num_redundant_literals_ % 2, 0);
return num_redundant_literals_;
Expand Down Expand Up @@ -676,8 +676,7 @@ class BinaryImplicationGraph : public SatPropagator {
// our implications_ database. Except if ComputeTransitiveReduction()
// was aborted early, but in this case, if only one is present, the
// other could be removed, so we shouldn't need to output it.
if (a < b &&
duplicate_detection.insert({a.Index(), b.Index()}).second) {
if (a < b && duplicate_detection.insert({a, b}).second) {
out->AddBinaryClause(a, b);
}
}
Expand Down Expand Up @@ -711,7 +710,7 @@ class BinaryImplicationGraph : public SatPropagator {
// called, and we update it in some situation but we don't deal with fixed
// variables, at_most ones and duplicates implications for now.
int DirectImplicationsEstimatedSize(Literal literal) const {
return estimated_sizes_[literal.Index()];
return estimated_sizes_[literal];
}

// Variable elimination by replacing everything of the form a => var => b by a
Expand All @@ -725,7 +724,7 @@ class BinaryImplicationGraph : public SatPropagator {
int64_t NumImplicationOnVariableRemoval(BooleanVariable var);
void RemoveBooleanVariable(
BooleanVariable var, std::deque<std::vector<Literal>>* postsolve_clauses);
bool IsRemoved(Literal l) const { return is_removed_[l.Index()]; }
bool IsRemoved(Literal l) const { return is_removed_[l]; }

// TODO(user): consider at most ones.
void CleanupAllRemovedVariables();
Expand Down
21 changes: 15 additions & 6 deletions ortools/sat/cp_model_checker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,22 @@ std::string ValidateIntDivConstraint(const CpModelProto& model,
RETURN_IF_NOT_EMPTY(ValidateAffineExpression(model, ct.int_div().exprs(1)));
RETURN_IF_NOT_EMPTY(ValidateAffineExpression(model, ct.int_div().target()));

const LinearExpressionProto& divisor_proto = ct.int_div().exprs(1);
if (MinOfExpression(model, divisor_proto) <= 0 &&
MaxOfExpression(model, divisor_proto) >= 0) {
return absl::StrCat("The divisor cannot span across zero in constraint: ",
ProtobufShortDebugString(ct));
const LinearExpressionProto& denom = ct.int_div().exprs(1);
const int64_t offset = denom.offset();
if (denom.vars().empty()) {
if (offset == 0) {
return absl::StrCat("Division by 0: ", ProtobufShortDebugString(ct));
}
} else {
const int64_t coeff = denom.coeffs(0);
CHECK_NE(coeff, 0);
const int64_t inverse_of_zero = -offset / coeff;
if (inverse_of_zero * coeff + offset == 0 &&
DomainOfRef(model, denom.vars(0)).Contains(inverse_of_zero)) {
return absl::StrCat("The domain of the divisor cannot contain 0: ",
ProtobufShortDebugString(ct));
}
}

return "";
}

Expand Down
2 changes: 1 addition & 1 deletion ortools/sat/docs/scheduling.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ public class OptionalIntervalSampleSat

## NoOverlap constraint

A NoOverlap constraint simply states that all intervals are disjoint. It is
A no_overlap constraint simply states that all intervals are disjoint. It is
built with a list of interval variables. Fixed intervals are useful for
excluding part of the timeline.

Expand Down
18 changes: 9 additions & 9 deletions ortools/sat/integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ void IntegerEncoder::AssociateToIntegerLiteral(Literal literal,
if (new_size > reverse_encoding_.size()) {
reverse_encoding_.resize(new_size);
}
reverse_encoding_[literal.Index()].push_back(canonical_pair.first);
reverse_encoding_[literal].push_back(canonical_pair.first);
reverse_encoding_[literal.NegatedIndex()].push_back(canonical_pair.second);

// Detect the case >= max or <= min and properly register them. Note that
Expand Down Expand Up @@ -440,17 +440,17 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal,
if (value == 1 && domain.Min() >= 0 && domain.Max() <= 1) {
if (literal.Index() >= literal_view_.size()) {
literal_view_.resize(literal.Index().value() + 1, kNoIntegerVariable);
literal_view_[literal.Index()] = var;
} else if (literal_view_[literal.Index()] == kNoIntegerVariable) {
literal_view_[literal.Index()] = var;
literal_view_[literal] = var;
} else if (literal_view_[literal] == kNoIntegerVariable) {
literal_view_[literal] = var;
}
}
if (value == -1 && domain.Min() >= -1 && domain.Max() <= 0) {
if (literal.Index() >= literal_view_.size()) {
literal_view_.resize(literal.Index().value() + 1, kNoIntegerVariable);
literal_view_[literal.Index()] = NegationOf(var);
} else if (literal_view_[literal.Index()] == kNoIntegerVariable) {
literal_view_[literal.Index()] = NegationOf(var);
literal_view_[literal] = NegationOf(var);
} else if (literal_view_[literal] == kNoIntegerVariable) {
literal_view_[literal] = NegationOf(var);
}
}

Expand Down Expand Up @@ -519,7 +519,7 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal,
if (new_size > reverse_equality_encoding_.size()) {
reverse_equality_encoding_.resize(new_size);
}
reverse_equality_encoding_[literal.Index()].push_back({var, value});
reverse_equality_encoding_[literal].push_back({var, value});
}

bool IntegerEncoder::IsFixedOrHasAssociatedLiteral(IntegerLiteral i_lit) const {
Expand Down Expand Up @@ -2118,7 +2118,7 @@ void GenericLiteralWatcher::UpdateCallingNeeds(Trail* trail) {
while (propagation_trail_index_ < trail->Index()) {
const Literal literal = (*trail)[propagation_trail_index_++];
if (literal.Index() >= literal_to_watcher_.size()) continue;
for (const auto entry : literal_to_watcher_[literal.Index()]) {
for (const auto entry : literal_to_watcher_[literal]) {
if (!in_queue_[entry.id]) {
in_queue_[entry.id] = true;
queue_by_priority_[id_to_priority_[entry.id]].push_back(entry.id);
Expand Down
8 changes: 4 additions & 4 deletions ortools/sat/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ class IntegerEncoder {
if (lit.Index() >= reverse_encoding_.size()) {
return empty_integer_literal_vector_;
}
return reverse_encoding_[lit.Index()];
return reverse_encoding_[lit];
}

// Returns the variable == value pairs that were associated with the given
Expand All @@ -575,7 +575,7 @@ class IntegerEncoder {
if (lit.Index() >= reverse_equality_encoding_.size()) {
return empty_integer_value_vector_;
}
return reverse_equality_encoding_[lit.Index()];
return reverse_equality_encoding_[lit];
}

// Returns all the variables for which this literal is associated to either
Expand All @@ -598,7 +598,7 @@ class IntegerEncoder {
// calling AssociateToIntegerEqualValue().
IntegerVariable GetLiteralView(Literal lit) const {
if (lit.Index() >= literal_view_.size()) return kNoIntegerVariable;
return literal_view_[lit.Index()];
return literal_view_[lit];
}

// If this is true, then a literal can be linearized with an affine expression
Expand Down Expand Up @@ -1729,7 +1729,7 @@ inline void GenericLiteralWatcher::WatchLiteral(Literal l, int id,
if (l.Index() >= literal_to_watcher_.size()) {
literal_to_watcher_.resize(l.Index().value() + 1);
}
literal_to_watcher_[l.Index()].push_back({id, watch_index});
literal_to_watcher_[l].push_back({id, watch_index});
}

inline void GenericLiteralWatcher::WatchLowerBound(IntegerVariable var, int id,
Expand Down
86 changes: 53 additions & 33 deletions ortools/sat/integer_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,12 @@
#include "ortools/base/logging.h"
#include "ortools/base/mathutil.h"
#include "ortools/base/stl_util.h"
#include "ortools/base/types.h"
#include "ortools/sat/integer.h"
#include "ortools/sat/linear_constraint.h"
#include "ortools/sat/model.h"
#include "ortools/sat/sat_base.h"
#include "ortools/sat/sat_solver.h"
#include "ortools/sat/util.h"
#include "ortools/util/saturated_arithmetic.h"
#include "ortools/util/sorted_interval_list.h"
#include "ortools/util/strong_integers.h"
#include "ortools/util/time_limit.h"
Expand Down Expand Up @@ -114,7 +112,7 @@ std::pair<IntegerValue, IntegerValue>
LinearConstraintPropagator<use_int128>::ConditionalLb(
IntegerLiteral integer_literal, IntegerVariable target_var) const {
// The code below is wrong if integer_literal and target_var are the same.
// In this case we return the trival bounds.
// In this case we return the trivial bounds.
if (PositiveVariable(integer_literal.var) == PositiveVariable(target_var)) {
if (integer_literal.var == target_var) {
return {kMinIntegerValue, integer_literal.bound};
Expand Down Expand Up @@ -704,7 +702,7 @@ bool LinMinPropagator::PropagateLinearUpperBound(
integer_trail_->RelaxLinearReason(
propagation_slack, reason_coeffs, trail_indices_reason);
}
// Now add the old integer_reason that triggered this propatation.
// Now add the old integer_reason that triggered this propagation.
for (IntegerLiteral reason_lit :
integer_reason_for_unique_candidate_) {
const int index = integer_trail_->FindTrailIndexOfVarBefore(
Expand Down Expand Up @@ -839,7 +837,7 @@ bool ProductPropagator::CanonicalizeCases() {
p_.GreaterOrEqual(0), {a_.GreaterOrEqual(0), b_.GreaterOrEqual(0)});
}

// Otherwise, make sure p is non-negative or accros zero.
// Otherwise, make sure p is non-negative or across zero.
if (integer_trail_->UpperBound(p_) <= 0) {
if (integer_trail_->LowerBound(a_) < 0) {
DCHECK_GT(integer_trail_->UpperBound(a_), 0);
Expand Down Expand Up @@ -1188,75 +1186,94 @@ DivisionPropagator::DivisionPropagator(AffineExpression num,
: num_(num),
denom_(denom),
div_(div),
negated_denom_(denom.Negated()),
negated_num_(num.Negated()),
negated_div_(div.Negated()),
integer_trail_(integer_trail) {
// The denominator can never be zero.
CHECK_GT(integer_trail->LevelZeroLowerBound(denom), 0);
}
integer_trail_(integer_trail) {}

bool DivisionPropagator::Propagate() {
if (!PropagateSigns()) return false;
if (integer_trail_->LowerBound(denom_) < 0 &&
integer_trail_->UpperBound(denom_) > 0) {
return true;
}

AffineExpression num = num_;
AffineExpression negated_num = negated_num_;
AffineExpression denom = denom_;
AffineExpression negated_denom = negated_denom_;

if (integer_trail_->UpperBound(num_) >= 0 &&
if (integer_trail_->UpperBound(denom) < 0) {
std::swap(num, negated_num);
std::swap(denom, negated_denom);
}

if (!PropagateSigns(num, denom, div_)) return false;

if (integer_trail_->UpperBound(num) >= 0 &&
integer_trail_->UpperBound(div_) >= 0 &&
!PropagateUpperBounds(num_, denom_, div_)) {
!PropagateUpperBounds(num, denom, div_)) {
return false;
}

if (integer_trail_->UpperBound(negated_num_) >= 0 &&
if (integer_trail_->UpperBound(negated_num) >= 0 &&
integer_trail_->UpperBound(negated_div_) >= 0 &&
!PropagateUpperBounds(negated_num_, denom_, negated_div_)) {
!PropagateUpperBounds(negated_num, denom, negated_div_)) {
return false;
}

if (integer_trail_->LowerBound(num_) >= 0 &&
if (integer_trail_->LowerBound(num) >= 0 &&
integer_trail_->LowerBound(div_) >= 0) {
return PropagatePositiveDomains(num_, denom_, div_);
return PropagatePositiveDomains(num, denom, div_);
}

if (integer_trail_->UpperBound(num_) <= 0 &&
if (integer_trail_->UpperBound(num) <= 0 &&
integer_trail_->UpperBound(div_) <= 0) {
return PropagatePositiveDomains(negated_num_, denom_, negated_div_);
return PropagatePositiveDomains(negated_num, denom, negated_div_);
}

return true;
}

bool DivisionPropagator::PropagateSigns() {
const IntegerValue min_num = integer_trail_->LowerBound(num_);
const IntegerValue max_num = integer_trail_->UpperBound(num_);
const IntegerValue min_div = integer_trail_->LowerBound(div_);
const IntegerValue max_div = integer_trail_->UpperBound(div_);
bool DivisionPropagator::PropagateSigns(AffineExpression num,
AffineExpression denom,
AffineExpression div) {
const IntegerValue min_num = integer_trail_->LowerBound(num);
const IntegerValue max_num = integer_trail_->UpperBound(num);
const IntegerValue min_div = integer_trail_->LowerBound(div);
const IntegerValue max_div = integer_trail_->UpperBound(div);

// If num >= 0, as denom > 0, then div must be >= 0.
if (min_num >= 0 && min_div < 0) {
if (!integer_trail_->SafeEnqueue(div_.GreaterOrEqual(0),
{num_.GreaterOrEqual(0)})) {
if (!integer_trail_->SafeEnqueue(
div.GreaterOrEqual(0),
{num.GreaterOrEqual(0), denom.GreaterOrEqual(1)})) {
return false;
}
}

// If div > 0, as denom > 0, then num must be > 0.
if (min_num <= 0 && min_div > 0) {
if (!integer_trail_->SafeEnqueue(num_.GreaterOrEqual(1),
{div_.GreaterOrEqual(1)})) {
if (!integer_trail_->SafeEnqueue(
num.GreaterOrEqual(1),
{div.GreaterOrEqual(1), denom.GreaterOrEqual(1)})) {
return false;
}
}

// If num <= 0, as denom > 0, then div must be <= 0.
if (max_num <= 0 && max_div > 0) {
if (!integer_trail_->SafeEnqueue(div_.LowerOrEqual(0),
{num_.LowerOrEqual(0)})) {
if (!integer_trail_->SafeEnqueue(
div.LowerOrEqual(0),
{num.LowerOrEqual(0), denom.GreaterOrEqual(1)})) {
return false;
}
}

// If div < 0, as denom > 0, then num must be < 0.
if (max_num >= 0 && max_div < 0) {
if (!integer_trail_->SafeEnqueue(num_.LowerOrEqual(-1),
{div_.LowerOrEqual(-1)})) {
if (!integer_trail_->SafeEnqueue(
num.LowerOrEqual(-1),
{div.LowerOrEqual(-1), denom.GreaterOrEqual(1)})) {
return false;
}
}
Expand Down Expand Up @@ -1291,6 +1308,7 @@ bool DivisionPropagator::PropagateUpperBounds(AffineExpression num,
if (!integer_trail_->SafeEnqueue(
num.LowerOrEqual(new_max_num),
{integer_trail_->UpperBoundAsLiteral(denom),
denom.GreaterOrEqual(1),
integer_trail_->UpperBoundAsLiteral(div)})) {
return false;
}
Expand All @@ -1314,7 +1332,8 @@ bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num,
if (!integer_trail_->SafeEnqueue(
div.GreaterOrEqual(new_min_div),
{integer_trail_->LowerBoundAsLiteral(num),
integer_trail_->UpperBoundAsLiteral(denom)})) {
integer_trail_->UpperBoundAsLiteral(denom),
denom.GreaterOrEqual(1)})) {
return false;
}
}
Expand Down Expand Up @@ -1342,7 +1361,8 @@ bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num,
if (!integer_trail_->SafeEnqueue(
denom.LowerOrEqual(new_max_denom),
{integer_trail_->UpperBoundAsLiteral(num), num.GreaterOrEqual(0),
integer_trail_->LowerBoundAsLiteral(div)})) {
integer_trail_->LowerBoundAsLiteral(div),
denom.GreaterOrEqual(1)})) {
return false;
}
}
Expand Down
Loading

0 comments on commit f1bbd65

Please sign in to comment.