Commit

factor out epochs creation
dmalt committed Aug 14, 2023
1 parent 296a861 commit 860b8ab
Showing 1 changed file with 57 additions and 35 deletions.
mne/tests/test_epochs.py: 92 changes (57 additions, 35 deletions)
@@ -1476,44 +1476,66 @@ def test_epochs_io_preload(tmp_path, preload):
     assert_equal(epochs.get_data().shape[-1], 1)


-@pytest.mark.parametrize(
-    "split_size, n_epochs, n_files",
-    [
-        ("1.5MB", 9, 6),
-        ("3MB", 18, 3),
-    ],
-)
-@pytest.mark.parametrize("metadata", [False, True])
-@pytest.mark.parametrize("concat", (False, True))
-def test_split_saving(tmp_path, split_size, n_epochs, n_files, metadata, concat):
-    """Test saving split epochs."""
-    if metadata:
-        pytest.importorskip("pandas")
-    # See gh-5102
-    fs = 1000.0
-    n_times = int(round(fs * (n_epochs + 1)))
-    raw = mne.io.RawArray(
-        np.random.RandomState(0).randn(100, n_times), mne.create_info(100, 1000.0)
-    )
-    events = mne.make_fixed_length_events(raw, 1)
-    epochs = mne.Epochs(raw, events)
-    if metadata:
-        from pandas import DataFrame
-
-        junk = ["*" * 10000 for _ in range(len(events))]
-        metadata = DataFrame(
-            {
-                "event_time": events[:, 0] / raw.info["sfreq"],
-                "trial_number": range(len(events)),
-                "junk": junk,
-            }
-        )
-        epochs.metadata = metadata
-    if concat:
-        epochs.drop_bad()
-        epochs = concatenate_epochs([epochs[ii] for ii in range(len(epochs))])
+@pytest.fixture(scope="session")
+def epochs_factory():
+    """Function to create fake Epochs object."""
+    def factory(n_epochs, metadata=False, concat=False):
+        if metadata:
+            pytest.importorskip("pandas")
+        # See gh-5102
+        fs = 1000.0
+        n_times = int(round(fs * (n_epochs + 1)))
+        raw = mne.io.RawArray(
+            np.random.RandomState(0).randn(100, n_times), mne.create_info(100, 1000.0)
+        )
+        events = mne.make_fixed_length_events(raw, 1)
+        epochs = mne.Epochs(raw, events)
+        if metadata:
+            from pandas import DataFrame
+
+            junk = ["*" * 10000 for _ in range(len(events))]
+            metadata = DataFrame(
+                {
+                    "event_time": events[:, 0] / raw.info["sfreq"],
+                    "trial_number": range(len(events)),
+                    "junk": junk,
+                }
+            )
+            epochs.metadata = metadata
+        epochs.drop_bad()
+        if concat:
+            epochs = concatenate_epochs([epochs[ii] for ii in range(len(epochs))])
+        assert len(epochs) == n_epochs
+        return epochs
+    return factory
+
+
+@pytest.fixture(params=[
+    ("1.5MB", 9, True, True, 6),
+    ("1.5MB", 9, True, False, 6),
+    ("1.5MB", 9, False, True, 6),
+    ("1.5MB", 9, False, False, 6),
+    ("3MB", 18, True, True, 3),
+    ("3MB", 18, True, False, 3),
+    ("3MB", 18, False, True, 3),
+    ("3MB", 18, False, False, 3),
+])
+def epochs_to_split(request, epochs_factory):
+    """Epochs tailored to produce specific number of splits when saving.
+    We're specifically interested in boundary cases, when a small size
+    excess triggers creation of a new split: gh-7897
+    """
+    split_size, n_epochs, metadata, concat, n_files = request.param
+    epochs = epochs_factory(n_epochs, metadata, concat)
+    return epochs, split_size, n_files
+
+
+def test_split_saving(tmp_path, epochs_to_split):
+    """Test saving split epochs."""
+    epochs, split_size, n_files = epochs_to_split
     epochs_data = epochs.get_data()
-    assert len(epochs) == n_epochs
     fname = tmp_path / "test-epo.fif"
     epochs.save(fname, split_size=split_size, overwrite=True)
     got_size = _get_split_size(split_size)
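
For context, the refactor swaps the stacked @pytest.mark.parametrize decorators for a session-scoped factory fixture that is consumed by a parametrized fixture. The sketch below shows that pytest pattern in isolation; the names (make_numbers, numbers_to_sum, test_sum) and the data are invented for illustration and are not part of the MNE test suite.

import pytest


@pytest.fixture(scope="session")
def make_numbers():
    """Return a function that builds a list of test data on demand."""

    def factory(n, offset=0):
        # Build n consecutive integers starting at offset.
        return [i + offset for i in range(n)]

    return factory


@pytest.fixture(params=[(3, 0, 3), (5, 1, 15)])
def numbers_to_sum(request, make_numbers):
    """Build data for one parameter combination; the expected sum rides along."""
    n, offset, expected = request.param
    return make_numbers(n, offset), expected


def test_sum(numbers_to_sum):
    numbers, expected = numbers_to_sum
    assert sum(numbers) == expected

Each tuple in params drives one test run, mirroring how epochs_to_split enumerates the split_size/metadata/concat combinations explicitly instead of relying on the cross-product of three decorators, and how epochs_factory keeps the expensive Epochs construction in one place.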
