Skip to content

Commit

Permalink
Merge pull request #75 from v0lta/fix-tests
Browse files Browse the repository at this point in the history
Fix imports in tests
  • Loading branch information
v0lta authored Jan 26, 2024
2 parents c2bc884 + 0a2b467 commit f46f7bc
Show file tree
Hide file tree
Showing 14 changed files with 33 additions and 24 deletions.
1 change: 1 addition & 0 deletions tests/_mackey_glass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generate artificial time-series data for debugging purposes."""

from typing import Optional, Union

import torch
Expand Down
6 changes: 3 additions & 3 deletions tests/test_convolution_fwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import torch
from scipy import datasets

from src.ptwt._util import _outer
from src.ptwt.conv_transform import (
from ptwt._util import _outer
from ptwt.conv_transform import (
_flatten_2d_coeff_lst,
_translate_boundary_strings,
wavedec,
waverec,
)
from src.ptwt.conv_transform_2 import wavedec2, waverec2
from ptwt.conv_transform_2 import wavedec2, waverec2
from src.ptwt.wavelets_learnable import SoftOrthogonalWavelet
from tests._mackey_glass import MackeyGenerator

Expand Down
2 changes: 1 addition & 1 deletion tests/test_convolution_fwt_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pywt
import torch

import src.ptwt as ptwt
import ptwt


def _expand_dims(batch_list: List) -> List:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_cwt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test the continuous transformation code."""

from typing import Union

import numpy as np
Expand All @@ -7,7 +8,7 @@
import torch
from scipy import signal

from src.ptwt.continuous_transform import _ShannonWavelet, cwt
from ptwt.continuous_transform import _ShannonWavelet, cwt

continuous_wavelets = [
"cgau1",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch
from scipy import signal

import src.ptwt as ptwt
from src.ptwt.continuous_transform import _ShannonWavelet
import ptwt
from ptwt.continuous_transform import _ShannonWavelet
from tests._mackey_glass import MackeyGenerator


Expand All @@ -31,7 +31,7 @@ def _set_up_wavelet_tuple(wavelet, dtype):


def _to_jit_wavedec_fun(data, wavelet, level):
return ptwt.wavedec(data, wavelet, "reflect", level)
return ptwt.wavedec(data, wavelet, mode="reflect", level=level)


@pytest.mark.slow
Expand Down Expand Up @@ -117,7 +117,7 @@ def _to_jit_wavedec_3(data, wavelet):
means we have to stack the lists in the output.
"""
assert data.shape == (10, 20, 20, 20), "Changing the chape requires re-tracing."
coeff = ptwt.wavedec3(data, wavelet, "reflect", 2)
coeff = ptwt.wavedec3(data, wavelet, mode="reflect", level=2)
coeff2 = []
keys = ("aad", "ada", "add", "daa", "dad", "dda", "ddd")
for c in coeff:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_matrix_fwt.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Test the fwt and ifwt matrices."""
# Written by moritz ( @ wolter.tech ) in 2021

from typing import List

import numpy as np
import pytest
import pywt
import torch

from src.ptwt.matmul_transform import (
from ptwt.matmul_transform import (
MatrixWavedec,
MatrixWaverec,
_construct_a,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_matrix_fwt_2.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Test code for the 2d boundary wavelets."""
# Created by moritz ( wolter@cs.uni-bonn.de ), 08.09.21

import numpy as np
import pytest
import pywt
import scipy.signal
import torch

from src.ptwt.conv_transform import _flatten_2d_coeff_lst
from src.ptwt.matmul_transform import MatrixWavedec, MatrixWaverec
from src.ptwt.matmul_transform_2 import (
from ptwt.conv_transform import _flatten_2d_coeff_lst
from ptwt.matmul_transform import MatrixWavedec, MatrixWaverec
from ptwt.matmul_transform_2 import (
MatrixWavedec2,
MatrixWaverec2,
construct_boundary_a2,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_matrix_fwt_3.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Test the 3d matrix-fwt code."""

from typing import List

import numpy as np
import pytest
import pywt
import torch

from src.ptwt.matmul_transform import construct_boundary_a
from src.ptwt.matmul_transform_3 import MatrixWavedec3, MatrixWaverec3
from src.ptwt.sparse_math import _batch_dim_mm
from ptwt.matmul_transform import construct_boundary_a
from ptwt.matmul_transform_3 import MatrixWavedec3, MatrixWaverec3
from ptwt.sparse_math import _batch_dim_mm


@pytest.mark.parametrize("axis", [1, 2, 3])
Expand Down
3 changes: 2 additions & 1 deletion tests/test_packets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test the wavelet packet code."""
# Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de)

from itertools import product

import numpy as np
Expand All @@ -8,7 +9,7 @@
import torch
from scipy import datasets

from src.ptwt.packets import WaveletPacket, WaveletPacket2D, get_freq_order
from ptwt.packets import WaveletPacket, WaveletPacket2D, get_freq_order


def _compare_trees1(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_separable_conv_fwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pywt
import torch

from src.ptwt.matmul_transform_2 import MatrixWavedec2
from src.ptwt.matmul_transform_3 import MatrixWavedec3
from src.ptwt.separable_conv_transform import (
from ptwt.matmul_transform_2 import MatrixWavedec2
from ptwt.matmul_transform_3 import MatrixWavedec3
from ptwt.separable_conv_transform import (
_separable_conv_wavedecn,
_separable_conv_waverecn,
fswavedec2,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_sparse_math.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Test the sparse math code from ptwt.sparse_math."""
# Written by moritz ( @ wolter.tech ) in 2021

import numpy as np
import pytest
import scipy.signal
import torch
from scipy import datasets

from src.ptwt.sparse_math import (
from ptwt.sparse_math import (
batch_mm,
construct_conv2d_matrix,
construct_conv_matrix,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_swt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pywt
import torch

from src.ptwt._stationary_transform import _swt
from ptwt._stationary_transform import _swt


@pytest.mark.slow
Expand Down
3 changes: 2 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Test the util methods."""

from typing import Tuple

import numpy as np
import pytest
import pywt
import torch

from src.ptwt._util import (
from ptwt._util import (
_as_wavelet,
_fold_axes,
_pad_symmetric,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_wavelet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Test the adaptive wavelet cost functions."""

import pytest
import pywt
import torch

from src.ptwt.wavelets_learnable import SoftOrthogonalWavelet
from ptwt.wavelets_learnable import SoftOrthogonalWavelet


@pytest.mark.parametrize(
Expand Down

0 comments on commit f46f7bc

Please sign in to comment.