
Fix type bug in to() method #55

Merged
7 commits merged into AdaptiveMotorControlLab:main on Sep 12, 2023

Conversation

sofiagilardini
Contributor

Hi all, thank you very much for this great work.

I have been using the latest stable release of CEBRA for my work on a summer research project. I recently switched to a new MacBook, so today I tried the main branch, as it has MPS compatibility. I came across an issue with the recently added to() method within the CEBRA class for sklearn integration.

From my understanding:

The recently implemented to() method within the CEBRA class in the sklearn integration has a bug in the typing of the class attributes device and device_. The two issues are:

  • The type-checking if statements break when device is of type torch.device, because device.startswith() assumes device is a str, even though the input type hint is device: Union[str, torch.device].
  • When the method is called with a device of type str, the original method sets self.device (and self.device_, if it exists) to a torch.device object. This causes downstream errors when calling other class methods such as partial_fit() after to(), because _prepare_fit() calls sklearn_utils.check_device(self.device), which fails when self.device is not a str.

The proposed fix adjusts the type-checking logic and always sets the class attribute self.device (and self.device_, if it exists) to a str. This way, calling to() does not break existing working code.
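The normalization the fix relies on can be sketched as follows. This is an illustrative, self-contained sketch only: the helper name normalize_device is hypothetical, and a stand-in class replaces torch.device so the snippet runs without PyTorch installed; the actual CEBRA code differs.

```python
# Illustrative sketch (not the actual CEBRA code): normalize a device
# argument to a plain string, so downstream checks that assume str
# (e.g. sklearn_utils.check_device) keep working after to() is called.

class FakeDevice:
    """Stand-in for torch.device, used only to keep this sketch
    self-contained; real code would accept torch.device directly."""

    def __init__(self, type_, index=None):
        self.type = type_
        self.index = index

    def __str__(self):
        # torch.device stringifies the same way, e.g. "cuda:0" or "mps"
        if self.index is None:
            return self.type
        return f"{self.type}:{self.index}"


def normalize_device(device):
    """Return the device as a str, whether given a str or a device object."""
    if isinstance(device, str):
        return device
    return str(device)


print(normalize_device("mps"))                  # mps
print(normalize_device(FakeDevice("cuda", 0)))  # cuda:0
```

With this approach, whatever to() stores in self.device (and self.device_) is always a str, so later calls such as partial_fit() see the type they expect.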

@cla-bot

cla-bot bot commented Aug 6, 2023

Thank you for your contribution. We require contributors to sign our Contributor License Agreement (CLA). We do not have a signed CLA on file for you. In order for us to review and merge your code, please sign our CLA here. After you have signed, you can comment on this PR with @cla-bot check to trigger another check.

@stes stes self-requested a review August 6, 2023 19:05
@sofiagilardini
Contributor Author

@cla-bot check

@cla-bot cla-bot bot added the CLA signed label Aug 6, 2023
@cla-bot

cla-bot bot commented Aug 6, 2023

Thanks for tagging me. I looked for a signed form under your signature again, and updated the status on this PR. If the check was successful, no further action is needed. If the check was unsuccessful, please see the instructions in my first comment.

@stes
Member

stes commented Aug 7, 2023

Hi @sofiagilardini, thanks a lot for flagging (and fixing!!) this.

A good addition to this PR would be a unit test that fails in the old CEBRA version, and works after applying your fix.

Would you be interested in contributing this test yourself, or should we add it to the PR based on your description?

@sofiagilardini
Contributor Author

Hi @stes,

I'd be happy to add it. Would you like the test to go in CEBRA/tests/test_sklearn.py?

Is there a specific pytest decorator you suggest?

@stes
Member

stes commented Aug 9, 2023

Hi @sofiagilardini , thanks a lot!

Adding to CEBRA/tests/test_sklearn.py is perfect. In case it is not clear from the docs/Makefile, this:

python -m pytest tests/test_sklearn.py::test_yours

is probably the quickest way to test on your end. If you want to run more of the test suite, I would suggest using -m "not requires_dataset" for some speed gains.

If you can think of multiple test cases, pytest.mark.parametrize might be useful; otherwise there are no special requirements, as long as the test reproduces your bug in the original version of the code and this PR makes the test pass.
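As a concrete illustration of the parametrize suggestion, a hypothetical test could look like the following. The test name, device list, and body are placeholders, not the actual test added in this PR; the real test would construct a CEBRA model and call to() and partial_fit().

```python
import pytest

# Hypothetical parametrized test: the name and device list are
# illustrative, not taken from the CEBRA test suite.
@pytest.mark.parametrize("device", ["cpu", "cuda", "mps"])
def test_device_attribute_stays_str(device):
    # Stand-in for: model = cebra.CEBRA(...); model.to(device)
    # Here we only check the property the fix relies on: whatever
    # gets stored in self.device should remain a plain str.
    stored = str(device)
    assert isinstance(stored, str)
    assert stored == device
```

Each device string then runs as its own test case, so a regression on any single backend shows up as an individual failure.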

Thanks again!

@sofiagilardini
Contributor Author

Hi @stes,

I've added the test; it can be run with: python -m pytest tests/test_sklearn.py::test_fit_after_moving_to_device.

The code in this PR passes the test.

The same test fails if you checkout the previous version of the sklearn integration:
git checkout cda6e11eb89828ade70d3c342fff8bb955cb2b69 -- ./cebra/integrations/sklearn/cebra.py

Let me know if there are any issues :)

Sofia

@MMathisLab
Member

@gonlairo please code review ASAP.

@gonlairo
Contributor

gonlairo commented Sep 4, 2023

@gonlairo please code review ASAP.

lgtm, ready to merge imo

@MMathisLab MMathisLab requested a review from gonlairo September 9, 2023 12:10
Member

@stes stes left a comment


This requires one confirmation from @gonlairo that tests pass on GPU and Apple silicon devices, then good to go from my end.

@stes stes changed the title from "fix type bug in to() method" to "Fix type bug in to() method" Sep 12, 2023
@stes stes added the "bug" label Sep 12, 2023
@stes
Member

stes commented Sep 12, 2023

Thanks @sofiagilardini for the contribution!

@stes stes merged commit eda4aa7 into AdaptiveMotorControlLab:main Sep 12, 2023
10 checks passed