diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index f6aa1904..aa0ad3db 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -108,6 +108,25 @@ def __getitem__(self, index): return self.neural[index].transpose(2, 1) +def _assert_datasets_same_device( + datasets: List[cebra_data.SingleSessionDataset]) -> str: + """Checks if the list of datasets are all on the same device. + + Args: + datasets: List of datasets. + + Returns: + The device name if all datasets are on the same device. + + Raises: + ValueError: If datasets are not all on the same device. + """ + devices = set([dataset.device for dataset in datasets]) + if len(devices) != 1: + raise ValueError("Datasets are not all on the same device") + return devices.pop() + + class DatasetCollection(cebra_data.MultiSessionDataset): """Multi session dataset made up of a list of datasets. @@ -165,11 +184,13 @@ def __init__( self, *datasets: cebra_data.SingleSessionDataset, ): - super().__init__() self._datasets: List[ cebra_data.SingleSessionDataset] = self._unpack_dataset_arguments( datasets) + device = _assert_datasets_same_device(self._datasets) + super().__init__(device=device) + continuous = all( self._has_not_none_attribute(session, "continuous_index") for session in self.iter_sessions()) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 2a8026b9..d5099b66 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -758,12 +758,17 @@ def _iterate_actions(): def do_nothing(model): return model - def fit_model(model): + def fit_singlesession_model(model): X = np.linspace(-1, 1, 1000)[:, None] model.fit(X) return model - return [do_nothing, fit_model] + def fit_multisession_model(model): + X = np.linspace(-1, 1, 1000)[:, None] + model.fit([X, X], [X, X]) + return model + + return [do_nothing, fit_singlesession_model, fit_multisession_model] def _assert_same_state_dict(first, second): @@ -797,17 +802,43 @@ def check_fitted(model): 
_assert_same_state_dict(original_model.state_dict_, loaded_model.state_dict_) X = np.random.normal(0, 1, (100, 1)) - assert np.allclose(loaded_model.transform(X), - original_model.transform(X)) + + if loaded_model.num_sessions is not None: + assert np.allclose(loaded_model.transform(X, session_id=0), + original_model.transform(X, session_id=0)) + else: + assert np.allclose(loaded_model.transform(X), + original_model.transform(X)) @pytest.mark.parametrize("action", _iterate_actions()) def test_save_and_load(action): model_architecture = "offset10-model" original_model = cebra_sklearn_cebra.CEBRA( - model_architecture=model_architecture, max_iterations=5) + model_architecture=model_architecture, max_iterations=5, batch_size=42) original_model = action(original_model) with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: original_model.save(savefile.name) loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) _assert_equal(original_model, loaded_model) + + +@pytest.mark.parametrize("device", ["cpu"] + + (["cuda"] if torch.cuda.is_available() else [])) +@pytest.mark.parametrize("action", _iterate_actions()) +def test_check_devices(action, device): + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset1-model", + max_iterations=5, + device=device, + batch_size=42, + ) + cebra_model = action(cebra_model) + assert cebra_model.device == device + + if action.__name__ != "do_nothing": + if device == "cuda": + #TODO(rodrigo): remove once https://github.com/AdaptiveMotorControlLab/CEBRA/pull/34 is merged. + device = torch.device(device, index=0) + assert next( + cebra_model.model_.parameters()).device == torch.device(device)