BUG: Fix bugs with corrmap computation #11858

Merged · 1 commit · Aug 3, 2023
1 change: 1 addition & 0 deletions doc/changes/latest.inc
@@ -54,6 +54,7 @@ Bugs
 - Fix bug with :class:`mne.io.Raw`, :class:`mne.SourceEstimate`, and related classes where the ``decimate`` and ``shift_time`` methods were errantly added (:gh:`11853` by `Eric Larson`_)
 - Fix hanging interpreter with matplotlib figures using ``mne/viz/_mpl_figure.py`` in spyder console and jupyter notebooks (:gh:`11696` by `Mathieu Scheltienne`_)
 - Fix bug with overlapping text for :meth:`mne.Evoked.plot` (:gh:`11698` by `Alex Rockhill`_)
+- Fix bug with :func:`mne.preprocessing.corrmap` where the template iteration had non-standard map averaging (:gh:`11857` by `Eric Larson`_)
 - For :func:`mne.io.read_raw_eyelink`, the default value of the ``gap_description`` parameter is now ``'BAD_ACQ_SKIP'``, following MNE convention (:gh:`11719` by `Scott Huberty`_)
 - Fix bug with :func:`mne.io.read_raw_fil` where datasets without sensor positions would not import (:gh:`11733` by `George O'Neill`_)
 - Fix bug with :func:`mne.chpi.compute_chpi_snr` where cHPI being off for part of the recording or bad channels being defined led to an error or incorrect behavior (:gh:`11754`, :gh:`11755` by `Eric Larson`_)
57 changes: 38 additions & 19 deletions mne/preprocessing/ica.py
@@ -3153,38 +3153,57 @@ def _band_pass_filter(inst, sources, target, l_freq, h_freq, verbose=None):
 
 def _find_max_corrs(all_maps, target, threshold):
     """Compute correlations between template and target components."""
-    all_corrs = [compute_corr(target, subj.T) for subj in all_maps]
+    # Following Fig.2 from:
+    # https://www.sciencedirect.com/science/article/abs/pii/S1388245709002338
+
+    # > ... inverse weights (i.e., IC maps) from a selected template IC are
+    # > correlated with all ICs from all datasets ...
+    all_corrs = [compute_corr(target, subj_maps.T) for subj_maps in all_maps]
     abs_corrs = [np.abs(a) for a in all_corrs]
     corr_polarities = [np.sign(a) for a in all_corrs]
     del all_corrs
 
+    # > selection of X ICs from each dataset with highest absolute
+    # > correlation >= TH
+    #
+    # subj_idxs is a list of indices for each subject that exceeded the threshold:
     if threshold <= 1:
-        max_corrs = [list(np.nonzero(s_corr > threshold)[0]) for s_corr in abs_corrs]
+        subj_idxs = [list(np.nonzero(s_corr > threshold)[0]) for s_corr in abs_corrs]
     else:
-        max_corrs = [
+        subj_idxs = [
             list(_find_outliers(s_corr, threshold=threshold)) for s_corr in abs_corrs
         ]
 
-    am = [l_[i] for l_, i_s in zip(abs_corrs, max_corrs) for i in i_s]
-    median_corr_with_target = np.median(am) if len(am) > 0 else 0
-
-    polarities = [l_[i] for l_, i_s in zip(corr_polarities, max_corrs) for i in i_s]
-
-    maxmaps = [l_[i] for l_, i_s in zip(all_maps, max_corrs) for i in i_s]
-
-    if len(maxmaps) == 0:
+    # > The mean correlation of a resulting cluster is then computed via
+    # > Fisher’s z transform, to account for the non-normal distribution of
+    # > correlation values.
+    #
+    # Here we just use the median rather than the (transformed-back) mean of
+    # the (Fisher z-transformed) correlations:
+    am = np.concatenate(
+        [abs_corr[subj_idx] for abs_corr, subj_idx in zip(abs_corrs, subj_idxs)]
+    )
+    if len(am) == 0:
         return [], 0, 0, []
-    newtarget = np.zeros(maxmaps[0].size)
-    std_of_maps = np.std(np.asarray(maxmaps))
-    mean_of_maps = np.std(np.asarray(maxmaps))
-    for maxmap, polarity in zip(maxmaps, polarities):
-        newtarget += (maxmap / std_of_maps - mean_of_maps) * polarity
+    median_corr_with_target = np.median(am)
 
-    newtarget /= len(maxmaps)
-    newtarget *= std_of_maps
+    # > Next, an average cluster map is calculated, after inversion of those
+    # > ICs showing a negative correlation (sign ambiguity problem) and root
+    # > mean square (RMS) normalization of each individual IC.
+    #
+    # Which is this (rms=Frobenius norm=np.linalg.norm):
+    newtarget = sum(
+        subj_maps[idx] * (pols[idx] / np.linalg.norm(subj_maps[idx]))
+        for subj_maps, pols, subj_idx in zip(all_maps, corr_polarities, subj_idxs)
+        for idx in subj_idx
+    )
Review discussion on the averaging expression above:

Contributor: Can't this be written in vectorized form that would also be more readable? Like

    (np.array(maxmaps) * np.array(polarities) / np.linalg.norm(np.array(maxmaps))).mean()

or something like that.

Member: I think that might work if you pass the right axis (and maybe keepdims=True) to np.linalg.norm.

Member Author: Maybe with the right axis manipulations... I'll try it and you can see if it's better.

Member Author: ... actually I'm not sure this will work unless we're guaranteed non-raggedness of the np.array, which we might not be if there are different numbers of components (which I think could happen if the number of PCA vectors differs across subjects, for example). But I can unify the iteration step at least to a single one, which I'll push.

Member: Yeah, we cannot assume non-raggedness (IIRC raggedness even happens in our tutorial docs on corrmap).

Member Author: In that case I think what's here is about as readable as you can get, with as few temporaries in memory (just one for the current sum plus one that is being added).
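
To make the raggedness point concrete, a tiny toy demonstration (the shapes are invented; exact behavior depends on the NumPy version, as older releases emitted a deprecation warning instead of raising):

import numpy as np

# Two subjects with different numbers of ICs cannot be stacked directly:
maps = [np.zeros((5, 32)), np.zeros((7, 32))]
try:
    np.array(maps)
except ValueError as err:  # raised by NumPy >= 1.24 for ragged input
    print(err)  # "... inhomogeneous shape ..."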

+    newtarget /= len(am)
 
     # And we also compute the similarity between this new map and our original
     # target map
     sim_i_o = np.abs(np.corrcoef(target, newtarget)[1, 0])
 
-    return newtarget, median_corr_with_target, sim_i_o, max_corrs
+    return newtarget, median_corr_with_target, sim_i_o, subj_idxs
 
 
 @verbose
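
Beyond the diff itself, here is a minimal, self-contained sketch of the same selection-and-averaging steps the new code performs, on invented toy data (the template, the subject maps, and the planted matches are all hypothetical, and np.corrcoef stands in for MNE's internal compute_corr helper), including ragged per-subject IC counts:

import numpy as np

rng = np.random.default_rng(0)
target = rng.standard_normal(32)  # a toy template IC map over 32 channels
# Ragged input: subjects may carry different numbers of ICs
all_maps = [rng.standard_normal((5, 32)), rng.standard_normal((7, 32))]
all_maps[0][2] = -2.0 * target  # plant an anti-correlated match
all_maps[1][4] = 3.0 * target   # plant a correlated match
threshold = 0.9

# Correlate the template with every IC of every subject
all_corrs = [
    np.array([np.corrcoef(target, m)[0, 1] for m in subj_maps])
    for subj_maps in all_maps
]
abs_corrs = [np.abs(c) for c in all_corrs]
polarities = [np.sign(c) for c in all_corrs]
subj_idxs = [np.nonzero(a > threshold)[0] for a in abs_corrs]

# Median absolute correlation of the selected cluster (the PR's stand-in
# for the paper's Fisher z-transformed mean)
am = np.concatenate([a[idx] for a, idx in zip(abs_corrs, subj_idxs)])
print(np.median(am))  # 1.0 here: both planted maps correlate perfectly

# Average cluster map: flip anti-correlated ICs, RMS-normalize each, average
newtarget = sum(
    subj_maps[i] * (pols[i] / np.linalg.norm(subj_maps[i]))
    for subj_maps, pols, idx in zip(all_maps, polarities, subj_idxs)
    for i in idx
) / len(am)

# For comparison, the Fisher z-transformed mean from the paper (clipped so
# arctanh stays finite for the perfect toy correlations):
print(np.tanh(np.mean(np.arctanh(np.clip(am, None, 1 - 1e-12)))))

The sum-over-a-generator form mirrors the choice defended in the review thread: it tolerates ragged inputs and keeps only the running sum plus one term in memory.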
5 changes: 4 additions & 1 deletion mne/preprocessing/tests/test_ica.py
@@ -689,7 +689,10 @@ def test_ica_additional(method, tmp_path, short_raw_epochs):
         plot=False,
         show=False,
     )
-    corrmap([ica, ica2], (0, 0), threshold=0.5, plot=False, show=False)
+    with catch_logging(True) as log:
+        corrmap([ica, ica2], (0, 0), threshold=0.5, plot=False, show=False)
+    log = log.getvalue()
+    assert "Median correlation with constructed map: 1.0" in log
     assert ica.labels_["blinks"] == ica2.labels_["blinks"]
     assert 0 in ica.labels_["blinks"]
     # test retrieval of component maps as arrays
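
For reference, a hypothetical sketch of checking the same log line outside the test suite; ica and ica2 stand for two already-fitted mne.preprocessing.ICA instances and are not defined here:

from mne.preprocessing import corrmap
from mne.utils import catch_logging

# ica and ica2 are assumed to be fitted ICA instances on compatible data
with catch_logging(True) as log:
    corrmap([ica, ica2], template=(0, 0), threshold=0.5, plot=False, show=False)
print(log.getvalue())  # contains "Median correlation with constructed map: ..."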
1 change: 0 additions & 1 deletion tutorials/preprocessing/40_artifact_correction_ica.py
@@ -63,7 +63,6 @@
 # is usually possible to separate the sources using ICA, and then re-construct
 # the sensor signals after excluding the sources that are unwanted.
 #
-#
 # ICA in MNE-Python
 # ~~~~~~~~~~~~~~~~~
 #