From 860b8abe042e24eb13aa25bff4d87235441832a2 Mon Sep 17 00:00:00 2001
From: Dmitrii Altukhov
Date: Mon, 14 Aug 2023 19:21:11 +0200
Subject: [PATCH] factor out epochs creation

---
 mne/tests/test_epochs.py | 92 +++++++++++++++++++++++++---------------
 1 file changed, 57 insertions(+), 35 deletions(-)

diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index 9218997e4a6..42bad2e518e 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -1476,44 +1476,66 @@ def test_epochs_io_preload(tmp_path, preload):
     assert_equal(epochs.get_data().shape[-1], 1)
 
 
-@pytest.mark.parametrize(
-    "split_size, n_epochs, n_files",
-    [
-        ("1.5MB", 9, 6),
-        ("3MB", 18, 3),
-    ],
-)
-@pytest.mark.parametrize("metadata", [False, True])
-@pytest.mark.parametrize("concat", (False, True))
-def test_split_saving(tmp_path, split_size, n_epochs, n_files, metadata, concat):
-    """Test saving split epochs."""
-    if metadata:
-        pytest.importorskip("pandas")
-    # See gh-5102
-    fs = 1000.0
-    n_times = int(round(fs * (n_epochs + 1)))
-    raw = mne.io.RawArray(
-        np.random.RandomState(0).randn(100, n_times), mne.create_info(100, 1000.0)
-    )
-    events = mne.make_fixed_length_events(raw, 1)
-    epochs = mne.Epochs(raw, events)
-    if metadata:
-        from pandas import DataFrame
-
-        junk = ["*" * 10000 for _ in range(len(events))]
-        metadata = DataFrame(
-            {
-                "event_time": events[:, 0] / raw.info["sfreq"],
-                "trial_number": range(len(events)),
-                "junk": junk,
-            }
+@pytest.fixture(scope="session")
+def epochs_factory():
+    """Return a factory function that creates fake Epochs objects."""
+    def factory(n_epochs, metadata=False, concat=False):
+        if metadata:
+            pytest.importorskip("pandas")
+        # See gh-5102
+        fs = 1000.0
+        n_times = int(round(fs * (n_epochs + 1)))
+        raw = mne.io.RawArray(
+            np.random.RandomState(0).randn(100, n_times), mne.create_info(100, 1000.0)
         )
-        epochs.metadata = metadata
-    if concat:
+        events = mne.make_fixed_length_events(raw, 1)
+        epochs = mne.Epochs(raw, events)
+        if metadata:
+            from pandas import DataFrame
+
+            junk = ["*" * 10000 for _ in range(len(events))]
+            metadata = DataFrame(
+                {
+                    "event_time": events[:, 0] / raw.info["sfreq"],
+                    "trial_number": range(len(events)),
+                    "junk": junk,
+                }
+            )
+            epochs.metadata = metadata
         epochs.drop_bad()
-        epochs = concatenate_epochs([epochs[ii] for ii in range(len(epochs))])
+        if concat:
+            epochs = concatenate_epochs([epochs[ii] for ii in range(len(epochs))])
+        assert len(epochs) == n_epochs
+        return epochs
+    return factory
+
+
+@pytest.fixture(params=[
+    ("1.5MB", 9, True, True, 6),
+    ("1.5MB", 9, True, False, 6),
+    ("1.5MB", 9, False, True, 6),
+    ("1.5MB", 9, False, False, 6),
+    ("3MB", 18, True, True, 3),
+    ("3MB", 18, True, False, 3),
+    ("3MB", 18, False, True, 3),
+    ("3MB", 18, False, False, 3),
+])
+def epochs_to_split(request, epochs_factory):
+    """Epochs tailored to produce a specific number of splits when saving.
+
+    We're specifically interested in boundary cases, where a small size
+    excess triggers the creation of a new split; see gh-7897.
+
+    """
+    split_size, n_epochs, metadata, concat, n_files = request.param
+    epochs = epochs_factory(n_epochs, metadata, concat)
+    return epochs, split_size, n_files
+
+
+def test_split_saving(tmp_path, epochs_to_split):
+    """Test saving split epochs."""
+    epochs, split_size, n_files = epochs_to_split
     epochs_data = epochs.get_data()
-    assert len(epochs) == n_epochs
     fname = tmp_path / "test-epo.fif"
     epochs.save(fname, split_size=split_size, overwrite=True)
     got_size = _get_split_size(split_size)
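
Note on the pattern used above (not part of the patch): the refactoring relies on
pytest's fixture-factory idiom. A session-scoped fixture returns a builder
function, and a separate parametrized fixture calls that builder once per
parameter tuple, so every test that requests it runs once per case. Below is a
minimal self-contained sketch of the same idiom; the names (item_factory,
sized_item, test_item_length) are illustrative only, not from the patch:

    import pytest

    @pytest.fixture(scope="session")
    def item_factory():
        """Return a builder so tests can request custom-sized items."""
        def factory(n):
            # Stand-in for an expensive-to-build object (Epochs in the patch).
            return list(range(n))
        return factory

    @pytest.fixture(params=[3, 5])
    def sized_item(request, item_factory):
        """Build one item per parameter; dependent tests run once per param."""
        return item_factory(request.param)

    def test_item_length(sized_item):
        # Runs twice: once with a 3-element item, once with a 5-element item.
        assert len(sized_item) in (3, 5)

Keeping the factory session-scoped means the builder closure is created once,
while each parametrized invocation still constructs a fresh object, which is
why the patch can drop the stacked @pytest.mark.parametrize decorators.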