Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dtypes handled in the labels of the decoders #30

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 43 additions & 14 deletions cebra/integrations/sklearn/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,40 @@
import torch


def _is_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool:
"""Check if the values in ``y`` are :py:class:`int`.

Args:
y: An array, either as a :py:func:`numpy.array` or a :py:class:`torch.Tensor`.

Returns:
``True`` if ``y`` contains :py:class:`int`.
"""
return (isinstance(y, np.ndarray) and np.issubdtype(y.dtype, np.integer)
) or (isinstance(y, torch.Tensor) and
(not torch.is_floating_point(y) and not torch.is_complex(y)))


def _is_floating(y: Union[npt.NDArray, torch.Tensor]) -> bool:
"""Check if the values in ``y`` are :py:class:`int`.

Note:
There is no ``torch`` method to check that the ``dtype`` of a :py:class:`torch.Tensor`
is a :py:class:`float`, consequently, we check that it is not :py:class:`int` nor
:py:class:`complex`.

Args:
y: An array, either as a :py:func:`numpy.array` or a :py:class:`torch.Tensor`.

Returns:
``True`` if ``y`` contains :py:class:`float`.
"""

return (isinstance(y, np.ndarray) and
np.issubdtype(y.dtype, np.floating)) or (isinstance(
y, torch.Tensor) and torch.is_floating_point(y))


class Decoder(abc.ABC, sklearn.base.BaseEstimator):
"""Abstract base class for implementing a decoder."""

Expand Down Expand Up @@ -118,10 +152,10 @@ def fit(
)

# Use regression or classification, based on if the targets are continuous or discrete
if y.dtype in (np.float32, np.float64, torch.float32, torch.float64):
if _is_floating(y):
self.knn = sklearn.neighbors.KNeighborsRegressor(
n_neighbors=self.n_neighbors, metric=self.metric)
elif y.dtype in (np.int32, np.int64, torch.int32, torch.int64):
elif _is_integer(y):
self.knn = sklearn.neighbors.KNeighborsClassifier(
n_neighbors=self.n_neighbors, metric=self.metric)
else:
Expand All @@ -132,7 +166,9 @@ def fit(
self.knn.fit(X, y)
return self

def predict(self, X: Union[npt.NDArray, torch.Tensor]) -> npt.NDArray:
def predict(
self, X: Union[npt.NDArray,
torch.Tensor]) -> Union[npt.NDArray, torch.Tensor]:
"""Predict the targets for data ``X``.

Args:
Expand Down Expand Up @@ -201,24 +237,17 @@ def fit(
f"Invalid shape: y and X must have the same number of samples, got y:{len(y)} and X:{len(X)}."
)

if not y.dtype in (
np.float32,
np.float64,
torch.float32,
torch.float64,
np.int32,
np.int64,
torch.int32,
torch.int64,
):
if not (_is_integer(y) or _is_floating(y)):
raise NotImplementedError(
f"Invalid type: targets must be numeric, got y:{y.dtype}")

self.model = sklearn.linear_model.Lasso(alpha=self.alpha)
self.model.fit(X, y)
return self

def predict(self, X: Union[npt.NDArray, torch.Tensor]) -> npt.NDArray:
def predict(
self, X: Union[npt.NDArray,
torch.Tensor]) -> Union[npt.NDArray, torch.Tensor]:
"""Predict the targets for data ``X``.

Args:
Expand Down
30 changes: 26 additions & 4 deletions tests/test_sklearn_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#
import numpy as np
import pytest
import torch

import cebra.integrations.sklearn.decoder as cebra_sklearn_decoder

Expand Down Expand Up @@ -48,25 +49,39 @@ def test_sklearn_decoder(decoder):
decoder.fit(X, y_c)
pred = decoder.predict(X)
assert isinstance(pred, np.ndarray)
assert pred.dtype in (np.float32, np.float64)
assert np.issubdtype(pred.dtype, np.floating)

score = decoder.score(X, y_c)
assert isinstance(score, float)

# torch
decoder.fit(torch.Tensor(X), torch.Tensor(y_c))
pred = decoder.predict(torch.Tensor(X))
assert isinstance(pred, np.ndarray)
assert np.issubdtype(pred.dtype, np.floating)

# discrete target
decoder.fit(X, y_d)
pred = decoder.predict(X)
assert isinstance(pred, np.ndarray)
assert pred.dtype in (np.int32, np.int64, np.float32, np.float64)
assert np.issubdtype(pred.dtype, np.integer) or np.issubdtype(
pred.dtype, np.floating)

score = decoder.score(X, y_d)
assert isinstance(score, float)

# torch
decoder.fit(torch.Tensor(X), torch.Tensor(y_d))
pred = decoder.predict(torch.Tensor(X))
assert isinstance(pred, np.ndarray)
assert np.issubdtype(pred.dtype, np.integer) or np.issubdtype(
pred.dtype, np.floating)

# multi-dim continuous target
decoder.fit(X, y_c_dim)
pred = decoder.predict(X)
assert isinstance(pred, np.ndarray)
assert pred.dtype in (np.float32, np.float64)
assert np.issubdtype(pred.dtype, np.floating)

score = decoder.score(X, y_c_dim)
assert isinstance(score, float)
Expand All @@ -76,7 +91,7 @@ def test_sklearn_decoder(decoder):
decoder.fit(X, multi_y)
pred = decoder.predict(X)
assert isinstance(pred, np.ndarray)
assert pred.dtype in (np.float32, np.float64)
assert np.issubdtype(pred.dtype, np.floating)

score = decoder.score(X, multi_y)
assert isinstance(score, float)
Expand All @@ -86,3 +101,10 @@ def test_sklearn_decoder(decoder):
decoder.fit(X, y_str)
with pytest.raises(ValueError, match="Invalid.*shape"):
decoder.fit(X, y_d_short)


def test_dtype_checker():
    """Check the dtype helpers on both torch tensors and numpy arrays."""
    floating = cebra_sklearn_decoder._is_floating
    integer = cebra_sklearn_decoder._is_integer

    # Each entry pairs a checker with an input it is expected to accept.
    expectations = (
        (floating, torch.Tensor([4.5])),
        (integer, torch.LongTensor([4])),
        (floating, np.array([4.5])),
        (integer, np.array([4])),
    )
    for checker, values in expectations:
        assert checker(values)
stes marked this conversation as resolved.
Show resolved Hide resolved