Discrete Key-Value Bottleneck
Frederik Träuble, Anirudh Goyal, Nasim Rahaman, Michael Mozer, Kenji Kawaguchi, Yoshua Bengio, Bernhard Schölkopf. ICML 2023.
To reproduce the results of the paper, you need to first create a conda environment and install the package:
conda create -n kvb python=3.10.6
conda activate kvb
git clone git@github.com:ftraeuble/experiments_discrete_key_value_bottleneck.git
cd experiments_discrete_key_value_bottleneck
pip install .
This will install the discrete_key_value_bottleneck package as well as all required dependencies.
To reproduce the toy experiments from Fig. 2 in the paper, you can run the following notebook:
To reproduce the main CIFAR10 experiments, you first need to log in to wandb on your machine and set your wandb as well as the PROJECT_ROOT_DIR environment variable.
To run the experiments, it is advisable to precompute the relevant backbone embeddings for all required datasets. To reproduce the ConvMixer experiments, you will have to download the CIFAR10 and ImageNet32 embeddings for the ConvMixer backbone from the submission repository of the SDMLP paper by Bricken et al. (2023).
To precompute the embeddings, you can use the following two notebooks:
Finally, store all created embeddings and label files in a folder named backbone_embeddings.
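Because the backbone stays frozen, its embeddings are identical across runs, so computing them once and loading them from disk saves a full backbone forward pass in every experiment. Below is a minimal NumPy sketch of that caching step, not the code from the notebooks: precompute_embeddings, the toy random-projection encoder and the *_embeddings.npy / *_labels.npy file names are hypothetical and may not match what the training scripts expect.

```python
import os
import tempfile

import numpy as np


def precompute_embeddings(encoder, images, labels, out_dir, name, batch_size=256):
    """Run a frozen encoder once over a dataset and cache embeddings and labels.

    `encoder` is any callable mapping a batch of inputs to a 2D array
    (a hypothetical stand-in for a pretrained backbone). The file names
    written below are illustrative only.
    """
    os.makedirs(out_dir, exist_ok=True)
    chunks = []
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size]
        chunks.append(np.asarray(encoder(batch), dtype=np.float32))
    embeddings = np.concatenate(chunks, axis=0)
    emb_path = os.path.join(out_dir, f"{name}_embeddings.npy")
    lbl_path = os.path.join(out_dir, f"{name}_labels.npy")
    np.save(emb_path, embeddings)
    np.save(lbl_path, np.asarray(labels))
    return emb_path, lbl_path


# Toy "backbone": a fixed random linear map applied to flattened images.
rng = np.random.default_rng(0)
weights = rng.standard_normal((32 * 32 * 3, 64)).astype(np.float32)


def toy_encoder(batch):
    return batch.reshape(len(batch), -1) @ weights


images = rng.random((1000, 32, 32, 3), dtype=np.float32)
labels = rng.integers(0, 10, size=1000)
out_dir = os.path.join(tempfile.mkdtemp(), "backbone_embeddings")
emb_path, lbl_path = precompute_embeddings(toy_encoder, images, labels, out_dir, "cifar10_train")

cached = np.load(emb_path)
print(cached.shape)  # (1000, 64)
```

Training code can then load the cached arrays with np.load instead of re-running the backbone on every epoch.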
A list of all sweeps can be found in the directory sweeps/icml2023. To launch a sweep, run the following commands:
wandb sweep sweeps/icml2023/NAME_OF_SWEEP.yaml
wandb agent <SWEEP_ID>
In total, the sweeps comprise 300+ trained models, which can be used to reproduce all results of the paper.
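A sweep file describes a set of hyperparameter combinations: wandb sweep registers it and prints a sweep ID, and wandb agents then pull combinations from that sweep and launch the training program once for each. Assuming the sweep files rely on wandb's default command, which passes hyperparameters as --name=value arguments, a grid sweep expands into commands shaped like the single-model command below. The configuration and the expand_grid helper here are hypothetical illustrations, not one of the files in sweeps/icml2023.

```python
import itertools
import shlex

# Hypothetical grid sweep in the shape of a wandb sweep config
# (the real sweep files may use other parameters, values or methods).
sweep_config = {
    "program": "scripts/train.py",
    "method": "grid",
    "parameters": {
        "backbone": {"value": "resnet50_imagenet_v2"},
        "pretrain_data": {"value": "CIFAR100"},
        "num_books": {"values": [64, 256]},
        "num_pairs": {"value": 4096},
        "seed": {"values": [0, 1, 2]},
    },
}


def expand_grid(config):
    """Yield one shell command per point of a grid sweep.

    Arguments use the --name=value form; names are sorted here only to make
    the output deterministic.
    """
    params = config["parameters"]
    names = sorted(params)
    choices = [
        params[n]["values"] if "values" in params[n] else [params[n]["value"]]
        for n in names
    ]
    for combo in itertools.product(*choices):
        args = [f"--{n}={v}" for n, v in zip(names, combo)]
        yield shlex.join(["python", config["program"], *args])


commands = list(expand_grid(sweep_config))
print(len(commands))  # 6
print(commands[0])
```

Each additional value in a "values" list multiplies the number of runs, which is how a handful of sweep files can add up to hundreds of trained models.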
A single model can be trained by running the following command:
python scripts/train.py --backbone=resnet50_imagenet_v2 --dim_key=14 --dim_value=10 --init_epochs=10 --learning_rate=0.3 --num_books=256 --num_pairs=4096 --pretrain_data=CIFAR100 --seed=2
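As a rough guide to the bottleneck flags in this command, here is a minimal NumPy sketch of a discrete key-value lookup under our own reading of the method: a frozen-backbone embedding is projected into num_books heads by fixed random projections, each head selects its nearest key (of size dim_key) among the num_pairs keys of its codebook, and the values (of size dim_value) belonging to the selected keys are averaged into class logits. The flag-to-shape mapping, the averaging decoder and the use of random keys (instead of keys initialized from data, which we assume --init_epochs controls) are assumptions; the function names are hypothetical and this is not the repository's implementation.

```python
import numpy as np


def kv_bottleneck_forward(z, projections, keys, values):
    """Discrete key-value lookup followed by an averaging decoder (sketch).

    z:           (B, d_z)                      frozen-backbone embeddings
    projections: (num_books, d_z, dim_key)     fixed random projections
    keys:        (num_books, num_pairs, dim_key)   frozen keys
    values:      (num_books, num_pairs, dim_value) trainable values
    Returns logits (B, dim_value) and selected key indices (B, num_books).
    """
    heads = np.einsum("bd,cdk->bck", z, projections)  # (B, num_books, dim_key)
    # squared Euclidean distance from each head to every key of its codebook
    dists = ((heads[:, :, None, :] - keys[None, :, :, :]) ** 2).sum(axis=-1)
    idx = dists.argmin(axis=-1)  # nearest key per codebook
    codebook = np.arange(keys.shape[0])[None, :]
    retrieved = values[codebook, idx]  # (B, num_books, dim_value)
    return retrieved.mean(axis=1), idx


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def sgd_step_on_values(values, idx, logits, labels, lr):
    """One cross-entropy SGD step; only retrieved values receive gradient."""
    batch, num_books = idx.shape
    grad_logits = softmax(logits)
    grad_logits[np.arange(batch), labels] -= 1.0
    grad_logits /= batch  # mean loss over the batch
    grad_values = np.zeros_like(values)
    for c in range(num_books):
        # each codebook contributes 1/num_books to the averaged logits;
        # np.add.at accumulates correctly when several inputs pick the same key
        np.add.at(grad_values[c], idx[:, c], grad_logits / num_books)
    return values - lr * grad_values


rng = np.random.default_rng(0)
B, d_z = 4, 32
num_books, num_pairs, dim_key, dim_value = 8, 64, 14, 10  # small for speed
z = rng.standard_normal((B, d_z))
projections = rng.standard_normal((num_books, d_z, dim_key)) / np.sqrt(d_z)
keys = rng.standard_normal((num_books, num_pairs, dim_key))
values = np.zeros((num_books, num_pairs, dim_value))
labels = rng.integers(0, dim_value, size=B)

logits, idx = kv_bottleneck_forward(z, projections, keys, values)
new_values = sgd_step_on_values(values, idx, logits, labels, lr=0.3)

touched = np.zeros((num_books, num_pairs), dtype=bool)
touched[np.arange(num_books)[None, :], idx] = True
print(np.abs(new_values[~touched]).max())  # 0.0
```

The final print checks the property that makes this kind of bottleneck attractive for continual learning: an update only changes the values whose keys were selected by the current inputs, leaving all other key-value pairs untouched.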
If you find this codebase useful, please cite our paper:
title={Discrete Key-Value Bottleneck},
author={Tr{\"a}uble, Frederik and Goyal, Anirudh and Rahaman, Nasim and Mozer, Michael and Kawaguchi, Kenji and Bengio, Yoshua and Sch{\"o}lkopf, Bernhard},
journal={International Conference on Machine Learning},