diff --git a/mne/morph.py b/mne/morph.py index a1e9f031c95..80a373cdafe 100644 --- a/mne/morph.py +++ b/mne/morph.py @@ -1142,7 +1142,7 @@ def _hemi_morph(tris, vertices_to, vertices_from, smooth, maps, warn): e = mesh_edges(tris) e.data[e.data == 2] = 1 n_vertices = e.shape[0] - e = e + sparse.eye(n_vertices) + e += sparse.eye(n_vertices, format='csr') if isinstance(smooth, str): _check_option('smooth', smooth, ('nearest',), extra=' when used as a string.') diff --git a/mne/source_space.py b/mne/source_space.py index 0e1186e798b..b62ac4e432b 100644 --- a/mne/source_space.py +++ b/mne/source_space.py @@ -33,7 +33,7 @@ _CheckInside) from .utils import (get_subjects_dir, check_fname, logger, verbose, fill_doc, _ensure_int, check_version, _get_call_line, warn, - _check_fname, _check_path_like, has_nibabel, _check_sphere, + _check_fname, _check_path_like, _check_sphere, _validate_type, _check_option, _is_numeric, _pl, _suggest, object_size, sizeof_fmt) from .parallel import parallel_func, check_n_jobs @@ -1357,17 +1357,17 @@ def read_talxfm(subject, subjects_dir=None, verbose=None): return mri_mni_t -def _read_mri_info(path, units='m', return_img=False): - if has_nibabel(): +def _read_mri_info(path, units='m', return_img=False, use_nibabel=False): + # This is equivalent but 100x slower, so only use nibabel if we need to + # (later): + if use_nibabel: import nibabel - mgz = nibabel.load(path) - hdr = mgz.header + hdr = nibabel.load(path).header n_orig = hdr.get_vox2ras() t_orig = hdr.get_vox2ras_tkr() dims = hdr.get_data_shape() zooms = hdr.get_zooms()[:3] else: - mgz = None hdr = _get_mgz_header(path) n_orig = hdr['vox2ras'] t_orig = hdr['vox2ras_tkr'] @@ -1395,7 +1395,8 @@ def _read_mri_info(path, units='m', return_img=False): out = (vox_ras_t, vox_mri_t, mri_ras_t, dims, zooms) if return_img: - out += (mgz,) + nibabel = _import_nibabel() + out += (nibabel.load(path),) return out @@ -2250,6 +2251,8 @@ def _vol_vertex(width, height, jj, kk, pp): def _get_mgz_header(fname): """Adapted from nibabel to quickly extract header info.""" + fname = _check_fname(fname, overwrite='read', must_exist=True, + name='MRI image') if not fname.endswith('.mgz'): raise IOError('Filename must end with .mgz') header_dtd = [('version', '>i4'), ('dims', '>i4', (4,)), diff --git a/mne/surface.py b/mne/surface.py index 9dda2cf2080..7609c4009d4 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -10,7 +10,7 @@ from copy import deepcopy from distutils.version import LooseVersion -from functools import partial +from functools import partial, lru_cache from glob import glob from os import path as op from struct import pack @@ -28,7 +28,7 @@ _get_trans, apply_trans, Transform) from .utils import (logger, verbose, get_subjects_dir, warn, _check_fname, _check_option, _ensure_int, _TempDir, run_subprocess, - _check_freesurfer_home) + _check_freesurfer_home, _hashable_ndarray) ############################################################################### @@ -343,7 +343,7 @@ def _normal_orth(nn): @verbose def complete_surface_info(surf, do_neighbor_vert=False, copy=True, - verbose=None): + do_neighbor_tri=True, *, verbose=None): """Complete surface information. Parameters @@ -351,9 +351,11 @@ def complete_surface_info(surf, do_neighbor_vert=False, copy=True, surf : dict The surface. do_neighbor_vert : bool - If True, add neighbor vertex information. + If True (default False), add neighbor vertex information. copy : bool If True (default), make a copy. If False, operate in-place. + do_neighbor_tri : bool + If True (default), compute triangle neighbors. %(verbose)s Returns @@ -383,27 +385,28 @@ def complete_surface_info(surf, do_neighbor_vert=False, copy=True, # Find neighboring triangles, accumulate vertex normals, normalize logger.info(' Triangle neighbors and vertex normals...') - surf['neighbor_tri'] = _triangle_neighbors(surf['tris'], surf['np']) surf['nn'] = _accumulate_normals(surf['tris'].astype(int), surf['tri_nn'], surf['np']) _normalize_vectors(surf['nn']) # Check for topological defects - zero, fewer = list(), list() - for ni, n in enumerate(surf['neighbor_tri']): - if len(n) < 3: - if len(n) == 0: - zero.append(ni) - else: - fewer.append(ni) - surf['neighbor_tri'][ni] = np.array([], int) - if len(zero) > 0: - logger.info(' Vertices do not have any neighboring ' - 'triangles: [%s]' % ', '.join(str(z) for z in zero)) - if len(fewer) > 0: - logger.info(' Vertices have fewer than three neighboring ' - 'triangles, removing neighbors: [%s]' - % ', '.join(str(f) for f in fewer)) + if do_neighbor_tri: + surf['neighbor_tri'] = _triangle_neighbors(surf['tris'], surf['np']) + zero, fewer = list(), list() + for ni, n in enumerate(surf['neighbor_tri']): + if len(n) < 3: + if len(n) == 0: + zero.append(ni) + else: + fewer.append(ni) + surf['neighbor_tri'][ni] = np.array([], int) + if len(zero) > 0: + logger.info(' Vertices do not have any neighboring ' + 'triangles: [%s]' % ', '.join(str(z) for z in zero)) + if len(fewer) > 0: + logger.info(' Vertices have fewer than three neighboring ' + 'triangles, removing neighbors: [%s]' + % ', '.join(str(f) for f in fewer)) # Determine the neighboring vertices and fix errors if do_neighbor_vert is True: @@ -1451,6 +1454,12 @@ def mesh_edges(tris): edges : sparse matrix The adjacency matrix. """ + tris = _hashable_ndarray(tris) + return _mesh_edges(tris=tris) + + +@lru_cache(maxsize=10) +def _mesh_edges(tris=None): from scipy.sparse import coo_matrix if np.max(tris) > len(np.unique(tris)): raise ValueError( diff --git a/mne/tests/test_coreg.py b/mne/tests/test_coreg.py index a2e1485a36c..843df90de3a 100644 --- a/mne/tests/test_coreg.py +++ b/mne/tests/test_coreg.py @@ -56,6 +56,7 @@ def make_dig(coords, cf): assert_array_almost_equal(trans_est['trans'], trans['trans']) +@requires_nibabel() @pytest.mark.slowtest # can take forever on OSX Travis @testing.requires_testing_data @pytest.mark.parametrize('scale', (.9, [1, .2, .8])) diff --git a/mne/tests/test_source_space.py b/mne/tests/test_source_space.py index cbf66a30485..568570eb9f1 100644 --- a/mne/tests/test_source_space.py +++ b/mne/tests/test_source_space.py @@ -293,6 +293,7 @@ def test_discrete_source_space(tmpdir): assert _get_src_type(src_new, None) == 'discrete' +@requires_nibabel() @pytest.mark.slowtest @testing.requires_testing_data def test_volume_source_space(tmpdir): @@ -593,6 +594,7 @@ def test_head_to_mni(): assert_allclose(coords_MNI, coords_MNI_2, atol=10.0) +@requires_nibabel() @testing.requires_testing_data def test_vertex_to_mni_fs_nibabel(monkeypatch): """Test equivalence of vert_to_mni for nibabel and freesurfer.""" @@ -601,7 +603,10 @@ def test_vertex_to_mni_fs_nibabel(monkeypatch): vertices = rng.randint(0, 100000, n_check) hemis = rng.randint(0, 1, n_check) coords = vertex_to_mni(vertices, hemis, subject, subjects_dir) - monkeypatch.setattr(mne.source_space, 'has_nibabel', lambda: False) + read_mri = mne.source_space._read_mri_info + monkeypatch.setattr( + mne.source_space, '_read_mri_info', + lambda *args, **kwargs: read_mri(*args, use_nibabel=True, **kwargs)) coords_2 = vertex_to_mni(vertices, hemis, subject, subjects_dir) # less than 0.1 mm error assert_allclose(coords, coords_2, atol=0.1) diff --git a/mne/utils/__init__.py b/mne/utils/__init__.py index bb41ad9fdcf..628cdee714f 100644 --- a/mne/utils/__init__.py +++ b/mne/utils/__init__.py @@ -60,7 +60,7 @@ _mask_to_onsets_offsets, _array_equal_nan, _julian_to_cal, _cal_to_julian, _dt_to_julian, _julian_to_dt, _dt_to_stamp, _stamp_to_dt, - _check_dt, _ReuseCycle, _arange_div) + _check_dt, _ReuseCycle, _arange_div, _hashable_ndarray) from .mixin import (SizeMixin, GetEpochsMixin, _prepare_read_metadata, _prepare_write_metadata, _FakeNoPandas, ShiftTimeMixin) from .linalg import (_svd_lwork, _repeated_svd, _sym_mat_pow, sqrtm_sym, eigh, diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index 7c24daf44be..7f181ffcce5 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -605,6 +605,18 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): return grand_average +class _HashableNdarray(np.ndarray): + def __hash__(self): + return object_hash(self) + + def __eq__(self, other): + return NotImplementedError # defer to hash + + +def _hashable_ndarray(x): + return x.view(_HashableNdarray) + + def object_hash(x, h=None): """Hash a reasonable python object. diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index f9002a59b13..edb1901f228 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -105,7 +105,8 @@ def __init__(self, renderer, vertices, triangles, normals): self._actor = None self._is_mapped = False - self._cache = None + self._current_colors = None + self._cached_colors = None self._overlays = OrderedDict() self._default_scalars = np.ones(vertices.shape) @@ -142,14 +143,15 @@ def _compute_over(self, B, A): return np.clip(C, 0, 1, out=C) def _compose_overlays(self): - B = None + B = cache = None for overlay in self._overlays.values(): A = overlay.to_colors() if B is None: B = A else: - B = self._compute_over(B, A) - return B + cache = B + B = self._compute_over(cache, A) + return B, cache def add_overlay(self, scalars, colormap, rng, opacity, name): overlay = _Overlay( @@ -161,36 +163,45 @@ def add_overlay(self, scalars, colormap, rng, opacity, name): ) self._overlays[name] = overlay colors = overlay.to_colors() - - # save colors in cache - if self._cache is None: - self._cache = colors + if self._current_colors is None: + self._current_colors = colors else: - self._cache = self._compute_over(self._cache, colors) + # save previous colors to cache + self._cached_colors = self._current_colors + self._current_colors = self._compute_over( + self._cached_colors, colors) - # update the texture - self._update() + # apply the texture + self._apply() def remove_overlay(self, names): + to_update = False if not isinstance(names, list): names = [names] for name in names: if name in self._overlays: del self._overlays[name] - self.update() + to_update = True + if to_update: + self.update() - def _update(self): - if self._cache is None or self._renderer is None: + def _apply(self): + if self._current_colors is None or self._renderer is None: return self._renderer._set_mesh_scalars( mesh=self._polydata, - scalars=self._cache, + scalars=self._current_colors, name=self._default_scalars_name, ) - def update(self): - self._cache = self._compose_overlays() - self._update() + def update(self, colors=None): + if colors is not None and self._cached_colors is not None: + self._current_colors = self._compute_over( + self._cached_colors, colors) + else: + self._current_colors, self._cached_colors = \ + self._compose_overlays() + self._apply() def _clean(self): mapper = self._actor.GetMapper() @@ -213,7 +224,11 @@ def update_overlay(self, name, scalars=None, colormap=None, overlay._opacity = opacity if rng is not None: overlay._rng = rng - self.update() + # partial update: use cache if possible + if name == list(self._overlays.keys())[-1]: + self.update(colors=overlay.to_colors()) + else: # full update + self.update() @fill_doc @@ -525,7 +540,8 @@ def __init__(self, subject_id, hemi, surf, title=None, alpha=self._silhouette["alpha"], decimate=self._silhouette["decimate"], ) - self._renderer.set_camera(**views_dicts[h][v]) + self._renderer.set_camera(update=False, reset_camera=False, + **views_dicts[h][v]) self.interaction = interaction self._closed = False @@ -1177,7 +1193,7 @@ def _configure_vertex_time_course(self): alpha=0.5, ls=':') # now plot the time line - self.plot_time_line() + self.plot_time_line(update=False) # then the picked points for idx, hemi in enumerate(['lh', 'rh', 'vol']): @@ -1206,7 +1222,7 @@ def _configure_vertex_time_course(self): else: mesh = self._layered_meshes[hemi]._polydata vertex_id = vertices[ind[0]] - self._add_vertex_glyph(hemi, mesh, vertex_id) + self._add_vertex_glyph(hemi, mesh, vertex_id, update=False) def _configure_picking(self): # get data for each hemi @@ -1476,7 +1492,7 @@ def _remove_label_glyph(self, hemi, label_id): self._layered_meshes[hemi].remove_overlay(label.name) self.picked_patches[hemi].remove(label_id) - def _add_vertex_glyph(self, hemi, mesh, vertex_id): + def _add_vertex_glyph(self, hemi, mesh, vertex_id, update=True): if vertex_id in self.picked_points[hemi]: return @@ -1484,7 +1500,7 @@ def _add_vertex_glyph(self, hemi, mesh, vertex_id): if self.act_data_smooth[hemi][0] is None: return color = next(self.color_cycle) - line = self.plot_time_course(hemi, vertex_id, color) + line = self.plot_time_course(hemi, vertex_id, color, update=update) if hemi == 'vol': ijk = np.unravel_index( vertex_id, np.array(mesh.GetDimensions()) - 1, order='F') @@ -1584,7 +1600,7 @@ def clear_glyphs(self): self.rms = None self._renderer._update() - def plot_time_course(self, hemi, vertex_id, color): + def plot_time_course(self, hemi, vertex_id, color, update=True): """Plot the vertex time course. Parameters @@ -1595,6 +1611,8 @@ def plot_time_course(self, hemi, vertex_id, color): The vertex identifier in the mesh. color : matplotlib color The color of the time course. + update : bool + Force an update of the plot. Defaults to True. Returns ------- @@ -1643,11 +1661,18 @@ def plot_time_course(self, hemi, vertex_id, color): lw=1., color=color, zorder=4, + update=update, ) return line - def plot_time_line(self): - """Add the time line to the MPL widget.""" + def plot_time_line(self, update=True): + """Add the time line to the MPL widget. + + Parameters + ---------- + update : bool + Force an update of the plot. Defaults to True. + """ if self.mpl_canvas is None: return if isinstance(self.show_traces, bool) and self.show_traces: @@ -1659,9 +1684,11 @@ def plot_time_line(self): label='time', color=self._fg_color, lw=1, + update=update, ) self.time_line.set_xdata(current_time) - self.mpl_canvas.update_plot() + if update: + self.mpl_canvas.update_plot() def _configure_help(self): pairs = [ @@ -1983,7 +2010,8 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, bgcolor=self._brain_color[:3]) kwargs.update(colorbar_kwargs or {}) self._scalar_bar = self._renderer.scalarbar(**kwargs) - self._renderer.set_camera(**views_dicts[hemi][v]) + self._renderer.set_camera( + update=False, reset_camera=False, **views_dicts[hemi][v]) # 4) update the scalar bar and opacity self.update_lut(alpha=alpha) @@ -2280,7 +2308,7 @@ def add_label(self, label, color=None, alpha=1, scalar_thresh=None, name=label_name, ) if reset_camera: - self._renderer.set_camera(**views_dicts[hemi][v]) + self._renderer.set_camera(update=False, **views_dicts[hemi][v]) if self.time_viewer and self.show_traces \ and self.traces_mode == 'label': label._color = orig_color @@ -2394,7 +2422,7 @@ def _configure_label_time_course(self): self.add_annotation(self.annot, color="w", alpha=0.75) # now plot the time line - self.plot_time_line() + self.plot_time_line(update=False) self.mpl_canvas.update_plot() for hemi in self._hemis: @@ -2736,7 +2764,6 @@ def set_data_smoothing(self, n_steps): n_steps : int Number of smoothing steps. """ - from scipy import sparse from ...morph import _hemi_morph for hemi in ['lh', 'rh']: hemi_data = self._data.get(hemi) @@ -2750,12 +2777,11 @@ def set_data_smoothing(self, n_steps): 'parameter must not be None' % (len(hemi_data), self.geo[hemi].x.shape[0])) morph_n_steps = 'nearest' if n_steps == -1 else n_steps - maps = sparse.eye(len(self.geo[hemi].coords), format='csr') with use_log_level(False): smooth_mat = _hemi_morph( self.geo[hemi].orig_faces, np.arange(len(self.geo[hemi].coords)), - vertices, morph_n_steps, maps, warn=False) + vertices, morph_n_steps, maps=None, warn=False) self._data[hemi]['smooth_mat'] = smooth_mat self.set_time_point(self._data['time_idx']) self._data['smoothing_steps'] = n_steps diff --git a/mne/viz/_brain/surface.py b/mne/viz/_brain/surface.py index e2e1e512692..dc7ebea19ab 100644 --- a/mne/viz/_brain/surface.py +++ b/mne/viz/_brain/surface.py @@ -131,7 +131,8 @@ def load_geometry(self): else: coords -= (np.min(x_) + self.offset) * self.x_dir surf = dict(rr=coords, tris=faces) - complete_surface_info(surf, copy=False, verbose=False) + complete_surface_info( + surf, copy=False, verbose=False, do_neighbor_tri=False) nn = surf['nn'] self.coords = coords self.faces = faces diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 8dc280baa60..b20da49c8c0 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -115,7 +115,8 @@ def test_layered_mesh(renderer_interactive_pyvista): assert not mesh._is_mapped mesh.map() assert mesh._is_mapped - assert mesh._cache is None + assert mesh._current_colors is None + assert mesh._cached_colors is None mesh.update() assert len(mesh._overlays) == 0 mesh.add_overlay( @@ -123,13 +124,27 @@ def test_layered_mesh(renderer_interactive_pyvista): colormap=np.array([(1, 1, 1, 1), (0, 0, 0, 0)]), rng=[0, 1], opacity=None, - name='test', + name='test1', ) - assert mesh._cache is not None + assert mesh._current_colors is not None + assert mesh._cached_colors is None + assert len(mesh._overlays) == 1 + assert 'test1' in mesh._overlays + mesh.add_overlay( + scalars=np.array([1, 0, 0, 1]), + colormap=np.array([(1, 1, 1, 1), (0, 0, 0, 0)]), + rng=[0, 1], + opacity=None, + name='test2', + ) + assert mesh._current_colors is not None + assert mesh._cached_colors is not None + assert len(mesh._overlays) == 2 + assert 'test2' in mesh._overlays + mesh.remove_overlay('test2') + assert 'test2' not in mesh._overlays + mesh.update() assert len(mesh._overlays) == 1 - assert 'test' in mesh._overlays - mesh.remove_overlay('test') - assert len(mesh._overlays) == 0 mesh._clean() diff --git a/mne/viz/backends/_abstract.py b/mne/viz/backends/_abstract.py index a101f50d62f..11f4c79ecf1 100644 --- a/mne/viz/backends/_abstract.py +++ b/mne/viz/backends/_abstract.py @@ -655,17 +655,19 @@ def _connect(self): self.canvas.mpl_connect( event + '_event', getattr(self, 'on_' + event)) - def plot(self, x, y, label, **kwargs): + def plot(self, x, y, label, update=True, **kwargs): """Plot a curve.""" line, = self.axes.plot( x, y, label=label, **kwargs) - self.update_plot() + if update: + self.update_plot() return line - def plot_time_line(self, x, label, **kwargs): + def plot_time_line(self, x, label, update=True, **kwargs): """Plot the vertical line.""" line = self.axes.axvline(x, label=label, **kwargs) - self.update_plot() + if update: + self.update_plot() return line def update_plot(self): diff --git a/mne/viz/backends/_pysurfer_mayavi.py b/mne/viz/backends/_pysurfer_mayavi.py index 24aa8475da9..00dd0bbb7f2 100644 --- a/mne/viz/backends/_pysurfer_mayavi.py +++ b/mne/viz/backends/_pysurfer_mayavi.py @@ -316,10 +316,10 @@ def close(self): def set_camera(self, azimuth=None, elevation=None, distance=None, focalpoint=None, roll=None, reset_camera=None, - rigid=None): + rigid=None, update=True): _set_3d_view(figure=self.fig, azimuth=azimuth, elevation=elevation, distance=distance, - focalpoint=focalpoint, roll=roll) + focalpoint=focalpoint, roll=roll, update=update) def reset_camera(self): renderer = getattr(self.fig.scene, 'renderer', None) @@ -454,13 +454,14 @@ def _close_all(): def _set_3d_view(figure, azimuth, elevation, focalpoint, distance, roll=None, - reset_camera=True): + reset_camera=True, update=True): from mayavi import mlab with warnings.catch_warnings(record=True): # traits with SilenceStdout(): mlab.view(azimuth, elevation, distance, focalpoint=focalpoint, figure=figure, roll=roll) - mlab.draw(figure) + if update: + mlab.draw(figure) def _set_3d_title(figure, title, size=40): diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index cf942c48a0c..92e2b352a42 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -598,10 +598,10 @@ def close(self): def set_camera(self, azimuth=None, elevation=None, distance=None, focalpoint='auto', roll=None, reset_camera=True, - rigid=None): + rigid=None, update=True): _set_3d_view(self.figure, azimuth=azimuth, elevation=elevation, distance=distance, focalpoint=focalpoint, roll=roll, - reset_camera=reset_camera, rigid=rigid) + reset_camera=reset_camera, rigid=rigid, update=update) def reset_camera(self): self.plotter.reset_camera() @@ -930,7 +930,8 @@ def _get_camera_direction(focalpoint, position): def _set_3d_view(figure, azimuth=None, elevation=None, focalpoint='auto', - distance=None, roll=None, reset_camera=True, rigid=None): + distance=None, roll=None, reset_camera=True, rigid=None, + update=True): rigid = np.eye(4) if rigid is None else rigid position = np.array(figure.plotter.camera_position[0]) bounds = np.array(figure.plotter.renderer.ComputeVisiblePropBounds()) @@ -991,8 +992,9 @@ def _set_3d_view(figure, azimuth=None, elevation=None, focalpoint='auto', if roll is not None: figure.plotter.camera.SetRoll(figure.plotter.camera.GetRoll() + roll) - figure.plotter.update() - _process_events(figure.plotter) + if update: + figure.plotter.update() + _process_events(figure.plotter) def _set_3d_title(figure, title, size=16):