Skip to content

Commit

Permalink
Merge pull request #132 from mrava87/feature-pnpgeneral
Browse files Browse the repository at this point in the history
feature: modify PnP signature to allow any proximal solver
  • Loading branch information
mrava87 authored Jul 20, 2023
2 parents b98160e + 5d378e2 commit 2eeb785
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 26 deletions.
21 changes: 11 additions & 10 deletions pyproximal/optimization/pnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,19 @@ def prox(self, x, tau):
return xden.ravel()


def PlugAndPlay(proxf, denoiser, dims, x0, tau, niter=10,
callback=None, show=False):
r"""Plug-and-Play Priors with ADMM optimization
def PlugAndPlay(proxf, denoiser, dims, x0, solver=ADMM, **kwargs_solver):
r"""Plug-and-Play Priors with any proximal algorithm of choice
Solves the following minimization problem using the ADMM algorithm:
Solves the following minimization problem using any proximal a
lgorithm of choice:
.. math::
\mathbf{x},\mathbf{z} = \argmin_{\mathbf{x}}
f(\mathbf{x}) + \lambda g(\mathbf{x})
where :math:`f(\mathbf{x})` is a function that has a known proximal
operator where :math:`g(\mathbf{x})` is a function acting as implicit
where :math:`f(\mathbf{x})` is a function that has a known gradient or
proximal operator and :math:`g(\mathbf{x})` is a function acting as implicit
prior. Implicit means that no explicit function should be defined: instead,
a denoising algorithm of choice is used. See Notes for details.
Expand All @@ -62,6 +62,8 @@ def PlugAndPlay(proxf, denoiser, dims, x0, tau, niter=10,
prior to calling the ``denoiser``
x0 : :obj:`numpy.ndarray`
Initial vector
solver : :func:`pyproximal.optimization.primal` or :func:`pyproximal.optimization.primaldual`
Solver of choice
tau : :obj:`float`, optional
Positive scalar weight, which should satisfy the following condition
to guarantees convergence: :math:`\tau \in (0, 1/L]` where ``L`` is
Expand All @@ -83,7 +85,8 @@ def PlugAndPlay(proxf, denoiser, dims, x0, tau, niter=10,
Notes
-----
Plug-and-Play Priors [1]_ can be expressed by the following recursion:
Plug-and-Play Priors [1]_ can be used with any proximal algorithm of choice. For example, when
ADMM is selected, the resulting scheme can be expressed by the following recursion:
.. math::
Expand Down Expand Up @@ -119,6 +122,4 @@ def PlugAndPlay(proxf, denoiser, dims, x0, tau, niter=10,
# Denoiser
proxpnp = _Denoise(denoiser, dims=dims)

return ADMM(proxf, proxpnp, tau=tau, x0=x0,
niter=niter, callback=callback,
show=show)
return solver(proxf, proxpnp, x0=x0, **kwargs_solver)
7 changes: 4 additions & 3 deletions pyproximal/optimization/primal.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,11 @@ def ProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
'Proximal operator (f): %s\n'
'Proximal operator (g): %s\n'
'tau = %s\tbeta=%10e\n'
'epsg = %s\tniter = %d\t'
'niterback = %d\n' % (type(proxf), type(proxg),
'epsg = %s\tniter = %d\n'
''
'niterback = %d\tacceleration = %s\n' % (type(proxf), type(proxg),
'Adaptive' if tau is None else str(tau), beta,
epsg_print, niter, niterback))
epsg_print, niter, niterback, acceleration))
head = ' Itn x[0] f g J=f+eps*g'
print(head)

Expand Down
53 changes: 40 additions & 13 deletions tutorials/plugandplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@
As an example, we will consider a simplified MRI experiment, where the
data is created by appling a 2D Fourier Transform to the input model and
by randomly sampling 60% of its values. We will also use the famous
by randomly sampling 60% of its values. We will use the famous
`BM3D <https://pypi.org/project/bm3d>`_ as the denoiser, but any other denoiser
of choice can be used instead!
Finally, whilst in the original paper, PnP is associated to the ADMM solver, subsequent
research showed that the same principle can be applied to pretty much any proximal
solver. We will show how to pass a solver of choice to our
:func:`pyproximal.optimization.pnp.PlugAndPlay` solver.
"""
import numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -67,7 +72,7 @@

###############################################################################
# At this point we create a denoiser instance using the BM3D algorithm and use
# as Plug-and-Play Prior to the ADMM algorithm
# as Plug-and-Play Prior to the PG and ADMM algorithms

def callback(x, xtrue, errhist):
errhist.append(np.linalg.norm(x - xtrue))
Expand All @@ -83,24 +88,46 @@ def callback(x, xtrue, errhist):
denoiser = lambda x, tau: bm3d.bm3d(np.real(x), sigma_psd=sigma * tau,
stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)

errhist = []
xpnp = pyproximal.optimization.pnp.PlugAndPlay(l2, denoiser, x.shape,
tau=tau, x0=np.zeros(x.size),
niter=40, show=True,
callback=lambda xx: callback(xx, x.ravel(),
errhist))[0]
xpnp = np.real(xpnp.reshape(x.shape))
# PG-Pnp
errhistpg = []
xpnppg = pyproximal.optimization.pnp.PlugAndPlay(l2, denoiser, x.shape,
solver=pyproximal.optimization.primal.ProximalGradient,
tau=tau, x0=np.zeros(x.size),
niter=40,
acceleration='fista',
show=True,
callback=lambda xx: callback(xx, x.ravel(),
errhistpg))
xpnppg = np.real(xpnppg.reshape(x.shape))

# ADMM-PnP
errhistadmm = []
xpnpadmm = pyproximal.optimization.pnp.PlugAndPlay(l2, denoiser, x.shape,
solver=pyproximal.optimization.primal.ADMM,
tau=tau, x0=np.zeros(x.size),
niter=40, show=True,
callback=lambda xx: callback(xx, x.ravel(),
errhistadmm))[0]
xpnpadmm = np.real(xpnpadmm.reshape(x.shape))

fig, axs = plt.subplots(1, 2, figsize=(9, 5))
fig, axs = plt.subplots(1, 3, figsize=(14, 5))
axs[0].imshow(x, vmin=0, vmax=1, cmap="gray")
axs[0].set_title("Model")
axs[0].axis("tight")
axs[1].imshow(xpnp, vmin=0, vmax=1, cmap="gray")
axs[1].set_title("PnP Inversion")
axs[1].imshow(xpnppg, vmin=0, vmax=1, cmap="gray")
axs[1].set_title("PG-PnP Inversion")
axs[1].axis("tight")
axs[2].imshow(xpnpadmm, vmin=0, vmax=1, cmap="gray")
axs[2].set_title("ADMM-PnP Inversion")
axs[2].axis("tight")
plt.tight_layout()

###############################################################################
# Finally, let's compare the error convergence of the two variations of PnP

plt.figure(figsize=(12, 3))
plt.plot(errhist, 'k', lw=2)
plt.plot(errhistpg, 'k', lw=2, label='PG')
plt.plot(errhistadmm, 'r', lw=2, label='ADMM')
plt.title("Error norm")
plt.legend()
plt.tight_layout()

0 comments on commit 2eeb785

Please sign in to comment.