
feat: added fungrad to Nonlinear #186

Merged
merged 1 commit on Aug 23, 2024
12 changes: 12 additions & 0 deletions pyproximal/proximal/Nonlinear.py
@@ -14,6 +14,8 @@ class Nonlinear(ProxOperator):
- ``fun``: a method evaluating the generic function :math:`f`
- ``grad``: a method evaluating the gradient of the generic function
:math:`f`
- ``fungrad``: a method evaluating both the generic function :math:`f`
and its gradient
- ``optimize``: a method that solves the optimization problem associated
with the proximal operator of :math:`f`. Note that the
``gradprox`` method must be used (instead of ``grad``) as this will
@@ -58,6 +60,12 @@ def _funprox(self, x, tau):
def _gradprox(self, x, tau):
return self.grad(x) + 1. / tau * (x - self.y)

def _fungradprox(self, x, tau):
f, g = self.fungrad(x)
f = f + 1. / (2 * tau) * ((x - self.y) ** 2).sum()
g = g + 1. / tau * (x - self.y)
return f, g

def fun(self, x):
raise NotImplementedError('The method fun has not been implemented.'
'Refer to the documentation for details on '
@@ -66,6 +74,10 @@ def grad(self, x):
raise NotImplementedError('The method grad has not been implemented.'
'Refer to the documentation for details on '
'how to subclass this operator.')
def fungrad(self, x):
raise NotImplementedError('The method fungrad has not been implemented.'
'Refer to the documentation for details on '
'how to subclass this operator.')
def optimize(self):
raise NotImplementedError('The method optimize has not been implemented.'
'Refer to the documentation for details on '
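For context, here is a minimal sketch (not part of this PR) of a subclass that implements the new ``fungrad`` method so that the function and its gradient are evaluated in a single pass. The ``Quadratic`` class, the ``b`` attribute, and the fixed-step loop in ``optimize`` are illustrative assumptions; only the ``fun``/``grad``/``fungrad``/``optimize`` contract comes from the docstring above.

import numpy as np
from pyproximal import Nonlinear


class Quadratic(Nonlinear):
    """Illustrative f(x) = 0.5 * ||x - b||^2 with a joint fun/grad evaluation.

    ``b`` is assumed to be attached to the instance after construction;
    it is not part of the Nonlinear API.
    """

    def fun(self, x):
        return 0.5 * np.sum((x - self.b) ** 2)

    def grad(self, x):
        return x - self.b

    def fungrad(self, x):
        # compute the residual once and reuse it for both outputs
        r = x - self.b
        return 0.5 * np.sum(r ** 2), r

    def optimize(self):
        # Fixed-step gradient descent on the proximal objective. The augmented
        # evaluation (_fungradprox) is used here, not fun/grad directly;
        # x0, niter and tau are assumed to be set following the base-class
        # conventions (x0/niter at construction, y/tau by the prox call).
        x = self.x0.copy()
        for _ in range(self.niter):
            _, g = self._fungradprox(x, self.tau)
            x = x - 0.1 * g
        return x

Because ``fungrad`` computes the residual only once, solvers that need both the value and the gradient at every iteration avoid the duplicated work of calling ``fun`` and ``grad`` separately.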
1 change: 1 addition & 0 deletions pytests/test_grads.py
@@ -52,6 +52,7 @@ def test_l2(par):
raiseerror=True, atol=1e-3,
verb=False)


@pytest.mark.parametrize("par", [(par1), (par2), (par1j), (par2j)])
def test_lowrank(par):
"""LowRankFactorizedMatrix gradient