diff --git a/python/ommx/ommx/v1/__init__.py b/python/ommx/ommx/v1/__init__.py index 20819eb2..d32c0894 100644 --- a/python/ommx/ommx/v1/__init__.py +++ b/python/ommx/ommx/v1/__init__.py @@ -994,8 +994,29 @@ def _as_pandas_entry(self) -> dict: } | {f"parameters.{key}": value for key, value in v.parameters.items()} +class AsConstraint(ABC): + def __le__(self, other) -> Constraint: + return Constraint( + function=self - other, equality=Equality.EQUALITY_LESS_THAN_OR_EQUAL_TO_ZERO + ) + + def __ge__(self, other) -> Constraint: + return Constraint( + function=other - self, equality=Equality.EQUALITY_LESS_THAN_OR_EQUAL_TO_ZERO + ) + + def __req__(self, other) -> Constraint: + return self == other + + def __rle__(self, other) -> Constraint: + return self.__ge__(other) + + def __rge__(self, other) -> Constraint: + return self.__le__(other) + + @dataclass -class Linear: +class Linear(AsConstraint): """ Modeler API for linear function @@ -1232,28 +1253,9 @@ def __eq__(self, other) -> Constraint: # type: ignore[reportIncompatibleMethodO function=self - other, equality=Equality.EQUALITY_EQUAL_TO_ZERO ) - def __le__(self, other) -> Constraint: - return Constraint( - function=self - other, equality=Equality.EQUALITY_LESS_THAN_OR_EQUAL_TO_ZERO - ) - - def __ge__(self, other) -> Constraint: - return Constraint( - function=other - self, equality=Equality.EQUALITY_LESS_THAN_OR_EQUAL_TO_ZERO - ) - - def __req__(self, other) -> Constraint: - return self == other - - def __rle__(self, other) -> Constraint: - return self.__ge__(other) - - def __rge__(self, other) -> Constraint: - return self.__le__(other) - @dataclass -class Quadratic: +class Quadratic(AsConstraint): raw: _Quadratic def __init__( @@ -1448,7 +1450,7 @@ def __neg__(self) -> Linear: @dataclass -class Polynomial: +class Polynomial(AsConstraint): raw: _Polynomial def __init__(self, *, terms: dict[Iterable[int], float | int] = {}): @@ -1633,7 +1635,14 @@ def __neg__(self) -> Linear: def as_function( - f: int | float | DecisionVariable | Linear | Quadratic | Polynomial | _Function, + f: int + | float + | DecisionVariable + | Linear + | Quadratic + | Polynomial + | _Function + | Function, ) -> _Function: if isinstance(f, (int, float)): return _Function(constant=f) @@ -1647,12 +1656,14 @@ def as_function( return _Function(polynomial=f.raw) elif isinstance(f, _Function): return f + elif isinstance(f, Function): + return f.raw else: raise ValueError(f"Unknown function type: {type(f)}") @dataclass -class Function: +class Function(AsConstraint): raw: _Function def __init__( @@ -1888,7 +1899,13 @@ class Constraint: def __init__( self, *, - function: int | float | DecisionVariable | Linear | Quadratic | Polynomial, + function: int + | float + | DecisionVariable + | Linear + | Quadratic + | Polynomial + | Function, equality: Equality.ValueType, id: Optional[int] = None, name: Optional[str] = None,