diff --git a/clouddrift/wavelet.py b/clouddrift/wavelet.py index fd3d999e..f198b7bf 100644 --- a/clouddrift/wavelet.py +++ b/clouddrift/wavelet.py @@ -40,7 +40,8 @@ def morse_wavelet_transform( Parameters ---------- x : np.ndarray - Real- or complex-valued signals. + Real- or complex-valued signals. The time axis is assumed to be the last. If not, specify optional + argument `time_axis`. gamma: float Gamma parameter of the Morse wavelets. beta: float @@ -69,22 +70,24 @@ def morse_wavelet_transform( The boundary condition to be imposed at the edges of the input signal ``x``. Allowed values are ``"mirror"``, ``"zeros"``, and ``"periodic"``. Default is ``"mirror"``. order: int, optional - Order of wavelets, default is 1. + Order of Morse wavelets, default is 1. Returns ------- If the input signal is real as specificied by ``complex=False``: wtx : np.ndarray - Time-domain wavelet transform of input ``x``. The axes of ``wtx`` will be organized as (x axes), orders, frequencies, time - unless ``time_axis`` is different from last (-1) in which case it will be moved back to its original position within the axes of ``x``. + Time-domain wavelet transform of input ``x`` with shape ((x shape without time_axis), orders, frequencies, time_axis) + but with dimensions of length 1 removed (squeezed). - If the input signal is complex as specificied by ``complex=True``: + If the input signal is complex as specificied by ``complex=True``, a tuple is returned: wtx_p: np.array - Time-domain positive wavelet transform of input ``x``. + Time-domain positive wavelet transform of input ``x`` with shape ((x shape without time_axis), frequencies, orders), + but with dimensions of length 1 removed (squeezed). wtx_n: np.array - Time-domain negative wavelet transform of input ``x``. + Time-domain negative wavelet transform of input ``x`` with shape ((x shape without time_axis), frequencies, orders), + but with dimensions of length 1 removed (squeezed). Examples -------- @@ -125,7 +128,7 @@ def morse_wavelet_transform( >>> x = np.random.random((10,15,1024)) >>> wtx = morse_wavelet_transform(x, 3, 4, np.array([2*np.pi*0.2]), boundary="periodic") - This function can be used to complete a time-frequency analysis of the input signal by specifying + This function can be used to conduct a time-frequency analysis of the input signal by specifying a range of randian frequencies using the ``morse_logspace_freq`` function as an example: >>> x = np.random.random(1024) @@ -134,6 +137,7 @@ def morse_wavelet_transform( >>> radian_frequency = morse_logspace_freq(gamma, beta, np.shape(x)[0]) >>> wtx = morse_wavelet_transform(x, gamma, beta, radian_frequency) + Raises ------ ValueError @@ -181,10 +185,15 @@ def morse_wavelet_transform( ) wtx = wtx_p, wtx_n - else: + elif ~complex: # real case wtx = wavelet_transform(x, wavelet, boundary=boundary, time_axis=time_axis) + else: + raise ValueError( + "`complex` optional argument must be boolean 'True' or 'False'" + ) + return wtx @@ -207,9 +216,9 @@ def wavelet_transform( wavelet : np.ndarray A suite of time-domain wavelets, typically returned by the function ``morse_wavelet``. The length of the time axis of the wavelets must be the last one and matches the - length of the time axis of x. The other dimensions (axes) of the wavelets (orders and frequencies) are + length of the time axis of x. The other dimensions (axes) of the wavelets (such as orders and frequencies) are typically organized as orders, frequencies, and time, unless specified by optional arguments freq_axis and order_axis. - The normalization of the wavelets is assumed to be "bandpass", if not use kwarg normalization="energy", see ``morse_wavelet``. + The normalization of the wavelets is assumed to be "bandpass", if not, use kwarg normalization="energy", see ``morse_wavelet``. boundary : str, optional The boundary condition to be imposed at the edges of the input signal ``x``. Allowed values are ``"mirror"``, ``"zeros"``, and ``"periodic"``. Default is ``"mirror"``. @@ -224,8 +233,8 @@ def wavelet_transform( Returns ------- wtx : np.ndarray - Time-domain wavelet transform of input ``x``. The axes of ``wtx`` will be organized as (x axes), orders, frequencies, time - unless ``time_axis`` is different from last (-1) in which case it will be moved back to its original position within the axes of ``x``. + Time-domain wavelet transform of ``x`` with shape ((x shape without time_axis), orders, frequencies, time_axis) + but with dimensions of length 1 removed (squeezed). Examples -------- @@ -261,7 +270,7 @@ def wavelet_transform( ) # Positions and time arrays must have the same shape. if x.shape[time_axis] != wavelet.shape[-1]: - raise ValueError("x and wave time axes must have the same length.") + raise ValueError("x and wavelet time axes must have the same length.") wavelet_ = np.moveaxis(wavelet, [freq_axis, order_axis], [-2, -3]) @@ -319,10 +328,13 @@ def wavelet_transform( complex_dtype = np.cdouble if x.dtype == np.single else np.csingle wtx = np.fft.ifft(X_ * np.conj(_wavelet_fft)).astype(complex_dtype) wtx = wtx[..., index] - # remove extra dimensions + + # reposition the time axis if needed from axis -1 + if time_axis != -1: + wtx = np.moveaxis(wtx, -1, time_axis) + + # remove extra dimensions if needed wtx = np.squeeze(wtx) - # reposition the time axis: should I add a condition to do so only if time_axis!=-1? works anyway - wtx = np.moveaxis(wtx, -1, time_axis) return wtx @@ -362,9 +374,9 @@ def morse_wavelet( Returns ------- wavelet : np.ndarray - Time-domain wavelets. ``wavelet`` will be of shape (length,np.size(radian_frequency),order). + Time-domain wavelets with shape (order, radian_frequency, length). wavelet_fft: np.ndarray - Frequency-domain wavelets. ``wavelet_fft`` will be of shape (length,np.size(radian_frequency),order). + Frequency-domain wavelets with shape (order, radian_frequency, length). Examples -------- @@ -380,10 +392,10 @@ def morse_wavelet( >>> wavelet, wavelet_fft = morse_wavelet(1024, 3, 4, np.array([2*np.pi*0.2, 2*np.pi*0.3]), order=3) >>> np.shape(wavelet) - (3, 3, 1024) + (3, 2, 1024) Compute a Morse wavelet specifying an energy normalization : - >>> wavelet, wavelet_fft = morse_wavelet(1024, 3, 4, np.array([2*np.pi*0.2]), normalization=energy) + >>> wavelet, wavelet_fft = morse_wavelet(1024, 3, 4, np.array([2*np.pi*0.2]), normalization="energy") Raises ------ @@ -468,6 +480,7 @@ def morse_wavelet( # enforce length 1 for first axis if order=1 (no squeezing) wavelet = np.moveaxis(wavelet, [0, 1, 2], [2, 0, 1]) waveletfft = np.moveaxis(waveletfft, [0, 1, 2], [2, 0, 1]) + return wavelet, waveletfft @@ -690,10 +703,11 @@ def _morsehigh( for i in range(0, len(gamma)): fm, _, _ = morse_freq(gamma[i], beta[i]) - om = fm * np.pi / omhigh - lnwave1 = beta[i] / gamma[i] * np.log(np.exp(1) * gamma[i] / beta[i]) - lnwave2 = beta[i] * np.log(om) - om ** gamma[i] - lnwave = lnwave1 + lnwave2 + with np.errstate(all="ignore"): + om = fm * np.pi / omhigh + lnwave1 = beta[i] / gamma[i] * np.log(np.exp(1) * gamma[i] / beta[i]) + lnwave2 = beta[i] * np.log(om) - om ** gamma[i] + lnwave = lnwave1 + lnwave2 index = np.nonzero(np.log(eta) - lnwave < 0)[0][0] f[i] = omhigh[index] diff --git a/pyproject.toml b/pyproject.toml index cbc693d9..a197b052 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "clouddrift" -version = "0.21.3" +version = "0.21.4" authors = [ { name="Shane Elipot", email="selipot@miami.edu" }, { name="Philippe Miron", email="philippemiron@gmail.com" }, diff --git a/tests/wavelet_tests.py b/tests/wavelet_tests.py index 7fc3c9f6..df0d91b2 100644 --- a/tests/wavelet_tests.py +++ b/tests/wavelet_tests.py @@ -143,9 +143,23 @@ def test_wavelet_transform_size(self): gamma = 3 beta = 4 x = np.random.random((m, m * 2, length)) - wave, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order) - w = wavelet_transform(x, wave) - self.assertTrue(np.shape(w) == (m, m * 2, order, len(radian_frequency), length)) + wavelet, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order) + wtx = wavelet_transform(x, wavelet) + self.assertTrue( + np.shape(wtx) == (m, m * 2, order, len(radian_frequency), length) + ) + x = np.random.random((length, m, m * 2)) + wavelet, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order) + wtx = wavelet_transform(x, wavelet, time_axis=0) + self.assertTrue( + np.shape(wtx) == (length, m, m * 2, order, len(radian_frequency)) + ) + x = np.random.random((m, length, m * 2)) + wavelet, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order) + wtx = wavelet_transform(x, wavelet, time_axis=1) + self.assertTrue( + np.shape(wtx) == (m, length, m * 2, order, len(radian_frequency)) + ) def test_wavelet_transform_size_axis(self): length = 1024 @@ -163,9 +177,9 @@ def test_wavelet_transform_centered(self): J = 10 ao = np.logspace(np.log10(5), np.log10(40), J) / 100 x = np.zeros(2**10) - wave, _ = morse_wavelet(len(x), 2, 4, ao, order=1) + wavelet, _ = morse_wavelet(len(x), 2, 4, ao, order=1) x[2**9] = 1 - y = wavelet_transform(x, wave) + y = wavelet_transform(x, wavelet) m = np.argmax(np.abs(y), axis=-1) self.assertTrue(np.allclose(m, 2**9)) @@ -182,9 +196,7 @@ def test_wavelet_transform_data_real(self): waveletb, _ = morse_wavelet( np.shape(t)[0], gamma, beta, omega, normalization="bandpass" ) - # wavelete, _ = morse_wavelet(np.shape(t)[0],gamma,beta,omega,normalization="energy") wtxb = wavelet_transform(x, waveletb, boundary="mirror") - # wtxe = wavelet_transform(x,wavelete,boundary="mirror") self.assertTrue(np.isclose(np.var(wtxb), 2 * np.var(x), rtol=1e-1)) def test_wavelet_transform_data_complex(self):