Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix epoch splits naming #11876

Merged
merged 12 commits into from
Aug 17, 2023
35 changes: 11 additions & 24 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np

from .io.utils import _construct_bids_filename
from .io.utils import _make_split_fnames
from .io.write import (
start_and_end_file,
start_block,
Expand Down Expand Up @@ -119,36 +119,22 @@ def _pack_reject_params(epochs):
return reject_params


def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming, overwrite):
def _save_split(epochs, split_fnames, part_idx, n_parts, fmt, overwrite):
"""Split epochs.

Anything new added to this function also needs to be added to
BaseEpochs.save to account for new file sizes.
"""
# insert index in filename
base, ext = op.splitext(fname)
if part_idx > 0:
if split_naming == "neuromag":
fname = "%s-%d%s" % (base, part_idx, ext)
else:
assert split_naming == "bids"
fname = _construct_bids_filename(base, ext, part_idx, validate=False)
_check_fname(fname, overwrite=overwrite)
this_fname = split_fnames[part_idx]
_check_fname(this_fname, overwrite=overwrite)

next_fname = None
next_fname, next_idx = None, None
if part_idx < n_parts - 1:
if split_naming == "neuromag":
next_fname = "%s-%d%s" % (base, part_idx + 1, ext)
else:
assert split_naming == "bids"
next_fname = _construct_bids_filename(
base, ext, part_idx + 1, validate=False
)
next_idx = part_idx + 1
else:
next_idx = None
next_fname = split_fnames[next_idx]

with start_and_end_file(fname) as fid:
with start_and_end_file(this_fname) as fid:
_save_part(fid, epochs, fmt, n_parts, next_fname, next_idx)


Expand Down Expand Up @@ -2149,13 +2135,14 @@ def save(

epoch_idxs = np.array_split(np.arange(n_epochs), n_parts)

_check_option("split_naming", split_naming, ("neuromag", "bids"))
split_fnames = _make_split_fnames(fname, n_parts, split_naming)
for part_idx, epoch_idx in enumerate(epoch_idxs):
this_epochs = self[epoch_idx] if n_parts > 1 else self
# avoid missing event_ids in splits
this_epochs.event_id = self.event_id
_save_split(
this_epochs, fname, part_idx, n_parts, fmt, split_naming, overwrite
)

_save_split(this_epochs, split_fnames, part_idx, n_parts, fmt, overwrite)

@verbose
def export(self, fname, fmt="auto", *, overwrite=False, verbose=None):
Expand Down
15 changes: 15 additions & 0 deletions mne/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,18 @@ def _construct_bids_filename(base, ext, part_idx, validate=True):
if dirname:
use_fname = op.join(dirname, use_fname)
return use_fname


def _make_split_fnames(fname, n_splits, split_naming):
    """Make a list of split filenames.

    Parameters
    ----------
    fname : path-like
        Filename the data would be written to if it were not split.
    n_splits : int
        Number of parts the data is written in.
    split_naming : 'neuromag' | 'bids'
        Naming scheme used to derive the name of each part.

    Returns
    -------
    split_fnames : list
        One filename per part, ordered by part index. When ``n_splits`` is 1
        this is ``[fname]`` with ``fname`` passed through untouched.

    Raises
    ------
    ValueError
        If more than one part is requested and ``split_naming`` is neither
        ``'neuromag'`` nor ``'bids'``.
    """
    if n_splits == 1:
        # Nothing is split, so no naming scheme is applied (or checked).
        return [fname]

    # Validate explicitly rather than with ``assert``: asserts are stripped
    # under ``python -O`` and any unknown value would silently be treated
    # as "bids".
    if split_naming not in ("neuromag", "bids"):
        raise ValueError(
            f"split_naming must be 'neuromag' or 'bids', got {split_naming!r}"
        )

    base, ext = op.splitext(fname)
    if split_naming == "neuromag":
        # The first part keeps the requested name; every later part gets
        # "-<part index>" inserted in front of the extension.
        return [f"{base}-{i:d}{ext}" if i else fname for i in range(n_splits)]

    # BIDS naming: the helper is called with a 1-based part number for every
    # part, including the first one.
    return [_construct_bids_filename(base, ext, i + 1) for i in range(n_splits)]
49 changes: 40 additions & 9 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,13 +1574,13 @@ def test_split_saving_and_loading_back(tmp_path, epochs_to_split, preload):
(
"bids",
"test_epo.fif",
lambda i: f"test_split-{i:02d}_epo.fif" if i else "test_epo.fif",
lambda i: f"test_split-{i + 1:02d}_epo.fif",
),
(
"bids",
"test-epo.fif",
"a_b-epo.fif",
# Merely stating the fact:
lambda i: f"_split-{i:02d}_test-epo.fif" if i else "test-epo.fif",
lambda i: f"a_split-{i + 1:02d}_b-epo.fif",
),
],
ids=["neuromag", "bids", "mix"],
Expand Down Expand Up @@ -1629,19 +1629,50 @@ def test_saved_fname_no_splitting(
assert not split_1_fpath.is_file()


@pytest.mark.parametrize(
    "epochs_to_split",
    [
        ("3MB", 18, False, False, 3),
        pytest.param(
            ("2GB", 18, False, False, 1),
            marks=pytest.mark.xfail(reason="No check when not splitting"),
        ),
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "dst_fname",
    [
        "test-epo.fif",
        pytest.param(
            "a_b_c-epo.fif",
            marks=pytest.mark.xfail(reason="No check for several bids clauses"),
        ),
    ],
)
def test_bids_splits_fail_for_bad_fname_ending(epochs_to_split, dst_fname, tmp_path):
    """Check that BIDS split naming rejects filenames without a BIDS ending.

    Saving with ``split_naming="bids"`` to a name such as ``test-epo.fif``
    is expected to raise a ``ValueError``; otherwise the parts would get
    surprising names like ``_split-01_test-epo.fif``.

    Two parametrized cases are marked as expected failures: a save that is
    not split at all, and a destination name with several BIDS-like clauses.
    """
    # The third fixture element is not needed by this test.
    epochs, split_size, _ = epochs_to_split
    target = tmp_path / dst_fname

    with pytest.raises(ValueError, match=".* must end with an underscore"):
        epochs.save(
            target,
            verbose=True,
            split_naming="bids",
            split_size=split_size,
        )


@pytest.mark.parametrize(
"epochs_to_split", [("3MB", 18, False, False, 3)], indirect=True
)
@pytest.mark.parametrize(
"split_naming, dst_fname, existing_fname",
[
("neuromag", "test-epo.fif", "test-epo.fif"),
pytest.param(
"neuromag",
"test-epo.fif",
"test-epo-1.fif",
marks=pytest.mark.xfail(reason="bug"),
),
("neuromag", "test-epo.fif", "test-epo-1.fif"),
("bids", "test_epo.fif", "test_epo.fif"),
("bids", "test_epo.fif", "test_split-01_epo.fif"),
("bids", "test_epo.fif", "test_split-02_epo.fif"),
Expand Down