From 1dd212eb0dcb34bbbb480e440382480ceb565143 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 17 Oct 2024 16:18:33 -0400 Subject: [PATCH] ENH: Improve plotting and reporting --- doc/changes/devel/bugfix.rst | 1 + doc/changes/devel/newfeature.rst | 6 ++ mne/report/report.py | 96 +++++++++++++++++++++----------- mne/utils/docs.py | 22 +++++--- mne/viz/_3d.py | 60 ++++++++++++++------ 5 files changed, 127 insertions(+), 58 deletions(-) create mode 100644 doc/changes/devel/bugfix.rst create mode 100644 doc/changes/devel/newfeature.rst diff --git a/doc/changes/devel/bugfix.rst b/doc/changes/devel/bugfix.rst new file mode 100644 index 00000000000..88db1504a06 --- /dev/null +++ b/doc/changes/devel/bugfix.rst @@ -0,0 +1 @@ +:class:`mne.Report` HDF5 files are now written in ``mode='a`` (append) to allow users to store other data in the HDF5 files, by `Eric Larson`_. diff --git a/doc/changes/devel/newfeature.rst b/doc/changes/devel/newfeature.rst new file mode 100644 index 00000000000..a92fa989cd7 --- /dev/null +++ b/doc/changes/devel/newfeature.rst @@ -0,0 +1,6 @@ +Improved reporting and plotting options: + +- :meth:`mne.Report.add_projs` can now plot with :func:`mne.viz.plot_projs_joint` rather than :func:`mne.viz.plot_projs_topomap` +- :func:`mne.viz.plot_head_positions` now has a ``totals=True`` option to show the total distance and angle of the head. + +Changes by `Eric Larson`_. diff --git a/mne/report/report.py b/mne/report/report.py index 0ca781378f5..05473dca819 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -75,6 +75,7 @@ plot_compare_evokeds, plot_cov, plot_events, + plot_projs_joint, plot_projs_topomap, set_3d_view, use_browser_backend, @@ -1613,27 +1614,43 @@ def add_projs( self, *, info, - projs=None, title, + projs=None, topomap_kwargs=None, tags=("ssp",), + joint=False, + picks_trace=None, + section=None, replace=False, ): """Render (SSP) projection vectors. Parameters ---------- - info : instance of Info | path-like - An `~mne.Info` structure or the path of a file containing one. This - is required to create the topographic plots. + info : instance of Info | instance of Evoked | path-like + An `~mne.Info` structure or the path of a file containing one. + title : str + The title corresponding to the :class:`~mne.Projection` object. projs : iterable of mne.Projection | path-like | None The projection vectors to add to the report. Can be the path to a file that will be loaded via `mne.read_proj`. If ``None``, the projectors are taken from ``info['projs']``. - title : str - The title corresponding to the `~mne.Projection` object. %(topomap_kwargs)s %(tags_report)s + joint : bool + If True (default False), plot the projectors using + :func:`mne.viz.plot_projs_joint`, otherwise use + :func:`mne.viz.plot_projs_topomap`. If True, then ``info`` must be an + instance of :class:`mne.Evoked`. + + .. versionadded:: 1.9 + %(picks_plot_projs_joint_trace)s + Only used when ``joint=True``. + + .. versionadded:: 1.9 + %(section_report)s + + .. versionadded:: 1.9 %(replace_report)s Notes @@ -1646,10 +1663,11 @@ def add_projs( projs=projs, title=title, image_format=self.image_format, - section=None, + section=section, tags=tags, topomap_kwargs=topomap_kwargs, replace=replace, + joint=joint, ) def _add_ica_overlay(self, *, ica, inst, image_format, section, tags, replace): @@ -2970,9 +2988,10 @@ def save( if is_hdf5: _, write_hdf5 = _import_h5io_funcs() - write_hdf5( - fname, self.__getstate__(), overwrite=overwrite, title="mnepython" - ) + import h5py + + with h5py.File(fname, "a") as f: + write_hdf5(f, self.__getstate__(), title="mnepython") else: # Add header, TOC, and footer. header_html = _html_header_element( @@ -3262,26 +3281,27 @@ def _add_projs( section, topomap_kwargs, replace, + picks_trace=None, + joint=False, ): + evoked = None if isinstance(info, Info): # no-op pass elif hasattr(info, "info"): # try to get the file name - if isinstance(info, BaseRaw): - fname = info.filenames[0] - # elif isinstance(info, (Evoked, BaseEpochs)): - # fname = info.filename - else: - fname = "" + if isinstance(info, Evoked): + evoked = info info = info.info else: # read from a file - fname = info - info = read_info(fname, verbose=False) + info = read_info(info, verbose=False) + if joint and evoked is None: + raise ValueError( + "joint=True requires an evoked instance to be passed as the info" + ) if projs is None: projs = info["projs"] elif not isinstance(projs, list): - fname = projs - projs = read_proj(fname) + projs = read_proj(projs) if not projs: raise ValueError("No SSP projectors found") @@ -3294,19 +3314,29 @@ def _add_projs( ) topomap_kwargs = self._validate_topomap_kwargs(topomap_kwargs) - fig = plot_projs_topomap( - projs=projs, - info=info, - colorbar=True, - vlim="joint", - show=False, - **topomap_kwargs, - ) - # TODO This seems like a bad idea, better to provide a way to set a - # desired size in plot_projs_topomap, but that uses prepare_trellis... - # hard to see how (6, 4) could work in all number-of-projs by - # number-of-channel-types conditions... - fig.set_size_inches((6, 4)) + if evoked is None: + fig = plot_projs_topomap( + projs=projs, + info=info, + colorbar=True, + vlim="joint", + show=False, + **topomap_kwargs, + ) + # TODO This seems like a bad idea, better to provide a way to set a + # desired size in plot_projs_topomap, but that uses prepare_trellis... + # hard to see how (6, 4) could work in all number-of-projs by + # number-of-channel-types conditions... + fig.set_size_inches((6, 4)) + else: + fig = plot_projs_joint( + projs, + evoked=evoked, + picks_trace=picks_trace, + topomap_kwargs=topomap_kwargs, + show=False, + ) + _constrain_fig_resolution(fig, max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES) self._add_figure( fig=fig, diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 4dec335df1f..03bf9f10cf1 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1118,14 +1118,20 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ docdict["destination_maxwell_dest"] = """ -destination : path-like | array-like, shape (3,) | None - The destination location for the head. Can be ``None``, which - will not change the head position, or a path to a FIF file - containing a MEG device<->head transformation, or a 3-element array - giving the coordinates to translate to (with no rotations). - For example, ``destination=(0, 0, 0.04)`` would translate the bases - as ``--trans default`` would in MaxFilter™ (i.e., to the default - head location). +destination : path-like | array-like, shape (3,) | instance of Transform | None + The destination location for the head. Can be: + + ``None`` + Will not change the head position. + :class:`mne.Transform` + A MEG device<->head transformation, e.g. ``info["dev_head_t"]``. + :class:`numpy.ndarray` + A 3-element array giving the coordinates to translate to (with no rotations). + For example, ``destination=(0, 0, 0.04)`` would translate the bases + as ``--trans default`` would in MaxFilter™ (i.e., to the default + head location). + ``path-like`` + A path to a FIF file containing the destination MEG device<->head transformation. """ docdict["detrend_epochs"] = """ diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 2fb94134830..7f8fb98b07a 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -50,6 +50,8 @@ ) from ..transforms import ( Transform, + _angle_between_quats, + _angle_dist_between_rigid, _ensure_trans, _find_trans, _frame_to_str, @@ -131,6 +133,7 @@ def plot_head_positions( info=None, color="k", axes=None, + totals=False, ): """Plot head positions. @@ -149,10 +152,9 @@ def plot_head_positions( directional axes in "field" mode. show : bool Show figure if True. Defaults to True. - destination : str | array-like, shape (3,) | None - The destination location for the head, assumed to be in head - coordinates. See :func:`mne.preprocessing.maxwell_filter` for - details. + destination : path-like | array-like, shape (3,) | instance of Transform | None + The destination location for the head. See + :func:`mne.preprocessing.maxwell_filter` for details. .. versionadded:: 0.16 %(info)s If provided, will be used to show the destination position when @@ -164,12 +166,16 @@ def plot_head_positions( arrows in ``mode == 'field'``. .. versionadded:: 0.16 - axes : array-like, shape (3, 2) + axes : array-like, shape (3, 2) or (4, 2) The matplotlib axes to use. .. versionadded:: 0.16 .. versionchanged:: 1.8 Added support for making use of this argument when ``mode="field"``. + totals : bool + If True and in traces mode, show the total distance and angle in a fourth row. + + .. versionadded:: 1.9 Returns ------- @@ -182,12 +188,12 @@ def plot_head_positions( from ..preprocessing.maxwell import _check_destination _check_option("mode", mode, ["traces", "field"]) + _validate_type(totals, bool, "totals") dest_info = dict(dev_head_t=None) if info is None else info destination = _check_destination(destination, dest_info, head_frame=True) if destination is not None: destination = _ensure_trans(destination, "head", "meg") # probably inv - destination = destination["trans"][:3].copy() - destination[:, 3] *= 1000 + destination = destination["trans"] if not isinstance(pos, list | tuple): pos = [pos] @@ -228,15 +234,17 @@ def plot_head_positions( surf["rr"] *= 1000.0 helmet_color = DEFAULTS["coreg"]["helmet_color"] if mode == "traces": + want_shape = (3 + int(totals), 2) if axes is None: - axes = plt.subplots(3, 2, sharex=True)[1] + _, axes = plt.subplots(*want_shape, sharex=True, layout="constrained") else: axes = np.array(axes) - if axes.shape != (3, 2): - raise ValueError(f"axes must have shape (3, 2), got {axes.shape}") + _check_option("axes.shape", axes.shape, (want_shape,)) fig = axes[0, 0].figure - - labels = ["xyz", ("$q_1$", "$q_2$", "$q_3$")] + labels = [["x (mm)", "y (mm)", "z (mm)"], ["$q_1$", "$q_2$", "$q_3$"]] + if totals: + labels[0].append("dist (mm)") + labels[1].append("angle (°)") for ii, (quat, coord) in enumerate(zip(use_quats.T, use_trans.T)): axes[ii, 0].plot(t, coord, color, lw=1.0, zorder=3) axes[ii, 0].set(ylabel=labels[0][ii], xlim=t[[0, -1]]) @@ -245,9 +253,19 @@ def plot_head_positions( for b in borders[:-1]: for jj in range(2): axes[ii, jj].axvline(t[b], color="r") - for ii, title in enumerate(("Position (mm)", "Rotation (quat)")): - axes[0, ii].set(title=title) - axes[-1, ii].set(xlabel="Time (s)") + if totals: + vals = [ + np.linalg.norm(use_trans, axis=-1), + np.rad2deg(_angle_between_quats(use_quats)), + ] + ii = -1 + for ci, val in enumerate(vals): + axes[ii, ci].plot(t, val, color, lw=1.0, zorder=3) + axes[ii, ci].set(ylabel=labels[ci][ii], xlim=t[[0, -1]]) + titles = ["Position", "Rotation"] + for ci, title in enumerate(titles): + axes[0, ci].set(title=title) + axes[-1, ci].set(xlabel="Time (s)") if rrs is not None: pos_bads = np.any( [ @@ -305,10 +323,18 @@ def plot_head_positions( if destination is not None: vals = np.array( - [destination[:, 3], rot_to_quat(destination[:, :3])] + [1000 * destination[:3, 3], rot_to_quat(destination[:3, :3])] ).T.ravel() - for ax, val in zip(fig.axes, vals): + for ax, val in zip(axes[:3].ravel(), vals): ax.axhline(val, color="r", ls=":", zorder=2, lw=1.0) + if totals: + dest_ang, dest_dist = _angle_dist_between_rigid( + destination, + angle_units="deg", + distance_units="mm", + ) + axes[-1, 0].axhline(dest_dist, color="r", ls=":", zorder=2, lw=1.0) + axes[-1, 1].axhline(dest_ang, color="r", ls=":", zorder=2, lw=1.0) else: # mode == 'field': from matplotlib.colors import Normalize