[DPO] use ref model logprobs if it exists in the data (#885)
* use logprobs if it exists in the batch

* add features to tokenized batch if in data

* make get_batch_logps a static method

* add tokenize_batch_element dataset mapper

* Remove tokenize_batch method from DPODataCollator

* Initial sketch to precompute reference_logps

* run ref model via pytorch dataloader

* add a padding helper

* clean up the helper

* use logprob item()

* default behaviour

* clean up collator

* add docstring

* copy data back to cpu if needed

* use get_train_dataloader methods

* fix tests

* rename: more explicit variable name precompute_ref_log_probs

* improve comment

* update comment

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* refactor models into setup parameters

* parametrize precompute_ref_log_probs flag

* remove useless test

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update function arg name

* distinguish between pad token_id and mask values

* fix tokenization #932 by @nrailg

* fix test

* undo test refactor

* new line

* undo breaking change

* Update token counter condition to allow Llama tokenizer

* Account for merged tokens on certain tokenizers such as the Llama-2 tokenizer

* Update variable name to match list value when truncating response

* map function on multi-gpu and gather

* Add test cases for DPOTrainer tokenization step

* revert since we need the prepared model

* Use gather_for_metrics on ref_logps precomputation to keep original dataset size

* Add flag to keep track of when ref_logps are precomputed

* make variable names private

* formatting

* if precompute_ref_log_probs is true, one can use a non-PEFT model to populate the log-probs

* Use tokenizer padding token unless padding_value is set

* Move dataset.map(tokenize_batch) outside dataloader to avoid serialization errors

* eval can be none

* move to cpu to avoid gpu oom

* remove unneeded cast to float32

* remove unneeded

* fix merge

* fix merge

* fix merge

* add precompute log-prob status via tqdm

* Truncate answer if too long once the prompt has been truncated

* Add prompt_input_ids to batch to enable generation

* formatting and add lora example

* fix formatting

* Tokenize row now expects sample to have space on chosen/rejected for llama

* Revert "Tokenize row now expects sample to have space on chosen/rejected for llama"

This reverts commit dd07a10.

* raise error when using zero-3 with precompute_ref_log_probs

---------

Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
Co-authored-by: Shoaib Burq <saburq@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
4 people authored Dec 12, 2023
1 parent c0ce52a commit 48b3ef0
Showing 4 changed files with 462 additions and 235 deletions.
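For orientation, here is a minimal sketch of the user-facing option this commit adds, modeled on the test changes below. The "gpt2" checkpoint, the output directory, and the tiny dataset are stand-ins for illustration only, not part of the commit.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Tiny stand-in preference dataset with the prompt/chosen/rejected columns DPOTrainer expects.
dataset = Dataset.from_dict(
    {
        "prompt": ["hello", "how are you"],
        "chosen": ["hi nice to meet you", "I am fine"],
        "rejected": ["leave me alone", "leave me alone"],
    }
)

trainer = DPOTrainer(
    model,
    ref_model,
    args=TrainingArguments(output_dir="./dpo_out", per_device_train_batch_size=2, max_steps=1, remove_unused_columns=False),
    beta=0.1,
    train_dataset=dataset,
    tokenizer=tokenizer,
    # New in this change: score the dataset with the reference model once up front
    # and reuse the cached log-probs instead of running the ref model every step.
    precompute_ref_log_probs=True,
)
trainer.train()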
16 changes: 16 additions & 0 deletions examples/scripts/dpo.py
@@ -22,6 +22,7 @@

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments

from trl import DPOTrainer
@@ -51,6 +52,10 @@ class ScriptArguments:
)
label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"})
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
# lora parameters
use_peft: Optional[bool] = field(default=True, metadata={"help": "Whether to use PEFT or not to train adapters"})
peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
# instrumentation
sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
report_to: Optional[str] = field(
@@ -163,6 +168,16 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
# gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
)

if script_args.use_peft:
peft_config = LoraConfig(
r=script_args.peft_lora_r,
lora_alpha=script_args.peft_lora_alpha,
bias="none",
task_type="CAUSAL_LM",
)
else:
peft_config = None

# 5. initialize the DPO trainer
dpo_trainer = DPOTrainer(
model,
@@ -176,6 +191,7 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
max_target_length=script_args.max_target_length,
max_prompt_length=script_args.max_prompt_length,
generate_during_eval=True,
peft_config=peft_config,
)

# 6. train
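The commit messages above sketch how the reference log-prob cache gets built when precompute_ref_log_probs is enabled: the prepared reference model is run once over an ordinary PyTorch dataloader (with a tqdm progress bar), the per-sample log-probs are gathered across processes so the original dataset size is preserved, copied back to the CPU to avoid GPU OOM, and written back into the dataset. The following is only a rough sketch of that idea under those assumptions; the helper name compute_reference_log_probs and the column names are illustrative, not a copy of the trainer internals.

import torch
from tqdm import tqdm

def precompute_reference_log_probs(trainer, dataset, dataloader):
    """Illustrative only: cache reference log-probs as extra dataset columns."""
    all_chosen, all_rejected = [], []
    for padded_batch in tqdm(dataloader, desc="Train dataset reference log probs"):
        with torch.no_grad():
            # Stand-in for however the trainer scores a padded batch with the reference model.
            ref_chosen, ref_rejected = trainer.compute_reference_log_probs(padded_batch)
        # gather_for_metrics drops samples duplicated by the distributed sampler,
        # so the cached columns keep the original dataset size.
        ref_chosen, ref_rejected = trainer.accelerator.gather_for_metrics((ref_chosen, ref_rejected))
        all_chosen.append(ref_chosen.cpu())      # copy back to CPU to avoid GPU OOM
        all_rejected.append(ref_rejected.cpu())

    dataset = dataset.add_column("reference_chosen_logps", torch.cat(all_chosen).float().numpy())
    dataset = dataset.add_column("reference_rejected_logps", torch.cat(all_rejected).float().numpy())
    return dataset

With the columns in place, the loss computation can read the cached values from the batch and skip the per-step reference forward pass, which is what the commit title refers to.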
21 changes: 19 additions & 2 deletions tests/test_dpo_trainer.py
@@ -51,6 +51,8 @@ def _init_dummy_dataset(self):
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
"[INST] How is the stock price? [/INST]",
"[INST] How is the stock price? [/INST] ",
],
"chosen": [
"hi nice to meet you",
@@ -60,6 +62,8 @@
"Python",
"Python",
"Python",
"$46 as of 10am EST",
"46 as of 10am EST",
],
"rejected": [
"leave me alone",
@@ -69,15 +73,24 @@
"Javascript",
"C++",
"Java",
" $46 as of 10am EST",
" 46 as of 10am EST",
],
}
# fmt: on
return Dataset.from_dict(dummy_dataset_dict)

@parameterized.expand(
[["gpt2", "sigmoid"], ["t5", "hinge"], ["gpt2", "ipo"], ["t5", "ipo"], ["gpt2", "kto"], ["t5", "kto"]]
[
["gpt2", "sigmoid", True],
["t5", "hinge", False],
["gpt2", "ipo", False],
["t5", "ipo", True],
["gpt2", "kto", True],
["t5", "kto", False],
]
)
def test_dpo_trainer(self, name, loss_type):
def test_dpo_trainer(self, name, loss_type, pre_compute):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
@@ -109,6 +122,7 @@ def test_dpo_trainer(self, name, loss_type):
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
precompute_ref_log_probs=pre_compute,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
@@ -146,6 +160,7 @@ def test_dpo_trainer_without_providing_ref_model(self):
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
precompute_ref_log_probs=True,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
@@ -196,6 +211,7 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
precompute_ref_log_probs=True,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
@@ -283,6 +299,7 @@ def test_dpo_lora_save(self):
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
precompute_ref_log_probs=True,
)

# train the model

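The new dummy-dataset rows above ("[INST] ... [/INST]" prompts with and without a trailing space, answers with and without a leading space) exercise the prompt/answer boundary behaviour behind the "merged tokens" and #932 fixes: with BPE or SentencePiece tokenizers, encoding the prompt and the answer separately can disagree with encoding their concatenation, because pieces can merge across the boundary. Below is a small illustrative check of that comparison, using gpt2 purely as a readily available stand-in (the effect is most pronounced with Llama-style tokenizers).

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer for illustration

prompt = "[INST] How is the stock price? [/INST] "
answer = "$46 as of 10am EST"

joint = tokenizer(prompt + answer, add_special_tokens=False)["input_ids"]
split = tokenizer(prompt, add_special_tokens=False)["input_ids"] + tokenizer(answer, add_special_tokens=False)["input_ids"]

# When these disagree, the prompt length has to be re-derived from the joint encoding
# rather than assuming the first len(tokenizer(prompt)) ids belong to the prompt.
print("identical ids:", joint == split)
print("joint length:", len(joint), "| split length:", len(split))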