Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MRG, ENH: Scrape traces when available #7927

Merged
merged 2 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ jobs:
echo "set -e" >> $BASH_ENV
echo "export DISPLAY=:99" >> $BASH_ENV
echo "export OPENBLAS_NUM_THREADS=4" >> $BASH_ENV
echo "export XDG_RUNTIME_DIR=/tmp/runtime-circleci" >> $BASH_ENV
source tools/get_minimal_commands.sh
echo "source ${PWD}/tools/get_minimal_commands.sh" >> $BASH_ENV
echo "export MNE_3D_BACKEND=pyvista" >> $BASH_ENV
echo "export _MNE_BRAIN_TRACES_AUTO=false" >> $BASH_ENV
echo "export PATH=~/.local/bin/:${MNE_ROOT}/bin:$PATH" >> $BASH_ENV
echo "BASH_ENV:"
cat $BASH_ENV
Expand Down
2 changes: 1 addition & 1 deletion doc/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ html-noplot:
@echo "Build finished. The HTML pages are in _build/html_stable."

html_dev-front:
@PATTERN="\(plot_mne_dspm_source_localization.py\|plot_receptive_field.py\|plot_mne_inverse_label_connectivity.py\|plot_sensors_decoding.py\|plot_stats_cluster_spatio_temporal.py\|plot_visualize_evoked.py\)" make html_dev-pattern;
@PATTERN="\(plot_mne_dspm_source_localization.py\|plot_receptive_field.py\|plot_mne_inverse_label_connectivity.py\|plot_sensors_decoding.py\|plot_stats_cluster_spatio_temporal.py\|plot_20_visualize_evoked.py\)" make html_dev-pattern;

dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) _build/dirhtml
Expand Down
5 changes: 5 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,11 @@
scrapers += (report_scraper,)
else:
report_scraper = None
if 'pyvista' in scrapers:
brain_scraper = mne.viz._3d._BrainScraper()
scrapers = list(scrapers)
scrapers.insert(scrapers.index('pyvista'), brain_scraper)
scrapers = tuple(scrapers)


def append_attr_meth_examples(app, what, name, obj, options, lines):
Expand Down
1 change: 1 addition & 0 deletions examples/inverse/plot_source_space_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
===============================

This example shows how to compute and plot source space SNR as in [1]_.

"""
# Author: Padma Sundaram <tottochan@gmail.com>
# Kaisu Lankinen <klankinen@mgh.harvard.edu>
Expand Down
67 changes: 67 additions & 0 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# License: Simplified BSD

from distutils.version import LooseVersion
import gc
from itertools import cycle
import os
import os.path as op
Expand Down Expand Up @@ -3188,3 +3189,69 @@ def _get_3d_option(key):
opt = opt.lower()
_check_option(f'3D option {key}', opt, ('true', 'false'))
return opt == 'true'


class _BrainScraper(object):
"""Scrape Brain objects."""

def __repr__(self):
return '<BrainScraper>'

def __call__(self, block, block_vars, gallery_conf):
rst = ''
try:
from ._brain import _Brain
except ImportError: # in case we haven't fired up the 3D plotting yet
return rst
for brain in block_vars['example_globals'].values():
# Only need to process if it's a brain with a time_viewer
# with traces on and shown in the same window, otherwise
# PyVista and matplotlib scrapers can just do the work
if (not isinstance(brain, _Brain)) or brain._closed:
continue
from sphinx_gallery.scrapers import figure_rst
from matplotlib.image import imsave
from .backends._pyvista import _process_events
img_fname = next(block_vars['image_path_iterator'])
img = brain.screenshot()
assert img.size > 0
if getattr(brain, 'time_viewer', None) is not None and \
brain.time_viewer.show_traces and \
not brain.time_viewer.separate_canvas:
canvas = brain.time_viewer.mpl_canvas.fig.canvas
canvas.draw_idle()
# In theory, one of these should work:
#
# trace_img = np.frombuffer(
# canvas.tostring_rgb(), dtype=np.uint8)
# trace_img.shape = canvas.get_width_height()[::-1] + (3,)
#
# or
#
# trace_img = np.frombuffer(
# canvas.tostring_rgb(), dtype=np.uint8)
# size = time_viewer.mpl_canvas.getSize()
# trace_img.shape = (size.height(), size.width(), 3)
#
# But in practice, sometimes the sizes does not match the
# renderer tostring_rgb() size. So let's directly use what
# matplotlib does in lib/matplotlib/backends/backend_agg.py
# before calling tobytes():
trace_img = np.asarray(
canvas.renderer._renderer).take([0, 1, 2], axis=2)
# need to slice into trace_img because generally it's a bit
# smaller
delta = trace_img.shape[1] - img.shape[1]
if delta > 0:
start = delta // 2
trace_img = trace_img[:, start:start + img.shape[1]]
img = np.concatenate([img, trace_img], axis=0)
imsave(img_fname, img)
assert op.isfile(img_fname)
rst += figure_rst(
[img_fname], gallery_conf['src_dir'], brain._title)
brain.close()
_process_events(brain._renderer.plotter)
_process_events(brain._renderer.plotter)
gc.collect()
return rst
2 changes: 2 additions & 0 deletions mne/viz/_brain/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def __init__(self, subject_id, hemi, surf, title=None,
self._renderer.set_camera(azimuth=views_dict[v].azim,
elevation=views_dict[v].elev)

self._closed = False
if show:
self._renderer.show()

Expand Down Expand Up @@ -925,6 +926,7 @@ def resolve_coincident_topology(self, actor):

def close(self):
"""Close all figures and cleanup data structure."""
self._closed = True
self._renderer.close()

def show(self):
Expand Down
31 changes: 19 additions & 12 deletions mne/viz/_brain/_timeviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def __init__(self, brain, show_traces=False):
self.default_playback_speed_range = [0.01, 1]
self.default_playback_speed_value = 0.05
self.default_status_bar_msg = "Press ? for help"
self.act_data = {'lh': None, 'rh': None}
self.act_data_smooth = {'lh': None, 'rh': None}
self.color_cycle = None
self.picked_points = {'lh': list(), 'rh': list()}
self._mouse_no_mvt = -1
Expand Down Expand Up @@ -819,6 +819,8 @@ def configure_playback(self):

def configure_point_picking(self):
from ..backends._pyvista import _update_picking_callback
# XXX we should change this to be if not self.show_traces: return
# and dedent...
if self.show_traces:
# use a matplotlib canvas
self.color_cycle = _ReuseCycle(_get_color_list())
Expand Down Expand Up @@ -851,9 +853,7 @@ def configure_point_picking(self):
if act_data.ndim == 3:
act_data = np.linalg.norm(act_data, axis=1)
smooth_mat = hemi_data['smooth_mat']
if smooth_mat is not None:
act_data = smooth_mat.dot(act_data)
self.act_data[hemi] = act_data
self.act_data_smooth[hemi] = (act_data, smooth_mat)

# simulate a picked renderer
if self.brain._hemi == 'split':
Expand All @@ -864,10 +864,10 @@ def configure_point_picking(self):
# initialize the default point
color = next(self.color_cycle)
ind = np.unravel_index(
np.argmax(self.act_data[hemi], axis=None),
self.act_data[hemi].shape
np.argmax(self.act_data_smooth[hemi][0], axis=None),
self.act_data_smooth[hemi][0].shape
)
vertex_id = ind[0]
vertex_id = hemi_data['vertices'][ind[0]]
mesh = hemi_data['mesh'][-1]
line = self.plot_time_course(hemi, vertex_id, color)
self.add_point(hemi, mesh, vertex_id, line, color)
Expand Down Expand Up @@ -1079,7 +1079,7 @@ def clear_points(self):
def plot_time_course(self, hemi, vertex_id, color):
if not hasattr(self, "mpl_canvas"):
return
time = self.brain._data['time']
time = self.brain._data['time'].copy() # avoid circular ref
hemi_str = 'L' if hemi == 'lh' else 'R'
hemi_int = 0 if hemi == 'lh' else 1
mni = vertex_to_mni(
Expand All @@ -1091,9 +1091,14 @@ def plot_time_course(self, hemi, vertex_id, color):
label = "{}:{} MNI: {}".format(
hemi_str, str(vertex_id).ljust(6),
', '.join('%5.1f' % m for m in mni))
act_data, smooth = self.act_data_smooth[hemi]
if smooth is not None:
act_data = smooth[vertex_id].dot(act_data)[0]
else:
act_data = act_data[vertex_id].copy()
line = self.mpl_canvas.plot(
time,
self.act_data[hemi][vertex_id, :],
act_data,
label=label,
lw=1.,
color=color
Expand Down Expand Up @@ -1177,14 +1182,16 @@ def clean(self):
self.interactor = None
if hasattr(self, "mpl_canvas"):
self.mpl_canvas.close()
self.mpl_canvas.axes.clear()
self.mpl_canvas.fig.clear()
self.mpl_canvas.time_viewer = None
self.mpl_canvas.canvas = None
self.mpl_canvas = None
self.time_actor = None
self.picked_renderer = None
self.act_data["lh"] = None
self.act_data["rh"] = None
self.act_data = None
self.act_data_smooth["lh"] = None
self.act_data_smooth["rh"] = None
self.act_data_smooth = None


class _LinkViewer(object):
Expand Down
22 changes: 20 additions & 2 deletions mne/viz/_brain/tests/test_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from mne import SourceEstimate, read_source_estimate
from mne.source_space import read_source_spaces, vertex_to_mni
from mne.datasets import testing
from mne.utils import check_version
from mne.viz._brain import _Brain, _TimeViewer, _LinkViewer
from mne.viz._brain.colormap import calculate_lut
from mne.viz._3d import _BrainScraper

from matplotlib import cm
from matplotlib import cm, image

data_path = testing.data_path(download=False)
subject_id = 'sample'
Expand Down Expand Up @@ -198,7 +200,7 @@ def test_brain_timeviewer(renderer_interactive):
pytest.param('split', marks=pytest.mark.slowtest),
pytest.param('both', marks=pytest.mark.slowtest),
])
def test_brain_timeviewer_traces(renderer_interactive, hemi):
def test_brain_timeviewer_traces(renderer_interactive, hemi, tmpdir):
"""Test _TimeViewer traces."""
if renderer_interactive._get_3d_backend() != 'pyvista':
pytest.skip('Only PyVista supports traces')
Expand Down Expand Up @@ -251,6 +253,22 @@ def test_brain_timeviewer_traces(renderer_interactive, hemi):
assert line.get_label() == label
assert len(spheres) == len(hemi_str)

# and the scraper for it (will close the instance)
if not check_version('sphinx_gallery'):
return
screenshot = brain_data.screenshot()
fnames = [str(tmpdir.join('temp.png'))]
block_vars = dict(image_path_iterator=iter(fnames),
example_globals=dict(brain=brain_data))
gallery_conf = dict(src_dir=str(tmpdir))
scraper = _BrainScraper()
rst = scraper(None, block_vars, gallery_conf)
assert 'temp.png' in rst
assert path.isfile(fnames[0])
img = image.imread(fnames[0])
assert img.shape[1] == screenshot.shape[1] # same width
assert img.shape[0] > screenshot.shape[0] # larger height


@testing.requires_testing_data
def test_brain_linkviewer(renderer_interactive, travis_macos):
Expand Down
1 change: 1 addition & 0 deletions requirements_testing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pytest-timeout
pytest-xdist
flake8
https://github.com/larsoner/flake8-array-spacing/archive/master.zip
https://github.com/sphinx-gallery/sphinx-gallery/archive/master.zip
https://github.com/numpy/numpydoc/archive/master.zip
https://github.com/codespell-project/codespell/archive/master.zip
pydocstyle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@
###############################################################################
# Morph data to average brain
# ---------------------------
# Next we morph data to ``fsaverage``.

# setup source morph
morph = mne.compute_source_morph(
Expand Down
3 changes: 1 addition & 2 deletions tutorials/source-modeling/plot_visualize_stc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@
# and ``pysurfer`` installed on your machine.
initial_time = 0.1
brain = stc.plot(subjects_dir=subjects_dir, initial_time=initial_time,
clim=dict(kind='value', pos_lims=[3, 6, 9]),
time_viewer=True)
clim=dict(kind='value', pos_lims=[3, 6, 9]))

###############################################################################
#
Expand Down