diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a31d1df4..1cb89b6f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,13 +8,18 @@ on: branches: - main - development - types: [opened, synchronize, reopened, ready_for_review] + types: + - opened + - reopened + - synchronize + - ready_for_review push: branches: - main jobs: tox: + if: ${{ !github.event.pull_request.draft }} strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] @@ -45,7 +50,7 @@ jobs: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} check: - if: always() + if: ${{ !github.event.pull_request.draft }} needs: tox runs-on: ubuntu-latest steps: diff --git a/.github/workflows/connect.yml b/.github/workflows/connect.yml index 63dac019..06b6e053 100644 --- a/.github/workflows/connect.yml +++ b/.github/workflows/connect.yml @@ -10,12 +10,12 @@ jobs: name: build # Name of the job strategy: matrix: - os: [ubuntu-latest] #[ubuntu-latest, macos-latest, windows-latest] - python-version: ['3.8', '3.9', '3.10'] + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ['3.9'] runs-on: ${{ matrix.os }} # Operating system for the job steps: - name: Checkout # Step to checkout the repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python # Step to set up Python uses: actions/setup-python@v4 # Use v4 for compatibility with pyproject.toml diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 25694f90..3bcc3e63 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,9 +1,221 @@ # Contributing -The NeMoS package is designed to provide a robust set of statistical analysis tools for neuroscience research. While the repository is managed by a core team of data scientists at the Center for Computational Neuroscience of the Flatiron Institute, we warmly welcome contributions from external collaborators. +The NeMoS package is designed to provide a robust set of statistical analysis tools for neuroscience research. While the repository is managed by a core team of data scientists at the Center for Computational Neuroscience of the Flatiron Institute, we warmly welcome contributions from external collaborators. +This guide explains how to contribute: if you have questions about the process, please feel free to reach out on [Github Discussions](https://github.com/flatironinstitute/nemos/discussions). ## General Guidelines Developers are encouraged to contribute to various areas of development. This could include the creation of concrete classes, such as those for new basis function types, or the addition of further checks at evaluation. Enhancements to documentation and the overall readability of the code are also greatly appreciated. -Feel free to work on any section of code that you believe you can improve. More importantly, remember to thoroughly test all your classes and functions, and to provide clear, detailed comments within your code. This not only aids others in using your library, but also facilitates future maintenance and further development. +Feel free to work on any section of code that you believe you can improve. More importantly, remember to thoroughly test all your classes and functions, and to provide clear, detailed comments within your code. This not only aids others in using the library, but also facilitates future maintenance and further development. 
+
+For more detailed information about NeMoS modules, including design choices and implementation details, visit the [`For Developers`](https://nemos.readthedocs.io/en/latest/developers_notes/) section of the package documentation.
+
+## Contributing to the code
+
+### Contribution workflow cycle
+
+In order to contribute, you will need to do the following:
+
+1) Create your own branch
+2) Make sure that tests pass and code coverage is maintained
+3) Open a Pull Request
+
+The NeMoS package follows the [Git Flow](https://www.atlassian.com/git/tutorials/comparing-workflows/gitflow-workflow) workflow. In essence, there are two primary branches, `main` and `development`, to which no one is allowed to
+push directly. All development happens in separate feature branches that are then merged into `development` once we have determined they are ready. When enough changes have accumulated, `development` is merged into `main`, and a new release is
+generated. This process includes adding a new tag to increment the version number and uploading the new release to PyPI.
+
+
+#### Creating a development environment
+
+You will need a local installation of `nemos` that stays up-to-date with any changes you make. To do so, you will need to fork and clone `nemos` before checking out
+a new branch:
+
+1) Go to the [nemos repo](https://github.com/flatironinstitute/nemos) and click on the `Fork` button at the top right of the page. This will create a copy
+of `nemos` in your GitHub account. You should then clone *your fork* to your local machine.
+
+```bash
+git clone https://github.com/<your-github-username>/nemos.git
+cd nemos
+```
+
+2) Install `nemos` in editable mode with developer dependencies
+
+```bash
+pip install -e .[dev]
+```
+
+> [!NOTE]
+> In order to install `nemos` in editable mode you will need a Python virtual environment. Please see our documentation [here](https://nemos.readthedocs.io/en/latest/installation/) for guidance on how to create and activate a virtual environment.
+
+3) Add the upstream branch:
+
+```bash
+git remote add upstream https://github.com/flatironinstitute/nemos
+```
+
+At this point you have two remotes: `origin` (your fork) and `upstream` (the canonical version). You won't have permission to push to `upstream` (only `origin`), but
+this makes it easy to keep your `nemos` up-to-date with the canonical version by pulling from upstream: `git pull upstream`.
+
+#### Creating a new branch
+
+As mentioned previously, each feature in `nemos` is worked on in a separate branch. This allows multiple people to develop multiple features simultaneously, without interfering with each other's work. To create
+your own branch, run the following from within your `nemos` directory:
+
+> [!NOTE]
+> Below we are checking out the `development` branch. In terms of the `nemos` contribution workflow cycle, the `development` branch accumulates a series of changes from different feature branches that are then all merged into the `main` branch at one time (normally at the time of a release).
+
+```bash
+# switch to the development branch on your local copy
+git checkout development
+# update your local copy from your fork
+git pull origin development
+# sync your local copy with upstream development
+git pull upstream development
+# update your fork's development branch with any changes from upstream
+git push origin development
+# create and switch to a new branch, where you'll work on your new feature
+git checkout -b my_feature_branch
+```
+
+After you have made changes on this branch, add and commit them when you are ready:
+
+```bash
+# stage the changes
+git add src/nemos/the_changed_file.py
+# commit your changes
+git commit -m "A one-line message explaining the changes made"
+# push to the remote origin
+git push origin my_feature_branch
+```
+
+#### Contributing your change back to NeMoS
+
+You can make any number of changes on your branch. Once you are happy with your changes, add tests to check that they run correctly and add documentation to properly note your changes.
+See below for details on how to [add tests](#adding-tests) and properly [document](#adding-documentation) your code.
+
+Lastly, you should make sure that the existing tests all run successfully and that the codebase is formatted properly:
+
+```bash
+# run tests and make sure they all pass
+pytest tests/
+# format the code base
+black src/
+isort src
+flake8 --config=tox.ini src
+```
+
+> [!IMPORTANT]
+> [`black`](https://black.readthedocs.io/en/stable/) and [`isort`](https://pycqa.github.io/isort/) automatically reformat your code and organize your imports, respectively. [`flake8`](https://flake8.pycqa.org/en/stable/#) does not modify your code directly; instead, it identifies syntax errors and code complexity issues that need to be addressed manually.
+
+> [!NOTE]
+> If some files were reformatted after running `black`, make sure to commit those changes and push them to your feature branch as well.
+
+Now you are ready to make a Pull Request (PR). You can open a pull request by clicking on the big `Compare & pull request` button that appears at the top of the `nemos` repo
+after pushing to your branch (see [here](https://intersect-training.org/collaborative-git/03-pr/index.html) for a tutorial).
+
+Your pull request should include the following:
+- A summary including information on what you changed and why
+- References to relevant issues or discussions
+- A note about any portion of your changes where you have lingering questions (e.g., "was this the right way to implement this?") or where you want reviewers to pay special attention
+
+Next, we will be notified of the pull request and will read it over. We will try to give an initial response quickly, and then do a longer in-depth review, at which point
+you will probably need to respond to our comments, making changes as appropriate. We'll then respond again, and proceed in an iterative fashion until everyone is happy with the proposed
+changes.
+
+Additionally, every PR to `main` or `development` will automatically run linters and tests through a [GitHub action](https://docs.github.com/en/actions). Merges can happen only when all checks pass.
+
+> [!NOTE]
+> The [NeMoS GitHub action](.github/workflows/ci.yml) runs tests in an isolated environment using [`tox`](https://tox.wiki/en/). `tox` is not included in our optional dependencies, so if you want to replicate the action workflow locally, you need to install `tox` via pip and then run it.
From the package directory:
+> ```sh
+> pip install tox
+> tox -e py
+> ```
+> This will execute `tox` with a Python version that matches your local environment. `tox` configurations can be found in the [`tox.ini`](tox.ini) file.
+
+Once your changes are integrated, you will be added as a GitHub contributor and as one of the authors of the package. Thank you for being part of `nemos`!
+
+### Style Guide
+
+This section covers the style of your code and specific requirements for certain feature development in `nemos`.
+
+- Longer, descriptive names are preferred (e.g., `x` is not an appropriate name for a variable), especially for anything user-facing, such as methods, attributes, or arguments.
+- Any public method or function must have a complete type-annotated docstring (see below for details). Private ones do not need to have complete docstrings, but they probably should.
+
+### Releases
+
+We create releases on GitHub, deploy on / distribute via [PyPI](https://pypi.org/), and try to follow [semantic versioning](https://semver.org/):
+
+> Given a version number MAJOR.MINOR.PATCH, increment the:
+> 1. MAJOR version when you make incompatible API changes
+> 2. MINOR version when you add functionality in a backward compatible manner
+> 3. PATCH version when you make backward compatible bug fixes
+
+To release a new version, we [create a GitHub release](https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository) with a new tag incrementing the version as described above. Creating the GitHub release will trigger the deployment to PyPI, via our `deploy` action (found in `.github/workflows/deploy-pure-python.yml`). The built version will grab the version tag from the GitHub release, using [setuptools_scm](https://github.com/pypa/setuptools_scm).
+
+### Testing
+
+To run all tests, run `pytest` from within the main `nemos` repository. This may take a while as there are many tests, broken into several categories.
+There are several options for how to run a subset of tests:
+- Run tests from one file: `pytest tests/test_glm.py`
+- Run a specific test within a specific module: `pytest tests/test_glm.py::test_func`
+- Another example specifying a test method via the command line: `pytest tests/test_glm.py::TestGLM::test_func`
+
+#### Adding tests
+
+New tests can be added in any of the existing `tests/test_*.py` scripts. Tests should be functions, contained within classes. Each class contains a group of related tests
+(e.g., regularizers, bases), and each test should ideally be a unit test, only testing one thing. The classes should be named `TestSomething`, while test functions should be named
+`test_something` in snake case.
+
+If you're adding a substantial set of tests that are separate from the existing ones, you can create a new test script. Its name must begin with `test_`,
+it must have a `.py` extension, and it must be contained within the `tests` directory. Assuming you do that, our GitHub Actions workflow will automatically find it and
+add it to the tests-to-run.
+
+> [!NOTE]
+> If you have many variants on a test you wish to run, you should make use of pytest's `parametrize` mark. See the official documentation [here](https://docs.pytest.org/en/stable/how-to/parametrize.html) and NeMoS [`test_error_invalid_entry`](https://github.com/flatironinstitute/nemos/blob/main/tests/test_vallidation.py#L27) for a concrete implementation.
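+
+As a quick illustration, here is a minimal, self-contained sketch of a parametrized test written in the style described above. The `positive_sqrt` helper is a stand-in for a NeMoS utility, not an actual package function:
+
+```python
+from contextlib import nullcontext as does_not_raise
+
+import pytest
+
+
+def positive_sqrt(x):
+    # Toy stand-in for a library function that validates its input.
+    if x < 0:
+        raise ValueError("Input must be non-negative.")
+    return x**0.5
+
+
+class TestPositiveSqrt:
+    # One class groups the related tests; each test checks a single behavior.
+    @pytest.mark.parametrize(
+        "value, expectation",
+        [
+            (4.0, does_not_raise()),
+            (0.0, does_not_raise()),
+            (-1.0, pytest.raises(ValueError, match="must be non-negative")),
+        ],
+    )
+    def test_positive_sqrt(self, value, expectation):
+        # The parametrize mark reruns this test once per tuple above.
+        with expectation:
+            positive_sqrt(value)
+```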
+
+> [!NOTE]
+> If an object is used in multiple tests (such as a model with certain data, regularizer, or solver), you should use pytest's `fixtures` to avoid having to load or instantiate the object multiple times. Look at our `conftest.py` to see the fixtures already available for your tests. See the official documentation [here](https://docs.pytest.org/en/stable/how-to/fixtures.html).
+
+### Documentation
+
+Documentation is a crucial part of open-source software and greatly influences the ability to use a codebase. As such, it is imperative that any new changes are
+properly documented as outlined below.
+
+#### Adding documentation
+
+1) **Docstrings**
+
+All public-facing functions and classes should have complete docstrings, which start with a one-line short summary of the function,
+a medium-length description of the function / class and what it does, and a complete description of all arguments and return values.
+Math should be included in a `Notes` section when necessary to explain what the function is doing, and references to primary literature
+should be included in a `References` section when appropriate. Docstrings should be relatively short, providing the information necessary
+for a user to use the code.
+
+Private functions and classes should have sufficient explanation that other developers know what the function / class does and how to use it,
+but do not need to be as extensive.
+
+We follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/) conventions for docstring structure.
+
+2) **Examples/Tutorials**
+
+If your changes are significant (add new functionality or drastically change the current codebase), then the current examples may need to be updated or
+a new example may need to be added.
+
+All examples live within the `docs/` subfolder of `nemos`. These are written as `.py` files but are converted to
+notebooks by [`mkdocs-gallery`](https://smarie.github.io/mkdocs-gallery/), and have a special syntax, as demonstrated in this [example
+gallery](https://smarie.github.io/mkdocs-gallery/generated/gallery/).
+
+We avoid using `.ipynb` notebooks directly because their JSON-based format makes them difficult to read and interpret, and makes merge conflicts hard to resolve in version control.
+
+To see if changes you have made break the current documentation, you can build the documentation locally.
+
+```bash
+# Clear the cached documentation pages
+# This step is only necessary if your changes affected the src/ directory
+rm -r docs/generated
+# build the docs within the nemos repo
+mkdocs build
+```
+
+If the build fails, you will see line-specific errors that prompted the failure.
diff --git a/docs/developers_notes/01-basis_module.md b/docs/developers_notes/01-basis_module.md
index c0bdf6f8..f6180c5b 100644
--- a/docs/developers_notes/01-basis_module.md
+++ b/docs/developers_notes/01-basis_module.md
@@ -26,18 +26,27 @@ Abstract Class Basis
 └─ Concrete Subclass OrthExponentialBasis
 ```
 
-The super-class `Basis` provides two public methods, [`evaluate`](#the-public-method-evaluate) and [`evaluate_on_grid`](#the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the private abstract method `_evaluate` that is specific for each concrete class. See below for more details.
+The super-class `Basis` provides two public methods, [`compute_features`](#the-public-method-compute_features) and [`evaluate_on_grid`](#the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method `__call__`, which is specific to each concrete class. See below for more details.
 
 ## The Class `nemos.basis.Basis`
 
-### The Public Method `evaluate`
+### The Public Method `compute_features`
 
-The `evaluate` method checks input consistency and evaluates the basis function at some sample points. It accepts one or more numpy arrays as input, which represent the sample points at which the basis will be evaluated, and performs the following steps:
+The `compute_features` method checks input consistency and applies the basis function to the inputs.
+`Basis` can operate in two modes defined at initialization: `"eval"` and `"conv"`. When a basis is in mode `"eval"`,
+`compute_features` evaluates the basis at the given input samples. When in mode `"conv"`, it will convolve the samples
+with a bank of kernels, one per basis function.
+
+It accepts one or more NumPy arrays or pynapple `Tsd` objects as input, and performs the following steps:
 
 1. Checks that the inputs all have the same sample size `M`, and raises a `ValueError` if this is not the case.
 2. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case.
-3. Calls the `_evaluate` method on the input, which is the subclass-specific implementation of the basis set evaluation.
-4. Returns a numpy array of shape `(M, n_basis_funcs)`, with each basis element evaluated at the samples.
+3. In `"eval"` mode, calls the `__call__` method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using `evaluate_on_grid` and then applies the convolution to the input with `nemos.convolve.create_convolutional_predictor`.
+4. Returns a NumPy array or pynapple `TsdFrame` of shape `(M, n_basis_funcs)`, with each basis element evaluated at the samples.
+
+!!! note "Multiple epochs"
+    Note that the convolution works gracefully with multiple disjoint epochs when a pynapple time series is used as
+    input.
 
 ### The Public Method `evaluate_on_grid`
 
@@ -47,14 +56,14 @@ This method performs the following steps:
 
 1. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case.
 2. Calls `_get_samples` method, which returns equidistant samples over the domain of the basis function. The domain may depend on the type of basis.
-3. Calls the `evaluate` method.
+3. Calls the `__call__` method.
 4. Returns both the sample grid points of shape `(m1, ..., mN)`, and the evaluation output at each grid point of shape `(m1, ..., mN, n_basis_funcs)`, where `mi` is the number of sample points for the i-th axis of the grid.
 
 ### Abstract Methods
 
 The `nemos.basis.Basis` class has the following abstract methods, which every concrete subclass must implement:
 
-1. `_evaluate`: Evaluates a basis over some specified samples.
+1. `__call__`: Evaluates a basis over some specified samples.
 2. 
`_check_n_basis_min`: Checks the minimum number of basis functions required. This requirement can be specific to the type of basis.
 
 ## Contributors Guidelines
 
@@ -63,8 +72,7 @@ The `nemos.basis.Basis` class has the following abstract methods, which every co
 
 To write a usable (i.e., concrete, non-abstract) basis object, you
 
 - **Must** inherit the abstract superclass `Basis`
-- **Must** define the `_evaluate` and `_check_n_basis_min` methods with the expected input/output format, see [Code References](../../reference/nemos/basis/) for the specifics.
-- **Should not** overwrite the `evaluate` and `evaluate_on_grid` methods inherited from `Basis`.
+- **Must** define the `__call__` and `_check_n_basis_min` methods with the expected input/output format; see [Code References](../../reference/nemos/basis/) for the specifics.
+- **Should not** overwrite the `compute_features` and `evaluate_on_grid` methods inherited from `Basis`.
 - **May** inherit any number of abstract intermediate classes (e.g., `SplineBasis`).
-- **May** reimplement the `_get_samples` method if your basis domain differs from `[0,1]`. However, we recommend mapping the specific basis domain to `[0,1]` whenever possible.
diff --git a/docs/developers_notes/02-base_class.md b/docs/developers_notes/02-base_class.md
index 18a3b8e7..f22d5c87 100644
--- a/docs/developers_notes/02-base_class.md
+++ b/docs/developers_notes/02-base_class.md
@@ -2,9 +2,9 @@
 
 ## Introduction
 
-The `base_class` module introduces the `Base` class and abstract classes defining broad model categories. These abstract classes **must** inherit from `Base`.
+The `base_class` module introduces the `Base` class and abstract classes defining broad model categories and feature constructors. These abstract classes **must** inherit from `Base`.
 
-The `Base` class is envisioned as the foundational component for any object type (e.g., regression, dimensionality reduction, clustering, observation models, regularizers etc.). In contrast, abstract classes derived from `Base` define overarching object categories (e.g., `base_class.BaseRegressor` is building block for GLMs, GAMS, etc. while `observation_models.Observations` is the building block for the Poisson observations, Gamma observations, ... etc.).
+The `Base` class is envisioned as the foundational component for any object type (e.g., basis, regression, dimensionality reduction, clustering, observation models, regularizers etc.). In contrast, abstract classes derived from `Base` define overarching object categories (e.g., `base_regressor.BaseRegressor` is the building block for GLMs, GAMs, etc., while `observation_models.Observations` is the building block for the Poisson observations, Gamma observations, ... etc.).
 
 Designed to be compatible with the `scikit-learn` API, the class structure aims to facilitate access to `scikit-learn`'s robust pipeline and cross-validation modules. This is achieved while leveraging the accelerated computational capabilities of `jax` and `jaxopt` in the backend, which is essential for analyzing extensive neural recordings and fitting large models.
 
@@ -34,20 +34,16 @@ Abstract Class Base
 │   │
 │   ├─ Concrete Subclass PoissonObservations
 │   │
-│   ├─ Concrete Subclass GammaObservations *(not implemented yet)
+│   ├─ Concrete Subclass GammaObservations
 │   ...
 │   ...
 ```
 
-!!! Example
-    The current package version includes a concrete class named `nemos.glm.GLM`. This class inherits from `BaseRegressor`, which in turn inherits `Base`, since it falls under the " GLM regression" category.
- As any `BaseRegressor`, it **must** implement the `fit`, `score`, `predict`, and `simulate` methods. - ## The Class `model_base.Base` -The `Base` class aligns with the `scikit-learn` API for `base.BaseEstimator`. This alignment is achieved by implementing the `get_params` and `set_params` methods, essential for `scikit-learn` compatibility and foundational for all model implementations. Additionally, the class provides auxiliary helper methods to identify available computational devices (such as GPUs and TPUs) and to facilitate data transfer to these devices. +The `Base` class aligns with the `scikit-learn` API for `base.BaseEstimator`. This alignment is achieved by implementing the `get_params` and `set_params` methods, essential for `scikit-learn` compatibility and foundational for all model implementations. For a detailed understanding, consult the [`scikit-learn` API Reference](https://scikit-learn.org/stable/modules/classes.html) and [`BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html). @@ -59,42 +55,3 @@ For a detailed understanding, consult the [`scikit-learn` API Reference](https:/ - **`get_params`**: The `get_params` method retrieves parameters set during model instance initialization. Opting for a deep inspection allows the method to assess nested object parameters, resulting in a comprehensive parameter dictionary. - **`set_params`**: The `set_params` method offers a mechanism to adjust or set an estimator's parameters. It's versatile, accommodating both individual estimators and more complex nested structures like pipelines. Feeding an unrecognized parameter will raise a `ValueError`. -## The Abstract Class `model_base.BaseRegressor` - -`BaseRegressor` is an abstract class that inherits from `Base`, stipulating the implementation of abstract methods: `fit`, `predict`, `score`, and `simulate`. This ensures seamless assimilation with `scikit-learn` pipelines and cross-validation procedures. - -### Abstract Methods - -For subclasses derived from `BaseRegressor` to function correctly, they must implement the following: - -1. `fit`: Adapt the model using input data `X` and corresponding observations `y`. -2. `predict`: Provide predictions based on the trained model and input data `X`. -3. `score`: Score the accuracy of model predictions using input data `X` against the actual observations `y`. -4. `simulate`: Simulate data based on the trained regression model. - -### Public Methods - -To ensure the consistency and conformity of input data, the `BaseRegressor` introduces two public preprocessing methods: - -1. `preprocess_fit`: Assesses and converts the input for the `fit` method into the desired `jax.ndarray` format. If necessary, this method can initialize model parameters using default values. -2. `preprocess_simulate`: Validates and converts inputs for the `simulate` method. This method confirms the integrity of the feedforward input and, when provided, the initial values for feedback. - -### Auxiliary Methods - -Moreover, `BaseRegressor` incorporates auxiliary methods such as `_convert_to_jnp_ndarray`, `_has_invalid_entry` -and a number of other methods for checking input consistency. - -!!! Tip - Deciding between concrete and abstract methods in a superclass can be nuanced. As a general guideline: any method that's expected in all subclasses and isn't subclass-specific should be concretely implemented in the superclass. 
Conversely, methods essential for a subclass's expected behavior, but vary based on the subclass, should be abstract in the superclass. For instance, compatibility with the `sklearn.cross_validation` module demands `score`, `fit`, `get_params`, and `set_params` methods. Given their specificity to individual models, `score` and `fit` are abstract in `BaseRegressor`. Conversely, as `get_params` and `set_params` are consistent across model classes, they're inherited from `Base`. This approach typifies our general implementation strategy. However, it's important to note that while these are sound guidelines, exceptions exist based on various factors like future extensibility, clarity, and maintainability.
-
-
-## Contributor Guidelines
-
-### Implementing Model Subclasses
-
-When devising a new model subclass based on the `BaseRegressor` abstract class, adhere to the subsequent guidelines:
-
-- **Must** inherit the `BaseRegressor` abstract superclass.
-- **Must** realize the abstract methods: `fit`, `predict`, `score`, and `simulate`.
-- **Should not** overwrite the `get_params` and `set_params` methods, inherited from `Base`.
-- **May** introduce auxiliary methods such as `_convert_to_jnp_ndarray` for added utility.
diff --git a/docs/developers_notes/03-base_regressor.md b/docs/developers_notes/03-base_regressor.md
new file mode 100644
index 00000000..2e03e751
--- /dev/null
+++ b/docs/developers_notes/03-base_regressor.md
@@ -0,0 +1,65 @@
+# The Abstract Class `BaseRegressor`
+
+`BaseRegressor` is an abstract class that inherits from `Base`, stipulating the implementation of a number of abstract methods, such as `fit`, `predict`, and `score`. This ensures seamless integration with `scikit-learn` pipelines and cross-validation procedures.
+
+
+!!! Example
+    The current package version includes a concrete class named `nemos.glm.GLM`. This class inherits from `BaseRegressor`, which in turn inherits `Base`, since it falls under the "GLM regression" category.
+    As a `BaseRegressor`, it **must** implement `fit`, `score`, `predict`, and the other abstract methods of this class; see below.
+
+### Abstract Methods
+
+For subclasses derived from `BaseRegressor` to function correctly, they must implement the following:
+
+1. `fit`: Adapt the model using input data `X` and corresponding observations `y`.
+2. `predict`: Provide predictions based on the trained model and input data `X`.
+3. `score`: Score the accuracy of model predictions using input data `X` against the actual observations `y`.
+4. `simulate`: Simulate data based on the trained regression model.
+5. `update`: Run a single optimization step and store the updated parameters and solver state. Used by stochastic optimization schemes.
+6. `_predict_and_compute_loss`: Compute the prediction and evaluate the loss function given the parameters, `X`, and `y`. This is used by the `instantiate_solver` method, which sets up the solver.
+7. `_check_params`: Check the parameter structure.
+8. `_check_input_dimensionality`: Check that the input dimensionality matches the model's expectations.
+9. `_check_input_and_params_consistency`: Check that the input and the parameters are consistent.
+10. `_get_coef_and_intercept` and `_set_coef_and_intercept`: Get and set the model coefficients and intercept term.
+
+All the `_check_` methods are called by the `_validate` method, which checks that the provided
+input and parameters conform with the model requirements.
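+
+As a rough sketch, a concrete regressor only needs to fill in these hooks. The skeleton below assumes the `base_regressor` module path referenced above; the argument lists are abridged and indicative only, not the exact signatures (check the `BaseRegressor` source for those):
+
+```python
+from nemos.base_regressor import BaseRegressor
+
+
+class MyRegressor(BaseRegressor):
+    """Outline of the abstract-method contract (bodies omitted)."""
+
+    # model fitting, prediction, and evaluation
+    def fit(self, X, y): ...
+    def predict(self, X): ...
+    def score(self, X, y): ...
+    def simulate(self, random_key, feedforward_input): ...
+    def update(self, params, opt_state, X, y): ...
+
+    # loss evaluation used when setting up the solver
+    def _predict_and_compute_loss(self, params, X, y): ...
+
+    # validation hooks invoked by `_validate`
+    def _check_params(self, params): ...
+    def _check_input_dimensionality(self, X, y): ...
+    def _check_input_and_params_consistency(self, params, X, y): ...
+
+    # parameter getter/setter
+    def _get_coef_and_intercept(self): ...
+    def _set_coef_and_intercept(self, params): ...
+```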
+
+### Attributes
+
+Public attributes are stored as properties:
+
+- `regularizer`: An instance of the `nemos.regularizer.Regularizer` class. The setter for this property accepts either the instance directly or a string that is used to instantiate the appropriate regularizer.
+- `regularizer_strength`: A float quantifying the amount of regularization.
+- `solver_name`: One of the supported `jaxopt` solvers, currently "GradientDescent", "BFGS", "LBFGS", "ProximalGradient", and "NonlinearCG".
+- `solver_kwargs`: Extra keyword arguments to be passed at solver initialization.
+- `solver_init_state`, `solver_update`, `solver_run`: Read-only properties with partially evaluated `solver.init_state`, `solver.update`, and `solver.run` methods. The partial evaluation guarantees a consistent API for all solvers.
+
+When implementing a new subclass of `BaseRegressor`, the only attributes you must interact with directly are those that operate on the solver, i.e., `solver_init_state`, `solver_update`, and `solver_run`.
+
+Typically, in `YourRegressor` you will call `self.solver_init_state` at the parameter initialization step, `self.solver_run` in `fit`, and `self.solver_update` in `update`.
+
+!!! note "Solvers"
+    Solvers are typically optimizers from the `jaxopt` package, but in principle they could be custom optimization routines as long as they respect the `jaxopt` API (i.e., have a `run`, `init_state`, and `update` method with the appropriate input/output types).
+    We rely on `jaxopt` because it provides a comprehensive set of robust, GPU-accelerated, batchable, and differentiable optimizers in JAX that are highly customizable. In the future we may provide a number of custom solvers optimized for convex stochastic optimization.
+
+## Contributor Guidelines
+
+### Implementing Model Subclasses
+
+When devising a new model subclass based on the `BaseRegressor` abstract class, adhere to the subsequent guidelines:
+
+- **Must** inherit the `BaseRegressor` abstract superclass.
+- **Must** realize the abstract methods; see above.
+- **Should not** overwrite the `get_params` and `set_params` methods, inherited from `Base`.
+- **May** introduce auxiliary methods for added utility.
+
+
+## Glossary
+
+| Term | Description |
+|--------------------| ----------- |
+| **Regularization** | Regularization is a technique used to prevent overfitting by adding a penalty to the loss function, which discourages complex models. Common regularization techniques include L1 (Lasso) and L2 (Ridge) regularization. |
+| **Optimization** | Optimization refers to the process of minimizing (or maximizing) a function by systematically choosing the values of the variables within an allowable set. In machine learning, optimization aims to minimize the loss function to train models. |
+| **Solver** | A solver is an algorithm or a set of algorithms used for solving optimization problems. In the given module, solvers are used to find the parameters that minimize the loss function, potentially subject to some constraints. |
+| **Runner** | A runner in this context refers to a callable function configured to execute the solver with the specified parameters and data. It acts as an interface to the solver, simplifying the process of running optimization tasks. 
|
diff --git a/docs/developers_notes/04-regularizer.md b/docs/developers_notes/04-regularizer.md
index 3026c162..aecd6faa 100644
--- a/docs/developers_notes/04-regularizer.md
+++ b/docs/developers_notes/04-regularizer.md
@@ -4,12 +4,9 @@
 
 The `regularizer` module introduces an archetype class `Regularizer` which provides the structural components for each concrete sub-class.
 
-Objects of type `Regularizer` provide methods to define a regularized optimization objective, and instantiate a solver for it. These objects serve as attribute of the [`nemos.glm.GLM`](../05-glm/#the-concrete-class-glm), equipping the glm with a solver for learning model parameters.
+Objects of type `Regularizer` provide methods to define a regularized optimization objective. These objects serve as attributes of the [`nemos.glm.GLM`](../05-glm/#the-concrete-class-glm), equipping the GLM with an appropriate regularization scheme.
 
-Solvers are typically optimizers from the `jaxopt` package, but in principle they could be custom optimization routines as long as they respect the `jaxopt` api (i.e., have a `run` and `update` method with the appropriate input/output types).
-We choose to rely on `jaxopt` because it provides a comprehensive set of robust, GPU accelerated, batchable and differentiable optimizers in JAX, that are highly customizable.
-
-Each `Regularizer` object defines a set of allowed optimizers, which in turn depends on the loss function characteristics (smooth vs non-smooth) and/or the optimization type (constrained, un-constrained, batched, etc.).
+Each `Regularizer` object defines a default solver and a set of allowed solvers, which depend on the loss function characteristics (smooth vs. non-smooth).
 
 ```
 Abstract Class Regularizer
 |
 ├─ Concrete Class UnRegularized
 |
 ├─ Concrete Class Ridge
 |
-└─ Abstract Class ProximalGradientRegularizer
-   |
-   ├─ Concrete Class Lasso
-   |
-   └─ Concrete Class GroupLasso
+├─ Concrete Class Lasso
+|
+└─ Concrete Class GroupLasso
 ```
 
 !!! note
-    If we need advanced adaptive optimizers (e.g., Adam, LAMB etc.) in the future, we should consider adding [`Optax`](https://optax.readthedocs.io/en/latest/) as a dependency, which is compatible with `jaxopt`, see [here](https://jaxopt.github.io/stable/_autosummary/jaxopt.OptaxSolver.html#jaxopt.OptaxSolver).
+    If we need advanced adaptive solvers (e.g., Adam, LAMB, etc.) in the future, we should consider adding [`Optax`](https://optax.readthedocs.io/en/latest/) as a dependency, which is compatible with `jaxopt`; see [here](https://jaxopt.github.io/stable/_autosummary/jaxopt.OptaxSolver.html#jaxopt.OptaxSolver).
 
 ## The Abstract Class `Regularizer`
 
-The abstract class `Regularizer` enforces the implementation of the `instantiate_solver` method on any concrete realization of a `Regularizer` object. `Regularizer` objects are equipped with a method for instantiating a solver runner with the appropriately regularized loss function, i.e., a function that receives as input the initial parameters, the endogenous and the exogenous variables, and outputs the optimization results.
-
-Additionally, the class provides auxiliary methods for checking that the solver and loss function specifications are valid.
+The abstract class `Regularizer` enforces the implementation of the `penalized_loss` and `get_proximal_operator` methods.
 
-### Public Methods
-
-- **`instantiate_solver`**: Instantiate a solver runner for a provided loss function, configure and return a `solver_run` callable. The loss function must be of type `Callable`.
+### Attributes
+
+The attributes of `Regularizer` consist of the `default_solver` and `allowed_solvers`, which are stored as read-only properties of type string and tuple of strings, respectively.
 
-### Auxiliary Methods
+### Abstract Methods
 
-- **`_check_solver`**: This method ensures that the provided solver name is in the list of allowed solvers for the specific `Regularizer` object. This is crucial for maintaining consistency and correctness in the solver's operation.
-
-- **`_check_solver_kwargs`**: This method checks if the provided keyword arguments are valid for the specified solver. This helps in catching and preventing potential errors in solver configuration.
+- **`penalized_loss`**: Returns a penalized version of the input loss function, which is uniquely defined by the regularization scheme and the regularizer strength parameter.
+- **`get_proximal_operator`**: Returns the proximal projection operator, which is uniquely defined by the regularization scheme.
 
 ## The `UnRegularized` Class
 
 The `UnRegularized` class extends the base `Regularizer` class and is designed specifically for optimizing unregularized models. This means that the solver instantiated by this class does not add any regularization penalty to the loss function during the optimization process.
 
-### Attributes
-
-- **`allowed_solvers`**: A list of string identifiers for the optimization solvers that can be used with this regularizer class. The optimization methods listed here are specifically suitable for unregularized optimization problems.
-
-### Methods
-
-- **`__init__`**: The constructor method for this class which initializes a new `UnRegularized` object. It accepts the name of the solver algorithm to use (`solver_name`) and an optional dictionary of additional keyword arguments (`solver_kwargs`) for the solver.
-
-- **`instantiate_solver`**: A method which prepares and returns a runner function for the specified loss function. This method ensures that the loss function is callable and prepares the necessary keyword arguments for calling the `get_runner` method from the base `Regularizer` class.
-
-### Example Usage
-
-```python
-unregularized = UnRegularized(solver_name="GradientDescent")
-runner = unregularized.instantiate_solver(loss_function)
-optim_results = runner(init_params, exog_vars, endog_vars)
-```
-
-## The `Ridge` Class
-
-The `Ridge` class extends the `Regularizer` class to handle optimization problems with Ridge regularization. Ridge regularization adds a penalty to the loss function, proportional to the sum of squares of the model parameters, to prevent overfitting and stabilize the optimization.
-
-### Attributes
-
-- **`allowed_solvers`**: A list containing string identifiers of optimization solvers compatible with Ridge regularization.
-
-- **`regularizer_strength`**: A floating-point value determining the strength of the Ridge regularization. Higher values correspond to stronger regularization which tends to drive the model parameters towards zero.
-
-### Methods
-
-- **`__init__`**: The constructor method for the `Ridge` class. It accepts the name of the solver algorithm (`solver_name`), an optional dictionary of additional keyword arguments (`solver_kwargs`) for the solver, and the regularization strength (`regularizer_strength`).
-**`penalization`**: A method to compute the Ridge regularization penalty for a given set of model parameters.
+### Concrete Method Specifics
+- **`penalized_loss`**: Returns the original loss without any changes. 
+- **`get_proximal_operator`**: Returns the identity operator. -- **`instantiate_solver`**: A method that prepares and returns a runner function with a penalized loss function for Ridge regularization. This method modifies the original loss function to include the Ridge penalty, ensures the loss function is callable, and prepares the necessary keyword arguments for calling the `get_runner` method from the base `Regularizer` class. - -### Example Usage - -```python -ridge = Ridge(solver_name="LBFGS", regularizer_strength=1.0) -runner = ridge.instantiate_solver(loss_function) -optim_results = runner(init_params, exog_vars, endog_vars) -``` - -## `ProxGradientRegularizer` Class - -`ProxGradientRegularizer` class extends the `Regularizer` class to utilize the Proximal Gradient method for optimization. It leverages the `jaxopt` library's Proximal Gradient optimizer, introducing the functionality of a proximal operator. - -### Attributes: -- **`allowed_solvers`**: A list containing string identifiers of optimization solvers compatible with this solver, specifically the "ProximalGradient". - -### Methods: -- **`__init__`**: The constructor method for the `ProxGradientRegularizer` class. It accepts the name of the solver algorithm (`solver_name`), an optional dictionary of additional keyword arguments (`solver_kwargs`) for the solver, the regularization strength (`regularizer_strength`), and an optional mask array (`mask`). - -- **`get_prox_operator`**: Abstract method to retrieve the proximal operator for this solver. - -- **`instantiate_solver`**: Method to prepare and return a runner function for optimization with a provided loss function and proximal operator. - -## `Lasso` Class - -`Lasso` class extends `ProxGradientRegularizer` to specialize in optimization using the Lasso (L1 regularization) method with Proximal Gradient. - -### Methods: -- **`__init__`**: Constructor method similar to `ProxGradientRegularizer` but defaults `solver_name` to "ProximalGradient". - -- **`get_prox_operator`**: Method to retrieve the proximal operator for Lasso regularization (L1 penalty). - -## `GroupLasso` Class - -`GroupLasso` class extends `ProxGradientRegularizer` to specialize in optimization using the Group Lasso regularization method with Proximal Gradient. It induces sparsity on groups of features rather than individual features. - -### Attributes: -- **`mask`**: A mask array indicating groups of features for regularization. - -### Methods: -- **`__init__`**: Constructor method similar to `ProxGradientRegularizer`, but additionally requires a `mask` array to identify groups of features. - -- **`get_prox_operator`**: Method to retrieve the proximal operator for Group Lasso regularization. - -- **`_check_mask`**: Static method to check that the provided mask is a float `jax.numpy.ndarray` of 0s and 1s. The mask must be in floats to be applied correctly through the linear algebra operations of the `nemos.proimal_operator.prox_group_lasso` function. 
-
-### Example Usage
-```python
-lasso = Lasso(regularizer_strength=1.0)
-runner = lasso.instantiate_solver(loss_function)
-optim_results = runner(init_params, exog_vars, endog_vars)
-
-group_lasso = GroupLasso(solver_name="ProximalGradient", mask=group_mask, regularizer_strength=1.0)
-runner = group_lasso.instantiate_solver(loss_function)
-optim_results = runner(init_params, exog_vars, endog_vars)
-```
 
 ## Contributor Guidelines
 
@@ -147,20 +53,17 @@ optim_results = runner(init_params, exog_vars, endog_vars)
 
 When developing a functional (i.e., concrete) `Regularizer` class:
 
 - **Must** inherit from `Regularizer` or one of its derivatives.
-- **Must** implement the `instantiate_solver` method to tailor the solver instantiation based on the provided loss function.
-- For any Proximal Gradient method, **must** include a `get_prox_operator` method to define the proximal operator.
-- **Must** possess an `allowed_solvers` attribute to list the solver names that are permissible to be used with this regularizer.
-- **May** embed additional attributes and methods such as `mask` and `_check_mask` if required by the specific Solver subclass for handling special optimization scenarios.
-- **May** include a `regularizer_strength` attribute to control the strength of the regularization in scenarios where regularization is applicable.
-- **May** rely on a custom solver implementation for specific optimization problems, but the implementation **must** adhere to the `jaxopt` API.
-
-These guidelines ensure that each Solver subclass adheres to a consistent structure and behavior, facilitating ease of extension and maintenance.
-
-## Glossary
-
-| Term | Description |
-|--------------------| ----------- |
-| **Regularization** | Regularization is a technique used to prevent overfitting by adding a penalty to the loss function, which discourages complex models. Common regularization techniques include L1 (Lasso) and L2 (Ridge) regularization. |
-| **Optimization** | Optimization refers to the process of minimizing (or maximizing) a function by systematically choosing the values of the variables within an allowable set. In machine learning, optimization aims to minimize the loss function to train models. |
-| **Solver** | A solver is an algorithm or a set of algorithms used for solving optimization problems. In the given module, solvers are used to find the parameters that minimize the loss function, potentially subject to some constraints. |
-| **Runner** | A runner in this context refers to a callable function configured to execute the solver with the specified parameters and data. It acts as an interface to the solver, simplifying the process of running optimization tasks. |
+- **Must** implement the `penalized_loss` and `get_proximal_operator` methods.
+- **Must** define a default solver and a tuple of allowed solvers.
+- **May** require extra initialization parameters, like the `mask` argument of `GroupLasso`.
+
+!!! info
+    When adding a new regularizer, you must include a convergence test, which verifies that
+    the model parameters the regularizer finds for a convex problem such as the GLM are identical
+    whether one minimizes the penalized loss directly or uses the proximal operator (i.e., when
+    using `ProximalGradient`). 
In practice, this means you should test the result of the `ProximalGradient` + optimization against that of either `GradientDescent` (if your regularization is differentiable) or + `Nelder-Mead` from [`scipy.optimize.minimize`](https://docs.scipy.org/doc/scipy/reference/optimize.minimize-neldermead.html) + (or another non-gradient based method, if your regularization is non-differentiable). You can refer to NeMoS `test_lasso_convergence` + from `tests/test_convergence.py` for a concrete example. + diff --git a/docs/developers_notes/05-glm.md b/docs/developers_notes/05-glm.md deleted file mode 100644 index dc01329d..00000000 --- a/docs/developers_notes/05-glm.md +++ /dev/null @@ -1,76 +0,0 @@ -# The `glm` Module - -## Introduction - - - -Generalized Linear Models (GLM) provide a flexible framework for modeling a variety of data types while establishing a relationship between multiple predictors and a response variable. A GLM extends the traditional linear regression by allowing for response variables that have error distribution models other than a normal distribution, such as binomial or Poisson distributions. - -The `nemos.glm` module currently offers implementations of two GLM classes: - -1. **`GLM`:** A direct implementation of a feedforward GLM. -2. **`RecurrentGLM`:** An implementation of a recurrent GLM. This class inherits from `GLM` and redefines the `simulate` method to generate spikes akin to a recurrent neural network. - -Our design aligns with the `scikit-learn` API, facilitating seamless integration of our GLM classes with the well-established `scikit-learn` pipeline and its cross-validation tools. - -The classes provided here are modular by design offering a standard foundation for any GLM variant. - -Instantiating a specific GLM simply requires providing an observation model (Gamma, Poisson, etc.) and a regularization strategies (Ridge, Lasso, etc.) during initialization. This is done using the [`nemos.observation_models.Observations`](../03-observation_models/#the-abstract-class-observations) and [`nemos.regularizer.Regularizer`](../04-regularizer/#the-abstract-class-regularizer) objects, respectively. - - -
-
-[Figure: Schematic of the module interactions.]
-
- - - -## The Concrete Class `GLM` - -The `GLM` class provides a direct implementation of the GLM model and is designed with `scikit-learn` compatibility in mind. - -### Inheritance - -`GLM` inherits from [`BaseRegressor`](../02-base_class/#the-abstract-class-baseregressor). This inheritance mandates the direct implementation of methods like `predict`, `fit`, `score`, and `simulate`. - -### Attributes - -- **`regularizer`**: Refers to the optimization regularizer - an object of the [`nemos.regularizer.regularizer`](../04-regularizer/#the-abstract-class-regularizer) type. It uses the `jaxopt` solver to minimize the (penalized) negative log-likelihood of the GLM. -- **`observation_models`**: Represents the GLM observation model, which is an object of the [`nemos.observation_models.Observations`](../03-observation_models/#the-abstract-class-observations) type. This model determines the log-likelihood and the emission probability mechanism for the `GLM`. -- **`coef_`**: Stores the solution for spike basis coefficients as `jax.ndarray` after the fitting process. It is initialized as `None` during class instantiation. -- **`intercept_`**: Stores the bias terms' solutions as `jax.ndarray` after the fitting process. It is initialized as `None` during class instantiation. -- **`solver_state`**: Indicates the solver's state. For specific solver states, refer to the [`jaxopt` documentation](https://jaxopt.github.io/stable/index.html#). - -### Public Methods - -- **`predict`**: Validates input and computes the mean rates of the `GLM` by invoking the inverse-link function of the `observation_models` attribute. -- **`score`**: Validates input and assesses the Poisson GLM using either log-likelihood or pseudo-$R^2$. This method uses the `observation_models` to determine log-likelihood or pseudo-$R^2$. -- **`fit`**: Validates input and aligns the Poisson GLM with spike train data. It leverages the `observation_models` and `regularizer` to define the model's loss function and instantiate the regularizer. -- **`simulate`**: Simulates spike trains using the GLM as a feedforward network, invoking the `observation_models.sample_generator` method for emission probability. - -### Private Methods - -- **`_predict`**: Forecasts rates based on current model parameters and the inverse-link function of the `observation_models`. -- **`_score`**: Determines the Poisson negative log-likelihood, excluding normalization constants. -- **`_check_is_fit`**: Validates whether the model has been appropriately fit by ensuring model parameters are set. If not, a `NotFittedError` is raised. - - -## The Concrete Class `RecurrentGLM` - -The `RecurrentGLM` class is an extension of the `GLM`, designed to simulate models with recurrent connections. It inherits the `predict`, `fit`, and `score` methods from `GLM`, but provides its own implementation for the `simulate` method. - -### Overridden Methods - -- **`simulate`**: This method simulates spike trains, treating the GLM as a recurrent neural network. It utilizes the `observation_models.sample_generator` method to determine the emission probability. - -## Contributor Guidelines - -### Implementing GLM Subclasses - -When crafting a functional (i.e., concrete) GLM class: - -- **Must** inherit from `BaseRegressor` or one of its derivatives. -- **Must** realize the `predict`, `fit`, `score`, and `simulate` methods, either directly or through inheritance. 
-- **Should** incorporate a `observation_models` attribute of type `nemos.observation_models.Observations` to specify the link-function, emission probability, and likelihood. -- **Should** include a `regularizer` attribute of type `nemos.regularizer.Regularizer` to instantiate the solver based on regularization type. -- **May** embed additional parameter and input checks if required by the specific GLM subclass. diff --git a/docs/developers_notes/03-observation_models.md b/docs/developers_notes/05-observation_models.md similarity index 62% rename from docs/developers_notes/03-observation_models.md rename to docs/developers_notes/05-observation_models.md index b7429bcd..e0c3fe32 100644 --- a/docs/developers_notes/03-observation_models.md +++ b/docs/developers_notes/05-observation_models.md @@ -8,42 +8,26 @@ The abstract class `Observations` defines the structure of the subclasses which ## The Abstract class `Observations` -The abstract class `Observations` is the backbone of any observation model. Any class inheriting `Observations` must reimplement the `negative_log_likelihood`, `sample_generator`, `residual_deviance`, and `estimate_scale` methods. +The abstract class `Observations` is the backbone of any observation model. Any class inheriting `Observations` must reimplement the `_negative_log_likelihood`, `log_likelihood`, `sample_generator`, `deviance`, and `estimate_scale` methods. ### Abstract Methods For subclasses derived from `Observations` to function correctly, they must implement the following: -- **negative_log_likelihood**: Computes the negative-log likelihood of the model up to a normalization constant. This method is usually part of the objective function used to learn GLM parameters. +- **_negative_log_likelihood**: Computes the negative-log likelihood of the model up to a normalization constant. This method is usually part of the objective function used to learn GLM parameters. + +- **log_likelihood**: Computes the full log-likelihood including the normalization constant. - **sample_generator**: Returns the random emission probability function. This typically invokes `jax.random` emission probability, provided some sufficient statistics[^1]. For distributions in the exponential family, the sufficient statistics are the canonical parameter and the scale. In GLMs, the canonical parameter is entirely specified by the model's weights, while the scale is either fixed (i.e., Poisson) or needs to be estimated (i.e., Gamma). -- **residual_deviance**: Computes the residual deviance based on the model's estimated rates and observations. +- **deviance**: Computes the deviance based on the model's estimated rates and observations. -- **estimate_scale**: A method for estimating the scale parameter of the model. +- **estimate_scale**: A method for estimating the scale parameter of the model. Rate and scale are sufficient to fully characterize distributions from the exponential family. ### Public Methods - **pseudo_r2**: Method for computing the pseudo-$R^2$ of the model based on the residual deviance. There is no consensus definition for the pseudo-$R^2$, what we used here is the definition by Cohen at al. 2003[^2]. - - -### Auxiliary Methods - -- **_check_inverse_link_function**: Check that the provided link function is a `Callable` of the `jax` namespace. - -## Concrete `PoissonObservations` class - -The `PoissonObservations` class extends the abstract `Observations` class to provide functionalities specific to the Poisson observation model. 
It is designed for modeling observed spike counts based on a Poisson distribution with a given rate.
-
-### Overridden Methods
-
-- **negative_log_likelihood**: This method computes the Poisson negative log-likelihood of the predicted rates for the observed spike counts.
-
-- **sample_generator**: Generates random numbers from a Poisson distribution based on the given `predicted_rate`.
-
-- **residual_deviance**: Calculates the residual deviance for a Poisson model.
-
-- **estimate_scale**: Assigns a fixed value of 1 to the scale parameter of the Poisson model since Poisson distribution has a fixed scale.
+- **check_inverse_link_function**: Checks that the link function is an auto-differentiable, vectorized function $\mathbb{R} \longrightarrow \mathbb{R}$.
 
 ## Contributor Guidelines
 
@@ -51,9 +35,9 @@ To implement an observation model class you
 
 - **Must** inherit from `Observations`
 
-- **Must** provide a concrete implementation of `negative_log_likelihood`, `sample_generator`, `residual_deviance`, and `estimate_scale`.
+- **Must** provide a concrete implementation of the abstract methods; see above.
 
-- **Should not** reimplement the `pseudo_r2` method as well as the `_check_inverse_link_function` auxiliary method.
+- **Should not** reimplement the `pseudo_r2` method or the `check_inverse_link_function` auxiliary method.
 
 [^1]: In statistics, a statistic is sufficient with respect to a statistical model and its associated unknown parameters if "no other statistic that can be calculated from the same sample provides any additional information as to the value of the parameters", adapted from Fisher R. A.
diff --git a/docs/developers_notes/06-glm.md b/docs/developers_notes/06-glm.md
new file mode 100644
index 00000000..c5104022
--- /dev/null
+++ b/docs/developers_notes/06-glm.md
@@ -0,0 +1,93 @@
+# The `glm` Module
+
+## Introduction
+
+
+
+Generalized Linear Models (GLM) provide a flexible framework for modeling a variety of data types while establishing a relationship between multiple predictors and a response variable. A GLM extends the traditional linear regression by allowing for response variables that have error distribution models other than a normal distribution, such as binomial or Poisson distributions.
+
+The `nemos.glm` module currently offers implementations of two GLM classes:
+
+1. **`GLM`:** A direct implementation of a feedforward GLM.
+2. **`PopulationGLM`:** An implementation of a GLM for fitting a population of neurons in a vectorized manner. This class inherits from `GLM` and redefines `fit` and `_predict` to fit the model and predict the firing rate.
+
+Our design aligns with the `scikit-learn` API, facilitating seamless integration of our GLM classes with the well-established `scikit-learn` pipeline and its cross-validation tools.
+
+The classes provided here are modular by design, offering a standard foundation for any GLM variant.
+
+Instantiating a specific GLM simply requires providing an observation model (Gamma, Poisson, etc.), a regularization strategy (Ridge, Lasso, etc.), and an optimization scheme during initialization. This is done using the [`nemos.observation_models.Observations`](../03-observation_models/#the-abstract-class-observations) and [`nemos.regularizer.Regularizer`](../04-regularizer/#the-abstract-class-regularizer) objects, as well as the compatible `jaxopt` solvers, respectively.
+
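+
+For instance, a model can be composed along these lines (the argument names follow the attribute names documented in these notes and should be treated as indicative, not as a definitive recipe):
+
+```python
+import nemos as nmo
+
+# A feedforward GLM with Poisson observations, Ridge regularization,
+# and an LBFGS solver -- a sketch of the composition described above.
+model = nmo.glm.GLM(
+    observation_model=nmo.observation_models.PoissonObservations(),
+    regularizer="Ridge",
+    regularizer_strength=0.1,
+    solver_name="LBFGS",
+)
+```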
+*Schematic of the module interactions.*
+
+## The Concrete Class `GLM`
+
+The `GLM` class provides a direct implementation of the GLM model and is designed with `scikit-learn` compatibility in mind.
+
+### Inheritance
+
+`GLM` inherits from [`BaseRegressor`](03-base_regressor.md). This inheritance mandates the direct implementation of methods like `predict`, `fit`, `score`, `update`, and `simulate`, plus a number of validation methods.
+
+### Attributes
+
+- **`observation_model`**: Property that represents the GLM observation model, which is an object of the [`nemos.observation_models.Observations`](../05-observation_models/#the-abstract-class-observations) type. This model determines the log-likelihood and the emission probability mechanism for the `GLM`.
+- **`coef_`**: Stores the solution for the spike basis coefficients as a `jax.ndarray` after the fitting process. It is initialized as `None` during class instantiation.
+- **`intercept_`**: Stores the bias terms' solutions as a `jax.ndarray` after the fitting process. It is initialized as `None` during class instantiation.
+- **`dof_resid_`**: The degrees of freedom of the model's residuals. This quantity is used to estimate the scale parameter (see below) and to compute frequentist confidence intervals.
+- **`scale_`**: The scale parameter of the observation distribution, which, together with the rate, uniquely specifies a distribution of the exponential family. Example: a 1D Gaussian is specified by the mean, which is the rate, and the standard deviation, which is the scale.
+- **`solver_state_`**: Indicates the solver's state. For specific solver states, refer to the [`jaxopt` documentation](https://jaxopt.github.io/stable/index.html#).
+
+Additionally, the `GLM` class inherits the attributes of `BaseRegressor`; see the [corresponding note](03-base_regressor.md) for more information.
+
+### Public Methods
+
+- **`predict`**: Validates input and computes the mean rates of the `GLM` by invoking the inverse-link function of the `observation_models` attribute.
+- **`score`**: Validates input and assesses the GLM using either the log-likelihood or the pseudo-$R^2$. This method relies on the `observation_models` attribute to compute either metric.
+- **`fit`**: Validates input and fits the GLM to the data. It leverages the `observation_models` and `regularizer` attributes to define the model's loss function and set up the solver.
+- **`simulate`**: Simulates spike trains using the GLM as a feedforward network, invoking the `observation_models.sample_generator` method for emission probability.
+- **`initialize_params`**: Initializes the model parameters, setting the coefficients to zero and setting the intercept by matching the mean firing rate.
+- **`initialize_state`**: Initializes the state of the solver.
+- **`update`**: Runs a single optimization step and updates the parameters and the solver state. The stepwise API can be combined as shown in the sketch below.
+
+### Private Methods
+
+Here we list the private methods related to the model computations:
+
+- **`_predict`**: Forecasts rates based on the current model parameters and the inverse-link function of the `observation_models` attribute.
+- **`_predict_and_compute_loss`**: Predicts the rate and calculates the mean negative log-likelihood, excluding normalization constants.
+
+A number of `GLM`-specific private methods are used for checking parameters and inputs, while the methods that check the solver-regularizer configuration and instantiation are inherited from `BaseRegressor`.
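+As a minimal sketch of the stepwise API (assuming `X` is an array of shape `(n_samples, n_features)`, `y` an array of shape `(n_samples,)`, and `model` an instantiated `GLM`; exact signatures may vary across versions):
+
+```python
+# Initialize parameters and solver state once, then iterate single update steps.
+params = model.initialize_params(X, y)
+state = model.initialize_state(X, y, params)
+for _ in range(100):
+    params, state = model.update(params, state, X, y)
+```
+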
+
+## The Concrete Class `PopulationGLM`
+
+The `PopulationGLM` class is an extension of the `GLM`, designed to fit multiple neurons jointly: the fitting process is vectorized across neurons, leveraging the inherent parallelism of the computation.
+
+### `PopulationGLM` Specific Attributes
+
+- **`feature_mask`**: A mask that determines which features are used as predictors for each neuron (a usage sketch is given below). It can be a matrix of shape `(num_features, num_neurons)` or a `FeaturePytree` of binary values, where 1 indicates that a feature is used for a particular neuron and 0 indicates that it is not.
+
+### Overridden Methods
+
+- **`fit`**: Overridden to handle fitting of the model to a neural population. This method validates the input, including the mask, and fits the model parameters (coefficients and intercepts) to the data.
+- **`_predict`**: Computes the predicted firing rates using the model parameters and the feature mask.
+
+
+
+## Contributor Guidelines
+
+### Implementing `BaseRegressor` Subclasses
+
+When crafting a functional (i.e., concrete) GLM class:
+
+- You **must** inherit from `GLM` or one of its derivatives.
+- If you inherit directly from `BaseRegressor`, you **must** implement all the abstract methods; see the [`BaseRegressor` page](03-base_regressor.md) for more details.
+- If you inherit from `GLM` or any of the other concrete classes directly, there are no abstract methods left to implement.
+- You **may** embed additional parameter and input checks if required by the specific GLM subclass.
+- You **may** override some of the computations if needed by the model specifications.
+
diff --git a/docs/developers_notes/GLM_scheme.jpg b/docs/developers_notes/GLM_scheme.jpg
deleted file mode 100644
index 712a98ba..00000000
Binary files a/docs/developers_notes/GLM_scheme.jpg and /dev/null differ
diff --git a/docs/developers_notes/README.md b/docs/developers_notes/README.md
index c0d90ac5..250ef834 100644
--- a/docs/developers_notes/README.md
+++ b/docs/developers_notes/README.md
@@ -1,7 +1,5 @@
 # Introduction
 
-!!! warning
-    This note is out-of-sync with the current API. Please, do not rely on this yet for contributing to NeMoS.
 
 Welcome to the Developer Notes of the NeMoS project. These notes aim to provide detailed technical information about the various modules, classes, and functions that make up this library, as well as guidelines on how to write code that integrates nicely with our package. They are intended to help current and future developers understand the design decisions, structure, and functioning of the library, and to provide guidance on how to modify, extend, and maintain the codebase.
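As referenced in the `feature_mask` description of the `PopulationGLM` section above, here is a minimal, hypothetical sketch; the mask convention follows the `(num_features, num_neurons)` shape stated there, while the constructor argument name may differ across versions:

```python
import numpy as np
import nemos as nmo

# Hypothetical mask for 3 features and 2 neurons: rows are features, columns
# are neurons; neuron 0 uses features 0 and 1, neuron 1 uses features 1 and 2.
feature_mask = np.array(
    [
        [1, 0],
        [1, 1],
        [0, 1],
    ]
)
model = nmo.glm.PopulationGLM(feature_mask=feature_mask)
```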
diff --git a/docs/developers_notes/classes_nemos.png b/docs/developers_notes/classes_nemos.png
new file mode 100644
index 00000000..187ab3b1
Binary files /dev/null and b/docs/developers_notes/classes_nemos.png differ
diff --git a/docs/developers_notes/classes_nemos.svg b/docs/developers_notes/classes_nemos.svg
new file mode 100644
index 00000000..fae37888
--- /dev/null
+++ b/docs/developers_notes/classes_nemos.svg
@@ -0,0 +1,739 @@
+<!-- Class diagram: Base → BaseRegressor → GLM; Observations → GammaObservations, PoissonObservations; Regularizer → GroupLasso, Lasso, Ridge, UnRegularized. Legend: "A is an attribute of B", "A is a subclass of B". -->
diff --git a/src/nemos/base_regressor.py b/src/nemos/base_regressor.py
index 5ebd3b7b..0e7581a7 100644
--- a/src/nemos/base_regressor.py
+++ b/src/nemos/base_regressor.py
@@ -428,7 +428,7 @@ def _check_params(
     data_type: Optional[jnp.dtype] = None,
 ) -> Tuple[DESIGN_INPUT_TYPE, jnp.ndarray]:
     """
-    Validate the dimensions and consistency of parameters and data.
+    Validate the dimensions and consistency of parameters.
 
     This function checks the consistency of shapes and dimensions for model parameters.
diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py
index e983287f..af5bf224 100644
--- a/src/nemos/regularizer.py
+++ b/src/nemos/regularizer.py
@@ -52,13 +52,14 @@ def __init__(
         super().__init__(**kwargs)
 
     @property
-    def allowed_solvers(self):
+    def allowed_solvers(self) -> Tuple[str, ...]:
         return self._allowed_solvers
 
     @property
-    def default_solver(self):
+    def default_solver(self) -> str:
         return self._default_solver
 
+    @abc.abstractmethod
     def penalized_loss(self, loss: Callable, regularizer_strength: float) -> Callable:
         """
         Abstract method to penalize loss functions.
@@ -78,6 +79,7 @@ def penalized_loss(self, loss: Callable, regularizer_strength: float) -> Callable:
         """
         pass
 
+    @abc.abstractmethod
     def get_proximal_operator(
         self,
     ) -> ProximalOperator:
diff --git a/tests/test_regularizer.py b/tests/test_regularizer.py
index 00049e8a..821eca1f 100644
--- a/tests/test_regularizer.py
+++ b/tests/test_regularizer.py
@@ -16,7 +16,7 @@
 @pytest.mark.parametrize(
     "reg_str, reg_type",
     [
-        ("UnRegularized", nmo.regularizer.Regularizer),
+        ("UnRegularized", nmo.regularizer.UnRegularized),
         ("Ridge", nmo.regularizer.Ridge),
         ("Lasso", nmo.regularizer.Lasso),
         ("GroupLasso", nmo.regularizer.GroupLasso),
@@ -1281,5 +1281,6 @@ def test_solver_combination(self, solver_name, poissonGLM_model_instantiation):
 
 def test_available_regularizer_match():
     """Test matching of the two regularizer lists."""
-    assert set(nmo._regularizer_builder.AVAILABLE_REGULARIZERS) == set(nmo.regularizer.__dir__())
-
+    assert set(nmo._regularizer_builder.AVAILABLE_REGULARIZERS) == set(
+        nmo.regularizer.__dir__()
+    )
diff --git a/tox.ini b/tox.ini
index 6b9cb384..7295d866 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
 [tox]
 isolated_build = True
-envlist = py38, py39, py310
+envlist = py39, py310, py311
 
 [testenv]
 # means we'll run the equivalent of `pip install .[dev]`, also installing pytest
@@ -23,9 +23,9 @@ commands =
 
 [gh-actions]
 python =
-    3.8: py38
     3.9: py39
     3.10: py310
+    3.11: py311
 
 [flake8]
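Since the `regularizer.py` change above makes `penalized_loss` and `get_proximal_operator` abstract, any concrete regularizer must now implement both. A minimal, illustrative sketch (the class name, the specific L2 penalty, and the proximal-operator signature are assumptions for demonstration, not part of the library):

```python
from typing import Callable, Tuple

from nemos.regularizer import Regularizer


class RidgeSketch(Regularizer):
    """Illustrative ridge-like regularizer implementing both abstract methods."""

    _allowed_solvers: Tuple[str, ...] = ("GradientDescent", "ProximalGradient")
    _default_solver: str = "GradientDescent"

    def penalized_loss(self, loss: Callable, regularizer_strength: float) -> Callable:
        # Wrap the unpenalized loss, adding an L2 penalty on the coefficients.
        def _penalized_loss(params, X, y):
            coef = params[0]
            return loss(params, X, y) + 0.5 * regularizer_strength * (coef**2).sum()

        return _penalized_loss

    def get_proximal_operator(self):
        # Proximal operator of the L2 penalty: shrink the coefficients,
        # leave the intercept untouched.
        def prox_op(params, regularizer_strength, scaling=1.0):
            coef, intercept = params
            shrink = 1.0 / (1.0 + scaling * regularizer_strength)
            return coef * shrink, intercept

        return prox_op
```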