Skip to content

Commit

Permalink
FIX: explicitly set weights_only to avoid FutureWarning (#193)
Browse files Browse the repository at this point in the history
* FIX: explicitly set weights_only to avoid FutureWarning

- If I'm understanding the Torch 2.4 changelog correctly, you just need to explicitly pass a value to weights_only. Since the default is alraedy False, I am just explicitly setting it here so this should be backward compatible

* FIX: fix other uses of torch.load in codebase

* DOC: Update changelog

* Update mne_icalabel/iclabel/network/tests/test_network.py

* fix one more

* Change weights_only to True

---------

Co-authored-by: Mathieu Scheltienne <mathieu.scheltienne@gmail.com>
Co-authored-by: Mathieu Scheltienne <mathieu.scheltienne@fcbg.ch>
  • Loading branch information
3 people authored Jul 25, 2024
1 parent 1c36710 commit 3cfc1ab
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 4 deletions.
3 changes: 2 additions & 1 deletion doc/changes/authors.inc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.. _Adam Li: https://github.com/adam2392
.. _Mathieu Scheltienne: https://github.com/mscheltienne
.. _Jacob Feitelberg: https://github.com/jacobf18
.. _Anand Saini: https://github.com/anandsaini024
.. _Anand Saini: https://github.com/anandsaini024
.. _Scott Huberty: https://github.com/scott-huberty
1 change: 1 addition & 0 deletions doc/changes/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ Version 0.7
===========

- Raise helpful error message when montage is incomplete (:pr:`181` by `Mathieu Scheltienne`_)
- Explicitly pass ``weights_only=True`` in all instances of ``torch.load`` used by mne-icalabel, both to suppress a warning in PyTorch 2.4 and to follow best security practices (:pr:`193` by `Scott Huberty`_)
4 changes: 2 additions & 2 deletions mne_icalabel/iclabel/network/tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

def test_weights_pytorch():
"""Compare the weights of pytorch model and matconvnet model."""
network_python = torch.load(torch_iclabel_path)
network_python = torch.load(torch_iclabel_path, weights_only=True)
network_matlab = loadmat(matconvnet_iclabel_path)

# load weights from matlab network
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_network_outputs_pytorch():

# run the forward pass on pytorch
iclabel_net = ICLabelNet()
iclabel_net.load_state_dict(torch.load(torch_iclabel_path))
iclabel_net.load_state_dict(torch.load(torch_iclabel_path, weights_only=True))
torch_labels = iclabel_net(images, psd, autocorr)
torch_labels = torch_labels.detach().numpy() # (30, 7)

Expand Down
2 changes: 1 addition & 1 deletion mne_icalabel/iclabel/network/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _run_iclabel(images: ArrayLike, psds: ArrayLike, autocorr: ArrayLike) -> NDA
# load weights
network_file = files("mne_icalabel.iclabel.network") / "assets" / "ICLabelNet.pt"
iclabel_net = ICLabelNet()
iclabel_net.load_state_dict(torch.load(network_file))
iclabel_net.load_state_dict(torch.load(network_file, weights_only=True))
# format inputs and run forward pass
labels = iclabel_net(
*_format_input_for_torch(*_format_input(images, psds, autocorr))
Expand Down

0 comments on commit 3cfc1ab

Please sign in to comment.