Skip to content

Commit

Permalink
support for nd arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanhammonds committed Sep 28, 2023
1 parent cf053ea commit 0350005
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 42 deletions.
4 changes: 2 additions & 2 deletions neurodsp/burst/dualthresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def detect_bursts_dual_threshold(sig, fs, dual_thresh, f_range=None,
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand All @@ -44,7 +44,7 @@ def detect_bursts_dual_threshold(sig, fs, dual_thresh, f_range=None,
Returns
-------
is_burst : 1d array
is_burst : nd array
Boolean indication of where bursts are present in the input signal.
True indicates that a burst was detected at that sample, otherwise False.
Expand Down
4 changes: 2 additions & 2 deletions neurodsp/filt/fir.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ def apply_fir_filter(sig, filter_coefs):
Parameters
----------
sig : array
sig : nd array
Time series to be filtered.
filter_coefs : 1d array
Filter coefficients of the FIR filter.
Returns
-------
array
nd array
Filtered time series.
Examples
Expand Down
4 changes: 2 additions & 2 deletions neurodsp/filt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,14 @@ def remove_filter_edges(sig, filt_len):
Parameters
----------
sig : 1d or 2d array
sig : nd array
Filtered signal to have edge artifacts removed from.
filt_len : int
Length of the filter that was applied.
Returns
-------
sig : 1d array
sig : nd array
Filter signal with edge artifacts switched to NaNs.
Examples
Expand Down
4 changes: 2 additions & 2 deletions neurodsp/rhythm/lc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def compute_lagged_coherence(sig, fs, freqs, n_cycles=3, return_spectrum=False):
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand All @@ -37,7 +37,7 @@ def compute_lagged_coherence(sig, fs, freqs, n_cycles=3, return_spectrum=False):
Returns
-------
lcs : float or 1d array
lcs : float or nd array
If `return_spectrum` is False: mean lagged coherence value across the frequency range.
If `return_spectrum` is True: lagged coherence values for all frequencies.
freqs : 1d array
Expand Down
6 changes: 3 additions & 3 deletions neurodsp/rhythm/swm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def sliding_window_matching(sig, fs, win_len, win_spacing, max_iterations=100,
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand All @@ -32,9 +32,9 @@ def sliding_window_matching(sig, fs, win_len, win_spacing, max_iterations=100,
Returns
-------
windows : 2d array
windows : n x 2d array
Putative patterns discovered in the input signal.
window_starts : 1d array
window_starts : n x 1d array
Indices at which each window begins for the final set of windows.
Notes
Expand Down
16 changes: 8 additions & 8 deletions neurodsp/spectral/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def compute_spectrum(sig, fs, method='welch', avg_type='mean', **kwargs):
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand All @@ -41,7 +41,7 @@ def compute_spectrum(sig, fs, method='welch', avg_type='mean', **kwargs):
-------
freqs : 1d array
Frequencies at which the measure was calculated.
spectrum : 1d or 2d array
spectrum : nd array
Power spectral density.
Examples
Expand Down Expand Up @@ -72,7 +72,7 @@ def compute_spectrum_wavelet(sig, fs, freqs, avg_type='mean', **kwargs):
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand All @@ -89,7 +89,7 @@ def compute_spectrum_wavelet(sig, fs, freqs, avg_type='mean', **kwargs):
-------
freqs : 1d array
Frequencies at which the measure was calculated.
spectrum : 1d or 2d array
spectrum : nd array
Power spectral density.
Examples
Expand Down Expand Up @@ -124,7 +124,7 @@ def compute_spectrum_welch(sig, fs, avg_type='mean', window='hann',
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand Down Expand Up @@ -152,7 +152,7 @@ def compute_spectrum_welch(sig, fs, avg_type='mean', window='hann',
-------
freqs : 1d array
Frequencies at which the measure was calculated.
spectrum : 1d or 2d array
spectrum : nd array
Power spectral density.
Notes
Expand Down Expand Up @@ -200,7 +200,7 @@ def compute_spectrum_medfilt(sig, fs, filt_len=1., f_range=None):
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand All @@ -213,7 +213,7 @@ def compute_spectrum_medfilt(sig, fs, filt_len=1., f_range=None):
-------
freqs : 1d array
Frequencies at which the measure was calculated.
spectrum : 1d or 2d array
spectrum : nd array
Power spectral density.
Examples
Expand Down
12 changes: 6 additions & 6 deletions neurodsp/spectral/variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def compute_scv(sig, fs, window='hann', nperseg=None, noverlap=0, outlier_pct=No
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series of measurement values.
fs : float
Sampling rate, in Hz.
Expand All @@ -39,7 +39,7 @@ def compute_scv(sig, fs, window='hann', nperseg=None, noverlap=0, outlier_pct=No
-------
freqs : 1d array
Frequencies at which the measure was calculated.
scv : 1d array
scv : nd array
Spectral coefficient of variation.
Notes
Expand Down Expand Up @@ -75,7 +75,7 @@ def compute_scv_rs(sig, fs, window='hann', nperseg=None, noverlap=0,
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series of measurement values.
fs : float
Sampling rate, in Hz.
Expand Down Expand Up @@ -112,7 +112,7 @@ def compute_scv_rs(sig, fs, window='hann', nperseg=None, noverlap=0,
t_inds : 1d array or None
Time indices at which the measure was calculated.
This is only returned for 'rolling' resampling. If 'bootstrap', t_inds = None.
scv_rs : 2d array
scv_rs : nd array
Resampled spectral coefficient of variation.
Notes
Expand Down Expand Up @@ -185,7 +185,7 @@ def compute_spectral_hist(sig, fs, window='hann', nperseg=None, noverlap=None,
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series of measurement values.
fs : float
Sampling rate, in Hz.
Expand All @@ -211,7 +211,7 @@ def compute_spectral_hist(sig, fs, window='hann', nperseg=None, noverlap=None,
Frequencies at which the measure was calculated.
power_bins : 1d array
Histogram bins used to compute the distribution.
spectral_hist : 2d array
spectral_hist : nd array
Power distribution at every frequency, as [n_bins, freqs].
Notes
Expand Down
19 changes: 16 additions & 3 deletions neurodsp/tests/utils/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,20 @@ def func(sig):
arr2d = np.array([[1, 2], [1, 2]])
assert np.array_equal(func(arr2d), np.array([3, 3]))

# Check error for input of unsupported dimension
# 3d input
# note: func(arr3d) will return (dima, dimb, 1), so add the last dim to .sum() with a reshape
arr3d = np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]]])
with raises(ValueError):
func(arr3d)
assert np.array_equal(func(arr3d), arr3d.sum(axis=-1).reshape((*arr3d.shape[:-1], 1)))

# 4d input
arr4d = np.random.rand(2, 3, 4, 5)
assert np.array_equal(func(arr4d), arr4d.sum(axis=-1).reshape((*arr4d.shape[:-1], 1)))

# 2d return shape (e.g. compute_spectrum)
@multidim(select=[0])
def func(sig):
return np.arange(3), np.random.rand(3)

freqs, powers = func(arr3d)
assert np.array_equal(freqs, np.arange(3))
assert powers.shape == (*arr3d.shape[:-1], 3)
16 changes: 8 additions & 8 deletions neurodsp/timefrequency/hilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ def robust_hilbert(sig):
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
Returns
-------
sig_hilb : 1d array
sig_hilb : nd array
The analytic signal, of which the imaginary part is the Hilbert transform of the input.
Examples
Expand Down Expand Up @@ -53,7 +53,7 @@ def phase_by_time(sig, fs, f_range=None, remove_edges=True, **filter_kwargs):
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand All @@ -67,7 +67,7 @@ def phase_by_time(sig, fs, f_range=None, remove_edges=True, **filter_kwargs):
Returns
-------
pha : 1d array
pha : nd array
Instantaneous phase time series.
Examples
Expand Down Expand Up @@ -99,7 +99,7 @@ def amp_by_time(sig, fs, f_range=None, remove_edges=True, **filter_kwargs):
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand All @@ -113,7 +113,7 @@ def amp_by_time(sig, fs, f_range=None, remove_edges=True, **filter_kwargs):
Returns
-------
amp : 1d array
amp : nd array
Instantaneous amplitude time series.
Examples
Expand Down Expand Up @@ -144,7 +144,7 @@ def freq_by_time(sig, fs, f_range=None, remove_edges=True, **filter_kwargs):
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand All @@ -158,7 +158,7 @@ def freq_by_time(sig, fs, f_range=None, remove_edges=True, **filter_kwargs):
Returns
-------
i_f : 1d array
i_f : nd array
Instantaneous frequency time series.
Notes
Expand Down
8 changes: 4 additions & 4 deletions neurodsp/timefrequency/wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm='amp
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand All @@ -37,7 +37,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm='amp
Returns
-------
mwt : 2d array
mwt : nd array
Time frequency representation of the input signal.
Notes
Expand Down Expand Up @@ -71,7 +71,7 @@ def convolve_wavelet(sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, n
Parameters
----------
sig : 1d or 2d array
sig : nd array
Time series.
fs : float
Sampling rate, in Hz.
Expand All @@ -91,7 +91,7 @@ def convolve_wavelet(sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, n
Returns
-------
array
nd array
Complex time series.
Notes
Expand Down
13 changes: 12 additions & 1 deletion neurodsp/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,18 @@ def wrapper(sig, *args, **kwargs):
out = np.array(outs)

else:
raise ValueError('Arrays of 3 or more dimensions are not supported.')
# Reshape to 2d and run func
shape = sig.shape
sig_2d = sig.reshape(-1, shape[-1])
out = wrapper(sig_2d, *args, **kwargs)

# Reshape back to original shape
if isinstance(out, (tuple, list)):
for i in range(len(out)):
if i not in select:
out[i] = out[i].reshape((*shape[:-1], -1))
else:
out = out.reshape((*shape[:-1], -1))

return out

Expand Down
2 changes: 1 addition & 1 deletion neurodsp/utils/outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def restore_nans(sig, sig_nans, dtype=float):
Parameters
----------
sig : 1d or 2d array
sig : nd array
Signal that has had NaN edges removed.
sig_nans : 1d array
Boolean array indicating where NaNs were in the original array.
Expand Down

0 comments on commit 0350005

Please sign in to comment.