Skip to content

Commit

Permalink
ENH: Add Forward.save and hdf5 support
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Sep 30, 2023
1 parent d52a0d3 commit 1d3a7c6
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 8 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Enhancements
- Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_)
- Add inferring EEGLAB files' montage unit automatically based on estimated head radius using :func:`read_raw_eeglab(..., montage_units="auto") <mne.io.read_raw_eeglab>` (:gh:`11925` by `Jack Zhang`_, :gh:`11951` by `Eric Larson`_)
- Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array <numpy.ndarray>` data (:gh:`11803` by `Alex Rockhill`_)
- Add support for writing forward solutions to HDF5 and convenience function :meth:`mne.forward.Forward.save` (:gh:`12036` by `Eric Larson`_)
- Refactored internals of :func:`mne.read_annotations` (:gh:`11964` by `Paul Roujansky`_)
- Enhance :func:`~mne.viz.plot_evoked_field` with a GUI that has controls for time, colormap, and contour lines (:gh:`11942` by `Marijn van Vliet`_)

Expand Down
3 changes: 3 additions & 0 deletions mne/_fiff/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
_check_on_missing,
fill_doc,
_check_fname,
check_fname,
repr_html,
)
from ._digitization import (
Expand Down Expand Up @@ -2006,6 +2007,8 @@ def read_info(fname, verbose=None):
-------
%(info_not_none)s
"""
check_fname(fname, "Info", (".fif", ".fif.gz"))
fname = _check_fname(fname, must_exist=True, overwrite="read")
f, tree, _ = fiff_open(fname)
with f as fid:
info = read_meas_info(fid, tree)[0]
Expand Down
8 changes: 8 additions & 0 deletions mne/_fiff/tests/test_meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,14 @@ def test_read_write_info(tmp_path):
write_info(fname, info)


@testing.requires_testing_data
def test_dir_warning():
"""Test that trying to read a bad filename emits a warning before an error."""
with pytest.raises(OSError, match="directory"):
with pytest.warns(RuntimeWarning, match="foo"):
read_info(ctf_fname)


def test_io_dig_points(tmp_path):
"""Test Writing for dig files."""
dest = tmp_path / "test.txt"
Expand Down
47 changes: 39 additions & 8 deletions mne/forward/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
_stamp_to_dt,
_on_missing,
repr_html,
_import_h5io_funcs,
)
from ..label import Label

Expand Down Expand Up @@ -165,6 +166,18 @@ def copy(self):
"""Copy the Forward instance."""
return Forward(deepcopy(self))

@verbose
def save(self, fname, *, overwrite=False, verbose=None):
"""Save the forward solution.
Parameters
----------
%(fname_fwd)s
%(overwrite)s
%(verbose)s
"""
write_forward_solution(fname, self, overwrite=overwrite)

def _get_src_type_and_ori_for_repr(self):
src_types = np.array([src["type"] for src in self["src"]])

Expand Down Expand Up @@ -520,7 +533,8 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, verbos
Parameters
----------
fname : path-like
The file name, which should end with ``-fwd.fif`` or ``-fwd.fif.gz``.
The file name, which should end with ``-fwd.fif``, ``-fwd.fif.gz``,
``_fwd.fif``, ``_fwd.fif.gz``, ``-fwd.h5``, or ``_fwd.h5``.
include : list, optional
List of names of channels to include. If empty all channels
are included.
Expand Down Expand Up @@ -554,11 +568,15 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, verbos
forward solution with :func:`read_forward_solution`.
"""
check_fname(
fname, "forward", ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz")
fname,
"forward",
("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz", "-fwd.h5", "_fwd.h5"),
)
fname = _check_fname(fname=fname, must_exist=True, overwrite="read")
# Open the file, create directory
logger.info("Reading forward solution from %s..." % fname)
if fname.suffix == ".h5":
return _read_forward_hdf5(fname)
f, tree, _ = fiff_open(fname)
with f as fid:
# Find all forward solutions
Expand Down Expand Up @@ -861,9 +879,7 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None):
Parameters
----------
fname : path-like
File name to save the forward solution to. It should end with
``-fwd.fif`` or ``-fwd.fif.gz``.
%(fname_fwd)s
fwd : Forward
Forward solution.
%(overwrite)s
Expand All @@ -889,13 +905,28 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None):
forward solution with :func:`read_forward_solution`.
"""
check_fname(
fname, "forward", ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz")
fname,
"forward",
("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz", "-fwd.h5", "_fwd.h5"),
)

# check for file existence and expand `~` if present
fname = _check_fname(fname, overwrite)
with start_and_end_file(fname) as fid:
_write_forward_solution(fid, fwd)
if fname.suffix == ".h5":
_write_forward_hdf5(fname, fwd)
else:
with start_and_end_file(fname) as fid:
_write_forward_solution(fid, fwd)


def _write_forward_hdf5(fname, fwd):
_, write_hdf5 = _import_h5io_funcs()
write_hdf5(fname, dict(fwd=fwd), overwrite=True)


def _read_forward_hdf5(fname):
read_hdf5, _ = _import_h5io_funcs()
return Forward(read_hdf5(fname)["fwd"])


def _write_forward_solution(fid, fwd):
Expand Down
9 changes: 9 additions & 0 deletions mne/forward/tests/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ def test_io_forward(tmp_path):
fwd_read = read_forward_solution(fname_temp)
assert_forward_allclose(fwd, fwd_read)

h5py = pytest.importorskip("h5py")
pytest.importorskip("h5io")
fname_h5 = fname_temp.with_suffix(".h5")
fwd.save(fname_h5)
with h5py.File(fname_h5, "r"):
pass # just checks for hdf5-ness
fwd_read = read_forward_solution(fname_h5)
assert_forward_allclose(fwd, fwd_read)


@testing.requires_testing_data
def test_apply_forward():
Expand Down
9 changes: 9 additions & 0 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,6 +1694,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
Name of the output file.
"""

docdict[
"fname_fwd"
] = """
fname : path-like
File name to save the forward solution to. It should end with
``-fwd.fif`` or ``-fwd.fif.gz`` to save to FIF, or ``-fwd.h5`` to save to
HDF5.
"""

docdict[
"fnirs"
] = """
Expand Down

0 comments on commit 1d3a7c6

Please sign in to comment.