diff --git a/doc/changes/devel.rst b/doc/changes/devel.rst index 6669474fd29..846593b3182 100644 --- a/doc/changes/devel.rst +++ b/doc/changes/devel.rst @@ -29,7 +29,7 @@ Enhancements Bugs ~~~~ -- None yet +- Fix bugs with saving splits for :class:`~mne.Epochs` (:gh:`11876` by `Dmitrii Altukhov`_) API changes ~~~~~~~~~~~ diff --git a/mne/epochs.py b/mne/epochs.py index 069f88ddb42..87241add7ec 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -18,7 +18,7 @@ import numpy as np -from .io.utils import _construct_bids_filename +from .io.utils import _make_split_fnames from .io.write import ( start_and_end_file, start_block, @@ -119,36 +119,22 @@ def _pack_reject_params(epochs): return reject_params -def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming, overwrite): +def _save_split(epochs, split_fnames, part_idx, n_parts, fmt, overwrite): """Split epochs. Anything new added to this function also needs to be added to BaseEpochs.save to account for new file sizes. """ # insert index in filename - base, ext = op.splitext(fname) - if part_idx > 0: - if split_naming == "neuromag": - fname = "%s-%d%s" % (base, part_idx, ext) - else: - assert split_naming == "bids" - fname = _construct_bids_filename(base, ext, part_idx, validate=False) - _check_fname(fname, overwrite=overwrite) + this_fname = split_fnames[part_idx] + _check_fname(this_fname, overwrite=overwrite) - next_fname = None + next_fname, next_idx = None, None if part_idx < n_parts - 1: - if split_naming == "neuromag": - next_fname = "%s-%d%s" % (base, part_idx + 1, ext) - else: - assert split_naming == "bids" - next_fname = _construct_bids_filename( - base, ext, part_idx + 1, validate=False - ) next_idx = part_idx + 1 - else: - next_idx = None + next_fname = split_fnames[next_idx] - with start_and_end_file(fname) as fid: + with start_and_end_file(this_fname) as fid: _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx) @@ -2149,13 +2135,14 @@ def save( epoch_idxs = np.array_split(np.arange(n_epochs), n_parts) + _check_option("split_naming", split_naming, ("neuromag", "bids")) + split_fnames = _make_split_fnames(fname, n_parts, split_naming) for part_idx, epoch_idx in enumerate(epoch_idxs): this_epochs = self[epoch_idx] if n_parts > 1 else self # avoid missing event_ids in splits this_epochs.event_id = self.event_id - _save_split( - this_epochs, fname, part_idx, n_parts, fmt, split_naming, overwrite - ) + + _save_split(this_epochs, split_fnames, part_idx, n_parts, fmt, overwrite) @verbose def export(self, fname, fmt="auto", *, overwrite=False, verbose=None): diff --git a/mne/io/utils.py b/mne/io/utils.py index 95b7ffe49ec..a6fa77e9aa6 100644 --- a/mne/io/utils.py +++ b/mne/io/utils.py @@ -342,3 +342,18 @@ def _construct_bids_filename(base, ext, part_idx, validate=True): if dirname: use_fname = op.join(dirname, use_fname) return use_fname + + +def _make_split_fnames(fname, n_splits, split_naming): + """Make a list of split filenames.""" + if n_splits == 1: + return [fname] + res = [] + base, ext = op.splitext(fname) + for i in range(n_splits): + if split_naming == "neuromag": + res.append(f"{base}-{i:d}{ext}" if i else fname) + else: + assert split_naming == "bids" + res.append(_construct_bids_filename(base, ext, i + 1)) + return res diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 26563c6a4af..2334a268ce4 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -1574,13 +1574,13 @@ def test_split_saving_and_loading_back(tmp_path, epochs_to_split, preload): ( "bids", "test_epo.fif", - lambda i: f"test_split-{i:02d}_epo.fif" if i else "test_epo.fif", + lambda i: f"test_split-{i + 1:02d}_epo.fif", ), ( "bids", - "test-epo.fif", + "a_b-epo.fif", # Merely stating the fact: - lambda i: f"_split-{i:02d}_test-epo.fif" if i else "test-epo.fif", + lambda i: f"a_split-{i + 1:02d}_b-epo.fif", ), ], ids=["neuromag", "bids", "mix"], @@ -1629,6 +1629,42 @@ def test_saved_fname_no_splitting( assert not split_1_fpath.is_file() +@pytest.mark.parametrize( + "epochs_to_split", + [ + ("3MB", 18, False, False, 3), + pytest.param( + ("2GB", 18, False, False, 1), + marks=pytest.mark.xfail(reason="No check when not splitting"), + ), + ], + indirect=True, +) +@pytest.mark.parametrize( + "dst_fname", + [ + "test-epo.fif", + pytest.param( + "a_b_c-epo.fif", + marks=pytest.mark.xfail(reason="No check for several bids clauses"), + ), + ], +) +def test_bids_splits_fail_for_bad_fname_ending(epochs_to_split, dst_fname, tmp_path): + """Make sure split_naming=bids is only used with bids endings. + + Non-bids endings can cause surprising split names, e.g. test-epo.fif + producing splits _split-01_test-epo.fif. + + """ + epochs, split_size, _ = epochs_to_split + dst_fpath = tmp_path / dst_fname + save_kwargs = {"split_naming": "bids", "split_size": split_size} + + with pytest.raises(ValueError, match=".* must end with an underscore"): + epochs.save(dst_fpath, verbose=True, **save_kwargs) + + @pytest.mark.parametrize( "epochs_to_split", [("3MB", 18, False, False, 3)], indirect=True ) @@ -1636,12 +1672,7 @@ def test_saved_fname_no_splitting( "split_naming, dst_fname, existing_fname", [ ("neuromag", "test-epo.fif", "test-epo.fif"), - pytest.param( - "neuromag", - "test-epo.fif", - "test-epo-1.fif", - marks=pytest.mark.xfail(reason="bug"), - ), + ("neuromag", "test-epo.fif", "test-epo-1.fif"), ("bids", "test_epo.fif", "test_epo.fif"), ("bids", "test_epo.fif", "test_split-01_epo.fif"), ("bids", "test_epo.fif", "test_split-02_epo.fif"),