Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rMS proximal operator #142

Merged
merged 6 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Non-Convex
Log1
QuadraticEnvelopeCard
QuadraticEnvelopeCardIndicator
RelaxedMumfordShah
SCAD


Expand Down
99 changes: 99 additions & 0 deletions pyproximal/proximal/RelaxedMS.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import numpy as np

from pyproximal.ProxOperator import _check_tau
from pyproximal import ProxOperator
from pyproximal.proximal.L1 import _current_sigma


def _l2(x, alpha):
r"""Scaling operation.

Applies the proximal of ``alpha||y - x||_2^2`` which is essentially a scaling operation.

Parameters
----------
x : :obj:`numpy.ndarray`
Vector
alpha : :obj:`float`
Scaling parameter

Returns
-------
y : :obj:`numpy.ndarray`
Scaled vector

"""
y = 1 / (1 + 2 * alpha) * x
return y


def _current_kappa(kappa, count):
if not callable(kappa):
return kappa
else:
return kappa(count)


class RelaxedMumfordShah(ProxOperator):
r"""Relaxed Mumford-Shah norm proximal operator.

Proximal operator of the relaxed Mumford-Shah norm:
:math:`\text{rMS}(x) = \min (\alpha\Vert x\Vert_2^2, \kappa)`.

Parameters
----------
sigma : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional
Multiplicative coefficient of L2 norm that controls the smoothness of the solutuon.
This can be a constant number, a list of values (for multidimensional inputs, acting
on the second dimension) or a function that is called passing a counter which keeps
track of how many times the ``prox`` method has been invoked before and returns a
scalar (or a list of) ``sigma`` to be used.
kappa : :obj:`float` or :obj:`list` or :obj:`np.ndarray` or :obj:`func`, optional
Constant value in the rMS norm which essentially controls when the norm allows a jump. This can be a
constant number, a list of values (for multidimensional inputs, acting on the second dimension) or
a function that is called passing a counter which keeps track of how many
times the ``prox`` method has been invoked before and returns a scalar (or a list of)
``kappa`` to be used.

Notes
-----
The :math:`rMS` proximal operator is defined as [1]_:

.. math::
\text{prox}_{\tau \text{rMS}}(x) =
\begin{cases}
\frac{1}{1+2\tau\alpha}x & \text{ if } & \vert x\vert \leq \sqrt{\frac{\kappa}{\alpha}(1 + 2\tau\alpha)} \\
\kappa & \text{ else }
\end{cases}.

.. [1] Strekalovskiy, E., and D. Cremers, 2014, Real-time minimization of the piecewise smooth
Mumford-Shah functional: European Conference on Computer Vision, 127–141.

"""
def __init__(self, sigma=1., kappa=1.):
super().__init__(None, False)
self.sigma = sigma
self.kappa = kappa
self.count = 0

def __call__(self, x):
sigma = _current_sigma(self.sigma, self.count)
kappa = _current_sigma(self.kappa, self.count)
return np.minimum(sigma * np.linalg.norm(x) ** 2, kappa)

def _increment_count(func):
"""Increment counter
"""
def wrapped(self, *args, **kwargs):
self.count += 1
return func(self, *args, **kwargs)
return wrapped

@_increment_count
@_check_tau
def prox(self, x, tau):
sigma = _current_sigma(self.sigma, self.count)
kappa = _current_sigma(self.kappa, self.count)

x = np.where(np.abs(x) <= np.sqrt(kappa / sigma * (1 + 2 * tau * sigma)), _l2(x, tau * sigma), x)
return x
18 changes: 10 additions & 8 deletions pyproximal/proximal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,22 @@
Box Box indicator
Simplex Simplex indicator
Intersection Intersection indicator
AffineSet Affines set indicator
AffineSet Affines set indicator
Quadratic Quadratic function
Nonlinear Nonlinear function
Nonlinear Nonlinear function
L0 L0 Norm
L0Ball L0 Ball
L1 L1 Norm
L1Ball L1 Ball
Euclidean Euclidean Norm
EuclideanBall Euclidean Ball
Euclidean Euclidean Norm
EuclideanBall Euclidean Ball
L2 L2 Norm
L2Convolve L2 Norm of convolution operator
L21 L2,1 Norm
L21_plus_L1 L2,1 + L1 mixed-norm
Huber Huber Norm
TV Total Variation Norm
Huber Huber Norm
TV Total Variation Norm
RelaxedMumfordShah Relaxed Mumford Shah Norm
Nuclear Nuclear Norm
NuclearBall Nuclear Ball
Orthogonal Product between orthogonal operator and vector
Expand Down Expand Up @@ -53,6 +54,7 @@
from .L21_plus_L1 import *
from .Huber import *
from .TV import *
from .RelaxedMS import *
from .Nuclear import *
from .Orthogonal import *
from .VStack import *
Expand All @@ -66,8 +68,8 @@

__all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic',
'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L1', 'L1Ball', 'L2',
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'Nuclear',
'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'RelaxedMumfordShah',
'Nuclear', 'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty',
'QuadraticEnvelopeCardIndicator', 'QuadraticEnvelopeRankL2',
'Hankel']
19 changes: 18 additions & 1 deletion pytests/test_norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from pylops.basicoperators import Identity, Diagonal, MatrixMult, FirstDerivative
from pyproximal.utils import moreau
from pyproximal.proximal import Box, Euclidean, L2, L1, L21, L21_plus_L1, Huber, Nuclear, TV
from pyproximal.proximal import Box, Euclidean, L2, L1, L21, L21_plus_L1, \
Huber, Nuclear, RelaxedMumfordShah, TV

par1 = {'nx': 10, 'sigma': 1., 'dtype': 'float32'} # even float32
par2 = {'nx': 11, 'sigma': 2., 'dtype': 'float64'} # odd float64
Expand Down Expand Up @@ -202,6 +203,22 @@ def test_TV(par):
assert_array_almost_equal(tv(x), par['sigma'] * np.sum(np.abs(dx), axis=0))


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_rMS(par):
"""rMS norm and proximal/dual proximal
"""
kappa = 1.
rMS = RelaxedMumfordShah(sigma=par['sigma'], kappa=kappa)

# norm
x = np.random.normal(0., 1., par['nx']).astype(par['dtype'])
assert rMS(x) == np.minimum(par['sigma'] * np.linalg.norm(x) ** 2, kappa)

# prox / dualprox
tau = 2.
assert moreau(rMS, x, tau)


def test_Nuclear_FOM():
"""Nuclear norm benchmark with FOM solver
"""
Expand Down
Binary file added testdata/marmousi_trace.npy
Binary file not shown.
Loading