INTR: A Simple Interpretable Transformer for Fine-grained Image Classification and Analysis (ICLR 2024)
This repo is the official implementation of INTR: A Simple Interpretable Transformer for Fine-grained Image Classification and Analysis. It currently includes code and models for the interpretation of fine-grained data. We will provide a link to the upcoming ICLR 2024 proceedings for this paper when it becomes available online.
INTR is a novel usage of Transformers to make image classification interpretable. In INTR, we investigate a proactive approach to classification, asking each class to look for itself in an image. We learn class-specific queries (one for each class) as input to the decoder, allowing them to look for their presence in an image via cross-attention. We show that INTR intrinsically encourages each class to attend distinctly; the cross-attention weights thus provide a meaningful interpretation of the model's prediction. Interestingly, via multi-head cross-attention, INTR could learn to localize different attributes of a class, making it particularly suitable for fine-grained classification and analysis.
In the INTR model, each query in the decoder is responsible for predicting one class, so each query searches the feature map for the features specific to its own class. First, we visualize the feature map, i.e., the value matrix of the Transformer architecture, to see the important parts of the object in the image. To show which of these features the model attends to, we then plot a heatmap of the model's attention over the value matrix. To avoid external interference in the classification, we use a shared weight vector for classification, so the attention weights explain the model's prediction.
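The mechanism described above can be illustrated with a minimal, single-head sketch in plain Python. It is not the repository's implementation: the function names, the toy numbers, and the omission of learned projections and multiple heads are all simplifying assumptions made for illustration.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def class_query_attention(queries, keys, values, w):
    """Return (logits, attention) with one entry per class query.

    Hypothetical helper: single head, no learned projections.
    """
    d = len(queries[0])
    logits, attention = [], []
    for q in queries:
        # Scaled dot-product attention of this class query over all patches.
        weights = softmax([dot(q, k) / math.sqrt(d) for k in keys])
        # Attention-weighted sum of the patch values for this class.
        out = [sum(a * v[j] for a, v in zip(weights, values))
               for j in range(len(values[0]))]
        # The same weight vector w scores every class, so classes can only
        # differ through what their queries attend to.
        logits.append(dot(w, out))
        attention.append(weights)
    return logits, attention

# Toy example: 2 class queries, 3 image patches, feature dimension 2.
queries = [[4.0, 0.0], [0.0, 4.0]]
keys = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
values = [[1.0, 0.0], [-1.0, 0.0], [0.0, 0.0]]
w = [1.0, 0.0]  # shared classification vector
logits, attention = class_query_attention(queries, keys, values, w)
predicted = max(range(len(logits)), key=logits.__getitem__)
```

Because `w` is shared, the two logits differ only through the attention weights, which is why those weights can be read as an explanation of the prediction in this toy setting.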
INTR on DETR-R50 backbone, classification performance, and fine-tuned models on different datasets.
Dataset | acc@1 | acc@5 | Model |
---|---|---|---|
CUB | 71.8 | 89.3 | checkpoint download |
Bird | 97.4 | 99.2 | checkpoint download |
Butterfly | 95.0 | 98.3 | checkpoint download |
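For readers unfamiliar with the column names, the sketch below assumes that acc@1 and acc@5 denote top-1 and top-5 accuracy in percent. The function `topk_accuracy` is a hypothetical helper written for illustration, not part of this repository.

```python
def topk_accuracy(scores, labels, k):
    """Percentage of samples whose true label is among the k highest scores.

    scores: one list of per-class scores per sample.
    labels: the true class index of each sample.
    """
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k classes with the highest scores for this sample.
        topk = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += label in topk
    return 100.0 * hits / len(labels)

# Toy scores for 3 samples over 3 classes.
scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
acc1 = topk_accuracy(scores, labels, k=1)
acc2 = topk_accuracy(scores, labels, k=2)
```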
Create a Python environment (optional)
conda create -n intr python=3.8 -y
conda activate intr
Clone the repository
git clone https://github.com/dipanjyoti/INTR.git
cd INTR
Install Python dependencies
pip install -r requirements.txt
Organize the data in the following format.
datasets
├── dataset_name
│ ├── train
│ │ ├── class1
│ │ │ ├── img1.jpeg
│ │ │ ├── img2.jpeg
│ │ │ └── ...
│ │ ├── class2
│ │ │ ├── img3.jpeg
│ │ │ └── ...
│ │ └── ...
│ └── val
│ ├── class1
│ │ ├── img4.jpeg
│ │ ├── img5.jpeg
│ │ └── ...
│ ├── class2
│ │ ├── img6.jpeg
│ │ └── ...
│ └── ...
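Before training or evaluation, it can help to confirm that a dataset follows the layout above. The following is a minimal sketch using only the standard library. `list_classes` and `check_layout` are hypothetical helpers, and the requirement that `train` and `val` contain the same class folders is an assumption, not something the repository is known to enforce.

```python
from pathlib import Path

def list_classes(split_dir):
    """Return the sorted class-folder names directly under a split directory."""
    return sorted(p.name for p in Path(split_dir).iterdir() if p.is_dir())

def check_layout(dataset_root):
    """Check that train/ and val/ hold the same class folders.

    Returns the number of classes; raises ValueError on a mismatch.
    """
    root = Path(dataset_root)
    train_classes = list_classes(root / "train")
    val_classes = list_classes(root / "val")
    if train_classes != val_classes:
        differing = sorted(set(train_classes) ^ set(val_classes))
        raise ValueError(f"class folders differ between train and val: {differing}")
    return len(train_classes)
```

For example, `check_layout("datasets/dataset_name")` returns the class count, which is also the value to pass as `--num_queries` when training.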
To evaluate the performance of INTR on the CUB dataset in a multi-GPU setting (e.g., 4 GPUs), execute the command below. INTR checkpoints are available under "Fine-tune model and results".
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 --use_env main.py --eval --resume <path/to/intr_checkpoint_cub_detr_r50.pth> --dataset_path <path/to/datasets> --dataset_name <dataset_name>
To visualize INTR's interpretations, execute the command below. It presents the interpretation for the class with index <class_number>. By default, it displays interpretations from all attention heads. To also show the interpretations associated with the top queries (labeled top_q), set the parameter sim_query_heads to 1. Use a batch size of 1 for visualization.
python -m tools.visualization --eval --resume <path/to/intr_checkpoint_cub_detr_r50.pth> --dataset_path <path/to/datasets> --dataset_name <dataset_name> --class_index <class_number>
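As an illustration of what a per-head attention heatmap involves, the sketch below reshapes a flat list of attention weights into a 2D grid and rescales it to [0, 1]. It is not the code in `tools.visualization`: the helper `attention_to_heatmap` is hypothetical, and the row-major patch ordering and min-max scaling are assumptions.

```python
def attention_to_heatmap(weights, height, width):
    """Reshape height*width attention weights into a grid scaled to [0, 1].

    Assumes the weights are stored in row-major patch order.
    """
    if len(weights) != height * width:
        raise ValueError("expected height * width attention weights")
    lo, hi = min(weights), max(weights)
    span = (hi - lo) or 1.0  # avoid division by zero for uniform attention
    return [
        [(weights[r * width + c] - lo) / span for c in range(width)]
        for r in range(height)
    ]

# Toy attention over a 2x2 feature map.
heatmap = attention_to_heatmap([0.1, 0.2, 0.3, 0.5], height=2, width=2)
```

The resulting grid can then be upsampled to the image size and overlaid on the input with any plotting library.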
Single-image prediction and visualization at inference time: we also provide a Jupyter notebook, demo.ipynb, for single-image prediction and visualization. Please note that the demo is focused on the CUB dataset.
To prepare INTR for training, use the pretrained model DETR-R50. To train on a particular dataset, set '--num_queries' to the number of classes in the dataset. Within the INTR architecture, each query in the decoder is assigned the task of capturing class-specific features, so every query is learned during training. Consequently, the total number of model parameters grows linearly with the number of classes in the dataset. To train INTR on a multi-GPU system (e.g., 4 GPUs), execute the command below.
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 --use_env main.py --finetune <path/to/detr-r50-e632da11.pth> --dataset_path <path/to/datasets> --dataset_name <dataset_name> --num_queries <num_of_classes>
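To make the growth in parameters with the number of classes concrete, the sketch below counts only the parameters held by the class queries, one vector per class. The hidden dimension of 256 is DETR's default and is assumed here to carry over to INTR. `query_parameter_count` is a hypothetical helper, and the class counts are arbitrary examples.

```python
def query_parameter_count(num_classes, hidden_dim=256):
    """Parameters in the learnable class queries: one vector per class.

    hidden_dim=256 is an assumption based on DETR's default setting.
    """
    return num_classes * hidden_dim

# Query parameters for a few example class counts.
for num_classes in (50, 200, 1000):
    print(num_classes, query_parameter_count(num_classes))
```

Under this assumption, doubling the number of classes doubles the query parameters, while the rest of the model stays the same size.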
Our model is inspired by the DEtection TRansformer (DETR) method.
We thank the authors of DETR for doing such great work.
If you find our work helpful for your research, please consider citing it with the BibTeX entry below.
@inproceedings{paul2024simple,
title={A Simple Interpretable Transformer for Fine-Grained Image Classification and Analysis},
author={Paul, Dipanjyoti and Chowdhury, Arpita and Xiong, Xinqi and Chang, Feng-Ju and Carlyn, David and Stevens, Samuel and Provost, Kaiya and Karpatne, Anuj and Carstens, Bryan and Rubenstein, Daniel and Stewart, Charles and Berger-Wolf, Tanya and Su, Yu and Chao, Wei-Lun},
booktitle={International Conference on Learning Representations},
year={2024}
}