
# Towards Scaling Difference Target Propagation with Backprop Targets

This repository is the official implementation of "Towards Scaling Difference Target Propagation with Backprop Targets", accepted to ICML 2022 for a short presentation. The code runs on Python > 3.7 with PyTorch >= 1.7.0.

## Installation

```bash
pip install -e .
```

(Optional) We suggest using a conda environment; the specs of ours are stored in `conda_env_specs.txt`.
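Assuming `conda_env_specs.txt` is an explicit conda spec list (the format produced by `conda list --explicit`), the environment can typically be recreated with a command like the following; the environment name here is arbitrary:

```bash
conda create --name scaling-dtp --file conda_env_specs.txt
conda activate scaling-dtp
```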

## Naming of methods

| Name in paper | Name in codebase |
| --- | --- |
| L-DRL | `DTP` |
| Backpropagation | `BaselineModel` |
| DRL | `meulemans_dtp` (based on the original authors' repo) |
| Target Propagation | `TargetProp` |
| Difference Target Propagation | `VanillaDTP` |
| "Parallel" L-DRL (not in the paper) | `ParallelDTP` |

## Codebase structure

The main logic of our method is in `target_prop/models/dtp.py`.

An initial PyTorch implementation of our DTP model can be found under `target_prop/legacy`. This model was then re-implemented using PyTorch-Lightning.
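For orientation, here is a minimal sketch of how a DTP-style model can be organized as a PyTorch-Lightning module, with the feedback weights trained by a separate optimizer from the forward weights. This is an illustrative assumption about the structure, not the repo's actual `DTP` class; the layer sizes and learning rates are placeholders:

```python
# Minimal sketch (illustrative, not the repo's dtp.py): a DTP-style
# LightningModule that trains feedback weights separately from forward
# weights, using manual optimization.
import torch
from torch import nn
import pytorch_lightning as pl


class DTPSketch(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # we step two optimizers ourselves
        self.forward_net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
        self.feedback_net = nn.Linear(10, 3 * 32 * 32)  # learns an approximate inverse

    def training_step(self, batch, batch_idx):
        x, y = batch
        f_opt, b_opt = self.optimizers()

        # Feedback phase: train the feedback net to invert the forward net.
        h = self.forward_net(x).detach()
        feedback_loss = nn.functional.mse_loss(self.feedback_net(h), x.flatten(1))
        b_opt.zero_grad()
        self.manual_backward(feedback_loss)
        b_opt.step()

        # Forward phase: in real DTP, per-layer targets are computed through
        # the feedback net; a plain task loss is used here for brevity.
        forward_loss = nn.functional.cross_entropy(self.forward_net(x), y)
        f_opt.zero_grad()
        self.manual_backward(forward_loss)
        f_opt.step()

    def configure_optimizers(self):
        f_opt = torch.optim.SGD(self.forward_net.parameters(), lr=0.08)
        b_opt = torch.optim.SGD(self.feedback_net.parameters(), lr=1e-4)
        return f_opt, b_opt
```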

Here is how the codebase is roughly structured:

```
├── main.py                # training script
├── main_legacy.py         # legacy training script
├── meulemans_dtp          # Codebase for DRL (Meulemans repo)
├── numerical_experiments  # Initial scripts for creating the figures (used for fig. 4.2)
└── target_prop
    ├── datasets  # Datasets
    ├── legacy    # Initial implementation
    ├── models    # Code for all the models except DRL
    └── networks  # Networks (SimpleVGG, LeNet, ResNet)
```

## Running the code

- Recreating figure 4.2:

  ```bash
  python -m numerical_experiments figure_4_2
  ```

  The figure's save location is then printed to the console.

- Recreating figure 4.3:

  ```bash
  pytest -vv target_prop/theorem_test.py
  python target_prop/legacy/plot.py
  ```

To see a list of available command-line options, use the `--help` flag:

```bash
python main.py --help
```

To run the PyTorch-Lightning re-implementation of DTP on CIFAR-10, use the following command:

```bash
python main.py model=dtp dataset=cifar10
```

To use the modified version of the above DTP model, with parallel feedback weight training, on CIFAR-10:

```bash
python main.py model=parallel_dtp dataset=cifar10
```

To run the backprop baseline:

```bash
python main.py model=backprop dataset=cifar10
```

## ImageNet

To train DTP on the downsampled ImageNet 32x32 dataset:

```bash
python main.py model=dtp dataset=imagenet32
```

## Legacy Implementation

To train the legacy implementation on CIFAR-10, run the following command in the terminal:

```bash
python main_legacy.py --batch-size 128 \
    --C 128 128 256 256 512 \
    --iter 20 30 35 55 20 \
    --epochs 90 \
    --lr_b 1e-4 3.5e-4 8e-3 8e-3 0.18 \
    --noise 0.4 0.4 0.2 0.2 0.08 \
    --lr_f 0.08 \
    --beta 0.7 \
    --path CIFAR-10 \
    --scheduler --wdecay 1e-4
```
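Note that `--C`, `--iter`, `--lr_b`, and `--noise` each take one value per layer (five values here, one per conv block). Below is a hypothetical sketch of how such per-layer options are typically declared with `argparse`; the option names are taken from the command above, but the types, defaults, and help strings are assumptions rather than the repo's actual `main_legacy.py`:

```python
# Hypothetical sketch: declaring per-layer options with argparse.
# Option names match the command above; help strings are assumptions.
import argparse

parser = argparse.ArgumentParser(description="Legacy DTP training (sketch)")
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--C", type=int, nargs="+",
                    help="number of channels for each conv layer")
parser.add_argument("--iter", type=int, nargs="+",
                    help="feedback-weight training iterations per layer")
parser.add_argument("--lr_b", type=float, nargs="+",
                    help="feedback learning rate per layer")
parser.add_argument("--noise", type=float, nargs="+",
                    help="noise scale per layer")
parser.add_argument("--lr_f", type=float, help="forward learning rate")
parser.add_argument("--beta", type=float, help="target step size")
parser.add_argument("--epochs", type=int)
parser.add_argument("--path", type=str, help="dataset name/path")
parser.add_argument("--scheduler", action="store_true")
parser.add_argument("--wdecay", type=float)

args = parser.parse_args()
# One hyperparameter per layer, e.g. args.C == [128, 128, 256, 256, 512]
assert len(args.C) == len(args.iter) == len(args.lr_b) == len(args.noise)
```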