
[DPO] use ref model logprobs if it exists in the data #885

Merged
kashif merged 69 commits into huggingface:main from reference-logprobs on Dec 12, 2023

Conversation

@kashif (Collaborator) commented Oct 17, 2023

kashif marked this pull request as draft on October 17, 2023 18:40
kashif changed the title from "[DPO] use logprobs if it exists in the data" to "[DPO] use ref model logprobs if it exists in the data" on Oct 17, 2023
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

kashif force-pushed the reference-logprobs branch from 90c470a to e1acfb3 on October 27, 2023 09:33
kashif marked this pull request as ready for review on October 27, 2023 10:38
@lvwerra (Member) left a comment

I understand this only affects internals, apart from the new flag? Looks good to me; maybe @lewtun wants to have a look too since he's been using this quite a bit recently.

lvwerra requested a review from lewtun on October 31, 2023 09:22
lewtun requested a review from edbeeching on October 31, 2023 09:31
@edbeeching (Collaborator) commented Oct 31, 2023

Hey @kashif, a couple of questions:

  1. Is the user expected to pass ref_model=None and precompute_logprobs=True to the DPOTrainer (roughly as in the sketch after this list)? Since in that case, isn't another instance of the model instantiated through a deepcopy here?:
    self.ref_model = create_reference_model(model)
  2. Will the pre-computation of logprobs work correctly in a distributed setting? If we have 8 GPUs with DDP, I think the dataset.map will run over the whole dataset on every GPU:
    self.train_dataset = self.train_dataset.map(
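
A minimal sketch of the usage question 1 is asking about. This is illustrative, not code from the PR: the flag name follows the later rename to precompute_ref_log_probs, and the dataset name is a hypothetical placeholder with prompt/chosen/rejected columns.

```python
# Illustrative sketch only (not code from this PR).
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Hypothetical dataset with "prompt", "chosen" and "rejected" columns.
train_dataset = load_dataset("my-org/my-dpo-dataset", split="train")

trainer = DPOTrainer(
    model=model,
    ref_model=None,                 # question 1: is a reference copy still deep-copied in this case?
    args=TrainingArguments(
        output_dir="dpo-out",
        per_device_train_batch_size=2,
        remove_unused_columns=False,
    ),
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    precompute_ref_log_probs=True,  # compute reference log-probs once, before training starts
)
trainer.train()
```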

@lewtun (Member) left a comment

Thanks a lot for optimising the DPO trainer @kashif 🔥 !

Overall the PR looks great, but I have a few questions about whether the log probs are precomputed on CPU vs GPU.

It would also be great to see some docs on how to use this feature, e.g. showing the 3 main cases (a rough sketch follows at the end of this comment):

  • Log probs are already precomputed and stored in a dataset
  • Log probs are precomputed at the start of training
  • Log probs are computed on the fly (previous behaviour)

Would it also make sense if I run this PR through a DPO training run to check we don't have any major regressions?
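
A rough sketch of those three cases, to make the request concrete. The flag name precompute_ref_log_probs and the column names reference_chosen_logps / reference_rejected_logps are taken from later commits in this PR; the dataset names are hypothetical placeholders, and exactly which arguments case 1 needs is part of what the docs should pin down.

```python
# Illustrative only, not code from this PR.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer, create_reference_model

model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
args = TrainingArguments(
    output_dir="dpo-out", per_device_train_batch_size=2, remove_unused_columns=False
)

# Hypothetical datasets with "prompt"/"chosen"/"rejected" columns; the second one also
# carries precomputed "reference_chosen_logps"/"reference_rejected_logps" columns.
dataset = load_dataset("my-org/my-dpo-dataset", split="train")
dataset_with_logps = load_dataset("my-org/my-dpo-dataset-with-ref-logps", split="train")

# 1. Log probs already precomputed and stored in the dataset (the case in the PR title);
#    no live reference model should be needed at training time.
trainer_stored = DPOTrainer(
    model=model, ref_model=None, args=args, beta=0.1,
    train_dataset=dataset_with_logps, tokenizer=tokenizer,
)

# 2. Log probs precomputed once at the start of training from an explicit reference model.
trainer_precompute = DPOTrainer(
    model=model, ref_model=ref_model, args=args, beta=0.1,
    train_dataset=dataset, tokenizer=tokenizer,
    precompute_ref_log_probs=True,
)

# 3. Log probs computed on the fly during training (previous behaviour, still the default).
trainer_on_the_fly = DPOTrainer(
    model=model, ref_model=ref_model, args=args, beta=0.1,
    train_dataset=dataset, tokenizer=tokenizer,
)
```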

Two resolved review threads on trl/trainer/dpo_trainer.py (outdated).
    # tokenize the dataset and compute reference logps for training datasets
    self.train_dataset = self.train_dataset.map(self.tokenize_batch_element)
    if self.precompute_ref_logps:
        self.train_dataset = self.train_dataset.map(
A Member commented:

Does the map() run on CPU or GPU? I seem to recall that batched maps like this were best done with a torch dataloader, but perhaps this is no longer true. I'm mostly worried that running Llama 70B on CPU will blow up the RAM :)

A Member commented:

Do you happen to know what the answer to the above question is (i.e. do we run inference on CPU or GPU)?

A Member commented:

It should run wherever the ref_model is at this point, which is already prepared via accelerate, no?

@kashif (Collaborator, Author) commented:

Yes, however if there is explicitly a ref_model, then (since it's not being passed to accelerate) we have to move it to the accelerator device, which is what I am doing.
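
For reference, a rough, self-contained sketch of the approach being described here: run the reference model over an accelerate-prepared dataloader (so each process scores only its own shard, on its own device), gather the per-sequence log-probs across processes, and store them back as dataset columns on CPU. The batch keys (chosen_*/rejected_*) and column names are assumptions in line with the naming used in this thread, not the PR's actual code.

```python
# Illustrative sketch only (not the PR's code). Batch keys and column names are assumptions.
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader


@torch.no_grad()
def _sequence_logps(model, input_ids, attention_mask, labels):
    # Sum of per-token log-probs of `labels` under `model`, ignoring -100 (masked) positions.
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[:, :-1, :]
    mask = labels[:, 1:] != -100
    targets = labels[:, 1:].clone()
    targets[~mask] = 0
    per_token = torch.gather(logits.log_softmax(-1), 2, targets.unsqueeze(2)).squeeze(2)
    return (per_token * mask).sum(-1)


@torch.no_grad()
def precompute_reference_logps(accelerator: Accelerator, ref_model, tokenized_dataset, collator, batch_size=4):
    # Prepare the dataloader so each process only iterates over its own shard,
    # and place the (frozen) reference model on the accelerator device.
    dataloader = accelerator.prepare(
        DataLoader(tokenized_dataset, batch_size=batch_size, collate_fn=collator)
    )
    ref_model = accelerator.prepare_model(ref_model.eval(), evaluation_mode=True)

    chosen, rejected = [], []
    for batch in dataloader:  # keys assume the DPO collator naming: chosen_* / rejected_*
        c = _sequence_logps(ref_model, batch["chosen_input_ids"], batch["chosen_attention_mask"], batch["chosen_labels"])
        r = _sequence_logps(ref_model, batch["rejected_input_ids"], batch["rejected_attention_mask"], batch["rejected_labels"])
        # Gather across processes so every rank ends up with the full set of values,
        # then move to CPU right away to keep GPU memory free for training.
        chosen.append(accelerator.gather_for_metrics(c).cpu())
        rejected.append(accelerator.gather_for_metrics(r).cpu())

    return (
        tokenized_dataset
        .add_column("reference_chosen_logps", torch.cat(chosen).float().tolist())
        .add_column("reference_rejected_logps", torch.cat(rejected).float().tolist())
    )
```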

Further review threads on trl/trainer/dpo_trainer.py (resolved).
kashif added the 🏋 DPO (Related to DPO) label on Dec 2, 2023
@younesbelkada (Contributor) left a comment

LGTM, what do you think @lvwerra @lewtun @edbeeching ?

@lewtun (Member) commented Dec 12, 2023

This looks good to go - let's merge it 🔥!

kashif merged commit 48b3ef0 into huggingface:main on Dec 12, 2023
9 checks passed
kashif deleted the reference-logprobs branch on December 12, 2023 18:12
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* use logprobs if it exists in the batch

* add features to tokenized batch if in data

* make get_batch_logps a static method

* add tokenize_batch_element dataset mapper

* Remove tokenize_batch method from DPODataCollator

* Initial sketch to precompute reference_logps

* run ref model via pytorch dataloader

* add a padding helper

* clean up the helper

* use logprob item()

* default behaviour

* clean up collator

* add docstring

* copy data back to cpu if needed

* use get_train_dataloader methods

* fix tests

* rename: more explicit variable name precompute_ref_log_probs

* improve comment

* update comment

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* refactor models into setup parameters

* parametrize precompute_ref_log_probs flag

* remove useless test

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update function arg name

* distinguish between pad token_id and mask values

* fix tokenization huggingface#932 by @nrailg

* fix test

* undo test refactor

* new line

* undo breaking change

* Update token counter condition to allow Llama tokenizer

* Account for merged tokens on certain tokenizers such as the Llama-2 tokenizer

* Update variable name to match list value when truncating response

* map function on multi-gpu and gather

* Add test cases for DPOTrainer tokenization step

* revert since we need the prepared model

* Use gather_with_metrics on ref_logps precomputation to keep original dataset size

* Add flag to keep track of when ref_logps are precomputed

* make variable names private

* formatting

* if precompute_ref_log_probs is true one can use non-peft to populate log-probs

* Use tokenizer padding token unless padding_value is set

* Move dataset.map(tokenize_batch) outside dataloader to avoid serialization errors

* eval can be none

* move to cpu to avoid gpu oom

* remove unneeded cast to float32

* remove unneeded

* fix merge

* fix merge

* fix merge

* add precompute log-prob status via tqdm

* Truncate answer if too long once prompt has been truncated

* Add prompt_input_ids to batch to enable generation

* formatting and add lora example

* fix formatting

* Tokenize row now expects sample to have space on chosen/rejected for llama

* Revert "Tokenize row now expects sample to have space on chosen/rejected for llama"

This reverts commit dd07a10.

* raise error when using zero-3 with precompute_ref_log_probs

---------

Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
Co-authored-by: Shoaib Burq <saburq@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Labels: 🏋 DPO (Related to DPO)
10 participants