From 1d3a7c64f497359e06982b66f26beb5262e4a658 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Sat, 30 Sep 2023 19:11:33 -0400 Subject: [PATCH] ENH: Add Forward.save and hdf5 support --- doc/changes/devel.rst | 1 + mne/_fiff/meas_info.py | 3 ++ mne/_fiff/tests/test_meas_info.py | 8 ++++++ mne/forward/forward.py | 47 +++++++++++++++++++++++++------ mne/forward/tests/test_forward.py | 9 ++++++ mne/utils/docs.py | 9 ++++++ 6 files changed, 69 insertions(+), 8 deletions(-) diff --git a/doc/changes/devel.rst b/doc/changes/devel.rst index a47f500bc5f..f33389d0cb4 100644 --- a/doc/changes/devel.rst +++ b/doc/changes/devel.rst @@ -33,6 +33,7 @@ Enhancements - Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_) - Add inferring EEGLAB files' montage unit automatically based on estimated head radius using :func:`read_raw_eeglab(..., montage_units="auto") ` (:gh:`11925` by `Jack Zhang`_, :gh:`11951` by `Eric Larson`_) - Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array ` data (:gh:`11803` by `Alex Rockhill`_) +- Add support for writing forward solutions to HDF5 and convenience function :meth:`mne.forward.Forward.save` (:gh:`12036` by `Eric Larson`_) - Refactored internals of :func:`mne.read_annotations` (:gh:`11964` by `Paul Roujansky`_) - Enhance :func:`~mne.viz.plot_evoked_field` with a GUI that has controls for time, colormap, and contour lines (:gh:`11942` by `Marijn van Vliet`_) diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index fe5c9d0d881..672f805c1b8 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -80,6 +80,7 @@ _check_on_missing, fill_doc, _check_fname, + check_fname, repr_html, ) from ._digitization import ( @@ -2006,6 +2007,8 @@ def read_info(fname, verbose=None): ------- %(info_not_none)s """ + check_fname(fname, "Info", (".fif", ".fif.gz")) + fname = _check_fname(fname, must_exist=True, overwrite="read") f, tree, _ = fiff_open(fname) with f as fid: info = read_meas_info(fid, tree)[0] diff --git a/mne/_fiff/tests/test_meas_info.py b/mne/_fiff/tests/test_meas_info.py index 6cee0c94d76..844d04fc624 100644 --- a/mne/_fiff/tests/test_meas_info.py +++ b/mne/_fiff/tests/test_meas_info.py @@ -345,6 +345,14 @@ def test_read_write_info(tmp_path): write_info(fname, info) +@testing.requires_testing_data +def test_dir_warning(): + """Test that trying to read a bad filename emits a warning before an error.""" + with pytest.raises(OSError, match="directory"): + with pytest.warns(RuntimeWarning, match="foo"): + read_info(ctf_fname) + + def test_io_dig_points(tmp_path): """Test Writing for dig files.""" dest = tmp_path / "test.txt" diff --git a/mne/forward/forward.py b/mne/forward/forward.py index 0fcb821ab2d..07ea99d59ce 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -81,6 +81,7 @@ _stamp_to_dt, _on_missing, repr_html, + _import_h5io_funcs, ) from ..label import Label @@ -165,6 +166,18 @@ def copy(self): """Copy the Forward instance.""" return Forward(deepcopy(self)) + @verbose + def save(self, fname, *, overwrite=False, verbose=None): + """Save the forward solution. + + Parameters + ---------- + %(fname_fwd)s + %(overwrite)s + %(verbose)s + """ + write_forward_solution(fname, self, overwrite=overwrite) + def _get_src_type_and_ori_for_repr(self): src_types = np.array([src["type"] for src in self["src"]]) @@ -520,7 +533,8 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, verbos Parameters ---------- fname : path-like - The file name, which should end with ``-fwd.fif`` or ``-fwd.fif.gz``. + The file name, which should end with ``-fwd.fif``, ``-fwd.fif.gz``, + ``_fwd.fif``, ``_fwd.fif.gz``, ``-fwd.h5``, or ``_fwd.h5``. include : list, optional List of names of channels to include. If empty all channels are included. @@ -554,11 +568,15 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, verbos forward solution with :func:`read_forward_solution`. """ check_fname( - fname, "forward", ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz") + fname, + "forward", + ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz", "-fwd.h5", "_fwd.h5"), ) fname = _check_fname(fname=fname, must_exist=True, overwrite="read") # Open the file, create directory logger.info("Reading forward solution from %s..." % fname) + if fname.suffix == ".h5": + return _read_forward_hdf5(fname) f, tree, _ = fiff_open(fname) with f as fid: # Find all forward solutions @@ -861,9 +879,7 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None): Parameters ---------- - fname : path-like - File name to save the forward solution to. It should end with - ``-fwd.fif`` or ``-fwd.fif.gz``. + %(fname_fwd)s fwd : Forward Forward solution. %(overwrite)s @@ -889,13 +905,28 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None): forward solution with :func:`read_forward_solution`. """ check_fname( - fname, "forward", ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz") + fname, + "forward", + ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz", "-fwd.h5", "_fwd.h5"), ) # check for file existence and expand `~` if present fname = _check_fname(fname, overwrite) - with start_and_end_file(fname) as fid: - _write_forward_solution(fid, fwd) + if fname.suffix == ".h5": + _write_forward_hdf5(fname, fwd) + else: + with start_and_end_file(fname) as fid: + _write_forward_solution(fid, fwd) + + +def _write_forward_hdf5(fname, fwd): + _, write_hdf5 = _import_h5io_funcs() + write_hdf5(fname, dict(fwd=fwd), overwrite=True) + + +def _read_forward_hdf5(fname): + read_hdf5, _ = _import_h5io_funcs() + return Forward(read_hdf5(fname)["fwd"]) def _write_forward_solution(fid, fwd): diff --git a/mne/forward/tests/test_forward.py b/mne/forward/tests/test_forward.py index d6981945ac6..ee37f11676c 100644 --- a/mne/forward/tests/test_forward.py +++ b/mne/forward/tests/test_forward.py @@ -197,6 +197,15 @@ def test_io_forward(tmp_path): fwd_read = read_forward_solution(fname_temp) assert_forward_allclose(fwd, fwd_read) + h5py = pytest.importorskip("h5py") + pytest.importorskip("h5io") + fname_h5 = fname_temp.with_suffix(".h5") + fwd.save(fname_h5) + with h5py.File(fname_h5, "r"): + pass # just checks for hdf5-ness + fwd_read = read_forward_solution(fname_h5) + assert_forward_allclose(fwd, fwd_read) + @testing.requires_testing_data def test_apply_forward(): diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 53a394022a3..7ec2dbc4534 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1694,6 +1694,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Name of the output file. """ +docdict[ + "fname_fwd" +] = """ +fname : path-like + File name to save the forward solution to. It should end with + ``-fwd.fif`` or ``-fwd.fif.gz`` to save to FIF, or ``-fwd.h5`` to save to + HDF5. +""" + docdict[ "fnirs" ] = """