Skip to content

Commit

Permalink
Merge pull request #81 from v0lta/typing-everywhere
Browse files Browse the repository at this point in the history
Type annotate the tests
  • Loading branch information
v0lta authored Feb 1, 2024
2 parents f390042 + 6401034 commit 38901ce
Show file tree
Hide file tree
Showing 23 changed files with 283 additions and 183 deletions.
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def mypy(session):
"--no-warn-return-any",
"--explicit-package-bases",
"src",
"tests",
)


Expand Down
2 changes: 1 addition & 1 deletion src/ptwt/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pywt
import torch

from ptwt.constants import OrthogonalizeMethod
from .constants import OrthogonalizeMethod


class Wavelet(Protocol):
Expand Down
8 changes: 4 additions & 4 deletions src/ptwt/continuous_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,10 @@ def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
shannon = (
torch.sqrt(self.bandwidth)
* (
torch.sin(torch.pi * self.bandwidth * grid_values) # type: ignore
torch.sin(torch.pi * self.bandwidth * grid_values)
/ (torch.pi * self.bandwidth * grid_values)
)
* torch.exp(1j * 2 * torch.pi * self.center * grid_values) # type: ignore
* torch.exp(1j * 2 * torch.pi * self.center * grid_values)
)
return shannon

Expand All @@ -306,8 +306,8 @@ def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
"""Return numerical values for the wavelet on a grid."""
morlet = (
1.0
/ torch.sqrt(torch.pi * self.bandwidth) # type: ignore
/ torch.sqrt(torch.pi * self.bandwidth)
* torch.exp(-(grid_values**2) / self.bandwidth)
* torch.exp(1j * 2 * torch.pi * self.center * grid_values) # type: ignore
* torch.exp(1j * 2 * torch.pi * self.center * grid_values)
)
return morlet
6 changes: 5 additions & 1 deletion src/ptwt/matmul_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,11 @@ def orthogonalize(
raise ValueError(f"Invalid orthogonalization method: {method}")


class MatrixWavedec(object):
class BaseMatrixWaveDec:
"""A base class for matrix wavedec."""


class MatrixWavedec(BaseMatrixWaveDec):
"""Compute the sparse matrix fast wavelet transform.
Intermediate scale results must be divisible
Expand Down
9 changes: 7 additions & 2 deletions src/ptwt/matmul_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
_preprocess_tensor_dec2d,
_waverec2d_fold_channels_2d_list,
)
from .matmul_transform import construct_boundary_a, construct_boundary_s, orthogonalize
from .matmul_transform import (
BaseMatrixWaveDec,
construct_boundary_a,
construct_boundary_s,
orthogonalize,
)
from .sparse_math import (
batch_mm,
cat_sparse_identity_matrix,
Expand Down Expand Up @@ -217,7 +222,7 @@ def _matrix_pad_2(height: int, width: int) -> Tuple[int, int, Tuple[bool, bool]]
return height, width, pad_tuple


class MatrixWavedec2(object):
class MatrixWavedec2(BaseMatrixWaveDec):
"""Experimental sparse matrix 2d wavelet transform.
For a completely pad-free transform,
Expand Down
4 changes: 2 additions & 2 deletions src/ptwt/matmul_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
wavelet: Union[Wavelet, str],
level: Optional[int] = None,
axes: Tuple[int, int, int] = (-3, -2, -1),
boundary: Optional[str] = "qr",
boundary: OrthogonalizeMethod = "qr",
):
"""Create a *separable* three-dimensional fast boundary wavelet transform.
Expand All @@ -69,7 +69,7 @@ def __init__(
wavelet (Union[Wavelet, str]): The wavelet to use.
level (Optional[int]): The desired decomposition level.
Defaults to None.
boundary (Optional[str]): The matrix orthogonalization method.
boundary: The matrix orthogonalization method.
Defaults to "qr".
Raises:
Expand Down
2 changes: 1 addition & 1 deletion src/ptwt/packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(
if len(data.shape) == 1:
# add a batch dimension.
data = data.unsqueeze(0)
self.transform(data, maxlevel) # type: ignore
self.transform(data, maxlevel)
else:
self.data = {}

Expand Down
6 changes: 3 additions & 3 deletions src/ptwt/separable_conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def _separable_conv_waverecn(

approx: torch.Tensor = coeffs[0]
for level_dict in coeffs[1:]:
keys = list(level_dict.keys())
level_dict["a" * max(map(len, keys))] = approx
approx = _separable_conv_idwtn(level_dict, wavelet)
keys = list(level_dict.keys()) # type: ignore
level_dict["a" * max(map(len, keys))] = approx # type: ignore
approx = _separable_conv_idwtn(level_dict, wavelet) # type: ignore
return approx


Expand Down
2 changes: 1 addition & 1 deletion src/ptwt/sparse_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch

from ptwt.constants import PaddingMode
from .constants import PaddingMode


def _dense_kron(
Expand Down
2 changes: 1 addition & 1 deletion tests/_mackey_glass.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _mackey(
return x[:, discard:]


class MackeyGenerator(object):
class MackeyGenerator:
"""Generates lorenz attractor data in 1 or 3d on the GPU."""

def __init__(
Expand Down
80 changes: 50 additions & 30 deletions tests/test_convolution_fwt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test the conv-fwt code."""

from typing import Iterable, List, Optional, Sequence, Tuple, Union

# Written by moritz ( @ wolter.tech ) in 2021
import numpy as np
import pytest
Expand All @@ -8,14 +10,15 @@
from scipy import datasets

from ptwt._util import _outer
from ptwt.constants import BoundaryMode
from ptwt.conv_transform import (
_flatten_2d_coeff_lst,
_translate_boundary_strings,
wavedec,
waverec,
)
from ptwt.conv_transform_2 import wavedec2, waverec2
from src.ptwt.wavelets_learnable import SoftOrthogonalWavelet
from ptwt.wavelets_learnable import SoftOrthogonalWavelet
from tests._mackey_glass import MackeyGenerator


Expand All @@ -28,7 +31,14 @@
"mode", ["reflect", "zero", "constant", "periodic", "symmetric"]
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_conv_fwt1d(wavelet_string, level, mode, length, batch_size, dtype):
def test_conv_fwt1d(
wavelet_string: str,
level: Optional[int],
mode: BoundaryMode,
length: int,
batch_size: int,
dtype: torch.dtype,
) -> None:
"""Test multiple convolution fwt, for various levels and padding options."""
generator = MackeyGenerator(
batch_size=batch_size, tmax=length, delta_t=1, device="cpu"
Expand All @@ -49,15 +59,15 @@ def test_conv_fwt1d(wavelet_string, level, mode, length, batch_size, dtype):
)
py_coeff = np.stack(py_list)
assert np.allclose(
cptcoeff.numpy(), py_coeff, atol=np.finfo(py_coeff.dtype).resolution
cptcoeff.numpy(), py_coeff, atol=float(np.finfo(py_coeff.dtype).resolution)
)
res = waverec(ptcoeff, wavelet)
assert np.allclose(mackey_data_1.numpy(), res.numpy()[:, : mackey_data_1.shape[-1]])


@pytest.mark.parametrize("size", [[5, 10, 64], [1, 1, 32]])
@pytest.mark.parametrize("wavelet", ["haar", "db2"])
def test_conv_fwt1d_channel(size, wavelet):
def test_conv_fwt1d_channel(size: List[int], wavelet: str) -> None:
"""Test channel dimension support."""
data = torch.randn(*size).type(torch.float64)
ptwt_coeff = wavedec(data, wavelet)
Expand All @@ -72,12 +82,12 @@ def test_conv_fwt1d_channel(size, wavelet):
assert np.allclose(data.numpy(), rec.numpy())


def test_ripples_haar_lvl3():
def test_ripples_haar_lvl3() -> None:
"""Compute example from page 7 of Ripples in Mathematics, Jensen, la Cour-Harbo."""

class _MyHaarFilterBank(object):
class _MyHaarFilterBank:
@property
def filter_bank(self):
def filter_bank(self) -> Tuple[List[float], ...]:
"""Unscaled Haar wavelet filters."""
return (
[1 / 2, 1 / 2.0],
Expand All @@ -95,7 +105,7 @@ def filter_bank(self):
assert (torch.squeeze(coeffs[3]).numpy() == [8.0, -8.0, 0.0, 12.0]).all()


def test_orth_wavelet():
def test_orth_wavelet() -> None:
"""Test an orthogonal wavelet fwt."""
generator = MackeyGenerator(batch_size=2, tmax=64, delta_t=1, device="cpu")

Expand All @@ -114,7 +124,7 @@ def test_orth_wavelet():

@pytest.mark.parametrize("level", [1, 2, 3, None])
@pytest.mark.parametrize("shape", [(64,), (1, 64), (3, 2, 64), (4, 3, 2, 64)])
def test_1d_multibatch(level, shape):
def test_1d_multibatch(level: Optional[int], shape: Sequence[int]) -> None:
"""Test 1D conv support for multiple inert batch dimensions."""
data = torch.randn(*shape, dtype=torch.float64)
ptwt_coeff = wavedec(data, "haar", level=level)
Expand All @@ -130,7 +140,7 @@ def test_1d_multibatch(level, shape):


@pytest.mark.parametrize("axis", [-1, 0, 1, 2])
def test_1d_axis_arg(axis):
def test_1d_axis_arg(axis: int) -> None:
"""Ensure the axis argument works as expected."""
data = torch.randn([16, 16, 16], dtype=torch.float64)

Expand All @@ -144,7 +154,7 @@ def test_1d_axis_arg(axis):
assert torch.allclose(rec, data)


def test_2d_haar_lvl1():
def test_2d_haar_lvl1() -> None:
"""Test a 2d-Haar wavelet conv-fwt."""
# ------------------------- 2d haar wavelet tests -----------------------
face = np.transpose(
Expand All @@ -161,7 +171,7 @@ def test_2d_haar_lvl1():
assert np.allclose(rec, face)


def test_2d_db2_lvl1():
def test_2d_db2_lvl1() -> None:
"""Test a 2d-db2 wavelet conv-fwt."""
# single level db2 - 2d
face = np.transpose(
Expand All @@ -178,7 +188,7 @@ def test_2d_db2_lvl1():
assert np.allclose(rec.numpy().squeeze(), face)


def test_2d_haar_multi():
def test_2d_haar_multi() -> None:
"""Test a 2d-db2 wavelet level 5 conv-fwt."""
# multi level haar - 2d
face = np.transpose(
Expand All @@ -195,7 +205,7 @@ def test_2d_haar_multi():
assert np.allclose(rec, face)


def test_outer():
def test_outer() -> None:
"""Test the outer-product implementation."""
a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
b = torch.tensor([6.0, 7.0, 8.0, 9.0, 10.0])
Expand All @@ -213,7 +223,9 @@ def test_outer():
@pytest.mark.parametrize(
"mode", ["reflect", "zero", "constant", "periodic", "symmetric"]
)
def test_2d_wavedec_rec(wavelet_str, level, size, mode):
def test_2d_wavedec_rec(
wavelet_str: str, level: Optional[int], size: Tuple[int, int], mode: BoundaryMode
) -> None:
"""Ensure pywt.wavedec2 and ptwt.wavedec2 produce the same coefficients.
wavedec2 and waverec2 must invert each other.
Expand Down Expand Up @@ -247,7 +259,9 @@ def test_2d_wavedec_rec(wavelet_str, level, size, mode):
)
@pytest.mark.parametrize("level", [1, None])
@pytest.mark.parametrize("wavelet", ["haar", "sym3"])
def test_input_4d(size, level, wavelet):
def test_input_4d(
size: Tuple[int, int, int, int], level: Optional[str], wavelet: str
) -> None:
"""Test the error for 4d inputs to wavedec2."""
data = torch.randn(*size).type(torch.float64)

Expand All @@ -274,20 +288,23 @@ def test_input_4d(size, level, wavelet):


@pytest.mark.parametrize("padding_str", ["invalid_padding_name"])
def test_incorrect_padding(padding_str):
def test_incorrect_padding(padding_str: BoundaryMode) -> None:
"""Test expected errors for an invalid padding name."""
with pytest.raises(ValueError):
_ = _translate_boundary_strings(padding_str)


def test_input_1d_dimension_error():
def test_input_1d_dimension_error() -> None:
"""Test the error for 1d inputs to wavedec2."""
with pytest.raises(ValueError):
data = torch.randn(50)
wavedec2(data, "haar", level=4)


def _compare_coeffs(ptwt_res, pywt_res):
def _compare_coeffs(
ptwt_res: Iterable[Union[torch.Tensor, Tuple[torch.Tensor, ...]]],
pywt_res: Iterable[Union[torch.Tensor, Tuple[torch.Tensor, ...]]],
) -> List[bool]:
"""Compare coefficient lists.
Args:
Expand All @@ -296,26 +313,29 @@ def _compare_coeffs(ptwt_res, pywt_res):
Returns:
A list with bools from allclose.
Raises:
TypeError: In case of a problem with the list structures.
"""
test_list = []
test_list: List[bool] = []
for ptwtcs, pywtcs in zip(ptwt_res, pywt_res):
if isinstance(ptwtcs, tuple):
if isinstance(ptwtcs, tuple) and isinstance(pywtcs, tuple):
test_list.extend(
tuple(
np.allclose(ptwtc.numpy(), pywtc)
for ptwtc, pywtc in zip(ptwtcs, pywtcs)
)
np.allclose(ptwtc.numpy(), pywtc)
for ptwtc, pywtc in zip(ptwtcs, pywtcs)
)
else:
elif isinstance(ptwtcs, torch.Tensor):
test_list.append(np.allclose(ptwtcs.numpy(), pywtcs))
else:
raise TypeError("Invalid coefficient typing.")
return test_list


@pytest.mark.slow
@pytest.mark.parametrize(
"size", [(50, 20, 128, 128), (8, 49, 21, 128, 128), (6, 4, 4, 5, 64, 64)]
)
def test_2d_multidim_input(size):
def test_2d_multidim_input(size: Tuple[int, ...]) -> None:
"""Test the error for multi-dimensional inputs to wavedec2."""
data = torch.randn(*size, dtype=torch.float64)
wavelet = "db2"
Expand All @@ -337,7 +357,7 @@ def test_2d_multidim_input(size):

@pytest.mark.slow
@pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (-3, -2), (0, 1), (1, 0)])
def test_2d_axis_argument(axes):
def test_2d_axis_argument(axes: Tuple[int, int]) -> None:
"""Ensure the axes argument works as expected."""
data = torch.randn([32, 32, 32, 32], dtype=torch.float64)

Expand All @@ -355,14 +375,14 @@ def test_2d_axis_argument(axes):
)


def test_2d_axis_error_axes_count():
def test_2d_axis_error_axes_count() -> None:
"""Check the error for too many axes."""
with pytest.raises(ValueError):
data = torch.randn([32, 32, 32, 32], dtype=torch.float64)
wavedec2(data, "haar", level=1, axes=(1, 2, 3))


def test_2d_axis_error_axes_repetition():
def test_2d_axis_error_axes_repetition() -> None:
"""Check the error for axes repetition."""
with pytest.raises(ValueError):
data = torch.randn([32, 32, 32, 32], dtype=torch.float64)
Expand Down
Loading

0 comments on commit 38901ce

Please sign in to comment.