hadasah/btm

Branch-Train-Merge: Embarrassingly Parallel Training of Expert Language Models

Below are instructions to access the code and models for the paper "Branch-Train-Merge: Embarrassingly Parallel Training of Expert Language Models".

Code

This code is based on Fairseq and includes a hard fork of Fairseq in the fairseq folder.

Setup

For basic setup of our code:

git clone https://github.com/hadasah/btm.git
cd btm/fairseq
pip install -e .

Note that this will uninstall any existing Fairseq installation in your environment. For additional install options, see the Fairseq repository README.

Data

Most of the experiments are conducted with data from DeMIX. To train on your own data, preprocess it by following the Fairseq instructions for data preprocessing.
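
A minimal sketch of binarizing one domain of your own data. It assumes the text is already BPE-encoded into train.bpe, valid.bpe and test.bpe files, and that each domain lives in its own subfolder of DATA_FOLDER named after DATA_DOMAIN_NAME (an assumption; check btm_shell_scripts/btm_train.sh for the layout it actually reads). The preprocess_domain function and the file names are hypothetical; the flags are standard fairseq-preprocess options for language-model data. Passing one shared dictionary with --srcdict keeps the vocabulary identical across domains, which matters if you later ensemble or parameter-average the resulting experts.

```bash
#!/usr/bin/env bash
# Sketch: binarize one domain's BPE-encoded text with fairseq-preprocess.
# Assumed (hypothetical) layout:
#   $raw_dir/train.bpe, $raw_dir/valid.bpe, $raw_dir/test.bpe  -> input text
#   $data_folder/$domain/                                      -> binarized output
preprocess_domain() {
  local raw_dir=$1 dict=$2 data_folder=$3 domain=$4
  # --only-source: LM data has no target side.
  # --srcdict: reuse one dictionary so every domain shares a vocabulary.
  fairseq-preprocess \
    --only-source \
    --srcdict "$dict" \
    --trainpref "$raw_dir/train.bpe" \
    --validpref "$raw_dir/valid.bpe" \
    --testpref "$raw_dir/test.bpe" \
    --destdir "$data_folder/$domain" \
    --workers 8
}
```

For example (all paths and the domain name are placeholders):

preprocess_domain /path/to/raw/my_domain /path/to/dict.txt /path/to/data my_domain ;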

Model Training

To train a Transformer-LM baseline with domain data balancing, or to conduct seed LM training:

MODEL_FOLDER_NAME=project_name;
NUM_GPUS=16;
ARCHITECTURE=transformer_lm_gpt3_small;
DATA_FOLDER=/path/to/data;
DATA_DOMAIN_NAME=data_domain_name;
SAVE_MODEL_FOLDER=/path/to/new/model/checkpointing;
NUM_UPDATES=80000;
UPDATE_FREQ=32;
LR=5e-4;
SAVE_INTERVAL_UPDATES=2000;
PORT=55555;
WANDB_PROJECT=project_name;
BTM_CODE_PATH=/path/to/this/repo;
RANDOM_SEED=1;
UNIQUE_RUN_ID=unique_run_name;

cd $BTM_CODE_PATH;
bash btm_shell_scripts/btm_train.sh $MODEL_FOLDER_NAME $NUM_GPUS \
$ARCHITECTURE dense $DATA_FOLDER $DATA_DOMAIN_NAME None \
$SAVE_MODEL_FOLDER None None None False \
$NUM_UPDATES $UPDATE_FREQ $LR $SAVE_INTERVAL_UPDATES $PORT \
$WANDB_PROJECT $BTM_CODE_PATH $RANDOM_SEED $UNIQUE_RUN_ID ;
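
btm_train.sh takes 21 positional arguments, and the variables above are expanded unquoted, so an unset or empty variable silently drops out and shifts every argument after it. A small pre-flight check can catch this. require_vars is a hypothetical helper, not part of this repository, and relies on bash indirect expansion.

```bash
#!/usr/bin/env bash
# Sketch: fail fast if any named launch variable is unset or empty.
# ${!name:-} is bash indirect expansion: the value of the variable whose
# name is stored in $name (empty if that variable is unset).
require_vars() {
  local name missing=0
  for name in "$@"; do
    if [[ -z "${!name:-}" ]]; then
      echo "error: $name is not set" >&2
      missing=1
    fi
  done
  return "$missing"
}
```

For example, just before the bash btm_shell_scripts/btm_train.sh call:

require_vars MODEL_FOLDER_NAME NUM_GPUS ARCHITECTURE DATA_FOLDER DATA_DOMAIN_NAME SAVE_MODEL_FOLDER BTM_CODE_PATH UNIQUE_RUN_ID || exit 1 ;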

To branch from an existing seed checkpoint and train an expert LM:

MODEL_FOLDER_NAME=project_name;
NUM_GPUS=2;
ARCHITECTURE=transformer_lm_gpt3_small;
DATA_FOLDER=/path/to/data;
DATA_DOMAIN_NAME=data_domain;
INIT_CHECKPOINT_FOLDER=/path/to/seed/model/folder;
SAVE_MODEL_FOLDER=/path/to/new/model/checkpointing;
SEED_PHASE_COMPUTE_SHARE=None;
SEED_PHASE_UPDATE_NUM=model_update_number;
NUM_UPDATES=80000;
UPDATE_FREQ=32;
LR=5e-4;
SAVE_INTERVAL_UPDATES=2000;
PORT=55555;
WANDB_PROJECT=project_name;
BTM_CODE_PATH=/path/to/this/repo;
RANDOM_SEED=1;
UNIQUE_RUN_ID=unique_run_name2;

cd $BTM_CODE_PATH;
bash btm_shell_scripts/btm_train.sh $MODEL_FOLDER_NAME $NUM_GPUS \
$ARCHITECTURE branch $DATA_FOLDER $DATA_DOMAIN_NAME $INIT_CHECKPOINT_FOLDER \
$SAVE_MODEL_FOLDER . $SEED_PHASE_COMPUTE_SHARE $SEED_PHASE_UPDATE_NUM True \
$NUM_UPDATES $UPDATE_FREQ $LR $SAVE_INTERVAL_UPDATES $PORT \
$WANDB_PROJECT $BTM_CODE_PATH $RANDOM_SEED $UNIQUE_RUN_ID ;
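
Because each expert trains independently after branching, several domains can be branched from the same seed checkpoint at once. A minimal sketch, assuming all runs share one node and are launched from the repository root: each branch gets its own GPU slice through CUDA_VISIBLE_DEVICES, its own PORT (assumed to be the distributed-training port that btm_train.sh uses), and its own UNIQUE_RUN_ID. The launch_branches and gpu_list helpers and the TRAIN_SCRIPT override are hypothetical; on a cluster you would more likely submit one job per domain.

```bash
#!/usr/bin/env bash
# Sketch: branch one expert per domain from the same seed checkpoint, in parallel.
# Expects the variables from the branch example (MODEL_FOLDER_NAME, NUM_GPUS, ...)
# to be set. TRAIN_SCRIPT is a hypothetical override for the training script path.

# Print COUNT consecutive GPU ids starting at START, comma-separated (e.g. "2,3").
gpu_list() {
  local start=$1 count=$2 out="" g
  for ((g = start; g < start + count; g++)); do
    out+="${out:+,}$g"
  done
  printf '%s\n' "$out"
}

launch_branches() {
  local script=${TRAIN_SCRIPT:-btm_shell_scripts/btm_train.sh}
  local i=0 domain
  for domain in "$@"; do
    # Branch i: GPUs [i*NUM_GPUS, (i+1)*NUM_GPUS), port PORT+i, run id suffixed by domain.
    CUDA_VISIBLE_DEVICES=$(gpu_list $((i * NUM_GPUS)) "$NUM_GPUS") \
    bash "$script" "$MODEL_FOLDER_NAME" "$NUM_GPUS" \
      "$ARCHITECTURE" branch "$DATA_FOLDER" "$domain" "$INIT_CHECKPOINT_FOLDER" \
      "$SAVE_MODEL_FOLDER" . "$SEED_PHASE_COMPUTE_SHARE" "$SEED_PHASE_UPDATE_NUM" True \
      "$NUM_UPDATES" "$UPDATE_FREQ" "$LR" "$SAVE_INTERVAL_UPDATES" $((PORT + i)) \
      "$WANDB_PROJECT" "$BTM_CODE_PATH" "$RANDOM_SEED" "${UNIQUE_RUN_ID}_${domain}" &
    i=$((i + 1))
  done
  wait  # block until every branch has exited
}
```

For example, with placeholder domain names: launch_branches domain_a domain_b ;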

Model Evaluation

To evaluate a single LM:

DATA_FOLDER=/path/to/data;
MODEL_FOLDER=/path/to/new/model/checkpointing/project_name/unique_run_name;
CHECKPOINT_FILE_NAME=checkpoint_last.pt;
DATA_SPLIT=test;
DATA_DOMAIN_NAME=data_domain;

bash btm_shell_scripts/eval_pipeline.sh $DATA_FOLDER $MODEL_FOLDER \
$DATA_SPLIT $DATA_DOMAIN_NAME ;

To evaluate an ensemble of LMs weighted by the domain posterior (this requires jq):

NUM_EXPERTS=8;
DATA_FOLDER=/path/to/data;
MODEL_PATHS=/path/to/expert1:/path/to/expert2:etc;
DATA_DOMAIN_NAME=data_domain;
ENSEMBLE_TYPE=cached_prior;
RESULTS_OUTPUT_FOLDER=/path/to/output/folder;

bash btm_shell_scripts/ensemble_eval.sh $NUM_EXPERTS $DATA_FOLDER \
$MODEL_PATHS $DATA_DOMAIN_NAME $ENSEMBLE_TYPE $RESULTS_OUTPUT_FOLDER ;
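
As a rough guide to what domain-posterior weighting computes (a sketch of the general technique; the exact estimator behind ENSEMBLE_TYPE=cached_prior is defined in btm_shell_scripts/ensemble_eval.sh and may differ in details): with k = NUM_EXPERTS experts, each expert j's next-token distribution is weighted by the posterior probability of domain j given the context, obtained from Bayes' rule, where p(x_{<t} | D = j) is expert j's likelihood of the context and p(D = j) is a prior over domains (presumably estimated once and cached, as the name suggests).

```latex
p(x_t \mid x_{<t}) = \sum_{j=1}^{k} p(x_t \mid x_{<t}, D = j)\, p(D = j \mid x_{<t}),
\qquad
p(D = j \mid x_{<t}) = \frac{p(x_{<t} \mid D = j)\, p(D = j)}{\sum_{j'=1}^{k} p(x_{<t} \mid D = j')\, p(D = j')}
```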

To parameter average LMs:

RESULTING_MODEL_OUTPUT_FOLDER=/path/to/output/folder;
WEIGHTS=0.1,0.9;
MODEL_PATHS=/path/to/expert1:/path/to/expert2;
python btm_utils/average.py --output-dir $RESULTING_MODEL_OUTPUT_FOLDER \
--weights $WEIGHTS --model-files $MODEL_PATHS ;

To evaluate the parameter-averaged LM, use the single-LM evaluation command above.
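
A weighted parameter average presumably forms theta = sum_i w_i * theta_i from the expert parameters theta_i and the entries w_i of WEIGHTS, so WEIGHTS needs one entry per path in MODEL_PATHS; for a proper weighted mean the weights should also sum to 1 (an assumption; check btm_utils/average.py for whether it normalizes them). A hypothetical pre-flight helper that builds the colon-separated MODEL_PATHS string (the same format ensemble_eval.sh takes) and checks WEIGHTS against it:

```bash
#!/usr/bin/env bash
# Sketch: build MODEL_PATHS and sanity-check WEIGHTS before averaging.

# Join all arguments with ':' (the separator MODEL_PATHS uses).
join_model_paths() {
  local IFS=:
  printf '%s\n' "$*"
}

# Succeed if WEIGHTS (comma-separated) has one entry per MODEL_PATHS entry
# and the weights sum to 1 within a small tolerance (assumed requirement).
check_weights() {
  local weights=$1 model_paths=$2 n_w n_m
  n_w=$(awk -F, '{ print NF }' <<< "$weights")
  n_m=$(awk -F: '{ print NF }' <<< "$model_paths")
  if [[ "$n_w" -ne "$n_m" ]]; then
    echo "error: $n_w weights for $n_m models" >&2
    return 1
  fi
  awk -F, '{ s = 0; for (i = 1; i <= NF; i++) s += $i
             ok = (s > 0.999999 && s < 1.000001); exit (ok ? 0 : 1) }' <<< "$weights" \
    || { echo "error: weights do not sum to 1" >&2; return 1; }
}
```

For example:

MODEL_PATHS=$(join_model_paths /path/to/expert1 /path/to/expert2);
check_weights "$WEIGHTS" "$MODEL_PATHS" || exit 1 ;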

Models

All trained ELMs and the Transformer-LM and DeMIX baselines at the 125M, 350M, 750M, and 1.3B parameter scales are available, as well as the 350M-parameter ELMs trained on our 64-domain corpus.

To download one of the Transformer-LMs:

# 125M, 350M, 750M or 1.3B
model_scale=125M;
model_architecture=transformer_lm;

mkdir -p btm_models/models/${model_scale}/${model_architecture}/
cd btm_models/models/${model_scale}/${model_architecture}/
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last.pt

To download the DeMIX models:

# 125M, 350M, 750M or 1.3B
model_scale=125M;
model_architecture=demix;

mkdir -p btm_models/models/${model_scale}/${model_architecture}/
cd btm_models/models/${model_scale}/${model_architecture}/
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-rank-0.pt;
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-rank-1.pt;
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-rank-2.pt;
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-rank-3.pt;
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-rank-4.pt;
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-rank-5.pt;
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-rank-6.pt;
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-rank-7.pt;
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-shared.pt;
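
The nine DeMIX files follow a fixed naming pattern (ranks 0 to 7 plus a shared file), so the list can be generated rather than written out line by line. A sketch; demix_checkpoint_urls and BTM_MODELS_URL are hypothetical names. Piping the list into wget -c -i - downloads every URL read from standard input.

```bash
#!/usr/bin/env bash
# Sketch: print the nine DeMIX checkpoint URLs for one model scale.
BTM_MODELS_URL=https://dl.fbaipublicfiles.com/btm/models

demix_checkpoint_urls() {
  local model_scale=$1 rank
  for rank in 0 1 2 3 4 5 6 7; do
    echo "$BTM_MODELS_URL/$model_scale/demix/checkpoint_last-rank-$rank.pt"
  done
  echo "$BTM_MODELS_URL/$model_scale/demix/checkpoint_last-shared.pt"
}
```

From inside the folder created above:

demix_checkpoint_urls $model_scale | wget -c -i -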

To download one of the ELMs:

# 125M, 350M, 750M or 1.3B
model_scale=125M;
model_architecture=elmforest;
# one of the domains specified in btm_utils/constants.py
domain=1b;

mkdir -p btm_models/models/${model_scale}/${model_architecture}/${domain}/
cd btm_models/models/${model_scale}/${model_architecture}/${domain}/
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/${domain}/checkpoint_last.pt
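
To fetch several ELMs at one scale, the single-ELM commands can be repeated per domain. A sketch with hypothetical helper names that mirrors the btm_models/models/<scale>/elmforest/<domain>/ layout used above; valid domain names are the ones listed in btm_utils/constants.py. The plan step only prints directory and URL pairs, so it can be inspected before anything is downloaded.

```bash
#!/usr/bin/env bash
# Sketch: print "<local dir> <url>" pairs for several ELM domains at one scale.
elm_download_plan() {
  local model_scale=$1; shift
  local domain
  for domain in "$@"; do
    echo "btm_models/models/$model_scale/elmforest/$domain https://dl.fbaipublicfiles.com/btm/models/$model_scale/elmforest/$domain/checkpoint_last.pt"
  done
}

# Download each planned checkpoint into its own folder (needs wget and network access).
download_elms() {
  local dir url
  elm_download_plan "$@" | while read -r dir url; do
    mkdir -p "$dir"
    wget -c -P "$dir" "$url"
  done
}
```

For example: download_elms 125M 1b ;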

To download one of the ELMs from the 64-domain experiments:

model_scale=64_domain_curriculum;
# one of the domains specified in btm_utils/constants.py
domain=2021newscrawl;

mkdir -p btm_models/models/${model_scale}/${domain}/
cd btm_models/models/${model_scale}/${domain}/
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${domain}/checkpoint_last.pt
