diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 21c84feafbc..87ed238e996 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -17,7 +17,7 @@ Current (0.23.dev0) Enhancements ~~~~~~~~~~~~ -- None yet +- Update the ``notebook`` 3d backend to use ``ipyvtk_simple`` for a better integration within ``Jupyter`` (:gh:`8503` by `Guillaume Favelier`_) Bugs ~~~~ diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 931527245ed..142c0320192 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -429,8 +429,10 @@ def __init__(self, subject_id, hemi, surf, title=None, shape=shape, fig=figure) - if _get_3d_backend() == "pyvista": - self.plotter = self._renderer.plotter + self.plotter = self._renderer.plotter + if self.notebook: + self.window = None + else: self.window = self.plotter.app_window self.window.signal_close.connect(self._clean) @@ -501,11 +503,6 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): self.orientation = list(_lh_views_dict.keys()) self.default_smoothing_range = [0, 15] - # setup notebook - if self.notebook: - self._configure_notebook() - return - # Default configuration self.playback = False self.visibility = False @@ -550,10 +547,15 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): # Direct access parameters: self._iren = self._renderer.plotter.iren - self.main_menu = self.plotter.main_menu - self.tool_bar = self.window.addToolBar("toolbar") - self.status_bar = self.window.statusBar() - self.interactor = self.plotter.interactor + self.tool_bar = None + if self.notebook: + self.main_menu = None + self.status_bar = None + self.interactor = None + else: + self.main_menu = self.plotter.main_menu + self.status_bar = self.window.statusBar() + self.interactor = self.plotter.interactor # Derived parameters: self.playback_speed = self.default_playback_speed_value @@ -584,21 +586,24 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): self.separate_canvas = False del show_traces - self._load_icons() self._configure_time_label() self._configure_sliders() self._configure_scalar_bar() - self._configure_playback() - self._configure_menu() - self._configure_tool_bar() - self._configure_status_bar() + self._configure_shortcuts() self._configure_picking() + self._configure_tool_bar() + if self.notebook: + self.show() self._configure_trace_mode() - - # show everything at the end self.toggle_interface() - with self.ensure_minimum_sizes(): - self.show() + if not self.notebook: + self._configure_playback() + self._configure_menu() + self._configure_status_bar() + + # show everything at the end + with self.ensure_minimum_sizes(): + self.show() @safe_event def _clean(self): @@ -677,10 +682,13 @@ def toggle_interface(self, value=None): self.visibility = value # update tool bar icon - if self.visibility: - self.actions["visibility"].setIcon(self.icons["visibility_on"]) - else: - self.actions["visibility"].setIcon(self.icons["visibility_off"]) + if not self.notebook: + if self.visibility: + self.actions["visibility"].setIcon( + self.icons["visibility_on"]) + else: + self.actions["visibility"].setIcon( + self.icons["visibility_off"]) # manage sliders for slider in self.plotter.slider_widgets: @@ -810,10 +818,6 @@ def _set_slider_style(self): ) slider_rep.GetCapProperty().SetOpacity(0) - def _configure_notebook(self): - from ._notebook import _NotebookInteractor - self._renderer.figure.display = _NotebookInteractor(self) - def _configure_time_label(self): self.time_actor = self._data.get('time_actor') if self.time_actor is not None: @@ -998,19 +1002,24 @@ def _configure_playback(self): self.plotter.add_callback(self._play, self.refresh_rate_ms) def _configure_mplcanvas(self): - win = self.plotter.app_window - dpi = win.windowHandle().screen().logicalDotsPerInch() ratio = (1 - self.interactor_fraction) / self.interactor_fraction - w = self.interactor.geometry().width() - h = self.interactor.geometry().height() / ratio + if self.notebook: + dpi = 96 + w, h = self.plotter.window_size + else: + dpi = self.window.windowHandle().screen().logicalDotsPerInch() + w = self.interactor.geometry().width() + h = self.interactor.geometry().height() + h /= ratio # Get the fractional components for the brain and mpl - self.mpl_canvas = MplCanvas(self, w / dpi, h / dpi, dpi) + self.mpl_canvas = MplCanvas(self, w / dpi, h / dpi, dpi, + self.notebook) xlim = [np.min(self._data['time']), np.max(self._data['time'])] with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) self.mpl_canvas.axes.set(xlim=xlim) - if not self.separate_canvas: + if not self.notebook and not self.separate_canvas: from PyQt5.QtWidgets import QSplitter from PyQt5.QtCore import Qt canvas = self.mpl_canvas.canvas @@ -1106,10 +1115,13 @@ def _configure_picking(self): def _configure_trace_mode(self): from ...source_estimate import _get_allowed_label_modes from ...label import _read_annot_cands - from PyQt5.QtWidgets import QComboBox, QLabel if not self.show_traces: return + if self.notebook: + self._configure_vertex_time_course() + return + # do not show trace mode for volumes if (self._data.get('src', None) is not None and self._data['src'].kind == 'volume'): @@ -1131,6 +1143,7 @@ def _set_annot(annot): self._configure_label_time_course() self._update() + from PyQt5.QtWidgets import QComboBox, QLabel dir_name = op.join(self._subjects_dir, self._subject_id, 'label') cands = _read_annot_cands(dir_name) self.tool_bar.addSeparator() @@ -1201,60 +1214,109 @@ def _load_icons(self): def _save_movie_noname(self): return self.save_movie(None) + def _screenshot(self): + if not self.notebook: + self.plotter._qt_screenshot() + + def _initialize_actions(self): + if not self.notebook: + self._load_icons() + self.tool_bar = self.window.addToolBar("toolbar") + + def _add_action(self, name, desc, func, icon_name, qt_icon_name=None, + notebook=True): + if self.notebook: + if not notebook: + return + from ipywidgets import Button + self.actions[name] = Button(description=desc, icon=icon_name) + self.actions[name].on_click(lambda x: func()) + else: + qt_icon_name = name if qt_icon_name is None else qt_icon_name + self.actions[name] = self.tool_bar.addAction( + self.icons[qt_icon_name], + desc, + func, + ) + def _configure_tool_bar(self): - self.actions["screenshot"] = self.tool_bar.addAction( - self.icons["screenshot"], - "Take a screenshot", - self.plotter._qt_screenshot + self._initialize_actions() + self._add_action( + name="screenshot", + desc="Take a screenshot", + func=self._screenshot, + icon_name=None, + notebook=False, ) - self.actions["movie"] = self.tool_bar.addAction( - self.icons["movie"], - "Save movie...", - self._save_movie_noname, + self._add_action( + name="movie", + desc="Save movie...", + func=self._save_movie_noname, + icon_name=None, + notebook=False, ) - self.actions["visibility"] = self.tool_bar.addAction( - self.icons["visibility_on"], - "Toggle Visibility", - self.toggle_interface + self._add_action( + name="visibility", + desc="Toggle Visibility", + func=self.toggle_interface, + icon_name="eye", + qt_icon_name="visibility_on", ) - self.actions["play"] = self.tool_bar.addAction( - self.icons["play"], - "Play/Pause", - self.toggle_playback + self._add_action( + name="play", + desc="Play/Pause", + func=self.toggle_playback, + icon_name=None, + notebook=False, ) - self.actions["reset"] = self.tool_bar.addAction( - self.icons["reset"], - "Reset", - self.reset + self._add_action( + name="reset", + desc="Reset", + func=self.reset, + icon_name="history", ) - self.actions["scale"] = self.tool_bar.addAction( - self.icons["scale"], - "Auto-Scale", - self.apply_auto_scaling + self._add_action( + name="scale", + desc="Auto-Scale", + func=self.apply_auto_scaling, + icon_name="magic", ) - self.actions["restore"] = self.tool_bar.addAction( - self.icons["restore"], - "Restore scaling", - self.restore_user_scaling + self._add_action( + name="restore", + desc="Restore scaling", + func=self.restore_user_scaling, + icon_name="reply", ) - self.actions["clear"] = self.tool_bar.addAction( - self.icons["clear"], - "Clear traces", - self.clear_glyphs + self._add_action( + name="clear", + desc="Clear traces", + func=self.clear_glyphs, + icon_name="trash", ) - self.actions["help"] = self.tool_bar.addAction( - self.icons["help"], - "Help", - self.help + self._add_action( + name="help", + desc="Help", + func=self.help, + icon_name=None, + notebook=False, ) - self.actions["movie"].setShortcut("ctrl+shift+s") - self.actions["visibility"].setShortcut("i") - self.actions["play"].setShortcut(" ") - self.actions["scale"].setShortcut("s") - self.actions["restore"].setShortcut("r") - self.actions["clear"].setShortcut("c") - self.actions["help"].setShortcut("?") + if self.notebook: + from IPython import display + from ipywidgets import HBox + self.tool_bar = HBox(tuple(self.actions.values())) + display.display(self.tool_bar) + else: + # Qt shortcuts + self.actions["movie"].setShortcut("ctrl+shift+s") + self.actions["play"].setShortcut(" ") + self.actions["help"].setShortcut("?") + + def _configure_shortcuts(self): + self.plotter.add_key_event("i", self.toggle_interface) + self.plotter.add_key_event("s", self.apply_auto_scaling) + self.plotter.add_key_event("r", self.restore_user_scaling) + self.plotter.add_key_event("c", self.clear_glyphs) def _configure_menu(self): # remove default picking menu @@ -3123,13 +3185,6 @@ def _iter_time(self, time_idx, callback): # Restore original time index func(current_time_idx) - def _show(self): - """Request rendering of the window.""" - try: - return self._renderer.show() - except RuntimeError: - logger.info("No active/running renderer available.") - def _check_stc(self, hemi, array, vertices): from ...source_estimate import ( _BaseSourceEstimate, _BaseSurfaceSourceEstimate, @@ -3220,7 +3275,7 @@ def _update(self): from ..backends import renderer if renderer.get_3d_backend() in ['pyvista', 'notebook']: if self.notebook and self._renderer.figure.display is not None: - self._renderer.figure.display.update() + self._renderer.figure.display.update_canvas() else: self._renderer.plotter.update() diff --git a/mne/viz/_brain/_notebook.py b/mne/viz/_brain/_notebook.py deleted file mode 100644 index 801ba240c07..00000000000 --- a/mne/viz/_brain/_notebook.py +++ /dev/null @@ -1,67 +0,0 @@ -# Authors: Guillaume Favelier -# -# License: Simplified BSD - -from ..backends._notebook \ - import _NotebookInteractor as _PyVistaNotebookInteractor - - -class _NotebookInteractor(_PyVistaNotebookInteractor): - def __init__(self, brain): - self.brain = brain - super().__init__(self.brain._renderer) - - def configure_controllers(self): - from ipywidgets import (IntSlider, interactive, Play, VBox, - HBox, Label, jslink) - super().configure_controllers() - # orientation - self.controllers["orientation"] = interactive( - self.set_orientation, - orientation=self.brain.orientation, - ) - # smoothing - self.sliders["smoothing"] = IntSlider( - value=self.brain._data['smoothing_steps'], - min=self.brain.default_smoothing_range[0], - max=self.brain.default_smoothing_range[1], - continuous_update=False - ) - self.controllers["smoothing"] = VBox([ - Label(value='Smoothing steps'), - interactive( - self.brain.set_data_smoothing, - n_steps=self.sliders["smoothing"] - ) - ]) - # time slider - max_time = len(self.brain._data['time']) - 1 - if max_time >= 1: - time_player = Play( - value=self.brain._data['time_idx'], - min=0, - max=max_time, - continuous_update=False - ) - time_slider = IntSlider( - min=0, - max=max_time, - ) - jslink((time_player, 'value'), (time_slider, 'value')) - time_slider.observe(self.set_time_point, 'value') - self.controllers["time"] = VBox([ - HBox([ - Label(value='Select time point'), - time_player, - ]), - time_slider, - ]) - self.sliders["time"] = time_slider - - def set_orientation(self, orientation): - row, col = self.plotter.index_to_loc( - self.plotter._active_renderer_index) - self.brain.show_view(orientation, row=row, col=col) - - def set_time_point(self, data): - self.brain.set_time_point(data['new']) diff --git a/mne/viz/_brain/mplcanvas.py b/mne/viz/_brain/mplcanvas.py index 23b9f4d7295..6870f6ada12 100644 --- a/mne/viz/_brain/mplcanvas.py +++ b/mne/viz/_brain/mplcanvas.py @@ -11,11 +11,10 @@ class MplCanvas(object): """Ultimately, this is a QWidget (as well as a FigureCanvasAgg, etc.).""" - def __init__(self, brain, width, height, dpi): - from PyQt5 import QtWidgets + def __init__(self, brain, width, height, dpi, notebook=False): from matplotlib import rc_context from matplotlib.figure import Figure - from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg + self.notebook = notebook if brain.separate_canvas: parent = None else: @@ -30,16 +29,25 @@ def __init__(self, brain, width, height, dpi): pass with context: self.fig = Figure(figsize=(width, height), dpi=dpi) - self.canvas = FigureCanvasQTAgg(self.fig) + if self.notebook: + from matplotlib.backends.backend_nbagg import (FigureCanvasNbAgg, + FigureManager) + self.canvas = FigureCanvasNbAgg(self.fig) + self.manager = FigureManager(self.canvas, 0) + else: + from PyQt5 import QtWidgets + from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg + self.canvas = FigureCanvasQTAgg(self.fig) + self.canvas.setParent(parent) + FigureCanvasQTAgg.setSizePolicy( + self.canvas, + QtWidgets.QSizePolicy.Expanding, + QtWidgets.QSizePolicy.Expanding + ) + FigureCanvasQTAgg.updateGeometry(self.canvas) + self.manager = None self.axes = self.fig.add_subplot(111) self.axes.set(xlabel='Time (sec)', ylabel='Activation (AU)') - self.canvas.setParent(parent) - FigureCanvasQTAgg.setSizePolicy( - self.canvas, - QtWidgets.QSizePolicy.Expanding, - QtWidgets.QSizePolicy.Expanding - ) - FigureCanvasQTAgg.updateGeometry(self.canvas) self.brain = brain self.time_func = brain.callbacks["time"] for event in ('button_press', 'motion_notify') + extra_events: @@ -86,7 +94,10 @@ def set_color(self, bg_color, fg_color): def show(self): """Show the canvas.""" - self.canvas.show() + if self.notebook: + self.manager.show() + else: + self.canvas.show() def close(self): """Close the canvas.""" @@ -108,6 +119,7 @@ def clear(self): self.fig.clear() self.brain = None self.canvas = None + self.manager = None on_motion_notify = on_button_press # for now they can be the same diff --git a/mne/viz/_brain/tests/test.ipynb b/mne/viz/_brain/tests/test.ipynb index 0e7dbe7ffbd..80a8bec809e 100644 --- a/mne/viz/_brain/tests/test.ipynb +++ b/mne/viz/_brain/tests/test.ipynb @@ -15,9 +15,10 @@ "trans = data_path + '/MEG/sample/sample_audvis_trunc-trans.fif'\n", "info = mne.io.read_info(raw_fname)\n", "mne.viz.set_3d_backend('notebook')\n", - "mne.viz.plot_alignment(info, trans, subject=subject, dig=True,\n", - " meg=['helmet', 'sensors'], subjects_dir=subjects_dir,\n", - " surfaces=['head-dense'])" + "fig = mne.viz.plot_alignment(info, trans, subject=subject, dig=True,\n", + " meg=['helmet', 'sensors'], subjects_dir=subjects_dir,\n", + " surfaces=['head-dense'])\n", + "assert fig.display is not None" ] }, { @@ -46,8 +47,8 @@ " hemi='split')\n", " assert isinstance(brain, brain_class)\n", " assert brain.notebook\n", - " interactor = brain._renderer.figure.display\n", - " interactor.set_time_point({'new': 0})\n", + " assert brain._renderer.figure.display is not None\n", + " brain._update()\n", " brain.close()" ] }, @@ -59,9 +60,13 @@ "source": [ "import mne\n", "mne.viz.set_3d_backend('notebook')\n", - "fig = mne.viz.create_3d_figure(size=(100, 100))\n", + "rend = mne.viz.create_3d_figure(size=(100, 100), scene=False)\n", + "fig = rend.scene()\n", "mne.viz.set_3d_title(fig, 'Notebook testing')\n", - "mne.viz.set_3d_view(fig, 200, 70, focalpoint=[0, 0, 0])" + "mne.viz.set_3d_view(fig, 200, 70, focalpoint=[0, 0, 0])\n", + "assert fig.display is None\n", + "rend.show()\n", + "assert fig.display is not None" ] } ], diff --git a/mne/viz/backends/_notebook.py b/mne/viz/backends/_notebook.py index ec211f229c7..761f0b8a60f 100644 --- a/mne/viz/backends/_notebook.py +++ b/mne/viz/backends/_notebook.py @@ -2,8 +2,6 @@ # # License: Simplified BSD -import matplotlib.pyplot as plt -from contextlib import contextmanager from ...fixes import nullcontext from ._pyvista import _Renderer as _PyVistaRenderer from ._pyvista import \ @@ -12,157 +10,16 @@ class _Renderer(_PyVistaRenderer): def __init__(self, *args, **kwargs): - from IPython import get_ipython - ipython = get_ipython() - ipython.magic('matplotlib widget') kwargs["notebook"] = True super().__init__(*args, **kwargs) def show(self): - self.figure.display = _NotebookInteractor(self) + from IPython.display import display + self.figure.display = self.plotter.show(use_ipyvtk=True, + return_viewer=True) + self.figure.display.layout.width = None # unlock the fixed layout + display(self.figure.display) return self.scene() -class _NotebookInteractor(object): - def __init__(self, renderer): - from IPython import display - from ipywidgets import HBox, VBox - self.dpi = 90 - self.sliders = dict() - self.controllers = dict() - self.renderer = renderer - self.plotter = self.renderer.plotter - with self.disabled_interactivity(): - self.fig, self.dh = self.screenshot() - self.configure_controllers() - controllers = VBox(list(self.controllers.values())) - layout = HBox([self.fig.canvas, controllers]) - display.display(layout) - - @contextmanager - def disabled_interactivity(self): - state = plt.isinteractive() - plt.ioff() - try: - yield - finally: - if state: - plt.ion() - else: - plt.ioff() - - def screenshot(self): - width, height = self.renderer.figure.store['window_size'] - - fig = plt.figure() - fig.figsize = (width / self.dpi, height / self.dpi) - fig.dpi = self.dpi - fig.canvas.toolbar_visible = False - fig.canvas.header_visible = False - fig.canvas.resizable = False - fig.canvas.callbacks.callbacks.clear() - ax = plt.Axes(fig, [0., 0., 1., 1.]) - ax.set_axis_off() - fig.add_axes(ax) - - dh = ax.imshow(self.plotter.screenshot()) - return fig, dh - - def update(self): - self.plotter.render() - self.dh.set_data(self.plotter.screenshot()) - self.fig.canvas.draw() - - def configure_controllers(self): - from ipywidgets import (interactive, Label, VBox, FloatSlider, - IntSlider, Checkbox) - # continuous update - self.continuous_update_button = Checkbox( - value=False, - description='Continuous update', - disabled=False, - indent=False, - ) - self.controllers["continuous_update"] = interactive( - self.set_continuous_update, - value=self.continuous_update_button - ) - # subplot - number_of_plots = len(self.plotter.renderers) - if number_of_plots > 1: - self.sliders["subplot"] = IntSlider( - value=number_of_plots - 1, - min=0, - max=number_of_plots - 1, - step=1, - continuous_update=False - ) - self.controllers["subplot"] = VBox([ - Label(value='Select the subplot'), - interactive( - self.set_subplot, - index=self.sliders["subplot"], - ) - ]) - # azimuth - default_azimuth = self.plotter.renderer._azimuth - self.sliders["azimuth"] = FloatSlider( - value=default_azimuth, - min=-180., - max=180., - step=10., - continuous_update=False - ) - # elevation - default_elevation = self.plotter.renderer._elevation - self.sliders["elevation"] = FloatSlider( - value=default_elevation, - min=-180., - max=180., - step=10., - continuous_update=False - ) - # distance - eps = 1e-5 - default_distance = self.plotter.renderer._distance - self.sliders["distance"] = FloatSlider( - value=default_distance, - min=eps, - max=2. * default_distance - eps, - step=default_distance / 10., - continuous_update=False - ) - # camera - self.controllers["camera"] = VBox([ - Label(value='Camera settings'), - interactive( - self.set_camera, - azimuth=self.sliders["azimuth"], - elevation=self.sliders["elevation"], - distance=self.sliders["distance"], - ) - ]) - - def set_camera(self, azimuth, elevation, distance): - focalpoint = self.plotter.camera.GetFocalPoint() - self.renderer.set_camera(azimuth, elevation, - distance, focalpoint) - self.update() - - def set_subplot(self, index): - row, col = self.plotter.index_to_loc(index) - self.renderer.subplot(row, col) - figure = self.renderer.figure - default_azimuth = figure.plotter.renderer._azimuth - default_elevation = figure.plotter.renderer._elevation - default_distance = figure.plotter.renderer._distance - self.sliders["azimuth"].value = default_azimuth - self.sliders["elevation"].value = default_elevation - self.sliders["distance"].value = default_distance - - def set_continuous_update(self, value): - for slider in self.sliders.values(): - slider.continuous_update = value - - _testing_context = nullcontext diff --git a/server_environment.yml b/server_environment.yml index 8530c5dd767..86c3fda05b6 100644 --- a/server_environment.yml +++ b/server_environment.yml @@ -23,4 +23,5 @@ dependencies: - jupyter - ipympl - ipywidgets + - ipyvtk_simple - jupyter_client!=6.1.5