[ENH] Add option to store and return TFR taper weights #12910

Open · wants to merge 14 commits into base: main
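For orientation, here is a minimal usage sketch of the option this PR adds. The ``return_weights`` keyword and the returned shapes follow the docstring changes in the diff below; the random data, sampling rate, and frequency grid are placeholders.

```python
import numpy as np

from mne.time_frequency import tfr_array_multitaper

rng = np.random.default_rng(0)
data = rng.standard_normal((5, 2, 1000))  # (n_epochs, n_channels, n_times)
freqs = np.arange(8.0, 30.0, 2.0)

out, weights = tfr_array_multitaper(
    data,
    sfreq=250.0,
    freqs=freqs,
    n_cycles=freqs / 2.0,
    output="complex",
    return_weights=True,  # new in this PR; only honoured for "complex"/"phase"
)
print(out.shape)      # (n_epochs, n_channels, n_tapers, n_freqs, n_times)
print(weights.shape)  # (n_tapers, n_freqs)
```
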
12 changes: 12 additions & 0 deletions mne/time_frequency/multitaper.py
@@ -469,6 +469,7 @@ def tfr_array_multitaper(
use_fft=True,
decim=1,
output="complex",
return_weights=False,
n_jobs=None,
*,
verbose=None,
@@ -502,6 +503,13 @@ def tfr_array_multitaper(
* ``'itc'`` : inter-trial coherence.
* ``'avg_power_itc'`` : average of single trial power and inter-trial
coherence across trials.

return_weights : bool, default False
If True, return the taper weights. Only applies if ``output='complex'`` or
``'phase'``.

.. versionadded:: 1.9.0

%(n_jobs)s
The parallelization is implemented across channels.
%(verbose)s
@@ -520,6 +528,9 @@ def tfr_array_multitaper(
If ``output`` is ``'avg_power_itc'``, the real values in ``out``
contain the average power and the imaginary values contain the
inter-trial coherence: :math:`out = power_{avg} + i * ITC`.
weights : array of shape (n_tapers, n_freqs)
The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and
``return_weights=True``.

See Also
--------
@@ -550,6 +561,7 @@ def tfr_array_multitaper(
use_fft=use_fft,
decim=decim,
output=output,
return_weights=return_weights,
n_jobs=n_jobs,
verbose=verbose,
)
14 changes: 9 additions & 5 deletions mne/time_frequency/tests/test_tfr.py
@@ -432,17 +432,21 @@ def test_tfr_morlet():
def test_dpsswavelet():
"""Test DPSS tapers."""
freqs = np.arange(5, 25, 3)
Ws = _make_dpss(
1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True
Ws, weights = _make_dpss(
1000,
freqs=freqs,
n_cycles=freqs / 2.0,
time_bandwidth=4.0,
zero_mean=True,
return_weights=True,
)

assert len(Ws) == 3 # 3 tapers expected
assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected
assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs)

# Check that zero mean is true
assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5

assert len(Ws[0]) == len(freqs) # As many wavelets as asked for


@pytest.mark.slowtest
def test_tfr_multitaper():
97 changes: 75 additions & 22 deletions mne/time_frequency/tfr.py
@@ -264,8 +264,11 @@ def _make_dpss(
-------
Ws : list of array
The wavelets time series.
Cs : list of array
The concentration weights. Only returned if return_weights=True.
"""
Ws = list()
Cs = list()

freqs = np.array(freqs)
if np.any(freqs <= 0):
@@ -281,6 +284,7 @@

for m in range(n_taps):
Wm = list()
Cm = list()
for k, f in enumerate(freqs):
if len(n_cycles) != 1:
this_n_cycles = n_cycles[k]
@@ -302,12 +306,15 @@
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])

Contributor Author commented:
This one I am somewhat unsure about. The existing implementation just uses conc as-is; however, the MNE-Connectivity implementation takes the sqrt: https://github.com/mne-tools/mne-connectivity/blob/97147a57eefb36a5c9680e539fdc6343a1183f20/mne_connectivity/spectral/time.py#L825

Member commented:
I am also unsure on this point. We should ask @ruuskas (who wrote the implementation in MNE-Connectivity) and @larsoner (who wrote the SciPy DPSS implementation) to weigh in.

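For reference on the point discussed above, a small standalone sketch of the two conventions (not the PR's code). SciPy's DPSS routine, which I believe MNE's `dpss_windows` wraps, returns the concentration ratios directly; the question is whether to apply them as-is or via their square root.

```python
import numpy as np
from scipy.signal.windows import dpss

# Placeholder parameters: window length, half time-bandwidth, number of tapers.
n_times, half_nbw, n_tapers = 1000, 2.0, 3
tapers, conc = dpss(n_times, half_nbw, Kmax=n_tapers, return_ratios=True)

weights_as_is = conc          # concentration ratios used directly (existing behaviour)
weights_sqrt = np.sqrt(conc)  # sqrt convention, as in MNE-Connectivity and this diff
```
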

Wm.append(Wk)
Cm.append(Ck)

Ws.append(Wm)
Cs.append(Cm)
if return_weights:
return Ws, conc
return Ws, Cs
return Ws


@@ -428,6 +435,7 @@ def _compute_tfr(
use_fft=True,
decim=1,
output="complex",
return_weights=False,
n_jobs=None,
*,
verbose=None,
@@ -479,6 +487,9 @@
* 'avg_power_itc' : average of single trial power and inter-trial
coherence across trials.

return_weights : bool, default False
Whether to return the taper weights. Only applies if method='multitaper' and
output='complex' or 'phase'.
%(n_jobs)s
The number of epochs to process at the same time. The parallelization
is implemented across channels.
@@ -495,6 +506,10 @@
n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the
real values in the ``output`` contain the average power and the imaginary
values contain the ITC: ``out = avg_power + i * itc``.

weights : array of shape (n_tapers, n_freqs)
The taper weights. Only returned if method='multitaper', output='complex' or
'phase', and return_weights=True.
"""
# Check data
epoch_data = np.asarray(epoch_data)
@@ -516,6 +531,9 @@
decim,
output,
)
return_weights = (
return_weights and method == "multitaper" and output in ["complex", "phase"]
)

decim = _ensure_slice(decim)
if (freqs > sfreq / 2.0).any():
@@ -531,13 +549,18 @@
Ws = [W] # to have same dimensionality as the 'multitaper' case

elif method == "multitaper":
Ws = _make_dpss(
out = _make_dpss(
sfreq,
freqs,
n_cycles=n_cycles,
time_bandwidth=time_bandwidth,
zero_mean=zero_mean,
return_weights=return_weights,
)
if return_weights:
Ws, weights = out
else:
Ws = out

# Check wavelets
if len(Ws[0][0]) > epoch_data.shape[2]:
@@ -561,6 +584,8 @@
out = np.empty((n_chans, n_freqs, n_times), dtype)
elif output in ["complex", "phase"] and method == "multitaper":
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
if return_weights:
weights = np.array(weights)
else:
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)

@@ -585,6 +610,9 @@
out = out.transpose(2, 0, 1, 3, 4)
else:
out = out.transpose(1, 0, 2, 3)

if return_weights:
return out, weights
return out


@@ -1187,9 +1215,6 @@ def __init__(
f'{classname} got unsupported parameter value{_pl(problem)} '
f'{" and ".join(problem)}.'
)
# shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release)
if method == "morlet":
method_kw.setdefault("zero_mean", True)
Comment on lines -1190 to -1192
Contributor Author commented:
Unrelated to this PR, but it can be removed.

# check method
valid_methods = ["morlet", "multitaper"]
if isinstance(inst, BaseEpochs):
@@ -1203,6 +1228,9 @@ def __init__(
method_kw.setdefault("output", "power")
self._freqs = np.asarray(freqs, dtype=np.float64)
del freqs
# always store weights for per-taper outputs
if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]:
method_kw["return_weights"] = True
Comment on lines +1230 to +1232
Member commented:
I hesitate to blindly overwrite what the user might have put into their method_kw dict, so I was going to suggest using .setdefault here. But then I wondered, is there ever a case where the user would sensibly want to pass method_kw=dict(return_weights=False, ...)? I'm guessing not, since when instantiating the TFR class object, the user isn't getting direct access to the return value of the method anyway. WDYT @tsbinns ?

Contributor Author commented:
Yeah, this was my line of thought as well. Also, allowing the user to control this would mean extra logic needs to be put in place when unpacking the tfr values (i.e., whether we need to separate the tfr from the weights). I think just forcing this to True simplifies things and would not affect the user at all.

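A tiny sketch of the trade-off discussed in this thread (hypothetical, not part of the diff): forcing the flag means the result unpacking in `_compute_tfr` can rely on the method/output combination alone to know whether weights come back.

```python
# User-supplied kwargs for a per-taper output.
method_kw = dict(output="complex", return_weights=False)

# Option A (what the diff does): always override, so downstream code can assume
# (method == "multitaper" and output in ("complex", "phase")) implies a
# (tfr, weights) tuple is returned.
method_kw["return_weights"] = True

# Option B (.setdefault): honour a user-passed False, but then the result
# unpacking would additionally have to inspect method_kw["return_weights"].
# method_kw.setdefault("return_weights", True)
```
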
# check validity of kwargs manually to save compute time if any are invalid
tfr_funcs = dict(
morlet=tfr_array_morlet,
@@ -1224,6 +1252,7 @@ def __init__(
self._method = method
self._inst_type = type(inst)
self._baseline = None
self._weights = None
self.preload = True # needed for __getitem__, never False for TFRs
# self._dims may also get updated by child classes
self._dims = ["channel", "freq", "time"]
@@ -1382,6 +1411,7 @@ def __getstate__(self):
info=self.info,
baseline=self._baseline,
decim=self._decim,
weights=self._weights,
)

def __setstate__(self, state):
@@ -1410,6 +1440,7 @@ def __setstate__(self, state):
self._decim = defaults["decim"]
self.preload = True
self._set_times(self._raw_times)
self._weights = state.get("weights") # objs saved before #XXX won't have
# Handle instance type. Prior to gh-11282, Raw was not a possibility so if
# `inst_type_str` is missing it must be Epochs or Evoked
unknown_class = Epochs if "epoch" in self._dims else Evoked
@@ -1516,6 +1547,10 @@ def _compute_tfr(self, data, n_jobs, verbose):
if self.method == "stockwell":
self._data, self._itc, freqs = result
assert np.array_equal(self._freqs, freqs)
elif self.method == "multitaper" and self._tfr_func.keywords.get(
"output", ""
) in ["complex", "phase"]:
self._data, self._weights = result
elif self._tfr_func.keywords.get("output", "").endswith("_itc"):
self._data, self._itc = result.real, result.imag
else:
@@ -1694,6 +1729,11 @@ def times(self):
"""The time points present in the data (in seconds)."""
return self._times_readonly

@property
def weights(self):
"""The weights used for each taper in the time-frequency estimates."""
return self._weights

@fill_doc
def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True):
"""Crop data to a given time interval in place.
@@ -2654,42 +2694,55 @@ def to_data_frame(
"""
# check pandas once here, instead of in each private utils function
pd = _check_pandas_installed() # noqa
# triage for Epoch-derived or unaggregated spectra
from_epo = isinstance(self, EpochsTFR)
unagg_mt = "taper" in self._dims
# arg checking
valid_index_args = ["time", "freq"]
if isinstance(self, EpochsTFR):
if from_epo:
valid_index_args.extend(["epoch", "condition"])
valid_time_formats = ["ms", "timedelta"]
index = _check_pandas_index_arguments(index, valid_index_args)
time_format = _check_time_format(time_format, valid_time_formats)
# get data
picks = _picks_to_idx(self.info, picks, "all", exclude=())
data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True)
axis = self._dims.index("channel")
if not isinstance(self, EpochsTFR):
ch_axis = self._dims.index("channel")
if not from_epo:
data = data[np.newaxis] # add singleton "epochs" axis
axis += 1
n_epochs, n_picks, n_freqs, n_times = data.shape
# reshape to (epochs*freqs*times) x signals
data = np.moveaxis(data, axis, -1)
data = data.reshape(n_epochs * n_freqs * n_times, n_picks)
ch_axis += 1
if not unagg_mt:
data = np.expand_dims(data, -3) # add singleton "tapers" axis
n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape
# reshape to (epochs*tapers*freqs*times) x signals
data = np.moveaxis(data, ch_axis, -1)
data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks)
# prepare extra columns / multiindex
mindex = list()
default_index = list()
times = _convert_times(times, time_format, self.info["meas_date"])
times = np.tile(times, n_epochs * n_freqs)
freqs = np.tile(np.repeat(freqs, n_times), n_epochs)
times = np.tile(times, n_epochs * n_freqs * n_tapers)
freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs)
mindex.append(("time", times))
mindex.append(("freq", freqs))
if isinstance(self, EpochsTFR):
mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs)))
if from_epo:
mindex.append(
("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers))
)
rev_event_id = {v: k for k, v in self.event_id.items()}
conditions = [rev_event_id[k] for k in self.events[:, 2]]
mindex.append(("condition", np.repeat(conditions, n_times * n_freqs)))
mindex.append(
("condition", np.repeat(conditions, n_times * n_freqs * n_tapers))
)
default_index.extend(["condition", "epoch"])
default_index.extend(["freq", "time"])
if unagg_mt:
name = "taper"
taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times)
mindex.append((name, taper_nums))
default_index.append(name)
assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:])
# build DataFrame
if isinstance(self, EpochsTFR):
default_index = ["condition", "epoch", "freq", "time"]
else:
default_index = ["freq", "time"]
df = _build_data_frame(
self, data, picks, long_format, mindex, index, default_index=default_index
)
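
As a toy illustration of the reshaping strategy above (a NumPy-only sketch with made-up shapes, not the PR's code): aggregated data without a taper axis gets a singleton taper dimension, so both the aggregated and per-taper cases share the same moveaxis/reshape path into long format.

```python
import numpy as np

# Made-up shapes: (n_epochs, n_channels, n_freqs, n_times), i.e. no taper axis.
n_epochs, n_picks, n_freqs, n_times = 2, 3, 4, 5
agg = np.zeros((n_epochs, n_picks, n_freqs, n_times))

agg = np.expand_dims(agg, -3)     # -> (n_epochs, n_picks, 1, n_freqs, n_times)
long = np.moveaxis(agg, 1, -1)    # channel axis moved last
long = long.reshape(-1, n_picks)  # -> (n_epochs * 1 * n_freqs * n_times, n_picks)
print(long.shape)                 # (40, 3)
```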