From ce8adf881b4dcb49078ad2c7c08dc74891418baa Mon Sep 17 00:00:00 2001 From: guanhuaw Date: Mon, 18 Mar 2024 22:09:22 -0700 Subject: [PATCH 1/6] start torchscript --- .pre-commit-config.yaml | 17 +++++++++-------- mirtorch/alg/fista.py | 2 +- mirtorch/linear/__init__.py | 7 ++++++- mirtorch/linear/linearmaps.py | 4 ++-- mirtorch/linear/mri.py | 12 ++++++++---- mirtorch/linear/spect.py | 1 + mirtorch/linear/util.py | 16 +++++++++------- mirtorch/prox/prox.py | 3 +-- pyproject.toml | 1 + tests/basics_tests.py | 3 ++- tests/prox_tests.py | 3 ++- tests/spect_tests.py | 3 ++- 12 files changed, 44 insertions(+), 28 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 30186c8..3ac53b3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,14 +14,15 @@ repos: language_version: python3 exclude: ^(?:tests|docs|examples)/ - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: 'v0.1.5' - hooks: - - id: ruff - types_or: [python, pyi, jupyter] - args: [ --fix, --exit-non-zero-on-fix ] - exclude: ^(?:tests|docs|examples)/ - + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.3.3 + hooks: + # Run the linter. + - id: ruff + types_or: [python, pyi, jupyter] + args: [--fix] + exclude: ^(?:tests|docs|examples)/ - repo: https://github.com/codespell-project/codespell rev: v2.1.0 diff --git a/mirtorch/alg/fista.py b/mirtorch/alg/fista.py index 4400d00..dff57a6 100644 --- a/mirtorch/alg/fista.py +++ b/mirtorch/alg/fista.py @@ -63,7 +63,7 @@ def _update_momentum(): beta = (told - 1) / tnew told = tnew - # initalize parameters + # initialize parameters xold = x0 yold = x0 told = 1.0 diff --git a/mirtorch/linear/__init__.py b/mirtorch/linear/__init__.py index 417ec7f..bf48ce4 100644 --- a/mirtorch/linear/__init__.py +++ b/mirtorch/linear/__init__.py @@ -20,7 +20,7 @@ Patch3D, Diffnd, ) -from .mri import FFTCn +from .mri import FFTCn, NuSense, NuSenseGram, Gmri, GmriGram, Sense from .wavelets import Wavelet2D from .spect import SPECT @@ -46,4 +46,9 @@ "Patch3D", "FFTCn", "SPECT", + "NuSense", + "NuSenseGram", + "Gmri", + "GmriGram", + "Sense", ] diff --git a/mirtorch/linear/linearmaps.py b/mirtorch/linear/linearmaps.py index f0653a5..289bf3a 100644 --- a/mirtorch/linear/linearmaps.py +++ b/mirtorch/linear/linearmaps.py @@ -60,8 +60,8 @@ def __init__(self, size_in: Sequence[int], size_out: Sequence[int]): self.size_out = list(size_out) def __repr__(self): - return "<{oshape}x{ishape} {repr_str} Linop>".format( - oshape=self.size_out, ishape=self.size_in, repr_str=self.__class__.__name__ + return "".format( + repr_str=self.__class__.__name__, oshape=self.size_out, ishape=self.size_in ) def __call__(self, x) -> Tensor: diff --git a/mirtorch/linear/mri.py b/mirtorch/linear/mri.py index 2342bcd..2793e4c 100644 --- a/mirtorch/linear/mri.py +++ b/mirtorch/linear/mri.py @@ -4,7 +4,7 @@ """ import math -from typing import Sequence, Union +from typing import Sequence, Union, List import numpy as np import torch @@ -31,20 +31,22 @@ def __init__( self, size_in: Sequence[int], size_out: Sequence[int], - dims: Union[int, Sequence[int]], + dims: Union[int, List[int]] | None = None, norm: str = "ortho", ): super(FFTCn, self).__init__(size_in, size_out) self.norm = norm self.dims = dims - def _apply(self, x: Tensor) -> Tensor: + @torch.jit.script + def _apply(self: LinearMap, x: Tensor) -> Tensor: x = ifftshift(x, self.dims) x = fftn(x, dim=self.dims, norm=self.norm) x = fftshift(x, self.dims) return x - def _apply_adjoint(self, x: Tensor) -> Tensor: + @torch.jit.script + def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor: x = ifftshift(x, self.dims) if self.norm == "ortho": x = ifftn(x, dim=self.dims, norm="ortho") @@ -96,6 +98,7 @@ def __init__( self.smaps = smaps self.batchmode = batchmode + @torch.jit.script def _apply(self, x: Tensor) -> Tensor: r""" Args: @@ -109,6 +112,7 @@ def _apply(self, x: Tensor) -> Tensor: k = fftshift(k, self.dims) * self.masks return k + @torch.jit.script def _apply_adjoint(self, k: Tensor) -> Tensor: r""" Args: diff --git a/mirtorch/linear/spect.py b/mirtorch/linear/spect.py index ec9b8d9..44cc281 100644 --- a/mirtorch/linear/spect.py +++ b/mirtorch/linear/spect.py @@ -3,6 +3,7 @@ SPECT forward-backward projector with parallel beam collimator. 2023-06, Zongyu Li, University of Michigan """ + from typing import Sequence import torch diff --git a/mirtorch/linear/util.py b/mirtorch/linear/util.py index 817516f..1b0df42 100644 --- a/mirtorch/linear/util.py +++ b/mirtorch/linear/util.py @@ -1,4 +1,4 @@ -from typing import Sequence, Union +from typing import Union, List import numpy as np import torch @@ -90,29 +90,31 @@ def backward(ctx, dx): return finitediff(dx, ctx.dim, ctx.mode), None, None -def fftshift(x: Tensor, dims: Union[int, Sequence[int]] = None): +def fftshift(x: Tensor, dims: Union[int, List[int]] | None = None): """ Similar to np.fft.fftshift but applies to PyTorch tensors. From fastMRI code. """ if dims is None: - dims = tuple(range(x.dim())) + dims = list(range(x.dim())) shifts = [dim // 2 for dim in x.shape] elif isinstance(dims, int): - shifts = x.shape[dims] // 2 + shifts = [x.shape[dims] // 2] + dims = [dims] else: shifts = [x.shape[i] // 2 for i in dims] return torch.roll(x, shifts, dims) -def ifftshift(x: Tensor, dims: Union[int, Sequence[int]] = None): +def ifftshift(x: Tensor, dims: Union[int, List[int]] | None = None): """ Similar to np.fft.ifftshift but applies to PyTorch tensors. From fastMRI code. """ if dims is None: - dims = tuple(range(x.dims())) + dims = list(range(x.dim())) shifts = [(dim + 1) // 2 for dim in x.shape] elif isinstance(dims, int): - shifts = (x.shape[dims] + 1) // 2 + shifts = [(x.shape[dims] + 1) // 2] + dims = [dims] else: shifts = [(x.shape[i] + 1) // 2 for i in dims] return torch.roll(x, shifts, dims) diff --git a/mirtorch/prox/prox.py b/mirtorch/prox/prox.py index 1a0ce3e..ba32c88 100644 --- a/mirtorch/prox/prox.py +++ b/mirtorch/prox/prox.py @@ -1,10 +1,9 @@ -""" Proximal operators, such as soft-thresholding, box-constraint and L2 norm. +"""Proximal operators, such as soft-thresholding, box-constraint and L2 norm. Prox() class includes the common proximal operators used in iterative optimization. 2021-02. Neel Shah and Guanhua Wang, University of Michigan """ - from mirtorch.linear import LinearMap import torch from typing import Union diff --git a/pyproject.toml b/pyproject.toml index d82a1ef..330a920 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "matplotlib", "pytorch_wavelets@git+https://github.com/fbcotter/pytorch_wavelets.git@8d2e3b4289beaea9aa89f7b1dbb290e448331197#egg=pytorch_wavelets", ] + dynamic = ["version"] [project.urls] # Optional diff --git a/tests/basics_tests.py b/tests/basics_tests.py index a41dd4e..10c2b45 100644 --- a/tests/basics_tests.py +++ b/tests/basics_tests.py @@ -1,5 +1,6 @@ import unittest -import sys, os +import sys +import os path = os.path.dirname(os.path.abspath(__file__)) path = path[: path.rfind("/")] diff --git a/tests/prox_tests.py b/tests/prox_tests.py index 8163eb3..50fd810 100644 --- a/tests/prox_tests.py +++ b/tests/prox_tests.py @@ -1,5 +1,6 @@ import unittest -import sys, os +import sys +import os path = os.path.dirname(os.path.abspath(__file__)) path = path[: path.rfind("/")] diff --git a/tests/spect_tests.py b/tests/spect_tests.py index 1bbd5d5..e2d20f8 100644 --- a/tests/spect_tests.py +++ b/tests/spect_tests.py @@ -5,7 +5,8 @@ """ import unittest -import sys, os +import sys +import os path = os.path.dirname(os.path.abspath(__file__)) path = path[: path.rfind("/")] From 4497f65d1f9876d613eed690f8d94b4ef5f8a2fd Mon Sep 17 00:00:00 2001 From: guanhuaw Date: Sun, 24 Mar 2024 19:51:22 -0700 Subject: [PATCH 2/6] jitting --- mirtorch/alg/fbpd.py | 8 ++--- mirtorch/linear/linearmaps.py | 68 ++++++++++++++++++----------------- mirtorch/linear/mri.py | 8 ++--- 3 files changed, 44 insertions(+), 40 deletions(-) diff --git a/mirtorch/alg/fbpd.py b/mirtorch/alg/fbpd.py index e3affa0..342d065 100644 --- a/mirtorch/alg/fbpd.py +++ b/mirtorch/alg/fbpd.py @@ -42,10 +42,10 @@ def __init__( h_prox: Prox, g_L: float, G_norm: float, - G: LinearMap = None, - tau: float = None, + G: LinearMap | None = None, + tau: float | None = None, max_iter: int = 10, - eval_func: Callable = None, + eval_func: Callable | None = None, p: int = 1, ): self.max_iter = max_iter @@ -87,7 +87,7 @@ def run(self, x0: torch.Tensor): if self.eval_func is not None: saved.append(self.eval_func(xold)) logger.info( - "The cost function at %dth iter in FBPD: %10.3e." % (i, saved[-1]) + "The cost function at %dth iter in FBPD: %10.3e.", i, saved[-1] ) if self.eval_func is not None: return xold, saved diff --git a/mirtorch/linear/linearmaps.py b/mirtorch/linear/linearmaps.py index 289bf3a..c042e0e 100644 --- a/mirtorch/linear/linearmaps.py +++ b/mirtorch/linear/linearmaps.py @@ -1,4 +1,6 @@ -from typing import Sequence, TypeVar, Union +from __future__ import annotations + +from typing import List, Union import numpy as np import torch @@ -14,9 +16,6 @@ def check_device(x, y): assert x.device == y.device, "Tensors should be on the same device" -T = TypeVar("T", bound="LinearMap") - - class LinearMap: r""" Abstraction of linear operators as matrices :math:`y = A*x`. @@ -52,7 +51,7 @@ def backward(ctx, grad_data_in): size_out: the size of the output of the linear map (a list) """ - def __init__(self, size_in: Sequence[int], size_out: Sequence[int]): + def __init__(self, size_in: List[int], size_out: List[int]): r""" Initiate the linear operator. """ @@ -60,8 +59,8 @@ def __init__(self, size_in: Sequence[int], size_out: Sequence[int]): self.size_out = list(size_out) def __repr__(self): - return "".format( - repr_str=self.__class__.__name__, oshape=self.size_out, ishape=self.size_in + return ( + f"" ) def __call__(self, x) -> Tensor: @@ -96,19 +95,21 @@ def adjoint(self, x) -> Tensor: return self._apply_adjoint(x) @property - def H(self) -> T: + def H(self) -> LinearMap: r""" Apply the (Hermitian) transpose """ return ConjTranspose(self) - def __add__(self: T, other: T) -> T: + def __add__(self: LinearMap, other: LinearMap) -> LinearMap: r""" Reload the + symbol. """ return Add(self, other) - def __mul__(self: T, other) -> T: + def __mul__( + self: LinearMap, other: Union[str, int, LinearMap, Tensor] + ) -> LinearMap: r""" Reload the * symbol. """ @@ -116,49 +117,52 @@ def __mul__(self: T, other) -> T: return Multiply(self, other) elif isinstance(other, LinearMap): return Matmul(self, other) - elif isinstance(other, torch.Tensor): + elif isinstance(other, Tensor): if not other.shape: - # raise ValueError( - # "Input tensor has empty shape. If want to scale the linear map, please use the standard scalar") return Multiply(self, other) return self.apply(other) else: raise NotImplementedError( - f"Only scalers, Linearmaps or Tensors, rather than '{type(other)}' are allowed as arguments for this function." + ( + f"Only scalers, Linearmaps or Tensors, rather than '{type(other)}' " + "fare allowed as arguments for this function." + ) ) - def __rmul__(self: T, other) -> T: + def __rmul__( + self: LinearMap, other: Union[str, int, LinearMap, Tensor] + ) -> LinearMap: r""" Reload the * symbol. """ if np.isscalar(other): return Multiply(self, other) - elif isinstance(other, torch.Tensor) and not other.shape: + elif isinstance(other, Tensor) and not other.shape: return Multiply(self, other) else: return NotImplemented - def __sub__(self: T, other: T) -> T: + def __sub__(self: LinearMap, other: LinearMap) -> LinearMap: r""" Reload the - symbol. """ return self.__add__(-other) - def __neg__(self: T) -> T: + def __neg__(self: LinearMap) -> LinearMap: r""" Reload the - symbol. """ return -1 * self - def to(self: T, *args, **kwargs): + def to(self: LinearMap, device: Union[torch.device, str]) -> LinearMap: r""" Copy to different devices """ for prop in self.__dict__.keys(): - if isinstance(self.__dict__[prop], torch.Tensor) or isinstance( + if isinstance(self.__dict__[prop], Tensor) or isinstance( self.__dict__[prop], torch.nn.Module ): - self.__dict__[prop] = self.__dict__[prop].to(*args, **kwargs) + self.__dict__[prop] = self.__dict__[prop].to(device) class Add(LinearMap): @@ -184,10 +188,10 @@ def __init__(self, A: LinearMap, B: LinearMap): self.B = B super().__init__(self.A.size_in, self.B.size_out) - def _apply(self: T, x: Tensor) -> Tensor: + def _apply(self: LinearMap, x: Tensor) -> Tensor: return self.A(x) + self.B(x) - def _apply_adjoint(self: T, x: Tensor) -> Tensor: + def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor: return self.A.H(x) + self.B.H(x) @@ -208,11 +212,11 @@ def __init__(self, A: LinearMap, a: FloatLike): self.A = A super().__init__(self.A.size_in, self.A.size_out) - def _apply(self: T, x: Tensor) -> Tensor: + def _apply(self: LinearMap, x: Tensor) -> Tensor: ax = x * self.a return self.A(ax) - def _apply_adjoint(self: T, x: Tensor) -> Tensor: + def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor: ax = x * self.a return self.A.H(ax) @@ -232,11 +236,11 @@ def __init__(self, A: LinearMap, B: LinearMap): assert list(self.B.size_out) == list(self.A.size_in), "Shapes do not match" super().__init__(self.B.size_in, self.A.size_out) - def _apply(self: T, x: Tensor) -> Tensor: + def _apply(self: LinearMap, x: Tensor) -> Tensor: # TODO: add gram operator return self.A(self.B(x)) - def _apply_adjoint(self: T, x: Tensor) -> Tensor: + def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor: return self.B.H(self.A.H(x)) @@ -249,10 +253,10 @@ def __init__(self, A: LinearMap): self.A = A super().__init__(A.size_out, A.size_in) - def _apply(self: T, x: Tensor) -> Tensor: + def _apply(self: LinearMap, x: Tensor) -> Tensor: return self.A.adjoint(x) - def _apply_adjoint(self: T, x: Tensor) -> Tensor: + def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor: return self.A.apply(x) @@ -265,7 +269,7 @@ class BlockDiagonal(LinearMap): A : List of 2D linear maps """ - def __init__(self, A: Sequence[LinearMap]): + def __init__(self, A: List[LinearMap]): self.A = A # dimension checks @@ -280,7 +284,7 @@ def __init__(self, A: Sequence[LinearMap]): size_out = list(A[0].size_out) + [nz] super().__init__(tuple(size_in), tuple(size_out)) - def _apply(self: T, x: Tensor) -> Tensor: + def _apply(self: LinearMap, x: Tensor) -> Tensor: out = torch.zeros( self.size_out, dtype=x.dtype, device=x.device, layout=x.layout ) @@ -291,7 +295,7 @@ def _apply(self: T, x: Tensor) -> Tensor: out[..., k] = self.A[k].apply(x[..., k]) return out - def _apply_adjoint(self: T, x: Tensor): + def _apply_adjoint(self: LinearMap, x: Tensor): out = torch.zeros(self.size_in, dtype=x.dtype, device=x.device, layout=x.layout) nz = self.size_in[-1] diff --git a/mirtorch/linear/mri.py b/mirtorch/linear/mri.py index 2793e4c..53ecfdb 100644 --- a/mirtorch/linear/mri.py +++ b/mirtorch/linear/mri.py @@ -165,7 +165,7 @@ def __init__( traj: Tensor, norm="ortho", batchmode=True, - numpoints: Union[int, Sequence[int]] = 6, + numpoints: Union[int, List[int]] = 6, grid_size: float = 2, sequential: bool = False, ): @@ -328,7 +328,7 @@ def __init__( traj: Tensor, norm="ortho", batchmode=True, - numpoints: Union[int, Sequence[int]] = 6, + numpoints: Union[int, List[int]] = 6, grid_size: float = 2, ): self.smaps = smaps @@ -433,7 +433,7 @@ def __init__( L: int = 6, nbins: int = 20, dt: int = 4e-3, - numpoints: Union[int, Sequence[int]] = 6, + numpoints: Union[int, List[int]] = 6, grid_size: float = 2, T: Tensor = None, ): @@ -560,7 +560,7 @@ def __init__( L: int = 6, nbins: int = 20, dt: int = 4e-3, - numpoints: Union[int, Sequence[int]] = 6, + numpoints: Union[int, List[int]] = 6, grid_size: float = 2, T: Tensor = None, ): From 85c7722af068974c9e41429e17bfaa62e19c96ec Mon Sep 17 00:00:00 2001 From: guanhuaw Date: Sun, 21 Jul 2024 20:22:25 -0700 Subject: [PATCH 3/6] Improve readme --- README.md | 9 +++++---- mirtorch/linear/mri.py | 40 +++++++++++----------------------------- 2 files changed, 16 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 1f73536..e315f0a 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ A Py***Torch***-based differentiable ***I***mage ***R***econstruction ***T***ool The work is inspired by [MIRT](https://github.com/JeffFessler/mirt), a well-acclaimed toolbox for medical imaging reconstruction. -The main objective is to facilitate rapid, data-driven image reconstruction using CPUs and GPUs through fast prototyping and iteration. Researchers can conveniently develop new model-based and learning-based methods (e.g., unrolled neural networks) with abstraction layers. The availability of auto-differentiation enables optimization of imaging protocols and reconstruction parameters using gradient methods. +The main objective is to facilitate rapid, data-driven medical image reconstruction using CPUs and GPUs, for fast prototyping. Researchers can conveniently develop new model-based and learning-based methods (e.g., unrolled neural networks) with abstraction layers. The availability of auto-differentiation enables optimization of imaging protocols and reconstruction parameters using gradient methods. Documentation: https://mirtorch.readthedocs.io/en/latest/ @@ -26,13 +26,13 @@ To install the `MIRTorch` package, after cloning the repo, please try `pip insta The `LinearMap` class overloads common matrix operations, such as `+, - , *`. -Instances include basic linear operations (like convolution), classical imaging processing, and MRI system matrix (Cartesian and Non-Cartesian, sensitivity- and B0-informed system models). ***NEW!*** MIRTorch recently adds the support for SPPET and CT. +Instances include basic linear operations (like convolution), classical imaging processing, and MRI system matrix (Cartesian and Non-Cartesian, sensitivity- and B0-informed system models). ***NEW!*** MIRTorch recently adds the support for SPECT and CT. Since the Jacobian matrix of a linear operator is itself, the toolbox can actively calculate such Jacobians during backpropagation, avoiding the large cache cost required by auto-differentiation. When defining linear operators, please make sure that all torch tensors are on the same device and compatible. For example, `torch.cfloat` are compatible with `torch.float` but not `torch.double`. Similarly, `torch.chalf` is compatible with `torch.half`. When the data is image, there are 2 empirical formats: `[num_batch, num_channel, nx, ny, (nz)]` and `[nx, ny, (nz)]`. -For some LinearMaps, there is a boolean `batchmode` to control it. +For some LinearMaps, there is a boolean `batchmode` to control the shape. #### Proximal operators @@ -81,7 +81,7 @@ This work is inspired by (but not limited to): * PyLops: https://github.com/PyLops/pylops -If the code is useful to your research, please cite: +If the code is useful to your research, please consider citing: ```bibtex @article{wang:22:bjork, @@ -102,6 +102,7 @@ If the code is useful to your research, please cite: year={2022} } ``` +If you use the SPECT code, please consider citing: ```bibtex @ARTICLE{li:23:tet, diff --git a/mirtorch/linear/mri.py b/mirtorch/linear/mri.py index 53ecfdb..b305111 100644 --- a/mirtorch/linear/mri.py +++ b/mirtorch/linear/mri.py @@ -4,7 +4,7 @@ """ import math -from typing import Sequence, Union, List +from typing import Union, List import numpy as np import torch @@ -29,8 +29,8 @@ class FFTCn(LinearMap): def __init__( self, - size_in: Sequence[int], - size_out: Sequence[int], + size_in: List[int], + size_out: List[int], dims: Union[int, List[int]] | None = None, norm: str = "ortho", ): @@ -39,14 +39,17 @@ def __init__( self.dims = dims @torch.jit.script - def _apply(self: LinearMap, x: Tensor) -> Tensor: - x = ifftshift(x, self.dims) - x = fftn(x, dim=self.dims, norm=self.norm) - x = fftshift(x, self.dims) + def fwd(x: Tensor, dims: Union[int, List[int]], norm: str) -> Tensor: + x = ifftshift(x, dims) + x = fftn(x, dim=dims, norm=norm) + x = fftshift(x, dims) + return x + + def _apply(self, x: Tensor) -> Tensor: return x @torch.jit.script - def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor: + def _apply_adjoint(self, x: Tensor) -> Tensor: x = ifftshift(x, self.dims) if self.norm == "ortho": x = ifftn(x, dim=self.dims, norm="ortho") @@ -691,24 +694,3 @@ def mri_exp_approx(b0, bins, lseg, t): ct = np.transpose(np.exp(-np.expand_dims(tl, axis=1) @ b0_v)) return b, ct, tl - - -# def tukey_filer(LinearMap): -# r""" -# A Tukey filter to counteract Gibbs ringing artifacts -# Parameters: -# size_in: the signal size [nbatch, nchannel, nx (ny, nz ...)] -# width: the window length [wdx (wdy, wdz) ...] -# alpha(s): control parameters of the tukey window -# Returns: -# -# """ -# -# def __init__(self, -# size_in: Sequence[int], -# width: Sequence[int], -# alpha: Sequence[int] -# ): -# self.width = width -# self.alpha = alpha -# super(tukey_filer, self).__init__(tuple(size_in), tuple(size_in)) From 675cdf75356cf42fa28a8fa7ad692e791b25b900 Mon Sep 17 00:00:00 2001 From: guanhuaw Date: Sat, 27 Jul 2024 21:22:19 -0700 Subject: [PATCH 4/6] add pytests --- .github/workflows/python-ci.yml | 50 +- .github/workflows/python-publish.yml | 29 +- .gitignore | 1 + .pre-commit-config.yaml | 2 +- LICENSE | 29 + README.md | 3 +- mirtorch/linear/__init__.py | 4 + mirtorch/linear/linearmaps.py | 90 +- mirtorch/linear/mri.py | 17 +- mirtorch/linear/util.py | 6 +- mirtorch/linear/wavelets.py | 2 +- mirtorch/vendors/pytorch_wavelets/__init__.py | 35 + mirtorch/vendors/pytorch_wavelets/_version.py | 2 + .../pytorch_wavelets/dtcwt/__init__.py | 6 + .../vendors/pytorch_wavelets/dtcwt/coeffs.py | 142 +++ .../pytorch_wavelets/dtcwt/lowlevel.py | 381 +++++++ .../pytorch_wavelets/dtcwt/lowlevel2.py | 663 ++++++++++++ .../pytorch_wavelets/dtcwt/transform2d.py | 299 ++++++ .../pytorch_wavelets/dtcwt/transform_funcs.py | 495 +++++++++ .../vendors/pytorch_wavelets/dwt/__init__.py | 0 .../vendors/pytorch_wavelets/dwt/lowlevel.py | 997 ++++++++++++++++++ .../pytorch_wavelets/dwt/swt_inverse.py | 213 ++++ .../pytorch_wavelets/dwt/transform1d.py | 117 ++ .../pytorch_wavelets/dwt/transform2d.py | 223 ++++ .../pytorch_wavelets/scatternet/__init__.py | 3 + .../pytorch_wavelets/scatternet/layers.py | 209 ++++ .../pytorch_wavelets/scatternet/lowlevel.py | 779 ++++++++++++++ mirtorch/vendors/pytorch_wavelets/utils.py | 243 +++++ pyproject.toml | 37 +- tests/basics_tests.py | 386 +++---- tests/linops_tests.py | 109 ++ tests/mri_tests.py | 76 ++ tests/prox_tests.py | 257 ++--- tests/spect_tests.py | 164 ++- tests/util_tests.py | 109 ++ 35 files changed, 5702 insertions(+), 476 deletions(-) create mode 100644 mirtorch/vendors/pytorch_wavelets/__init__.py create mode 100644 mirtorch/vendors/pytorch_wavelets/_version.py create mode 100644 mirtorch/vendors/pytorch_wavelets/dtcwt/__init__.py create mode 100644 mirtorch/vendors/pytorch_wavelets/dtcwt/coeffs.py create mode 100644 mirtorch/vendors/pytorch_wavelets/dtcwt/lowlevel.py create mode 100644 mirtorch/vendors/pytorch_wavelets/dtcwt/lowlevel2.py create mode 100644 mirtorch/vendors/pytorch_wavelets/dtcwt/transform2d.py create mode 100644 mirtorch/vendors/pytorch_wavelets/dtcwt/transform_funcs.py create mode 100644 mirtorch/vendors/pytorch_wavelets/dwt/__init__.py create mode 100644 mirtorch/vendors/pytorch_wavelets/dwt/lowlevel.py create mode 100644 mirtorch/vendors/pytorch_wavelets/dwt/swt_inverse.py create mode 100644 mirtorch/vendors/pytorch_wavelets/dwt/transform1d.py create mode 100644 mirtorch/vendors/pytorch_wavelets/dwt/transform2d.py create mode 100644 mirtorch/vendors/pytorch_wavelets/scatternet/__init__.py create mode 100644 mirtorch/vendors/pytorch_wavelets/scatternet/layers.py create mode 100644 mirtorch/vendors/pytorch_wavelets/scatternet/lowlevel.py create mode 100644 mirtorch/vendors/pytorch_wavelets/utils.py create mode 100644 tests/util_tests.py diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index ccb695a..e8a646f 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -3,35 +3,49 @@ name: Python-CI on: push: branches: - - main - master - develop - feature/* pull_request: branches: - - main - master - develop - feature/* jobs: build: - runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.x' # Specify the Python version you need - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -e . - pip install ruff pytest - - - name: Lint with Ruff - run: | - ruff ./mirtorch + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" # Specify the Python version you need + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . + pip install ruff pytest + + - name: Lint with Ruff + run: | + ruff ./mirtorch + + - name: Test with pytest + run: | + pytest ./tests + + - name: Automated Version Bump + if: github.ref == 'refs/heads/master' + uses: phips28/gh-action-bump-version@master + with: + tag-prefix: "v" # or an empty string if you don't want a prefix + filename: "pyproject.toml" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Push changes + if: github.ref == 'refs/heads/master' + run: git push && git push --tags diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index bdaab28..d58d71f 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -26,14 +26,23 @@ jobs: uses: actions/setup-python@v3 with: python-version: '3.x' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install build - - name: Build package - run: python -m build - - name: Publish package - uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + - name: Install pypa/build + run: >- + python3 -m + pip install + build + --user + - name: Build a binary wheel and a source tarball + run: python3 -m build + - name: Store the distribution packages + uses: actions/upload-artifact@v3 with: - user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} + name: python-package-distributions + path: dist/ + - name: Download all the dists + uses: actions/download-artifact@v3 + with: + name: python-package-distributions + path: dist/ + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index 9234227..5bdf639 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,4 @@ mrt/ docs/_build docs/_static docs/_templates +.ruff_cache diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3ac53b3..d18a903 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,4 +28,4 @@ repos: rev: v2.1.0 hooks: - id: codespell - exclude: ^(?:tests|docs|examples)/ + exclude: ^(?:tests|docs|examples|mirtorch/vendors)/ diff --git a/LICENSE b/LICENSE index 4b68dcf..5b5e122 100644 --- a/LICENSE +++ b/LICENSE @@ -49,3 +49,32 @@ Version 3, 29 June 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. + +----------------------------- LICENSE for pytorch_wavelets---------------------------- +This licence applies to any parts of this library which are novel in comparison +to the original DTCWT MATLAB toolbox written by Nick Kingsbury and Cian +Shaffrey. See the Provenance section of README.rst file for details on any further +restrictions of use. If you wish to use the DTCWT, you should read that license as well. +The DWT sections come under this license. + +MIT License + +Copyright (c) 2020 Fergal Cotter + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index e315f0a..8364902 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,8 @@ Documentation: https://mirtorch.readthedocs.io/en/latest/ ### Installation We recommend to [pre-install `PyTorch` first](https://pytorch.org/). -To install the `MIRTorch` package, after cloning the repo, please try `pip install -e .`(one may modify the package locally with this option.) +Use `pip install mirtorch` to install. +To install the `MIRTorch` locally, after cloning the repo, please try `pip install -e .`(one may modify the package locally with this option.) ------ diff --git a/mirtorch/linear/__init__.py b/mirtorch/linear/__init__.py index bf48ce4..c27a018 100644 --- a/mirtorch/linear/__init__.py +++ b/mirtorch/linear/__init__.py @@ -6,6 +6,8 @@ ConjTranspose, BlockDiagonal, Kron, + Vstack, + Hstack, ) from .basics import ( Diff1d, @@ -51,4 +53,6 @@ "Gmri", "GmriGram", "Sense", + "Vstack", + "Hstack", ] diff --git a/mirtorch/linear/linearmaps.py b/mirtorch/linear/linearmaps.py index c042e0e..81516a0 100644 --- a/mirtorch/linear/linearmaps.py +++ b/mirtorch/linear/linearmaps.py @@ -63,20 +63,20 @@ def __repr__(self): f"" ) - def __call__(self, x) -> Tensor: + def __call__(self, x: Tensor) -> Tensor: # for a instance A, we can apply it by calling A(x). Equal to A*x return self.apply(x) - def _apply(self, x) -> Tensor: + def _apply(self, x: Tensor) -> Tensor: # worth noting that the function here should be differentiable, # for example, composed of native torch functions, # or torch.autograd.Function, or nn.module raise NotImplementedError - def _apply_adjoint(self, x) -> Tensor: + def _apply_adjoint(self, x: Tensor) -> Tensor: raise NotImplementedError - def apply(self, x) -> Tensor: + def apply(self, x: Tensor) -> Tensor: r""" Apply the forward operator """ @@ -85,7 +85,7 @@ def apply(self, x) -> Tensor: ), f"Shape of input data {x.shape} and forward linear op {self.size_in} do not match!" return self._apply(x) - def adjoint(self, x) -> Tensor: + def adjoint(self, x: Tensor) -> Tensor: r""" Apply the adjoint operator """ @@ -109,7 +109,7 @@ def __add__(self: LinearMap, other: LinearMap) -> LinearMap: def __mul__( self: LinearMap, other: Union[str, int, LinearMap, Tensor] - ) -> LinearMap: + ) -> Union[LinearMap, Tensor]: r""" Reload the * symbol. """ @@ -347,10 +347,80 @@ def _apply_adjoint(self, x: Tensor): class Vstack(LinearMap): - # TODO - pass + r""" + Vertical stacking of linear operators. + + .. math:: + [A1; A2; ...; An] * x = [A1(x); A2(x); ...; An(x)] + + Attributes: + A: List of LinearMaps to be stacked vertically + dim: the dimension along which to stack the LinearMaps + """ + + def __init__(self, A: List[LinearMap], dim: int = 0): + self.A = A + + # Check that all input sizes are the same + assert all( + [A[i].size_in == A[0].size_in for i in range(len(A))] + ), "All input sizes must be the same" + + # Calculate the total output size + size_out = [sum(A[i].size_out[0] for i in range(len(A)))] + list( + A[0].size_out[1:] + ) + + self.dim = dim + + super().__init__(A[0].size_in, size_out) + + def _apply(self, x: Tensor) -> Tensor: + return torch.cat([A_i(x) for A_i in self.A], dim=self.dim) + + def _apply_adjoint(self, x: Tensor) -> Tensor: + outputs = [] + start = 0 + for A_i in self.A: + end = start + A_i.size_out[0] + outputs.append(A_i.H(x[start:end])) + start = end + return sum(outputs) class Hstack(LinearMap): - # TODO - pass + r""" + Horizontal stacking of linear operators. + + .. math:: + [A1, A2, ..., An] * [x1; x2; ...; xn] = A1(x1) + A2(x2) + ... + An(xn) + + Attributes: + A: List of LinearMaps to be stacked horizontally + """ + + def __init__(self, A: List[LinearMap], dim: int = 0): + self.A = A + + # Check that all output sizes are the same + assert all( + [A[i].size_out == A[0].size_out for i in range(len(A))] + ), "All output sizes must be the same" + + # Calculate the total input size + size_in = [sum(A[i].size_in[0] for i in range(len(A)))] + list(A[0].size_in[1:]) + self.dim = dim + + super().__init__(size_in, A[0].size_out) + + def _apply(self, x: Tensor) -> Tensor: + outputs = [] + start = 0 + for A_i in self.A: + end = start + A_i.size_in[0] + outputs.append(A_i(x[start:end])) + start = end + return sum(outputs) + + def _apply_adjoint(self, x: Tensor) -> Tensor: + return torch.cat([A_i.H(x) for A_i in self.A], dim=self.dim) diff --git a/mirtorch/linear/mri.py b/mirtorch/linear/mri.py index b305111..5f3d86f 100644 --- a/mirtorch/linear/mri.py +++ b/mirtorch/linear/mri.py @@ -4,7 +4,7 @@ """ import math -from typing import Union, List +from typing import Union, List, Tuple import numpy as np import torch @@ -31,24 +31,19 @@ def __init__( self, size_in: List[int], size_out: List[int], - dims: Union[int, List[int]] | None = None, + dims: Tuple[int] | None = None, norm: str = "ortho", ): super(FFTCn, self).__init__(size_in, size_out) self.norm = norm self.dims = dims - @torch.jit.script - def fwd(x: Tensor, dims: Union[int, List[int]], norm: str) -> Tensor: - x = ifftshift(x, dims) - x = fftn(x, dim=dims, norm=norm) - x = fftshift(x, dims) - return x - def _apply(self, x: Tensor) -> Tensor: + x = ifftshift(x, self.dims) + x = fftn(x, dim=self.dims, norm=self.norm) + x = fftshift(x, self.dims) return x - @torch.jit.script def _apply_adjoint(self, x: Tensor) -> Tensor: x = ifftshift(x, self.dims) if self.norm == "ortho": @@ -101,7 +96,6 @@ def __init__( self.smaps = smaps self.batchmode = batchmode - @torch.jit.script def _apply(self, x: Tensor) -> Tensor: r""" Args: @@ -115,7 +109,6 @@ def _apply(self, x: Tensor) -> Tensor: k = fftshift(k, self.dims) * self.masks return k - @torch.jit.script def _apply_adjoint(self, k: Tensor) -> Tensor: r""" Args: diff --git a/mirtorch/linear/util.py b/mirtorch/linear/util.py index 1b0df42..33f3a9a 100644 --- a/mirtorch/linear/util.py +++ b/mirtorch/linear/util.py @@ -1,4 +1,4 @@ -from typing import Union, List +from typing import Union, Tuple, List import numpy as np import torch @@ -39,7 +39,7 @@ def finitediff_adj(y: Tensor, dim: int = -1, mode="reflexive"): Returns: y: the first-order finite difference of x """ - if mode == "reflexibe": + if mode == "reflexive": len_dim = y.shape[dim] return torch.cat( ( @@ -105,7 +105,7 @@ def fftshift(x: Tensor, dims: Union[int, List[int]] | None = None): return torch.roll(x, shifts, dims) -def ifftshift(x: Tensor, dims: Union[int, List[int]] | None = None): +def ifftshift(x: Tensor, dims: Union[int, Tuple[int]] | None = None): """ Similar to np.fft.ifftshift but applies to PyTorch tensors. From fastMRI code. """ diff --git a/mirtorch/linear/wavelets.py b/mirtorch/linear/wavelets.py index cbdd4e6..609082b 100644 --- a/mirtorch/linear/wavelets.py +++ b/mirtorch/linear/wavelets.py @@ -2,7 +2,7 @@ from typing import Sequence, Tuple, List import torch -from pytorch_wavelets import DWTForward, DWTInverse +from mirtorch.vendors.pytorch_wavelets import DWTForward, DWTInverse from torch import Tensor from .linearmaps import LinearMap diff --git a/mirtorch/vendors/pytorch_wavelets/__init__.py b/mirtorch/vendors/pytorch_wavelets/__init__.py new file mode 100644 index 0000000..36dbc03 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/__init__.py @@ -0,0 +1,35 @@ +__all__ = [ + "__version__", + "DTCWTForward", + "DTCWTInverse", + "DWTForward", + "DWTInverse", + "DWT1DForward", + "DWT1DInverse", + "DTCWT", + "IDTCWT", + "DWT", + "IDWT", + "DWT1D", + "DWT2D", + "IDWT1D", + "IDWT2D", + "ScatLayer", + "ScatLayerj2", +] + +from .dtcwt.transform2d import DTCWTForward, DTCWTInverse +from .dwt.transform2d import DWTForward, DWTInverse +from .dwt.transform1d import DWT1DForward, DWT1DInverse +from .scatternet import ScatLayer, ScatLayerj2 + +# Some aliases +DTCWT = DTCWTForward +IDTCWT = DTCWTInverse +DWT = DWTForward +IDWT = DWTInverse +DWT2D = DWT +IDWT2D = IDWT + +DWT1D = DWT1DForward +IDWT1D = DWT1DInverse diff --git a/mirtorch/vendors/pytorch_wavelets/_version.py b/mirtorch/vendors/pytorch_wavelets/_version.py new file mode 100644 index 0000000..fd832a5 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/_version.py @@ -0,0 +1,2 @@ +# IMPORTANT: before release, remove the 'devN' tag from the release name +__version__ = "1.3.0" diff --git a/mirtorch/vendors/pytorch_wavelets/dtcwt/__init__.py b/mirtorch/vendors/pytorch_wavelets/dtcwt/__init__.py new file mode 100644 index 0000000..68296e7 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/dtcwt/__init__.py @@ -0,0 +1,6 @@ +""" +Provide low-level torch accelerated operations. This backend requires that +torch be installed. Works best with a GPU but still offers good +improvements with a CPU. + +""" diff --git a/mirtorch/vendors/pytorch_wavelets/dtcwt/coeffs.py b/mirtorch/vendors/pytorch_wavelets/dtcwt/coeffs.py new file mode 100644 index 0000000..6055b21 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/dtcwt/coeffs.py @@ -0,0 +1,142 @@ +"""Functions to load standard wavelet coefficients. + +""" +from __future__ import absolute_import + +from numpy import load + +try: + import pywt + + _HAVE_PYWT = True +except ImportError: + _HAVE_PYWT = False + +COEFF_CACHE = {} + + +def _load_from_file(basename, varnames): + + try: + mat = COEFF_CACHE[basename] + except KeyError: + COEFF_CACHE[basename] = mat + + try: + return tuple(mat[k] for k in varnames) + except KeyError: + raise ValueError( + "Wavelet does not define ({0}) coefficients".format(", ".join(varnames)) + ) + + +def biort(name): + """Deprecated. Use :py::func:`pytorch_wavelets.dtcwt.coeffs.level1` + Instead + """ + return level1(name, compact=True) + + +def level1(name, compact=False): + """Load level 1 wavelet by name. + + :param name: a string specifying the wavelet family name + :returns: a tuple of vectors giving filter coefficients + + ============= ============================================ + Name Wavelet + ============= ============================================ + antonini Antonini 9,7 tap filters. + farras Farras 8,8 tap filters + legall LeGall 5,3 tap filters. + near_sym_a Near-Symmetric 5,7 tap filters. + near_sym_b Near-Symmetric 13,19 tap filters. + near_sym_b_bp Near-Symmetric 13,19 tap filters + BP filter + ============= ============================================ + + Return a tuple whose elements are a vector specifying the h0o, g0o, h1o and + g1o coefficients. + + See :ref:`rot-symm-wavelets` for an explanation of the ``near_sym_b_bp`` + wavelet filters. + + :raises IOError: if name does not correspond to a set of wavelets known to + the library. + :raises ValueError: if name doesn't specify + :py:func:`pytorch_wavelets.dtcwt.coeffs.qshift` wavelet. + + """ + if compact: + if name == "near_sym_b_bp": + return _load_from_file(name, ("h0o", "g0o", "h1o", "g1o", "h2o", "g2o")) + else: + return _load_from_file(name, ("h0o", "g0o", "h1o", "g1o")) + else: + return _load_from_file( + name, ("h0a", "h0b", "g0a", "g0b", "h1a", "h1b", "g1a", "g1b") + ) + + +def qshift(name): + """Load level >=2 wavelet by name, + + :param name: a string specifying the wavelet family name + :returns: a tuple of vectors giving filter coefficients + + ============ ============================================ + Name Wavelet + ============ ============================================ + qshift_06 Quarter Sample Shift Orthogonal (Q-Shift) 10,10 tap filters, + (only 6,6 non-zero taps). + qshift_a Q-shift 10,10 tap filters, + (with 10,10 non-zero taps, unlike qshift_06). + qshift_b Q-Shift 14,14 tap filters. + qshift_c Q-Shift 16,16 tap filters. + qshift_d Q-Shift 18,18 tap filters. + qshift_b_bp Q-Shift 18,18 tap filters + BP + ============ ============================================ + + Return a tuple whose elements are a vector specifying the h0a, h0b, g0a, + g0b, h1a, h1b, g1a and g1b coefficients. + + See :ref:`rot-symm-wavelets` for an explanation of the ``qshift_b_bp`` + wavelet filters. + + :raises IOError: if name does not correspond to a set of wavelets known to + the library. + :raises ValueError: if name doesn't specify a + :py:func:`pytorch_wavelets.dtcwt.coeffs.biort` wavelet. + + """ + if name == "qshift_b_bp": + return _load_from_file( + name, + ( + "h0a", + "h0b", + "g0a", + "g0b", + "h1a", + "h1b", + "g1a", + "g1b", + "h2a", + "h2b", + "g2a", + "g2b", + ), + ) + else: + return _load_from_file( + name, ("h0a", "h0b", "g0a", "g0b", "h1a", "h1b", "g1a", "g1b") + ) + + +def pywt_coeffs(name): + """Wraps pywt Wavelet function.""" + if not _HAVE_PYWT: + raise ImportError("Could not find PyWavelets module") + return pywt.Wavelet(name) + + +# vim:sw=4:sts=4:et diff --git a/mirtorch/vendors/pytorch_wavelets/dtcwt/lowlevel.py b/mirtorch/vendors/pytorch_wavelets/dtcwt/lowlevel.py new file mode 100644 index 0000000..125c440 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/dtcwt/lowlevel.py @@ -0,0 +1,381 @@ +from __future__ import absolute_import + +import torch +import torch.nn.functional as F +import numpy as np +from ..utils import symm_pad_1d as symm_pad + + +def as_column_vector(v): + """Return *v* as a column vector with shape (N,1).""" + v = np.atleast_2d(v) + if v.shape[0] == 1: + return v.T + else: + return v + + +def _as_row_vector(v): + """Return *v* as a row vector with shape (1, N).""" + v = np.atleast_2d(v) + if v.shape[0] == 1: + return v + else: + return v.T + + +def _as_row_tensor(h): + if isinstance(h, torch.Tensor): + h = torch.reshape(h, [1, -1]) + else: + h = as_column_vector(h).T + h = torch.tensor(h, dtype=torch.get_default_dtype()) + return h + + +def _as_col_vector(v): + """Return *v* as a column vector with shape (N,1).""" + v = np.atleast_2d(v) + if v.shape[0] == 1: + return v.T + else: + return v + + +def _as_col_tensor(h): + if isinstance(h, torch.Tensor): + h = torch.reshape(h, [-1, 1]) + else: + h = as_column_vector(h) + h = torch.tensor(h, dtype=torch.get_default_dtype()) + return h + + +def prep_filt(h, c, transpose=False): + """Prepares an array to be of the correct format for pytorch. + Can also specify whether to make it a row filter (set tranpose=True)""" + h = _as_col_vector(h)[::-1] + h = h[None, None, :] + h = np.repeat(h, repeats=c, axis=0) + if transpose: + h = h.transpose((0, 1, 3, 2)) + h = np.copy(h) + return torch.tensor(h, dtype=torch.get_default_dtype()) + + +def colfilter(X, h, mode="symmetric"): + if X is None or X.shape == torch.Size([]): + return torch.zeros(1, 1, 1, 1, device=X.device) + b, ch, row, col = X.shape + m = h.shape[2] // 2 + if mode == "symmetric": + xe = symm_pad(row, m) + X = F.conv2d(X[:, :, xe], h.repeat(ch, 1, 1, 1), groups=ch) + else: + X = F.conv2d(X, h.repeat(ch, 1, 1, 1), groups=ch, padding=(m, 0)) + return X + + +def rowfilter(X, h, mode="symmetric"): + if X is None or X.shape == torch.Size([]): + return torch.zeros(1, 1, 1, 1, device=X.device) + b, ch, row, col = X.shape + m = h.shape[2] // 2 + h = h.transpose(2, 3).contiguous() + if mode == "symmetric": + xe = symm_pad(col, m) + X = F.conv2d(X[:, :, :, xe], h.repeat(ch, 1, 1, 1), groups=ch) + else: + X = F.conv2d(X, h.repeat(ch, 1, 1, 1), groups=ch, padding=(0, m)) + return X + + +def coldfilt(X, ha, hb, highpass=False, mode="symmetric"): + if X is None or X.shape == torch.Size([]): + return torch.zeros(1, 1, 1, 1, device=X.device) + batch, ch, r, c = X.shape + r2 = r // 2 + if r % 4 != 0: + raise ValueError( + "No. of rows in X must be a multiple of 4\n" + "X was {}".format(X.shape) + ) + + if mode == "symmetric": + m = ha.shape[2] + xe = symm_pad(r, m) + X = torch.cat((X[:, :, xe[2::2]], X[:, :, xe[3::2]]), dim=1) + h = torch.cat((ha.repeat(ch, 1, 1, 1), hb.repeat(ch, 1, 1, 1)), dim=0) + X = F.conv2d(X, h, stride=(2, 1), groups=ch * 2) + else: + raise NotImplementedError() + + # Reshape result to be shape [Batch, ch, r/2, c]. This reshaping + # interleaves the columns + if highpass: + X = torch.stack((X[:, ch:], X[:, :ch]), dim=-2).view(batch, ch, r2, c) + else: + X = torch.stack((X[:, :ch], X[:, ch:]), dim=-2).view(batch, ch, r2, c) + + return X + + +def rowdfilt(X, ha, hb, highpass=False, mode="symmetric"): + if X is None or X.shape == torch.Size([]): + return torch.zeros(1, 1, 1, 1, device=X.device) + batch, ch, r, c = X.shape + c2 = c // 2 + if c % 4 != 0: + raise ValueError( + "No. of cols in X must be a multiple of 4\n" + "X was {}".format(X.shape) + ) + + if mode == "symmetric": + m = ha.shape[2] + xe = symm_pad(c, m) + X = torch.cat((X[:, :, :, xe[2::2]], X[:, :, :, xe[3::2]]), dim=1) + h = torch.cat( + ( + ha.reshape(1, 1, 1, m).repeat(ch, 1, 1, 1), + hb.reshape(1, 1, 1, m).repeat(ch, 1, 1, 1), + ), + dim=0, + ) + X = F.conv2d(X, h, stride=(1, 2), groups=ch * 2) + else: + raise NotImplementedError() + + # Reshape result to be shape [Batch, ch, r/2, c]. This reshaping + # interleaves the columns + if highpass: + Y = torch.stack((X[:, ch:], X[:, :ch]), dim=-1).view(batch, ch, r, c2) + else: + Y = torch.stack((X[:, :ch], X[:, ch:]), dim=-1).view(batch, ch, r, c2) + + return Y + + +def colifilt(X, ha, hb, highpass=False, mode="symmetric"): + if X is None or X.shape == torch.Size([]): + return torch.zeros(1, 1, 1, 1, device=X.device) + m = ha.shape[2] + m2 = m // 2 + hao = ha[:, :, 1::2] + hae = ha[:, :, ::2] + hbo = hb[:, :, 1::2] + hbe = hb[:, :, ::2] + batch, ch, r, c = X.shape + if r % 2 != 0: + raise ValueError( + "No. of rows in X must be a multiple of 2.\n" + "X was {}".format(X.shape) + ) + xe = symm_pad(r, m2) + + if m2 % 2 == 0: + h1 = hae + h2 = hbe + h3 = hao + h4 = hbo + if highpass: + X = torch.cat( + ( + X[:, :, xe[1:-2:2]], + X[:, :, xe[:-2:2]], + X[:, :, xe[3::2]], + X[:, :, xe[2::2]], + ), + dim=1, + ) + else: + X = torch.cat( + ( + X[:, :, xe[:-2:2]], + X[:, :, xe[1:-2:2]], + X[:, :, xe[2::2]], + X[:, :, xe[3::2]], + ), + dim=1, + ) + else: + h1 = hao + h2 = hbo + h3 = hae + h4 = hbe + if highpass: + X = torch.cat( + ( + X[:, :, xe[2:-1:2]], + X[:, :, xe[1:-1:2]], + X[:, :, xe[2:-1:2]], + X[:, :, xe[1:-1:2]], + ), + dim=1, + ) + else: + X = torch.cat( + ( + X[:, :, xe[1:-1:2]], + X[:, :, xe[2:-1:2]], + X[:, :, xe[1:-1:2]], + X[:, :, xe[2:-1:2]], + ), + dim=1, + ) + h = torch.cat( + ( + h1.repeat(ch, 1, 1, 1), + h2.repeat(ch, 1, 1, 1), + h3.repeat(ch, 1, 1, 1), + h4.repeat(ch, 1, 1, 1), + ), + dim=0, + ) + + X = F.conv2d(X, h, groups=4 * ch) + # Stack 4 tensors of shape [batch, ch, r2, c] into one tensor + # [batch, ch, r2, 4, c] + X = torch.stack( + [X[:, :ch], X[:, ch : 2 * ch], X[:, 2 * ch : 3 * ch], X[:, 3 * ch :]], dim=3 + ).view(batch, ch, r * 2, c) + + return X + + +def rowifilt(X, ha, hb, highpass=False, mode="symmetric"): + if X is None or X.shape == torch.Size([]): + return torch.zeros(1, 1, 1, 1, device=X.device) + m = ha.shape[2] + m2 = m // 2 + hao = ha[:, :, 1::2] + hae = ha[:, :, ::2] + hbo = hb[:, :, 1::2] + hbe = hb[:, :, ::2] + batch, ch, r, c = X.shape + if c % 2 != 0: + raise ValueError( + "No. of cols in X must be a multiple of 2.\n" + "X was {}".format(X.shape) + ) + xe = symm_pad(c, m2) + + if m2 % 2 == 0: + h1 = hae + h2 = hbe + h3 = hao + h4 = hbo + if highpass: + X = torch.cat( + ( + X[:, :, :, xe[1:-2:2]], + X[:, :, :, xe[:-2:2]], + X[:, :, :, xe[3::2]], + X[:, :, :, xe[2::2]], + ), + dim=1, + ) + else: + X = torch.cat( + ( + X[:, :, :, xe[:-2:2]], + X[:, :, :, xe[1:-2:2]], + X[:, :, :, xe[2::2]], + X[:, :, :, xe[3::2]], + ), + dim=1, + ) + else: + h1 = hao + h2 = hbo + h3 = hae + h4 = hbe + if highpass: + X = torch.cat( + ( + X[:, :, :, xe[2:-1:2]], + X[:, :, :, xe[1:-1:2]], + X[:, :, :, xe[2:-1:2]], + X[:, :, :, xe[1:-1:2]], + ), + dim=1, + ) + else: + X = torch.cat( + ( + X[:, :, :, xe[1:-1:2]], + X[:, :, :, xe[2:-1:2]], + X[:, :, :, xe[1:-1:2]], + X[:, :, :, xe[2:-1:2]], + ), + dim=1, + ) + h = torch.cat( + ( + h1.repeat(ch, 1, 1, 1), + h2.repeat(ch, 1, 1, 1), + h3.repeat(ch, 1, 1, 1), + h4.repeat(ch, 1, 1, 1), + ), + dim=0, + ).reshape(4 * ch, 1, 1, m2) + + X = F.conv2d(X, h, groups=4 * ch) + # Stack 4 tensors of shape [batch, ch, r2, c] into one tensor + # [batch, ch, r2, 4, c] + X = torch.stack( + [X[:, :ch], X[:, ch : 2 * ch], X[:, 2 * ch : 3 * ch], X[:, 3 * ch :]], dim=4 + ).view(batch, ch, r, c * 2) + return X + + +# def q2c(y, dim=-1): +def q2c(y, dim=-1): + """ + Convert from quads in y to complex numbers in z. + """ + + # Arrange pixels from the corners of the quads into + # 2 subimages of alternate real and imag pixels. + # a----b + # | | + # | | + # c----d + # Combine (a,b) and (d,c) to form two complex subimages. + y = y / np.sqrt(2) + a, b = y[:, :, 0::2, 0::2], y[:, :, 0::2, 1::2] + c, d = y[:, :, 1::2, 0::2], y[:, :, 1::2, 1::2] + + # return torch.stack((a-d, b+c), dim=dim), torch.stack((a+d, b-c), dim=dim) + return ((a - d, b + c), (a + d, b - c)) + + +def c2q(w1, w2): + """ + Scale by gain and convert from complex w(:,:,1:2) to real quad-numbers + in z. + + Arrange pixels from the real and imag parts of the 2 highpasses + into 4 separate subimages . + A----B Re Im of w(:,:,1) + | | + | | + C----D Re Im of w(:,:,2) + + """ + w1r, w1i = w1 + w2r, w2i = w2 + + x1 = w1r + w2r + x2 = w1i + w2i + x3 = w1i - w2i + x4 = -w1r + w2r + + # Get the shape of the tensor excluding the real/imagniary part + b, ch, r, c = w1r.shape + + # Create new empty tensor and fill it + y = w1r.new_zeros((b, ch, r * 2, c * 2), requires_grad=w1r.requires_grad) + y[:, :, ::2, ::2] = x1 + y[:, :, ::2, 1::2] = x2 + y[:, :, 1::2, ::2] = x3 + y[:, :, 1::2, 1::2] = x4 + y /= np.sqrt(2) + + return y diff --git a/mirtorch/vendors/pytorch_wavelets/dtcwt/lowlevel2.py b/mirtorch/vendors/pytorch_wavelets/dtcwt/lowlevel2.py new file mode 100644 index 0000000..cdf6ec8 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/dtcwt/lowlevel2.py @@ -0,0 +1,663 @@ +""" This module was part of an attempt to speed up the DTCWT. The code was +ultimately slower than the original implementation, but it is a nice +reference point for doing a DTCWT directly as 4 separate DWTs. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..dwt.lowlevel import roll, mypad +import pywt +from ..dwt.transform2d import DWTForward, DWTInverse +from ..dwt.lowlevel import afb2d, sfb2d_nonsep as sfb2d +from ..dwt.lowlevel import prep_filt_afb2d, prep_filt_sfb2d_nonsep as prep_filt_sfb2d +from ..dtcwt.coeffs import level1 as _level1, qshift as _qshift, biort as _biort + + +class DTCWTForward2(nn.Module): + """DTCWT based on 4 DWTs. Still works, but the above implementation is + faster""" + + def __init__(self, biort="farras", qshift="qshift_a", J=3, mode="symmetric"): + super().__init__() + self.biort = biort + self.qshift = qshift + self.J = J + + if isinstance(biort, str): + biort = _level1(biort) + assert len(biort) == 8 + h0a1, h0b1, _, _, h1a1, h1b1, _, _ = biort + DWTaa1 = DWTForward(J=1, wave=(h0a1, h1a1, h0a1, h1a1), mode=mode) + DWTab1 = DWTForward(J=1, wave=(h0a1, h1a1, h0b1, h1b1), mode=mode) + DWTba1 = DWTForward(J=1, wave=(h0b1, h1b1, h0a1, h1a1), mode=mode) + DWTbb1 = DWTForward(J=1, wave=(h0b1, h1b1, h0b1, h1b1), mode=mode) + self.level1 = nn.ModuleList([DWTaa1, DWTab1, DWTba1, DWTbb1]) + + if J > 1: + if isinstance(qshift, str): + qshift = _qshift(qshift) + assert len(qshift) == 8 + h0a, h0b, _, _, h1a, h1b, _, _ = qshift + DWTaa = DWTForward(J - 1, (h0a, h1a, h0a, h1a), mode=mode) + DWTab = DWTForward(J - 1, (h0a, h1a, h0b, h1b), mode=mode) + DWTba = DWTForward(J - 1, (h0b, h1b, h0a, h1a), mode=mode) + DWTbb = DWTForward(J - 1, (h0b, h1b, h0b, h1b), mode=mode) + self.level2 = nn.ModuleList([DWTaa, DWTab, DWTba, DWTbb]) + + def forward(self, x): + x = x / 2 + J = self.J + w = [[[None for _ in range(2)] for _ in range(2)] for j in range(J)] + lows = [[None for _ in range(2)] for _ in range(2)] + for m in range(2): + for n in range(2): + # Do the first level transform + ll, (w[0][m][n],) = self.level1[m * 2 + n](x) + # w[0][m][n] = [bands[:,:,2], bands[:,:,1], bands[:,:,3]] + + # Do the second+ level transform with the second level filters + if J > 1: + ll, bands = self.level2[m * 2 + n](ll) + for j in range(1, J): + w[j][m][n] = bands[j - 1] + lows[m][n] = ll + + # Convert the quads into real and imaginary parts + yh = [ + None, + ] * J + for j in range(J): + deg75r, deg105i = pm(w[j][0][0][:, :, 1], w[j][1][1][:, :, 1]) + deg105r, deg75i = pm(w[j][0][1][:, :, 1], w[j][1][0][:, :, 1]) + deg15r, deg165i = pm(w[j][0][0][:, :, 0], w[j][1][1][:, :, 0]) + deg165r, deg15i = pm(w[j][0][1][:, :, 0], w[j][1][0][:, :, 0]) + deg135r, deg45i = pm(w[j][0][0][:, :, 2], w[j][1][1][:, :, 2]) + deg45r, deg135i = pm(w[j][0][1][:, :, 2], w[j][1][0][:, :, 2]) + w[j] = None + yhr = torch.stack( + (deg15r, deg45r, deg75r, deg105r, deg135r, deg165r), dim=1 + ) + yhi = torch.stack( + (deg15i, deg45i, deg75i, deg105i, deg135i, deg165i), dim=1 + ) + yh[j] = torch.stack((yhr, yhi), dim=-1) + + return lows, yh + + +class DTCWTInverse2(nn.Module): + def __init__(self, biort="farras", qshift="qshift_a", mode="symmetric"): + super().__init__() + self.biort = biort + self.qshift = qshift + + if isinstance(biort, str): + biort = _level1(biort) + assert len(biort) == 8 + _, _, g0a1, g0b1, _, _, g1a1, g1b1 = biort + IWTaa1 = DWTInverse(wave=(g0a1, g1a1, g0a1, g1a1), mode=mode) + IWTab1 = DWTInverse(wave=(g0a1, g1a1, g0b1, g1b1), mode=mode) + IWTba1 = DWTInverse(wave=(g0b1, g1b1, g0a1, g1a1), mode=mode) + IWTbb1 = DWTInverse(wave=(g0b1, g1b1, g0b1, g1b1), mode=mode) + self.level1 = nn.ModuleList([IWTaa1, IWTab1, IWTba1, IWTbb1]) + + if isinstance(qshift, str): + qshift = _qshift(qshift) + assert len(qshift) == 8 + _, _, g0a, g0b, _, _, g1a, g1b = qshift + IWTaa = DWTInverse(wave=(g0a, g1a, g0a, g1a), mode=mode) + IWTab = DWTInverse(wave=(g0a, g1a, g0b, g1b), mode=mode) + IWTba = DWTInverse(wave=(g0b, g1b, g0a, g1a), mode=mode) + IWTbb = DWTInverse(wave=(g0b, g1b, g0b, g1b), mode=mode) + self.level2 = nn.ModuleList([IWTaa, IWTab, IWTba, IWTbb]) + + def forward(self, x): + # Convert the highs back to subbands + yl, yh = x + J = len(yh) + # w = [[[[None for i in range(3)] for j in range(2)] + # for k in range(2)] for l in range(J)] + w = [ + [[[None for band in range(3)] for j in range(J)] for m in range(2)] + for n in range(2) + ] + for j in range(J): + w[0][0][j][0], w[1][1][j][0] = pm( + yh[j][:, 2, :, :, :, 0], yh[j][:, 3, :, :, :, 1] + ) + w[0][1][j][0], w[1][0][j][0] = pm( + yh[j][:, 3, :, :, :, 0], yh[j][:, 2, :, :, :, 1] + ) + w[0][0][j][1], w[1][1][j][1] = pm( + yh[j][:, 0, :, :, :, 0], yh[j][:, 5, :, :, :, 1] + ) + w[0][1][j][1], w[1][0][j][1] = pm( + yh[j][:, 5, :, :, :, 0], yh[j][:, 0, :, :, :, 1] + ) + w[0][0][j][2], w[1][1][j][2] = pm( + yh[j][:, 1, :, :, :, 0], yh[j][:, 4, :, :, :, 1] + ) + w[0][1][j][2], w[1][0][j][2] = pm( + yh[j][:, 4, :, :, :, 0], yh[j][:, 1, :, :, :, 1] + ) + w[0][0][j] = torch.stack(w[0][0][j], dim=2) + w[0][1][j] = torch.stack(w[0][1][j], dim=2) + w[1][0][j] = torch.stack(w[1][0][j], dim=2) + w[1][1][j] = torch.stack(w[1][1][j], dim=2) + + y = None + for m in range(2): + for n in range(2): + lo = yl[m][n] + if J > 1: + lo = self.level2[m * 2 + n]((lo, w[m][n][1:])) + lo = self.level1[m * 2 + n]((lo, (w[m][n][0],))) + + # Add to the output + if y is None: + y = lo + else: + y = y + lo + + # Normalize + y = y / 2 + return y + + +def prep_filt_quad_afb2d_nonsep( + h0a_col, + h1a_col, + h0a_row, + h1a_row, + h0b_col, + h1b_col, + h0b_row, + h1b_row, + h0c_col, + h1c_col, + h0c_row, + h1c_row, + h0d_col, + h1d_col, + h0d_row, + h1d_row, + device=None, +): + """ + Prepares the filters to be of the right form for the afb2d_nonsep function. + In particular, makes 2d point spread functions, and mirror images them in + preparation to do torch.conv2d. + + Inputs: + h0_col (array-like): low pass column filter bank + h1_col (array-like): high pass column filter bank + h0_row (array-like): low pass row filter bank. If none, will assume the + same as column filter + h1_row (array-like): high pass row filter bank. If none, will assume the + same as column filter + device: which device to put the tensors on to + + Returns: + filts: (4, 1, h, w) tensor ready to get the four subbands + """ + lla = np.outer(h0a_col, h0a_row) + lha = np.outer(h1a_col, h0a_row) + hla = np.outer(h0a_col, h1a_row) + hha = np.outer(h1a_col, h1a_row) + llb = np.outer(h0b_col, h0b_row) + lhb = np.outer(h1b_col, h0b_row) + hlb = np.outer(h0b_col, h1b_row) + hhb = np.outer(h1b_col, h1b_row) + llc = np.outer(h0c_col, h0c_row) + lhc = np.outer(h1c_col, h0c_row) + hlc = np.outer(h0c_col, h1c_row) + hhc = np.outer(h1c_col, h1c_row) + lld = np.outer(h0d_col, h0d_row) + lhd = np.outer(h1d_col, h0d_row) + hld = np.outer(h0d_col, h1d_row) + hhd = np.outer(h1d_col, h1d_row) + filts = np.stack( + [ + lla[None, ::-1, ::-1], + llb[None, ::-1, ::-1], + llc[None, ::-1, ::-1], + lld[None, ::-1, ::-1], + lha[None, ::-1, ::-1], + lhb[None, ::-1, ::-1], + lhc[None, ::-1, ::-1], + lhd[None, ::-1, ::-1], + hla[None, ::-1, ::-1], + hlb[None, ::-1, ::-1], + hlc[None, ::-1, ::-1], + hld[None, ::-1, ::-1], + hha[None, ::-1, ::-1], + hhb[None, ::-1, ::-1], + hhc[None, ::-1, ::-1], + hhd[None, ::-1, ::-1], + ], + axis=0, + ) + filts = torch.tensor(filts, dtype=torch.get_default_dtype(), device=device) + return filts + + +def prep_filt_quad_afb2d(h0a, h1a, h0b, h1b, device=None): + """ + Prepares the filters to be of the right form for the quad_afb2d function. + + Inputs: + h0_col (array-like): low pass column filter bank + h1_col (array-like): high pass column filter bank + h0_row (array-like): low pass row filter bank. If none, will assume the + same as column filter + h1_row (array-like): high pass row filter bank. If none, will assume the + same as column filter + device: which device to put the tensors on to + + Returns: + filts: (4, 1, h, w) tensor ready to get the four subbands + """ + h0a_col = np.array(h0a).ravel()[::-1][None, :, None] + h1a_col = np.array(h1a).ravel()[::-1][None, :, None] + h0b_col = np.array(h0a).ravel()[::-1][None, :, None] + h1b_col = np.array(h1a).ravel()[::-1][None, :, None] + h0c_col = np.array(h0b).ravel()[::-1][None, :, None] + h1c_col = np.array(h1b).ravel()[::-1][None, :, None] + h0d_col = np.array(h0b).ravel()[::-1][None, :, None] + h1d_col = np.array(h1b).ravel()[::-1][None, :, None] + h0a_row = np.array(h0a).ravel()[::-1][None, None, :] + h1a_row = np.array(h1a).ravel()[::-1][None, None, :] + h0b_row = np.array(h0b).ravel()[::-1][None, None, :] + h1b_row = np.array(h1b).ravel()[::-1][None, None, :] + h0c_row = np.array(h0a).ravel()[::-1][None, None, :] + h1c_row = np.array(h1a).ravel()[::-1][None, None, :] + h0d_row = np.array(h0b).ravel()[::-1][None, None, :] + h1d_row = np.array(h1b).ravel()[::-1][None, None, :] + cols = np.stack( + (h0a_col, h1a_col, h0b_col, h1b_col, h0c_col, h1c_col, h0d_col, h1d_col), axis=0 + ) + rows = np.stack( + ( + h0a_row, + h1a_row, + h0a_row, + h1a_row, + h0b_row, + h1b_row, + h0b_row, + h1b_row, + h0c_row, + h1c_row, + h0c_row, + h1c_row, + h0d_row, + h1d_row, + h0d_row, + h1d_row, + ), + axis=0, + ) + cols = torch.tensor(np.copy(cols), dtype=torch.get_default_dtype(), device=device) + rows = torch.tensor(np.copy(rows), dtype=torch.get_default_dtype(), device=device) + return cols, rows + + +def quad_afb2d(x, cols, rows, mode="zero", split=True, stride=2): + """Does a single level 2d wavelet decomposition of an input. Does separate + row and column filtering by two calls to + :py:func:`pytorch_wavelets.dwt.lowlevel.afb1d` + + Inputs: + x (torch.Tensor): Input to decompose + filts (list of ndarray or torch.Tensor): If a list of tensors has been + given, this function assumes they are in the right form (the form + returned by + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`). + Otherwise, this function will prepare the filters to be of the right + form by calling + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`. + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which + padding to use. If periodization, the output size will be half the + input size. Otherwise, the output size will be slightly larger than + half. + """ + x = x / 2 + C = x.shape[1] + cols = torch.cat([cols] * C, dim=0) + rows = torch.cat([rows] * C, dim=0) + + if mode == "per" or mode == "periodization": + # Do column filtering + L = cols.shape[2] + L2 = L // 2 + if x.shape[2] % 2 == 1: + x = torch.cat((x, x[:, :, -1:]), dim=2) + N2 = x.shape[2] // 2 + x = roll(x, -L2, dim=2) + pad = (L - 1, 0) + lohi = F.conv2d(x, cols, padding=pad, stride=(stride, 1), groups=C) + lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2 : N2 + L2] + lohi = lohi[:, :, :N2] + + # Do row filtering + L = rows.shape[3] + L2 = L // 2 + if lohi.shape[3] % 2 == 1: + lohi = torch.cat((lohi, lohi[:, :, :, -1:]), dim=3) + N2 = x.shape[3] // 2 + lohi = roll(lohi, -L2, dim=3) + pad = (0, L - 1) + w = F.conv2d(lohi, rows, padding=pad, stride=(1, stride), groups=8 * C) + w[:, :, :, :L2] = w[:, :, :, :L2] + w[:, :, :, N2 : N2 + L2] + w = w[:, :, :, :N2] + elif mode == "zero": + # Do column filtering + N = x.shape[2] + L = cols.shape[2] + outsize = pywt.dwt_coeff_len(N, L, mode="zero") + p = 2 * (outsize - 1) - N + L + + # Sadly, pytorch only allows for same padding before and after, if + # we need to do more padding after for odd length signals, have to + # prepad + if p % 2 == 1: + x = F.pad(x, (0, 0, 0, 1)) + pad = (p // 2, 0) + # Calculate the high and lowpass + lohi = F.conv2d(x, cols, padding=pad, stride=(stride, 1), groups=C) + + # Do row filtering + N = lohi.shape[3] + L = rows.shape[3] + outsize = pywt.dwt_coeff_len(N, L, mode="zero") + p = 2 * (outsize - 1) - N + L + if p % 2 == 1: + lohi = F.pad(lohi, (0, 1, 0, 0)) + pad = (0, p // 2) + w = F.conv2d(lohi, rows, padding=pad, stride=(1, stride), groups=8 * C) + elif mode == "symmetric" or mode == "reflect": + # Do column filtering + N = x.shape[2] + L = cols.shape[2] + outsize = pywt.dwt_coeff_len(N, L, mode=mode) + p = 2 * (outsize - 1) - N + L + x = mypad(x, pad=(0, 0, p // 2, (p + 1) // 2), mode=mode) + lohi = F.conv2d(x, cols, stride=(stride, 1), groups=C) + + # Do row filtering + N = lohi.shape[3] + L = rows.shape[3] + outsize = pywt.dwt_coeff_len(N, L, mode=mode) + p = 2 * (outsize - 1) - N + L + lohi = mypad(lohi, pad=(p // 2, (p + 1) // 2, 0, 0), mode=mode) + w = F.conv2d(lohi, rows, stride=(1, stride), groups=8 * C) + else: + raise ValueError("Unkown pad type: {}".format(mode)) + + y = w.view((w.shape[0], C, 4, 4, w.shape[-2], w.shape[-1])) + yl = y[:, :, :, 0] + yh = y[:, :, :, 1:] + deg75r, deg105i = pm(yh[:, :, 0, 0], yh[:, :, 3, 0]) + deg105r, deg75i = pm(yh[:, :, 1, 0], yh[:, :, 2, 0]) + deg15r, deg165i = pm(yh[:, :, 0, 1], yh[:, :, 3, 1]) + deg165r, deg15i = pm(yh[:, :, 1, 1], yh[:, :, 2, 1]) + deg135r, deg45i = pm(yh[:, :, 0, 2], yh[:, :, 3, 2]) + deg45r, deg135i = pm(yh[:, :, 1, 2], yh[:, :, 2, 2]) + yhr = torch.stack((deg15r, deg45r, deg75r, deg105r, deg135r, deg165r), dim=1) + yhi = torch.stack((deg15i, deg45i, deg75i, deg105i, deg135i, deg165i), dim=1) + yh = torch.stack((yhr, yhi), dim=-1) + + yl_rowa = torch.stack((yl[:, :, 1], yl[:, :, 0]), dim=-1) + yl_rowb = torch.stack((yl[:, :, 3], yl[:, :, 2]), dim=-1) + yl_rowa = yl_rowa.view(yl.shape[0], C, yl.shape[-2], yl.shape[-1] * 2) + yl_rowb = yl_rowb.view(yl.shape[0], C, yl.shape[-2], yl.shape[-1] * 2) + z = torch.stack((yl_rowb, yl_rowa), dim=-2) + yl = z.view(yl.shape[0], C, yl.shape[-2] * 2, yl.shape[-1] * 2) + + return yl.contiguous(), yh + + +def quad_afb2d_nonsep(x, filts, mode="zero"): + """Does a 1 level 2d wavelet decomposition of an input. Doesn't do separate + row and column filtering. + + Inputs: + x (torch.Tensor): Input to decompose + filts (list or torch.Tensor): If a list is given, should be the low and + highpass filter banks. If a tensor is given, it should be of the + form created by + :py:func:`pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d_nonsep` + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which + padding to use. If periodization, the output size will be half the + input size. Otherwise, the output size will be slightly larger than + half. + """ + C = x.shape[1] + Ny = x.shape[2] + Nx = x.shape[3] + + # Check the filter inputs + f = torch.cat([filts] * C, dim=0) + Ly = f.shape[2] + Lx = f.shape[3] + + if mode == "periodization" or mode == "per": + if x.shape[2] % 2 == 1: + x = torch.cat((x, x[:, :, -1:]), dim=2) + Ny += 1 + if x.shape[3] % 2 == 1: + x = torch.cat((x, x[:, :, :, -1:]), dim=3) + Nx += 1 + pad = (Ly - 1, Lx - 1) + stride = (2, 2) + x = roll(roll(x, -Ly // 2, dim=2), -Lx // 2, dim=3) + y = F.conv2d(x, f, padding=pad, stride=stride, groups=C) + y[:, :, : Ly // 2] += y[:, :, Ny // 2 : Ny // 2 + Ly // 2] + y[:, :, :, : Lx // 2] += y[:, :, :, Nx // 2 : Nx // 2 + Lx // 2] + y = y[:, :, : Ny // 2, : Nx // 2] + elif mode == "zero" or mode == "symmetric" or mode == "reflect": + # Calculate the pad size + out1 = pywt.dwt_coeff_len(Ny, Ly, mode=mode) + out2 = pywt.dwt_coeff_len(Nx, Lx, mode=mode) + p1 = 2 * (out1 - 1) - Ny + Ly + p2 = 2 * (out2 - 1) - Nx + Lx + if mode == "zero": + # Sadly, pytorch only allows for same padding before and after, if + # we need to do more padding after for odd length signals, have to + # prepad + if p1 % 2 == 1 and p2 % 2 == 1: + x = F.pad(x, (0, 1, 0, 1)) + elif p1 % 2 == 1: + x = F.pad(x, (0, 0, 0, 1)) + elif p2 % 2 == 1: + x = F.pad(x, (0, 1, 0, 0)) + # Calculate the high and lowpass + y = F.conv2d(x, f, padding=(p1 // 2, p2 // 2), stride=2, groups=C) + elif mode == "symmetric" or mode == "reflect": + pad = (p2 // 2, (p2 + 1) // 2, p1 // 2, (p1 + 1) // 2) + x = mypad(x, pad=pad, mode=mode) + y = F.conv2d(x, f, stride=2, groups=C) + else: + raise ValueError("Unkown pad type: {}".format(mode)) + + y = y.reshape((y.shape[0], C, 4, y.shape[-2], y.shape[-1])) + yl = y[:, :, 0].contiguous() + yh = y[:, :, 1:].contiguous() + return yl, yh + + +def cplxdual2D( + x, J, level1="farras", qshift="qshift_a", mode="periodization", mag=False +): + """Do a complex dtcwt + + Returns: + lows: lowpass outputs from each of the 4 trees. Is a 2x2 list of lists + w: bandpass outputs from each of the 4 trees. Is a list of lists, with + shape [J][2][2][3]. Initially the 3 outputs are the lh, hl and hh from + each of the 4 trees. After doing sums and differences though, they + become the real and imaginary parts for the 6 orientations. In + particular: + first index - indexes over scales + second index - 0 = real, 1 = imaginary + third and fourth indices: + 0,1 = 15 degrees + 1,2 = 45 degrees + 0,0 = 75 degrees + 1,0 = 105 degrees + 0,2 = 135 degrees + 1,1 = 165 degrees + """ + x = x / 2 + # Get the filters + h0a1, h0b1, _, _, h1a1, h1b1, _, _ = _level1(level1) + h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift) + + Faf = ( + ( + prep_filt_afb2d(h0a1, h1a1, h0a1, h1a1, device=x.device), + prep_filt_afb2d(h0a1, h1a1, h0b1, h1b1, device=x.device), + ), + ( + prep_filt_afb2d(h0b1, h1b1, h0a1, h1a1, device=x.device), + prep_filt_afb2d(h0b1, h1b1, h0b1, h1b1, device=x.device), + ), + ) + af = ( + ( + prep_filt_afb2d(h0a, h1a, h0a, h1a, device=x.device), + prep_filt_afb2d(h0a, h1a, h0b, h1b, device=x.device), + ), + ( + prep_filt_afb2d(h0b, h1b, h0a, h1a, device=x.device), + prep_filt_afb2d(h0b, h1b, h0b, h1b, device=x.device), + ), + ) + + # Do 4 fully decimated dwts + w = [[[None for _ in range(2)] for _ in range(2)] for j in range(J)] + lows = [[None for _ in range(2)] for _ in range(2)] + for m in range(2): + for n in range(2): + # Do the first level transform with the first level filters + # ll, bands = afb2d(x, (Faf[m][0], Faf[m][1], Faf[n][0], Faf[n][1]), mode=mode) + bands = afb2d(x, Faf[m][n], mode=mode) + # Separate the low and bandpasses + s = bands.shape + bands = bands.reshape(s[0], -1, 4, s[-2], s[-1]) + ll = bands[:, :, 0].contiguous() + w[0][m][n] = [bands[:, :, 2], bands[:, :, 1], bands[:, :, 3]] + + # Do the second+ level transform with the second level filters + for j in range(1, J): + # ll, bands = afb2d(ll, (af[m][0], af[m][1], af[n][0], af[n][1]), mode=mode) + bands = afb2d(ll, af[m][n], mode=mode) + # Separate the low and bandpasses + s = bands.shape + bands = bands.reshape(s[0], -1, 4, s[-2], s[-1]) + ll = bands[:, :, 0].contiguous() + w[j][m][n] = [bands[:, :, 2], bands[:, :, 1], bands[:, :, 3]] + lows[m][n] = ll + + # Convert the quads into real and imaginary parts + yh = [ + None, + ] * J + for j in range(J): + deg75r, deg105i = pm(w[j][0][0][0], w[j][1][1][0]) + deg105r, deg75i = pm(w[j][0][1][0], w[j][1][0][0]) + deg15r, deg165i = pm(w[j][0][0][1], w[j][1][1][1]) + deg165r, deg15i = pm(w[j][0][1][1], w[j][1][0][1]) + deg135r, deg45i = pm(w[j][0][0][2], w[j][1][1][2]) + deg45r, deg135i = pm(w[j][0][1][2], w[j][1][0][2]) + yhr = torch.stack((deg15r, deg45r, deg75r, deg105r, deg135r, deg165r), dim=1) + yhi = torch.stack((deg15i, deg45i, deg75i, deg105i, deg135i, deg165i), dim=1) + if mag: + yh[j] = torch.sqrt(yhr**2 + yhi**2 + 0.01) - np.sqrt(0.01) + else: + yh[j] = torch.stack((yhr, yhi), dim=-1) + + return lows, yh + + +def icplxdual2D(yl, yh, level1="farras", qshift="qshift_a", mode="periodization"): + # Get the filters + _, _, g0a1, g0b1, _, _, g1a1, g1b1 = _level1(level1) + _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift) + + dev = yl[0][0].device + Faf = ( + ( + prep_filt_sfb2d(g0a1, g1a1, g0a1, g1a1, device=dev), + prep_filt_sfb2d(g0a1, g1a1, g0b1, g1b1, device=dev), + ), + ( + prep_filt_sfb2d(g0b1, g1b1, g0a1, g1a1, device=dev), + prep_filt_sfb2d(g0b1, g1b1, g0b1, g1b1, device=dev), + ), + ) + af = ( + ( + prep_filt_sfb2d(g0a, g1a, g0a, g1a, device=dev), + prep_filt_sfb2d(g0a, g1a, g0b, g1b, device=dev), + ), + ( + prep_filt_sfb2d(g0b, g1b, g0a, g1a, device=dev), + prep_filt_sfb2d(g0b, g1b, g0b, g1b, device=dev), + ), + ) + + # Convert the highs back to subbands + J = len(yh) + w = [ + [[[None for i in range(3)] for j in range(2)] for k in range(2)] + for l in range(J) + ] + for j in range(J): + w[j][0][0][0], w[j][1][1][0] = pm( + yh[j][:, 2, :, :, :, 0], yh[j][:, 3, :, :, :, 1] + ) + w[j][0][1][0], w[j][1][0][0] = pm( + yh[j][:, 3, :, :, :, 0], yh[j][:, 2, :, :, :, 1] + ) + w[j][0][0][1], w[j][1][1][1] = pm( + yh[j][:, 0, :, :, :, 0], yh[j][:, 5, :, :, :, 1] + ) + w[j][0][1][1], w[j][1][0][1] = pm( + yh[j][:, 5, :, :, :, 0], yh[j][:, 0, :, :, :, 1] + ) + w[j][0][0][2], w[j][1][1][2] = pm( + yh[j][:, 1, :, :, :, 0], yh[j][:, 4, :, :, :, 1] + ) + w[j][0][1][2], w[j][1][0][2] = pm( + yh[j][:, 4, :, :, :, 0], yh[j][:, 1, :, :, :, 1] + ) + w[j][0][0] = torch.stack(w[j][0][0], dim=2) + w[j][0][1] = torch.stack(w[j][0][1], dim=2) + w[j][1][0] = torch.stack(w[j][1][0], dim=2) + w[j][1][1] = torch.stack(w[j][1][1], dim=2) + + y = None + for m in range(2): + for n in range(2): + lo = yl[m][n] + for j in range(J - 1, 0, -1): + lo = sfb2d(lo, w[j][m][n], af[m][n], mode=mode) + lo = sfb2d(lo, w[0][m][n], Faf[m][n], mode=mode) + + # Add to the output + if y is None: + y = lo + else: + y = y + lo + + # Normalize + y = y / 2 + return y + + +def pm(a, b): + u = (a + b) / np.sqrt(2) + v = (a - b) / np.sqrt(2) + return u, v diff --git a/mirtorch/vendors/pytorch_wavelets/dtcwt/transform2d.py b/mirtorch/vendors/pytorch_wavelets/dtcwt/transform2d.py new file mode 100644 index 0000000..4b1fc32 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/dtcwt/transform2d.py @@ -0,0 +1,299 @@ +import torch +import torch.nn as nn +from numpy import ndarray, sqrt + +from ..dtcwt.coeffs import qshift as _qshift, biort as _biort, level1 +from ..dtcwt.lowlevel import prep_filt +from ..dtcwt.transform_funcs import FWD_J1, FWD_J2PLUS +from ..dtcwt.transform_funcs import INV_J1, INV_J2PLUS +from ..dtcwt.transform_funcs import get_dimensions6 +from ..dwt.lowlevel import mode_to_int +from ..dwt.transform2d import DWTForward, DWTInverse + + +def pm(a, b): + u = (a + b) / sqrt(2) + v = (a - b) / sqrt(2) + return u, v + + +class DTCWTForward(nn.Module): + """Performs a 2d DTCWT Forward decomposition of an image + + Args: + biort (str): One of 'antonini', 'legall', 'near_sym_a', 'near_sym_b'. + Specifies the first level biorthogonal wavelet filters. Can also + give a two tuple for the low and highpass filters directly. + qshift (str): One of 'qshift_06', 'qshift_a', 'qshift_b', 'qshift_c', + 'qshift_d'. Specifies the second level quarter shift filters. Can + also give a 4-tuple for the low tree a, low tree b, high tree a and + high tree b filters directly. + J (int): Number of levels of decomposition + skip_hps (bools): List of bools of length J which specify whether or + not to calculate the bandpass outputs at the given scale. + skip_hps[0] is for the first scale. Can be a single bool in which + case that is applied to all scales. + include_scale (bool): If true, return the bandpass outputs. Can also be + a list of length J specifying which lowpasses to return. I.e. if + [False, True, True], the forward call will return the second and + third lowpass outputs, but discard the lowpass from the first level + transform. + o_dim (int): Which dimension to put the orientations in + ri_dim (int): which dimension to put the real and imaginary parts + """ + + def __init__( + self, + biort="near_sym_a", + qshift="qshift_a", + J=3, + skip_hps=False, + include_scale=False, + o_dim=2, + ri_dim=-1, + mode="symmetric", + ): + super().__init__() + if o_dim == ri_dim: + raise ValueError( + "Orientations and real/imaginary parts must be " + "in different dimensions." + ) + + self.biort = biort + self.qshift = qshift + self.J = J + self.o_dim = o_dim + self.ri_dim = ri_dim + self.mode = mode + if isinstance(biort, str): + h0o, _, h1o, _ = _biort(biort) + self.register_buffer("h0o", prep_filt(h0o, 1)) + self.register_buffer("h1o", prep_filt(h1o, 1)) + else: + self.register_buffer("h0o", prep_filt(biort[0], 1)) + self.register_buffer("h1o", prep_filt(biort[1], 1)) + if isinstance(qshift, str): + h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift) + self.register_buffer("h0a", prep_filt(h0a, 1)) + self.register_buffer("h0b", prep_filt(h0b, 1)) + self.register_buffer("h1a", prep_filt(h1a, 1)) + self.register_buffer("h1b", prep_filt(h1b, 1)) + else: + self.register_buffer("h0a", prep_filt(qshift[0], 1)) + self.register_buffer("h0b", prep_filt(qshift[1], 1)) + self.register_buffer("h1a", prep_filt(qshift[2], 1)) + self.register_buffer("h1b", prep_filt(qshift[3], 1)) + + # Get the function to do the DTCWT + if isinstance(skip_hps, (list, tuple, ndarray)): + self.skip_hps = skip_hps + else: + self.skip_hps = [ + skip_hps, + ] * self.J + if isinstance(include_scale, (list, tuple, ndarray)): + self.include_scale = include_scale + else: + self.include_scale = [ + include_scale, + ] * self.J + + def forward(self, x): + """Forward Dual Tree Complex Wavelet Transform + + Args: + x (tensor): Input to transform. Should be of shape + :math:`(N, C_{in}, H_{in}, W_{in})`. + + Returns: + (yl, yh) + tuple of lowpass (yl) and bandpass (yh) coefficients. + If include_scale was true, yl will be a list of lowpass + coefficients, otherwise will be just the final lowpass + coefficient of shape :math:`(N, C_{in}, H_{in}', W_{in}')`. Yh + will be a list of the complex bandpass coefficients of shape + :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar + shape depending on o_dim and ri_dim + + Note: + :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a + DTCWT pyramid. + """ + scales = [ + x.new_zeros([]), + ] * self.J + highs = [ + x.new_zeros([]), + ] * self.J + mode = mode_to_int(self.mode) + if self.J == 0: + return x, None + + # If the row/col count of X is not divisible by 2 then we need to + # extend X + r, c = x.shape[2:] + if r % 2 != 0: + x = torch.cat((x, x[:, :, -1:]), dim=2) + if c % 2 != 0: + x = torch.cat((x, x[:, :, :, -1:]), dim=3) + + # Do the level 1 transform + low, h = FWD_J1.apply( + x, self.h0o, self.h1o, self.skip_hps[0], self.o_dim, self.ri_dim, mode + ) + highs[0] = h + if self.include_scale[0]: + scales[0] = low + + for j in range(1, self.J): + # Ensure the lowpass is divisible by 4 + r, c = low.shape[2:] + if r % 4 != 0: + low = torch.cat((low[:, :, 0:1], low, low[:, :, -1:]), dim=2) + if c % 4 != 0: + low = torch.cat((low[:, :, :, 0:1], low, low[:, :, :, -1:]), dim=3) + + low, h = FWD_J2PLUS.apply( + low, + self.h0a, + self.h1a, + self.h0b, + self.h1b, + self.skip_hps[j], + self.o_dim, + self.ri_dim, + mode, + ) + highs[j] = h + if self.include_scale[j]: + scales[j] = low + + if True in self.include_scale: + return scales, highs + else: + return low, highs + + +class DTCWTInverse(nn.Module): + """2d DTCWT Inverse + + Args: + biort (str): One of 'antonini', 'legall', 'near_sym_a', 'near_sym_b'. + Specifies the first level biorthogonal wavelet filters. Can also + give a two tuple for the low and highpass filters directly. + qshift (str): One of 'qshift_06', 'qshift_a', 'qshift_b', 'qshift_c', + 'qshift_d'. Specifies the second level quarter shift filters. Can + also give a 4-tuple for the low tree a, low tree b, high tree a and + high tree b filters directly. + J (int): Number of levels of decomposition. + o_dim (int):which dimension the orientations are in + ri_dim (int): which dimension to put th real and imaginary parts in + """ + + def __init__( + self, + biort="near_sym_a", + qshift="qshift_a", + o_dim=2, + ri_dim=-1, + mode="symmetric", + ): + super().__init__() + self.biort = biort + self.qshift = qshift + self.o_dim = o_dim + self.ri_dim = ri_dim + self.mode = mode + if isinstance(biort, str): + _, g0o, _, g1o = _biort(biort) + self.register_buffer("g0o", prep_filt(g0o, 1)) + self.register_buffer("g1o", prep_filt(g1o, 1)) + else: + self.register_buffer("g0o", prep_filt(biort[0], 1)) + self.register_buffer("g1o", prep_filt(biort[1], 1)) + if isinstance(qshift, str): + _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift) + self.register_buffer("g0a", prep_filt(g0a, 1)) + self.register_buffer("g0b", prep_filt(g0b, 1)) + self.register_buffer("g1a", prep_filt(g1a, 1)) + self.register_buffer("g1b", prep_filt(g1b, 1)) + else: + self.register_buffer("g0a", prep_filt(qshift[0], 1)) + self.register_buffer("g0b", prep_filt(qshift[1], 1)) + self.register_buffer("g1a", prep_filt(qshift[2], 1)) + self.register_buffer("g1b", prep_filt(qshift[3], 1)) + + def forward(self, coeffs): + """ + Args: + coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: + yl is a tensor of shape :math:`(N, C_{in}, H_{in}', W_{in}')` + and yh is a list of the complex bandpass coefficients of shape + :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar + depending on o_dim and ri_dim + + Returns: + Reconstructed output + + Note: + Can accept Nones or an empty tensor (torch.tensor([])) for the + lowpass or bandpass inputs. In this cases, an array of zeros + replaces that input. + + Note: + :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a + DTCWT pyramid. + + Note: + If include_scale was true for the forward pass, you should provide + only the final lowpass output here, as normal for an inverse wavelet + transform. + """ + low, highs = coeffs + J = len(highs) + mode = mode_to_int(self.mode) + _, _, h_dim, w_dim = get_dimensions6(self.o_dim, self.ri_dim) + for j, s in zip(range(J - 1, 0, -1), highs[1:][::-1]): + if s is not None and s.shape != torch.Size([]): + assert s.shape[self.o_dim] == 6, ( + "Inverse transform must " "have input with 6 orientations" + ) + assert len(s.shape) == 6, "Bandpass inputs must have " "6 dimensions" + assert s.shape[self.ri_dim] == 2, ( + "Inputs must be complex " + "with real and imaginary parts in the ri dimension" + ) + # Ensure the low and highpass are the right size + r, c = low.shape[2:] + r1, c1 = s.shape[h_dim], s.shape[w_dim] + if r != r1 * 2: + low = low[:, :, 1:-1] + if c != c1 * 2: + low = low[:, :, :, 1:-1] + + low = INV_J2PLUS.apply( + low, + s, + self.g0a, + self.g1a, + self.g0b, + self.g1b, + self.o_dim, + self.ri_dim, + mode, + ) + + # Ensure the low and highpass are the right size + if highs[0] is not None and highs[0].shape != torch.Size([]): + r, c = low.shape[2:] + r1, c1 = highs[0].shape[h_dim], highs[0].shape[w_dim] + if r != r1 * 2: + low = low[:, :, 1:-1] + if c != c1 * 2: + low = low[:, :, :, 1:-1] + + low = INV_J1.apply( + low, highs[0], self.g0o, self.g1o, self.o_dim, self.ri_dim, mode + ) + return low diff --git a/mirtorch/vendors/pytorch_wavelets/dtcwt/transform_funcs.py b/mirtorch/vendors/pytorch_wavelets/dtcwt/transform_funcs.py new file mode 100644 index 0000000..1d49eb8 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/dtcwt/transform_funcs.py @@ -0,0 +1,495 @@ +import torch +from torch import tensor +from torch.autograd import Function +from ..dtcwt.lowlevel import colfilter, rowfilter +from ..dtcwt.lowlevel import coldfilt, rowdfilt +from ..dtcwt.lowlevel import colifilt, rowifilt, q2c, c2q +from ..dwt.lowlevel import int_to_mode + + +def get_dimensions5(o_dim, ri_dim): + """Get the orientation, height and width dimensions after the real and + imaginary parts have been popped off (5 dimensional tensor).""" + o_dim = o_dim % 6 + ri_dim = ri_dim % 6 + + if ri_dim < o_dim: + o_dim -= 1 + + if o_dim == 4: + h_dim = 2 + w_dim = 3 + elif o_dim == 3: + h_dim = 2 + w_dim = 4 + else: + h_dim = 3 + w_dim = 4 + + return o_dim, ri_dim, h_dim, w_dim + + +def get_dimensions6(o_dim, ri_dim): + """Get the orientation, real/imag, height and width dimensions + for the full tensor (6 dimensions).""" + # Calculate which dimension to put the real and imaginary parts and the + # orientations. Also work out where the rows and columns in the original + # image were + o_dim = o_dim % 6 + ri_dim = ri_dim % 6 + + if ri_dim < o_dim: + o_dim -= 1 + + if o_dim >= 3 and ri_dim >= 3: + h_dim = 2 + elif o_dim >= 4 or ri_dim >= 4: + h_dim = 3 + else: + h_dim = 4 + + if o_dim >= 4 and ri_dim >= 4: + w_dim = 3 + elif o_dim >= 4 or ri_dim >= 4: + w_dim = 4 + else: + w_dim = 5 + + return o_dim, ri_dim, h_dim, w_dim + + +def highs_to_orientations(lh, hl, hh, o_dim): + (deg15r, deg15i), (deg165r, deg165i) = q2c(lh) + (deg45r, deg45i), (deg135r, deg135i) = q2c(hh) + (deg75r, deg75i), (deg105r, deg105i) = q2c(hl) + + # Convert real and imaginary to magnitude + reals = torch.stack([deg15r, deg45r, deg75r, deg105r, deg135r, deg165r], dim=o_dim) + imags = torch.stack([deg15i, deg45i, deg75i, deg105i, deg135i, deg165i], dim=o_dim) + + return reals, imags + + +def orientations_to_highs(reals, imags, o_dim): + dev = reals.device + horiz = torch.index_select(reals, o_dim, tensor([0, 5], device=dev)) + diag = torch.index_select(reals, o_dim, tensor([1, 4], device=dev)) + vertic = torch.index_select(reals, o_dim, tensor([2, 3], device=dev)) + deg15r, deg165r = torch.unbind(horiz, dim=o_dim) + deg45r, deg135r = torch.unbind(diag, dim=o_dim) + deg75r, deg105r = torch.unbind(vertic, dim=o_dim) + dev = imags.device + horiz = torch.index_select(imags, o_dim, tensor([0, 5], device=dev)) + diag = torch.index_select(imags, o_dim, tensor([1, 4], device=dev)) + vertic = torch.index_select(imags, o_dim, tensor([2, 3], device=dev)) + deg15i, deg165i = torch.unbind(horiz, dim=o_dim) + deg45i, deg135i = torch.unbind(diag, dim=o_dim) + deg75i, deg105i = torch.unbind(vertic, dim=o_dim) + + lh = c2q((deg15r, deg15i), (deg165r, deg165i)) + hl = c2q((deg75r, deg75i), (deg105r, deg105i)) + hh = c2q((deg45r, deg45i), (deg135r, deg135i)) + + return lh, hl, hh + + +def fwd_j1(x, h0, h1, skip_hps, o_dim, mode): + """Level 1 forward dtcwt. + + Have it as a separate function as can be used by + the forward pass of the forward transform and the backward pass of the + inverse transform. + """ + # Level 1 forward (biorthogonal analysis filters) + if not skip_hps: + lo = rowfilter(x, h0, mode) + hi = rowfilter(x, h1, mode) + ll = colfilter(lo, h0, mode) + lh = colfilter(lo, h1, mode) + del lo + hl = colfilter(hi, h0, mode) + hh = colfilter(hi, h1, mode) + del hi + highr, highi = highs_to_orientations(lh, hl, hh, o_dim) + else: + ll = rowfilter(x, h0, mode) + ll = colfilter(ll, h0, mode) + highr = x.new_zeros([]) + highi = x.new_zeros([]) + return ll, highr, highi + + +def fwd_j1_rot(x, h0, h1, h2, skip_hps, o_dim, mode): + """Level 1 forward dtcwt. + + Have it as a separate function as can be used by + the forward pass of the forward transform and the backward pass of the + inverse transform. + """ + # Level 1 forward (biorthogonal analysis filters) + if not skip_hps: + lo = rowfilter(x, h0, mode) + hi = rowfilter(x, h1, mode) + ba = rowfilter(x, h2, mode) + + lh = colfilter(lo, h1, mode) + hl = colfilter(hi, h0, mode) + hh = colfilter(ba, h2, mode) + ll = colfilter(lo, h0, mode) + + del lo, hi, ba + highr, highi = highs_to_orientations(lh, hl, hh, o_dim) + else: + ll = rowfilter(x, h0, mode) + ll = colfilter(ll, h0, mode) + highr = x.new_zeros([]) + highi = x.new_zeros([]) + return ll, highr, highi + + +def inv_j1(ll, highr, highi, g0, g1, o_dim, h_dim, w_dim, mode): + """Level1 inverse dtcwt. + + Have it as a separate function as can be used by the forward pass of the + inverse transform and the backward pass of the forward transform. + """ + if highr is None or highr.shape == torch.Size([]): + y = rowfilter(colfilter(ll, g0), g0) + else: + # Get the double sampled bandpass coefficients + lh, hl, hh = orientations_to_highs(highr, highi, o_dim) + + if ll is None or ll.shape == torch.Size([]): + # Interpolate + hi = colfilter(hh, g1, mode) + colfilter(hl, g0, mode) + lo = colfilter(lh, g1, mode) + del lh, hh, hl + else: + # Possibly cut back some rows to make the ll match the highs + r, c = ll.shape[2:] + r1, c1 = highr.shape[h_dim], highr.shape[w_dim] + if r != r1 * 2: + ll = ll[:, :, 1:-1] + if c != c1 * 2: + ll = ll[:, :, :, 1:-1] + # Interpolate + hi = colfilter(hh, g1, mode) + colfilter(hl, g0, mode) + lo = colfilter(lh, g1, mode) + colfilter(ll, g0, mode) + del lh, hl, hh + + y = rowfilter(hi, g1, mode) + rowfilter(lo, g0, mode) + + return y + + +def inv_j1_rot(ll, highr, highi, g0, g1, g2, o_dim, h_dim, w_dim, mode): + """Level1 inverse dtcwt. + + Have it as a separate function as can be used by the forward pass of the + inverse transform and the backward pass of the forward transform. + """ + if highr is None or highr.shape == torch.Size([]): + y = rowfilter(colfilter(ll, g0), g0) + else: + # Get the double sampled bandpass coefficients + lh, hl, hh = orientations_to_highs(highr, highi, o_dim) + + if ll is None or ll.shape == torch.Size([]): + # Interpolate + lo = colfilter(lh, g1, mode) + hi = colfilter(hl, g0, mode) + ba = colfilter(hh, g2, mode) + del lh, hh, hl + else: + # Possibly cut back some rows to make the ll match the highs + r, c = ll.shape[2:] + r1, c1 = highr.shape[h_dim], highr.shape[w_dim] + if r != r1 * 2: + ll = ll[:, :, 1:-1] + if c != c1 * 2: + ll = ll[:, :, :, 1:-1] + + # Interpolate + lo = colfilter(lh, g1, mode) + colfilter(ll, g0, mode) + hi = colfilter(hl, g0, mode) + ba = colfilter(hh, g2, mode) + del lh, hl, hh + + y = rowfilter(hi, g1, mode) + rowfilter(lo, g0, mode) + rowfilter(ba, g2, mode) + + return y + + +def fwd_j2plus(x, h0a, h1a, h0b, h1b, skip_hps, o_dim, mode): + """Level 2 plus forward dtcwt. + + Have it as a separate function as can be used by + the forward pass of the forward transform and the backward pass of the + inverse transform. + """ + if not skip_hps: + lo = rowdfilt(x, h0b, h0a, False, mode) + hi = rowdfilt(x, h1b, h1a, True, mode) + + ll = coldfilt(lo, h0b, h0a, False, mode) + lh = coldfilt(lo, h1b, h1a, True, mode) + hl = coldfilt(hi, h0b, h0a, False, mode) + hh = coldfilt(hi, h1b, h1a, True, mode) + del lo, hi + highr, highi = highs_to_orientations(lh, hl, hh, o_dim) + else: + ll = rowdfilt(x, h0b, h0a, False, mode) + ll = coldfilt(ll, h0b, h0a, False, mode) + highr = None + highi = None + + return ll, highr, highi + + +def fwd_j2plus_rot(x, h0a, h1a, h0b, h1b, h2a, h2b, skip_hps, o_dim, mode): + """Level 2 plus forward dtcwt. + + Have it as a separate function as can be used by + the forward pass of the forward transform and the backward pass of the + inverse transform. + """ + if not skip_hps: + lo = rowdfilt(x, h0b, h0a, False, mode) + hi = rowdfilt(x, h1b, h1a, True, mode) + ba = rowdfilt(x, h2b, h2a, True, mode) + + lh = coldfilt(lo, h1b, h1a, True, mode) + hl = coldfilt(hi, h0b, h0a, False, mode) + hh = coldfilt(ba, h2b, h2a, True, mode) + ll = coldfilt(lo, h0b, h0a, False, mode) + del lo, hi, ba + highr, highi = highs_to_orientations(lh, hl, hh, o_dim) + else: + ll = rowdfilt(x, h0b, h0a, False, mode) + ll = coldfilt(ll, h0b, h0a, False, mode) + highr = None + highi = None + + return ll, highr, highi + + +def inv_j2plus(ll, highr, highi, g0a, g1a, g0b, g1b, o_dim, h_dim, w_dim, mode): + """Level2+ inverse dtcwt. + + Have it as a separate function as can be used by the forward pass of the + inverse transform and the backward pass of the forward transform. + """ + if highr is None or highr.shape == torch.Size([]): + y = rowifilt(colifilt(ll, g0b, g0a, False, mode), g0b, g0a, False, mode) + else: + # Get the double sampled bandpass coefficients + lh, hl, hh = orientations_to_highs(highr, highi, o_dim) + + if ll is None or ll.shape == torch.Size([]): + # Interpolate + hi = colifilt(hh, g1b, g1a, True, mode) + colifilt( + hl, g0b, g0a, False, mode + ) + lo = colifilt(lh, g1b, g1a, True, mode) + del lh, hh, hl + else: + # Interpolate + hi = colifilt(hh, g1b, g1a, True, mode) + colifilt( + hl, g0b, g0a, False, mode + ) + lo = colifilt(lh, g1b, g1a, True, mode) + colifilt( + ll, g0b, g0a, False, mode + ) + del lh, hl, hh + + y = rowifilt(hi, g1b, g1a, True, mode) + rowifilt(lo, g0b, g0a, False, mode) + return y + + +def inv_j2plus_rot( + ll, highr, highi, g0a, g1a, g0b, g1b, g2a, g2b, o_dim, h_dim, w_dim, mode +): + """Level2+ inverse dtcwt. + + Have it as a separate function as can be used by the forward pass of the + inverse transform and the backward pass of the forward transform. + """ + if highr is None or highr.shape == torch.Size([]): + y = rowifilt(colifilt(ll, g0b, g0a, False, mode), g0b, g0a, False, mode) + else: + # Get the double sampled bandpass coefficients + lh, hl, hh = orientations_to_highs(highr, highi, o_dim) + + if ll is None or ll.shape == torch.Size([]): + # Interpolate + lo = colifilt(lh, g1b, g1a, True, mode) + hi = colifilt(hl, g0b, g0a, False, mode) + ba = colifilt(hh, g2b, g2a, True, mode) + del lh, hh, hl + else: + # Interpolate + lo = colifilt(lh, g1b, g1a, True, mode) + colifilt( + ll, g0b, g0a, False, mode + ) + hi = colifilt(hl, g0b, g0a, False, mode) + ba = colifilt(hh, g2b, g2a, True, mode) + del lh, hl, hh + + y = ( + rowifilt(hi, g1b, g1a, True, mode) + + rowifilt(lo, g0b, g0a, False, mode) + + rowifilt(ba, g2b, g2a, True, mode) + ) + return y + + +class FWD_J1(Function): + """Differentiable function doing 1 level forward DTCWT""" + + @staticmethod + def forward(ctx, x, h0, h1, skip_hps, o_dim, ri_dim, mode): + mode = int_to_mode(mode) + ctx.mode = mode + ctx.save_for_backward(h0, h1) + ctx.dims = get_dimensions5(o_dim, ri_dim) + o_dim, ri_dim = ctx.dims[0], ctx.dims[1] + + ll, highr, highi = fwd_j1(x, h0, h1, skip_hps, o_dim, mode) + if not skip_hps: + highs = torch.stack((highr, highi), dim=ri_dim) + else: + highs = ll.new_zeros([]) + return ll, highs + + @staticmethod + def backward(ctx, dl, dh): + h0, h1 = ctx.saved_tensors + mode = ctx.mode + dx = None + if ctx.needs_input_grad[0]: + o_dim, ri_dim, h_dim, w_dim = ctx.dims + if dh is not None and dh.shape != torch.Size([]): + dhr, dhi = torch.unbind(dh, dim=ri_dim) + else: + dhr = dl.new_zeros([]) + dhi = dl.new_zeros([]) + dx = inv_j1(dl, dhr, dhi, h0, h1, o_dim, h_dim, w_dim, mode) + + return dx, None, None, None, None, None, None + + +class FWD_J2PLUS(Function): + """Differentiable function doing second level forward DTCWT""" + + @staticmethod + def forward(ctx, x, h0a, h1a, h0b, h1b, skip_hps, o_dim, ri_dim, mode): + mode = "symmetric" + ctx.mode = mode + ctx.save_for_backward(h0a, h1a, h0b, h1b) + ctx.dims = get_dimensions5(o_dim, ri_dim) + o_dim, ri_dim = ctx.dims[0], ctx.dims[1] + + ll, highr, highi = fwd_j2plus(x, h0a, h1a, h0b, h1b, skip_hps, o_dim, mode) + if not skip_hps: + highs = torch.stack((highr, highi), dim=ri_dim) + else: + highs = ll.new_zeros([]) + return ll, highs + + @staticmethod + def backward(ctx, dl, dh): + h0a, h1a, h0b, h1b = ctx.saved_tensors + mode = ctx.mode + # The colifilt and rowifilt functions use conv2d not conv2d_transpose, + # so need to reverse the filters + h0a, h0b = h0b, h0a + h1a, h1b = h1b, h1a + dx = None + if ctx.needs_input_grad[0]: + o_dim, ri_dim, h_dim, w_dim = ctx.dims + if dh is not None and dh.shape != torch.Size([]): + dhr, dhi = torch.unbind(dh, dim=ri_dim) + else: + dhr = dl.new_zeros([]) + dhi = dl.new_zeros([]) + dx = inv_j2plus(dl, dhr, dhi, h0a, h1a, h0b, h1b, o_dim, h_dim, w_dim, mode) + + return dx, None, None, None, None, None, None, None, None + + +class INV_J1(Function): + """Differentiable function doing 1 level inverse DTCWT""" + + @staticmethod + def forward(ctx, lows, highs, g0, g1, o_dim, ri_dim, mode): + mode = int_to_mode(mode) + ctx.mode = mode + ctx.save_for_backward(g0, g1) + ctx.dims = get_dimensions5(o_dim, ri_dim) + o_dim, ri_dim, h_dim, w_dim = ctx.dims + if highs is not None and highs.shape != torch.Size([]): + highr, highi = torch.unbind(highs, dim=ri_dim) + else: + highr = lows.new_zeros([]) + highi = lows.new_zeros([]) + y = inv_j1(lows, highr, highi, g0, g1, o_dim, h_dim, w_dim, mode) + return y + + @staticmethod + def backward(ctx, dy): + g0, g1 = ctx.saved_tensors + dl = None + dh = None + o_dim, ri_dim = ctx.dims[0], ctx.dims[1] + mode = ctx.mode + if ctx.needs_input_grad[0] and not ctx.needs_input_grad[1]: + dl, _, _ = fwd_j1(dy, g0, g1, True, o_dim, mode) + elif ctx.needs_input_grad[1] and not ctx.needs_input_grad[0]: + _, dhr, dhi = fwd_j1(dy, g0, g1, False, o_dim, mode) + dh = torch.stack((dhr, dhi), dim=ri_dim) + elif ctx.needs_input_grad[0] and ctx.needs_input_grad[1]: + dl, dhr, dhi = fwd_j1(dy, g0, g1, False, o_dim, mode) + dh = torch.stack((dhr, dhi), dim=ri_dim) + + return dl, dh, None, None, None, None, None + + +class INV_J2PLUS(Function): + """Differentiable function doing level 2 onwards inverse DTCWT""" + + @staticmethod + def forward(ctx, lows, highs, g0a, g1a, g0b, g1b, o_dim, ri_dim, mode): + mode = "symmetric" + ctx.mode = mode + ctx.save_for_backward(g0a, g1a, g0b, g1b) + ctx.dims = get_dimensions5(o_dim, ri_dim) + o_dim, ri_dim, h_dim, w_dim = ctx.dims + if highs is not None and highs.shape != torch.Size([]): + highr, highi = torch.unbind(highs, dim=ri_dim) + else: + highr = lows.new_zeros([]) + highi = lows.new_zeros([]) + y = inv_j2plus( + lows, highr, highi, g0a, g1a, g0b, g1b, o_dim, h_dim, w_dim, mode + ) + return y + + @staticmethod + def backward(ctx, dy): + g0a, g1a, g0b, g1b = ctx.saved_tensors + g0a, g0b = g0b, g0a + g1a, g1b = g1b, g1a + o_dim, ri_dim = ctx.dims[0], ctx.dims[1] + mode = ctx.mode + dl = None + dh = None + if ctx.needs_input_grad[0] and not ctx.needs_input_grad[1]: + dl, _, _ = fwd_j2plus(dy, g0a, g1a, g0b, g1b, True, o_dim, mode) + elif ctx.needs_input_grad[1] and not ctx.needs_input_grad[0]: + _, dhr, dhi = fwd_j2plus(dy, g0a, g1a, g0b, g1b, False, o_dim, mode) + dh = torch.stack((dhr, dhi), dim=ri_dim) + elif ctx.needs_input_grad[0] and ctx.needs_input_grad[1]: + dl, dhr, dhi = fwd_j2plus(dy, g0a, g1a, g0b, g1b, False, o_dim, mode) + dh = torch.stack((dhr, dhi), dim=ri_dim) + + return dl, dh, None, None, None, None, None, None, None diff --git a/mirtorch/vendors/pytorch_wavelets/dwt/__init__.py b/mirtorch/vendors/pytorch_wavelets/dwt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mirtorch/vendors/pytorch_wavelets/dwt/lowlevel.py b/mirtorch/vendors/pytorch_wavelets/dwt/lowlevel.py new file mode 100644 index 0000000..bfd2264 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/dwt/lowlevel.py @@ -0,0 +1,997 @@ +import torch +import torch.nn.functional as F +import numpy as np +from torch.autograd import Function +from ..utils import reflect +import pywt + + +def roll(x, n, dim, make_even=False): + if n < 0: + n = x.shape[dim] + n + + if make_even and x.shape[dim] % 2 == 1: + end = 1 + else: + end = 0 + + if dim == 0: + return torch.cat((x[-n:], x[: -n + end]), dim=0) + elif dim == 1: + return torch.cat((x[:, -n:], x[:, : -n + end]), dim=1) + elif dim == 2 or dim == -2: + return torch.cat((x[:, :, -n:], x[:, :, : -n + end]), dim=2) + elif dim == 3 or dim == -1: + return torch.cat((x[:, :, :, -n:], x[:, :, :, : -n + end]), dim=3) + + +def mypad(x, pad, mode="constant", value=0): + """Function to do numpy like padding on tensors. Only works for 2-D + padding. + + Inputs: + x (tensor): tensor to pad + pad (tuple): tuple of (left, right, top, bottom) pad sizes + mode (str): 'symmetric', 'wrap', 'constant, 'reflect', 'replicate', or + 'zero'. The padding technique. + """ + if mode == "symmetric": + # Vertical only + if pad[0] == 0 and pad[1] == 0: + m1, m2 = pad[2], pad[3] + l = x.shape[-2] + xe = reflect(np.arange(-m1, l + m2, dtype="int32"), -0.5, l - 0.5) + return x[:, :, xe] + # horizontal only + elif pad[2] == 0 and pad[3] == 0: + m1, m2 = pad[0], pad[1] + l = x.shape[-1] + xe = reflect(np.arange(-m1, l + m2, dtype="int32"), -0.5, l - 0.5) + return x[:, :, :, xe] + # Both + else: + m1, m2 = pad[0], pad[1] + l1 = x.shape[-1] + xe_row = reflect(np.arange(-m1, l1 + m2, dtype="int32"), -0.5, l1 - 0.5) + m1, m2 = pad[2], pad[3] + l2 = x.shape[-2] + xe_col = reflect(np.arange(-m1, l2 + m2, dtype="int32"), -0.5, l2 - 0.5) + i = np.outer(xe_col, np.ones(xe_row.shape[0])) + j = np.outer(np.ones(xe_col.shape[0]), xe_row) + return x[:, :, i, j] + elif mode == "periodic": + # Vertical only + if pad[0] == 0 and pad[1] == 0: + xe = np.arange(x.shape[-2]) + xe = np.pad(xe, (pad[2], pad[3]), mode="wrap") + return x[:, :, xe] + # Horizontal only + elif pad[2] == 0 and pad[3] == 0: + xe = np.arange(x.shape[-1]) + xe = np.pad(xe, (pad[0], pad[1]), mode="wrap") + return x[:, :, :, xe] + # Both + else: + xe_col = np.arange(x.shape[-2]) + xe_col = np.pad(xe_col, (pad[2], pad[3]), mode="wrap") + xe_row = np.arange(x.shape[-1]) + xe_row = np.pad(xe_row, (pad[0], pad[1]), mode="wrap") + i = np.outer(xe_col, np.ones(xe_row.shape[0])) + j = np.outer(np.ones(xe_col.shape[0]), xe_row) + return x[:, :, i, j] + + elif mode == "constant" or mode == "reflect" or mode == "replicate": + return F.pad(x, pad, mode, value) + elif mode == "zero": + return F.pad(x, pad) + else: + raise ValueError("Unkown pad type: {}".format(mode)) + + +def afb1d(x, h0, h1, mode="zero", dim=-1): + """1D analysis filter bank (along one dimension only) of an image + + Inputs: + x (tensor): 4D input with the last two dimensions the spatial input + h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1, + h, 1) or (1, 1, 1, w) + h1 (tensor): 4D input for the highpass filter. Should have shape (1, 1, + h, 1) or (1, 1, 1, w) + mode (str): padding method + dim (int) - dimension of filtering. d=2 is for a vertical filter (called + column filtering but filters across the rows). d=3 is for a + horizontal filter, (called row filtering but filters across the + columns). + + Returns: + lohi: lowpass and highpass subbands concatenated along the channel + dimension + """ + C = x.shape[1] + # Convert the dim to positive + d = dim % 4 + s = (2, 1) if d == 2 else (1, 2) + N = x.shape[d] + # If h0, h1 are not tensors, make them. If they are, then assume that they + # are in the right order + if not isinstance(h0, torch.Tensor): + h0 = torch.tensor( + np.copy(np.array(h0).ravel()[::-1]), dtype=torch.float, device=x.device + ) + if not isinstance(h1, torch.Tensor): + h1 = torch.tensor( + np.copy(np.array(h1).ravel()[::-1]), dtype=torch.float, device=x.device + ) + L = h0.numel() + L2 = L // 2 + shape = [1, 1, 1, 1] + shape[d] = L + # If h aren't in the right shape, make them so + if h0.shape != tuple(shape): + h0 = h0.reshape(*shape) + if h1.shape != tuple(shape): + h1 = h1.reshape(*shape) + h = torch.cat([h0, h1] * C, dim=0) + + if mode == "per" or mode == "periodization": + if x.shape[dim] % 2 == 1: + if d == 2: + x = torch.cat((x, x[:, :, -1:]), dim=2) + else: + x = torch.cat((x, x[:, :, :, -1:]), dim=3) + N += 1 + x = roll(x, -L2, dim=d) + pad = (L - 1, 0) if d == 2 else (0, L - 1) + lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C) + N2 = N // 2 + if d == 2: + lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2 : N2 + L2] + lohi = lohi[:, :, :N2] + else: + lohi[:, :, :, :L2] = lohi[:, :, :, :L2] + lohi[:, :, :, N2 : N2 + L2] + lohi = lohi[:, :, :, :N2] + else: + # Calculate the pad size + outsize = pywt.dwt_coeff_len(N, L, mode=mode) + p = 2 * (outsize - 1) - N + L + if mode == "zero": + # Sadly, pytorch only allows for same padding before and after, if + # we need to do more padding after for odd length signals, have to + # prepad + if p % 2 == 1: + pad = (0, 0, 0, 1) if d == 2 else (0, 1, 0, 0) + x = F.pad(x, pad) + pad = (p // 2, 0) if d == 2 else (0, p // 2) + # Calculate the high and lowpass + lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C) + elif mode == "symmetric" or mode == "reflect" or mode == "periodic": + pad = ( + (0, 0, p // 2, (p + 1) // 2) if d == 2 else (p // 2, (p + 1) // 2, 0, 0) + ) + x = mypad(x, pad=pad, mode=mode) + lohi = F.conv2d(x, h, stride=s, groups=C) + else: + raise ValueError("Unkown pad type: {}".format(mode)) + + return lohi + + +def afb1d_atrous(x, h0, h1, mode="periodic", dim=-1, dilation=1): + """1D analysis filter bank (along one dimension only) of an image without + downsampling. Does the a trous algorithm. + + Inputs: + x (tensor): 4D input with the last two dimensions the spatial input + h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1, + h, 1) or (1, 1, 1, w) + h1 (tensor): 4D input for the highpass filter. Should have shape (1, 1, + h, 1) or (1, 1, 1, w) + mode (str): padding method + dim (int) - dimension of filtering. d=2 is for a vertical filter (called + column filtering but filters across the rows). d=3 is for a + horizontal filter, (called row filtering but filters across the + columns). + dilation (int): dilation factor. Should be a power of 2. + + Returns: + lohi: lowpass and highpass subbands concatenated along the channel + dimension + """ + C = x.shape[1] + # Convert the dim to positive + d = dim % 4 + # If h0, h1 are not tensors, make them. If they are, then assume that they + # are in the right order + if not isinstance(h0, torch.Tensor): + h0 = torch.tensor( + np.copy(np.array(h0).ravel()[::-1]), dtype=torch.float, device=x.device + ) + if not isinstance(h1, torch.Tensor): + h1 = torch.tensor( + np.copy(np.array(h1).ravel()[::-1]), dtype=torch.float, device=x.device + ) + L = h0.numel() + shape = [1, 1, 1, 1] + shape[d] = L + # If h aren't in the right shape, make them so + if h0.shape != tuple(shape): + h0 = h0.reshape(*shape) + if h1.shape != tuple(shape): + h1 = h1.reshape(*shape) + h = torch.cat([h0, h1] * C, dim=0) + + # Calculate the pad size + L2 = (L * dilation) // 2 + pad = (0, 0, L2 - dilation, L2) if d == 2 else (L2 - dilation, L2, 0, 0) + x = mypad(x, pad=pad, mode=mode) + lohi = F.conv2d(x, h, groups=C, dilation=dilation) + + return lohi + + +def sfb1d(lo, hi, g0, g1, mode="zero", dim=-1): + """1D synthesis filter bank of an image tensor""" + C = lo.shape[1] + d = dim % 4 + # If g0, g1 are not tensors, make them. If they are, then assume that they + # are in the right order + if not isinstance(g0, torch.Tensor): + g0 = torch.tensor( + np.copy(np.array(g0).ravel()), dtype=torch.float, device=lo.device + ) + if not isinstance(g1, torch.Tensor): + g1 = torch.tensor( + np.copy(np.array(g1).ravel()), dtype=torch.float, device=lo.device + ) + L = g0.numel() + shape = [1, 1, 1, 1] + shape[d] = L + N = 2 * lo.shape[d] + # If g aren't in the right shape, make them so + if g0.shape != tuple(shape): + g0 = g0.reshape(*shape) + if g1.shape != tuple(shape): + g1 = g1.reshape(*shape) + + s = (2, 1) if d == 2 else (1, 2) + g0 = torch.cat([g0] * C, dim=0) + g1 = torch.cat([g1] * C, dim=0) + if mode == "per" or mode == "periodization": + y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + F.conv_transpose2d( + hi, g1, stride=s, groups=C + ) + if d == 2: + y[:, :, : L - 2] = y[:, :, : L - 2] + y[:, :, N : N + L - 2] + y = y[:, :, :N] + else: + y[:, :, :, : L - 2] = y[:, :, :, : L - 2] + y[:, :, :, N : N + L - 2] + y = y[:, :, :, :N] + y = roll(y, 1 - L // 2, dim=dim) + else: + if ( + mode == "zero" + or mode == "symmetric" + or mode == "reflect" + or mode == "periodic" + ): + pad = (L - 2, 0) if d == 2 else (0, L - 2) + y = F.conv_transpose2d( + lo, g0, stride=s, padding=pad, groups=C + ) + F.conv_transpose2d(hi, g1, stride=s, padding=pad, groups=C) + else: + raise ValueError("Unkown pad type: {}".format(mode)) + + return y + + +def mode_to_int(mode): + if mode == "zero": + return 0 + elif mode == "symmetric": + return 1 + elif mode == "per" or mode == "periodization": + return 2 + elif mode == "constant": + return 3 + elif mode == "reflect": + return 4 + elif mode == "replicate": + return 5 + elif mode == "periodic": + return 6 + else: + raise ValueError("Unkown pad type: {}".format(mode)) + + +def int_to_mode(mode): + if mode == 0: + return "zero" + elif mode == 1: + return "symmetric" + elif mode == 2: + return "periodization" + elif mode == 3: + return "constant" + elif mode == 4: + return "reflect" + elif mode == 5: + return "replicate" + elif mode == 6: + return "periodic" + else: + raise ValueError("Unkown pad type: {}".format(mode)) + + +class AFB2D(Function): + """Does a single level 2d wavelet decomposition of an input. Does separate + row and column filtering by two calls to + :py:func:`pytorch_wavelets.dwt.lowlevel.afb1d` + + Needs to have the tensors in the right form. Because this function defines + its own backward pass, saves on memory by not having to save the input + tensors. + + Inputs: + x (torch.Tensor): Input to decompose + h0_row: row lowpass + h1_row: row highpass + h0_col: col lowpass + h1_col: col highpass + mode (int): use mode_to_int to get the int code here + + We encode the mode as an integer rather than a string as gradcheck causes an + error when a string is provided. + + Returns: + y: Tensor of shape (N, C*4, H, W) + """ + + @staticmethod + def forward(ctx, x, h0_row, h1_row, h0_col, h1_col, mode): + ctx.save_for_backward(h0_row, h1_row, h0_col, h1_col) + ctx.shape = x.shape[-2:] + mode = int_to_mode(mode) + ctx.mode = mode + lohi = afb1d(x, h0_row, h1_row, mode=mode, dim=3) + y = afb1d(lohi, h0_col, h1_col, mode=mode, dim=2) + s = y.shape + y = y.reshape(s[0], -1, 4, s[-2], s[-1]) + low = y[:, :, 0].contiguous() + highs = y[:, :, 1:].contiguous() + return low, highs + + @staticmethod + def backward(ctx, low, highs): + dx = None + if ctx.needs_input_grad[0]: + mode = ctx.mode + h0_row, h1_row, h0_col, h1_col = ctx.saved_tensors + lh, hl, hh = torch.unbind(highs, dim=2) + lo = sfb1d(low, lh, h0_col, h1_col, mode=mode, dim=2) + hi = sfb1d(hl, hh, h0_col, h1_col, mode=mode, dim=2) + dx = sfb1d(lo, hi, h0_row, h1_row, mode=mode, dim=3) + if dx.shape[-2] > ctx.shape[-2] and dx.shape[-1] > ctx.shape[-1]: + dx = dx[:, :, : ctx.shape[-2], : ctx.shape[-1]] + elif dx.shape[-2] > ctx.shape[-2]: + dx = dx[:, :, : ctx.shape[-2]] + elif dx.shape[-1] > ctx.shape[-1]: + dx = dx[:, :, :, : ctx.shape[-1]] + return dx, None, None, None, None, None + + +class AFB1D(Function): + """Does a single level 1d wavelet decomposition of an input. + + Needs to have the tensors in the right form. Because this function defines + its own backward pass, saves on memory by not having to save the input + tensors. + + Inputs: + x (torch.Tensor): Input to decompose + h0: lowpass + h1: highpass + mode (int): use mode_to_int to get the int code here + + We encode the mode as an integer rather than a string as gradcheck causes an + error when a string is provided. + + Returns: + x0: Tensor of shape (N, C, L') - lowpass + x1: Tensor of shape (N, C, L') - highpass + """ + + @staticmethod + def forward(ctx, x, h0, h1, mode): + mode = int_to_mode(mode) + + # Make inputs 4d + x = x[:, :, None, :] + h0 = h0[:, :, None, :] + h1 = h1[:, :, None, :] + + # Save for backwards + ctx.save_for_backward(h0, h1) + ctx.shape = x.shape[3] + ctx.mode = mode + + lohi = afb1d(x, h0, h1, mode=mode, dim=3) + x0 = lohi[:, ::2, 0].contiguous() + x1 = lohi[:, 1::2, 0].contiguous() + return x0, x1 + + @staticmethod + def backward(ctx, dx0, dx1): + dx = None + if ctx.needs_input_grad[0]: + mode = ctx.mode + h0, h1 = ctx.saved_tensors + + # Make grads 4d + dx0 = dx0[:, :, None, :] + dx1 = dx1[:, :, None, :] + + dx = sfb1d(dx0, dx1, h0, h1, mode=mode, dim=3)[:, :, 0] + + # Check for odd input + if dx.shape[2] > ctx.shape: + dx = dx[:, :, : ctx.shape] + + return dx, None, None, None, None, None + + +def afb2d(x, filts, mode="zero"): + """Does a single level 2d wavelet decomposition of an input. Does separate + row and column filtering by two calls to + :py:func:`pytorch_wavelets.dwt.lowlevel.afb1d` + + Inputs: + x (torch.Tensor): Input to decompose + filts (list of ndarray or torch.Tensor): If a list of tensors has been + given, this function assumes they are in the right form (the form + returned by + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`). + Otherwise, this function will prepare the filters to be of the right + form by calling + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`. + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which + padding to use. If periodization, the output size will be half the + input size. Otherwise, the output size will be slightly larger than + half. + + Returns: + y: Tensor of shape (N, C*4, H, W) + """ + tensorize = [not isinstance(f, torch.Tensor) for f in filts] + if len(filts) == 2: + h0, h1 = filts + if True in tensorize: + h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(h0, h1, device=x.device) + else: + h0_col = h0 + h0_row = h0.transpose(2, 3) + h1_col = h1 + h1_row = h1.transpose(2, 3) + elif len(filts) == 4: + if True in tensorize: + h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(*filts, device=x.device) + else: + h0_col, h1_col, h0_row, h1_row = filts + else: + raise ValueError("Unknown form for input filts") + + lohi = afb1d(x, h0_row, h1_row, mode=mode, dim=3) + y = afb1d(lohi, h0_col, h1_col, mode=mode, dim=2) + + return y + + +def afb2d_atrous(x, filts, mode="periodization", dilation=1): + """Does a single level 2d wavelet decomposition of an input. Does separate + row and column filtering by two calls to + :py:func:`pytorch_wavelets.dwt.lowlevel.afb1d` + + Inputs: + x (torch.Tensor): Input to decompose + filts (list of ndarray or torch.Tensor): If a list of tensors has been + given, this function assumes they are in the right form (the form + returned by + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`). + Otherwise, this function will prepare the filters to be of the right + form by calling + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`. + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which + padding to use. If periodization, the output size will be half the + input size. Otherwise, the output size will be slightly larger than + half. + dilation (int): dilation factor for the filters. Should be 2**level + + Returns: + y: Tensor of shape (N, C, 4, H, W) + """ + tensorize = [not isinstance(f, torch.Tensor) for f in filts] + if len(filts) == 2: + h0, h1 = filts + if True in tensorize: + h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(h0, h1, device=x.device) + else: + h0_col = h0 + h0_row = h0.transpose(2, 3) + h1_col = h1 + h1_row = h1.transpose(2, 3) + elif len(filts) == 4: + if True in tensorize: + h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(*filts, device=x.device) + else: + h0_col, h1_col, h0_row, h1_row = filts + else: + raise ValueError("Unknown form for input filts") + + lohi = afb1d_atrous(x, h0_row, h1_row, mode=mode, dim=3, dilation=dilation) + y = afb1d_atrous(lohi, h0_col, h1_col, mode=mode, dim=2, dilation=dilation) + + return y + + +def afb2d_nonsep(x, filts, mode="zero"): + """Does a 1 level 2d wavelet decomposition of an input. Doesn't do separate + row and column filtering. + + Inputs: + x (torch.Tensor): Input to decompose + filts (list or torch.Tensor): If a list is given, should be the low and + highpass filter banks. If a tensor is given, it should be of the + form created by + :py:func:`pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d_nonsep` + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which + padding to use. If periodization, the output size will be half the + input size. Otherwise, the output size will be slightly larger than + half. + + Returns: + y: Tensor of shape (N, C, 4, H, W) + """ + C = x.shape[1] + Ny = x.shape[2] + Nx = x.shape[3] + + # Check the filter inputs + if isinstance(filts, (tuple, list)): + if len(filts) == 2: + filts = prep_filt_afb2d_nonsep(filts[0], filts[1], device=x.device) + else: + filts = prep_filt_afb2d_nonsep( + filts[0], filts[1], filts[2], filts[3], device=x.device + ) + f = torch.cat([filts] * C, dim=0) + Ly = f.shape[2] + Lx = f.shape[3] + + if mode == "periodization" or mode == "per": + if x.shape[2] % 2 == 1: + x = torch.cat((x, x[:, :, -1:]), dim=2) + Ny += 1 + if x.shape[3] % 2 == 1: + x = torch.cat((x, x[:, :, :, -1:]), dim=3) + Nx += 1 + pad = (Ly - 1, Lx - 1) + stride = (2, 2) + x = roll(roll(x, -Ly // 2, dim=2), -Lx // 2, dim=3) + y = F.conv2d(x, f, padding=pad, stride=stride, groups=C) + y[:, :, : Ly // 2] += y[:, :, Ny // 2 : Ny // 2 + Ly // 2] + y[:, :, :, : Lx // 2] += y[:, :, :, Nx // 2 : Nx // 2 + Lx // 2] + y = y[:, :, : Ny // 2, : Nx // 2] + elif mode == "zero" or mode == "symmetric" or mode == "reflect": + # Calculate the pad size + out1 = pywt.dwt_coeff_len(Ny, Ly, mode=mode) + out2 = pywt.dwt_coeff_len(Nx, Lx, mode=mode) + p1 = 2 * (out1 - 1) - Ny + Ly + p2 = 2 * (out2 - 1) - Nx + Lx + if mode == "zero": + # Sadly, pytorch only allows for same padding before and after, if + # we need to do more padding after for odd length signals, have to + # prepad + if p1 % 2 == 1 and p2 % 2 == 1: + x = F.pad(x, (0, 1, 0, 1)) + elif p1 % 2 == 1: + x = F.pad(x, (0, 0, 0, 1)) + elif p2 % 2 == 1: + x = F.pad(x, (0, 1, 0, 0)) + # Calculate the high and lowpass + y = F.conv2d(x, f, padding=(p1 // 2, p2 // 2), stride=2, groups=C) + elif mode == "symmetric" or mode == "reflect" or mode == "periodic": + pad = (p2 // 2, (p2 + 1) // 2, p1 // 2, (p1 + 1) // 2) + x = mypad(x, pad=pad, mode=mode) + y = F.conv2d(x, f, stride=2, groups=C) + else: + raise ValueError("Unkown pad type: {}".format(mode)) + + return y + + +def sfb2d(ll, lh, hl, hh, filts, mode="zero"): + """Does a single level 2d wavelet reconstruction of wavelet coefficients. + Does separate row and column filtering by two calls to + :py:func:`pytorch_wavelets.dwt.lowlevel.sfb1d` + + Inputs: + ll (torch.Tensor): lowpass coefficients + lh (torch.Tensor): horizontal coefficients + hl (torch.Tensor): vertical coefficients + hh (torch.Tensor): diagonal coefficients + filts (list of ndarray or torch.Tensor): If a list of tensors has been + given, this function assumes they are in the right form (the form + returned by + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`). + Otherwise, this function will prepare the filters to be of the right + form by calling + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`. + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which + padding to use. If periodization, the output size will be half the + input size. Otherwise, the output size will be slightly larger than + half. + """ + tensorize = [not isinstance(x, torch.Tensor) for x in filts] + if len(filts) == 2: + g0, g1 = filts + if True in tensorize: + g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(g0, g1) + else: + g0_col = g0 + g0_row = g0.transpose(2, 3) + g1_col = g1 + g1_row = g1.transpose(2, 3) + elif len(filts) == 4: + if True in tensorize: + g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(*filts) + else: + g0_col, g1_col, g0_row, g1_row = filts + else: + raise ValueError("Unknown form for input filts") + + lo = sfb1d(ll, lh, g0_col, g1_col, mode=mode, dim=2) + hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2) + y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3) + + return y + + +class SFB2D(Function): + """Does a single level 2d wavelet decomposition of an input. Does separate + row and column filtering by two calls to + :py:func:`pytorch_wavelets.dwt.lowlevel.afb1d` + + Needs to have the tensors in the right form. Because this function defines + its own backward pass, saves on memory by not having to save the input + tensors. + + Inputs: + x (torch.Tensor): Input to decompose + h0_row: row lowpass + h1_row: row highpass + h0_col: col lowpass + h1_col: col highpass + mode (int): use mode_to_int to get the int code here + + We encode the mode as an integer rather than a string as gradcheck causes an + error when a string is provided. + + Returns: + y: Tensor of shape (N, C*4, H, W) + """ + + @staticmethod + def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode): + mode = int_to_mode(mode) + ctx.mode = mode + ctx.save_for_backward(g0_row, g1_row, g0_col, g1_col) + + lh, hl, hh = torch.unbind(highs, dim=2) + lo = sfb1d(low, lh, g0_col, g1_col, mode=mode, dim=2) + hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2) + y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3) + return y + + @staticmethod + def backward(ctx, dy): + dlow, dhigh = None, None + if ctx.needs_input_grad[0]: + mode = ctx.mode + g0_row, g1_row, g0_col, g1_col = ctx.saved_tensors + dx = afb1d(dy, g0_row, g1_row, mode=mode, dim=3) + dx = afb1d(dx, g0_col, g1_col, mode=mode, dim=2) + s = dx.shape + dx = dx.reshape(s[0], -1, 4, s[-2], s[-1]) + dlow = dx[:, :, 0].contiguous() + dhigh = dx[:, :, 1:].contiguous() + return dlow, dhigh, None, None, None, None, None + + +class SFB1D(Function): + """Does a single level 1d wavelet decomposition of an input. + + Needs to have the tensors in the right form. Because this function defines + its own backward pass, saves on memory by not having to save the input + tensors. + + Inputs: + low (torch.Tensor): Lowpass to reconstruct of shape (N, C, L) + high (torch.Tensor): Highpass to reconstruct of shape (N, C, L) + g0: lowpass + g1: highpass + mode (int): use mode_to_int to get the int code here + + We encode the mode as an integer rather than a string as gradcheck causes an + error when a string is provided. + + Returns: + y: Tensor of shape (N, C*2, L') + """ + + @staticmethod + def forward(ctx, low, high, g0, g1, mode): + mode = int_to_mode(mode) + # Make into a 2d tensor with 1 row + low = low[:, :, None, :] + high = high[:, :, None, :] + g0 = g0[:, :, None, :] + g1 = g1[:, :, None, :] + + ctx.mode = mode + ctx.save_for_backward(g0, g1) + + return sfb1d(low, high, g0, g1, mode=mode, dim=3)[:, :, 0] + + @staticmethod + def backward(ctx, dy): + dlow, dhigh = None, None + if ctx.needs_input_grad[0]: + mode = ctx.mode + ( + g0, + g1, + ) = ctx.saved_tensors + dy = dy[:, :, None, :] + + dx = afb1d(dy, g0, g1, mode=mode, dim=3) + + dlow = dx[:, ::2, 0].contiguous() + dhigh = dx[:, 1::2, 0].contiguous() + return dlow, dhigh, None, None, None, None, None + + +def sfb2d_nonsep(coeffs, filts, mode="zero"): + """Does a single level 2d wavelet reconstruction of wavelet coefficients. + Does not do separable filtering. + + Inputs: + coeffs (torch.Tensor): tensor of coefficients of shape (N, C, 4, H, W) + where the third dimension indexes across the (ll, lh, hl, hh) bands. + filts (list of ndarray or torch.Tensor): If a list of tensors has been + given, this function assumes they are in the right form (the form + returned by + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d_nonsep`). + Otherwise, this function will prepare the filters to be of the right + form by calling + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d_nonsep`. + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which + padding to use. If periodization, the output size will be half the + input size. Otherwise, the output size will be slightly larger than + half. + """ + C = coeffs.shape[1] + Ny = coeffs.shape[-2] + Nx = coeffs.shape[-1] + + # Check the filter inputs - should be in the form of a torch tensor, but if + # not, tensorize it here. + if isinstance(filts, (tuple, list)): + if len(filts) == 2: + filts = prep_filt_sfb2d_nonsep(filts[0], filts[1], device=coeffs.device) + elif len(filts) == 4: + filts = prep_filt_sfb2d_nonsep( + filts[0], filts[1], filts[2], filts[3], device=coeffs.device + ) + else: + raise ValueError("Unkown form for input filts") + f = torch.cat([filts] * C, dim=0) + Ly = f.shape[2] + Lx = f.shape[3] + + x = coeffs.reshape(coeffs.shape[0], -1, coeffs.shape[-2], coeffs.shape[-1]) + if mode == "periodization" or mode == "per": + ll = F.conv_transpose2d(x, f, groups=C, stride=2) + ll[:, :, : Ly - 2] += ll[:, :, 2 * Ny : 2 * Ny + Ly - 2] + ll[:, :, :, : Lx - 2] += ll[:, :, :, 2 * Nx : 2 * Nx + Lx - 2] + ll = ll[:, :, : 2 * Ny, : 2 * Nx] + ll = roll(roll(ll, 1 - Ly // 2, dim=2), 1 - Lx // 2, dim=3) + elif ( + mode == "symmetric" or mode == "zero" or mode == "reflect" or mode == "periodic" + ): + pad = (Ly - 2, Lx - 2) + ll = F.conv_transpose2d(x, f, padding=pad, groups=C, stride=2) + else: + raise ValueError("Unkown pad type: {}".format(mode)) + + return ll.contiguous() + + +def prep_filt_afb2d_nonsep(h0_col, h1_col, h0_row=None, h1_row=None, device=None): + """ + Prepares the filters to be of the right form for the afb2d_nonsep function. + In particular, makes 2d point spread functions, and mirror images them in + preparation to do torch.conv2d. + + Inputs: + h0_col (array-like): low pass column filter bank + h1_col (array-like): high pass column filter bank + h0_row (array-like): low pass row filter bank. If none, will assume the + same as column filter + h1_row (array-like): high pass row filter bank. If none, will assume the + same as column filter + device: which device to put the tensors on to + + Returns: + filts: (4, 1, h, w) tensor ready to get the four subbands + """ + h0_col = np.array(h0_col).ravel() + h1_col = np.array(h1_col).ravel() + if h0_row is None: + h0_row = h0_col + if h1_row is None: + h1_row = h1_col + ll = np.outer(h0_col, h0_row) + lh = np.outer(h1_col, h0_row) + hl = np.outer(h0_col, h1_row) + hh = np.outer(h1_col, h1_row) + filts = np.stack( + [ + ll[None, ::-1, ::-1], + lh[None, ::-1, ::-1], + hl[None, ::-1, ::-1], + hh[None, ::-1, ::-1], + ], + axis=0, + ) + filts = torch.tensor(filts, dtype=torch.get_default_dtype(), device=device) + return filts + + +def prep_filt_sfb2d_nonsep(g0_col, g1_col, g0_row=None, g1_row=None, device=None): + """ + Prepares the filters to be of the right form for the sfb2d_nonsep function. + In particular, makes 2d point spread functions. Does not mirror image them + as sfb2d_nonsep uses conv2d_transpose which acts like normal convolution. + + Inputs: + g0_col (array-like): low pass column filter bank + g1_col (array-like): high pass column filter bank + g0_row (array-like): low pass row filter bank. If none, will assume the + same as column filter + g1_row (array-like): high pass row filter bank. If none, will assume the + same as column filter + device: which device to put the tensors on to + + Returns: + filts: (4, 1, h, w) tensor ready to combine the four subbands + """ + g0_col = np.array(g0_col).ravel() + g1_col = np.array(g1_col).ravel() + if g0_row is None: + g0_row = g0_col + if g1_row is None: + g1_row = g1_col + ll = np.outer(g0_col, g0_row) + lh = np.outer(g1_col, g0_row) + hl = np.outer(g0_col, g1_row) + hh = np.outer(g1_col, g1_row) + filts = np.stack([ll[None], lh[None], hl[None], hh[None]], axis=0) + filts = torch.tensor(filts, dtype=torch.get_default_dtype(), device=device) + return filts + + +def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None): + """ + Prepares the filters to be of the right form for the sfb2d function. In + particular, makes the tensors the right shape. It does not mirror image them + as as sfb2d uses conv2d_transpose which acts like normal convolution. + + Inputs: + g0_col (array-like): low pass column filter bank + g1_col (array-like): high pass column filter bank + g0_row (array-like): low pass row filter bank. If none, will assume the + same as column filter + g1_row (array-like): high pass row filter bank. If none, will assume the + same as column filter + device: which device to put the tensors on to + + Returns: + (g0_col, g1_col, g0_row, g1_row) + """ + g0_col, g1_col = prep_filt_sfb1d(g0_col, g1_col, device) + if g0_row is None: + g0_row, g1_row = g0_col, g1_col + else: + g0_row, g1_row = prep_filt_sfb1d(g0_row, g1_row, device) + + g0_col = g0_col.reshape((1, 1, -1, 1)) + g1_col = g1_col.reshape((1, 1, -1, 1)) + g0_row = g0_row.reshape((1, 1, 1, -1)) + g1_row = g1_row.reshape((1, 1, 1, -1)) + + return g0_col, g1_col, g0_row, g1_row + + +def prep_filt_sfb1d(g0, g1, device=None): + """ + Prepares the filters to be of the right form for the sfb1d function. In + particular, makes the tensors the right shape. It does not mirror image them + as as sfb2d uses conv2d_transpose which acts like normal convolution. + + Inputs: + g0 (array-like): low pass filter bank + g1 (array-like): high pass filter bank + device: which device to put the tensors on to + + Returns: + (g0, g1) + """ + g0 = np.array(g0).ravel() + g1 = np.array(g1).ravel() + t = torch.get_default_dtype() + g0 = torch.tensor(g0, device=device, dtype=t).reshape((1, 1, -1)) + g1 = torch.tensor(g1, device=device, dtype=t).reshape((1, 1, -1)) + + return g0, g1 + + +def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=None): + """ + Prepares the filters to be of the right form for the afb2d function. In + particular, makes the tensors the right shape. It takes mirror images of + them as as afb2d uses conv2d which acts like normal correlation. + + Inputs: + h0_col (array-like): low pass column filter bank + h1_col (array-like): high pass column filter bank + h0_row (array-like): low pass row filter bank. If none, will assume the + same as column filter + h1_row (array-like): high pass row filter bank. If none, will assume the + same as column filter + device: which device to put the tensors on to + + Returns: + (h0_col, h1_col, h0_row, h1_row) + """ + h0_col, h1_col = prep_filt_afb1d(h0_col, h1_col, device) + if h0_row is None: + h0_row, h1_col = h0_col, h1_col + else: + h0_row, h1_row = prep_filt_afb1d(h0_row, h1_row, device) + + h0_col = h0_col.reshape((1, 1, -1, 1)) + h1_col = h1_col.reshape((1, 1, -1, 1)) + h0_row = h0_row.reshape((1, 1, 1, -1)) + h1_row = h1_row.reshape((1, 1, 1, -1)) + return h0_col, h1_col, h0_row, h1_row + + +def prep_filt_afb1d(h0, h1, device=None): + """ + Prepares the filters to be of the right form for the afb2d function. In + particular, makes the tensors the right shape. It takes mirror images of + them as as afb2d uses conv2d which acts like normal correlation. + + Inputs: + h0 (array-like): low pass column filter bank + h1 (array-like): high pass column filter bank + device: which device to put the tensors on to + + Returns: + (h0, h1) + """ + h0 = np.array(h0[::-1]).ravel() + h1 = np.array(h1[::-1]).ravel() + t = torch.get_default_dtype() + h0 = torch.tensor(h0, device=device, dtype=t).reshape((1, 1, -1)) + h1 = torch.tensor(h1, device=device, dtype=t).reshape((1, 1, -1)) + return h0, h1 diff --git a/mirtorch/vendors/pytorch_wavelets/dwt/swt_inverse.py b/mirtorch/vendors/pytorch_wavelets/dwt/swt_inverse.py new file mode 100644 index 0000000..810ac74 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/dwt/swt_inverse.py @@ -0,0 +1,213 @@ +def sfb1d_atrous( + lo, hi, g0, g1, mode="periodization", dim=-1, dilation=1, pad1=None, pad=None +): + """1D synthesis filter bank of an image tensor with no upsampling. Used for + the stationary wavelet transform. + """ + C = lo.shape[1] + d = dim % 4 + # If g0, g1 are not tensors, make them. If they are, then assume that they + # are in the right order + if not isinstance(g0, torch.Tensor): + g0 = torch.tensor( + np.copy(np.array(g0).ravel()), dtype=torch.float, device=lo.device + ) + if not isinstance(g1, torch.Tensor): + g1 = torch.tensor( + np.copy(np.array(g1).ravel()), dtype=torch.float, device=lo.device + ) + L = g0.numel() + shape = [1, 1, 1, 1] + shape[d] = L + # If g aren't in the right shape, make them so + if g0.shape != tuple(shape): + g0 = g0.reshape(*shape) + if g1.shape != tuple(shape): + g1 = g1.reshape(*shape) + g0 = torch.cat([g0] * C, dim=0) + g1 = torch.cat([g1] * C, dim=0) + + # Calculate the padding size. + # With dilation, zeros are inserted between the filter taps but not after. + # that means a filter that is [a b c d] becomes [a 0 b 0 c 0 d]. + centre = L / 2 + fsz = (L - 1) * dilation + 1 + newcentre = fsz / 2 + before = newcentre - dilation * centre + + # When conv_transpose2d is done, a filter with k taps expands an input with + # N samples to be N + k - 1 samples. The 'padding' is really the opposite of + # that, and is how many samples on the edges you want to cut out. + # In addition to this, we want the input to be extended before convolving. + # This means the final output size without the padding option will be + # N + k - 1 + k - 1 + # The final thing to worry about is making sure that the output is centred. + short_offset = dilation - 1 + centre_offset = fsz % 2 + a = fsz // 2 + b = fsz // 2 + (fsz + 1) % 2 + # a = 0 + # b = 0 + pad = (0, 0, a, b) if d == 2 else (a, b, 0, 0) + lo = mypad(lo, pad=pad, mode=mode) + hi = mypad(hi, pad=pad, mode=mode) + unpad = (fsz - 1, 0) if d == 2 else (0, fsz - 1) + unpad = (0, 0) + y = F.conv_transpose2d( + lo, g0, padding=unpad, groups=C, dilation=dilation + ) + F.conv_transpose2d(hi, g1, padding=unpad, groups=C, dilation=dilation) + # pad = (L-1, 0) if d == 2 else (0, L-1) + # y = F.conv_transpose2d(lo, g0, padding=pad, groups=C, dilation=dilation) + \ + # F.conv_transpose2d(hi, g1, padding=pad, groups=C, dilation=dilation) + # + # + # Calculate the pad size + # L2 = (L * dilation)//2 + # # pad = (0, 0, L2, L2+dilation) if d == 2 else (L2, L2+dilation, 0, 0) + # a = dilation*2 + # b = dilation*(L-2) + # if pad1 is None: + # pad1 = (0, 0, a, b) if d == 2 else (a, b, 0, 0) + # print(pad1) + # lo = mypad(lo, pad=pad1, mode=mode) + # hi = mypad(hi, pad=pad1, mode=mode) + # if pad is None: + # p = (a + b + (L - 1)*dilation)//2 + # pad = (p, 0) if d == 2 else (0, p) + # print(pad) + + return y / (2 * dilation) + + +def sfb2d_atrous(ll, lh, hl, hh, filts, mode="zero"): + """Does a single level 2d wavelet reconstruction of wavelet coefficients. + Does separate row and column filtering by two calls to + :py:func:`pytorch_wavelets.dwt.lowlevel.sfb1d` + + Inputs: + ll (torch.Tensor): lowpass coefficients + lh (torch.Tensor): horizontal coefficients + hl (torch.Tensor): vertical coefficients + hh (torch.Tensor): diagonal coefficients + filts (list of ndarray or torch.Tensor): If a list of tensors has been + given, this function assumes they are in the right form (the form + returned by + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`). + Otherwise, this function will prepare the filters to be of the right + form by calling + :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`. + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which + padding to use. If periodization, the output size will be half the + input size. Otherwise, the output size will be slightly larger than + half. + """ + tensorize = [not isinstance(x, torch.Tensor) for x in filts] + if len(filts) == 2: + g0, g1 = filts + if True in tensorize: + g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(g0, g1) + else: + g0_col = g0 + g0_row = g0.transpose(2, 3) + g1_col = g1 + g1_row = g1.transpose(2, 3) + elif len(filts) == 4: + if True in tensorize: + g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(*filts) + else: + g0_col, g1_col, g0_row, g1_row = filts + else: + raise ValueError("Unknown form for input filts") + + lo = sfb1d_atrous(ll, lh, g0_col, g1_col, mode=mode, dim=2) + hi = sfb1d_atrous(hl, hh, g0_col, g1_col, mode=mode, dim=2) + y = sfb1d_atrous(lo, hi, g0_row, g1_row, mode=mode, dim=3) + + return y + + +class SWTInverse(nn.Module): + """Performs a 2d DWT Inverse reconstruction of an image + + Args: + wave (str or pywt.Wavelet): Which wavelet to use + C: deprecated, will be removed in future + """ + + def __init__(self, wave="db1", mode="zero", separable=True): + super().__init__() + if isinstance(wave, str): + wave = pywt.Wavelet(wave) + if isinstance(wave, pywt.Wavelet): + g0_col, g1_col = wave.rec_lo, wave.rec_hi + g0_row, g1_row = g0_col, g1_col + else: + if len(wave) == 2: + g0_col, g1_col = wave[0], wave[1] + g0_row, g1_row = g0_col, g1_col + elif len(wave) == 4: + g0_col, g1_col = wave[0], wave[1] + g0_row, g1_row = wave[2], wave[3] + # Prepare the filters + if separable: + filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row) + self.register_buffer("g0_col", filts[0]) + self.register_buffer("g1_col", filts[1]) + self.register_buffer("g0_row", filts[2]) + self.register_buffer("g1_row", filts[3]) + else: + filts = lowlevel.prep_filt_sfb2d_nonsep(g0_col, g1_col, g0_row, g1_row) + self.register_buffer("h", filts) + self.mode = mode + self.separable = separable + + def forward(self, coeffs): + """ + Args: + coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: + yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}', + W_{in}')` and yh is a list of bandpass tensors of shape + :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match + the format returned by DWTForward + + Returns: + Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})` + + Note: + :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly + downsampled shapes of the DWT pyramid. + + Note: + Can have None for any of the highpass scales and will treat the + values as zeros (not in an efficient way though). + """ + yl, yh = coeffs + ll = yl + + # Do a multilevel inverse transform + for h in yh[::-1]: + if h is None: + h = torch.zeros( + ll.shape[0], + ll.shape[1], + 3, + ll.shape[-2], + ll.shape[-1], + device=ll.device, + ) + + # 'Unpad' added dimensions + if ll.shape[-2] > h.shape[-2]: + ll = ll[..., :-1, :] + if ll.shape[-1] > h.shape[-1]: + ll = ll[..., :-1] + + # Do the synthesis filter banks + if self.separable: + lh, hl, hh = torch.unbind(h, dim=2) + filts = (self.g0_col, self.g1_col, self.g0_row, self.g1_row) + ll = lowlevel.sfb2d(ll, lh, hl, hh, filts, mode=self.mode) + else: + c = torch.cat((ll[:, :, None], h), dim=2) + ll = lowlevel.sfb2d_nonsep(c, self.h, mode=self.mode) + return ll diff --git a/mirtorch/vendors/pytorch_wavelets/dwt/transform1d.py b/mirtorch/vendors/pytorch_wavelets/dwt/transform1d.py new file mode 100644 index 0000000..f077e54 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/dwt/transform1d.py @@ -0,0 +1,117 @@ +import torch.nn as nn +import pywt +from . import lowlevel +import torch + + +class DWT1DForward(nn.Module): + """Performs a 1d DWT Forward decomposition of an image + + Args: + J (int): Number of levels of decomposition + wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use. + Can be: + 1) a string to pass to pywt.Wavelet constructor + 2) a pywt.Wavelet class + 3) a tuple of numpy arrays (h0, h1) + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The + padding scheme + """ + + def __init__(self, J=1, wave="db1", mode="zero"): + super().__init__() + if isinstance(wave, str): + wave = pywt.Wavelet(wave) + if isinstance(wave, pywt.Wavelet): + h0, h1 = wave.dec_lo, wave.dec_hi + else: + assert len(wave) == 2 + h0, h1 = wave[0], wave[1] + + # Prepare the filters - this makes them into column filters + filts = lowlevel.prep_filt_afb1d(h0, h1) + self.register_buffer("h0", filts[0]) + self.register_buffer("h1", filts[1]) + self.J = J + self.mode = mode + + def forward(self, x): + """Forward pass of the DWT. + + Args: + x (tensor): Input of shape :math:`(N, C_{in}, L_{in})` + + Returns: + (yl, yh) + tuple of lowpass (yl) and bandpass (yh) coefficients. + yh is a list of length J with the first entry + being the finest scale coefficients. + """ + assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)" + highs = [] + x0 = x + mode = lowlevel.mode_to_int(self.mode) + + # Do a multilevel transform + for j in range(self.J): + x0, x1 = lowlevel.AFB1D.apply(x0, self.h0, self.h1, mode) + highs.append(x1) + + return x0, highs + + +class DWT1DInverse(nn.Module): + """Performs a 1d DWT Inverse reconstruction of an image + + Args: + wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use. + Can be: + 1) a string to pass to pywt.Wavelet constructor + 2) a pywt.Wavelet class + 3) a tuple of numpy arrays (h0, h1) + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The + padding scheme + """ + + def __init__(self, wave="db1", mode="zero"): + super().__init__() + if isinstance(wave, str): + wave = pywt.Wavelet(wave) + if isinstance(wave, pywt.Wavelet): + g0, g1 = wave.rec_lo, wave.rec_hi + else: + assert len(wave) == 2 + g0, g1 = wave[0], wave[1] + + # Prepare the filters + filts = lowlevel.prep_filt_sfb1d(g0, g1) + self.register_buffer("g0", filts[0]) + self.register_buffer("g1", filts[1]) + self.mode = mode + + def forward(self, coeffs): + """ + Args: + coeffs (yl, yh): tuple of lowpass and bandpass coefficients, should + match the format returned by DWT1DForward. + + Returns: + Reconstructed input of shape :math:`(N, C_{in}, L_{in})` + + Note: + Can have None for any of the highpass scales and will treat the + values as zeros (not in an efficient way though). + """ + x0, highs = coeffs + assert x0.ndim == 3, "Can only handle 3d inputs (N, C, L)" + mode = lowlevel.mode_to_int(self.mode) + # Do a multilevel inverse transform + for x1 in highs[::-1]: + if x1 is None: + x1 = torch.zeros_like(x0) + + # 'Unpad' added signal + if x0.shape[-1] > x1.shape[-1]: + x0 = x0[..., :-1] + x0 = lowlevel.SFB1D.apply(x0, x1, self.g0, self.g1, mode) + return x0 diff --git a/mirtorch/vendors/pytorch_wavelets/dwt/transform2d.py b/mirtorch/vendors/pytorch_wavelets/dwt/transform2d.py new file mode 100644 index 0000000..400d516 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/dwt/transform2d.py @@ -0,0 +1,223 @@ +import torch.nn as nn +import pywt +from . import lowlevel +import torch + + +class DWTForward(nn.Module): + """Performs a 2d DWT Forward decomposition of an image + + Args: + J (int): Number of levels of decomposition + wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use. + Can be: + 1) a string to pass to pywt.Wavelet constructor + 2) a pywt.Wavelet class + 3) a tuple of numpy arrays, either (h0, h1) or (h0_col, h1_col, h0_row, h1_row) + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The + padding scheme + """ + + def __init__(self, J=1, wave="db1", mode="zero"): + super().__init__() + if isinstance(wave, str): + wave = pywt.Wavelet(wave) + if isinstance(wave, pywt.Wavelet): + h0_col, h1_col = wave.dec_lo, wave.dec_hi + h0_row, h1_row = h0_col, h1_col + else: + if len(wave) == 2: + h0_col, h1_col = wave[0], wave[1] + h0_row, h1_row = h0_col, h1_col + elif len(wave) == 4: + h0_col, h1_col = wave[0], wave[1] + h0_row, h1_row = wave[2], wave[3] + + # Prepare the filters + filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row) + self.register_buffer("h0_col", filts[0]) + self.register_buffer("h1_col", filts[1]) + self.register_buffer("h0_row", filts[2]) + self.register_buffer("h1_row", filts[3]) + self.J = J + self.mode = mode + + def forward(self, x): + """Forward pass of the DWT. + + Args: + x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})` + + Returns: + (yl, yh) + tuple of lowpass (yl) and bandpass (yh) coefficients. + yh is a list of length J with the first entry + being the finest scale coefficients. yl has shape + :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape + :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new + dimension in yh iterates over the LH, HL and HH coefficients. + + Note: + :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly + downsampled shapes of the DWT pyramid. + """ + yh = [] + ll = x + mode = lowlevel.mode_to_int(self.mode) + + # Do a multilevel transform + for j in range(self.J): + # Do 1 level of the transform + ll, high = lowlevel.AFB2D.apply( + ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row, mode + ) + yh.append(high) + + return ll, yh + + +class DWTInverse(nn.Module): + """Performs a 2d DWT Inverse reconstruction of an image + + Args: + wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use. + Can be: + 1) a string to pass to pywt.Wavelet constructor + 2) a pywt.Wavelet class + 3) a tuple of numpy arrays, either (h0, h1) or (h0_col, h1_col, h0_row, h1_row) + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The + padding scheme + """ + + def __init__(self, wave="db1", mode="zero"): + super().__init__() + if isinstance(wave, str): + wave = pywt.Wavelet(wave) + if isinstance(wave, pywt.Wavelet): + g0_col, g1_col = wave.rec_lo, wave.rec_hi + g0_row, g1_row = g0_col, g1_col + else: + if len(wave) == 2: + g0_col, g1_col = wave[0], wave[1] + g0_row, g1_row = g0_col, g1_col + elif len(wave) == 4: + g0_col, g1_col = wave[0], wave[1] + g0_row, g1_row = wave[2], wave[3] + # Prepare the filters + filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row) + self.register_buffer("g0_col", filts[0]) + self.register_buffer("g1_col", filts[1]) + self.register_buffer("g0_row", filts[2]) + self.register_buffer("g1_row", filts[3]) + self.mode = mode + + def forward(self, coeffs): + """ + Args: + coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: + yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}', + W_{in}')` and yh is a list of bandpass tensors of shape + :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match + the format returned by DWTForward + + Returns: + Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})` + + Note: + :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly + downsampled shapes of the DWT pyramid. + + Note: + Can have None for any of the highpass scales and will treat the + values as zeros (not in an efficient way though). + """ + yl, yh = coeffs + ll = yl + mode = lowlevel.mode_to_int(self.mode) + + # Do a multilevel inverse transform + for h in yh[::-1]: + if h is None: + h = torch.zeros( + ll.shape[0], + ll.shape[1], + 3, + ll.shape[-2], + ll.shape[-1], + device=ll.device, + ) + + # 'Unpad' added dimensions + if ll.shape[-2] > h.shape[-2]: + ll = ll[..., :-1, :] + if ll.shape[-1] > h.shape[-1]: + ll = ll[..., :-1] + ll = lowlevel.SFB2D.apply( + ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode + ) + return ll + + +class SWTForward(nn.Module): + """Performs a 2d Stationary wavelet transform (or undecimated wavelet + transform) of an image + + Args: + J (int): Number of levels of decomposition + wave (str or pywt.Wavelet): Which wavelet to use. Can be a string to + pass to pywt.Wavelet constructor, can also be a pywt.Wavelet class, + or can be a two tuple of array-like objects for the analysis low and + high pass filters. + mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The + padding scheme. PyWavelets uses only periodization so we use this + as our default scheme. + """ + + def __init__(self, J=1, wave="db1", mode="periodization"): + super().__init__() + if isinstance(wave, str): + wave = pywt.Wavelet(wave) + if isinstance(wave, pywt.Wavelet): + h0_col, h1_col = wave.dec_lo, wave.dec_hi + h0_row, h1_row = h0_col, h1_col + else: + if len(wave) == 2: + h0_col, h1_col = wave[0], wave[1] + h0_row, h1_row = h0_col, h1_col + elif len(wave) == 4: + h0_col, h1_col = wave[0], wave[1] + h0_row, h1_row = wave[2], wave[3] + + # Prepare the filters + filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row) + self.register_buffer("h0_col", filts[0]) + self.register_buffer("h1_col", filts[1]) + self.register_buffer("h0_row", filts[2]) + self.register_buffer("h1_row", filts[3]) + + self.J = J + self.mode = mode + + def forward(self, x): + """Forward pass of the SWT. + + Args: + x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})` + + Returns: + List of coefficients for each scale. Each coefficient has + shape :math:`(N, C_{in}, 4, H_{in}, W_{in})` where the extra + dimension stores the 4 subbands for each scale. The ordering in + these 4 coefficients is: (A, H, V, D) or (ll, lh, hl, hh). + """ + ll = x + coeffs = [] + # Do a multilevel transform + filts = (self.h0_col, self.h1_col, self.h0_row, self.h1_row) + for j in range(self.J): + # Do 1 level of the transform + y = lowlevel.afb2d_atrous(ll, filts, self.mode, 2**j) + coeffs.append(y) + ll = y[:, :, 0] + + return coeffs diff --git a/mirtorch/vendors/pytorch_wavelets/scatternet/__init__.py b/mirtorch/vendors/pytorch_wavelets/scatternet/__init__.py new file mode 100644 index 0000000..48614ca --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/scatternet/__init__.py @@ -0,0 +1,3 @@ +from .layers import ScatLayer, ScatLayerj2 + +__all__ = ["ScatLayer", "ScatLayerj2"] diff --git a/mirtorch/vendors/pytorch_wavelets/scatternet/layers.py b/mirtorch/vendors/pytorch_wavelets/scatternet/layers.py new file mode 100644 index 0000000..59fd0d7 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/scatternet/layers.py @@ -0,0 +1,209 @@ +import torch +import torch.nn as nn +from ..dtcwt.coeffs import biort as _biort, qshift as _qshift +from ..dtcwt.lowlevel import prep_filt + +from .lowlevel import mode_to_int +from .lowlevel import ScatLayerj1_f, ScatLayerj1_rot_f +from .lowlevel import ScatLayerj2_f, ScatLayerj2_rot_f + + +class ScatLayer(nn.Module): + """Does one order of scattering at a single scale. Can be made into a + second order scatternet by stacking two of these layers. + Inputs: + biort (str): the biorthogonal filters to use. if 'near_sym_b_bp' will + use the rotationally symmetric filters. These have 13 and 19 taps + so are quite long. They also require 7 1D convolutions instead of 6. + x (torch.tensor): Input of shape (N, C, H, W) + mode (str): padding mode. Can be 'symmetric' or 'zero' + magbias (float): the magnitude bias to use for smoothing + combine_colour (bool): if true, will only have colour lowpass and have + greyscale bandpass + Returns: + y (torch.tensor): y has the lowpass and invariant U terms stacked along + the channel dimension, and so has shape (N, 7*C, H/2, W/2). Where + the first C channels are the lowpass outputs, and the next 6C are + the magnitude highpass outputs. + """ + + def __init__( + self, biort="near_sym_a", mode="symmetric", magbias=1e-2, combine_colour=False + ): + super().__init__() + self.biort = biort + # Have to convert the string to an int as the grad checks don't work + # with string inputs + self.mode_str = mode + self.mode = mode_to_int(mode) + self.magbias = magbias + self.combine_colour = combine_colour + if biort == "near_sym_b_bp": + self.bandpass_diag = True + h0o, _, h1o, _, h2o, _ = _biort(biort) + self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False) + self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False) + self.h2o = torch.nn.Parameter(prep_filt(h2o, 1), False) + else: + self.bandpass_diag = False + h0o, _, h1o, _ = _biort(biort) + self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False) + self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False) + + def forward(self, x): + # Do the single scale DTCWT + # If the row/col count of X is not divisible by 2 then we need to + # extend X + _, ch, r, c = x.shape + if r % 2 != 0: + x = torch.cat((x, x[:, :, -1:]), dim=2) + if c % 2 != 0: + x = torch.cat((x, x[:, :, :, -1:]), dim=3) + + if self.combine_colour: + assert ch == 3 + + if self.bandpass_diag: + Z = ScatLayerj1_rot_f.apply( + x, + self.h0o, + self.h1o, + self.h2o, + self.mode, + self.magbias, + self.combine_colour, + ) + else: + Z = ScatLayerj1_f.apply( + x, self.h0o, self.h1o, self.mode, self.magbias, self.combine_colour + ) + if not self.combine_colour: + b, _, c, h, w = Z.shape + Z = Z.view(b, 7 * c, h, w) + return Z + + def extra_repr(self): + return "biort='{}', mode='{}', magbias={}".format( + self.biort, self.mode_str, self.magbias + ) + + +class ScatLayerj2(nn.Module): + """Does second order scattering for two scales. Uses correct dtcwt first + and second level filters compared to ScatLayer which only uses biorthogonal + filters. + + Inputs: + biort (str): the biorthogonal filters to use. if 'near_sym_b_bp' will + use the rotationally symmetric filters. These have 13 and 19 taps + so are quite long. They also require 7 1D convolutions instead of 6. + x (torch.tensor): Input of shape (N, C, H, W) + mode (str): padding mode. Can be 'symmetric' or 'zero' + Returns: + y (torch.tensor): y has the lowpass and invariant U terms stacked along + the channel dimension, and so has shape (N, 7*C, H/2, W/2). Where + the first C channels are the lowpass outputs, and the next 6C are + the magnitude highpass outputs. + """ + + def __init__( + self, + biort="near_sym_a", + qshift="qshift_a", + mode="symmetric", + magbias=1e-2, + combine_colour=False, + ): + super().__init__() + self.biort = biort + self.qshift = biort + # Have to convert the string to an int as the grad checks don't work + # with string inputs + self.mode_str = mode + self.mode = mode_to_int(mode) + self.magbias = magbias + self.combine_colour = combine_colour + if biort == "near_sym_b_bp": + assert qshift == "qshift_b_bp" + self.bandpass_diag = True + h0o, _, h1o, _, h2o, _ = _biort(biort) + self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False) + self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False) + self.h2o = torch.nn.Parameter(prep_filt(h2o, 1), False) + h0a, h0b, _, _, h1a, h1b, _, _, h2a, h2b, _, _ = _qshift("qshift_b_bp") + self.h0a = torch.nn.Parameter(prep_filt(h0a, 1), False) + self.h0b = torch.nn.Parameter(prep_filt(h0b, 1), False) + self.h1a = torch.nn.Parameter(prep_filt(h1a, 1), False) + self.h1b = torch.nn.Parameter(prep_filt(h1b, 1), False) + self.h2a = torch.nn.Parameter(prep_filt(h2a, 1), False) + self.h2b = torch.nn.Parameter(prep_filt(h2b, 1), False) + else: + self.bandpass_diag = False + h0o, _, h1o, _ = _biort(biort) + self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False) + self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False) + h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift) + self.h0a = torch.nn.Parameter(prep_filt(h0a, 1), False) + self.h0b = torch.nn.Parameter(prep_filt(h0b, 1), False) + self.h1a = torch.nn.Parameter(prep_filt(h1a, 1), False) + self.h1b = torch.nn.Parameter(prep_filt(h1b, 1), False) + + def forward(self, x): + # Ensure the input size is divisible by 8 + ch, r, c = x.shape[1:] + rem = r % 8 + if rem != 0: + rows_after = (9 - rem) // 2 + rows_before = (8 - rem) // 2 + x = torch.cat((x[:, :, :rows_before], x, x[:, :, -rows_after:]), dim=2) + rem = c % 8 + if rem != 0: + cols_after = (9 - rem) // 2 + cols_before = (8 - rem) // 2 + x = torch.cat( + (x[:, :, :, :cols_before], x, x[:, :, :, -cols_after:]), dim=3 + ) + + if self.combine_colour: + assert ch == 3 + + if self.bandpass_diag: + pass + Z = ScatLayerj2_rot_f.apply( + x, + self.h0o, + self.h1o, + self.h2o, + self.h0a, + self.h0b, + self.h1a, + self.h1b, + self.h2a, + self.h2b, + self.mode, + self.magbias, + self.combine_colour, + ) + else: + Z = ScatLayerj2_f.apply( + x, + self.h0o, + self.h1o, + self.h0a, + self.h0b, + self.h1a, + self.h1b, + self.mode, + self.magbias, + self.combine_colour, + ) + + if not self.combine_colour: + b, _, c, h, w = Z.shape + Z = Z.view(b, 49 * c, h, w) + return Z + + def extra_repr(self): + return "biort='{}', mode='{}', magbias={}".format( + self.biort, self.mode_str, self.magbias + ) diff --git a/mirtorch/vendors/pytorch_wavelets/scatternet/lowlevel.py b/mirtorch/vendors/pytorch_wavelets/scatternet/lowlevel.py new file mode 100644 index 0000000..f7d5348 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/scatternet/lowlevel.py @@ -0,0 +1,779 @@ +from __future__ import absolute_import +import torch +import torch.nn.functional as F + +from ..dtcwt.transform_funcs import fwd_j1, inv_j1 +from ..dtcwt.transform_funcs import fwd_j1_rot, inv_j1_rot +from ..dtcwt.transform_funcs import fwd_j2plus, inv_j2plus +from ..dtcwt.transform_funcs import fwd_j2plus_rot, inv_j2plus_rot + + +def mode_to_int(mode): + if mode == "zero": + return 0 + elif mode == "symmetric": + return 1 + elif mode == "per" or mode == "periodization": + return 2 + elif mode == "constant": + return 3 + elif mode == "reflect": + return 4 + elif mode == "replicate": + return 5 + elif mode == "periodic": + return 6 + else: + raise ValueError("Unkown pad type: {}".format(mode)) + + +def int_to_mode(mode): + if mode == 0: + return "zero" + elif mode == 1: + return "symmetric" + elif mode == 2: + return "periodization" + elif mode == 3: + return "constant" + elif mode == 4: + return "reflect" + elif mode == 5: + return "replicate" + elif mode == 6: + return "periodic" + else: + raise ValueError("Unkown pad type: {}".format(mode)) + + +class SmoothMagFn(torch.autograd.Function): + """Class to do complex magnitude""" + + @staticmethod + def forward(ctx, x, y, b): + r = torch.sqrt(x**2 + y**2 + b**2) + if x.requires_grad: + dx = x / r + dy = y / r + ctx.save_for_backward(dx, dy) + + return r - b + + @staticmethod + def backward(ctx, dr): + dx = None + if ctx.needs_input_grad[0]: + drdx, drdy = ctx.saved_tensors + dx = drdx * dr + dy = drdy * dr + return dx, dy, None + + +class ScatLayerj1_f(torch.autograd.Function): + """Function to do forward and backward passes of a single scattering + layer with the DTCWT biorthogonal filters.""" + + @staticmethod + def forward(ctx, x, h0o, h1o, mode, bias, combine_colour): + # bias = 1e-2 + # bias = 0 + ctx.in_shape = x.shape + batch, ch, r, c = x.shape + assert r % 2 == c % 2 == 0 + mode = int_to_mode(mode) + ctx.mode = mode + ctx.combine_colour = combine_colour + + ll, reals, imags = fwd_j1(x, h0o, h1o, False, 1, mode) + ll = F.avg_pool2d(ll, 2) + if combine_colour: + r = torch.sqrt( + reals[:, :, 0] ** 2 + + imags[:, :, 0] ** 2 + + reals[:, :, 1] ** 2 + + imags[:, :, 1] ** 2 + + reals[:, :, 2] ** 2 + + imags[:, :, 2] ** 2 + + bias**2 + ) + r = r[:, :, None] + else: + r = torch.sqrt(reals**2 + imags**2 + bias**2) + + if x.requires_grad: + drdx = reals / r + drdy = imags / r + ctx.save_for_backward(h0o, h1o, drdx, drdy) + else: + z = x.new_zeros(1) + ctx.save_for_backward(h0o, h1o, z, z) + + r = r - bias + del reals, imags + if combine_colour: + Z = torch.cat((ll, r[:, :, 0]), dim=1) + else: + Z = torch.cat((ll[:, None], r), dim=1) + + return Z + + @staticmethod + def backward(ctx, dZ): + dX = None + mode = ctx.mode + + if ctx.needs_input_grad[0]: + # h0o, h1o, θ = ctx.saved_tensors + h0o, h1o, drdx, drdy = ctx.saved_tensors + # Use the special properties of the filters to get the time reverse + h0o_t = h0o + h1o_t = h1o + + # Level 1 backward (time reversed biorthogonal analysis filters) + if ctx.combine_colour: + dYl, dr = dZ[:, :3], dZ[:, 3:] + dr = dr[:, :, None] + else: + dYl, dr = dZ[:, 0], dZ[:, 1:] + ll = 1 / 4 * F.interpolate(dYl, scale_factor=2, mode="nearest") + reals = dr * drdx + imags = dr * drdy + + dX = inv_j1(ll, reals, imags, h0o_t, h1o_t, 1, 3, 4, mode) + + return (dX,) + (None,) * 5 + + +class ScatLayerj1_rot_f(torch.autograd.Function): + """Function to do forward and backward passes of a single scattering + layer with the DTCWT biorthogonal filters. Uses the rotationally symmetric + filters, i.e. a slightly more expensive operation.""" + + @staticmethod + def forward(ctx, x, h0o, h1o, h2o, mode, bias, combine_colour): + mode = int_to_mode(mode) + ctx.mode = mode + # bias = 0 + ctx.in_shape = x.shape + ctx.combine_colour = combine_colour + batch, ch, r, c = x.shape + assert r % 2 == c % 2 == 0 + + # Level 1 forward (biorthogonal analysis filters) + ll, reals, imags = fwd_j1_rot(x, h0o, h1o, h2o, False, 1, mode) + ll = F.avg_pool2d(ll, 2) + if combine_colour: + r = torch.sqrt( + reals[:, :, 0] ** 2 + + imags[:, :, 0] ** 2 + + reals[:, :, 1] ** 2 + + imags[:, :, 1] ** 2 + + reals[:, :, 2] ** 2 + + imags[:, :, 2] ** 2 + + bias**2 + ) + r = r[:, :, None] + else: + r = torch.sqrt(reals**2 + imags**2 + bias**2) + if x.requires_grad: + drdx = reals / r + drdy = imags / r + ctx.save_for_backward(h0o, h1o, h2o, drdx, drdy) + else: + z = x.new_zeros(1) + ctx.save_for_backward(h0o, h1o, h2o, z, z) + r = r - bias + del reals, imags + if combine_colour: + Z = torch.cat((ll, r[:, :, 0]), dim=1) + else: + Z = torch.cat((ll[:, None], r), dim=1) + + return Z + + @staticmethod + def backward(ctx, dZ): + dX = None + mode = ctx.mode + + if ctx.needs_input_grad[0]: + # Don't need to do time reverse as these filters are symmetric + # h0o, h1o, h2o, θ = ctx.saved_tensors + h0o, h1o, h2o, drdx, drdy = ctx.saved_tensors + + # Level 1 backward (time reversed biorthogonal analysis filters) + if ctx.combine_colour: + dYl, dr = dZ[:, :3], dZ[:, 3:] + dr = dr[:, :, None] + else: + dYl, dr = dZ[:, 0], dZ[:, 1:] + ll = 1 / 4 * F.interpolate(dYl, scale_factor=2, mode="nearest") + + reals = dr * drdx + imags = dr * drdy + dX = inv_j1_rot(ll, reals, imags, h0o, h1o, h2o, 1, 3, 4, mode) + + return (dX,) + (None,) * 6 + + +class ScatLayerj2_f(torch.autograd.Function): + """Function to do forward and backward passes of a single scattering + layer with the DTCWT biorthogonal filters.""" + + @staticmethod + def forward(ctx, x, h0o, h1o, h0a, h0b, h1a, h1b, mode, bias, combine_colour): + # bias = 1e-2 + # bias = 0 + ctx.in_shape = x.shape + batch, ch, r, c = x.shape + assert r % 8 == c % 8 == 0 + mode = int_to_mode(mode) + ctx.mode = mode + ctx.combine_colour = combine_colour + + # First order scattering + s0, reals, imags = fwd_j1(x, h0o, h1o, False, 1, mode) + if combine_colour: + s1_j1 = torch.sqrt( + reals[:, :, 0] ** 2 + + imags[:, :, 0] ** 2 + + reals[:, :, 1] ** 2 + + imags[:, :, 1] ** 2 + + reals[:, :, 2] ** 2 + + imags[:, :, 2] ** 2 + + bias**2 + ) + s1_j1 = s1_j1[:, :, None] + if x.requires_grad: + dsdx1 = reals / s1_j1 + dsdy1 = imags / s1_j1 + s1_j1 = s1_j1 - bias + + s0, reals, imags = fwd_j2plus(s0, h0a, h1a, h0b, h1b, False, 1, mode) + s1_j2 = torch.sqrt( + reals[:, :, 0] ** 2 + + imags[:, :, 0] ** 2 + + reals[:, :, 1] ** 2 + + imags[:, :, 1] ** 2 + + reals[:, :, 2] ** 2 + + imags[:, :, 2] ** 2 + + bias**2 + ) + s1_j2 = s1_j2[:, :, None] + if x.requires_grad: + dsdx2 = reals / s1_j2 + dsdy2 = imags / s1_j2 + s1_j2 = s1_j2 - bias + s0 = F.avg_pool2d(s0, 2) + + # Second order scattering + s1_j1 = s1_j1[:, :, 0] + s1_j1, reals, imags = fwd_j1(s1_j1, h0o, h1o, False, 1, mode) + s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) + if x.requires_grad: + dsdx2_1 = reals / s2_j1 + dsdy2_1 = imags / s2_j1 + q = s2_j1.shape + s2_j1 = s2_j1.view(q[0], 36, q[3], q[4]) + s2_j1 = s2_j1 - bias + s1_j1 = F.avg_pool2d(s1_j1, 2) + if x.requires_grad: + ctx.save_for_backward( + h0o, + h1o, + h0a, + h0b, + h1a, + h1b, + dsdx1, + dsdy1, + dsdx2, + dsdy2, + dsdx2_1, + dsdy2_1, + ) + else: + z = x.new_zeros(1) + ctx.save_for_backward(h0o, h1o, h0a, h0b, h1a, h1b, z, z, z, z, z, z) + + del reals, imags + Z = torch.cat((s0, s1_j1, s1_j2[:, :, 0], s2_j1), dim=1) + + else: + s1_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) + if x.requires_grad: + dsdx1 = reals / s1_j1 + dsdy1 = imags / s1_j1 + s1_j1 = s1_j1 - bias + + s0, reals, imags = fwd_j2plus(s0, h0a, h1a, h0b, h1b, False, 1, mode) + s1_j2 = torch.sqrt(reals**2 + imags**2 + bias**2) + if x.requires_grad: + dsdx2 = reals / s1_j2 + dsdy2 = imags / s1_j2 + s1_j2 = s1_j2 - bias + s0 = F.avg_pool2d(s0, 2) + + # Second order scattering + p = s1_j1.shape + s1_j1 = s1_j1.view(p[0], 6 * p[2], p[3], p[4]) + + s1_j1, reals, imags = fwd_j1(s1_j1, h0o, h1o, False, 1, mode) + s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) + if x.requires_grad: + dsdx2_1 = reals / s2_j1 + dsdy2_1 = imags / s2_j1 + q = s2_j1.shape + s2_j1 = s2_j1.view(q[0], 36, q[2] // 6, q[3], q[4]) + s2_j1 = s2_j1 - bias + s1_j1 = F.avg_pool2d(s1_j1, 2) + s1_j1 = s1_j1.view(p[0], 6, p[2], p[3] // 2, p[4] // 2) + + if x.requires_grad: + ctx.save_for_backward( + h0o, + h1o, + h0a, + h0b, + h1a, + h1b, + dsdx1, + dsdy1, + dsdx2, + dsdy2, + dsdx2_1, + dsdy2_1, + ) + else: + z = x.new_zeros(1) + ctx.save_for_backward(h0o, h1o, h0a, h0b, h1a, h1b, z, z, z, z, z, z) + + del reals, imags + Z = torch.cat((s0[:, None], s1_j1, s1_j2, s2_j1), dim=1) + + return Z + + @staticmethod + def backward(ctx, dZ): + dX = None + mode = ctx.mode + + if ctx.needs_input_grad[0]: + # Input has shape N, L, C, H, W + o_dim = 1 + h_dim = 3 + w_dim = 4 + + # Retrieve phase info + ( + h0o, + h1o, + h0a, + h0b, + h1a, + h1b, + dsdx1, + dsdy1, + dsdx2, + dsdy2, + dsdx2_1, + dsdy2_1, + ) = ctx.saved_tensors + + # Use the special properties of the filters to get the time reverse + h0o_t = h0o + h1o_t = h1o + h0a_t = h0b + h0b_t = h0a + h1a_t = h1b + h1b_t = h1a + + # Level 1 backward (time reversed biorthogonal analysis filters) + if ctx.combine_colour: + ds0, ds1_j1, ds1_j2, ds2_j1 = ( + dZ[:, :3], + dZ[:, 3:9], + dZ[:, 9:15], + dZ[:, 15:], + ) + ds1_j2 = ds1_j2[:, :, None] + + ds1_j1 = 1 / 4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest") + q = ds2_j1.shape + ds2_j1 = ds2_j1.view(q[0], 6, 6, q[2], q[3]) + + # Inverse second order scattering + reals = ds2_j1 * dsdx2_1 + imags = ds2_j1 * dsdy2_1 + ds1_j1 = inv_j1( + ds1_j1, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode + ) + ds1_j1 = ds1_j1[:, :, None] + + # Inverse first order scattering j=2 + ds0 = 1 / 4 * F.interpolate(ds0, scale_factor=2, mode="nearest") + # s = ds1_j2.shape + # ds1_j2 = ds1_j2.view(s[0], 6, s[1]//6, s[2], s[3]) + reals = ds1_j2 * dsdx2 + imags = ds1_j2 * dsdy2 + ds0 = inv_j2plus( + ds0, + reals, + imags, + h0a_t, + h1a_t, + h0b_t, + h1b_t, + o_dim, + h_dim, + w_dim, + mode, + ) + + # Inverse first order scattering j=1 + reals = ds1_j1 * dsdx1 + imags = ds1_j1 * dsdy1 + dX = inv_j1(ds0, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode) + else: + ds0, ds1_j1, ds1_j2, ds2_j1 = ( + dZ[:, 0], + dZ[:, 1:7], + dZ[:, 7:13], + dZ[:, 13:], + ) + p = ds1_j1.shape + ds1_j1 = ds1_j1.view(p[0], p[2] * 6, p[3], p[4]) + ds1_j1 = 1 / 4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest") + q = ds2_j1.shape + ds2_j1 = ds2_j1.view(q[0], 6, q[2] * 6, q[3], q[4]) + + # Inverse second order scattering + reals = ds2_j1 * dsdx2_1 + imags = ds2_j1 * dsdy2_1 + ds1_j1 = inv_j1( + ds1_j1, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode + ) + ds1_j1 = ds1_j1.view(p[0], 6, p[2], p[3] * 2, p[4] * 2) + + # Inverse first order scattering j=2 + ds0 = 1 / 4 * F.interpolate(ds0, scale_factor=2, mode="nearest") + # s = ds1_j2.shape + # ds1_j2 = ds1_j2.view(s[0], 6, s[1]//6, s[2], s[3]) + reals = ds1_j2 * dsdx2 + imags = ds1_j2 * dsdy2 + ds0 = inv_j2plus( + ds0, + reals, + imags, + h0a_t, + h1a_t, + h0b_t, + h1b_t, + o_dim, + h_dim, + w_dim, + mode, + ) + + # Inverse first order scattering j=1 + reals = ds1_j1 * dsdx1 + imags = ds1_j1 * dsdy1 + dX = inv_j1(ds0, reals, imags, h0o_t, h1o_t, o_dim, h_dim, w_dim, mode) + + return (dX,) + (None,) * 9 + + +class ScatLayerj2_rot_f(torch.autograd.Function): + """Function to do forward and backward passes of a single scattering + layer with the DTCWT bandpass biorthogonal and qshift filters .""" + + @staticmethod + def forward( + ctx, x, h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, mode, bias, combine_colour + ): + # bias = 1e-2 + # bias = 0 + ctx.in_shape = x.shape + batch, ch, r, c = x.shape + assert r % 8 == c % 8 == 0 + mode = int_to_mode(mode) + ctx.mode = mode + ctx.combine_colour = combine_colour + + # First order scattering + s0, reals, imags = fwd_j1_rot(x, h0o, h1o, h2o, False, 1, mode) + if combine_colour: + s1_j1 = torch.sqrt( + reals[:, :, 0] ** 2 + + imags[:, :, 0] ** 2 + + reals[:, :, 1] ** 2 + + imags[:, :, 1] ** 2 + + reals[:, :, 2] ** 2 + + imags[:, :, 2] ** 2 + + bias**2 + ) + s1_j1 = s1_j1[:, :, None] + if x.requires_grad: + dsdx1 = reals / s1_j1 + dsdy1 = imags / s1_j1 + s1_j1 = s1_j1 - bias + + s0, reals, imags = fwd_j2plus_rot( + s0, h0a, h1a, h0b, h1b, h2a, h2b, False, 1, mode + ) + s1_j2 = torch.sqrt( + reals[:, :, 0] ** 2 + + imags[:, :, 0] ** 2 + + reals[:, :, 1] ** 2 + + imags[:, :, 1] ** 2 + + reals[:, :, 2] ** 2 + + imags[:, :, 2] ** 2 + + bias**2 + ) + s1_j2 = s1_j2[:, :, None] + if x.requires_grad: + dsdx2 = reals / s1_j2 + dsdy2 = imags / s1_j2 + s1_j2 = s1_j2 - bias + s0 = F.avg_pool2d(s0, 2) + + # Second order scattering + s1_j1 = s1_j1[:, :, 0] + s1_j1, reals, imags = fwd_j1_rot(s1_j1, h0o, h1o, h2o, False, 1, mode) + s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) + if x.requires_grad: + dsdx2_1 = reals / s2_j1 + dsdy2_1 = imags / s2_j1 + q = s2_j1.shape + s2_j1 = s2_j1.view(q[0], 36, q[3], q[4]) + s2_j1 = s2_j1 - bias + s1_j1 = F.avg_pool2d(s1_j1, 2) + if x.requires_grad: + ctx.save_for_backward( + h0o, + h1o, + h2o, + h0a, + h0b, + h1a, + h1b, + h2a, + h2b, + dsdx1, + dsdy1, + dsdx2, + dsdy2, + dsdx2_1, + dsdy2_1, + ) + else: + z = x.new_zeros(1) + ctx.save_for_backward( + h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, z, z, z, z, z, z + ) + + del reals, imags + Z = torch.cat((s0, s1_j1, s1_j2[:, :, 0], s2_j1), dim=1) + else: + s1_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) + if x.requires_grad: + dsdx1 = reals / s1_j1 + dsdy1 = imags / s1_j1 + s1_j1 = s1_j1 - bias + + s0, reals, imags = fwd_j2plus_rot( + s0, h0a, h1a, h0b, h1b, h2a, h2b, False, 1, mode + ) + s1_j2 = torch.sqrt(reals**2 + imags**2 + bias**2) + if x.requires_grad: + dsdx2 = reals / s1_j2 + dsdy2 = imags / s1_j2 + s1_j2 = s1_j2 - bias + s0 = F.avg_pool2d(s0, 2) + + # Second order scattering + p = s1_j1.shape + s1_j1 = s1_j1.view(p[0], 6 * p[2], p[3], p[4]) + s1_j1, reals, imags = fwd_j1_rot(s1_j1, h0o, h1o, h2o, False, 1, mode) + s2_j1 = torch.sqrt(reals**2 + imags**2 + bias**2) + if x.requires_grad: + dsdx2_1 = reals / s2_j1 + dsdy2_1 = imags / s2_j1 + q = s2_j1.shape + s2_j1 = s2_j1.view(q[0], 36, q[2] // 6, q[3], q[4]) + s2_j1 = s2_j1 - bias + s1_j1 = F.avg_pool2d(s1_j1, 2) + s1_j1 = s1_j1.view(p[0], 6, p[2], p[3] // 2, p[4] // 2) + + if x.requires_grad: + ctx.save_for_backward( + h0o, + h1o, + h2o, + h0a, + h0b, + h1a, + h1b, + h2a, + h2b, + dsdx1, + dsdy1, + dsdx2, + dsdy2, + dsdx2_1, + dsdy2_1, + ) + else: + z = x.new_zeros(1) + ctx.save_for_backward( + h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, z, z, z, z, z, z + ) + + del reals, imags + Z = torch.cat((s0[:, None], s1_j1, s1_j2, s2_j1), dim=1) + + return Z + + @staticmethod + def backward(ctx, dZ): + dX = None + mode = ctx.mode + + if ctx.needs_input_grad[0]: + # Input has shape N, L, C, H, W + o_dim = 1 + h_dim = 3 + w_dim = 4 + + # Retrieve phase info + ( + h0o, + h1o, + h2o, + h0a, + h0b, + h1a, + h1b, + h2a, + h2b, + dsdx1, + dsdy1, + dsdx2, + dsdy2, + dsdx2_1, + dsdy2_1, + ) = ctx.saved_tensors + + # Use the special properties of the filters to get the time reverse + h0o_t = h0o + h1o_t = h1o + h2o_t = h2o + h0a_t = h0b + h0b_t = h0a + h1a_t = h1b + h1b_t = h1a + h2a_t = h2b + h2b_t = h2a + + # Level 1 backward (time reversed biorthogonal analysis filters) + if ctx.combine_colour: + ds0, ds1_j1, ds1_j2, ds2_j1 = ( + dZ[:, :3], + dZ[:, 3:9], + dZ[:, 9:15], + dZ[:, 15:], + ) + ds1_j2 = ds1_j2[:, :, None] + + # Inverse second order scattering + ds1_j1 = 1 / 4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest") + q = ds2_j1.shape + ds2_j1 = ds2_j1.view(q[0], 6, 6, q[2], q[3]) + + # Inverse second order scattering + reals = ds2_j1 * dsdx2_1 + imags = ds2_j1 * dsdy2_1 + ds1_j1 = inv_j1_rot( + ds1_j1, reals, imags, h0o_t, h1o_t, h2o_t, o_dim, h_dim, w_dim, mode + ) + ds1_j1 = ds1_j1[:, :, None] + + # Inverse first order scattering j=2 + ds0 = 1 / 4 * F.interpolate(ds0, scale_factor=2, mode="nearest") + # s = ds1_j2.shape + # ds1_j2 = ds1_j2.view(s[0], 6, s[1]//6, s[2], s[3]) + reals = ds1_j2 * dsdx2 + imags = ds1_j2 * dsdy2 + ds0 = inv_j2plus_rot( + ds0, + reals, + imags, + h0a_t, + h1a_t, + h0b_t, + h1b_t, + h2a_t, + h2b_t, + o_dim, + h_dim, + w_dim, + mode, + ) + + # Inverse first order scattering j=1 + reals = ds1_j1 * dsdx1 + imags = ds1_j1 * dsdy1 + dX = inv_j1_rot( + ds0, reals, imags, h0o_t, h1o_t, h2o_t, o_dim, h_dim, w_dim, mode + ) + else: + ds0, ds1_j1, ds1_j2, ds2_j1 = ( + dZ[:, 0], + dZ[:, 1:7], + dZ[:, 7:13], + dZ[:, 13:], + ) + + # Inverse second order scattering + p = ds1_j1.shape + ds1_j1 = ds1_j1.view(p[0], p[2] * 6, p[3], p[4]) + ds1_j1 = 1 / 4 * F.interpolate(ds1_j1, scale_factor=2, mode="nearest") + q = ds2_j1.shape + ds2_j1 = ds2_j1.view(q[0], 6, q[2] * 6, q[3], q[4]) + reals = ds2_j1 * dsdx2_1 + imags = ds2_j1 * dsdy2_1 + ds1_j1 = inv_j1_rot( + ds1_j1, reals, imags, h0o_t, h1o_t, h2o_t, o_dim, h_dim, w_dim, mode + ) + ds1_j1 = ds1_j1.view(p[0], 6, p[2], p[3] * 2, p[4] * 2) + + # Inverse first order scattering j=2 + ds0 = 1 / 4 * F.interpolate(ds0, scale_factor=2, mode="nearest") + # s = ds1_j2.shape + # ds1_j2 = ds1_j2.view(s[0], 6, s[1]//6, s[2], s[3]) + reals = ds1_j2 * dsdx2 + imags = ds1_j2 * dsdy2 + ds0 = inv_j2plus_rot( + ds0, + reals, + imags, + h0a_t, + h1a_t, + h0b_t, + h1b_t, + h2a_t, + h2b_t, + o_dim, + h_dim, + w_dim, + mode, + ) + + # Inverse first order scattering j=1 + reals = ds1_j1 * dsdx1 + imags = ds1_j1 * dsdy1 + dX = inv_j1_rot( + ds0, reals, imags, h0o_t, h1o_t, h2o_t, o_dim, h_dim, w_dim, mode + ) + + return (dX,) + (None,) * 12 diff --git a/mirtorch/vendors/pytorch_wavelets/utils.py b/mirtorch/vendors/pytorch_wavelets/utils.py new file mode 100644 index 0000000..1c8a491 --- /dev/null +++ b/mirtorch/vendors/pytorch_wavelets/utils.py @@ -0,0 +1,243 @@ +""" Useful utilities for testing the 2-D DTCWT with synthetic images""" + +from __future__ import absolute_import + +import functools +import numpy as np + + +def unpack(pyramid, backend="numpy"): + """Unpacks a pyramid give back the constituent parts. + + :param pyramid: The Pyramid of DTCWT transforms you wish to unpack + :param str backend: A string from 'numpy', 'opencl', or 'tf' indicating + which attributes you want to unpack from the pyramid. + + :returns: returns a generator which can be unpacked into the Yl, Yh and + Yscale components of the pyramid. The generator will only return 2 + values if the pyramid was created with the include_scale parameter set + to false. + + .. note:: + + You can still unpack a tf or opencl pyramid as if it were created by a + numpy. In this case it will return a numpy array, rather than the + backend specific array type. + """ + backend = backend.lower() + if backend == "numpy": + yield pyramid.lowpass + yield pyramid.highpasses + if pyramid.scales is not None: + yield pyramid.scales + elif backend == "opencl": + yield pyramid.cl_lowpass + yield pyramid.cl_highpasses + if pyramid.cl_scales is not None: + yield pyramid.cl_scales + elif backend == "tf": + yield pyramid.lowpass_op + yield pyramid.highpasses_ops + if pyramid.scales_ops is not None: + yield pyramid.scales_ops + + +def drawedge(theta, r, w, N): + """Generate an image of size N * N pels, of an edge going from 0 to 1 in + height at theta degrees to the horizontal (top of image = 1 if angle = 0). + r is a two-element vector, it is a coordinate in ij coords through which the + step should pass. + The shape of the intensity step is half a raised cosine w pels wide (w>=1). + + T. E . Gale's enhancement to drawedge() for MATLAB, transliterated + to Python by S. C. Forshaw, Nov. 2013.""" + + # convert theta from degrees to radians + thetar = np.array(theta * np.pi / 180) + + # Calculate image centre from given width + imCentre = (np.array([N, N]).T - 1) / 2 + 1 + + # Calculate values to subtract from the plane + r = np.array([np.cos(thetar), np.sin(thetar)]) * (-1) * (r - imCentre) + + # check width of raised cosine section + w = np.maximum(1, w) + + ramp = np.arange(0, N) - (N + 1) / 2 + hgrad = np.sin(thetar) * (-1) * np.ones([N, 1]) + vgrad = np.cos(thetar) * (-1) * np.ones([1, N]) + plane = ((hgrad * ramp) - r[0]) + ((ramp * vgrad).T - r[1]) + x = 0.5 + 0.5 * np.sin( + np.minimum(np.maximum(plane * (np.pi / w), np.pi / (-2)), np.pi / 2) + ) + + return x + + +def drawcirc(r, w, du, dv, N): + + """Generate an image of size N*N pels, containing a circle + radius r pels and centred at du,dv relative + to the centre of the image. The edge of the circle is a cosine shaped + edge of width w (from 10 to 90% points). + + Python implementation by S. C. Forshaw, November 2013.""" + + # check value of w to avoid dividing by zero + w = np.maximum(w, 1) + + # x plane + x = np.ones([N, 1]) * ((np.arange(0, N, 1, dtype="float") - (N + 1) / 2 - dv) / r) + + # y vector + y = ( + ((np.arange(0, N, 1, dtype="float") - (N + 1) / 2 - du) / r) * np.ones([1, N]) + ).T + + # Final circle image plane + p = 0.5 + 0.5 * np.sin( + np.minimum( + np.maximum( + (np.exp(np.array([-0.5]) * (x**2 + y**2)).T - np.exp((-0.5))) + * (r * 3 / w), # noqa + np.pi / (-2), + ), + np.pi / 2, + ) + ) + return p + + +def asfarray(X): + """Similar to :py:func:`numpy.asfarray` except that this function tries to + preserve the original datatype of X if it is already a floating point type + and will pass floating point arrays through directly without copying. + + """ + X = np.asanyarray(X) + return np.asfarray(X, dtype=X.dtype) + + +def appropriate_complex_type_for(X): + """Return an appropriate complex data type depending on the type of X. If X + is already complex, return that, if it is floating point return a complex + type of the appropriate size and if it is integer, choose an complex + floating point type depending on the result of :py:func:`numpy.asfarray`. + + """ + X = asfarray(X) + + if np.issubsctype(X.dtype, np.complex64) or np.issubsctype(X.dtype, np.complex128): + return X.dtype + elif np.issubsctype(X.dtype, np.float32): + return np.complex64 + elif np.issubsctype(X.dtype, np.float64): + return np.complex128 + + # God knows, err on the side of caution + return np.complex128 + + +def as_column_vector(v): + """Return *v* as a column vector with shape (N,1).""" + v = np.atleast_2d(v) + if v.shape[0] == 1: + return v.T + else: + return v + + +def reflect(x, minx, maxx): + """Reflect the values in matrix *x* about the scalar values *minx* and + *maxx*. Hence a vector *x* containing a long linearly increasing series is + converted into a waveform which ramps linearly up and down between *minx* + and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers + + 0.5), the ramps will have repeated max and min samples. + + .. codeauthor:: Rich Wareham , Aug 2013 + .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999. + + """ + x = np.asanyarray(x) + rng = maxx - minx + rng_by_2 = 2 * rng + mod = np.fmod(x - minx, rng_by_2) + normed_mod = np.where(mod < 0, mod + rng_by_2, mod) + out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx + return np.array(out, dtype=x.dtype) + + +def symm_pad_1d(l, m): + """Creates indices for symmetric padding. Works for 1-D. + + Inptus: + l (int): size of input + m (int): size of filter + """ + xe = reflect(np.arange(-m, l + m, dtype="int32"), -0.5, l - 0.5) + return xe + + +# note that this decorator ignores **kwargs +# From https://wiki.python.org/moin/PythonDecoratorLibrary#Alternate_memoize_as_nested_functions # noqa +def memoize(obj): + cache = obj.cache = {} + + @functools.wraps(obj) + def memoizer(*args, **kwargs): + if args not in cache: + cache[args] = obj(*args, **kwargs) + return cache[args] + + return memoizer + + +def stacked_2d_matrix_vector_prod(mats, vecs): + """ + Interpret *mats* and *vecs* as arrays of 2D matrices and vectors. I.e. + *mats* has shape PxQxNxM and *vecs* has shape PxQxM. The result + is a PxQxN array equivalent to: + + .. code:: + + result[i,j,:] = mats[i,j,:,:].dot(vecs[i,j,:]) + + for all valid row and column indices *i* and *j*. + """ + return np.einsum("...ij,...j->...i", mats, vecs) + + +def stacked_2d_vector_matrix_prod(vecs, mats): + """ + Interpret *mats* and *vecs* as arrays of 2D matrices and vectors. I.e. + *mats* has shape PxQxNxM and *vecs* has shape PxQxN. The result + is a PxQxM array equivalent to: + + .. code:: + + result[i,j,:] = mats[i,j,:,:].T.dot(vecs[i,j,:]) + + for all valid row and column indices *i* and *j*. + """ + vecshape = np.array(vecs.shape + (1,)) + vecshape[-1:-3:-1] = vecshape[-2:] + outshape = mats.shape[:-2] + (mats.shape[-1],) + return stacked_2d_matrix_matrix_prod(vecs.reshape(vecshape), mats).reshape( + outshape + ) # noqa + + +def stacked_2d_matrix_matrix_prod(mats1, mats2): + """ + Interpret *mats1* and *mats2* as arrays of 2D matrices. I.e. + *mats1* has shape PxQxNxM and *mats2* has shape PxQxMxR. The result + is a PxQxNxR array equivalent to: + + .. code:: + + result[i,j,:,:] = mats1[i,j,:,:].dot(mats2[i,j,:,:]) + + for all valid row and column indices *i* and *j*. + """ + return np.einsum("...ij,...jk->...ik", mats1, mats2) diff --git a/pyproject.toml b/pyproject.toml index 330a920..a1fda2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "MIRTorch" +version = "0.1.0" authors = [{ name = "Guanhua Wang", email = "guanhuaw@umich.edu" }] description = "a PyTorch-based image reconstruction toolbox" readme = "README.md" @@ -15,7 +16,7 @@ classifiers = [ "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", 'Development Status :: 2 - Pre-Alpha', - 'Topic :: Signal Processing', + 'Topic :: Scientific/Engineering :: Image Processing', ] dependencies = [ "torch>=1.13", @@ -29,10 +30,8 @@ dependencies = [ 'importlib-metadata; python_version<"3.8"', "einops", "matplotlib", - "pytorch_wavelets@git+https://github.com/fbcotter/pytorch_wavelets.git@8d2e3b4289beaea9aa89f7b1dbb290e448331197#egg=pytorch_wavelets", ] -dynamic = ["version"] [project.urls] # Optional "repository" = "https://github.com/guanhuaw/MIRTorch" @@ -43,3 +42,35 @@ test = ["coverage", "pytest"] [project.scripts] my-script = "my_package.module:function" + +[tool.ruff] +exclude = [ + "__pycache__", + ".github", + ".idea", + ".vscode", + "build", + "dist", + "docs", + "examples", + "tests", + "mirtorch/vendors/pytorch_wavelets", + "venv", + ".git", +] + +[tool.pyright] +exclude = [ + "__pycache__", + ".github", + ".idea", + ".vscode", + "build", + "dist", + "docs", + "examples", + "tests", + "mirtorch/vendors/pytorch_wavelets", + "venv", + ".git", +] diff --git a/tests/basics_tests.py b/tests/basics_tests.py index 10c2b45..4e3b58b 100644 --- a/tests/basics_tests.py +++ b/tests/basics_tests.py @@ -1,236 +1,154 @@ -import unittest -import sys -import os - -path = os.path.dirname(os.path.abspath(__file__)) -path = path[: path.rfind("/")] -sys.path.insert(0, path) -from mirtorch.linear import basics -from mirtorch.linear import wavelets +import pytest import torch import torch.nn.functional as F -from .utils import conv1D, conv2D - - -class TestBasic(unittest.TestCase): - def test_diag(self): - x = torch.randn(5, 5) - P = torch.randn(5, 5) - - diag = basics.Diag(P) - out = diag.apply(x) - - exp = P * x - assert torch.allclose(out, exp, rtol=1e-3) - - def test_conv1d_apply_simple(self): - x = torch.randn(1, 16, 50) - weight = torch.randn(33, 16, 3) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - x, weight = x.to(device), weight.to(device) - - conv = basics.Convolve1d(x.shape, weight) - out = conv.apply(x) - - # exp = F.conv1d(x, weight) - x = x.permute(0, 2, 1).detach().cpu().numpy() - weight = weight.permute(2, 1, 0).detach().cpu().numpy() - exp = conv1D(x, weight, stride=1, pad=0, dilation=0) - exp = torch.from_numpy(exp).to(device).permute(0, 2, 1) - assert torch.allclose(out, exp, rtol=1.5e-3) - - def test_conv2d_apply_simple(self): - x = torch.randn(1, 4, 5, 5) - weight = torch.randn(8, 4, 3, 3) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - x, weight = x.to(device), weight.to(device) - - conv = basics.Convolve2d(x.shape, weight) - out = conv.apply(x) - - # exp = F.conv2d(x, weight) - x = x.permute(0, 2, 3, 1).detach().cpu().numpy() - weight = weight.permute(2, 3, 1, 0).detach().cpu().numpy() - exp = conv2D(x, weight, stride=1, pad=0, dilation=0) - exp = torch.from_numpy(exp).to(device).permute(0, 3, 1, 2) - assert torch.allclose(out, exp, rtol=1e-3) - - def test_conv3d_apply_simple(self): - x = torch.randn(20, 16, 50, 10, 20) - weight = torch.randn(33, 16, 3, 3, 3) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - x, weight = x.to(device), weight.to(device) - - conv = basics.Convolve3d(x.shape, weight) - out = conv.apply(x) - - exp = F.conv3d(x, weight) - assert torch.allclose(out, exp, rtol=1e-3) - - def test_conv1d_apply_hard(self): - x = torch.randn(20, 16, 50) - weight = torch.randn(33, 16, 3) - bias = torch.randn(33) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - x, weight, bias = x.to(device), weight.to(device), bias.to(device) - - conv = basics.Convolve1d( - x.shape, weight, bias=bias, stride=2, padding=1, dilation=2 - ) - out = conv.apply(x) - - exp = F.conv1d(x, weight, bias=bias, stride=2, padding=1, dilation=2) - assert torch.allclose(out, exp, rtol=1e-3) - - def test_conv2d_apply_hard(self): - x = torch.randn(1, 4, 5, 5) - weight = torch.randn(8, 4, 3, 3) - bias = torch.randn(8) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - x, weight, bias = x.to(device), weight.to(device), bias.to(device) - - conv = basics.Convolve2d( - x.shape, weight, bias=bias, stride=3, padding=2, dilation=2 - ) - out = conv.apply(x) - - exp = F.conv2d(x, weight, bias=bias, stride=3, padding=2, dilation=2) - assert torch.allclose(out, exp, rtol=1e-3) - - def test_conv3d_apply_hard(self): - x = torch.randn(20, 16, 50, 10, 20) - weight = torch.randn(33, 16, 3, 3, 3) - bias = torch.randn(33) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - x, weight, bias = x.to(device), weight.to(device), bias.to(device) - - conv = basics.Convolve3d( - x.shape, weight, bias=bias, stride=3, padding=3, dilation=4 - ) - out = conv.apply(x) - - exp = F.conv3d(x, weight, bias=bias, stride=3, padding=3, dilation=4) - assert torch.allclose(out, exp, rtol=1e-3) - - def test_conv1d_adjoint_simple(self): - x = torch.randn(20, 16, 50) - weight = torch.randn(33, 16, 3) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - x, weight = x.to(device), weight.to(device) - - Ax = F.conv1d(x, weight) - conv = basics.Convolve1d(x.shape, weight) - out = conv.adjoint(Ax) - - exp = F.conv_transpose1d(Ax, weight) - assert torch.allclose(out, exp, rtol=1e-3) - - def test_conv2d_adjoint_simple(self): - x = torch.randn(1, 4, 5, 5) - weight = torch.randn(8, 4, 3, 3) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - x, weight = x.to(device), weight.to(device) - - Ax = F.conv2d(x, weight) - conv = basics.Convolve2d(x.shape, weight) - out = conv.adjoint(Ax) - - exp = F.conv_transpose2d(Ax, weight) - assert torch.allclose(out, exp, rtol=1e-3) - - def test_conv3d_adjoint_simple(self): - x = torch.randn(20, 16, 50, 10, 20) - weight = torch.randn(33, 16, 3, 3, 3) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - x, weight = x.to(device), weight.to(device) - - Ax = F.conv3d(x, weight) - conv = basics.Convolve3d(x.shape, weight) - out = conv.adjoint(Ax) - - exp = F.conv_transpose3d(Ax, weight) - assert torch.allclose(out, exp, rtol=1e-3) - - def test_conv1d_adjoint_hard(self): - x = torch.randn(20, 16, 50) - weight = torch.randn(33, 16, 3) - bias = torch.randn(16) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - x, weight, bias = x.to(device), weight.to(device), bias.to(device) - - Ax = F.conv1d(x, weight, stride=2, padding=1, dilation=2) - conv = basics.Convolve1d( - x.shape, weight, bias=bias, stride=2, padding=1, dilation=2 - ) - out = conv.adjoint(Ax) - - exp = F.conv_transpose1d(Ax, weight, bias=bias, stride=2, padding=1, dilation=2) - assert torch.allclose(out, exp, rtol=1e-3) - - def test_patch2d_forward(self): - x = torch.randn(2, 3, 10, 10) - kernel_size = 2 - stride = 1 - exp = torch.zeros(2, 3, 9, 9, 2, 2) - for ix in range(9): - for iy in range(9): - exp[:, :, ix, iy, :, :] = x[:, :, ix : ix + 2, iy : iy + 2] - P = basics.Patch2D(x.shape, kernel_size, stride) - out = P * x - assert torch.allclose(out, exp, rtol=1e-3) - - def test_patch2d_adjoint(self): - x = torch.randn(2, 3, 9, 9, 2, 2) - kernel_size = 2 - stride = 1 - exp = torch.zeros(2, 3, 10, 10) - for ix in range(9): - for iy in range(9): - exp[:, :, ix : ix + 2, iy : iy + 2] = ( - exp[:, :, ix : ix + 2, iy : iy + 2] + x[:, :, ix, iy, :, :] - ) - P = basics.Patch2D(exp.shape, kernel_size, stride) - out = P.H * x - - def test_patch3d_forward(self): - x = torch.randn(2, 3, 10, 10, 10) - kernel_size = 2 - stride = 1 - exp = torch.zeros(2, 3, 9, 9, 9, 2, 2, 2) - for ix in range(9): - for iy in range(9): - for iz in range(9): - exp[:, :, ix, iy, iz, :, :, :] = x[ - :, :, ix : ix + 2, iy : iy + 2, iz : iz + 2 - ] - P = basics.Patch3D(x.shape, kernel_size, stride) - out = P * x - assert torch.allclose(out, exp, rtol=1e-3) - - def test_patch3d_adjoint(self): - x = torch.randn(2, 3, 9, 9, 9, 2, 2, 2) - kernel_size = 2 - stride = 1 - exp = torch.zeros(2, 3, 10, 10, 10) - for ix in range(9): - for iy in range(9): - for iz in range(9): - exp[:, :, ix : ix + 2, iy : iy + 2, iz : iz + 2] = ( - exp[:, :, ix : ix + 2, iy : iy + 2, iz : iz + 2] - + x[:, :, ix, iy, iz, :, :, :] - ) - P = basics.Patch3D(exp.shape, kernel_size, stride) - out = P.H * x - - def test_wavelet2D(self): - x = torch.randn(1, 2, 101, 167) * (1 + 1j) - W = wavelets.Wavelet2D([1, 2, 101, 167], padding="periodization") - x = torch.randn(101, 167) * (1 + 1j) - W = wavelets.Wavelet2D([101, 167]) - assert torch.allclose(W.H * W * x, x, rtol=1e-3) - - -if __name__ == "__main__": - # if torch.cuda.is_available(): - # print(torch.cuda.get_device_name(0)) - unittest.main() +from mirtorch.linear import basics +from mirtorch.linear import wavelets + +@pytest.fixture +def device(): + """Fixture to handle the allocation of tensors to devices.""" + return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +# Here we use a fixture to initialize data used across multiple tests +@pytest.fixture +def setup_conv_data(device): + """Setup data for convolution tests.""" + x = torch.randn(20, 16, 50, device=device) + weight = torch.randn(33, 16, 3, device=device) + bias = torch.randn(33, device=device) + return x, weight, bias + +# Individual tests using the fixture for device +def test_diag(): + x = torch.randn(5, 5) + P = torch.randn(5, 5) + diag = basics.Diag(P) + out = diag.apply(x) + exp = P * x + assert torch.allclose(out, exp, rtol=1e-3) + +def test_conv1d_apply_simple(device): + x = torch.randn(1, 16, 50, device=device) + weight = torch.randn(33, 16, 3, device=device) + conv = basics.Convolve1d(x.shape, weight) + out = conv.apply(x) + exp = F.conv1d(x, weight) + assert torch.allclose(out, exp, rtol=1.5e-3) + +def test_conv2d_apply_simple(device): + x = torch.randn(1, 4, 5, 5, device=device) + weight = torch.randn(8, 4, 3, 3, device=device) + conv = basics.Convolve2d(x.shape, weight) + out = conv.apply(x) + exp = F.conv2d(x, weight) + assert torch.allclose(out, exp, rtol=1e-3) + +def test_conv3d_apply_simple(device): + x = torch.randn(20, 16, 50, 10, 20, device=device) + weight = torch.randn(33, 16, 3, 3, 3, device=device) + conv = basics.Convolve3d(x.shape, weight) + out = conv.apply(x) + exp = F.conv3d(x, weight) + assert torch.allclose(out, exp, rtol=1e-3) + +def test_conv1d_apply_hard(setup_conv_data): + x, weight, bias = setup_conv_data + conv = basics.Convolve1d(x.shape, weight, bias=bias, stride=2, padding=1, dilation=2) + out = conv.apply(x) + exp = F.conv1d(x, weight, bias, stride=2, padding=1, dilation=2) + assert torch.allclose(out, exp, rtol=1e-3) + +def test_conv2d_apply_hard(device): + x = torch.randn(1, 4, 5, 5, device=device) + weight = torch.randn(8, 4, 3, 3, device=device) + bias = torch.randn(8, device=device) + conv = basics.Convolve2d(x.shape, weight, bias=bias, stride=3, padding=2, dilation=2) + out = conv.apply(x) + exp = F.conv2d(x, weight, bias, stride=3, padding=2, dilation=2) + assert torch.allclose(out, exp, rtol=1e-3) + +def test_conv3d_apply_hard(device): + x = torch.randn(20, 16, 50, 10, 20, device=device) + weight = torch.randn(33, 16, 3, 3, 3, device=device) + bias = torch.randn(33, device=device) + conv = basics.Convolve3d(x.shape, weight, bias=bias, stride=3, padding=3, dilation=4) + out = conv.apply(x) + exp = F.conv3d(x, weight, bias, stride=3, padding=3, dilation=4) + assert torch.allclose(out, exp, rtol=1e-3) + +def test_conv1d_adjoint_simple(device): + x = torch.randn(20, 16, 50, device=device) + weight = torch.randn(33, 16, 3, device=device) + Ax = F.conv1d(x, weight) + conv = basics.Convolve1d(x.shape, weight) + out = conv.adjoint(Ax) + exp = F.conv_transpose1d(Ax, weight) + assert torch.allclose(out, exp, rtol=1e-3) + +def test_conv2d_adjoint_simple(device): + x = torch.randn(1, 4, 5, 5, device=device) + weight = torch.randn(8, 4, 3, 3, device=device) + Ax = F.conv2d(x, weight) + conv = basics.Convolve2d(x.shape, weight) + out = conv.adjoint(Ax) + exp = F.conv_transpose2d(Ax, weight) + assert torch.allclose(out, exp, rtol=1e-3) + +def test_conv3d_adjoint_simple(device): + x = torch.randn(20, 16, 50, 10, 20, device=device) + weight = torch.randn(33, 16, 3, 3, 3, device=device) + Ax = F.conv3d(x, weight) + conv = basics.Convolve3d(x.shape, weight) + out = conv.adjoint(Ax) + exp = F.conv_transpose3d(Ax, weight) + assert torch.allclose(out, exp, rtol=1e-3) + +def test_patch2d_forward(device): + x = torch.randn(2, 3, 10, 10, device=device) + kernel_size = 2 + stride = 1 + exp = torch.zeros(2, 3, 9, 9, 2, 2, device=device) + for ix in range(9): + for iy in range(9): + exp[:, :, ix, iy, :, :] = x[:, :, ix:ix+2, iy:iy+2] + P = basics.Patch2D(x.shape, kernel_size, stride) + out = P * x + assert torch.allclose(out, exp, rtol=1e-3) + +def test_patch2d_adjoint(device): + x = torch.randn(2, 3, 9, 9, 2, 2, device=device) + kernel_size = 2 + stride = 1 + exp = torch.zeros(2, 3, 10, 10, device=device) + for ix in range(9): + for iy in range(9): + exp[:, :, ix:ix+2, iy:iy+2] += x[:, :, ix, iy, :, :] + P = basics.Patch2D(exp.shape, kernel_size, stride) + out = P.H * x + assert torch.allclose(out, exp, rtol=1e-3) + +def test_patch3d_forward(device): + x = torch.randn(2, 3, 10, 10, 10, device=device) + kernel_size = 2 + stride = 1 + exp = torch.zeros(2, 3, 9, 9, 9, 2, 2, 2, device=device) + for ix in range(9): + for iy in range(9): + for iz in range(9): + exp[:, :, ix, iy, iz, :, :, :] = x[:, :, ix:ix+2, iy:iy+2, iz:iz+2] + P = basics.Patch3D(x.shape, kernel_size, stride) + out = P * x + assert torch.allclose(out, exp, rtol=1e-3) + +def test_patch3d_adjoint(device): + x = torch.randn(2, 3, 9, 9, 9, 2, 2, 2, device=device) + kernel_size = 2 + stride = 1 + exp = torch.zeros(2, 3, 10, 10, 10, device=device) + for ix in range(9): + for iy in range(9): + for iz in range(9): + exp[:, :, ix:ix+2, iy:iy+2, iz:iz+2] += x[:, :, ix, iy, iz, :, :, :] + P = basics.Patch3D(exp.shape, kernel_size, stride) + out = P.H * x + assert torch.allclose(out, exp, rtol=1e-3) diff --git a/tests/linops_tests.py b/tests/linops_tests.py index e69de29..68bc00e 100644 --- a/tests/linops_tests.py +++ b/tests/linops_tests.py @@ -0,0 +1,109 @@ +import pytest +import torch +from typing import List +from torch import Tensor +from mirtorch.linear import LinearMap, Add, Multiply, Matmul, ConjTranspose, BlockDiagonal, Kron, Vstack, Hstack + + +# Define a mock linear operator for testing purposes +class MockLinearOperator(LinearMap): + def _apply(self, x: Tensor) -> Tensor: + return 2 * x + + def _apply_adjoint(self, x: Tensor) -> Tensor: + return 0.5 * x + + +@pytest.fixture +def tensor1(): + return torch.tensor([1.0, 2.0, 3.0]) + + +@pytest.fixture +def tensor2(): + return torch.tensor([4.0, 5.0, 6.0]) + + +@pytest.fixture +def linear_operator(): + return MockLinearOperator([3], [3]) + + +def test_linear_map_initialization(): + lm = LinearMap([3], [3]) + assert lm.size_in == [3] + assert lm.size_out == [3] + + +def test_add_operator(tensor1, linear_operator): + op = Add(linear_operator, linear_operator) + result = op.apply(tensor1) + expected = 4 * tensor1 + assert torch.allclose(result, expected) + + +def test_multiply_operator(tensor1, linear_operator): + op = Multiply(linear_operator, 3) + result = op.apply(tensor1) + expected = linear_operator.apply(3 * tensor1) + assert torch.allclose(result, expected) + + +def test_matmul_operator(tensor1, linear_operator): + op = Matmul(linear_operator, linear_operator) + result = op.apply(tensor1) + expected = linear_operator.apply(linear_operator.apply(tensor1)) + assert torch.allclose(result, expected) + + +def test_conj_transpose_operator(tensor1, linear_operator): + op = ConjTranspose(linear_operator) + result = op.apply(tensor1) + expected = linear_operator.adjoint(tensor1) + assert torch.allclose(result, expected) + + +def test_block_diagonal_operator(tensor1, linear_operator): + op = BlockDiagonal([linear_operator, linear_operator]) + x = torch.stack([tensor1, tensor1], dim=-1) + result = op.apply(x) + expected = torch.stack([linear_operator.apply(tensor1), linear_operator.apply(tensor1)], dim=-1) + assert torch.allclose(result, expected) + + +def test_kron_operator(tensor1, linear_operator): + op = Kron(linear_operator, 2) + x = torch.stack([tensor1, tensor1], dim=-1) + result = op.apply(x) + expected = torch.stack([linear_operator.apply(tensor1), linear_operator.apply(tensor1)], dim=-1) + assert torch.allclose(result, expected) + + +def test_vstack_operator(tensor1, linear_operator): + op = Vstack([linear_operator, linear_operator]) + result = op.apply(tensor1) + expected = torch.cat([linear_operator.apply(tensor1), linear_operator.apply(tensor1)]) + assert torch.allclose(result, expected) + assert result.shape == (6,) # 3 + 3 + + # Test adjoint + adjoint_input = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + adjoint_result = op.adjoint(adjoint_input) + expected_adjoint = linear_operator.adjoint(adjoint_input[:3]) + linear_operator.adjoint(adjoint_input[3:]) + assert torch.allclose(adjoint_result, expected_adjoint) + assert adjoint_result.shape == (3,) + +def test_hstack_operator(tensor1, tensor2, linear_operator): + op = Hstack([linear_operator, linear_operator]) + input_tensor = torch.cat([tensor1, tensor2]) + result = op.apply(input_tensor) + expected = linear_operator.apply(tensor1) + linear_operator.apply(tensor2) + assert torch.allclose(result, expected) + assert result.shape == (3,) + + # Test adjoint + adjoint_input = torch.tensor([1.0, 2.0, 3.0]) + adjoint_result = op.adjoint(adjoint_input) + expected_adjoint = torch.cat([linear_operator.adjoint(adjoint_input), linear_operator.adjoint(adjoint_input)]) + assert torch.allclose(adjoint_result, expected_adjoint) + assert adjoint_result.shape == (6,) # 3 diff --git a/tests/mri_tests.py b/tests/mri_tests.py index e69de29..3e9f642 100644 --- a/tests/mri_tests.py +++ b/tests/mri_tests.py @@ -0,0 +1,76 @@ +import pytest +import torch +import numpy as np +from mirtorch.linear import FFTCn, Sense, NuSense, NuSenseGram + +@pytest.fixture +def complex_tensor(): + return torch.complex(torch.randn(2, 1, 16, 16), torch.randn(2, 1, 16, 16)) + +@pytest.fixture +def smaps(): + return torch.complex(torch.randn(2, 4, 16, 16), torch.randn(2, 4, 16, 16)) + +@pytest.fixture +def masks(): + return torch.randint(0, 2, (2, 16, 16)).float() + +@pytest.fixture +def traj(): + return torch.rand(2, 2, 1000) * 2 - 1 + +def test_fftcn_forward_backward(complex_tensor): + fftcn = FFTCn([2, 1, 16, 16], [2, 1, 16, 16], dims=(2, 3)) + k_space = fftcn(complex_tensor) + image = fftcn.H(k_space) + assert torch.allclose(complex_tensor, image, atol=1e-6) + +def test_fftcn_adjoint_property(complex_tensor): + fftcn = FFTCn([2, 1, 16, 16], [2, 1, 16, 16], dims=(2, 3)) + k_space = torch.randn_like(complex_tensor) + lhs = torch.sum(fftcn(complex_tensor).conj() * k_space) + rhs = torch.sum(complex_tensor.conj() * fftcn.H(k_space)) + assert torch.allclose(lhs, rhs, atol=1e-6) + +def test_sense_forward_backward(complex_tensor, smaps, masks): + sense = Sense(smaps, masks) + k_space = sense(complex_tensor) + image = sense.H(k_space) + assert k_space.shape == (2, 4, 16, 16) + assert image.shape == (2, 1, 16, 16) + assert not torch.allclose(complex_tensor, image, atol=1e-6) # Due to undersampling + +def test_sense_adjoint_property(complex_tensor, smaps, masks): + sense = Sense(smaps, masks) + k_space = torch.randn(2, 4, 16, 16, dtype=torch.complex64) + lhs = torch.sum(sense(complex_tensor).conj() * k_space) + rhs = torch.sum(complex_tensor.conj() * sense.H(k_space)) + assert torch.allclose(lhs, rhs, atol=1e-6) + +def test_nusense_forward_backward(complex_tensor, smaps, traj): + nusense = NuSense(smaps, traj) + k_space = nusense(complex_tensor) + image = nusense.H(k_space) + assert k_space.shape == (2, 4, 1000) + assert image.shape == (2, 1, 16, 16) + assert not torch.allclose(complex_tensor, image, atol=1e-6) # Due to non-Cartesian sampling + +def test_nusense_adjoint_property(complex_tensor, smaps, traj): + nusense = NuSense(smaps, traj) + k_space = torch.randn(2, 4, 1000, dtype=torch.complex64) + lhs = torch.sum(nusense(complex_tensor).conj() * k_space) + rhs = torch.sum(complex_tensor.conj() * nusense.H(k_space)) + assert torch.allclose(lhs, rhs, atol=1e-6) + +def test_nusense_gram_forward(complex_tensor, smaps, traj): + nusense_gram = NuSenseGram(smaps, traj) + output = nusense_gram(complex_tensor) + assert output.shape == complex_tensor.shape + assert not torch.allclose(complex_tensor, output, atol=1e-6) + +def test_nusense_gram_adjoint_property(complex_tensor, smaps, traj): + nusense_gram = NuSenseGram(smaps, traj) + y = torch.randn_like(complex_tensor) + lhs = torch.sum(nusense_gram(complex_tensor).conj() * y) + rhs = torch.sum(complex_tensor.conj() * nusense_gram.H(y)) + assert torch.allclose(lhs, rhs, atol=1e-6) diff --git a/tests/prox_tests.py b/tests/prox_tests.py index 50fd810..cd50331 100644 --- a/tests/prox_tests.py +++ b/tests/prox_tests.py @@ -1,131 +1,132 @@ -import unittest -import sys -import os - -path = os.path.dirname(os.path.abspath(__file__)) -path = path[: path.rfind("/")] -sys.path.insert(0, path) +import pytest import torch import numpy as np import numpy.testing as npt -import mirtorch - - -class TestProx(unittest.TestCase): - def test_l1(self): - lambd = np.random.random() - prox = mirtorch.prox.L1Regularizer(lambd) - a = torch.rand((5, 4, 8), dtype=torch.float) - exp = np.zeros((5, 4, 8)).flatten() - out = prox(a, 0.1) - - Lambd = 0.1 * lambd - a = a.numpy().flatten() - for i in range(a.shape[0]): - if a[i] > Lambd: - exp[i] = a[i] - Lambd - elif a[i] < -Lambd: - exp[i] = a[i] + Lambd - else: - exp[i] = 0 - - exp = exp.reshape((5, 4, 8)) - - npt.assert_allclose(out, exp, rtol=1e-3) - - def test_l2(self): - a = torch.rand((3, 4, 2, 1), dtype=torch.float) - lambd = np.random.random() - prox = mirtorch.prox.L2Regularizer(lambd) - exp = 1.0 - lambd * 0.1 / max(np.linalg.norm(a.numpy()), lambd * 0.1) - npt.assert_allclose(prox(a, 0.1), exp * a.numpy(), rtol=1e-3) - - def test_squaredl2(self): - a = torch.rand((3, 4, 2, 1), dtype=torch.float) - lambd = np.random.random() - prox = mirtorch.prox.SquaredL2Regularizer(lambd) - exp = a.numpy() / (1.0 + 2 * lambd * 0.1) - npt.assert_allclose(prox(a, 0.1), exp, rtol=1e-3) - - def test_boxconstraint(self): - lambd = np.random.random() - lower, upper = np.random.randint(0, 10), np.random.randint(10, 20) - prox = mirtorch.prox.BoxConstraint(lambd, lower, upper) - a = 100 * torch.rand((5, 4, 8), dtype=torch.float) - out = prox(a, 0.1) - exp = np.clip(a.numpy(), lower, upper) - npt.assert_allclose(out, exp, rtol=1e-3) - - def test_l0_complex(self): - lambd = np.random.random() - prox = mirtorch.prox.L0Regularizer(lambd) - a = torch.rand(2, 2, dtype=torch.cfloat, requires_grad=True) - out = prox(a, 0.1) - torch.sum(out).backward() - a.requires_grad = False - an = a.numpy() - exp = torch.from_numpy(an * (np.abs(an) > (lambd * 0.1))).to(out) - npt.assert_allclose(out.detach(), exp, rtol=1e-3) - - def test_l1_complex(self): - lambd = np.random.random() - prox = mirtorch.prox.L1Regularizer(lambd) - a = torch.rand(2, 2, dtype=torch.cfloat, requires_grad=True) - out = prox(a, 0.1) - torch.sum(out).backward() - a.requires_grad = False - exp = torch.exp(1j * a.angle()) * prox(a.abs(), 0.1) - npt.assert_allclose(out.detach(), exp, rtol=1e-3) - - def test_l2_complex(self): - lambd = np.random.random() - prox = mirtorch.prox.L2Regularizer(lambd) - a = torch.rand(2, 2, dtype=torch.cfloat, requires_grad=True) - out = prox(a, 0.1) - torch.sum(out).backward() - a.requires_grad = False - exp = torch.exp(1j * a.angle()) * prox(a.abs(), 0.1) - npt.assert_allclose(out.detach(), exp, rtol=1e-3) - - def test_squaredl2_complex(self): - lambd = np.random.random() - prox = mirtorch.prox.SquaredL2Regularizer(lambd) - a = torch.rand(2, 2, dtype=torch.cfloat, requires_grad=True) - out = prox(a, 0.1) - torch.sum(out).backward() - a.requires_grad = False - exp = torch.exp(1j * a.angle()) * prox(a.abs(), 0.1) - npt.assert_allclose(out.detach(), exp, rtol=1e-3) - - def test_angle(self): - a = torch.complex(torch.Tensor([1]), torch.Tensor([-1])) - npt.assert_allclose(a.angle(), torch.atan2(a.imag, a.real)) - - def test_boxconstraint_complex(self): - lambd = np.random.random() - lower, upper = np.random.randint(0, 10), np.random.randint(10, 20) - prox = mirtorch.prox.BoxConstraint(lambd, lower, upper) - a = torch.rand(2, 2, dtype=torch.cfloat, requires_grad=True) - out = prox(a, 0.1) - torch.sum(out).backward() - a.requires_grad = False - exp = torch.exp(1j * a.angle()) * prox(a.abs(), 0.1) - npt.assert_allclose(out.detach(), exp, rtol=1e-3) - - def test_complex_edge_cases(self): - a = torch.complex(torch.Tensor([1]), torch.Tensor([0])) - npt.assert_allclose(a.angle(), torch.atan2(a.imag, a.real)) - - def test_complex_edge_cases2(self): - a = torch.complex(torch.Tensor([0]), torch.Tensor([1])) - npt.assert_allclose(a.angle(), torch.atan2(a.imag, a.real)) - - def test_complex_edge_cases3(self): - # Should we ever need to worry about this issue? - a = torch.complex(torch.Tensor([0]), torch.Tensor([0])) - npt.assert_allclose(a.angle(), torch.atan2(a.imag, a.real)) - - -if __name__ == "__main__": - print(f"PyTorch version: {torch.__version__}") - unittest.main() +from mirtorch.prox import ( + Prox, + L1Regularizer, + L2Regularizer, + SquaredL2Regularizer, + BoxConstraint, + L0Regularizer, + Conj, + Const, +) + +# Fixtures for common test data +@pytest.fixture +def random_tensor(): + return torch.rand((5, 4, 8), dtype=torch.float) + +@pytest.fixture +def random_lambda(): + return np.abs(np.random.random()) + +@pytest.fixture +def random_tensor_complex(): + return torch.rand(2, 2, dtype=torch.float, requires_grad=True) + +# Test cases +def test_l1_regularizer(random_tensor, random_lambda): + prox = L1Regularizer(random_lambda) + out = prox(random_tensor, 0.1) + + lambd = 0.1 * random_lambda + a = random_tensor.numpy().flatten() + exp = np.zeros_like(a) + + for i in range(a.shape[0]): + if a[i] > lambd: + exp[i] = a[i] - lambd + elif a[i] < -lambd: + exp[i] = a[i] + lambd + else: + exp[i] = 0 + + exp = exp.reshape(random_tensor.shape) + npt.assert_allclose(out, exp, rtol=1e-3) + +def test_l2_regularizer(random_tensor, random_lambda): + prox = L2Regularizer(random_lambda) + out = prox(random_tensor, 0.1) + + exp = 1.0 - random_lambda * 0.1 / max(np.linalg.norm(random_tensor.numpy()), random_lambda * 0.1) + npt.assert_allclose(out, exp * random_tensor.numpy(), rtol=1e-3) + +def test_squaredl2_regularizer(random_tensor, random_lambda): + prox = SquaredL2Regularizer(random_lambda) + out = prox(random_tensor, 0.1) + + exp = random_tensor.numpy() / (1.0 + 2 * random_lambda * 0.1) + npt.assert_allclose(out, exp, rtol=1e-3) + +def test_boxconstraint(random_tensor, random_lambda): + lower, upper = np.random.randint(0, 10), np.random.randint(10, 20) + prox = BoxConstraint(random_lambda, lower, upper) + out = prox(random_tensor, 0.1) + + exp = np.clip(random_tensor.numpy(), lower, upper) + npt.assert_allclose(out, exp, rtol=1e-3) + +def test_l0_regularizer_complex(random_tensor_complex, random_lambda): + prox = L0Regularizer(random_lambda) + out = prox(random_tensor_complex, 0.1) + torch.sum(out).backward() + + random_tensor_complex.requires_grad = False + an = random_tensor_complex.numpy() + exp = torch.from_numpy(an * (np.abs(an) > (random_lambda * 0.1))).to(out) + npt.assert_allclose(out.detach(), exp, rtol=1e-3) + +def test_l1_regularizer_complex(random_tensor_complex, random_lambda): + prox = L1Regularizer(random_lambda) + out = prox(random_tensor_complex, 0.1) + torch.sum(out).backward() + + random_tensor_complex.requires_grad = False + exp = torch.exp(1j * random_tensor_complex.angle()) * prox(random_tensor_complex.abs(), 0.1) + npt.assert_allclose(out.detach(), exp, rtol=1e-3) + +def test_l2_regularizer_complex(random_tensor_complex, random_lambda): + prox = L2Regularizer(random_lambda) + out = prox(random_tensor_complex, 0.1) + torch.sum(out).backward() + + random_tensor_complex.requires_grad = False + exp = torch.exp(1j * random_tensor_complex.angle()) * prox(random_tensor_complex.abs(), 0.1) + npt.assert_allclose(out.detach(), exp, rtol=1e-3) + +def test_squaredl2_regularizer_complex(random_tensor_complex, random_lambda): + prox = SquaredL2Regularizer(random_lambda) + out = prox(random_tensor_complex, 0.1) + torch.sum(out).backward() + + random_tensor_complex.requires_grad = False + exp = torch.exp(1j * random_tensor_complex.angle()) * prox(random_tensor_complex.abs(), 0.1) + npt.assert_allclose(out.detach(), exp, rtol=1e-3) + +def test_angle(): + a = torch.complex(torch.Tensor([1]), torch.Tensor([-1])) + npt.assert_allclose(a.angle(), torch.atan2(a.imag, a.real)) + +def test_boxconstraint_complex(random_tensor_complex, random_lambda): + lower, upper = np.random.randint(0, 10), np.random.randint(10, 20) + prox = BoxConstraint(random_lambda, lower, upper) + out = prox(random_tensor_complex, 0.1) + torch.sum(out).backward() + + random_tensor_complex.requires_grad = False + exp = torch.exp(1j * random_tensor_complex.angle()) * prox(random_tensor_complex.abs(), 0.1) + npt.assert_allclose(out.detach(), exp, rtol=1e-3) + +def test_complex_edge_cases(): + a = torch.complex(torch.Tensor([1]), torch.Tensor([0])) + npt.assert_allclose(a.angle(), torch.atan2(a.imag, a.real)) + +def test_complex_edge_cases2(): + a = torch.complex(torch.Tensor([0]), torch.Tensor([1])) + npt.assert_allclose(a.angle(), torch.atan2(a.imag, a.real)) + +def test_complex_edge_cases3(): + a = torch.complex(torch.Tensor([0]), torch.Tensor([0])) + npt.assert_allclose(a.angle(), torch.atan2(a.imag, a.real)) diff --git a/tests/spect_tests.py b/tests/spect_tests.py index e2d20f8..e909402 100644 --- a/tests/spect_tests.py +++ b/tests/spect_tests.py @@ -1,58 +1,112 @@ -# spect_tests.py -""" -Adjoint tests for SPECT forward-backward projector -Author: Zongyu Li, zonyul@umich.edu -""" - -import unittest -import sys -import os - -path = os.path.dirname(os.path.abspath(__file__)) -path = path[: path.rfind("/")] -sys.path.insert(0, path) -from mirtorch.linear.spect import SPECT +# # spect_tests.py +# """ +# Adjoint tests for SPECT forward-backward projector +# Author: Zongyu Li, zonyul@umich.edu +# """ + +# import unittest +# import sys +# import os + +# path = os.path.dirname(os.path.abspath(__file__)) +# path = path[: path.rfind("/")] +# sys.path.insert(0, path) +# from mirtorch.linear.spect import SPECT +# import torch + + +# def gen_data(nx, ny, nz, nview, px, pz): +# img = torch.zeros(nx, ny, nz) +# img[1:-1, 1:-1, 1:-1] = torch.rand(nx - 2, ny - 2, nz - 2) +# view = torch.zeros(nx, nz, nview) +# view[1:-1, 1:-1, 1:-1] = torch.rand(nx - 2, nz - 2, nview - 2) +# mumap = torch.zeros(nx, ny, nz) +# mumap[1:-1, 1:-1, 1:-1] = torch.rand(nx - 2, ny - 2, nz - 2) +# psfs = torch.ones(px, pz, ny, nview) / (px * pz) +# return img, view, mumap, psfs + + +# class TestSPECT(unittest.TestCase): +# def test_adjoint(self): +# torch.manual_seed(42) +# nx = 8 +# ny = 8 +# nz = 6 +# nview = 9 +# px = 3 +# pz = 3 +# dy = 4.8 + +# img, view, mumap, psfs = gen_data(nx, ny, nz, nview, px, pz) +# SPECT_sys = SPECT( +# size_in=(nx, ny, nz), +# size_out=(nx, nz, nview), +# mumap=mumap, +# psfs=psfs, +# dy=dy, +# ) +# out1 = SPECT_sys * img +# out2 = SPECT_sys.H * view + +# test1 = torch.dot(out1.reshape(-1), view.reshape(-1)) +# test2 = torch.dot(out2.reshape(-1), img.reshape(-1)) +# assert torch.allclose(test1, test2, rtol=5e-3) + + +# if __name__ == "__main__": +# t = TestSPECT() +# t.test_adjoint() + +import pytest import torch +from mirtorch.linear.spect import SPECT, project, backproject + +@pytest.fixture +def mumap(): + return torch.rand((32, 32, 32), dtype=torch.float32) + +@pytest.fixture +def psfs(): + return torch.rand((16, 16, 32, 60), dtype=torch.float32) + +@pytest.fixture +def dy(): + return 1.0 + +@pytest.fixture +def input_tensor(mumap): + return torch.rand(mumap.shape, dtype=torch.float32) + +@pytest.fixture +def view_tensor(psfs): + return torch.rand((32, 32, 60), dtype=torch.float32) + +def test_spect_init(mumap, psfs, dy): + size_in = mumap.shape + size_out = [mumap.shape[0], mumap.shape[2], psfs.shape[-1]] + spect = SPECT(size_in, size_out, mumap, psfs, dy) + assert spect.mumap.shape == mumap.shape + assert spect.psfs.shape == psfs.shape + assert spect.dy == dy + +def test_project(input_tensor, mumap, psfs, dy): + views = project(input_tensor, mumap, psfs, dy) + assert views.shape == (32, 32, 60) + +def test_backproject(view_tensor, mumap, psfs, dy): + image = backproject(view_tensor, mumap, psfs, dy) + assert image.shape == mumap.shape +def test_spect_apply(input_tensor, mumap, psfs, dy): + size_in = mumap.shape + size_out = [mumap.shape[0], mumap.shape[2], psfs.shape[-1]] + spect = SPECT(size_in, size_out, mumap, psfs, dy) + result = spect._apply(input_tensor) + assert result.shape == (32, 32, 60) -def gen_data(nx, ny, nz, nview, px, pz): - img = torch.zeros(nx, ny, nz) - img[1:-1, 1:-1, 1:-1] = torch.rand(nx - 2, ny - 2, nz - 2) - view = torch.zeros(nx, nz, nview) - view[1:-1, 1:-1, 1:-1] = torch.rand(nx - 2, nz - 2, nview - 2) - mumap = torch.zeros(nx, ny, nz) - mumap[1:-1, 1:-1, 1:-1] = torch.rand(nx - 2, ny - 2, nz - 2) - psfs = torch.ones(px, pz, ny, nview) / (px * pz) - return img, view, mumap, psfs - - -class TestSPECT(unittest.TestCase): - def test_adjoint(self): - torch.manual_seed(42) - nx = 8 - ny = 8 - nz = 6 - nview = 9 - px = 3 - pz = 3 - dy = 4.8 - - img, view, mumap, psfs = gen_data(nx, ny, nz, nview, px, pz) - SPECT_sys = SPECT( - size_in=(nx, ny, nz), - size_out=(nx, nz, nview), - mumap=mumap, - psfs=psfs, - dy=dy, - ) - out1 = SPECT_sys * img - out2 = SPECT_sys.H * view - - test1 = torch.dot(out1.reshape(-1), view.reshape(-1)) - test2 = torch.dot(out2.reshape(-1), img.reshape(-1)) - assert torch.allclose(test1, test2, rtol=5e-3) - - -if __name__ == "__main__": - t = TestSPECT() - t.test_adjoint() +def test_spect_apply_adjoint(view_tensor, mumap, psfs, dy): + size_in = mumap.shape + size_out = [mumap.shape[0], mumap.shape[2], psfs.shape[-1]] + spect = SPECT(size_in, size_out, mumap, psfs, dy) + result = spect._apply_adjoint(view_tensor) + assert result.shape == mumap.shape diff --git a/tests/util_tests.py b/tests/util_tests.py new file mode 100644 index 0000000..73a553b --- /dev/null +++ b/tests/util_tests.py @@ -0,0 +1,109 @@ +import pytest +import torch +import torchvision.transforms.functional as F + +from mirtorch.linear.util import ( + finitediff, + finitediff_adj, + fftshift, + ifftshift, + dim_conv, + imrotate, + fft2, + ifft2, + pad2sizezero, + fft_conv, + fft_conv_adj, + map2x, + map2y, + integrate1D +) + +@pytest.fixture +def tensor_2d(): + return torch.rand((4, 4)) + +@pytest.fixture +def tensor_3d(): + return torch.rand((3, 4, 4)) + +@pytest.fixture +def tensor_4d(): + return torch.rand((2, 3, 4, 4)) + +def test_finitediff(tensor_2d): + result = finitediff(tensor_2d, dim=1, mode='reflexive') + assert result.shape == (4, 3) + +def test_finitediff_periodic(tensor_2d): + result = finitediff(tensor_2d, dim=1, mode="periodic") + assert result.shape == (4, 4) + +def test_finitediff_adj(tensor_2d): + result = finitediff_adj(tensor_2d, dim=1, mode='reflexive') + assert result.shape == (4, 5) + +def test_finitediff_adj_periodic(tensor_2d): + result = finitediff_adj(tensor_2d, dim=1, mode="periodic") + assert result.shape == (4, 4) + +def test_fftshift(tensor_2d): + result = fftshift(tensor_2d) + assert result.shape == tensor_2d.shape + +def test_ifftshift(tensor_2d): + result = ifftshift(tensor_2d) + assert result.shape == tensor_2d.shape + +def test_dim_conv(): + result = dim_conv(32, 3, dim_stride=2, dim_padding=1) + assert result == 16 + +def test_imrotate(tensor_4d): + angle = 45 + result = imrotate(tensor_4d, angle) + assert result.shape == tensor_4d.shape + +def test_fft2(tensor_2d): + result = fft2(tensor_2d) + assert result.shape == tensor_2d.shape + +def test_ifft2(tensor_2d): + result = ifft2(tensor_2d) + assert result.shape == tensor_2d.shape + +def test_pad2sizezero(tensor_2d): + result = pad2sizezero(tensor_2d, 6, 6) + assert result.shape == (6, 6) + +def test_fft_conv(tensor_2d): + ker = torch.rand((3, 3)) + result = fft_conv(tensor_2d, ker) + assert result.shape == tensor_2d.shape + +def test_fft_conv_adj(tensor_2d): + ker = torch.rand((3, 3)) + result = fft_conv_adj(tensor_2d, ker) + assert result.shape == tensor_2d.shape + +def test_map2x(): + x1 = torch.tensor(1.0) + y1 = torch.rand((4, 4)) + x2 = torch.tensor(0.0) + y2 = torch.rand((4, 4)) + result = map2x(x1, y1, x2, y2) + assert result.shape == y1.shape + +def test_map2y(): + x1 = torch.tensor(1.0) + y1 = torch.rand((4, 4)) + x2 = torch.tensor(0.0) + y2 = torch.rand((4, 4)) + result = map2y(x1, y1, x2, y2) + assert result.shape == y1.shape + +def test_integrate1D(): + p_v = torch.rand((4,)) + pixelSize = torch.tensor([1.0, 1.0, 1.0, 1.0]) + result = integrate1D(p_v, pixelSize) + assert result.shape == (5,) From f055e51ad25b887bd43a5cad8f5f0174b2237e36 Mon Sep 17 00:00:00 2001 From: guanhuaw Date: Sat, 27 Jul 2024 22:11:02 -0700 Subject: [PATCH 5/6] updated ci --- .github/workflows/python-ci.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index e8a646f..6d162bb 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -17,11 +17,12 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" # Specify the Python version you need + cache: 'pip' # caching pip dependencies - name: Install dependencies run: | @@ -31,7 +32,7 @@ jobs: - name: Lint with Ruff run: | - ruff ./mirtorch + ruff check ./mirtorch - name: Test with pytest run: | From 9f24708bb3fb67ff28c6ef351627355f2bcb7664 Mon Sep 17 00:00:00 2001 From: guanhuaw Date: Sat, 27 Jul 2024 22:24:36 -0700 Subject: [PATCH 6/6] update ci --- .github/workflows/python-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 6d162bb..0420046 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -36,7 +36,7 @@ jobs: - name: Test with pytest run: | - pytest ./tests + pytest ./tests/* - name: Automated Version Bump if: github.ref == 'refs/heads/master'