This repository contains the original implementation for "Residual Prompt Tuning: Improving Prompt Tuning with Residual Reparameterization" (ACL 2023) by Anastasia Razdaibiedina, Yuning Mao, Rui Hou, Madian Khabsa, Mike Lewis, Jimmy Ba and Amjad Almahairi.
🎊 Our work is accepted to ACL Findings 2023!
We introduce Residual Prompt Tuning – a simple and efficient method that significantly improves the performance and stability of prompt tuning. We propose to reparameterize soft prompt embeddings using a shallow network with a residual connection.
This reparameterization gives the model more flexibility to decide between using a separate embedding for each prompt token versus the representation obtained from the shared reparameterization network. After training is completed, the reparameterization network can be discarded and the original prompt embeddings can be replaced with their projections.
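For intuition, here is a minimal PyTorch sketch of the idea; the module name, bottleneck size, and dimensions are illustrative and not the exact classes used in this repo. Each soft prompt token embedding is passed through a shared shallow MLP and added back to itself via a skip connection.

```python
import torch
import torch.nn as nn

class ResidualReparamMLP(nn.Module):
    """Shared shallow network applied to every soft prompt token embedding (illustrative)."""
    def __init__(self, d_model: int, bottleneck: int = 400):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(d_model, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, d_model),
        )

    def forward(self, prompt: torch.Tensor) -> torch.Tensor:
        # residual connection: the model can fall back to the raw prompt embedding
        return self.mlp(prompt) + prompt

# 10 trainable prompt tokens for a model with hidden size 768 (e.g. T5-base)
prompt_embeddings = nn.Parameter(torch.randn(10, 768))
reparam = ResidualReparamMLP(d_model=768)
projected_prompt = reparam(prompt_embeddings)  # used in place of the raw embeddings
```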
Our codebase includes a PyTorch implementation of:
- original prompt tuning (following Lester et al.; see the sketch after this list)
- residual prompt tuning (our modification)
- full model tuning
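To illustrate original prompt tuning (the baseline we build on), below is a minimal, self-contained sketch of prepending trainable soft prompt embeddings to a frozen T5 model with Hugging Face transformers. This is not the repo's training script; the input text, initialization, and names are purely illustrative.

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5Tokenizer.from_pretrained("t5-base")
for p in model.parameters():
    p.requires_grad = False  # freeze the backbone; only the soft prompt is trained

prefix_len, d_model = 10, model.config.d_model
soft_prompt = torch.nn.Parameter(torch.randn(prefix_len, d_model) * 0.5)

batch = tokenizer(["wsc: The trophy did not fit in the suitcase because it was too big."],
                  return_tensors="pt")
labels = tokenizer(["True"], return_tensors="pt").input_ids

# Prepend the soft prompt to the token embeddings of the input
token_embeds = model.get_input_embeddings()(batch.input_ids)              # (B, L, d)
prompt_embeds = soft_prompt.unsqueeze(0).expand(token_embeds.size(0), -1, -1)
inputs_embeds = torch.cat([prompt_embeds, token_embeds], dim=1)
attention_mask = torch.cat(
    [torch.ones(token_embeds.size(0), prefix_len, dtype=batch.attention_mask.dtype),
     batch.attention_mask], dim=1)

loss = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels).loss
loss.backward()  # gradients flow only into soft_prompt
```

In residual prompt tuning, the soft prompt would additionally be passed through the shared reparameterization network (as in the sketch above) before being prepended.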
Clone this repo and set up the conda environment as follows:

```bash
git clone https://github.com/arazd/ResidualPrompts
cd ResidualPrompts
conda env create -f environment.yaml
conda activate nlp
```
An example of training a 10-token soft prompt on the WSC task using the T5-base model with residual reparameterization (MLP1 network type):
```bash
python train.py --task wsc --prefix_MLP MLP1 \
    --lr 0.3 --freeze_weights 1 --freeze_except xxxx \
    --model_name t5-base --early_stopping 1 \
    --test_eval_after_every_task 1 --select_k_per_class -1 \
    --batch_size 8 --num_epochs 20 --prefix_len 10 \
    --save_dir /home/%u/my_dir/ --save_name my_model_folder
```
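After training, as noted above, the reparameterization network can be folded away so that inference uses plain prompt tuning. A rough sketch of that step, reusing the hypothetical names from the reparameterization example above:

```python
import torch

# Replace the learned prompt embeddings with their projections, then drop the MLP.
with torch.no_grad():
    prompt_embeddings.copy_(reparam(prompt_embeddings))
del reparam  # the shallow network is no longer needed at inference time
```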
If you use the code for your work, please consider citing our paper:
```bibtex
@inproceedings{razdaibiedina2023residual,
  title={Residual Prompt Tuning: Improving Prompt Tuning with Residual Reparameterization},
  author={Razdaibiedina, Anastasia and Mao, Yuning and Hou, Rui and Khabsa, Madian and Lewis, Mike and Ba, Jimmy and Almahairi, Amjad},
  booktitle={Findings of the Association for Computational Linguistics: ACL 2023},
  year={2023}
}
```