From f17bb1301b9920644d9c1b68ad38ac4f4dfa77f4 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Fri, 23 Aug 2024 10:43:03 -0500 Subject: [PATCH] feat: added fungrad to Nonlinear --- pyproximal/proximal/Nonlinear.py | 12 ++++++++++++ pytests/test_grads.py | 1 + 2 files changed, 13 insertions(+) diff --git a/pyproximal/proximal/Nonlinear.py b/pyproximal/proximal/Nonlinear.py index 4bf318c..4564732 100644 --- a/pyproximal/proximal/Nonlinear.py +++ b/pyproximal/proximal/Nonlinear.py @@ -14,6 +14,8 @@ class Nonlinear(ProxOperator): - ``fun``: a method evaluating the generic function :math:`f` - ``grad``: a method evaluating the gradient of the generic function :math:`f` + - ``fungrad``: a method evaluating both the generic function :math:`f` + and its gradient - ``optimize``: a method that solves the optimization problem associated with the proximal operator of :math:`f`. Note that the ``gradprox`` method must be used (instead of ``grad``) as this will @@ -58,6 +60,12 @@ def _funprox(self, x, tau): def _gradprox(self, x, tau): return self.grad(x) + 1. / tau * (x - self.y) + def _fungradprox(self, x, tau): + f, g = self.fungrad(x) + f = f + 1. / (2 * tau) * ((x - self.y) ** 2).sum() + g = g + 1. / tau * (x - self.y) + return f, g + def fun(self, x): raise NotImplementedError('The method fun has not been implemented.' 'Refer to the documentation for details on ' 'how to subclass this operator.') @@ -66,6 +74,10 @@ def grad(self, x): raise NotImplementedError('The method grad has not been implemented.' 'Refer to the documentation for details on ' 'how to subclass this operator.') + def fungrad(self, x): + raise NotImplementedError('The method fungrad has not been implemented.' + 'Refer to the documentation for details on ' + 'how to subclass this operator.') def optimize(self): raise NotImplementedError('The method optimize has not been implemented.' 
'Refer to the documentation for details on ' diff --git a/pytests/test_grads.py b/pytests/test_grads.py index 88b8fd4..b121593 100644 --- a/pytests/test_grads.py +++ b/pytests/test_grads.py @@ -52,6 +52,7 @@ def test_l2(par): raiseerror=True, atol=1e-3, verb=False) + @pytest.mark.parametrize("par", [(par1), (par2), (par1j), (par2j)]) def test_lowrank(par): """LowRankFactorizedMatrix gradient