Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow add_constraint and add_objective to accept expressions #23

Merged
merged 2 commits into from
May 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ilpy/wrapper.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from enum import IntEnum, auto

from ilpy.expressions import Expression

class Preference(IntEnum):
Any = auto()
Scip = auto()
Expand Down Expand Up @@ -87,9 +89,9 @@ class Solver:
variable_types: dict[int, VariableType] | None = None,
preference: Preference = Preference.Any,
) -> None: ...
def set_objective(self, objective: Objective) -> None: ...
def set_objective(self, objective: Objective | Expression) -> None: ...
def set_constraints(self, constraints: Constraints) -> None: ...
def add_constraint(self, constraint: Constraint) -> None: ...
def add_constraint(self, constraint: Constraint | Expression) -> None: ...
def set_timeout(self, timeout: float) -> None: ...
def set_optimality_gap(self, gap: float, absolute: bool = False) -> None: ...
def set_num_threads(self, num_threads: int) -> None: ...
Expand Down
22 changes: 18 additions & 4 deletions ilpy/wrapper.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# distutils: language = c++
from typing import TYPE_CHECKING

from libc.stdint cimport uint32_t
from libcpp.memory cimport shared_ptr
Expand All @@ -7,6 +8,9 @@ from libcpp.string cimport string
from cython.operator cimport dereference as deref
cimport decl

if TYPE_CHECKING:
from .expression import Expression

####################################
# Enums #
####################################
Expand Down Expand Up @@ -191,14 +195,24 @@ cdef class Solver:
self.num_variables = num_variables
deref(self.p).initialize(num_variables, default_variable_type, vtypes)

def set_objective(self, Objective objective):
deref(self.p).setObjective(objective.p[0])
def set_objective(self, objective: Objective | Expression):
cdef Objective obj
if hasattr(objective, "as_objective"):
obj = objective.as_objective()
else:
obj = objective
deref(self.p).setObjective(obj.p[0])

def set_constraints(self, Constraints constraints):
deref(self.p).setConstraints(constraints.p[0])

def add_constraint(self, Constraint constraint):
deref(self.p).addConstraint(constraint.p[0])
def add_constraint(self, constraint: Constraint | Expression):
cdef Constraint const
if hasattr(constraint, "as_constraint"):
const = constraint.as_constraint()
else:
const = constraint
deref(self.p).addConstraint(const.p[0])

def set_timeout(self, timeout):
deref(self.p).setTimeout(timeout)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_simple_solver(preference: ilpy.Preference, as_expression: bool) -> None
# note: the Constant(0) here is only to satisfy mypy... it would work without
_e = sum((Variable(str(i), index=i) for i in range(num_vars)), Constant(0))
_e += 0.5 * Variable(str(special_var), index=special_var)
objective = _e.as_objective()
objective = _e
else:
objective = ilpy.Objective()
for i in range(num_vars):
Expand All @@ -45,7 +45,7 @@ def test_simple_solver(preference: ilpy.Preference, as_expression: bool) -> None
# constraints
if as_expression:
_e = sum((Variable(str(i), index=i) for i in range(num_vars)), Constant(0))
constraint = (_e == 1).as_constraint()
constraint = _e == 1
else:
constraint = ilpy.Constraint()
for i in range(num_vars):
Expand Down