diff --git a/doc/changes/devel/12071.newfeature.rst b/doc/changes/devel/12071.newfeature.rst new file mode 100644 index 00000000000..4e7995e3beb --- /dev/null +++ b/doc/changes/devel/12071.newfeature.rst @@ -0,0 +1 @@ +Add new ``select`` parameter to :func:`mne.viz.plot_evoked_topo` and :meth:`mne.Evoked.plot_topo` to toggle lasso selection of sensors, by `Marijn van Vliet`_. diff --git a/mne/epochs.py b/mne/epochs.py index 515bbb69a72..ba9d78df1e1 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1353,6 +1353,7 @@ def plot_topo_image( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): return plot_topo_image_epochs( @@ -1371,6 +1372,7 @@ def plot_topo_image( fig_facecolor=fig_facecolor, fig_background=fig_background, font_color=font_color, + select=select, show=show, ) diff --git a/mne/evoked.py b/mne/evoked.py index a985fc30ad7..597633c5f9e 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -613,6 +613,7 @@ def plot_topo( background_color="w", noise_cov=None, exclude="bads", + select=False, show=True, ): """ @@ -638,6 +639,7 @@ def plot_topo( background_color=background_color, noise_cov=noise_cov, exclude=exclude, + select=select, show=show, ) diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index ffb1c57dafd..4eebf273772 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -500,11 +500,11 @@ def _create_ch_location_fig(self, pick): show=False, ) # highlight desired channel & disable interactivity - inds = np.isin(fig.lasso.ch_names, [ch_name]) + fig.lasso.selection_inds = np.isin(fig.lasso.names, [ch_name]) fig.lasso.disconnect() - fig.lasso.alpha_other = 0.3 + fig.lasso.alpha_nonselected = 0.3 fig.lasso.linewidth_selected = 3 - fig.lasso.style_sensors(inds) + fig.lasso.style_objects() return fig diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index 2e552bd4012..3987b641dff 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -1536,7 +1536,7 @@ def _update_selection(self): def _update_highlighted_sensors(self): """Update the sensor plot to show what is selected.""" inds = np.isin( - self.mne.fig_selection.lasso.ch_names, self.mne.ch_names[self.mne.picks] + self.mne.fig_selection.lasso.names, self.mne.ch_names[self.mne.picks] ).nonzero()[0] self.mne.fig_selection.lasso.select_many(inds) diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 0343a5b7d62..6fd99b0e3ba 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -1151,6 +1151,7 @@ def plot_evoked_topo( background_color="w", noise_cov=None, exclude="bads", + select=False, show=True, ): """Plot 2D topography of evoked responses. @@ -1216,6 +1217,15 @@ def plot_evoked_topo( exclude : list of str | ``'bads'`` Channels names to exclude from the plot. If ``'bads'``, the bad channels are excluded. By default, exclude is set to ``'bads'``. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.9.0 + exclude : list of str | ``'bads'`` + Channels names to exclude from the plot. If ``'bads'``, the + bad channels are excluded. By default, exclude is set to ``'bads'``. show : bool Show figure if True. @@ -1272,10 +1282,11 @@ def plot_evoked_topo( font_color=font_color, merge_channels=merge_grads, legend=legend, + noise_cov=noise_cov, axes=axes, exclude=exclude, + select=select, show=show, - noise_cov=noise_cov, ) diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index a2565927feb..8a3282b55b7 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -1094,32 +1094,48 @@ def test_plot_sensors(raw): ax = fig.axes[0] # Click with no sensors - _fake_click(fig, ax, (0.0, 0.0), xform="data") - _fake_click(fig, ax, (0, 0.0), xform="data", kind="release") + _fake_click(fig, ax, (-0.14, 0.14), xform="data") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.14), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.14), xform="data", kind="release") assert fig.lasso.selection == [] # Lasso with 1 sensor (upper left) - _fake_click(fig, ax, (0, 1), xform="ax") - fig.canvas.draw() - assert fig.lasso.selection == [] - _fake_click(fig, ax, (0.65, 1), xform="ax", kind="motion") - _fake_click(fig, ax, (0.65, 0.7), xform="ax", kind="motion") - _fake_keypress(fig, "control") - _fake_click(fig, ax, (0, 0.7), xform="ax", kind="release", key="control") + _fake_click(fig, ax, (-0.13, 0.13), xform="data") + _fake_click(fig, ax, (-0.11, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.11, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="release") assert fig.lasso.selection == ["MEG 0121"] - # check that point appearance changes + # Use SHIFT key to lasso an additional sensor. + _fake_keypress(fig, "shift") + _fake_click(fig, ax, (-0.17, 0.07), xform="data") + _fake_click(fig, ax, (-0.17, 0.05), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.05), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="release") + _fake_keypress(fig, "shift", kind="release") + assert fig.lasso.selection == ["MEG 0111", "MEG 0121"] + + # Check that the two selected sensors have a different appearance. fc = fig.lasso.collection.get_facecolors() ec = fig.lasso.collection.get_edgecolors() - assert (fc[:, -1] == [0.5, 1.0, 0.5]).all() - assert (ec[:, -1] == [0.25, 1.0, 0.25]).all() - - _fake_click(fig, ax, (0.7, 1), xform="ax", kind="motion", key="control") - xy = ax.collections[0].get_offsets() - _fake_click(fig, ax, xy[2], xform="data", key="control") # single sel - assert fig.lasso.selection == ["MEG 0121", "MEG 0131"] - _fake_click(fig, ax, xy[2], xform="data", key="control") # deselect - assert fig.lasso.selection == ["MEG 0121"] + assert (fc[2:, -1] == 0.5).all() + assert (ec[2:, -1] == 0.25).all() + assert (fc[:2, -1] == 1.0).all() + assert (ec[:2:, -1] == 1.0).all() + + # Use ALT key to remove a sensor from the lasso. + _fake_keypress(fig, "alt") + _fake_click(fig, ax, (-0.17, 0.07), xform="data") + _fake_click(fig, ax, (-0.17, 0.05), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.05), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="release") + _fake_keypress(fig, "alt", kind="release") + plt.close("all") raw.info["dev_head_t"] = None # like empty room diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 3364a455aed..4db685a459b 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -13,8 +13,10 @@ from .._fiff.pick import _picks_to_idx, channel_type, pick_types from ..defaults import _handle_default from ..utils import Bunch, _check_option, _clean_names, _is_numeric, _to_rgb, fill_doc +from .ui_events import ChannelsSelect, publish, subscribe from .utils import ( DraggableColorbar, + SelectFromCollection, _check_cov, _check_delayed_ssp, _draw_proj_checkbox, @@ -37,6 +39,7 @@ def iter_topography( axis_spinecolor="k", layout_scale=None, legend=False, + select=False, ): """Create iterator over channel positions. @@ -72,6 +75,10 @@ def iter_topography( If True, an additional axis is created in the bottom right corner that can be used to, e.g., construct a legend. The index of this axis will be -1. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. Returns ------- @@ -93,6 +100,7 @@ def iter_topography( axis_spinecolor, layout_scale, legend=legend, + select=select, ) @@ -128,6 +136,7 @@ def _iter_topography( img=False, axes=None, legend=False, + select=False, ): """Iterate over topography. @@ -193,8 +202,11 @@ def format_coord_multiaxis(x, y, ch_name=None): under_ax.set(xlim=[0, 1], ylim=[0, 1]) axs = list() + + shown_ch_names = [] for idx, name in iter_ch: ch_idx = ch_names.index(name) + shown_ch_names.append(name) if not unified: # old, slow way ax = plt.axes(pos[idx]) ax.patch.set_facecolor(axis_facecolor) @@ -226,24 +238,47 @@ def format_coord_multiaxis(x, y, ch_name=None): if unified: under_ax._mne_axs = axs # Create a PolyCollection for the axis backgrounds + sel_pos = pos[[i[0] for i in iter_ch]] verts = np.transpose( [ - pos[:, :2], - pos[:, :2] + pos[:, 2:] * [1, 0], - pos[:, :2] + pos[:, 2:], - pos[:, :2] + pos[:, 2:] * [0, 1], + sel_pos[:, :2], + sel_pos[:, :2] + sel_pos[:, 2:] * [1, 0], + sel_pos[:, :2] + sel_pos[:, 2:], + sel_pos[:, :2] + sel_pos[:, 2:] * [0, 1], ], [1, 0, 2], ) - if not img: - under_ax.add_collection( - collections.PolyCollection( - verts, - facecolor=axis_facecolor, - edgecolor=axis_spinecolor, - linewidth=1.0, + if not img: # Not needed for image plots. + collection = collections.PolyCollection( + verts, + facecolor=axis_facecolor, + edgecolor=axis_spinecolor, + linewidth=1.0, + ) + under_ax.add_collection(collection) + + if select: + # Configure the lasso-selection tool + fig.lasso = SelectFromCollection( + ax=under_ax, + collection=collection, + names=shown_ch_names, + alpha_nonselected=0, + alpha_selected=1, + linewidth_nonselected=0, + linewidth_selected=0.7, ) - ) # Not needed for image plots. + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + ch_inds = {name: i for i, name in enumerate(ch_names)} + selection_inds = [ch_inds[name] for name in event.ch_names] + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) for ax in axs: yield ax, ax._mne_ch_idx @@ -270,6 +305,7 @@ def _plot_topo( unified=False, img=False, axes=None, + select=False, ): """Plot on sensor layout.""" import matplotlib.pyplot as plt @@ -322,6 +358,7 @@ def _plot_topo( unified=unified, img=img, axes=axes, + select=select, ) for ax, ch_idx in my_topo_plot: @@ -342,6 +379,9 @@ def _plot_topo_onpick(event, show_func): """Onpick callback that shows a single channel in a new figure.""" # make sure that the swipe gesture in OS-X doesn't open many figures orig_ax = event.inaxes + if orig_ax.figure.canvas._key in ["shift", "alt"]: + return + import matplotlib.pyplot as plt try: @@ -838,9 +878,10 @@ def _plot_evoked_topo( merge_channels=False, legend=True, axes=None, + noise_cov=None, exclude="bads", + select=False, show=True, - noise_cov=None, ): """Plot 2D topography of evoked responses. @@ -912,6 +953,10 @@ def _plot_evoked_topo( exclude : list of str | 'bads' Channels names to exclude from being shown. If 'bads', the bad channels are excluded. By default, exclude is set to 'bads'. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. show : bool Show figure if True. @@ -1091,6 +1136,7 @@ def _plot_evoked_topo( y_label=y_label, unified=True, axes=axes, + select=select, ) add_background_image(fig, fig_background) @@ -1098,7 +1144,10 @@ def _plot_evoked_topo( if legend is not False: legend_loc = 0 if legend is True else legend labels = [e.comment if e.comment else "Unknown" for e in evoked] - handles = fig.axes[0].lines[: len(evoked)] + if select: + handles = fig.axes[0].lines[1 : len(evoked) + 1] + else: + handles = fig.axes[0].lines[: len(evoked)] legend = plt.legend( labels=labels, handles=handles, loc=legend_loc, prop={"size": 10} ) @@ -1157,6 +1206,7 @@ def plot_topo_image_epochs( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): """Plot Event Related Potential / Fields image on topographies. @@ -1204,6 +1254,10 @@ def plot_topo_image_epochs( :func:`matplotlib.pyplot.imshow`. Defaults to ``None``. font_color : color The color of tick labels in the colorbar. Defaults to white. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. show : bool Whether to show the figure. Defaults to ``True``. @@ -1293,6 +1347,7 @@ def plot_topo_image_epochs( y_label="Epoch", unified=True, img=True, + select=select, ) add_background_image(fig, fig_background) plt_show(show) diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index 256d5741ad3..b8b3fe29a4d 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -212,6 +212,26 @@ class Contours(UIEvent): contours: list[str] +@dataclass +@fill_doc +class ChannelsSelect(UIEvent): + """Indicates that the user has selected one or more channels. + + Parameters + ---------- + ch_names : list of str + The names of the channels that were selected. + + Attributes + ---------- + %(ui_event_name_source)s + ch_names : list of str + The names of the channels that were selected. + """ + + ch_names: list[str] + + def _get_event_channel(fig): """Get the event channel associated with a figure. diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 675b89b2852..e3f26224fd5 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -58,7 +58,7 @@ warn, ) from ..utils.misc import _identity_function -from .ui_events import ColormapRange, publish, subscribe +from .ui_events import ChannelsSelect, ColormapRange, publish, subscribe _channel_type_prettyprint = { "eeg": "EEG channel", @@ -807,12 +807,12 @@ def _fake_click(fig, ax, point, xform="ax", button=1, kind="press", key=None): ) -def _fake_keypress(fig, key): +def _fake_keypress(fig, key, kind="press"): from matplotlib import backend_bases fig.canvas.callbacks.process( - "key_press_event", - backend_bases.KeyEvent(name="key_press_event", canvas=fig.canvas, key=key), + f"key_{kind}_event", + backend_bases.KeyEvent(name=f"key_{kind}_event", canvas=fig.canvas, key=key), ) @@ -952,7 +952,7 @@ def plot_sensors( Whether to plot the sensors as 3d, topomap or as an interactive sensor selection dialog. Available options ``'topomap'``, ``'3d'``, ``'select'``. If ``'select'``, a set of channels can be selected - interactively by using lasso selector or clicking while holding control + interactively by using lasso selector or clicking while holding the shift key. The selected channels are returned along with the figure instance. Defaults to ``'topomap'``. ch_type : None | str @@ -1163,7 +1163,7 @@ def _onpick_sensor(event, fig, ax, pos, ch_names, show_names): if event.mouseevent.inaxes != ax: return - if event.mouseevent.key == "control" and fig.lasso is not None: + if event.mouseevent.key in ["shift", "alt"] and fig.lasso is not None: for ind in event.ind: fig.lasso.select_one(ind) @@ -1272,7 +1272,18 @@ def _plot_sensors_2d( lw=linewidth, ) if kind == "select": - fig.lasso = SelectFromCollection(ax, pts, ch_names) + fig.lasso = SelectFromCollection(ax, pts, names=ch_names) + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + ch_inds = {name: i for i, name in enumerate(ch_names)} + selection_inds = [ch_inds[name] for name in event.ch_names] + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) else: fig.lasso = None @@ -1595,11 +1606,11 @@ def _update(self): class SelectFromCollection: - """Select channels from a matplotlib collection using ``LassoSelector``. + """Select objects from a matplotlib collection using ``LassoSelector``. - Selected channels are saved in the ``selection`` attribute. This tool - highlights selected points by fading other points out (i.e., reducing their - alpha values). + The names of the selected objects are saved in the ``selection`` attribute. + This tool highlights selected objects by fading other objects out (i.e., + reducing their alpha values). Parameters ---------- @@ -1607,62 +1618,93 @@ class SelectFromCollection: Axes to interact with. collection : instance of matplotlib collection Collection you want to select from. - alpha_other : 0 <= float <= 1 - To highlight a selection, this tool sets all selected points to an - alpha value of 1 and non-selected points to ``alpha_other``. - Defaults to 0.3. - linewidth_other : float - Linewidth to use for non-selected sensors. Default is 1. + names : list of str + The names of the object. The selection is returned as a subset of these names. + alpha_selected : float + Alpha for selected objects (0=tranparant, 1=opaque). + alpha_nonselected : float + Alpha for non-selected objects (0=tranparant, 1=opaque). + linewidth_selected : float + Linewidth for the borders of selected objects. + linewidth_nonselected : float + Linewidth for the borders of non-selected objects. Notes ----- - This tool selects collection objects based on their *origins* - (i.e., ``offsets``). Calls all callbacks in self.callbacks when selection - is ready. + This tool selects collection objects which bounding boxes intersect with a lasso + path. Calls all callbacks in self.callbacks when selection is ready. """ def __init__( self, ax, collection, - ch_names, - alpha_other=0.5, - linewidth_other=0.5, + *, + names, alpha_selected=1, + alpha_nonselected=0.5, linewidth_selected=1, + linewidth_nonselected=0.5, ): from matplotlib.widgets import LassoSelector + self.fig = ax.figure self.canvas = ax.figure.canvas self.collection = collection - self.ch_names = ch_names - self.alpha_other = alpha_other - self.linewidth_other = linewidth_other + self.names = names self.alpha_selected = alpha_selected + self.alpha_nonselected = alpha_nonselected self.linewidth_selected = linewidth_selected + self.linewidth_nonselected = linewidth_nonselected - self.xys = collection.get_offsets() - self.Npts = len(self.xys) + from matplotlib.collections import PolyCollection + from matplotlib.path import Path - # Ensure that we have separate colors for each object + if isinstance(collection, PolyCollection): + self.paths = collection.get_paths() + else: + self.paths = [Path([point]) for point in collection.get_offsets()] + self.Npts = len(self.paths) + if self.Npts != len(names): + raise ValueError( + f"Number of names ({len(names)}) does not match the number of objects " + f"in the collection ({self.Npts})." + ) + + # Ensure that we have colors for each object. self.fc = collection.get_facecolors() self.ec = collection.get_edgecolors() - self.lw = collection.get_linewidths() if len(self.fc) == 0: raise ValueError("Collection must have a facecolor") elif len(self.fc) == 1: self.fc = np.tile(self.fc, self.Npts).reshape(self.Npts, -1) + if len(self.ec) == 0: + self.ec = np.zeros((self.Npts, 4)) # all black + elif len(self.ec) == 1: self.ec = np.tile(self.ec, self.Npts).reshape(self.Npts, -1) - self.fc[:, -1] = self.alpha_other # deselect in the beginning - self.ec[:, -1] = self.alpha_other - self.lw = np.full(self.Npts, self.linewidth_other) + self.lw = np.full(self.Npts, float(self.linewidth_nonselected)) + # Initialize the lasso selector self.lasso = LassoSelector( ax, onselect=self.on_select, props=dict(color="red", linewidth=0.5) ) self.selection = list() + self.selection_inds = np.array([], dtype="int") self.callbacks = list() + # Deselect everything in the beginning. + self.style_objects() + + # For backwards compatibility + @property + def ch_names(self): + return self.names + + def notify(self): + """Notify listeners that a selection has been made.""" + for callback in self.callbacks: + callback() + def on_select(self, verts): """Select a subset from the collection.""" from matplotlib.path import Path @@ -1671,48 +1713,45 @@ def on_select(self, verts): return path = Path(verts) - inds = np.nonzero([path.contains_point(xy) for xy in self.xys])[0] - if self.canvas._key == "control": # Appending selection. - sels = [np.where(self.ch_names == c)[0][0] for c in self.selection] - inters = set(inds) - set(sels) - inds = list(inters.union(set(sels) - set(inds))) - - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + inds = np.nonzero([path.intersects_path(p) for p in self.paths])[0] + if self.canvas._key == "shift": # Appending selection. + self.selection_inds = np.union1d(self.selection_inds, inds).astype("int") + elif self.canvas._key == "alt": # Removing selection. + self.selection_inds = np.setdiff1d(self.selection_inds, inds).astype("int") + else: + self.selection_inds = inds + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() self.notify() def select_one(self, ind): """Select or deselect one sensor.""" - ch_name = self.ch_names[ind] - if ch_name in self.selection: - sel_ind = self.selection.index(ch_name) - self.selection.pop(sel_ind) + if self.canvas._key == "shift": + self.selection_inds = np.union1d(self.selection_inds, [ind]) + elif self.canvas._key == "alt": + self.selection_inds = np.setdiff1d(self.selection_inds, [ind]) else: - self.selection.append(ch_name) - inds = np.isin(self.ch_names, self.selection).nonzero()[0] - self.style_sensors(inds) + return # don't notify() + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() self.notify() - def notify(self): - """Notify listeners that a selection has been made.""" - for callback in self.callbacks: - callback() - def select_many(self, inds): """Select many sensors using indices (for predefined selections).""" - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + self.selection_inds = inds + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() - def style_sensors(self, inds): + def style_objects(self): """Style selected sensors as "active".""" # reset - self.fc[:, -1] = self.alpha_other - self.ec[:, -1] = self.alpha_other / 2 - self.lw[:] = self.linewidth_other + self.fc[:, -1] = self.alpha_nonselected + self.ec[:, -1] = self.alpha_nonselected / 2 + self.lw[:] = self.linewidth_nonselected # style sensors at `inds` - self.fc[inds, -1] = self.alpha_selected - self.ec[inds, -1] = self.alpha_selected - self.lw[inds] = self.linewidth_selected + self.fc[self.selection_inds, -1] = self.alpha_selected + self.ec[self.selection_inds, -1] = self.alpha_selected + self.lw[self.selection_inds] = self.linewidth_selected self.collection.set_facecolors(self.fc) self.collection.set_edgecolors(self.ec) self.collection.set_linewidths(self.lw)