Skip to content

Commit

Permalink
wavelet transform fixes (Cloud-Drift#259)
Browse files Browse the repository at this point in the history
* wavelet transform output change

* remove comment

* more shape tests

* docstring updates

* ValueError for morse transform

* Bump minor version

* consistent docstrings

* suppress warning

* reverse+docstring

* pyproject version

---------

Co-authored-by: milancurcic <caomaco@gmail.com>
  • Loading branch information
2 people authored and Philippe Miron committed Nov 16, 2023
1 parent 7a7f169 commit d2c3861
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 33 deletions.
64 changes: 39 additions & 25 deletions clouddrift/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def morse_wavelet_transform(
Parameters
----------
x : np.ndarray
Real- or complex-valued signals.
Real- or complex-valued signals. The time axis is assumed to be the last. If not, specify optional
argument `time_axis`.
gamma: float
Gamma parameter of the Morse wavelets.
beta: float
Expand Down Expand Up @@ -69,22 +70,24 @@ def morse_wavelet_transform(
The boundary condition to be imposed at the edges of the input signal ``x``.
Allowed values are ``"mirror"``, ``"zeros"``, and ``"periodic"``. Default is ``"mirror"``.
order: int, optional
Order of wavelets, default is 1.
Order of Morse wavelets, default is 1.
Returns
-------
If the input signal is real as specificied by ``complex=False``:
wtx : np.ndarray
Time-domain wavelet transform of input ``x``. The axes of ``wtx`` will be organized as (x axes), orders, frequencies, time
unless ``time_axis`` is different from last (-1) in which case it will be moved back to its original position within the axes of ``x``.
Time-domain wavelet transform of input ``x`` with shape ((x shape without time_axis), orders, frequencies, time_axis)
but with dimensions of length 1 removed (squeezed).
If the input signal is complex as specificied by ``complex=True``:
If the input signal is complex as specificied by ``complex=True``, a tuple is returned:
wtx_p: np.array
Time-domain positive wavelet transform of input ``x``.
Time-domain positive wavelet transform of input ``x`` with shape ((x shape without time_axis), frequencies, orders),
but with dimensions of length 1 removed (squeezed).
wtx_n: np.array
Time-domain negative wavelet transform of input ``x``.
Time-domain negative wavelet transform of input ``x`` with shape ((x shape without time_axis), frequencies, orders),
but with dimensions of length 1 removed (squeezed).
Examples
--------
Expand Down Expand Up @@ -125,7 +128,7 @@ def morse_wavelet_transform(
>>> x = np.random.random((10,15,1024))
>>> wtx = morse_wavelet_transform(x, 3, 4, np.array([2*np.pi*0.2]), boundary="periodic")
This function can be used to complete a time-frequency analysis of the input signal by specifying
This function can be used to conduct a time-frequency analysis of the input signal by specifying
a range of randian frequencies using the ``morse_logspace_freq`` function as an example:
>>> x = np.random.random(1024)
Expand All @@ -134,6 +137,7 @@ def morse_wavelet_transform(
>>> radian_frequency = morse_logspace_freq(gamma, beta, np.shape(x)[0])
>>> wtx = morse_wavelet_transform(x, gamma, beta, radian_frequency)
Raises
------
ValueError
Expand Down Expand Up @@ -181,10 +185,15 @@ def morse_wavelet_transform(
)
wtx = wtx_p, wtx_n

else:
elif ~complex:
# real case
wtx = wavelet_transform(x, wavelet, boundary=boundary, time_axis=time_axis)

else:
raise ValueError(
"`complex` optional argument must be boolean 'True' or 'False'"
)

return wtx


Expand All @@ -207,9 +216,9 @@ def wavelet_transform(
wavelet : np.ndarray
A suite of time-domain wavelets, typically returned by the function ``morse_wavelet``.
The length of the time axis of the wavelets must be the last one and matches the
length of the time axis of x. The other dimensions (axes) of the wavelets (orders and frequencies) are
length of the time axis of x. The other dimensions (axes) of the wavelets (such as orders and frequencies) are
typically organized as orders, frequencies, and time, unless specified by optional arguments freq_axis and order_axis.
The normalization of the wavelets is assumed to be "bandpass", if not use kwarg normalization="energy", see ``morse_wavelet``.
The normalization of the wavelets is assumed to be "bandpass", if not, use kwarg normalization="energy", see ``morse_wavelet``.
boundary : str, optional
The boundary condition to be imposed at the edges of the input signal ``x``.
Allowed values are ``"mirror"``, ``"zeros"``, and ``"periodic"``. Default is ``"mirror"``.
Expand All @@ -224,8 +233,8 @@ def wavelet_transform(
Returns
-------
wtx : np.ndarray
Time-domain wavelet transform of input ``x``. The axes of ``wtx`` will be organized as (x axes), orders, frequencies, time
unless ``time_axis`` is different from last (-1) in which case it will be moved back to its original position within the axes of ``x``.
Time-domain wavelet transform of ``x`` with shape ((x shape without time_axis), orders, frequencies, time_axis)
but with dimensions of length 1 removed (squeezed).
Examples
--------
Expand Down Expand Up @@ -261,7 +270,7 @@ def wavelet_transform(
)
# Positions and time arrays must have the same shape.
if x.shape[time_axis] != wavelet.shape[-1]:
raise ValueError("x and wave time axes must have the same length.")
raise ValueError("x and wavelet time axes must have the same length.")

wavelet_ = np.moveaxis(wavelet, [freq_axis, order_axis], [-2, -3])

Expand Down Expand Up @@ -319,10 +328,13 @@ def wavelet_transform(
complex_dtype = np.cdouble if x.dtype == np.single else np.csingle
wtx = np.fft.ifft(X_ * np.conj(_wavelet_fft)).astype(complex_dtype)
wtx = wtx[..., index]
# remove extra dimensions

# reposition the time axis if needed from axis -1
if time_axis != -1:
wtx = np.moveaxis(wtx, -1, time_axis)

# remove extra dimensions if needed
wtx = np.squeeze(wtx)
# reposition the time axis: should I add a condition to do so only if time_axis!=-1? works anyway
wtx = np.moveaxis(wtx, -1, time_axis)

return wtx

Expand Down Expand Up @@ -362,9 +374,9 @@ def morse_wavelet(
Returns
-------
wavelet : np.ndarray
Time-domain wavelets. ``wavelet`` will be of shape (length,np.size(radian_frequency),order).
Time-domain wavelets with shape (order, radian_frequency, length).
wavelet_fft: np.ndarray
Frequency-domain wavelets. ``wavelet_fft`` will be of shape (length,np.size(radian_frequency),order).
Frequency-domain wavelets with shape (order, radian_frequency, length).
Examples
--------
Expand All @@ -380,10 +392,10 @@ def morse_wavelet(
>>> wavelet, wavelet_fft = morse_wavelet(1024, 3, 4, np.array([2*np.pi*0.2, 2*np.pi*0.3]), order=3)
>>> np.shape(wavelet)
(3, 3, 1024)
(3, 2, 1024)
Compute a Morse wavelet specifying an energy normalization :
>>> wavelet, wavelet_fft = morse_wavelet(1024, 3, 4, np.array([2*np.pi*0.2]), normalization=energy)
>>> wavelet, wavelet_fft = morse_wavelet(1024, 3, 4, np.array([2*np.pi*0.2]), normalization="energy")
Raises
------
Expand Down Expand Up @@ -468,6 +480,7 @@ def morse_wavelet(
# enforce length 1 for first axis if order=1 (no squeezing)
wavelet = np.moveaxis(wavelet, [0, 1, 2], [2, 0, 1])
waveletfft = np.moveaxis(waveletfft, [0, 1, 2], [2, 0, 1])

return wavelet, waveletfft


Expand Down Expand Up @@ -690,10 +703,11 @@ def _morsehigh(

for i in range(0, len(gamma)):
fm, _, _ = morse_freq(gamma[i], beta[i])
om = fm * np.pi / omhigh
lnwave1 = beta[i] / gamma[i] * np.log(np.exp(1) * gamma[i] / beta[i])
lnwave2 = beta[i] * np.log(om) - om ** gamma[i]
lnwave = lnwave1 + lnwave2
with np.errstate(all="ignore"):
om = fm * np.pi / omhigh
lnwave1 = beta[i] / gamma[i] * np.log(np.exp(1) * gamma[i] / beta[i])
lnwave2 = beta[i] * np.log(om) - om ** gamma[i]
lnwave = lnwave1 + lnwave2
index = np.nonzero(np.log(eta) - lnwave < 0)[0][0]
f[i] = omhigh[index]

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "clouddrift"
version = "0.21.3"
version = "0.21.4"
authors = [
{ name="Shane Elipot", email="selipot@miami.edu" },
{ name="Philippe Miron", email="philippemiron@gmail.com" },
Expand Down
26 changes: 19 additions & 7 deletions tests/wavelet_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,23 @@ def test_wavelet_transform_size(self):
gamma = 3
beta = 4
x = np.random.random((m, m * 2, length))
wave, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order)
w = wavelet_transform(x, wave)
self.assertTrue(np.shape(w) == (m, m * 2, order, len(radian_frequency), length))
wavelet, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order)
wtx = wavelet_transform(x, wavelet)
self.assertTrue(
np.shape(wtx) == (m, m * 2, order, len(radian_frequency), length)
)
x = np.random.random((length, m, m * 2))
wavelet, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order)
wtx = wavelet_transform(x, wavelet, time_axis=0)
self.assertTrue(
np.shape(wtx) == (length, m, m * 2, order, len(radian_frequency))
)
x = np.random.random((m, length, m * 2))
wavelet, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order)
wtx = wavelet_transform(x, wavelet, time_axis=1)
self.assertTrue(
np.shape(wtx) == (m, length, m * 2, order, len(radian_frequency))
)

def test_wavelet_transform_size_axis(self):
length = 1024
Expand All @@ -163,9 +177,9 @@ def test_wavelet_transform_centered(self):
J = 10
ao = np.logspace(np.log10(5), np.log10(40), J) / 100
x = np.zeros(2**10)
wave, _ = morse_wavelet(len(x), 2, 4, ao, order=1)
wavelet, _ = morse_wavelet(len(x), 2, 4, ao, order=1)
x[2**9] = 1
y = wavelet_transform(x, wave)
y = wavelet_transform(x, wavelet)
m = np.argmax(np.abs(y), axis=-1)
self.assertTrue(np.allclose(m, 2**9))

Expand All @@ -182,9 +196,7 @@ def test_wavelet_transform_data_real(self):
waveletb, _ = morse_wavelet(
np.shape(t)[0], gamma, beta, omega, normalization="bandpass"
)
# wavelete, _ = morse_wavelet(np.shape(t)[0],gamma,beta,omega,normalization="energy")
wtxb = wavelet_transform(x, waveletb, boundary="mirror")
# wtxe = wavelet_transform(x,wavelete,boundary="mirror")
self.assertTrue(np.isclose(np.var(wtxb), 2 * np.var(x), rtol=1e-1))

def test_wavelet_transform_data_complex(self):
Expand Down

0 comments on commit d2c3861

Please sign in to comment.