This repository contains code for the NeurIPS Spotlight paper "Benchmarking Uncertainty Disentanglement: Specialized Uncertainties for Specialized Tasks" and also serves as a standalone benchmark suite for future methods.
The `untangle` repository is a comprehensive uncertainty quantification and uncertainty disentanglement benchmark suite that comes with

- implementations of 19 uncertainty quantification methods as convenient wrapper classes (`untangle.wrappers`) and their corresponding loss functions (`untangle.losses`); an illustrative sketch of the wrapper pattern follows this list
- an efficient training script that supports these methods (`train.py`)
- an extensive evaluation suite for uncertainty quantification methods (`validate.py`)
- support for CIFAR-10 ResNet models, including pre-activation and Fixup variants of Wide ResNets (`untangle.models`)
- ImageNet-C and CIFAR-10C support (`untangle.transforms`)
- ImageNet-ReaL and CIFAR-10H support (`untangle.datasets`)
- out-of-the-box support for PyTorch Image Models (`timm`) models and configs
- plotting utilities to recreate the plots of the paper
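For orientation, here is a minimal, self-contained sketch of the wrapper pattern: an MC-Dropout-style head on top of a `timm` backbone that returns a Bayesian model average together with information-theoretical aleatoric and epistemic estimates. The class name, its dropout head, and its outputs are illustrative assumptions, not the actual classes in `untangle.wrappers`.

```python
# Illustrative sketch only -- not untangle's API. It shows the general wrapper pattern:
# a timm backbone plus an uncertainty head that yields disentangled estimates.
import timm
import torch


class MCDropoutSketch(torch.nn.Module):
    """Toy MC-Dropout wrapper returning a BMA and aleatoric/epistemic estimates."""

    def __init__(self, backbone, num_classes, num_samples=10, drop_rate=0.2):
        super().__init__()
        self.backbone = backbone  # timm feature extractor (num_classes=0 -> pooled features)
        self.dropout = torch.nn.Dropout(drop_rate)
        self.head = torch.nn.Linear(backbone.num_features, num_classes)
        self.num_samples = num_samples

    @torch.no_grad()
    def forward(self, x):
        features = self.backbone(x)
        self.dropout.train()  # keep dropout stochastic even at inference time
        probs = torch.stack(
            [self.head(self.dropout(features)).softmax(dim=-1) for _ in range(self.num_samples)]
        )  # [num_samples, batch, classes]
        bma = probs.mean(dim=0)  # Bayesian model average

        # Information-theoretical decomposition: total = aleatoric + epistemic.
        total = -(bma * bma.clamp_min(1e-12).log()).sum(dim=-1)
        aleatoric = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1).mean(dim=0)
        return {"bma": bma, "aleatoric": aleatoric, "epistemic": total - aleatoric}


backbone = timm.create_model("resnet18", num_classes=0)  # any timm backbone works here
outputs = MCDropoutSketch(backbone, num_classes=10)(torch.randn(4, 3, 32, 32))
print({name: tuple(value.shape) for name, value in outputs.items()})
```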
If you found the paper or the code useful in your research, please cite our work as
@article{mucsanyi2024benchmarking,
  title={Benchmarking Uncertainty Disentanglement: Specialized Uncertainties for Specialized Tasks},
  author={Mucs{\'a}nyi, B{\'a}lint and Kirchhof, Michael and Oh, Seong Joon},
  journal={arXiv preprint arXiv:2402.19460},
  year={2024}
}
If you use the benchmark, please also cite the datasets it uses.
The package supports Python 3.11 and 3.12.
All package requirements are listed in the `requirements.txt` file.
A local copy of the ImageNet-1k dataset is needed to run the ImageNet experiments.
CIFAR-10 is available in `torchvision.datasets` and is downloaded automatically.
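For reference, the automatic download can be triggered with a few lines of `torchvision`; the `data` root directory below is only an illustration, not necessarily the path used by the training script.

```python
from torchvision.datasets import CIFAR10

# The first call downloads CIFAR-10 into ./data; later calls reuse the local copy.
train_set = CIFAR10(root="data", train=True, download=True)
test_set = CIFAR10(root="data", train=False, download=True)
print(len(train_set), len(test_set))  # 50000 10000
```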
The ImageNet-ReaL labels are available in this GitHub repository. The needed files are `raters.npz` and `real.json`.
The CIFAR-10H test dataset can be downloaded from this link.
The ImageNet-C and CIFAR-10C perturbations use Wand, a Python binding of ImageMagick. Follow these instructions to install ImageMagick. Wand is installed below.
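Once ImageMagick is installed (Wand itself comes with the Python requirements below), a short sanity check like the following, which is only a suggestion and not part of the repository, confirms that Wand can locate the ImageMagick libraries:

```python
# Fails with an ImportError if the ImageMagick shared libraries cannot be found.
from wand.color import Color
from wand.image import Image
from wand.version import MAGICK_VERSION

print(MAGICK_VERSION)
with Image(width=8, height=8, background=Color("white")) as img:
    img.gaussian_blur(radius=0, sigma=1.0)  # blur primitives like this drive the corruptions
```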
Create a virtual environment for `untangle` by running `python -m venv .venv` (or `uv venv`) in the root folder. Activate the virtual environment with `source .venv/bin/activate` and run one of the following commands based on your use case:
- Work with the existing code: `python -m pip install .` (or `uv pip install .`)
- Extend the code base: `python -m pip install -e '.[dev]'` (or `uv pip install -e '.[dev]'`)
The test metrics and hyperparameter sweeps used for all methods on both ImageNet and CIFAR-10 (including the chosen hyperparameter ranges and logs) are available on Weights & Biases. Below, we provide direct links to per-method hyperparameter sweeps and final runs used in the paper for ImageNet, CIFAR-10, CIFAR-10 50%, and CIFAR-10 10%.
ImageNet:

- CE Baseline: Hyperparameter Sweep, Final Results
- Correctness Prediction: Hyperparameter Sweep, Final Results
- HetClassNN: Hyperparameter Sweep, Final Results
- HET-XL: Hyperparameter Sweep, Final Results
- HET: Hyperparameter Sweep, Final Results
- Loss Prediction: Hyperparameter Sweep, Final Results
- Shallow Ensemble: Hyperparameter Sweep, Final Results
- DDU: Hyperparameter Sweep, Final Results
- MC Dropout: Hyperparameter Sweep, Final Results
- GP: Hyperparameter Sweep, Final Results
- SNGP: Hyperparameter Sweep, Final Results
- PostNet: Hyperparameter Sweep, Final Results
- EDL: Hyperparameter Sweep, Final Results
- Deep Ensemble: Final Results
- Laplace: Final Results
- Mahalanobis: Final Results
- Temperature Scaling: Final Results
- SWAG: Final Results
- The repository contains an improved version of Automatic Mixed Precision inference, for which results are available here.
CIFAR-10:

- CE Baseline: Hyperparameter Sweep, Final Results
- Correctness Prediction: Hyperparameter Sweep, Final Results
- HetClassNN: Hyperparameter Sweep, Final Results
- HET-XL: Hyperparameter Sweep, Final Results
- HET: Hyperparameter Sweep, Final Results
- Loss Prediction: Hyperparameter Sweep, Final Results
- Shallow Ensemble: Hyperparameter Sweep, Final Results
- DDU: Hyperparameter Sweep, Final Results
- MC Dropout: Hyperparameter Sweep, Final Results
- GP: Hyperparameter Sweep, Final Results
- SNGP: Hyperparameter Sweep, Final Results
- PostNet: Hyperparameter Sweep, Final Results
- EDL: Hyperparameter Sweep, Final Results
- Deep Ensemble: Final Results
- Laplace: Final Results
- Mahalanobis: Final Results
- Temperature Scaling: Final Results
- SWAG: Final Results
CIFAR-10 50%:

- CE Baseline: Final Results
- Correctness Prediction: Final Results
- HetClassNN: Final Results
- HET-XL: Final Results
- HET: Final Results
- Loss Prediction: Final Results
- Shallow Ensemble: Final Results
- DDU: Final Results
- MC Dropout: Final Results
- GP: Final Results
- SNGP: Final Results
- PostNet: Final Results
- EDL: Final Results
- Deep Ensemble: Final Results
- Laplace: Final Results
- Mahalanobis: Final Results
- Temperature Scaling: Final Results
- SWAG: Final Results
CIFAR-10 10%:

- CE Baseline: Final Results
- Correctness Prediction: Final Results
- HetClassNN: Final Results
- HET-XL: Final Results
- HET: Final Results
- Loss Prediction: Final Results
- Shallow Ensemble: Final Results
- DDU: Final Results
- MC Dropout: Final Results
- GP: Final Results
- SNGP: Final Results
- PostNet: Final Results
- EDL: Final Results
- Deep Ensemble: Final Results
- Laplace: Final Results
- Mahalanobis: Final Results
- Temperature Scaling: Final Results
- SWAG: Final Results
We also provide access to the exact Singularity container we used in our experiments. The `Singularity` file was used to create this container by running `singularity build --fakeroot untangle.simg Singularity`.
To recreate the plots used in the paper, please refer to the Plotting Utilities section below.
The results of our experiments are available in full on Weights & Biases and in our paper. The Weights & Biases results can either be queried online or through our plotting utilities.
To access our results from the Weights & Biases console, click any of the "Final Results" links above and search for any metric from the Metrics section.
As our paper contains more than 50 plots, we provide general plotting utilities that can visualize results for any metric instead of scripts that reproduce one particular plot. These utilities are found in the `plots` folder.
- `plot_ranking.py`: Our main plotting script that generates bar plots of method rankings on a particular metric. Needs a dataset (`imagenet` or `cifar10`), the label of the y-axis, and the metric (see Metrics). The script has several optional arguments for fine-grained control over the plot, as documented in the source code.
- `plot_ranking_it.py`: Variant of `plot_ranking.py` that uses the information-theoretical decomposition's estimators for all methods. Can only be used with the `auroc_oodness` and `rank_correlation_bregman_au` metrics.
- `plot_ranking_shallow.py`: Variant of `plot_ranking.py` that uses the information-theoretical decomposition's estimators for the Shallow Ensemble method and the best-performing estimator for the other methods. The Shallow Ensemble gives the most decorrelated aleatoric and epistemic estimates with the information-theoretical decomposition, and the plots check how practically relevant the resulting estimates are. Can only be used with the `auroc_oodness` and `rank_correlation_bregman_au` metrics.
- `plot_full_correlation_matrix.py`: Calculates the correlation of method rankings across different metrics. Needs only a dataset (`imagenet` or `cifar10`).
- `plot_correlation_matrix.py`: Variant of `plot_full_correlation_matrix.py` that calculates the correlations between a smaller set of metrics.
- `plot_estimator_correlation_matrix.py`: Variant of `plot_full_correlation_matrix.py` that only calculates the correlations w.r.t. one estimator. This estimator can be either `one_minus_max_probs_of_dual_bma`, `one_minus_max_probs_of_bma`, or `one_minus_expected_max_probs`.
- `plot_correlation_datasets.py`: Prints correlation statistics of rankings on different metrics across different datasets (CIFAR-10 and ImageNet).
- `plot_correctness_robustness.py`: Plots the performance of estimators and methods on in-distribution and increasingly out-of-distribution datasets w.r.t. the accuracy, correctness AUROC, and AUAC metrics. Requires only a dataset (`imagenet` or `cifar10`).
- `plot_calibration_robustness.py`: Plots the robustness of the Laplace, Shallow Ensemble, EDL, and CE Baseline methods on the ECE metrics when going from in-distribution data to out-of-distribution data of severity level one. Requires only a dataset (`imagenet` or `cifar10`).
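As an illustration of how these scripts are meant to be driven, the following shows one possible way to call `plot_ranking.py` from Python. The argument names are placeholders inferred from the description above, not the script's actual interface; please consult the source of `plot_ranking.py` for the exact arguments.

```python
# Hypothetical invocation -- the real argument names are documented in plots/plot_ranking.py.
import subprocess

subprocess.run(
    [
        "python",
        "plots/plot_ranking.py",
        "--dataset", "imagenet",                            # imagenet or cifar10
        "--metric", "auroc_hard_bma_correctness_original",  # see the Metrics section
        "--y-label", "Correctness AUROC",                   # label of the y-axis
    ],
    check=True,
)
```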
In this section, we provide a (non-exhaustive) list of descriptions of several metrics we consider in our benchmarks. Their identifiers can be used either to search for results on the online Weights & Biases console or in our plotting utilities.
- `auroc_hard_bma_correctness_original`: AUROC of uncertainty estimators w.r.t. the Bayesian Model Average's correctness and the original hard (i.e., one-hot) labels.
- `auroc_soft_bma_correctness`: AUROC of uncertainty estimators w.r.t. the Bayesian Model Average's binary correctness indicators and the soft labels (either ImageNet-ReaL or CIFAR-10H).
- `auroc_oodness`: AUROC of uncertainty estimators w.r.t. the binary OOD indicators on a balanced mixture of ID and OOD data.
- `hard_bma_accuracy_original`: Accuracy of the Bayesian Model Average w.r.t. the original hard (i.e., one-hot) labels.
- `cumulative_hard_bma_abstinence_auc`: AUAC value of uncertainty estimators w.r.t. the Bayesian Model Average.
- `log_prob_score_hard_bma_correctness_original`: The log probability proper scoring rule of uncertainty estimators w.r.t. the Bayesian Model Average's correctness on the original hard (i.e., one-hot) labels.
- `brier_score_hard_bma_correctness`: The Brier score of uncertainty estimators w.r.t. the Bayesian Model Average's correctness on the original hard (i.e., one-hot) labels.
- `log_prob_score_hard_bma_aleatoric_original`: The log probability proper scoring rule of the Bayesian Model Average's predicted probability vector w.r.t. the ground-truth original hard (i.e., one-hot) labels, a.k.a. the log-likelihood of the labels under the model.
- `brier_score_hard_bma_aleatoric_original`: The Brier score of the Bayesian Model Average's predicted probability vector w.r.t. the ground-truth original hard (i.e., one-hot) labels, a.k.a. the negative L2 loss of the model's predictions.
- `rank_correlation_bregman_au`: The rank correlation of uncertainty estimators with the ground-truth aleatoric uncertainty values from the Bregman decomposition.
- `rank_correlation_bregman_b_dual_bma`: The rank correlation of uncertainty estimators with the bias values from the Bregman decomposition (which uses the Dual Bayesian Model Average instead of the primal one).
- `rank_correlation_it_au_eu`: The rank correlation of the information-theoretical aleatoric and epistemic estimates.
- `rank_correlation_bregman_eu_au_hat`: The rank correlation of the Bregman decomposition's epistemic estimates with the aleatoric estimates predicted by the model.
- `rank_correlation_bregman_au_b_dual_bma`: The rank correlation of the Bregman decomposition's ground-truth aleatoric and bias values.
- `ece_hard_bma_correctness_original`: ECE of uncertainty estimators w.r.t. the Bayesian Model Average's correctness and the original hard (i.e., one-hot) labels.
For more metrics, please refer to `validate.py`.
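To make the naming above concrete, here is a small sketch (not the repository's implementation) of the idea behind the correctness-AUROC metrics: measure how well a confidence score, here the max probability of a simulated Bayesian Model Average, separates correct from incorrect predictions.

```python
# Minimal sketch of a correctness-AUROC style metric -- not the code in validate.py.
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(10), size=1000)  # simulated BMA class probabilities
labels = rng.integers(0, 10, size=1000)        # hard (i.e., one-hot) ground-truth labels

correct = (probs.argmax(axis=1) == labels).astype(int)  # binary correctness indicators
confidence = probs.max(axis=1)                          # estimator: higher = more certain

# AUROC of the confidence estimator w.r.t. correctness; 0.5 is chance level.
print(roc_auc_score(correct, confidence))
```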
Contributions are very welcome. Before contributing, please make sure to run `pre-commit install`. Feel free to open a pull request with new methods or fixes.