diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index 1276f3ad7..dc3995648 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -17,12 +17,77 @@ VerticalStack, ) from scico.numpy import Array +from scico.typing import DType from ._functional import Functional from ._norm import L1Norm, L21Norm -class AnisotropicTVNorm(Functional): +class AbstractTVNorm(Functional): + """Abstract base class for total variation (TV) norms. + + Abstract base class for total variation (TV) norms with + proximal operators approximations. + """ + + has_eval = True + has_prox = True + + def __init__(self, ndims: Optional[int] = None): + """ + Args: + ndims: Number of (trailing) dimensions of the input over + which to apply the finite difference operator. If + ``None``, differences are evaluated along all axes. + """ + self.ndims = ndims + self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter + self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter + self.G: Optional[LinearOperator] = None + self.W: Optional[LinearOperator] = None + + @staticmethod + def _shape(idx: int, ndims: int) -> Tuple: + """Construct a shape tuple. + + Construct a tuple of size `ndims` with all unit entries except + for index `idx`, which has a -1 entry. + """ + return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1) + + def _construct_W(self, shape: Tuple, dtype: DType, ndims: int) -> VerticalStack: + """Construct a partial shift-invariant Haar transform operator. + + Construct a single-level shift-invariant Haar transform operator. + """ + h0 = self.h0.astype(dtype) + h1 = self.h1.astype(dtype) + C0 = VerticalStack( # Stack of lowpass filter operators for each axis + [ + CircularConvolve( + h0.reshape(AbstractTVNorm._shape(k, ndims)), + shape, + ndims=self.ndims, + ) + for k in range(ndims) + ] + ) + C1 = VerticalStack( # Stack of highpass filter operators for each axis + [ + CircularConvolve( + h1.reshape(AbstractTVNorm._shape(k, ndims)), + shape, + ndims=self.ndims, + ) + for k in range(ndims) + ] + ) + # single-level shift-invariant Haar transform + W = VerticalStack([C0, C1], jit=True) + return W + + +class AnisotropicTVNorm(AbstractTVNorm): r"""The anisotropic total variation (TV) norm. The anisotropic total variation (TV) norm computed by @@ -53,9 +118,6 @@ class AnisotropicTVNorm(Functional): in the `rho_list` algorithm parameter. """ - has_eval = True - has_prox = True - def __init__(self, ndims: Optional[int] = None): """ Args: @@ -63,12 +125,8 @@ def __init__(self, ndims: Optional[int] = None): which to apply the finite difference operator. If ``None``, differences are evaluated along all axes. """ - self.ndims = ndims - self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter - self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter + super().__init__(ndims=ndims) self.l1norm = L1Norm() - self.G: Optional[LinearOperator] = None - self.W: Optional[LinearOperator] = None def __call__(self, x: Array) -> float: """Compute the anisotropic TV norm of an array.""" @@ -83,15 +141,6 @@ def __call__(self, x: Array) -> float: ) return self.l1norm(self.G @ x) - @staticmethod - def _shape(idx: int, ndims: int) -> Tuple: - """Construct a shape tuple. - - Construct a tuple of size `ndims` with all unit entries except - for index `idx`, which has a -1 entry. - """ - return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1) - def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: r"""Approximate proximal operator of the isotropic TV norm. @@ -111,30 +160,7 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: K = 2 * ndims if self.W is None or self.W.shape[1] != v.shape: - h0 = self.h0.astype(v.dtype) - h1 = self.h1.astype(v.dtype) - C0 = VerticalStack( # Stack of lowpass filter operators for each axis - [ - CircularConvolve( - h0.reshape(AnisotropicTVNorm._shape(k, ndims)), - v.shape, - ndims=self.ndims, - ) - for k in range(ndims) - ] - ) - C1 = VerticalStack( # Stack of highpass filter operators for each axis - [ - CircularConvolve( - h1.reshape(AnisotropicTVNorm._shape(k, ndims)), - v.shape, - ndims=self.ndims, - ) - for k in range(ndims) - ] - ) - # single-level shift-invariant Haar transform - self.W = VerticalStack([C0, C1], jit=True) + self.W = self._construct_W(v.shape, v.dtype, ndims) Wv = self.W @ v # Apply 𝑙1 shrinkage to highpass component of shift-invariant Haar transform @@ -142,7 +168,7 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: return (1.0 / K) * self.W.T @ Wv -class IsotropicTVNorm(Functional): +class IsotropicTVNorm(AbstractTVNorm): r"""The isotropic total variation (TV) norm. The isotropic total variation (TV) norm computed by @@ -173,9 +199,6 @@ class IsotropicTVNorm(Functional): in the `rho_list` algorithm parameter. """ - has_eval = True - has_prox = True - def __init__(self, ndims: Optional[int] = None): r""" Args: @@ -183,12 +206,8 @@ def __init__(self, ndims: Optional[int] = None): which to apply the finite difference operator. If ``None``, differences are evaluated along all axes. """ - self.ndims = ndims - self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter - self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter + super().__init__(ndims=ndims) self.l21norm = L21Norm() - self.G = None - self.W = None def __call__(self, x: Array) -> float: r"""Compute the isotropic TV norm of an array.""" @@ -203,15 +222,6 @@ def __call__(self, x: Array) -> float: ) return self.l21norm(self.G @ x) - @staticmethod - def _shape(idx: int, ndims: int) -> Tuple: - """Construct a shape tuple. - - Construct a tuple of size `ndims` with all unit entries except - for index `idx`, which has a -1 entry. - """ - return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1) - def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: r"""Approximate proximal operator of the isotropic TV norm. @@ -231,28 +241,7 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: K = 2 * ndims if self.W is None or self.W.shape[1] != v.shape: - C0 = VerticalStack( # Stack of lowpass filter operators for each axis - [ - CircularConvolve( - self.h0.reshape(AnisotropicTVNorm._shape(k, ndims)), - v.shape, - ndims=self.ndims, - ) - for k in range(ndims) - ] - ) - C1 = VerticalStack( # Stack of highpass filter operators for each axis - [ - CircularConvolve( - self.h1.reshape(AnisotropicTVNorm._shape(k, ndims)), - v.shape, - ndims=self.ndims, - ) - for k in range(ndims) - ] - ) - # single-level shift-invariant Haar transform - self.W = VerticalStack((C0, C1), jit=True) + self.W = self._construct_W(v.shape, v.dtype, ndims) Wv = self.W @ v # Apply 𝑙21 shrinkage to highpass component of shift-invariant Haar transform