diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 2f5ae6dd792..b6b731743ae 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -475,12 +475,10 @@ def __init__(self, subject_id, hemi, surf, title=None, shape=shape, fig=figure) - self.plotter = self._renderer.plotter - self.viewer = self._renderer.viewer if self.notebook: self.window = None else: - self.window = self.plotter._window + self.window = self._renderer.figure.plotter._window self.window.signal_close.connect(self._clean) self._setup_canonical_rotation() @@ -514,7 +512,7 @@ def __init__(self, subject_id, hemi, surf, title=None, mesh._polydata._hemi = h else: actor = self._layered_meshes[h]._actor - self._renderer.plotter.add_actor(actor) + self._renderer.figure.viewer.add_actor(actor) if self.silhouette: mesh = self._layered_meshes[h] self._renderer._silhouette( @@ -664,6 +662,7 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): del show_traces self._configure_time_label() + # XXX:WIP self._configure_scalar_bar() self._configure_shortcuts() self._configure_picking() @@ -696,7 +695,7 @@ def _clean(self): self._layered_meshes[hemi]._clean() self._clear_callbacks() self._clear_widgets() - self.viewer._key_press_event_callbacks.clear() + self._renderer.figure.viewer._key_press_event_callbacks.clear() if getattr(self, 'mpl_canvas', None) is not None: self.mpl_canvas.clear() if getattr(self, 'act_data_smooth', None) is not None: @@ -707,14 +706,14 @@ def _clean(self): renderer.RemoveAllLights() # app_window cannot be set to None because it is used in __del__ for key in ('lighting', 'interactor', '_RenderWindow'): - setattr(self.plotter, key, None) + setattr(self._renderer.figure.plotter, key, None) # Qt LeaveEvent requires _Iren so we use _FakeIren instead of None # to resolve the ref to vtkGenericRenderWindowInteractor - self.plotter._Iren = _FakeIren() - if getattr(self.plotter, 'scalar_bar', None) is not None: - self.plotter.scalar_bar = None - if getattr(self.plotter, 'picker', None) is not None: - self.plotter.picker = None + self._renderer.figure.plotter._Iren = _FakeIren() + if getattr(self._renderer.figure.plotter, 'scalar_bar', None) is not None: + self._renderer.figure.plotter.scalar_bar = None + if getattr(self._renderer.figure.plotter, 'picker', None) is not None: + self._renderer.figure.plotter.picker = None # XXX end PyVista for key in ('plotter', 'main_menu', 'window', 'tool_bar', 'status_bar', 'interactor', 'mpl_canvas', 'time_actor', @@ -885,7 +884,7 @@ def _configure_time_label(self): def _configure_scalar_bar(self): if self._colorbar_added: - scalar_bar = self.viewer.scalar_bar + scalar_bar = self._renderer.figure.viewer.scalar_bar scalar_bar.SetOrientationToVertical() scalar_bar.SetHeight(0.6) scalar_bar.SetWidth(0.05) @@ -998,8 +997,8 @@ def _configure_dock_orientation_widget(self, name): if len(rends) > 1: def select_renderer(idx): idx = int(idx) - loc = self.viewer.index_to_loc(idx) - self.plotter.subplot(*loc) + loc = self._renderer._index_to_loc(idx) + self._renderer.subplot(*loc) self.callbacks["renderer"] = SmartCallBack( callback=select_renderer, @@ -1025,7 +1024,7 @@ def select_renderer(idx): orientation_data = [None] * len(rends) for hemi in hemis_ref: for ri, ci, view in self._iter_views(hemi): - idx = self.viewer.loc_to_index((ri, ci)) + idx = self._renderer._loc_to_index((ri, ci)) if view == 'flat': _data = None else: @@ -1309,7 +1308,7 @@ def _configure_vertex_time_course(self): # simulate a picked renderer if self._hemi in ('both', 'rh') or hemi == 'vol': idx = 0 - self.picked_renderer = self.plotter.renderers[idx] + self.picked_renderer = self._renderer.figure.plotter.renderers[idx] # initialize the default point if self._data['initial_time'] is not None: @@ -1372,7 +1371,7 @@ def _save_image(fname, img): except ImportError: from pyvistaqt.plotting import FileDialog FileDialog( - self.plotter.app_window, + self.window, callback=partial(_save_image, img=img) ) @@ -1454,21 +1453,21 @@ def _rotate_elevation(self, value): def _configure_shortcuts(self): # First, we remove the default bindings: - self.viewer._key_press_event_callbacks.clear() + self._renderer.figure.viewer._key_press_event_callbacks.clear() # Then, we add our own: - self.viewer.add_key_event("i", self.toggle_interface) - self.viewer.add_key_event("s", self.apply_auto_scaling) - self.viewer.add_key_event("r", self.restore_user_scaling) - self.viewer.add_key_event("c", self.clear_glyphs) - self.viewer.add_key_event("n", partial(self._shift_time, + self._renderer.figure.viewer.add_key_event("i", self.toggle_interface) + self._renderer.figure.viewer.add_key_event("s", self.apply_auto_scaling) + self._renderer.figure.viewer.add_key_event("r", self.restore_user_scaling) + self._renderer.figure.viewer.add_key_event("c", self.clear_glyphs) + self._renderer.figure.viewer.add_key_event("n", partial(self._shift_time, op=lambda x, y: x + y)) - self.viewer.add_key_event("b", partial(self._shift_time, + self._renderer.figure.viewer.add_key_event("b", partial(self._shift_time, op=lambda x, y: x - y)) for key, func, sign in (("Left", self._rotate_azimuth, 1), ("Right", self._rotate_azimuth, -1), ("Up", self._rotate_elevation, 1), ("Down", self._rotate_elevation, -1)): - self.viewer.add_key_event(key, partial(func, sign * _ARROW_MOVE)) + self._renderer.figure.viewer.add_key_event(key, partial(func, sign * _ARROW_MOVE)) # XXX:WIP # def _configure_menu(self): @@ -1503,9 +1502,9 @@ def _on_button_release(self, vtk_picker, event): if self._mouse_no_mvt > 0: x, y = vtk_picker.GetEventPosition() # programmatically detect the picked renderer - self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y) + self.picked_renderer = self._renderer.figure.plotter.iren.FindPokedRenderer(x, y) # trigger the pick - self.plotter.picker.Pick(x, y, 0, self.picked_renderer) + self._renderer.figure.plotter.picker.Pick(x, y, 0, self.picked_renderer) self._mouse_no_mvt = 0 def _on_pick(self, vtk_picker, event): @@ -1647,13 +1646,13 @@ def _add_vertex_glyph(self, hemi, mesh, vertex_id): del mesh # from the picked renderer to the subplot coords - rindex = self.plotter.renderers.index(self.picked_renderer) - row, col = self.plotter.index_to_loc(rindex) + rindex = self._renderer.figure.plotter.renderers.index(self.picked_renderer) + row, col = self._renderer._index_to_loc(rindex) actors = list() spheres = list() for ri, ci, _ in self._iter_views(hemi): - self.plotter.subplot(ri, ci) + self._renderer.subplot(ri, ci) # Using _sphere() instead of renderer.sphere() for 2 reasons: # 1) renderer.sphere() fails on Windows in a scenario where a lot # of picking requests are done in a short span of time (could be @@ -1701,7 +1700,7 @@ def _remove_vertex_glyph(self, mesh, render=True): self.color_cycle.restore(color) for sphere in spheres: # remove all actors - self.plotter.remove_actor(sphere._actors, render=render) + self._renderer.figure.plotter.remove_actor(sphere._actors, render=render) sphere._actors = None self._spheres.pop(self._spheres.index(sphere)) self.pick_table.pop(vertex_id) @@ -2238,16 +2237,16 @@ def _add_volume_data(self, hemi, src, volume_options): self._data[hemi]['grid_shape'] = dimensions self._data[hemi]['grid_volume_pos'] = volume_pos self._data[hemi]['grid_volume_neg'] = volume_neg - actor_pos, _ = self._renderer.plotter.add_actor( + actor_pos, _ = self._renderer.figure.viewer.add_actor( volume_pos, reset_camera=False, name=None, culling=False) if volume_neg is not None: - actor_neg, _ = self._renderer.plotter.add_actor( + actor_neg, _ = self._renderer.figure.viewer.add_actor( volume_neg, reset_camera=False, name=None, culling=False) else: actor_neg = None grid_mesh = self._data[hemi]['grid_mesh'] if grid_mesh is not None: - _, prop = self._renderer.plotter.add_actor( + _, prop = self._renderer.figure.viewer.add_actor( grid_mesh, reset_camera=False, name=None, culling=False, pickable=False) prop.SetColor(*self._brain_color[:3]) @@ -3047,7 +3046,7 @@ def _update_glyphs(self, hemi, vectors): prop = glyph_actor.GetProperty() prop.SetLineWidth(2.) prop.SetOpacity(vector_alpha) - self._renderer.plotter.add_actor(glyph_actor) + self._renderer.figure.viewer.add_actor(glyph_actor) hemi_data['glyph_actor'].append(glyph_actor) else: glyph_actor = hemi_data['glyph_actor'][count] diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index dce8df4b13a..90346e7b3a7 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -79,6 +79,11 @@ def build(self): plotter_class = Plotter else: plotter_class = MultiPlotter + self.store["menu_bar"] = False + self.store["toolbar"] = False + self.store["nrows"] = self.store["shape"][0] + self.store["ncols"] = self.store["shape"][1] + self.store.pop('shape', None) self.store.pop('show', None) self.store.pop('title', None) self.store.pop('auto_update', None) @@ -105,7 +110,6 @@ def build(self): _process_events(self.plotter) _process_events(self.plotter) - return self.plotter, self.viewer def is_active(self): if self.viewer is None: @@ -199,14 +203,10 @@ def __init__(self, fig=None, size=(600, 600), bgcolor='black', # smooth_shading=True fails on MacOS CIs self.figure.smooth_shading = False with _disabled_depth_peeling(): - self.plotter, self.viewer = self.figure.build() - self.viewer.hide_axes() - if hasattr(self.viewer, "default_camera_tool_bar"): - self.viewer.default_camera_tool_bar.close() - if hasattr(self.viewer, "saved_cameras_tool_bar"): - self.viewer.saved_cameras_tool_bar.close() + self.figure.build() + self.figure.viewer.hide_axes() if self.antialias: - _enable_aa(self.figure, self.viewer) + _enable_aa(self.figure, self.figure.viewer) self.update_lighting() @@ -216,11 +216,11 @@ def _get_screenshot_filename(self): return "MNE" + dt_string + ".png" def _get_all_renderers(self): - return _get_all_renderers(self.plotter) + return _get_all_renderers(self.figure.plotter) def _update(self): if self.figure.notebook: - self.plotter.update() + self.figure.plotter.update() # XXX:WIP # @contextmanager @@ -252,11 +252,27 @@ def _update(self): # _process_events(self.plotter) # _process_events(self.plotter) + def _index_to_loc(self, idx): + if self.figure.notebook: + return self.figure.plotter.index_to_loc(idx) + else: + _ncols = self.figure.store["ncols"] + row = idx // _ncols + col = idx % _ncols + return (row, col) + + def _loc_to_index(self, loc): + if self.figure.notebook: + return self.figure.plotter.loc_to_index(loc) + else: + _ncols = self.figure.store["ncols"] + return loc[0] * _ncols + loc[1] + def _subplot(self, x, y): if self.figure.notebook: - self.plotter.subplot(x, y) + self.figure.plotter.subplot(x, y) else: - self.viewer = self.plotter[x, y] + self.figure.viewer = self.figure.plotter[x, y] def subplot(self, x, y): x = np.max([0, np.min([x, self.shape[0] - 1])]) @@ -265,7 +281,7 @@ def subplot(self, x, y): warnings.filterwarnings("ignore", category=FutureWarning) self._subplot(x, y) if self.antialias: - _enable_aa(self.figure, self.viewer) + _enable_aa(self.figure, self.figure.viewer) def scene(self): return self.figure @@ -338,7 +354,7 @@ def polydata(self, mesh, color=None, opacity=1.0, normals=None, rgba = kwargs["rgba"] kwargs.pop('rgba') actor = _add_mesh( - plotter=self.viewer, + plotter=self.figure.viewer, mesh=mesh, color=color, scalars=scalars, rgba=rgba, opacity=opacity, cmap=colormap, backface_culling=backface_culling, @@ -403,7 +419,7 @@ def contour(self, surface, scalars, contours, width=1.0, opacity=1.0, contour = contour.tube(radius=width, n_sides=self.tube_n_sides) line_width = 1.0 actor = _add_mesh( - plotter=self.viewer, + plotter=self.figure.viewer, mesh=contour, show_scalar_bar=False, line_width=line_width, @@ -459,7 +475,7 @@ def sphere(self, center, color, scale, opacity=1.0, glyph = mesh.glyph(orient=False, scale=False, factor=factor, geom=geom) actor = _add_mesh( - plotter=self.viewer, + plotter=self.figure.viewer, mesh=glyph, color=color, opacity=opacity, backface_culling=backface_culling, smooth_shading=self.figure.smooth_shading @@ -482,7 +498,7 @@ def tube(self, origin, destination, radius=0.001, color='white', scalars = None tube = line.tube(radius, n_sides=self.tube_n_sides) _add_mesh( - plotter=self.viewer, + plotter=self.figure.viewer, mesh=tube, scalars=scalars, flip_scalars=reverse_lut, @@ -571,7 +587,7 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, mesh = grid.glyph(orient='vec', scale=scale, factor=factor, geom=geom) _add_mesh( - plotter=self.viewer, + plotter=self.figure.viewer, mesh=mesh, color=color, opacity=opacity, @@ -584,7 +600,7 @@ def text2d(self, x_window, y_window, text, size=14, color='white', position = (x_window, y_window) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) - actor = self.viewer.add_text( + actor = self.figure.viewer.add_text( text, position=position, font_size=size, font=self.font_family, color=color, viewport=True) if isinstance(justification, str): @@ -612,9 +628,9 @@ def text3d(self, x, y, z, text, scale, color='white'): name=text, shape_opacity=0, ) - if 'always_visible' in _get_args(self.viewer.add_point_labels): + if 'always_visible' in _get_args(self.figure.viewer.add_point_labels): kwargs['always_visible'] = True - self.viewer.add_point_labels(**kwargs) + self.figure.viewer.add_point_labels(**kwargs) def scalarbar(self, source, color="white", title=None, n_labels=4, bgcolor=None, **extra_kwargs): @@ -626,10 +642,10 @@ def scalarbar(self, source, color="white", title=None, n_labels=4, label_font_size=22, font_family=self.font_family, background_color=bgcolor) kwargs.update(extra_kwargs) - self.viewer.add_scalar_bar(**kwargs) + self.figure.viewer.add_scalar_bar(**kwargs) def show(self): - self.figure.display = self.plotter.show() + self.figure.display = self.figure.plotter.show() # XXX:WIP # if hasattr(self.plotter, "app_window"): # with _qt_disable_paint(self.plotter): @@ -649,17 +665,17 @@ def set_camera(self, azimuth=None, elevation=None, distance=None, reset_camera=reset_camera, rigid=rigid) def reset_camera(self): - self.viewer.reset_camera() + self.figure.viewer.reset_camera() def screenshot(self, mode='rgb', filename=None): return _take_3d_screenshot(figure=self.figure, mode=mode, filename=filename) def project(self, xyz, ch_names): - xy = _3d_to_2d(self.viewer, xyz) + xy = _3d_to_2d(self.figure.viewer, xyz) xy = dict(zip(ch_names, xy)) # pts = self.fig.children[-1] - pts = self.viewer.renderer.GetActors().GetLastItem() + pts = self.figure.viewer.renderer.GetActors().GetLastItem() return _Projection(xy=xy, pts=pts) @@ -670,18 +686,18 @@ def enable_depth_peeling(self): def remove_mesh(self, mesh_data): actor, _ = mesh_data - self.viewer.remove_actor(actor) + self.figure.viewer.remove_actor(actor) @contextmanager def _disabled_interaction(self): - if not self.viewer.renderer.GetInteractive(): + if not self.figure.viewer.renderer.GetInteractive(): yield else: - self.viewer.disable() + self.figure.viewer.disable() try: yield finally: - self.viewer.enable() + self.figure.viewer.enable() def _actor(self, mapper=None): actor = vtk.vtkActor() @@ -690,31 +706,31 @@ def _actor(self, mapper=None): return actor def _process_events(self): - _process_events(self.plotter) + _process_events(self.figure.plotter) def _update_picking_callback(self, on_mouse_move, on_button_press, on_button_release, on_pick): - self.viewer.iren.AddObserver( + self.figure.viewer.iren.AddObserver( vtk.vtkCommand.RenderEvent, on_mouse_move ) - self.viewer.iren.AddObserver( + self.figure.viewer.iren.AddObserver( vtk.vtkCommand.LeftButtonPressEvent, on_button_press ) - self.viewer.iren.AddObserver( + self.figure.viewer.iren.AddObserver( vtk.vtkCommand.EndInteractionEvent, on_button_release ) - self.viewer.picker = vtk.vtkCellPicker() - self.viewer.picker.AddObserver( + self.figure.viewer.picker = vtk.vtkCellPicker() + self.figure.viewer.picker.AddObserver( vtk.vtkCommand.EndPickEvent, on_pick ) - self.viewer.picker.SetVolumeOpacityIsovalue(0.) + self.figure.viewer.picker.SetVolumeOpacityIsovalue(0.) def _set_mesh_scalars(self, mesh, scalars, name): # Catch: FutureWarning: Conversion of the second argument of @@ -767,7 +783,7 @@ def _sphere(self, center, color, radius): sphere.Update() mesh = pyvista.wrap(sphere.GetOutput()) actor = _add_mesh( - plotter=self.viewer, + plotter=self.figure.viewer, mesh=mesh, color=color ) @@ -839,12 +855,12 @@ def _silhouette(self, mesh, color=None, line_width=None, alpha=None, mesh = mesh.decimate(decimate) if decimate is not None else mesh silhouette_filter = vtk.vtkPolyDataSilhouette() silhouette_filter.SetInputData(mesh) - silhouette_filter.SetCamera(self.viewer.renderer.GetActiveCamera()) + silhouette_filter.SetCamera(self.figure.viewer.renderer.GetActiveCamera()) silhouette_filter.SetEnableFeatureAngle(0) silhouette_mapper = vtk.vtkPolyDataMapper() silhouette_mapper.SetInputConnection( silhouette_filter.GetOutputPort()) - _, prop = self.viewer.add_actor( + _, prop = self.figure.viewer.add_actor( silhouette_mapper, reset_camera=False, name=None, culling=False, pickable=False) if color is not None: diff --git a/mne/viz/backends/_qt.py b/mne/viz/backends/_qt.py index 25567ad1591..af77c02ce9e 100644 --- a/mne/viz/backends/_qt.py +++ b/mne/viz/backends/_qt.py @@ -34,7 +34,7 @@ def _dock_initialize(self): self.scroll.setWidgetResizable(True) self.dock.setAllowedAreas(Qt.LeftDockWidgetArea) self.dock.setFeatures(QDockWidget.NoDockWidgetFeatures) - self.plotter._window.addDockWidget(Qt.LeftDockWidgetArea, self.dock) + self.figure.plotter._window.addDockWidget(Qt.LeftDockWidgetArea, self.dock) self.dock_layout = QVBoxLayout() widget.setLayout(self.dock_layout) @@ -226,7 +226,7 @@ def _tool_bar_load_icons(self): def _tool_bar_initialize(self, name="default"): self.actions = dict() - self.tool_bar = self.plotter._window.addToolBar(name) + self.tool_bar = self.figure.plotter._window.addToolBar(name) def _tool_bar_finalize(self): pass