Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: return complex multitaper output per taper #10281

Merged
merged 15 commits into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ Enhancements

- Fix some unused variables in time_frequency_erds.py example (:gh:`10076` by :newcontrib:`Jan Zerfowski`)

- :func:`mne.time_frequency.tfr_array_multitaper` can now return results for ``output='phase'`` instead of an error (:gh:`10281` by `Mikołaj Magnuski`_)

- Add show local maxima toggling button to :func:`mne.gui.locate_ieeg` (:gh:`9952` by `Alex Rockhill`_)

- Improve docstring of :class:`mne.Info` and add attributes that were not covered (:gh:`9922` by `Mathieu Scheltienne`_)
Expand Down Expand Up @@ -105,6 +107,8 @@ Bugs

- Teach :func:`mne.io.read_raw_bti` to use its ``eog_ch`` parameter (:gh:`10093` by :newcontrib:`Adina Wagner`)

- :func:`mne.time_frequency.tfr_array_multitaper` now returns results per taper when ``output='complex'`` (:gh:`10281` by `Mikołaj Magnuski`_)

- Fix default of :func:`mne.io.Raw.plot` to be ``use_opengl=None``, which will act like False unless ``MNE_BROWSER_USE_OPENGL=true`` is set in the user configuration (:gh:`9957` by `Eric Larson`_)

- Fix bug with :class:`mne.Report` where figures were saved with ``bbox_inches='tight'``, which led to inconsistent sizes in sliders (:gh:`9966` by `Eric Larson`_)
Expand Down
16 changes: 9 additions & 7 deletions mne/time_frequency/multitaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,9 @@ def tfr_array_multitaper(epoch_data, sfreq, freqs, n_cycles=7.0,
is done after the convolutions.
output : str, default 'complex'

* 'complex' : single trial complex.
* 'complex' : single trial per taper complex values.
* 'power' : single trial power.
* 'phase' : single trial phase.
* 'phase' : single trial per taper phase.
* 'avg_power' : average of single trial power.
* 'itc' : inter-trial coherence.
* 'avg_power_itc' : average of single trial power and inter-trial
Expand All @@ -509,11 +509,13 @@ def tfr_array_multitaper(epoch_data, sfreq, freqs, n_cycles=7.0,
Returns
-------
out : array
Time frequency transform of epoch_data. If output is in ['complex',
'phase', 'power'], then shape of out is (n_epochs, n_chans, n_freqs,
n_times), else it is (n_chans, n_freqs, n_times). If output is
'avg_power_itc', the real values code for 'avg_power' and the
imaginary values code for the 'itc': out = avg_power + i * itc.
Time frequency transform of epoch_data. If ``output in ['complex',
'phase']``, then the shape of ``out`` is ``(n_epochs, n_chans,
n_tapers, n_freqs, n_times)``; if output is 'power', the shape of
``out`` is ``(n_epochs, n_chans, n_freqs, n_times)``, else it is
``(n_chans, n_freqs, n_times)``. If output is 'avg_power_itc', the real
values in ``out`` contain the average power and the imaginary values
contain the ITC: ``out = avg_power + i * itc``.

See Also
--------
Expand Down
43 changes: 29 additions & 14 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,21 @@ def test_time_frequency():
# computed within the method.
assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data)

# test that averaging power across tapers when multitaper with
# output='complex' gives the same as output='power'
epoch_data = epochs.get_data()
multitaper_power = tfr_array_multitaper(
epoch_data, epochs.info['sfreq'], freqs, n_cycles,
output="power")
multitaper_complex = tfr_array_multitaper(
epoch_data, epochs.info['sfreq'], freqs, n_cycles,
output="complex")

taper_dim = 2
power_from_complex = (multitaper_complex * multitaper_complex.conj()
).real.mean(axis=taper_dim)
assert_allclose(power_from_complex, multitaper_power)

print(itc) # test repr
print(itc.ch_names) # test property
itc += power # test add
Expand Down Expand Up @@ -721,17 +736,16 @@ def test_compute_tfr():
(tfr_array_multitaper, tfr_array_morlet), (False, True), (False, True),
('complex', 'power', 'phase',
'avg_power_itc', 'avg_power', 'itc')):
# Check exception
if (func == tfr_array_multitaper) and (output == 'phase'):
pytest.raises(NotImplementedError, func, data, sfreq=sfreq,
freqs=freqs, output=output)
continue

# Check runs
out = func(data, sfreq=sfreq, freqs=freqs, use_fft=use_fft,
zero_mean=zero_mean, n_cycles=2., output=output)
# Check shapes
shape = np.r_[data.shape[:2], len(freqs), data.shape[2]]
if func == tfr_array_multitaper and output in ['complex', 'phase']:
n_tapers = 3
shape = np.r_[data.shape[:2], n_tapers, len(freqs), data.shape[2]]
else:
shape = np.r_[data.shape[:2], len(freqs), data.shape[2]]
if ('avg' in output) or ('itc' in output):
assert_array_equal(shape[1:], out.shape)
else:
Expand Down Expand Up @@ -762,9 +776,6 @@ def test_compute_tfr():
# No time_bandwidth param in morlet
pytest.raises(ValueError, _compute_tfr, data, freqs, sfreq,
method='morlet', time_bandwidth=1)
# No phase in multitaper XXX Check ?
pytest.raises(NotImplementedError, _compute_tfr, data, freqs, sfreq,
method='multitaper', output='phase')

# Inter-trial coherence tests
out = _compute_tfr(data, freqs, sfreq, output='itc', n_cycles=2.)
Expand All @@ -780,10 +791,11 @@ def test_compute_tfr():
_decim = slice(None, None, decim) if isinstance(decim, int) else decim
n_time = len(np.arange(data.shape[2])[_decim])
shape = np.r_[data.shape[:2], len(freqs), n_time]

for method in ('multitaper', 'morlet'):
# Single trials
out = _compute_tfr(data, freqs, sfreq, method=method, decim=decim,
n_cycles=2.)
output='power', n_cycles=2.)
assert_array_equal(shape, out.shape)
# Averages
out = _compute_tfr(data, freqs, sfreq, method=method, decim=decim,
Expand All @@ -798,14 +810,17 @@ def test_compute_tfr_correct(method, decim):
sfreq = 1000.
t = np.arange(1000) / sfreq
f = 50.
data = np.sin(2 * np.pi * 50. * t)
data = np.sin(2 * np.pi * f * t)
data *= np.hanning(data.size)
data = data[np.newaxis, np.newaxis]
freqs = np.arange(10, 111, 10)
freqs = np.arange(10, 111, 4)
assert f in freqs

# previous n_cycles=2 gives weird results for multitaper
n_cycles = freqs * 0.25
tfr = _compute_tfr(data, freqs, sfreq, method=method, decim=decim,
n_cycles=2)[0, 0]
assert freqs[np.argmax(np.abs(tfr).mean(-1))] == f
n_cycles=n_cycles, output='power')[0, 0]
assert freqs[np.argmax(tfr.mean(-1))] == f


def test_averaging_epochsTFR():
Expand Down
43 changes: 29 additions & 14 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,10 +331,13 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet',
-------
out : array
Time frequency transform of epoch_data. If output is in ['complex',
'phase', 'power'], then shape of out is (n_epochs, n_chans, n_freqs,
n_times), else it is (n_chans, n_freqs, n_times). If output is
'avg_power_itc', the real values code for 'avg_power' and the
imaginary values code for the 'itc': out = avg_power + i * itc
'phase', 'power'], then shape of ``out`` is ``(n_epochs, n_chans,
n_freqs, n_times)``, else it is ``(n_chans, n_freqs, n_times)``.
However, using multitaper method and output ``'complex'`` or
``'phase'`` results in shape of ``out`` being ``(n_epochs, n_chans,
n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the
real values in the ``output`` contain average power' and the imaginary
values contain the ITC: ``out = avg_power + i * itc``.
"""
# Check data
epoch_data = np.asarray(epoch_data)
Expand Down Expand Up @@ -370,6 +373,7 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet',

# Initialize output
n_freqs = len(freqs)
n_tapers = len(Ws)
n_epochs, n_chans, n_times = epoch_data[:, :, decim].shape
if output in ('power', 'phase', 'avg_power', 'itc'):
dtype = np.float64
Expand All @@ -380,6 +384,8 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet',

if ('avg_' in output) or ('itc' in output):
out = np.empty((n_chans, n_freqs, n_times), dtype)
elif output in ['complex', 'phase'] and method == 'multitaper':
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
else:
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)

Expand All @@ -390,7 +396,7 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet',

# Parallelization is applied across channels.
tfrs = parallel(
my_cwt(channel, Ws, output, use_fft, 'same', decim)
my_cwt(channel, Ws, output, use_fft, 'same', decim, method)
for channel in epoch_data.transpose(1, 0, 2))

# FIXME: to avoid overheads we should use np.array_split()
Expand All @@ -399,7 +405,10 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet',

if ('avg_' not in output) and ('itc' not in output):
# This is to enforce that the first dimension is for epochs
out = out.transpose(1, 0, 2, 3)
if output in ['complex', 'phase'] and method == 'multitaper':
out = out.transpose(2, 0, 1, 3, 4)
else:
out = out.transpose(1, 0, 2, 3)
return out


Expand Down Expand Up @@ -428,11 +437,6 @@ def _check_tfr_param(freqs, sfreq, method, zero_mean, n_cycles,
% type(zero_mean))
freqs = np.asarray(freqs)

if (method == 'multitaper') and (output == 'phase'):
raise NotImplementedError(
'This function is not optimized to compute the phase using the '
'multitaper method. Use np.angle of the complex output instead.')

# Check n_cycles
if isinstance(n_cycles, (int, float)):
n_cycles = float(n_cycles)
Expand Down Expand Up @@ -472,7 +476,8 @@ def _check_tfr_param(freqs, sfreq, method, zero_mean, n_cycles,
return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim


def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim,
method=None):
"""Aux. function to _compute_tfr.

Loops time-frequency transform across wavelets and epochs.
Expand All @@ -499,6 +504,9 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):
See numpy.convolve.
decim : slice
The decimation slice: e.g. power[:, decim]
method : str | None
Used only for multitapering to create tapers dimension in the output
if ``output in ['complex', 'phase']``.
"""
# Set output type
dtype = np.float64
Expand All @@ -507,15 +515,19 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):

# Init outputs
decim = _check_decim(decim)
n_tapers = len(Ws)
n_epochs, n_times = X[:, decim].shape
n_freqs = len(Ws[0])
if ('avg_' in output) or ('itc' in output):
tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
elif output in ['complex', 'phase'] and method == 'multitaper':
tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times),
dtype=dtype)
else:
tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)

# Loops across tapers.
for W in Ws:
for taper_idx, W in enumerate(Ws):
# No need to check here, it's done earlier (outside parallel part)
nfft = _get_nfft(W, X, use_fft, check=False)
coefs = _cwt_gen(
Expand Down Expand Up @@ -543,6 +555,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):
# Stack or add
if ('avg_' in output) or ('itc' in output):
tfrs += tfr
elif output in ['complex', 'phase'] and method == 'multitaper':
tfrs[taper_idx, epoch_idx] += tfr
else:
tfrs[epoch_idx] += tfr

Expand All @@ -557,7 +571,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):
tfrs /= n_epochs

# Normalization by number of taper
tfrs /= len(Ws)
if n_tapers > 1 and output not in ['complex', 'phase']:
tfrs /= n_tapers
return tfrs


Expand Down