Skip to content

Commit

Permalink
feat: added fungrad to Nonlinear
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Aug 23, 2024
1 parent 2f7b7ad commit f17bb13
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
12 changes: 12 additions & 0 deletions pyproximal/proximal/Nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class Nonlinear(ProxOperator):
- ``fun``: a method evaluating the generic function :math:`f`
- ``grad``: a method evaluating the gradient of the generic function
:math:`f`
- ``fungrad``: a method evaluating both the generic function :math:`f`
and its gradient
- ``optimize``: a method that solves the optimization problem associated
with the proximal operator of :math:`f`. Note that the
``gradprox`` method must be used (instead of ``grad``) as this will
Expand Down Expand Up @@ -58,6 +60,12 @@ def _funprox(self, x, tau):
def _gradprox(self, x, tau):
return self.grad(x) + 1. / tau * (x - self.y)

def _fungradprox(self, x, tau):
f, g = self.fungrad(x)
f = f + 1. / (2 * tau) * ((x - self.y) ** 2).sum()
g = g + 1. / tau * (x - self.y)
return f, g

def fun(self, x):
raise NotImplementedError('The method fun has not been implemented.'
'Refer to the documentation for details on '
Expand All @@ -66,6 +74,10 @@ def grad(self, x):
raise NotImplementedError('The method grad has not been implemented.'
'Refer to the documentation for details on '
'how to subclass this operator.')
def fungrad(self, x):
raise NotImplementedError('The method grad has not been implemented.'
'Refer to the documentation for details on '
'how to subclass this operator.')
def optimize(self):
raise NotImplementedError('The method optimize has not been implemented.'
'Refer to the documentation for details on '
Expand Down
1 change: 1 addition & 0 deletions pytests/test_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_l2(par):
raiseerror=True, atol=1e-3,
verb=False)


@pytest.mark.parametrize("par", [(par1), (par2), (par1j), (par2j)])
def test_lowrank(par):
"""LowRankFactorizedMatrix gradient
Expand Down

0 comments on commit f17bb13

Please sign in to comment.