Merge pull request #84 from NiclasPi/fix-padding

SWT: Make circular padding wrap more than once if needed

v0lta authored Jun 13, 2024
2 parents 638a3c0 + 3326e2b commit 78abc5f
Showing 6 changed files with 194 additions and 101 deletions.
10 changes: 10 additions & 0 deletions docs/ptwt.rst
@@ -53,6 +53,16 @@ ptwt.separable\_conv\_transform module
:undoc-members:
:show-inheritance:


ptwt.stationary\_transform module
---------------------------------

.. automodule:: ptwt.stationary_transform
:members:
:undoc-members:
:show-inheritance:


ptwt.matmul\_transform module
-----------------------------

48 changes: 30 additions & 18 deletions src/ptwt/conv_transform.py
@@ -218,7 +218,7 @@ def _adjust_padding_at_reconstruction(

def _preprocess_tensor_dec1d(
data: torch.Tensor,
) -> Tuple[torch.Tensor, Union[List[int], None]]:
) -> Tuple[torch.Tensor, List[int]]:
"""Preprocess input tensor dimensions.
Args:
@@ -227,13 +227,13 @@ def _preprocess_tensor_dec1d(
Returns:
Tuple[torch.Tensor, Union[List[int], None]]:
A data tensor of shape [new_batch, 1, to_process]
and the original shape, if the shape has changed.
and the original shape.
"""
ds = None
if len(data.shape) == 1:
ds = list(data.shape)
if len(ds) == 1:
# assume time series
data = data.unsqueeze(0).unsqueeze(0)
elif len(data.shape) == 2:
elif len(ds) == 2:
# assume batched time series
data = data.unsqueeze(1)
else:
@@ -243,18 +243,33 @@ def _preprocess_tensor_dec1d(


def _postprocess_result_list_dec1d(
result_lst: List[torch.Tensor], ds: List[int]
result_list: List[torch.Tensor], ds: List[int], axis: int
) -> List[torch.Tensor]:
# Unfold axes for the wavelets
return [_unfold_axes(fres, ds, 1) for fres in result_lst]
if len(ds) == 1:
result_list = [r_el.squeeze(0) for r_el in result_list]
elif len(ds) > 2:
# Unfold axes for the wavelets
result_list = [_unfold_axes(fres, ds, 1) for fres in result_list]
else:
result_list = result_list

if axis != -1:
result_list = [coeff.swapaxes(axis, -1) for coeff in result_list]

return result_list


def _preprocess_result_list_rec1d(
result_lst: List[torch.Tensor],
) -> Tuple[List[torch.Tensor], List[int]]:
# Fold axes for the wavelets
ds = list(result_lst[0].shape)
fold_coeffs = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst]
if len(ds) == 1:
fold_coeffs = [uf_coeff.unsqueeze(0) for uf_coeff in result_lst]
elif len(ds) > 2:
fold_coeffs = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst]
else:
fold_coeffs = result_lst
return fold_coeffs, ds
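
The pre- and postprocessing helpers above now always return the original shape list ds instead of None. A minimal standalone sketch of the decomposition-side shape handling, with the higher-dimensional case approximated by a plain reshape (ptwt's internal _fold_axes helper is assumed to behave similarly):

from typing import List, Tuple

import torch

def preprocess_1d(data: torch.Tensor) -> Tuple[torch.Tensor, List[int]]:
    """Bring 1d, 2d, or Nd input into [batch, 1, time] form, remembering ds."""
    ds = list(data.shape)
    if len(ds) == 1:
        data = data.unsqueeze(0).unsqueeze(0)  # [time] -> [1, 1, time]
    elif len(ds) == 2:
        data = data.unsqueeze(1)  # [batch, time] -> [batch, 1, time]
    else:
        data = data.reshape(-1, 1, ds[-1])  # fold leading axes into the batch
    return data, ds

On the way back the branch runs in reverse: len(ds) == 1 squeezes the stand-in batch dimension off again, and len(ds) > 2 unfolds the leading axes, which is exactly the three-way split the two helpers above implement.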


@@ -350,11 +365,7 @@ def wavedec(
result_list.append(res_lo.squeeze(1))
result_list.reverse()

if ds:
result_list = _postprocess_result_list_dec1d(result_list, ds)

if axis != -1:
result_list = [coeff.swapaxes(axis, -1) for coeff in result_list]
result_list = _postprocess_result_list_dec1d(result_list, ds, axis)

return result_list

@@ -412,9 +423,8 @@ def waverec(
raise ValueError("waverec transforms a single axis only.")

# fold channels, if necessary.
ds = None
if coeffs[0].dim() >= 3:
coeffs, ds = _preprocess_result_list_rec1d(coeffs)
ds = list(coeffs[0].shape)
coeffs, ds = _preprocess_result_list_rec1d(coeffs)

_, _, rec_lo, rec_hi = _get_filter_tensors(
wavelet, flip=False, device=torch_device, dtype=torch_dtype
@@ -439,7 +449,9 @@ def waverec(
if padr > 0:
res_lo = res_lo[..., :-padr]

if ds:
if len(ds) == 1:
res_lo = res_lo.squeeze(0)
elif len(ds) > 2:
res_lo = _unfold_axes(res_lo, ds, 1)

if axis != -1:
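
With the branching unified in the helpers, wavedec and waverec accept unbatched, batched, and higher-dimensional inputs through the same code path. A hedged round-trip sketch (even signal lengths assumed, as in the repository's own tests):

import torch

from ptwt import wavedec, waverec

# [time] / [batch, time] / [batch, channels, time]
for shape in [(64,), (8, 64), (4, 3, 64)]:
    data = torch.randn(*shape, dtype=torch.float64)
    coeffs = wavedec(data, "db2", level=2)
    rec = waverec(coeffs, "db2")
    assert torch.allclose(data, rec)
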
20 changes: 6 additions & 14 deletions src/ptwt/matmul_transform.py
@@ -382,15 +382,7 @@ def __call__(self, input_signal: torch.Tensor) -> List[torch.Tensor]:
result_list = [s.T for s in split_list[::-1]]

# unfold if necessary
if ds:
result_list = _postprocess_result_list_dec1d(result_list, ds)

if self.axis != -1:
swap = []
for coeff in result_list:
swap.append(coeff.swapaxes(self.axis, -1))
result_list = swap

result_list = _postprocess_result_list_dec1d(result_list, ds, self.axis)
return result_list


@@ -616,9 +608,7 @@ def __call__(self, coefficients: List[torch.Tensor]) -> torch.Tensor:
swap.append(coeff.swapaxes(self.axis, -1))
coefficients = swap

ds = None
if coefficients[0].ndim > 2:
coefficients, ds = _preprocess_result_list_rec1d(coefficients)
coefficients, ds = _preprocess_result_list_rec1d(coefficients)

level = len(coefficients) - 1
input_length = coefficients[-1].shape[-1] * 2
@@ -670,8 +660,10 @@ def __call__(self, coefficients: List[torch.Tensor]) -> torch.Tensor:

res_lo = lo.T

if ds:
res_lo = _unfold_axes(res_lo.unsqueeze(-2), list(ds), 1).squeeze(-2)
if len(ds) == 1:
res_lo = res_lo.squeeze(0)
elif len(ds) > 2:
res_lo = _unfold_axes(res_lo, ds, 1)

if self.axis != -1:
res_lo = res_lo.swapaxes(self.axis, -1)
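
Because the matrix transforms now delegate to the same helpers, they inherit the identical shape handling. A usage sketch (haar with a power-of-two signal length, which the boundary matrices are assumed to require):

import torch

from ptwt import MatrixWavedec, MatrixWaverec

matrix_wavedec = MatrixWavedec("haar", level=3)
matrix_waverec = MatrixWaverec("haar")

data = torch.randn(4, 64, dtype=torch.float64)  # [batch, time]
coeffs = matrix_wavedec(data)
rec = matrix_waverec(coeffs)
assert torch.allclose(data, rec)
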
102 changes: 51 additions & 51 deletions src/ptwt/_stationary_transform.py → src/ptwt/stationary_transform.py
@@ -1,9 +1,10 @@
"""This module implements stationary wavelet transforms."""

from typing import List, Optional, Union
from typing import List, Optional, Sequence, Union

import pywt
import torch
import torch.nn.functional as F # noqa:N812

from ._util import Wavelet, _as_wavelet, _unfold_axes
from .conv_transform import (
@@ -14,14 +15,48 @@
)


def _swt(
def _circular_pad(x: torch.Tensor, padding_dimensions: Sequence[int]) -> torch.Tensor:
"""Pad a tensor in circular mode, more than once if needed."""
trailing_dimension = x.shape[-1]

# if every padding dimension is smaller than or equal to the trailing dimension,
# we do not need to wrap manually
if not any(
padding_dimension > trailing_dimension
for padding_dimension in padding_dimensions
):
return F.pad(x, padding_dimensions, mode="circular")

# repeatedly pad by at most the trailing dimension until every padding dimension is zero
while any(padding_dimension > 0 for padding_dimension in padding_dimensions):
# cap every padding dimension at the trailing dimension width
reduced_padding_dimensions = [
min(trailing_dimension, padding_dimension)
for padding_dimension in padding_dimensions
]
# pad using reduced dimensions,
# which will never throw the circular wrap error
x = F.pad(x, reduced_padding_dimensions, mode="circular")
# subtract the pad width that was just applied, and repeat
# while any pad width is greater than zero
padding_dimensions = [
max(padding_dimension - trailing_dimension, 0)
for padding_dimension in padding_dimensions
]

return x
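
This helper is the heart of the fix: in the SWT the pad width grows with the dilation, roughly 2**level * (filt_len // 2), so deep levels on short signals request more padding than a single wrap can supply, which torch.nn.functional.pad rejects in circular mode. An illustration of the module-internal helper (a sketch, not part of the public API):

import torch
import torch.nn.functional as F

x = torch.arange(4.0).reshape(1, 1, 4)  # trailing dimension of size 4
try:
    F.pad(x, [6, 6], mode="circular")  # pad wider than the dimension itself
except Exception as err:
    print(err)  # padding would wrap around more than once

padded = _circular_pad(x, [6, 6])  # wraps the signal as often as needed
print(padded.shape)  # torch.Size([1, 1, 16])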


def swt(
data: torch.Tensor,
wavelet: Union[Wavelet, str],
level: Optional[int] = None,
axis: Optional[int] = -1,
) -> List[torch.Tensor]:
"""Compute a multilevel 1d stationary wavelet transform.
This function is equivalent to pywt's swt with `trim_approx=True` and `norm=False`.
Args:
data (torch.Tensor): The input data of shape [batch_size, time].
wavelet (Union[Wavelet, str]): The wavelet to use.
@@ -56,57 +91,20 @@ def _swt(
for current_level in range(level):
dilation = 2**current_level
padl, padr = dilation * (filt_len // 2 - 1), dilation * (filt_len // 2)
res_lo = torch.nn.functional.pad(res_lo, [padl, padr], mode="circular")
res_lo = _circular_pad(res_lo, [padl, padr])
res = torch.nn.functional.conv1d(res_lo, filt, stride=1, dilation=dilation)
res_lo, res_hi = torch.split(res, 1, 1)
# Trim_approx == False
# result_list.append((res_lo.squeeze(1), res_hi.squeeze(1)))
result_list.append(res_hi.squeeze(1))
result_list.append(res_lo.squeeze(1))

if ds:
result_list = _postprocess_result_list_dec1d(result_list, ds)

if axis != -1:
result_list = [coeff.swapaxes(axis, -1) for coeff in result_list]
result_list = _postprocess_result_list_dec1d(result_list, ds, axis)

return result_list[::-1]
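
A usage sketch of the newly public swt, checked against the pywt reference named in the docstring; the exact numeric match is an assumption based on the repository's tests, and the input follows the documented [batch_size, time] shape:

import numpy as np
import pywt
import torch

from ptwt.stationary_transform import swt

signal = torch.randn(1, 32, dtype=torch.float64)
ptwt_coeffs = swt(signal, "haar", level=2)
pywt_coeffs = pywt.swt(
    signal.numpy(), "haar", level=2, trim_approx=True, norm=False
)
assert all(
    np.allclose(ptwtc.numpy(), pywtc)
    for ptwtc, pywtc in zip(ptwt_coeffs, pywt_coeffs)
)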


def _conv_transpose_dedilate(
conv_res: torch.Tensor,
rec_filt: torch.Tensor,
dilation: int,
length: int,
) -> torch.Tensor:
"""Undo the forward dilated convolution from the analysis transform.
Args:
conv_res (torch.Tensor): The dilated coefficients
of shape [batch, 2, length].
rec_filt (torch.Tensor): The reconstruction filter pair
of shape [1, 2, filter_length].
dilation (int): The dilation factor.
length (int): The signal length.
Returns:
torch.Tensor: The deconvolution result.
"""
to_conv_t_list = [
conv_res[..., fl : (fl + dilation * rec_filt.shape[-1]) : dilation]
for fl in range(length)
]
to_conv_t = torch.cat(to_conv_t_list, 0)
padding = rec_filt.shape[-1] - 1
rec = torch.nn.functional.conv_transpose1d(
to_conv_t, rec_filt, stride=1, padding=padding, output_padding=0
)
rec = rec / 2.0
splits = torch.split(rec, rec.shape[0] // len(to_conv_t_list))
return torch.cat(splits, -1)


def _iswt(
def iswt(
coeffs: List[torch.Tensor],
wavelet: Union[pywt.Wavelet, str],
axis: Optional[int] = -1,
@@ -134,10 +132,7 @@ def _iswt(
else:
raise ValueError("iswt transforms a single axis only.")

ds = None
length = coeffs[0].shape[-1]
if coeffs[0].ndim > 2:
coeffs, ds = _preprocess_result_list_rec1d(coeffs)
coeffs, ds = _preprocess_result_list_rec1d(coeffs)

wavelet = _as_wavelet(wavelet)
_, _, rec_lo, rec_hi = _get_filter_tensors(
@@ -151,13 +146,18 @@ def _iswt(
dilation = 2 ** (len(coeffs[1:]) - c_pos - 1)
res_lo = torch.stack([res_lo, res_hi], 1)
padl, padr = dilation * (filt_len // 2), dilation * (filt_len // 2 - 1)
res_lo = torch.nn.functional.pad(res_lo, (padl, padr), mode="circular")
res_lo = _conv_transpose_dedilate(
res_lo, rec_filt, dilation=dilation, length=length
# res_lo = torch.nn.functional.pad(res_lo, (padl, padr), mode="circular")
res_lo_pad = _circular_pad(res_lo, (padl, padr))
res_lo = torch.mean(
torch.nn.functional.conv_transpose1d(
res_lo_pad, rec_filt, dilation=dilation, groups=2, padding=(padl + padr)
),
1,
)
res_lo = res_lo.squeeze(1)

if ds:
if len(ds) == 1:
res_lo = res_lo.squeeze(0)
elif len(ds) > 2:
res_lo = _unfold_axes(res_lo, ds, 1)

if axis != -1:
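
With the reconstruction rewritten around a single grouped, dilated conv_transpose1d (replacing the removed _conv_transpose_dedilate helper), the inverse should undo the forward transform. A hedged round-trip sketch:

import torch

from ptwt.stationary_transform import iswt, swt

signal = torch.randn(2, 64, dtype=torch.float64)  # [batch, time]
coeffs = swt(signal, "db2", level=3)
rec = iswt(coeffs, "db2")
assert torch.allclose(signal, rec)
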
17 changes: 17 additions & 0 deletions tests/test_convolution_fwt.py
@@ -82,6 +82,23 @@ def test_conv_fwt1d_channel(size: List[int], wavelet: str) -> None:
assert np.allclose(data.numpy(), rec.numpy())


@pytest.mark.parametrize("size", [[32], [64]])
@pytest.mark.parametrize("wavelet", ["haar", "db2"])
def test_conv_fwt1d_nobatch(size: List[int], wavelet: str) -> None:
"""1d conv for inputs without batch dim."""
data = torch.randn(*size).type(torch.float64)
ptwt_coeff = wavedec(data, wavelet)
pywt_coeff = pywt.wavedec(data.numpy(), wavelet, mode="reflect")
assert all(
[
np.allclose(ptwtc.numpy(), pywtc)
for ptwtc, pywtc in zip(ptwt_coeff, pywt_coeff)
]
)
rec = waverec(ptwt_coeff, wavelet)
assert np.allclose(data.numpy(), rec.numpy())


def test_ripples_haar_lvl3() -> None:
"""Compute example from page 7 of Ripples in Mathematics, Jensen, la Cour-Harbo."""
