From 08a5896b8c4e247938e3a3d3632d978bd7bf7aa1 Mon Sep 17 00:00:00 2001 From: Saltan Date: Sun, 5 Nov 2023 10:50:03 -0800 Subject: [PATCH] Added a new Functional for TV Norm (#456) * Added a new Functional for TV Norm implementing its proximal operator using the fast subiteration free algorithm proposed by Kamilov, 2016 * added checks for input shape in TV2DNorm * fixed lint errors, changed required argument to default in TV2DNorm, fixed inconsistent signature for prox function, added more comments to the helper functions * some unsaved changes from last commit * newline at end of file error * sort imports lint error * removed the default shape parameter from TV2DNorm * Some docs edits * Disable BlockArray tests on TV2DNorm * Fix black formatting * updated the TV norm logic to apply shrinkage to only the difference operator of the haar transform as in Kamilov, 2016 * Implementation supporting arbitrary dimensional inputs * Add a test * Minor changes * New implementation of TV norm and approximage prox * Clean up * Typo fix * Minor change * Add change log entry * Resolve typing errors * Resolve some oversights and issues arising when 64 bit floats enabled * Apply skipped pre-commit --------- Co-authored-by: Salman Naqvi Co-authored-by: Brendt Wohlberg --- CHANGES.rst | 2 + docs/source/references.bib | 12 +++ scico/functional/__init__.py | 2 + scico/functional/_tvnorm.py | 142 +++++++++++++++++++++++++++ scico/test/functional/test_core.py | 7 +- scico/test/functional/test_tvnorm.py | 40 ++++++++ 6 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 scico/functional/_tvnorm.py create mode 100644 scico/test/functional/test_tvnorm.py diff --git a/CHANGES.rst b/CHANGES.rst index 7dc048a0a..265b9dd38 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,6 +6,8 @@ SCICO Release Notes Version 0.0.5 (unreleased) ---------------------------- +• New functional ``functional.AnisotropicTVNorm`` with proximal operator + approximation. • New integrated Radon/X-ray transform ``linop.XRayTransform``. • Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and ``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes diff --git a/docs/source/references.bib b/docs/source/references.bib index e611ff201..b611601bb 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -387,6 +387,18 @@ @Article {jin-2017-unet doi = {10.1109/TIP.2017.2713099} } +@Article {kamilov-2016-parallel, + title = {A parallel proximal algorithm for anisotropic total + variation minimization}, + author = {Ulugbek S. Kamilov}, + journal = {IEEE Transactions on Image Processing}, + volume = 26, + number = 2, + pages = {539--548}, + year = 2016, + doi = {10.1109/tip.2016.2629449 } +} + @Article {kamilov-2017-plugandplay, author = {Ulugbek Kamilov and Hassan Mansour and Brendt Wohlberg}, diff --git a/scico/functional/__init__.py b/scico/functional/__init__.py index 53d426067..48509cd40 100644 --- a/scico/functional/__init__.py +++ b/scico/functional/__init__.py @@ -21,12 +21,14 @@ NuclearNorm, L1MinusL2Norm, ) +from ._tvnorm import AnisotropicTVNorm from ._indicator import NonNegativeIndicator, L2BallIndicator from ._denoiser import BM3D, BM4D, DnCNN from ._dist import SetDistance, SquaredSetDistance __all__ = [ + "AnisotropicTVNorm", "Functional", "ScaledFunctional", "SeparableFunctional", diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py new file mode 100644 index 000000000..b8d621c93 --- /dev/null +++ b/scico/functional/_tvnorm.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2023 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Anisotropic total variation norm.""" + +from typing import Optional, Tuple + +from scico import numpy as snp +from scico.linop import ( + CircularConvolve, + FiniteDifference, + LinearOperator, + VerticalStack, +) +from scico.numpy import Array + +from ._functional import Functional +from ._norm import L1Norm + + +class AnisotropicTVNorm(Functional): + r"""The anisotropic total variation (TV) norm. + + The anisotropic total variation (TV) norm computed by + + .. code-block:: python + + ATV = scico.functional.AnisotropicTVNorm() + x_norm = ATV(x) + + is equivalent to + + .. code-block:: python + + C = linop.FiniteDifference(input_shape=x.shape, circular=True) + L1 = functional.L1Norm() + x_norm = L1(C @ x) + + The scaled proximal operator is computed using an approximation that + holds for small scaling parameters :cite:`kamilov-2016-parallel`. + This does not imply that it can only be applied to problems requiring + a small regularization parameter since most proximal algorithms + include an additional algorithm parameter that also plays a role in + the parameter of the proximal operator. For example, in :class:`.PGM` + and :class:`.AcceleratedPGM`, the scaled proximal operator parameter + is the regularization parameter divided by the `L0` algorithm + parameter, and for :class:`.ADMM`, the scaled proximal operator + parameters are the regularization parameters divided by the entries + in the `rho_list` algorithm parameter. + """ + + has_eval = True + has_prox = True + + def __init__(self, ndims: Optional[int] = None): + r""" + Args: + ndims: Number of (trailing) dimensions of the input over + which to apply the finite difference operator. If + ``None``, differences are evaluated along all axes. + """ + self.ndims = ndims + self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter + self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter + self.l1norm = L1Norm() + self.G: Optional[LinearOperator] = None + self.W: Optional[LinearOperator] = None + + def __call__(self, x: Array) -> float: + r"""Compute the anisotropic TV norm of an array.""" + if self.G is None or self.G.shape[1] != x.shape: + if self.ndims is None: + ndims = x.ndim + else: + ndims = self.ndims + axes = tuple(range(ndims)) + self.G = FiniteDifference( + x.shape, input_dtype=x.dtype, axes=axes, circular=True, jit=True + ) + return snp.sum(snp.abs(self.G @ x)) + + @staticmethod + def _shape(idx: int, ndims: int) -> Tuple: + """Construct a shape tuple. + + Construct a tuple of size `ndims` with all unit entries except + for index `idx`, which has a -1 entry. + """ + return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1) + + def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: + r"""Approximate proximal operator of the isotropic TV norm. + + Approximation of the proximal operator of the anisotropic TV norm, + computed via the method described in :cite:`kamilov-2016-parallel`. + + Args: + v: Input array :math:`\mb{v}`. + lam: Proximal parameter :math:`\lam`. + kwargs: Additional arguments that may be used by derived + classes. + """ + if self.ndims is None: + ndims = v.ndim + else: + ndims = self.ndims + K = 2 * ndims + + if self.W is None or self.W.shape[1] != v.shape: + h0 = self.h0.astype(v.dtype) + h1 = self.h1.astype(v.dtype) + C0 = VerticalStack( # Stack of lowpass filter operators for each axis + [ + CircularConvolve( + h0.reshape(AnisotropicTVNorm._shape(k, ndims)), + v.shape, + ndims=self.ndims, + ) + for k in range(ndims) + ] + ) + C1 = VerticalStack( # Stack of highpass filter operators for each axis + [ + CircularConvolve( + h1.reshape(AnisotropicTVNorm._shape(k, ndims)), + v.shape, + ndims=self.ndims, + ) + for k in range(ndims) + ] + ) + # single-level shift-invariant Haar transform + self.W = VerticalStack([C0, C1], jit=True) + + Wv = self.W @ v + # Apply 𝑙1 shrinkage to highpass component of shift-invariant Haar transform + Wv = Wv.at[1].set(self.l1norm.prox(Wv[1], snp.sqrt(2) * K * lam)) + return (1.0 / K) * self.W.T @ Wv diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index a48fa632b..6cf0c0393 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -16,7 +16,12 @@ from scico import functional from scico.random import randn -NO_BLOCK_ARRAY = [functional.L21Norm, functional.L1MinusL2Norm, functional.NuclearNorm] +NO_BLOCK_ARRAY = [ + functional.L21Norm, + functional.L1MinusL2Norm, + functional.NuclearNorm, + functional.AnisotropicTVNorm, +] NO_COMPLEX = [functional.NonNegativeIndicator] diff --git a/scico/test/functional/test_tvnorm.py b/scico/test/functional/test_tvnorm.py new file mode 100644 index 000000000..1cef21535 --- /dev/null +++ b/scico/test/functional/test_tvnorm.py @@ -0,0 +1,40 @@ +import numpy as np + +import scico.random +from scico import functional, linop, loss, metric +from scico.optimize.admm import ADMM, LinearSubproblemSolver +from scico.optimize.pgm import AcceleratedPGM + + +def test_tvnorm(): + + N = 128 + g = np.linspace(0, 2 * np.pi, N, dtype=np.float32) + x_gt = np.sin(2 * g) + x_gt[x_gt > 0.5] = 0.5 + x_gt[x_gt < -0.5] = -0.5 + σ = 0.02 + noise, key = scico.random.randn(x_gt.shape, seed=0) + y = x_gt + σ * noise + + λ = 5e-2 + f = loss.SquaredL2Loss(y=y) + + C = linop.FiniteDifference(input_shape=x_gt.shape, circular=True) + g = λ * functional.L1Norm() + solver = ADMM( + f=f, + g_list=[g], + C_list=[C], + rho_list=[1e1], + x0=y, + maxiter=50, + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}), + ) + x_tvdn = solver.solve() + + h = λ * functional.AnisotropicTVNorm() + solver = AcceleratedPGM(f=f, g=h, L0=2e2, x0=y, maxiter=50) + x_approx = solver.solve() + + assert metric.snr(x_tvdn, x_approx) > 45