Skip to content

Commit

Permalink
[CP-SAT] more work on python layer
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Dec 29, 2024
1 parent 39248b0 commit 2b22356
Show file tree
Hide file tree
Showing 3 changed files with 316 additions and 212 deletions.
310 changes: 150 additions & 160 deletions ortools/sat/python/linear_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ LinearExpr* LinearExpr::Constant(int64_t value) {
return new IntConstant(value);
}

LinearExpr* LinearExpr::Add(LinearExpr* other) {
return new BinaryAdd(this, other);
LinearExpr* LinearExpr::Add(LinearExpr* expr) {
return new BinaryAdd(this, expr);
}

LinearExpr* LinearExpr::AddInt(int64_t cst) {
Expand All @@ -203,40 +203,40 @@ LinearExpr* LinearExpr::AddDouble(double cst) {
return new FloatAffine(this, 1.0, cst);
}

LinearExpr* LinearExpr::Sub(ExprOrValue other) {
if (other.expr != nullptr) {
return new IntWeightedSum({this, other.expr}, {1, -1}, 0);
} else if (other.double_value != 0.0) {
return new FloatAffine(this, 1.0, -other.double_value);
} else if (other.int_value != 0) {
return new IntAffine(this, 1, -other.int_value);
} else {
return this;
}
LinearExpr* LinearExpr::Sub(LinearExpr* expr) {
return new IntWeightedSum({this, expr}, {1, -1}, 0);
}

LinearExpr* LinearExpr::RSub(ExprOrValue other) {
if (other.expr != nullptr) {
return new IntWeightedSum({this, other.expr}, {-1, 1}, 0);
} else if (other.double_value != 0.0) {
return new FloatAffine(this, -1.0, other.double_value);
} else {
return new IntAffine(this, -1, other.int_value);
}
LinearExpr* LinearExpr::SubInt(int64_t cst) {
if (cst == 0) return this;
return new IntAffine(this, 1, -cst);
}

LinearExpr* LinearExpr::Mul(double cst) {
if (cst == 0.0) return new IntConstant(0);
if (cst == 1.0) return this;
return new FloatAffine(this, cst, 0.0);
LinearExpr* LinearExpr::SubDouble(double cst) {
if (cst == 0.0) return this;
return new FloatAffine(this, 1.0, -cst);
}

LinearExpr* LinearExpr::Mul(int64_t cst) {
LinearExpr* LinearExpr::RSubInt(int64_t cst) {
return new IntAffine(this, -1, cst);
}

LinearExpr* LinearExpr::RSubDouble(double cst) {
return new FloatAffine(this, -1.0, cst);
}

LinearExpr* LinearExpr::MulInt(int64_t cst) {
if (cst == 0) return new IntConstant(0);
if (cst == 1) return this;
return new IntAffine(this, cst, 0);
}

LinearExpr* LinearExpr::MulDouble(double cst) {
if (cst == 0.0) return new IntConstant(0);
if (cst == 1.0) return this;
return new FloatAffine(this, cst, 0.0);
}

LinearExpr* LinearExpr::Neg() { return new IntAffine(this, -1, 0); }

void FloatExprVisitor::AddToProcess(LinearExpr* expr, double coeff) {
Expand Down Expand Up @@ -449,152 +449,142 @@ std::string FloatAffine::DebugString() const {
return absl::StrCat("FloatAffine(expr=", expr_->DebugString(),
", coeff=", coeff_, ", offset=", offset_, ")");
}
BoundedLinearExpression* LinearExpr::Eq(LinearExpr* rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(rhs, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(vars, coeffs, offset, Domain(0));
}

BoundedLinearExpression* LinearExpr::Eq(ExprOrValue other) {
if (other.double_value != 0.0) return nullptr;
if (other.expr != nullptr) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(other.expr, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(vars, coeffs, offset, Domain(0));
} else {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(vars, coeffs, offset,
Domain(other.int_value));
}
BoundedLinearExpression* LinearExpr::EqCst(int64_t rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(vars, coeffs, offset, Domain(rhs));
}

BoundedLinearExpression* LinearExpr::Ne(ExprOrValue other) {
if (other.double_value != 0.0) return nullptr;
if (other.expr != nullptr) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(other.expr, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(vars, coeffs, offset,
Domain(0).Complement());
} else {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(vars, coeffs, offset,
Domain(other.int_value).Complement());
}
BoundedLinearExpression* LinearExpr::Ne(LinearExpr* rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(rhs, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(vars, coeffs, offset,
Domain(0).Complement());
}

BoundedLinearExpression* LinearExpr::Le(ExprOrValue other) {
if (other.double_value != 0.0) return nullptr;
if (other.expr != nullptr) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(other.expr, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset, Domain(std::numeric_limits<int64_t>::min(), 0));
} else {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset,
Domain(std::numeric_limits<int64_t>::min(), other.int_value));
}
BoundedLinearExpression* LinearExpr::NeCst(int64_t rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(vars, coeffs, offset,
Domain(rhs).Complement());
}

BoundedLinearExpression* LinearExpr::Lt(ExprOrValue other) {
if (other.double_value != 0.0) return nullptr;
if (other.expr != nullptr) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(other.expr, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset, Domain(std::numeric_limits<int64_t>::min(), -1));
} else {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset,
Domain(std::numeric_limits<int64_t>::min(), other.int_value - 1));
}
BoundedLinearExpression* LinearExpr::Le(LinearExpr* rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(rhs, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset, Domain(std::numeric_limits<int64_t>::min(), 0));
}

BoundedLinearExpression* LinearExpr::Ge(ExprOrValue other) {
if (other.double_value != 0.0) return nullptr;
if (other.expr != nullptr) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(other.expr, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset, Domain(0, std::numeric_limits<int64_t>::max()));
} else {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset,
Domain(other.int_value, std::numeric_limits<int64_t>::max()));
}
BoundedLinearExpression* LinearExpr::LeCst(int64_t rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset, Domain(std::numeric_limits<int64_t>::min(), rhs));
}

BoundedLinearExpression* LinearExpr::Gt(ExprOrValue other) {
if (other.double_value != 0.0) return nullptr;
if (other.expr != nullptr) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(other.expr, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset, Domain(1, std::numeric_limits<int64_t>::max()));
} else {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset,
Domain(other.int_value + 1, std::numeric_limits<int64_t>::max()));
}
BoundedLinearExpression* LinearExpr::Lt(LinearExpr* rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(rhs, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset, Domain(std::numeric_limits<int64_t>::min(), -1));
}

BoundedLinearExpression* LinearExpr::LtCst(int64_t rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset,
Domain(std::numeric_limits<int64_t>::min(), rhs - 1));
}

BoundedLinearExpression* LinearExpr::Ge(LinearExpr* rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(rhs, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset, Domain(0, std::numeric_limits<int64_t>::max()));
}

BoundedLinearExpression* LinearExpr::GeCst(int64_t rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset, Domain(rhs, std::numeric_limits<int64_t>::max()));
}

BoundedLinearExpression* LinearExpr::Gt(LinearExpr* rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
lin.AddToProcess(rhs, -1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset, Domain(1, std::numeric_limits<int64_t>::max()));
}

BoundedLinearExpression* LinearExpr::GtCst(int64_t rhs) {
IntExprVisitor lin;
lin.AddToProcess(this, 1);
std::vector<BaseIntVar*> vars;
std::vector<int64_t> coeffs;
int64_t offset;
if (!lin.Process(&vars, &coeffs, &offset)) return nullptr;
return new BoundedLinearExpression(
vars, coeffs, offset,
Domain(rhs + 1, std::numeric_limits<int64_t>::max()));
}

void IntExprVisitor::AddToProcess(LinearExpr* expr, int64_t coeff) {
Expand Down
31 changes: 20 additions & 11 deletions ortools/sat/python/linear_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,30 @@ class LinearExpr {
static LinearExpr* Constant(int64_t value);
static LinearExpr* Constant(double value);

LinearExpr* Add(LinearExpr* other);
LinearExpr* Add(LinearExpr* expr);
LinearExpr* AddInt(int64_t cst);
LinearExpr* AddDouble(double cst);
LinearExpr* Sub(ExprOrValue other);
LinearExpr* RSub(ExprOrValue other);
LinearExpr* Mul(double cst);
LinearExpr* Mul(int64_t cst);
LinearExpr* Sub(LinearExpr* expr);
LinearExpr* SubInt(int64_t cst);
LinearExpr* SubDouble(double cst);
LinearExpr* RSubInt(int64_t cst);
LinearExpr* RSubDouble(double cst);
LinearExpr* MulInt(int64_t cst);
LinearExpr* MulDouble(double cst);
LinearExpr* Neg();

BoundedLinearExpression* Eq(ExprOrValue other);
BoundedLinearExpression* Ne(ExprOrValue other);
BoundedLinearExpression* Ge(ExprOrValue other);
BoundedLinearExpression* Le(ExprOrValue other);
BoundedLinearExpression* Lt(ExprOrValue other);
BoundedLinearExpression* Gt(ExprOrValue other);
BoundedLinearExpression* Eq(LinearExpr* rhs);
BoundedLinearExpression* EqCst(int64_t rhs);
BoundedLinearExpression* Ne(LinearExpr* rhs);
BoundedLinearExpression* NeCst(int64_t rhs);
BoundedLinearExpression* Ge(LinearExpr* rhs);
BoundedLinearExpression* GeCst(int64_t rhs);
BoundedLinearExpression* Le(LinearExpr* rhs);
BoundedLinearExpression* LeCst(int64_t rhs);
BoundedLinearExpression* Lt(LinearExpr* rhs);
BoundedLinearExpression* LtCst(int64_t rhs);
BoundedLinearExpression* Gt(LinearExpr* rhs);
BoundedLinearExpression* GtCst(int64_t rhs);
};

// Compare the indices of variables.
Expand Down
Loading

0 comments on commit 2b22356

Please sign in to comment.