Skip to content

Commit

Permalink
ENH: Improve plotting and reporting
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Oct 17, 2024
1 parent 922a780 commit 1dd212e
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 58 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:class:`mne.Report` HDF5 files are now written in ``mode='a`` (append) to allow users to store other data in the HDF5 files, by `Eric Larson`_.
6 changes: 6 additions & 0 deletions doc/changes/devel/newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Improved reporting and plotting options:

- :meth:`mne.Report.add_projs` can now plot with :func:`mne.viz.plot_projs_joint` rather than :func:`mne.viz.plot_projs_topomap`
- :func:`mne.viz.plot_head_positions` now has a ``totals=True`` option to show the total distance and angle of the head.

Changes by `Eric Larson`_.
96 changes: 63 additions & 33 deletions mne/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
plot_compare_evokeds,
plot_cov,
plot_events,
plot_projs_joint,
plot_projs_topomap,
set_3d_view,
use_browser_backend,
Expand Down Expand Up @@ -1613,27 +1614,43 @@ def add_projs(
self,
*,
info,
projs=None,
title,
projs=None,
topomap_kwargs=None,
tags=("ssp",),
joint=False,
picks_trace=None,
section=None,
replace=False,
):
"""Render (SSP) projection vectors.
Parameters
----------
info : instance of Info | path-like
An `~mne.Info` structure or the path of a file containing one. This
is required to create the topographic plots.
info : instance of Info | instance of Evoked | path-like
An `~mne.Info` structure or the path of a file containing one.
title : str
The title corresponding to the :class:`~mne.Projection` object.
projs : iterable of mne.Projection | path-like | None
The projection vectors to add to the report. Can be the path to a
file that will be loaded via `mne.read_proj`. If ``None``, the
projectors are taken from ``info['projs']``.
title : str
The title corresponding to the `~mne.Projection` object.
%(topomap_kwargs)s
%(tags_report)s
joint : bool
If True (default False), plot the projectors using
:func:`mne.viz.plot_projs_joint`, otherwise use
:func:`mne.viz.plot_projs_topomap`. If True, then ``info`` must be an
instance of :class:`mne.Evoked`.
.. versionadded:: 1.9
%(picks_plot_projs_joint_trace)s
Only used when ``joint=True``.
.. versionadded:: 1.9
%(section_report)s
.. versionadded:: 1.9
%(replace_report)s
Notes
Expand All @@ -1646,10 +1663,11 @@ def add_projs(
projs=projs,
title=title,
image_format=self.image_format,
section=None,
section=section,
tags=tags,
topomap_kwargs=topomap_kwargs,
replace=replace,
joint=joint,
)

def _add_ica_overlay(self, *, ica, inst, image_format, section, tags, replace):
Expand Down Expand Up @@ -2970,9 +2988,10 @@ def save(

if is_hdf5:
_, write_hdf5 = _import_h5io_funcs()
write_hdf5(
fname, self.__getstate__(), overwrite=overwrite, title="mnepython"
)
import h5py

with h5py.File(fname, "a") as f:
write_hdf5(f, self.__getstate__(), title="mnepython")
else:
# Add header, TOC, and footer.
header_html = _html_header_element(
Expand Down Expand Up @@ -3262,26 +3281,27 @@ def _add_projs(
section,
topomap_kwargs,
replace,
picks_trace=None,
joint=False,
):
evoked = None
if isinstance(info, Info): # no-op
pass
elif hasattr(info, "info"): # try to get the file name
if isinstance(info, BaseRaw):
fname = info.filenames[0]
# elif isinstance(info, (Evoked, BaseEpochs)):
# fname = info.filename
else:
fname = ""
if isinstance(info, Evoked):
evoked = info
info = info.info
else: # read from a file
fname = info
info = read_info(fname, verbose=False)
info = read_info(info, verbose=False)
if joint and evoked is None:
raise ValueError(
"joint=True requires an evoked instance to be passed as the info"
)

if projs is None:
projs = info["projs"]
elif not isinstance(projs, list):
fname = projs
projs = read_proj(fname)
projs = read_proj(projs)

if not projs:
raise ValueError("No SSP projectors found")
Expand All @@ -3294,19 +3314,29 @@ def _add_projs(
)

topomap_kwargs = self._validate_topomap_kwargs(topomap_kwargs)
fig = plot_projs_topomap(
projs=projs,
info=info,
colorbar=True,
vlim="joint",
show=False,
**topomap_kwargs,
)
# TODO This seems like a bad idea, better to provide a way to set a
# desired size in plot_projs_topomap, but that uses prepare_trellis...
# hard to see how (6, 4) could work in all number-of-projs by
# number-of-channel-types conditions...
fig.set_size_inches((6, 4))
if evoked is None:
fig = plot_projs_topomap(
projs=projs,
info=info,
colorbar=True,
vlim="joint",
show=False,
**topomap_kwargs,
)
# TODO This seems like a bad idea, better to provide a way to set a
# desired size in plot_projs_topomap, but that uses prepare_trellis...
# hard to see how (6, 4) could work in all number-of-projs by
# number-of-channel-types conditions...
fig.set_size_inches((6, 4))
else:
fig = plot_projs_joint(
projs,
evoked=evoked,
picks_trace=picks_trace,
topomap_kwargs=topomap_kwargs,
show=False,
)

_constrain_fig_resolution(fig, max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES)
self._add_figure(
fig=fig,
Expand Down
22 changes: 14 additions & 8 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,14 +1118,20 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
"""

docdict["destination_maxwell_dest"] = """
destination : path-like | array-like, shape (3,) | None
The destination location for the head. Can be ``None``, which
will not change the head position, or a path to a FIF file
containing a MEG device<->head transformation, or a 3-element array
giving the coordinates to translate to (with no rotations).
For example, ``destination=(0, 0, 0.04)`` would translate the bases
as ``--trans default`` would in MaxFilter™ (i.e., to the default
head location).
destination : path-like | array-like, shape (3,) | instance of Transform | None
The destination location for the head. Can be:
``None``
Will not change the head position.
:class:`mne.Transform`
A MEG device<->head transformation, e.g. ``info["dev_head_t"]``.
:class:`numpy.ndarray`
A 3-element array giving the coordinates to translate to (with no rotations).
For example, ``destination=(0, 0, 0.04)`` would translate the bases
as ``--trans default`` would in MaxFilter™ (i.e., to the default
head location).
``path-like``
A path to a FIF file containing the destination MEG device<->head transformation.
"""

docdict["detrend_epochs"] = """
Expand Down
60 changes: 43 additions & 17 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
)
from ..transforms import (
Transform,
_angle_between_quats,
_angle_dist_between_rigid,
_ensure_trans,
_find_trans,
_frame_to_str,
Expand Down Expand Up @@ -131,6 +133,7 @@ def plot_head_positions(
info=None,
color="k",
axes=None,
totals=False,
):
"""Plot head positions.
Expand All @@ -149,10 +152,9 @@ def plot_head_positions(
directional axes in "field" mode.
show : bool
Show figure if True. Defaults to True.
destination : str | array-like, shape (3,) | None
The destination location for the head, assumed to be in head
coordinates. See :func:`mne.preprocessing.maxwell_filter` for
details.
destination : path-like | array-like, shape (3,) | instance of Transform | None
The destination location for the head. See
:func:`mne.preprocessing.maxwell_filter` for details.
.. versionadded:: 0.16
%(info)s If provided, will be used to show the destination position when
Expand All @@ -164,12 +166,16 @@ def plot_head_positions(
arrows in ``mode == 'field'``.
.. versionadded:: 0.16
axes : array-like, shape (3, 2)
axes : array-like, shape (3, 2) or (4, 2)
The matplotlib axes to use.
.. versionadded:: 0.16
.. versionchanged:: 1.8
Added support for making use of this argument when ``mode="field"``.
totals : bool
If True and in traces mode, show the total distance and angle in a fourth row.
.. versionadded:: 1.9
Returns
-------
Expand All @@ -182,12 +188,12 @@ def plot_head_positions(
from ..preprocessing.maxwell import _check_destination

_check_option("mode", mode, ["traces", "field"])
_validate_type(totals, bool, "totals")
dest_info = dict(dev_head_t=None) if info is None else info
destination = _check_destination(destination, dest_info, head_frame=True)
if destination is not None:
destination = _ensure_trans(destination, "head", "meg") # probably inv
destination = destination["trans"][:3].copy()
destination[:, 3] *= 1000
destination = destination["trans"]

if not isinstance(pos, list | tuple):
pos = [pos]
Expand Down Expand Up @@ -228,15 +234,17 @@ def plot_head_positions(
surf["rr"] *= 1000.0
helmet_color = DEFAULTS["coreg"]["helmet_color"]
if mode == "traces":
want_shape = (3 + int(totals), 2)
if axes is None:
axes = plt.subplots(3, 2, sharex=True)[1]
_, axes = plt.subplots(*want_shape, sharex=True, layout="constrained")
else:
axes = np.array(axes)
if axes.shape != (3, 2):
raise ValueError(f"axes must have shape (3, 2), got {axes.shape}")
_check_option("axes.shape", axes.shape, (want_shape,))
fig = axes[0, 0].figure

labels = ["xyz", ("$q_1$", "$q_2$", "$q_3$")]
labels = [["x (mm)", "y (mm)", "z (mm)"], ["$q_1$", "$q_2$", "$q_3$"]]
if totals:
labels[0].append("dist (mm)")
labels[1].append("angle (°)")
for ii, (quat, coord) in enumerate(zip(use_quats.T, use_trans.T)):
axes[ii, 0].plot(t, coord, color, lw=1.0, zorder=3)
axes[ii, 0].set(ylabel=labels[0][ii], xlim=t[[0, -1]])
Expand All @@ -245,9 +253,19 @@ def plot_head_positions(
for b in borders[:-1]:
for jj in range(2):
axes[ii, jj].axvline(t[b], color="r")
for ii, title in enumerate(("Position (mm)", "Rotation (quat)")):
axes[0, ii].set(title=title)
axes[-1, ii].set(xlabel="Time (s)")
if totals:
vals = [
np.linalg.norm(use_trans, axis=-1),
np.rad2deg(_angle_between_quats(use_quats)),
]
ii = -1
for ci, val in enumerate(vals):
axes[ii, ci].plot(t, val, color, lw=1.0, zorder=3)
axes[ii, ci].set(ylabel=labels[ci][ii], xlim=t[[0, -1]])
titles = ["Position", "Rotation"]
for ci, title in enumerate(titles):
axes[0, ci].set(title=title)
axes[-1, ci].set(xlabel="Time (s)")
if rrs is not None:
pos_bads = np.any(
[
Expand Down Expand Up @@ -305,10 +323,18 @@ def plot_head_positions(

if destination is not None:
vals = np.array(
[destination[:, 3], rot_to_quat(destination[:, :3])]
[1000 * destination[:3, 3], rot_to_quat(destination[:3, :3])]
).T.ravel()
for ax, val in zip(fig.axes, vals):
for ax, val in zip(axes[:3].ravel(), vals):
ax.axhline(val, color="r", ls=":", zorder=2, lw=1.0)
if totals:
dest_ang, dest_dist = _angle_dist_between_rigid(
destination,
angle_units="deg",
distance_units="mm",
)
axes[-1, 0].axhline(dest_dist, color="r", ls=":", zorder=2, lw=1.0)
axes[-1, 1].axhline(dest_ang, color="r", ls=":", zorder=2, lw=1.0)

else: # mode == 'field':
from matplotlib.colors import Normalize
Expand Down

0 comments on commit 1dd212e

Please sign in to comment.