B-cos Alignment for Inherently Interpretable CNNs and Vision Transformers
Moritz Böhle, Navdeeppal Singh, Mario Fritz, Bernt Schiele. TPAMI, 2024.
This repository contains the code for the B-cos v2 models.
These models are more efficient and easier to train than the original v1 B-cos models. Furthermore, we make a large number of pretrained B-cos models available for use.
If you want to take a quick look at the explanations the models generate, you can try out the Gradio web demo.
If you prefer a more hands-on approach, you can take a look at the demo notebook on Colab or load the models directly via `torch.hub` as explained below.
If you simply want to copy the model definitions, we provide a minimal, single-file reference implementation, including explanation mode, in `extra/minimal_bcos_resnet.py`!
UPDATE: We have also released our ViT models! See Model Zoo.
You only need to make sure you have `torch` and `torchvision` installed. Then, loading the models via `torch.hub` is as easy as:
```python
import torch

# list all available models
torch.hub.list('B-cos/B-cos-v2')

# load a pretrained model
model = torch.hub.load('B-cos/B-cos-v2', 'resnet50', pretrained=True)
```
Inference and explanation visualization are as simple as:
```python
from PIL import Image
import matplotlib.pyplot as plt

# load and preprocess the image
img = model.transform(Image.open('cat.jpg'))
img = img[None].requires_grad_()

# predict and explain
model.eval()
expl_out = model.explain(img)
print("Prediction:", expl_out["prediction"])  # predicted class idx
plt.imshow(expl_out["explanation"])
plt.show()
```
Each of the models has its inference transform attached to it, accessible via `model.transform`.
Furthermore, each model has an `.explain()` method that takes an image tensor and returns a dictionary containing the prediction, the explanation, and some extras.
See the demo notebook for more details on the `.explain()` method.
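If you want to look at explanations for several images at once, a minimal sketch along these lines works, using only the calls shown above (the image file names are placeholders; replace them with your own files):

```python
import matplotlib.pyplot as plt
from PIL import Image

# hypothetical example images; replace these paths with your own files
image_paths = ["cat.jpg", "dog.jpg", "bird.jpg"]

model.eval()
fig, axes = plt.subplots(1, len(image_paths), figsize=(4 * len(image_paths), 4))
for ax, path in zip(axes, image_paths):
    # preprocess with the model's attached inference transform
    img = model.transform(Image.open(path))[None].requires_grad_()
    expl_out = model.explain(img)
    ax.imshow(expl_out["explanation"])
    ax.set_title(f"class idx: {expl_out['prediction']}")
    ax.axis("off")
plt.show()
```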
Furthermore, each model has `get_classifier()` and `get_feature_extractor()` methods that return the classifier and feature extractor modules, respectively. These can be useful for fine-tuning the models!
Depending on your use case, you can either install the `bcos` package or set up the development environment for training the models (for your custom models or for reproducing the results).

If you are simply interested in using the models (pretrained or otherwise), we provide a `bcos` package that can be installed via `pip`:
```bash
pip install bcos
```
This contains the models, their modules, transforms, and other utilities making it easy to use and build B-cos models. Take a look at the public API here. (I'll add a proper docs site if I have time or there's enough interest. Nonetheless, I have tried to keep the code well-documented, so it should be easy to follow.)
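As a rough, hypothetical sketch of what using the package directly might look like (the import path follows `bcos/models/pretrained.py` mentioned below, and the `pretrained=` keyword assumes the usual `torch.hub` entrypoint convention; check the public API for the actual signatures):

```python
# hypothetical usage; verify the import path and signature against the public API
from bcos.models.pretrained import resnet50

model = resnet50(pretrained=True)
model.eval()
```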
If you want to train your own B-cos models using this repository or are interested in reproducing the results, you can set up the development environment as follows:
Using `conda` (recommended, especially if you want to reproduce the results):

```bash
conda env create -f environment.yml
conda activate bcos
```
Using `pip`:

```bash
pip install -r requirements-train.txt
```
You can either set the paths in `bcos/settings.py` or set the environment variables `DATA_ROOT` and `IMAGENET_PATH` to the paths of the data directories.
The `DATA_ROOT` environment variable should point to the data root directory for CIFAR-10 (which will be downloaded automatically).
For ImageNet, the `IMAGENET_PATH` environment variable should point to the directory containing the `train` and `val` directories.
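For example, you could also set these programmatically before the `bcos` settings are read (the paths below are placeholders; shell `export`s work just as well):

```python
import os

# placeholder paths -- point these at your actual data directories
os.environ["DATA_ROOT"] = "/path/to/data"          # CIFAR-10 root (downloaded automatically)
os.environ["IMAGENET_PATH"] = "/path/to/imagenet"  # must contain the train/ and val/ directories
```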
For the `bcos` package, as mentioned earlier, take a look at the public API here.
For evaluating or training the models, you can use the `evaluate.py` and `train.py` scripts, as follows:
You can evaluate the accuracy of the models on the ImageNet validation set using:
```bash
python evaluate.py --dataset ImageNet --hubconf resnet18
```
This will download the model from `torch.hub` and evaluate it on the ImageNet validation set.
The default batch size is 1, but you can change it using the `--batch-size` argument.
Replace `resnet18` with any of the other models listed in the Model Zoo that you wish to evaluate.
Short version:

```bash
python train.py \
    --dataset ImageNet \
    --base_network bcos_final \
    --experiment_name resnet18
```
Long version: See TRAINING.md for more details on how the setup works and how to train your own models.
Here are the ImageNet pre-trained models available in the model zoo.
You can find the links to the model weights below (uploaded to the Weights GitHub release).
| Model/Entrypoint | Top-1 Accuracy | Top-5 Accuracy | #Params | Download |
|---|---|---|---|---|
| `resnet18` | 68.736% | 87.430% | 11.69M | link |
| `resnet34` | 72.284% | 90.052% | 21.80M | link |
| `resnet50` | 75.882% | 92.528% | 25.52M | link |
| `resnet101` | 76.532% | 92.538% | 44.50M | link |
| `resnet152` | 76.484% | 92.398% | 60.13M | link |
| `resnext50_32x4d` | 75.820% | 91.810% | 25.00M | link |
| `densenet121` | 73.612% | 91.106% | 7.95M | link |
| `densenet161` | 76.622% | 92.554% | 28.58M | link |
| `densenet169` | 75.186% | 91.786% | 14.08M | link |
| `densenet201` | 75.480% | 91.992% | 19.91M | link |
| `vgg11_bnu` | 69.310% | 88.388% | 132.86M | link |
| `convnext_tiny` | 77.488% | 93.192% | 28.54M | link |
| `convnext_base` | 79.650% | 94.614% | 88.47M | link |
| `convnext_tiny_bnu` | 76.826% | 93.090% | 28.54M | link |
| `convnext_base_bnu` | 80.142% | 94.834% | 88.47M | link |
| `densenet121_long` | 77.302% | 93.234% | 7.95M | link |
| `resnet50_long` | 79.468% | 94.452% | 25.52M | link |
| `resnet152_long` | 80.144% | 94.116% | 60.13M | link |
| `simple_vit_ti_patch16_224` | 59.960% | 81.838% | 5.80M | link |
| `simple_vit_s_patch16_224` | 69.246% | 88.096% | 22.28M | link |
| `simple_vit_b_patch16_224` | 74.408% | 91.156% | 86.90M | link |
| `simple_vit_l_patch16_224` | 75.060% | 91.378% | 178.79M | link |
| `vitc_ti_patch1_14` | 67.260% | 86.774% | 5.32M | link |
| `vitc_s_patch1_14` | 74.504% | 91.288% | 20.88M | link |
| `vitc_b_patch1_14` | 77.152% | 92.926% | 81.37M | link |
| `vitc_l_patch1_14` | 77.782% | 92.966% | 167.44M | link |
| `standard_simple_vit_ti_patch16_224` | 70.230% | 89.380% | 5.67M | link |
| `standard_simple_vit_s_patch16_224` | 74.470% | 91.226% | 21.96M | link |
| `standard_simple_vit_b_patch16_224` | 75.300% | 91.026% | 86.38M | link |
| `standard_simple_vit_l_patch16_224` | 75.710% | 90.050% | 178.10M | link |
| `standard_vitc_ti_patch1_14` | 72.590% | 90.788% | 5.33M | link |
| `standard_vitc_s_patch1_14` | 75.756% | 91.994% | 20.91M | link |
| `standard_vitc_b_patch1_14` | 76.790% | 92.024% | 81.39M | link |
| `standard_vitc_l_patch1_14` | 77.866% | 92.298% | 167.54M | link |
You can find these entrypoints in `bcos/models/pretrained.py`.
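For example, any entrypoint from the table above can be loaded by name via `torch.hub`, e.g. one of the B-cos ViTs:

```python
import torch

# load a pretrained B-cos ViT by its entrypoint name from the table above
model = torch.hub.load('B-cos/B-cos-v2', 'simple_vit_b_patch16_224', pretrained=True)
model.eval()
```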
This repository's code is licensed under the Apache License 2.0 which you can find in the LICENSE file.
The pre-trained models are trained on ImageNet (and are hence derived from it), which is licensed under the ImageNet Terms of Access, which, among other things, only allows non-commercial use of the dataset. It is therefore your responsibility to check whether you have permission to use the pre-trained models for your use case.
```bibtex
@article{Boehle2024TPAMI,
  author  = {Böhle, Moritz and Singh, Navdeeppal and Fritz, Mario and Schiele, Bernt},
  title   = {B-cos Alignment for Inherently Interpretable CNNs and Vision Transformers},
  journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year    = {2024},
  pages   = {1-15},
  doi     = {10.1109/TPAMI.2024.3355155},
}
```