From fe63c876e1587ea680e1a4d899640d1d6871504f Mon Sep 17 00:00:00 2001
From: CEBRA
Date: Sat, 24 Jun 2023 13:57:56 +0200
Subject: [PATCH] Standardize files with current pre-commit config
---
CHANGELOG.md | 10 +-
CITATION.cff | 2 +-
CLA.md | 2 +-
CODE_OF_CONDUCT.md | 2 +-
LICENSE.md | 36 +--
Makefile | 7 +-
NOTICE.yml | 7 +-
PKGBUILD | 2 +-
README.md | 12 +-
cebra/__init__.py | 2 +-
cebra/data/__init__.py | 2 +-
cebra/data/helper.py | 12 +-
cebra/data/load.py | 4 +-
cebra/datasets/allen/ca_movie_decoding.py | 10 +-
cebra/datasets/allen/single_session_ca.py | 10 +-
cebra/datasets/hippocampus.py | 6 +-
cebra/datasets/monkey_reaching.py | 50 ++--
cebra/distributions/__init__.py | 2 +-
cebra/distributions/base.py | 2 +-
cebra/grid_search.py | 10 +-
cebra/integrations/__init__.py | 2 +-
cebra/integrations/matplotlib.py | 6 +-
cebra/integrations/sklearn/cebra.py | 4 +-
cebra/integrations/sklearn/metrics.py | 4 +-
cebra/models/__init__.py | 2 +-
cebra/models/model.py | 4 +-
cebra/registry.py | 12 +-
cebra/solver/base.py | 2 +-
cebra/solver/single_session.py | 10 +-
conda/cebra_paper_m1.yml | 2 +-
conftest.py | 31 ++-
docs/root/index.html | 4 +-
docs/source/_static/css/custom.css | 2 +-
docs/source/api.rst | 6 +-
docs/source/api/pytorch/data.rst | 6 +-
docs/source/api/pytorch/distributions.rst | 1 -
docs/source/api/pytorch/helpers.rst | 2 +-
docs/source/api/pytorch/models.rst | 6 +-
docs/source/api/sklearn/decoder.rst | 1 -
docs/source/api/sklearn/metrics.rst | 1 -
docs/source/conf.py | 1 +
docs/source/contributing.rst | 26 +-
docs/source/figures.rst | 8 +-
docs/source/index.rst | 6 +-
docs/source/installation.rst | 34 +--
docs/source/usage.rst | 317 +++++++++++-----------
setup.cfg | 2 +-
tests/test_datasets.py | 2 +-
tests/test_sklearn.py | 5 +-
tests/test_usecases.py | 5 +-
50 files changed, 354 insertions(+), 350 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d96513b3..a8338f47 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,6 @@
# Changelog
-When contributing a PR, please add the title, link and a short 1-2 line description of the
+When contributing a PR, please add the title, link and a short 1-2 line description of the
PR to this document. If you are an external contributor, please also add your github handle.
You can use markdown formatting in this document.
@@ -19,8 +19,8 @@ in this file to the released code version using the name of the github tag (e.g.
Add ``ensemble_embeddings`` that aligns multiple embeddings and combine them into an averaged one.
- **Move `max_validation_iterations` from `cebra.CEBRA` to `cebra.metrics.infonce_loss` [#527](https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/527)**:
- Move `max_validation_iterations` from `cebra.CEBRA` to `cebra.metrics.infonce_loss` and
- rename the variable to `num_batches`.
+ Move `max_validation_iterations` from `cebra.CEBRA` to `cebra.metrics.infonce_loss` and
+ rename the variable to `num_batches`.
- **Add `plot_consistency` and demo notebook [#502](https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/502)**:
Add `plot_consistency` helper function and complete the corresponding notebook.
@@ -50,7 +50,7 @@ It is the official first release distributed along with the publication of the C
- **Add cebra.plot package [#385](https://github.com/stes/neural_cl/pull/385)**:
Simplify post-hoc analysis of model performance and embeddings by collecting plotting functions for the most common usecases.
- **Multisession API integration [#333](https://github.com/stes/neural_cl/pull/333)**:
- Add multisession implementation compatibility to the sklearn API.
+ Add multisession implementation compatibility to the sklearn API.
- v0.0.2rc1
- **Implementation for general dataloading [#305](https://github.com/stes/neural_cl/pull/305)**:
Implement `load`, a general function to convert any supported data file types to ``numpy.array``.
@@ -59,7 +59,7 @@ It is the official first release distributed along with the publication of the C
- **Add quick testing option [#318](https://github.com/stes/neural_cl/pull/318)**:
Add slow marker for longer tests and a quick testing option for pytest and in github workflow.
- **Add CITATION.cff file [#339](https://github.com/stes/neural_cl/pull/339)**:
- Add CITATION.cff file for easy-to-use citation of the pre-print paper.
+ Add CITATION.cff file for easy-to-use citation of the pre-print paper.
- **Update sklearn dependency [#317](https://github.com/stes/neural_cl/pull/317)**:
The sklearn dependency was updated to `scikit-learn` as discussed
[in the scikit-learn docs](https://github.com/scikit-learn/sklearn-pypi-package)
diff --git a/CITATION.cff b/CITATION.cff
index 4011d6f2..3e5c4e9b 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -37,4 +37,4 @@ preferred-citation:
month: 05
doi: 10.1038/s41586-023-06031-6
issn: 1476-4687
- url: https://doi.org/10.1038/s41586-023-06031-6
\ No newline at end of file
+ url: https://doi.org/10.1038/s41586-023-06031-6
diff --git a/CLA.md b/CLA.md
index 8988f91c..878fe562 100644
--- a/CLA.md
+++ b/CLA.md
@@ -9,7 +9,7 @@ For reference, or for printing and emailing or mailing the form it is reproduced
CLA Version as of March 17th, 2023.
-Thank you for your interest in software from The Mathis Laboratory of
+Thank you for your interest in software from The Mathis Laboratory of
Adaptive Motor Control, UPMWMATHIS ("Lab").
In order to clarify the intellectual property license
granted with Contributions from any person or entity, the Lab
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index 8d26f5c3..ef0a8c30 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -55,7 +55,7 @@ further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
-reported by contacting the project team at steffen.schneider@epfl.ch or steffen@bethgelab.org.
+reported by contacting the project team at steffen.schneider@epfl.ch or steffen@bethgelab.org.
All complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
diff --git a/LICENSE.md b/LICENSE.md
index cd714bd0..97234104 100644
--- a/LICENSE.md
+++ b/LICENSE.md
@@ -7,15 +7,15 @@ Steffen Schneider, Jin H Lee, Mackenzie W Mathis. 2023.
**Introduction**
- This license agreement sets forth the terms and conditions under which ECOLE POLYTECHNIQUE FEDERALE DE LAUSANNE ( EPFL),
-CH-1015 Lausanne, Switzerland and Prof. Mackenzie W. Mathis and code authors (hereafter "LICENSOR") will grant you
+CH-1015 Lausanne, Switzerland and Prof. Mackenzie W. Mathis and code authors (hereafter "LICENSOR") will grant you
(hereafter "LICENSEE") a fully-paid, non-exclusive, and non-transferable license for academic, non-commercial purposes only
(hereafter “LICENSE”) to use the "CEBRA" computer software program (hereafter "PROGRAM").
- LICENSEE acknowledges that the PROGRAM is a research tool that is being supplied "as is", without any related services,
improvements or warranties from LICENSOR and that this license is entered into in order to enable others to utilize the
-PROGRAM in their academic activities.
+PROGRAM in their academic activities.
-- The ideas covered in this work is also patent pending (as of Jan 2023): US 63/302,670 “DIMENSIONALITY REDUCTION OF TIME-SERIES DATA,
+- The ideas covered in this work is also patent pending (as of Jan 2023): US 63/302,670 “DIMENSIONALITY REDUCTION OF TIME-SERIES DATA,
AND SYSTEMS AND DEVICES THAT USE THE RESULTANT EMBEDDINGS”
- If this license is not appropriate for your application, please contact Prof. Mackenzie W. Mathis (mackenzie@post.harvard.edu)
@@ -23,34 +23,34 @@ and/or the TTO office at EPFL (tto@epfl.ch) for a commercial use license.
**Terms and Conditions of the LICENSE**
1. LICENSOR grants to LICENSEE a fully-paid up, non-exclusive, and non-transferable license to use the PROGRAM for academic,
- non-commercial purposes, upon the terms and conditions hereinafter set out and until termination of this license as set
+ non-commercial purposes, upon the terms and conditions hereinafter set out and until termination of this license as set
forth below.
-2. LICENSEE acknowledges the PROGRAM is provided "as is", without any related services or improvements from LICENSOR and
+2. LICENSEE acknowledges the PROGRAM is provided "as is", without any related services or improvements from LICENSOR and
that the LICENSE is entered into in order to enable others to utilize the PROGRAM in their academic activities.
3. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY REPRESENTATIONS OR
- WARRANTIES OF MERCHANTABILITY OR FITNESS FOR PARTICULAR PURPOSE OR THAT THE USE OF THE PROGRAM WILL NOT INFRINGE ANY
+ WARRANTIES OF MERCHANTABILITY OR FITNESS FOR PARTICULAR PURPOSE OR THAT THE USE OF THE PROGRAM WILL NOT INFRINGE ANY
PATENTS, COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS. LICENSOR shall not be liable for any direct, indirect or consequential
damages with respect to any claim by LICENSEE or any third party arising from this Agreement or use of the PROGRAM.
-
-4. LICENSEE agrees that it will use the PROGRAM, and any modifications, improvements, or derivatives to PROGRAM that
- LICENSEE may create (collectively, "IMPROVEMENTS") solely for academic, non-commercial purposes and shall not
+
+4. LICENSEE agrees that it will use the PROGRAM, and any modifications, improvements, or derivatives to PROGRAM that
+ LICENSEE may create (collectively, "IMPROVEMENTS") solely for academic, non-commercial purposes and shall not
distribute or transfer the PROGRAM or any IMPROVEMENTS to any person without prior written permission from LICENSOR.
Any IMPROVEMENTS must remain open source with a copy of this license. The terms "academic, non-commercial", as used
- in this Agreement, mean academic or other scholarly research which (a) is not undertaken for profit, or (b) is not
+ in this Agreement, mean academic or other scholarly research which (a) is not undertaken for profit, or (b) is not
intended to produce works, services, or data for commercial use, or (c) is neither conducted, nor funded, by a person
or an entity engaged in the commercial use, application or exploitation of works similar to the PROGRAM.
-5. LICENSEE agrees that they shall credit the use of CEBRA with an appropriate citation:
- Steffen Schneider, Jin H. Lee, Mackenzie Weygandt Mathis. Learnable latent embeddings for joint behavioral
+5. LICENSEE agrees that they shall credit the use of CEBRA with an appropriate citation:
+ Steffen Schneider, Jin H. Lee, Mackenzie Weygandt Mathis. Learnable latent embeddings for joint behavioral
and neural analysis. Nature 2023 doi: https://doi.org/10.1038/s41586-023-06031-6.
6. Ownership of all rights, including copyright in the PROGRAM and in any material associated therewith, shall at all times
- remain with LICENSOR and LICENSEE agrees to preserve the same. LICENSEE agrees not to use any portion of the PROGRAM or
- of any IMPROVEMENTS in any machine-readable form outside the PROGRAM, nor to make any copies except for its internal use,
+ remain with LICENSOR and LICENSEE agrees to preserve the same. LICENSEE agrees not to use any portion of the PROGRAM or
+ of any IMPROVEMENTS in any machine-readable form outside the PROGRAM, nor to make any copies except for its internal use,
without prior written consent of LICENSOR. LICENSEE agrees to maintain this license file with the source code and place the
- following copyright notice on any such copies:
+ following copyright notice on any such copies:
© All rights reserved. ECOLE POLYTECHNIQUE FÉDÉRALE DE LAUSANNE, Switzerland, Laboratory of Prof. Mackenzie W. Mathis
(UPMWMATHIS) and original authors: Steffen Schneider, Jin H Lee, Mackenzie W Mathis. 2023.
@@ -58,9 +58,9 @@ and/or the TTO office at EPFL (tto@epfl.ch) for a commercial use license.
7. The LICENSE shall not be construed to confer any rights upon LICENSEE by implication or otherwise except as specifically
set forth herein.
-8. This Agreement shall be governed by the material laws of Switzerland and any dispute arising out of this Agreement or
- use of the PROGRAM shall be brought before the courts of Lausanne, Switzerland.
+8. This Agreement shall be governed by the material laws of Switzerland and any dispute arising out of this Agreement or
+ use of the PROGRAM shall be brought before the courts of Lausanne, Switzerland.
-9. This Agreement and the LICENSE shall remain effective until expiration of the copyrights of the PROGRAM except that,
+9. This Agreement and the LICENSE shall remain effective until expiration of the copyrights of the PROGRAM except that,
upon any breach of this Agreement by LICENSEE, LICENSOR shall have the right to terminate the LICENSE immediately upon
notice to LICENSEE.
diff --git a/Makefile b/Makefile
index be4cc2fb..0d309fda 100644
--- a/Makefile
+++ b/Makefile
@@ -9,8 +9,8 @@ build: dist
archlinux:
mkdir -p dist/arch
- cp PKGBUILD dist/arch
- cp dist/cebra-0.2.0.tar.gz dist/arch
+ cp PKGBUILD dist/arch
+ cp dist/cebra-0.2.0.tar.gz dist/arch
(cd dist/arch; makepkg --skipchecksums -f)
# NOTE(stes): Ensure that no old tempfiles are present. Ideally, move this into
@@ -83,7 +83,7 @@ format:
# https://github.com/PyCQA/docformatter/issues/119
# is resolved.
# docformatter --config pyproject.toml -i cebra
- # docformatter --config pyproject.toml -i tests
+ # docformatter --config pyproject.toml -i tests
isort cebra/
isort tests/
@@ -99,4 +99,3 @@ report: check_docker format .coverage .pylint
coverage report
.PHONY: dist build archlinux clean_test test doctest test_parallel test_parallel_debug test_all test_fast test_debug test_benchmark interrogate docs docs-touch docs-strict serve_docs serve_page format codespell check_for_binary
-
diff --git a/NOTICE.yml b/NOTICE.yml
index b661a316..36d172de 100644
--- a/NOTICE.yml
+++ b/NOTICE.yml
@@ -1,11 +1,11 @@
# This notice file contains license headers for all code files
# in the repository (currently all *.py files).
-#
+#
# When updating headers, lower entries take precedence over higher
# entries. For each header, include/exclude statements can be used
# to define files to apply them to.
#
-# When adding code from an external repo, make sure to cover the
+# When adding code from an external repo, make sure to cover the
# added code files with the correct license header.
# Main repository license
@@ -22,7 +22,6 @@
include:
- 'cebra/**/*.py'
- - 'tests/**/*.py'
+ - 'tests/**/*.py'
- 'docs/**/*.py'
- 'conda/**/*.yml'
-
diff --git a/PKGBUILD b/PKGBUILD
index 835549db..e107a05c 100644
--- a/PKGBUILD
+++ b/PKGBUILD
@@ -39,7 +39,7 @@ build() {
package() {
cd $srcdir/${_pkgname}-${pkgver}
- pip install --ignore-installed --no-deps --root="${pkgdir}" dist/${_pkgname}-${pkgver}-py2.py3-none-any.whl
+ pip install --ignore-installed --no-deps --root="${pkgdir}" dist/${_pkgname}-${pkgver}-py2.py3-none-any.whl
find ${pkgdir} -iname __pycache__ -exec rm -r {} \; 2>/dev/null || echo
install -Dm 644 LICENSE.md $pkgdir/usr/share/licenses/${pkgname}/LICENSE
}
diff --git a/README.md b/README.md
index 3811d9ce..7efd0d21 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-
+
@@ -11,7 +11,7 @@
[🛠️ Installation](https://cebra.ai/docs/installation.html) |
[🌎 Home Page](https://www.cebra.ai) |
[🚨 News](https://cebra.ai/docs/index.html) |
-[🪲 Reporting Issues](https://github.com/AdaptiveMotorControlLab/CEBRA)
+[🪲 Reporting Issues](https://github.com/AdaptiveMotorControlLab/CEBRA)
[![Downloads](https://static.pepy.tech/badge/cebra)](https://pepy.tech/project/cebra)
@@ -30,19 +30,19 @@
To receive updates on code releases, please 👀 watch or ⭐️ star this repository!
-``cebra`` is a self-supervised method for non-linear clustering that allows for label-informed time series analysis.
+``cebra`` is a self-supervised method for non-linear clustering that allows for label-informed time series analysis.
It can jointly use behavioral and neural data in a hypothesis- or discovery-driven manner to produce consistent, high-performance latent spaces. While it is not specific to neural and behavioral data, this is the first domain we used the tool in. This application case is to obtain a consistent representation of latent variables driving activity and behavior, improving decoding accuracy of behavioral variables over standard supervised learning, and obtaining embeddings which are robust to domain shifts.
-# Reference
+# Reference
- 📄 **Publication May 2023**:
[Learnable latent embeddings for joint behavioural and neural analysis.](https://doi.org/10.1038/s41586-023-06031-6)
Steffen Schneider*, Jin Hwa Lee* and Mackenzie Weygandt Mathis. Nature 2023.
-
+
- 📄 **Preprint April 2022**:
[Learnable latent embeddings for joint behavioral and neural analysis.](https://arxiv.org/abs/2204.00673)
Steffen Schneider*, Jin Hwa Lee* and Mackenzie Weygandt Mathis
-
+
# License
- CEBRA is released for academic use only (please read the license file). If this license is not appropriate for your application, please contact Prof. Mackenzie W. Mathis (mackenzie@post.harvard.edu) for a commercial use license.
diff --git a/cebra/__init__.py b/cebra/__init__.py
index 7f0e551c..66d8b602 100644
--- a/cebra/__init__.py
+++ b/cebra/__init__.py
@@ -9,7 +9,7 @@
# Please see LICENSE.md for the full license document:
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
#
-"""CEBRA is a library for estimating Consistent Embeddings of high-dimensional Recordings
+"""CEBRA is a library for estimating Consistent Embeddings of high-dimensional Recordings
using Auxiliary variables. It contains self-supervised learning algorithms implemented in
PyTorch, and has support for a variety of different datasets common in biology and neuroscience.
"""
diff --git a/cebra/data/__init__.py b/cebra/data/__init__.py
index d3a9b7c0..41d22334 100644
--- a/cebra/data/__init__.py
+++ b/cebra/data/__init__.py
@@ -16,7 +16,7 @@
It is non-specific to a particular dataset (see :py:mod:`cebra.datasets` for actual dataset
implementations). However, the base classes for all datasets are defined here, as well as helper
functions to interact with datasets.
-
+
CEBRA supports different dataset types out-of-the box:
- :py:class:`cebra.data.single_session.SingleSessionDataset` is the abstract base class for a single session dataset. Single session datasets
diff --git a/cebra/data/helper.py b/cebra/data/helper.py
index 98638a0c..b6bfcbad 100644
--- a/cebra/data/helper.py
+++ b/cebra/data/helper.py
@@ -41,7 +41,7 @@ class OrthogonalProcrustesAlignment:
In linear algebra, the orthogonal Procrustes problem is a matrix approximation
problem. Considering two matrices A and B, it consists in finding the orthogonal
matrix R which most closely maps A to B, so that it minimizes the Frobenius norm of
- ``(A @ R) - B`` subject to ``R.T @ R = I``.
+ ``(A @ R) - B`` subject to ``R.T @ R = I``.
See :py:func:`scipy.linalg.orthogonal_procrustes` for more information.
For each dataset, the data and labels to align the data on is provided.
@@ -92,9 +92,9 @@ def fit(
label: Optional[npt.NDArray] = None,
) -> "OrthogonalProcrustesAlignment":
"""Compute the matrix solution of the orthogonal Procrustes problem.
-
+
The obtained matrix is used to align a dataset to a reference dataset.
-
+
Args:
ref_data: Reference data matrix on which to align the data.
data: Data matrix to align on the reference dataset.
@@ -117,7 +117,7 @@ def fit(
... data=aux_embedding,
... ref_label=ref_label,
... label=aux_label)
-
+
"""
if len(ref_data.shape) == 1:
ref_data = np.expand_dims(ref_data, axis=1)
@@ -307,9 +307,9 @@ def ensemble_embeddings(
Args:
embeddings: List of embeddings to align and ensemble.
- labels: Optional list of indexes associated to the embeddings in ``embeddings`` to align the embeddings on.
+ labels: Optional list of indexes associated to the embeddings in ``embeddings`` to align the embeddings on.
To be ensembled, the embeddings should already be aligned on time, and consequently do not require extra
- labels for alignment.
+ labels for alignment.
post_norm: If True, the resulting joint embedding is normalized (divided by its norm across
the features - axis 1).
n_jobs: The maximum number of concurrently running jobs to compute embedding alignment in a parallel manner using
diff --git a/cebra/data/load.py b/cebra/data/load.py
index ee35aeec..b98dae55 100644
--- a/cebra/data/load.py
+++ b/cebra/data/load.py
@@ -508,8 +508,8 @@ def load(
loaded_array = loaded_data
elif type(loaded_data) is dict:
if key is not None:
- if (key in loaded_data.keys() and type(loaded_data[key]) is
- np.ndarray): # check that key is valid
+ if (key in loaded_data.keys() and type(loaded_data[key])
+ is np.ndarray): # check that key is valid
loaded_array = loaded_data[key]
else:
raise AttributeError(
diff --git a/cebra/datasets/allen/ca_movie_decoding.py b/cebra/datasets/allen/ca_movie_decoding.py
index e9ea9adc..dfc0526f 100644
--- a/cebra/datasets/allen/ca_movie_decoding.py
+++ b/cebra/datasets/allen/ca_movie_decoding.py
@@ -9,7 +9,7 @@
# Please see LICENSE.md for the full license document:
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
#
-"""Allen pseudomouse Ca decoding dataset with train/test split.
+"""Allen pseudomouse Ca decoding dataset with train/test split.
References:
*Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
@@ -149,9 +149,11 @@ def _split(self, pseudo_mice, frame_feature):
)
self.index = frame_feature.repeat(9, 1)
elif self.split_flag == "test":
- neural = pseudo_mice[self.neurons_indices, (self.test_repeat - 1) *
- self.movie_len:self.test_repeat *
- self.movie_len,]
+ neural = pseudo_mice[
+ self.neurons_indices,
+ (self.test_repeat - 1) * self.movie_len:self.test_repeat *
+ self.movie_len,
+ ]
self.index = frame_feature.repeat(1, 1)
else:
raise ValueError("split_flag should be either train or test")
diff --git a/cebra/datasets/allen/single_session_ca.py b/cebra/datasets/allen/single_session_ca.py
index b03568d5..da748ee9 100644
--- a/cebra/datasets/allen/single_session_ca.py
+++ b/cebra/datasets/allen/single_session_ca.py
@@ -351,16 +351,16 @@ def __init__(self, repeat_no, split_flag):
class SingleSessionAllenCaDecoding(cebra.data.SingleSessionDataset):
"""A corrupted single mouse 30Hz calcium events dataset during the allen MOVIE1 stimulus with train/test splits.
- A dataset of a single mouse 30Hz calcium events from the excitatory neurons
- in the primary visual cortex during the 10 repeats of the MOVIE1 stimulus
+ A dataset of a single mouse 30Hz calcium events from the excitatory neurons
+ in the primary visual cortex during the 10 repeats of the MOVIE1 stimulus
in session type A. The preprocessed data from *Deitch et al. (2021) are used.
- The continuous labels corresponding to a DINO embedding of each stimulus frame,
+ The continuous labels corresponding to a DINO embedding of each stimulus frame,
but in randomly shuffled order.
- A neural recording during the chosen repeat is used as a test set and the
+ A neural recording during the chosen repeat is used as a test set and the
remaining 9 repeats are used as a train set.
Args:
- session_id: The integer value to pick a session among 4 sessions with the
+ session_id: The integer value to pick a session among 4 sessions with the
largest number of recorded neruons. Choose between 0-3.
repeat_no: The nth repeat to use as the test set. Choose between 0-9.
split_flag: The `train`/`test` split to load.
diff --git a/cebra/datasets/hippocampus.py b/cebra/datasets/hippocampus.py
index 701a9e6c..1ce789cf 100644
--- a/cebra/datasets/hippocampus.py
+++ b/cebra/datasets/hippocampus.py
@@ -12,11 +12,11 @@
"""Rat hippocampus dataset
References:
- * Grosmark, A.D., and Buzsáki, G. (2016). Diversity in neural firing dynamics supports both rigid and learned
+ * Grosmark, A.D., and Buzsáki, G. (2016). Diversity in neural firing dynamics supports both rigid and learned
hippocampal sequences. Science 351, 1440–1443.
- * Chen, Z., Grosmark, A.D., Penagos, H., and Wilson, M.A. (2016). Uncovering representations of sleep-associated
+ * Chen, Z., Grosmark, A.D., Penagos, H., and Wilson, M.A. (2016). Uncovering representations of sleep-associated
hippocampal ensemble spike activity. Sci. Rep. 6, 32193.
- * Grosmark, A.D., Long J. and Buzsáki, G. (2016); Recordings from hippocampal area CA1, PRE, during and POST
+ * Grosmark, A.D., Long J. and Buzsáki, G. (2016); Recordings from hippocampal area CA1, PRE, during and POST
novel spatial learning. CRCNS.org. http://dx.doi.org/10.6080/K0862DC5
"""
diff --git a/cebra/datasets/monkey_reaching.py b/cebra/datasets/monkey_reaching.py
index 001044df..d20e3d21 100644
--- a/cebra/datasets/monkey_reaching.py
+++ b/cebra/datasets/monkey_reaching.py
@@ -9,7 +9,7 @@
# Please see LICENSE.md for the full license document:
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
#
-"""Ephys neural and behavior data used for the monkey reaching experiment.
+"""Ephys neural and behavior data used for the monkey reaching experiment.
References:
* Chowdhury, Raeed H., Joshua I. Glaser, and Lee E. Miller. "Area 2 of primary somatosensory cortex encodes kinematics of the whole arm." Elife 9 (2020).
@@ -136,17 +136,17 @@ def _get_info(trial_info, data):
class Area2BumpDataset(cebra.data.SingleSessionDataset):
"""Base dataclass to generate monkey reaching datasets.
- Ephys and behavior recording from -100ms and 500ms from the movement
+ Ephys and behavior recording from -100ms and 500ms from the movement
onset in 1ms bin size.
Neural recording is smoothened with Gaussian kernel with 40ms std.
- The behavior labels can include trial types, target directions and the
+ The behavior labels can include trial types, target directions and the
x,y hand positions.
- After initialization of the dataset, split method can splits the data
+ After initialization of the dataset, split method can splits the data
into 'train', 'valid' and 'test' split.
Args:
path: The path to the directory where the preloaded data is.
- session: The trial type. Choose between 'active', 'passive',
+ session: The trial type. Choose between 'active', 'passive',
'all', 'active-passive'.
"""
@@ -169,9 +169,9 @@ def __init__(
def split(self, split):
"""Split the dataset.
- The train trials are the same as one defined in Neural Latent
+ The train trials are the same as one defined in Neural Latent
Benchmark (NLB) Dataset.
- The half of the valid trials defined in NLBDataset is used as
+ The half of the valid trials defined in NLBDataset is used as
the valid set and the other half is used as the test set.
Args:
@@ -242,18 +242,18 @@ def __getitem__(self, index):
class Area2BumpShuffledDataset(Area2BumpDataset):
"""Base dataclass to generate shuffled monkey reaching datasets.
- Ephys and behavior recording from -100ms and 500ms from the movement
+ Ephys and behavior recording from -100ms and 500ms from the movement
onset in 1ms bin size.
Neural recording is smoothened with Gaussian kernel with 40ms std.
- The shuffled behavior labels can include trial types, target directions
+ The shuffled behavior labels can include trial types, target directions
and the x,y hand positions.
- After initialization of the dataset, split method can splits the data
+ After initialization of the dataset, split method can splits the data
into 'train', 'valid' and 'test' split.
Args:
path: The path to the directory where the preloaded data is.
- session: The trial type. Choose between 'active', 'passive', 'all',
+ session: The trial type. Choose between 'active', 'passive', 'all',
'active-passive'.
"""
@@ -297,7 +297,7 @@ def _create_area2_dataset():
"""Register the monkey reaching datasets of different trial types, behavior labels.
The trial types are 'active', 'passive', 'all' and 'active-passive'.
- The 'active-passive' type distinguishes movement direction between active, passive
+ The 'active-passive' type distinguishes movement direction between active, passive
(0-7 for active and 8-15 for passive) and 'all' does not (0-7).
"""
@@ -310,12 +310,12 @@ class Dataset(Area2BumpDataset):
"""Monkey reaching dataset with hand position labels.
The dataset loads continuous x,y hand position as behavior labels.
- For the 'active-passive' trial type, it additionally loads discrete binary
+ For the 'active-passive' trial type, it additionally loads discrete binary
label of active(0)/passive(1).
Args:
path: The path to the directory where the preloaded data is.
- session: The trial type. Choose between 'active', 'passive', 'all',
+ session: The trial type. Choose between 'active', 'passive', 'all',
'active-passive'.
"""
@@ -342,7 +342,7 @@ class Dataset(Area2BumpDataset):
Args:
path: The path to the directory where the preloaded data is.
- session: The trial type. Choose between 'active', 'passive', 'all',
+ session: The trial type. Choose between 'active', 'passive', 'all',
'active-passive'.
"""
@@ -365,14 +365,14 @@ def continuous_index(self):
class Dataset(Area2BumpDataset):
"""Monkey reaching dataset with hand position labels and discrete target labels.
- The dataset loads continuous x,y hand position and discrete target labels (0-7)
+ The dataset loads continuous x,y hand position and discrete target labels (0-7)
as behavior labels.
- For active-passive type, the discrete target labels 0-7 for active and 8-16 for
+ For active-passive type, the discrete target labels 0-7 for active and 8-16 for
passive are loaded.
Args:
path: The path to the directory where the preloaded data is.
- session: The trial type. Choose between 'active', 'passive', 'all',
+ session: The trial type. Choose between 'active', 'passive', 'all',
'active-passive'.
"""
@@ -396,7 +396,7 @@ def _create_area2_shuffled_dataset():
"""Register the shuffled monkey reaching datasets of different trial types, behavior labels.
The trial types are 'active' and 'active-passive'.
- The behavior labels are randomly shuffled and the trial types are shuffled
+ The behavior labels are randomly shuffled and the trial types are shuffled
in case of 'shuffled-trial' datasets.
"""
@@ -408,12 +408,12 @@ def _create_area2_shuffled_dataset():
class Dataset(Area2BumpShuffledDataset):
"""Monkey reaching dataset with the shuffled trial type.
- The dataset loads the discrete binary trial type label active(0)/passive(1)
+ The dataset loads the discrete binary trial type label active(0)/passive(1)
in randomly shuffled order.
Args:
path: The path to the directory where the preloaded data is.
- session: The trial type. Choose between 'active', 'passive', 'all',
+ session: The trial type. Choose between 'active', 'passive', 'all',
'active-passive'.
"""
@@ -437,12 +437,12 @@ class Dataset(Area2BumpShuffledDataset):
"""Monkey reaching dataset with the shuffled hand position.
The dataset loads continuous x,y hand position in randomly shuffled order.
- For the 'active-passive' trial type, it additionally loads discrete binary label
+ For the 'active-passive' trial type, it additionally loads discrete binary label
of active(0)/passive(1).
Args:
path: The path to the directory where the preloaded data is.
- session: The trial type. Choose between 'active', 'passive', 'all',
+ session: The trial type. Choose between 'active', 'passive', 'all',
'active-passive'.
"""
@@ -465,12 +465,12 @@ def continuous_index(self):
class Dataset(Area2BumpShuffledDataset):
"""Monkey reaching dataset with the shuffled hand position.
- The dataset loads discrete target direction (0-7 for active and 0-15 for active-passive)
+ The dataset loads discrete target direction (0-7 for active and 0-15 for active-passive)
in randomly shuffled order.
Args:
path: The path to the directory where the preloaded data is.
- session: The trial type. Choose between 'active', 'passive', 'all',
+ session: The trial type. Choose between 'active', 'passive', 'all',
'active-passive'.
"""
diff --git a/cebra/distributions/__init__.py b/cebra/distributions/__init__.py
index 4fe0f149..58edffbe 100644
--- a/cebra/distributions/__init__.py
+++ b/cebra/distributions/__init__.py
@@ -14,7 +14,7 @@
This package contains classes for sampling and indexing of datasets.
Typically, the functionality of
classes in this module is guided by the auxiliary variables of CEBRA. A dataset would pass auxiliary
-variables to a sampler, and within the sampler the *indices* of reference, negative and positive
+variables to a sampler, and within the sampler the *indices* of reference, negative and positive
samples will be sampled based on the auxiliary information. Custom ways of sampling should therefore
be implemented in this package. Functionality in this package is fully agnostic to the actual signal
to be analysed, and only considers the auxiliary information of a dataset (called "index").
diff --git a/cebra/distributions/base.py b/cebra/distributions/base.py
index 3e1179fe..94092e4e 100644
--- a/cebra/distributions/base.py
+++ b/cebra/distributions/base.py
@@ -15,7 +15,7 @@
Distributions are defined in terms of _indices_ that reference samples within
the dataset.
-The appropriate base classes are defined in this module: An :py:class:`Index` is the
+The appropriate base classes are defined in this module: An :py:class:`Index` is the
part of the dataset used to inform the prior and conditional distributions;
and could for example be time, or information about an experimental condition.
"""
diff --git a/cebra/grid_search.py b/cebra/grid_search.py
index b779f25c..cbc05c95 100644
--- a/cebra/grid_search.py
+++ b/cebra/grid_search.py
@@ -151,7 +151,7 @@ def fit_models(self,
... verbose = False)
>>> # 2. Fit the models generated from the list of parameters
>>> grid_search = cebra.grid_search.GridSearch()
- >>> grid_search = grid_search.fit_models(datasets={"neural_data": neural_data},
+ >>> grid_search = grid_search.fit_models(datasets={"neural_data": neural_data},
... params=params_grid,
... models_dir="grid_search_models")
@@ -232,7 +232,7 @@ def load(cls, dir: str) -> Tuple[cebra_sklearn_cebra.CEBRA, List[dict]]:
>>> import cebra.grid_search
>>> models, parameter_grid = cebra.grid_search.GridSearch().load(dir="grid_search_models")
-
+
"""
dir = pathlib.Path(dir)
if not pathlib.Path.exists(dir):
@@ -331,7 +331,7 @@ def get_best_model(
>>> # 2. Fit the models generated from the list of parameters
>>> grid_search = cebra.grid_search.GridSearch()
>>> grid_search = grid_search.fit_models(datasets={"neural_data": neural_data},
- ... params=params_grid,
+ ... params=params_grid,
... models_dir="grid_search_models")
>>> # 3. Get model with the best performances and use it as usual
>>> best_model, best_model_name = grid_search.get_best_model()
@@ -395,7 +395,7 @@ def get_df_results(self, models_dir: str = None) -> pd.DataFrame:
>>> # 2. Fit the models generated from the list of parameters
>>> grid_search = cebra.grid_search.GridSearch()
>>> grid_search = grid_search.fit_models(datasets={"neural_data": neural_data},
- ... params=params_grid,
+ ... params=params_grid,
... models_dir="grid_search_models")
>>> # 3. Get results for all models
>>> df_results = grid_search.get_df_results()
@@ -453,7 +453,7 @@ def plot_loss_comparison(self,
>>> # 2. Fit the models generated from the list of parameters
>>> grid_search = cebra.grid_search.GridSearch()
>>> grid_search = grid_search.fit_models(datasets={"neural_data": neural_data},
- ... params=params_grid,
+ ... params=params_grid,
... models_dir="grid_search_models")
>>> # 3. Plot losses for all models
>>> ax = grid_search.plot_loss_comparison()
diff --git a/cebra/integrations/__init__.py b/cebra/integrations/__init__.py
index 1b670308..72d13820 100644
--- a/cebra/integrations/__init__.py
+++ b/cebra/integrations/__init__.py
@@ -12,7 +12,7 @@
"""Integration of CEBRA into common machine learning libraries.
This package contains a growing collection of interfaces to other Python packages.
-There is no clear limit (yet) of what can go into it. The current examples include
+There is no clear limit (yet) of what can go into it. The current examples include
interfaces (implemented or planned) to `scikit-learn `_,
`streamlit `_, `deeplabcut `_,
`matplotlib `_ and `threejs `_.
diff --git a/cebra/integrations/matplotlib.py b/cebra/integrations/matplotlib.py
index b4eaf1e2..f5d46fcf 100644
--- a/cebra/integrations/matplotlib.py
+++ b/cebra/integrations/matplotlib.py
@@ -736,8 +736,10 @@ def plot_overview(
figsize: tuple = (15, 4),
dpi: int = 100,
**kwargs,
-) -> Tuple[matplotlib.figure.Figure, Tuple[
- matplotlib.axes.Axes, matplotlib.axes.Axes, matplotlib.axes.Axes],]:
+) -> Tuple[
+ matplotlib.figure.Figure,
+ Tuple[matplotlib.axes.Axes, matplotlib.axes.Axes, matplotlib.axes.Axes],
+]:
"""Plot an overview of a trained CEBRA model.
Args:
diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py
index 2276b84d..a54223d8 100644
--- a/cebra/integrations/sklearn/cebra.py
+++ b/cebra/integrations/sklearn/cebra.py
@@ -746,8 +746,8 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
else:
label_types_idx = self._label_types[i][session_id]
- if (len(label_types_idx[1]) > 1 and len(y[i].shape) >
- 1): # is there more than one feature in the index
+ if (len(label_types_idx[1]) > 1 and len(y[i].shape)
+ > 1): # is there more than one feature in the index
if label_types_idx[1][1] != y[i].shape[1]:
raise ValueError(
f"Labels invalid: must have the same number of features as the ones used for fitting,"
diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py
index 4fbc871a..c6c96972 100644
--- a/cebra/integrations/sklearn/metrics.py
+++ b/cebra/integrations/sklearn/metrics.py
@@ -53,8 +53,8 @@ def infonce_loss(
>>> cebra_model = cebra.CEBRA(max_iterations=10)
>>> cebra_model.fit(neural_data)
CEBRA(max_iterations=10)
- >>> loss = cebra.sklearn.metrics.infonce_loss(cebra_model,
- ... neural_data,
+ >>> loss = cebra.sklearn.metrics.infonce_loss(cebra_model,
+ ... neural_data,
... num_batches=5)
"""
diff --git a/cebra/models/__init__.py b/cebra/models/__init__.py
index 90590dad..e80ed379 100644
--- a/cebra/models/__init__.py
+++ b/cebra/models/__init__.py
@@ -12,7 +12,7 @@
"""Pre-defined neural network model architectures
This package contains everything related to implementing data encoders and the loss functions
-applied to the feature spaces. :py:mod:`cebra.models.criterions` contains the implementations of
+applied to the feature spaces. :py:mod:`cebra.models.criterions` contains the implementations of
InfoNCE and other contrastive losses. All additions regarding how data is encoded and losses are
computed should be added to this package.
diff --git a/cebra/models/model.py b/cebra/models/model.py
index 79ffec5a..f9567605 100644
--- a/cebra/models/model.py
+++ b/cebra/models/model.py
@@ -697,7 +697,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
@_register_conditionally("offset36-model-dropout")
class Offset36Dropout(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a 10 sample receptive field.
-
+
Note:
Requires ``torch>=1.12``.
"""
@@ -741,7 +741,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
@_register_conditionally("offset36-model-more-dropout")
class Offset36Dropoutv2(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a 10 sample receptive field.
-
+
Note:
Requires ``torch>=1.12``.
"""
diff --git a/cebra/registry.py b/cebra/registry.py
index 54809b91..359d1def 100644
--- a/cebra/registry.py
+++ b/cebra/registry.py
@@ -321,15 +321,15 @@ def _wrap(text, indent: int):
>>> print({module.__name__}.get_options())
{_shorten(options)}
- To obtain an initialized instance, call ``{module.__name__}.init``,
- defined in :py:func:`cebra.registry.add_helper_functions`.
- The first parameter to provide is the {toplevel_name} name to use,
+ To obtain an initialized instance, call ``{module.__name__}.init``,
+ defined in :py:func:`cebra.registry.add_helper_functions`.
+ The first parameter to provide is the {toplevel_name} name to use,
which is one of the available options presented above.
- Then the required positional arguments specific to the module are provided, if
- needed.
+ Then the required positional arguments specific to the module are provided, if
+ needed.
You can register additional options by defining and registering
- classes with a name. To do that, you can add a decorator on top of it:
+ classes with a name. To do that, you can add a decorator on top of it:
``@{module.__name__}.register("my-{module.__name__.replace('.', '-')}")``.
Later, initialize your class similarly to the pre-defined options, using ``{module.__name__}.init``
diff --git a/cebra/solver/base.py b/cebra/solver/base.py
index d5fefb95..92103173 100644
--- a/cebra/solver/base.py
+++ b/cebra/solver/base.py
@@ -15,7 +15,7 @@
loops. When subclassing abstract solvers, in the simplest case only the
:py:meth:`Solver._inference` needs to be overridden.
-For more complex use cases, the :py:meth:`Solver.step` and
+For more complex use cases, the :py:meth:`Solver.step` and
:py:meth:`Solver.fit` method can be overridden to
implement larger changes to the training loop.
"""
diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py
index 204daac3..4dc03dde 100644
--- a/cebra/solver/single_session.py
+++ b/cebra/solver/single_session.py
@@ -87,15 +87,15 @@ def get_embedding(self, data: torch.Tensor) -> torch.Tensor:
class SingleSessionAuxVariableSolver(abc_.Solver):
"""Single session training for reference and positive/negative samples.
- This solver processes reference samples with a model different from
+ This solver processes reference samples with a model different from
processing the positive and
negative samples. Requires that the ``reference_model`` is initialized
- to be different from the ``model`` used to process the positive and
+ to be different from the ``model`` used to process the positive and
negative samples.
- Besides using an asymmetric encoder for the same modality, this solver
- also allows for e.g. time-contrastive learning across modalities, by
- using a reference model on modality A, and a different model processing
+ Besides using an asymmetric encoder for the same modality, this solver
+ also allows for e.g. time-contrastive learning across modalities, by
+ using a reference model on modality A, and a different model processing
the signal from modality B.
"""
diff --git a/conda/cebra_paper_m1.yml b/conda/cebra_paper_m1.yml
index c020f393..223126f4 100644
--- a/conda/cebra_paper_m1.yml
+++ b/conda/cebra_paper_m1.yml
@@ -10,7 +10,7 @@
## https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
##
-#NOTE: if you need to install piVAE, then you need to install tensorflow for apple. This is NOT required for cebra-only use.
+#NOTE: if you need to install piVAE, then you need to install tensorflow for apple. This is NOT required for cebra-only use.
#for M1/m2 chip, use miniconda3 (https://developer.apple.com/metal/tensorflow-plugin/)
# Get the miniconda M1/M2 bash installer, as explained in
# https://docs.conda.io/projects/conda/en/latest/user-guide/install/macos.html
diff --git a/conftest.py b/conftest.py
index 97665bf9..40aed8aa 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,7 +1,7 @@
"""Configuration options for pytest
See Also:
- * https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option
+ * https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option
"""
import pytest
@@ -9,16 +9,18 @@
def pytest_addoption(parser):
"""Define customized pytest flags.
-
+
Examples:
>>> pytest tests/test_sklearn.py --runfast
"""
- parser.addoption(
- "--runfast", action="store_true", default=False, help="don't run slow tests"
- )
- parser.addoption(
- "--runbenchmark", action="store_true", default=False, help="run benchmark test"
- )
+ parser.addoption("--runfast",
+ action="store_true",
+ default=False,
+ help="don't run slow tests")
+ parser.addoption("--runbenchmark",
+ action="store_true",
+ default=False,
+ help="run benchmark test")
def pytest_configure(config):
@@ -27,23 +29,26 @@ def pytest_configure(config):
config.addinivalue_line("markers", "fast: run tests with fast arguments")
config.addinivalue_line("markers", "benchmark: run benchmark tests")
+
def pytest_collection_modifyitems(config, items):
"""Select tests to skip based on current flag.
-
+
By default, slow arguments are used and fast ones are skipped.
If runfast flag is provided, tests are run with the arguments marked
- as fast, arguments marked as slow are be skipped.
-
+ as fast, arguments marked as slow are be skipped.
+
"""
if config.getoption("--runfast"):
# --runfast given in cli: skip slow tests
- skip_slow = pytest.mark.skip(reason="test marked as slowing a --runfast mode")
+ skip_slow = pytest.mark.skip(
+ reason="test marked as slowing a --runfast mode")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
else:
- skip_fast = pytest.mark.skip(reason="test marked as fast, run only in --runfast mode")
+ skip_fast = pytest.mark.skip(
+ reason="test marked as fast, run only in --runfast mode")
for item in items:
if "fast" in item.keywords:
item.add_marker(skip_fast)
diff --git a/docs/root/index.html b/docs/root/index.html
index 4beb467d..86015297 100644
--- a/docs/root/index.html
+++ b/docs/root/index.html
@@ -186,7 +186,7 @@
CEBRA applied to mouse primary visual cortex, collected at the Allen Institute (de Vries et al. 2020, Siegle et al. 2021). 2-photon and Neuropixels recordings are embedded with CEBRA using DINO frame features as labels.
The embedding is used to decode the video frames using a kNN decoder on the CEBRA-Behavior embedding from the test set.
-
+
@@ -244,7 +244,7 @@
You can find our official implementation of the CEBRA algorithm on GitHub:
Watch and Star the repository to
be notified of future updates and releases.
- You can also follow us on Twitter or subscribe to our
+ You can also follow us on Twitter or subscribe to our
mailing list for updates on the project.
diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css
index 4f84d46a..29760478 100644
--- a/docs/source/_static/css/custom.css
+++ b/docs/source/_static/css/custom.css
@@ -42,7 +42,7 @@ dt {
em.sig-param {
font-family: 'IBM Plex Mono', monospace;
font-style: normal;
- color: rgb(23, 23, 23);
+ color: rgb(23, 23, 23);
}
dt.sig.sig-object.py {
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 83a1554d..a71b2d15 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -9,13 +9,13 @@ CEBRA has two main APIs:
- The low-level ``torch`` API exposes models, layers, loss functions and other components. The ``torch`` API exposes all low-level functions and classes used for training CEBRA models.
For **day-to-day use of CEBRA**, it is sufficient to know the high-level ``scikit-learn`` API, which
-is currently limited to a single estimator class, :py:class:`cebra.CEBRA`. CEBRA's main
+is currently limited to a single estimator class, :py:class:`cebra.CEBRA`. CEBRA's main
functionalities are covered by this class.
For machine learning researchers, and everybody with **custom data analysis needs**, we expose
all core functions of CEBRA via our ``torch`` API. This allows more fine-grained control over
the different components of the algorithm (models used for encoders, addition of custom
-sampling mechanisms, variations of the base loss function, etc.). It also allows to use
+sampling mechanisms, variations of the base loss function, etc.). It also allows to use
these components in other contexts and research code bases.
.. toctree::
@@ -25,7 +25,7 @@ these components in other contexts and research code bases.
api/sklearn/cebra
api/sklearn/metrics
api/sklearn/decoder
-
+
.. toctree::
diff --git a/docs/source/api/pytorch/data.rst b/docs/source/api/pytorch/data.rst
index 34557670..4f35932a 100644
--- a/docs/source/api/pytorch/data.rst
+++ b/docs/source/api/pytorch/data.rst
@@ -34,7 +34,7 @@ Pre-defined Datasets
:show-inheritance:
-Single Session Dataloaders
+Single Session Dataloaders
--------------------------------
.. automodule:: cebra.data.single_session
@@ -42,7 +42,7 @@ Single Session Dataloaders
:show-inheritance:
-Multi Session Dataloaders
+Multi Session Dataloaders
--------------------------------
.. automodule:: cebra.data.multi_session
@@ -56,5 +56,3 @@ Datatypes
.. automodule:: cebra.data.datatypes
:members:
:show-inheritance:
-
-
diff --git a/docs/source/api/pytorch/distributions.rst b/docs/source/api/pytorch/distributions.rst
index 5fb39466..fc3bedab 100644
--- a/docs/source/api/pytorch/distributions.rst
+++ b/docs/source/api/pytorch/distributions.rst
@@ -47,4 +47,3 @@ Multi-session
.. automodule:: cebra.distributions.multisession
:members:
:show-inheritance:
-
diff --git a/docs/source/api/pytorch/helpers.rst b/docs/source/api/pytorch/helpers.rst
index 5b6a89e0..2900b218 100644
--- a/docs/source/api/pytorch/helpers.rst
+++ b/docs/source/api/pytorch/helpers.rst
@@ -29,7 +29,7 @@ Plots
:members:
-Grid-Search
+Grid-Search
-----------
.. automodule:: cebra.grid_search
diff --git a/docs/source/api/pytorch/models.rst b/docs/source/api/pytorch/models.rst
index 7c20012d..ee3455bc 100644
--- a/docs/source/api/pytorch/models.rst
+++ b/docs/source/api/pytorch/models.rst
@@ -14,9 +14,9 @@ Registration and initialization
.. autofunction:: get_options
-.. autofunction:: register
+.. autofunction:: register
-.. autofunction:: parametrize
+.. autofunction:: parametrize
Models
@@ -50,5 +50,5 @@ Multi-objective models
:private-members:
:show-inheritance:
-..
+..
- projector
diff --git a/docs/source/api/sklearn/decoder.rst b/docs/source/api/sklearn/decoder.rst
index aeb96151..9d61aa5c 100644
--- a/docs/source/api/sklearn/decoder.rst
+++ b/docs/source/api/sklearn/decoder.rst
@@ -4,4 +4,3 @@ Decoders
.. automodule:: cebra.integrations.sklearn.decoder
:show-inheritance:
:members:
-
diff --git a/docs/source/api/sklearn/metrics.rst b/docs/source/api/sklearn/metrics.rst
index 3a20c794..b63def74 100644
--- a/docs/source/api/sklearn/metrics.rst
+++ b/docs/source/api/sklearn/metrics.rst
@@ -4,4 +4,3 @@ Metrics
.. automodule:: cebra.integrations.sklearn.metrics
:show-inheritance:
:members:
-
diff --git a/docs/source/conf.py b/docs/source/conf.py
index e329c078..1fe66c32 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -28,6 +28,7 @@
sys.path.insert(0, os.path.abspath("."))
import datetime
+
import cebra
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
index f2be152a..9743dfff 100644
--- a/docs/source/contributing.rst
+++ b/docs/source/contributing.rst
@@ -1,8 +1,8 @@
Contribution Guide
==================
-CEBRA is an actively developed package and we welcome community development
-and involvement. We are happy to receive code extensions, bug fixes, documentation
+CEBRA is an actively developed package and we welcome community development
+and involvement. We are happy to receive code extensions, bug fixes, documentation
updates, etc, but please sign the `Contributor License Agreement (CLA) `_
and note that it was signed in your pull request.
@@ -12,7 +12,7 @@ Development setup
Development should be done inside the provided docker environment.
All essential commands are included in the project ``Makefile``.
-To start an interactive console, run:
+To start an interactive console, run:
.. code:: bash
@@ -24,7 +24,7 @@ We use ``pytest`` for running tests. The full test suite can be run with:
$ make test
-A faster version of the test suite, only running one iteration of each longer tests, can be run with:
+A faster version of the test suite, only running one iteration of each longer tests, can be run with:
.. code:: bash
@@ -46,7 +46,7 @@ Code is formatted using `Google code style .ipynb>
-**Example:**
+**Example:**
.. code:: bash
diff --git a/docs/source/figures.rst b/docs/source/figures.rst
index ffd0c425..24b1987e 100644
--- a/docs/source/figures.rst
+++ b/docs/source/figures.rst
@@ -13,19 +13,19 @@ in two categories:
on a set of worked examples, this is the place to start.
* The collection of plotting code for all paper figures. The figures are generated from cached experimental
results. For data (re-) analysis and performance comparisons of CEBRA, this is the easiest way to get started.
- * The collection of experiments for obtaining results for the figures. Experiments should ideally be run on
- a GPU cluster with SLURM pre-installed for the best user experience. Alternatively, experiments can also be
+ * The collection of experiments for obtaining results for the figures. Experiments should ideally be run on
+ a GPU cluster with SLURM pre-installed for the best user experience. Alternatively, experiments can also be
manually scheduled (our submission system produces a stack of bash files which can be executed on any machine).
We recommend this route for follow-up research, when CEBRA (or any of our baselines) should be used for
comparisons against other methods.
-List of paper figures
+List of paper figures
---------------------
We provide reference code for plotting all paper figures here.
-Note that for the paper version, panels might have been post edited, and the figures might
+Note that for the paper version, panels might have been post edited, and the figures might
differ in minor typographic details.
.. toctree::
diff --git a/docs/source/index.rst b/docs/source/index.rst
index a2764bcb..73979414 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -38,7 +38,7 @@ Please see the dedicated :doc:`Installation Guide ` for informati
Have fun! 😁
-Usage
+Usage
-----
Please head over to the :doc:`Usage ` tab to find step-by-step instructions to use CEBRA on your data. For example use cases, see the :doc:`Demos ` tab.
@@ -54,8 +54,8 @@ possibility to compute CEBRA embeddings on DeepLabCut_ outputs directly.
Licensing
---------
-© All rights reserved. ECOLE POLYTECHNIQUE FÉDÉRALE DE LAUSANNE, Switzerland, Laboratory of Prof. Mackenzie W. Mathis (UPMWMATHIS) and original authors: Steffen Schneider, Jin H Lee, Mackenzie W Mathis. 2023.
-It is made available for non-commercial research use only. It comes without any warranty or guarantee.
+© All rights reserved. ECOLE POLYTECHNIQUE FÉDÉRALE DE LAUSANNE, Switzerland, Laboratory of Prof. Mackenzie W. Mathis (UPMWMATHIS) and original authors: Steffen Schneider, Jin H Lee, Mackenzie W Mathis. 2023.
+It is made available for non-commercial research use only. It comes without any warranty or guarantee.
Please see the full license file on Github_, and if it is not suitable to your project, please email_ Mackenzie Mathis for a commercial license.
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index 448b10ea..56ee556a 100644
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -17,7 +17,7 @@ CEBRA is written in Python (3.8+) and PyTorch. CEBRA is most effective when used
Installation Guide
------------------
-We outline installation instructions for different systems.
+We outline installation instructions for different systems.
CEBRA will be installed via ``pip install cebra``.
Its dependencies can be installed using ``pip`` or ``conda`` and
@@ -38,7 +38,7 @@ Most users can only install the **minimal install**. 🚀 For more advanced user
.. tab:: Google Colab
CEBRA can also be installed and run on Google colaboratory.
- Please see the ``open in colab`` button at the top of each demo notebook for examples.
+ Please see the ``open in colab`` button at the top of each demo notebook for examples.
If you are starting with a new notebook, simply run
@@ -60,16 +60,16 @@ Most users can only install the **minimal install**. 🚀 For more advanced user
.. tab:: Supplied conda (paper reproduction)
-
- We provide a ``conda`` environment with the full requirements needed to reproduce the first CEBRA paper (although we
+
+ We provide a ``conda`` environment with the full requirements needed to reproduce the first CEBRA paper (although we
recommend using Docker). Namely, you can run CEBRA, piVAE, tSNE and UMAP within this conda env. It is *NOT* needed if you only want to use CEBRA.
-
+
* For all platforms except MacOS with M1/2 chipsets, create the full environment using ``cebra_paper.yml``, by running the following from the CEBRA repo root directory:
-
+
.. code:: bash
$ conda env create -f conda/cebra_paper.yml
-
+
* If you are a MacOS M1 or M2 user and want to reproduce the paper, use the ``cebra_paper_m1.yml`` instead. You'll need to install tensorflow. For that, use `miniconda3 `_ and follow the setup instructions for tensorflow listed in the `Apple developer docs `_. In the Terminal, run the following commands:
.. code:: bash
@@ -78,7 +78,7 @@ Most users can only install the **minimal install**. 🚀 For more advanced user
bash ~/miniconda.sh -b -p $HOME/miniconda
source ~/miniconda/bin/activate
conda init zsh
-
+
Then, you can build the full environment from the root directory:
.. code:: bash
@@ -87,7 +87,7 @@ Most users can only install the **minimal install**. 🚀 For more advanced user
.. tab:: conda
- Conda users should currently use ``pip`` for installation. The missing dependencies will be installed in the install process. A fresh conda environment can be created using
+ Conda users should currently use ``pip`` for installation. The missing dependencies will be installed in the install process. A fresh conda environment can be created using
.. code:: bash
@@ -119,12 +119,12 @@ Most users can only install the **minimal install**. 🚀 For more advanced user
.. rubric:: Install CEBRA using ``pip``
Once PyTorch is set up, the remaining dependencies can be installed via ``pip``. Select the correct feature
- set based on your usecase:
+ set based on your usecase:
* Regular usage
.. code:: bash
-
+
$ pip install cebra
* Inference and development tools only
@@ -148,18 +148,18 @@ Most users can only install the **minimal install**. 🚀 For more advanced user
.. note::
Consider using a `virtual environment`_ when installing the package via ``pip``.
-
- *(Optional)* Create the virtual environment by running
+
+ *(Optional)* Create the virtual environment by running
.. code:: bash
-
+
$ virtualenv .env && source .env/bin/activate
- We recommend that you install ``PyTorch`` before CEBRA by selecting the correct version in the `PyTorch Docs`_. Select your desired PyTorch build, operating
- system, select ``pip`` as your package manager and ``Python`` as the language. Select your compute platform (either a
+ We recommend that you install ``PyTorch`` before CEBRA by selecting the correct version in the `PyTorch Docs`_. Select your desired PyTorch build, operating
+ system, select ``pip`` as your package manager and ``Python`` as the language. Select your compute platform (either a
CUDA version or CPU only). Then, use the command to install the PyTorch package. See the ``conda`` tab for examples.
- Then you can install CEBRA, by running one of these lines, depending on your usage, in the root directory.
+ Then you can install CEBRA, by running one of these lines, depending on your usage, in the root directory.
* For **regular usage**, the PyPi package can be installed using
diff --git a/docs/source/usage.rst b/docs/source/usage.rst
index c1682d57..b2c6aa10 100644
--- a/docs/source/usage.rst
+++ b/docs/source/usage.rst
@@ -1,10 +1,10 @@
Using CEBRA
===========
-This page covers a standard CEBRA usage. We recommend checking out the :py:doc:`demos` for in-depth CEBRA usage examples as well. Here we present a quick overview on how to use CEBRA on various datasets. Note that we provide two ways to interact with the code:
+This page covers a standard CEBRA usage. We recommend checking out the :py:doc:`demos` for in-depth CEBRA usage examples as well. Here we present a quick overview on how to use CEBRA on various datasets. Note that we provide two ways to interact with the code:
-* For regular usage, we recommend leveraging the **high-level interface**, adhering to ``scikit-learn`` formatting.
-* Upon specific needs, advanced users might consider diving into the **low-level interface** that adheres to ``PyTorch`` formatting.
+* For regular usage, we recommend leveraging the **high-level interface**, adhering to ``scikit-learn`` formatting.
+* Upon specific needs, advanced users might consider diving into the **low-level interface** that adheres to ``PyTorch`` formatting.
Firstly, why use CEBRA?
@@ -32,14 +32,14 @@ Step-by-step CEBRA
------------------
For a quick start into applying CEBRA to your own datasets, we provide a `scikit-learn compatible `_ API, similar to methods such as tSNE, UMAP, etc.
-We assume you have CEBRA installed in the environment you are working in, if not go to the :doc:`installation`.
-Next, launch your conda env (e.g., ``conda activate cebra``).
+We assume you have CEBRA installed in the environment you are working in, if not go to the :doc:`installation`.
+Next, launch your conda env (e.g., ``conda activate cebra``).
Create a CEBRA workspace
^^^^^^^^^^^^^^^^^^^^^^^^
-Assuming you have your data recorded, you want to start using CEBRA on it.
+Assuming you have your data recorded, you want to start using CEBRA on it.
For instance you can create a new jupyter notebook.
For the sake of this usage guide, we create some example data:
@@ -52,7 +52,7 @@ For the sake of this usage guide, we create some example data:
X = np.random.normal(0,1,(100,3))
X_new = np.random.normal(0,1,(100,4))
np.savez("neural_data", neural = X, new_neural = X_new)
-
+
# Create a .h5 file, containing a pd.DataFrame
import pandas as pd
@@ -66,24 +66,24 @@ For the sake of this usage guide, we create some example data:
You can start by importing the CEBRA package, as well as the CEBRA model as a classical ``scikit-learn`` estimator.
.. testcode::
-
+
import cebra
from cebra import CEBRA
Data loading
^^^^^^^^^^^^
-
+
Get the data ready
""""""""""""""""""
-We acknowledge that your data can come in all formats.
+We acknowledge that your data can come in all formats.
That is why we developed a loading helper function to help you get your data ready to be used by CEBRA.
-The function :py:func:`cebra.load_data` supports various file formats to convert the data of interest to a :py:func:`numpy.array`.
+The function :py:func:`cebra.load_data` supports various file formats to convert the data of interest to a :py:func:`numpy.array`.
It handles three categories of data. Note that it will only read the data of interest and output the corresponding :py:func:`numpy.array`.
-It does not perform pre-processing so your data should be ready to be used for CEBRA.
+It does not perform pre-processing so your data should be ready to be used for CEBRA.
* Your data is a **2D array**. In that case, we handle Numpy, HDF5, PyTorch, csv, Excel, Joblib, Pickle and MAT-files. If your file only containsyour data then you can use the default :py:func:`cebra.load_data`. If your file contains more than one dataset, you will have to provide a ``key``, which corresponds to the data of interest in the file.
@@ -104,7 +104,7 @@ In the following example, ``neural_data.npz`` contains multiple :py:func:`numpy.
discrete_label = cebra.load_data(file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"]).flatten()
You can then use ``neural_data``, ``continuous_label`` or ``discrete_label`` directly as the input or index data of your CEBRA model. Note that we flattened ``discrete_label``
-in order to get a 1D :py:func:`numpy.array` as required for discrete index inputs.
+in order to get a 1D :py:func:`numpy.array` as required for discrete index inputs.
.. note::
@@ -133,7 +133,7 @@ CEBRA allows you to jointly use time-series data and (optionally) auxiliary vari
* **CEBRA-Time:** Discovery-driven: time contrastive learning. Set ``conditional='time'``. No assumption on the behaviors that are influencing neural activity. It can be used as a first step into the data analysis for instance, or as a comparison point to multiple hypothesis-driven analyses.
-To use auxiliary (behavioral) variables you can choose both continuous and discrete variables. The label information (none, discrete, continuous) determine the algorithm to use for data sampling. Using labels allows you to project future behavior onto past time-series activity, and explicitly use label-prior to shape the embedding. The conditional distribution can be chosen upon model initialization with the :py:attr:`cebra.CEBRA.conditional` parameter.
+To use auxiliary (behavioral) variables you can choose both continuous and discrete variables. The label information (none, discrete, continuous) determine the algorithm to use for data sampling. Using labels allows you to project future behavior onto past time-series activity, and explicitly use label-prior to shape the embedding. The conditional distribution can be chosen upon model initialization with the :py:attr:`cebra.CEBRA.conditional` parameter.
* **CEBRA-Behavior:** Hypothesis-driven: behavioral contrastive learning. Set ``conditional='time_delta'``. The user makes an hypothesis on the variables influencing neural activity (behavioral features such as position or head orientation, trial number, brain region, etc.). If the chosen auxiliary variables are in fact influencing the data to reduce, the resulting embedding should reflect that. Hence, it can easily be used to *compare hypotheses*. Auxiliary variables can be multiple, and both continuous and discrete. 👉 Examples on how to select them are presented in :py:doc:`demo_notebooks/Demo_primate_reaching`.
@@ -151,7 +151,7 @@ To use auxiliary (behavioral) variables you can choose both continuous and discr
:width: 500
:alt: CEBRA can be used in three modes: discovery-driven, hypothesis-driven, or in a hybrid mode, which allows for weaker priors on the latent embedding.
:align: center
-
+
*CEBRA sampling schemes: discovery-driven, hypothesis-driven, or in a hybrid mode. In the hypothesis-driven mode, the positive and negative samples are found based on the reference samples.*
👉 Examples on how to use each of the conditional distribution and how to compare them when analyzing data are presented in :doc:`demo_notebooks/Demo_hippocampus`.
@@ -160,9 +160,9 @@ To use auxiliary (behavioral) variables you can choose both continuous and discr
Model definition
^^^^^^^^^^^^^^^^
-CEBRA training is *modular*, and model fitting can serve different downstream applications and research questions. Here, we describe how you can adjust the parameters depending on your data type and the hypotheses you might have.
+CEBRA training is *modular*, and model fitting can serve different downstream applications and research questions. Here, we describe how you can adjust the parameters depending on your data type and the hypotheses you might have.
-.. _Model architecture:
+.. _Model architecture:
.. rubric:: Model architecture :py:attr:`~.CEBRA.model_architecture`
@@ -181,7 +181,7 @@ Then, you can choose the one that fits best with your needs and provide it to th
As an indication the table below presents the model architecture we used to train CEBRA on the datasets presented in our paper (Schneider, Lee, Mathis, 2022).
-.. list-table::
+.. list-table::
:widths: 25 25 20 30
:header-rows: 1
@@ -191,7 +191,7 @@ As an indication the table below presents the model architecture we used to trai
- Model architecture
* - Artificial spiking
- Synthetic
- -
+ -
- 'offset1-model-mse'
* - Rat hippocampus
- Electrophysiology
@@ -209,11 +209,11 @@ As an indication the table below presents the model architecture we used to trai
- Neuropixels
- Visual cortex
- 'offset40-model-4x-subsample'
-
+
.. dropdown:: 🚀 Optional: design your own model architectures
:color: light
-
+
It is possible to construct a personalized model and use the ``@cebra.models.register`` decorator on it. For example:
.. testcode::
@@ -238,7 +238,7 @@ As an indication the table below presents the model architecture we used to trai
normalize=normalize,
)
- # ... and you can also redefine the forward method,
+ # ... and you can also redefine the forward method,
# as you would for a typical pytorch model
def get_offset(self):
@@ -258,16 +258,16 @@ For standard usage we recommend the default values (i.e., ``InfoNCE`` and ``cosi
.. rubric:: Conditional distribution :py:attr:`~.CEBRA.conditional`
-👉 See the :ref:`previous section ` on how to choose the auxiliary variables and a conditional distribution.
+👉 See the :ref:`previous section ` on how to choose the auxiliary variables and a conditional distribution.
.. note::
- If the auxiliary variables types do not match with :py:attr:`~.CEBRA.conditional`, the model training will fall back to time contrastive learning.
+ If the auxiliary variables types do not match with :py:attr:`~.CEBRA.conditional`, the model training will fall back to time contrastive learning.
.. rubric:: Temperature :py:attr:`~.CEBRA.temperature`
-:py:attr:`~.CEBRA.temperature` has the largest effect on visualization of the embedding (see :py:doc:`cebra-figures/figures/ExtendedDataFigure2`). Hence, it is important that it is fitted to your specific data.
+:py:attr:`~.CEBRA.temperature` has the largest effect on visualization of the embedding (see :py:doc:`cebra-figures/figures/ExtendedDataFigure2`). Hence, it is important that it is fitted to your specific data.
-The simplest way to handle it is to use a *learnable temperature*. For that, set :py:attr:`~.CEBRA.temperature_mode` to ``auto``. :py:attr:`~.CEBRA.temperature` will be trained alongside the model.
+The simplest way to handle it is to use a *learnable temperature*. For that, set :py:attr:`~.CEBRA.temperature_mode` to ``auto``. :py:attr:`~.CEBRA.temperature` will be trained alongside the model.
🚀 For advance usage, you might need to find the optimal :py:attr:`~.CEBRA.temperature`. For that we recommend to perform a grid-search.
@@ -275,9 +275,9 @@ The simplest way to handle it is to use a *learnable temperature*. For that, set
.. rubric:: Time offsets :math:`\Delta` :py:attr:`~.CEBRA.time_offsets`
-This corresponds to the distance (in time) between positive pairs and informs the algorithm about the time-scale of interest.
+This corresponds to the distance (in time) between positive pairs and informs the algorithm about the time-scale of interest.
-The interpretation of this parameter depends on the chosen conditional distribution. A higher time offset typically will increase the difficulty of the learning task, and (within a range) improve the quality of the representation.
+The interpretation of this parameter depends on the chosen conditional distribution. A higher time offset typically will increase the difficulty of the learning task, and (within a range) improve the quality of the representation.
For time-contrastive learning, we generally recommend that the time offset should be larger than the specified receptive field of the model.
.. rubric:: Number of iterations :py:attr:`~.CEBRA.max_iterations`
@@ -289,7 +289,7 @@ We recommend to use at least 10,000 iterations to train the model. For prototypi
.. rubric:: Number of adaptation iterations :py:attr:`~.CEBRA.max_adapt_iterations`
-One feature of CEBRA is you can apply (adapt) your model to new data. If you are planning to adapt your trained model to a new set of data, we recommend to use around 500 steps to re-tuned the first layer of the model.
+One feature of CEBRA is you can apply (adapt) your model to new data. If you are planning to adapt your trained model to a new set of data, we recommend to use around 500 steps to re-tuned the first layer of the model.
In the paper, we show that fine-tuning the input embedding (first layer) on the novel data while using a pretrained model can be done with 500 steps in 3.5s only, and has better performance overall.
@@ -298,7 +298,7 @@ In the paper, we show that fine-tuning the input embedding (first layer) on the
CEBRA should be trained on the biggest batch size possible. Ideally, and depending on the size of your dataset, you should set :py:attr:`~.CEBRA.batch_size` to ``None`` (default value) which will train the model drawing samples from the full dataset at each iteration. As an indication, all the models used in the paper were trained with ``batch_size=512``. You should avoid having to set your batch size to a smaller value.
.. warning::
- Using the full dataset (``batch_size=None``) is only implemented for single-session training with continuous auxiliary variables.
+ Using the full dataset (``batch_size=None``) is only implemented for single-session training with continuous auxiliary variables.
Here is an example of a CEBRA model initialization:
@@ -343,7 +343,7 @@ Single-session versus multi-session training
We provide both single-sesison and multi-session training. The latest makes the resulting embeddings **invariant to the auxiliary variables** across all sessions.
.. note::
- For flexibility reasons, the multi-session training fits one model for each session and thus sessions don't necessarily have the same number of features (e.g., number of neurons).
+ For flexibility reasons, the multi-session training fits one model for each session and thus sessions don't necessarily have the same number of features (e.g., number of neurons).
Check out the following list to verify if the multi-session implementation is the right tool for your needs.
@@ -359,11 +359,11 @@ Check out the following list to verify if the multi-session implementation is th
|uncheck| I want to be able to use CEBRA for a new session that is fully unseen during training.
-.. warning::
+.. warning::
Using multi-session training limits the **influence of individual variations per session** on the embedding. Make sure that this session/animal-specific information won't be needed in your downstream analysis.
-👉 Have a look at :py:doc:`demo_notebooks/Demo_hippocampus_multisession` for more in-depth usage examples of the multi-session training.
+👉 Have a look at :py:doc:`demo_notebooks/Demo_hippocampus_multisession` for more in-depth usage examples of the multi-session training.
Training
""""""""
@@ -383,14 +383,14 @@ CEBRA is trained using :py:meth:`cebra.CEBRA.fit`, similarly to the examples bel
discrete_label = np.random.randint(0,10,(timesteps,))
single_cebra_model = cebra.CEBRA(batch_size=512,
- output_dimension=out_dim,
- max_iterations=10,
+ output_dimension=out_dim,
+ max_iterations=10,
max_adapt_iterations=10)
Note that the ``discrete_label`` array needs to be one dimensional, and needs to be of type :py:class:`int`.
-We can now fit the model in different modes.
+We can now fit the model in different modes.
* For **CEBRA-Time (time-contrastive training)** with the chosen ``time_offsets``, run:
@@ -398,13 +398,13 @@ We can now fit the model in different modes.
single_cebra_model.fit(neural_data)
-* For **CEBRA-Behavior (supervised constrastive learning)** using **discrete labels**, run:
+* For **CEBRA-Behavior (supervised constrastive learning)** using **discrete labels**, run:
.. testcode::
single_cebra_model.fit(neural_data, discrete_label)
-* For **CEBRA-Behavior (supervised constrastive learning)** using **continuous labels**, run:
+* For **CEBRA-Behavior (supervised constrastive learning)** using **continuous labels**, run:
.. testcode::
@@ -439,8 +439,8 @@ For multi-sesson training, lists of data are provided instead of a single datase
continuous_label2 = np.random.uniform(0,1,(timesteps2, 3))
multi_cebra_model = cebra.CEBRA(batch_size=512,
- output_dimension=out_dim,
- max_iterations=10,
+ output_dimension=out_dim,
+ max_iterations=10,
max_adapt_iterations=10)
Once you defined your CEBRA model, you can run:
@@ -452,23 +452,23 @@ Once you defined your CEBRA model, you can run:
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.CEBRA.fit
:noindex:
.. rubric:: Partial training
Consistently with the ``scikit-learn`` API, :py:meth:`cebra.CEBRA.partial_fit` can be used to perform incremental learning of your model on multiple data batches.
-That means by using :py:meth:`cebra.CEBRA.partial_fit`, you can fit your model on a set of data a first time and the model training will take on from the resulting
+That means by using :py:meth:`cebra.CEBRA.partial_fit`, you can fit your model on a set of data a first time and the model training will take on from the resulting
parameters to train at the next call of :py:meth:`cebra.CEBRA.partial_fit`, either on a new batch of data with the same number of features or on the same dataset.
-It can be used for both single-session or multi-session training, similarly to :py:meth:`cebra.CEBRA.fit`.
+It can be used for both single-session or multi-session training, similarly to :py:meth:`cebra.CEBRA.fit`.
.. testcode::
cebra_model = cebra.CEBRA(max_iterations=10)
# The model is fitted a first time ...
- cebra_model.partial_fit(neural_data)
+ cebra_model.partial_fit(neural_data)
# ... later on the model can be fitted again
cebra_model.partial_fit(neural_data)
@@ -480,7 +480,7 @@ It can be used for both single-session or multi-session training, similarly to :
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.CEBRA.partial_fit
:noindex:
@@ -507,10 +507,10 @@ The model will be saved as a ``.pt`` file.
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.CEBRA.save
:noindex:
-
+
.. autofunction:: cebra.CEBRA.load
:noindex:
@@ -521,18 +521,18 @@ Grid search
"""""""""""
.. tip::
-
+
A **grid-search** is the process of performing hyperparameter tuning in order to determine the optimal values of a given model. Practically, it consists in running a model on the data, by modifying the hyperparameters values at each iteration. Then, evaluating the performances of each model allows the user to select the best set of hyperparameters for its specific data.
In order to optimize a CEBRA model to the data, we recommend fine-tuning the parameters. For that, you can perform a grid-search over the hyperparameters you want to optimize.
-We provide a simple hyperparameters sweep to compare CEBRA models with different parameters over different datasets or combinations of data and auxiliary variables.
+We provide a simple hyperparameters sweep to compare CEBRA models with different parameters over different datasets or combinations of data and auxiliary variables.
-.. testcode::
+.. testcode::
import cebra
-
+
# 1. Define the parameters, either variable or fixed
params_grid = dict(
output_dimension = [3, 16],
@@ -541,7 +541,7 @@ We provide a simple hyperparameters sweep to compare CEBRA models with different
max_iterations = 5,
temperature_mode = "auto",
verbose = False)
-
+
# 2. Define the datasets to iterate over
datasets = {"dataset1": neural_session1, # time contrastive learning
"dataset2": (neural_session1, continuous_label1), # behavioral contrastive learning
@@ -581,7 +581,7 @@ Once the model is trained, embeddings can be computed using :py:meth:`cebra.CEBR
.. rubric:: Single-session training
-For a model trained on a single session, you just have to provide the input data on which to compte the embedding.
+For a model trained on a single session, you just have to provide the input data on which to compte the embedding.
.. testcode::
@@ -590,7 +590,7 @@ For a model trained on a single session, you just have to provide the input data
.. rubric:: Multi-session training
-For a model trained on multiple sessions, you will need to provide the ``session_id`` (between ``0`` and ``num_sessions-1``), to select the model corresponding to the accurate number of features.
+For a model trained on multiple sessions, you will need to provide the ``session_id`` (between ``0`` and ``num_sessions-1``), to select the model corresponding to the accurate number of features.
.. testcode::
@@ -602,17 +602,17 @@ In both case, the embedding will be of size ``time x`` :py:attr:`~.CEBRA.output_
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.CEBRA.transform
:noindex:
-Results visualization
+Results visualization
^^^^^^^^^^^^^^^^^^^^^
Here, we want to emphasize that if CEBRA is providing a low-dimensional representation of your data, i.e., the embedding, there are also plenty of elements that should be checked to assess the results. We provide a post-hoc package to easily visualize the crucial information.
-The visualization functions all have the same structure such that they are merely wrappers around :py:func:`matplotlib.pyplot.plot` and :py:func:`matplotlib.pyplot.scatter`. Consequently, you can provide the functions parameters to be used by those ``matplotlib.pyplot`` functions.
+The visualization functions all have the same structure such that they are merely wrappers around :py:func:`matplotlib.pyplot.plot` and :py:func:`matplotlib.pyplot.scatter`. Consequently, you can provide the functions parameters to be used by those ``matplotlib.pyplot`` functions.
*Note that all examples were computed on the rat hippocampus dataset (Grosmark & Buzsáki, 2016) with default parameters,* ``max_iterations=15000`` *,* ``batch_size=512`` *,* ``model_architecture=offset10-model`` *,* ``output_dimension=3`` *except if stated otherwise.*
@@ -623,10 +623,10 @@ Displaying the embedding
To get a 3D visualization of an embedding ``embedding``, obtained using :py:meth:`cebra.CEBRA.transform` (see above), you can use :py:func:`~.plot_embedding`.
-It takes a 2D matrix representing an embedding and returns a 3D scatter plot by taking the 3 first latents by default.
+It takes a 2D matrix representing an embedding and returns a 3D scatter plot by taking the 3 first latents by default.
.. note::
- If your embedding only has 2 dimensions, then the plot will automatically switch to a 2D mode. You can then use the function
+ If your embedding only has 2 dimensions, then the plot will automatically switch to a 2D mode. You can then use the function
similarly.
@@ -659,7 +659,7 @@ It takes a 2D matrix representing an embedding and returns a 3D scatter plot by
:width: 500
:alt: Matplotlib list of named colors
:align: center
-
+
.. testcode::
@@ -672,8 +672,8 @@ It takes a 2D matrix representing an embedding and returns a 3D scatter plot by
:align: center
- * By setting ``embedding_labels`` to ``time``. It will use the color map ``cmap`` to display the embedding based on temporality. By default, ``cmap=cool``. You can customize it by setting it to a valid :py:class:`matplotlib.colors.Colormap` (see `Choosing Colormaps in Matplotlib `_ for more information). You can also use our CEBRA-custom colormap by setting ``cmap="cebra"``.
-
+ * By setting ``embedding_labels`` to ``time``. It will use the color map ``cmap`` to display the embedding based on temporality. By default, ``cmap=cool``. You can customize it by setting it to a valid :py:class:`matplotlib.colors.Colormap` (see `Choosing Colormaps in Matplotlib `_ for more information). You can also use our CEBRA-custom colormap by setting ``cmap="cebra"``.
+
.. figure:: docs-imgs/cebra-colormap.png
:width: 1000
:alt: darkorchid embedding
@@ -681,10 +681,10 @@ It takes a 2D matrix representing an embedding and returns a 3D scatter plot by
*CEBRA-custom colormap. You can use it by calling* ``cmap="cebra"`` *.*
-
+
In the following example, you can also see how to change the size (``markersize``) or the transparency (``alpha``) of the markers.
- .. testcode::
+ .. testcode::
cebra.plot_embedding(embedding, embedding_labels="time", cmap="magma", markersize=5, alpha=0.5)
@@ -712,13 +712,13 @@ It takes a 2D matrix representing an embedding and returns a 3D scatter plot by
``embedding_labels`` must be uni-dimensional. Be sure to provide only one dimension of your auxiliary variables if you are using multi-dimensional continuous data for instance (e.g., only the x-coordinate of the position).
- You can specify the **latents to display** by setting ``idx_order=(latent_num_1, latent_num_2, latent_num_3)`` with ``latent_num_*`` the latent indices of your choice.
+ You can specify the **latents to display** by setting ``idx_order=(latent_num_1, latent_num_2, latent_num_3)`` with ``latent_num_*`` the latent indices of your choice.
In the following example we trained a model with ``output_dimension==10`` and we show embeddings when displaying latents (1, 2, 3) on the left and (4, 5, 6) on the right respectively. The code snippet also offers an example on how to combine multiple graphs and how to set a customized title (``title``). Note the parameter ``projection="3d"`` when adding a subplot to the figure.
-
+
.. testcode::
import matplotlib.pyplot as plt
- fig = plt.figure(figsize=(10,5))
+ fig = plt.figure(figsize=(10,5))
ax1 = fig.add_subplot(121, projection="3d")
ax2 = fig.add_subplot(122, projection="3d")
@@ -732,7 +732,7 @@ It takes a 2D matrix representing an embedding and returns a 3D scatter plot by
:align: center
If your embedding only has 2 dimensions or if you only want to display 2 dimensions from it, you can use the same function. The plot will automatically switch to 2D. Then you can use the function as usual.
-
+
The plot will be 2D if:
* If your embedding only has 2 dimensions and you don't specify the ``idx_order`` (then the default will be ``idx_order=(0,1)``)
@@ -753,11 +753,11 @@ It takes a 2D matrix representing an embedding and returns a 3D scatter plot by
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.plot_embedding
:noindex:
-.. _Visualize the training loss:
+.. _Visualize the training loss:
Displaying the training loss
""""""""""""""""""""""""""""
@@ -766,7 +766,7 @@ Observing the training loss is of great importance. It allows you to assess that
To visualize the loss evolution through training, you can use :py:func:`~.plot_loss`.
-It takes a CEBRA model and returns a 2D plot of the loss against the number of iterations. It can be used with default values as simply as this:
+It takes a CEBRA model and returns a 2D plot of the loss against the number of iterations. It can be used with default values as simply as this:
.. testcode::
@@ -784,7 +784,7 @@ It takes a CEBRA model and returns a 2D plot of the loss against the number of i
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.plot_loss
:noindex:
@@ -793,9 +793,9 @@ Displaying the temperature
:py:attr:`~.CEBRA.temperature` has the largest effect on the visualization of the embedding. Hence it might be interesting to check its evolution when ``temperature_mode=auto``.
-To that extend, you can use the function :py:func:`~.plot_temperature`.
+To that extend, you can use the function :py:func:`~.plot_temperature`.
-It takes a CEBRA model and returns a 2D plot of the value of :py:attr:`~.CEBRA.temperature` against the number of iterations. It can be used with default values as simply as this:
+It takes a CEBRA model and returns a 2D plot of the value of :py:attr:`~.CEBRA.temperature` against the number of iterations. It can be used with default values as simply as this:
.. testcode::
@@ -812,20 +812,20 @@ It takes a CEBRA model and returns a 2D plot of the value of :py:attr:`~.CEBRA.t
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.plot_temperature
:noindex:
-Comparing models
+Comparing models
""""""""""""""""
In order to select the most performant model, you might need to plot the training loss for a set of models on the same figure.
-First, we create a list of fitted models to compare. Here we suppose we have a dataset with neural data available, as well as the position and the direction of the animal. We will show differences in performance when training with any combination of these variables.
+First, we create a list of fitted models to compare. Here we suppose we have a dataset with neural data available, as well as the position and the direction of the animal. We will show differences in performance when training with any combination of these variables.
-.. testcode::
+.. testcode::
cebra_posdir_model = CEBRA(model_architecture='offset10-model',
batch_size=512,
@@ -849,9 +849,9 @@ First, we create a list of fitted models to compare. Here we suppose we have a d
cebra_dir_model.fit(neural_data, discrete_label)
-Then, you can compare their losses. To do that you can use :py:func:`~.compare_models`.
-It takes a list of CEBRA models and returns a 2D plot displaying their training losses.
-It can be used with default values as simply as this:
+Then, you can compare their losses. To do that you can use :py:func:`~.compare_models`.
+It takes a list of CEBRA models and returns a 2D plot displaying their training losses.
+It can be used with default values as simply as this:
.. testcode::
@@ -867,20 +867,20 @@ It can be used with default values as simply as this:
:alt: Default comparison
:align: center
-🚀 The function is a wrapper around :py:func:`matplotlib.pyplot.plot` and consequently accepts all the parameters of that function (e.g., ``alpha``, ``linewidth``, ``title``, ``color``, etc.) as parameters. Note that
-however, if you want to differentiate the traces with a set of colors, you need to provide a `colormap `_ to the ``cmap`` parameter. If you want a unique
-color for all traces, you can provide a `valid color `_ to the ``color`` parameter that will override the ``cmap`` parameter. By default, ``color=None`` and
-``cmap="cebra"`` our very special CEBRA-custom color map.
+🚀 The function is a wrapper around :py:func:`matplotlib.pyplot.plot` and consequently accepts all the parameters of that function (e.g., ``alpha``, ``linewidth``, ``title``, ``color``, etc.) as parameters. Note that
+however, if you want to differentiate the traces with a set of colors, you need to provide a `colormap `_ to the ``cmap`` parameter. If you want a unique
+color for all traces, you can provide a `valid color `_ to the ``color`` parameter that will override the ``cmap`` parameter. By default, ``color=None`` and
+``cmap="cebra"`` our very special CEBRA-custom color map.
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.compare_models
:noindex:
-What else do to with your CEBRA model
+What else do to with your CEBRA model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
As mentioned at the start of the guide, CEBRA is much more than a visualization tool. Here we present a (non-exhaustive) list of post-hoc analysis and investigations that we support with CEBRA. Happy hacking! 👩💻
@@ -888,18 +888,18 @@ As mentioned at the start of the guide, CEBRA is much more than a visualization
Consistency across features
"""""""""""""""""""""""""""
-One of the major strengths of CEBRA is measuring consistency across embeddings. We demonstrate in Schneider, Lee, Mathis 2023, that consistent latents can be derived across animals (i.e., across CA1 recordings in rats), and even across recording modalities (i.e., from calcium imaging to electrophysiology recordings).
+One of the major strengths of CEBRA is measuring consistency across embeddings. We demonstrate in Schneider, Lee, Mathis 2023, that consistent latents can be derived across animals (i.e., across CA1 recordings in rats), and even across recording modalities (i.e., from calcium imaging to electrophysiology recordings).
Thus, we provide the :py:func:`~.consistency_score` metrics to compute consistency across model runs or models computed on different datasets (i.e., subjects, sessions).
-To use it, you have to set the ``between`` parameter to either ``datasets`` or ``runs``. The main difference between the two modes is that for between-datasets comparisons you will provide
-labels to align the embeddings on. When using between-runs comparison, it supposes that the embeddings are already aligned. The simplest example being the model was run on the same dataset
+To use it, you have to set the ``between`` parameter to either ``datasets`` or ``runs``. The main difference between the two modes is that for between-datasets comparisons you will provide
+labels to align the embeddings on. When using between-runs comparison, it supposes that the embeddings are already aligned. The simplest example being the model was run on the same dataset
but it can also be for datasets that were recorded at the same time for example, i.e., neural activity in different brain regions, recorded during the same session.
.. note::
As consistency between CEBRA runs on the same dataset is demonstrated in Schneider, Lee, Mathis 2023 (consistent up to linear transformations), assessing consistency between different runs on the same dataset is a good way to reinsure you that you set your CEBRA model properly.
-We first create the embeddings to compare: we use two different datasets of data and fit a CEBRA model three times on each.
+We first create the embeddings to compare: we use two different datasets of data and fit a CEBRA model three times on each.
.. testcode::
@@ -910,7 +910,7 @@ We first create the embeddings to compare: we use two different datasets of data
output_dimension=32,
max_iterations=5,
time_offsets=10)
-
+
embeddings, dataset_ids, labels = [], [], []
for i in range(n_runs):
embeddings.append(cebra_model.fit_transform(neural_session1, continuous_label1))
@@ -928,17 +928,17 @@ To get the :py:func:`~.consistency_score` on the set of embeddings that we just
.. testcode::
# Between-runs, with dataset IDs (optional)
- scores_runs, pairs_runs, datasets_runs = cebra.sklearn.metrics.consistency_score(embeddings=embeddings,
- dataset_ids=dataset_ids,
+ scores_runs, pairs_runs, datasets_runs = cebra.sklearn.metrics.consistency_score(embeddings=embeddings,
+ dataset_ids=dataset_ids,
between="runs")
assert scores_runs.shape == (n_runs**2 - n_runs, )
assert pairs_runs.shape == (n_datasets, n_runs*n_datasets, 2)
assert datasets_runs.shape == (n_datasets, )
# Between-datasets, by aligning on the labels
- (scores_datasets,
- pairs_datasets,
- datasets_datasets) = cebra.sklearn.metrics.consistency_score(embeddings=embeddings,
+ (scores_datasets,
+ pairs_datasets,
+ datasets_datasets) = cebra.sklearn.metrics.consistency_score(embeddings=embeddings,
labels=labels,
dataset_ids=dataset_ids,
between="datasets")
@@ -948,7 +948,7 @@ To get the :py:func:`~.consistency_score` on the set of embeddings that we just
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.sklearn.metrics.consistency_score
:noindex:
@@ -963,24 +963,24 @@ You can then display the resulting scores using :py:func:`~.plot_consistency`.
ax1 = cebra.plot_consistency(scores_runs, pairs_runs, datasets_runs, vmin=0, vmax=100, ax=ax1, title="Between-runs consistencies")
ax2 = cebra.plot_consistency(scores_datasets, pairs_datasets, datasets_datasets, vmin=0, vmax=100, ax=ax2, title="Between-subjects consistencies")
-
+
.. figure:: docs-imgs/consistency-score.png
:width: 700
:alt: Consistency scores
:align: center
-🚀 This function is a wrapper around :py:func:`matplotlib.pyplot.imshow` and, similarly to the other plot functions we provide,
+🚀 This function is a wrapper around :py:func:`matplotlib.pyplot.imshow` and, similarly to the other plot functions we provide,
it accepts all the parameters of that function (e.g., cmap, vmax, vmin, etc.) as parameters. Check the full API for more details.
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.plot_consistency
:noindex:
-Embeddings comparison via the InfoNCE loss
+Embeddings comparison via the InfoNCE loss
""""""""""""""""""""""""""""""""""""""""""
.. rubric:: Usage case 👩🔬
@@ -989,12 +989,12 @@ You can also compare how a new dataset compares to prior models. This can be use
.. rubric:: How to use it
-The performances of a given model on a dataset can be evaluated by using the :py:func:`~.infonce_loss` function. That metric
+The performances of a given model on a dataset can be evaluated by using the :py:func:`~.infonce_loss` function. That metric
corresponds to the loss over the data, obtained using the criterion on which the model was trained (by default, ``infonce``). Hence, the
-smaller that metric is, the higher the model performances on a sample are, and so the better the fit to the positive samples is.
+smaller that metric is, the higher the model performances on a sample are, and so the better the fit to the positive samples is.
.. note::
- As an indication, you can consider that a good trained CEBRA model should get a value for the InfoNCE loss smaller than **~6.1**. If that is not the case,
+ As an indication, you can consider that a good trained CEBRA model should get a value for the InfoNCE loss smaller than **~6.1**. If that is not the case,
you might want to refer to the dedicated section `Improve your model`_.
Here are examples on how you can use :py:func:`~.infonce_loss` on your data for both single-session and multi-session trained models.
@@ -1002,23 +1002,23 @@ Here are examples on how you can use :py:func:`~.infonce_loss` on your data for
.. testcode::
# single-session
- single_score = cebra.sklearn.metrics.infonce_loss(single_cebra_model,
- neural_data,
- continuous_label,
+ single_score = cebra.sklearn.metrics.infonce_loss(single_cebra_model,
+ neural_data,
+ continuous_label,
discrete_label,
num_batches=5)
-
+
# multi-session
- multi_score = cebra.sklearn.metrics.infonce_loss(multi_cebra_model,
- neural_session1,
- continuous_label1,
+ multi_score = cebra.sklearn.metrics.infonce_loss(multi_cebra_model,
+ neural_session1,
+ continuous_label1,
session_id=0,
num_batches=5)
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.sklearn.metrics.infonce_loss
:noindex:
@@ -1026,10 +1026,10 @@ Here are examples on how you can use :py:func:`~.infonce_loss` on your data for
Adapt the model to new data
"""""""""""""""""""""""""""
-In some cases, it can be useful to adapt your CEBRA model to a novel dataset, with a different number of features.
+In some cases, it can be useful to adapt your CEBRA model to a novel dataset, with a different number of features.
For that, you can set ``adapt=True`` as a parameter of :py:meth:`cebra.CEBRA.fit`. It will reset the first layer of
-the model so that the input dimension corresponds to the new features dimensions and retrain it for
-:py:attr:`cebra.CEBRA.max_adapt_iterations`. You can set that parameter :py:attr:`cebra.CEBRA.max_adapt_iterations`
+the model so that the input dimension corresponds to the new features dimensions and retrain it for
+:py:attr:`cebra.CEBRA.max_adapt_iterations`. You can set that parameter :py:attr:`cebra.CEBRA.max_adapt_iterations`
when initializing your :py:class:`cebra.CEBRA` model.
.. note::
@@ -1041,21 +1041,21 @@ when initializing your :py:class:`cebra.CEBRA` model.
single_cebra_model.fit(neural_session1)
# ... do something with it (embedding, visualization, saving) ...
-
+
# ... and adapt the model
cebra_model.fit(neural_session2, adapt=True)
.. note::
- We recommend that you save your model, using :py:meth:`cebra.CEBRA.save`, before adapting it to a different dataset.
- The adapted model will replace the previous model in ``cebra_model.state_dict_`` so saving it beforehand allows you
- to keep the trained parameters for later. You can then load the model again, using :py:meth:`cebra.CEBRA.load` whenever
+ We recommend that you save your model, using :py:meth:`cebra.CEBRA.save`, before adapting it to a different dataset.
+ The adapted model will replace the previous model in ``cebra_model.state_dict_`` so saving it beforehand allows you
+ to keep the trained parameters for later. You can then load the model again, using :py:meth:`cebra.CEBRA.load` whenever
you need it.
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.CEBRA.fit
:noindex:
@@ -1071,20 +1071,20 @@ two decoders: :py:class:`~.KNNDecoder` and :py:class:`~.L1LinearRegressor`. Here
from sklearn.model_selection import train_test_split
- # 1. Train a CEBRA-Time model on the whole dataset
+ # 1. Train a CEBRA-Time model on the whole dataset
cebra_model = cebra.CEBRA(max_iterations=10)
cebra_model.fit(neural_data)
embedding = cebra_model.transform(neural_data)
# 2. Split the embedding and label to decode into train/validation sets
(
- train_embedding,
- valid_embedding,
- train_discrete_label,
+ train_embedding,
+ valid_embedding,
+ train_discrete_label,
valid_discrete_label,
- ) = train_test_split(embedding,
- discrete_label,
- test_size=0.3)
+ ) = train_test_split(embedding,
+ discrete_label,
+ test_size=0.3)
# 3. Train the decoder on the training set
decoder = cebra.KNNDecoder()
@@ -1093,16 +1093,16 @@ two decoders: :py:class:`~.KNNDecoder` and :py:class:`~.L1LinearRegressor`. Here
# 4. Get the score on the validation set
score = decoder.score(valid_embedding, valid_discrete_label)
- # 5. Get the discrete labels predictions
+ # 5. Get the discrete labels predictions
prediction = decoder.predict(valid_embedding)
``prediction`` contains the predictions of the decoder on the discrete labels.
.. warning::
- Be careful to avoid `double dipping `_ when using the decoder. The previous example uses time contrastive learning.
- If you are using CEBRA-Behavior or CEBRA-Hybrid and you consequently use labels, you will have to split your
+ Be careful to avoid `double dipping `_ when using the decoder. The previous example uses time contrastive learning.
+ If you are using CEBRA-Behavior or CEBRA-Hybrid and you consequently use labels, you will have to split your
original data from start as you don't want decode labels from an embedding that is itself trained on those labels.
-
+
.. dropdown:: 👉 Decoder example with CEBRA-Behavior
:color: light
@@ -1113,26 +1113,26 @@ two decoders: :py:class:`~.KNNDecoder` and :py:class:`~.L1LinearRegressor`. Here
# 1. Split your neural data and auxiliary variable
(
- train_data,
- valid_data,
- train_discrete_label,
+ train_data,
+ valid_data,
+ train_discrete_label,
valid_discrete_label,
- ) = train_test_split(neural_data,
- discrete_label,
+ ) = train_test_split(neural_data,
+ discrete_label,
test_size=0.2)
# 2. Train a CEBRA-Behavior model on training data only
cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
cebra_model.fit(train_data, train_discrete_label)
-
+
# 3. Get embedding for training and validation data
train_embedding = cebra_model.transform(train_data)
valid_embedding = cebra_model.transform(valid_data)
-
+
# 4. Train the decoder on training embedding and labels
decoder = cebra.KNNDecoder()
decoder.fit(train_embedding, train_discrete_label)
-
+
# 5. Compute the score on validation embedding and labels
score = decoder.score(valid_embedding, valid_discrete_label)
@@ -1140,10 +1140,10 @@ two decoders: :py:class:`~.KNNDecoder` and :py:class:`~.L1LinearRegressor`. Here
.. admonition:: See API docs
:class: dropdown
-
+
.. autofunction:: cebra.KNNDecoder.fit
:noindex:
-
+
.. autofunction:: cebra.KNNDecoder.score
:noindex:
@@ -1153,7 +1153,7 @@ two decoders: :py:class:`~.KNNDecoder` and :py:class:`~.L1LinearRegressor`. Here
Improve model performance
^^^^^^^^^^^^^^^^^^^^^^^^^
-🧐 Below is a (non-exhaustive) list of actions you can try if your embedding looks different from what you were expecting.
+🧐 Below is a (non-exhaustive) list of actions you can try if your embedding looks different from what you were expecting.
#. Assess that your model `converged `_. For that, observe if the training loss stabilizes itself around the end of the training or still seems to be decreasing. Refer to `Visualize the training loss`_ for more details on how to display the training loss.
#. Increase the number of iterations. It should be at least 10,000.
@@ -1165,14 +1165,14 @@ Improve model performance
Quick Start: Scikit-learn API example
-------------------------------------
-Putting all previous snippet examples together, we obtain the following pipeline.
+Putting all previous snippet examples together, we obtain the following pipeline.
.. testcode::
import cebra
from numpy.random import uniform, randint
from sklearn.model_selection import train_test_split
-
+
# 1. Define a CEBRA model
cebra_model = cebra.CEBRA(
model_architecture = "offset10-model",
@@ -1190,7 +1190,7 @@ Putting all previous snippet examples together, we obtain the following pipeline
new_neural_data = cebra.load_data(file="neural_data.npz", key="new_neural")
continuous_label = cebra.load_data(file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["continuous1", "continuous2", "continuous3"])
discrete_label = cebra.load_data(file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"]).flatten()
-
+
assert neural_data.shape == (100, 3)
assert new_neural_data.shape == (100, 4)
assert discrete_label.shape == (100, )
@@ -1208,7 +1208,7 @@ Putting all previous snippet examples together, we obtain the following pipeline
discrete_label,
continuous_label,
test_size=0.3)
-
+
# 4. Fit the model
# time contrastive learning
cebra_model.fit(train_data)
@@ -1218,27 +1218,27 @@ Putting all previous snippet examples together, we obtain the following pipeline
cebra_model.fit(train_data, train_continuous_label)
# mixed behavior contrastive learning
cebra_model.fit(train_data, train_discrete_label, train_continuous_label)
-
+
# 5. Save the model
cebra_model.save('/tmp/foo.pt')
-
+
# 6. Load the model and compute an embedding
cebra_model = cebra.CEBRA.load('/tmp/foo.pt')
train_embedding = cebra_model.transform(train_data)
valid_embedding = cebra_model.transform(valid_data)
assert train_embedding.shape == (70, 8)
assert valid_embedding.shape == (30, 8)
-
+
# 7. Evaluate the model performances
- goodness_of_fit = cebra.sklearn.metrics.infonce_loss(cebra_model,
- valid_data,
- valid_discrete_label,
- valid_continuous_label,
+ goodness_of_fit = cebra.sklearn.metrics.infonce_loss(cebra_model,
+ valid_data,
+ valid_discrete_label,
+ valid_continuous_label,
num_batches=5)
-
+
# 8. Adapt the model to a new session
cebra_model.fit(new_neural_data, adapt = True)
-
+
# 9. Decode discrete labels behavior from the embedding
decoder = cebra.KNNDecoder()
decoder.fit(train_embedding, train_discrete_label)
@@ -1251,9 +1251,9 @@ Putting all previous snippet examples together, we obtain the following pipeline
Quick Start: Torch API example
------------------------------
-🚀 You have special custom data analysis needs? We invite you to use the ``torch``-API interface.
+🚀 You have special custom data analysis needs? We invite you to use the ``torch``-API interface.
-Refer to the ``examples/`` folder for a set of demo scripts.
+Refer to the ``examples/`` folder for a set of demo scripts.
Single- and multi-session training can be launched using the following ``bash`` command.
.. code:: bash
@@ -1315,4 +1315,3 @@ Below is the documentation on the available arguments.
--train-ratio 0.8 Ratio of train dataset. The remaining will be used for valid and test split.
--valid-ratio 0.1 Ratio of validation set after the train data split. The remaining will be test split
--share-model
-
diff --git a/setup.cfg b/setup.cfg
index 3d16b29e..ddd269ee 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -22,7 +22,7 @@ classifiers =
[options]
packages = find:
-where =
+where =
- .
- tests
python_requires = >=3.8
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 159dabc4..3b9dfa3c 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -181,7 +181,7 @@ def test_all_multisubject(dataset):
def test_compat_fix611(dataset):
"""Check that confirm the fix applied in internal PR #611
- The PR removed the explicit continuous and discrete args from the
+ The PR removed the explicit continuous and discrete args from the
datasets used to parametrize this function. We manually check that
the continuous index is available, and no discrete index is set.
diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py
index c260a30b..2a8026b9 100644
--- a/tests/test_sklearn.py
+++ b/tests/test_sklearn.py
@@ -439,8 +439,9 @@ def test_adapt_model():
assert before_adapt.keys() == after_adapt.keys()
for key in before_adapt.keys():
if key in adaptation_param_key:
- assert (before_adapt[key].shape != after_adapt[key].shape
- ) or not torch.allclose(before_adapt[key], after_adapt[key])
+ assert (before_adapt[key].shape
+ != after_adapt[key].shape) or not torch.allclose(
+ before_adapt[key], after_adapt[key])
else:
assert torch.allclose(before_adapt[key], after_adapt[key])
diff --git a/tests/test_usecases.py b/tests/test_usecases.py
index 035074fb..50ecb412 100644
--- a/tests/test_usecases.py
+++ b/tests/test_usecases.py
@@ -121,8 +121,9 @@ def test_hybrid():
"leave_out,args",
[(leave_out, args) for args in _args for leave_out in args.keys()])
def test_leave_arg_out(leave_out, args):
- model = cebra.CEBRA(**{k: v for k, v in args.items() if k != leave_out},
- **_default_kwargs())
+ model = cebra.CEBRA(**{
+ k: v for k, v in args.items() if k != leave_out
+ }, **_default_kwargs())
_run_test(model)