From 1b7f583e16249d5a3094c74103b60024ab7111d1 Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Thu, 5 Oct 2023 15:02:48 -0700 Subject: [PATCH 01/29] Added a new Functional for TV Norm implementing its proximal operator using the fast subiteration free algorithm proposed by Kamilov, 2016 --- scico/functional/__init__.py | 2 + scico/functional/_norm.py | 104 +++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/scico/functional/__init__.py b/scico/functional/__init__.py index 53d426067..377e29b1a 100644 --- a/scico/functional/__init__.py +++ b/scico/functional/__init__.py @@ -20,6 +20,7 @@ L21Norm, NuclearNorm, L1MinusL2Norm, + TV2DNorm, ) from ._indicator import NonNegativeIndicator, L2BallIndicator from ._denoiser import BM3D, BM4D, DnCNN @@ -46,6 +47,7 @@ "BM3D", "BM4D", "DnCNN", + "TV2DNorm", ] # Imported items in __all__ appear to originate in top-level functional module diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 332d0500f..22e7e9825 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -15,6 +15,7 @@ from scico.numpy import Array, BlockArray, count_nonzero from scico.numpy.linalg import norm from scico.numpy.util import no_nan_divide +from scico.linop import FiniteDifference from ._functional import Functional @@ -477,3 +478,106 @@ def prox( svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False) svdS = snp.maximum(0, svdS - lam) return svdU @ snp.diag(svdS) @ svdV + + +class TV2DNorm(Functional): + r"""The :math:`\ell_{TV}` norm. + + For a :math:`M \times N` matrix, :math:`\mb{A}`, by default, + + .. math:: + \norm{\mb{A}}_{TV} = \sum_{n=1}^N \sum_{m=1}^M + \abs{\nabla{A}_{m,n}} \;. + + This norm currently only has proximal operator defined only for + 2 dimensional data. + + For `BlockArray` inputs, the :math:`\ell_{TV}` norm follows the + reduction rules described in :class:`BlockArray`. + + A typical use case is computing the anisotropic total variation norm. + """ + + has_eval = True + has_prox = True + + def __init__(self, dims, tau: float = 1.0): + r""" + Args: + tau: Parameter :math:`\tau` in the norm definition. + """ + self.dims = dims + self.tau = tau + + def __call__(self, x: Union[Array, BlockArray]) -> float: + r"""Return the :math:`\ell_{TV}` norm of an array.""" + y = 0 + gradOp = FiniteDifference(self.dims, input_dtype=x.dtype, circular=True) + grads = gradOp @ x + for g in grads: + y += snp.abs(g) + return self.tau * snp.sum(y) + + def prox( + self, x: Union[Array, BlockArray], lam: float = 1.0, **kwargs + ) -> Union[Array, BlockArray]: + r"""Proximal operator of the :math:`\ell_{TV}` norm. + + Evaluate proximal operator of the TV norm + :cite:`tip-2016-kamilov`. + + Args: + v: Input array :math:`\mb{v}`. + lam: Proximal parameter :math:`\lam`. + kwargs: Additional arguments that may be used by derived + classes. + """ + D = 2 + K = 2*D + thresh = snp.sqrt(2) * K * self.tau * lam + + y = snp.zeros_like(x) + for ax in range(2): + y = y.at[:].add(self.iht2(self.shrink(self.ht2(x, axis=ax, shift=False), thresh), axis=ax, shift=False)) + y = y.at[:].add(self.iht2(self.shrink(self.ht2(x, axis=ax, shift=True), thresh), axis=ax, shift=True)) + y = y.at[:].divide(K) + + return y + + def ht2(self, x, axis, shift): + s = x.shape + w = snp.zeros(s) + C = 1 / snp.sqrt(2) + if shift: + x = snp.roll(x, -1, axis=axis) + + m = s[axis] // 2 + if not axis: + w = w.at[:m, :].set(C * (x[1::2, :] + x[::2, :])) + w = w.at[m:, :].set(C * (x[1::2, :] - x[::2, :])) + else: + w = w.at[:, :m].set(C * (x[:, 1::2] + x[:, ::2])) + w = w.at[:, m:].set(C * (x[:, 1::2] - x[:, ::2])) + return w + + def iht2(self, w, axis, shift): + s = snp.shape(w) + y = snp.zeros(s) + C = 1 / snp.sqrt(2) + m = s[axis] // 2 + if not axis: + y = y.at[::2, :].set(C * (w[:m, :] - w[m:, :])) + y = y.at[1::2, :].set(C * (w[:m, :] + w[m:, :])) + else: + y = y.at[:, ::2].set(C * (w[:, :m] - w[:, m:])) + y = y.at[:, 1::2].set(C * (w[:, :m] + w[:, m:])) + + if shift: + y = snp.roll(y, 1, axis) + + return y + + def shrink(self, x, tau): + threshed = snp.maximum(snp.abs(x)-tau, 0) + threshed = threshed.at[:].multiply(snp.sign(x)) + return threshed \ No newline at end of file From 9d1d73a9184afe4515be483a61d120a384affabb Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Thu, 5 Oct 2023 16:14:10 -0700 Subject: [PATCH 02/29] added checks for input shape in TV2DNorm --- scico/functional/_norm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 22e7e9825..15b50cfad 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -501,7 +501,7 @@ class TV2DNorm(Functional): has_eval = True has_prox = True - def __init__(self, dims, tau: float = 1.0): + def __init__(self, dims: Tuple[int, int], tau: float = 1.0): r""" Args: tau: Parameter :math:`\tau` in the norm definition. @@ -511,6 +511,7 @@ def __init__(self, dims, tau: float = 1.0): def __call__(self, x: Union[Array, BlockArray]) -> float: r"""Return the :math:`\ell_{TV}` norm of an array.""" + assert x.shape == self.dims y = 0 gradOp = FiniteDifference(self.dims, input_dtype=x.dtype, circular=True) grads = gradOp @ x @@ -532,6 +533,7 @@ def prox( kwargs: Additional arguments that may be used by derived classes. """ + assert x.shape == self.dims D = 2 K = 2*D thresh = snp.sqrt(2) * K * self.tau * lam From 877df4c79790584960a3a69f3a5ac3fd1d50a6a0 Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Thu, 5 Oct 2023 16:28:51 -0700 Subject: [PATCH 03/29] fixed lint errors, changed required argument to default in TV2DNorm, fixed inconsistent signature for prox function, added more comments to the helper functions --- scico/functional/_norm.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 15b50cfad..0b07596fa 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -501,7 +501,7 @@ class TV2DNorm(Functional): has_eval = True has_prox = True - def __init__(self, dims: Tuple[int, int], tau: float = 1.0): + def __init__(self, dims: Tuple[int, int] = (1,1), tau: float = 1.0): r""" Args: tau: Parameter :math:`\tau` in the norm definition. @@ -520,7 +520,7 @@ def __call__(self, x: Union[Array, BlockArray]) -> float: return self.tau * snp.sum(y) def prox( - self, x: Union[Array, BlockArray], lam: float = 1.0, **kwargs + self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Proximal operator of the :math:`\ell_{TV}` norm. @@ -533,20 +533,21 @@ def prox( kwargs: Additional arguments that may be used by derived classes. """ - assert x.shape == self.dims + assert v.shape == self.dims D = 2 K = 2*D thresh = snp.sqrt(2) * K * self.tau * lam - y = snp.zeros_like(x) + y = snp.zeros_like(v) for ax in range(2): - y = y.at[:].add(self.iht2(self.shrink(self.ht2(x, axis=ax, shift=False), thresh), axis=ax, shift=False)) - y = y.at[:].add(self.iht2(self.shrink(self.ht2(x, axis=ax, shift=True), thresh), axis=ax, shift=True)) + y = y.at[:].add(self.iht2(self.shrink(self.ht2(v, axis=ax, shift=False), thresh), axis=ax, shift=False)) + y = y.at[:].add(self.iht2(self.shrink(self.ht2(v, axis=ax, shift=True), thresh), axis=ax, shift=True)) y = y.at[:].divide(K) return y def ht2(self, x, axis, shift): + r"""Forward Discrete Haar Wavelet transform in 2D""" s = x.shape w = snp.zeros(s) C = 1 / snp.sqrt(2) @@ -563,6 +564,7 @@ def ht2(self, x, axis, shift): return w def iht2(self, w, axis, shift): + r"""Inverse Discrete Haar Wavelet transform in 2D""" s = snp.shape(w) y = snp.zeros(s) C = 1 / snp.sqrt(2) @@ -580,6 +582,7 @@ def iht2(self, w, axis, shift): return y def shrink(self, x, tau): + r"""Wavelet shrinkage operator""" threshed = snp.maximum(snp.abs(x)-tau, 0) threshed = threshed.at[:].multiply(snp.sign(x)) return threshed \ No newline at end of file From 4947382e21e323882d52f288a421d6b7a920716f Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Thu, 5 Oct 2023 16:30:57 -0700 Subject: [PATCH 04/29] some unsaved changes from last commit --- scico/functional/_norm.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 0b07596fa..61443f6b9 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -490,7 +490,7 @@ class TV2DNorm(Functional): \abs{\nabla{A}_{m,n}} \;. This norm currently only has proximal operator defined only for - 2 dimensional data. + 2 dimensional data. For `BlockArray` inputs, the :math:`\ell_{TV}` norm follows the reduction rules described in :class:`BlockArray`. @@ -523,10 +523,10 @@ def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Proximal operator of the :math:`\ell_{TV}` norm. - + Evaluate proximal operator of the TV norm :cite:`tip-2016-kamilov`. - + Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lam`. @@ -535,15 +535,23 @@ def prox( """ assert v.shape == self.dims D = 2 - K = 2*D + K = 2 * D thresh = snp.sqrt(2) * K * self.tau * lam y = snp.zeros_like(v) for ax in range(2): - y = y.at[:].add(self.iht2(self.shrink(self.ht2(v, axis=ax, shift=False), thresh), axis=ax, shift=False)) - y = y.at[:].add(self.iht2(self.shrink(self.ht2(v, axis=ax, shift=True), thresh), axis=ax, shift=True)) + y = y.at[:].add( + self.iht2( + self.shrink(self.ht2(v, axis=ax, shift=False), thresh), axis=ax, shift=False + ) + ) + y = y.at[:].add( + self.iht2( + self.shrink(self.ht2(v, axis=ax, shift=True), thresh), axis=ax, shift=True + ) + ) y = y.at[:].divide(K) - + return y def ht2(self, x, axis, shift): @@ -554,13 +562,14 @@ def ht2(self, x, axis, shift): if shift: x = snp.roll(x, -1, axis=axis) - m = s[axis] // 2 + m = s[axis] // 2 if not axis: w = w.at[:m, :].set(C * (x[1::2, :] + x[::2, :])) w = w.at[m:, :].set(C * (x[1::2, :] - x[::2, :])) else: w = w.at[:, :m].set(C * (x[:, 1::2] + x[:, ::2])) w = w.at[:, m:].set(C * (x[:, 1::2] - x[:, ::2])) + return w def iht2(self, w, axis, shift): @@ -568,7 +577,7 @@ def iht2(self, w, axis, shift): s = snp.shape(w) y = snp.zeros(s) C = 1 / snp.sqrt(2) - m = s[axis] // 2 + m = s[axis] // 2 if not axis: y = y.at[::2, :].set(C * (w[:m, :] - w[m:, :])) y = y.at[1::2, :].set(C * (w[:m, :] + w[m:, :])) @@ -583,6 +592,7 @@ def iht2(self, w, axis, shift): def shrink(self, x, tau): r"""Wavelet shrinkage operator""" - threshed = snp.maximum(snp.abs(x)-tau, 0) + threshed = snp.maximum(snp.abs(x) - tau, 0) threshed = threshed.at[:].multiply(snp.sign(x)) - return threshed \ No newline at end of file + return threshed + \ No newline at end of file From 4375ddfe04b0c42619a2d4540385f776b46ab17b Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Thu, 5 Oct 2023 16:34:36 -0700 Subject: [PATCH 05/29] some unsaved changes from last commit --- scico/functional/_norm.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 61443f6b9..352377d26 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -501,7 +501,7 @@ class TV2DNorm(Functional): has_eval = True has_prox = True - def __init__(self, dims: Tuple[int, int] = (1,1), tau: float = 1.0): + def __init__(self, dims: Tuple[int, int] = (1, 1), tau: float = 1.0): r""" Args: tau: Parameter :math:`\tau` in the norm definition. @@ -543,15 +543,14 @@ def prox( y = y.at[:].add( self.iht2( self.shrink(self.ht2(v, axis=ax, shift=False), thresh), axis=ax, shift=False - ) ) + ) y = y.at[:].add( self.iht2( self.shrink(self.ht2(v, axis=ax, shift=True), thresh), axis=ax, shift=True - ) ) + ) y = y.at[:].divide(K) - return y def ht2(self, x, axis, shift): @@ -569,7 +568,6 @@ def ht2(self, x, axis, shift): else: w = w.at[:, :m].set(C * (x[:, 1::2] + x[:, ::2])) w = w.at[:, m:].set(C * (x[:, 1::2] - x[:, ::2])) - return w def iht2(self, w, axis, shift): @@ -587,12 +585,10 @@ def iht2(self, w, axis, shift): if shift: y = snp.roll(y, 1, axis) - return y def shrink(self, x, tau): r"""Wavelet shrinkage operator""" threshed = snp.maximum(snp.abs(x) - tau, 0) threshed = threshed.at[:].multiply(snp.sign(x)) - return threshed - \ No newline at end of file + return threshed \ No newline at end of file From 71fc6363b08b00b8a091faf6634885e3237de8fd Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Thu, 5 Oct 2023 16:39:38 -0700 Subject: [PATCH 06/29] newline at end of file error --- scico/functional/_norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 352377d26..945ab8bbb 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -591,4 +591,4 @@ def shrink(self, x, tau): r"""Wavelet shrinkage operator""" threshed = snp.maximum(snp.abs(x) - tau, 0) threshed = threshed.at[:].multiply(snp.sign(x)) - return threshed \ No newline at end of file + return threshed From c62da28ebc4e1de42452e0384800359977f5bfb2 Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Thu, 5 Oct 2023 16:43:33 -0700 Subject: [PATCH 07/29] sort imports lint error --- scico/functional/_norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 945ab8bbb..80fc15209 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -12,10 +12,10 @@ from jax import jit, lax from scico import numpy as snp +from scico.linop import FiniteDifference from scico.numpy import Array, BlockArray, count_nonzero from scico.numpy.linalg import norm from scico.numpy.util import no_nan_divide -from scico.linop import FiniteDifference from ._functional import Functional From 98ce9898394cf4469a687fe472810e570c4bb213 Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Thu, 5 Oct 2023 17:51:10 -0700 Subject: [PATCH 08/29] removed the default shape parameter from TV2DNorm --- scico/functional/_norm.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 80fc15209..11a931c5c 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -7,6 +7,7 @@ """Functionals that are norms.""" +import warnings from typing import Optional, Tuple, Union from jax import jit, lax @@ -501,19 +502,17 @@ class TV2DNorm(Functional): has_eval = True has_prox = True - def __init__(self, dims: Tuple[int, int] = (1, 1), tau: float = 1.0): + def __init__(self, tau: float = 1.0): r""" Args: tau: Parameter :math:`\tau` in the norm definition. """ - self.dims = dims self.tau = tau def __call__(self, x: Union[Array, BlockArray]) -> float: r"""Return the :math:`\ell_{TV}` norm of an array.""" - assert x.shape == self.dims y = 0 - gradOp = FiniteDifference(self.dims, input_dtype=x.dtype, circular=True) + gradOp = FiniteDifference(x.shape, input_dtype=x.dtype, circular=True) grads = gradOp @ x for g in grads: y += snp.abs(g) @@ -533,7 +532,6 @@ def prox( kwargs: Additional arguments that may be used by derived classes. """ - assert v.shape == self.dims D = 2 K = 2 * D thresh = snp.sqrt(2) * K * self.tau * lam From 096d1c92658dee82b76f7e67e43017272ebb04cb Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 9 Oct 2023 01:41:06 -0600 Subject: [PATCH 09/29] Some docs edits --- docs/source/references.bib | 12 ++++++++++++ scico/functional/_norm.py | 22 ++++++++++------------ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/docs/source/references.bib b/docs/source/references.bib index e611ff201..b611601bb 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -387,6 +387,18 @@ @Article {jin-2017-unet doi = {10.1109/TIP.2017.2713099} } +@Article {kamilov-2016-parallel, + title = {A parallel proximal algorithm for anisotropic total + variation minimization}, + author = {Ulugbek S. Kamilov}, + journal = {IEEE Transactions on Image Processing}, + volume = 26, + number = 2, + pages = {539--548}, + year = 2016, + doi = {10.1109/tip.2016.2629449 } +} + @Article {kamilov-2017-plugandplay, author = {Ulugbek Kamilov and Hassan Mansour and Brendt Wohlberg}, diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 11a931c5c..da9372102 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -482,21 +482,19 @@ def prox( class TV2DNorm(Functional): - r"""The :math:`\ell_{TV}` norm. + r"""The anisotropic total variation (TV) norm. For a :math:`M \times N` matrix, :math:`\mb{A}`, by default, .. math:: - \norm{\mb{A}}_{TV} = \sum_{n=1}^N \sum_{m=1}^M + \norm{\mb{A}}_{\text{TV}} = \sum_{n=1}^N \sum_{m=1}^M \abs{\nabla{A}_{m,n}} \;. - This norm currently only has proximal operator defined only for - 2 dimensional data. + The proximal operator of this norm is currently only defined for 2 + dimensional data. - For `BlockArray` inputs, the :math:`\ell_{TV}` norm follows the - reduction rules described in :class:`BlockArray`. - - A typical use case is computing the anisotropic total variation norm. + For `BlockArray` inputs, the TV norm follows the reduction rules + described in :class:`BlockArray`. """ has_eval = True @@ -510,7 +508,7 @@ def __init__(self, tau: float = 1.0): self.tau = tau def __call__(self, x: Union[Array, BlockArray]) -> float: - r"""Return the :math:`\ell_{TV}` norm of an array.""" + r"""Return the TV norm of an array.""" y = 0 gradOp = FiniteDifference(x.shape, input_dtype=x.dtype, circular=True) grads = gradOp @ x @@ -521,10 +519,10 @@ def __call__(self, x: Union[Array, BlockArray]) -> float: def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: - r"""Proximal operator of the :math:`\ell_{TV}` norm. + r"""Proximal operator of the TV norm. - Evaluate proximal operator of the TV norm - :cite:`tip-2016-kamilov`. + Approximate the proximal operator of the anisotropic TV norm via + the method described in :cite:`kamilov-2016-parallel`. Args: v: Input array :math:`\mb{v}`. From c2e1de5b112f59e4e0b94ff01a017f6e47b171b3 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 9 Oct 2023 02:06:12 -0600 Subject: [PATCH 10/29] Disable BlockArray tests on TV2DNorm --- scico/test/functional/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index 9b4f43f8a..8dd1f6706 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -16,7 +16,7 @@ from scico import functional from scico.random import randn -NO_BLOCK_ARRAY = [functional.L21Norm, functional.L1MinusL2Norm, functional.NuclearNorm] +NO_BLOCK_ARRAY = [functional.L21Norm, functional.L1MinusL2Norm, functional.NuclearNorm, functional.TV2DNorm] NO_COMPLEX = [functional.NonNegativeIndicator] From 3a0cdb0c07a700333d69f350b09987752f258f98 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 9 Oct 2023 02:06:49 -0600 Subject: [PATCH 11/29] Fix black formatting --- scico/test/functional/test_core.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index 8dd1f6706..74c79bbe8 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -16,7 +16,12 @@ from scico import functional from scico.random import randn -NO_BLOCK_ARRAY = [functional.L21Norm, functional.L1MinusL2Norm, functional.NuclearNorm, functional.TV2DNorm] +NO_BLOCK_ARRAY = [ + functional.L21Norm, + functional.L1MinusL2Norm, + functional.NuclearNorm, + functional.TV2DNorm, +] NO_COMPLEX = [functional.NonNegativeIndicator] From 605d11b1094ee574b6aa3a60cd5dcd960ba8cdc2 Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Tue, 10 Oct 2023 23:14:47 -0700 Subject: [PATCH 12/29] updated the TV norm logic to apply shrinkage to only the difference operator of the haar transform as in Kamilov, 2016 --- scico/functional/_norm.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index da9372102..933062bb0 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -538,40 +538,38 @@ def prox( for ax in range(2): y = y.at[:].add( self.iht2( - self.shrink(self.ht2(v, axis=ax, shift=False), thresh), axis=ax, shift=False + self.ht2_shrink(v, axis=ax, shift=False, thresh=thresh), axis=ax, shift=False ) ) y = y.at[:].add( self.iht2( - self.shrink(self.ht2(v, axis=ax, shift=True), thresh), axis=ax, shift=True + self.ht2_shrink(v, axis=ax, shift=True, thresh=thresh), axis=ax, shift=True ) ) y = y.at[:].divide(K) return y - def ht2(self, x, axis, shift): + def ht2_shrink(self, x, axis, shift, thresh): r"""Forward Discrete Haar Wavelet transform in 2D""" - s = x.shape - w = snp.zeros(s) + w = snp.zeros_like(x) C = 1 / snp.sqrt(2) if shift: x = snp.roll(x, -1, axis=axis) - m = s[axis] // 2 + m = x.shape[axis] // 2 if not axis: w = w.at[:m, :].set(C * (x[1::2, :] + x[::2, :])) - w = w.at[m:, :].set(C * (x[1::2, :] - x[::2, :])) + w = w.at[m:, :].set(self.shrink(C * (x[1::2, :] - x[::2, :]), thresh)) else: w = w.at[:, :m].set(C * (x[:, 1::2] + x[:, ::2])) - w = w.at[:, m:].set(C * (x[:, 1::2] - x[:, ::2])) + w = w.at[:, m:].set(self.shrink(C * (x[:, 1::2] - x[:, ::2]), thresh)) return w def iht2(self, w, axis, shift): r"""Inverse Discrete Haar Wavelet transform in 2D""" - s = snp.shape(w) - y = snp.zeros(s) + y = snp.zeros_like(w) C = 1 / snp.sqrt(2) - m = s[axis] // 2 + m = w.shape[axis] // 2 if not axis: y = y.at[::2, :].set(C * (w[:m, :] - w[m:, :])) y = y.at[1::2, :].set(C * (w[:m, :] + w[m:, :])) From c8efe90beeb7e73311190c27c8fc9eb746d8d506 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 1 Nov 2023 21:47:43 -0600 Subject: [PATCH 13/29] Implementation supporting arbitrary dimensional inputs --- scico/functional/__init__.py | 4 +- scico/functional/_norm.py | 107 ----------------------------------- 2 files changed, 2 insertions(+), 109 deletions(-) diff --git a/scico/functional/__init__.py b/scico/functional/__init__.py index 377e29b1a..48509cd40 100644 --- a/scico/functional/__init__.py +++ b/scico/functional/__init__.py @@ -20,14 +20,15 @@ L21Norm, NuclearNorm, L1MinusL2Norm, - TV2DNorm, ) +from ._tvnorm import AnisotropicTVNorm from ._indicator import NonNegativeIndicator, L2BallIndicator from ._denoiser import BM3D, BM4D, DnCNN from ._dist import SetDistance, SquaredSetDistance __all__ = [ + "AnisotropicTVNorm", "Functional", "ScaledFunctional", "SeparableFunctional", @@ -47,7 +48,6 @@ "BM3D", "BM4D", "DnCNN", - "TV2DNorm", ] # Imported items in __all__ appear to originate in top-level functional module diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 933062bb0..5d0a6579e 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -479,110 +479,3 @@ def prox( svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False) svdS = snp.maximum(0, svdS - lam) return svdU @ snp.diag(svdS) @ svdV - - -class TV2DNorm(Functional): - r"""The anisotropic total variation (TV) norm. - - For a :math:`M \times N` matrix, :math:`\mb{A}`, by default, - - .. math:: - \norm{\mb{A}}_{\text{TV}} = \sum_{n=1}^N \sum_{m=1}^M - \abs{\nabla{A}_{m,n}} \;. - - The proximal operator of this norm is currently only defined for 2 - dimensional data. - - For `BlockArray` inputs, the TV norm follows the reduction rules - described in :class:`BlockArray`. - """ - - has_eval = True - has_prox = True - - def __init__(self, tau: float = 1.0): - r""" - Args: - tau: Parameter :math:`\tau` in the norm definition. - """ - self.tau = tau - - def __call__(self, x: Union[Array, BlockArray]) -> float: - r"""Return the TV norm of an array.""" - y = 0 - gradOp = FiniteDifference(x.shape, input_dtype=x.dtype, circular=True) - grads = gradOp @ x - for g in grads: - y += snp.abs(g) - return self.tau * snp.sum(y) - - def prox( - self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs - ) -> Union[Array, BlockArray]: - r"""Proximal operator of the TV norm. - - Approximate the proximal operator of the anisotropic TV norm via - the method described in :cite:`kamilov-2016-parallel`. - - Args: - v: Input array :math:`\mb{v}`. - lam: Proximal parameter :math:`\lam`. - kwargs: Additional arguments that may be used by derived - classes. - """ - D = 2 - K = 2 * D - thresh = snp.sqrt(2) * K * self.tau * lam - - y = snp.zeros_like(v) - for ax in range(2): - y = y.at[:].add( - self.iht2( - self.ht2_shrink(v, axis=ax, shift=False, thresh=thresh), axis=ax, shift=False - ) - ) - y = y.at[:].add( - self.iht2( - self.ht2_shrink(v, axis=ax, shift=True, thresh=thresh), axis=ax, shift=True - ) - ) - y = y.at[:].divide(K) - return y - - def ht2_shrink(self, x, axis, shift, thresh): - r"""Forward Discrete Haar Wavelet transform in 2D""" - w = snp.zeros_like(x) - C = 1 / snp.sqrt(2) - if shift: - x = snp.roll(x, -1, axis=axis) - - m = x.shape[axis] // 2 - if not axis: - w = w.at[:m, :].set(C * (x[1::2, :] + x[::2, :])) - w = w.at[m:, :].set(self.shrink(C * (x[1::2, :] - x[::2, :]), thresh)) - else: - w = w.at[:, :m].set(C * (x[:, 1::2] + x[:, ::2])) - w = w.at[:, m:].set(self.shrink(C * (x[:, 1::2] - x[:, ::2]), thresh)) - return w - - def iht2(self, w, axis, shift): - r"""Inverse Discrete Haar Wavelet transform in 2D""" - y = snp.zeros_like(w) - C = 1 / snp.sqrt(2) - m = w.shape[axis] // 2 - if not axis: - y = y.at[::2, :].set(C * (w[:m, :] - w[m:, :])) - y = y.at[1::2, :].set(C * (w[:m, :] + w[m:, :])) - else: - y = y.at[:, ::2].set(C * (w[:, :m] - w[:, m:])) - y = y.at[:, 1::2].set(C * (w[:, :m] + w[:, m:])) - - if shift: - y = snp.roll(y, 1, axis) - return y - - def shrink(self, x, tau): - r"""Wavelet shrinkage operator""" - threshed = snp.maximum(snp.abs(x) - tau, 0) - threshed = threshed.at[:].multiply(snp.sign(x)) - return threshed From ec8686ef7b8413127ca1e36efc8d48ef10f22fcd Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 2 Nov 2023 21:39:03 -0600 Subject: [PATCH 14/29] Add a test --- scico/test/functional/test_tvnorm.py | 43 ++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 scico/test/functional/test_tvnorm.py diff --git a/scico/test/functional/test_tvnorm.py b/scico/test/functional/test_tvnorm.py new file mode 100644 index 000000000..740601353 --- /dev/null +++ b/scico/test/functional/test_tvnorm.py @@ -0,0 +1,43 @@ +import numpy as np + +import scico.random +from scico import functional, linop, loss, metric +from scico.optimize.admm import ADMM, LinearSubproblemSolver +from scico.optimize.pgm import AcceleratedPGM + + +def test_tvnorm(): + + N = 128 + g = np.linspace(0, 2 * np.pi, N) + x_gt = np.sin(2 * g) + x_gt[x_gt > 0.5] = 0.5 + x_gt[x_gt < -0.5] = -0.5 + σ = 0.02 + noise, key = scico.random.randn(x_gt.shape, seed=0) + y = x_gt + σ * noise + + λ = 5e-2 + + f = loss.SquaredL2Loss(y=y) + g = λ * functional.L1Norm() + C = linop.FiniteDifference(input_shape=x_gt.shape, circular=True) + solver = ADMM( + f=f, + g_list=[g], + C_list=[C], + rho_list=[1e1], + x0=y, + maxiter=50, + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}), + itstat_options={"display": True, "period": 10}, + ) + x_tvdn = solver.solve() + + h = λ * functional.AnisotropicTVNorm() + solver = AcceleratedPGM( + f=f, g=h, L0=2e2, x0=y, maxiter=50, itstat_options={"display": True, "period": 10} + ) + x_approx = solver.solve() + + assert metric.snr(x_tvdn, x_approx) > 45 From 4f2f189bcedd664b7e4569ca447d119d006e6169 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 3 Nov 2023 10:54:50 -0600 Subject: [PATCH 15/29] Minor changes --- scico/test/functional/test_tvnorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/test/functional/test_tvnorm.py b/scico/test/functional/test_tvnorm.py index 740601353..232665aca 100644 --- a/scico/test/functional/test_tvnorm.py +++ b/scico/test/functional/test_tvnorm.py @@ -18,10 +18,10 @@ def test_tvnorm(): y = x_gt + σ * noise λ = 5e-2 - f = loss.SquaredL2Loss(y=y) - g = λ * functional.L1Norm() + C = linop.FiniteDifference(input_shape=x_gt.shape, circular=True) + g = λ * functional.L1Norm() solver = ADMM( f=f, g_list=[g], From b7427f7fa5cde3b68dc67b1f30bd657715295b7d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 3 Nov 2023 11:47:39 -0600 Subject: [PATCH 16/29] New implementation of TV norm and approximage prox --- scico/functional/_tvnorm.py | 138 ++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 scico/functional/_tvnorm.py diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py new file mode 100644 index 000000000..ed2a2992f --- /dev/null +++ b/scico/functional/_tvnorm.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2023 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Anisotropic total variation norm.""" + +import warnings +from typing import Optional, Tuple, Union + +from jax import jit, lax + +from scico import numpy as snp +from scico.linop import FiniteDifference, VerticalStack, CircularConvolve +from scico.numpy import Array, BlockArray, count_nonzero +from scico.numpy.linalg import norm +from scico.numpy.util import no_nan_divide + +from ._functional import Functional +from ._norm import L1Norm + + +class AnisotropicTVNorm(Functional): + r"""The anisotropic total variation (TV) norm. + + The anisotropic total variation (TV) norm computed by + + .. code-block:: python + + ATV = scico.functional.AnisotropicTVNorm() + x_norm = ATV(x) + + is equivalent to + + .. code-block:: python + + C = linop.FiniteDifference(input_shape=x.shape, circular=True) + L1 = functional.L1Norm() + x_norm = L1(C @ x) + + The scaled proximal operator is computed using an approximation that + holds for small scaling parameters :cite:`kamilov-2016-parallel`. + This does not imply that it can only be applied to problems requiring + a small regularization parameter since most proximal algorithms + include an additional algorithm parameter that also plays a role in + the parameter of the proximal operator. For example, in :class:`.PGM` + and :class:`.AcceleratedPGM`, the scaled proximal operator parameter + is the regularization parameter divided by the `L0` algorithm + parameter, and for :class:`.ADMM`, the scaled proximal operator + parameters are the regularization parameters divided by the entries + in the `rho_list` algorithm parameter. + """ + + has_eval = True + has_prox = True + + def __init__(self, ndims: Optional[int] = None): + r""" + Args: + ndims: Number of (trailing) dimensions of the input over + which to apply the finite difference operator. If + ``None``, differences are evaluated along all axes. + """ + self.ndims = ndims + self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter + self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter + self.l1norm = L1Norm() + self.G = None + self.W = None + + def __call__(self, x: Array) -> float: + r"""Compute the anisotropic TV norm of an array.""" + if self.G is None or self.G.shape[1] != x.shape: + if self.ndims is None: + ndims = x.ndim + else: + ndims = self.ndims + axes = tuple(range(ndims)) + self.G = FiniteDifference(x.shape, input_dtype=x.dtype, axes=axes, circular=True) + return snp.sum(snp.abs(self.G @ x)) + + @staticmethod + def _shape(idx: int, ndims: int) -> Tuple: + """Construct a shape tuple. + + Construct a tuple of size `ndims` with all unit entries except + for index `idx`, which has a -1 entry. + """ + return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1) + + def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: + r"""Approximate proximal operator of the anisotripic TV norm. + + Approximation of the proximal operator of the anisotropic TV norm, + computed via the method described in :cite:`kamilov-2016-parallel`. + + Args: + v: Input array :math:`\mb{v}`. + lam: Proximal parameter :math:`\lam`. + kwargs: Additional arguments that may be used by derived + classes. + """ + if self.ndims is None: + ndims = v.ndim + else: + ndims = self.ndims + K = 2 * ndims + + if self.W is None or self.W.shape[1] != v.shape: + C0 = VerticalStack( # Stack of lowpass filter operators for each axis + [ + CircularConvolve( + self.h0.reshape(AnisotropicTVNorm._shape(k, ndims)), + v.shape, + ndims=self.ndims, + ) + for k in range(ndims) + ] + ) + C1 = VerticalStack( # Stack of highpass filter operators for each axis + [ + CircularConvolve( + self.h1.reshape(AnisotropicTVNorm._shape(k, ndims)), + v.shape, + ndims=self.ndims, + ) + for k in range(ndims) + ] + ) + # single-level shift-invariant Haar transform + self.W = VerticalStack((C0, C1), jit=True) + + Wv = self.W @ v + # Apply 𝑙1 shrinkage to highpass component of shift-invariant Haar transform + Wv = Wv.at[1].set(self.l1norm.prox(Wv[1], snp.sqrt(2) * K * lam)) + return (1.0 / K) * self.W.T @ Wv From c0c96337fcc529a26aef20655587cb6f9a72bd5b Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 3 Nov 2023 11:51:20 -0600 Subject: [PATCH 17/29] Clean up --- scico/functional/_tvnorm.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index ed2a2992f..2f1f1088d 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -7,16 +7,11 @@ """Anisotropic total variation norm.""" -import warnings -from typing import Optional, Tuple, Union - -from jax import jit, lax +from typing import Optional, Tuple from scico import numpy as snp -from scico.linop import FiniteDifference, VerticalStack, CircularConvolve -from scico.numpy import Array, BlockArray, count_nonzero -from scico.numpy.linalg import norm -from scico.numpy.util import no_nan_divide +from scico.linop import CircularConvolve, FiniteDifference, VerticalStack +from scico.numpy import Array from ._functional import Functional from ._norm import L1Norm @@ -64,7 +59,7 @@ def __init__(self, ndims: Optional[int] = None): ``None``, differences are evaluated along all axes. """ self.ndims = ndims - self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter + self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter self.l1norm = L1Norm() self.G = None From f251c60dfa3f0dc1f17f4d14c71e0c8075791072 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 3 Nov 2023 11:52:07 -0600 Subject: [PATCH 18/29] Typo fix --- scico/functional/_tvnorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index 2f1f1088d..a1a7a0ff7 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -86,7 +86,7 @@ def _shape(idx: int, ndims: int) -> Tuple: return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1) def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: - r"""Approximate proximal operator of the anisotripic TV norm. + r"""Approximate proximal operator of the isotropic TV norm. Approximation of the proximal operator of the anisotropic TV norm, computed via the method described in :cite:`kamilov-2016-parallel`. From feb4b7765426860f7ac3420830007bd0aed17b09 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 4 Nov 2023 07:14:36 -0600 Subject: [PATCH 19/29] Minor change --- scico/functional/_tvnorm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index a1a7a0ff7..f0e726cb7 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -73,7 +73,9 @@ def __call__(self, x: Array) -> float: else: ndims = self.ndims axes = tuple(range(ndims)) - self.G = FiniteDifference(x.shape, input_dtype=x.dtype, axes=axes, circular=True) + self.G = FiniteDifference( + x.shape, input_dtype=x.dtype, axes=axes, circular=True, jit=True + ) return snp.sum(snp.abs(self.G @ x)) @staticmethod From 7fe98b9fe2c140febb6bcf539e9be11f9c87ebe0 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 4 Nov 2023 08:02:58 -0600 Subject: [PATCH 20/29] Add change log entry --- CHANGES.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 7dc048a0a..265b9dd38 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,6 +6,8 @@ SCICO Release Notes Version 0.0.5 (unreleased) ---------------------------- +• New functional ``functional.AnisotropicTVNorm`` with proximal operator + approximation. • New integrated Radon/X-ray transform ``linop.XRayTransform``. • Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and ``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes From ded11f8ac5de4d9cbb652495103ed1576d9ca8ec Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sun, 5 Nov 2023 10:26:16 -0700 Subject: [PATCH 21/29] Resolve typing errors --- scico/functional/_tvnorm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index f0e726cb7..8c3195a09 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -10,7 +10,7 @@ from typing import Optional, Tuple from scico import numpy as snp -from scico.linop import CircularConvolve, FiniteDifference, VerticalStack +from scico.linop import LinearOperator, CircularConvolve, FiniteDifference, VerticalStack from scico.numpy import Array from ._functional import Functional @@ -62,8 +62,8 @@ def __init__(self, ndims: Optional[int] = None): self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter self.l1norm = L1Norm() - self.G = None - self.W = None + self.G: Optional[LinearOperator] = None + self.W: Optional[LinearOperator] = None def __call__(self, x: Array) -> float: r"""Compute the anisotropic TV norm of an array.""" @@ -127,7 +127,7 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: ] ) # single-level shift-invariant Haar transform - self.W = VerticalStack((C0, C1), jit=True) + self.W = VerticalStack([C0, C1], jit=True) Wv = self.W @ v # Apply 𝑙1 shrinkage to highpass component of shift-invariant Haar transform From c760e45eaeacbc9f6bcddb3b7a29095d18f4f2fc Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sun, 5 Nov 2023 10:41:56 -0700 Subject: [PATCH 22/29] Resolve some oversights and issues arising when 64 bit floats enabled --- scico/functional/_tvnorm.py | 6 ++++-- scico/test/functional/test_core.py | 2 +- scico/test/functional/test_tvnorm.py | 7 +++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index 8c3195a09..a6b36a32d 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -106,10 +106,12 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: K = 2 * ndims if self.W is None or self.W.shape[1] != v.shape: + h0 = self.h0.astype(v.dtype) + h1 = self.h1.astype(v.dtype) C0 = VerticalStack( # Stack of lowpass filter operators for each axis [ CircularConvolve( - self.h0.reshape(AnisotropicTVNorm._shape(k, ndims)), + h0.reshape(AnisotropicTVNorm._shape(k, ndims)), v.shape, ndims=self.ndims, ) @@ -119,7 +121,7 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: C1 = VerticalStack( # Stack of highpass filter operators for each axis [ CircularConvolve( - self.h1.reshape(AnisotropicTVNorm._shape(k, ndims)), + h1.reshape(AnisotropicTVNorm._shape(k, ndims)), v.shape, ndims=self.ndims, ) diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index e77226d5f..9d1aa78fe 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -20,7 +20,7 @@ functional.L21Norm, functional.L1MinusL2Norm, functional.NuclearNorm, - functional.TV2DNorm, + functional.AnisotropicTVNorm ] NO_COMPLEX = [functional.NonNegativeIndicator] diff --git a/scico/test/functional/test_tvnorm.py b/scico/test/functional/test_tvnorm.py index 232665aca..3a2fb80c5 100644 --- a/scico/test/functional/test_tvnorm.py +++ b/scico/test/functional/test_tvnorm.py @@ -9,7 +9,7 @@ def test_tvnorm(): N = 128 - g = np.linspace(0, 2 * np.pi, N) + g = np.linspace(0, 2 * np.pi, N, dtype=np.float32) x_gt = np.sin(2 * g) x_gt[x_gt > 0.5] = 0.5 x_gt[x_gt < -0.5] = -0.5 @@ -29,14 +29,13 @@ def test_tvnorm(): rho_list=[1e1], x0=y, maxiter=50, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}), - itstat_options={"display": True, "period": 10}, + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}) ) x_tvdn = solver.solve() h = λ * functional.AnisotropicTVNorm() solver = AcceleratedPGM( - f=f, g=h, L0=2e2, x0=y, maxiter=50, itstat_options={"display": True, "period": 10} + f=f, g=h, L0=2e2, x0=y, maxiter=50 ) x_approx = solver.solve() From dafd6261322f8d299bd2b0e7f47eeaf074187bce Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sun, 5 Nov 2023 10:44:45 -0700 Subject: [PATCH 23/29] Standardise code formatting --- scico/test/functional/test_tvnorm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/scico/test/functional/test_tvnorm.py b/scico/test/functional/test_tvnorm.py index 3a2fb80c5..1cef21535 100644 --- a/scico/test/functional/test_tvnorm.py +++ b/scico/test/functional/test_tvnorm.py @@ -29,14 +29,12 @@ def test_tvnorm(): rho_list=[1e1], x0=y, maxiter=50, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}) + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}), ) x_tvdn = solver.solve() h = λ * functional.AnisotropicTVNorm() - solver = AcceleratedPGM( - f=f, g=h, L0=2e2, x0=y, maxiter=50 - ) + solver = AcceleratedPGM(f=f, g=h, L0=2e2, x0=y, maxiter=50) x_approx = solver.solve() assert metric.snr(x_tvdn, x_approx) > 45 From 2949931bccc34d780df8565227fbb97b18e69ac7 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sun, 5 Nov 2023 10:46:11 -0700 Subject: [PATCH 24/29] Standardise code formatting --- scico/test/functional/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index 9d1aa78fe..6cf0c0393 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -20,7 +20,7 @@ functional.L21Norm, functional.L1MinusL2Norm, functional.NuclearNorm, - functional.AnisotropicTVNorm + functional.AnisotropicTVNorm, ] NO_COMPLEX = [functional.NonNegativeIndicator] From 6a654ecd9913e737db109893c2fa7c3433c12740 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sun, 5 Nov 2023 10:48:00 -0700 Subject: [PATCH 25/29] Standardise code formatting --- scico/functional/_tvnorm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index a6b36a32d..b8d621c93 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -10,7 +10,12 @@ from typing import Optional, Tuple from scico import numpy as snp -from scico.linop import LinearOperator, CircularConvolve, FiniteDifference, VerticalStack +from scico.linop import ( + CircularConvolve, + FiniteDifference, + LinearOperator, + VerticalStack, +) from scico.numpy import Array from ._functional import Functional From 3b7f75b6412e28395779fd0b46d0aa67c20400af Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sun, 5 Nov 2023 11:38:48 -0700 Subject: [PATCH 26/29] Apply skipped pre-commit --- scico/functional/_norm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 5d0a6579e..332d0500f 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -7,13 +7,11 @@ """Functionals that are norms.""" -import warnings from typing import Optional, Tuple, Union from jax import jit, lax from scico import numpy as snp -from scico.linop import FiniteDifference from scico.numpy import Array, BlockArray, count_nonzero from scico.numpy.linalg import norm from scico.numpy.util import no_nan_divide From 64228e828e62a2928e245cc3e67ae62ef8e673c6 Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Wed, 8 Nov 2023 11:50:34 -0800 Subject: [PATCH 27/29] added a new solver for solving composite prior minimization problem, called Proximal Averaged Projected Gradient Method --- scico/optimize/_papgm.py | 125 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 scico/optimize/_papgm.py diff --git a/scico/optimize/_papgm.py b/scico/optimize/_papgm.py new file mode 100644 index 000000000..31a7e9311 --- /dev/null +++ b/scico/optimize/_papgm.py @@ -0,0 +1,125 @@ +"""Proximal Averaged Accelerated Projected Gradient Method.""" + +from typing import List, Optional, Tuple, Union + +import scico.numpy as snp +from scico.numpy import Array, BlockArray +from scico.functional import Loss, Functional + +from ._common import Optimizer + +class AcceleratedPAPGM(Optimizer): + r"""Accelerated Proximal Averaged Projected Gradient Method (AcceleratedPAPGM) base class. + + Minimize a function of the form :math:`f(\mb{x}) + \sum_{i=1}^N \rho_i g_i(\mb{x})`, + + where :math:`f` and the :math:`g` are instances of :class:`.Functional`, + `rho_i` are positive and non-zero and sum upto 1. + This modifies FISTA to handle the case of composite prior minimization. + :cite:`yaoliang-2013-nips`. + + """ + + def __init__( + self, + f: Union[Loss, Functional], + g_list: List[Functional], + rho_list: List[float], + L0: float, + x0: Union[Array, BlockArray], + **kwargs, + ): + r""" + Args: + f: (:class:`.Functional`): Functional :math:`f` (usually a + :class:`.Loss`) + g_list: (list of :class:`.Functional`): List of :math:`g_i` + functionals. Must be same length as :code:`rho_list`. + rho_list: (list of scalars): List of :math:`\rho_i` penalty + parameters. Must be same length as :code:`g_list` and sum to 1. + L0: (float): Initial estimate of Lipschitz constant of f. + x0: (array-like): Starting point for :math:`\mb{x}`. + **kwargs: Additional optional parameters handled by + initializer of base class :class:`.Optimizer`. + """ + self.f: Union[Loss, Functional] = f + self.g_list: List[Functional] = g_list + self.rho_list: List[float] = rho_list + self.x: Union[Array, BlockArray] = x0 + self.fixed_point_residual: float = snp.inf + self.v: Union[Array, BlockArray] = x0 + self.t: float = 1.0 + self.L: float = L0 + + super().__init__(**kwargs) + + def step(self): + """Take a single AcceleratedPAPGM step.""" + assert snp.sum(snp.array(self.rho_list)) == 1 + assert snp.all(snp.array([rho>=0 for rho in self.rho_list])) + + x_old = self.x + z = self.v - 1.0 / self.L * self.f.grad(self.v) + + self.fixed_point_residual = 0 + self.x = snp.zeros_like(z) + for gi, rhoi in zip(self.g_list, self.rho_list): + self.x += rhoi * gi.prox(z, 1.0 / self.L) + self.fixed_point_residual += snp.linalg.norm(self.x - self.v) + + t_old = self.t + self.t = 0.5 * (1 + snp.sqrt(1 + 4 * t_old**2)) + self.v = self.x + ((t_old - 1) / self.t) * (self.x - x_old) + + def _working_vars_finite(self) -> bool: + """Determine where ``NaN`` of ``Inf`` encountered in solve. + + Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in + a solver working variable. + """ + return snp.all(snp.isfinite(self.x)) and snp.all(snp.isfinite(self.v)) + + def minimizer(self): + """Return current estimate of the functional mimimizer.""" + return self.x + + def objective(self, x: Optional[Union[Array, BlockArray]] = None) -> float: + r"""Evaluate the objective function + + .. math:: + f(\mb{x}) + \sum_{i=1}^N g_i(\mb{x}_i) \;. + + Args: + x: Point at which to evaluate objective function. If ``None``, + the objective is evaluated at the current iterate + :code:`self.x`. + + Returns: + Value of the objective function. + """ + if x is None: + x = self.x + out = 0.0 + if self.f: + out += self.f(x) + for gi, rhoi in zip(self.g_list, self.rho_list): + out += rhoi * gi(x) + return out + + def _objective_evaluatable(self): + """Determine whether the objective function can be evaluated.""" + return (not self.f or self.f.has_eval) and all([_.has_eval for _ in self.g_list]) + + def _itstat_extra_fields(self): + """Define AcceleratedPAPGM iteration statistics fields.""" + itstat_fields = {"L": "%9.3e", "Residual": "%9.3e"} + itstat_attrib = ["L", "norm_residual()"] + return itstat_fields, itstat_attrib + + def norm_residual(self) -> float: + r"""Return the fixed point residual. + + Return the fixed point residual (see Sec. 4.3 of + :cite:`liu-2018-first`). + """ + return self.fixed_point_residual From de76dec12c34044bcc4b8712cede8c1f0cae836e Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Wed, 8 Nov 2023 12:15:17 -0800 Subject: [PATCH 28/29] Fixed typos that were causing the lint and mypy tests to fail. --- scico/optimize/_papgm.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/scico/optimize/_papgm.py b/scico/optimize/_papgm.py index 31a7e9311..e8092bf7a 100644 --- a/scico/optimize/_papgm.py +++ b/scico/optimize/_papgm.py @@ -4,17 +4,19 @@ import scico.numpy as snp from scico.numpy import Array, BlockArray -from scico.functional import Loss, Functional +from scico.functional import Functional +from scico.loss import Loss from ._common import Optimizer + class AcceleratedPAPGM(Optimizer): r"""Accelerated Proximal Averaged Projected Gradient Method (AcceleratedPAPGM) base class. - Minimize a function of the form :math:`f(\mb{x}) + \sum_{i=1}^N \rho_i g_i(\mb{x})`, + Minimize a function of the form :math:`f(\mb{x}) + \sum_{i=1}^N \rho_i g_i(\mb{x})`, - where :math:`f` and the :math:`g` are instances of :class:`.Functional`, - `rho_i` are positive and non-zero and sum upto 1. + where :math:`f` and the :math:`g` are instances of :class:`.Functional`, + `rho_i` are positive and non-zero and sum upto 1. This modifies FISTA to handle the case of composite prior minimization. :cite:`yaoliang-2013-nips`. @@ -56,7 +58,7 @@ def __init__( def step(self): """Take a single AcceleratedPAPGM step.""" assert snp.sum(snp.array(self.rho_list)) == 1 - assert snp.all(snp.array([rho>=0 for rho in self.rho_list])) + assert snp.all(snp.array([rho >= 0 for rho in self.rho_list])) x_old = self.x z = self.v - 1.0 / self.L * self.f.grad(self.v) @@ -82,7 +84,7 @@ def _working_vars_finite(self) -> bool: def minimizer(self): """Return current estimate of the functional mimimizer.""" return self.x - + def objective(self, x: Optional[Union[Array, BlockArray]] = None) -> float: r"""Evaluate the objective function From 7d7df69a8ba73fb5e244d1eb7e4b80d9d8e51272 Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Wed, 8 Nov 2023 12:17:37 -0800 Subject: [PATCH 29/29] stylistic sugar changes for lint tests --- scico/optimize/_papgm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/optimize/_papgm.py b/scico/optimize/_papgm.py index e8092bf7a..f7102c6c2 100644 --- a/scico/optimize/_papgm.py +++ b/scico/optimize/_papgm.py @@ -3,9 +3,9 @@ from typing import List, Optional, Tuple, Union import scico.numpy as snp -from scico.numpy import Array, BlockArray from scico.functional import Functional from scico.loss import Loss +from scico.numpy import Array, BlockArray from ._common import Optimizer