CLAPP code

This is the code for the publication:

B. Illing, J. Ventura, G. Bellec & W. Gerstner: Local plasticity rules can learn deep representations using self-supervised contrastive predictions, accepted at NeurIPS 2021.

Contact: bernd.illing@epfl.ch

Implementation of CLAPP in PyTorch

We implement CLAPP (and its variants) using the auto-differentiation provided by PyTorch. That means we do not implement the learning rule, Equations (6) - (8), explicitly. Instead, we apply the CLAPP loss, Equation (3), at every layer and block gradients (PyTorch .detach()), such that the automatically calculated gradients (.backward()) match the CLAPP learning rules. We summarize the procedure for a single layer in Python/PyTorch pseudocode:

""" 
require:
layer (encoder layer to train)
clapp_hinge_loss (CLAPP hinge loss as in Equation (3); contains the prediction weights)
opt (optimiser, e.g. ADAM, containing all trainable parameters of this layer)
x_past (previous input)
x (current input)
"""

c = layer(x_past.detach()) # context: encoding of previous input
z = layer(x.detach()) # future activity

loss = clapp_hinge_loss(c, z)
loss.backward() # autodiff calculates gradients

opt.step() # update parameters of layer and prediction weights
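
For intuition, here is a minimal sketch of what such a clapp_hinge_loss module could look like, assuming a single linear prediction weight matrix and a ±1 label y marking positive vs. negative samples. The class name ClappHingeLoss and this simplified interface are illustrative only; the actual loss in the repository additionally handles negative sampling and spatial patches internally:

    import torch
    import torch.nn as nn

    class ClappHingeLoss(nn.Module):
        def __init__(self, dim):
            super().__init__()
            # prediction weights W^pred, trained together with the encoder layer
            self.W_pred = nn.Linear(dim, dim, bias=False)

        def forward(self, c, z, y):
            # score: agreement between the prediction W^pred c and the actual future activity z
            score = (z * self.W_pred(c)).sum(dim=1)
            # hinge loss as in Equation (3): y = +1 for positive, -1 for negative samples
            return torch.clamp(1.0 - y * score, min=0.0).mean()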

We verified numerically that the resulting updates are equivalent to evaluating the CLAPP learning rules, Equations (6) - (8). The code for this check can be found in ./vision/CLAPPVision/vision/compare_updates.py; see the Vision section for more details.

Note that for Hinge Loss CPC, the end-to-end version of CLAPP, we only apply a single CLAPP loss at the final layer. Furthermore, we omit the .detach() calls to allow gradients to flow through the whole network.
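
Concretely, the single-layer pattern above then reduces to something like the following sketch, where encoder (an assumed name, not the repository's API) stands for the full layer stack:

    c = encoder(x_past)  # no .detach(): gradients flow through all layers
    z = encoder(x)

    loss = clapp_hinge_loss(c, z)  # single CLAPP loss at the final layer
    loss.backward()
    opt.step()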

Variants of CLAPP mainly differ in the exact implementation of the CLAPP loss clapp_hinge_loss. E.g. for the synchronous version CLAPP-s, the CLAPP loss adds the contributions of the positive and the negative sample at every step, instead of sampling one of the two with 50/50 probability as in CLAPP, as sketched below.
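
As an illustration of that difference (the helper hinge and the variables z_pos and z_neg are hypothetical stand-ins for a loss evaluation on one positive and one negative sample, not names from the repository):

    # CLAPP: at each step, use either the positive or the negative sample (50/50)
    if torch.rand(1).item() < 0.5:
        loss = hinge(c, z_pos, y=+1)
    else:
        loss = hinge(c, z_neg, y=-1)

    # CLAPP-s (synchronous): add both contributions at every step
    loss = hinge(c, z_pos, y=+1) + hinge(c, z_neg, y=-1)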

Structure of the code

The code is divided into three independent sections, corresponding to the three domains we apply CLAPP to:

  • vision
  • video
  • audio

Each section comes with its own dependencies handled by conda environments, as explained in the respective sections below.

Vision

The implementation of the CLAPP vision experiments is based on Sindy Löwe's code for the Greedy InfoMax model.

New Setup (updated April 19th, 2024)

This setup is compatible with newer PyTorch versions.

  1. Create a conda environment with Python 3.9:

    conda create -n clappvision python=3.9

  2. Install PyTorch v2.0.1 and Torchvision v0.15.2 following the guide on the PyTorch website.

  3. Install the required packages with pip:

    cd vision
    pip3 install -r requirements.txt

  4. To activate and deactivate the created conda environment, run

    conda activate clappvision
    conda deactivate

Original Setup

This is the setup originally used to produce the numbers reported in the paper.

To set up the conda environment, simply run

    cd vision
    bash ./setup_dependencies.sh

To activate and deactivate the created conda environment, run

    conda activate clappvision
    conda deactivate

respectively.

Usage

We include three sample scripts to run CLAPP, CLAPP-s (synchronous positive and negative updates; the version with symmetric prediction and retrodiction weights) and Hinge Loss CPC (the end-to-end version of CLAPP). To run, e.g., the Hinge Loss CPC simulations (model training + evaluation), run:

    cd vision
    bash ./scripts/vision_traineval_HingeLossCPC.sh

The code includes many (experimental) versions of CLAPP as command-line options that are not used or mentioned in the paper. To view all command-line options for model training, run:

    cd vision
    python -m CLAPPVision.vision.main_vision --help

We also added code to run the above-mentioned numerical check that the updates obtained with auto-differentiation are equivalent to evaluating the CLAPP learning rules. To run this check, e.g. for a randomly initialised network at the first epoch of training, run:

    mkdir ./logs/CLAPP_init/
    python -m CLAPPVision.vision.compare_updates --download_dataset --save_dir CLAPP_init --encoder_type 'vgg_like' --model_splits 6 --train_module 6 --contrast_mode 'hinge' --num_epochs 1 --negative_samples 1 --sample_negs_locally --sample_negs_locally_same_everywhere --start_epoch 0 --model_path ./logs/CLAPP_init/ --save_vars_for_update_calc 3 --batch_size 4

The equivalence was found to also hold later during training. To check this, the respective simulations first need to be run (see the comments in ./vision/CLAPPVision/vision/compare_updates.py).

Checkpoint (updated April 19th, 2024)

A checkpoint of CLAPP-s trained on the STL-10 vision task is provided. This checkpoint was trained with the updated code and environment setup (git commit b31bd3d), which supports PyTorch 2.0.1. The classification accuracy using the representations of the 5th layer is 74.9%, slightly lower than the 75.0% reported in Table 1 of the paper.

The trained checkpoints of CLAPP-s are stored in vision/checkpoints/CLAPP_s. Trained models are saved as model_i_k.ckpt, where i is the layer number and k is the epoch number. Example code for using the 5th layer for downstream classification:

    cd vision
    python -m CLAPPVision.vision.downstream_classification --model_path ./checkpoints/CLAPP_s --model_num 299 --encoder_type 'vgg_like' --model_splits 6 --train_module 6 --module_num 5

Video

The implementation of the CLAPP video experiments was inspired by Tengda Han's code for Dense Predictive Coding.

Setup

The setup of the conda environment is described in ./video/env_setup.txt. To activate and deactivate the created conda environment pdm, run

    conda activate pdm
    conda deactivate

respectively.

Usage

The basic simulations described in the paper can be replicated using the commands listed in ./video/commands.txt.

Audio

The implementation of the CLAPP audio experiments is based on Sindy Löwe's code for the Greedy InfoMax model.

Setup

Usage

Cite

Please cite our paper if you use this code in your own work:

    @inproceedings{illing2021local,
      title = {Local plasticity rules can learn deep representations using self-supervised contrastive predictions},
      author = {Illing, Bernd and Ventura, Jean and Bellec, Guillaume and Gerstner, Wulfram},
      booktitle = {Advances in Neural Information Processing Systems},
      volume = {34},
      year = {2021}
    }
