This is the official implementation of our paper: Expanding Sparse Tuning for Low Memory Usage.
We propose SNELL (Sparse tuning with kerNELized LoRA), a method that enables sparse tuning with low memory usage. SNELL decomposes the tunable matrix to be sparsified into two learnable low-rank matrices, avoiding the costly storage of the original full matrix. To keep sparse tuning effective with low-rank matrices, we extend the low-rank decomposition from a kernel perspective: applying nonlinear kernel functions when merging the full matrix increases the rank of the merged matrix. The higher rank strengthens SNELL's ability to adapt the pre-trained model sparsely to downstream tasks. To further reduce memory usage, we introduce a competition-based sparsification mechanism that avoids storing the indices of tunable weights. Extensive experiments on multiple downstream tasks show that SNELL achieves state-of-the-art performance with low memory usage, extending effective PEFT with sparse tuning to large-scale models.
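To make the two ingredients concrete, here is a minimal, hypothetical PyTorch sketch (not the repository's actual implementation; the kernel choice, threshold rule, and all names are illustrative assumptions): a nonlinear kernel applied when merging the low-rank factors, and magnitude-based competition that recomputes the sparse mask on the fly instead of storing weight indices.

```python
import torch

def kernelized_merge(B: torch.Tensor, A: torch.Tensor, degree: int = 3) -> torch.Tensor:
    """Merge low-rank factors with a nonlinear kernel (illustrative).

    Vanilla LoRA uses the linear kernel, delta[i, j] = <B[i], A[:, j]>,
    so rank(delta) <= r. An odd-degree homogeneous polynomial kernel,
    kappa(x, y) = <x, y> ** degree, can raise the rank of the merged
    matrix beyond r.
    """
    return (B @ A) ** degree

def competitive_sparsify(delta: torch.Tensor, sparsity: float = 0.9) -> torch.Tensor:
    """Zero out all but the largest-magnitude entries (illustrative).

    Entries compete by magnitude on every forward pass, so no index
    mask has to be kept in memory.
    """
    k = int(delta.numel() * sparsity)
    threshold = delta.abs().flatten().kthvalue(k).values
    return torch.where(delta.abs() > threshold, delta, torch.zeros_like(delta))

# Example: a rank-4 decomposition of a 64x64 weight update.
B, A = torch.randn(64, 4), torch.randn(4, 64)
delta = competitive_sparsify(kernelized_merge(B, A))
print(f"nonzero fraction: {(delta != 0).float().mean():.2f}")  # ~0.10
```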
If you find this repository or our paper useful, please consider citing and starring us!
```
@InProceedings{Shen_2024_SNELL,
    title={Expanding Sparse Tuning for Low Memory Usage},
    author={Shen, Shufan and Sun, Junshu and Ji, Xiangyang and Huang, Qingming and Wang, Shuhui},
    booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
    year={2024}
}
```
The repository is organized as follows:

- `./train.py`: the entry point for training.
- `./scripts`: scripts for adapting pre-trained models to downstream tasks with SNELL.
- `./lib`: helper functions for I/O, logging, training, and data loading.
- `./model`: backbone architectures and fine-tuning methods.
- `./engine.py`: main training and evaluation functions.
- `./data`: storage for the FGVC and VTAB-1k benchmarks.
- Clone this repo:

```bash
git clone https://github.com/ssfgunner/SNELL.git
cd SNELL
```

- Create a conda virtual environment and activate it:

```bash
conda create -n snell python=3.8 -y
conda activate snell
```

- Install `torch==1.12.1` and `torchvision==0.13.1` with `CUDA==11.3`:

```bash
conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
```

- Install other dependencies:

```bash
pip install -r requirements.txt
```
Prepare the datasets as follows:

- FGVC: Please download the datasets following VPT.
- VTAB-1k: Since the processing of some datasets in the original VTAB benchmark is tricky, we recommend the extracted VTAB-1k datasets shared by SSF for convenience. (Note that the license follows the original VTAB benchmark.)

The file structure should look like:

```
data
├── fgvc
│   ├── cub
│   ├── nabirds
│   └── ...
└── vtab-1k
    ├── caltech101
    ├── cifar
    └── ...
```
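As an optional sanity check (a hypothetical helper, not part of the repo), you can verify that the benchmark folders are in place before launching training:

```python
from pathlib import Path

# Hypothetical check: confirm the layout described above exists under ./data.
for folder in ["data/fgvc", "data/vtab-1k"]:
    path = Path(folder)
    print(f"{folder}: {'ok' if path.is_dir() else 'MISSING'}")
    if path.is_dir():
        print("  datasets:", sorted(p.name for p in path.iterdir() if p.is_dir()))
```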
Please download the pre-trained models to `./checkpoints`:

```bash
mkdir checkpoints
cd checkpoints
# Supervised pre-trained ViT-B/16
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
# MAE pre-trained ViT-B/16
wget https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth
# MoCo v3 pre-trained ViT-B/16
wget https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/linear-vit-b-300ep.pth.tar
# Supervised pre-trained Swin Transformer
wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
# Supervised pre-trained ConvNeXt
wget https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
```
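A quick, optional way to confirm the downloads are intact (a hypothetical snippet; the printed keys depend on each checkpoint's format):

```python
import numpy as np
import torch

# Hypothetical sanity check for the downloaded checkpoints.
vit = np.load("checkpoints/ViT-B_16.npz")
print(f"ViT-B/16: {len(vit.files)} arrays")

for name in ["mae_pretrain_vit_base.pth",
             "linear-vit-b-300ep.pth.tar",
             "swin_base_patch4_window7_224_22k.pth",
             "convnext_base_22k_224.pth"]:
    ckpt = torch.load(f"checkpoints/{name}", map_location="cpu")
    print(f"{name}: top-level keys = {list(ckpt.keys())[:5]}")
```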
We provide training scripts for adapting the supervised pre-trained ViT to FGVC and VTAB-1k with SNELL-32, for example:

```bash
# Fine-tune supervised pre-trained ViT-B/16 with SNELL-32 on the CUB dataset of FGVC
bash scripts/fgvc/snell32/vit_cub_snell.sh
# Fine-tune supervised pre-trained ViT-B/16 with SNELL-32 on the CIFAR dataset of VTAB-1k
bash scripts/vtab/snell32/vit_cifar_snell.sh
```
For other models, we provide example commands for fine-tuning on FGVC. Shell variables such as `${DATASET}`, `${init_thres}`, `${save_dir}`, `${batch_size}`, and `${WEIGHT_DECAY}` are placeholders to set per experiment; `--low_rank_dim=32` selects SNELL-32.
- For ViT pre-trained with MAE:

```bash
python train.py --data-path=./data/fgvc/${DATASET} --init_thres=${init_thres} \
    --data-set=${DATASET} --model_name=vit_base_patch16_224_in21k_snell --resume=checkpoints/mae_pretrain_vit_base.pth \
    --output_dir=${save_dir} \
    --batch-size=${batch_size} --lr=0.001 --epochs=100 --weight-decay=${WEIGHT_DECAY} --mixup=0 --cutmix=0 \
    --smoothing=0 --launcher="none" --seed=0 --val_interval=10 --opt=adamw --low_rank_dim=32 \
    --exp_name="ViT_MAE_${DATASET}" \
    --test --block=BlockSNELLParallel --tuning_model=snell --freeze_stage
```
- For ViT pre-trained with MoCo v3:

```bash
python train.py --data-path=./data/fgvc/${DATASET} --init_thres=${init_thres} \
    --data-set=${DATASET} --model_name=vit_base_patch16_224_in21k_snell --resume=checkpoints/linear-vit-b-300ep.pth.tar \
    --output_dir=${save_dir} \
    --batch-size=${batch_size} --lr=0.001 --epochs=100 --weight-decay=${WEIGHT_DECAY} --mixup=0 --cutmix=0 \
    --smoothing=0 --launcher="none" --seed=0 --val_interval=10 --opt=adamw --low_rank_dim=32 \
    --exp_name="ViT_MoCo_${DATASET}" \
    --test --block=BlockSNELLParallel --tuning_model=snell --freeze_stage
```
- For supervised pre-trained Swin Transformer:

```bash
python train.py --data-path=./data/fgvc/${DATASET} --init_thres=${init_thres} \
    --data-set=${DATASET} --model_name=swin_base_patch4_window7_224_in22k --resume=./checkpoints/swin_base_patch4_window7_224_22k.pth \
    --output_dir=${save_dir} \
    --batch-size=${batch_size} --lr=0.001 --epochs=100 --weight-decay=${WEIGHT_DECAY} --mixup=0 --cutmix=0 \
    --smoothing=0 --launcher="none" --seed=0 --val_interval=10 --opt=adamw --low_rank_dim=32 \
    --exp_name="Swin_${DATASET}" \
    --test --block=BlockSNELLParallel --tuning_model=snell --freeze_stage
```
- For supervised pre-trained ConvNeXt:

```bash
python train.py --data-path=./data/fgvc/${DATASET} --init_thres=${init_thres} \
    --data-set=${DATASET} --model_name=convnext_base_in22k --resume=./checkpoints/convnext_base_22k_224.pth \
    --output_dir=${save_dir} \
    --batch-size=${batch_size} --lr=0.001 --epochs=100 --weight-decay=${WEIGHT_DECAY} --mixup=0 --cutmix=0 \
    --smoothing=0 --launcher="none" --seed=0 --val_interval=10 --opt=adamw --low_rank_dim=32 \
    --exp_name="ConvNeXt_${DATASET}" \
    --test --block=BlockSNELLParallel --tuning_model=snell --freeze_stage
```
Our code is modified from VPT, SSF, and SPT. We thank the authors for open-sourcing their code.