# Resolving Interference When Merging Models (NeurIPS 2023)

teaser image

## Setup

1. Create a virtual environment and activate it:

   ```bash
   python3 -m venv env
   source env/bin/activate
   ```

2. Install the dependencies:

   ```bash
   python -m pip install -r requirements.txt -f https://download.pytorch.org/whl/cu113/torch_stable.html
   ```

3. Download the Story Cloze dataset and update its path in the `StoryClozeReader` class in `data/dataset_readers.py`.

4. Set the path to the directory where the finetuned models are stored in `utils/merge_utils.py` (a sketch of the edits for steps 3 and 4 follows this list).
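For concreteness, here is a minimal sketch of the kind of edits steps 3 and 4 describe. The attribute and variable names below are assumptions for illustration only; check `data/dataset_readers.py` and `utils/merge_utils.py` for the actual names.

```python
# Hypothetical sketch only -- the real attribute/variable names in the repo may differ.

# data/dataset_readers.py
class StoryClozeReader:
    # assumed attribute: points at the downloaded Story Cloze files
    DATA_DIR = "/path/to/story_cloze"


# utils/merge_utils.py
# assumed variable: directory containing the finetuned checkpoints to merge
FINETUNED_MODEL_DIR = "/path/to/finetuned_models"
```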

We have released the IA3 checkpoints here!

## Train

### Train T5 Models

```bash
python src/training.py -c configs/t5_base.json -k train_batch_size=8 gradient_accumulation_factor=1 project_name=training experiment_name=test train_dataset=rte train_dataset_mixture=None num_batches=2
```

## Evaluation

Evaluate IA3 across multiple prompts and report the median.

```bash
path_to_checkpoint=  # path to your checkpoint
eval_split=validation
dataset=rte

python ./src/inference.py -c configs/ia3_base.json --multiple_prompts -i ${dataset} --kwargs checkpoint_to_directly_load_model=${path_to_checkpoint} split=${eval_split} project_name=ia3 experiment_name=${dataset}
```

Evaluate T5-Large.

```bash
path_to_checkpoint=  # path to your checkpoint
eval_split=validation
dataset=rte

python ./src/inference.py -c configs/t5_large.json -i ${dataset} --kwargs checkpoint_to_directly_load_model=${path_to_checkpoint} split=${eval_split} project_name=t5-large experiment_name=${dataset}
```

## Merging Models

### T5-Large

#### Basic Averaging

```bash
eval_split=validation

python ./src/ties_merging.py -c configs/t5_large.json -i t5_mixture -m t5_mixture -f basic_mean --kwargs split=${eval_split} project_name=t5-large experiment_name=mean
```
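Basic averaging simply takes the element-wise mean of the finetuned checkpoints' parameters. A minimal PyTorch sketch (the checkpoint paths and loading code are placeholders, not the repo's actual pipeline in `src/ties_merging.py`):

```python
import torch


def average_state_dicts(state_dicts):
    """Element-wise mean of a list of state dicts with identical keys and shapes."""
    return {
        key: torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
        for key in state_dicts[0]
    }


# Hypothetical usage: load the finetuned checkpoints and average them.
paths = ["/path/to/checkpoint_a.pt", "/path/to/checkpoint_b.pt"]
merged = average_state_dicts([torch.load(p, map_location="cpu") for p in paths])
```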

#### Task Vectors

```bash
eval_split=validation
eval_function=task-vector_linear+0.1+1.01+0.1

python ./src/ties_merging.py -c configs/t5_large.json -i t5_mixture -m t5_mixture -f ${eval_function} --kwargs split=${eval_split} project_name=t5-large experiment_name=task_vectors
```

This performs merging for different values of lambda, trying all lambda values between 0 and 1 in increments of 0.1.
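As a rough illustration of what the sweep does, task-vector merging adds a scaled sum of task vectors (finetuned minus pretrained parameters) back onto the pretrained model. The sketch below is a simplification, and reading `linear+0.1+1.01+0.1` as a start/stop/step range is an assumption based on the description above, not the repo's parsing code:

```python
import torch


def task_vector_merge(base_sd, finetuned_sds, lam):
    """theta_merged = theta_base + lam * sum_t (theta_t - theta_base)."""
    return {
        key: base_sd[key].float()
        + lam * sum(sd[key].float() - base_sd[key].float() for sd in finetuned_sds)
        for key in base_sd
    }


# Sweep lambda over 0.1, 0.2, ..., 1.0 (assumed meaning of linear+0.1+1.01+0.1).
lambdas = torch.arange(0.1, 1.01, 0.1)
```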

#### TIES-Merging

```bash
eval_split=validation
redundant=topk20
elect=mass
agg=dis-mean
scale=linear+0.8+2.51+0.1

python ./src/ties_merging.py -c configs/t5_large.json -i t5_mixture -m t5_mixture -f ${redundant}_${elect}_${agg}_${scale} --kwargs split=${eval_split} project_name=t5-large experiment_name=ties
```
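To unpack the flags: `topk20` trims each task vector to its top 20% of entries by magnitude, `mass` elects each parameter's sign as the one with the larger total magnitude, `dis-mean` averages only the entries that agree with the elected sign, and `scale` sweeps the scaling factor lambda (presumably 0.8 to 2.5 in steps of 0.1, by analogy with the task-vector sweep above). Below is a minimal sketch of those steps on flattened task vectors; it illustrates the TIES-Merging idea and is not the implementation in `src/ties_merging.py`:

```python
import torch


def ties_merge(task_vectors, k=0.20, lam=1.0):
    """Sketch of trim / elect-sign / disjoint-mean over task vectors of shape
    [num_tasks, num_params]; the merged task vector is later added to the base model."""
    # 1) Trim: keep only the top-k fraction of entries (by magnitude) in each task vector.
    trimmed = torch.zeros_like(task_vectors)
    n_keep = max(1, int(k * task_vectors.shape[1]))
    top_idx = task_vectors.abs().topk(n_keep, dim=1).indices
    trimmed.scatter_(1, top_idx, task_vectors.gather(1, top_idx))

    # 2) Elect sign: per parameter, pick the sign with the larger total magnitude ("mass").
    elected_sign = torch.sign(trimmed.sum(dim=0))

    # 3) Disjoint mean: average only the entries whose sign matches the elected sign.
    agrees = torch.sign(trimmed) == elected_sign
    merged = (trimmed * agrees).sum(dim=0) / agrees.sum(dim=0).clamp(min=1)

    # 4) Scale the merged task vector by lambda.
    return lam * merged
```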

### IA3

#### Basic Averaging

```bash
eval_split=validation

python ./src/ties_merging.py -c configs/ia3_base.json -i T0_held_out -m T0_held_out -f basic_mean --multiple_prompts --kwargs pretrained_model=bigscience/T0_3B split=${eval_split} project_name=ia3 experiment_name=mean
```

#### Task Vectors

```bash
eval_split=validation
eval_function=task-vector_linear+0.1+1.01+0.1

python ./src/ties_merging.py -c configs/ia3_base.json -i T0_held_out -m T0_held_out -f ${eval_function} --multiple_prompts --kwargs pretrained_model=bigscience/T0_3B split=${eval_split} project_name=ia3 experiment_name=task_vectors
```

#### TIES-Merging

```bash
eval_split=validation
redundant=topk20
elect=mass
agg=dis-mean
scale=linear+0.8+2.51+0.1

python ./src/ties_merging.py -c configs/ia3_base.json -i T0_held_out -m T0_held_out -f ${redundant}_${elect}_${agg}_${scale} --multiple_prompts --kwargs pretrained_model=bigscience/T0_3B split=${eval_split} project_name=ia3 experiment_name=ties
```

## Reference

Please cite our paper if you use our models in your work:

```bibtex
@inproceedings{
    yadav2023tiesmerging,
    title={{TIES}-Merging: Resolving Interference When Merging Models},
    author={Prateek Yadav and Derek Tam and Leshem Choshen and Colin Raffel and Mohit Bansal},
    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
    year={2023},
    url={https://openreview.net/forum?id=xtaX3WyCj1}
}
```