Skip to content

Commit

Permalink
Implementing JuMP format scalar and indexed
Browse files Browse the repository at this point in the history
variables.
  • Loading branch information
jezsadler committed Sep 28, 2024
1 parent 3c20611 commit 29b89bc
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 19 deletions.
279 changes: 270 additions & 9 deletions src/omlt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@
from abc import ABC, abstractmethod
import pyomo.environ as pyo

from omlt.dependencies import julia_available, moi_available

if julia_available and moi_available:
from juliacall import Main as jl
from juliacall import Base

jl.seval("import MathOptInterface")
moi = jl.MathOptInterface
jl.seval("import JuMP")
jump = jl.JuMP


class OmltVar(ABC):
def __new__(cls, *indexes, **kwargs):
Expand All @@ -26,7 +37,7 @@ def __new__(cls, *args, format="pyomo", **kwargs):
if format not in subclass_map:
raise ValueError(
f"Variable format %s not recognized. Supported formats "
"are 'pyomo' or 'moi'.",
"are 'pyomo' or 'jump'.",
format,
)
subclass = subclass_map[format]
Expand Down Expand Up @@ -76,6 +87,12 @@ def ub(self):
def ub(self, val):
pass

def is_component_type(self):
return True

def is_indexed(self):
return False

# @abstractmethod
# def __mul__(self, other):
# pass
Expand Down Expand Up @@ -204,15 +221,128 @@ def __abs__(self):
return pyo.NumericValue.__abs__(self)


class OmltScalarJuMP(OmltScalar):
format = "jump"

# Claim to be a Pyomo Var so blocks will register
# properly.
@property
def __class__(self):
return pyo.ScalarVar

def __init__(self, *args, **kwargs):

self._block = kwargs.pop("block", None)

self._bounds = kwargs.pop("bounds", None)

if isinstance(self._bounds, tuple) and len(self._bounds) == 2:
_lb = self._bounds[0]
_has_lb = _lb is not None
_ub = self._bounds[1]
_has_ub = _ub is not None
elif self._bounds is None:
_has_lb = False
_lb = None
_has_ub = False
_ub = None
else:
raise ValueError("Bounds must be given as a tuple")

_domain = kwargs.pop("domain", None)
_within = kwargs.pop("within", None)

if _domain and _within and _domain != _within:
raise ValueError(
"'domain' and 'within' keywords have both "
"been supplied and do not agree. Please try "
"with a single keyword for the domain of this "
"variable."
)
elif _domain:
self.domain = _domain
elif _within:
self.domain = _within
else:
self.domain = None

if self.domain == pyo.Binary:
self.binary = True
else:
self.binary = False
if self.domain == pyo.Integers:
self.integer = True
else:
self.integer = False

_initialize = kwargs.pop("initialize", None)

if _initialize:
self._value = _initialize
else:
self._value = None

self._jumpvarinfo = jump.VariableInfo(
_has_lb,
_lb,
_has_ub,
_ub,
False, # is fixed
None, # fixed value
_initialize is not None,
self._value,
self.binary,
self.integer,
)
self._constructed = False
self._parent = None
self._ctype = pyo.ScalarVar

def construct(self, data):
if self._block:
self._jumpvar = jump.add_variable(self._block, self._jumpvarinfo)
else:
self._jumpvar = jump.build_variable(Base.error, self._jumpvarinfo)
self._constructed = True

def fix(self, value, skip_validation):
self.fixed = True
self._value = value

@property
def bounds(self):
pass

@bounds.setter
def bounds(self, val):
pass

@property
def lb(self):
return self._jumpvar.info.lower_bound

@lb.setter
def lb(self, val):
jump.set_upper_bound(self._jumpvar, val)

@property
def ub(self):
return self._jumpvar.info.upper_bound

@ub.setter
def ub(self, val):
jump.set_upper_bound(self._jumpvar, val)

def to_jump(self):
if self._constructed:
return self._jumpvar


"""
Future formats to implement.
"""


class OmltScalarMOI(OmltScalar):
format = "moi"


class OmltScalarSmoke(OmltScalar):
format = "smoke"

Expand Down Expand Up @@ -257,11 +387,16 @@ def setub(self, value):
def setlb(self, value):
pass

def valid_model_component(self):
"""Return True if this can be used as a model component."""
return True


class OmltIndexedPyomo(pyo.Var, OmltIndexed):
format = "pyomo"

def __init__(self, *indexes, **kwargs):
kwargs.pop("format", None)
super().__init__(*indexes, **kwargs)

def fix(self, value=None, skip_validation=False):
Expand All @@ -282,15 +417,141 @@ def setlb(self, value):
vardata.lb = value


class OmltIndexedJuMP(OmltIndexed):
format = "jump"

# Claim to be a Pyomo Var so blocks will register
# properly.
@property
def __class__(self):
return pyo.Var

def __init__(self, *indexes, **kwargs):
if len(indexes) == 1:
index_set = indexes[0]
i_dict = {}
for i, val in enumerate(index_set):
i_dict[i] = val
self._index_set = tuple(i_dict[i] for i in range(len(index_set)))
else:
raise ValueError("Currently index cross-products are unsupported.")
self._varinfo = {}
for idx in self._index_set:
self._varinfo[idx] = jump.VariableInfo(
False, # _has_lb,
None, # _lb,
False, # _has_ub,
None, # _ub,
False, # is fixed
None, # fix value
False, # _initialize is not None,
None, # self._value,
False, # self.binary,
False, # self.integer
)
self._vars = {}
self._constructed = False
self._ctype = pyo.Var
self._parent = None

def __getitem__(self, item):
if isinstance(item, tuple) and len(item) == 1:
return self._vars[item[0]]
else:
return self._vars[item]

def __setitem__(self, item, value):
self._vars[item] = value

def keys(self):
return self._vars.keys()

def values(self):
return self._vars.values()

def items(self):
return self._vars.items()

def fix(self, value=None, skip_validation=False):
self.fixed = True
if value is None:
for vardata in self.values():
vardata.fix(skip_validation)
else:
for vardata in self.values():
vardata.fix(value, skip_validation)

def __len__(self):
"""
Return the number of component data objects stored by this
component.
"""
return len(self._vars)

def __contains__(self, idx):
"""Return true if the index is in the dictionary"""
return idx in self._vars

# The default implementation is for keys() and __iter__ to be
# synonyms. The logic is implemented in keys() so that
# keys/values/items continue to work for components that implement
# other definitions for __iter__ (e.g., Set)
def __iter__(self):
"""Return an iterator of the component data keys"""
return self._vars.__iter__()

def construct(self, data=None):
for idx in self._index_set:
self._vars[idx] = jump.build_variable(Base.error, self._varinfo[idx])
self._constructed = True

def setub(self, value):
if self._constructed:
for idx in self.index_set():
self._varinfo[idx].has_ub = True
self._varinfo[idx].upper_bound = value
self._vars[idx].info.has_ub = True
self._vars[idx].info.upper_bound = value
else:
for idx in self.index_set():
self._varinfo[idx].has_ub = True
self._varinfo[idx].upper_bound = value

def setlb(self, value):
if self._constructed:
for idx in self.index_set():
self._varinfo[idx].has_lb = True
self._varinfo[idx].lower_bound = value
self._vars[idx].info.has_lb = True
self._vars[idx].info.lower_bound = value
else:
for idx in self.index_set():
self._varinfo[idx].has_lb = True
self._varinfo[idx].lower_bound = value

@property
def ctype(self):
return self._ctype

def index_set(self):
return self._index_set

@property
def name(self):
return self._name

def to_jump(self):
if self._constructed:
return jump.Containers.DenseAxisArray(
list(self._vars.values()), self.index_set()
)


"""
Future formats to implement.
"""


class OmltIndexedMOI(OmltIndexed):
format = "moi"


class OmltIndexedSmoke(OmltIndexed):
format = "smoke"

Expand Down
9 changes: 7 additions & 2 deletions src/omlt/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, component):
self.__formulation = None
self.__input_indexes = None
self.__output_indexes = None
self.__format = "pyomo"

def _setup_inputs_outputs(self, *, input_indexes, output_indexes):
"""Setup inputs and outputs.
Expand All @@ -58,9 +59,9 @@ def _setup_inputs_outputs(self, *, input_indexes, output_indexes):
self.__output_indexes = output_indexes

self.inputs_set = pyo.Set(initialize=input_indexes)
self.inputs = OmltVar(self.inputs_set, initialize=0)
self.inputs = OmltVar(self.inputs_set, initialize=0, format=self.__format)
self.outputs_set = pyo.Set(initialize=output_indexes)
self.outputs = OmltVar(self.outputs_set, initialize=0)
self.outputs = OmltVar(self.outputs_set, initialize=0, format=self.__format)

def build_formulation(self, formulation):
"""Build formulation.
Expand All @@ -74,6 +75,10 @@ def build_formulation(self, formulation):
----------
formulation : instance of _PyomoFormulation
see, for example, FullSpaceNNFormulation
format : str
Which modelling language to build the formulation in.
Currently supported are "pyomo" (default) and "jump".
"""
if not formulation.input_indexes:
msg = (
Expand Down
10 changes: 10 additions & 0 deletions src/omlt/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,13 @@

torch_geometric, torch_geometric_available = attempt_import("torch_geometric")
lineartree, lineartree_available = attempt_import("lineartree")

julia, julia_available = attempt_import("juliacall")

if julia_available:
from juliacall import Main as jl
try:
jl.seval("import MathOptInterface")
moi_available = True
except jl.ArgumentError:
moi_available = False
Loading

0 comments on commit 29b89bc

Please sign in to comment.