[DPO] use ref model logprobs if they exist in the data (#885)
* use ref model logprobs if they exist in the batch
* add features to tokenized batch if in data
* make get_batch_logps a static method
* add tokenize_batch_element dataset mapper
* Remove tokenize_batch method from DPODataCollator
* Initial sketch to precompute reference_logps
* run ref model via a PyTorch dataloader (see the sketch after this list)
* add a padding helper (a second sketch follows the list)
* clean up the helper
* use logprob item()
* default behaviour
* clean up collator
* add docstring
* copy data back to cpu if needed
* use get_train_dataloader methods
* fix tests
* rename: more explicit variable name precompute_ref_log_probs
* improve comment
* update comment
* Update trl/trainer/dpo_trainer.py
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* refactor models into setup parameters
* parametrize precompute_ref_log_probs flag
* remove useless test
* Update trl/trainer/dpo_trainer.py
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update tests/test_dpo_trainer.py
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update tests/test_dpo_trainer.py
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update trl/trainer/dpo_trainer.py
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update trl/trainer/dpo_trainer.py
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* update function arg name
* distinguish between pad token_id and mask values
* fix tokenization #932 by @nrailg
* fix test
* undo test refactor
* new line
* undo breaking change
* Update token counter condition to allow Llama tokenizer
* Account for merged tokens on certain tokenizers such as the Llama-2 tokenizer
* Update variable name to match list value when truncating response
* map function on multi-gpu and gather
* Add test cases for DPOTrainer tokenization step
* revert since we need the prepared model
* Use gather_for_metrics on ref_logps precomputation to keep original dataset size
* Add flag to keep track of when ref_logps are precomputed
* make variable names private
* formatting
* if precompute_ref_log_probs is true, one can use a non-PEFT reference model to populate log-probs
* Use tokenizer padding token unless padding_value is set
* Move dataset.map(tokenize_batch) outside dataloader to avoid serialization errors
* eval dataset can be None
* move to CPU to avoid GPU OOM
* remove unneeded cast to float32
* remove unneeded
* fix merge
* fix merge
* fix merge
* add precompute log-prob status via tqdm
* Truncate answer if too long once the prompt has been truncated
* Add prompt_input_ids to batch to enable generation
* formatting and add lora example
* fix formatting
* Tokenize row now expects sample to have space on chosen/rejected for llama
* Revert "Tokenize row now expects sample to have space on chosen/rejected for llama"
  This reverts commit dd07a10.
* raise error when using zero-3 with precompute_ref_log_probs

---------

Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
Co-authored-by: Shoaib Burq <saburq@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
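The heart of this PR is the `precompute_ref_log_probs` path: the frozen reference model is run once over the dataset through a plain PyTorch dataloader, the resulting log probs are cached as dataset columns, and training then reads those columns instead of keeping the reference model around. The sketch below is a minimal illustration of that control flow, not the actual trl implementation: the column names follow the PR's naming, while `get_batch_logps` stands in for DPOTrainer's static method and is assumed here to map a reference model and a collated batch to `(chosen_logps, rejected_logps)`.

```python
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def precompute_ref_log_probs(dataset, ref_model, collator, get_batch_logps, batch_size=8):
    """One-off pass over the dataset with a plain PyTorch dataloader,
    caching the reference model's log probs as new dataset columns."""
    chosen, rejected = [], []
    loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collator)
    for batch in tqdm(loader, desc="Precomputing reference log probs"):
        with torch.no_grad():
            chosen_logps, rejected_logps = get_batch_logps(ref_model, batch)
        # Copy back to CPU so the cached values do not hold GPU memory.
        chosen.extend(chosen_logps.cpu().tolist())
        rejected.extend(rejected_logps.cpu().tolist())
    dataset = dataset.add_column("reference_chosen_logps", chosen)
    return dataset.add_column("reference_rejected_logps", rejected)


def get_reference_logps(batch, ref_model, get_batch_logps):
    """During training, prefer the cached columns if the collator put them
    in the batch; otherwise fall back to a live reference-model forward."""
    if "reference_chosen_logps" in batch:
        return batch["reference_chosen_logps"], batch["reference_rejected_logps"]
    with torch.no_grad():
        return get_batch_logps(ref_model, batch)
```

Once the columns are cached, no reference model is needed at training time at all, which is what makes the non-PEFT/PEFT combination mentioned above possible; on multi-GPU setups the precomputed values are gathered with `gather_for_metrics` so the dataset keeps its original size.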
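The padding helper mentioned in the list pads a tensor along one dimension to a target length, using the tokenizer's pad token id for input ids (unless `padding_value` is set) and the label-mask value for labels. A minimal sketch, with the name and signature assumed rather than taken from trl:

```python
import torch


def pad_to_length(tensor: torch.Tensor, length: int, pad_value: int, dim: int = -1) -> torch.Tensor:
    """Right-pad `tensor` along `dim` to `length` with `pad_value`;
    return it unchanged if it is already long enough."""
    if tensor.size(dim) >= length:
        return tensor
    pad_shape = list(tensor.shape)
    pad_shape[dim] = length - tensor.size(dim)
    pad = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, pad], dim=dim)
```

Keeping the pad token id separate from the label mask value (typically -100) is the "distinguish between pad token_id and mask values" item above: both are padding, but only the latter must be ignored by the loss.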