[DPO] use ref model logprobs if it exists in the data #885
Conversation
kashif commented on Oct 17, 2023 (edited)
- Refactor the trainer so that it uses reference-model log-probs from the dataset, if they are present, instead of recomputing them with a reference model (see the sketch below)
- Add a flag that precomputes the reference-model log-probs and adds them to the dataset in the dataloader creation phase, before training starts
- Fix the confusion between padding_value and label_pad_token_id
- Fix from "fix DPO data collator" #932
- Fix for "DPO models generate multiple / corrupted responses" #1025
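To make the first point concrete, here is a minimal sketch of how a sigmoid DPO loss could consume reference log-probs that are already stored in the batch instead of running a reference-model forward pass. The column names reference_chosen_logps / reference_rejected_logps and the fallback behaviour are assumptions for illustration, not necessarily the exact names used in this PR.

```python
import torch
import torch.nn.functional as F

def dpo_loss_from_batch(policy_chosen_logps, policy_rejected_logps, batch, beta=0.1):
    """Sigmoid DPO loss that reads precomputed reference log-probs from the batch.

    Column names are illustrative; if they are absent, the caller would fall back
    to scoring the pairs with a reference model (the previous behaviour).
    """
    if "reference_chosen_logps" not in batch or "reference_rejected_logps" not in batch:
        raise ValueError("No precomputed reference log-probs; run the reference model instead.")

    ref_chosen_logps = batch["reference_chosen_logps"]
    ref_rejected_logps = batch["reference_rejected_logps"]

    # DPO objective: -log sigmoid(beta * (policy margin - reference margin))
    policy_margin = policy_chosen_logps - policy_rejected_logps
    reference_margin = ref_chosen_logps - ref_rejected_logps
    losses = -F.logsigmoid(beta * (policy_margin - reference_margin))
    return losses.mean()
```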
Force-pushed from 90c470a to e1acfb3
I understand this only affects internals, except for the new flag? Looks good to me; maybe @lewtun wants to have a look too since he's been using this quite a bit recently.
Hey @kashif, a couple of questions:
Thanks a lot for optimising the DPO trainer @kashif 🔥 !
Overall the PR looks great, but I have a few questions about whether the log probs are precomputed on CPU vs GPU.
It would also be great to see some docs on how to use this feature, e.g. showing the 3 main cases:
- Log probs are already precomputed and stored in a dataset
- Log probs are precomputed at the start of training
- Log probs are computed on the fly (previous behaviour)
Would it also make sense if I run this PR through a DPO training run to check we don't have any major regressions?
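For reference, a rough sketch of what such docs could show for the three cases. It assumes the flag is exposed as a DPOTrainer argument named precompute_ref_log_probs (the name that appears in the commit history below), that model, ref_model, training_args, tokenizer and the datasets are already defined, and that the precomputed columns are named as in the earlier sketch; treat it as an assumption-laden illustration rather than the final API.

```python
from trl import DPOTrainer

# Case 1: reference log-probs are already stored in the dataset
# (hypothetical columns reference_chosen_logps / reference_rejected_logps),
# so no reference model is needed during training.
trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=training_args,
    train_dataset=dataset_with_ref_logps,
    tokenizer=tokenizer,
)

# Case 2: precompute the reference log-probs once, before training starts.
trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    precompute_ref_log_probs=True,
)

# Case 3: previous behaviour, compute reference log-probs on the fly at each step.
trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
```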
trl/trainer/dpo_trainer.py
# tokenize the dataset and compute reference logps for training datasets
self.train_dataset = self.train_dataset.map(self.tokenize_batch_element)
if self.precompute_ref_logps:
    self.train_dataset = self.train_dataset.map(
Does the map() run on CPU or GPU? I seem to recall that batched maps like this were best done with a torch dataloader, but perhaps this is no longer true. I'm mostly worried that running Llama 70B on CPU will blow up the RAM :)
Do you happen to know what the answer to the above question is (i.e. do we run inference on CPU or GPU)?
It should run wherever the ref_model is at this point, which is already prepared via accelerate, no?
Yes, however if a ref_model is passed explicitly, then (since it is not being prepared by accelerate) we have to move it to the accelerator device ourselves, which is what I am doing.
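To spell that out, here is a rough sketch of the flow being described: an explicitly passed ref_model is moved to the accelerator device, scored batch-wise with a plain torch DataLoader, and the results are copied back to CPU before being attached to the dataset. The helper concatenated_forward and all column names are stand-ins, not the PR's actual API.

```python
import torch
from torch.utils.data import DataLoader

def precompute_reference_log_probs(ref_model, dataset, data_collator, accelerator,
                                   batch_size, concatenated_forward):
    """Score every chosen/rejected pair once with the reference model.

    `concatenated_forward` is a stand-in for whatever helper returns
    (chosen_logps, rejected_logps) for a batch; all names here are illustrative.
    """
    # An explicitly passed ref_model is not prepared by accelerate, so move it
    # to the accelerator device before running inference.
    ref_model = ref_model.to(accelerator.device)
    ref_model.eval()

    loader = DataLoader(dataset, batch_size=batch_size,
                        collate_fn=data_collator, shuffle=False)

    chosen, rejected = [], []
    with torch.no_grad():
        for batch in loader:
            # Move tensors to the accelerator device for the forward pass.
            batch = {k: (v.to(accelerator.device) if isinstance(v, torch.Tensor) else v)
                     for k, v in batch.items()}
            chosen_logps, rejected_logps = concatenated_forward(ref_model, batch)
            # Copy results back to CPU so precomputation does not hold GPU memory.
            chosen.append(chosen_logps.cpu())
            rejected.append(rejected_logps.cpu())

    # Store the scores as new dataset columns so training can read them directly.
    dataset = dataset.add_column("reference_chosen_logps", torch.cat(chosen).tolist())
    dataset = dataset.add_column("reference_rejected_logps", torch.cat(rejected).tolist())
    return dataset
```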
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
LGTM, what do you think @lvwerra @lewtun @edbeeching?
Revert "Tokenize row now expects sample to have space on chosen/rejected for llama" (this reverts commit dd07a10).
This looks good to go, let's merge it 🔥!
* use logprobs if it exists in the batch
* add features to tokenized batch if in data
* make get_batch_logps a static method
* add tokenize_batch_element dataset mapper
* Remove tokenize_batch method from DPODataCollator
* Initial sketch to precompute reference_logps
* run ref model via pytorch dataloader
* add a padding helper
* clean up the helper
* use logprob item()
* default behaviour
* clean up collator
* add docstring
* copy data back to cpu if needed
* use get_train_dataloader methods
* fix tests
* rename: more explicit variable name precompute_ref_log_probs
* improve comment
* update comment
* Update trl/trainer/dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* refactor models into setup parameters
* parametrize precompute_ref_log_probs flag
* remove useless test
* Update trl/trainer/dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update tests/test_dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update tests/test_dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update trl/trainer/dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update trl/trainer/dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* update function arg name
* distinguish between pad token_id and mask values
* fix tokenization huggingface#932 by @nrailg
* fix test
* undo test refactor
* new line
* undo breaking change
* Update token counter condition to allow Llama tokenizer
* Acount for merged tokens on certain tokenizers such Llama-2 tokenizer
* Update variable name to match list value when truncating response
* map function on multi-gpu and gather
* Add test cases for DPOTrainer tokenization step
* revert since we need the prepeared model
* Use gather_with_metrics on ref_logps precomputation to keep original dataset size
* Add flag to keep track of when ref_logps are precomputed
* make variable names private
* formatting
* if precompute_ref_log_probs is true one can use non-peft to populate log-probs
* Use tokenizer padding token unless padding_value is set
* Move dataset.map(tokenize_batch) outside dataloader to avoid serialization errors
* eval can be none
* move to cpu to avoid gpu oom
* remove unneeded cast to float32
* remove unneeded
* fix merge
* fix merge
* fix merge
* add precompute log-prob status via tqdm
* Truncate answer if too longer once prompt has been truncated
* Add prompt_input_ids to batch to enable generation
* formatting and add lora example
* fix formatting
* Tokenize row now expects sample to have space on chosen/rejected for llama
* Revert "Tokenize row now expects sample to have space on chosen/rejected for llama" (this reverts commit dd07a10)
* raise error when using zero-3 with precompute_ref_log_probs

---------

Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
Co-authored-by: Shoaib Burq <saburq@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>