diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index d6706b9482c..675b48ed422 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -115,8 +115,6 @@ class _Brain(object): +---------------------------+--------------+-----------------------+ | foci | ✓ | | +---------------------------+--------------+-----------------------+ - | index_for_time | ✓ | ✓ | - +---------------------------+--------------+-----------------------+ | labels | ✓ | | +---------------------------+--------------+-----------------------+ | labels_dict | ✓ | | @@ -371,6 +369,7 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, raise ValueError('time has shape %s, but need shape %s ' '(array.shape[-1])' % (time.shape, (array.shape[-1],))) + self._data["time"] = time if self._n_times is None: self._n_times = len(time) @@ -386,7 +385,8 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, if initial_time is None: time_idx = 0 else: - time_idx = self.index_for_time(initial_time) + time_idx = self._to_time_index(initial_time) + self._data["time_idx"] = time_idx # time label if isinstance(time_label, str): @@ -395,13 +395,12 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, def time_label(x): return time_label_fmt % x self._data["time_label"] = time_label - self._data["time"] = time - self._data["time_idx"] = 0 y_txt = 0.05 + 0.1 * bool(colorbar) if time is not None and len(array.shape) == 2: # we have scalar_data with time dimension - act_data = array[:, time_idx] + act_data, act_time = self._interpolate_data(array, time_idx) + self._current_time = act_time else: # we have scalar data without time act_data = array @@ -479,7 +478,7 @@ def time_label(x): time_actor = self._renderer.text2d( x_window=0.95, y_window=y_txt, size=time_label_size, - text=time_label(time[time_idx]), + text=time_label(self._current_time), justification='right' ) self._data['time_actor'] = time_actor @@ -724,46 +723,6 @@ def remove_labels(self, labels=None): """ pass - def index_for_time(self, time, rounding='closest'): - """Find the data time index closest to a specific time point. - - Parameters - ---------- - time : scalar - Time. - rounding : 'closest' | 'up' | 'down' - How to round if the exact time point is not an index. - - Returns - ------- - index : int - Data time index closest to time. - """ - if self._n_times is None: - raise RuntimeError("Brain has no time axis") - times = self._times - - # Check that time is in range - tmin = np.min(times) - tmax = np.max(times) - max_diff = (tmax - tmin) / (len(times) - 1) / 2 - if time < tmin - max_diff or time > tmax + max_diff: - err = ("time = %s lies outside of the time axis " - "[%s, %s]" % (time, tmin, tmax)) - raise ValueError(err) - - if rounding == 'closest': - idx = np.argmin(np.abs(times - time)) - elif rounding == 'up': - idx = np.nonzero(times >= time)[0][0] - elif rounding == 'down': - idx = np.nonzero(times <= time)[0][-1] - else: - err = "Invalid rounding parameter: %s" % repr(rounding) - raise ValueError(err) - - return idx - def close(self): """Close all figures and cleanup data structure.""" self._renderer.close() @@ -861,10 +820,19 @@ def set_data_smoothing(self, n_steps): self._data[hemi]['smooth_mat'] = smooth_mat self.set_time_point(self._data['time_idx']) + def _interpolate_data(self, array, time_idx): + from scipy.interpolate import interp1d + times = np.arange(self._n_times) + act_data = interp1d( + times, array, self.interp_kind, axis=1, + assume_sorted=True)(time_idx) + ifunc = interp1d(times, self._data['time']) + act_time = ifunc(time_idx) + return act_data, act_time + def set_time_point(self, time_idx): """Set the time point shown.""" from ..backends._pyvista import _set_mesh_scalars - from scipy.interpolate import interp1d time = self._data['time'] for hemi in ['lh', 'rh']: hemi_data = self._data.get(hemi) @@ -877,12 +845,9 @@ def set_time_point(self, time_idx): act_data = array[:, time_idx] self._current_time = time[time_idx] else: - times = np.arange(self._n_times) - act_data = interp1d( - times, array, self.interp_kind, axis=1, - assume_sorted=True)(time_idx) - ifunc = interp1d(times, self._data['time']) - self._current_time = ifunc(time_idx) + act_data, act_time = self._interpolate_data( + array, time_idx) + self._current_time = act_time smooth_mat = hemi_data['smooth_mat'] if smooth_mat is not None: @@ -1030,6 +995,12 @@ def update_auto_scaling(self, restore=False): _set_colormap_range(actor, ctable, scalar_bar, rng) self._data['ctable'] = ctable + def _to_time_index(self, value): + """Return the interpolated time index of the given time value.""" + time = self._data['time'] + value = np.interp(value, time, np.arange(len(time))) + return value + @property def data(self): u"""Data used by time viewer and color bar widgets.""" diff --git a/mne/viz/_brain/_timeviewer.py b/mne/viz/_brain/_timeviewer.py index 98fed987d9a..342a0a760b3 100644 --- a/mne/viz/_brain/_timeviewer.py +++ b/mne/viz/_brain/_timeviewer.py @@ -7,14 +7,14 @@ from itertools import cycle import time import numpy as np -from ..utils import _show_help, _get_color_list, tight_layout +from ..utils import _check_option, _show_help, _get_color_list, tight_layout from ...source_space import vertex_to_mni class MplCanvas(object): """Ultimately, this is a QWidget (as well as a FigureCanvasAgg, etc.).""" - def __init__(self, parent, width, height, dpi): + def __init__(self, timeviewer, parent, width, height, dpi): from PyQt5 import QtWidgets from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg @@ -31,6 +31,10 @@ def __init__(self, parent, width, height, dpi): FigureCanvasQTAgg.updateGeometry(self.canvas) # XXX eventually this should be called in the window resize callback tight_layout(fig=self.axes.figure) + self.timeviewer = timeviewer + for event in ('button_press', 'motion_notify'): + self.canvas.mpl_connect( + event + '_event', getattr(self, 'on_' + event)) def plot(self, x, y, label, **kwargs): """Plot a curve.""" @@ -39,6 +43,12 @@ def plot(self, x, y, label, **kwargs): self.update_plot() return line + def plot_time_line(self, x, label, **kwargs): + """Plot the vertical line.""" + line = self.axes.axvline(x, label=label, **kwargs) + self.update_plot() + return line + def update_plot(self): """Update the plot.""" self.axes.legend(prop={'family': 'monospace', 'size': 'small'}, @@ -53,6 +63,17 @@ def close(self): """Close the canvas.""" self.canvas.close() + def on_button_press(self, event): + """Handle button presses.""" + # left click (and maybe drag) in progress in axes + if (event.inaxes != self.axes or + event.button != 1): + return + self.timeviewer.time_call( + event.xdata, update_widget=True, time_as_index=False) + + on_motion_notify = on_button_press # for now they can be the same + class IntSlider(object): """Class to set a integer slider.""" @@ -79,9 +100,10 @@ def __call__(self, value): class TimeSlider(object): """Class to update the time slider.""" - def __init__(self, plotter=None, brain=None): + def __init__(self, plotter=None, brain=None, callback=None): self.plotter = plotter self.brain = brain + self.callback = callback self.slider_rep = None if brain is None: self.time_label = None @@ -89,9 +111,14 @@ def __init__(self, plotter=None, brain=None): if callable(self.brain._data['time_label']): self.time_label = self.brain._data['time_label'] - def __call__(self, value, update_widget=False): + def __call__(self, value, update_widget=False, time_as_index=True): """Update the time slider.""" + value = float(value) + if not time_as_index: + value = self.brain._to_time_index(value) self.brain.set_time_point(value) + if self.callback is not None: + self.callback() current_time = self.brain._current_time if self.slider_rep is None: for slider in self.plotter.slider_widgets: @@ -252,14 +279,20 @@ class _TimeViewer(object): """Class to interact with _Brain.""" def __init__(self, brain, show_traces=False): - self.brain = brain - self.brain.time_viewer = self - self.plotter = brain._renderer.plotter - self.interactor = self.plotter - self.interactor.keyPressEvent = self.keyPressEvent - - # orientation slider - orientation = [ + # Default configuration + self.playback = False + self.visibility = True + self.refresh_rate_ms = max(int(round(1000. / 60.)), 1) + self.default_scaling_range = [0.2, 2.0] + self.default_smoothing_range = [0, 15] + self.default_smoothing_value = 7 + self.default_playback_speed_range = [0.01, 1] + self.default_playback_speed_value = 0.05 + self.act_data = {'lh': None, 'rh': None} + self.color_cycle = None + self.picked_points = {'lh': list(), 'rh': list()} + self._mouse_no_mvt = -1 + self.orientation = [ 'lateral', 'medial', 'rostral', @@ -269,7 +302,139 @@ def __init__(self, brain, show_traces=False): 'frontal', 'parietal' ] + self.key_bindings = { + '?': self.help, + 'i': self.toggle_interface, + 's': self.apply_auto_scaling, + 'r': self.restore_user_scaling, + 'c': self.clear_points, + ' ': self.toggle_playback, + } + + # Direct access parameters: + self.brain = brain + self.brain.time_viewer = self + self.plotter = brain._renderer.plotter + self.main_menu = self.plotter.main_menu + self.interactor = self.plotter + self.interactor.keyPressEvent = self.keyPressEvent + + # Derived parameters: + self.playback_speed = self.default_playback_speed_value + _check_option('show_traces', type(show_traces), [bool, str]) + if isinstance(show_traces, str) and show_traces == "separate": + self.show_traces = True + self.separate_canvas = True + else: + self.show_traces = show_traces + self.separate_canvas = False + + self.configure_time_label() + self.configure_sliders() + self.configure_scalar_bar() + self.configure_playback() + self.configure_point_picking() + self.configure_menu() + + def keyPressEvent(self, event): + callback = self.key_bindings.get(event.text()) + if callback is not None: + callback() + + def toggle_interface(self): + self.visibility = not self.visibility + + # manage sliders + for slider in self.plotter.slider_widgets: + slider_rep = slider.GetRepresentation() + if self.visibility: + slider_rep.VisibilityOn() + else: + slider_rep.VisibilityOff() + + # manage time label + time_label = self.brain._data['time_label'] + if callable(time_label) and self.time_actor is not None: + if self.visibility: + self.time_actor.VisibilityOff() + else: + self.time_actor.SetInput(time_label(self.brain._current_time)) + self.time_actor.VisibilityOn() + + def apply_auto_scaling(self): + self.brain.update_auto_scaling() + self.fmin_slider_rep.SetValue(self.brain._data['fmin']) + self.fmid_slider_rep.SetValue(self.brain._data['fmid']) + self.fmax_slider_rep.SetValue(self.brain._data['fmax']) + + def restore_user_scaling(self): + self.brain.update_auto_scaling(restore=True) + self.fmin_slider_rep.SetValue(self.brain._data['fmin']) + self.fmid_slider_rep.SetValue(self.brain._data['fmid']) + self.fmax_slider_rep.SetValue(self.brain._data['fmax']) + + def toggle_playback(self): + self.playback = not self.playback + if self.playback: + time_data = self.brain._data['time'] + max_time = np.max(time_data) + if self.brain._current_time == max_time: # start over + self.brain.set_time_point(np.min(time_data)) + self._last_tick = time.time() + + def set_playback_speed(self, speed): + self.playback_speed = speed + + def play(self): + if self.playback: + this_time = time.time() + delta = this_time - self._last_tick + self._last_tick = time.time() + time_data = self.brain._data['time'] + times = np.arange(self.brain._n_times) + time_shift = delta * self.playback_speed + max_time = np.max(time_data) + time_point = min(self.brain._current_time + time_shift, max_time) + # always use linear here -- this does not determine the data + # interpolation mode, it just finds where we are (in time) in + # terms of the time indices + idx = np.interp(time_point, time_data, times) + self.time_call(idx, update_widget=True) + if time_point == max_time: + self.playback = False + self.plotter.update() # critical for smooth animation + def set_slider_style(self, slider, show_label=True): + if slider is not None: + slider_rep = slider.GetRepresentation() + slider_rep.SetSliderLength(0.02) + slider_rep.SetSliderWidth(0.04) + slider_rep.SetTubeWidth(0.005) + slider_rep.SetEndCapLength(0.01) + slider_rep.SetEndCapWidth(0.02) + slider_rep.GetSliderProperty().SetColor((0.5, 0.5, 0.5)) + if not show_label: + slider_rep.ShowSliderLabelOff() + + def configure_time_label(self): + self.time_actor = self.brain._data.get('time_actor') + if self.time_actor is not None: + self.time_actor.SetPosition(0.5, 0.03) + self.time_actor.GetTextProperty().SetJustificationToCentered() + self.time_actor.GetTextProperty().BoldOn() + self.time_actor.VisibilityOff() + + def configure_scalar_bar(self): + if self.brain._colorbar_added: + scalar_bar = self.plotter.scalar_bar + scalar_bar.SetOrientationToVertical() + scalar_bar.SetHeight(0.6) + scalar_bar.SetWidth(0.05) + scalar_bar.SetPosition(0.02, 0.2) + + def configure_sliders(self): + rng = _get_range(self.brain) + # Orientation slider # default: put orientation slider on the first view if self.brain._hemi in ('split', 'both'): self.plotter.subplot(0, 0) @@ -290,7 +455,7 @@ def __init__(self, brain, show_traces=False): self.orientation_call = ShowView( plotter=self.plotter, brain=self.brain, - orientation=orientation, + orientation=self.orientation, hemi=hemi, row=ri, col=ci, @@ -299,7 +464,7 @@ def __init__(self, brain, show_traces=False): orientation_slider = self.plotter.add_text_slider_widget( self.orientation_call, value=0, - data=orientation, + data=self.orientation, pointa=(0.82, 0.74), pointb=(0.98, 0.74), event_type='always' @@ -312,44 +477,28 @@ def __init__(self, brain, show_traces=False): if self.brain._hemi in ('split', 'both'): self.plotter.subplot(0, 0) - # scalar bar - if brain._colorbar_added: - scalar_bar = self.plotter.scalar_bar - scalar_bar.SetOrientationToVertical() - scalar_bar.SetHeight(0.6) - scalar_bar.SetWidth(0.05) - scalar_bar.SetPosition(0.02, 0.2) - - # smoothing slider - default_smoothing_value = 7 + # Smoothing slider self.smoothing_call = IntSlider( plotter=self.plotter, - callback=brain.set_data_smoothing, + callback=self.brain.set_data_smoothing, name="smoothing" ) smoothing_slider = self.plotter.add_slider_widget( self.smoothing_call, - value=default_smoothing_value, - rng=[0, 15], title="smoothing", + value=self.default_smoothing_value, + rng=self.default_smoothing_range, title="smoothing", pointa=(0.82, 0.90), pointb=(0.98, 0.90) ) smoothing_slider.name = 'smoothing' - self.smoothing_call(default_smoothing_value) + self.smoothing_call(self.default_smoothing_value) - # time label - self.time_actor = brain._data.get('time_actor') - if self.time_actor is not None: - self.time_actor.SetPosition(0.5, 0.03) - self.time_actor.GetTextProperty().SetJustificationToCentered() - self.time_actor.GetTextProperty().BoldOn() - self.time_actor.VisibilityOff() - - # time slider - max_time = len(brain._data['time']) - 1 + # Time slider + max_time = len(self.brain._data['time']) - 1 self.time_call = TimeSlider( plotter=self.plotter, - brain=self.brain + brain=self.brain, + callback=self.plot_time_line, ) time_slider = self.plotter.add_slider_widget( self.time_call, @@ -360,13 +509,11 @@ def __init__(self, brain, show_traces=False): event_type='always' ) time_slider.GetRepresentation().SetLabelFormat('idx=%0.1f') - time_slider.name = "time" # set the default value - self.time_call(value=brain._data['time_idx']) + self.time_call(value=self.brain._data['time_idx']) - # playback speed - default_playback_speed = 0.05 + # Playback speed slider self.playback_speed_call = SmartSlider( plotter=self.plotter, callback=self.set_playback_speed, @@ -374,61 +521,60 @@ def __init__(self, brain, show_traces=False): ) playback_speed_slider = self.plotter.add_slider_widget( self.playback_speed_call, - value=default_playback_speed, - rng=[0.01, 1], title="speed", + value=self.default_playback_speed_value, + rng=self.default_playback_speed_range, title="speed", pointa=(0.02, 0.1), pointb=(0.18, 0.1), event_type='always' ) playback_speed_slider.name = "playback_speed" - # colormap slider - scaling_limits = [0.2, 2.0] + # Colormap slider pointa = np.array((0.82, 0.26)) pointb = np.array((0.98, 0.26)) shift = np.array([0, 0.08]) - fmin = brain._data["fmin"] + fmin = self.brain._data["fmin"] self.fmin_call = BumpColorbarPoints( plotter=self.plotter, - brain=brain, + brain=self.brain, name="fmin" ) fmin_slider = self.plotter.add_slider_widget( self.fmin_call, value=fmin, - rng=_get_range(brain), title="clim", + rng=rng, title="clim", pointa=pointa, pointb=pointb, event_type="always", ) fmin_slider.name = "fmin" self.fmin_slider_rep = fmin_slider.GetRepresentation() - fmid = brain._data["fmid"] + fmid = self.brain._data["fmid"] self.fmid_call = BumpColorbarPoints( plotter=self.plotter, - brain=brain, + brain=self.brain, name="fmid", ) fmid_slider = self.plotter.add_slider_widget( self.fmid_call, value=fmid, - rng=_get_range(brain), title="", + rng=rng, title="", pointa=pointa + shift, pointb=pointb + shift, event_type="always", ) fmid_slider.name = "fmid" self.fmid_slider_rep = fmid_slider.GetRepresentation() - fmax = brain._data["fmax"] + fmax = self.brain._data["fmax"] self.fmax_call = BumpColorbarPoints( plotter=self.plotter, - brain=brain, + brain=self.brain, name="fmax", ) fmax_slider = self.plotter.add_slider_widget( self.fmax_call, value=fmax, - rng=_get_range(brain), title="", + rng=rng, title="", pointa=pointa + 2 * shift, pointb=pointb + 2 * shift, event_type="always", @@ -437,26 +583,17 @@ def __init__(self, brain, show_traces=False): self.fmax_slider_rep = fmax_slider.GetRepresentation() self.fscale_call = UpdateColorbarScale( plotter=self.plotter, - brain=brain, + brain=self.brain, ) fscale_slider = self.plotter.add_slider_widget( self.fscale_call, value=1.0, - rng=scaling_limits, title="fscale", + rng=self.default_scaling_range, title="fscale", pointa=(0.82, 0.10), pointb=(0.98, 0.10) ) fscale_slider.name = "fscale" - # add toggle to start/pause playback - self.playback = False - self.playback_speed = default_playback_speed - self.refresh_rate_ms = max(int(round(1000. / 60.)), 1) - self.plotter.add_callback(self.play, self.refresh_rate_ms) - - # add toggle to show/hide interface - self.visibility = True - # set the slider style self.set_slider_style(smoothing_slider) self.set_slider_style(fmin_slider) @@ -466,179 +603,86 @@ def __init__(self, brain, show_traces=False): self.set_slider_style(playback_speed_slider) self.set_slider_style(time_slider) - # Point Picking and MplCanvas plotting - if isinstance(show_traces, str) and show_traces == "separate": - show_traces = True - self.separate_canvas = True - else: - self.separate_canvas = False - if isinstance(show_traces, bool) and show_traces: - self.act_data = {'lh': None, 'rh': None} - self.color_cycle = None - self.picked_points = {'lh': list(), 'rh': list()} - self._mouse_no_mvt = -1 - self.enable_point_picking() + def configure_playback(self): + self.plotter.add_callback(self.play, self.refresh_rate_ms) + + def configure_point_picking(self): + from ..backends._pyvista import _update_picking_callback + if self.show_traces: + # use a matplotlib canvas + self.color_cycle = cycle(_get_color_list()) + win = self.plotter.app_window + dpi = win.windowHandle().screen().logicalDotsPerInch() + w, h = win.geometry().width() / dpi, win.geometry().height() / dpi + h /= 3 # one third of the window + if self.separate_canvas: + parent = None + else: + parent = win + self.mpl_canvas = MplCanvas(self, parent, w, h, dpi) + xlim = [np.min(self.brain._data['time']), + np.max(self.brain._data['time'])] + self.mpl_canvas.axes.set(xlim=xlim) + vlayout = self.plotter.frame.layout() + if self.separate_canvas: + self.plotter.app_window.signal_close.connect( + self.mpl_canvas.close) + self.mpl_canvas.show() + else: + vlayout.addWidget(self.mpl_canvas.canvas) + vlayout.setStretch(0, 2) + vlayout.setStretch(1, 1) + + # get brain data + for idx, hemi in enumerate(['lh', 'rh']): + hemi_data = self.brain._data.get(hemi) + if hemi_data is not None: + self.act_data[hemi] = hemi_data['array'] + smooth_mat = hemi_data['smooth_mat'] + if smooth_mat is not None: + self.act_data[hemi] = smooth_mat.dot( + self.act_data[hemi]) + + # simulate a picked renderer + if self.brain._hemi == 'split': + self.picked_renderer = self.plotter.renderers[idx] + else: + self.picked_renderer = self.plotter.renderers[0] + + # initialize the default point + color = next(self.color_cycle) + ind = np.unravel_index( + np.argmax(self.act_data[hemi], axis=None), + self.act_data[hemi].shape + ) + vertex_id = ind[0] + mesh = hemi_data['mesh'][-1] + line = self.plot_time_course(hemi, vertex_id, color) + self.add_point(hemi, mesh, vertex_id, line, color) + + self.plot_time_line() + + _update_picking_callback( + self.plotter, + self.on_mouse_move, + self.on_button_press, + self.on_button_release, + self.on_pick + ) + def configure_menu(self): # remove default picking menu - main_menu = self.plotter.main_menu to_remove = list() - for action in main_menu.actions(): + for action in self.main_menu.actions(): if action.text() == "Tools": to_remove.append(action) for action in to_remove: - main_menu.removeAction(action) + self.main_menu.removeAction(action) - # setup key bindings - self.key_bindings = { - '?': self.help, - 'i': self.toggle_interface, - 's': self.apply_auto_scaling, - 'r': self.restore_user_scaling, - 'c': self.clear_points, - ' ': self.toggle_playback, - } - menu = self.plotter.main_menu.addMenu('Help') + # add help menu + menu = self.main_menu.addMenu('Help') menu.addAction('Show MNE key bindings\t?', self.help) - def keyPressEvent(self, event): - callback = self.key_bindings.get(event.text()) - if callback is not None: - callback() - - def toggle_interface(self): - self.visibility = not self.visibility - - # manage sliders - for slider in self.plotter.slider_widgets: - slider_rep = slider.GetRepresentation() - if self.visibility: - slider_rep.VisibilityOn() - else: - slider_rep.VisibilityOff() - - # manage time label - time_label = self.brain._data['time_label'] - if callable(time_label) and self.time_actor is not None: - if self.visibility: - self.time_actor.VisibilityOff() - else: - self.time_actor.SetInput(time_label(self.brain._current_time)) - self.time_actor.VisibilityOn() - - def apply_auto_scaling(self): - self.brain.update_auto_scaling() - self.fmin_slider_rep.SetValue(self.brain._data['fmin']) - self.fmid_slider_rep.SetValue(self.brain._data['fmid']) - self.fmax_slider_rep.SetValue(self.brain._data['fmax']) - - def restore_user_scaling(self): - self.brain.update_auto_scaling(restore=True) - self.fmin_slider_rep.SetValue(self.brain._data['fmin']) - self.fmid_slider_rep.SetValue(self.brain._data['fmid']) - self.fmax_slider_rep.SetValue(self.brain._data['fmax']) - - def toggle_playback(self): - self.playback = not self.playback - if self.playback: - time_data = self.brain._data['time'] - max_time = np.max(time_data) - if self.brain._current_time == max_time: # start over - self.brain.set_time_point(np.min(time_data)) - self._last_tick = time.time() - - def set_playback_speed(self, speed): - self.playback_speed = speed - - def play(self): - if self.playback: - this_time = time.time() - delta = this_time - self._last_tick - self._last_tick = time.time() - time_data = self.brain._data['time'] - times = np.arange(self.brain._n_times) - time_shift = delta * self.playback_speed - max_time = np.max(time_data) - time_point = min(self.brain._current_time + time_shift, max_time) - # always use linear here -- this does not determine the data - # interpolation mode, it just finds where we are (in time) in - # terms of the time indices - idx = np.interp(time_point, time_data, times) - self.time_call(idx, update_widget=True) - if time_point == max_time: - self.playback = False - self.plotter.update() # critical for smooth animation - - def set_slider_style(self, slider, show_label=True): - if slider is not None: - slider_rep = slider.GetRepresentation() - slider_rep.SetSliderLength(0.02) - slider_rep.SetSliderWidth(0.04) - slider_rep.SetTubeWidth(0.005) - slider_rep.SetEndCapLength(0.01) - slider_rep.SetEndCapWidth(0.02) - slider_rep.GetSliderProperty().SetColor((0.5, 0.5, 0.5)) - if not show_label: - slider_rep.ShowSliderLabelOff() - - def enable_point_picking(self): - from ..backends._pyvista import _update_picking_callback - # use a matplotlib canvas - self.color_cycle = cycle(_get_color_list()) - win = self.plotter.app_window - dpi = win.windowHandle().screen().logicalDotsPerInch() - w, h = win.geometry().width() / dpi, win.geometry().height() / dpi - h /= 3 # one third of the window - if self.separate_canvas: - parent = None - else: - parent = win - self.mpl_canvas = MplCanvas(parent, w, h, dpi) - xlim = [np.min(self.brain._data['time']), - np.max(self.brain._data['time'])] - self.mpl_canvas.axes.set(xlim=xlim) - vlayout = self.plotter.frame.layout() - if self.separate_canvas: - self.plotter.app_window.signal_close.connect(self.mpl_canvas.close) - self.mpl_canvas.show() - else: - vlayout.addWidget(self.mpl_canvas.canvas) - vlayout.setStretch(0, 2) - vlayout.setStretch(1, 1) - - # get brain data - for idx, hemi in enumerate(['lh', 'rh']): - hemi_data = self.brain._data.get(hemi) - if hemi_data is not None: - self.act_data[hemi] = hemi_data['array'] - smooth_mat = hemi_data['smooth_mat'] - if smooth_mat is not None: - self.act_data[hemi] = smooth_mat.dot(self.act_data[hemi]) - - # simulate a picked renderer - if self.brain._hemi == 'split': - self.picked_renderer = self.plotter.renderers[idx] - else: - self.picked_renderer = self.plotter.renderers[0] - - # initialize the default point - color = next(self.color_cycle) - ind = np.unravel_index( - np.argmax(self.act_data[hemi], axis=None), - self.act_data[hemi].shape - ) - vertex_id = ind[0] - mesh = hemi_data['mesh'][-1] - line = self.plot_time_course(hemi, vertex_id, color) - self.add_point(hemi, mesh, vertex_id, line, color) - - _update_picking_callback( - self.plotter, - self.on_mouse_move, - self.on_button_press, - self.on_button_release, - self.on_pick - ) - def on_mouse_move(self, vtk_picker, event): if self._mouse_no_mvt: self._mouse_no_mvt -= 1 @@ -732,6 +776,8 @@ def clear_points(self): self._spheres.clear() def plot_time_course(self, hemi, vertex_id, color): + if not hasattr(self, "mpl_canvas"): + return time = self.brain._data['time'] hemi_str = 'L' if hemi == 'lh' else 'R' hemi_int = 0 if hemi == 'lh' else 1 @@ -753,6 +799,23 @@ def plot_time_course(self, hemi, vertex_id, color): ) return line + def plot_time_line(self): + if not hasattr(self, "mpl_canvas"): + return + if isinstance(self.show_traces, bool) and self.show_traces: + # add time information + current_time = self.brain._current_time + if not hasattr(self, "time_line"): + self.time_line = self.mpl_canvas.plot_time_line( + x=current_time, + label='time', + color='black', + lw=1, + ) + else: + self.time_line.set_xdata(current_time) + self.mpl_canvas.update_plot() + def help(self): pairs = [ ('?', 'Display help window'),