diff --git a/mne/coreg.py b/mne/coreg.py index 2ce15dcb966..3c38d55c3ba 100644 --- a/mne/coreg.py +++ b/mne/coreg.py @@ -15,6 +15,7 @@ import re import shutil from functools import reduce +from collections import OrderedDict import numpy as np @@ -1632,6 +1633,11 @@ def _nearest_transformed_high_res_mri_idx_hsp(self): return self._nearest_calc.query( apply_trans(self._head_mri_t, self._filtered_extra_points))[1] + @property + def _has_hsp_data(self): + return (self._has_mri_data and + len(self._nearest_transformed_high_res_mri_idx_hsp) > 0) + @property def _has_hpi_data(self): return (self._has_mri_data and @@ -1648,7 +1654,7 @@ def _has_lpa_data(self): @property def _has_nasion_data(self): - return (np.any(self._nasion) and np.any(self._dig_dict.nasion)) + return (np.any(self._nasion) and np.any(self._dig_dict['nasion'])) @property def _has_rpa_data(self): @@ -1963,3 +1969,57 @@ def reset(self): self._extra_points_filter = None self._update_nearest_calc() return self + + def _get_fiducials_distance(self): + distance = OrderedDict() + for key in ('lpa', 'nasion', 'rpa'): + fid = getattr(self, f"_{key}") + transformed_mri = apply_trans(self._mri_trans, fid) + transformed_hsp = apply_trans( + self._head_mri_t, self._dig_dict[key]) + distance[key] = np.linalg.norm( + np.ravel(transformed_mri - transformed_hsp)) + return np.array(list(distance.values())) * 1e3 + + def _get_fiducials_distance_str(self): + dists = self._get_fiducials_distance() + return f"Fiducials: {dists[0]:.1f}, {dists[1]:.1f}, {dists[2]:.1f} mm" + + def _get_point_distance(self): + mri_points = list() + hsp_points = list() + if self._hsp_weight > 0 and self._has_hsp_data: + mri_points.append(self._transformed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_hsp]) + hsp_points.append(self._transformed_dig_extra) + assert len(mri_points[-1]) == len(hsp_points[-1]) + if self._eeg_weight > 0 and self._has_eeg_data: + mri_points.append(self._transformed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_eeg]) + hsp_points.append(self._transformed_dig_eeg) + assert len(mri_points[-1]) == len(hsp_points[-1]) + if self._hpi_weight > 0 and self._has_hpi_data: + mri_points.append(self._transformed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_hpi]) + hsp_points.append(self._transformed_dig_hpi) + assert len(mri_points[-1]) == len(hsp_points[-1]) + if all(len(h) == 0 for h in hsp_points): + return None + mri_points = np.concatenate(mri_points) + hsp_points = np.concatenate(hsp_points) + return np.linalg.norm(mri_points - hsp_points, axis=-1) + + def _get_point_distance_str(self): + point_distance = self._get_point_distance() + if point_distance is None: + return "" + dists = 1e3 * point_distance + av_dist = np.mean(dists) + std_dist = np.std(dists) + kinds = [kind for kind, check in + (('HSP', self._hsp_weight > 0 and self._has_hsp_data), + ('EEG', self._eeg_weight > 0 and self._has_eeg_data), + ('HPI', self._hpi_weight > 0 and self._has_hpi_data)) + if check] + kinds = '+'.join(kinds) + return f"{len(dists)} {kinds}: {av_dist:.1f} ± {std_dist:.1f} mm" diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py index 52fe66440c6..d7ac1b04dfd 100644 --- a/mne/gui/_coreg.py +++ b/mne/gui/_coreg.py @@ -219,6 +219,7 @@ def _get_default(var, val): # configure UI self._reset_fitting_parameters() + self._configure_status_bar() self._configure_dock() self._configure_picking() @@ -242,9 +243,11 @@ def _get_default(var, val): False: dict(azimuth=180, elevation=90)} # left self._renderer.set_camera(distance=None, **views[self._lock_fids]) self._redraw() + # XXX: internal plotter/renderer should not be exposed if not self._immediate_redraw: self._renderer.plotter.add_callback( self._redraw, self._refresh_rate_ms) + self._renderer.plotter.show_axes() if standalone: _qt_app_exec(self._renderer.figure.store["app"]) @@ -373,6 +376,8 @@ def _set_point_weight(self, weight, point): if point in funcs.keys(): getattr(self, funcs[point])(weight > 0) setattr(self, f"_{point}_weight", weight) + setattr(self._coreg, f"_{point}_weight", weight) + self._update_distance_estimation() @observe("_subjects_dir") def _subjects_dir_changed(self, change=None): @@ -402,6 +407,7 @@ def _lock_fids_changed(self, change=None): if self._lock_fids: self._forward_widget_command(view_widgets, "set_enabled", True) self._display_message() + self._update_distance_estimation() else: self._forward_widget_command(view_widgets, "set_enabled", False) self._display_message("Picking fiducials - " @@ -415,6 +421,7 @@ def _lock_fids_changed(self, change=None): def _fiducials_file_changed(self, change=None): fids, _ = read_fiducials(self._fiducials_file) self._coreg._setup_fiducials(fids) + self._update_distance_estimation() self._reset() self._set_lock_fids(True) @@ -530,6 +537,7 @@ def _redraw(self, verbose=None): draw_map[key]() self._redraws_pending.clear() self._renderer._update() + self._renderer._process_events() # necessary for MacOS? def _on_mouse_move(self, vtk_picker, event): if self._mouse_no_mvt: @@ -541,7 +549,7 @@ def _on_button_press(self, vtk_picker, event): def _on_button_release(self, vtk_picker, event): if self._mouse_no_mvt > 0: x, y = vtk_picker.GetEventPosition() - # XXX: plotter/renderer should not be exposed if possible + # XXX: internal plotter/renderer should not be exposed plotter = self._renderer.figure.plotter picked_renderer = self._renderer.figure.plotter.renderer # trigger the pick @@ -589,11 +597,22 @@ def _reset_fiducials(self): def _omit_hsp(self): self._coreg.omit_head_shape_points(self._omit_hsp_distance / 1e3) + n_omitted = np.sum(~self._coreg._extra_points_filter) + n_remaining = len(self._coreg._dig_dict['hsp']) - n_omitted self._update_plot("hsp") + self._update_distance_estimation() + self._display_message( + f"{n_omitted} head shape points omitted, " + f"{n_remaining} remaining.") def _reset_omit_hsp_filter(self): self._coreg._extra_points_filter = None + self._coreg._update_params(force_update_omitted=True) self._update_plot("hsp") + self._update_distance_estimation() + n_total = len(self._coreg._dig_dict['hsp']) + self._display_message( + f"No head shape point is omitted, the total is {n_total}.") def _update_plot(self, changes="all"): # Update list of things that need to be updated/plotted (and maybe @@ -638,11 +657,9 @@ def _lock_plot(self): self._plot_locked = old_plot_locked def _display_message(self, msg=""): - if "msg" not in self._actors: - self._actors["msg"] = self._renderer.text2d(0, 0, msg) - else: - self._actors["msg"].SetInput(msg) - self._renderer._update() + self._status_msg.set_value(msg) + self._status_msg.show() + self._status_msg.update() def _follow_fiducial_view(self): fid = self._current_fiducial.lower() @@ -659,6 +676,16 @@ def _update_fiducials(self): self._forward_widget_command( ["fid_X", "fid_Y", "fid_Z"], "set_value", val) + def _update_distance_estimation(self): + value = self._coreg._get_fiducials_distance_str() + '\n' + \ + self._coreg._get_point_distance_str() + dists = self._coreg.compute_dig_mri_distances() * 1e3 + if self._hsp_weight > 0: + value += "\nHSP <-> MRI (mean/min/max): "\ + f"{np.mean(dists):.2f} "\ + f"/ {np.min(dists):.2f} / {np.max(dists):.2f} mm" + self._forward_widget_command("fit_label", "set_value", value) + def _update_parameters(self): with self._lock_plot(): # rotation @@ -676,6 +703,7 @@ def _reset(self): self._coreg.reset() self._update_plot() self._update_parameters() + self._update_distance_estimation() def _forward_widget_command(self, names, command, value): names = [names] if not isinstance(names, list) else names @@ -697,6 +725,7 @@ def _set_sensors_visibility(self, state): self._renderer._update() def _update_actor(self, actor_name, actor): + # XXX: internal plotter/renderer should not be exposed self._renderer.plotter.remove_actor(self._actors.get(actor_name)) self._actors[actor_name] = actor @@ -786,6 +815,10 @@ def _add_head_hair(self): self._coreg._get_processed_mri_points(res) def _fit_fiducials(self): + if not self._lock_fids: + self._display_message( + "Fitting is disabled, lock the fiducials first.") + return start = time.time() self._coreg.fit_fiducials( lpa_weight=self._lpa_weight, @@ -794,18 +827,25 @@ def _fit_fiducials(self): verbose=self._verbose, ) end = time.time() - self._renderer._status_bar_show_message( + self._display_message( f"Fitting fiducials finished in {end - start:.2f} seconds.") self._update_plot("sensors") self._update_parameters() + self._update_distance_estimation() def _fit_icp(self): + if not self._lock_fids: + self._display_message( + "Fitting is disabled, lock the fiducials first.") + return self._current_icp_iterations = 0 def callback(iteration, n_iterations): - self._display_message(f"Fitting ICP - iteration {iteration + 1}") + self._display_message( + f"Fitting ICP - iteration {iteration + 1}") self._update_plot("sensors") - self._current_icp_iterations = iteration + self._current_icp_iterations = n_iterations + self._update_distance_estimation() self._renderer._process_events() # allow a draw or cancel start = time.time() @@ -819,7 +859,7 @@ def callback(iteration, n_iterations): ) end = time.time() self._display_message() - self._renderer._status_bar_show_message( + self._display_message( f"Fitting ICP finished in {end - start:.2f} seconds and " f"{self._current_icp_iterations} iterations.") self._update_parameters() @@ -827,6 +867,8 @@ def callback(iteration, n_iterations): def _save_trans(self, fname): write_trans(fname, self._coreg.trans) + self._display_message( + "{fname} transform file is saved.") def _load_trans(self, fname): mri_head_t = _ensure_trans(read_trans(fname, return_all=True), @@ -838,6 +880,9 @@ def _load_trans(self, fname): tra=np.array([x, y, z]), ) self._update_parameters() + self._update_distance_estimation() + self._display_message( + f"{fname} transform file is loaded.") def _get_subjects(self, sdir=None): # XXX: would be nice to move this function to util @@ -931,7 +976,7 @@ def _configure_dock(self): layout=layout, ) self._widgets["grow_hair"] = self._renderer._dock_add_spin_box( - name="Grow Hair", + name="Grow Hair (mm)", value=self._grow_hair, rng=[0.0, 10.0], callback=self._set_grow_hair, @@ -939,7 +984,7 @@ def _configure_dock(self): ) hlayout = self._renderer._dock_add_layout(vertical=False) self._widgets["omit_distance"] = self._renderer._dock_add_spin_box( - name="Omit Distance", + name="Omit Distance (mm)", value=self._omit_hsp_distance, rng=[0.0, 100.0], callback=self._set_omit_hsp_distance, @@ -1032,6 +1077,10 @@ def _configure_dock(self): layout=hlayout, ) self._renderer._layout_add_widget(layout, hlayout) + self._widgets["fit_label"] = self._renderer._dock_add_label( + value="", + layout=layout, + ) self._widgets["icp_n_iterations"] = self._renderer._dock_add_spin_box( name="Number Of ICP Iterations", value=self._defaults["icp_n_iterations"], @@ -1110,6 +1159,10 @@ def _configure_dock(self): self._renderer._layout_add_widget(layout, hlayout) self._renderer._dock_add_stretch() + def _configure_status_bar(self): + self._status_msg = self._renderer._status_bar_add_label("", stretch=1) + self._status_msg.hide() + def _clean(self): self._renderer = None self._coreg = None @@ -1118,6 +1171,8 @@ def _clean(self): self._surfaces.clear() self._defaults.clear() self._head_geo = None + self._redraw = None + self._status_msg = None def close(self): """Close interface and cleanup data structure.""" diff --git a/mne/viz/backends/_abstract.py b/mne/viz/backends/_abstract.py index e1585f71204..fafe4546eb8 100644 --- a/mne/viz/backends/_abstract.py +++ b/mne/viz/backends/_abstract.py @@ -598,10 +598,6 @@ class _AbstractStatusBar(ABC): def _status_bar_initialize(self, window=None): pass - @abstractmethod - def _status_bar_show_message(self, value, timeout=5000): - pass - @abstractmethod def _status_bar_add_label(self, value, stretch=0): pass diff --git a/mne/viz/backends/_notebook.py b/mne/viz/backends/_notebook.py index 037423f6ac9..5bb7390363d 100644 --- a/mne/viz/backends/_notebook.py +++ b/mne/viz/backends/_notebook.py @@ -279,9 +279,6 @@ def _status_bar_initialize(self, window=None): self._status_bar = self._status_bar_layout = HBox() self._layout_initialize(None) - def _status_bar_show_message(self, value, timeout=5000): - pass - def _status_bar_add_label(self, value, stretch=0): widget = Text(value=value, disabled=True) self._layout_add_widget(self._status_bar_layout, widget) diff --git a/mne/viz/backends/_qt.py b/mne/viz/backends/_qt.py index e766a6a5105..55f924027b2 100644 --- a/mne/viz/backends/_qt.py +++ b/mne/viz/backends/_qt.py @@ -400,9 +400,6 @@ def _status_bar_initialize(self, window=None): self._status_bar = window.statusBar() self._status_bar_layout = self._status_bar.layout() - def _status_bar_show_message(self, value, timeout=5000): - self._status_bar.showMessage(value, timeout) - def _status_bar_add_label(self, value, stretch=0): widget = QLabel(value) self._layout_add_widget(self._status_bar_layout, widget, stretch)