Skip to content

Commit

Permalink
Add raw stc (#12001)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Marijn van Vliet <w.m.vanvliet@gmail.com>
Co-authored-by: Daniel McCloy <dan@mccloy.info>
  • Loading branch information
4 people authored Oct 6, 2023
1 parent 647fdd3 commit 37ae7e3
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 19 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Enhancements
- Added public :func:`mne.io.write_info` to complement :func:`mne.io.read_info` (:gh:`11918` by `Eric Larson`_)
- Added option ``remove_dc`` to to :meth:`Raw.compute_psd() <mne.io.Raw.compute_psd>`, :meth:`Epochs.compute_psd() <mne.Epochs.compute_psd>`, and :meth:`Evoked.compute_psd() <mne.Evoked.compute_psd>`, to allow skipping DC removal when computing Welch or multitaper spectra (:gh:`11769` by `Nikolai Chapochnikov`_)
- Add the possibility to provide a float between 0 and 1 as ``n_grad``, ``n_mag`` and ``n_eeg`` in `~mne.compute_proj_raw`, `~mne.compute_proj_epochs` and `~mne.compute_proj_evoked` to select the number of vectors based on the cumulative explained variance (:gh:`11919` by `Mathieu Scheltienne`_)
- Add extracting all time courses in a label using :func:`mne.extract_label_time_course` without applying an aggregation function (like ``mean``) (:gh:`12001` by `Hamza Abdelhedi`_)
- Added support for Artinis fNIRS data files to :func:`mne.io.read_raw_snirf` (:gh:`11926` by `Robert Luke`_)
- Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_)
- Add support for passing a :class:`python:dict` as ``sensor_color`` to specify per-channel-type colors in :func:`mne.viz.plot_alignment` (:gh:`12067` by `Eric Larson`_)
Expand Down
26 changes: 15 additions & 11 deletions mne/source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3240,6 +3240,7 @@ def _pca_flip(flip, data):
"mean_flip": lambda flip, data: np.mean(flip * data, axis=0),
"max": lambda flip, data: np.max(np.abs(data), axis=0),
"pca_flip": _pca_flip,
None: lambda flip, data: data, # Return Identity: Preserves all vertices.
}


Expand Down Expand Up @@ -3494,7 +3495,7 @@ def _volume_labels(src, labels, mri_resolution):


def _get_default_label_modes():
return sorted(_label_funcs.keys()) + ["auto"]
return sorted(_label_funcs.keys(), key=lambda x: (x is None, x)) + ["auto"]


def _get_allowed_label_modes(stc):
Expand Down Expand Up @@ -3572,7 +3573,12 @@ def _gen_extract_label_time_course(
)

# do the extraction
label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
if mode is None:
# prepopulate an empty list for easy array-like index-based assignment
label_tc = [None] * max(len(label_vertidx), len(src_flip))
else:
# For other modes, initialize the label_tc array
label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
for i, (vertidx, flip) in enumerate(zip(label_vertidx, src_flip)):
if vertidx is not None:
if isinstance(vertidx, sparse.csr_matrix):
Expand All @@ -3585,15 +3591,13 @@ def _gen_extract_label_time_course(
this_data = stc.data[vertidx]
label_tc[i] = func(flip, this_data)

# extract label time series for the vol src space (only mean supported)
offset = nvert[:-n_mean].sum() # effectively :2 or :0
for i, nv in enumerate(nvert[2:]):
if nv != 0:
v2 = offset + nv
label_tc[n_mode + i] = np.mean(stc.data[offset:v2], axis=0)
offset = v2

# this is a generator!
if mode is not None:
offset = nvert[:-n_mean].sum() # effectively :2 or :0
for i, nv in enumerate(nvert[2:]):
if nv != 0:
v2 = offset + nv
label_tc[n_mode + i] = np.mean(stc.data[offset:v2], axis=0)
offset = v2
yield label_tc


Expand Down
46 changes: 38 additions & 8 deletions mne/tests/test_source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,12 +678,24 @@ def test_extract_label_time_course(kind, vector):

label_tcs = dict(mean=np.arange(n_labels)[:, None] * np.ones((n_labels, n_times)))
label_tcs["max"] = label_tcs["mean"]
label_tcs[None] = label_tcs["mean"]

# compute the mean with sign flip
label_tcs["mean_flip"] = np.zeros_like(label_tcs["mean"])
for i, label in enumerate(labels):
label_tcs["mean_flip"][i] = i * np.mean(label_sign_flip(label, src[:2]))

# compute pca_flip
label_flip = []
for i, label in enumerate(labels):
this_flip = i * label_sign_flip(label, src[:2])
label_flip.append(this_flip)
# compute pca_flip
label_tcs["pca_flip"] = np.zeros_like(label_tcs["mean"])
for i, (label, flip) in enumerate(zip(labels, label_flip)):
sign = np.sign(np.dot(np.full((flip.shape[0]), i), flip))
label_tcs["pca_flip"][i] = sign * label_tcs["mean"][i]

# generate some stc's with known data
stcs = list()
pad = (((0, 0), (2, 0), (0, 0)), "constant")
Expand Down Expand Up @@ -734,7 +746,7 @@ def test_extract_label_time_course(kind, vector):
assert_array_equal(arr[1:], vol_means_t)

# test the different modes
modes = ["mean", "mean_flip", "pca_flip", "max", "auto"]
modes = ["mean", "mean_flip", "pca_flip", "max", "auto", None]

for mode in modes:
if vector and mode not in ("mean", "max", "auto"):
Expand All @@ -748,18 +760,36 @@ def test_extract_label_time_course(kind, vector):
]
assert len(label_tc) == n_stcs
assert len(label_tc_method) == n_stcs
for tc1, tc2 in zip(label_tc, label_tc_method):
assert tc1.shape == (n_labels + len(vol_means),) + end_shape
assert tc2.shape == (n_labels + len(vol_means),) + end_shape
assert_allclose(tc1, tc2, rtol=1e-8, atol=1e-16)
for j, (tc1, tc2) in enumerate(zip(label_tc, label_tc_method)):
if mode is None:
assert all(arr.shape[1] == tc1[0].shape[1] for arr in tc1)
assert all(arr.shape[1] == tc2[0].shape[1] for arr in tc2)
assert (len(tc1), tc1[0].shape[1]) == (n_labels,) + end_shape
assert (len(tc2), tc2[0].shape[1]) == (n_labels,) + end_shape
for arr1, arr2 in zip(tc1, tc2): # list of arrays
assert_allclose(arr1, arr2, rtol=1e-8, atol=1e-16)
else:
assert tc1.shape == (n_labels + len(vol_means),) + end_shape
assert tc2.shape == (n_labels + len(vol_means),) + end_shape
assert_allclose(tc1, tc2, rtol=1e-8, atol=1e-16)
if mode == "auto":
use_mode = "mean" if vector else "mean_flip"
else:
use_mode = mode
# XXX we don't check pca_flip, probably should someday...
if use_mode in ("mean", "max", "mean_flip"):
if mode == "pca_flip":
for arr1, arr2 in zip(tc1, label_tcs[use_mode]):
assert_array_almost_equal(arr1, arr2)
elif use_mode is None:
for arr1, arr2 in zip(
tc1[:n_labels], label_tcs[use_mode]
): # list of arrays
assert_allclose(
arr1, np.tile(arr2, (arr1.shape[0], 1)), rtol=1e-8, atol=1e-16
)
elif use_mode in ("mean", "max", "mean_flip"):
assert_array_almost_equal(tc1[:n_labels], label_tcs[use_mode])
assert_array_almost_equal(tc1[n_labels:], vol_means_t)
if mode is not None:
assert_array_almost_equal(tc1[n_labels:], vol_means_t)

# test label with very few vertices (check SVD conditionals)
label = Label(vertices=src[0]["vertno"][:2], hemi="lh")
Expand Down
3 changes: 3 additions & 0 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
- ``'auto'`` (default)
Uses ``'mean_flip'`` when a standard source estimate is applied, and
``'mean'`` when a vector source estimate is supplied.
- ``None``
No aggregation is performed, and an array of shape ``(n_vertices, n_times)`` is
returned.
.. versionadded:: 0.21
Support for ``'auto'``, vector, and volume source estimates.
Expand Down

0 comments on commit 37ae7e3

Please sign in to comment.