Skip to content

Commit

Permalink
On the road to type annotating the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Jan 28, 2024
1 parent 5826fa7 commit f2d943e
Show file tree
Hide file tree
Showing 17 changed files with 119 additions and 74 deletions.
1 change: 0 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def mypy(session):
"--explicit-package-bases",
"src",
"tests",
"examples",
)


Expand Down
6 changes: 5 additions & 1 deletion src/ptwt/matmul_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,11 @@ def orthogonalize(
raise ValueError(f"Invalid orthogonalization method: {method}")


class MatrixWavedec(object):
class BaseMatrixWaveDec:
"""A base class for matrix wavedec."""


class MatrixWavedec(BaseMatrixWaveDec):
"""Compute the sparse matrix fast wavelet transform.
Intermediate scale results must be divisible
Expand Down
9 changes: 7 additions & 2 deletions src/ptwt/matmul_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
_preprocess_tensor_dec2d,
_waverec2d_fold_channels_2d_list,
)
from .matmul_transform import construct_boundary_a, construct_boundary_s, orthogonalize
from .matmul_transform import (
BaseMatrixWaveDec,
construct_boundary_a,
construct_boundary_s,
orthogonalize,
)
from .sparse_math import (
batch_mm,
cat_sparse_identity_matrix,
Expand Down Expand Up @@ -217,7 +222,7 @@ def _matrix_pad_2(height: int, width: int) -> Tuple[int, int, Tuple[bool, bool]]
return height, width, pad_tuple


class MatrixWavedec2(object):
class MatrixWavedec2(BaseMatrixWaveDec):
"""Experimental sparse matrix 2d wavelet transform.
For a completely pad-free transform,
Expand Down
2 changes: 1 addition & 1 deletion tests/_mackey_glass.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _mackey(
return x[:, discard:]


class MackeyGenerator(object):
class MackeyGenerator:
"""Generates lorenz attractor data in 1 or 3d on the GPU."""

def __init__(
Expand Down
28 changes: 19 additions & 9 deletions tests/test_convolution_fwt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test the conv-fwt code."""

from typing import List, Optional, Sequence

# Written by moritz ( @ wolter.tech ) in 2021
import numpy as np
import pytest
Expand All @@ -8,14 +10,15 @@
from scipy import datasets

from ptwt._util import _outer
from ptwt.constants import BoundaryMode
from ptwt.conv_transform import (
_flatten_2d_coeff_lst,
_translate_boundary_strings,
wavedec,
waverec,
)
from ptwt.conv_transform_2 import wavedec2, waverec2
from src.ptwt.wavelets_learnable import SoftOrthogonalWavelet
from ptwt.wavelets_learnable import SoftOrthogonalWavelet
from tests._mackey_glass import MackeyGenerator


Expand All @@ -28,7 +31,14 @@
"mode", ["reflect", "zero", "constant", "periodic", "symmetric"]
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_conv_fwt1d(wavelet_string, level, mode, length, batch_size, dtype):
def test_conv_fwt1d(
wavelet_string: str,
level: Optional[int],
mode: BoundaryMode,
length: int,
batch_size: int,
dtype: torch.dtype,
) -> None:
"""Test multiple convolution fwt, for various levels and padding options."""
generator = MackeyGenerator(
batch_size=batch_size, tmax=length, delta_t=1, device="cpu"
Expand Down Expand Up @@ -57,7 +67,7 @@ def test_conv_fwt1d(wavelet_string, level, mode, length, batch_size, dtype):

@pytest.mark.parametrize("size", [[5, 10, 64], [1, 1, 32]])
@pytest.mark.parametrize("wavelet", ["haar", "db2"])
def test_conv_fwt1d_channel(size, wavelet):
def test_conv_fwt1d_channel(size: List[int], wavelet: str) -> None:
"""Test channel dimension support."""
data = torch.randn(*size).type(torch.float64)
ptwt_coeff = wavedec(data, wavelet)
Expand All @@ -72,10 +82,10 @@ def test_conv_fwt1d_channel(size, wavelet):
assert np.allclose(data.numpy(), rec.numpy())


def test_ripples_haar_lvl3():
def test_ripples_haar_lvl3() -> None:
"""Compute example from page 7 of Ripples in Mathematics, Jensen, la Cour-Harbo."""

class _MyHaarFilterBank(object):
class _MyHaarFilterBank:
@property
def filter_bank(self):
"""Unscaled Haar wavelet filters."""
Expand All @@ -95,7 +105,7 @@ def filter_bank(self):
assert (torch.squeeze(coeffs[3]).numpy() == [8.0, -8.0, 0.0, 12.0]).all()


def test_orth_wavelet():
def test_orth_wavelet() -> None:
"""Test an orthogonal wavelet fwt."""
generator = MackeyGenerator(batch_size=2, tmax=64, delta_t=1, device="cpu")

Expand All @@ -114,7 +124,7 @@ def test_orth_wavelet():

@pytest.mark.parametrize("level", [1, 2, 3, None])
@pytest.mark.parametrize("shape", [(64,), (1, 64), (3, 2, 64), (4, 3, 2, 64)])
def test_1d_multibatch(level, shape):
def test_1d_multibatch(level: Optional[int], shape: Sequence[int]) -> None:
"""Test 1D conv support for multiple inert batch dimensions."""
data = torch.randn(*shape, dtype=torch.float64)
ptwt_coeff = wavedec(data, "haar", level=level)
Expand Down Expand Up @@ -144,7 +154,7 @@ def test_1d_axis_arg(axis):
assert torch.allclose(rec, data)


def test_2d_haar_lvl1():
def test_2d_haar_lvl1() -> None:
"""Test a 2d-Haar wavelet conv-fwt."""
# ------------------------- 2d haar wavelet tests -----------------------
face = np.transpose(
Expand All @@ -161,7 +171,7 @@ def test_2d_haar_lvl1():
assert np.allclose(rec, face)


def test_2d_db2_lvl1():
def test_2d_db2_lvl1() -> None:
"""Test a 2d-db2 wavelet conv-fwt."""
# single level db2 - 2d
face = np.transpose(
Expand Down
12 changes: 7 additions & 5 deletions tests/test_convolution_fwt_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def _expand_dims(batch_list: List) -> List:
for pos, bel in enumerate(batch_list):
if type(bel) is np.ndarray:
if isinstance(bel, np.ndarray):
batch_list[pos] = np.expand_dims(bel, 0)
else:
for key, item in batch_list[pos].items():
Expand Down Expand Up @@ -57,7 +57,9 @@ def _cat_batch_list(batch_lists: List) -> List:
@pytest.mark.parametrize("wavelet", ["haar", "db2", "db4"])
@pytest.mark.parametrize("level", [1, 2, None])
@pytest.mark.parametrize("mode", typing.get_args(BoundaryMode))
def test_waverec3(shape: list, wavelet: str, level: int, mode: BoundaryMode) -> None:
def test_waverec3(
shape: List[int], wavelet: str, level: int, mode: BoundaryMode
) -> None:
"""Ensure the 3d analysis transform is invertible."""
data = np.random.randn(*shape)
data = torch.from_numpy(data)
Expand Down Expand Up @@ -92,7 +94,7 @@ def test_waverec3(shape: list, wavelet: str, level: int, mode: BoundaryMode) ->
@pytest.mark.parametrize("level", [1, 2, None])
@pytest.mark.parametrize("wavelet", ["haar", "sym3", "db3"])
@pytest.mark.parametrize("mode", ["zero", "symmetric", "reflect"])
def test_multidim_input(size: List[int], level: int, wavelet: str, mode: str):
def test_multidim_input(size: List[int], level: int, wavelet: str, mode: str) -> None:
"""Ensure correct folding of multidimensional inputs."""
data = torch.randn(size, dtype=torch.float64)
ptwc = ptwt.wavedec3(data, wavelet, level=level, mode=mode)
Expand Down Expand Up @@ -146,14 +148,14 @@ def test_axes_arg_3d(
assert np.allclose(data, rec)


def test_2d_dimerror():
def test_2d_dimerror() -> None:
"""Check the error for too many axes."""
with pytest.raises(ValueError):
data = torch.randn([32, 32], dtype=torch.float64)
ptwt.wavedec3(data, "haar")


def test_1d_dimerror():
def test_1d_dimerror() -> None:
"""Check the error for too many axes."""
with pytest.raises(ValueError):
data = torch.randn([32], dtype=torch.float64)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_cwt_cuda(cuda: bool, wavelet: str = "cgau6") -> None:


@pytest.mark.parametrize("wavelet", continuous_wavelets)
def test_cwt_batched(wavelet):
def test_cwt_batched(wavelet: str) -> None:
"""Test batched transforms."""
sig = np.random.randn(10, 200)
widths = np.arange(1, 30)
Expand All @@ -81,7 +81,7 @@ def test_cwt_batched(wavelet):

@pytest.mark.parametrize("type", ["shan1-1"])
@pytest.mark.parametrize("grid_size", [8, 9, 10])
def test_nn_schannon_wavefun(type: str, grid_size: int):
def test_nn_schannon_wavefun(type: str, grid_size: int) -> None:
"""Test the wavelet sampling for the differentiable shannon example."""
pywt_shannon = pywt.ContinuousWavelet(type)
ptwt_shannon = _ShannonWavelet(type)
Expand Down
10 changes: 6 additions & 4 deletions tests/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def _to_jit_wavedec_fun(data, wavelet, level):
@pytest.mark.parametrize("length", [64, 65])
@pytest.mark.parametrize("batch_size", [1, 3])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_conv_fwt_jit(wavelet_string, level, length, batch_size, dtype):
def test_conv_fwt_jit(
wavelet_string: str, level: int, length: int, batch_size: int, dtype: torch.dtype
) -> None:
"""Test jitting a convolution fwt, for various levels and padding options."""
generator = MackeyGenerator(
batch_size=batch_size, tmax=length, delta_t=1, device="cpu"
Expand Down Expand Up @@ -89,7 +91,7 @@ def _to_jit_waverec_2(data, wavelet):
return rec


def test_conv_fwt_jit_2d():
def test_conv_fwt_jit_2d() -> None:
"""Test the jit compilation feature for the wavedec2 function."""
data = torch.randn(10, 20, 20).type(torch.float64)
wavelet = pywt.Wavelet("db4")
Expand Down Expand Up @@ -141,7 +143,7 @@ def _to_jit_waverec_3(data, wavelet):
return rec


def test_conv_fwt_jit_3d():
def test_conv_fwt_jit_3d() -> None:
"""Test the jit compilation feature for the wavedec3 function."""
data = torch.randn(10, 20, 20, 20).type(torch.float64)
wavelet = pywt.Wavelet("db4")
Expand Down Expand Up @@ -171,7 +173,7 @@ def _to_jit_cwt(sig):
return cwtmatr


def test_cwt_jit():
def test_cwt_jit() -> None:
"""Test cwt jitting."""
t = np.linspace(-2, 2, 800, endpoint=False)
sig = torch.from_numpy(signal.chirp(t, f0=1, f1=12, t1=2, method="linear"))
Expand Down
4 changes: 2 additions & 2 deletions tests/test_matrix_fwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_4d_invalid_axis_error():


@pytest.mark.parametrize("size", [[2, 3, 32], [5, 32], [32], [1, 1, 64]])
def test_matrix1d_batch_channel(size):
def test_matrix1d_batch_channel(size: List[int]):
"""Test if batch and channel support works as expected."""
data = torch.randn(*size).type(torch.float64)
matrix_wavedec_1d = MatrixWavedec("haar", 3)
Expand All @@ -214,7 +214,7 @@ def test_matrix1d_batch_channel(size):


@pytest.mark.parametrize("axis", (0, 1, 2, 3, 4))
def test_axis_1d(axis):
def test_axis_1d(axis: int):
"""Ensure the axis argument is supported correctly."""
data = torch.randn(24, 24, 24, 24, 24).type(torch.float64)
matrix_wavedec = MatrixWavedec(wavelet="haar", level=3, axis=axis)
Expand Down
20 changes: 11 additions & 9 deletions tests/test_matrix_fwt_2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test code for the 2d boundary wavelets."""

# Created by moritz ( wolter@cs.uni-bonn.de ), 08.09.21
from typing import List, Type

import numpy as np
import pytest
Expand All @@ -9,7 +9,7 @@
import torch

from ptwt.conv_transform import _flatten_2d_coeff_lst
from ptwt.matmul_transform import MatrixWavedec, MatrixWaverec
from ptwt.matmul_transform import BaseMatrixWaveDec, MatrixWavedec, MatrixWaverec
from ptwt.matmul_transform_2 import (
MatrixWavedec2,
MatrixWaverec2,
Expand All @@ -18,6 +18,8 @@
)
from tests.test_convolution_fwt import _compare_coeffs

# Created by moritz ( wolter@cs.uni-bonn.de ), 08.09.21


@pytest.mark.parametrize("size", [(16, 16), (16, 8), (8, 16)])
@pytest.mark.parametrize("wavelet_str", ["db1", "db2", "db3", "db4", "db5"])
Expand Down Expand Up @@ -190,7 +192,7 @@ def test_separable_haar_2d():


@pytest.mark.parametrize("size", [[3, 2, 32, 32], [4, 32, 32], [1, 1, 32, 32]])
def test_batch_channel_2d_haar(size):
def test_batch_channel_2d_haar(size: List[int]):
"""Test matrix fwt-2d leading channel and batch dimension code."""
signal = torch.randn(*size).type(torch.float64)
ptwt_coeff = MatrixWavedec2("haar", 2, separable=False)(signal)
Expand All @@ -210,23 +212,23 @@ def test_batch_channel_2d_haar(size):


@pytest.mark.parametrize("operator", [MatrixWavedec2, MatrixWavedec])
def test_empty_operators(operator) -> None:
def test_empty_operators(operator: Type[BaseMatrixWaveDec]) -> None:
"""Check if the error is thrown properly if no matrix was ever built."""
if operator is MatrixWavedec2:
matrixfwt = operator("haar", separable=False)
matrixfwt = operator(wavelet="haar", separable=False)
else:
matrixfwt = operator("haar")
matrixfwt = operator(wavelet="haar")
with pytest.raises(ValueError):
_ = matrixfwt.sparse_fwt_operator


@pytest.mark.parametrize("operator", [MatrixWaverec2, MatrixWaverec])
def test_empty_inverse_operators(operator) -> None:
def test_empty_inverse_operators(operator: Type[BaseMatrixWaveDec]) -> None:
"""Check if the error is thrown properly if no matrix was ever built."""
if operator is MatrixWaverec2:
matrixifwt = operator("haar", separable=False)
matrixifwt = operator(wavelet="haar", separable=False)
else:
matrixifwt = operator("haar")
matrixifwt = operator(wavelet="haar")
with pytest.raises(ValueError):
_ = matrixifwt.sparse_ifwt_operator

Expand Down
16 changes: 9 additions & 7 deletions tests/test_matrix_fwt_3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test the 3d matrix-fwt code."""

from typing import List
from typing import List, Optional, Tuple

import numpy as np
import pytest
Expand All @@ -14,9 +14,9 @@

@pytest.mark.parametrize("axis", [1, 2, 3])
@pytest.mark.parametrize(
"shape", [[32, 32, 32], [64, 32, 32], [32, 64, 32], [32, 32, 64]]
"shape", [(32, 32, 32), (64, 32, 32), (32, 64, 32), (32, 32, 64)]
)
def test_single_dim_mm(axis: int, shape: tuple):
def test_single_dim_mm(axis: int, shape: Tuple[int, int, int]) -> None:
"""Test the transposed matrix multiplication approach."""
test_tensor = torch.rand(4, shape[0], shape[1], shape[2]).type(torch.float64)
pywt_dec_lo, pywt_dec_hi = pywt.wavedec(
Expand All @@ -30,9 +30,9 @@ def test_single_dim_mm(axis: int, shape: tuple):


@pytest.mark.parametrize(
"shape", [[32, 32, 32], [64, 32, 32], [32, 64, 32], [32, 32, 64]]
"shape", [(32, 32, 32), (64, 32, 32), (32, 64, 32), (32, 32, 64)]
)
def test_boundary_wavedec3_level1_haar(shape):
def test_boundary_wavedec3_level1_haar(shape: Tuple[int, int, int]) -> None:
"""Test a separable boundary 3d-transform."""
batch_size = 1
test_data = torch.rand(batch_size, shape[0], shape[1], shape[2]).type(torch.float64)
Expand Down Expand Up @@ -73,9 +73,11 @@ def test_boundary_wavedec3_level1_haar(shape):
@pytest.mark.slow
@pytest.mark.parametrize("level", [1, 2, 3, None])
@pytest.mark.parametrize(
"shape", [[31, 32, 33], [63, 35, 32], [32, 62, 31], [32, 32, 64]]
"shape", [(31, 32, 33), (63, 35, 32), (32, 62, 31), (32, 32, 64)]
)
def test_boundary_wavedec3_inverse(level, shape):
def test_boundary_wavedec3_inverse(
level: Optional[int], shape: Tuple[int, int, int]
) -> None:
"""Test the 3d matrix wavedec and the padding for odd axes."""
batch_size = 1
test_data = torch.rand(batch_size, shape[0], shape[1], shape[2]).type(torch.float64)
Expand Down
Loading

0 comments on commit f2d943e

Please sign in to comment.