[DPO] use ref model logprobs if it exists in the data #885

Merged (69 commits) on Dec 12, 2023
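For context, here is a minimal usage sketch of the option this PR introduces, based on the DPOTrainer call signature visible in the test diff below. The toy dataset contents, TrainingArguments values, and output directory are placeholder assumptions, not part of the PR.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
model = AutoModelForCausalLM.from_pretrained(model_id)
ref_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Toy preference data in the prompt/chosen/rejected format DPOTrainer consumes.
dummy_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello", "How are you?"],
        "chosen": [" nice to meet you", " I am fine"],
        "rejected": [" leave me alone", " not so well"],
    }
)

training_args = TrainingArguments(
    output_dir="dpo-output",
    per_device_train_batch_size=2,
    max_steps=3,
    remove_unused_columns=False,
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,  # can be None when training a PEFT adapter, as in the tests below
    beta=0.1,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dummy_dataset,
    eval_dataset=dummy_dataset,
    precompute_ref_log_probs=True,  # new flag: run the reference model once up front and cache its log-probs
)
trainer.train()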
Commits (69 total; the file changes below are shown from 24 commits)
0110d87
use logprobs if it exists in the batch
kashif Oct 17, 2023
25d726b
add features to tokenized batch if in data
kashif Oct 17, 2023
920a605
make get_batch_logps a static method
kashif Oct 18, 2023
ab09bc0
add tokenize_batch_element dataset mapper
kashif Oct 18, 2023
3471d89
Remove tokenize_batch method from DPODataCollator
pablovicente Oct 19, 2023
4e6ccb9
Initial sketch to precompute reference_logps
pablovicente Oct 19, 2023
ba35bca
run ref model via pytorch dataloader
kashif Oct 19, 2023
4f967d1
add a padding helper
kashif Oct 19, 2023
2168cae
clean up the helper
kashif Oct 19, 2023
6631c9f
use logprob item()
kashif Oct 19, 2023
f06b9b0
default behaviour
kashif Oct 19, 2023
16cfc87
clean up collator
kashif Oct 19, 2023
0bd2058
add docstring
kashif Oct 19, 2023
e1acfb3
copy data back to cpu if needed
kashif Oct 27, 2023
1e50685
use get_train_dataloader methods
kashif Oct 27, 2023
5096ff2
fix tests
kashif Oct 27, 2023
1d0145c
rename: more explicit variable name precompute_ref_log_probs
sabman Nov 1, 2023
d3bc976
improve comment
sabman Nov 1, 2023
64e06d9
update comment
sabman Nov 1, 2023
5281f3c
Update trl/trainer/dpo_trainer.py
sabman Nov 1, 2023
14c3ce9
refactor models into setup parameters
sabman Nov 1, 2023
a1fa8d1
parametrize precompute_ref_log_probs flag
kashif Nov 2, 2023
102a841
remove useless test
kashif Nov 3, 2023
1c03583
Update trl/trainer/dpo_trainer.py
kashif Nov 3, 2023
dba4f27
Update tests/test_dpo_trainer.py
kashif Nov 3, 2023
d58075f
Update tests/test_dpo_trainer.py
kashif Nov 3, 2023
f56710a
Update trl/trainer/dpo_trainer.py
kashif Nov 3, 2023
7f0bd70
Update trl/trainer/dpo_trainer.py
kashif Nov 3, 2023
03f3dda
update function arg name
kashif Nov 3, 2023
0086a02
distinguish between pad token_id and mask values
kashif Nov 3, 2023
5464b12
fix tokenization #932 by @nrailg
kashif Nov 3, 2023
ef7c5ee
Merge branch 'main' into reference-logprobs
kashif Nov 6, 2023
4b32a8e
fix test
kashif Nov 6, 2023
b7c0255
Merge branch 'main' into reference-logprobs
kashif Nov 6, 2023
3c08299
undo test refactor
kashif Nov 7, 2023
15447b7
new line
kashif Nov 7, 2023
0182c95
undo breaking change
kashif Nov 7, 2023
cc9ec00
Update token counter condition to allow Llama tokenizer
pablovicente Nov 7, 2023
e5533bf
Account for merged tokens on certain tokenizers such as the Llama-2 tokenizer
pablovicente Nov 7, 2023
9fb7868
Update variable name to match list value when truncating response
pablovicente Nov 7, 2023
78cc8f2
map function on multi-gpu and gather
kashif Nov 8, 2023
d0644c9
Add test cases for DPOTrainer tokenization step
pablovicente Nov 8, 2023
3d175bc
revert since we need the prepared model
kashif Nov 9, 2023
352bde3
Use gather_with_metrics on ref_logps precomputation to keep original …
pablovicente Nov 9, 2023
b82b868
Add flag to keep track of when ref_logps are precomputed
pablovicente Nov 9, 2023
411cf79
make variable names private
kashif Nov 9, 2023
5ed4f10
formatting
kashif Nov 9, 2023
058e6e8
if precompute_ref_log_probs is true one can use non-peft to populate …
kashif Nov 9, 2023
2639a19
Use tokenizer padding token unless padding_value is set
pablovicente Nov 15, 2023
2dc9dc1
Move dataset.map(tokenize_batch) outside dataloader to avoid serializ…
pablovicente Nov 16, 2023
619a170
eval can be none
kashif Nov 17, 2023
1575da8
move to cpu to avoid gpu oom
kashif Nov 17, 2023
1c9d770
remove unneeded cast to float32
kashif Nov 20, 2023
206248c
remove unneeded
kashif Nov 20, 2023
190c2f0
Merge branch 'main' into reference-logprobs
kashif Nov 24, 2023
db2eec4
fix merge
kashif Nov 24, 2023
7fb8846
Merge remote-tracking branch 'upstream/main' into reference-logprobs
kashif Dec 1, 2023
cbd13c0
fix merge
kashif Dec 1, 2023
158ecfd
Merge remote-tracking branch 'upstream/main' into reference-logprobs
kashif Dec 5, 2023
1aa5c38
fix merge
kashif Dec 5, 2023
9c52b21
add precompute log-prob status via tqdm
kashif Dec 6, 2023
36b80b0
Truncate answer if too long once the prompt has been truncated
pablovicente Dec 8, 2023
7f0bf14
Add prompt_input_ids to batch to enable generation
pablovicente Dec 8, 2023
5d1dd2d
formatting and add lora example
kashif Dec 8, 2023
9a65cc3
Merge branch 'main' into reference-logprobs
kashif Dec 11, 2023
b40c264
fix formatting
kashif Dec 11, 2023
dd07a10
Tokenize row now expects sample to have space on chosen/rejected for …
pablovicente Dec 11, 2023
94207ba
Revert "Tokenize row now expects sample to have space on chosen/rejec…
kashif Dec 12, 2023
0d9a4c0
raise error when using zero-3 with precompute_ref_log_probs
kashif Dec 12, 2023
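The commits above boil down to computing the reference model's log-probability for each chosen and rejected completion once, storing those values as extra columns in the tokenized dataset, and reading them from the batch during training instead of running a second forward pass through the reference model. A standalone sketch of the cached quantity follows; this is not TRL's internal implementation, and the helper name and the -100 label mask convention are assumptions.

import torch

def sequence_log_prob(model, input_ids, attention_mask, labels, label_pad_token_id=-100):
    # Sum of per-token log-probs of `labels` under `model`, ignoring masked positions.
    # The trainer batches this through a DataLoader and pads inputs first; this sketch
    # only illustrates the per-sequence value that gets cached.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    logits = logits[:, :-1, :]   # tokens before position n predict token n
    labels = labels[:, 1:].clone()
    mask = labels != label_pad_token_id
    labels[~mask] = 0            # any valid index; these positions are zeroed out below
    per_token_logps = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return (per_token_logps * mask).sum(-1)

Note also, per the final commit, that combining precompute_ref_log_probs with DeepSpeed ZeRO-3 is rejected with an explicit error.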
74 changes: 43 additions & 31 deletions tests/test_dpo_trainer.py
@@ -28,17 +28,24 @@
class DPOTrainerTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id)
cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model_id)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
cls.tokenizer.pad_token = cls.tokenizer.eos_token

# get t5 as seq2seq example:
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
cls.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
cls.models = {
"gpt2": {
"model_id": "trl-internal-testing/dummy-GPT2-correct-vocab",
"model_type": AutoModelForCausalLM,
"tokenizer_type": AutoTokenizer,
},
"t5": {
"model_id": "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab",
"model_type": AutoModelForSeq2SeqLM,
"tokenizer_type": AutoTokenizer,
},
# add more models here if needed
}
for key, model_info in cls.models.items():
model_info["model"] = model_info["model_type"].from_pretrained(model_info["model_id"])
model_info["ref_model"] = model_info["model_type"].from_pretrained(model_info["model_id"])
model_info["tokenizer"] = model_info["tokenizer_type"].from_pretrained(model_info["model_id"])
cls.models["gpt2"]["tokenizer"].pad_token = cls.models["gpt2"]["tokenizer"].eos_token

def _init_dummy_dataset(self):
# fmt: off
@@ -74,8 +81,12 @@ def _init_dummy_dataset(self):
# fmt: on
return Dataset.from_dict(dummy_dataset_dict)

@parameterized.expand([["gpt2", "sigmoid"], ["t5", "hinge"]])
def test_dpo_trainer(self, name, loss_type):
def _get_models_by_name(self, name):
return self.models[name]["model"], self.models[name]["ref_model"], self.models[name]["tokenizer"]

@parameterized.expand([["gpt2", "sigmoid", True], ["t5", "hinge", False]])
def test_dpo_trainer(self, name, loss_type, precompute_ref_log_probs):
model, ref_model, tokenizer = self._get_models_by_name(name)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
@@ -89,15 +100,6 @@ def test_dpo_trainer(self, name, loss_type):

dummy_dataset = self._init_dummy_dataset()

if name == "gpt2":
model = self.model
ref_model = self.ref_model
tokenizer = self.tokenizer
elif name == "t5":
model = self.t5_model
ref_model = self.t5_ref_model
tokenizer = self.t5_tokenizer

trainer = DPOTrainer(
model=model,
ref_model=ref_model,
@@ -107,6 +109,7 @@
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
precompute_ref_log_probs=precompute_ref_log_probs,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
@@ -122,7 +125,9 @@
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

def test_dpo_trainer_without_providing_ref_model(self):
@parameterized.expand([["gpt2", False], ["t5", True]])
def test_dpo_trainer_without_providing_ref_model(self, name, precompute_ref_log_probs):
model, _, tokenizer = self._get_models_by_name(name)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
@@ -137,13 +142,14 @@
dummy_dataset = self._init_dummy_dataset()

trainer = DPOTrainer(
model=self.model,
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
precompute_ref_log_probs=precompute_ref_log_probs,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
@@ -161,7 +167,9 @@

@require_peft
@mark.peft_test
def test_dpo_trainer_without_providing_ref_model_with_lora(self):
@parameterized.expand([["gpt2", True], ["t5", False]])
def test_dpo_trainer_without_providing_ref_model_with_lora(self, name, precompute_ref_log_probs):
model, _, tokenizer = self._get_models_by_name(name)
from peft import LoraConfig

lora_config = LoraConfig(
@@ -186,14 +194,15 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):
dummy_dataset = self._init_dummy_dataset()

trainer = DPOTrainer(
model=self.model,
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
precompute_ref_log_probs=precompute_ref_log_probs,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
@@ -210,8 +219,10 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@parameterized.expand([["gpt2", False], ["t5", True]])
@require_no_wandb
def test_dpo_trainer_generate_during_eval_no_wandb(self):
def test_dpo_trainer_generate_during_eval_no_wandb(self, name, precompute_ref_log_probs):
model, _, tokenizer = self._get_models_by_name(name)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
@@ -231,12 +242,13 @@ def test_dpo_trainer_generate_during_eval_no_wandb(self):
" Please install `wandb` to resolve.",
):
DPOTrainer(
model=self.model,
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
generate_during_eval=True,
precompute_ref_log_probs=precompute_ref_log_probs,
)