Agent Transformer

This folder contains the implementation of Agent Attention built on the DeiT, PVT, Swin, and CSwin architectures for image classification.
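
To make the idea concrete, here is a minimal single-head sketch of agent attention in plain Python, following the two-step form described in the paper cited at the end of this README. A small set of n agent tokens A first aggregates information from all N keys and values, then broadcasts it back to every query: O = softmax(Q Aᵀ/√d) · softmax(A Kᵀ/√d) · V. Obtaining the agents by average-pooling contiguous groups of queries is an assumption made for illustration, the helper names are hypothetical, and the sketch leaves out the agent bias and depthwise-convolution terms used in the released models.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def softmax_rows(m: Matrix, scale: float) -> Matrix:
    """Row-wise softmax of scale * m, max-subtracted for numerical stability."""
    out = []
    for row in m:
        scaled = [scale * x for x in row]
        peak = max(scaled)
        exps = [math.exp(x - peak) for x in scaled]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def pool_agents(q: Matrix, num_agents: int) -> Matrix:
    """Assumption: each agent is the mean of a contiguous group of query tokens."""
    if len(q) % num_agents:
        raise ValueError("number of tokens must be divisible by num_agents")
    group = len(q) // num_agents
    agents = []
    for i in range(num_agents):
        chunk = q[i * group:(i + 1) * group]
        agents.append([sum(col) / len(chunk) for col in zip(*chunk)])
    return agents


def agent_attention(q: Matrix, k: Matrix, v: Matrix, num_agents: int) -> Matrix:
    """Simplified single-head agent attention (no agent bias, no DWC)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    a = pool_agents(q, num_agents)                               # (n, d)
    agent_attn = softmax_rows(matmul(a, transpose(k)), scale)    # (n, N): agent aggregation
    agent_v = matmul(agent_attn, v)                              # (n, d_v)
    query_attn = softmax_rows(matmul(q, transpose(a)), scale)    # (N, n): agent broadcast
    return matmul(query_attn, agent_v)                           # (N, d_v)
```

Both score matrices are N×n or n×N instead of N×N, so the cost grows as O(N·n·d) rather than O(N²·d) when n is much smaller than N.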

Dependencies

  • Python 3.9
  • PyTorch == 1.11.0
  • torchvision == 0.12.0
  • numpy
  • timm == 0.4.12
  • einops
  • yacs
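
A quick way to compare an environment against the pins above is a small standard-library script like the one below. It is a hypothetical helper, not part of this repo; it relies on importlib.metadata (Python 3.8+) and treats the unpinned packages (numpy, einops, yacs) as presence-only checks.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, List, Optional

# Pins copied from the dependency list above; None means "any version".
REQUIRED: Dict[str, Optional[str]] = {
    "torch": "1.11.0",
    "torchvision": "0.12.0",
    "timm": "0.4.12",
    "numpy": None,
    "einops": None,
    "yacs": None,
}


def find_mismatches(installed: Dict[str, Optional[str]],
                    required: Dict[str, Optional[str]] = REQUIRED) -> List[str]:
    """Return problems; `installed` maps package name -> version string, or None if missing."""
    problems = []
    for name, pin in required.items():
        found = installed.get(name)
        if found is None:
            problems.append(f"{name}: not installed")
        elif pin is not None and found.split("+")[0] != pin:
            # Strip local build tags such as "+cu113" before comparing.
            problems.append(f"{name}: found {found}, expected {pin}")
    return problems


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Look up the installed version of each package, or None if it is absent."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for line in find_mismatches(installed_versions(REQUIRED)) or ["all dependencies match"]:
        print(line)
```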

Data preparation

The ImageNet dataset should be prepared as follows:

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
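
Each class is a subfolder holding that class's images, which is the layout torchvision.datasets.ImageFolder reads. Assuming that layout, a hypothetical sanity check that confirms train/ and val/ both exist and contain the same class folders could look like this:

```python
from pathlib import Path
from typing import Dict, List

IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def summarize_imagenet(root: str) -> Dict[str, int]:
    """Count class folders and image files per split; raise if the layout looks wrong."""
    base = Path(root)
    classes: Dict[str, List[str]] = {}
    counts: Dict[str, int] = {}
    for split in ("train", "val"):
        split_dir = base / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        class_dirs = sorted(p for p in split_dir.iterdir() if p.is_dir())
        classes[split] = [p.name for p in class_dirs]
        counts[f"{split}_classes"] = len(class_dirs)
        counts[f"{split}_images"] = sum(
            1 for d in class_dirs for f in d.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    # Class indices are derived from folder names, so both splits must agree.
    if classes["train"] != classes["val"]:
        raise ValueError("train/ and val/ must contain the same class folders")
    return counts
```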

Pretrained Models

| Model | Resolution | Params | FLOPs | Top-1 acc. (%) | Config | Pretrained weights |
|---|---|---|---|---|---|---|
| Agent-DeiT-T | $224^2$ | 6.0M | 1.2G | 74.9 | config | TsinghuaCloud |
| Agent-DeiT-S | $224^2$ | 22.7M | 4.4G | 80.5 | config | TsinghuaCloud |
| Agent-DeiT-S | $448^2$ | 23.1M | 17.7G | 83.1 | config | TsinghuaCloud |
| Agent-DeiT-B | $224^2$ | 87.2M | 17.6G | 82.0 | config | TsinghuaCloud |
| Agent-PVT-T | $224^2$ | 11.6M | 2.0G | 78.4 | config | TsinghuaCloud |
| Agent-PVT-S | $224^2$ | 20.6M | 4.0G | 82.2 | config | TsinghuaCloud |
| Agent-PVT-M | $224^2$ | 35.9M | 7.0G | 83.4 | config | TsinghuaCloud |
| Agent-PVT-M | $256^2$ | 36.1M | 9.2G | 83.8 | config | TsinghuaCloud |
| Agent-PVT-L | $224^2$ | 48.7M | 10.4G | 83.7 | config | TsinghuaCloud |
| Agent-Swin-T | $224^2$ | 29M | 4.5G | 82.6 | config | TsinghuaCloud |
| Agent-Swin-S | $224^2$ | 50M | 8.7G | 83.7 | config | TsinghuaCloud |
| Agent-Swin-S | $288^2$ | 50M | 14.6G | 84.1 | config | TsinghuaCloud |
| Agent-Swin-B | $224^2$ | 88M | 15.4G | 84.0 | config | TsinghuaCloud |
| Agent-Swin-B | $384^2$ | 88M | 46.3G | 84.9 | config | TsinghuaCloud |
| Agent-CSwin-T | $224^2$ | 21M | 4.3G | 83.1 | config | TsinghuaCloud |
| Agent-CSwin-S | $224^2$ | 33M | 6.8G | 83.9 | config | TsinghuaCloud |
| Agent-CSwin-B | $224^2$ | 73M | 14.9G | 84.7 | config | TsinghuaCloud |
| Agent-CSwin-B | $384^2$ | 73M | 46.3G | 85.8 | config | TsinghuaCloud |
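
As a rough check on the resolution rows, the arithmetic below (assuming DeiT's 16×16 patch size) shows that doubling the input side from 224 to 448 quadruples the number of patch tokens, and the Agent-DeiT-S FLOPs reported in the table grow by almost exactly the same factor (4.4G to 17.7G). That is consistent with attention cost scaling close to linearly in the token count; the projection and MLP layers, which are linear in tokens anyway, also account for much of the total.

```python
def num_patch_tokens(resolution: int, patch_size: int = 16) -> int:
    """Number of non-overlapping patch tokens for a square input (class token excluded)."""
    if resolution % patch_size:
        raise ValueError("resolution must be a multiple of patch_size")
    side = resolution // patch_size
    return side * side


tokens_224 = num_patch_tokens(224)   # 14 x 14 = 196
tokens_448 = num_patch_tokens(448)   # 28 x 28 = 784
token_ratio = tokens_448 / tokens_224
flops_ratio = 17.7 / 4.4             # Agent-DeiT-S FLOPs from the table

print(f"tokens: {tokens_224} -> {tokens_448} (x{token_ratio:.2f})")
print(f"FLOPs:  4.4G -> 17.7G (x{flops_ratio:.2f})")
```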

Evaluate Agent-DeiT/Agent-PVT/Agent-Swin on ImageNet:

python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>
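
The acc@1 column in the table is top-1 accuracy: the percentage of validation images whose highest-scoring class matches the label. A minimal, framework-free sketch of that metric (and of top-k accuracy in general) follows; `topk_accuracy` is a hypothetical name, not a function from this repo.

```python
from typing import Sequence


def topk_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int], k: int = 1) -> float:
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("need one non-empty logit row per label")
    hits = 0
    for scores, label in zip(logits, labels):
        # Indices of the k largest scores; ties are broken by the lower index.
        top = sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]
        hits += label in top
    return 100.0 * hits / len(labels)


logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.9], [0.2, 0.4, 0.3]]
labels = [1, 2, 2]
print(topk_accuracy(logits, labels, k=1))  # about 33.33: only the first sample's top score is correct
print(topk_accuracy(logits, labels, k=2))  # 100.0: every label is within its sample's top 2
```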

Evaluate Agent-CSwin on ImageNet:

python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>
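
The two evaluation commands differ only in the entry script: Agent-CSwin models go through main_ema.py and the DeiT, PVT, and Swin variants through main.py. A hypothetical helper that encodes this rule and assembles the same argument list could look like this. The flags are exactly those in the commands; the function name and the model-name prefix check are assumptions.

```python
import shlex
from typing import List


def build_eval_command(model: str, cfg: str, data_path: str, output: str,
                       weights: str, gpus: int = 8) -> List[str]:
    """Return argv for evaluating a pretrained Agent model with torch.distributed.launch."""
    # Agent-CSwin uses main_ema.py; Agent-DeiT/PVT/Swin use main.py.
    script = "main_ema.py" if model.lower().startswith("agent-cswin") else "main.py"
    return [
        "python", "-m", "torch.distributed.launch", f"--nproc_per_node={gpus}",
        script,
        "--cfg", cfg,
        "--data-path", data_path,
        "--output", output,
        "--eval",
        "--resume", weights,
    ]


cmd = build_eval_command("Agent-CSwin-T", "<path-to-config-file>", "<imagenet-path>",
                         "<output-path>", "<path-to-pretrained-weights>")
print(shlex.join(cmd))
```

Once the placeholders are replaced with real paths, the list can be passed directly to `subprocess.run(cmd, check=True)`, which avoids shell-quoting issues.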

Train Models from Scratch

  • To train Agent-DeiT/Agent-PVT/Agent-Swin on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path>
  • To train Agent-CSwin-T/S/B on ImageNet from scratch, run the command below, setting --model-ema-decay to 0.99984 for T, 0.99984 for S, or 0.99992 for B:
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay <ema-decay>
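
The --model-ema-decay value controls an exponential moving average (EMA) of the model weights kept alongside the trained model. Assuming the usual timm-style update, ema = decay · ema + (1 - decay) · weight, applied once per optimizer step, a decay d averages over roughly 1/(1 - d) recent steps: about 6,250 steps for 0.99984 and about 12,500 for 0.99992. The scalar sketch below illustrates this. It is not the repo's code, and the per-variant mapping simply restates the decay values given for Agent-CSwin-T/S/B.

```python
from typing import Dict, List

# Decay values for Agent-CSwin-T/S/B as given in the training instructions.
EMA_DECAY: Dict[str, float] = {"T": 0.99984, "S": 0.99984, "B": 0.99992}


def ema_update(ema: List[float], weights: List[float], decay: float) -> List[float]:
    """One EMA step: ema <- decay * ema + (1 - decay) * weights."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema, weights)]


def effective_horizon(decay: float) -> float:
    """Approximate number of recent steps that dominate the average."""
    return 1.0 / (1.0 - decay)


for size, decay in EMA_DECAY.items():
    print(f"Agent-CSwin-{size}: decay={decay}, ~{effective_horizon(decay):,.0f} steps")

# Toy run: the EMA of a weight that jumps from 0 to 1 approaches 1 slowly.
ema = [0.0]
for _ in range(6250):
    ema = ema_update(ema, [1.0], EMA_DECAY["T"])
print(f"after 6250 steps: {ema[0]:.3f}")  # roughly 1 - 1/e
```

A larger decay gives a smoother but slower-moving average, which is consistent with the biggest model using the largest value.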

Fine-tuning at higher resolution

  • Fine-tune an Agent-Swin-B model pre-trained at 224x224 to 384x384 resolution:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/agent_swin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights>
  • Fine-tune an Agent-CSwin-B model pre-trained at 224x224 to 384x384 resolution:
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/agent_cswin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights> --model-ema --model-ema-decay 0.9998
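
Moving from 224x224 to 384x384 changes the spatial token grid, so any parameters laid out on that grid (for example position embeddings or bias tables) generally have to be resized when the 224x224 weights are loaded. This README does not say how the repo handles that step. As a generic illustration only, the sketch below bilinearly resizes a 2D parameter grid with corner alignment, one common choice for this (some codebases use bicubic instead). The function name is hypothetical.

```python
from typing import List

Grid = List[List[float]]


def resize_bilinear(grid: Grid, new_h: int, new_w: int) -> Grid:
    """Bilinearly resample a 2D grid to (new_h, new_w), keeping the corner values fixed."""
    old_h, old_w = len(grid), len(grid[0])

    def coord(i: int, new: int, old: int) -> float:
        # Map output index i onto the input axis so that both endpoints coincide.
        return 0.0 if new == 1 else i * (old - 1) / (new - 1)

    out = []
    for i in range(new_h):
        y = coord(i, new_h, old_h)
        y0 = int(y)
        y1 = min(y0 + 1, old_h - 1)
        fy = y - y0
        row = []
        for j in range(new_w):
            x = coord(j, new_w, old_w)
            x0 = int(x)
            x1 = min(x0 + 1, old_w - 1)
            fx = x - x0
            # Interpolate along x on the two neighbouring rows, then along y.
            top = grid[y0][x0] * (1 - fx) + grid[y0][x1] * fx
            bottom = grid[y1][x0] * (1 - fx) + grid[y1][x1] * fx
            row.append(top * (1 - fy) + bottom * fy)
        out.append(row)
    return out


small = [[0.0, 1.0], [2.0, 3.0]]
for row in resize_bilinear(small, 3, 3):
    print(row)
# [0.0, 0.5, 1.0]
# [1.0, 1.5, 2.0]
# [2.0, 2.5, 3.0]
```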

Citation

If you find this repo helpful, please consider citing us.

@inproceedings{han2024agent,
  title={Agent attention: On the integration of softmax and linear attention},
  author={Han, Dongchen and Ye, Tianzhu and Han, Yizeng and Xia, Zhuofan and Pan, Siyuan and Wan, Pengfei and Song, Shiji and Huang, Gao},
  booktitle={European Conference on Computer Vision},
  year={2024},
}