Skip to content

Commit

Permalink
Clean up duplicated code
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Nov 30, 2023
1 parent c8b19c7 commit 368bd1a
Showing 1 changed file with 71 additions and 82 deletions.
153 changes: 71 additions & 82 deletions scico/functional/_tvnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,77 @@
VerticalStack,
)
from scico.numpy import Array
from scico.typing import DType

from ._functional import Functional
from ._norm import L1Norm, L21Norm


class AnisotropicTVNorm(Functional):
class AbstractTVNorm(Functional):
"""Abstract base class for total variation (TV) norms.
Abstract base class for total variation (TV) norms with
proximal operators approximations.
"""

has_eval = True
has_prox = True

def __init__(self, ndims: Optional[int] = None):
"""
Args:
ndims: Number of (trailing) dimensions of the input over
which to apply the finite difference operator. If
``None``, differences are evaluated along all axes.
"""
self.ndims = ndims
self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter
self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter
self.G: Optional[LinearOperator] = None
self.W: Optional[LinearOperator] = None

@staticmethod
def _shape(idx: int, ndims: int) -> Tuple:
"""Construct a shape tuple.
Construct a tuple of size `ndims` with all unit entries except
for index `idx`, which has a -1 entry.
"""
return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1)

def _construct_W(self, shape: Tuple, dtype: DType, ndims: int) -> VerticalStack:
"""Construct a partial shift-invariant Haar transform operator.
Construct a single-level shift-invariant Haar transform operator.
"""
h0 = self.h0.astype(dtype)
h1 = self.h1.astype(dtype)
C0 = VerticalStack( # Stack of lowpass filter operators for each axis
[
CircularConvolve(
h0.reshape(AbstractTVNorm._shape(k, ndims)),
shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
C1 = VerticalStack( # Stack of highpass filter operators for each axis
[
CircularConvolve(
h1.reshape(AbstractTVNorm._shape(k, ndims)),
shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
# single-level shift-invariant Haar transform
W = VerticalStack([C0, C1], jit=True)
return W


class AnisotropicTVNorm(AbstractTVNorm):
r"""The anisotropic total variation (TV) norm.
The anisotropic total variation (TV) norm computed by
Expand Down Expand Up @@ -53,22 +118,15 @@ class AnisotropicTVNorm(Functional):
in the `rho_list` algorithm parameter.
"""

has_eval = True
has_prox = True

def __init__(self, ndims: Optional[int] = None):
"""
Args:
ndims: Number of (trailing) dimensions of the input over
which to apply the finite difference operator. If
``None``, differences are evaluated along all axes.
"""
self.ndims = ndims
self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter
self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter
super().__init__(ndims=ndims)
self.l1norm = L1Norm()
self.G: Optional[LinearOperator] = None
self.W: Optional[LinearOperator] = None

def __call__(self, x: Array) -> float:
"""Compute the anisotropic TV norm of an array."""
Expand All @@ -83,15 +141,6 @@ def __call__(self, x: Array) -> float:
)
return self.l1norm(self.G @ x)

@staticmethod
def _shape(idx: int, ndims: int) -> Tuple:
"""Construct a shape tuple.
Construct a tuple of size `ndims` with all unit entries except
for index `idx`, which has a -1 entry.
"""
return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1)

def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
r"""Approximate proximal operator of the isotropic TV norm.
Expand All @@ -111,38 +160,15 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
K = 2 * ndims

if self.W is None or self.W.shape[1] != v.shape:
h0 = self.h0.astype(v.dtype)
h1 = self.h1.astype(v.dtype)
C0 = VerticalStack( # Stack of lowpass filter operators for each axis
[
CircularConvolve(
h0.reshape(AnisotropicTVNorm._shape(k, ndims)),
v.shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
C1 = VerticalStack( # Stack of highpass filter operators for each axis
[
CircularConvolve(
h1.reshape(AnisotropicTVNorm._shape(k, ndims)),
v.shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
# single-level shift-invariant Haar transform
self.W = VerticalStack([C0, C1], jit=True)
self.W = self._construct_W(v.shape, v.dtype, ndims)

Wv = self.W @ v
# Apply 𝑙1 shrinkage to highpass component of shift-invariant Haar transform
Wv = Wv.at[1].set(self.l1norm.prox(Wv[1], snp.sqrt(2) * K * lam))
return (1.0 / K) * self.W.T @ Wv


class IsotropicTVNorm(Functional):
class IsotropicTVNorm(AbstractTVNorm):
r"""The isotropic total variation (TV) norm.
The isotropic total variation (TV) norm computed by
Expand Down Expand Up @@ -173,22 +199,15 @@ class IsotropicTVNorm(Functional):
in the `rho_list` algorithm parameter.
"""

has_eval = True
has_prox = True

def __init__(self, ndims: Optional[int] = None):
r"""
Args:
ndims: Number of (trailing) dimensions of the input over
which to apply the finite difference operator. If
``None``, differences are evaluated along all axes.
"""
self.ndims = ndims
self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter
self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter
super().__init__(ndims=ndims)
self.l21norm = L21Norm()
self.G = None
self.W = None

def __call__(self, x: Array) -> float:
r"""Compute the isotropic TV norm of an array."""
Expand All @@ -203,15 +222,6 @@ def __call__(self, x: Array) -> float:
)
return self.l21norm(self.G @ x)

@staticmethod
def _shape(idx: int, ndims: int) -> Tuple:
"""Construct a shape tuple.
Construct a tuple of size `ndims` with all unit entries except
for index `idx`, which has a -1 entry.
"""
return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1)

def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
r"""Approximate proximal operator of the isotropic TV norm.
Expand All @@ -231,28 +241,7 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
K = 2 * ndims

if self.W is None or self.W.shape[1] != v.shape:
C0 = VerticalStack( # Stack of lowpass filter operators for each axis
[
CircularConvolve(
self.h0.reshape(AnisotropicTVNorm._shape(k, ndims)),
v.shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
C1 = VerticalStack( # Stack of highpass filter operators for each axis
[
CircularConvolve(
self.h1.reshape(AnisotropicTVNorm._shape(k, ndims)),
v.shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
# single-level shift-invariant Haar transform
self.W = VerticalStack((C0, C1), jit=True)
self.W = self._construct_W(v.shape, v.dtype, ndims)

Wv = self.W @ v
# Apply 𝑙21 shrinkage to highpass component of shift-invariant Haar transform
Expand Down

0 comments on commit 368bd1a

Please sign in to comment.