This is the official implementation of "Energy-based Hopfield Boosting for Out-of-Distribution Detection". The paper is available here.
Hopfield Boosting works best with Anaconda (download here). To install Hopfield Boosting and all dependencies, run the following commands:
conda env create -f environment.yml
conda activate hopfield-boosting
pip install -e .
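To check that the installation worked before launching a long training run, a quick import test can help. The snippet below is a minimal sketch and not part of the repository. It assumes that the editable install exposes a top-level package named hopfield_boosting, which the python -m hopfield_boosting commands in this README suggest, and that PyTorch is among the installed dependencies.

```python
# Post-install sanity check (a sketch, not shipped with the repository).
# Assumption: the package is importable as "hopfield_boosting" and PyTorch
# was installed by environment.yml.
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(["hopfield_boosting", "torch"])
    if missing:
        print("Not importable:", ", ".join(missing))
    else:
        print("Installation looks OK")
```

Run it inside the activated hopfield-boosting environment; an empty result means both packages were found.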
Hopfield Boosting supports logging with Weights and Biases (W&B). By default, W&B logs all metrics in anonymous mode. Note that runs logged in anonymous mode are deleted after 7 days. To keep the logs, create a W&B account and log in to it from the command line (e.g. with wandb login).
To run the experiments, you need the following data sets. We follow the established benchmark also used by, e.g., Liu et al. (2020) and Ming et al. (2022).
- CIFAR: Automatically downloaded by PyTorch
- ImageNet-RC: We use ImageNet64x64, which can be downloaded from the ImageNet website.
- MNIST: Automatically downloaded by PyTorch
- FashionMNIST: Automatically downloaded by PyTorch
The OOD test data consists of the following vision data sets:
- SVHN: Street View House Numbers
- Places 365: Scene recognition data set
- LSUN-Resize: A resized version of the Large-scale Scene UNderstanding Challenge
- LSUN-Crop: A cropped version of the Large-scale Scene UNderstanding Challenge
- iSUN: Contains a large number of different scenes
- Textures: A collection of textural images in the wild
We have included a Python script that downloads all OOD test data sets. To execute it, run
python -m hopfield_boosting.download_data
The downloaded data sets will be placed in downloaded_datasets/ within the current working directory.
Set the paths to the data sets: copy the .env.examples file located in the root directory of the repository and name the copy .env. Customize the new file so that it contains the paths to the data sets on your machine. You can also set a project_root, which is where Hopfield Boosting will store your model checkpoints.
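A mistyped path in .env usually only surfaces once training starts. The sketch below reads simple KEY=VALUE lines and reports entries whose paths do not exist. It is an illustrative helper, not how the repository itself loads .env, and the key names used in the usage example (DATA_ROOT) are hypothetical; use the keys from .env.examples instead.

```python
# Sanity check for a .env file (hypothetical helper, not part of the repository).
# Assumes plain KEY=VALUE lines; quoted values and '#' comments are tolerated.
from pathlib import Path


def read_env_file(path=".env"):
    """Parse KEY=VALUE lines into a dict, skipping blanks and comments."""
    values = {}
    for raw in Path(path).read_text().splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")  # split on the first '=' only
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


def missing_paths(values):
    """Return the entries whose value is not an existing file or directory."""
    return {k: v for k, v in values.items() if not Path(v).exists()}
```

For example, missing_paths(read_env_file(".env")) lists every entry pointing nowhere. The project_root may legitimately appear in that list if the checkpoint directory has not been created yet.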
To run Hopfield Boosting on CIFAR-10, run the command
python -m hopfield_boosting -cn resnet-18-cifar-10-aux-from-scratch
For CIFAR-100, use the command
python -m hopfield_boosting -cn resnet-18-cifar-100-aux-from-scratch
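To train on both CIFAR-10 and CIFAR-100 in one go, the two commands above can be chained from a small Python driver. This is a convenience sketch, not a script from the repository. The -cn flag appears to be Hydra's short form of --config-name; the driver simply passes it through unchanged.

```python
# Run both CIFAR configurations sequentially (a convenience sketch).
# Uses sys.executable so the interpreter of the active conda env is reused.
import subprocess
import sys

CONFIGS = [
    "resnet-18-cifar-10-aux-from-scratch",
    "resnet-18-cifar-100-aux-from-scratch",
]


def build_command(config_name):
    """Mirror `python -m hopfield_boosting -cn <config_name>`."""
    return [sys.executable, "-m", "hopfield_boosting", "-cn", config_name]


def run_all(configs=CONFIGS):
    """Train each configuration in turn; stop on the first failing run."""
    for name in configs:
        subprocess.run(build_command(name), check=True)
```

Calling run_all() trains both models one after the other. check=True makes a failed run raise CalledProcessError instead of silently continuing with the next configuration.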
The performance on the OOD validation data sets will be logged to W&B; the performance on the OOD test sets will be logged to a file in test_logs named after the W&B run.id.
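Once a run finishes, its test results can be located by its W&B run id. The sketch below assumes only that the log file name contains the run id; the exact naming pattern and file extension are not specified here, so the function matches on a substring. It is a hypothetical helper, not part of the repository.

```python
# Locate test-set logs for a W&B run (hypothetical helper).
# Assumption: files in test_logs/ contain the run id somewhere in their name.
from pathlib import Path


def find_test_logs(run_id, log_dir="test_logs"):
    """Return sorted paths of files in log_dir whose name contains run_id."""
    log_dir = Path(log_dir)
    if not log_dir.is_dir():
        return []
    return sorted(p for p in log_dir.iterdir() if p.is_file() and run_id in p.name)
```

The run id is shown in the W&B web interface and in the run's URL.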
We provide a demo notebook here that demonstrates the ability of Hopfield Boosting to detect OOD inputs. A model pre-trained on CIFAR-10 for running the notebook is available for download here.
To run the notebook, first set the paths to the data sets and the model in hopfield_boosting_notebook_config.yaml. The notebook uses additional data sets; you can find the links to download them in the notebook itself.
If you found this repository helpful, consider giving it a ⭐ and citing our paper:
@article{hofmann2024energybased,
title={Energy-based Hopfield Boosting for Out-of-Distribution Detection},
author={Claus Hofmann and Simon Schmid and Bernhard Lehner and Daniel Klotz and Sepp Hochreiter},
year={2024},
journal={arXiv preprint arXiv:2405.08766}
}