Skip to content

Commit

Permalink
style: merge precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
dmalt committed Aug 12, 2023
1 parent e855ae3 commit 29b128b
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#
# License: BSD-3-Clause

import os
import pickle
from copy import deepcopy
from datetime import timedelta
Expand Down Expand Up @@ -1477,27 +1476,30 @@ def test_epochs_io_preload(tmp_path, preload):


@pytest.fixture(params=[True, False], ids=["metadata", "no_metadata"])
def metadata(request):
def is_add_metadata(request):
"""Weither to create metadata in epochs fixture."""
return request.param


@pytest.fixture(params=[True, False], ids=["concat", "no_concat"])
def concat(request):
def is_concat(request):
"""Weither to create concatenated version of epochs."""
return request.param


@pytest.fixture
def epochs(n_epochs, metadata, concat):
def fake_epochs(n_epochs, is_add_metadata, is_concat):
"""Epochs object to test split naming."""
n_ch, fs = 100, 1000.0
if metadata:
if is_add_metadata:
pytest.importorskip("pandas")
# See gh-5102
n_times = int(round(fs * (n_epochs + 1)))
raw_data = np.random.RandomState(0).randn(n_ch, n_times)
raw = mne.io.RawArray(raw_data, mne.create_info(n_ch, fs))
events = mne.make_fixed_length_events(raw, 1)
epochs = mne.Epochs(raw, events)
if metadata:
if is_add_metadata:
from pandas import DataFrame

junk = ["*" * 10000 for _ in range(len(events))]
Expand All @@ -1510,7 +1512,7 @@ def epochs(n_epochs, metadata, concat):
)
epochs.metadata = metadata
epochs.drop_bad()
if concat:
if is_concat:
epochs = concatenate_epochs([epochs[ii] for ii in range(len(epochs))])
assert len(epochs) == n_epochs
return epochs
Expand All @@ -1521,18 +1523,18 @@ def epochs(n_epochs, metadata, concat):
)
@pytest.mark.parametrize("preload", [True, False], ids=["preload", "no_preload"])
def test_split_saving_and_loading_back(
tmp_path, split_size, epochs, expected_n_splits, preload
tmp_path, split_size, fake_epochs, expected_n_splits, preload
):
"""Test saving split epochs."""
dst = tmp_path / "test-epo.fif"
split_size_bytes = _get_split_size(split_size)

epochs.save(dst, split_size=split_size)
fake_epochs.save(dst, split_size=split_size)
loaded_epochs = mne.read_epochs(dst, preload=preload)

_assert_splits(dst, expected_n_splits, split_size_bytes)
assert_allclose(loaded_epochs.get_data(), epochs.get_data())
assert_array_equal(epochs.events, loaded_epochs.events)
assert_allclose(loaded_epochs.get_data(), fake_epochs.get_data())
assert_array_equal(fake_epochs.events, loaded_epochs.events)


@pytest.mark.parametrize(
Expand All @@ -1552,16 +1554,19 @@ def test_split_saving_and_loading_back(
)
def test_split_naming(
tmp_path,
epochs,
fake_epochs,
split_size,
expected_n_splits,
dst_fname,
split_naming,
split_fname_fn,
):
"""Save epochs with small split_size and check the split filenames."""
dst = tmp_path / dst_fname

epochs.save(dst, split_size=split_size, split_naming=split_naming, verbose=True)
fake_epochs.save(
dst, split_size=split_size, split_naming=split_naming, verbose=True
)

assert expected_n_splits >= 2, "Need at least 2 splits for saved epochs."
assert len(list(tmp_path.iterdir())) == expected_n_splits
Expand All @@ -1579,26 +1584,27 @@ def test_split_naming(
],
)
def test_saved_fname_no_splitting(
tmp_path, epochs, dst_fname, split_naming, split_1_fname
tmp_path, fake_epochs, dst_fname, split_naming, split_1_fname
):
"""Test saved fname doesn't get split suffix when splitting not needed."""
dst_fpath = tmp_path / dst_fname
split_1_fpath = tmp_path / split_1_fname

epochs.save(dst_fpath, split_naming=split_naming, verbose=True, split_size="2GB")
fake_epochs.save(dst_fpath, split_naming=split_naming, verbose=True, split_size="2GB")

assert dst_fpath.is_file()
assert not split_1_fpath.is_file()


@pytest.mark.parametrize("n_epochs", [5])
@pytest.mark.parametrize("split_naming", ["neuromag", "bids"])
def test_saving_fails_with_not_permitted_overwrite(tmp_path, epochs, split_naming):
def test_saving_fails_with_not_permitted_overwrite(tmp_path, fake_epochs, split_naming):
"""Check exception is raised when overwriting without explicit flag."""
dst_fpath = tmp_path / "test-epo.fif"
epochs.save(dst_fpath, split_naming=split_naming, verbose=True)
fake_epochs.save(dst_fpath, split_naming=split_naming, verbose=True)

with pytest.raises(FileExistsError, match="Destination file"):
epochs.save(dst_fpath, split_naming=split_naming, verbose=True)
fake_epochs.save(dst_fpath, split_naming=split_naming, verbose=True)


@pytest.mark.slowtest
Expand Down

0 comments on commit 29b128b

Please sign in to comment.