The PyTorch implementation of Att-DARTS: Differentiable Neural Architecture Search for Attention.
The code is based on https://github.com/dragen1860/DARTS-PyTorch.
- Python == 3.7
- PyTorch == 1.0.1
- torchvision == 0.2.2
- pillow == 6.2.1
- numpy
- graphviz
- requests
- tqdm
We recommend downloading PyTorch from here.
- CIFAR-10/100: automatically downloaded by torchvision to the `data` folder.
- ImageNet (ILSVRC2012 version): must be downloaded manually following the instructions here.
Architecture | CIFAR-10 test error (%) | CIFAR-100 test error (%) | Params (M) |
---|---|---|---|
DARTS | 2.76 ± 0.09 | 16.69 ± 0.28 | 3.3 |
Att-DARTS | 2.54 ± 0.10 | 16.54 ± 0.40 | 3.2 |
Architecture | ImageNet top-1 error (%) | ImageNet top-5 error (%) | Params (M) |
---|---|---|---|
DARTS | 26.7 | 8.7 | 4.7 |
Att-DARTS | 26.0 | 8.5 | 4.6 |
Our scripts use all available GPUs. To restrict them, set the `CUDA_VISIBLE_DEVICES` environment variable.
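If you launch the scripts from Python instead of a shell, a minimal sketch is to copy the current environment and restrict it to the chosen GPU indices. The helper name `gpu_env` is hypothetical and not part of this repository. In a shell, the equivalent is to prefix the command, e.g. `CUDA_VISIBLE_DEVICES=0 python train_search.py --unrolled`.

```python
import os


def gpu_env(gpu_ids):
    """Return a copy of the current environment exposing only `gpu_ids` to CUDA.

    Hypothetical helper for illustration; not part of the Att-DARTS code.
    """
    env = dict(os.environ)
    # CUDA reads a comma-separated list of device indices; inside the child
    # process the selected devices are renumbered starting from 0.
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return env
```

For example, `subprocess.run([sys.executable, "train_search.py", "--unrolled"], env=gpu_env([0]), check=True)` would run the search on GPU 0 only, assuming the working directory is the repository root.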
To perform architecture search with the second-order approximation, run:
python train_search.py --unrolled
The discovered cell is saved to `genotype.json`.
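The schema of `genotype.json` is not documented here. Assuming it mirrors the DARTS-style `Genotype` namedtuple (fields `normal`, `normal_concat`, `reduce`, `reduce_concat`, with each operation stored as an `[op_name, input_index]` pair), a saved file could be inspected as sketched below. The loader `load_genotype` and this field layout are assumptions; Att-DARTS genotypes may carry additional attention-related fields, so check a real file before relying on it.

```python
import json
from collections import namedtuple

# DARTS-style genotype; the field names are an assumption about genotype.json.
Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")


def _as_pairs(ops):
    # JSON has no tuples, so (op_name, input_index) pairs are stored as lists.
    return [(str(op), int(idx)) for op, idx in ops]


def load_genotype(path):
    """Hypothetical loader: read a JSON-saved genotype into a Genotype tuple."""
    with open(path) as f:
        data = json.load(f)
    return Genotype(
        normal=_as_pairs(data["normal"]),
        normal_concat=[int(i) for i in data["normal_concat"]],
        reduce=_as_pairs(data["reduce"]),
        reduce_concat=[int(i) for i in data["reduce_concat"]],
    )
```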
The cell we found, `Att_DARTS`, is defined in `genotypes.py`.
To insert the attention module at a different location in the cell, use the `--location` flag. The available locations are defined by `AttLocation` in `model_search.py`.
To evaluate our best cells by training from scratch, run:
python train_CIFAR10.py --auxiliary --cutout --arch Att_DARTS # CIFAR-10
python train_CIFAR100.py --auxiliary --cutout --arch Att_DARTS # CIFAR-100
python train_ImageNet.py --auxiliary --arch Att_DARTS # ImageNet
Custom architectures can be selected with the `--arch` flag once they are defined in `genotypes.py`.
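As a rough sketch, assuming `genotypes.py` follows the upstream DARTS convention (a `Genotype` namedtuple with two `(op_name, input_index)` pairs for each of four intermediate nodes), a custom entry and a by-name lookup might look like this. `MY_ARCH`, `resolve_arch` and `check_cell` are hypothetical names, the operation names come from the standard DARTS search space, and Att-DARTS genotypes may include attention fields not shown here; compare with the existing `Att_DARTS` entry for the actual format.

```python
from collections import namedtuple

# DARTS-style genotype (assumed layout; the Att-DARTS version may add fields).
Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")

# Hypothetical custom cell: 4 intermediate nodes with 2 incoming ops each.
# Input index 0 is c_{k-2}, 1 is c_{k-1}, and 2.. are earlier intermediate nodes.
MY_ARCH = Genotype(
    normal=[
        ("sep_conv_3x3", 0), ("sep_conv_3x3", 1),
        ("sep_conv_3x3", 0), ("skip_connect", 2),
        ("dil_conv_3x3", 1), ("skip_connect", 0),
        ("sep_conv_5x5", 3), ("skip_connect", 4),
    ],
    normal_concat=list(range(2, 6)),
    reduce=[
        ("max_pool_3x3", 0), ("max_pool_3x3", 1),
        ("skip_connect", 2), ("max_pool_3x3", 1),
        ("max_pool_3x3", 0), ("skip_connect", 2),
        ("skip_connect", 2), ("avg_pool_3x3", 0),
    ],
    reduce_concat=list(range(2, 6)),
)


def check_cell(ops, steps=4):
    """Raise ValueError unless every node has 2 inputs from earlier nodes only."""
    if len(ops) != 2 * steps:
        raise ValueError("expected two (op, input_index) pairs per node")
    for i in range(steps):
        for _, idx in ops[2 * i : 2 * i + 2]:
            # Node i may read c_{k-2}, c_{k-1}, or intermediate nodes 0..i-1.
            if not 0 <= idx < i + 2:
                raise ValueError(f"node {i} reads invalid input {idx}")


def resolve_arch(module, name):
    """Look up a genotype by its --arch name, e.g. in an imported genotypes module."""
    genotype = getattr(module, name, None)
    if not isinstance(genotype, Genotype):
        raise ValueError(f"{name!r} is not a Genotype")
    return genotype
```

Looking the name up with `getattr` and checking the type keeps a mistyped `--arch` value from silently resolving to something that is not a cell.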
You can also specify a search result saved as a `.json` file with the `--arch_path` flag:
python train_CIFAR10.py --auxiliary --cutout --arch_path ${PATH} # CIFAR-10
python train_CIFAR100.py --auxiliary --cutout --arch_path ${PATH} # CIFAR-100
python train_ImageNet.py --auxiliary --arch_path ${PATH} # ImageNet
where `${PATH}` is the path to the `.json` file.
The trained model is saved to `trained.pt`.
After training, the test script runs automatically.
You can also test `trained.pt` at any time as described below.
To test a pretrained model saved as a `.pt` file, run:
python test_CIFAR10.py --auxiliary --model_path ${PATH} --arch Att_DARTS # CIFAR-10
python test_CIFAR100.py --auxiliary --model_path ${PATH} --arch Att_DARTS # CIFAR-100
python test_imagenet.py --auxiliary --model_path ${PATH} --arch Att_DARTS # ImageNet
where `${PATH}` is the path to the `.pt` file.
You can specify one of our pretrained models (`cifar10_att.pt`, `cifar100_att.pt`, `imagenet_att.pt`) or the `trained.pt` saved during Architecture Evaluation.
As with training, custom architectures can be selected either by name with the `--arch` flag (once defined in `genotypes.py`) or from a `.json` file with the `--arch_path` flag.
You can visualize any cell defined in `genotypes.py`. For example, to visualize Att-DARTS, run:
python visualize.py Att_DARTS
You can also visualize a cell saved in a `.json` file:
python visualize.py genotype.json
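`visualize.py` presumably renders cells through the graphviz package listed in the requirements. The sketch below is not that script: it is a dependency-free illustration of how a DARTS-style cell maps to a directed graph, assuming two `(op_name, input_index)` pairs per intermediate node, with indices 0 and 1 referring to the cell inputs `c_{k-2}` and `c_{k-1}`. The function name `cell_to_dot` is hypothetical, and any attention operation specific to Att-DARTS is not modeled.

```python
def cell_to_dot(ops, concat, name="cell"):
    """Hypothetical helper: describe a DARTS-style cell as Graphviz DOT source."""
    steps = len(ops) // 2
    lines = [f"digraph {name} {{", "  rankdir=LR;"]

    def node(idx):
        # Indices 0 and 1 are the two cell inputs; 2.. are intermediate nodes.
        return ('"c_{k-2}"', '"c_{k-1}"')[idx] if idx < 2 else str(idx - 2)

    for i in range(steps):
        # Each intermediate node i sums the outputs of its two incoming ops.
        for op, idx in ops[2 * i : 2 * i + 2]:
            lines.append(f'  {node(idx)} -> {i} [label="{op}"];')
    for j in concat:
        # The cell output concatenates the listed intermediate nodes.
        lines.append(f'  {node(j)} -> "c_{{k}}";')
    lines.append("}")
    return "\n".join(lines)
```

Saving the returned string to `cell.dot` and running `dot -Tpdf cell.dot -o cell.pdf` produces a rendering.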
This repository includes the following attention modules:
- Squeeze-and-Excitation (paper / code (unofficial))
- Gather-Excite (paper / code (unofficial))
- BAM (paper / code)
- CBAM (paper / code)
- A2-Nets (paper / code (unofficial))
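For readers unfamiliar with these modules, the core of Squeeze-and-Excitation fits in a few lines. Global average pooling squeezes each channel to a scalar, a two-layer bottleneck (reduction ratio `r`) with ReLU and sigmoid produces one gate per channel, and the input is rescaled channel-wise. The NumPy sketch below uses random weights and a single unbatched `(C, H, W)` feature map for illustration only; it is not the repository's PyTorch module, and `se_block` is a hypothetical name.

```python
import numpy as np


def se_block(x, w1, b1, w2, b2):
    """Squeeze-and-Excitation on one feature map x of shape (C, H, W).

    w1: (C // r, C) and w2: (C, C // r) are the bottleneck FC weights.
    """
    # Squeeze: global average pooling gives one descriptor per channel.
    z = x.mean(axis=(1, 2))                      # (C,)
    # Excitation: FC -> ReLU -> FC -> sigmoid yields a gate in (0, 1) per channel.
    s = np.maximum(w1 @ z + b1, 0.0)             # (C // r,)
    gate = 1.0 / (1.0 + np.exp(-(w2 @ s + b2)))  # (C,)
    # Scale: reweight every spatial position of each channel by its gate.
    return x * gate[:, None, None]


rng = np.random.default_rng(0)
C, r, H, W = 16, 4, 8, 8
x = rng.standard_normal((C, H, W))
w1, b1 = rng.standard_normal((C // r, C)), np.zeros(C // r)
w2, b2 = rng.standard_normal((C, C // r)), np.zeros(C)
y = se_block(x, w1, b1, w2, b2)
```

The other modules differ mainly in how the gate is computed; CBAM and BAM, for instance, also produce spatial attention maps.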
@inproceedings{att-darts2020IJCNN,
author = {Nakai, Kohei and Matsubara, Takashi and Uehara, Kuniaki},
booktitle = {The International Joint Conference on Neural Networks (IJCNN)},
title = {{Att-DARTS: Differentiable Neural Architecture Search for Attention}},
year = {2020}
}