This is an unofficial implementation of Chain of Hindsight using PyTorch and the Hugging Face `Trainer`. The data loading script is taken directly from the original repo; only the training part is rewritten in PyTorch.
- For pip,

  ```
  pip install -r environment/requirements.txt
  ```

- For conda,

  ```
  conda env create -f environment/env.yml
  ```
A shell script for training can be found in `train.sh`. It takes GPU device ids as input and passes them to the `CUDA_VISIBLE_DEVICES` environment variable.

```
sh train.sh 0,1,2,3
```
To customize command line arguments, take a look at the argument dataclasses used in the following files:

- `coh.coh_train.ExperimentArgs`
- `coh.data.coh_data.CoHDataArgs`
- `coh.trainer.CoHTrainArgs` (this inherits from `transformers.TrainingArguments`)
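For reference, dataclasses like these can be parsed together with `transformers.HfArgumentParser`. A minimal sketch, assuming the import paths mirror the module names above (check those files for the actual fields):

```python
# Sketch only: parse the three argument dataclasses from the command line.
from transformers import HfArgumentParser

from coh.coh_train import ExperimentArgs
from coh.data.coh_data import CoHDataArgs
from coh.trainer import CoHTrainArgs

parser = HfArgumentParser((ExperimentArgs, CoHDataArgs, CoHTrainArgs))
exp_args, data_args, train_args = parser.parse_args_into_dataclasses()

# CoHTrainArgs inherits transformers.TrainingArguments, so standard flags
# such as --output_dir are available as well.
print(train_args.output_dir)
```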
A training script for LLaMA is also provided. The baseline invocation is:

```
sh llama_train.sh 0,1,2,3 ${LLAMA_PATH}
```

To use this script, you need to have already downloaded the LLaMA weights and converted them to PyTorch format using the conversion script from the Hugging Face Transformers repo.
- Relevant PR
```
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/downloaded/llama/weights \
    --model_size 7B \
    --output_dir /output/path
```
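After conversion, the checkpoint in `--output_dir` can be loaded with the standard Transformers LLaMA classes, which is a quick way to verify the conversion worked:

```python
# Quick check that the converted weights load (path is the --output_dir above).
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("/output/path")
model = LlamaForCausalLM.from_pretrained("/output/path")
```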
To use DeepSpeed, you need `nvcc` with the correct version installed. Conda provides a `cuda-nvcc` package, which is also included in `env.yml`. However, to use it, you need to set the `CUDA_HOME` environment variable to point to the conda environment (this is required so that DeepSpeed's JIT C++ compiler picks up the conda-installed `nvcc`, not the system-wide one). After creating and activating the environment, set

```
export CUDA_HOME=/path/to/conda/envs/coh
```
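As an optional sanity check, PyTorch's extension utilities (which DeepSpeed's JIT build relies on) expose the resolved `CUDA_HOME`, so you can confirm it points at the conda environment:

```python
# Should print the conda env path set above, not e.g. /usr/local/cuda.
from torch.utils.cpp_extension import CUDA_HOME

print(CUDA_HOME)
```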
Example DeepSpeed config files can be found in `ds_config`. They are taken directly from Hugging Face's DeepSpeed integration tutorial.

By default, `train.sh` uses DeepSpeed, while `llama_train.sh` uses FSDP instead.
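For reference, when going through the Hugging Face `Trainer`, a DeepSpeed config file is typically wired in via the `deepspeed` field of `TrainingArguments` (or the equivalent `--deepspeed` command line flag). The file name below is only illustrative; use one of the files shipped in `ds_config`:

```python
# Illustrative only: point the Trainer at a DeepSpeed config from ds_config/.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="outputs",
    deepspeed="ds_config/ds_config_zero3.json",  # hypothetical file name
)
```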
To further improve training efficiency, LoRA (via PEFT) can be applied by passing `--use_lora` in the training arguments.
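Under the hood, LoRA with PEFT wraps the base model with low-rank adapters, roughly along these lines (the base model and hyperparameters below are placeholders, not the repo's settings):

```python
# Rough sketch of LoRA via PEFT; values here are placeholders.
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable
```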
You can also use 8-bit training!
- This is compatible with PEFT.
- This is NOT compatible with DeepSpeed.
- Need to use the `torchrun` launcher instead of the `deepspeed` launcher.
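A minimal sketch of 8-bit loading through Transformers and `bitsandbytes` (the model name is a placeholder; combine this with the LoRA setup above rather than with DeepSpeed):

```python
# Placeholder model; requires bitsandbytes (and accelerate) to be installed.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    load_in_8bit=True,
    device_map="auto",
)
```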
This repo diverges from the original repo's implementation in a few ways:

- The original repo does not have an evaluation step.
- No `bos_token` is prepended to the `input_ids`, because the batching logic is chunk-wise and each sequence in a batch is not really a standalone sentence (see the sketch below).
- No `weight_decay_mask` is used.
- Forgetful Causal Masking is not applied.
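To illustrate the `bos_token` point, a toy version of chunk-wise batching looks roughly like this (illustrative only, not the repo's actual data code):

```python
# Tokenized examples are concatenated and split into fixed-length chunks, so a
# chunk generally does not start at a sentence boundary; prepending a bos_token
# to every chunk would therefore not mark a real sentence start.
def chunk_token_ids(all_token_ids, seq_len):
    n_chunks = len(all_token_ids) // seq_len  # drop the incomplete tail
    return [all_token_ids[i * seq_len:(i + 1) * seq_len] for i in range(n_chunks)]

# Two concatenated "sentences" (2 = eos id here), chunked into length-4 blocks:
ids = [5, 6, 7, 2, 8, 9, 10, 11, 2]
print(chunk_token_ids(ids, 4))  # [[5, 6, 7, 2], [8, 9, 10, 11]]
```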