Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add option to store and return TFR taper weights #12910

Open
wants to merge 14 commits into
base: main
Choose a base branch
from

Conversation

tsbinns
Copy link
Contributor

@tsbinns tsbinns commented Oct 22, 2024

Reference issue (if any)

PR for #12851

What does this implement/fix?

Adds an option to return taper weights for complex and phase outputs of the multitaper method in tfr_array_multitaper(), and also ensures taper weights are stored in TFR objects.

Additional information

When working on this, I discovered a couple of other issues with the per-taper TFR implementations (#12851 (comment)), including the fact that the TFR object plotting methods and to_data_frame methods do not account for a taper dimension, leading to errors. Wasn't sure if people want me to also address these here or in a separate PR.

@@ -302,12 +306,15 @@ def _make_dpss(
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This I am somewhat unsure on. The existing implementation is to just use conc as-is, however in the MNE-Connectivity implementation that sqrt is taken: https://github.com/mne-tools/mne-connectivity/blob/97147a57eefb36a5c9680e539fdc6343a1183f20/mne_connectivity/spectral/time.py#L825

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also unsure on this point. We should ask @ruuskas (who wrote the implementation in MNE-Connectivity) and @larsoner (who wrote the SciPy DPSS implementation) to weigh in.

@tsbinns
Copy link
Contributor Author

tsbinns commented Oct 22, 2024

I'm also somewhat confused about the design of the _make_dpss function:

for m in range(n_taps):
Wm = list()
Cm = list()
for k, f in enumerate(freqs):
if len(n_cycles) != 1:
this_n_cycles = n_cycles[k]
else:
this_n_cycles = n_cycles[0]
t_win = this_n_cycles / float(f)
t = np.arange(0.0, t_win, 1.0 / sfreq)
# Making sure wavelets are centered before tapering
oscillation = np.exp(2.0 * 1j * np.pi * f * (t - t_win / 2.0))
# Get dpss tapers
tapers, conc = dpss_windows(
t.shape[0], time_bandwidth / 2.0, n_taps, sym=False
)
Wk = oscillation * tapers[m]
if zero_mean: # to make it zero mean
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])
Wm.append(Wk)
Cm.append(Ck)
Ws.append(Wm)
Cs.append(Cm)

It is looping over tapers, and then over frequencies. However, the dpss_windows function it calls internally provides the tapers and their weights for all tapers of a given frequency.

Would it not be more efficient to only loop over frequencies and take advantage of the fact that this will also return information for each taper?

Comment on lines -1190 to -1192
# shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release)
if method == "morlet":
method_kw.setdefault("zero_mean", True)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this PR, but it can be removed.

@tsbinns
Copy link
Contributor Author

tsbinns commented Oct 22, 2024

I also have a question regarding testing: for the I/O tests, we're reading TFR objects that do not have a weights property (just gets assigned to None) when loaded. Do I need to create new TFR objects that actually have some weights, or is the current test sufficient?

Apart from this there are still some tests I need to expand.

mne/time_frequency/multitaper.py Outdated Show resolved Hide resolved
Comment on lines +1230 to +1232
# always store weights for per-taper outputs
if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]:
method_kw["return_weights"] = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hesitate to blindly overwrite what the user might have put into their method_kw dict, so I was going to suggest using .setdefault here. But then I wondered, is there ever a case where the user would sensibly want to pass method_kw=dict(return_weights=False, ...)? I'm guessing not, since when instantiating the TFR class object, the user isn't getting direct access to the return value of the method anyway. WDYT @tsbinns ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this was my line of thought as well. Also, allowing the user to control this would mean extra logic needs to be put in place when unpacking the tfr values (i.e., whether we need to separate the tfr from the weights). I think just forcing this to True simplifies things and would not affect the user at all.

@@ -302,12 +306,15 @@ def _make_dpss(
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also unsure on this point. We should ask @ruuskas (who wrote the implementation in MNE-Connectivity) and @larsoner (who wrote the SciPy DPSS implementation) to weigh in.

mne/time_frequency/tfr.py Show resolved Hide resolved
@tsbinns
Copy link
Contributor Author

tsbinns commented Oct 29, 2024

Thanks for the review @drammock! I will sort out those remaining tests, although I'm in the process of moving at the moment so it might not be for some days.

Regarding those issues I came across with TFR multitapers and converting to dataframes / plotting: would you like me to incorporate that into this PR?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants