diff --git a/README.md b/README.md
index 22e0532..f1b1bd8 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
-# **H**um**a**n-Centered **Lo**ss Functions (HALOs) :innocent:
+# Human-Aware Loss Functions (HALOs) :innocent:
 
-This repo allows you to design new **HumAn-centered LOss functions (HALOs)** for aligning LLMs with offline human feedback at scale [(read more in our technical report)](assets/report.pdf).
+This repo allows you to design new **Human-Aware Loss Functions (HALOs)** for aligning LLMs with offline human feedback at scale (read more in our [technical report](assets/report.pdf) or our [full paper](assets/full_paper.pdf)).
 
 It was used to create Archangel, the largest-ever suite of human-feedback-aligned LLMs, and has been tested at scales from 1B to 30B.
 
 This repo draws from the excellently written [DPO repo](https://github.com/eric-mitchell/direct-preference-optimization) and has preserved many design choices from the original.
@@ -45,7 +45,7 @@ What should we do?
 
 5. Write a trainer in `trainers.py`. This should subclass either `UnpairedPreferenceTrainer` or `PairedPreferenceTrainer` depending on whether it uses pairs of preferences or not. If you need highly custom behavior that is not in either, then you can subclass `BasicTrainer` directly.
 
-   We can implement a simple version of KTO as follows (not that this is different from the proper version of KTO in `KTOTrainer`, which does not assume the existence of both chosen and rejected examples in each batch).
+   We can implement a simple version of KTO as follows (note that this is different from the proper version of KTO in `KTOTrainer`, which does not assume the existence of both chosen and rejected examples in each batch).
 
    To make SimpleKTOTrainer, we just subclass `trainers.UnpairedPreferenceTrainer` as `trainers.SimpleKTOTrainer` and overwrite the loss function definition. KTO has one hyperparameter, beta, which we can access via `self.config.loss.beta`:
@@ -65,17 +65,7 @@ What should we do?
        If generation y ~ p_rejected, where x' are the examples with chosen generations, we have the 'rejected' loss:
          L(x, y) := 1 - sigmoid(beta * (KL(p_policy(y_chosen|x') || p_reference(y_chosen|x')) - [log p_policy(y|x) - log p_reference(y|x)]))
        """
-       chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
-       rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
-
-       chosen_logratios = (policy_chosen_logps - reference_chosen_logps)
-       rejected_logratios = (policy_rejected_logps - reference_rejected_logps)
-
-       losses = torch.cat((1 - F.sigmoid(self.config.loss.beta * (chosen_logratios - rejected_KL)), 1 - F.sigmoid(self.config.loss.beta * (chosen_KL - rejected_logratios))), 0)
-
-       chosen_rewards = self.config.loss.beta * (policy_chosen_logps - reference_chosen_logps).detach()
-       rejected_rewards = self.config.loss.beta * (policy_rejected_logps - reference_rejected_logps).detach()
-
+       # your implementation goes here: it should define losses, chosen_rewards, and rejected_rewards
        return losses, chosen_rewards, rejected_rewards
    ```
@@ -145,7 +135,7 @@ If you find this repo or the technical paper useful in your research, please fee
 ```
 @techreport{ethayarajh2023halos,
   author = {Ethayarajh, Kawin and Xu, Winnie and Jurafsky, Dan and Kiela, Douwe},
-  title = {Human-Centered Loss Functions (HALOs)},
+  title = {Human-Aware Loss Functions (HALOs)},
   institution = {Contextual AI},
   note = {https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf},
   year = {2023},
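The README diff above swaps the worked `SimpleKTOTrainer` loss for a placeholder. For readers following the tutorial, here is a minimal standalone sketch of the implementation that was removed, with `beta` standing in for `self.config.loss.beta` and the four log-probability tensors taken from the loss signature in the tutorial:

```python
import torch

def simple_kto_loss(policy_chosen_logps: torch.FloatTensor,
                    policy_rejected_logps: torch.FloatTensor,
                    reference_chosen_logps: torch.FloatTensor,
                    reference_rejected_logps: torch.FloatTensor,
                    beta: float):
    # Monte Carlo estimates of the KL terms, clamped to be non-negative
    chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
    rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

    # per-example log ratios between the policy and the reference model
    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps

    # the 'chosen' loss is anchored on the rejected-KL estimate and vice versa,
    # matching the two cases described in the docstring
    losses = torch.cat(
        (1 - torch.sigmoid(beta * (chosen_logratios - rejected_KL)),
         1 - torch.sigmoid(beta * (chosen_KL - rejected_logratios))), 0)

    chosen_rewards = beta * chosen_logratios.detach()
    rejected_rewards = beta * rejected_logratios.detach()
    return losses, chosen_rewards, rejected_rewards
```

Note that this simplified version still assumes each batch contains both chosen and rejected examples; the proper `KTOTrainer` in the repo does not.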
diff --git a/assets/full_paper.pdf b/assets/full_paper.pdf
new file mode 100644
index 0000000..401c407
Binary files /dev/null and b/assets/full_paper.pdf differ
diff --git a/assets/report.pdf b/assets/report.pdf
index 7554a75..2ee06c2 100644
Binary files a/assets/report.pdf and b/assets/report.pdf differ
diff --git a/config/model/mistral7b_sft_beta.yaml b/config/model/mistral7b_sft_beta.yaml
new file mode 100644
index 0000000..9149924
--- /dev/null
+++ b/config/model/mistral7b_sft_beta.yaml
@@ -0,0 +1,6 @@
+defaults:
+  - base_model
+
+name_or_path: HuggingFaceH4/mistral-7b-sft-beta
+block_name: MistralDecoderLayer
+use_flash_attention: true
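For context on how a model config like `mistral7b_sft_beta.yaml` is typically consumed in the DPO-style trainers this repo derives from: `name_or_path` is what gets passed to `from_pretrained`, and `block_name` names the transformer block class (here `MistralDecoderLayer`) so the trainer can target it for FSDP wrapping and activation checkpointing. Below is a rough, illustrative sketch of the loading side only, not the repo's actual code:

```python
# Illustrative only: how the config fields might map onto model loading.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/mistral-7b-sft-beta",      # name_or_path
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # use_flash_attention: true (kwarg in recent transformers)
)
```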
diff --git a/scripts/eval.sh b/scripts/eval.sh
new file mode 100644
index 0000000..38c9e9b
--- /dev/null
+++ b/scripts/eval.sh
@@ -0,0 +1,72 @@
+#!/bin/bash
+CKPT=$1
+# Set OUT to the checkpoint name (the last path component of CKPT)
+OUT=results/${CKPT##*/}
+mkdir -p $OUT
+
+DATADIR=evaldata/eval
+FORMAT=eval.templates.create_prompt_with_halo_chat_format
+
+pwd=$(pwd)
+cd open-instruct
+
+python -m eval.gsm.run_eval \
+--data_dir $DATADIR/gsm/ \
+--max_num_examples 200 \
+--save_dir $OUT \
+--model $CKPT \
+--tokenizer_name_or_path $CKPT \
+--n_shot 8 \
+--use_chat_format \
+--use_vllm \
+--chat_formatting_function $FORMAT
+
+python -m eval.mmlu.run_eval \
+--ntrain 0 \
+--data_dir $DATADIR/mmlu/ \
+--save_dir $OUT \
+--model_name_or_path $CKPT \
+--tokenizer_name_or_path $CKPT \
+--eval_batch_size 4 \
+--use_chat_format \
+--chat_formatting_function $FORMAT
+
+python -m eval.bbh.run_eval \
+--data_dir $DATADIR/bbh \
+--save_dir $OUT \
+--model $CKPT \
+--tokenizer_name_or_path $CKPT \
+--max_num_examples_per_task 40 \
+--use_chat_format \
+--use_vllm \
+--chat_formatting_function $FORMAT
+
+python -m eval.tydiqa.run_eval \
+--data_dir $DATADIR/tydiqa \
+--n_shot 1 \
+--max_num_examples_per_lang 100 \
+--max_context_length 512 \
+--save_dir $OUT \
+--model $CKPT \
+--tokenizer_name_or_path $CKPT \
+--use_chat_format \
+--use_vllm \
+--chat_formatting_function $FORMAT
+
+cd ../bigcode-evaluation-harness
+
+accelerate launch --config_file /home/niklas/sgpt2/scripts/configs/config_1gpu_vanilla.yml main.py \
+--model $CKPT \
+--tasks humanevalsynthesize-python \
+--do_sample True \
+--temperature 0.2 \
+--n_samples 20 \
+--batch_size 5 \
+--allow_code_execution \
+--save_generations \
+--trust_remote_code \
+--prompt halo \
+--save_generations_path $OUT/generations_humanevalsynthesizepython.json \
+--metric_output_path $OUT/evaluation_humanevalsynthesizepython.json \
+--max_length_generation 2048 \
+--precision bf16
diff --git a/scripts/reformat_statedict.py b/scripts/reformat_statedict.py
new file mode 100644
index 0000000..48e200e
--- /dev/null
+++ b/scripts/reformat_statedict.py
@@ -0,0 +1,19 @@
+import os
+import sys
+import torch
+
+path = sys.argv[1]
+
+sd_path = os.path.join(path, "policy.pt")
+
+sd = torch.load(sd_path)
+# If there is no top-level "state" key, the checkpoint has already been reformatted
+if "state" not in sd:
+    print('State dict seems already reformatted:', sd.keys())
+    sys.exit(0)
+torch.save(sd["state"], sd_path)
+
+# Rename to the standard HF filename and copy in the tokenizer/config files
+os.system(f"mv {sd_path} {os.path.join(path, 'pytorch_model.bin')}")
+os.system(f"cp -r /data/niklas/kto_mistralsft/*tok* {path}/")
+os.system(f"cp -r /data/niklas/kto_mistralsft/*json* {path}/")
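`reformat_statedict.py` turns a HALOs training checkpoint (a `policy.pt` with a top-level `state` key) into a standard Hugging Face layout: the weights become `pytorch_model.bin` and the tokenizer/config JSON files are copied alongside. Assuming the copied files include `config.json`, the resulting directory loads with the usual transformers calls; a small hypothetical usage sketch:

```python
# Hypothetical usage: load a checkpoint directory already processed by
# scripts/reformat_statedict.py (pytorch_model.bin + tokenizer + config.json).
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer

path = sys.argv[1]
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path)

inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```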