diff --git a/ilpy/wrapper.pyi b/ilpy/wrapper.pyi index e95b460..1dbdcbc 100644 --- a/ilpy/wrapper.pyi +++ b/ilpy/wrapper.pyi @@ -1,5 +1,7 @@ from enum import IntEnum, auto +from ilpy.expressions import Expression + class Preference(IntEnum): Any = auto() Scip = auto() @@ -87,9 +89,9 @@ class Solver: variable_types: dict[int, VariableType] | None = None, preference: Preference = Preference.Any, ) -> None: ... - def set_objective(self, objective: Objective) -> None: ... + def set_objective(self, objective: Objective | Expression) -> None: ... def set_constraints(self, constraints: Constraints) -> None: ... - def add_constraint(self, constraint: Constraint) -> None: ... + def add_constraint(self, constraint: Constraint | Expression) -> None: ... def set_timeout(self, timeout: float) -> None: ... def set_optimality_gap(self, gap: float, absolute: bool = False) -> None: ... def set_num_threads(self, num_threads: int) -> None: ... diff --git a/ilpy/wrapper.pyx b/ilpy/wrapper.pyx index d580504..b386d26 100644 --- a/ilpy/wrapper.pyx +++ b/ilpy/wrapper.pyx @@ -1,4 +1,5 @@ # distutils: language = c++ +from typing import TYPE_CHECKING from libc.stdint cimport uint32_t from libcpp.memory cimport shared_ptr @@ -7,6 +8,9 @@ from libcpp.string cimport string from cython.operator cimport dereference as deref cimport decl +if TYPE_CHECKING: + from .expression import Expression + #################################### # Enums # #################################### @@ -191,14 +195,24 @@ cdef class Solver: self.num_variables = num_variables deref(self.p).initialize(num_variables, default_variable_type, vtypes) - def set_objective(self, Objective objective): - deref(self.p).setObjective(objective.p[0]) + def set_objective(self, objective: Objective | Expression): + cdef Objective obj + if hasattr(objective, "as_objective"): + obj = objective.as_objective() + else: + obj = objective + deref(self.p).setObjective(obj.p[0]) def set_constraints(self, Constraints constraints): deref(self.p).setConstraints(constraints.p[0]) - def add_constraint(self, Constraint constraint): - deref(self.p).addConstraint(constraint.p[0]) + def add_constraint(self, constraint: Constraint | Expression): + cdef Constraint const + if hasattr(constraint, "as_constraint"): + const = constraint.as_constraint() + else: + const = constraint + deref(self.p).addConstraint(const.p[0]) def set_timeout(self, timeout): deref(self.p).setTimeout(timeout) diff --git a/tests/test_solvers.py b/tests/test_solvers.py index b69e05e..5d9aa55 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -35,7 +35,7 @@ def test_simple_solver(preference: ilpy.Preference, as_expression: bool) -> None # note: the Constant(0) here is only to satisfy mypy... it would work without _e = sum((Variable(str(i), index=i) for i in range(num_vars)), Constant(0)) _e += 0.5 * Variable(str(special_var), index=special_var) - objective = _e.as_objective() + objective = _e else: objective = ilpy.Objective() for i in range(num_vars): @@ -45,7 +45,7 @@ def test_simple_solver(preference: ilpy.Preference, as_expression: bool) -> None # constraints if as_expression: _e = sum((Variable(str(i), index=i) for i in range(num_vars)), Constant(0)) - constraint = (_e == 1).as_constraint() + constraint = _e == 1 else: constraint = ilpy.Constraint() for i in range(num_vars):