diff --git a/tests/_mackey_glass.py b/tests/_mackey_glass.py index fb8e9d8c..63f17ee3 100644 --- a/tests/_mackey_glass.py +++ b/tests/_mackey_glass.py @@ -1,4 +1,5 @@ """Generate artificial time-series data for debugging purposes.""" + from typing import Optional, Union import torch diff --git a/tests/test_convolution_fwt.py b/tests/test_convolution_fwt.py index d7663c28..baaed98d 100644 --- a/tests/test_convolution_fwt.py +++ b/tests/test_convolution_fwt.py @@ -6,14 +6,14 @@ import torch from scipy import datasets -from src.ptwt._util import _outer -from src.ptwt.conv_transform import ( +from ptwt._util import _outer +from ptwt.conv_transform import ( _flatten_2d_coeff_lst, _translate_boundary_strings, wavedec, waverec, ) -from src.ptwt.conv_transform_2 import wavedec2, waverec2 +from ptwt.conv_transform_2 import wavedec2, waverec2 from src.ptwt.wavelets_learnable import SoftOrthogonalWavelet from tests._mackey_glass import MackeyGenerator diff --git a/tests/test_convolution_fwt_3.py b/tests/test_convolution_fwt_3.py index 13b1ebbd..e6148fad 100644 --- a/tests/test_convolution_fwt_3.py +++ b/tests/test_convolution_fwt_3.py @@ -7,7 +7,7 @@ import pywt import torch -import src.ptwt as ptwt +import ptwt def _expand_dims(batch_list: List) -> List: diff --git a/tests/test_cwt.py b/tests/test_cwt.py index 01979d36..fa03830b 100644 --- a/tests/test_cwt.py +++ b/tests/test_cwt.py @@ -1,4 +1,5 @@ """Test the continuous transformation code.""" + from typing import Union import numpy as np @@ -7,7 +8,7 @@ import torch from scipy import signal -from src.ptwt.continuous_transform import _ShannonWavelet, cwt +from ptwt.continuous_transform import _ShannonWavelet, cwt continuous_wavelets = [ "cgau1", diff --git a/tests/test_jit.py b/tests/test_jit.py index 1f34e482..a821574f 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -7,8 +7,8 @@ import torch from scipy import signal -import src.ptwt as ptwt -from src.ptwt.continuous_transform import _ShannonWavelet +import ptwt +from ptwt.continuous_transform import _ShannonWavelet from tests._mackey_glass import MackeyGenerator @@ -31,7 +31,7 @@ def _set_up_wavelet_tuple(wavelet, dtype): def _to_jit_wavedec_fun(data, wavelet, level): - return ptwt.wavedec(data, wavelet, "reflect", level) + return ptwt.wavedec(data, wavelet, mode="reflect", level=level) @pytest.mark.slow @@ -117,7 +117,7 @@ def _to_jit_wavedec_3(data, wavelet): means we have to stack the lists in the output. """ assert data.shape == (10, 20, 20, 20), "Changing the chape requires re-tracing." - coeff = ptwt.wavedec3(data, wavelet, "reflect", 2) + coeff = ptwt.wavedec3(data, wavelet, mode="reflect", level=2) coeff2 = [] keys = ("aad", "ada", "add", "daa", "dad", "dda", "ddd") for c in coeff: diff --git a/tests/test_matrix_fwt.py b/tests/test_matrix_fwt.py index 0cc82484..913ba36f 100644 --- a/tests/test_matrix_fwt.py +++ b/tests/test_matrix_fwt.py @@ -1,5 +1,6 @@ """Test the fwt and ifwt matrices.""" # Written by moritz ( @ wolter.tech ) in 2021 + from typing import List import numpy as np @@ -7,7 +8,7 @@ import pywt import torch -from src.ptwt.matmul_transform import ( +from ptwt.matmul_transform import ( MatrixWavedec, MatrixWaverec, _construct_a, diff --git a/tests/test_matrix_fwt_2.py b/tests/test_matrix_fwt_2.py index d975a32e..b128154f 100644 --- a/tests/test_matrix_fwt_2.py +++ b/tests/test_matrix_fwt_2.py @@ -1,14 +1,15 @@ """Test code for the 2d boundary wavelets.""" # Created by moritz ( wolter@cs.uni-bonn.de ), 08.09.21 + import numpy as np import pytest import pywt import scipy.signal import torch -from src.ptwt.conv_transform import _flatten_2d_coeff_lst -from src.ptwt.matmul_transform import MatrixWavedec, MatrixWaverec -from src.ptwt.matmul_transform_2 import ( +from ptwt.conv_transform import _flatten_2d_coeff_lst +from ptwt.matmul_transform import MatrixWavedec, MatrixWaverec +from ptwt.matmul_transform_2 import ( MatrixWavedec2, MatrixWaverec2, construct_boundary_a2, diff --git a/tests/test_matrix_fwt_3.py b/tests/test_matrix_fwt_3.py index 09790810..868f5c89 100644 --- a/tests/test_matrix_fwt_3.py +++ b/tests/test_matrix_fwt_3.py @@ -1,4 +1,5 @@ """Test the 3d matrix-fwt code.""" + from typing import List import numpy as np @@ -6,9 +7,9 @@ import pywt import torch -from src.ptwt.matmul_transform import construct_boundary_a -from src.ptwt.matmul_transform_3 import MatrixWavedec3, MatrixWaverec3 -from src.ptwt.sparse_math import _batch_dim_mm +from ptwt.matmul_transform import construct_boundary_a +from ptwt.matmul_transform_3 import MatrixWavedec3, MatrixWaverec3 +from ptwt.sparse_math import _batch_dim_mm @pytest.mark.parametrize("axis", [1, 2, 3]) diff --git a/tests/test_packets.py b/tests/test_packets.py index 488b8035..8e54d496 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -1,5 +1,6 @@ """Test the wavelet packet code.""" # Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de) + from itertools import product import numpy as np @@ -8,7 +9,7 @@ import torch from scipy import datasets -from src.ptwt.packets import WaveletPacket, WaveletPacket2D, get_freq_order +from ptwt.packets import WaveletPacket, WaveletPacket2D, get_freq_order def _compare_trees1( diff --git a/tests/test_separable_conv_fwt.py b/tests/test_separable_conv_fwt.py index 07afac4b..1ad11b38 100644 --- a/tests/test_separable_conv_fwt.py +++ b/tests/test_separable_conv_fwt.py @@ -5,9 +5,9 @@ import pywt import torch -from src.ptwt.matmul_transform_2 import MatrixWavedec2 -from src.ptwt.matmul_transform_3 import MatrixWavedec3 -from src.ptwt.separable_conv_transform import ( +from ptwt.matmul_transform_2 import MatrixWavedec2 +from ptwt.matmul_transform_3 import MatrixWavedec3 +from ptwt.separable_conv_transform import ( _separable_conv_wavedecn, _separable_conv_waverecn, fswavedec2, diff --git a/tests/test_sparse_math.py b/tests/test_sparse_math.py index 3cc45679..648db4cb 100644 --- a/tests/test_sparse_math.py +++ b/tests/test_sparse_math.py @@ -1,12 +1,13 @@ """Test the sparse math code from ptwt.sparse_math.""" # Written by moritz ( @ wolter.tech ) in 2021 + import numpy as np import pytest import scipy.signal import torch from scipy import datasets -from src.ptwt.sparse_math import ( +from ptwt.sparse_math import ( batch_mm, construct_conv2d_matrix, construct_conv_matrix, diff --git a/tests/test_swt.py b/tests/test_swt.py index 1266f781..0dc4f494 100644 --- a/tests/test_swt.py +++ b/tests/test_swt.py @@ -5,7 +5,7 @@ import pywt import torch -from src.ptwt._stationary_transform import _swt +from ptwt._stationary_transform import _swt @pytest.mark.slow diff --git a/tests/test_util.py b/tests/test_util.py index a6b6016c..6f756ac0 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,4 +1,5 @@ """Test the util methods.""" + from typing import Tuple import numpy as np @@ -6,7 +7,7 @@ import pywt import torch -from src.ptwt._util import ( +from ptwt._util import ( _as_wavelet, _fold_axes, _pad_symmetric, diff --git a/tests/test_wavelet.py b/tests/test_wavelet.py index eea5baf2..36b1a236 100644 --- a/tests/test_wavelet.py +++ b/tests/test_wavelet.py @@ -1,9 +1,10 @@ """Test the adaptive wavelet cost functions.""" + import pytest import pywt import torch -from src.ptwt.wavelets_learnable import SoftOrthogonalWavelet +from ptwt.wavelets_learnable import SoftOrthogonalWavelet @pytest.mark.parametrize(