Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: return complex multitaper output per taper #10281

Merged
merged 15 commits into from
Feb 4, 2022

Conversation

mmagnuski
Copy link
Member

@mmagnuski mmagnuski commented Feb 1, 2022

Reference issue

Fixes #8722.

What does this implement/fix?

Changes multitaper time-frequency to return per-taper Fourier coefficients when output == 'complex'. So far the coefficients were averaged across tapers which is not correct. This change should also allow users to calculate connectivity measures per taper and average across tapers (this is what we do correctly now for ITC).

Additional information

Currently the output has the following shape: (n_epochs, n_channels, n_tapers, n_freqs, n_times)
This could be changed if you think a different output structure will work better.

To see how this PR works, you can try the script below (based on code provided by @chapochn in #8722; I will reshape this code into a test later this week):

import numpy as np
import mne.time_frequency as tfr

signal = np.zeros(1000)
signal[500:] = 1

# %% compute power and complex output with 2 tapers
tfr_out1 = tfr.tfr_array_multitaper([[signal]], sfreq=100, freqs=[20],
                                    time_bandwidth=3, output='complex')
tfr_out2 = tfr.tfr_array_multitaper([[signal]], sfreq=100, freqs=[20],
                                    time_bandwidth=3, output='power')

# %% confirm that we can get the power from complex output
taper_dim = 2
tfr1_pwr = (np.abs(tfr_out1) ** 2).mean(axis=taper_dim)
print(np.allclose(tfr1_pwr, tfr_out2))
print(tfr_out1.shape, tfr_out2.shape)

TODOs:

  • update failing tests
  • add one more test (based on the code snippet)
  • return results per taper also for output='phase'
  • whats new
  • disable output='complex' for normal (non-array) tfr multitaper? complex or phase results seem to not be available through tfr_multitaper

@larsoner
Copy link
Member

larsoner commented Feb 1, 2022

FYI the plan in the next year or two is to overhaul these functions so that there are proper containers to hold different variants of spectral information, probably including complex tapers

https://mne.tools/stable/overview/roadmap.html#time-frequency-classes
https://chanzuckerberg.com/eoss/proposals/building-pediatric-and-clinical-data-pipelines-for-mne-python/
#10184

IIRC this will include both pure frequency (spectral) and time-resolved (time-frequency) estimates. @drammock is going to work on this mostly, but he's out for a bit.

That being said, given how long it might take to get all the above work done, this seems like a good intermediate solution. I agree we should not allow output='complex' in a way where we average across tapers

@larsoner larsoner mentioned this pull request Feb 1, 2022
@mmagnuski
Copy link
Member Author

mmagnuski commented Feb 2, 2022

@larsoner do you think it would be useful to provide per taper results for output='phase' too?

edit: yes, I think so - currently asking for output='phase' raises an error.

@larsoner
Copy link
Member

larsoner commented Feb 2, 2022

do you think it would be useful to provide per taper results for output='phase' too?

edit: yes, I think so - currently asking for output='phase' raises an error.

Either way is fine by me. It can be added here, or assume YAGNI and we can add it later if someone wants it

@mmagnuski mmagnuski marked this pull request as ready for review February 2, 2022 20:58
@mmagnuski
Copy link
Member Author

@larsoner ready for review. Previous tests were green but the build failed on docs timeout.

doc/changes/latest.inc Outdated Show resolved Hide resolved
doc/changes/latest.inc Outdated Show resolved Hide resolved
Comment on lines 144 to 145
print(multitaper_complex.shape)
print(multitaper_power.shape)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cruft? But probably okay since pytest will capture it anyway

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha, yes :) I will remove it.

@@ -380,6 +385,8 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet',

if ('avg_' in output) or ('itc' in output):
out = np.empty((n_chans, n_freqs, n_times), dtype)
elif output in ['complex', 'phase'] and n_tapers > 1:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me it makes more sense to always output the tapers dim, even if it's a singleton. It makes code and .ndim work the same regardless of if the time-bandwidth product / half-bandwidth (or whatever we call it) changes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I agree, I'll change that

@mmagnuski
Copy link
Member Author

@larsoner All done and green. :)

@larsoner larsoner merged commit 6422d6c into mne-tools:main Feb 4, 2022
@larsoner
Copy link
Member

larsoner commented Feb 4, 2022

Thanks @mmagnuski !

@mmagnuski
Copy link
Member Author

Thanks @larsoner 🚀

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

tfr_array_multitaper with output="complex" is misleading
2 participants