Sklearn-API: support of all the integer types as discrete labels? #39

Closed
2 tasks done
gwappa opened this issue Jul 5, 2023 · 2 comments · Fixed by #43
Labels
bug Something isn't working

gwappa commented Jul 5, 2023

Is there an existing issue for this?

  • I have searched the existing issues

Bug description

Hi,

thanks a lot for sharing this nice work!

I happened to use numpy.uint8 as the discrete label type, and the resulting CEBRA model ended up performing time-contrastive learning (instead of supervised learning using discrete labels).

It seems that the labels must have the numpy.int32 or numpy.int64 data type to be treated as discrete: https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/cebra/integrations/sklearn/dataset.py#L142

Would it be possible to also support the other integer types for discrete labels?

I found this StackOverflow post: https://stackoverflow.com/questions/37726830/how-to-determine-if-a-number-is-any-type-of-int-core-or-numpy-signed-or-not

So substituting something like the following should work, at least in theory:

-    elif y.dtype in (np.int32, np.int64):
+    elif np.issubdtype(y.dtype, np.integer):
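To see the difference concretely, here is a small standalone comparison of the two conditions from the diff, applied to a range of dtypes. It reproduces only the one-line condition, not CEBRA's surrounding label-parsing logic, and the function names `passes_current_check` and `passes_proposed_check` are hypothetical.

```python
import numpy as np


def passes_current_check(y: np.ndarray) -> bool:
    # Mirrors the removed line of the diff: only int32 and int64 qualify.
    return y.dtype in (np.int32, np.int64)


def passes_proposed_check(y: np.ndarray) -> bool:
    # np.integer is the abstract parent of every signed and unsigned integer type.
    return bool(np.issubdtype(y.dtype, np.integer))


for dtype in (np.int8, np.int16, np.int32, np.int64,
              np.uint8, np.uint16, np.uint32, np.uint64,
              np.float32, np.bool_):
    y = np.zeros(3, dtype=dtype)
    print(f"{np.dtype(dtype).name:>8}  "
          f"current={passes_current_check(y)!s:<5}  "
          f"proposed={passes_proposed_check(y)}")
```

With these definitions, only int32 and int64 pass the current check, whereas every signed and unsigned integer width passes the proposed one. float32 and bool fail both, so boolean labels would still not be treated as discrete under the proposed condition.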

Thank you very much in advance!

Operating System

  • Windows 10
  • Python 3.11.3 (Anaconda)

CEBRA version

cebra version 0.2.0

Device type

Core i9 / RTX 3090

Steps To Reproduce

Something like the following should reproduce this issue (unfortunately, I am writing this on another computer, so please forgive any typos):

import numpy as np
import cebra

N   = 1000
rng = np.random.default_rng()

X = np.concatenate([rng.multivariate_normal(mean=[0, 0], cov=[[0.2, 0], [0, 0.2]], size=N),
                    rng.multivariate_normal(mean=[1, 0], cov=[[0.2, 0.1], [0.1, 0.2]], size=N)],
                   axis=0)
y = np.concatenate([np.zeros(N), np.ones(N)])

# casting the label to uint8 (the use of int32 instead should work)
y = y.astype(np.uint8)

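# _prepare_fit is a private helper; it is used here only to inspect the data loader
_, _, loader, _ = cebra.CEBRA(batch_size=N // 2)._prepare_fit(X, y)

# fails with uint8 labels: they are not recognized as discrete, so no discrete index is set
assert loader.dataset.discrete_index is not None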
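Until other integer dtypes are accepted, the report suggests a simple workaround: cast the labels to int64 (or int32) before fitting. Below is a minimal sketch of such a cast, using a hypothetical helper `as_discrete_labels` (not part of CEBRA) that also refuses non-integer input and values outside the int64 range.

```python
import numpy as np


def as_discrete_labels(y) -> np.ndarray:
    """Hypothetical helper (not part of CEBRA): return integer labels as int64.

    int64 is one of the two dtypes accepted by the current check, so labels of
    any other integer width are cast to it before being passed to fit().
    """
    y = np.asarray(y)
    if not np.issubdtype(y.dtype, np.integer):
        raise TypeError(f"expected integer labels, got dtype {y.dtype}")
    # uint64 values above the int64 range would wrap around when cast
    if y.size and int(y.max()) > np.iinfo(np.int64).max:
        raise OverflowError("labels do not fit into int64")
    return y.astype(np.int64)


labels = np.concatenate([np.zeros(5), np.ones(5)]).astype(np.uint8)
labels = as_discrete_labels(labels)
print(labels.dtype)  # int64
```

The resulting array could then be passed as the label argument to `fit`, assuming, as the report states, that int64 labels are recognized as discrete.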

Relevant log output

No response

Anything else?

No response

Code of Conduct

MMathisLab added the bug (Something isn't working) label on Jul 5, 2023
stes (Member) commented Jul 5, 2023

Hi @gwappa, great catch, and thanks a lot for your bug report.

This seems related to #30 (CC @CeliaBenquet), where we started fixing a similar issue for the decoder.

@gonlairo, let's try to convert this issue into a test, move the helper functions from #30 out of cebra/integrations/sklearn/decoder.py, and then fix the selection of the distribution when another data type is given?
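One way to turn the report into a regression test is to check that every integer dtype is classified as discrete. The sketch below tests a hypothetical selector, `label_kind`, that mirrors the proposed `np.issubdtype` condition. It is not CEBRA's actual selection code; a real test would instead go through CEBRA's own dataset or `_prepare_fit`, as in the reproduction in the bug report.

```python
import numpy as np

INTEGER_DTYPES = [np.int8, np.int16, np.int32, np.int64,
                  np.uint8, np.uint16, np.uint32, np.uint64]
FLOAT_DTYPES = [np.float16, np.float32, np.float64]


def label_kind(y: np.ndarray) -> str:
    """Hypothetical label-type selector (not CEBRA's implementation)."""
    if np.issubdtype(y.dtype, np.integer):
        return "discrete"
    if np.issubdtype(y.dtype, np.floating):
        return "continuous"
    # e.g. bool or object arrays: raise instead of silently falling back
    raise ValueError(f"unsupported label dtype: {y.dtype}")


def test_all_integer_dtypes_are_discrete():
    for dtype in INTEGER_DTYPES:
        y = np.array([0, 1, 1, 0], dtype=dtype)
        assert label_kind(y) == "discrete", dtype


def test_float_dtypes_are_continuous():
    for dtype in FLOAT_DTYPES:
        y = np.linspace(0.0, 1.0, num=4).astype(dtype)
        assert label_kind(y) == "continuous", dtype


if __name__ == "__main__":
    test_all_integer_dtypes_are_discrete()
    test_float_dtypes_are_continuous()
    print("all label-dtype checks passed")
```

With pytest, the loops could be replaced by `@pytest.mark.parametrize("dtype", INTEGER_DTYPES)` so that each dtype is reported as a separate test case.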

@gwappa, I assume this issue is not currently blocking you? We will probably address it in an upcoming version of CEBRA.

gonlairo (Contributor) commented Jul 6, 2023

Thank you very much @gwappa. I'm on it.

4 participants