diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 7cf8f200bcc..d0f5f51df9c 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -27,6 +27,8 @@ Changelog - Add :class:`mne.MixedVectorSourceEstimate` for vector source estimates for mixed source spaces, by `Eric Larson`_ +- Add mixed and volumetric source estimate plotting using volumetric ray-casting to :meth:`mne.MixedSourceEstimate.plot` and :meth:`mne.VolSourceEstimate.plot_3d` by `Eric Larson`_ + - Add :meth:`mne.MixedSourceEstimate.surface` and :meth:`mne.MixedSourceEstimate.volume` methods to allow surface and volume extraction by `Eric Larson`_ - Add :meth:`mne.VectorSourceEstimate.project` to project vector source estimates onto the direction of maximum source power by `Eric Larson`_ diff --git a/doc/conf.py b/doc/conf.py index f54ef1d34bf..00fb25a15f1 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -596,6 +596,7 @@ def reset_warnings(gallery_conf, fname): 'VolSourceEstimate': 'mne.VolSourceEstimate', 'VolVectorSourceEstimate': 'mne.VolVectorSourceEstimate', 'MixedSourceEstimate': 'mne.MixedSourceEstimate', + 'MixedVectorSourceEstimate': 'mne.MixedVectorSourceEstimate', 'SourceEstimate': 'mne.SourceEstimate', 'Projection': 'mne.Projection', 'ConductorModel': 'mne.bem.ConductorModel', 'Dipole': 'mne.Dipole', 'DipoleFixed': 'mne.DipoleFixed', diff --git a/examples/inverse/plot_mixed_source_space_inverse.py b/examples/inverse/plot_mixed_source_space_inverse.py index 81a5867697b..a66e4a74f7d 100644 --- a/examples/inverse/plot_mixed_source_space_inverse.py +++ b/examples/inverse/plot_mixed_source_space_inverse.py @@ -128,17 +128,27 @@ pick_ori=None) src = inverse_operator['src'] +############################################################################### +# Plot the mixed source estimate +# ------------------------------ + +# sphinx_gallery_thumbnail_number = 3 +initial_time = 0.1 +stc_vec = apply_inverse(evoked, inverse_operator, lambda2, inv_method, + pick_ori='vector') +brain = stc_vec.plot( + hemi='both', src=inverse_operator['src'], views='coronal', + initial_time=initial_time, subjects_dir=subjects_dir) + ############################################################################### # Plot the surface # ---------------- -initial_time = 0.1 brain = stc.surface().plot(initial_time=initial_time, subjects_dir=subjects_dir) ############################################################################### # Plot the volume # ---------------- -# sphinx_gallery_thumbnail_number = 4 fig = stc.volume().plot(initial_time=initial_time, src=src, subjects_dir=subjects_dir) diff --git a/mne/conftest.py b/mne/conftest.py index 980ced1b0d8..56ea405c3d3 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -98,6 +98,7 @@ def pytest_configure(config): ignore:.*pandas\.util\.testing is deprecated.*: ignore:.*tostring.*is deprecated.*:DeprecationWarning ignore:.*QDesktopWidget\.availableGeometry.*:DeprecationWarning + ignore:Unable to enable faulthandler.*:UserWarning always:.*get_data.* is deprecated in favor of.*:DeprecationWarning always::ResourceWarning """ # noqa: E501 @@ -320,6 +321,21 @@ def renderer_notebook(): yield +@pytest.fixture(scope='session') +def pixel_ratio(): + """Get the pixel ratio.""" + from mne.viz.backends.tests._utils import (has_mayavi, has_pyvista, + has_pyqt5) + if not (has_mayavi() or has_pyvista()) or not has_pyqt5(): + return 1. + from PyQt5.QtWidgets import QApplication, QMainWindow + _ = QApplication.instance() or QApplication([]) + window = QMainWindow() + ratio = float(window.devicePixelRatio()) + window.close() + return ratio + + @pytest.fixture(scope='function', params=[testing._pytest_param()]) def subjects_dir_tmp(tmpdir): """Copy MNE-testing-data subjects_dir to a temp dir for manipulation.""" diff --git a/mne/defaults.py b/mne/defaults.py index 7cf7848200f..41d8c38e783 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -82,6 +82,8 @@ depth_sparse=dict(exp=0.8, limit=None, limit_depth_chs='whiten', combine_xyz='fro', allow_fixed_depth=True), interpolation_method=dict(eeg='spline', meg='MNE', fnirs='nearest'), + volume_options=dict( + alpha=None, resolution=1., surface_alpha=None, blending='mip'), ) diff --git a/mne/source_estimate.py b/mne/source_estimate.py index 2874e5b7f96..a5925aa847a 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -611,6 +611,30 @@ def save(self, fname, ftype='h5', verbose=None): src_type=self._src_type), title='mnepython', overwrite=True) + @copy_function_doc_to_method_doc(plot_source_estimates) + def plot(self, subject=None, surface='inflated', hemi='lh', + colormap='auto', time_label='auto', smoothing_steps=10, + transparent=True, alpha=1.0, time_viewer='auto', + subjects_dir=None, + figure=None, views='lat', colorbar=True, clim='auto', + cortex="classic", size=800, background="black", + foreground=None, initial_time=None, time_unit='s', + backend='auto', spacing='oct6', title=None, show_traces='auto', + src=None, volume_options=1., view_layout='vertical', + verbose=None): + brain = plot_source_estimates( + self, subject, surface=surface, hemi=hemi, colormap=colormap, + time_label=time_label, smoothing_steps=smoothing_steps, + transparent=transparent, alpha=alpha, time_viewer=time_viewer, + subjects_dir=subjects_dir, figure=figure, views=views, + colorbar=colorbar, clim=clim, cortex=cortex, size=size, + background=background, foreground=foreground, + initial_time=initial_time, time_unit=time_unit, backend=backend, + spacing=spacing, title=title, show_traces=show_traces, + src=src, volume_options=volume_options, view_layout=view_layout, + verbose=verbose) + return brain + @property def sfreq(self): """Sample rate of the data.""" @@ -1576,28 +1600,6 @@ def save(self, fname, ftype='stc', verbose=None): super().save(fname) logger.info('[done]') - @copy_function_doc_to_method_doc(plot_source_estimates) - def plot(self, subject=None, surface='inflated', hemi='lh', - colormap='auto', time_label='auto', smoothing_steps=10, - transparent=True, alpha=1.0, time_viewer='auto', - subjects_dir=None, - figure=None, views='lat', colorbar=True, clim='auto', - cortex="classic", size=800, background="black", - foreground=None, initial_time=None, time_unit='s', - backend='auto', spacing='oct6', title=None, - show_traces='auto', verbose=None): - brain = plot_source_estimates( - self, subject, surface=surface, hemi=hemi, colormap=colormap, - time_label=time_label, smoothing_steps=smoothing_steps, - transparent=transparent, alpha=alpha, time_viewer=time_viewer, - subjects_dir=subjects_dir, figure=figure, views=views, - colorbar=colorbar, clim=clim, cortex=cortex, size=size, - background=background, foreground=foreground, - initial_time=initial_time, time_unit=time_unit, backend=backend, - spacing=spacing, title=title, show_traces=show_traces, - verbose=verbose) - return brain - @verbose def estimate_snr(self, info, fwd, cov, verbose=None): r"""Compute time-varying SNR in the source space. @@ -1899,12 +1901,57 @@ def project(self, directions, src=None, use_cps=True): self.verbose) return stc, directions + @copy_function_doc_to_method_doc(plot_vector_source_estimates) + def plot(self, subject=None, hemi='lh', colormap='hot', time_label='auto', + smoothing_steps=10, transparent=True, brain_alpha=0.4, + overlay_alpha=None, vector_alpha=1.0, scale_factor=None, + time_viewer='auto', subjects_dir=None, figure=None, views='lat', + colorbar=True, clim='auto', cortex='classic', size=800, + background='black', foreground=None, initial_time=None, + time_unit='s', show_traces='auto', src=None, volume_options=1., + view_layout='vertical', verbose=None): # noqa: D102 + return plot_vector_source_estimates( + self, subject=subject, hemi=hemi, colormap=colormap, + time_label=time_label, smoothing_steps=smoothing_steps, + transparent=transparent, brain_alpha=brain_alpha, + overlay_alpha=overlay_alpha, vector_alpha=vector_alpha, + scale_factor=scale_factor, time_viewer=time_viewer, + subjects_dir=subjects_dir, figure=figure, views=views, + colorbar=colorbar, clim=clim, cortex=cortex, size=size, + background=background, foreground=foreground, + initial_time=initial_time, time_unit=time_unit, + show_traces=show_traces, src=src, volume_options=volume_options, + view_layout=view_layout, verbose=verbose) + class _BaseVolSourceEstimate(_BaseSourceEstimate): _src_type = 'volume' _src_count = None + @copy_function_doc_to_method_doc(plot_source_estimates) + def plot_3d(self, subject=None, surface='white', hemi='both', + colormap='auto', time_label='auto', smoothing_steps=10, + transparent=True, alpha=0.2, time_viewer='auto', + subjects_dir=None, + figure=None, views='axial', colorbar=True, clim='auto', + cortex="classic", size=800, background="black", + foreground=None, initial_time=None, time_unit='s', + backend='auto', spacing='oct6', title=None, show_traces='auto', + src=None, volume_options=1., view_layout='vertical', + verbose=None): + return super().plot( + subject=subject, surface=surface, hemi=hemi, colormap=colormap, + time_label=time_label, smoothing_steps=smoothing_steps, + transparent=transparent, alpha=alpha, time_viewer=time_viewer, + subjects_dir=subjects_dir, + figure=figure, views=views, colorbar=colorbar, clim=clim, + cortex=cortex, size=size, background=background, + foreground=foreground, initial_time=initial_time, + time_unit=time_unit, backend=backend, spacing=spacing, title=title, + show_traces=show_traces, src=src, volume_options=volume_options, + view_layout=view_layout, verbose=verbose) + @copy_function_doc_to_method_doc(plot_volume_source_estimates) def plot(self, src, subject=None, subjects_dir=None, mode='stat_map', bg_img='T1.mgz', colorbar=True, colormap='auto', clim='auto', @@ -2162,8 +2209,8 @@ def save(self, fname, ftype='stc', verbose=None): @fill_doc -class VolVectorSourceEstimate(_BaseVectorSourceEstimate, - _BaseVolSourceEstimate): +class VolVectorSourceEstimate(_BaseVolSourceEstimate, + _BaseVectorSourceEstimate): """Container for volume source estimates. Parameters @@ -2209,6 +2256,32 @@ class VolVectorSourceEstimate(_BaseVectorSourceEstimate, _scalar_class = VolSourceEstimate + # defaults differ: hemi='both', views='axial' + @copy_function_doc_to_method_doc(plot_vector_source_estimates) + def plot_3d(self, subject=None, hemi='both', colormap='hot', + time_label='auto', + smoothing_steps=10, transparent=True, brain_alpha=0.4, + overlay_alpha=None, vector_alpha=1.0, scale_factor=None, + time_viewer='auto', subjects_dir=None, figure=None, + views='axial', + colorbar=True, clim='auto', cortex='classic', size=800, + background='black', foreground=None, initial_time=None, + time_unit='s', show_traces='auto', src=None, + volume_options=1., view_layout='vertical', + verbose=None): # noqa: D102 + return _BaseVectorSourceEstimate.plot( + self, subject=subject, hemi=hemi, colormap=colormap, + time_label=time_label, smoothing_steps=smoothing_steps, + transparent=transparent, brain_alpha=brain_alpha, + overlay_alpha=overlay_alpha, vector_alpha=vector_alpha, + scale_factor=scale_factor, time_viewer=time_viewer, + subjects_dir=subjects_dir, figure=figure, views=views, + colorbar=colorbar, clim=clim, cortex=cortex, size=size, + background=background, foreground=foreground, + initial_time=initial_time, time_unit=time_unit, + show_traces=show_traces, src=src, volume_options=volume_options, + view_layout=view_layout, verbose=verbose) + @fill_doc class VectorSourceEstimate(_BaseVectorSourceEstimate, @@ -2259,27 +2332,6 @@ class VectorSourceEstimate(_BaseVectorSourceEstimate, _scalar_class = SourceEstimate - @copy_function_doc_to_method_doc(plot_vector_source_estimates) - def plot(self, subject=None, hemi='lh', colormap='hot', time_label='auto', - smoothing_steps=10, transparent=True, brain_alpha=0.4, - overlay_alpha=None, vector_alpha=1.0, scale_factor=None, - time_viewer='auto', subjects_dir=None, figure=None, views='lat', - colorbar=True, clim='auto', cortex='classic', size=800, - background='black', foreground=None, initial_time=None, - time_unit='s', show_traces='auto', verbose=None): # noqa: D102 - return plot_vector_source_estimates( - self, subject=subject, hemi=hemi, colormap=colormap, - time_label=time_label, smoothing_steps=smoothing_steps, - transparent=transparent, brain_alpha=brain_alpha, - overlay_alpha=overlay_alpha, vector_alpha=vector_alpha, - scale_factor=scale_factor, time_viewer=time_viewer, - subjects_dir=subjects_dir, figure=figure, views=views, - colorbar=colorbar, clim=clim, cortex=cortex, size=size, - background=background, foreground=foreground, - initial_time=initial_time, time_unit=time_unit, - show_traces=show_traces, verbose=verbose, - ) - ############################################################################### # Mixed source estimate (two cortical surfs plus other stuff) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index b65205b0ac1..cf74e805d46 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -990,13 +990,15 @@ or 'cubic'. """ docdict["show_traces"] = """ -show_traces : bool | str +show_traces : bool | str | float If True, enable interactive picking of a point on the surface of the - brain and plot it's time course using the bottom 1/3 of the figure. + brain and plot its time course. This feature is only available with the PyVista 3d backend, and requires ``time_viewer=True``. Defaults to 'auto', which will use True if and only if ``time_viewer=True``, the backend is PyVista, and there is more - than one time point. + than one time point. If float (between zero and one), it specifies what + proportion of the total window should be devoted to traces (True is + equivalent to 0.25, i.e., it will occupy the bottom 1/4 of the figure). .. versionadded:: 0.20.0 """ @@ -1007,7 +1009,34 @@ default is ``'auto'``, which will use ``time=%0.2f ms`` if there is more than one time point. """ - +docdict["src_volume_options_layout"] = """ +src : instance of SourceSpaces | None + The source space corresponding to the source estimate. Only necessary + if the STC is a volume or mixed source estimate. +volume_options : float | dict | None + Options for volumetric source estimate plotting, with key/value pairs: + + - ``'resolution'`` : float | None + Resolution (in mm) of volume rendering. Smaller (e.g., 1.) looks + better at the cost of speed. None (default) uses the volume source + space resolution, which is often something like 7 or 5 mm, + without resampling. + - ``'blending'`` : str + Can be "mip" (default) for maximum intensity projection or + "composite" for composite blending. + - ``'alpha'`` : float | None + Alpha for the volumetric rendering. Uses 0.4 for vector source + estimates and 1.0 for scalar source estimates. + - ``'surface_alpha'`` : float | None + Alpha for the surface enclosing the volume(s). None will use + half the volume alpha. Set to zero to avoid plotting the surface. + + A float input (default 1.) or None will be used for the ``'resolution'`` + entry. +view_layout : str + Can be "vertical" (default) or "horizontal". When using "horizontal" mode, + the PyVista backend must be used and hemi cannot be "split". +""" # STC label time course docdict['eltc_labels'] = """ diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index b709ab2f995..4c89132019d 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -27,7 +27,7 @@ from ..io.constants import FIFF from ..io.meas_info import read_fiducials, create_info from ..source_space import (_ensure_src, _create_surf_spacing, _check_spacing, - _read_mri_info) + _read_mri_info, SourceSpaces) from ..surface import (get_meg_helmet_surf, read_surface, _DistanceQuery, transform_surface_to, _project_onto_surface, @@ -953,7 +953,7 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None, # initialize figure renderer = _get_renderer(fig, bgcolor=(0.5, 0.5, 0.5), size=(800, 800)) if interaction == 'terrain': - renderer.set_interactive() + renderer.set_interaction('terrain') # plot surfaces alphas = dict(head=head_alpha, helmet=0.25, lh=hemi_val, rh=hemi_val) @@ -1549,6 +1549,34 @@ def link_brains(brains, time=True, camera=False): _LinkViewer(brains, time, camera) +def _triage_stc(stc, src, surface, backend_name, kind='scalar'): + from ..source_estimate import ( + _BaseSurfaceSourceEstimate, _BaseMixedSourceEstimate) + if isinstance(stc, _BaseSurfaceSourceEstimate): + stc_vol = src_vol = None + else: + if backend_name == 'mayavi': + raise RuntimeError( + 'Must use the PyVista 3D backend to plot a mixed or volume ' + 'source estimate') + _validate_type(src, SourceSpaces, 'src', + 'src when stc is a mixed or volume source estimate') + if isinstance(stc, _BaseMixedSourceEstimate): + stc_vol = stc.volume() + stc = stc.surface() + # When showing subvolumes, surfaces that preserve geometry must + # be used (i.e., no inflated) + _check_option( + 'surface', surface, ('white', 'pial'), + extra='when plotting a mixed source estimate') + else: + stc_vol = stc + stc = None + src_vol = src + src_vol = src[2:] if src.kind == 'mixed' else src + return stc, stc_vol, src_vol + + @verbose def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', colormap='auto', time_label='auto', @@ -1558,12 +1586,10 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', cortex="classic", size=800, background="black", foreground=None, initial_time=None, time_unit='s', backend='auto', spacing='oct6', - title=None, show_traces='auto', verbose=None): - """Plot SourceEstimate with PySurfer. - - By default this function uses :mod:`mayavi.mlab` to plot the source - estimates. If Mayavi is not installed, the plotting is done with - :mod:`matplotlib.pyplot` (much slower, decimated source space by default). + title=None, show_traces='auto', + src=None, volume_options=1., view_layout='vertical', + verbose=None): + """Plot SourceEstimate. Parameters ---------- @@ -1649,6 +1675,7 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', .. versionadded:: 0.17.0 %(show_traces)s + %(src_volume_options_layout)s %(verbose)s Returns @@ -1658,9 +1685,8 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', matplotlib figure. """ # noqa: E501 from .backends.renderer import _get_3d_backend, set_3d_backend - # import here to avoid circular import problem - from ..source_estimate import SourceEstimate - _validate_type(stc, SourceEstimate, "stc", "Surface Source Estimate") + from ..source_estimate import _BaseSourceEstimate + _validate_type(stc, _BaseSourceEstimate, 'stc', 'source estimate') subjects_dir = get_subjects_dir(subjects_dir=subjects_dir, raise_error=True) subject = _check_subject(stc.subject, subject, True) @@ -1675,6 +1701,8 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', plot_mpl = True else: # 'mayavi' raise + else: + backend = _get_3d_backend() kwargs = dict( subject=subject, surface=surface, hemi=hemi, colormap=colormap, time_label=time_label, smoothing_steps=smoothing_steps, @@ -1687,24 +1715,18 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', return _plot_stc( stc, overlay_alpha=alpha, brain_alpha=alpha, vector_alpha=alpha, cortex=cortex, foreground=foreground, size=size, scale_factor=None, - show_traces=show_traces, **kwargs) + show_traces=show_traces, src=src, volume_options=volume_options, + view_layout=view_layout, **kwargs) def _plot_stc(stc, subject, surface, hemi, colormap, time_label, smoothing_steps, subjects_dir, views, clim, figure, initial_time, time_unit, background, time_viewer, colorbar, transparent, brain_alpha, overlay_alpha, vector_alpha, cortex, foreground, - size, scale_factor, show_traces): + size, scale_factor, show_traces, src, volume_options, + view_layout): from .backends.renderer import _get_3d_backend - from ..source_estimate import ( - _BaseSourceEstimate, SourceEstimate, VectorSourceEstimate) - _validate_type(stc, _BaseSourceEstimate) vec = stc._data_ndim == 3 - if vec: - allowed = VectorSourceEstimate - else: - allowed = SourceEstimate - _validate_type(stc, allowed, 'stc') subjects_dir = get_subjects_dir(subjects_dir=subjects_dir, raise_error=True) subject = _check_subject(stc.subject, subject, True) @@ -1719,12 +1741,17 @@ def _plot_stc(stc, subject, surface, hemi, colormap, time_label, from ._brain import _Brain as Brain _check_option('hemi', hemi, ['lh', 'rh', 'split', 'both']) + _check_option('view_layout', view_layout, ('vertical', 'horizontal')) time_label, times = _handle_time(time_label, time_unit, stc.times) # convert control points to locations in colormap mapdata = _process_clim(clim, colormap, transparent, stc.data, allow_pos_lims=not vec) + stc_surf, stc_vol, src_vol = _triage_stc( + stc, src, surface, backend, 'scalar') + del src, stc + # XXX we should only need to do this for PySurfer/Mayavi, the PyVista # plotter should be smart enough to do this separation in the cmap-to-ctab # conversion. But this will need to be another refactoring that will @@ -1742,6 +1769,8 @@ def _plot_stc(stc, subject, surface, hemi, colormap, time_label, hemis = ['lh', 'rh'] else: hemis = [hemi] + if stc_vol is not None: + hemis.append('vol') if overlay_alpha is None: overlay_alpha = brain_alpha @@ -1758,8 +1787,12 @@ def _plot_stc(stc, subject, surface, hemi, colormap, time_label, } if backend in ['pyvista', 'notebook']: kwargs["show"] = not time_viewer + kwargs["view_layout"] = view_layout else: kwargs.update(_check_pysurfer_antialias(Brain)) + if view_layout != 'vertical': + raise ValueError('view_layout must be "vertical" when using the ' + 'mayavi backend') with warnings.catch_warnings(record=True): # traits warnings brain = Brain(**kwargs) del kwargs @@ -1774,17 +1807,22 @@ def _plot_stc(stc, subject, surface, hemi, colormap, time_label, sd_kwargs = dict(transparent=transparent, verbose=False) center = 0. if diverging else None for hemi in hemis: - data = getattr(stc, hemi + '_data') - vertices = stc.vertices[0 if hemi == 'lh' else 1] - alpha = overlay_alpha - if len(data) == 0: - continue + if hemi == 'vol': + data = stc_vol.data + vertices = np.concatenate(stc_vol.vertices) + else: + if stc_surf is None: + continue + data = getattr(stc_surf, hemi + '_data') + vertices = stc_surf.vertices[0 if hemi == 'lh' else 1] + if len(data) == 0: + continue kwargs = { "array": data, "colormap": colormap, "vertices": vertices, "smoothing_steps": smoothing_steps, "time": times, "time_label": time_label, - "alpha": alpha, "hemi": hemi, + "alpha": overlay_alpha, "hemi": hemi, "colorbar": colorbar, "vector_alpha": vector_alpha, "scale_factor": scale_factor, @@ -1802,6 +1840,8 @@ def _plot_stc(stc, subject, surface, hemi, colormap, time_label, kwargs["fmid"] = scale_pts[1] kwargs["fmax"] = scale_pts[2] kwargs["clim"] = clim + kwargs["volume_options"] = volume_options + kwargs["src"] = src_vol with warnings.catch_warnings(record=True): # traits warnings brain.add_data(**kwargs) brain.scale_data_colormap(fmin=scale_pts[0], fmid=scale_pts[1], @@ -1832,8 +1872,10 @@ def _plot_stc(stc, subject, surface, hemi, colormap, time_label, # time_viewer and show_traces _check_option('time_viewer', time_viewer, (True, False, 'auto')) - _check_option('show_traces', show_traces, - (True, False, 'auto', 'separate')) + _validate_type(show_traces, (str, bool, 'numeric'), 'show_traces') + if isinstance(show_traces, str): + _check_option('show_traces', show_traces, ('auto', 'separate'), + extra='when a string') if time_viewer == 'auto': time_viewer = not using_mayavi if show_traces == 'auto': @@ -2315,7 +2357,8 @@ def plot_vector_source_estimates(stc, subject=None, hemi='lh', colormap='hot', size=800, background='black', foreground=None, initial_time=None, time_unit='s', show_traces='auto', - verbose=None): + src=None, volume_options=1., + view_layout='vertical', verbose=None): """Plot VectorSourceEstimate with PySurfer. A "glass brain" is drawn and all dipoles defined in the source estimate @@ -2325,7 +2368,7 @@ def plot_vector_source_estimates(stc, subject=None, hemi='lh', colormap='hot', Parameters ---------- - stc : VectorSourceEstimate + stc : VectorSourceEstimate | MixedVectorSourceEstimate The vector source estimate to plot. subject : str | None The subject name corresponding to FreeSurfer environment @@ -2389,6 +2432,7 @@ def plot_vector_source_estimates(stc, subject=None, hemi='lh', colormap='hot', Whether time is represented in seconds ("s", default) or milliseconds ("ms"). %(show_traces)s + %(src_volume_options_layout)s %(verbose)s Returns @@ -2403,6 +2447,9 @@ def plot_vector_source_estimates(stc, subject=None, hemi='lh', colormap='hot', If the current magnitude overlay is not desired, set ``overlay_alpha=0`` and ``smoothing_steps=1``. """ + from ..source_estimate import _BaseVectorSourceEstimate + _validate_type( + stc, _BaseVectorSourceEstimate, 'stc', 'vector source estimate') return _plot_stc( stc, subject=subject, surface='white', hemi=hemi, colormap=colormap, time_label=time_label, smoothing_steps=smoothing_steps, @@ -2411,7 +2458,8 @@ def plot_vector_source_estimates(stc, subject=None, hemi='lh', colormap='hot', time_viewer=time_viewer, colorbar=colorbar, transparent=transparent, brain_alpha=brain_alpha, overlay_alpha=overlay_alpha, vector_alpha=vector_alpha, cortex=cortex, foreground=foreground, - size=size, scale_factor=scale_factor, show_traces=show_traces) + size=size, scale_factor=scale_factor, show_traces=show_traces, + src=src, volume_options=volume_options, view_layout=view_layout) @verbose diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index e7ac2e840b8..c0fd1fd061d 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -15,13 +15,16 @@ from .colormap import calculate_lut from .surface import Surface -from .view import lh_views_dict, rh_views_dict, View +from .view import views_dicts from .._3d import _process_clim, _handle_time +from ...defaults import _handle_default from ...surface import mesh_edges +from ...source_space import SourceSpaces +from ...transforms import apply_trans from ...utils import (_check_option, logger, verbose, fill_doc, _validate_type, - use_log_level) + use_log_level, Bunch) class _Brain(object): @@ -82,7 +85,6 @@ class _Brain(object): mostly for generating images or screenshots, but can be buggy. Use at your own risk. interaction : str - Not supported yet. Can be "trackball" (default) or "terrain", i.e. a turntable-style camera. units : str @@ -148,6 +150,10 @@ class _Brain(object): +---------------------------+--------------+-----------------------+ | get_picked_points | | ✓ | +---------------------------+--------------+-----------------------+ + | add_data(volume) | | ✓ | + +---------------------------+--------------+-----------------------+ + | view_layout | | ✓ | + +---------------------------+--------------+-----------------------+ """ @@ -155,13 +161,11 @@ def __init__(self, subject_id, hemi, surf, title=None, cortex="classic", alpha=1.0, size=800, background="black", foreground=None, figure=None, subjects_dir=None, views=['lateral'], offset=True, show_toolbar=False, - offscreen=False, interaction=None, units='mm', - show=True): + offscreen=False, interaction='trackball', units='mm', + view_layout='vertical', show=True): from ..backends.renderer import backend, _get_renderer, _get_3d_backend from matplotlib.colors import colorConverter - - if interaction is not None: - raise ValueError('"interaction" parameter is not supported.') + from matplotlib.cm import get_cmap if hemi in ('both', 'split'): self._hemis = ('lh', 'rh') @@ -170,6 +174,8 @@ def __init__(self, subject_id, hemi, surf, title=None, else: raise KeyError('hemi has to be either "lh", "rh", "split", ' 'or "both"') + _check_option('view_layout', view_layout, ('vertical', 'horizontal')) + self._view_layout = view_layout if figure is not None and not isinstance(figure, int): backend._check_3d_figure(figure) @@ -177,6 +183,7 @@ def __init__(self, subject_id, hemi, surf, title=None, self._title = subject_id else: self._title = title + self._interaction = 'trackball' if isinstance(background, str): background = colorConverter.to_rgb(background) @@ -186,11 +193,14 @@ def __init__(self, subject_id, hemi, surf, title=None, if isinstance(foreground, str): foreground = colorConverter.to_rgb(foreground) self._fg_color = foreground + if isinstance(views, str): views = [views] - n_row = len(views) col_dict = dict(lh=1, rh=1, both=1, split=2) - n_col = col_dict[hemi] + shape = (len(views), col_dict[hemi]) + if self._view_layout == 'horizontal': + shape = shape[::-1] + self._subplot_shape = shape size = tuple(np.atleast_1d(size).round(0).astype(int).flat) if len(size) not in (1, 2): @@ -201,6 +211,7 @@ def __init__(self, subject_id, hemi, surf, title=None, self._notebook = (_get_3d_backend() == "notebook") self._hemi = hemi self._units = units + self._alpha = float(alpha) self._subject_id = subject_id self._subjects_dir = subjects_dir self._views = views @@ -220,13 +231,16 @@ def __init__(self, subject_id, hemi, surf, title=None, self.set_time_interpolation('nearest') geo_kwargs = self.cortex_colormap(cortex) + # evaluate at the midpoint of the used colormap + val = -geo_kwargs['vmin'] / (geo_kwargs['vmax'] - geo_kwargs['vmin']) + self._brain_color = get_cmap(geo_kwargs['colormap'])(val) # load geometry for one or both hemispheres as necessary offset = None if (not offset or hemi != 'both') else 0.0 self._renderer = _get_renderer(name=self._title, size=self._size, bgcolor=background, - shape=(n_row, n_col), + shape=shape, fig=figure) for h in self._hemis: @@ -237,51 +251,61 @@ def __init__(self, subject_id, hemi, surf, title=None, geo.load_geometry() geo.load_curvature() self.geo[h] = geo - - for ri, v in enumerate(views): - for hi, h in enumerate(['lh', 'rh']): - views_dict = lh_views_dict if hemi == 'lh' else rh_views_dict - if not (hemi in ['lh', 'rh'] and h != hemi): - ci = hi if hemi == 'split' else 0 - self._renderer.subplot(ri, ci) - kwargs = { - "color": None, - "scalars": self.geo[h].bin_curv, - "vmin": geo_kwargs["vmin"], - "vmax": geo_kwargs["vmax"], - "colormap": geo_kwargs["colormap"], - "opacity": alpha, - } - if self._hemi_meshes.get(h) is None: - mesh_data = self._renderer.mesh( - x=self.geo[h].coords[:, 0], - y=self.geo[h].coords[:, 1], - z=self.geo[h].coords[:, 2], - triangles=self.geo[h].faces, - normals=self.geo[h].nn, - **kwargs, - ) - if isinstance(mesh_data, tuple): - actor, mesh = mesh_data - # add metadata to the mesh for picking - mesh._hemi = h - else: - actor, mesh = mesh_data.actor, mesh_data - self._hemi_meshes[h] = mesh - self._hemi_actors[h] = actor + for ri, ci, v in self._iter_views(h): + self._renderer.subplot(ri, ci) + kwargs = { + "color": None, + "scalars": self.geo[h].bin_curv, + "vmin": geo_kwargs["vmin"], + "vmax": geo_kwargs["vmax"], + "colormap": geo_kwargs["colormap"], + "opacity": alpha, + "pickable": False, + } + if self._hemi_meshes.get(h) is None: + mesh_data = self._renderer.mesh( + x=self.geo[h].coords[:, 0], + y=self.geo[h].coords[:, 1], + z=self.geo[h].coords[:, 2], + triangles=self.geo[h].faces, + normals=self.geo[h].nn, + **kwargs, + ) + if isinstance(mesh_data, tuple): + actor, mesh = mesh_data + # add metadata to the mesh for picking + mesh._hemi = h else: - self._renderer.polydata( - self._hemi_meshes[h], - **kwargs, - ) - del kwargs - self._renderer.set_camera(azimuth=views_dict[v].azim, - elevation=views_dict[v].elev) - + actor, mesh = mesh_data.actor, mesh_data + self._hemi_meshes[h] = mesh + self._hemi_actors[h] = actor + else: + self._renderer.polydata( + self._hemi_meshes[h], + **kwargs, + ) + del kwargs + self._renderer.set_camera(**views_dicts[h][v]) + + self.interaction = interaction self._closed = False if show: self._renderer.show() + @property + def interaction(self): + """The interaction style.""" + return self._interaction + + @interaction.setter + def interaction(self, interaction): + """Set the interaction style.""" + _validate_type(interaction, str, 'interaction') + _check_option('interaction', interaction, ('trackball', 'terrain')) + for ri, ci, _ in self._iter_views('vol'): # will traverse all + self._renderer.subplot(ri, ci) + self._renderer.set_interaction(interaction) + def cortex_colormap(self, cortex): """Return the colormap corresponding to the cortex.""" colormap_map = dict(classic=dict(colormap="Greys", @@ -302,8 +326,8 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, time_label="auto", colorbar=True, hemi=None, remove_existing=None, time_label_size=None, initial_time=None, scale_factor=None, vector_alpha=None, - clim=None, verbose=None): - """Display data from a numpy array on the surface. + clim=None, src=None, volume_options=0.4, verbose=None): + """Display data from a numpy array on the surface or volume. This provides a similar interface to :meth:`surfer.Brain.add_overlay`, but it displays @@ -382,6 +406,9 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, Not supported yet. alpha level to control opacity of the arrows. Only used for vector-valued data. If None (default), ``alpha`` is used. + clim : dict + Original clim arguments. + %(src_volume_options_layout)s %(verbose)s Notes @@ -405,7 +432,7 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, _check_option('remove_existing', remove_existing, [None]) _check_option('time_label_size', time_label_size, [None]) - hemi = self._check_hemi(hemi) + hemi = self._check_hemi(hemi, extras=['vol']) array = np.asarray(array) vector_alpha = alpha if vector_alpha is None else vector_alpha self._data['vector_alpha'] = vector_alpha @@ -476,7 +503,7 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, self._data['transparent'] = transparent # data specific for a hemi self._data[hemi] = dict() - self._data[hemi]['actor'] = None + self._data[hemi]['actors'] = None self._data[hemi]['mesh'] = None self._data[hemi]['glyph_actor'] = None self._data[hemi]['glyph_mesh'] = None @@ -488,66 +515,34 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, self._data['fmin'] = fmin self._data['fmid'] = fmid self._data['fmax'] = fmax - - dt_max = fmax - dt_min = fmin if center is None else -1 * fmax self.update_lut() # 1) add the surfaces first - for ri, v in enumerate(self._views): - views_dict = lh_views_dict if hemi == 'lh' else rh_views_dict - if self._hemi != 'split': - ci = 0 - else: - ci = 0 if hemi == 'lh' else 1 + actor = None + for ri, ci, _ in self._iter_views(hemi): self._renderer.subplot(ri, ci) - kwargs = { - "color": None, - "colormap": self._data['ctable'], - "vmin": dt_min, - "vmax": dt_max, - "opacity": alpha, - "scalars": np.zeros(len(self.geo[hemi].coords)), - } - if self._data[hemi]['mesh'] is None: - mesh_data = self._renderer.mesh( - x=self.geo[hemi].coords[:, 0], - y=self.geo[hemi].coords[:, 1], - z=self.geo[hemi].coords[:, 2], - triangles=self.geo[hemi].faces, - normals=self.geo[hemi].nn, - **kwargs, - ) - if isinstance(mesh_data, tuple): - actor, mesh = mesh_data - # add metadata to the mesh for picking - mesh._hemi = hemi - self.resolve_coincident_topology(actor) - else: - actor, mesh = mesh_data, None - self._data[hemi]['actor'] = actor + if hemi in ('lh', 'rh'): + if self._data[hemi]['actors'] is None: + self._data[hemi]['actors'] = list() + actor, mesh = self._add_surface_data(hemi) + self._data[hemi]['actors'].append(actor) self._data[hemi]['mesh'] = mesh else: - self._renderer.polydata( - self._data[hemi]['mesh'], - **kwargs, - ) - del kwargs + actor, _ = self._add_volume_data(hemi, src, volume_options) + assert actor is not None # should have added one # 2) update time and smoothing properties # set_data_smoothing calls "set_time_point" for us, which will set # _current_time self.set_time_interpolation(self.time_interpolation) - self.set_data_smoothing(smoothing_steps) + self.set_data_smoothing(self._data['smoothing_steps']) # 3) add the other actors - for ri, v in enumerate(self._views): - views_dict = lh_views_dict if hemi == 'lh' else rh_views_dict - if self._hemi != 'split': - ci = 0 - else: - ci = 0 if hemi == 'lh' else 1 - if not self._time_label_added and time_label is not None: + for ri, ci, v in self._iter_views(hemi): + self._renderer.subplot(ri, ci) + # Add the time label to the bottommost view + do = (ri == self._subplot_shape[0] - 1) + if not self._time_label_added and time_label is not None and do: time_actor = self._renderer.text2d( x_window=0.95, y_window=y_txt, color=self._fg_color, @@ -557,16 +552,67 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, ) self._data['time_actor'] = time_actor self._time_label_added = True - if colorbar and not self._colorbar_added: + if colorbar and not self._colorbar_added and do: self._renderer.scalarbar(source=actor, n_labels=8, color=self._fg_color, - bgcolor=(0.5, 0.5, 0.5)) + bgcolor=self._brain_color[:3]) self._colorbar_added = True - self._renderer.set_camera(azimuth=views_dict[v].azim, - elevation=views_dict[v].elev) - + self._renderer.set_camera(**views_dicts[hemi][v]) self._update() + def _iter_views(self, hemi): + # which rows and columns each type of visual needs to be added to + if self._hemi == 'split': + hemi_dict = dict(lh=[0], rh=[1], vol=[0, 1]) + else: + hemi_dict = dict(lh=[0], rh=[0], vol=[0]) + for vi, view in enumerate(self._views): + if self._hemi == 'split': + view_dict = dict(lh=[vi], rh=[vi], vol=[vi, vi]) + else: + view_dict = dict(lh=[vi], rh=[vi], vol=[vi]) + if self._view_layout == 'vertical': + rows = view_dict # views are rows + cols = hemi_dict # hemis are columns + else: + rows = hemi_dict # hemis are rows + cols = view_dict # views are columns + for ri, ci in zip(rows[hemi], cols[hemi]): + yield ri, ci, view + + def _add_surface_data(self, hemi): + rng = self._cmap_range + kwargs = { + "color": None, + "colormap": self._data['ctable'], + "vmin": rng[0], + "vmax": rng[1], + "opacity": self._data['alpha'], + "scalars": np.zeros(len(self.geo[hemi].coords)), + } + if self._data[hemi]['mesh'] is not None: + actor, mesh = self._renderer.polydata( + self._data[hemi]['mesh'], + **kwargs, + ) + return actor, mesh + mesh_data = self._renderer.mesh( + x=self.geo[hemi].coords[:, 0], + y=self.geo[hemi].coords[:, 1], + z=self.geo[hemi].coords[:, 2], + triangles=self.geo[hemi].faces, + normals=self.geo[hemi].nn, + **kwargs, + ) + if isinstance(mesh_data, tuple): + actor, mesh = mesh_data + # add metadata to the mesh for picking + mesh._hemi = hemi + self.resolve_coincident_topology(actor) + else: + actor, mesh = mesh_data, None + return actor, mesh + def remove_labels(self): """Remove all the ROI labels from the image.""" for data in self._label_data: @@ -574,6 +620,157 @@ def remove_labels(self): self._label_data.clear() self._update() + def _add_volume_data(self, hemi, src, volume_options): + _validate_type(src, SourceSpaces, 'src') + _check_option('src.kind', src.kind, ('volume',)) + _validate_type( + volume_options, (dict, 'numeric', None), 'volume_options') + assert hemi == 'vol' + if not isinstance(volume_options, dict): + volume_options = dict( + resolution=float(volume_options) if volume_options is not None + else None) + volume_options = _handle_default('volume_options', volume_options) + allowed_types = ( + ['resolution', (None, 'numeric')], + ['blending', (str,)], + ['alpha', ('numeric', None)], + ['surface_alpha', (None, 'numeric')]) + for key, types in allowed_types: + _validate_type(volume_options[key], types, + f'volume_options[{repr(key)}]') + extra_keys = set(volume_options) - set(a[0] for a in allowed_types) + if len(extra_keys): + raise ValueError( + f'volume_options got unknown keys {sorted(extra_keys)}') + _check_option('volume_options["blending"]', volume_options['blending'], + ('composite', 'mip')) + blending = volume_options['blending'] + alpha = volume_options['alpha'] + if alpha is None: + alpha = 0.4 if self._data[hemi]['array'].ndim == 3 else 1. + alpha = np.clip(float(alpha), 0., 1.) + surface_alpha = volume_options['surface_alpha'] + resolution = volume_options['resolution'] + if surface_alpha is None: + surface_alpha = min(alpha / 2., 0.4) + del volume_options + volume_pos = self._data[hemi].get('grid_volume_pos') + volume_neg = self._data[hemi].get('grid_volume_neg') + if volume_pos is None: + xyz = np.meshgrid( + *[np.arange(s) for s in src[0]['shape']], indexing='ij') + dimensions = np.array(src[0]['shape'], int) + mult = 1000 if self._units == 'mm' else 1 + src_mri_t = src[0]['src_mri_t']['trans'].copy() + src_mri_t[:3] *= mult + if resolution is not None: + resolution = resolution * mult / 1000. # to mm + del src, mult + coords = np.array([c.ravel(order='F') for c in xyz]).T + coords = apply_trans(src_mri_t, coords) + self.geo[hemi] = Bunch(coords=coords) + vertices = self._data[hemi]['vertices'] + assert self._data[hemi]['array'].shape[0] == len(vertices) + # MNE constructs the source space on a uniform grid in MRI space, + # but let's make sure + assert np.allclose( + src_mri_t[:3, :3], np.diag([src_mri_t[0, 0]] * 3)) + spacing = np.diag(src_mri_t)[:3] + origin = src_mri_t[:3, 3] - spacing / 2. + scalars = np.zeros(np.prod(dimensions)) + scalars[vertices] = 1. # for the outer mesh + grid, grid_mesh, volume_pos, volume_neg = \ + self._add_volume_object( + dimensions, origin, spacing, scalars, surface_alpha, + resolution, blending) + self._data[hemi]['alpha'] = alpha # incorrectly set earlier + self._data[hemi]['grid'] = grid + self._data[hemi]['grid_mesh'] = grid_mesh + self._data[hemi]['grid_coords'] = coords + self._data[hemi]['grid_src_mri_t'] = src_mri_t + self._data[hemi]['grid_shape'] = dimensions + self._data[hemi]['grid_volume_pos'] = volume_pos + self._data[hemi]['grid_volume_neg'] = volume_neg + actor_pos, _ = self._renderer.plotter.add_actor( + volume_pos, reset_camera=False, name=None, culling=False) + if volume_neg is not None: + actor_neg, _ = self._renderer.plotter.add_actor( + volume_neg, reset_camera=False, name=None, culling=False) + else: + actor_neg = None + grid_mesh = self._data[hemi]['grid_mesh'] + if grid_mesh is not None: + _, prop = self._renderer.plotter.add_actor( + grid_mesh, reset_camera=False, name=None, culling=False, + pickable=False) + prop.SetColor(*self._brain_color[:3]) + prop.SetOpacity(surface_alpha) + return actor_pos, actor_neg + + def _add_volume_object(self, dimensions, origin, spacing, scalars, + surface_alpha, resolution, blending): + # Now we can actually construct the visualization + import vtk + import pyvista as pv + grid = pv.UniformGrid() + grid.dimensions = dimensions + 1 # inject data on the cells + grid.origin = origin + grid.spacing = spacing + grid.cell_arrays['values'] = scalars + + # Add contour of enclosed volume (use GetOutput instead of + # GetOutputPort below to avoid updating) + grid_alg = vtk.vtkCellDataToPointData() + grid_alg.SetInputDataObject(grid) + grid_alg.SetPassCellData(False) + grid_alg.Update() + + if surface_alpha > 0: + grid_surface = vtk.vtkMarchingContourFilter() + grid_surface.ComputeNormalsOn() + grid_surface.ComputeScalarsOff() + grid_surface.SetInputData(grid_alg.GetOutput()) + grid_surface.SetValue(0, 0.1) + grid_surface.Update() + grid_mesh = vtk.vtkPolyDataMapper() + grid_mesh.SetInputData(grid_surface.GetOutput()) + else: + grid_mesh = None + + mapper = vtk.vtkSmartVolumeMapper() + if resolution is None: # native + mapper.SetScalarModeToUseCellData() + mapper.SetInputDataObject(grid) + else: + upsampler = vtk.vtkImageResample() + upsampler.SetInterpolationModeToNearestNeighbor() + upsampler.SetOutputSpacing(*([resolution] * 3)) + upsampler.SetInputConnection(grid_alg.GetOutputPort()) + mapper.SetInputConnection(upsampler.GetOutputPort()) + # Additive, AverageIntensity, and Composite might also be reasonable + remap = dict(composite='Composite', mip='MaximumIntensity') + getattr(mapper, f'SetBlendModeTo{remap[blending]}')() + volume_pos = vtk.vtkVolume() + volume_pos.SetMapper(mapper) + dist = grid.length / (np.mean(grid.dimensions) - 1) + volume_pos.GetProperty().SetScalarOpacityUnitDistance(dist) + if self._data['center'] is not None and blending == 'mip': + # We need to create a minimum intensity projection for the neg half + mapper_neg = vtk.vtkSmartVolumeMapper() + if resolution is None: # native + mapper_neg.SetScalarModeToUseCellData() + mapper_neg.SetInputDataObject(grid) + else: + mapper_neg.SetInputConnection(upsampler.GetOutputPort()) + mapper_neg.SetBlendModeToMinimumIntensity() + volume_neg = vtk.vtkVolume() + volume_neg.SetMapper(mapper_neg) + volume_neg.GetProperty().SetScalarOpacityUnitDistance(dist) + else: + volume_neg = None + return grid, grid_mesh, volume_pos, volume_neg + def add_label(self, label, color=None, alpha=1, scalar_thresh=None, borders=False, hemi=None, subdir=None): """Add an ROI label to the image. @@ -676,12 +873,7 @@ def add_label(self, label, color=None, alpha=1, scalar_thresh=None, cmap = np.array([(0, 0, 0, 0,), color]) ctable = np.round(cmap * 255).astype(np.uint8) - for ri, v in enumerate(self._views): - if self._hemi != 'split': - ci = 0 - else: - ci = 0 if hemi == 'lh' else 1 - views_dict = lh_views_dict if hemi == 'lh' else rh_views_dict + for ri, ci, v in self._iter_views(hemi): self._renderer.subplot(ri, ci) if borders: surface = { @@ -710,8 +902,7 @@ def add_label(self, label, color=None, alpha=1, scalar_thresh=None, actor, _ = mesh_data self.resolve_coincident_topology(actor) self._label_data.append(mesh_data) - self._renderer.set_camera(azimuth=views_dict[v].azim, - elevation=views_dict[v].elev) + self._renderer.set_camera(**views_dicts[hemi][v]) self._update() @@ -764,18 +955,12 @@ def add_foci(self, coords, coords_as_verts=False, map_surface=None, if self._units == 'm': scale_factor = scale_factor / 1000. - for ri, v in enumerate(self._views): - views_dict = lh_views_dict if hemi == 'lh' else rh_views_dict - if self._hemi != 'split': - ci = 0 - else: - ci = 0 if hemi == 'lh' else 1 + for ri, ci, v in self._iter_views(hemi): self._renderer.subplot(ri, ci) self._renderer.sphere(center=coords, color=color, scale=(10. * scale_factor), opacity=alpha) - self._renderer.set_camera(azimuth=views_dict[v].azim, - elevation=views_dict[v].elev) + self._renderer.set_camera(**views_dicts[hemi][v]) def add_text(self, x, y, text, name=None, color=None, opacity=1.0, row=-1, col=-1, font_size=None, justification=None): @@ -868,7 +1053,6 @@ def add_annotation(self, annot, borders=True, alpha=1, hemi=None, self._subject_id, 'label', ".".join([hemi, annot, 'annot'])) - print(filepath) if not os.path.exists(filepath): raise ValueError('Annotation file %s does not exist' % filepath) @@ -889,8 +1073,7 @@ def add_annotation(self, annot, borders=True, alpha=1, hemi=None, # Handle null labels properly cmap[:, 3] = 255 - # bgcolor = self._brain_color - bgcolor = [144, 144, 144, 255] + bgcolor = np.round(np.array(self._brain_color) * 255).astype(int) bgcolor[-1] = 0 cmap[cmap[:, 4] < 0, 4] += 2 ** 24 # wrap to positive cmap[cmap[:, 4] <= 0, :4] = bgcolor @@ -958,15 +1141,21 @@ def show_view(self, view=None, roll=None, distance=None, row=0, col=0, hemi=None): """Orient camera to display view.""" hemi = self._hemi if hemi is None else hemi - views_dict = lh_views_dict if hemi == 'lh' else rh_views_dict + if hemi == 'split': + if (self._view_layout == 'vertical' and col == 1 or + self._view_layout == 'horizontal' and row == 1): + hemi = 'rh' + else: + hemi = 'lh' if isinstance(view, str): - view = views_dict.get(view) - elif isinstance(view, dict): - view = View(azim=view['azimuth'], - elev=view['elevation']) + view = views_dicts[hemi].get(view) + view = view.copy() + if roll is not None: + view.update(roll=roll) + if distance is not None: + view.update(distance=distance) self._renderer.subplot(row, col) - self._renderer.set_camera(azimuth=view.azim, - elevation=view.elev) + self._renderer.set_camera(**view) self._renderer.reset_camera() self._update() @@ -1010,7 +1199,7 @@ def update_lut(self, fmin=None, fmid=None, fmax=None): fmax : float | None Maximum value in colormap. """ - from ..backends._pyvista import _set_colormap_range + from ..backends._pyvista import _set_colormap_range, _set_volume_range center = self._data['center'] colormap = self._data['colormap'] transparent = self._data['transparent'] @@ -1030,19 +1219,33 @@ def update_lut(self, fmin=None, fmid=None, fmax=None): # update our values rng = self._cmap_range ctable = self._data['ctable'] - for hemi in ['lh', 'rh']: + if self._colorbar_added: + scalar_bar = self._renderer.plotter.scalar_bar + else: + scalar_bar = None + for hemi in ['lh', 'rh', 'vol']: hemi_data = self._data.get(hemi) if hemi_data is not None: - if hemi_data['actor'] is not None: - actor = hemi_data['actor'] - if self._colorbar_added: - scalar_bar = self._renderer.plotter.scalar_bar - else: + if hemi_data.get('actors') is not None: + for actor in hemi_data['actors']: + _set_colormap_range(actor, ctable, scalar_bar, rng) + scalar_bar = None + + grid_volume_pos = hemi_data.get('grid_volume_pos') + grid_volume_neg = hemi_data.get('grid_volume_neg') + for grid_volume in (grid_volume_pos, grid_volume_neg): + if grid_volume is not None: + _set_volume_range( + grid_volume, ctable, hemi_data['alpha'], + scalar_bar, rng) scalar_bar = None - _set_colormap_range(actor, ctable, scalar_bar, rng) - glyph_actor = hemi_data['glyph_actor'] + + glyph_actor = hemi_data.get('glyph_actor') if glyph_actor is not None: - _set_colormap_range(glyph_actor, ctable, None, rng) + for glyph_actor_ in glyph_actor: + _set_colormap_range( + glyph_actor_, ctable, scalar_bar, rng) + scalar_bar = None def set_data_smoothing(self, n_steps): """Set the number of smoothing steps. @@ -1101,7 +1304,7 @@ def set_time_interpolation(self, interpolation): self._time_interp_inv = None if self._times is not None: idx = np.arange(self._n_times) - for hemi in ['lh', 'rh']: + for hemi in ['lh', 'rh', 'vol']: hemi_data = self._data.get(hemi) if hemi_data is not None: array = hemi_data['array'] @@ -1113,14 +1316,15 @@ def set_time_interpolation(self, interpolation): def set_time_point(self, time_idx): """Set the time point shown (can be a float to interpolate).""" from ..backends._pyvista import _set_mesh_scalars - current_act_data = list() + self._current_act_data = dict() time_actor = self._data.get('time_actor', None) time_label = self._data.get('time_label', None) - for hemi in ['lh', 'rh']: + for hemi in ['lh', 'rh', 'vol']: hemi_data = self._data.get(hemi) if hemi_data is not None: array = hemi_data['array'] # interpolate in time + vectors = None if array.ndim == 1: act_data = array self._current_time = 0 @@ -1131,39 +1335,60 @@ def set_time_point(self, time_idx): vectors = act_data act_data = np.linalg.norm(act_data, axis=1) self._current_time = self._time_interp_inv(time_idx) - current_act_data.append(act_data) + self._current_act_data[hemi] = act_data if time_actor is not None and time_label is not None: time_actor.SetInput(time_label(self._current_time)) + # update the volume interpolation + grid = hemi_data.get('grid') + if grid is not None: + vertices = self._data['vol']['vertices'] + values = self._current_act_data['vol'] + rng = self._cmap_range + fill = 0 if self._data['center'] is not None else rng[0] + grid.cell_arrays['values'].fill(fill) + # XXX for sided data, we probably actually need two + # volumes as composite/MIP needs to look at two + # extremes... for now just use abs. Eventually we can add + # two volumes if we want. + grid.cell_arrays['values'][vertices] = values + # interpolate in space smooth_mat = hemi_data.get('smooth_mat') if smooth_mat is not None: act_data = smooth_mat.dot(act_data) # update the mesh scalar values - if hemi_data['mesh'] is not None: + if hemi_data.get('mesh') is not None: mesh = hemi_data['mesh'] _set_mesh_scalars(mesh, act_data, 'Data') # update the glyphs - if array.ndim == 3: - self.update_glyphs(hemi, vectors) - self._current_act_data = np.concatenate(current_act_data) + if vectors is not None: + self._update_glyphs(hemi, vectors) + self._data['time_idx'] = time_idx self._update() - def update_glyphs(self, hemi, vectors): + def _update_glyphs(self, hemi, vectors): from ..backends._pyvista import _set_colormap_range hemi_data = self._data.get(hemi) - if hemi_data is not None: - vertices = hemi_data['vertices'] - ctable = self._data['ctable'] - vector_alpha = self._data['vector_alpha'] - scale_factor = self._data['scale_factor'] - rng = self._cmap_range - vertices = slice(None) if vertices is None else vertices - x, y, z = np.array(self.geo[hemi].coords)[vertices].T - + assert hemi_data is not None + vertices = hemi_data['vertices'] + vector_alpha = self._data['vector_alpha'] + scale_factor = self._data['scale_factor'] + vertices = slice(None) if vertices is None else vertices + x, y, z = np.array(self.geo[hemi].coords)[vertices].T + + if hemi_data['glyph_mesh'] is None: + add = True + hemi_data['glyph_mesh'] = list() + hemi_data['glyph_actor'] = list() + else: + add = False + count = 0 + for ri, ci, _ in self._iter_views(hemi): + self._renderer.subplot(ri, ci) polydata = self._renderer.quiver3d( x, y, z, vectors[:, 0], vectors[:, 1], vectors[:, 2], @@ -1175,16 +1400,20 @@ def update_glyphs(self, hemi, vectors): name=str(hemi) + "_glyph" ) if polydata is not None: - if hemi_data['glyph_mesh'] is None: + if not add: + glyph_actor = hemi_data['glyph_actor'][count] + glyph_mesh = hemi_data['glyph_mesh'][count] + glyph_mesh.shallow_copy(polydata) + else: glyph_actor, _ = self._renderer.polydata(polydata) + assert not isinstance(glyph_actor, list) glyph_actor.VisibilityOff() glyph_actor.GetProperty().SetLineWidth(2.) - hemi_data['glyph_actor'] = glyph_actor - hemi_data['glyph_mesh'] = polydata - else: - glyph_actor = hemi_data['glyph_actor'] - glyph_mesh = hemi_data['glyph_mesh'] - glyph_mesh.shallow_copy(polydata) + hemi_data['glyph_mesh'].append(polydata) + hemi_data['glyph_actor'].append(glyph_actor) + count += 1 + ctable = self._data['ctable'] + rng = self._cmap_range _set_colormap_range(glyph_actor, ctable, None, rng) # the glyphs are now ready to be displayed glyph_actor.VisibilityOn() @@ -1219,7 +1448,8 @@ def update_auto_scaling(self, restore=False): colormap = self._data['colormap'] transparent = self._data['transparent'] mapdata = _process_clim( - clim, colormap, transparent, self._current_act_data, + clim, colormap, transparent, + np.concatenate(list(self._current_act_data.values())), allow_pos_lims) diverging = 'pos_lims' in mapdata['clim'] colormap = mapdata['colormap'] @@ -1381,7 +1611,7 @@ def _show(self): except RuntimeError: logger.info("No active/running renderer available.") - def _check_hemi(self, hemi): + def _check_hemi(self, hemi, extras=()): """Check for safe single-hemi input, returns str.""" if hemi is None: if self._hemi not in ['lh', 'rh']: @@ -1389,7 +1619,7 @@ def _check_hemi(self, hemi): 'hemispheres are displayed') else: hemi = self._hemi - elif hemi not in ['lh', 'rh']: + elif hemi not in ['lh', 'rh'] + list(extras): extra = ' or None' if self._hemi in ['lh', 'rh'] else '' raise ValueError('hemi must be either "lh" or "rh"' + extra + ", got " + str(hemi)) @@ -1441,15 +1671,16 @@ def scale_data_colormap(self, fmin, fmid, fmax, transparent, for hemi in ['lh', 'rh']: hemi_data = self._data.get(hemi) if hemi_data is not None: - if hemi_data['actor'] is not None: - actor = hemi_data['actor'] - vtk_lut = actor.GetMapper().GetLookupTable() - vtk_lut.SetNumberOfColors(n_col) - vtk_lut.SetRange([fmin, fmax]) - vtk_lut.Build() - for i in range(0, n_col): - lt = lut_lst[i] - vtk_lut.SetTableValue(i, lt[0], lt[1], lt[2], alpha) + if hemi_data['actors'] is not None: + for actor in hemi_data['actors']: + vtk_lut = actor.GetMapper().GetLookupTable() + vtk_lut.SetNumberOfColors(n_col) + vtk_lut.SetRange([fmin, fmax]) + vtk_lut.Build() + for i in range(0, n_col): + lt = lut_lst[i] + vtk_lut.SetTableValue( + i, lt[0], lt[1], lt[2], alpha) self._update_fscale(1.0) def enable_depth_peeling(self): diff --git a/mne/viz/_brain/_timeviewer.py b/mne/viz/_brain/_timeviewer.py index 11959a1d937..eb0e690e4b1 100644 --- a/mne/viz/_brain/_timeviewer.py +++ b/mne/viz/_brain/_timeviewer.py @@ -13,12 +13,17 @@ import warnings import numpy as np +from scipy import sparse from . import _Brain -from ..utils import _check_option, _show_help, _get_color_list, tight_layout +from .view import _lh_views_dict + +from ..utils import _show_help, _get_color_list, tight_layout from ...externals.decorator import decorator -from ...source_space import vertex_to_mni -from ...utils import _ReuseCycle, warn, copy_doc +from ...source_space import vertex_to_mni, _read_talxfm +from ...transforms import apply_trans +from ...utils import _ReuseCycle, warn, copy_doc, _validate_type +from ...fixes import nullcontext @decorator @@ -43,13 +48,23 @@ class MplCanvas(object): def __init__(self, time_viewer, width, height, dpi): from PyQt5 import QtWidgets + from matplotlib import rc_context from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg if time_viewer.separate_canvas: parent = None else: parent = time_viewer.window - self.fig = Figure(figsize=(width, height), dpi=dpi) + # prefer constrained layout here but live with tight_layout otherwise + context = nullcontext + extra_events = ('resize',) + try: + context = rc_context({'figure.constrained_layout.use': True}) + extra_events = () + except KeyError: + pass + with context: + self.fig = Figure(figsize=(width, height), dpi=dpi) self.canvas = FigureCanvasQTAgg(self.fig) self.axes = self.fig.add_subplot(111) self.axes.set(xlabel='Time (sec)', ylabel='Activation (AU)') @@ -60,11 +75,9 @@ def __init__(self, time_viewer, width, height, dpi): QtWidgets.QSizePolicy.Expanding ) FigureCanvasQTAgg.updateGeometry(self.canvas) - # XXX eventually this should be called in the window resize callback - tight_layout(fig=self.axes.figure) self.time_viewer = time_viewer self.time_func = time_viewer.time_call - for event in ('button_press', 'motion_notify'): + for event in ('button_press', 'motion_notify') + extra_events: self.canvas.mpl_connect( event + '_event', getattr(self, 'on_' + event)) @@ -89,7 +102,9 @@ def update_plot(self): facecolor=self.time_viewer.brain._bg_color) for text in leg.get_texts(): text.set_color(self.time_viewer.brain._fg_color) - self.canvas.draw() + with warnings.catch_warnings(record=True): + warnings.filterwarnings('ignore', 'constrained_layout') + self.canvas.draw() def set_color(self, bg_color, fg_color): """Set the widget colors.""" @@ -123,6 +138,10 @@ def on_button_press(self, event): on_motion_notify = on_button_press # for now they can be the same + def on_resize(self, event): + """Handle resize events.""" + tight_layout(fig=self.axes.figure) + class IntSlider(object): """Class to set a integer slider.""" @@ -311,17 +330,10 @@ def __init__(self, brain, show_traces=False): _require_minimum_version('0.24') # shared configuration + if hasattr(brain, 'time_viewer'): + raise RuntimeError('brain already has a TimeViewer') self.brain = brain - self.orientation = [ - 'lateral', - 'medial', - 'rostral', - 'caudal', - 'dorsal', - 'ventral', - 'frontal', - 'parietal' - ] + self.orientation = list(_lh_views_dict.keys()) self.default_smoothing_range = [0, 15] # detect notebook @@ -340,9 +352,10 @@ def __init__(self, brain, show_traces=False): self.default_playback_speed_range = [0.01, 1] self.default_playback_speed_value = 0.05 self.default_status_bar_msg = "Press ? for help" - self.act_data_smooth = {'lh': None, 'rh': None} + all_keys = ('lh', 'rh', 'vol') + self.act_data_smooth = {key: (None, None) for key in all_keys} self.color_cycle = None - self.picked_points = {'lh': list(), 'rh': list()} + self.picked_points = {key: list() for key in all_keys} self._mouse_no_mvt = -1 self.icons = dict() self.actions = dict() @@ -365,16 +378,28 @@ def __init__(self, brain, show_traces=False): # Derived parameters: self.playback_speed = self.default_playback_speed_value - _check_option('show_traces', type(show_traces), [bool, str]) - if isinstance(show_traces, str) and show_traces == "separate": + _validate_type(show_traces, (bool, str, 'numeric'), 'show_traces') + self.interactor_fraction = 0.25 + if isinstance(show_traces, str): + assert 'show_traces' == 'separate' # should be guaranteed earlier self.show_traces = True self.separate_canvas = True else: - self.show_traces = show_traces + if isinstance(show_traces, bool): + self.show_traces = show_traces + else: + show_traces = float(show_traces) + if not 0 < show_traces < 1: + raise ValueError( + 'show traces, if numeric, must be between 0 and 1, ' + f'got {show_traces}') + self.show_traces = True + self.interactor_fraction = show_traces self.separate_canvas = False + del show_traces + self._spheres = list() self.load_icons() - self.interactor_stretch = 3 self.configure_time_label() self.configure_sliders() self.configure_scalar_bar() @@ -391,17 +416,26 @@ def __init__(self, brain, show_traces=False): @contextlib.contextmanager def ensure_minimum_sizes(self): + from ..backends._pyvista import _process_events sz = self.brain._size adjust_mpl = self.show_traces and not self.separate_canvas if not adjust_mpl: yield else: - self.mpl_canvas.canvas.setMinimumSize( - sz[0], int(round(sz[1] / self.interactor_stretch))) + mpl_h = int(round((sz[1] * self.interactor_fraction) / + (1 - self.interactor_fraction))) + self.mpl_canvas.canvas.setMinimumSize(sz[0], mpl_h) try: yield finally: + self.splitter.setSizes([sz[1], mpl_h]) self.mpl_canvas.canvas.setMinimumSize(0, 0) + _process_events(self.plotter) + for hemi in ('lh', 'rh'): + if hemi == 'rh' and self.brain._hemi == 'split': + continue + for ri, ci, v in self.brain._iter_views(hemi): + self.brain.show_view(view=v, row=ri, col=ci) def toggle_interface(self, value=None): if value is None: @@ -618,21 +652,13 @@ def configure_scalar_bar(self): def configure_sliders(self): rng = _get_range(self.brain) # Orientation slider - # default: put orientation slider on the first view - if self.brain._hemi in ('split', 'both'): - self.plotter.subplot(0, 0) - # Use 'lh' as a reference for orientation for 'both' if self.brain._hemi == 'both': hemis_ref = ['lh'] else: hemis_ref = self.brain._hemis for hemi in hemis_ref: - if self.brain._hemi == 'split': - ci = 0 if hemi == 'lh' else 1 - else: - ci = 0 - for ri, view in enumerate(self.brain._views): + for ri, ci, view in self.brain._iter_views(hemi): self.plotter.subplot(ri, ci) self.orientation_call = ShowView( plotter=self.plotter, @@ -655,9 +681,9 @@ def configure_sliders(self): self.set_slider_style(orientation_slider, show_label=False) self.orientation_call(view, update_widget=True) - # necessary because show_view modified subplot - if self.brain._hemi in ('split', 'both'): - self.plotter.subplot(0, 0) + # Put other sliders on the bottom right view + ri, ci = np.array(self.brain._subplot_shape) - 1 + self.plotter.subplot(ri, ci) # Smoothing slider self.smoothing_call = IntSlider( @@ -829,54 +855,90 @@ def configure_point_picking(self): self.color_cycle = _ReuseCycle(_get_color_list()) win = self.plotter.app_window dpi = win.windowHandle().screen().logicalDotsPerInch() - w, h = win.geometry().width() / dpi, win.geometry().height() / dpi - h /= 3 # one third of the window - self.mpl_canvas = MplCanvas(self, w, h, dpi) + ratio = (1 - self.interactor_fraction) / self.interactor_fraction + w = self.interactor.geometry().width() + h = self.interactor.geometry().height() / ratio + # Get the fractional components for the brain and mpl + self.mpl_canvas = MplCanvas(self, w / dpi, h / dpi, dpi) xlim = [np.min(self.brain._data['time']), np.max(self.brain._data['time'])] with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) self.mpl_canvas.axes.set(xlim=xlim) - vlayout = self.plotter.frame.layout() if not self.separate_canvas: - vlayout.addWidget(self.mpl_canvas.canvas) - vlayout.setStretch(0, self.interactor_stretch) - vlayout.setStretch(1, 1) + from PyQt5.QtWidgets import QSplitter + from PyQt5.QtCore import Qt + canvas = self.mpl_canvas.canvas + vlayout = self.plotter.frame.layout() + vlayout.removeWidget(self.interactor) + self.splitter = splitter = QSplitter( + orientation=Qt.Vertical, parent=self.plotter.frame) + vlayout.addWidget(splitter) + splitter.addWidget(self.interactor) + splitter.addWidget(canvas) self.mpl_canvas.set_color( bg_color=self.brain._bg_color, fg_color=self.brain._fg_color, ) self.mpl_canvas.show() - # get brain data - for idx, hemi in enumerate(['lh', 'rh']): + # get data for each hemi + for idx, hemi in enumerate(['vol', 'lh', 'rh']): hemi_data = self.brain._data.get(hemi) if hemi_data is not None: act_data = hemi_data['array'] if act_data.ndim == 3: act_data = np.linalg.norm(act_data, axis=1) smooth_mat = hemi_data.get('smooth_mat') + vertices = hemi_data['vertices'] + if hemi == 'vol': + assert smooth_mat is None + smooth_mat = sparse.csr_matrix( + (np.ones(len(vertices)), + (vertices, np.arange(len(vertices))))) self.act_data_smooth[hemi] = (act_data, smooth_mat) - # simulate a picked renderer - if self.brain._hemi == 'split': - self.picked_renderer = self.plotter.renderers[idx] - else: - self.picked_renderer = self.plotter.renderers[0] - - # initialize the default point - color = next(self.color_cycle) - ind = np.unravel_index( - np.argmax(self.act_data_smooth[hemi][0], axis=None), - self.act_data_smooth[hemi][0].shape - ) - vertex_id = hemi_data['vertices'][ind[0]] - mesh = hemi_data['mesh'] - line = self.plot_time_course(hemi, vertex_id, color) - self.add_point(hemi, mesh, vertex_id, line, color) + # plot the GFP + y = np.concatenate(list(v[0] for v in self.act_data_smooth.values() + if v[0] is not None)) + y = np.linalg.norm(y, axis=0) / np.sqrt(len(y)) + self.mpl_canvas.axes.plot( + self.brain._data['time'], y, + lw=3, label='GFP', zorder=3, color=self.brain._fg_color, + alpha=0.5, ls=':') + # now plot the time line self.plot_time_line() + # then the picked points + for idx, hemi in enumerate(['lh', 'rh', 'vol']): + act_data = self.act_data_smooth.get(hemi, [None])[0] + if act_data is None: + continue + hemi_data = self.brain._data[hemi] + vertices = hemi_data['vertices'] + + # simulate a picked renderer + if self.brain._hemi in ('both', 'rh') or hemi == 'vol': + idx = 0 + self.picked_renderer = self.plotter.renderers[idx] + + # initialize the default point + if self.brain._data['initial_time'] is not None: + # pick at that time + use_data = act_data[ + :, [np.round(self.brain._data['time_idx']).astype(int)]] + else: + use_data = act_data + ind = np.unravel_index(np.argmax(np.abs(use_data), axis=None), + use_data.shape) + if hemi == 'vol': + mesh = hemi_data['grid'] + else: + mesh = hemi_data['mesh'] + vertex_id = vertices[ind[0]] + self.add_point(hemi, mesh, vertex_id) + _update_picking_callback( self.plotter, self.on_mouse_move, @@ -987,17 +1049,79 @@ def on_button_release(self, vtk_picker, event): self._mouse_no_mvt = 0 def on_pick(self, vtk_picker, event): + # vtk_picker is a vtkCellPicker cell_id = vtk_picker.GetCellId() mesh = vtk_picker.GetDataSet() - if mesh is None or cell_id == -1: - return - + if mesh is None or cell_id == -1 or not self._mouse_no_mvt: + return # don't pick + + # 1) Check to see if there are any spheres along the ray + if len(self._spheres): + collection = vtk_picker.GetProp3Ds() + found_sphere = None + for ii in range(collection.GetNumberOfItems()): + actor = collection.GetItemAsObject(ii) + for sphere in self._spheres: + if any(a is actor for a in sphere._actors): + found_sphere = sphere + break + if found_sphere is not None: + break + if found_sphere is not None: + assert found_sphere._is_point + mesh = found_sphere + + # 2) Remove sphere if it's what we have if hasattr(mesh, "_is_point"): self.remove_point(mesh) - elif self._mouse_no_mvt: + return + + # 3) Otherwise, pick the objects in the scene + try: hemi = mesh._hemi - pos = vtk_picker.GetPickPosition() + except AttributeError: # volume + hemi = 'vol' + else: + assert hemi in ('lh', 'rh') + if self.act_data_smooth[hemi][0] is None: # no data to add for hemi + return + pos = np.array(vtk_picker.GetPickPosition()) + if hemi == 'vol': + # VTK will give us the point closest to the viewer in the vol. + # We want to pick the point with the maximum value along the + # camera-to-click array, which fortunately we can get "just" + # by inspecting the points that are sufficiently close to the + # ray. + grid = mesh = self.brain._data[hemi]['grid'] + vertices = self.brain._data[hemi]['vertices'] + coords = self.brain._data[hemi]['grid_coords'][vertices] + scalars = grid.cell_arrays['values'][vertices] + spacing = np.array(grid.GetSpacing()) + max_dist = np.linalg.norm(spacing) / 2. + origin = vtk_picker.GetRenderer().GetActiveCamera().GetPosition() + ori = pos - origin + ori /= np.linalg.norm(ori) + # the magic formula: distance from a ray to a given point + dists = np.linalg.norm(np.cross(ori, coords - pos), axis=1) + assert dists.shape == (len(coords),) + mask = dists <= max_dist + idx = np.where(mask)[0] + if len(idx) == 0: + return # weird point on edge of volume? + # useful for debugging the ray by mapping it into the volume: + # dists = dists - dists.min() + # dists = (1. - dists / dists.max()) * self.brain._cmap_range[1] + # grid.cell_arrays['values'][vertices] = dists * mask + idx = idx[np.argmax(np.abs(scalars[idx]))] + vertex_id = vertices[idx] + # Naive way: convert pos directly to idx; i.e., apply mri_src_t + # shape = self.brain._data[hemi]['grid_shape'] + # taking into account the cell vs point difference (spacing/2) + # shift = np.array(grid.GetOrigin()) + spacing / 2. + # ijk = np.round((pos - shift) / spacing).astype(int) + # vertex_id = np.ravel_multi_index(ijk, shape, order='F') + else: vtk_cell = mesh.GetCell(cell_id) cell = [vtk_cell.GetPointId(point_id) for point_id in range(vtk_cell.GetNumberOfPoints())] @@ -1005,18 +1129,29 @@ def on_pick(self, vtk_picker, event): idx = np.argmin(abs(vertices - pos), axis=0) vertex_id = cell[idx[0]] - if vertex_id not in self.picked_points[hemi]: - color = next(self.color_cycle) - - # update associated time course - line = self.plot_time_course(hemi, vertex_id, color) + if vertex_id not in self.picked_points[hemi]: + self.add_point(hemi, mesh, vertex_id) - # add glyph at picked point - self.add_point(hemi, mesh, vertex_id, line, color) - - def add_point(self, hemi, mesh, vertex_id, line, color): + def add_point(self, hemi, mesh, vertex_id): from ..backends._pyvista import _sphere - center = mesh.GetPoints().GetPoint(vertex_id) + color = next(self.color_cycle) + line = self.plot_time_course(hemi, vertex_id, color) + if hemi == 'vol': + ijk = np.unravel_index( + vertex_id, np.array(mesh.GetDimensions()) - 1, order='F') + # should just be GetCentroid(center), but apparently it's VTK9+: + # center = np.empty(3) + # voxel.GetCentroid(center) + voxel = mesh.GetCell(*ijk) + pts = voxel.GetPoints() + n_pts = pts.GetNumberOfPoints() + center = np.empty((n_pts, 3)) + for ii in range(pts.GetNumberOfPoints()): + pts.GetPoint(ii, center[ii]) + center = np.mean(center, axis=0) + else: + center = mesh.GetPoints().GetPoint(vertex_id) + del mesh # from the picked renderer to the subplot coords rindex = self.plotter.renderers.index(self.picked_renderer) @@ -1024,8 +1159,8 @@ def add_point(self, hemi, mesh, vertex_id, line, color): actors = list() spheres = list() - for ri, view in enumerate(self.brain._views): - self.plotter.subplot(ri, col) + for ri, ci, _ in self.brain._iter_views(hemi): + self.plotter.subplot(ri, ci) # Using _sphere() instead of renderer.sphere() for 2 reasons: # 1) renderer.sphere() fails on Windows in a scenario where a lot # of picking requests are done in a short span of time (could be @@ -1049,16 +1184,14 @@ def add_point(self, hemi, mesh, vertex_id, line, color): sphere._actors = actors sphere._color = color sphere._vertex_id = vertex_id + sphere._spheres = spheres self.picked_points[hemi].append(vertex_id) - - # this is used for testing only - if hasattr(self, "_spheres"): - self._spheres += spheres - else: - self._spheres = spheres + self._spheres.extend(spheres) def remove_point(self, mesh): + if mesh._spheres is None: + return # already removed mesh._line.remove() self.mpl_canvas.update_plot() self.picked_points[mesh._hemi].remove(mesh._vertex_id) @@ -1067,30 +1200,42 @@ def remove_point(self, mesh): # entire color cycle warnings.simplefilter('ignore') self.color_cycle.restore(mesh._color) + # remove all actors self.plotter.remove_actor(mesh._actors) mesh._actors = None + # remove all meshes from sphere list + for sphere in list(mesh._spheres): # includes itself, so copy + self._spheres.pop(self._spheres.index(sphere)) + sphere._spheres = sphere._actors = None def clear_points(self): - if hasattr(self, "_spheres"): - for sphere in self._spheres: - vertex_id = sphere._vertex_id - hemi = sphere._hemi - if vertex_id in self.picked_points[hemi]: - self.remove_point(sphere) - self._spheres.clear() + for sphere in list(self._spheres): # will remove itself, so copy + self.remove_point(sphere) + assert sum(len(v) for v in self.picked_points.values()) == 0 + assert len(self._spheres) == 0 def plot_time_course(self, hemi, vertex_id, color): if not hasattr(self, "mpl_canvas"): return time = self.brain._data['time'].copy() # avoid circular ref - hemi_str = 'L' if hemi == 'lh' else 'R' - hemi_int = 0 if hemi == 'lh' else 1 - mni = vertex_to_mni( - vertices=vertex_id, - hemis=hemi_int, - subject=self.brain._subject_id, - subjects_dir=self.brain._subjects_dir - ) + if hemi == 'vol': + hemi_str = 'V' + xfm = _read_talxfm( + self.brain._subject_id, self.brain._subjects_dir) + if self.brain._units == 'm': + xfm['trans'][:3, 3] /= 1000. + ijk = np.unravel_index( + vertex_id, self.brain._data[hemi]['grid_shape'], order='F') + src_mri_t = self.brain._data[hemi]['grid_src_mri_t'] + mni = apply_trans(np.dot(xfm['trans'], src_mri_t), ijk) + else: + hemi_str = 'L' if hemi == 'lh' else 'R' + mni = vertex_to_mni( + vertices=vertex_id, + hemis=0 if hemi == 'lh' else 1, + subject=self.brain._subject_id, + subjects_dir=self.brain._subjects_dir + ) label = "{}:{} MNI: {}".format( hemi_str, str(vertex_id).ljust(6), ', '.join('%5.1f' % m for m in mni)) @@ -1104,7 +1249,8 @@ def plot_time_course(self, hemi, vertex_id, color): act_data, label=label, lw=1., - color=color + color=color, + zorder=4, ) return line @@ -1121,9 +1267,8 @@ def plot_time_line(self): color=self.brain._fg_color, lw=1, ) - else: - self.time_line.set_xdata(current_time) - self.mpl_canvas.update_plot() + self.time_line.set_xdata(current_time) + self.mpl_canvas.update_plot() def help(self): pairs = [ @@ -1192,9 +1337,8 @@ def clean(self): self.mpl_canvas = None self.time_actor = None self.picked_renderer = None - self.act_data_smooth["lh"] = None - self.act_data_smooth["rh"] = None - self.act_data_smooth = None + for key in list(self.act_data_smooth.keys()): + self.act_data_smooth[key] = None class _LinkViewer(object): @@ -1286,7 +1430,7 @@ def _update_camera(vtk_picker, event): def _get_range(brain): - val = np.abs(brain._current_act_data) + val = np.abs(np.concatenate(list(brain._current_act_data.values()))) return [np.min(val), np.max(val)] diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 746a9211ff9..2ecbc1dfd1f 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -14,14 +14,17 @@ import numpy as np from numpy.testing import assert_allclose -from mne import SourceEstimate, read_source_estimate -from mne.source_space import read_source_spaces, vertex_to_mni +from mne import (read_source_estimate, SourceEstimate, MixedSourceEstimate, + VolSourceEstimate) +from mne.source_space import (read_source_spaces, vertex_to_mni, + setup_volume_source_space) from mne.datasets import testing from mne.utils import check_version from mne.viz._brain import _Brain, _TimeViewer, _LinkViewer, _BrainScraper from mne.viz._brain.colormap import calculate_lut from matplotlib import cm, image +import matplotlib.pyplot as plt data_path = testing.data_path(download=False) subject_id = 'sample' @@ -32,13 +35,27 @@ 'sample-oct-6-src.fif') +class _Collection(object): + def __init__(self, actors): + self._actors = actors + + def GetNumberOfItems(self): + return len(self._actors) + + def GetItemAsObject(self, ii): + return self._actors[ii] + + class TstVTKPicker(object): """Class to test cell picking.""" - def __init__(self, mesh, cell_id): + def __init__(self, mesh, cell_id, hemi, brain): self.mesh = mesh self.cell_id = cell_id self.point_id = None + self.hemi = hemi + self.brain = brain + self._actors = () def GetCellId(self): """Return the picked cell.""" @@ -50,15 +67,33 @@ def GetDataSet(self): def GetPickPosition(self): """Return the picked position.""" - vtk_cell = self.mesh.GetCell(self.cell_id) - cell = [vtk_cell.GetPointId(point_id) for point_id - in range(vtk_cell.GetNumberOfPoints())] - self.point_id = cell[0] - return self.mesh.points[self.point_id] + if self.hemi == 'vol': + self.point_id = self.cell_id + return self.brain._data['vol']['grid_coords'][self.cell_id] + else: + vtk_cell = self.mesh.GetCell(self.cell_id) + cell = [vtk_cell.GetPointId(point_id) for point_id + in range(vtk_cell.GetNumberOfPoints())] + self.point_id = cell[0] + return self.mesh.points[self.point_id] + + def GetProp3Ds(self): + """Return all picked Prop3Ds.""" + return _Collection(self._actors) + + def GetRenderer(self): + """Return the "renderer".""" + return self # set this to also be the renderer and active camera + + GetActiveCamera = GetRenderer + + def GetPosition(self): + """Return the position.""" + return np.array(self.GetPickPosition()) - (0, 0, 100) @testing.requires_testing_data -def test_brain(renderer): +def test_brain_init(renderer, tmpdir, pixel_ratio): """Test initialization of the _Brain instance.""" from mne.label import read_label hemi = 'lh' @@ -67,18 +102,21 @@ def test_brain(renderer): title = 'test' size = (300, 300) + kwargs = dict(subject_id=subject_id, subjects_dir=subjects_dir) with pytest.raises(ValueError, match='"size" parameter must be'): - _Brain(subject_id=subject_id, hemi=hemi, surf=surf, size=[1, 2, 3]) + _Brain(hemi=hemi, surf=surf, size=[1, 2, 3], **kwargs) with pytest.raises(TypeError, match='figure'): - _Brain(subject_id=subject_id, hemi=hemi, surf=surf, figure='foo') + _Brain(hemi=hemi, surf=surf, figure='foo', **kwargs) + with pytest.raises(TypeError, match='interaction'): + _Brain(hemi=hemi, surf=surf, interaction=0, **kwargs) with pytest.raises(ValueError, match='interaction'): - _Brain(subject_id=subject_id, hemi=hemi, surf=surf, interaction=0) + _Brain(hemi=hemi, surf=surf, interaction='foo', **kwargs) with pytest.raises(KeyError): - _Brain(subject_id=subject_id, hemi='foo', surf=surf) + _Brain(hemi='foo', surf=surf, **kwargs) - brain = _Brain(subject_id, hemi=hemi, surf=surf, size=size, - subjects_dir=subjects_dir, title=title, - cortex=cortex) + brain = _Brain(hemi=hemi, surf=surf, size=size, title=title, + cortex=cortex, units='m', **kwargs) + assert brain.interaction == 'trackball' # add_data stc = read_source_estimate(fname_stc) fmin = stc.data.min() @@ -110,19 +148,47 @@ def test_brain(renderer): with pytest.raises(ValueError): brain.add_data(hemi_data, fmin=fmin, hemi=hemi, fmax=fmax, vertices=None) + with pytest.raises(ValueError, match='has shape'): + brain.add_data(hemi_data[:, np.newaxis], fmin=fmin, hemi=hemi, + fmax=fmax, vertices=None, time=[0, 1]) brain.add_data(hemi_data, fmin=fmin, hemi=h, fmax=fmax, colormap='hot', vertices=hemi_vertices, smoothing_steps='nearest', colorbar=False, time=None) - brain.add_data(hemi_data, fmin=fmin, hemi=h, fmax=fmax, + assert brain.data['lh']['array'] is hemi_data + assert brain.views == ['lateral'] + assert brain.hemis == ('lh',) + brain.add_data(hemi_data[:, np.newaxis], fmin=fmin, hemi=h, fmax=fmax, colormap='hot', vertices=hemi_vertices, smoothing_steps=1, initial_time=0., colorbar=False, - time=None) - + time=[0]) + brain.set_time_point(0) # should hit _safe_interp1d + + with pytest.raises(ValueError, match='consistent with'): + brain.add_data(hemi_data[:, np.newaxis], fmin=fmin, hemi=h, + fmax=fmax, colormap='hot', vertices=hemi_vertices, + smoothing_steps='nearest', colorbar=False, + time=[1]) + with pytest.raises(ValueError, match='different from'): + brain.add_data(hemi_data[:, np.newaxis][:, [0, 0]], + fmin=fmin, hemi=h, fmax=fmax, colormap='hot', + vertices=hemi_vertices) + with pytest.raises(ValueError, match='need shape'): + brain.add_data(hemi_data[:, np.newaxis], time=[0, 1], + fmin=fmin, hemi=h, fmax=fmax, colormap='hot', + vertices=hemi_vertices) + with pytest.raises(ValueError, match='If array has 3'): + brain.add_data(hemi_data[:, np.newaxis, np.newaxis], + fmin=fmin, hemi=h, fmax=fmax, colormap='hot', + vertices=hemi_vertices) # add label label = read_label(fname_label) brain.add_label(label, scalar_thresh=0.) brain.remove_labels() + brain.add_label(fname_label) + brain.add_label('V1', borders=True) + brain.remove_labels() + brain.remove_labels() # add foci brain.add_foci([0], coords_as_verts=True, @@ -131,21 +197,31 @@ def test_brain(renderer): # add text brain.add_text(x=0, y=0, text='foo') - # screenshot - brain.show_view(view=dict(azimuth=180., elevation=90.)) - img = brain.screenshot(mode='rgb') - assert_allclose(img.shape, (size[0], size[1], 3), - atol=70) # XXX undo once size is fixed - # add annotation - annots = ['aparc', 'PALS_B12_Lobes'] + annots = ['aparc', path.join(subjects_dir, 'fsaverage', 'label', + 'lh.PALS_B12_Lobes.annot')] borders = [True, 2] alphas = [1, 0.5] + colors = [None, 'r'] brain = _Brain(subject_id='fsaverage', hemi=hemi, size=size, surf='inflated', subjects_dir=subjects_dir) - for a, b, p in zip(annots, borders, alphas): - brain.add_annotation(a, b, p) + for a, b, p, color in zip(annots, borders, alphas, colors): + brain.add_annotation(a, b, p, color=color) + brain.show_view(dict(focalpoint=(1e-5, 1e-5, 1e-5)), roll=1, distance=500) + + # image and screenshot + fname = path.join(str(tmpdir), 'test.png') + assert not path.isfile(fname) + brain.save_image(fname) + assert path.isfile(fname) + brain.show_view(view=dict(azimuth=180., elevation=90.)) + img = brain.screenshot(mode='rgb') + if renderer._get_3d_backend() == 'mayavi': + pixel_ratio = 1. # no HiDPI when using the testing backend + want_size = np.array([size[0] * pixel_ratio, size[1] * pixel_ratio, 3]) + assert_allclose(img.shape, want_size, + atol=70 * pixel_ratio) # XXX undo once size is fixed brain.close() @@ -155,7 +231,7 @@ def test_brain_save_movie(tmpdir, renderer): """Test saving a movie of a _Brain instance.""" if renderer._get_3d_backend() == "mayavi": pytest.skip('Save movie only supported on PyVista') - brain_data = _create_testing_brain(hemi='lh') + brain_data = _create_testing_brain(hemi='lh', time_viewer=False) filename = str(path.join(tmpdir, "brain_test.mov")) brain_data.save_movie(filename, time_dilation=1, interpolation='nearest') @@ -164,13 +240,15 @@ def test_brain_save_movie(tmpdir, renderer): @testing.requires_testing_data -def test_brain_timeviewer(renderer_interactive): +def test_brain_timeviewer(renderer_interactive, pixel_ratio): """Test _TimeViewer primitives.""" if renderer_interactive._get_3d_backend() != 'pyvista': pytest.skip('TimeViewer tests only supported on PyVista') - brain_data = _create_testing_brain(hemi='both') + brain_data = _create_testing_brain(hemi='both', show_traces=False) - time_viewer = _TimeViewer(brain_data) + with pytest.raises(RuntimeError, match='already'): + _TimeViewer(brain_data) + time_viewer = brain_data.time_viewer time_viewer.time_call(value=0) time_viewer.orientation_call(value='lat', update_widget=True) time_viewer.orientation_call(value='medial', update_widget=True) @@ -185,11 +263,16 @@ def test_brain_timeviewer(renderer_interactive): time_viewer.toggle_playback() time_viewer.apply_auto_scaling() time_viewer.restore_user_scaling() + plt.close('all') + time_viewer.help() + assert len(plt.get_fignums()) == 1 + plt.close('all') # screenshot brain_data.show_view(view=dict(azimuth=180., elevation=90.)) img = brain_data.screenshot(mode='rgb') - assert(img.shape == (300, 300, 3)) + want_shape = np.array([300 * pixel_ratio, 300 * pixel_ratio, 3]) + assert_allclose(img.shape, want_shape) @testing.requires_testing_data @@ -199,45 +282,80 @@ def test_brain_timeviewer(renderer_interactive): pytest.param('split', marks=pytest.mark.slowtest), pytest.param('both', marks=pytest.mark.slowtest), ]) -def test_brain_timeviewer_traces(renderer_interactive, hemi, tmpdir): +@pytest.mark.parametrize('src', [ + 'surface', + pytest.param('volume', marks=pytest.mark.slowtest), + pytest.param('mixed', marks=pytest.mark.slowtest), +]) +def test_brain_timeviewer_traces(renderer_interactive, hemi, src, tmpdir): """Test _TimeViewer traces.""" if renderer_interactive._get_3d_backend() != 'pyvista': pytest.skip('Only PyVista supports traces') - brain_data = _create_testing_brain(hemi=hemi) - time_viewer = _TimeViewer(brain_data, show_traces=True) + brain_data = _create_testing_brain( + hemi=hemi, surf='white', src=src, show_traces=0.5, initial_time=0, + volume_options=None, # for speed, don't upsample + ) + with pytest.raises(RuntimeError, match='already'): + _TimeViewer(brain_data) + time_viewer = brain_data.time_viewer + assert time_viewer.show_traces assert hasattr(time_viewer, "picked_points") assert hasattr(time_viewer, "_spheres") # test points picked by default picked_points = brain_data.get_picked_points() spheres = time_viewer._spheres - hemi_str = [hemi] if hemi in ('lh', 'rh') else ['lh', 'rh'] + hemi_str = list() + if src in ('surface', 'mixed'): + hemi_str.extend([hemi] if hemi in ('lh', 'rh') else ['lh', 'rh']) + if src in ('mixed', 'volume'): + hemi_str.extend(['vol']) for current_hemi in hemi_str: assert len(picked_points[current_hemi]) == 1 - assert len(spheres) == len(hemi_str) + n_spheres = len(hemi_str) + if hemi == 'split' and src in ('mixed', 'volume'): + n_spheres += 1 + assert len(spheres) == n_spheres # test removing points time_viewer.clear_points() - assert len(picked_points['lh']) == 0 - assert len(picked_points['rh']) == 0 + assert len(spheres) == 0 + for key in ('lh', 'rh', 'vol'): + assert len(picked_points[key]) == 0 # test picking a cell at random + rng = np.random.RandomState(0) for idx, current_hemi in enumerate(hemi_str): - current_mesh = brain_data._hemi_meshes[current_hemi] - cell_id = np.random.randint(0, current_mesh.n_cells) - test_picker = TstVTKPicker(current_mesh, cell_id) + assert len(spheres) == 0 + if current_hemi == 'vol': + current_mesh = brain_data._data['vol']['grid'] + vertices = brain_data._data['vol']['vertices'] + values = current_mesh.cell_arrays['values'][vertices] + cell_id = vertices[np.argmax(np.abs(values))] + else: + current_mesh = brain_data._hemi_meshes[current_hemi] + cell_id = rng.randint(0, current_mesh.n_cells) + test_picker = TstVTKPicker(None, None, current_hemi, brain_data) + assert time_viewer.on_pick(test_picker, None) is None + test_picker = TstVTKPicker( + current_mesh, cell_id, current_hemi, brain_data) assert cell_id == test_picker.cell_id assert test_picker.point_id is None time_viewer.on_pick(test_picker, None) assert test_picker.point_id is not None assert len(picked_points[current_hemi]) == 1 assert picked_points[current_hemi][0] == test_picker.point_id - sphere = spheres[idx] + assert len(spheres) > 0 + sphere = spheres[-1] vertex_id = sphere._vertex_id assert vertex_id == test_picker.point_id line = sphere._line - hemi_prefix = 'L' if current_hemi == 'lh' else 'R' + hemi_prefix = current_hemi[0].upper() + if current_hemi == 'vol': + assert hemi_prefix + ':' in line.get_label() + assert 'MNI' in line.get_label() + continue # the MNI conversion is more complex hemi_int = 0 if current_hemi == 'lh' else 1 mni = vertex_to_mni( vertices=vertex_id, @@ -250,7 +368,12 @@ def test_brain_timeviewer_traces(renderer_interactive, hemi, tmpdir): ', '.join('%5.1f' % m for m in mni)) assert line.get_label() == label - assert len(spheres) == len(hemi_str) + + # remove the sphere by clicking in its vicinity + old_len = len(spheres) + test_picker._actors = sum((s._actors for s in spheres), []) + time_viewer.on_pick(test_picker, None) + assert len(spheres) < old_len # and the scraper for it (will close the instance) if not check_version('sphinx_gallery'): @@ -276,8 +399,7 @@ def test_brain_linkviewer(renderer_interactive, travis_macos): pytest.skip('Linkviewer only supported on PyVista') if travis_macos: pytest.skip('Linkviewer tests unstable on Travis macOS') - brain_data = _create_testing_brain(hemi='split') - _TimeViewer(brain_data) + brain_data = _create_testing_brain(hemi='split', show_traces=False) link_viewer = _LinkViewer( [brain_data], @@ -394,30 +516,46 @@ def test_brain_colormap(): calculate_lut(colormap, alpha, 1, 0, 2) -def _create_testing_brain(hemi, surf='inflated'): - sample_src = read_source_spaces(src_fname) +def _create_testing_brain(hemi, surf='inflated', src='surface', size=300, + **kwargs): + assert src in ('surface', 'mixed', 'volume') + meth = 'plot' + if src in ('surface', 'mixed'): + sample_src = read_source_spaces(src_fname) + klass = MixedSourceEstimate if src == 'mixed' else SourceEstimate + if src in ('volume', 'mixed'): + vol_src = setup_volume_source_space( + subject_id, 7., mri='aseg.mgz', + volume_label='Left-Cerebellum-Cortex', + subjects_dir=subjects_dir, add_interpolator=False) + assert len(vol_src) == 1 + assert vol_src[0]['nuse'] == 150 + if src == 'mixed': + sample_src = sample_src + vol_src + else: + sample_src = vol_src + klass = VolSourceEstimate + meth = 'plot_3d' + assert sample_src.kind == src # dense version + rng = np.random.RandomState(0) vertices = [s['vertno'] for s in sample_src] n_time = 5 n_verts = sum(len(v) for v in vertices) stc_data = np.zeros((n_verts * n_time)) stc_size = stc_data.size - stc_data[(np.random.rand(stc_size // 20) * stc_size).astype(int)] = \ - np.random.RandomState(0).rand(stc_data.size // 20) + stc_data[(rng.rand(stc_size // 20) * stc_size).astype(int)] = \ + rng.rand(stc_data.size // 20) stc_data.shape = (n_verts, n_time) - stc = SourceEstimate(stc_data, vertices, 1, 1) + stc = klass(stc_data, vertices, 1, 1) fmin = stc.data.min() fmax = stc.data.max() - brain_data = _Brain(subject_id, hemi, surf, size=300, - subjects_dir=subjects_dir) - hemi_list = ['lh', 'rh'] if hemi in ['both', 'split'] else [hemi] - for hemi_str in hemi_list: - hemi_idx = 0 if hemi_str == 'lh' else 1 - data = getattr(stc, hemi_str + '_data') - vertices = stc.vertices[hemi_idx] - brain_data.add_data(data, fmin=fmin, hemi=hemi_str, fmax=fmax, - colormap='hot', vertices=vertices, - colorbar=True) + fmid = (fmin + fmax) / 2. + brain_data = getattr(stc, meth)( + subject=subject_id, hemi=hemi, surface=surf, size=size, + subjects_dir=subjects_dir, colormap='hot', + clim=dict(kind='value', lims=(fmin, fmid, fmax)), src=sample_src, + **kwargs) return brain_data diff --git a/mne/viz/_brain/view.py b/mne/viz/_brain/view.py index 1be4e9ed76b..acba52030a3 100644 --- a/mne/viz/_brain/view.py +++ b/mne/viz/_brain/view.py @@ -6,34 +6,39 @@ # # License: Simplified BSD -from collections import namedtuple - - -View = namedtuple('View', 'elev azim') - -lh_views_dict = {'lateral': View(azim=180., elev=90.), - 'medial': View(azim=0., elev=90.0), - 'rostral': View(azim=90., elev=90.), - 'caudal': View(azim=270., elev=90.), - 'dorsal': View(azim=180., elev=0.), - 'ventral': View(azim=180., elev=180.), - 'frontal': View(azim=120., elev=80.), - 'parietal': View(azim=-120., elev=60.)} -rh_views_dict = {'lateral': View(azim=180., elev=-90.), - 'medial': View(azim=0., elev=-90.0), - 'rostral': View(azim=-90., elev=-90.), - 'caudal': View(azim=90., elev=-90.), - 'dorsal': View(azim=180., elev=0.), - 'ventral': View(azim=180., elev=180.), - 'frontal': View(azim=60., elev=80.), - 'parietal': View(azim=-60., elev=60.)} +_lh_views_dict = { + 'lateral': dict(azimuth=180., elevation=90.), + 'medial': dict(azimuth=0., elevation=90.0), + 'rostral': dict(azimuth=90., elevation=90.), + 'caudal': dict(azimuth=270., elevation=90.), + 'dorsal': dict(azimuth=180., elevation=0.), + 'ventral': dict(azimuth=180., elevation=180.), + 'frontal': dict(azimuth=120., elevation=80.), + 'parietal': dict(azimuth=-120., elevation=60.), + 'sagittal': dict(azimuth=180., elevation=-90.), + 'coronal': dict(azimuth=90., elevation=-90.), + 'axial': dict(azimuth=180., elevation=0., roll=180), +} +_rh_views_dict = { + 'lateral': dict(azimuth=180., elevation=-90.), + 'medial': dict(azimuth=0., elevation=-90.0), + 'rostral': dict(azimuth=-90., elevation=-90.), + 'caudal': dict(azimuth=90., elevation=-90.), + 'dorsal': dict(azimuth=180., elevation=0.), + 'ventral': dict(azimuth=180., elevation=180.), + 'frontal': dict(azimuth=60., elevation=80.), + 'parietal': dict(azimuth=-60., elevation=60.), + 'sagittal': dict(azimuth=180., elevation=-90.), + 'coronal': dict(azimuth=90., elevation=-90.), + 'axial': dict(azimuth=180., elevation=0., roll=180), +} # add short-size version entries into the dict -_lh_views_dict = dict() -for k, v in lh_views_dict.items(): - _lh_views_dict[k[:3]] = v -lh_views_dict.update(_lh_views_dict) +lh_views_dict = _lh_views_dict.copy() +for k, v in _lh_views_dict.items(): + lh_views_dict[k[:3]] = v -_rh_views_dict = dict() -for k, v in rh_views_dict.items(): - _rh_views_dict[k[:3]] = v -rh_views_dict.update(_rh_views_dict) +rh_views_dict = _rh_views_dict.copy() +for k, v in _rh_views_dict.items(): + rh_views_dict[k[:3]] = v +views_dicts = dict(lh=lh_views_dict, vol=lh_views_dict, both=lh_views_dict, + rh=rh_views_dict) diff --git a/mne/viz/backends/_pysurfer_mayavi.py b/mne/viz/backends/_pysurfer_mayavi.py index 5ad63ee45dc..11a81cb47b0 100644 --- a/mne/viz/backends/_pysurfer_mayavi.py +++ b/mne/viz/backends/_pysurfer_mayavi.py @@ -89,16 +89,18 @@ def subplot(self, x, y): def scene(self): return self.fig - def set_interactive(self): + def set_interaction(self, interaction): from tvtk.api import tvtk if self.fig.scene is not None: self.fig.scene.interactor.interactor_style = \ - tvtk.InteractorStyleTerrain() + getattr(tvtk, f'InteractorStyle{interaction.capitalize()}')() def mesh(self, x, y, z, triangles, color, opacity=1.0, shading=False, backface_culling=False, scalars=None, colormap=None, vmin=None, vmax=None, interpolate_before_map=True, - representation='surface', line_width=1., normals=None, **kwargs): + representation='surface', line_width=1., normals=None, + pickable=None, **kwargs): + # normals and pickable are unused if color is not None: color = _check_color(color) if color is not None and isinstance(color, np.ndarray) \ @@ -299,10 +301,10 @@ def close(self): _close_3d_figure(figure=self.fig) def set_camera(self, azimuth=None, elevation=None, distance=None, - focalpoint=None): + focalpoint=None, roll=None): _set_3d_view(figure=self.fig, azimuth=azimuth, elevation=elevation, distance=distance, - focalpoint=focalpoint) + focalpoint=focalpoint, roll=roll) def reset_camera(self): renderer = getattr(self.fig.scene, 'renderer', None) @@ -436,12 +438,12 @@ def _close_all(): mlab.close(all=True) -def _set_3d_view(figure, azimuth, elevation, focalpoint, distance): +def _set_3d_view(figure, azimuth, elevation, focalpoint, distance, roll=None): from mayavi import mlab with warnings.catch_warnings(record=True): # traits with SilenceStdout(): mlab.view(azimuth, elevation, distance, - focalpoint=focalpoint, figure=figure) + focalpoint=focalpoint, figure=figure, roll=roll) mlab.draw(figure) @@ -488,16 +490,16 @@ def _take_3d_screenshot(figure, mode='rgb', filename=None): figure_size = figure._window_size else: figure_size = figure.scene._renwin.size - return np.zeros(tuple(figure_size) + (ndim,), np.uint8) + img = np.zeros(tuple(figure_size) + (ndim,), np.uint8) else: from pyface.api import GUI gui = GUI() gui.process_events() with warnings.catch_warnings(record=True): # traits img = mlab.screenshot(figure, mode=mode) - if isinstance(filename, str): - _save_figure(img, filename) - return img + if isinstance(filename, str): + _save_figure(img, filename) + return img @contextmanager diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index 4405d1b5d08..c103307dc0f 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -67,6 +67,8 @@ def __init__(self, plotter=None, self.store['off_screen'] = off_screen self.store['border'] = False self.store['auto_update'] = False + # multi_samples > 1 is broken on macOS + Intel Iris + volume rendering + self.store['multi_samples'] = 1 if sys.platform == 'darwin' else 4 def build(self): if self.plotter_class is None: @@ -199,7 +201,7 @@ def ensure_minimum_sizes(self): yield finally: for _ in range(2): - self.plotter.app.processEvents() + _process_events(self.plotter) self.plotter.interactor.setMinimumSize(0, 0) def subplot(self, x, y): @@ -214,8 +216,8 @@ def subplot(self, x, y): def scene(self): return self.figure - def set_interactive(self): - self.plotter.enable_terrain_style() + def set_interaction(self, interaction): + getattr(self.plotter, f'enable_{interaction}_style')() def polydata(self, mesh, color=None, opacity=1.0, normals=None, backface_culling=False, scalars=None, colormap=None, @@ -533,9 +535,9 @@ def close(self): _close_3d_figure(figure=self.figure) def set_camera(self, azimuth=None, elevation=None, distance=None, - focalpoint=None): + focalpoint=None, roll=None): _set_3d_view(self.figure, azimuth=azimuth, elevation=elevation, - distance=distance, focalpoint=focalpoint) + distance=distance, focalpoint=focalpoint, roll=roll) def reset_camera(self): self.plotter.reset_camera() @@ -659,7 +661,7 @@ def _get_camera_direction(focalpoint, position): return r, theta, phi, focalpoint -def _set_3d_view(figure, azimuth, elevation, focalpoint, distance): +def _set_3d_view(figure, azimuth, elevation, focalpoint, distance, roll=None): position = np.array(figure.plotter.camera_position[0]) focalpoint = np.array(figure.plotter.camera_position[1]) r, theta, phi, fp = _get_camera_direction(focalpoint, position) @@ -668,20 +670,18 @@ def _set_3d_view(figure, azimuth, elevation, focalpoint, distance): phi = _deg2rad(azimuth) if elevation is not None: theta = _deg2rad(elevation) + if roll is not None: + roll = _deg2rad(roll) renderer = figure.plotter.renderer bounds = np.array(renderer.ComputeVisiblePropBounds()) - if distance is not None: - r = distance - else: - r = max(bounds[1::2] - bounds[::2]) * 2.0 - distance = r + if distance is None: + distance = max(bounds[1::2] - bounds[::2]) * 2.0 if focalpoint is not None: - cen = np.asarray(focalpoint) + focalpoint = np.asarray(focalpoint) else: - cen = (bounds[1::2] + bounds[::2]) * 0.5 - focalpoint = cen + focalpoint = (bounds[1::2] + bounds[::2]) * 0.5 # Now calculate the view_up vector of the camera. If the view up is # close to the 'z' axis, the view plane normal is parallel to the @@ -692,15 +692,19 @@ def _set_3d_view(figure, azimuth, elevation, focalpoint, distance): view_up = [np.sin(phi), np.cos(phi), 0] position = [ - r * np.cos(phi) * np.sin(theta), - r * np.sin(phi) * np.sin(theta), - r * np.cos(theta)] + distance * np.cos(phi) * np.sin(theta), + distance * np.sin(phi) * np.sin(theta), + distance * np.cos(theta)] figure.plotter.camera_position = [ - position, cen, view_up] + position, focalpoint, view_up] + if roll is not None: + figure.plotter.camera.SetRoll(roll) + # set the distance figure.plotter.renderer._azimuth = azimuth figure.plotter.renderer._elevation = elevation figure.plotter.renderer._distance = distance + figure.plotter.renderer._roll = roll def _set_3d_title(figure, title, size=16): @@ -737,7 +741,9 @@ def _take_3d_screenshot(figure, mode='rgb', filename=None): def _process_events(plotter, show=False): if hasattr(plotter, 'app'): - plotter.app.processEvents() + with warnings.catch_warnings(record=True): + warnings.filterwarnings('ignore', 'constrained_layout') + plotter.app.processEvents() if show: plotter.app_window.show() @@ -758,6 +764,25 @@ def _set_colormap_range(actor, ctable, scalar_bar, rng=None): scalar_bar.SetLookupTable(actor.GetMapper().GetLookupTable()) +def _set_volume_range(volume, ctable, alpha, scalar_bar, rng): + import vtk + from vtk.util.numpy_support import numpy_to_vtk + color_tf = vtk.vtkColorTransferFunction() + opacity_tf = vtk.vtkPiecewiseFunction() + for loc, color in zip(np.linspace(*rng, num=len(ctable)), ctable): + color_tf.AddRGBPoint(loc, *color[:-1]) + opacity_tf.AddPoint(loc, color[-1] * alpha / 255. / (len(ctable) - 1)) + color_tf.ClampingOn() + opacity_tf.ClampingOn() + volume.GetProperty().SetColor(color_tf) + volume.GetProperty().SetScalarOpacity(opacity_tf) + if scalar_bar is not None: + lut = vtk.vtkLookupTable() + lut.SetRange(*rng) + lut.SetTable(numpy_to_vtk(ctable)) + scalar_bar.SetLookupTable(lut) + + def _set_mesh_scalars(mesh, scalars, name): # Catch: FutureWarning: Conversion of the second argument of # issubdtype from `complex` to `np.complexfloating` is deprecated. @@ -812,6 +837,7 @@ def _update_picking_callback(plotter, vtk.vtkCommand.EndPickEvent, on_pick ) + picker.SetVolumeOpacityIsovalue(0.) plotter.picker = picker diff --git a/mne/viz/backends/base_renderer.py b/mne/viz/backends/base_renderer.py index 48b91acec71..94ab73ff5df 100644 --- a/mne/viz/backends/base_renderer.py +++ b/mne/viz/backends/base_renderer.py @@ -28,8 +28,8 @@ def scene(self): pass @abstractclassmethod - def set_interactive(self): - """Enable interactive mode.""" + def set_interaction(self, interaction): + """Set interaction mode.""" pass @abstractclassmethod diff --git a/mne/viz/backends/tests/test_renderer.py b/mne/viz/backends/tests/test_renderer.py index 87e4075002c..52bc92a9157 100644 --- a/mne/viz/backends/tests/test_renderer.py +++ b/mne/viz/backends/tests/test_renderer.py @@ -107,7 +107,8 @@ def test_3d_backend(renderer): bgcolor=win_color, smooth_shading=True, ) - rend.set_interactive() + for interaction in ('terrain', 'trackball'): + rend.set_interaction(interaction) # use mesh mesh_data = rend.mesh( diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 26ed45ed6d8..06ebaa5290a 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -7,7 +7,7 @@ # # License: Simplified BSD -import os +from mne.minimum_norm.inverse import apply_inverse import os.path as op from pathlib import Path import sys @@ -19,11 +19,12 @@ from matplotlib.colors import Colormap from mne import (make_field_map, pick_channels_evoked, read_evokeds, - read_trans, read_dipole, SourceEstimate, VectorSourceEstimate, + read_trans, read_dipole, SourceEstimate, VolSourceEstimate, make_sphere_model, use_coil_def, setup_volume_source_space, read_forward_solution, VolVectorSourceEstimate, convert_forward_solution, compute_source_morph, MixedSourceEstimate) +from mne.source_estimate import _BaseVolSourceEstimate from mne.io import (read_raw_ctf, read_raw_bti, read_raw_kit, read_info, read_raw_nirx) from mne.io._digitization import write_dig @@ -689,39 +690,78 @@ def test_plot_volume_source_estimates_morph(): clim=dict(lims=[-1, 2, 3], kind='value')) -bad_azure_3d = pytest.mark.skipif( - os.getenv('AZURE_CI_WINDOWS', 'false') == 'true' and - sys.version_info[:2] == (3, 8), - reason='Crashes workers on Azure') - - @pytest.mark.slowtest # can be slow on OSX @testing.requires_testing_data @requires_pysurfer @traits_test -@bad_azure_3d -def test_plot_vector_source_estimates(renderer_interactive): - """Test plotting of vector source estimates.""" - sample_src = read_source_spaces(src_fname) - - vertices = [s['vertno'] for s in sample_src] - n_verts = sum(len(v) for v in vertices) - n_time = 5 - data = np.random.RandomState(0).rand(n_verts, 3, n_time) - stc = VectorSourceEstimate(data, vertices, 1, 1) - - brain = stc.plot('sample', subjects_dir=subjects_dir, hemi='both', - smoothing_steps=1, verbose='error') +@pytest.mark.parametrize('pick_ori', ('vector', None)) +@pytest.mark.parametrize('kind', ('surface', 'volume', 'mixed')) +def test_plot_source_estimates(renderer_interactive, all_src_types_inv_evoked, + pick_ori, kind): + """Test plotting of scalar and vector source estimates.""" + invs, evoked = all_src_types_inv_evoked + inv = invs[kind] + is_pyvista = renderer_interactive._get_3d_backend() == 'pyvista' + with pytest.warns(None): # PCA mag + stc = apply_inverse(evoked, inv, pick_ori=pick_ori) + stc.data[1] *= -1 # make it signed + meth = 'plot_3d' if isinstance(stc, _BaseVolSourceEstimate) else 'plot' + meth = getattr(stc, meth) + kwargs = dict(subject='sample', subjects_dir=subjects_dir, + time_viewer=False, show_traces=False, # for speed + smoothing_steps=1, verbose='error', src=inv['src'], + volume_options=dict(resolution=None), # for speed + ) + if pick_ori != 'vector': + kwargs['surface'] = 'white' + # Mayavi can't handle non-surface + if kind != 'surface' and not is_pyvista: + with pytest.raises(RuntimeError, match='PyVista'): + meth(**kwargs) + return + brain = meth(**kwargs) brain.close() del brain - with pytest.raises(ValueError, match='use "pos_lims"'): - stc.plot('sample', subjects_dir=subjects_dir, - clim=dict(pos_lims=[1, 2, 3])) + if pick_ori == 'vector': + with pytest.raises(ValueError, match='use "pos_lims"'): + meth(**kwargs, clim=dict(pos_lims=[1, 2, 3])) + if kind in ('volume', 'mixed'): + with pytest.raises(TypeError, match='when stc is a mixed or vol'): + these_kwargs = kwargs.copy() + these_kwargs.pop('src') + meth(**these_kwargs) with pytest.raises(ValueError, match='cannot be used'): - stc.plot('sample', subjects_dir=subjects_dir, - show_traces=True, time_viewer=False) + these_kwargs = kwargs.copy() + these_kwargs.update(show_traces=True, time_viewer=False) + meth(**these_kwargs) + if not is_pyvista: + with pytest.raises(ValueError, match='view_layout must be'): + meth(view_layout='horizontal', **kwargs) + + # just test one for speed + if kind != 'mixed': + return + assert is_pyvista + brain = meth( + views=['lat', 'med', 'ven'], hemi='lh', + view_layout='horizontal', **kwargs) + brain.close() + assert brain._subplot_shape == (1, 3) + del brain + these_kwargs = kwargs.copy() + these_kwargs['volume_options'] = dict(blending='foo') + with pytest.raises(ValueError, match='mip'): + meth(**these_kwargs) + these_kwargs['volume_options'] = dict(badkey='foo') + with pytest.raises(ValueError, match='unknown'): + meth(**these_kwargs) + # with resampling (actually downsampling but it's okay) + these_kwargs['volume_options'] = dict(resolution=20., surface_alpha=0.) + brain = meth(**these_kwargs) + brain.close() + del brain @pytest.mark.slowtest @@ -786,7 +826,6 @@ def test_brain_colorbar(orientation, diverging, lims): @requires_pysurfer @testing.requires_testing_data @traits_test -@bad_azure_3d def test_mixed_sources_plot_surface(renderer_interactive): """Test plot_surface() for mixed source space.""" src = read_source_spaces(fwd_fname2) diff --git a/setup.cfg b/setup.cfg index 2429a580712..f0e828267f5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,10 +22,11 @@ select = A,E,F,W,C [tool:pytest] addopts = - --durations=20 --doctest-modules -ra --cov-report= + --durations=20 --doctest-modules -ra --cov-report= --tb=short --doctest-ignore-import-errors --junit-xml=junit-results.xml --ignore=doc --ignore=logo --ignore=examples --ignore=tutorials --ignore=mne/gui/_*.py --ignore=mne/externals --ignore=mne/icons + --capture=sys junit_family = xunit2 [pydocstyle] diff --git a/tutorials/source-modeling/plot_beamformer_lcmv.py b/tutorials/source-modeling/plot_beamformer_lcmv.py index 692e03a1232..b739ebdaf7b 100644 --- a/tutorials/source-modeling/plot_beamformer_lcmv.py +++ b/tutorials/source-modeling/plot_beamformer_lcmv.py @@ -195,20 +195,34 @@ ############################################################################### # Visualize the reconstructed source activity # ------------------------------------------- -# We can visualize the source estimate in different ways, e.g. as an overlay -# onto the MRI or as a glass brain. +# We can visualize the source estimate in different ways, e.g. as a volume +# rendering, an overlay onto the MRI, or as an overlay onto a glass brain. +# # The plots for the scalar beamformer show brain activity in the right temporal # lobe around 100 ms post stimulus. This is expected given the left-ear # auditory stimulation of the experiment. +# +# Volumetric rendering (3D) +# ~~~~~~~~~~~~~~~~~~~~~~~~~ lims = [0.3, 0.45, 0.6] - kwargs = dict(src=forward['src'], subject='sample', subjects_dir=subjects_dir, initial_time=0.087, verbose=True) +stc.plot_3d(clim=dict(kind='value', pos_lims=lims), hemi='both', + views=['sagittal', 'coronal', 'axial'], size=(800, 300), + view_layout='horizontal', show_traces=0.4, **kwargs) + +############################################################################### +# On MRI slices (orthoview; 2D) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + stc.plot(mode='stat_map', clim=dict(kind='value', pos_lims=lims), **kwargs) ############################################################################### +# On MNI glass brain (orthoview; 2D) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + stc.plot(mode='glass_brain', clim=dict(kind='value', lims=lims), **kwargs) ############################################################################### @@ -219,6 +233,11 @@ # sphinx_gallery_thumbnail_number = 5 +stc_vec.plot_3d(clim=dict(kind='value', lims=lims), hemi='both', + views=['sagittal', 'coronal', 'axial'], size=(800, 300), + view_layout='horizontal', show_traces=0.4, **kwargs) + +############################################################################### stc_vec.plot(mode='stat_map', clim=dict(kind='value', pos_lims=lims), **kwargs) ###############################################################################