Skip to content

Commit

Permalink
BUG: Fix bug with plot_projs_topomap (#11792)
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner authored Oct 2, 2023
1 parent fd08b52 commit 8051f6d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 30 deletions.
26 changes: 15 additions & 11 deletions mne/viz/tests/test_topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
create_info,
read_cov,
EvokedArray,
compute_proj_raw,
Projection,
)
from mne._fiff.proj import make_eeg_average_ref_proj
Expand Down Expand Up @@ -71,6 +72,8 @@
layout = read_layout("Vectorview-all")
cov_fname = base_dir / "test-cov.fif"

fast_test = dict(res=8, contours=0, sensors=False)


@pytest.mark.parametrize("constrained_layout", (False, True))
def test_plot_topomap_interactive(constrained_layout):
Expand Down Expand Up @@ -135,32 +138,36 @@ def test_plot_projs_topomap():
"""Test plot_projs_topomap."""
projs = read_proj(ecg_fname)
info = read_info(raw_fname)
fast_test = {"res": 8, "contours": 0, "sensors": False}
plot_projs_topomap(projs, info=info, colorbar=True, **fast_test)
plt.close("all")
ax = plt.subplot(111)
_, ax = plt.subplots()
projs[3].plot_topomap(info)
plot_projs_topomap(projs[:1], info, axes=ax, **fast_test) # test axes
plt.close("all")
triux_info = read_info(triux_fname)
plot_projs_topomap(triux_info["projs"][-1:], triux_info, **fast_test)
plt.close("all")
plot_projs_topomap(triux_info["projs"][:1], triux_info, **fast_test)
plt.close("all")
eeg_avg = make_eeg_average_ref_proj(info)
eeg_avg.plot_topomap(info, **fast_test)
plt.close("all")
# test vlims
for vlim in ("joint", (-1, 1), (None, 0.5), (0.5, None), (None, None)):
plot_projs_topomap(projs[:-1], info, vlim=vlim, colorbar=True)
plt.close("all")

eeg_proj = make_eeg_average_ref_proj(info)
info_meg = pick_info(info, pick_types(info, meg=True, eeg=False))
with pytest.raises(ValueError, match="Missing channels"):
plot_projs_topomap([eeg_proj], info_meg)


@pytest.mark.parametrize("vlim", ("joint", None))
@pytest.mark.parametrize("meg", ("combined", "separate"))
def test_plot_projs_topomap_joint(meg, vlim, raw):
"""Test that plot_projs_topomap works with joint vlim."""
if vlim is None:
vlim = (None, None)
projs = compute_proj_raw(raw, meg=meg)
fig = plot_projs_topomap(projs, info=raw.info, vlim=vlim, **fast_test)
assert len(fig.axes) == 4 # 2 mag, 2 grad


def test_plot_topomap_animation(capsys):
"""Test topomap plotting."""
# evoked
Expand Down Expand Up @@ -322,7 +329,6 @@ def test_plot_topomap_basic():
"""Test basics of topomap plotting."""
evoked = read_evokeds(evoked_fname, "Left Auditory", baseline=(None, 0))
res = 8
fast_test = dict(res=res, contours=0, sensors=False, time_unit="s")
fast_test_noscale = dict(res=res, contours=0, sensors=False)
ev_bad = evoked.copy().pick(picks="eeg")
ev_bad.pick(ev_bad.ch_names[:2])
Expand Down Expand Up @@ -649,8 +655,6 @@ def test_plot_arrowmap(evoked):
@testing.requires_testing_data
def test_plot_topomap_neuromag122():
"""Test topomap plotting."""
res = 8
fast_test = dict(res=res, contours=0, sensors=False)
evoked = read_evokeds(evoked_fname, "Left Auditory", baseline=(None, 0))
evoked.pick(picks="grad")
evoked.pick(evoked.ch_names[:122])
Expand Down
50 changes: 31 additions & 19 deletions mne/viz/topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,34 +474,46 @@ def _plot_projs_topomap(
projs = _check_type_projs(projs)
_validate_type(info, "info", "info")

types, datas, poss, spheres, outliness, ch_typess = [], [], [], [], [], []
# Preprocess projs to deal with joint MEG projectors. If we duplicate these and
# split into mag and grad, they should work as expected
info_names = _clean_names(info["ch_names"], remove_whitespace=True)
use_projs = list()
for proj in projs:
proj = _eliminate_zeros(proj) # gh 5641, makes a copy
proj["data"]["col_names"] = _clean_names(
proj["data"]["col_names"],
remove_whitespace=True,
)
picks = pick_channels(info_names, proj["data"]["col_names"], ordered=True)
proj_types = info.get_channel_types(picks)
unique_types = sorted(set(proj_types))
for type_ in unique_types:
proj_picks = np.where([proj_type == type_ for proj_type in proj_types])[0]
use_projs.append(copy.deepcopy(proj))
use_projs[-1]["data"]["data"] = proj["data"]["data"][:, proj_picks]
use_projs[-1]["data"]["col_names"] = [
proj["data"]["col_names"][pick] for pick in proj_picks
]
projs = use_projs

datas, poss, spheres, outliness, ch_typess = [], [], [], [], []
for proj in projs:
# get ch_names, ch_types, data
proj = _eliminate_zeros(proj) # gh 5641
ch_names = _clean_names(proj["data"]["col_names"], remove_whitespace=True)
if vlim == "joint":
ch_idxs = np.where(np.isin(info["ch_names"], proj["data"]["col_names"]))[0]
these_ch_types = info.get_channel_types(ch_idxs, unique=True)
# each projector should have only one channel type
assert len(these_ch_types) == 1
types.append(list(these_ch_types)[0])
data = proj["data"]["data"].ravel()
info_names = _clean_names(info["ch_names"], remove_whitespace=True)
picks = pick_channels(info_names, ch_names, ordered=True)
picks = pick_channels(info_names, proj["data"]["col_names"], ordered=True)
use_info = pick_info(info, picks)
these_ch_types = use_info.get_channel_types(unique=True)
assert len(these_ch_types) == 1 # should be guaranteed above
ch_type = these_ch_types[0]
(
data_picks,
pos,
merge_channels,
names,
ch_type,
_,
this_sphere,
clip_origin,
) = _prepare_topomap_plot(
use_info,
_get_plot_ch_type(use_info, None),
sphere=sphere,
)
) = _prepare_topomap_plot(use_info, ch_type, sphere=sphere)
these_outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)
data = data[data_picks]
if merge_channels:
Expand Down Expand Up @@ -530,8 +542,8 @@ def _plot_projs_topomap(
# handle vmin/vmax
vlims = [None for _ in range(len(datas))]
if vlim == "joint":
for _ch_type in set(types):
idx = np.where(np.isin(types, _ch_type))[0]
for _ch_type in set(ch_typess):
idx = np.where(np.isin(ch_typess, _ch_type))[0]
these_data = np.concatenate(np.array(datas, dtype=object)[idx])
norm = all(these_data >= 0)
_vl = _setup_vmin_vmax(these_data, vmin=None, vmax=None, norm=norm)
Expand Down

0 comments on commit 8051f6d

Please sign in to comment.