From 275ffaa63b497c32a923749f9d2ff10c8828b2aa Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Tue, 1 Feb 2022 14:39:20 +0100 Subject: [PATCH 01/15] ENH: return complex output per taper --- mne/time_frequency/tfr.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index aeb23026e7c..3d26d200c27 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -370,6 +370,7 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', # Initialize output n_freqs = len(freqs) + n_tapers = len(Ws) n_epochs, n_chans, n_times = epoch_data[:, :, decim].shape if output in ('power', 'phase', 'avg_power', 'itc'): dtype = np.float64 @@ -380,6 +381,8 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', if ('avg_' in output) or ('itc' in output): out = np.empty((n_chans, n_freqs, n_times), dtype) + elif output == 'complex' and n_tapers > 1: + out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -399,7 +402,10 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', if ('avg_' not in output) and ('itc' not in output): # This is to enforce that the first dimension is for epochs - out = out.transpose(1, 0, 2, 3) + if output == 'complex' and n_tapers > 1: + out = out.transpose(2, 0, 1, 3, 4) + else: + out = out.transpose(1, 0, 2, 3) return out @@ -507,15 +513,19 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): # Init outputs decim = _check_decim(decim) + n_tapers = len(Ws) n_epochs, n_times = X[:, decim].shape n_freqs = len(Ws[0]) if ('avg_' in output) or ('itc' in output): tfrs = np.zeros((n_freqs, n_times), dtype=dtype) + elif output == 'complex' and n_tapers > 1: + tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), + dtype=dtype) else: tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype) # Loops across tapers. - for W in Ws: + for taper_idx, W in enumerate(Ws): # No need to check here, it's done earlier (outside parallel part) nfft = _get_nfft(W, X, use_fft, check=False) coefs = _cwt_gen( @@ -543,6 +553,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): # Stack or add if ('avg_' in output) or ('itc' in output): tfrs += tfr + elif output == 'complex' and n_tapers > 1: + tfrs[taper_idx, epoch_idx] += tfr else: tfrs[epoch_idx] += tfr From 4c510e29b1216624170994423b7a6761196bfa6a Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Tue, 1 Feb 2022 18:01:36 +0100 Subject: [PATCH 02/15] FIX: update test, so that per-taper shape is tested when complex output is requrested --- mne/time_frequency/tests/test_tfr.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 168e6b0a513..436250b397d 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -731,7 +731,11 @@ def test_compute_tfr(): out = func(data, sfreq=sfreq, freqs=freqs, use_fft=use_fft, zero_mean=zero_mean, n_cycles=2., output=output) # Check shapes - shape = np.r_[data.shape[:2], len(freqs), data.shape[2]] + if func == tfr_array_multitaper and output == 'complex': + n_tapers = 3 + shape = np.r_[data.shape[:2], n_tapers, len(freqs), data.shape[2]] + else: + shape = np.r_[data.shape[:2], len(freqs), data.shape[2]] if ('avg' in output) or ('itc' in output): assert_array_equal(shape[1:], out.shape) else: From 19500f7b835b06b8b3889f4ec14a17735380a007 Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Tue, 1 Feb 2022 18:39:58 +0100 Subject: [PATCH 03/15] FIX: change output to power in correctness test and have equal window length across freqs --- mne/time_frequency/tests/test_tfr.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 436250b397d..743703d95b5 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -802,14 +802,17 @@ def test_compute_tfr_correct(method, decim): sfreq = 1000. t = np.arange(1000) / sfreq f = 50. - data = np.sin(2 * np.pi * 50. * t) + data = np.sin(2 * np.pi * f * t) data *= np.hanning(data.size) data = data[np.newaxis, np.newaxis] - freqs = np.arange(10, 111, 10) + freqs = np.arange(10, 111, 4) assert f in freqs + + # previous n_cycles=2 gives weird results for multitaper + n_cycles = freqs * 0.25 tfr = _compute_tfr(data, freqs, sfreq, method=method, decim=decim, - n_cycles=2)[0, 0] - assert freqs[np.argmax(np.abs(tfr).mean(-1))] == f + n_cycles=n_cycles, output='power')[0, 0] + assert freqs[np.argmax(tfr.mean(-1))] == f def test_averaging_epochsTFR(): From c08302c63cccd093f0d24e32f071b441f8310435 Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Tue, 1 Feb 2022 18:40:45 +0100 Subject: [PATCH 04/15] FIX: output='power' in decim data shape tests --- mne/time_frequency/tests/test_tfr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 743703d95b5..51990b232c7 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -784,10 +784,11 @@ def test_compute_tfr(): _decim = slice(None, None, decim) if isinstance(decim, int) else decim n_time = len(np.arange(data.shape[2])[_decim]) shape = np.r_[data.shape[:2], len(freqs), n_time] + for method in ('multitaper', 'morlet'): # Single trials out = _compute_tfr(data, freqs, sfreq, method=method, decim=decim, - n_cycles=2.) + output='power', n_cycles=2.) assert_array_equal(shape, out.shape) # Averages out = _compute_tfr(data, freqs, sfreq, method=method, decim=decim, From 8e533d4c7e275cb206c4d54af9f63b3b3145fa81 Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Wed, 2 Feb 2022 11:51:54 +0100 Subject: [PATCH 05/15] ENH: also return phase per taper --- mne/time_frequency/tfr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 3d26d200c27..18bfe9cfb5e 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -381,7 +381,7 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', if ('avg_' in output) or ('itc' in output): out = np.empty((n_chans, n_freqs, n_times), dtype) - elif output == 'complex' and n_tapers > 1: + elif output in ['complex', 'phase'] and n_tapers > 1: out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -518,7 +518,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): n_freqs = len(Ws[0]) if ('avg_' in output) or ('itc' in output): tfrs = np.zeros((n_freqs, n_times), dtype=dtype) - elif output == 'complex' and n_tapers > 1: + elif output in ['complex', 'phase'] and n_tapers > 1: tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype) else: @@ -553,7 +553,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): # Stack or add if ('avg_' in output) or ('itc' in output): tfrs += tfr - elif output == 'complex' and n_tapers > 1: + elif output in ['complex', 'phase'] and n_tapers > 1: tfrs[taper_idx, epoch_idx] += tfr else: tfrs[epoch_idx] += tfr From fa3bdf1b27b2adb91ec6be27e267279f847c3f39 Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Wed, 2 Feb 2022 12:56:52 +0100 Subject: [PATCH 06/15] TST, FIX: adapt tests for multitaper output='phase', fix transpose --- mne/time_frequency/tests/test_tfr.py | 10 +--------- mne/time_frequency/tfr.py | 7 +------ 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 51990b232c7..88512024616 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -721,17 +721,12 @@ def test_compute_tfr(): (tfr_array_multitaper, tfr_array_morlet), (False, True), (False, True), ('complex', 'power', 'phase', 'avg_power_itc', 'avg_power', 'itc')): - # Check exception - if (func == tfr_array_multitaper) and (output == 'phase'): - pytest.raises(NotImplementedError, func, data, sfreq=sfreq, - freqs=freqs, output=output) - continue # Check runs out = func(data, sfreq=sfreq, freqs=freqs, use_fft=use_fft, zero_mean=zero_mean, n_cycles=2., output=output) # Check shapes - if func == tfr_array_multitaper and output == 'complex': + if func == tfr_array_multitaper and output in ['complex', 'phase']: n_tapers = 3 shape = np.r_[data.shape[:2], n_tapers, len(freqs), data.shape[2]] else: @@ -766,9 +761,6 @@ def test_compute_tfr(): # No time_bandwidth param in morlet pytest.raises(ValueError, _compute_tfr, data, freqs, sfreq, method='morlet', time_bandwidth=1) - # No phase in multitaper XXX Check ? - pytest.raises(NotImplementedError, _compute_tfr, data, freqs, sfreq, - method='multitaper', output='phase') # Inter-trial coherence tests out = _compute_tfr(data, freqs, sfreq, output='itc', n_cycles=2.) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 18bfe9cfb5e..34a8798f4e9 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -402,7 +402,7 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', if ('avg_' not in output) and ('itc' not in output): # This is to enforce that the first dimension is for epochs - if output == 'complex' and n_tapers > 1: + if output in ['complex', 'phase'] and n_tapers > 1: out = out.transpose(2, 0, 1, 3, 4) else: out = out.transpose(1, 0, 2, 3) @@ -434,11 +434,6 @@ def _check_tfr_param(freqs, sfreq, method, zero_mean, n_cycles, % type(zero_mean)) freqs = np.asarray(freqs) - if (method == 'multitaper') and (output == 'phase'): - raise NotImplementedError( - 'This function is not optimized to compute the phase using the ' - 'multitaper method. Use np.angle of the complex output instead.') - # Check n_cycles if isinstance(n_cycles, (int, float)): n_cycles = float(n_cycles) From 99a0743f82b5d1d2e0f9d76ab3b1fafd879abed0 Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Wed, 2 Feb 2022 13:07:41 +0100 Subject: [PATCH 07/15] TST: test that averaging power across tapers when output='complex' gives the same as output='power' --- mne/time_frequency/tests/test_tfr.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 88512024616..23404a31757 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -131,6 +131,22 @@ def test_time_frequency(): # computed within the method. assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data) + # complex test for multitaper case + epoch_data = epochs.get_data() + multitaper_power = tfr_array_multitaper( + epoch_data, epochs.info['sfreq'], freqs, n_cycles, + output="power") + multitaper_complex = tfr_array_multitaper( + epoch_data, epochs.info['sfreq'], freqs, n_cycles, + output="complex") + + print(multitaper_complex.shape) + print(multitaper_power.shape) + taper_dim = 2 + power_from_complex = (multitaper_complex * multitaper_complex.conj() + ).real.mean(axis=taper_dim) + assert_allclose(power_from_complex, multitaper_power) + print(itc) # test repr print(itc.ch_names) # test property itc += power # test add From 1531136b7374bf7791cc322109e80a0a7c812cac Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Wed, 2 Feb 2022 13:08:09 +0100 Subject: [PATCH 08/15] FIX: fix normalization by n_tapers --- mne/time_frequency/tfr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 34a8798f4e9..1c09537af78 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -564,7 +564,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): tfrs /= n_epochs # Normalization by number of taper - tfrs /= len(Ws) + if n_tapers > 1 and output not in ['complex', 'phase']: + tfrs /= len(Ws) return tfrs From afb4b6b2df5993784d580fa2a16fd9d1b939f9bf Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Wed, 2 Feb 2022 13:10:15 +0100 Subject: [PATCH 09/15] better --- mne/time_frequency/tests/test_tfr.py | 3 ++- mne/time_frequency/tfr.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 23404a31757..f8ec04c5809 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -131,7 +131,8 @@ def test_time_frequency(): # computed within the method. assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data) - # complex test for multitaper case + # test that averaging power across tapers when multitaper with + # output='complex' gives the same as output='power' epoch_data = epochs.get_data() multitaper_power = tfr_array_multitaper( epoch_data, epochs.info['sfreq'], freqs, n_cycles, diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 1c09537af78..e8509859778 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -565,7 +565,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): # Normalization by number of taper if n_tapers > 1 and output not in ['complex', 'phase']: - tfrs /= len(Ws) + tfrs /= n_tapers return tfrs From 67ee3a3512950dbec079148842bc93966aec8d8e Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Wed, 2 Feb 2022 13:10:34 +0100 Subject: [PATCH 10/15] DOC: update docstrings --- mne/time_frequency/multitaper.py | 14 ++++++++------ mne/time_frequency/tfr.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index b95ab496964..d44b8e27831 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -494,9 +494,9 @@ def tfr_array_multitaper(epoch_data, sfreq, freqs, n_cycles=7.0, is done after the convolutions. output : str, default 'complex' - * 'complex' : single trial complex. + * 'complex' : single trial per taper complex values. * 'power' : single trial power. - * 'phase' : single trial phase. + * 'phase' : single trial per taper phase. * 'avg_power' : average of single trial power. * 'itc' : inter-trial coherence. * 'avg_power_itc' : average of single trial power and inter-trial @@ -510,10 +510,12 @@ def tfr_array_multitaper(epoch_data, sfreq, freqs, n_cycles=7.0, ------- out : array Time frequency transform of epoch_data. If output is in ['complex', - 'phase', 'power'], then shape of out is (n_epochs, n_chans, n_freqs, - n_times), else it is (n_chans, n_freqs, n_times). If output is - 'avg_power_itc', the real values code for 'avg_power' and the - imaginary values code for the 'itc': out = avg_power + i * itc. + 'phase'], then the shape of ``out`` is ``(n_epochs, n_chans, n_tapers, + n_freqs, n_times)``; if output is 'power', the shape of ``out`` is + ``(n_epochs, n_chans, n_freqs, n_times)``, else it is ``(n_chans, + n_freqs, n_times)``. If output is 'avg_power_itc', the real values + in ``out`` contain the average power and the imaginary values contain + the ITC: ``out = avg_power + i * itc``. See Also -------- diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index e8509859778..78455d76d97 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -331,10 +331,14 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', ------- out : array Time frequency transform of epoch_data. If output is in ['complex', - 'phase', 'power'], then shape of out is (n_epochs, n_chans, n_freqs, - n_times), else it is (n_chans, n_freqs, n_times). If output is - 'avg_power_itc', the real values code for 'avg_power' and the - imaginary values code for the 'itc': out = avg_power + i * itc + 'phase', 'power'], then shape of ``ou``t is ``(n_epochs, n_chans, + n_freqs, n_times)``, else it is ``(n_chans, n_freqs, n_times)``. + However, using multitaper method with at least 2 tapers and output + ``'complex'`` or ``'phase'`` results in shape of ``out`` being + ``(n_epochs, n_chans, n_tapers, n_freqs, n_times)``. If output is + ``'avg_power_itc'``, the real values in the ``output`` contain + average power' and the imaginary values contain the ITC: + ``out = avg_power + i * itc``. """ # Check data epoch_data = np.asarray(epoch_data) From 01de221612d5842ae1039c3fa78edad8fb68eb7c Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Wed, 2 Feb 2022 21:57:34 +0100 Subject: [PATCH 11/15] DOC: update whats new --- doc/changes/latest.inc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index d6f295bb2f3..fd302b59f6b 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -24,6 +24,8 @@ Current (1.0.dev0) Enhancements ~~~~~~~~~~~~ +- :func:`mne.time_frequency.tfr_array_multitaper` can now return results for ``output='phase'`` instead of an error (:gh:`10281` by `Mikołaj Magnuski`) + - Speed up :func:`mne.preprocessing.annotate_muscle_zscore`, :func:`mne.preprocessing.annotate_movement`, and :func:`mne.preprocessing.annotate_nan` through better annotation creation (:gh:`10089` by :newcontrib:`Senwen Deng`) - Fix some unused variables in time_frequency_erds.py example (:gh:`10076` by :newcontrib:`Jan Zerfowski`) @@ -103,6 +105,8 @@ Enhancements Bugs ~~~~ +- :func:`mne.time_frequency.tfr_array_multitaper` now returns results per taper when ``output='complex'`` (:gh:`10281` by `Mikołaj Magnuski`) + - Teach :func:`mne.io.read_raw_bti` to use its ``eog_ch`` parameter (:gh:`10093` by :newcontrib:`Adina Wagner`) - Fix default of :func:`mne.io.Raw.plot` to be ``use_opengl=None``, which will act like False unless ``MNE_BROWSER_USE_OPENGL=true`` is set in the user configuration (:gh:`9957` by `Eric Larson`_) From 9b1a189e5ca2d3c50f26601cc1f0b8c61361b257 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 2 Feb 2022 16:25:10 -0500 Subject: [PATCH 12/15] DOC: Move and RST --- doc/changes/latest.inc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index fd302b59f6b..1f2cb10a40c 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -24,12 +24,12 @@ Current (1.0.dev0) Enhancements ~~~~~~~~~~~~ -- :func:`mne.time_frequency.tfr_array_multitaper` can now return results for ``output='phase'`` instead of an error (:gh:`10281` by `Mikołaj Magnuski`) - - Speed up :func:`mne.preprocessing.annotate_muscle_zscore`, :func:`mne.preprocessing.annotate_movement`, and :func:`mne.preprocessing.annotate_nan` through better annotation creation (:gh:`10089` by :newcontrib:`Senwen Deng`) - Fix some unused variables in time_frequency_erds.py example (:gh:`10076` by :newcontrib:`Jan Zerfowski`) +- :func:`mne.time_frequency.tfr_array_multitaper` can now return results for ``output='phase'`` instead of an error (:gh:`10281` by `Mikołaj Magnuski`_) + - Add show local maxima toggling button to :func:`mne.gui.locate_ieeg` (:gh:`9952` by `Alex Rockhill`_) - Improve docstring of :class:`mne.Info` and add attributes that were not covered (:gh:`9922` by `Mathieu Scheltienne`_) @@ -105,10 +105,10 @@ Enhancements Bugs ~~~~ -- :func:`mne.time_frequency.tfr_array_multitaper` now returns results per taper when ``output='complex'`` (:gh:`10281` by `Mikołaj Magnuski`) - - Teach :func:`mne.io.read_raw_bti` to use its ``eog_ch`` parameter (:gh:`10093` by :newcontrib:`Adina Wagner`) +- :func:`mne.time_frequency.tfr_array_multitaper` now returns results per taper when ``output='complex'`` (:gh:`10281` by `Mikołaj Magnuski`_) + - Fix default of :func:`mne.io.Raw.plot` to be ``use_opengl=None``, which will act like False unless ``MNE_BROWSER_USE_OPENGL=true`` is set in the user configuration (:gh:`9957` by `Eric Larson`_) - Fix bug with :class:`mne.Report` where figures were saved with ``bbox_inches='tight'``, which led to inconsistent sizes in sliders (:gh:`9966` by `Eric Larson`_) From 928d337fdbf5c5d03bf40004ace009362a8f8f6e Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Thu, 3 Feb 2022 15:42:55 +0100 Subject: [PATCH 13/15] remove leftover prints --- mne/time_frequency/tests/test_tfr.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index f8ec04c5809..b0edd9a968e 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -141,8 +141,6 @@ def test_time_frequency(): epoch_data, epochs.info['sfreq'], freqs, n_cycles, output="complex") - print(multitaper_complex.shape) - print(multitaper_power.shape) taper_dim = 2 power_from_complex = (multitaper_complex * multitaper_complex.conj() ).real.mean(axis=taper_dim) From 4ec4d3d036b038abdc0565977340f68e877911cc Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Thu, 3 Feb 2022 15:44:11 +0100 Subject: [PATCH 14/15] ENH: return tapers dimension for complex and phase multitaper --- mne/time_frequency/tfr.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 78455d76d97..a63ad52c9d2 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -385,7 +385,7 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', if ('avg_' in output) or ('itc' in output): out = np.empty((n_chans, n_freqs, n_times), dtype) - elif output in ['complex', 'phase'] and n_tapers > 1: + elif output in ['complex', 'phase'] and method == 'multitaper': out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -397,7 +397,7 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', # Parallelization is applied across channels. tfrs = parallel( - my_cwt(channel, Ws, output, use_fft, 'same', decim) + my_cwt(channel, Ws, output, use_fft, 'same', decim, method) for channel in epoch_data.transpose(1, 0, 2)) # FIXME: to avoid overheads we should use np.array_split() @@ -406,7 +406,7 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', if ('avg_' not in output) and ('itc' not in output): # This is to enforce that the first dimension is for epochs - if output in ['complex', 'phase'] and n_tapers > 1: + if output in ['complex', 'phase'] and method == 'multitaper': out = out.transpose(2, 0, 1, 3, 4) else: out = out.transpose(1, 0, 2, 3) @@ -477,7 +477,8 @@ def _check_tfr_param(freqs, sfreq, method, zero_mean, n_cycles, return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim -def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): +def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, + method=None): """Aux. function to _compute_tfr. Loops time-frequency transform across wavelets and epochs. @@ -504,6 +505,9 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): See numpy.convolve. decim : slice The decimation slice: e.g. power[:, decim] + method : str | None + Used only for multitapering to create tapers dimension in the output + if ``output in ['complex', 'phase']``. """ # Set output type dtype = np.float64 @@ -517,7 +521,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): n_freqs = len(Ws[0]) if ('avg_' in output) or ('itc' in output): tfrs = np.zeros((n_freqs, n_times), dtype=dtype) - elif output in ['complex', 'phase'] and n_tapers > 1: + elif output in ['complex', 'phase'] and method == 'multitaper': tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype) else: @@ -552,7 +556,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): # Stack or add if ('avg_' in output) or ('itc' in output): tfrs += tfr - elif output in ['complex', 'phase'] and n_tapers > 1: + elif output in ['complex', 'phase'] and method == 'multitaper': tfrs[taper_idx, epoch_idx] += tfr else: tfrs[epoch_idx] += tfr From f439206d97905142cce68886f6888778b5922d41 Mon Sep 17 00:00:00 2001 From: mmagnuski Date: Thu, 3 Feb 2022 15:52:19 +0100 Subject: [PATCH 15/15] DOC: update docs --- mne/time_frequency/multitaper.py | 14 +++++++------- mne/time_frequency/tfr.py | 13 ++++++------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index d44b8e27831..495451456f4 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -509,13 +509,13 @@ def tfr_array_multitaper(epoch_data, sfreq, freqs, n_cycles=7.0, Returns ------- out : array - Time frequency transform of epoch_data. If output is in ['complex', - 'phase'], then the shape of ``out`` is ``(n_epochs, n_chans, n_tapers, - n_freqs, n_times)``; if output is 'power', the shape of ``out`` is - ``(n_epochs, n_chans, n_freqs, n_times)``, else it is ``(n_chans, - n_freqs, n_times)``. If output is 'avg_power_itc', the real values - in ``out`` contain the average power and the imaginary values contain - the ITC: ``out = avg_power + i * itc``. + Time frequency transform of epoch_data. If ``output in ['complex', + 'phase']``, then the shape of ``out`` is ``(n_epochs, n_chans, + n_tapers, n_freqs, n_times)``; if output is 'power', the shape of + ``out`` is ``(n_epochs, n_chans, n_freqs, n_times)``, else it is + ``(n_chans, n_freqs, n_times)``. If output is 'avg_power_itc', the real + values in ``out`` contain the average power and the imaginary values + contain the ITC: ``out = avg_power + i * itc``. See Also -------- diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index a63ad52c9d2..fb8e1c6353b 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -331,14 +331,13 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', ------- out : array Time frequency transform of epoch_data. If output is in ['complex', - 'phase', 'power'], then shape of ``ou``t is ``(n_epochs, n_chans, + 'phase', 'power'], then shape of ``out`` is ``(n_epochs, n_chans, n_freqs, n_times)``, else it is ``(n_chans, n_freqs, n_times)``. - However, using multitaper method with at least 2 tapers and output - ``'complex'`` or ``'phase'`` results in shape of ``out`` being - ``(n_epochs, n_chans, n_tapers, n_freqs, n_times)``. If output is - ``'avg_power_itc'``, the real values in the ``output`` contain - average power' and the imaginary values contain the ITC: - ``out = avg_power + i * itc``. + However, using multitaper method and output ``'complex'`` or + ``'phase'`` results in shape of ``out`` being ``(n_epochs, n_chans, + n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the + real values in the ``output`` contain average power' and the imaginary + values contain the ITC: ``out = avg_power + i * itc``. """ # Check data epoch_data = np.asarray(epoch_data)