Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix imports in tests #75

Merged
merged 2 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/_mackey_glass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generate artificial time-series data for debugging purposes."""

from typing import Optional, Union

import torch
Expand Down
6 changes: 3 additions & 3 deletions tests/test_convolution_fwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import torch
from scipy import datasets

from src.ptwt._util import _outer
from src.ptwt.conv_transform import (
from ptwt._util import _outer
from ptwt.conv_transform import (
_flatten_2d_coeff_lst,
_translate_boundary_strings,
wavedec,
waverec,
)
from src.ptwt.conv_transform_2 import wavedec2, waverec2
from ptwt.conv_transform_2 import wavedec2, waverec2
from src.ptwt.wavelets_learnable import SoftOrthogonalWavelet
from tests._mackey_glass import MackeyGenerator

Expand Down
2 changes: 1 addition & 1 deletion tests/test_convolution_fwt_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pywt
import torch

import src.ptwt as ptwt
import ptwt


def _expand_dims(batch_list: List) -> List:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_cwt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test the continuous transformation code."""

from typing import Union

import numpy as np
Expand All @@ -7,7 +8,7 @@
import torch
from scipy import signal

from src.ptwt.continuous_transform import _ShannonWavelet, cwt
from ptwt.continuous_transform import _ShannonWavelet, cwt

continuous_wavelets = [
"cgau1",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch
from scipy import signal

import src.ptwt as ptwt
from src.ptwt.continuous_transform import _ShannonWavelet
import ptwt
from ptwt.continuous_transform import _ShannonWavelet
from tests._mackey_glass import MackeyGenerator


Expand All @@ -31,7 +31,7 @@ def _set_up_wavelet_tuple(wavelet, dtype):


def _to_jit_wavedec_fun(data, wavelet, level):
return ptwt.wavedec(data, wavelet, "reflect", level)
return ptwt.wavedec(data, wavelet, mode="reflect", level=level)


@pytest.mark.slow
Expand Down Expand Up @@ -117,7 +117,7 @@ def _to_jit_wavedec_3(data, wavelet):
means we have to stack the lists in the output.
"""
assert data.shape == (10, 20, 20, 20), "Changing the chape requires re-tracing."
coeff = ptwt.wavedec3(data, wavelet, "reflect", 2)
coeff = ptwt.wavedec3(data, wavelet, mode="reflect", level=2)
coeff2 = []
keys = ("aad", "ada", "add", "daa", "dad", "dda", "ddd")
for c in coeff:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_matrix_fwt.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Test the fwt and ifwt matrices."""
# Written by moritz ( @ wolter.tech ) in 2021

from typing import List

import numpy as np
import pytest
import pywt
import torch

from src.ptwt.matmul_transform import (
from ptwt.matmul_transform import (
MatrixWavedec,
MatrixWaverec,
_construct_a,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_matrix_fwt_2.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Test code for the 2d boundary wavelets."""
# Created by moritz ( wolter@cs.uni-bonn.de ), 08.09.21

import numpy as np
import pytest
import pywt
import scipy.signal
import torch

from src.ptwt.conv_transform import _flatten_2d_coeff_lst
from src.ptwt.matmul_transform import MatrixWavedec, MatrixWaverec
from src.ptwt.matmul_transform_2 import (
from ptwt.conv_transform import _flatten_2d_coeff_lst
from ptwt.matmul_transform import MatrixWavedec, MatrixWaverec
from ptwt.matmul_transform_2 import (
MatrixWavedec2,
MatrixWaverec2,
construct_boundary_a2,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_matrix_fwt_3.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Test the 3d matrix-fwt code."""

from typing import List

import numpy as np
import pytest
import pywt
import torch

from src.ptwt.matmul_transform import construct_boundary_a
from src.ptwt.matmul_transform_3 import MatrixWavedec3, MatrixWaverec3
from src.ptwt.sparse_math import _batch_dim_mm
from ptwt.matmul_transform import construct_boundary_a
from ptwt.matmul_transform_3 import MatrixWavedec3, MatrixWaverec3
from ptwt.sparse_math import _batch_dim_mm


@pytest.mark.parametrize("axis", [1, 2, 3])
Expand Down
3 changes: 2 additions & 1 deletion tests/test_packets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test the wavelet packet code."""
# Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de)

from itertools import product

import numpy as np
Expand All @@ -8,7 +9,7 @@
import torch
from scipy import datasets

from src.ptwt.packets import WaveletPacket, WaveletPacket2D, get_freq_order
from ptwt.packets import WaveletPacket, WaveletPacket2D, get_freq_order


def _compare_trees1(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_separable_conv_fwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pywt
import torch

from src.ptwt.matmul_transform_2 import MatrixWavedec2
from src.ptwt.matmul_transform_3 import MatrixWavedec3
from src.ptwt.separable_conv_transform import (
from ptwt.matmul_transform_2 import MatrixWavedec2
from ptwt.matmul_transform_3 import MatrixWavedec3
from ptwt.separable_conv_transform import (
_separable_conv_wavedecn,
_separable_conv_waverecn,
fswavedec2,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_sparse_math.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Test the sparse math code from ptwt.sparse_math."""
# Written by moritz ( @ wolter.tech ) in 2021

import numpy as np
import pytest
import scipy.signal
import torch
from scipy import datasets

from src.ptwt.sparse_math import (
from ptwt.sparse_math import (
batch_mm,
construct_conv2d_matrix,
construct_conv_matrix,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_swt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pywt
import torch

from src.ptwt._stationary_transform import _swt
from ptwt._stationary_transform import _swt


@pytest.mark.slow
Expand Down
3 changes: 2 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Test the util methods."""

from typing import Tuple

import numpy as np
import pytest
import pywt
import torch

from src.ptwt._util import (
from ptwt._util import (
_as_wavelet,
_fold_axes,
_pad_symmetric,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_wavelet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Test the adaptive wavelet cost functions."""

import pytest
import pywt
import torch

from src.ptwt.wavelets_learnable import SoftOrthogonalWavelet
from ptwt.wavelets_learnable import SoftOrthogonalWavelet


@pytest.mark.parametrize(
Expand Down
Loading