Skip to content

Commit

Permalink
Merge pull request #142 from NickLuiken/dev
Browse files Browse the repository at this point in the history
rMS proximal operator
  • Loading branch information
mrava87 authored Oct 10, 2023
2 parents a9f3249 + fd56deb commit a9cca82
Show file tree
Hide file tree
Showing 6 changed files with 348 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Non-Convex
Log1
QuadraticEnvelopeCard
QuadraticEnvelopeCardIndicator
RelaxedMumfordShah
SCAD


Expand Down
99 changes: 99 additions & 0 deletions pyproximal/proximal/RelaxedMS.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import numpy as np

from pyproximal.ProxOperator import _check_tau
from pyproximal import ProxOperator
from pyproximal.proximal.L1 import _current_sigma


def _l2(x, alpha):
r"""Scaling operation.
Applies the proximal of ``alpha||y - x||_2^2`` which is essentially a scaling operation.
Parameters
----------
x : :obj:`numpy.ndarray`
Vector
alpha : :obj:`float`
Scaling parameter
Returns
-------
y : :obj:`numpy.ndarray`
Scaled vector
"""
y = 1 / (1 + 2 * alpha) * x
return y


def _current_kappa(kappa, count):
if not callable(kappa):
return kappa
else:
return kappa(count)


class RelaxedMumfordShah(ProxOperator):
r"""Relaxed Mumford-Shah norm proximal operator.
Proximal operator of the relaxed Mumford-Shah norm:
:math:`\text{rMS}(x) = \min (\alpha\Vert x\Vert_2^2, \kappa)`.
Parameters
----------
sigma : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional
Multiplicative coefficient of L2 norm that controls the smoothness of the solutuon.
This can be a constant number, a list of values (for multidimensional inputs, acting
on the second dimension) or a function that is called passing a counter which keeps
track of how many times the ``prox`` method has been invoked before and returns a
scalar (or a list of) ``sigma`` to be used.
kappa : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional
Constant value in the rMS norm which essentially controls when the norm allows a jump. This can be a
constant number, a list of values (for multidimensional inputs, acting on the second dimension) or
a function that is called passing a counter which keeps track of how many
times the ``prox`` method has been invoked before and returns a scalar (or a list of)
``kappa`` to be used.
Notes
-----
The :math:`rMS` proximal operator is defined as [1]_:
.. math::
\text{prox}_{\tau \text{rMS}}(x) =
\begin{cases}
\frac{1}{1+2\tau\alpha}x & \text{ if } & \vert x\vert \leq \sqrt{\frac{\kappa}{\alpha}(1 + 2\tau\alpha)} \\
\kappa & \text{ else }
\end{cases}.
.. [1] Strekalovskiy, E., and D. Cremers, 2014, Real-time minimization of the piecewise smooth
Mumford-Shah functional: European Conference on Computer Vision, 127–141.
"""
def __init__(self, sigma=1., kappa=1.):
super().__init__(None, False)
self.sigma = sigma
self.kappa = kappa
self.count = 0

def __call__(self, x):
sigma = _current_sigma(self.sigma, self.count)
kappa = _current_sigma(self.kappa, self.count)
return np.minimum(sigma * np.linalg.norm(x) ** 2, kappa)

def _increment_count(func):
"""Increment counter
"""
def wrapped(self, *args, **kwargs):
self.count += 1
return func(self, *args, **kwargs)
return wrapped

@_increment_count
@_check_tau
def prox(self, x, tau):
sigma = _current_sigma(self.sigma, self.count)
kappa = _current_sigma(self.kappa, self.count)

x = np.where(np.abs(x) <= np.sqrt(kappa / sigma * (1 + 2 * tau * sigma)), _l2(x, tau * sigma), x)
return x
18 changes: 10 additions & 8 deletions pyproximal/proximal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,22 @@
Box Box indicator
Simplex Simplex indicator
Intersection Intersection indicator
AffineSet Affines set indicator
AffineSet Affines set indicator
Quadratic Quadratic function
Nonlinear Nonlinear function
Nonlinear Nonlinear function
L0 L0 Norm
L0Ball L0 Ball
L1 L1 Norm
L1Ball L1 Ball
Euclidean Euclidean Norm
EuclideanBall Euclidean Ball
Euclidean Euclidean Norm
EuclideanBall Euclidean Ball
L2 L2 Norm
L2Convolve L2 Norm of convolution operator
L21 L2,1 Norm
L21_plus_L1 L2,1 + L1 mixed-norm
Huber Huber Norm
TV Total Variation Norm
Huber Huber Norm
TV Total Variation Norm
RelaxedMumfordShah Relaxed Mumford Shah Norm
Nuclear Nuclear Norm
NuclearBall Nuclear Ball
Orthogonal Product between orthogonal operator and vector
Expand Down Expand Up @@ -53,6 +54,7 @@
from .L21_plus_L1 import *
from .Huber import *
from .TV import *
from .RelaxedMS import *
from .Nuclear import *
from .Orthogonal import *
from .VStack import *
Expand All @@ -66,8 +68,8 @@

__all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic',
'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L1', 'L1Ball', 'L2',
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'Nuclear',
'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'RelaxedMumfordShah',
'Nuclear', 'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty',
'QuadraticEnvelopeCardIndicator', 'QuadraticEnvelopeRankL2',
'Hankel']
19 changes: 18 additions & 1 deletion pytests/test_norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from pylops.basicoperators import Identity, Diagonal, MatrixMult, FirstDerivative
from pyproximal.utils import moreau
from pyproximal.proximal import Box, Euclidean, L2, L1, L21, L21_plus_L1, Huber, Nuclear, TV
from pyproximal.proximal import Box, Euclidean, L2, L1, L21, L21_plus_L1, \
Huber, Nuclear, RelaxedMumfordShah, TV

par1 = {'nx': 10, 'sigma': 1., 'dtype': 'float32'} # even float32
par2 = {'nx': 11, 'sigma': 2., 'dtype': 'float64'} # odd float64
Expand Down Expand Up @@ -202,6 +203,22 @@ def test_TV(par):
assert_array_almost_equal(tv(x), par['sigma'] * np.sum(np.abs(dx), axis=0))


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_rMS(par):
"""rMS norm and proximal/dual proximal
"""
kappa = 1.
rMS = RelaxedMumfordShah(sigma=par['sigma'], kappa=kappa)

# norm
x = np.random.normal(0., 1., par['nx']).astype(par['dtype'])
assert rMS(x) == np.minimum(par['sigma'] * np.linalg.norm(x) ** 2, kappa)

# prox / dualprox
tau = 2.
assert moreau(rMS, x, tau)


def test_Nuclear_FOM():
"""Nuclear norm benchmark with FOM solver
"""
Expand Down
Binary file added testdata/marmousi_trace.npy
Binary file not shown.
Loading

0 comments on commit a9cca82

Please sign in to comment.