-
Notifications
You must be signed in to change notification settings - Fork 78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix type bug in to() method #55
Fix type bug in to() method #55
Conversation
Thank you for your contribution. We require contributors to sign our Contributor License Agreement (CLA). We do not have a signed CLA on file for you. In order for us to review and merge your code, please sign our CLA here. After you signed, you can comment on this PR with |
@cla-bot check |
Thanks for tagging me. I looked for a signed form under your signature again, and updated the status on this PR. If the check was successful, no further action is needed. If the check was unsuccessful, please see the instructions in my first comment. |
Hi @sofiagilardini, thanks a lot for flagging (and fixing!!) this. A good addition to this PR would be a unit test that fails in the old CEBRA version, and works after applying your fix. Would you be interested in contributing this test yourself, or should we add it to the PR based on your description? |
Hi @stes, I'd be happy to add it. Would you like the test to go in CEBRA/tests/test_sklearn.py ? Is there a specific pytest decorator you suggest? |
Hi @sofiagilardini , thanks a lot! Adding to
is probably the quickest way for testing on your end. If you want to run more of the test suite, I would suggest to use the If you can think of multiple test cases, using a Thanks again! |
Hi @stes, I've added the test, it can be run by doing: The code in this PR passes the test. The same test fails if you checkout the previous version of the sklearn integration: Let me know if there are any issues :) Sofia |
@gonlairo please code review ASAP. |
lgtm, ready to merge imo |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This requires one confirmation from @gonlairo that tests pass on GPU and Apple silicon devices, then good to go from my end.
Thanks @sofiagilardini for the contribution! |
Hi all, thank you very much for this great work.
I have been using the latest stable release of Cebra for my work on a summer research project. I recently switched to a new MacBook, so today I tried using main branch as it has MPS compatibility. I came across an issue with the recently added
to()
method within the CEBRA class for sklearn integration.From my understanding:
The recently implemented
to()
method within the CEBRA class in sklearn integration has a bug in the type of the class attributesdevice
anddevice_
. The two bugs are:device.startswith()
, which assumes device is of type str, yet the input type hint isdevice: Union[str, torch.device]
self.device
(andself.device_
if existent) to a torch.device object. This causes downstream errors when calling other class methods such aspartial_fit()
after having calledto()
, because the method_prepare_fit()
callssklearn_utils.check_device(self.device)
which fails whenself.device
is not of type str.The proposed fix modifies the logic for type checking and sets class attribute
self.device
(andself.device_
if existent) to a str object always. In this way, calling theto()
method does not break existing working code.