Skip to content

Commit

Permalink
Solver and functional extensions (#264)
Browse files Browse the repository at this point in the history
* Add Huber function norm

* Add bisection vector root finder

* Improve error handling and tests

* Improvements

* Add golden section minimizer

* Docs correction

* Minor edits

* Add set distance functional

* Add squared set distance functional

* Clean up set distance tests

* Clean up functional tests

* Temporary fix for jinja2/nbconvert bug jupyter/nbconvert#1736

* Update change summary

* Address codefactor complaints

* Rename test file
  • Loading branch information
bwohlberg authored Mar 25, 2022
1 parent 63ea350 commit 31797e2
Show file tree
Hide file tree
Showing 16 changed files with 609 additions and 166 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Version 0.0.3 (unreleased)
• Add function ``linop.linop_from_function`` for constructing linear
operators from functions.
• Add support for addition of functionals.
• Additional solvers in ``scico.solver``.
• New Huber norm and set distance functionals.



Expand Down
1 change: 1 addition & 0 deletions docs/docs_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ sphinxcontrib-napoleon
sphinxcontrib-bibtex
sphinx-autodoc-typehints
faculty-sphinx-theme
jinja2<3.1.0 # temporary fix for jina2/nbconvert bug
nbsphinx
ipython_genutils
py2jn
Expand Down
12 changes: 12 additions & 0 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,18 @@ @Book {goodman-2005-fourier
edition = 3
}

@Article {huber-1964-robust,
doi = {10.1214/aoms/1177703732},
year = 1964,
month = Mar,
volume = 35,
number = 1,
pages = {73--101},
author = {Peter J. Huber},
title = {Robust Estimation of a Location Parameter},
journal = {The Annals of Mathematical Statistics}
}

@Article {kamilov-2017-plugandplay,
author = {Ulugbek Kamilov and Hassan Mansour and Brendt
Wohlberg},
Expand Down
6 changes: 5 additions & 1 deletion scico/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

# isort: off
from ._functional import Functional, ScaledFunctional, SeparableFunctional, ZeroFunctional
from ._norm import L0Norm, L1Norm, SquaredL2Norm, L2Norm, L21Norm, NuclearNorm
from ._norm import HuberNorm, L0Norm, L1Norm, SquaredL2Norm, L2Norm, L21Norm, NuclearNorm
from ._indicator import NonNegativeIndicator, L2BallIndicator
from ._dist import SetDistance, SquaredSetDistance
from ._denoiser import BM3D, DnCNN


Expand All @@ -21,6 +22,7 @@
"ScaledFunctional",
"SeparableFunctional",
"ZeroFunctional",
"HuberNorm",
"L0Norm",
"L1Norm",
"SquaredL2Norm",
Expand All @@ -29,6 +31,8 @@
"NonNegativeIndicator",
"NuclearNorm",
"L2BallIndicator",
"SetDistance",
"SquaredSetDistance",
"BM3D",
"DnCNN",
]
Expand Down
156 changes: 156 additions & 0 deletions scico/functional/_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2022 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Distance functions."""

from typing import Callable, Union

from scico import numpy as snp
from scico.blockarray import BlockArray
from scico.typing import JaxArray

from ._functional import Functional


class SetDistance(Functional):
r"""Distance to a closed convex set.
This functional computes the :math:`\ell_2` distance from a vector to
a closed convex set :math:`C`
.. math::
d(\mb{x}) = \min_{\mb{y} \in C} \, \| \mb{x} - \mb{y} \|_2 \;.
The set is not specified directly, but in terms of a function
computing the projection into that set, i.e.
.. math::
d(\mb{x}) = \| \mb{x} - P_C(\mb{x}) \|_2 \;,
where :math:`P_C(\mb{x})` is the projection of :math:`\mb{x}` into
set :math:`C`.
"""

has_eval = True
has_prox = True

def __init__(self, proj: Callable, args=()):
r"""
Args:
proj: Function computing the projection into the convex set.
args: Additional arguments for function `proj`.
"""
self.proj = proj
self.args = args

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
r"""Compute the :math:`\ell_2` distance to the set.
Compute the distance :math:`d(\mb{x})` between :math:`\mb{x}` and
the set :math:`C`.
Args:
x: Input array :math:`\mb{x}`.
Returns:
Euclidean distance from `x` to the projection of `x`.
"""
y = self.proj(*((x,) + self.args))
return snp.linalg.norm(x - y)

def prox(
self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
r"""Proximal operator of the :math:`\ell_2` distance function.
Compute the proximal operator of the :math:`\ell_2` distance
function :math:`d(\mb{x})` :cite:`beck-2017-first` (Lemma 6.43).
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
Returns:
Scaled proximal operator evaluated at `v`.
"""
y = self.proj(*((v,) + self.args))
d = snp.linalg.norm(v - y)
𝜃 = lam / d if d >= lam else 1.0
return 𝜃 * y + (1.0 - 𝜃) * v


class SquaredSetDistance(Functional):
r"""Squared :math:`\ell_2` distance to a closed convex set.
This functional computes the :math:`\ell_2` distance from a vector to
a closed convex set :math:`C`
.. math::
d(\mb{x}) = \min_{\mb{y} \in C} \, (1/2) \| \mb{x} - \mb{y} \|_2^2
\;.
The set is not specified directly, but in terms of a function
computing the projection into that set, i.e.
.. math::
d(\mb{x}) = (1/2) \| \mb{x} - P_C(\mb{x}) \|_2^2 \;,
where :math:`P_C(\mb{x})` is the projection of :math:`\mb{x}` into
set :math:`C`.
"""

has_eval = True
has_prox = True

def __init__(self, proj: Callable, args=()):
r"""
Args:
proj: Function computing the projection into the convex set.
args: Additional arguments for function `proj`.
"""
self.proj = proj
self.args = args

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
r"""Compute the squared :math:`\ell_2` distance to the set.
Compute the distance :math:`d(\mb{x})` between :math:`\mb{x}` and
the set :math:`C`.
Args:
x: Input array :math:`\mb{x}`.
Returns:
Squared :math:`\ell_2` distance from `x` to the projection of `x`.
"""
y = self.proj(*((x,) + self.args))
return 0.5 * snp.linalg.norm(x - y) ** 2

def prox(
self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
r"""Proximal operator of the squared :math:`\ell_2` distance function.
Compute the proximal operator of the squared :math:`\ell_2` distance
function :math:`d(\mb{x})` :cite:`beck-2017-first` (Example 6.65).
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
Returns:
Scaled proximal operator evaluated at `v`.
"""
y = self.proj(*((v,) + self.args))
𝛼 = 1.0 / (1.0 + lam)
return 𝛼 * v + lam * 𝛼 * y
7 changes: 2 additions & 5 deletions scico/functional/_indicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class NonNegativeIndicator(Functional):
0 & \text{if } x_i \geq 0 \text{ for each } i \\
\infty & \text{else} \;.
\end{cases}
"""

has_eval = True
Expand All @@ -57,11 +56,10 @@ def prox(
\end{cases}
Args:
v : Input array :math:`\mb{v}`.
lam : Proximal parameter :math:`\lambda` (has no effect).
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda` (has no effect).
kwargs: Additional arguments that may be used by derived
classes.
"""
return snp.maximum(v, 0)

Expand Down Expand Up @@ -107,6 +105,5 @@ def prox(
.. math::
\mathrm{prox}_{\lambda I_r}(\mb{v}) = r \frac{\mb{v}}{\norm{\mb{v}}_2}\;.
"""
return self.radius * v / norm(v)
57 changes: 56 additions & 1 deletion scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from typing import Union

from jax import jit
from jax import jit, lax

from scico import numpy as snp
from scico.array import no_nan_divide
Expand Down Expand Up @@ -258,6 +258,61 @@ def prox(
return new_length * direction


class HuberNorm(Functional):
r"""Huber norm.
Compute a norm based on the Huber function :cite:`huber-1964-robust`
:cite:`beck-2017-first` (Sec. 6.7.1)
.. math::
H_{\delta}(\mb{x}) = \begin{cases}
(1/2) \| \mb{x} \|_2^2 & \text{ when } \| \mb{x} \|_2 \leq
\delta \\
\delta \| \mb{x} \|_2 - (1/2) & \text{ when } \| \mb{x} \|_2
> \delta \;,
\end{cases}
where :math:`\delta` is a parameter controlling the transitions
between :math:`\ell_1`-norm like and :math:`\ell_2`-norm like
behavior.
"""

has_eval = True
has_prox = True

def __init__(self, delta: float = 1.0):
r"""
Args:
delta: Huber function parameter :math:`\delta`.
"""
self.delta = delta
self._call_lt_branch = lambda xl2: 0.5 * xl2 ** 2
self._call_gt_branch = lambda xl2: self.delta * xl2 - 0.5
super().__init__()

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
xl2 = snp.linalg.norm(x)
return lax.cond(xl2 <= self.delta, self._call_lt_branch, self._call_gt_branch, xl2)

def prox(
self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
) -> Union[JaxArray, BlockArray]:
r"""Evaluate proximal operator of the Huber function.
Evaluate proximal operator of the Huber function
:cite:`beck-2017-first` (Sec. 6.7.3).
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
"""
vl2 = snp.linalg.norm(v)
den = snp.maximum(vl2, self.delta * (1.0 + lam))
return (1 - ((self.delta * lam) / den)) * v


class NuclearNorm(Functional):
r"""Nuclear norm.
Expand Down
Loading

0 comments on commit 31797e2

Please sign in to comment.