Skip to content

Commit

Permalink
ENH: return complex output per taper
Browse files Browse the repository at this point in the history
  • Loading branch information
mmagnuski committed Feb 1, 2022
1 parent 89a975d commit 98d4860
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet',

# Initialize output
n_freqs = len(freqs)
n_tapers = len(Ws)
n_epochs, n_chans, n_times = epoch_data[:, :, decim].shape
if output in ('power', 'phase', 'avg_power', 'itc'):
dtype = np.float64
Expand All @@ -381,6 +382,8 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet',

if ('avg_' in output) or ('itc' in output):
out = np.empty((n_chans, n_freqs, n_times), dtype)
elif output == 'complex' and n_tapers > 1:
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
else:
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)

Expand All @@ -400,7 +403,10 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet',

if ('avg_' not in output) and ('itc' not in output):
# This is to enforce that the first dimension is for epochs
out = out.transpose(1, 0, 2, 3)
if output == 'complex' and n_tapers > 1:
out = out.transpose(2, 0, 1, 3, 4)
else:
out = out.transpose(1, 0, 2, 3)
return out


Expand Down Expand Up @@ -508,15 +514,19 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):

# Init outputs
decim = _check_decim(decim)
n_tapers = len(Ws)
n_epochs, n_times = X[:, decim].shape
n_freqs = len(Ws[0])
if ('avg_' in output) or ('itc' in output):
tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
elif output == 'complex' and n_tapers > 1:
tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times),
dtype=dtype)
else:
tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)

# Loops across tapers.
for W in Ws:
for taper_idx, W in enumerate(Ws):
# No need to check here, it's done earlier (outside parallel part)
nfft = _get_nfft(W, X, use_fft, check=False)
coefs = _cwt_gen(
Expand Down Expand Up @@ -544,6 +554,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):
# Stack or add
if ('avg_' in output) or ('itc' in output):
tfrs += tfr
elif output == 'complex' and n_tapers > 1:
tfrs[taper_idx, epoch_idx] += tfr
else:
tfrs[epoch_idx] += tfr

Expand Down

0 comments on commit 98d4860

Please sign in to comment.