Skip to content

Commit

Permalink
BUG: Fix bug with pickling MNEBadsList (#12063)
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner authored Oct 4, 2023
1 parent 9d20815 commit c65345a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
9 changes: 8 additions & 1 deletion mne/_fiff/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,13 @@ def __init__(self, *, bads, info):
def extend(self, iterable):
if not isinstance(iterable, list):
iterable = list(iterable)
_check_bads_info_compat(iterable, self._mne_info)
# can happen during pickling
try:
info = self._mne_info
except AttributeError:
pass # can happen during pickling
else:
_check_bads_info_compat(iterable, info)
return super().extend(iterable)

def append(self, x):
Expand Down Expand Up @@ -1551,6 +1557,7 @@ def __getstate__(self):
def __setstate__(self, state):
"""Set state (for pickling)."""
self._unlocked = state["_unlocked"]
self["bads"] = MNEBadsList(bads=self["bads"], info=self)

def __setitem__(self, key, val):
"""Attribute setter."""
Expand Down
10 changes: 8 additions & 2 deletions mne/_fiff/tests/test_meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,21 +1070,27 @@ def test_channel_name_limit(tmp_path, monkeypatch, fname):
apply_inverse(evoked, inv) # smoke test


@pytest.mark.parametrize("protocol", ("highest", "default"))
@pytest.mark.parametrize("fname_info", (raw_fname, "create_info"))
@pytest.mark.parametrize("unlocked", (True, False))
def test_pickle(fname_info, unlocked):
def test_pickle(fname_info, unlocked, protocol):
"""Test that Info can be (un)pickled."""
if fname_info == "create_info":
info = create_info(3, 1000.0, "eeg")
else:
info = read_info(fname_info)
protocol = getattr(pickle, f"{protocol.upper()}_PROTOCOL")
assert isinstance(info["bads"], MNEBadsList)
info["bads"] = info["ch_names"][:1]
assert not info._unlocked
info._unlocked = unlocked
data = pickle.dumps(info)
data = pickle.dumps(info, protocol=protocol)
info_un = pickle.loads(data) # nosec B301
assert isinstance(info_un, Info)
assert_object_equal(info, info_un)
assert info_un._unlocked == unlocked
assert isinstance(info_un["bads"], MNEBadsList)
assert info_un["bads"]._mne_info is info_un


def test_info_bad():
Expand Down

0 comments on commit c65345a

Please sign in to comment.