This folder contains the implementation of Agent Attention for image classification, built on the DeiT, PVT, Swin, and CSwin models.

## Dependencies
- Python 3.9
- PyTorch == 1.11.0
- torchvision == 0.12.0
- numpy
- timm == 0.4.12
- einops
- yacs
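The dependencies above can be installed with pip; a minimal sketch, assuming a CUDA 11.3 build of PyTorch (pick the wheel index that matches your CUDA toolkit):

```shell
# Pinned versions from the list above; the --extra-index-url points at
# PyTorch's CUDA 11.3 wheels -- adjust (e.g. cu102, cpu) to your setup.
pip install torch==1.11.0 torchvision==0.12.0 \
    --extra-index-url https://download.pytorch.org/whl/cu113
pip install numpy timm==0.4.12 einops yacs
```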
## Data Preparation

The ImageNet dataset should be prepared as follows:
    $ tree data
    imagenet
    ├── train
    │   ├── class1
    │   │   ├── img1.jpeg
    │   │   ├── img2.jpeg
    │   │   └── ...
    │   ├── class2
    │   │   ├── img3.jpeg
    │   │   └── ...
    │   └── ...
    └── val
        ├── class1
        │   ├── img4.jpeg
        │   ├── img5.jpeg
        │   └── ...
        ├── class2
        │   ├── img6.jpeg
        │   └── ...
        └── ...
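The layout above is the standard `ImageFolder` format used by `torchvision`: class labels are derived from the sorted subdirectory names under each split. A small illustrative sketch (plain Python, not the repo's loading code) that builds and inspects such a skeleton:

```python
from pathlib import Path
import tempfile

def build_skeleton(root, classes=("class1", "class2")):
    # Create train/ and val/ splits with one subdirectory per class,
    # mirroring the ImageFolder-style layout shown above.
    for split in ("train", "val"):
        for cls in classes:
            (Path(root) / split / cls).mkdir(parents=True, exist_ok=True)

def list_classes(root, split):
    # torchvision's ImageFolder assigns label indices from the sorted
    # subdirectory names of each split, exactly as listed here.
    return sorted(p.name for p in (Path(root) / split).iterdir() if p.is_dir())

root = Path(tempfile.mkdtemp()) / "imagenet"
build_skeleton(root)
print(list_classes(root, "val"))  # → ['class1', 'class2']
```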
## Pretrained Models

| model | resolution | Params | FLOPs | acc@1 | config | pretrained weights |
| --- | --- | --- | --- | --- | --- | --- |
| Agent-DeiT-T | 224 | 6.0M | 1.2G | 74.9 | config | TsinghuaCloud |
| Agent-DeiT-S | 224 | 22.7M | 4.4G | 80.5 | config | TsinghuaCloud |
| Agent-DeiT-S | 448 | 23.1M | 17.7G | 83.1 | config | TsinghuaCloud |
| Agent-DeiT-B | 224 | 87.2M | 17.6G | 82.0 | config | TsinghuaCloud |
| Agent-PVT-T | 224 | 11.6M | 2.0G | 78.4 | config | TsinghuaCloud |
| Agent-PVT-S | 224 | 20.6M | 4.0G | 82.2 | config | TsinghuaCloud |
| Agent-PVT-M | 224 | 35.9M | 7.0G | 83.4 | config | TsinghuaCloud |
| Agent-PVT-M | 256 | 36.1M | 9.2G | 83.8 | config | TsinghuaCloud |
| Agent-PVT-L | 224 | 48.7M | 10.4G | 83.7 | config | TsinghuaCloud |
| Agent-Swin-T | 224 | 29M | 4.5G | 82.6 | config | TsinghuaCloud |
| Agent-Swin-S | 224 | 50M | 8.7G | 83.7 | config | TsinghuaCloud |
| Agent-Swin-S | 288 | 50M | 14.6G | 84.1 | config | TsinghuaCloud |
| Agent-Swin-B | 224 | 88M | 15.4G | 84.0 | config | TsinghuaCloud |
| Agent-Swin-B | 384 | 88M | 46.3G | 84.9 | config | TsinghuaCloud |
| Agent-CSwin-T | 224 | 21M | 4.3G | 83.1 | config | TsinghuaCloud |
| Agent-CSwin-S | 224 | 33M | 6.8G | 83.9 | config | TsinghuaCloud |
| Agent-CSwin-B | 224 | 73M | 14.9G | 84.7 | config | TsinghuaCloud |
| Agent-CSwin-B | 384 | 73M | 46.3G | 85.8 | config | TsinghuaCloud |
## Evaluation

To evaluate Agent-DeiT, Agent-PVT, or Agent-Swin models on ImageNet, run:

    python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>

To evaluate Agent-CSwin models on ImageNet, use `main_ema.py` instead:

    python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>
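The `--eval` flag reports top-1 accuracy on the validation set (the `acc@1` column above). As a plain-Python sketch of what that metric computes, not the repo's actual evaluation code:

```python
def top1_accuracy(logits, labels):
    # logits: one list of per-class scores per image; labels: ground-truth
    # class indices. Top-1 accuracy is the fraction of images whose
    # highest-scoring class matches the label, expressed in percent.
    correct = sum(
        max(range(len(scores)), key=scores.__getitem__) == label
        for scores, label in zip(logits, labels)
    )
    return 100.0 * correct / len(labels)

scores = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.3, 0.3, 0.4]]
print(top1_accuracy(scores, [1, 0, 1]))  # 2 of 3 correct → 66.66...
```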
## Training

- To train Agent-DeiT, Agent-PVT, or Agent-Swin on ImageNet from scratch, run:

        python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path>

- To train Agent-CSwin-T/S/B on ImageNet from scratch, run the command below, setting `--model-ema-decay` to 0.99984, 0.99984, or 0.99992 for the T, S, or B model, respectively:

        python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99984/0.99984/0.99992
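The `--model-ema` flag maintains an exponential moving average (EMA) of the model weights during training; the averaged weights are what the Agent-CSwin checkpoints are evaluated with. A minimal sketch of the update rule (illustrative only, not the scripts' actual EMA implementation):

```python
def ema_update(ema_weights, model_weights, decay=0.99984):
    # One EMA step: each averaged weight moves a fraction (1 - decay)
    # toward the current model weight. A decay closer to 1.0 gives a
    # slower, smoother average over many optimizer steps.
    return [decay * e + (1.0 - decay) * w
            for e, w in zip(ema_weights, model_weights)]

ema = [0.0]
for _ in range(3):  # three steps with the model weight fixed at 1.0
    ema = ema_update(ema, [1.0], decay=0.9)
print(ema)  # climbs toward 1.0 as steps accumulate
```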
- To fine-tune an Agent-Swin-B model pre-trained at 224x224 resolution on 384x384 inputs, run:

        python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/agent_swin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights>

- To fine-tune an Agent-CSwin-B model pre-trained at 224x224 resolution on 384x384 inputs, run:

        python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/agent_cswin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights> --model-ema --model-ema-decay 0.9998
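Fine-tuning at a higher resolution requires resizing resolution-dependent parameters such as position biases, which the scripts handle when loading the `--pretrained` checkpoint. As a toy illustration of the underlying idea (1-D linear interpolation; the real models resize 2-D tables, typically with bicubic interpolation):

```python
def interpolate_1d(values, new_len):
    # Linearly resample a sequence of learned values (e.g. a position
    # embedding trained at one resolution) to a new length.
    old_len = len(values)
    if new_len == 1:
        return [values[0]]
    out = []
    for i in range(new_len):
        x = i * (old_len - 1) / (new_len - 1)  # position in the old grid
        lo = int(x)
        hi = min(lo + 1, old_len - 1)
        frac = x - lo
        out.append(values[lo] * (1 - frac) + values[hi] * frac)
    return out

print(interpolate_1d([0.0, 1.0], 3))  # → [0.0, 0.5, 1.0]
```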
## Citation

If you find this repo helpful, please consider citing us.

    @inproceedings{han2024agent,
      title={Agent attention: On the integration of softmax and linear attention},
      author={Han, Dongchen and Ye, Tianzhu and Han, Yizeng and Xia, Zhuofan and Pan, Siyuan and Wan, Pengfei and Song, Shiji and Huang, Gao},
      booktitle={European Conference on Computer Vision},
      year={2024},
    }