From 376a154b060c32c9a3933af65ce16c09d7b1f8a8 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 3 Jul 2023 17:09:25 +0200 Subject: [PATCH 01/11] add device support to DatasetCollection --- cebra/data/datasets.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index f6aa1904..45a90672 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -161,14 +161,35 @@ def _unpack_dataset_arguments( else: return datasets + def _assert_datasets_same_device( + self, datasets: List[cebra_data.SingleSessionDataset]) -> str: + + """Checks if the list of datasets are all on the same device. + + Args: + datasets: List of datasets. + + Returns: + str: The device name if all datasets are on the same device. + + Raises: + ValueError: If datasets are not all on the same device. + """ + devices = set([dataset.device for dataset in datasets]) + if len(devices) != 1: + raise ValueError("Datasets are not all on the same device") + return devices.pop() + def __init__( self, *datasets: cebra_data.SingleSessionDataset, ): - super().__init__() self._datasets: List[ cebra_data.SingleSessionDataset] = self._unpack_dataset_arguments( datasets) + + device = self._assert_datasets_same_device(self._datasets) + super().__init__(device = device) continuous = all( self._has_not_none_attribute(session, "continuous_index") From 136213977413fb8ade6dfa1e39cf47f9c53c52bc Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 3 Jul 2023 17:29:50 +0200 Subject: [PATCH 02/11] add tests --- tests/test_sklearn.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index c260a30b..342be7c3 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -757,12 +757,18 @@ def _iterate_actions(): def do_nothing(model): return model - def fit_model(model): + def fit_singlession_model(model): X = np.linspace(-1, 1, 1000)[:, None] model.fit(X) return model + + def fit_multisession_model(model): + X = np.linspace(-1, 1, 1000)[:, None] + model.fit([X, X], [X, X]) + return model - return [do_nothing, fit_model] + + return [do_nothing, fit_singlession_model, fit_multisession_model] def _assert_same_state_dict(first, second): @@ -810,3 +816,27 @@ def test_save_and_load(action): original_model.save(savefile.name) loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) _assert_equal(original_model, loaded_model) + +#let's contribute a test to test_sklearn that checks devices for training for all solvers, and link it in the issue here. + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("action", _iterate_actions()) +def test_check_devices(action, device): + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset1-model", + time_offsets=10, + learning_rate=3e-4, + max_iterations=5, + device=device, + output_dimension=4, + batch_size=42, + verbose=True, + ) + cebra_model = action(cebra_model) + assert cebra_model.device == device + + if action.__name__ != "do_nothing": + if device == "cuda": + #TODO(rodrigo): remove once https://github.com/AdaptiveMotorControlLab/CEBRA/pull/34 is merged. + device = torch.device(device, index = 0) + assert next(cebra_model.model_.parameters()).device == torch.device(device) \ No newline at end of file From f91d4ee4524c5de76b80d953b108f2541d327faa Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 3 Jul 2023 17:09:25 +0200 Subject: [PATCH 03/11] add device support to DatasetCollection --- cebra/data/datasets.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index f6aa1904..45a90672 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -161,14 +161,35 @@ def _unpack_dataset_arguments( else: return datasets + def _assert_datasets_same_device( + self, datasets: List[cebra_data.SingleSessionDataset]) -> str: + + """Checks if the list of datasets are all on the same device. + + Args: + datasets: List of datasets. + + Returns: + str: The device name if all datasets are on the same device. + + Raises: + ValueError: If datasets are not all on the same device. + """ + devices = set([dataset.device for dataset in datasets]) + if len(devices) != 1: + raise ValueError("Datasets are not all on the same device") + return devices.pop() + def __init__( self, *datasets: cebra_data.SingleSessionDataset, ): - super().__init__() self._datasets: List[ cebra_data.SingleSessionDataset] = self._unpack_dataset_arguments( datasets) + + device = self._assert_datasets_same_device(self._datasets) + super().__init__(device = device) continuous = all( self._has_not_none_attribute(session, "continuous_index") From 9ea8ba89b0bbc6f1e2916428a23c6b18947ff853 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 3 Jul 2023 17:29:50 +0200 Subject: [PATCH 04/11] add tests --- tests/test_sklearn.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index c260a30b..342be7c3 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -757,12 +757,18 @@ def _iterate_actions(): def do_nothing(model): return model - def fit_model(model): + def fit_singlession_model(model): X = np.linspace(-1, 1, 1000)[:, None] model.fit(X) return model + + def fit_multisession_model(model): + X = np.linspace(-1, 1, 1000)[:, None] + model.fit([X, X], [X, X]) + return model - return [do_nothing, fit_model] + + return [do_nothing, fit_singlession_model, fit_multisession_model] def _assert_same_state_dict(first, second): @@ -810,3 +816,27 @@ def test_save_and_load(action): original_model.save(savefile.name) loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) _assert_equal(original_model, loaded_model) + +#let's contribute a test to test_sklearn that checks devices for training for all solvers, and link it in the issue here. + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("action", _iterate_actions()) +def test_check_devices(action, device): + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset1-model", + time_offsets=10, + learning_rate=3e-4, + max_iterations=5, + device=device, + output_dimension=4, + batch_size=42, + verbose=True, + ) + cebra_model = action(cebra_model) + assert cebra_model.device == device + + if action.__name__ != "do_nothing": + if device == "cuda": + #TODO(rodrigo): remove once https://github.com/AdaptiveMotorControlLab/CEBRA/pull/34 is merged. + device = torch.device(device, index = 0) + assert next(cebra_model.model_.parameters()).device == torch.device(device) \ No newline at end of file From a1c11768ca8ee242b64e224e30b5d89db7092d77 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 7 Jul 2023 10:32:04 +0200 Subject: [PATCH 05/11] move helper function out of the class --- cebra/data/datasets.py | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index 45a90672..ab201d0f 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -107,6 +107,24 @@ def __getitem__(self, index): index = self.expand_index(index) return self.neural[index].transpose(2, 1) +def _assert_datasets_same_device( + datasets: List[cebra_data.SingleSessionDataset]) -> str: + + """Checks if the list of datasets are all on the same device. + + Args: + datasets: List of datasets. + + Returns: + str: The device name if all datasets are on the same device. + + Raises: + ValueError: If datasets are not all on the same device. + """ + devices = set([dataset.device for dataset in datasets]) + if len(devices) != 1: + raise ValueError("Datasets are not all on the same device") + return devices.pop() class DatasetCollection(cebra_data.MultiSessionDataset): """Multi session dataset made up of a list of datasets. @@ -161,25 +179,6 @@ def _unpack_dataset_arguments( else: return datasets - def _assert_datasets_same_device( - self, datasets: List[cebra_data.SingleSessionDataset]) -> str: - - """Checks if the list of datasets are all on the same device. - - Args: - datasets: List of datasets. - - Returns: - str: The device name if all datasets are on the same device. - - Raises: - ValueError: If datasets are not all on the same device. - """ - devices = set([dataset.device for dataset in datasets]) - if len(devices) != 1: - raise ValueError("Datasets are not all on the same device") - return devices.pop() - def __init__( self, *datasets: cebra_data.SingleSessionDataset, @@ -188,7 +187,7 @@ def __init__( cebra_data.SingleSessionDataset] = self._unpack_dataset_arguments( datasets) - device = self._assert_datasets_same_device(self._datasets) + device = _assert_datasets_same_device(self._datasets) super().__init__(device = device) continuous = all( From 286f752189710a0553ae0a9ca970db168ea21686 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 7 Jul 2023 10:33:07 +0200 Subject: [PATCH 06/11] fix typo in test --- tests/test_sklearn.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 342be7c3..21a58938 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -757,7 +757,7 @@ def _iterate_actions(): def do_nothing(model): return model - def fit_singlession_model(model): + def fit_singlesession_model(model): X = np.linspace(-1, 1, 1000)[:, None] model.fit(X) return model @@ -768,7 +768,7 @@ def fit_multisession_model(model): return model - return [do_nothing, fit_singlession_model, fit_multisession_model] + return [do_nothing, fit_singlesession_model, fit_multisession_model] def _assert_same_state_dict(first, second): @@ -824,13 +824,9 @@ def test_save_and_load(action): def test_check_devices(action, device): cebra_model = cebra_sklearn_cebra.CEBRA( model_architecture="offset1-model", - time_offsets=10, - learning_rate=3e-4, max_iterations=5, device=device, - output_dimension=4, batch_size=42, - verbose=True, ) cebra_model = action(cebra_model) assert cebra_model.device == device From 248a6a6adaef33c19c0aca4e40e143832ea761e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Gonz=C3=A1lez=20Laiz?= <31796689+gonlairo@users.noreply.github.com> Date: Wed, 12 Jul 2023 22:59:34 +0200 Subject: [PATCH 07/11] fix test --- tests/test_sklearn.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 21a58938..fe700c32 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -802,15 +802,20 @@ def check_fitted(model): _assert_same_state_dict(original_model.state_dict_, loaded_model.state_dict_) X = np.random.normal(0, 1, (100, 1)) - assert np.allclose(loaded_model.transform(X), - original_model.transform(X)) + + if loaded_model.num_sessions is not None: + assert np.allclose(loaded_model.transform(X, session_id=0), + original_model.transform(X, session_id=0)) + else: + assert np.allclose(loaded_model.transform(X), + original_model.transform(X)) @pytest.mark.parametrize("action", _iterate_actions()) def test_save_and_load(action): model_architecture = "offset10-model" original_model = cebra_sklearn_cebra.CEBRA( - model_architecture=model_architecture, max_iterations=5) + model_architecture=model_architecture, max_iterations=5, batch_size=42) original_model = action(original_model) with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: original_model.save(savefile.name) @@ -835,4 +840,4 @@ def test_check_devices(action, device): if device == "cuda": #TODO(rodrigo): remove once https://github.com/AdaptiveMotorControlLab/CEBRA/pull/34 is merged. device = torch.device(device, index = 0) - assert next(cebra_model.model_.parameters()).device == torch.device(device) \ No newline at end of file + assert next(cebra_model.model_.parameters()).device == torch.device(device) From c3993f3c0d30382d30d0a677018acdb138eaba02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Gonz=C3=A1lez=20Laiz?= <31796689+gonlairo@users.noreply.github.com> Date: Wed, 12 Jul 2023 23:09:36 +0200 Subject: [PATCH 08/11] fix test_check_devices when cpu only available --- tests/test_sklearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index fe700c32..154ba4b0 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -824,7 +824,7 @@ def test_save_and_load(action): #let's contribute a test to test_sklearn that checks devices for training for all solvers, and link it in the issue here. -@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("device", ["cpu"] + ["cuda"] if torch.cuda.is_available() else []) @pytest.mark.parametrize("action", _iterate_actions()) def test_check_devices(action, device): cebra_model = cebra_sklearn_cebra.CEBRA( From fd67a36a1838d7a5b4985fed00a44e6e5376c9ae Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Thu, 13 Jul 2023 15:34:23 +0200 Subject: [PATCH 09/11] Update docstring --- cebra/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index ab201d0f..72205c82 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -116,7 +116,7 @@ def _assert_datasets_same_device( datasets: List of datasets. Returns: - str: The device name if all datasets are on the same device. + The device name if all datasets are on the same device. Raises: ValueError: If datasets are not all on the same device. From e9ca9fd9c7a1b18f055e1169c9ec49d7fdcf70d5 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Thu, 13 Jul 2023 15:34:35 +0200 Subject: [PATCH 10/11] Remove comment --- tests/test_sklearn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 1f4ade6e..25f6a199 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -823,8 +823,6 @@ def test_save_and_load(action): loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) _assert_equal(original_model, loaded_model) -#let's contribute a test to test_sklearn that checks devices for training for all solvers, and link it in the issue here. - @pytest.mark.parametrize("device", ["cpu"] + ["cuda"] if torch.cuda.is_available() else []) @pytest.mark.parametrize("action", _iterate_actions()) def test_check_devices(action, device): From 10dc68b013476a32661fa9de8b1f6966aeb20ccf Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Thu, 13 Jul 2023 17:45:38 +0200 Subject: [PATCH 11/11] Run pre-commit formatting --- cebra/data/datasets.py | 7 ++++--- tests/test_sklearn.py | 16 +++++++++------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index 72205c82..aa0ad3db 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -107,9 +107,9 @@ def __getitem__(self, index): index = self.expand_index(index) return self.neural[index].transpose(2, 1) + def _assert_datasets_same_device( datasets: List[cebra_data.SingleSessionDataset]) -> str: - """Checks if the list of datasets are all on the same device. Args: @@ -126,6 +126,7 @@ def _assert_datasets_same_device( raise ValueError("Datasets are not all on the same device") return devices.pop() + class DatasetCollection(cebra_data.MultiSessionDataset): """Multi session dataset made up of a list of datasets. @@ -186,9 +187,9 @@ def __init__( self._datasets: List[ cebra_data.SingleSessionDataset] = self._unpack_dataset_arguments( datasets) - + device = _assert_datasets_same_device(self._datasets) - super().__init__(device = device) + super().__init__(device=device) continuous = all( self._has_not_none_attribute(session, "continuous_index") diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 25f6a199..d5099b66 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -762,13 +762,12 @@ def fit_singlesession_model(model): X = np.linspace(-1, 1, 1000)[:, None] model.fit(X) return model - + def fit_multisession_model(model): X = np.linspace(-1, 1, 1000)[:, None] model.fit([X, X], [X, X]) return model - return [do_nothing, fit_singlesession_model, fit_multisession_model] @@ -805,8 +804,8 @@ def check_fitted(model): X = np.random.normal(0, 1, (100, 1)) if loaded_model.num_sessions is not None: - assert np.allclose(loaded_model.transform(X, session_id=0), - original_model.transform(X, session_id=0)) + assert np.allclose(loaded_model.transform(X, session_id=0), + original_model.transform(X, session_id=0)) else: assert np.allclose(loaded_model.transform(X), original_model.transform(X)) @@ -823,7 +822,9 @@ def test_save_and_load(action): loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) _assert_equal(original_model, loaded_model) -@pytest.mark.parametrize("device", ["cpu"] + ["cuda"] if torch.cuda.is_available() else []) + +@pytest.mark.parametrize("device", ["cpu"] + + ["cuda"] if torch.cuda.is_available() else []) @pytest.mark.parametrize("action", _iterate_actions()) def test_check_devices(action, device): cebra_model = cebra_sklearn_cebra.CEBRA( @@ -838,5 +839,6 @@ def test_check_devices(action, device): if action.__name__ != "do_nothing": if device == "cuda": #TODO(rodrigo): remove once https://github.com/AdaptiveMotorControlLab/CEBRA/pull/34 is merged. - device = torch.device(device, index = 0) - assert next(cebra_model.model_.parameters()).device == torch.device(device) + device = torch.device(device, index=0) + assert next( + cebra_model.model_.parameters()).device == torch.device(device)