Skip to content

Commit

Permalink
Removes enforce precision in all codebase (squashed commit)
Browse files Browse the repository at this point in the history
  • Loading branch information
joanrue authored and SepandKashani committed Aug 17, 2024
1 parent 28751cc commit 11d4b4b
Show file tree
Hide file tree
Showing 74 changed files with 651 additions and 1,336 deletions.
33 changes: 0 additions & 33 deletions src/pyxu/abc/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pyxu.abc.operator as pxo
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.runtime as pxrt
import pyxu.util as pxu


Expand Down Expand Up @@ -52,12 +51,10 @@ def svdvals(self, **kwargs) -> pxt.NDArray:
D = self.__class__.svdvals(self, **kwargs)
return D

@pxrt.enforce_precision(i=("arr", "damp"))
def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
out = self.__class__.pinv(self, arr=arr, damp=damp, **kwargs)
return out

@pxrt.enforce_precision()
def trace(self, **kwargs) -> pxt.Real:
tr = self.__class__.trace(self, **kwargs)
return tr
Expand Down Expand Up @@ -172,7 +169,6 @@ def _infer_op_klass(self) -> pxt.OpC:
klass = pxo.Operator._infer_operator_type(properties)
return klass

@pxrt.enforce_precision(i="arr")
def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
out = pxu.copy_if_unsafe(self._op.apply(arr))
out *= self._cst
Expand All @@ -187,7 +183,6 @@ def estimate_lipschitz(self, **kwargs) -> pxt.Real:
L *= abs(self._cst)
return L

@pxrt.enforce_precision(i=("arr", "tau"))
def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
return self._op.prox(arr, tau * self._cst)

Expand All @@ -214,13 +209,11 @@ def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
dL *= abs(self._cst)
return dL

@pxrt.enforce_precision(i="arr")
def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
out = pxu.copy_if_unsafe(self._op.grad(arr))
out *= self._cst
return out

@pxrt.enforce_precision(i="arr")
def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
out = pxu.copy_if_unsafe(self._op.adjoint(arr))
out *= self._cst
Expand All @@ -236,7 +229,6 @@ def svdvals(self, **kwargs) -> pxt.NDArray:
D *= abs(self._cst)
return D

@pxrt.enforce_precision(i=("arr", "damp"))
def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
scale = damp / (self._cst**2)
out = pxu.copy_if_unsafe(self._op.pinv(arr, damp=scale, **kwargs))
Expand All @@ -251,7 +243,6 @@ def cogram(self) -> pxt.OpT:
op = self._op.cogram() * (self._cst**2)
return op

@pxrt.enforce_precision()
def trace(self, **kwargs) -> pxt.Real:
tr = self._op.trace(**kwargs) * self._cst
return tr
Expand Down Expand Up @@ -317,7 +308,6 @@ def op(self) -> pxt.OpT:
# ConstantVECTOR output: modify ConstantValued to work.
from pyxu.operator import ConstantValued

@pxrt.enforce_precision(i="arr")
def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
xp = pxu.get_array_module(arr)
arr = xp.zeros_like(arr)
Expand Down Expand Up @@ -373,7 +363,6 @@ def _infer_op_klass(self) -> pxt.OpC:
klass = pxo.Operator._infer_operator_type(properties)
return klass

@pxrt.enforce_precision(i="arr")
def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
x = arr.copy()
x *= self._cst
Expand All @@ -389,7 +378,6 @@ def estimate_lipschitz(self, **kwargs) -> pxt.Real:
L *= abs(self._cst)
return L

@pxrt.enforce_precision(i=("arr", "tau"))
def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
x = arr.copy()
x *= self._cst
Expand Down Expand Up @@ -423,15 +411,13 @@ def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
dL *= self._cst**2
return dL

@pxrt.enforce_precision(i="arr")
def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
x = arr.copy()
x *= self._cst
out = pxu.copy_if_unsafe(self._op.grad(x))
out *= self._cst
return out

@pxrt.enforce_precision(i="arr")
def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
out = pxu.copy_if_unsafe(self._op.adjoint(arr))
out *= self._cst
Expand All @@ -447,7 +433,6 @@ def svdvals(self, **kwargs) -> pxt.NDArray:
D *= abs(self._cst)
return D

@pxrt.enforce_precision(i=("arr", "damp"))
def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
scale = damp / (self._cst**2)
out = pxu.copy_if_unsafe(self._op.pinv(arr, damp=scale, **kwargs))
Expand All @@ -462,7 +447,6 @@ def cogram(self) -> pxt.OpT:
op = self._op.cogram() * (self._cst**2)
return op

@pxrt.enforce_precision()
def trace(self, **kwargs) -> pxt.Real:
tr = self._op.trace(**kwargs) * self._cst
return tr
Expand Down Expand Up @@ -571,7 +555,6 @@ def _infer_op_klass(self) -> pxt.OpC:
klass = pxo.Operator._infer_operator_type(properties)
return klass

@pxrt.enforce_precision(i="arr")
def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
x = arr.copy()
x += self._cst
Expand All @@ -586,7 +569,6 @@ def estimate_lipschitz(self, **kwargs) -> pxt.Real:
L = self._op.estimate_lipschitz(**kwargs)
return L

@pxrt.enforce_precision(i=("arr", "tau"))
def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
x = arr.copy()
x += self._cst
Expand Down Expand Up @@ -626,7 +608,6 @@ def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
dL = self._op.estimate_diff_lipschitz(**kwargs)
return dL

@pxrt.enforce_precision(i="arr")
def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
x = arr.copy()
x += self._cst
Expand Down Expand Up @@ -845,7 +826,6 @@ def _infer_op_klass(
klass = pxo.Operator._infer_operator_type(base)
return klass

@pxrt.enforce_precision(i="arr")
def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
out = pxu.copy_if_unsafe(self._lhs.apply(arr))
out += self._rhs.apply(arr)
Expand All @@ -866,7 +846,6 @@ def estimate_lipschitz(self, **kwargs) -> pxt.Real:
L = L_lhs + L_rhs
return L

@pxrt.enforce_precision(i=("arr", "tau"))
def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
P_LHS = self._lhs.properties()
P_RHS = self._rhs.properties()
Expand Down Expand Up @@ -908,14 +887,12 @@ def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
dL = dL_lhs + dL_rhs
return dL

@pxrt.enforce_precision(i="arr")
def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
out = self._lhs.grad(arr)
out = pxu.copy_if_unsafe(out)
out += self._rhs.grad(arr)
return out

@pxrt.enforce_precision(i="arr")
def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
out = self._lhs.adjoint(arr)
out = pxu.copy_if_unsafe(out)
Expand Down Expand Up @@ -944,7 +921,6 @@ def cogram(self) -> pxt.OpT:
op = op1 + op2 + (op3 + op4).asop(pxo.SelfAdjointOp)
return op

@pxrt.enforce_precision()
def trace(self, **kwargs) -> pxt.Real:
tr = 0
for side in (self._lhs, self._rhs):
Expand Down Expand Up @@ -1106,7 +1082,6 @@ def _infer_op_klass(self) -> pxt.OpC:
klass = pxo.Operator._infer_operator_type(properties)
return klass

@pxrt.enforce_precision(i="arr")
def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
x = self._rhs.apply(arr)
out = self._lhs.apply(x)
Expand All @@ -1131,7 +1106,6 @@ def estimate_lipschitz(self, **kwargs) -> pxt.Real:
L = L_lhs * L_rhs
return L

@pxrt.enforce_precision(i=("arr", "tau"))
def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
if self.has(pxo.Property.PROXIMABLE):
out = None
Expand Down Expand Up @@ -1226,7 +1200,6 @@ def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
dL = np.inf
return dL

@pxrt.enforce_precision(i="arr")
def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
sh = arr.shape[: -self.dim_rank]
if (len(sh) == 0) or self._rhs.has(pxo.Property.LINEAR):
Expand Down Expand Up @@ -1262,7 +1235,6 @@ def f(arr: pxt.NDArray) -> pxt.NDArray:
out = f(arr)
return out

@pxrt.enforce_precision(i="arr")
def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
x = self._lhs.adjoint(arr)
out = self._rhs.adjoint(x)
Expand Down Expand Up @@ -1363,7 +1335,6 @@ def _infer_op_klass(self) -> pxt.OpC:
klass = pxo.Operator._infer_operator_type(prop)
return klass

@pxrt.enforce_precision(i="arr")
def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
out = self._op.adjoint(arr)
return out
Expand All @@ -1376,7 +1347,6 @@ def estimate_lipschitz(self, **kwargs) -> pxt.Real:
L = self._op.estimate_lipschitz(**kwargs)
return L

@pxrt.enforce_precision(i=("arr", "tau"))
def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
out = pxo.LinFunc.prox(self, arr, tau)
return out
Expand All @@ -1387,12 +1357,10 @@ def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
return 0

@pxrt.enforce_precision(i="arr")
def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
out = pxo.LinFunc.grad(self, arr)
return out

@pxrt.enforce_precision(i="arr")
def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
out = self._op.apply(arr)
return out
Expand All @@ -1417,7 +1385,6 @@ def svdvals(self, **kwargs) -> pxt.NDArray:
D = self._op.svdvals(**kwargs)
return D

@pxrt.enforce_precision()
def trace(self, **kwargs) -> pxt.Real:
tr = self._op.trace(**kwargs)
return tr
Loading

0 comments on commit 11d4b4b

Please sign in to comment.