Chinmay Savadikar1, Xi Song2, Tianfu Wu1
1North Carolina State University, 2Independent Researcher
[Paper] | [Website]
The code in this repository has been tested with Python 3.9 and PyTorch 2.1.0, but should be compatible with Python >= 3.8. GIFT can be installed as a standalone package, without the dependencies required for the experiments. To install GIFT, first install PyTorch:
conda create -n gift python=3.9.19
conda activate gift
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia
# or, if you have CUDA 11.8
# conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia
Once PyTorch is installed, install GIFT by running the following commands:
git clone https://github.com/savadikarc/gift.git
cd gift
pip install -e .
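After installing, you can verify that the package is importable with a quick sanity check from Python (this only confirms that the gift module resolves; it does not load any model weights):
# Minimal sanity check that GIFT is importable (no model weights are loaded).
from gift.gift import GIFTConfig, GIFTWrapperForCausalLM
print("GIFT imported successfully")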
The directory gift_experiment_utils/visual_classification_utils contains the utility functions required for the visual classification experiments. The scripts that run our experiments do not require installing the utility code as a package, but they do require additional dependencies. To replicate the exact environment used in our experiments, please see the instructions in experiment_setup/SETUP.md. To run the experiments, please jump here.
Causal Language Modeling using transformers (example using Llama 3)
import torch
from transformers import AutoModelForCausalLM
from gift.gift import GIFTConfig, GIFTWrapperForCausalLM
# Define the pretrained backbone
dtype = torch.bfloat16 # or torch.float32, if bfloat16 is not supported
backbone = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B",
torch_dtype=dtype,
device_map=<DEVICE>
)
# Define the GIFT configuration
gift_config = GIFTConfig(
rank=64, # Rank of GIFT
dtype=dtype, # dtype for GIFT parameters and residual generation
gift_paramters=dict( # GIFT schema. By default, GIFT uses two simple linear projections
block_type='simple_block', # denoted by 'simple_block'
act_layer="identity" # with no non-linerarity in between
)
in_projection_bias=False, # and no biases in the first projection (phi)
out_projection_bias=False, # and no biases in the second projection (psi)
target_modules=["q_proj", "v_proj"] # target modules for GIFT
)
# Wrap the backbone with GIFT
model = GIFTWrapperForCausalLM(gift_config, backbone)
# Now, you can train the model as you would normally do
# ...
# ...
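Before training, it can be useful to check the parameter budget. The snippet below is a minimal sketch, assuming the wrapper follows the usual parameter-efficient fine-tuning convention of freezing the backbone so that only the GIFT parameters have requires_grad=True:
# A minimal sketch: report trainable vs. total parameters of the wrapped model.
# Assumes the backbone weights are frozen and only GIFT parameters are trainable.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.4f}%)")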
Vision Transformer (ViT) backbone using timm for image classification
import timm
from gift.gift import GIFTWrapperForImageClassification, GIFTConfig
# Define the pretrained backbone
backbone = timm.create_model('vit_base_patch16_224.augreg_in21k', pretrained=True, num_classes=<NUM_CLASSES>)
# Define the GIFT configuration
gift_config = GIFTConfig(
rank=16,
dtype='float32',
gift_paramters=dict(
block_type='simple_block',
act_layer="identity"
),
in_projection_bias=False,
out_projection_bias=False,
target_modules=["attn:proj"] # target modules for GIFT. By default, we use the final linear projection in the MHSA layer for ViTs
)
# The notation 'attn:proj' is chosen to disambiguate from any other modules named 'proj' in the model.
# The target modules should match the module names in the backbone model.
# Since GIFT stores its modules in a ModuleDict, we use ':' instead of the standard '.'
# to separate the module name from the submodule name, because keys in a ModuleDict cannot contain '.'.
# Wrap the backbone with GIFT
model = GIFTWrapperForImageClassification(gift_config, backbone)
# Now, you can train the model as you would normally do
# ...
# ...
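If you want to target different layers, it helps to first inspect the module names in the backbone. The snippet below is plain PyTorch (not a GIFT API) and lists the Linear layers in the timm ViT; remember that GIFT's target specification joins name parts with ':' rather than '.':
import torch.nn as nn
# Print the '.'-separated names of all Linear modules in the backbone,
# e.g., blocks.0.attn.proj is the output projection of the first attention block.
for name, module in backbone.named_modules():
    if isinstance(module, nn.Linear):
        print(name)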
RoBERTa backbone using transformers for sequence classification
from transformers import AutoConfig, AutoModelForSequenceClassification
from gift.gift import GIFTConfig, GIFTWrapperForSeqClassification
# Define the pretrained backbone
config = AutoConfig.from_pretrained(
"FacebookAI/roberta-base",
num_labels=<NUM_LABELS>,
)
backbone = AutoModelForSequenceClassification.from_pretrained(
"FacebookAI/roberta-base",
config=config,
)
# Define the GIFT configuration
gift_config = GIFTConfig(
rank=32,
dtype='float32',
gift_paramters=dict(
block_type='simple_block',
act_layer="identity"
),
in_projection_bias=False,
out_projection_bias=False,
target_modules=["query", "value"]
)
# Wrap the backbone with GIFT
model = GIFTWrapperForSeqClassification(gift_config, backbone)
# Now, you can train the model as you would normally do
# ...
# ...
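As a usage sketch, the snippet below runs a single optimization step. It assumes (this is not guaranteed by the snippet above) that the wrapper forwards keyword arguments to the underlying transformers model and returns a standard output object with a .loss when labels are provided; the learning rate is illustrative:
import torch
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
batch = tokenizer(["GIFT is a parameter-efficient fine-tuning method."], return_tensors="pt")
batch["labels"] = torch.tensor([0])  # dummy label for illustration
# Optimize only the trainable (GIFT) parameters; the backbone is assumed frozen.
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=5e-4)
outputs = model(**batch)  # assumed to return HF-style outputs with a .loss attribute
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()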
Please follow the instructions in experiment_setup/SETUP.md to set up the environment.
Please refer to scripts/language_modeling/README.md for details on data preparation and training for the commonsense reasoning, arithmetic reasoning, and instruction tuning experiments.
Please refer to scripts/visual_classification/fgvc/README.md for details on data preparation and training.
Please refer to scripts/visual_classification/vtab/README.md for details on data preparation and training.
Please refer to scripts/language_understanding/README.md for details on data preparation and training.
This repository builds on code from timm, TOAST, and pyreft. We thank the authors for their amazing work.
@misc{savadikar2024gift,
title={GIFT: Generative Interpretable Fine-Tuning},
author={Chinmay Savadikar and Xi Song and Tianfu Wu},
year={2024},
eprint={2312.00700},
archivePrefix={arXiv},
primaryClass={cs.CV}
}