This repository contains the official implementation of experiments conducted in
- EVA: Efficient attention via control variates (ICLR 2023)
- LARA: Linear complexity randomized self-attention mechanism (ICML 2022)
🌲 Repo structure:
- `efficient-attention`: a small self-contained codebase that implements various efficient attention mechanisms. Please see the usage section below for more details.
- `vit`: codebase for image classification experiments, adapted from DeiT.
- `fairseq`: a modified fork of fairseq for language tasks, including machine translation and autoregressive language modeling.
- `main.sh`: a bash script for launching all experiments.
  - See the script for the full list of arguments.
  - Note that arguments after `-e True` are passed directly to the training command; you can pass custom arguments by appending them after `-e True`, as in the example below.
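For instance, the following invocation (an illustrative recombination of flags that appear in the full commands further below, not an additional tested configuration) forwards everything after `-e True` straight to the training command:

```bash
# everything after `-e True` is passed verbatim to the training command
bash main.sh -m evit_tiny_p16 -p <dir-of-imagenet-data> -g 8 -d imagenet -e True \
    --attn-name eva --num-landmarks 49 --adaptive-proj default --window-size 7 --attn-2d --use-rpe
```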
To set up the environment, run the following commands to install the required dependencies (recommended in a virtual environment):

```bash
# install packages
pip install -r requirements.txt

# install the efficient-attention library
pip install -e efficient-attention

# OPTIONAL: install the fairseq library for running language tasks
cd fairseq
python3 setup.py build develop
cd ..
```
The environment is tested with Python 3.8.10, PyTorch 1.12.0, and CUDA 11.3. Also note our fork of fairseq modifies several files in the original codebase; using more recent versions of fairseq might lead to unexpected dependency conflicts.
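Optionally, you can verify that the library is importable with a quick one-liner (a minimal sanity check only, not required for any experiment):

```bash
# quick import check (optional)
python3 -c "from efficient_attention import AttentionFactory; print('efficient_attention OK')"
```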
`efficient-attention` is a small self-contained codebase that collects several efficient attention mechanisms.
- For arguments specific to each attention mechanism, please check the `add_attn_specific_args()` class method in the corresponding Python file.
- To pass these arguments to the `argparse` parser, follow the code snippet below:
```python
import argparse

from efficient_attention import AttentionFactory, NestedNamespace

# ...
parser = argparse.ArgumentParser()
parser.add_argument('--attn-name', default='softmax', type=str, metavar='ATTN',
                    help='Name of attention model to use')
# ...
temp_args, _ = parser.parse_known_args()

# Add attention-specific arguments to the parser.
# struct_name: name of the inner namespace that stores all attention-specific arguments.
# prefix: prefix prepended to all argument names; for example, if prefix = encoder-attn,
#         the argument --window-size must be passed as --encoder-attn-window-size.
#         This is useful to avoid argument name conflicts.
AttentionFactory.add_attn_specific_args(parser, temp_args.attn_name, struct_name="attn_args", prefix="")

# parse arguments into a namespace that supports nested attributes
args = parser.parse_args(namespace=NestedNamespace())

# attention-specific arguments are now accessible via args.attn_args
print(args.attn_args.window_size)
```
Inside a `torch.nn.Module` subclass, you can create an efficient attention module as follows:
```python
# pass the attention-specific arguments to the attention module,
# along with other related arguments
attn_args = {
    **vars(args.attn_args),
    'dim': args.embed_dim,
    'num_heads': args.num_heads,
    'qkv_bias': args.qkv_bias,
    'attn_drop': args.attn_drop_rate,
    'proj_drop': args.drop_rate,
}
self.attn = AttentionFactory.build_attention(attn_name=args.attn_name, attn_args=attn_args)

# the module can then be used as a normal callable:
x = self.attn(x)
```
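For a fuller picture, here is a minimal sketch of a pre-norm residual block wrapping a factory-built attention module. The block and its field names are illustrative and not part of the library; only `AttentionFactory.build_attention` and the `args` layout above come from this repository.

```python
import torch.nn as nn

from efficient_attention import AttentionFactory


class AttnBlock(nn.Module):
    """Illustrative sketch: only AttentionFactory.build_attention comes from efficient_attention."""

    def __init__(self, args):
        super().__init__()
        self.norm = nn.LayerNorm(args.embed_dim)
        # reuse the attention-specific namespace parsed earlier plus model-level arguments
        attn_args = {
            **vars(args.attn_args),
            'dim': args.embed_dim,
            'num_heads': args.num_heads,
            'qkv_bias': args.qkv_bias,
            'attn_drop': args.attn_drop_rate,
            'proj_drop': args.drop_rate,
        }
        self.attn = AttentionFactory.build_attention(attn_name=args.attn_name, attn_args=attn_args)

    def forward(self, x):
        # residual attention over the normalized input, as in typical ViT blocks
        return x + self.attn(self.norm(x))
```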
We follow a setup similar to DeiT to pre-process the ImageNet dataset. Download the ImageNet train and val images and place them in the following directory structure so that they are compatible with torchvision's `datasets.ImageFolder`:
```
/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
```
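As a quick sanity check (not part of the training scripts), the layout above can be loaded directly with `torchvision`; the transform below is a plain placeholder rather than the DeiT-style augmentation used during training.

```python
from torchvision import datasets, transforms

# minimal check that the directory layout is readable by ImageFolder;
# this placeholder transform is NOT the augmentation used by the training scripts
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
train_set = datasets.ImageFolder('/path/to/imagenet/train', transform=transform)
val_set = datasets.ImageFolder('/path/to/imagenet/val', transform=transform)
print(f'{len(train_set.classes)} classes, {len(train_set)} train images, {len(val_set)} val images')
```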
The following commands are used for training and evaluating various vision transformers with `LARA`/`EVA`. Training is assumed to be conducted with 8 GPUs.

To use `LARA`/`EVA` in different DeiT architectures:
```bash
# LARA: DeiT-tiny-p8
bash main.sh -m evit_tiny_p8 -p <dir-of-imagenet-data> -g 8 -d imagenet -e TRUE --dist-eval --num-workers 16 --clip-grad 5.0 --warmup-epochs 10 --seed 1 --attn-name lara --mis-type mis-opt --proposal-gen pool-mixed --alpha-coeff 2.0 --num-landmarks 49

# LARA: DeiT-tiny-p16
bash main.sh -m evit_tiny_p16 -p <dir-of-imagenet-data> -g 8 -d imagenet -e TRUE --dist-eval --num-workers 16 --clip-grad 5.0 --warmup-epochs 10 --seed 1 --attn-name lara --mis-type mis-opt --proposal-gen pool-mixed --alpha-coeff 2.0 --num-landmarks 49

# LARA: DeiT-small-p16
bash main.sh -m evit_small_p16 -p <dir-of-imagenet-data> -g 8 -d imagenet -e TRUE --dist-eval --num-workers 16 --clip-grad 5.0 --warmup-epochs 10 --seed 1 --attn-name lara --mis-type mis-opt --proposal-gen pool-mixed --alpha-coeff 2.0 --num-landmarks 49

# EVA: DeiT-tiny-p8
bash main.sh -m evit_tiny_p8 -p <dir-of-imagenet-data> -g 8 -d imagenet -e TRUE --dist-eval --num-workers 16 --clip-grad 5.0 --warmup-epochs 10 --seed 1 --attn-name eva --num-landmarks 49 --adaptive-proj default --window-size 7 --attn-2d --use-rpe

# EVA: DeiT-tiny-p16
bash main.sh -m evit_tiny_p16 -p <dir-of-imagenet-data> -g 8 -d imagenet -e TRUE --dist-eval --num-workers 16 --clip-grad 5.0 --warmup-epochs 10 --seed 1 --attn-name eva --num-landmarks 49 --adaptive-proj default --window-size 7 --attn-2d --use-rpe

# EVA: DeiT-small-p16
bash main.sh -m evit_small_p16 -p <dir-of-imagenet-data> -g 8 -d imagenet -e TRUE --dist-eval --num-workers 16 --clip-grad 5.0 --warmup-epochs 10 --seed 1 --attn-name eva --num-landmarks 49 --adaptive-proj default --window-size 7 --attn-2d --use-rpe
```
To adapt `LARA`/`EVA` to PvTv2 architectures:
```bash
# LARA Attention
bash main.sh -m pvt_medium2 -p <dir-of-imagenet-data> -g 8 -d imagenet -e TRUE --dist-eval --num-workers 16 --clip-grad 1.0 --drop-path-rate 0.3 --warmup-epochs 10 --seed 1 --attn-name lara --pool-module-type dense --mis-type mis-opt --proposal-gen pool-mixed --num-landmarks 49 --alpha-coeff 2.0 --repeated-aug

# EVA Attention
bash main.sh -m pvt_medium2 -p <dir-of-imagenet-data> -g 8 -d imagenet -e TRUE --dist-eval --num-workers 16 --clip-grad 5.0 --drop-path-rate 0.3 --warmup-epochs 10 --seed 1 --attn-name eva --num-landmarks 49 --adaptive-proj default --window-size 7 --attn-2d --use-rpe --repeated-aug
```
Alternatively, you may want to try out other attention mechanisms:
```bash
# Softmax Attention
bash main.sh -m evit_tiny_p8 -d imagenet -e TRUE --dist-eval --num-workers 16 --clip-grad 5.0 --warmup-epochs 10 --seed 1 --attn-name softmax

# RFA/Performer
bash main.sh -m evit_tiny_p8 -d imagenet -e TRUE --dist-eval --num-workers 16 --clip-grad 5.0 --warmup-epochs 10 --seed 1 --attn-name performer --proj-method favorp --approx-attn-dim 64

# Local Attention
bash main.sh -m evit_tiny_p8 -d imagenet -e TRUE --dist-eval --num-workers 16 --clip-grad 5.0 --warmup-epochs 10 --seed 1 --attn-name local --window-size 7 --attn-2d --use-rpe
```
We use the standard fairseq pre-processing to prepare the data for language tasks:
- For machine translation, please follow the instructions here to prepare the binarized WMT'14 EN-DE data;
- For autoregressive language modeling, follow the instructions here to process the Wikitext-103 dataset (a generic binarization sketch is shown below).
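For reference, the final binarization step in both cases uses `fairseq-preprocess`. The commands below are a generic sketch with placeholder paths, assuming the tokenization/BPE steps from the linked instructions have already been run; they are not part of this repository's scripts.

```bash
# WMT'14 EN-DE: binarize the BPE-tokenized splits with a joined dictionary (placeholder paths)
fairseq-preprocess --source-lang en --target-lang de \
    --trainpref <bpe-dir>/train --validpref <bpe-dir>/valid --testpref <bpe-dir>/test \
    --destdir data-bin/wmt14_en_de --joined-dictionary --workers 20

# Wikitext-103: binarize the word-level corpus for language modeling (no target side)
fairseq-preprocess --only-source \
    --trainpref <wikitext-dir>/wiki.train.tokens \
    --validpref <wikitext-dir>/wiki.valid.tokens \
    --testpref <wikitext-dir>/wiki.test.tokens \
    --destdir data-bin/wikitext-103 --workers 20
```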
- `-r <resume-ckpt-DIR>` specifies the directory that stores your checkpoints during training and can be used to resume training.
- Note that all attention-specific arguments need to be prefixed with `--encoder-attn-` (for the encoder side) / `--decoder-attn-` (for the decoder side). See the examples below.
To train models for machine translation on WMT'14 EN-DE:

```bash
## LARA
CUDA_VISIBLE_DEVICES=0,1,2,3 bash main.sh -p <dir-of-your-bin-data> -d wmt -s lara_8 -g 4 -e TRUE --attn-name-encoder lara --encoder-attn-num-landmarks 8 --encoder-attn-proposal-gen adaptive-1d --encoder-attn-mis-type mis-opt
CUDA_VISIBLE_DEVICES=0,1,2,3 bash main.sh -p <dir-of-your-bin-data> -d wmt -s lara_16 -g 4 -e TRUE --attn-name-encoder lara --encoder-attn-num-landmarks 16 --encoder-attn-proposal-gen adaptive-1d --encoder-attn-mis-type mis-opt
CUDA_VISIBLE_DEVICES=0,1,2,3 bash main.sh -p <dir-of-your-bin-data> -d wmt -s lara_32 -g 4 -e TRUE --attn-name-encoder lara --encoder-attn-num-landmarks 32 --encoder-attn-proposal-gen adaptive-1d --encoder-attn-mis-type mis-opt

## EVA
CUDA_VISIBLE_DEVICES=0,1,2,3 bash main.sh -p <dir-of-your-bin-data> -d wmt -s eva_8_8 -g 4 -e TRUE --attn-name-encoder eva --encoder-attn-window-size 8 --encoder-attn-num-landmarks 8 --encoder-attn-adaptive-proj no-ln --encoder-attn-use-t5-rpe --encoder-attn-overlap-window
CUDA_VISIBLE_DEVICES=0,1,2,3 bash main.sh -p <dir-of-your-bin-data> -d wmt -s eva_16_8 -g 4 -e TRUE --attn-name-encoder eva --encoder-attn-window-size 16 --encoder-attn-num-landmarks 8 --encoder-attn-adaptive-proj no-ln --encoder-attn-use-t5-rpe --encoder-attn-overlap-window
CUDA_VISIBLE_DEVICES=0,1,2,3 bash main.sh -p <dir-of-your-bin-data> -d wmt -s eva_32_8 -g 4 -e TRUE --attn-name-encoder eva --encoder-attn-window-size 32 --encoder-attn-num-landmarks 8 --encoder-attn-adaptive-proj no-ln --encoder-attn-use-t5-rpe --encoder-attn-overlap-window
```
To train Transformer language models on Wikitext-103 (note that LARA does not support causal masking yet):

```bash
# EVA on a 16-layer Transformer LM
CUDA_VISIBLE_DEVICES=0,1,2,3 bash main.sh -p <dir-of-your-bin-data> -m 16layers -d wikitext103 -s eva_128_8_16layers -g 4 -e TRUE --attn-name-decoder causal_eva --decoder-attn-window-size 128 --decoder-attn-causal --decoder-attn-adaptive-proj qk --decoder-attn-chunk-size 8 --decoder-attn-use-t5-rpe

# EVA on a 32-layer Transformer LM
CUDA_VISIBLE_DEVICES=0,1,2,3 bash main.sh -p <dir-of-your-bin-data> -m 32layers -d wikitext103 -s eva_128_8_32layers -g 4 -e TRUE --attn-name-decoder causal_eva --decoder-attn-window-size 128 --decoder-attn-causal --decoder-attn-adaptive-proj qk --decoder-attn-chunk-size 8 --decoder-attn-use-t5-rpe
```
For generation & evaluation, simply pass the argument `-i true` when calling `main.sh` to run the inference procedure only. The checkpoint path can be specified via `-c <your-ckpt-path>`. For example:
```bash
# Machine Translation
CUDA_VISIBLE_DEVICES=0 bash main.sh -i true -c <your-possibly-avg-checkpoint.pt> -p <dir-of-your-bin-data> -d wmt -g 1

# Autoregressive Language Modeling
CUDA_VISIBLE_DEVICES=0 bash main.sh -i true -c <your-checkpoint_last.pt> -p <dir-of-your-bin-data> -d wikitext103 -g 1
```
We also provide trained `EVA` model checkpoints on OneDrive for machine translation and language modeling tasks:
- wikitext103-eva-16layers-lm
- wikitext103-eva-32layers-lm
- wmt14ende-eva-e32_c8-mt
- wmt14ende-eva-e8_c8-mt
```bibtex
@inproceedings{zheng2023efficient,
  title={Efficient Attention via Control Variates},
  author={Lin Zheng and Jianbo Yuan and Chong Wang and Lingpeng Kong},
  booktitle={International Conference on Learning Representations},
  year={2023},
  url={https://openreview.net/forum?id=G-uNfHKrj46}
}

@inproceedings{zheng2022linear,
  title={Linear Complexity Randomized Self-attention Mechanism},
  author={Lin Zheng and Chong Wang and Lingpeng Kong},
  booktitle={International Conference on Machine Learning},
  pages={27011--27041},
  year={2022},
  organization={PMLR}
}
```