Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend type checking to all float datatypes #166

Merged
merged 10 commits into from
Oct 26, 2024
24 changes: 18 additions & 6 deletions cebra/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,35 @@ def __init__(
device: str = "cpu"
):
super().__init__(device=device)
self.neural = self._to_tensor(neural, torch.FloatTensor).float()
self.continuous = self._to_tensor(continuous, torch.FloatTensor)
self.discrete = self._to_tensor(discrete, torch.LongTensor)
self.neural = self._to_tensor(neural, check_dtype="float").float()
self.continuous = self._to_tensor(continuous, check_dtype="float").float()
self.discrete = self._to_tensor(discrete, check_dtype="integer")
if self.continuous is None and self.discrete is None:
raise ValueError(
"You have to pass at least one of the arguments 'continuous' or 'discrete'."
)
self.offset = offset

def _to_tensor(self, array, check_dtype=None):
def _to_tensor(self, array, check_dtype: str = None):
"""Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype.

Args:
array: Array to check.
check_dtype (list, optional): If not `None`, list of dtypes to which the values in `array`
stes marked this conversation as resolved.
Show resolved Hide resolved
must belong to. Defaults to None.

Returns:
The `array` as a :py:class:`torch.Tensor`.
"""
if array is None:
return None
if isinstance(array, np.ndarray):
array = torch.from_numpy(array)
if check_dtype is not None:
if not isinstance(array, check_dtype):
raise TypeError(f"{type(array)} instead of {check_dtype}.")
if (check_dtype == "integer" and not cebra.helper._is_integer(array)
) or (check_dtype == "float" and not cebra.helper._is_floating(array)
) or (check_dtype == "float_integer" and not cebra.helper._is_floating_or_integer(array)):
CeliaBenquet marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(f"{array.dtype} instead of {check_dtype}.")
stes marked this conversation as resolved.
Show resolved Hide resolved
return array

@property
Expand Down
15 changes: 14 additions & 1 deletion cebra/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _is_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool:


def _is_floating(y: Union[npt.NDArray, torch.Tensor]) -> bool:
"""Check if the values in ``y`` are :py:class:`int`.
"""Check if the values in ``y`` are :py:class:`float`.

Note:
There is no ``torch`` method to check that the ``dtype`` of a :py:class:`torch.Tensor`
Expand All @@ -118,6 +118,19 @@ def _is_floating(y: Union[npt.NDArray, torch.Tensor]) -> bool:
y, torch.Tensor) and torch.is_floating_point(y))


def _is_floating_or_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool:
"""Check if the values in ``y`` are :py:class:`int` or :py:class:`float`.

Args:
y: An array, either as a :py:func:`numpy.array` or a :py:class:`torch.Tensor`.

Returns:
``True`` if ``y`` contains :py:class:`float` or :py:class:`int`.
"""

return _is_floating(y) or _is_integer(y)


def get_loader_options(dataset: "cebra.data.Dataset") -> List[str]:
"""Return all possible dataloaders for the given dataset.

Expand Down
Loading