This repository is the official implementation of "Towards Scaling Difference Target Propagation with Backprop Targets", accepted to ICML 2022 for a short presentation. The code runs on Python > 3.7 with PyTorch >= 1.7.0.
To install the package in editable mode, run:
pip install -e .
(Optional) We suggest using a conda environment. The specifications of our environment are stored in conda_env_specs.txt.
The names used in the paper correspond to the following names in the codebase:

Name in paper | Name in codebase |
---|---|
L-DRL | DTP |
Backpropagation | BaselineModel |
DRL | meulemans_dtp (Based on the original authors' repo) |
Target Propagation | TargetProp |
Difference Target Propagation | VanillaDTP |
"Parallel" L-DRL (not in the paper) | ParallelDTP |
The main logic of our method is in target_prop/models/dtp.py.
An initial PyTorch implementation of our DTP model can be found under target_prop/legacy; this model was later re-implemented using PyTorch Lightning.
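For readers new to difference target propagation, the sketch below shows the classic DTP target computation in NumPy. It is not the code in target_prop/models/dtp.py and does not include the feedback-weight training (DRL / L-DRL) that the paper is about. Everything in it is an assumption for illustration: tanh forward layers, feedback mappings of the form g(v) = tanh(G v), an MSE output loss, beta as the step size of the output target, and the helper names forward, feedback_fn, dtp_targets and layer_losses. The output target is t_L = h_L - beta * dL/dh_L, and each hidden target is t_i = h_i + g(t_{i+1}) - g(h_{i+1}).

```python
import numpy as np


def forward(weights, x):
    """Return the activations [h_0, h_1, ..., h_L], with h_0 = x."""
    activations = [x]
    for W in weights:
        # Assumption: every forward layer is tanh(W @ h).
        activations.append(np.tanh(W @ activations[-1]))
    return activations


def feedback_fn(G, v):
    """Hypothetical feedback mapping g(v) = tanh(G @ v) from layer i+1 back to layer i."""
    return np.tanh(G @ v)


def dtp_targets(weights, feedback, x, y, beta):
    """Compute difference-target-propagation targets for every non-input layer.

    weights[i] maps layer i to layer i+1; feedback[i] maps layer i+1 back to layer i.
    feedback[0] is never used, because the input layer gets no target.
    """
    h = forward(weights, x)
    targets = [None] * len(h)
    # Output target: one gradient step of size beta on the loss 0.5 * ||h_L - y||^2,
    # whose gradient with respect to h_L is (h_L - y).
    targets[-1] = h[-1] - beta * (h[-1] - y)
    # Propagate targets downwards with the difference correction
    # t_i = h_i + g(t_{i+1}) - g(h_{i+1}).
    for i in range(len(weights) - 1, 0, -1):
        G = feedback[i]
        targets[i] = h[i] + feedback_fn(G, targets[i + 1]) - feedback_fn(G, h[i + 1])
    return h, targets


def layer_losses(h, targets):
    """Local loss 0.5 * ||h_i - t_i||^2 for every non-input layer i."""
    return [0.5 * float(np.sum((h[i] - targets[i]) ** 2)) for i in range(1, len(h))]
```

With beta = 0 the output target equals h_L, and because of the difference correction every hidden target then equals its own activation even when g is not an exact inverse of the forward layer. In a typical DTP setup each layer's forward weights are updated to reduce its local loss, with the targets treated as constants.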
Here is how the codebase is roughly structured:
├── main.py # training script (legacy)
├── meulemans_dtp # Codebase for DRL (Meulemans repo)
├── numerical_experiments # Initial scripts for creating the figures (used for fig. 4.2)
└── target_prop
    ├── datasets # Datasets
    ├── legacy # Initial implementation
    ├── models # Code for all the models except DRL
    └── networks # Networks (SimpleVGG, LeNet, ResNet)
- Recreating figure 4.2:
$ python -m numerical_experiments figure_4_2
The path of the saved figure is then printed to the console.
- Recreating figure 4.3:
$ pytest -vv target_prop/theorem_test.py
$ python target_prop/legacy/plot.py
To see the list of available command-line options, use the --help flag:
python main.py --help
To run the PyTorch Lightning re-implementation of DTP on CIFAR-10, use the following command:
python main.py model=dtp dataset=cifar10
To run the modified version of the DTP model above, which trains the feedback weights in parallel (ParallelDTP), on CIFAR-10, use the following command:
python main.py model=parallel_dtp dataset=cifar10
To run the backprop baseline, use:
python main.py model=backprop dataset=cifar10
To train DTP on the downsampled ImageNet 32x32 dataset, use:
python main.py model=dtp dataset=imagenet32
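To launch several of these runs one after another, a small launcher script such as the one below may help. It is a sketch: build_command, run_all and RUNS are hypothetical helpers that do not ship with the repository; only the model=... and dataset=... values are taken from the commands above. It defaults to a dry run that only prints the commands.

```python
import subprocess
from typing import List, Sequence, Tuple

# Model/dataset pairs taken from the commands above.
RUNS: List[Tuple[str, str]] = [
    ("dtp", "cifar10"),
    ("parallel_dtp", "cifar10"),
    ("backprop", "cifar10"),
    ("dtp", "imagenet32"),
]


def build_command(model: str, dataset: str) -> List[str]:
    """Build the argument list for one main.py run (hypothetical helper)."""
    return ["python", "main.py", f"model={model}", f"dataset={dataset}"]


def run_all(runs: Sequence[Tuple[str, str]], dry_run: bool = True) -> List[str]:
    """Print each command; execute them sequentially only when dry_run is False."""
    printed = []
    for model, dataset in runs:
        cmd = build_command(model, dataset)
        printed.append(" ".join(cmd))
        print(printed[-1])
        if not dry_run:
            # check=True stops the sweep as soon as one run exits with an error.
            subprocess.run(cmd, check=True)
    return printed


if __name__ == "__main__":
    run_all(RUNS)
```

Passing the arguments as a list (rather than a single shell string) avoids shell quoting issues; call run_all(RUNS, dry_run=False) to actually start the runs.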
To train the legacy implementation on CIFAR-10, run the following command in a terminal:
python main_legacy.py --batch-size 128 \
--C 128 128 256 256 512 \
--iter 20 30 35 55 20 \
--epochs 90 \
--lr_b 1e-4 3.5e-4 8e-3 8e-3 0.18 \
--noise 0.4 0.4 0.2 0.2 0.08 \
--lr_f 0.08 \
--beta 0.7 \
--path CIFAR-10 \
--scheduler --wdecay 1e-4
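The list-valued options in this command (--C, --iter, --lr_b, --noise) each take five values, which we read as one value per layer; that reading is an assumption, not something documented here. If you edit these lists, a hypothetical helper like the one below can rebuild the command and catch lists of mismatched length before launching. build_legacy_command, PER_LAYER, SCALARS and FLAGS are illustrative names, not part of the repository.

```python
from typing import Dict, List, Sequence, Union

Value = Union[int, float, str]

# Assumption: each list holds one value per layer (five layers here).
PER_LAYER: Dict[str, List[Value]] = {
    "--C": [128, 128, 256, 256, 512],
    "--iter": [20, 30, 35, 55, 20],
    "--lr_b": [1e-4, 3.5e-4, 8e-3, 8e-3, 0.18],
    "--noise": [0.4, 0.4, 0.2, 0.2, 0.08],
}
# Single-valued options and boolean flags, copied from the command above.
SCALARS: Dict[str, Value] = {
    "--batch-size": 128,
    "--epochs": 90,
    "--lr_f": 0.08,
    "--beta": 0.7,
    "--path": "CIFAR-10",
    "--wdecay": 1e-4,
}
FLAGS: List[str] = ["--scheduler"]


def build_legacy_command(
    per_layer: Dict[str, Sequence[Value]],
    scalars: Dict[str, Value],
    flags: Sequence[str],
) -> List[str]:
    """Return the main_legacy.py argument list, checking per-layer list lengths."""
    lengths = {name: len(values) for name, values in per_layer.items()}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"per-layer options disagree on the number of layers: {lengths}")
    cmd = ["python", "main_legacy.py"]
    for name, value in scalars.items():
        cmd += [name, str(value)]
    for name, values in per_layer.items():
        cmd += [name, *(str(v) for v in values)]
    return cmd + list(flags)


if __name__ == "__main__":
    print(" ".join(build_legacy_command(PER_LAYER, SCALARS, FLAGS)))
```

The helper emits the options in a different order than the command above and writes 1e-4 as 0.0001; both should be harmless if main_legacy.py parses its arguments with a standard argparse parser, which is also an assumption.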