Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ONNX support #129

Merged
merged 8 commits into from
Jun 26, 2023
Merged

Added ONNX support #129

merged 8 commits into from
Jun 26, 2023

Conversation

jacobf18
Copy link
Collaborator

@jacobf18 jacobf18 commented Jun 19, 2023

PR Description

Added ONNX support to iclabel. I changed the function run_iclabel to include a library parameter. This parameter's default value is set to use PyTorch. I added unit tests to check the network outputs against MATLAB. ONNX is faster to load and more lightweight than PyTorch.

The options for the library parameter are pytorch and onnx so far. It throws a ValueError for any other input.

Closes #125.

Merge checklist

Maintainer, please confirm the following before merging:

  • All comments resolved
  • This is not your own PR
  • All CIs are happy
  • PR title starts with [MRG]
  • whats_new.rst is updated
  • New contributors have been added to CITATION.cff
  • PR description includes phrase "closes <#issue-number>"

@jacobf18 jacobf18 requested a review from adam2392 June 20, 2023 04:57
Copy link
Member

@mscheltienne mscheltienne left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jacobf18 ! That looks good to me.

As you said, we are going to need to add some logic to load the libraries. I'm also +1 to make Qt an optional dependency. Overall, +1 to merge as is and clean-up, add library selection logic, and so on.. in a second PR.

mne_icalabel/iclabel/tests/test_network.py Outdated Show resolved Hide resolved
mne_icalabel/iclabel/tests/test_network.py Outdated Show resolved Hide resolved
Comment on lines -59 to +60
'PyQt5',
'PyQt6',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to change this?

Copy link
Member

@adam2392 adam2392 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we can make the dependencies optional, then we should nest the import statements for torch and onnx inside the functions that use them.

Perhaps we should add a helper function that does the import and checks that at least one of them is installed?

Comment on lines 44 to +45
'torch',
'onnxruntime',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way we can make both optional dependencies, since the user only needs to install one?

@mscheltienne
Copy link
Member

Yes, we need to add some selection logic and get rid of the import at the top-level. Probably requires also 2 separate files, one for pytorch and one for onnx since the classes defined inherit from torch.nn directly. Maybe some logic like psychopy has here in an __init__.
Anyway, I think that could be done in a second PR which restructure the iclabel module before a 0.5 release.

@adam2392
Copy link
Member

Alright. Sounds good to me. LGTM with the exception of CI. Does it just need a restart possibly?

Thanks @jacobf18!

@mscheltienne
Copy link
Member

No the ubuntu ones have been broken for a while because of some xvfb stuff, it will be a good occasion to fix them as well..

@mscheltienne
Copy link
Member

@jacobf18 Do you want to give a shot to a logic to set pytorch and onnx as optional?
Ideally:

  • import mne_icalabel and from mne_icalabel import label_components would work regardless of the installation status of the backends.
  • import mne_icalabel.iclabel and from mne_icalabel.iclabel import label_components would fail if neither backend is available.

A couple of ideas: we should move the network part to a mne_icalabel.iclabel.network module, with 3 files:

  • __init__.py which attempts to load both in a similar fashion to psychopy with the drivers for the parallel port
  • pytorch.py with the pytorch backend
  • onnx.py with the onnx backend

Or maybe an empty __init__.py and some helper function which handles the loading of the backends.
The tests might also require a decorator @requires_onnx or @requires_torch.


Let me know if you want to give it a shot, else I'll do it in the coming days. Once it's sorted out, we can release a 0.5 version with this addition, and maybe get this correctly packaged on conda-forge for windows.

@mscheltienne
Copy link
Member

Merging and picking this up today with another PR to add the selection of the backend. Let's not stall this :)

@mscheltienne mscheltienne merged commit b7294e5 into mne-tools:main Jun 26, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ONNX/Tensorflow Port
3 participants