Skip to content

Commit

Permalink
Update test_matrix_fwt.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Jan 29, 2024
1 parent a889cf5 commit 108dd7a
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/test_matrix_fwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pywt
import torch

from ptwt.constants import OrthogonalizeMethod
from ptwt.matmul_transform import (
MatrixWavedec,
MatrixWaverec,
Expand Down Expand Up @@ -128,7 +129,7 @@ def test_boundary_filter_analysis_and_synthethis_matrices(
@pytest.mark.parametrize("level", [2, 1])
@pytest.mark.parametrize("boundary", ["gramschmidt", "qr"])
def test_boundary_transform_1d(
wavelet_str: str, data: np.ndarray, level: int, boundary: str
wavelet_str: str, data: np.ndarray, level: int, boundary: OrthogonalizeMethod
) -> None:
"""Ensure matrix fwt reconstructions are pywt compatible."""
data_torch = torch.from_numpy(data.astype(np.float64))
Expand Down Expand Up @@ -159,7 +160,7 @@ def test_boundary_transform_1d(

@pytest.mark.parametrize("wavelet_str", ["db2", "db3", "haar"])
@pytest.mark.parametrize("boundary", ["qr", "gramschmidt"])
def test_matrix_transform_1d_rebuild(wavelet_str: str, boundary: str) -> None:
def test_matrix_transform_1d_rebuild(wavelet_str: str, boundary: OrthogonalizeMethod) -> None:
"""Ensure matrix fwt reconstructions are pywt compatible."""
data_list = [np.random.randn(18), np.random.randn(21)]
wavelet = pywt.Wavelet(wavelet_str)
Expand Down Expand Up @@ -188,9 +189,7 @@ def test_matrix_transform_1d_rebuild(wavelet_str: str, boundary: str) -> None:
def test_4d_invalid_axis_error() -> None:
"""Test the error for 1d axis arguments."""
with pytest.raises(ValueError):
data = torch.randn(50, 50, 50, 50)
matrix_wavedec_1d = MatrixWavedec("haar", axis=(1, 2))
matrix_wavedec_1d(data)
MatrixWavedec("haar", axis=(1, 2))


@pytest.mark.parametrize("size", [[2, 3, 32], [5, 32], [32], [1, 1, 64]])
Expand Down

0 comments on commit 108dd7a

Please sign in to comment.