From 5bf39ff3454d15404304917193ea36841acbac55 Mon Sep 17 00:00:00 2001 From: Guillaume Favelier Date: Wed, 19 May 2021 11:53:45 +0200 Subject: [PATCH 01/12] Reduce number of calls to _update --- mne/viz/_brain/_brain.py | 47 ++++++++++++++++++------------------ mne/viz/backends/_pyvista.py | 18 +++++++++++--- 2 files changed, 39 insertions(+), 26 deletions(-) diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index f9002a59b13..d76e9ad0cba 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -529,16 +529,19 @@ def __init__(self, subject_id, hemi, surf, title=None, self.interaction = interaction self._closed = False - if show: - self.show() + # update the views once the geometry is all set - for h in self._hemis: - for ri, ci, v in self._iter_views(h): - self.show_view(v, row=ri, col=ci, hemi=h) + with self._renderer._disabled_update(): + for h in self._hemis: + for ri, ci, v in self._iter_views(h): + self.show_view(v, row=ri, col=ci, hemi=h) if surf == 'flat': self._renderer.set_interaction("rubber_band_2d") + if show: + self._renderer.show() + def _setup_canonical_rotation(self): from ...coreg import fit_matched_points, _trans_from_params self._rigid = np.eye(4) @@ -652,27 +655,25 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): self.separate_canvas = False del show_traces - self._configure_time_label() - self._configure_scalar_bar() - self._configure_shortcuts() - self._configure_picking() - self._configure_tool_bar() - self._configure_dock() - self._configure_menu() - self._configure_status_bar() - self._configure_playback() - self._configure_help() + with self._renderer._disabled_update(): + self._configure_time_label() + self._configure_scalar_bar() + self._configure_shortcuts() + self._configure_picking() + self._configure_tool_bar() + self._configure_dock() + self._configure_menu() + self._configure_status_bar() + self._configure_playback() + self._configure_help() + # sizes could change, update views + for hemi in ('lh', 'rh'): + for ri, ci, v in self._iter_views(hemi): + self.show_view(view=v, row=ri, col=ci) + self._renderer._process_events() # show everything at the end - self.toggle_interface() self._renderer.show() - # sizes could change, update views - for hemi in ('lh', 'rh'): - for ri, ci, v in self._iter_views(hemi): - self.show_view(view=v, row=ri, col=ci) - self._renderer._process_events() - - self._renderer._update() # finally, show the MplCanvas if self.show_traces: self.mpl_canvas.show() diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index cf942c48a0c..52bd24d9ac7 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -26,7 +26,7 @@ ALLOWED_QUIVER_MODES, _init_qt_resources) from ...fixes import _get_args from ...transforms import apply_trans -from ...utils import copy_base_doc_to_subclass_doc, _check_option +from ...utils import copy_base_doc_to_subclass_doc, _check_option, logger with warnings.catch_warnings(): @@ -152,6 +152,7 @@ def __init__(self, fig=None, size=(600, 600), bgcolor='black', figure = _Figure(show=show, title=name, size=size, shape=shape, background_color=bgcolor, notebook=notebook, smooth_shading=smooth_shading) + self._to_update = True self.font_family = "arial" self.tube_n_sides = 20 antialias = _get_3d_option('antialias') @@ -204,8 +205,10 @@ def _hide_axes(self): renderer.hide_axes() def _update(self): - for plotter in self._all_plotters: - plotter.update() + if self._to_update: + logger.debug('Update renderer') + for plotter in self._all_plotters: + plotter.update() def _index_to_loc(self, idx): _ncols = self.figure._ncols @@ -646,6 +649,15 @@ def remove_mesh(self, mesh_data): actor, _ = mesh_data self.plotter.remove_actor(actor) + @contextmanager + def _disabled_update(self): + to_update = self._to_update + self._to_update = False + try: + yield + finally: + self._to_update = to_update + @contextmanager def _disabled_interaction(self): if not self.plotter.renderer.GetInteractive(): From cb128401fe003d736d3925a222d0debfcf23d223 Mon Sep 17 00:00:00 2001 From: Guillaume Favelier Date: Wed, 19 May 2021 17:04:10 +0200 Subject: [PATCH 02/12] Revert "Reduce number of calls to _update" This reverts commit 5bf39ff3454d15404304917193ea36841acbac55. --- mne/viz/_brain/_brain.py | 47 ++++++++++++++++++------------------ mne/viz/backends/_pyvista.py | 18 +++----------- 2 files changed, 26 insertions(+), 39 deletions(-) diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index d76e9ad0cba..f9002a59b13 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -529,19 +529,16 @@ def __init__(self, subject_id, hemi, surf, title=None, self.interaction = interaction self._closed = False - + if show: + self.show() # update the views once the geometry is all set - with self._renderer._disabled_update(): - for h in self._hemis: - for ri, ci, v in self._iter_views(h): - self.show_view(v, row=ri, col=ci, hemi=h) + for h in self._hemis: + for ri, ci, v in self._iter_views(h): + self.show_view(v, row=ri, col=ci, hemi=h) if surf == 'flat': self._renderer.set_interaction("rubber_band_2d") - if show: - self._renderer.show() - def _setup_canonical_rotation(self): from ...coreg import fit_matched_points, _trans_from_params self._rigid = np.eye(4) @@ -655,25 +652,27 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): self.separate_canvas = False del show_traces - with self._renderer._disabled_update(): - self._configure_time_label() - self._configure_scalar_bar() - self._configure_shortcuts() - self._configure_picking() - self._configure_tool_bar() - self._configure_dock() - self._configure_menu() - self._configure_status_bar() - self._configure_playback() - self._configure_help() - # sizes could change, update views - for hemi in ('lh', 'rh'): - for ri, ci, v in self._iter_views(hemi): - self.show_view(view=v, row=ri, col=ci) - self._renderer._process_events() + self._configure_time_label() + self._configure_scalar_bar() + self._configure_shortcuts() + self._configure_picking() + self._configure_tool_bar() + self._configure_dock() + self._configure_menu() + self._configure_status_bar() + self._configure_playback() + self._configure_help() # show everything at the end + self.toggle_interface() self._renderer.show() + # sizes could change, update views + for hemi in ('lh', 'rh'): + for ri, ci, v in self._iter_views(hemi): + self.show_view(view=v, row=ri, col=ci) + self._renderer._process_events() + + self._renderer._update() # finally, show the MplCanvas if self.show_traces: self.mpl_canvas.show() diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index 52bd24d9ac7..cf942c48a0c 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -26,7 +26,7 @@ ALLOWED_QUIVER_MODES, _init_qt_resources) from ...fixes import _get_args from ...transforms import apply_trans -from ...utils import copy_base_doc_to_subclass_doc, _check_option, logger +from ...utils import copy_base_doc_to_subclass_doc, _check_option with warnings.catch_warnings(): @@ -152,7 +152,6 @@ def __init__(self, fig=None, size=(600, 600), bgcolor='black', figure = _Figure(show=show, title=name, size=size, shape=shape, background_color=bgcolor, notebook=notebook, smooth_shading=smooth_shading) - self._to_update = True self.font_family = "arial" self.tube_n_sides = 20 antialias = _get_3d_option('antialias') @@ -205,10 +204,8 @@ def _hide_axes(self): renderer.hide_axes() def _update(self): - if self._to_update: - logger.debug('Update renderer') - for plotter in self._all_plotters: - plotter.update() + for plotter in self._all_plotters: + plotter.update() def _index_to_loc(self, idx): _ncols = self.figure._ncols @@ -649,15 +646,6 @@ def remove_mesh(self, mesh_data): actor, _ = mesh_data self.plotter.remove_actor(actor) - @contextmanager - def _disabled_update(self): - to_update = self._to_update - self._to_update = False - try: - yield - finally: - self._to_update = to_update - @contextmanager def _disabled_interaction(self): if not self.plotter.renderer.GetInteractive(): From 8067b94c929184c4e450bdcf282493df501784c9 Mon Sep 17 00:00:00 2001 From: Guillaume Favelier Date: Thu, 20 May 2021 12:53:12 +0200 Subject: [PATCH 03/12] Improve overlay caching --- mne/viz/_brain/_brain.py | 53 +++++++++++++++++++----------- mne/viz/_brain/tests/test_brain.py | 27 +++++++++++---- 2 files changed, 55 insertions(+), 25 deletions(-) diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index f9002a59b13..9623a0a0da4 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -105,7 +105,8 @@ def __init__(self, renderer, vertices, triangles, normals): self._actor = None self._is_mapped = False - self._cache = None + self._current_colors = None + self._cached_colors = None self._overlays = OrderedDict() self._default_scalars = np.ones(vertices.shape) @@ -142,14 +143,15 @@ def _compute_over(self, B, A): return np.clip(C, 0, 1, out=C) def _compose_overlays(self): - B = None + B = cache = None for overlay in self._overlays.values(): A = overlay.to_colors() if B is None: B = A else: - B = self._compute_over(B, A) - return B + cache = B + B = self._compute_over(cache, A) + return B, cache def add_overlay(self, scalars, colormap, rng, opacity, name): overlay = _Overlay( @@ -161,36 +163,45 @@ def add_overlay(self, scalars, colormap, rng, opacity, name): ) self._overlays[name] = overlay colors = overlay.to_colors() - - # save colors in cache - if self._cache is None: - self._cache = colors + if self._current_colors is None: + self._current_colors = colors else: - self._cache = self._compute_over(self._cache, colors) + # save previous colors to cache + self._cached_colors = self._current_colors + self._current_colors = self._compute_over( + self._cached_colors, colors) - # update the texture - self._update() + # apply the texture + self._apply() def remove_overlay(self, names): + to_update = False if not isinstance(names, list): names = [names] for name in names: if name in self._overlays: del self._overlays[name] - self.update() + to_update = True + if to_update: + self.update() - def _update(self): - if self._cache is None or self._renderer is None: + def _apply(self): + if self._current_colors is None or self._renderer is None: return self._renderer._set_mesh_scalars( mesh=self._polydata, - scalars=self._cache, + scalars=self._current_colors, name=self._default_scalars_name, ) - def update(self): - self._cache = self._compose_overlays() - self._update() + def update(self, colors=None): + if colors is not None and self._cached_colors is not None: + self._current_colors = self._compute_over( + self._cached_colors, colors) + else: + self._current_colors, self._cached_colors = \ + self._compose_overlays() + self._apply() def _clean(self): mapper = self._actor.GetMapper() @@ -213,7 +224,11 @@ def update_overlay(self, name, scalars=None, colormap=None, overlay._opacity = opacity if rng is not None: overlay._rng = rng - self.update() + # partial update: use cache if possible + if name == list(self._overlays.keys())[-1]: + self.update(colors=overlay.to_colors()) + else: # full update + self.update() @fill_doc diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 8dc280baa60..b20da49c8c0 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -115,7 +115,8 @@ def test_layered_mesh(renderer_interactive_pyvista): assert not mesh._is_mapped mesh.map() assert mesh._is_mapped - assert mesh._cache is None + assert mesh._current_colors is None + assert mesh._cached_colors is None mesh.update() assert len(mesh._overlays) == 0 mesh.add_overlay( @@ -123,13 +124,27 @@ def test_layered_mesh(renderer_interactive_pyvista): colormap=np.array([(1, 1, 1, 1), (0, 0, 0, 0)]), rng=[0, 1], opacity=None, - name='test', + name='test1', ) - assert mesh._cache is not None + assert mesh._current_colors is not None + assert mesh._cached_colors is None + assert len(mesh._overlays) == 1 + assert 'test1' in mesh._overlays + mesh.add_overlay( + scalars=np.array([1, 0, 0, 1]), + colormap=np.array([(1, 1, 1, 1), (0, 0, 0, 0)]), + rng=[0, 1], + opacity=None, + name='test2', + ) + assert mesh._current_colors is not None + assert mesh._cached_colors is not None + assert len(mesh._overlays) == 2 + assert 'test2' in mesh._overlays + mesh.remove_overlay('test2') + assert 'test2' not in mesh._overlays + mesh.update() assert len(mesh._overlays) == 1 - assert 'test' in mesh._overlays - mesh.remove_overlay('test') - assert len(mesh._overlays) == 0 mesh._clean() From c2f69cebf9c53efea556e1821004cdd5d6b3b846 Mon Sep 17 00:00:00 2001 From: Guillaume Favelier Date: Fri, 21 May 2021 16:46:07 +0200 Subject: [PATCH 04/12] Disable plot updates at init --- mne/viz/_brain/_brain.py | 18 +++++++++++------- mne/viz/backends/_abstract.py | 10 ++++++---- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 9623a0a0da4..90a41b3b965 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -1192,7 +1192,7 @@ def _configure_vertex_time_course(self): alpha=0.5, ls=':') # now plot the time line - self.plot_time_line() + self.plot_time_line(update=False) # then the picked points for idx, hemi in enumerate(['lh', 'rh', 'vol']): @@ -1221,7 +1221,7 @@ def _configure_vertex_time_course(self): else: mesh = self._layered_meshes[hemi]._polydata vertex_id = vertices[ind[0]] - self._add_vertex_glyph(hemi, mesh, vertex_id) + self._add_vertex_glyph(hemi, mesh, vertex_id, update=False) def _configure_picking(self): # get data for each hemi @@ -1491,7 +1491,7 @@ def _remove_label_glyph(self, hemi, label_id): self._layered_meshes[hemi].remove_overlay(label.name) self.picked_patches[hemi].remove(label_id) - def _add_vertex_glyph(self, hemi, mesh, vertex_id): + def _add_vertex_glyph(self, hemi, mesh, vertex_id, update=True): if vertex_id in self.picked_points[hemi]: return @@ -1499,7 +1499,7 @@ def _add_vertex_glyph(self, hemi, mesh, vertex_id): if self.act_data_smooth[hemi][0] is None: return color = next(self.color_cycle) - line = self.plot_time_course(hemi, vertex_id, color) + line = self.plot_time_course(hemi, vertex_id, color, update=update) if hemi == 'vol': ijk = np.unravel_index( vertex_id, np.array(mesh.GetDimensions()) - 1, order='F') @@ -1599,7 +1599,7 @@ def clear_glyphs(self): self.rms = None self._renderer._update() - def plot_time_course(self, hemi, vertex_id, color): + def plot_time_course(self, hemi, vertex_id, color, update=True): """Plot the vertex time course. Parameters @@ -1610,6 +1610,8 @@ def plot_time_course(self, hemi, vertex_id, color): The vertex identifier in the mesh. color : matplotlib color The color of the time course. + update : bool + Force an update of the plot. Defaults to True. Returns ------- @@ -1658,10 +1660,11 @@ def plot_time_course(self, hemi, vertex_id, color): lw=1., color=color, zorder=4, + update=update, ) return line - def plot_time_line(self): + def plot_time_line(self, update=True): """Add the time line to the MPL widget.""" if self.mpl_canvas is None: return @@ -1676,7 +1679,8 @@ def plot_time_line(self): lw=1, ) self.time_line.set_xdata(current_time) - self.mpl_canvas.update_plot() + if update: + self.mpl_canvas.update_plot() def _configure_help(self): pairs = [ diff --git a/mne/viz/backends/_abstract.py b/mne/viz/backends/_abstract.py index a101f50d62f..11f4c79ecf1 100644 --- a/mne/viz/backends/_abstract.py +++ b/mne/viz/backends/_abstract.py @@ -655,17 +655,19 @@ def _connect(self): self.canvas.mpl_connect( event + '_event', getattr(self, 'on_' + event)) - def plot(self, x, y, label, **kwargs): + def plot(self, x, y, label, update=True, **kwargs): """Plot a curve.""" line, = self.axes.plot( x, y, label=label, **kwargs) - self.update_plot() + if update: + self.update_plot() return line - def plot_time_line(self, x, label, **kwargs): + def plot_time_line(self, x, label, update=True, **kwargs): """Plot the vertical line.""" line = self.axes.axvline(x, label=label, **kwargs) - self.update_plot() + if update: + self.update_plot() return line def update_plot(self): From 0a31936b1305f730739a6882ddaa4f06ae33e3d9 Mon Sep 17 00:00:00 2001 From: Guillaume Favelier Date: Fri, 21 May 2021 17:03:01 +0200 Subject: [PATCH 05/12] Disable more --- mne/viz/_brain/_brain.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 90a41b3b965..36b9a02ac34 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -1665,7 +1665,13 @@ def plot_time_course(self, hemi, vertex_id, color, update=True): return line def plot_time_line(self, update=True): - """Add the time line to the MPL widget.""" + """Add the time line to the MPL widget. + + Parameters + ---------- + update : bool + Force an update of the plot. Defaults to True. + """ if self.mpl_canvas is None: return if isinstance(self.show_traces, bool) and self.show_traces: @@ -1677,6 +1683,7 @@ def plot_time_line(self, update=True): label='time', color=self._fg_color, lw=1, + update=update, ) self.time_line.set_xdata(current_time) if update: @@ -2413,7 +2420,7 @@ def _configure_label_time_course(self): self.add_annotation(self.annot, color="w", alpha=0.75) # now plot the time line - self.plot_time_line() + self.plot_time_line(update=False) self.mpl_canvas.update_plot() for hemi in self._hemis: From 412d45d66332fff73082d48251569261347d0a3e Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 21 May 2021 14:33:09 -0400 Subject: [PATCH 06/12] ENH: Speed up vertex_to_mni --- mne/source_space.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/mne/source_space.py b/mne/source_space.py index 0e1186e798b..df079af1e62 100644 --- a/mne/source_space.py +++ b/mne/source_space.py @@ -1358,21 +1358,21 @@ def read_talxfm(subject, subjects_dir=None, verbose=None): def _read_mri_info(path, units='m', return_img=False): - if has_nibabel(): - import nibabel - mgz = nibabel.load(path) - hdr = mgz.header - n_orig = hdr.get_vox2ras() - t_orig = hdr.get_vox2ras_tkr() - dims = hdr.get_data_shape() - zooms = hdr.get_zooms()[:3] - else: - mgz = None - hdr = _get_mgz_header(path) - n_orig = hdr['vox2ras'] - t_orig = hdr['vox2ras_tkr'] - dims = hdr['dims'] - zooms = hdr['zooms'] + # This is equivalent but 100x slower, so only use nibabel if we need to + # (later): + # + # import nibabel + # hdr = nibabel.load(path).header + # hdr = mgz.header + # n_orig = hdr.get_vox2ras() + # t_orig = hdr.get_vox2ras_tkr() + # dims = hdr.get_data_shape() + # zooms = hdr.get_zooms()[:3] + hdr = _get_mgz_header(path) + n_orig = hdr['vox2ras'] + t_orig = hdr['vox2ras_tkr'] + dims = hdr['dims'] + zooms = hdr['zooms'] # extract the MRI_VOXEL to RAS (non-zero origin) transform vox_ras_t = Transform('mri_voxel', 'ras', n_orig) @@ -1395,7 +1395,8 @@ def _read_mri_info(path, units='m', return_img=False): out = (vox_ras_t, vox_mri_t, mri_ras_t, dims, zooms) if return_img: - out += (mgz,) + nibabel = _import_nibabel() + out += (nibabel.load(path),) return out From ae61bc2b00f6f7726976ce2a53cacb06f2ac8b1b Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 21 May 2021 14:43:05 -0400 Subject: [PATCH 07/12] FIX: Flake --- mne/source_space.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/source_space.py b/mne/source_space.py index df079af1e62..733cd320519 100644 --- a/mne/source_space.py +++ b/mne/source_space.py @@ -33,7 +33,7 @@ _CheckInside) from .utils import (get_subjects_dir, check_fname, logger, verbose, fill_doc, _ensure_int, check_version, _get_call_line, warn, - _check_fname, _check_path_like, has_nibabel, _check_sphere, + _check_fname, _check_path_like, _check_sphere, _validate_type, _check_option, _is_numeric, _pl, _suggest, object_size, sizeof_fmt) from .parallel import parallel_func, check_n_jobs From 9416de284a975b1807cf26a39ae42f341deba703 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 21 May 2021 15:32:29 -0400 Subject: [PATCH 08/12] FIX: Fix path --- mne/source_space.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mne/source_space.py b/mne/source_space.py index 733cd320519..d2cb29ba5c1 100644 --- a/mne/source_space.py +++ b/mne/source_space.py @@ -2251,6 +2251,8 @@ def _vol_vertex(width, height, jj, kk, pp): def _get_mgz_header(fname): """Adapted from nibabel to quickly extract header info.""" + fname = _check_fname(fname, overwrite='read', must_exist=True, + name='MRI image') if not fname.endswith('.mgz'): raise IOError('Filename must end with .mgz') header_dtd = [('version', '>i4'), ('dims', '>i4', (4,)), From 04893caa0d6a042647b4b8daf4087caa485b1a35 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 21 May 2021 15:43:25 -0400 Subject: [PATCH 09/12] ENH: Cache mesh_tris --- mne/morph.py | 2 +- mne/surface.py | 10 ++++++++-- mne/utils/__init__.py | 2 +- mne/utils/numerics.py | 12 ++++++++++++ mne/viz/_brain/_brain.py | 4 +--- 5 files changed, 23 insertions(+), 7 deletions(-) diff --git a/mne/morph.py b/mne/morph.py index a1e9f031c95..80a373cdafe 100644 --- a/mne/morph.py +++ b/mne/morph.py @@ -1142,7 +1142,7 @@ def _hemi_morph(tris, vertices_to, vertices_from, smooth, maps, warn): e = mesh_edges(tris) e.data[e.data == 2] = 1 n_vertices = e.shape[0] - e = e + sparse.eye(n_vertices) + e += sparse.eye(n_vertices, format='csr') if isinstance(smooth, str): _check_option('smooth', smooth, ('nearest',), extra=' when used as a string.') diff --git a/mne/surface.py b/mne/surface.py index 9dda2cf2080..58ea6302c65 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -10,7 +10,7 @@ from copy import deepcopy from distutils.version import LooseVersion -from functools import partial +from functools import partial, lru_cache from glob import glob from os import path as op from struct import pack @@ -28,7 +28,7 @@ _get_trans, apply_trans, Transform) from .utils import (logger, verbose, get_subjects_dir, warn, _check_fname, _check_option, _ensure_int, _TempDir, run_subprocess, - _check_freesurfer_home) + _check_freesurfer_home, _hashable_ndarray) ############################################################################### @@ -1451,6 +1451,12 @@ def mesh_edges(tris): edges : sparse matrix The adjacency matrix. """ + tris = _hashable_ndarray(tris) + return _mesh_edges(tris=tris) + + +@lru_cache(maxsize=10) +def _mesh_edges(tris=None): from scipy.sparse import coo_matrix if np.max(tris) > len(np.unique(tris)): raise ValueError( diff --git a/mne/utils/__init__.py b/mne/utils/__init__.py index bb41ad9fdcf..628cdee714f 100644 --- a/mne/utils/__init__.py +++ b/mne/utils/__init__.py @@ -60,7 +60,7 @@ _mask_to_onsets_offsets, _array_equal_nan, _julian_to_cal, _cal_to_julian, _dt_to_julian, _julian_to_dt, _dt_to_stamp, _stamp_to_dt, - _check_dt, _ReuseCycle, _arange_div) + _check_dt, _ReuseCycle, _arange_div, _hashable_ndarray) from .mixin import (SizeMixin, GetEpochsMixin, _prepare_read_metadata, _prepare_write_metadata, _FakeNoPandas, ShiftTimeMixin) from .linalg import (_svd_lwork, _repeated_svd, _sym_mat_pow, sqrtm_sym, eigh, diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index 7c24daf44be..7f181ffcce5 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -605,6 +605,18 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): return grand_average +class _HashableNdarray(np.ndarray): + def __hash__(self): + return object_hash(self) + + def __eq__(self, other): + return NotImplementedError # defer to hash + + +def _hashable_ndarray(x): + return x.view(_HashableNdarray) + + def object_hash(x, h=None): """Hash a reasonable python object. diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 36b9a02ac34..ae3c1212200 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -2762,7 +2762,6 @@ def set_data_smoothing(self, n_steps): n_steps : int Number of smoothing steps. """ - from scipy import sparse from ...morph import _hemi_morph for hemi in ['lh', 'rh']: hemi_data = self._data.get(hemi) @@ -2776,12 +2775,11 @@ def set_data_smoothing(self, n_steps): 'parameter must not be None' % (len(hemi_data), self.geo[hemi].x.shape[0])) morph_n_steps = 'nearest' if n_steps == -1 else n_steps - maps = sparse.eye(len(self.geo[hemi].coords), format='csr') with use_log_level(False): smooth_mat = _hemi_morph( self.geo[hemi].orig_faces, np.arange(len(self.geo[hemi].coords)), - vertices, morph_n_steps, maps, warn=False) + vertices, morph_n_steps, maps=None, warn=False) self._data[hemi]['smooth_mat'] = smooth_mat self.set_time_point(self._data['time_idx']) self._data['smoothing_steps'] = n_steps From b2d187d9931895859087fda13df50cef963a080f Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 21 May 2021 16:21:57 -0400 Subject: [PATCH 10/12] ENH: Defer updates and avoid reset --- mne/surface.py | 39 +++++++++++++++------------- mne/viz/_brain/_brain.py | 8 +++--- mne/viz/_brain/surface.py | 3 ++- mne/viz/backends/_pysurfer_mayavi.py | 9 ++++--- mne/viz/backends/_pyvista.py | 12 +++++---- 5 files changed, 40 insertions(+), 31 deletions(-) diff --git a/mne/surface.py b/mne/surface.py index 58ea6302c65..7609c4009d4 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -343,7 +343,7 @@ def _normal_orth(nn): @verbose def complete_surface_info(surf, do_neighbor_vert=False, copy=True, - verbose=None): + do_neighbor_tri=True, *, verbose=None): """Complete surface information. Parameters @@ -351,9 +351,11 @@ def complete_surface_info(surf, do_neighbor_vert=False, copy=True, surf : dict The surface. do_neighbor_vert : bool - If True, add neighbor vertex information. + If True (default False), add neighbor vertex information. copy : bool If True (default), make a copy. If False, operate in-place. + do_neighbor_tri : bool + If True (default), compute triangle neighbors. %(verbose)s Returns @@ -383,27 +385,28 @@ def complete_surface_info(surf, do_neighbor_vert=False, copy=True, # Find neighboring triangles, accumulate vertex normals, normalize logger.info(' Triangle neighbors and vertex normals...') - surf['neighbor_tri'] = _triangle_neighbors(surf['tris'], surf['np']) surf['nn'] = _accumulate_normals(surf['tris'].astype(int), surf['tri_nn'], surf['np']) _normalize_vectors(surf['nn']) # Check for topological defects - zero, fewer = list(), list() - for ni, n in enumerate(surf['neighbor_tri']): - if len(n) < 3: - if len(n) == 0: - zero.append(ni) - else: - fewer.append(ni) - surf['neighbor_tri'][ni] = np.array([], int) - if len(zero) > 0: - logger.info(' Vertices do not have any neighboring ' - 'triangles: [%s]' % ', '.join(str(z) for z in zero)) - if len(fewer) > 0: - logger.info(' Vertices have fewer than three neighboring ' - 'triangles, removing neighbors: [%s]' - % ', '.join(str(f) for f in fewer)) + if do_neighbor_tri: + surf['neighbor_tri'] = _triangle_neighbors(surf['tris'], surf['np']) + zero, fewer = list(), list() + for ni, n in enumerate(surf['neighbor_tri']): + if len(n) < 3: + if len(n) == 0: + zero.append(ni) + else: + fewer.append(ni) + surf['neighbor_tri'][ni] = np.array([], int) + if len(zero) > 0: + logger.info(' Vertices do not have any neighboring ' + 'triangles: [%s]' % ', '.join(str(z) for z in zero)) + if len(fewer) > 0: + logger.info(' Vertices have fewer than three neighboring ' + 'triangles, removing neighbors: [%s]' + % ', '.join(str(f) for f in fewer)) # Determine the neighboring vertices and fix errors if do_neighbor_vert is True: diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index ae3c1212200..edb1901f228 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -540,7 +540,8 @@ def __init__(self, subject_id, hemi, surf, title=None, alpha=self._silhouette["alpha"], decimate=self._silhouette["decimate"], ) - self._renderer.set_camera(**views_dicts[h][v]) + self._renderer.set_camera(update=False, reset_camera=False, + **views_dicts[h][v]) self.interaction = interaction self._closed = False @@ -2009,7 +2010,8 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, bgcolor=self._brain_color[:3]) kwargs.update(colorbar_kwargs or {}) self._scalar_bar = self._renderer.scalarbar(**kwargs) - self._renderer.set_camera(**views_dicts[hemi][v]) + self._renderer.set_camera( + update=False, reset_camera=False, **views_dicts[hemi][v]) # 4) update the scalar bar and opacity self.update_lut(alpha=alpha) @@ -2306,7 +2308,7 @@ def add_label(self, label, color=None, alpha=1, scalar_thresh=None, name=label_name, ) if reset_camera: - self._renderer.set_camera(**views_dicts[hemi][v]) + self._renderer.set_camera(update=False, **views_dicts[hemi][v]) if self.time_viewer and self.show_traces \ and self.traces_mode == 'label': label._color = orig_color diff --git a/mne/viz/_brain/surface.py b/mne/viz/_brain/surface.py index e2e1e512692..dc7ebea19ab 100644 --- a/mne/viz/_brain/surface.py +++ b/mne/viz/_brain/surface.py @@ -131,7 +131,8 @@ def load_geometry(self): else: coords -= (np.min(x_) + self.offset) * self.x_dir surf = dict(rr=coords, tris=faces) - complete_surface_info(surf, copy=False, verbose=False) + complete_surface_info( + surf, copy=False, verbose=False, do_neighbor_tri=False) nn = surf['nn'] self.coords = coords self.faces = faces diff --git a/mne/viz/backends/_pysurfer_mayavi.py b/mne/viz/backends/_pysurfer_mayavi.py index 24aa8475da9..00dd0bbb7f2 100644 --- a/mne/viz/backends/_pysurfer_mayavi.py +++ b/mne/viz/backends/_pysurfer_mayavi.py @@ -316,10 +316,10 @@ def close(self): def set_camera(self, azimuth=None, elevation=None, distance=None, focalpoint=None, roll=None, reset_camera=None, - rigid=None): + rigid=None, update=True): _set_3d_view(figure=self.fig, azimuth=azimuth, elevation=elevation, distance=distance, - focalpoint=focalpoint, roll=roll) + focalpoint=focalpoint, roll=roll, update=update) def reset_camera(self): renderer = getattr(self.fig.scene, 'renderer', None) @@ -454,13 +454,14 @@ def _close_all(): def _set_3d_view(figure, azimuth, elevation, focalpoint, distance, roll=None, - reset_camera=True): + reset_camera=True, update=True): from mayavi import mlab with warnings.catch_warnings(record=True): # traits with SilenceStdout(): mlab.view(azimuth, elevation, distance, focalpoint=focalpoint, figure=figure, roll=roll) - mlab.draw(figure) + if update: + mlab.draw(figure) def _set_3d_title(figure, title, size=40): diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index cf942c48a0c..92e2b352a42 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -598,10 +598,10 @@ def close(self): def set_camera(self, azimuth=None, elevation=None, distance=None, focalpoint='auto', roll=None, reset_camera=True, - rigid=None): + rigid=None, update=True): _set_3d_view(self.figure, azimuth=azimuth, elevation=elevation, distance=distance, focalpoint=focalpoint, roll=roll, - reset_camera=reset_camera, rigid=rigid) + reset_camera=reset_camera, rigid=rigid, update=update) def reset_camera(self): self.plotter.reset_camera() @@ -930,7 +930,8 @@ def _get_camera_direction(focalpoint, position): def _set_3d_view(figure, azimuth=None, elevation=None, focalpoint='auto', - distance=None, roll=None, reset_camera=True, rigid=None): + distance=None, roll=None, reset_camera=True, rigid=None, + update=True): rigid = np.eye(4) if rigid is None else rigid position = np.array(figure.plotter.camera_position[0]) bounds = np.array(figure.plotter.renderer.ComputeVisiblePropBounds()) @@ -991,8 +992,9 @@ def _set_3d_view(figure, azimuth=None, elevation=None, focalpoint='auto', if roll is not None: figure.plotter.camera.SetRoll(figure.plotter.camera.GetRoll() + roll) - figure.plotter.update() - _process_events(figure.plotter) + if update: + figure.plotter.update() + _process_events(figure.plotter) def _set_3d_title(figure, title, size=16): From 9931f3891525dde9355d0c6e0f84aa6bf979b058 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 21 May 2021 17:16:58 -0400 Subject: [PATCH 11/12] TST: Fix test --- mne/source_space.py | 28 ++++++++++++++-------------- mne/tests/test_source_space.py | 6 +++++- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/mne/source_space.py b/mne/source_space.py index d2cb29ba5c1..b62ac4e432b 100644 --- a/mne/source_space.py +++ b/mne/source_space.py @@ -1357,22 +1357,22 @@ def read_talxfm(subject, subjects_dir=None, verbose=None): return mri_mni_t -def _read_mri_info(path, units='m', return_img=False): +def _read_mri_info(path, units='m', return_img=False, use_nibabel=False): # This is equivalent but 100x slower, so only use nibabel if we need to # (later): - # - # import nibabel - # hdr = nibabel.load(path).header - # hdr = mgz.header - # n_orig = hdr.get_vox2ras() - # t_orig = hdr.get_vox2ras_tkr() - # dims = hdr.get_data_shape() - # zooms = hdr.get_zooms()[:3] - hdr = _get_mgz_header(path) - n_orig = hdr['vox2ras'] - t_orig = hdr['vox2ras_tkr'] - dims = hdr['dims'] - zooms = hdr['zooms'] + if use_nibabel: + import nibabel + hdr = nibabel.load(path).header + n_orig = hdr.get_vox2ras() + t_orig = hdr.get_vox2ras_tkr() + dims = hdr.get_data_shape() + zooms = hdr.get_zooms()[:3] + else: + hdr = _get_mgz_header(path) + n_orig = hdr['vox2ras'] + t_orig = hdr['vox2ras_tkr'] + dims = hdr['dims'] + zooms = hdr['zooms'] # extract the MRI_VOXEL to RAS (non-zero origin) transform vox_ras_t = Transform('mri_voxel', 'ras', n_orig) diff --git a/mne/tests/test_source_space.py b/mne/tests/test_source_space.py index cbf66a30485..589e5aa44f0 100644 --- a/mne/tests/test_source_space.py +++ b/mne/tests/test_source_space.py @@ -593,6 +593,7 @@ def test_head_to_mni(): assert_allclose(coords_MNI, coords_MNI_2, atol=10.0) +@requires_nibabel() @testing.requires_testing_data def test_vertex_to_mni_fs_nibabel(monkeypatch): """Test equivalence of vert_to_mni for nibabel and freesurfer.""" @@ -601,7 +602,10 @@ def test_vertex_to_mni_fs_nibabel(monkeypatch): vertices = rng.randint(0, 100000, n_check) hemis = rng.randint(0, 1, n_check) coords = vertex_to_mni(vertices, hemis, subject, subjects_dir) - monkeypatch.setattr(mne.source_space, 'has_nibabel', lambda: False) + read_mri = mne.source_space._read_mri_info + monkeypatch.setattr( + mne.source_space, '_read_mri_info', + lambda *args, **kwargs: read_mri(*args, use_nibabel=True, **kwargs)) coords_2 = vertex_to_mni(vertices, hemis, subject, subjects_dir) # less than 0.1 mm error assert_allclose(coords, coords_2, atol=0.1) From 74ebeb8f8b9f587a5d3694d10f34c5ad75fc814f Mon Sep 17 00:00:00 2001 From: Guillaume Favelier Date: Sat, 22 May 2021 10:12:38 +0200 Subject: [PATCH 12/12] Fix tests --- mne/tests/test_coreg.py | 1 + mne/tests/test_source_space.py | 1 + 2 files changed, 2 insertions(+) diff --git a/mne/tests/test_coreg.py b/mne/tests/test_coreg.py index a2e1485a36c..843df90de3a 100644 --- a/mne/tests/test_coreg.py +++ b/mne/tests/test_coreg.py @@ -56,6 +56,7 @@ def make_dig(coords, cf): assert_array_almost_equal(trans_est['trans'], trans['trans']) +@requires_nibabel() @pytest.mark.slowtest # can take forever on OSX Travis @testing.requires_testing_data @pytest.mark.parametrize('scale', (.9, [1, .2, .8])) diff --git a/mne/tests/test_source_space.py b/mne/tests/test_source_space.py index 589e5aa44f0..568570eb9f1 100644 --- a/mne/tests/test_source_space.py +++ b/mne/tests/test_source_space.py @@ -293,6 +293,7 @@ def test_discrete_source_space(tmpdir): assert _get_src_type(src_new, None) == 'discrete' +@requires_nibabel() @pytest.mark.slowtest @testing.requires_testing_data def test_volume_source_space(tmpdir):