diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 7ecb4bec6c1..3de4042320d 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -861,8 +861,6 @@ def set_time_point(self, time_idx): from ..backends._pyvista import _set_mesh_scalars from scipy.interpolate import interp1d time = self._data['time'] - time_label = self._data['time_label'] - time_actor = self._data.get('time_actor') for hemi in ['lh', 'rh']: hemi_data = self._data.get(hemi) if hemi_data is not None: @@ -872,23 +870,18 @@ def set_time_point(self, time_idx): if array.ndim == 2: if isinstance(time_idx, int): act_data = array[:, time_idx] + self._current_time = time[time_idx] else: times = np.arange(self._n_times) act_data = interp1d(times, array, 'linear', axis=1, assume_sorted=True)(time_idx) + ifunc = interp1d(times, self._data['time']) + self._current_time = ifunc(time_idx) smooth_mat = hemi_data['smooth_mat'] if smooth_mat is not None: act_data = smooth_mat.dot(act_data) _set_mesh_scalars(mesh, act_data, 'Data') - if callable(time_label) and time_actor is not None: - if isinstance(time_idx, int): - self._current_time = time[time_idx] - time_actor.SetInput(time_label(self._current_time)) - else: - ifunc = interp1d(times, self._data['time']) - self._current_time = ifunc(time_idx) - time_actor.SetInput(time_label(self._current_time)) self._data['time_idx'] = time_idx def update_fmax(self, fmax): diff --git a/mne/viz/_brain/_timeviewer.py b/mne/viz/_brain/_timeviewer.py index b2142b98cd8..d512c0b3955 100644 --- a/mne/viz/_brain/_timeviewer.py +++ b/mne/viz/_brain/_timeviewer.py @@ -30,6 +30,36 @@ def __call__(self, value): self.callback(idx) +class TimeSlider(object): + """Class to update the time slider.""" + + def __init__(self, plotter=None, brain=None): + self.plotter = plotter + self.brain = brain + self.slider_rep = None + if brain is None: + self.time_label = None + else: + if callable(self.brain._data['time_label']): + self.time_label = self.brain._data['time_label'] + + def __call__(self, value, update_widget=False): + """Update the time slider.""" + self.brain.set_time_point(value) + current_time = self.brain._current_time + if self.slider_rep is None: + for slider in self.plotter.slider_widgets: + name = getattr(slider, "name", None) + if name == "time": + self.slider_rep = slider.GetRepresentation() + if self.slider_rep is not None: + if update_widget: + self.slider_rep.SetValue(value) + if self.time_label is not None: + current_time = self.time_label(current_time) + self.slider_rep.SetTitleText(current_time) + + class UpdateColorbarScale(object): """Class to update the values of the colorbar sliders.""" @@ -259,23 +289,25 @@ def __init__(self, brain): if self.time_actor is not None: self.time_actor.SetPosition(0.5, 0.03) self.time_actor.GetTextProperty().SetJustificationToCentered() + self.time_actor.GetTextProperty().BoldOn() + self.time_actor.VisibilityOff() # time slider max_time = len(brain._data['time']) - 1 - self.time_call = SmartSlider( + self.time_call = TimeSlider( plotter=self.plotter, - callback=self.brain.set_time_point, - name="time" + brain=self.brain ) time_slider = self.plotter.add_slider_widget( self.time_call, - value=brain._data['time_idx'], rng=[0, max_time], pointa=(0.23, 0.1), pointb=(0.77, 0.1), event_type='always' ) time_slider.name = "time" + # set the default value + self.time_call(value=brain._data['time_idx']) # playback speed default_playback_speed = 0.05 @@ -384,13 +416,11 @@ def __init__(self, brain): self.set_slider_style(fmax_slider) self.set_slider_style(fscale_slider) self.set_slider_style(playback_speed_slider) - self.set_slider_style(time_slider, show_label=False) - - # set the text style - _set_text_style(self.time_actor) + self.set_slider_style(time_slider) def toggle_interface(self): self.visibility = not self.visibility + # manage sliders for slider in self.plotter.slider_widgets: slider_rep = slider.GetRepresentation() if self.visibility: @@ -398,6 +428,15 @@ def toggle_interface(self): else: slider_rep.VisibilityOff() + # manage time label + time_label = self.brain._data['time_label'] + if callable(time_label) and self.time_actor is not None: + if self.visibility: + self.time_actor.VisibilityOff() + else: + self.time_actor.SetInput(time_label(self.brain._current_time)) + self.time_actor.VisibilityOn() + def apply_auto_scaling(self): self.brain.update_auto_scaling() self.fmin_slider_rep.SetValue(self.brain._data['fmin']) @@ -506,12 +545,6 @@ def link_sliders(self, name, callback, event_type): ) -def _set_text_style(text_actor): - if text_actor is not None: - prop = text_actor.GetTextProperty() - prop.BoldOn() - - def _get_range(brain): val = np.abs(brain._data['array']) return [np.min(val), np.max(val)]