-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Reward Backpropogation Support #1585
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @mihirp1998 for your hardwork ! In principle this looks good !
I just have few questions with respect to the differences between this method and DDPO, could you clearly highlight either on the documentation or in this PR what are the major differences between DDPO and this algorithm ? 🙏
I would also like to have a review from @sayakpaul if possible, what do you think of Reward Backpropagation ?
Thanks !
docs/source/alignprop_trainer.mdx
Outdated
| Before | After finetuning | | ||
| --- | --- | | ||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> | | ||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> | | ||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These images are the ones generated from DDPO no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, i wanted to update them, although I wasn't sure how to do it, as they linked to a huggingface internal webpage https://huggingface.co/datasets/trl-internal-testing/
If you can guide me on how to do it, i can update them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can open a PR to https://huggingface.co/datasets/trl-internal-testing/ repository adding the resultant images you want.
docs/source/alignprop_trainer.mdx
Outdated
library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. | ||
Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made. | ||
|
||
There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.** |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here it references DDPO trainer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing this. I have fixed this.
I have added the differences in the pull request, let me know if u have some doubts or think something is missing. |
@@ -0,0 +1,117 @@ | |||
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation | |||
|
|||
## The why |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think as a reader I understand if the following table justifies the name of this section. Would you mind elaborating?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I added a better why statement.
docs/source/alignprop_trainer.mdx
Outdated
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> | | ||
|
||
|
||
## Getting started with Stable Diffusion finetuning with reinforcement learning |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is needed. We should strive to keep the API documentation lean and precise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes i removed it.
docs/source/alignprop_trainer.mdx
Outdated
```python | ||
|
||
import torch | ||
from trl import DefaultDDPOStableDiffusionPipeline |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we have to use a non-diffusers pipeline here? Does DiffusionPipeline
from diffusers
not work here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes indeed, i changed it to StableDiffusionPipeline from diffusers
docs/source/alignprop_trainer.mdx
Outdated
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/alignprop-finetuned-sd-model") | ||
|
||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | ||
|
||
# memory optimization | ||
pipeline.vae.to(device, torch.float16) | ||
pipeline.text_encoder.to(device, torch.float16) | ||
pipeline.unet.to(device, torch.float16) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These LoCs could be reduce if we do:
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/alignprop-finetuned-sd-model", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additionally, https://huggingface.co/metric-space/alignprop-finetuned-sd-model is not available. Let's make sure we're using the right checkpoint ids here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes i reduced it and fixed the checkpoint ids.
examples/scripts/alignprop.py
Outdated
class MLP(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.layers = nn.Sequential( | ||
nn.Linear(768, 1024), | ||
nn.Dropout(0.2), | ||
nn.Linear(1024, 128), | ||
nn.Dropout(0.2), | ||
nn.Linear(128, 64), | ||
nn.Dropout(0.1), | ||
nn.Linear(64, 16), | ||
nn.Linear(16, 1), | ||
) | ||
|
||
def forward(self, embed): | ||
return self.layers(embed) | ||
|
||
|
||
class AestheticScorer(torch.nn.Module): | ||
""" | ||
This model attempts to predict the aesthetic score of an image. The aesthetic score | ||
is a numerical approximation of how much a specific image is liked by humans on average. | ||
This is from https://github.com/christophschuhmann/improved-aesthetic-predictor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we copy-pasting these modules from the DDPO script?
@younesbelkada would it make sense to have a separate module for these (auxiliary_modules
, perhaps)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are not exactly copy pasted, as DDPO had clamp and no_grad operations within them, which were preventing gradients from backpropagating.
Anyhow I still transfered the above reward function code from alignprop.py to trl/models/auxiliary_modules.py, as u suggested.
trl/models/modeling_sd_base.py
Outdated
if truncated_backprop: | ||
if truncated_backprop_rand: | ||
rand_timestep = random.randint(truncated_rand_backprop_minmax[0],truncated_rand_backprop_minmax[1]) | ||
if i < rand_timestep: | ||
noise_pred = noise_pred.detach() | ||
else: | ||
if i < truncated_backprop_timestep: | ||
noise_pred = noise_pred.detach() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We would want to supplement this code block with comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes added comments.
trl/models/modeling_sd_base.py
Outdated
@@ -527,6 +527,243 @@ def pipeline_step( | |||
|
|||
return DDPOPipelineOutput(image, all_latents, all_log_probs) | |||
|
|||
def pipeline_step_with_grad( | |||
self, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self
, could be replaced with pipeline
as that is what we're passing down the line, IIUC?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes i changed it to pipeline.
trl/trainer/alignprop_trainer.py
Outdated
|
||
# {model_name} | ||
|
||
This is a pipeline that finetunes a diffusion model with reward gradients. The model can be used for image generation conditioned with text. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what is the norm is within the library but I think it could be nice to also include a link to the AlignProp paper here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes i added.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your contributions. I left a couple of comments.
I would love to see some concrete comparisons to DDPO (training time, reward dynamics, convergence of the validation samples, etc.).
I have made concrete comparisions with DDPO here. I ran the DDPO default code in TRL with batch size 128, while AlignProp also uses the same batch size. As can be seen AlignProp is significantly more sample efficient, here i train both the models for a few hours. Here x-axis is the epochs and y-axis is the reward achieved. Although i ran the above experiments for a few hours, AlignProp only takes about 30 minutes to converge to a good solution. So i early stopped at the 8th epoch in training. Below is the comparision with DDPO after training both models for 30 minutes. Here x-axis is the training time and y-axis is the reward achieved. Both the curves are similar to the curves in the AlignProp paper. The above curves were with the same set of prompts during training/testing. In the curve below i show AlignProp results on unseen prompts. As can be seen there is not much gap in results between seen/unseen prompts. Here dotted lines is the unseen prompts while solid line is the seen prompts. Finally here are some generated images from AlignProp, for seen/unseen animals after training. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks clean on my end ! Just one comment about an import that is not needed and we should be good to go !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice documentation page !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Yes the import wasn't needed, i committed it.
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice integration thanks for adding this !
cc @sayakpaul could you give a final look ? |
@mihirp1998 can you run the styling checks? |
Will do. Allow me today as today is the SD3 release :) |
Sure, no rush |
@mihirp1998 can you also add |
Yes done. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work. I really like the sample efficiency gains :)
Thanks! Is something left from my side? As it still says: 2 workflows awaiting approval |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Oh nvm, seems like everything is approved! Thanks! |
cc @vwxyzjn for a final look, if all looks great to you, feel free to merge! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Thanks for the work!
commit 9e9dc96 Author: Maxim Kopecki <kopecki.maxim@gmail.com> Date: Wed Jul 10 19:11:13 2024 +0200 Added missing token kwarg in Peft model loading (#1825) commit 7ddef5c Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Wed Jul 10 18:26:11 2024 +0200 Make use of `trust_remote_code` consistent (#1806) Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> commit a9cddf8 Author: Adnan Khan <AdnaneKhan@users.noreply.github.com> Date: Wed Jul 10 11:25:07 2024 -0400 Delete unused benchmark.yml workflow. (#1822) commit 2860ce5 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Tue Jul 9 09:22:52 2024 +0200 DPO Llava 1.5 and PaliGemma support (#1797) * llava support dpo * add_special_tokens=False only when possible * format * pali gemma * refactor size * remove image resize --------- Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> commit 30e33bd Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Tue Jul 9 05:37:12 2024 +0200 upgrade gh actions (#1818) Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> commit d5a0d2d Author: Costa Huang <costa.huang@outlook.com> Date: Mon Jul 8 11:12:41 2024 -0400 Set dev version (#1817) commit 314e8eb Author: Puneet Singh Bhooi <puneetb@iiitd.ac.in> Date: Mon Jul 8 19:11:36 2024 +0530 fix broken url in `docs\source\index.mdx` (#1813) commit e107920 Author: Costa Huang <costa.huang@outlook.com> Date: Mon Jul 8 09:38:09 2024 -0400 0.9.6 release (#1816) commit 78045de Author: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Mon Jul 8 01:59:26 2024 +0200 Fix `TRL_USE_RICH` environment variable handling (#1808) * Add `strtobool` custom implementation from `distutils` * Fix `TRL_USE_RICH` handling via `strtobool` * Run `make precommit` commit 747612f Author: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Fri Jul 5 16:28:59 2024 +0200 Fix `torch_dtype` handling in `{DPO,SFT}Trainer` when provided via CLI (#1807) * Fix `torch_dtype` handling through CLI The `torch_dtype` is not properly handled when provided via the TRL CLI since it's provided initially as a string, but is then casted to `torch.dtype` before providing it to the `{DPO,SFT}Trainer`, which means that those trainers should handle the scenario where `torch_dtype` is a `torch.dtype` too. * Add `torch_dtype` tests in `test_{dpo,sft}_trainer.py` * Forward contribution credits * Run `make precommit` --------- Co-authored-by: Tash Srivastava <yash-srivastava19@users.noreply.github.com> commit 9e3a35b Author: Michael <mnoukhov@gmail.com> Date: Fri Jul 5 07:29:48 2024 -0400 Remove extra print in reward_trainer.py (#1799) `print_rich_table` is called twice and the first call doesn't restrict to `num_print_samples`. Remove the first, extra call commit 4402b36 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Thu Jul 4 14:29:25 2024 +0200 clean examples (#1791) Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> commit 78f8228 Author: Noah Tye <hi@noahtye.com> Date: Wed Jul 3 11:10:50 2024 -0700 Bugfix: Preserve token fields when converting TrainingArguments to SFTConfig (#1794) * Preserve token fields when converting TrainingArguments to SFTConfig TrainingArguments.to_dict() redacts token fields, so we have to individually copy them over when converting to SFTConfig to avoid breaking push_to_hub functionality. Also adds a test. * run precommit * one-line args_as_dict definition per suggestion from kashif * generalize token copying to match TrainingArguments behavior * unwrap |= on dict, to support python 3.8 * use .update instead of |= or for-loop commit b6af2ed Author: Kashif Rasul <kashif.rasul@gmail.com> Date: Wed Jul 3 08:29:16 2024 +0200 add model_init_kwargs to training_args (#1787) commit cd85b14 Author: Tommaso Buonocore <buonocore.tms@gmail.com> Date: Sat Jun 29 15:35:48 2024 +0200 Fixed typo in SFT trainer docs (#1788) 'STFConfig' instead of 'SFTConfig' appears multiple times in the doc, causing error when running the code snippets. commit a57544f Author: Kashif Rasul <kashif.rasul@gmail.com> Date: Thu Jun 27 15:47:58 2024 +0200 fix docs and examples (#1780) commit b68ff96 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Wed Jun 26 16:26:37 2024 +0200 Visual DPO (#1647) * Remove extra whitespaces * idefics * vdpo * sft idefics * pad with test * use prompt instead of tokenizer * rm name main * support vlm in tokenize row * temp fix for regex in lora_target_module * format * vdpo * tmp float16 hard code * concatenated_forward support for vision * style and new command line * all-linear * format * delete old examples * get image * upcast * new test * modified test * new strat for tokenizer * rm token transfer * integrate vision in dpo example * format * add FDivergenceType back * precommit * pillow test dep * optional prompt * `evaluation_strategy` to `eval_strategy` * revert vsft change (oos) * update test * test * comment and support more in process * update process * update doc for vdpo * caution about limited support * Update docs/source/dpo_trainer.mdx Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * revert DPO example changes * cleaner way to check if a model is vision * comment * update vdpo example * rename --------- Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit c8c01cc Author: Mubin Manasia <48038715+Mubin17@users.noreply.github.com> Date: Wed Jun 26 03:23:36 2024 -0600 Fix Documentation Overflow Issues for Long URLs in SFTConfig (#1774) * Update sft_config.py * Update sft_config.py commit 3479606 Author: Costa Huang <costa.huang@outlook.com> Date: Wed Jun 26 03:18:22 2024 -0400 Remove the leading space in the tldr preference dataset (#1773) commit 7965b78 Author: Haozhe Ji <jihaozhe@gmail.com> Date: Tue Jun 25 22:47:32 2024 +0800 add Efficient Exact Optimization (EXO) (#1735) * add exo * fix a detail * Update trl/trainer/dpo_trainer.py * Update trl/trainer/dpo_trainer.py * Update trl/trainer/dpo_trainer.py --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit 56bd1bb Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Tue Jun 25 16:14:26 2024 +0200 `evaluation_strategy` to `eval_strategy` (#1771) Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> commit 94d53e6 Author: Clara Pohland <54847419+claralp@users.noreply.github.com> Date: Mon Jun 24 21:27:00 2024 +0200 MoE Models: option to add load balancing loss (#1765) * KTO: add aux loss * use router_aux_loss_coef in KtoTrainer when aux_loss enabled * align optional aux_loss in DPO, KTO, CPO, ORPO * precommit changes * fix KL forward kwargs * add aux_loss doku entry * apply docs suggestions --------- Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de> commit b5be100 Author: Mihir Prabhudesai <mihirp1998.mp@gmail.com> Date: Mon Jun 24 12:05:44 2024 -0400 Added Reward Backpropogation Support (#1585) * added alignprop template * added alignprop support * Update alignprop_trainer.mdx * Update alignprop_trainer.mdx * added better why statement * fixed inference code * changed self to pipeline * removed aesthetic classifier * added aesthetic to auxiliary models * added unseen prompt logging * removed unseen prompt log * fixed minor * remove not needed import in trl/__init__.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * fixed styling * updated _toctree --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> commit 6e1652b Author: Haoran Xu <45837851+fe1ixxu@users.noreply.github.com> Date: Sun Jun 23 09:54:30 2024 -0700 Add CPO-SimPO method (#1760) * enable cpo-simpo * highlight SimPO and CPO-SimPO * add test for cpo_alpha * formatting * Update docs/source/cpo_trainer.mdx --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit 65374c6 Author: Costa Huang <costa.huang@outlook.com> Date: Fri Jun 21 11:20:54 2024 -0400 New sentiment and descriptiveness dataset (#1757) * push changes * handle edge cases where the chosen and the rejected are the same commit 9956091 Author: Juyoung Suk <scottsuk0306@gmail.com> Date: Fri Jun 21 18:01:08 2024 +0900 Add dataset_text_field in examples/scripts/sft.py (#1758) commit 34d273f Author: Costa Huang <costa.huang@outlook.com> Date: Thu Jun 20 13:16:43 2024 -0400 Support num_train_epochs (#1743) * add a test case for num_train_epochs * fix ci * quick change * disable push to hub * debug windows ci * try another fix * skip subprocess tests on windows commit 3bf9449 Author: Mert Sayar <mert.sayar@gmail.com> Date: Thu Jun 20 18:22:20 2024 +0300 Fix masking of response tokens (#1718) Current handling of `response_masks` inside `batch_forward_pass` function does not take padding into consideration which results with shape unmatch during masking. Since response mask is a mask tensor of response tokens, response tokens should not be concatenated with a `torch.zeros(query_length)` and masking operation should be done without slicing. Remove the concatenation of the response mask, remove the slicing from the response mask since response mask already has the length of `end - start + 1`, which is equal to length of `masks[j, start:end]`. commit ba6abee Author: idanshen <49375140+idanshen@users.noreply.github.com> Date: Thu Jun 20 09:14:16 2024 -0400 Support for returning past_key_values from the model (#1742) * add support for returning past_key_values from the model * change order of keys commit a57e759 Author: 1485840691 <110707330+1485840691@users.noreply.github.com> Date: Wed Jun 19 18:02:51 2024 +0800 Integrate f-divergence to DPO (Follow up) (#1610) * Step 1: update ppo_trainer and hello_world example * Step 2: Refine comments and add parameter type * Step 2: Add missing parameter comments * Step 1: Organize ptx loss into a function and add ptx_loss to train_stats * Step 1 updates: add comment to ptx_loss function, fix a bug and add warning message * Step 2: 1) Add ppo_ptx trainig example as ppo; 2) separate pretrain data fetch and iterate * Step 2: Remove loss from columns_to_log in ppo_ptx example * Remove data set revision in load imbd dataset * Run pre-commit and fix format issues * Initial draft of f-divergence fn * Update f-divergence to avoid overflow * fix test errors and comments * Add Unit tests for dpo loss with alpha and js div f * Adjust format * Fix test error * Reverse this update * Add test cases * Reverse un-needed updates * Update code style * Try to fix code fmt error * remove extra end line --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit ae23d40 Author: Shihyueh Hsu <66808901+AIR-hl@users.noreply.github.com> Date: Tue Jun 18 22:07:24 2024 +0800 change the `process` function in the example of DPO (#1753) * change the `process` function in the example of DPO * fix commit 83b367b Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue Jun 18 11:31:17 2024 +0200 CI / `KTOTrainer`: Remove old tests (#1750) * remove old tests * remove datasets * Update test_dpo_trainer.py * Update test_dpo_trainer.py commit d1ed730 Author: Michael <mnoukhov@gmail.com> Date: Mon Jun 17 10:50:21 2024 -0400 prepare deepspeed accomodate fp16 and bf16 (#1728) * prepare deepspeed accomodate fp16 and bf16 * precommit commit 8f8e95e Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon Jun 17 16:49:00 2024 +0200 CPO / DPO: Fix red CI (#1749) * fix red CI * precommit commit 4e23d95 Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon Jun 17 16:41:36 2024 +0200 fix red CI commit 50c4620 Author: Kawin <kawin.ethayarajh@gmail.com> Date: Mon Jun 17 07:14:44 2024 -0700 small KTO fixes (#1734) * add warning for imbalanced data * update documentation * update script commands to be same as in dpo * use batch_size KL examples and batch_size target examples to calculate batch_size losses * fix deepspeed issue * speed up forward with no_grad for KL * add some removed metrics * Update trl/trainer/kto_trainer.py * Update trl/trainer/kto_trainer.py * Update trl/trainer/kto_trainer.py add reference to paper Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * add more detailed comments * convert assert to ValueError * Update kto_trainer.py * precommit formatting * remove nans in metrics by gathering across machines * fix formatting * fix choice of mismatched examples for KL term * describe weights * fix hanging issue in distributed training * linting * move metrics to cpu * Update trl/trainer/kto_trainer.py Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Update trl/trainer/kto_trainer.py * Update trl/trainer/kto_trainer.py * remove kto_pair * speed up data processing * move bco code inside * raise error for kto_pair argument * fix formatting --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> Co-authored-by: Winnie Xu <winnie.xu97@gmail.com> commit 6105d03 Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon Jun 17 16:01:06 2024 +0200 `TrlParser`: Add ignore extra args option (#1748) * add ignore extra args option * Update trl/commands/cli_utils.py commit e247bbd Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon Jun 17 15:16:07 2024 +0200 CI / core: Pin `numpy` to `!=2.0.0` for CI and to users (#1747) * Update setup.py * Update setup.py * Update setup.py * Update test_best_of_n_sampler.py dummy commit * pin numpy * Update tests/test_best_of_n_sampler.py * Update setup.py commit 3d04496 Author: Michael <mnoukhov@gmail.com> Date: Mon Jun 17 08:43:33 2024 -0400 better trl parser with yaml config (#1739) * working trl parser with config correctly overrides yaml config with command line arguments adds return_remaining_strings when return_remaining_strings is False, raises error if yaml contains extra args that are not in the dataclasses simpler and cleaner than previous yaml parsing and merging addresses #1733 * lowercase trlparser commit 2d244f8 Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon Jun 17 11:56:13 2024 +0200 Workflow: Notify tests results on slack channel (#1744) * Update tests-main.yml * Update docker-build.yml commit f5168fd Author: Igor Melnyk <igoraries@gmail.com> Date: Wed Jun 12 05:54:54 2024 -0400 adds AOT (#1701) * adds AOT * Applied format changes * added docs and tests --------- Co-authored-by: Igor Melnyk <igor.melnyk@ibm.com> commit 79686e1 Author: jetlime <paul.houssel@yahoo.de> Date: Wed Jun 12 00:35:31 2024 +1000 ktotrainer: Refuse datasets which contain only one class of labels (#1724) * ktotrainer: refuse dataset which contain only one class of labels * ktotrainer: document new dataset constraint commit 34ebc4c Author: Luc Georges <McPatate@users.noreply.github.com> Date: Mon Jun 10 11:17:54 2024 +0200 feat(ci): add trufflehog secrets detection (#1721) * feat(ci): add trufflehog secrets detection * fix(ci): remove unnecessary permissions commit 1d84e2b Author: Michael <mnoukhov@gmail.com> Date: Fri Jun 7 11:42:08 2024 +0200 Fix default padding_value in dpo_config.py (#1692) dpo_config default padding value should be None, not 0, otherwise it by default overrides the padding value of any tokenizer to 0 commit 2f71b8b Author: Michael <mnoukhov@gmail.com> Date: Fri Jun 7 10:37:27 2024 +0200 fix yaml parser for derived config classes (#1713) fixes #1712 reformatted cli_utils with ruff commit 5bcb8ad Author: Kashif Rasul <kashif.rasul@gmail.com> Date: Fri Jun 7 08:48:17 2024 +0100 RDPO fix nll loss (#1705) commit b8b972f Author: Haoran Xu <45837851+fe1ixxu@users.noreply.github.com> Date: Thu Jun 6 14:06:47 2024 -0700 Add a variant of CPO, SimPO (#1703) * add a variant of cpo: simpo * correct cpo-simpo loss * avoid 0 int error in logging * add simpo description * Update trl/trainer/cpo_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * fix formatting * add test for simpo * Update docs/source/cpo_trainer.mdx Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * add a docstring for simpogamma * move simpo description to the above docstring * change simpo description in the doc * formatting --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit 3eb9ccb Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu Jun 6 19:33:20 2024 +0200 set dev version (#1710) * Update setup.py * Update __init__.py commit 974b0d3 Author: Costa Huang <costa.huang@outlook.com> Date: Thu Jun 6 10:13:00 2024 -0400 0.9.4 release (#1708) commit 39a7d1c Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu Jun 6 15:50:17 2024 +0200 SFTTrainer: Fix backward Compatibility issue with `TrainingArguments` (#1707) * fix BC * fixup commit 0bdc638 Author: Guilherme Freire <guilhermebfreire@gmail.com> Date: Thu Jun 6 14:42:58 2024 +0100 Fixed doc string and docs for the SFTConfig update (#1706) commit 275d33b Author: Costa Huang <costa.huang@outlook.com> Date: Wed Jun 5 14:34:59 2024 -0400 0.9.3 release (#1699) commit c0819ee Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed Jun 5 17:29:03 2024 +0200 Update sft_trainer.py (#1698) commit a03e7cc Author: Costa Huang <costa.huang@outlook.com> Date: Wed Jun 5 11:00:19 2024 -0400 Release 0.9.2 (#1697) * Release: 0.9.0 * Release commit a13cb89 Author: Costa Huang <costa.huang@outlook.com> Date: Wed Jun 5 10:20:54 2024 -0400 Quick fix on GPT4-eval (#1696) * quick fix * precommit commit 84156f1 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Mon Jun 3 20:09:05 2024 +0200 Fix typo in DPOTrainer's warnings (#1688) commit 4eb0b90 Author: Alex Brooks <alex.brooks@ibm.com> Date: Mon Jun 3 10:24:32 2024 -0600 Skip packing validation (#1673) * Add test for skipping preproc if packing=True Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com> * Allow skipping of validation for packing=True Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com> * Use dummy dataset in no packing preproc test Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com> --------- Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com> commit 6c203f9 Author: Alexey Rozhkov <alexisrozhkov@gmail.com> Date: Mon Jun 3 10:16:22 2024 +0100 Fix overriding optimize_device_cache with optimize_cuda_cache in PPOConfig (#1690) * Don't override optimize_device_cache when optimize_cuda_cache is not provided Raise an exception when both optimize_cuda_cache and optimize_device_cache are set * Minor fix commit f18253b Author: Kashif Rasul <kashif.rasul@gmail.com> Date: Mon Jun 3 09:43:02 2024 +0100 intial RPO loss (#1686) * intial RPO loss * fix sign * clean up commit 151a452 Author: Samuel <s.kiegeland@gmx.de> Date: Wed May 29 20:29:38 2024 +0200 Fix max completion length (#1588) commit 488b502 Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed May 29 20:19:26 2024 +0200 fix (#1678) commit 3c0a10b Author: Wang, Yi <yi.a.wang@intel.com> Date: Mon May 27 20:52:20 2024 +0800 fix dataset load error (#1670) Signed-off-by: Wang, Yi <yi.a.wang@intel.com> commit b031adf Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri May 24 15:20:16 2024 +0200 FIX / PPO: Fix `enable_input_require_grads` issues with PPO models (#1664) * Update modeling_base.py * Update ppo_config.py * Update ppo_trainer.py * style commit e7cb597 Author: Costa Huang <costa.huang@outlook.com> Date: Thu May 23 11:37:16 2024 -0400 Fix ppov2 test case (#1661) * Fix PPOv2 / RLOO refactor's stuff * update terminology to use stop token commit bc8dfbf Author: Kashif Rasul <kashif.rasul@gmail.com> Date: Thu May 23 15:28:04 2024 +0200 update eval_strategy (#1662) commit e4ed7a3 Author: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Thu May 23 18:34:22 2024 +0530 do not upcast adapters when using FSDP+QLoRA (#1654) commit 9a7efbd Author: syrn1k <85796210+syrn1k@users.noreply.github.com> Date: Thu May 23 15:58:49 2024 +0300 🤫 TR-DPO implementation (#1593) * 🤫 TR-DPO implementation baseline * fix comments * docs * fix linters * test added * move configs to DPOConfig * fix typo * add docs * fix import * use state.global_step * fix order of arguments * make sure plugins are not none * Update trl/trainer/utils.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> * Update trl/trainer/utils.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> * checking that reference model weights have changed * sync_target_model as staticmethod * set reference model --------- Co-authored-by: Nikita Surnachev <n.surnachev@tinkoff.ru> Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> commit b344bce Author: Anush Kini <33577829+Abilityguy@users.noreply.github.com> Date: Thu May 23 18:27:25 2024 +0530 [DPO] Add 'robust' loss_type (#1653) * Initial commit * pre-commit fix * Minor change to comments * Added some documentation on how to use Robust DPO commit 35e12dc Author: Nicolinho <Nicolinho@users.noreply.github.com> Date: Thu May 23 14:36:15 2024 +0200 Fix inheritance order in PPOv2Config (#1659) * fix inheritance order in PPOv2Config * fix inheritance order in rloo_config commit 1da6be1 Author: Ali Bakly <anbakly@gmail.com> Date: Thu May 23 14:10:29 2024 +0200 docs: correct cDPO usage in DPOTrainer (#1655) commit e249cd8 Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu May 23 14:10:05 2024 +0200 add support for training collator (#1658) commit a02513c Author: Zach Mueller <muellerzr@gmail.com> Date: Thu May 23 06:48:00 2024 -0400 Apply deprecated `evaluation_strategy` (#1559) * Deprecate * Update tests/test_dpo_trainer.py --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit 13454d2 Author: Costa Huang <costa.huang@outlook.com> Date: Wed May 22 08:31:10 2024 -0400 PPO / Reinforce Trainers (#1540) * Add ppov2 trainer * make eos trick optional, remove unused args * quick fix * precommit * update debugging script * fix out of bound `drop_last=True`; use built-in scheduler * Add PPO examples * push changes * quick change * quick change * various bug fixes * remove unnecessary grad accumulation setting * push new changes * fix DS3 model saving * update ppo.py * refactor * quick change * refactor * update ppo trainer * refactor * quick test * add ds2 /ds3 7 processes config * add vllm trainer * quick change * experiment with reward normalization * push changes * quick push * push changes * push various changes * refactor to use ModelConfig * quick change * refactor * refactor * Simplify DS logic * quick update * remove unnecessary files * precommit * deepspeed fix; handle edge case when eos_token_id = 0 * add PPO tldr example * add TL;DR example * fix undefined var * utilize all samples in rloo * quick setting * remove the unnecessary `value_model` * use exact_div * allow saving the deepspeed model * refactor * remove dead code * Use some shared utilities * add some end-to-end test cases * add PPOv2 docs and RLOO docs / tests * update docs * quikc push * fix ci * fix type annotation for ci * quick update * update trainer docs commit 99f2c94 Author: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed May 15 19:55:46 2024 +0530 don't cast the trainable lora layers to half precision (#1644) * don't cast the trainable lora layers to half precision * quality commit 6401d08 Author: Wing Lian <wing.lian@gmail.com> Date: Tue May 14 09:41:07 2024 -0400 Pairwise Noise Contrastive Alignment (#1632) * add NCA paired preference loss * chore: lint * set more lenient tolerance for integration tests * Update tests/test_dpo_trainer.py * skip test * fix --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: younesbelkada <younesbelkada@gmail.com> commit d632a5b Author: bartoszzuk <57541034+bartoszzuk@users.noreply.github.com> Date: Tue May 14 12:25:54 2024 +0200 Fixed wrong logs prefixes in KTOTrainer (#1641) * Fixed wrong logs prefixes in KTOTrainer * Pre-commit formating commit 5aeb752 Author: Tiezhen WANG <38108242+xianbaoqian@users.noreply.github.com> Date: Fri May 10 23:19:15 2024 +0800 Update sft_llama2.py to work with the latest API (#1637) * Update sft_llama2.py to work with the latest API SFTTrainer now takes a STFConfig argument * Update dpo_llama2.py * precommit commit b8b8978 Author: Ilya Gusev <phoenixilya@gmail.com> Date: Fri May 10 15:43:13 2024 +0200 [ORPO] Correct label mask for pad tokens (#1625) * [ORPO] Correct label mask for pad tokens Recent [fix](57aebe9) for calculating NLL loss for a whole sequence introduced a bug. When input_ids are copied to labels, pad tokens are not masked. This PR aims to path this by masking labels based on the attention mask. * -100 -> label_pad_token_id Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit 8799952 Author: Costa Huang <costa.huang@outlook.com> Date: Fri May 10 09:32:20 2024 -0400 visualize rm prediction (#1636) * visualize rm prediction * quick update * quick check * quick fix * update eval steps commit 3b4c249 Author: Xiao Yu <39458711+jasonyux@users.noreply.github.com> Date: Fri May 3 18:19:35 2024 -0400 fixed adding bos and eos token unconditionally (#1591) * fixed adding bos and eos token unconditionally * fixed typo of tokenizer -> self.tokenizer. Also added update to ORPO * fixed code quality, and added BOS/EOS fix to KTO * code reformatting with pre-commit run --all-files * bug fix: check input id length before checking for EOS/BOS commit 0347f58 Author: lewtun <lewis.c.tunstall@gmail.com> Date: Fri May 3 15:59:59 2024 +0200 Fix ZeRO-3 generation context manager (#1617)
* Add WinRateCallback * Enable PairRM * Refactor * Streamline * Add HF judge * Add base judge * Use better prompt * Clean * Add max tokens * Use logging * Add batched inference * Squashed commit of the following: commit 9e9dc96 Author: Maxim Kopecki <kopecki.maxim@gmail.com> Date: Wed Jul 10 19:11:13 2024 +0200 Added missing token kwarg in Peft model loading (#1825) commit 7ddef5c Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Wed Jul 10 18:26:11 2024 +0200 Make use of `trust_remote_code` consistent (#1806) Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> commit a9cddf8 Author: Adnan Khan <AdnaneKhan@users.noreply.github.com> Date: Wed Jul 10 11:25:07 2024 -0400 Delete unused benchmark.yml workflow. (#1822) commit 2860ce5 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Tue Jul 9 09:22:52 2024 +0200 DPO Llava 1.5 and PaliGemma support (#1797) * llava support dpo * add_special_tokens=False only when possible * format * pali gemma * refactor size * remove image resize --------- Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> commit 30e33bd Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Tue Jul 9 05:37:12 2024 +0200 upgrade gh actions (#1818) Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> commit d5a0d2d Author: Costa Huang <costa.huang@outlook.com> Date: Mon Jul 8 11:12:41 2024 -0400 Set dev version (#1817) commit 314e8eb Author: Puneet Singh Bhooi <puneetb@iiitd.ac.in> Date: Mon Jul 8 19:11:36 2024 +0530 fix broken url in `docs\source\index.mdx` (#1813) commit e107920 Author: Costa Huang <costa.huang@outlook.com> Date: Mon Jul 8 09:38:09 2024 -0400 0.9.6 release (#1816) commit 78045de Author: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Mon Jul 8 01:59:26 2024 +0200 Fix `TRL_USE_RICH` environment variable handling (#1808) * Add `strtobool` custom implementation from `distutils` * Fix `TRL_USE_RICH` handling via `strtobool` * Run `make precommit` commit 747612f Author: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Fri Jul 5 16:28:59 2024 +0200 Fix `torch_dtype` handling in `{DPO,SFT}Trainer` when provided via CLI (#1807) * Fix `torch_dtype` handling through CLI The `torch_dtype` is not properly handled when provided via the TRL CLI since it's provided initially as a string, but is then casted to `torch.dtype` before providing it to the `{DPO,SFT}Trainer`, which means that those trainers should handle the scenario where `torch_dtype` is a `torch.dtype` too. * Add `torch_dtype` tests in `test_{dpo,sft}_trainer.py` * Forward contribution credits * Run `make precommit` --------- Co-authored-by: Tash Srivastava <yash-srivastava19@users.noreply.github.com> commit 9e3a35b Author: Michael <mnoukhov@gmail.com> Date: Fri Jul 5 07:29:48 2024 -0400 Remove extra print in reward_trainer.py (#1799) `print_rich_table` is called twice and the first call doesn't restrict to `num_print_samples`. Remove the first, extra call commit 4402b36 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Thu Jul 4 14:29:25 2024 +0200 clean examples (#1791) Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> commit 78f8228 Author: Noah Tye <hi@noahtye.com> Date: Wed Jul 3 11:10:50 2024 -0700 Bugfix: Preserve token fields when converting TrainingArguments to SFTConfig (#1794) * Preserve token fields when converting TrainingArguments to SFTConfig TrainingArguments.to_dict() redacts token fields, so we have to individually copy them over when converting to SFTConfig to avoid breaking push_to_hub functionality. Also adds a test. * run precommit * one-line args_as_dict definition per suggestion from kashif * generalize token copying to match TrainingArguments behavior * unwrap |= on dict, to support python 3.8 * use .update instead of |= or for-loop commit b6af2ed Author: Kashif Rasul <kashif.rasul@gmail.com> Date: Wed Jul 3 08:29:16 2024 +0200 add model_init_kwargs to training_args (#1787) commit cd85b14 Author: Tommaso Buonocore <buonocore.tms@gmail.com> Date: Sat Jun 29 15:35:48 2024 +0200 Fixed typo in SFT trainer docs (#1788) 'STFConfig' instead of 'SFTConfig' appears multiple times in the doc, causing error when running the code snippets. commit a57544f Author: Kashif Rasul <kashif.rasul@gmail.com> Date: Thu Jun 27 15:47:58 2024 +0200 fix docs and examples (#1780) commit b68ff96 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Wed Jun 26 16:26:37 2024 +0200 Visual DPO (#1647) * Remove extra whitespaces * idefics * vdpo * sft idefics * pad with test * use prompt instead of tokenizer * rm name main * support vlm in tokenize row * temp fix for regex in lora_target_module * format * vdpo * tmp float16 hard code * concatenated_forward support for vision * style and new command line * all-linear * format * delete old examples * get image * upcast * new test * modified test * new strat for tokenizer * rm token transfer * integrate vision in dpo example * format * add FDivergenceType back * precommit * pillow test dep * optional prompt * `evaluation_strategy` to `eval_strategy` * revert vsft change (oos) * update test * test * comment and support more in process * update process * update doc for vdpo * caution about limited support * Update docs/source/dpo_trainer.mdx Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * revert DPO example changes * cleaner way to check if a model is vision * comment * update vdpo example * rename --------- Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit c8c01cc Author: Mubin Manasia <48038715+Mubin17@users.noreply.github.com> Date: Wed Jun 26 03:23:36 2024 -0600 Fix Documentation Overflow Issues for Long URLs in SFTConfig (#1774) * Update sft_config.py * Update sft_config.py commit 3479606 Author: Costa Huang <costa.huang@outlook.com> Date: Wed Jun 26 03:18:22 2024 -0400 Remove the leading space in the tldr preference dataset (#1773) commit 7965b78 Author: Haozhe Ji <jihaozhe@gmail.com> Date: Tue Jun 25 22:47:32 2024 +0800 add Efficient Exact Optimization (EXO) (#1735) * add exo * fix a detail * Update trl/trainer/dpo_trainer.py * Update trl/trainer/dpo_trainer.py * Update trl/trainer/dpo_trainer.py --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit 56bd1bb Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Tue Jun 25 16:14:26 2024 +0200 `evaluation_strategy` to `eval_strategy` (#1771) Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> commit 94d53e6 Author: Clara Pohland <54847419+claralp@users.noreply.github.com> Date: Mon Jun 24 21:27:00 2024 +0200 MoE Models: option to add load balancing loss (#1765) * KTO: add aux loss * use router_aux_loss_coef in KtoTrainer when aux_loss enabled * align optional aux_loss in DPO, KTO, CPO, ORPO * precommit changes * fix KL forward kwargs * add aux_loss doku entry * apply docs suggestions --------- Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de> commit b5be100 Author: Mihir Prabhudesai <mihirp1998.mp@gmail.com> Date: Mon Jun 24 12:05:44 2024 -0400 Added Reward Backpropogation Support (#1585) * added alignprop template * added alignprop support * Update alignprop_trainer.mdx * Update alignprop_trainer.mdx * added better why statement * fixed inference code * changed self to pipeline * removed aesthetic classifier * added aesthetic to auxiliary models * added unseen prompt logging * removed unseen prompt log * fixed minor * remove not needed import in trl/__init__.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * fixed styling * updated _toctree --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> commit 6e1652b Author: Haoran Xu <45837851+fe1ixxu@users.noreply.github.com> Date: Sun Jun 23 09:54:30 2024 -0700 Add CPO-SimPO method (#1760) * enable cpo-simpo * highlight SimPO and CPO-SimPO * add test for cpo_alpha * formatting * Update docs/source/cpo_trainer.mdx --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit 65374c6 Author: Costa Huang <costa.huang@outlook.com> Date: Fri Jun 21 11:20:54 2024 -0400 New sentiment and descriptiveness dataset (#1757) * push changes * handle edge cases where the chosen and the rejected are the same commit 9956091 Author: Juyoung Suk <scottsuk0306@gmail.com> Date: Fri Jun 21 18:01:08 2024 +0900 Add dataset_text_field in examples/scripts/sft.py (#1758) commit 34d273f Author: Costa Huang <costa.huang@outlook.com> Date: Thu Jun 20 13:16:43 2024 -0400 Support num_train_epochs (#1743) * add a test case for num_train_epochs * fix ci * quick change * disable push to hub * debug windows ci * try another fix * skip subprocess tests on windows commit 3bf9449 Author: Mert Sayar <mert.sayar@gmail.com> Date: Thu Jun 20 18:22:20 2024 +0300 Fix masking of response tokens (#1718) Current handling of `response_masks` inside `batch_forward_pass` function does not take padding into consideration which results with shape unmatch during masking. Since response mask is a mask tensor of response tokens, response tokens should not be concatenated with a `torch.zeros(query_length)` and masking operation should be done without slicing. Remove the concatenation of the response mask, remove the slicing from the response mask since response mask already has the length of `end - start + 1`, which is equal to length of `masks[j, start:end]`. commit ba6abee Author: idanshen <49375140+idanshen@users.noreply.github.com> Date: Thu Jun 20 09:14:16 2024 -0400 Support for returning past_key_values from the model (#1742) * add support for returning past_key_values from the model * change order of keys commit a57e759 Author: 1485840691 <110707330+1485840691@users.noreply.github.com> Date: Wed Jun 19 18:02:51 2024 +0800 Integrate f-divergence to DPO (Follow up) (#1610) * Step 1: update ppo_trainer and hello_world example * Step 2: Refine comments and add parameter type * Step 2: Add missing parameter comments * Step 1: Organize ptx loss into a function and add ptx_loss to train_stats * Step 1 updates: add comment to ptx_loss function, fix a bug and add warning message * Step 2: 1) Add ppo_ptx trainig example as ppo; 2) separate pretrain data fetch and iterate * Step 2: Remove loss from columns_to_log in ppo_ptx example * Remove data set revision in load imbd dataset * Run pre-commit and fix format issues * Initial draft of f-divergence fn * Update f-divergence to avoid overflow * fix test errors and comments * Add Unit tests for dpo loss with alpha and js div f * Adjust format * Fix test error * Reverse this update * Add test cases * Reverse un-needed updates * Update code style * Try to fix code fmt error * remove extra end line --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit ae23d40 Author: Shihyueh Hsu <66808901+AIR-hl@users.noreply.github.com> Date: Tue Jun 18 22:07:24 2024 +0800 change the `process` function in the example of DPO (#1753) * change the `process` function in the example of DPO * fix commit 83b367b Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue Jun 18 11:31:17 2024 +0200 CI / `KTOTrainer`: Remove old tests (#1750) * remove old tests * remove datasets * Update test_dpo_trainer.py * Update test_dpo_trainer.py commit d1ed730 Author: Michael <mnoukhov@gmail.com> Date: Mon Jun 17 10:50:21 2024 -0400 prepare deepspeed accomodate fp16 and bf16 (#1728) * prepare deepspeed accomodate fp16 and bf16 * precommit commit 8f8e95e Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon Jun 17 16:49:00 2024 +0200 CPO / DPO: Fix red CI (#1749) * fix red CI * precommit commit 4e23d95 Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon Jun 17 16:41:36 2024 +0200 fix red CI commit 50c4620 Author: Kawin <kawin.ethayarajh@gmail.com> Date: Mon Jun 17 07:14:44 2024 -0700 small KTO fixes (#1734) * add warning for imbalanced data * update documentation * update script commands to be same as in dpo * use batch_size KL examples and batch_size target examples to calculate batch_size losses * fix deepspeed issue * speed up forward with no_grad for KL * add some removed metrics * Update trl/trainer/kto_trainer.py * Update trl/trainer/kto_trainer.py * Update trl/trainer/kto_trainer.py add reference to paper Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * add more detailed comments * convert assert to ValueError * Update kto_trainer.py * precommit formatting * remove nans in metrics by gathering across machines * fix formatting * fix choice of mismatched examples for KL term * describe weights * fix hanging issue in distributed training * linting * move metrics to cpu * Update trl/trainer/kto_trainer.py Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Update trl/trainer/kto_trainer.py * Update trl/trainer/kto_trainer.py * remove kto_pair * speed up data processing * move bco code inside * raise error for kto_pair argument * fix formatting --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> Co-authored-by: Winnie Xu <winnie.xu97@gmail.com> commit 6105d03 Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon Jun 17 16:01:06 2024 +0200 `TrlParser`: Add ignore extra args option (#1748) * add ignore extra args option * Update trl/commands/cli_utils.py commit e247bbd Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon Jun 17 15:16:07 2024 +0200 CI / core: Pin `numpy` to `!=2.0.0` for CI and to users (#1747) * Update setup.py * Update setup.py * Update setup.py * Update test_best_of_n_sampler.py dummy commit * pin numpy * Update tests/test_best_of_n_sampler.py * Update setup.py commit 3d04496 Author: Michael <mnoukhov@gmail.com> Date: Mon Jun 17 08:43:33 2024 -0400 better trl parser with yaml config (#1739) * working trl parser with config correctly overrides yaml config with command line arguments adds return_remaining_strings when return_remaining_strings is False, raises error if yaml contains extra args that are not in the dataclasses simpler and cleaner than previous yaml parsing and merging addresses #1733 * lowercase trlparser commit 2d244f8 Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon Jun 17 11:56:13 2024 +0200 Workflow: Notify tests results on slack channel (#1744) * Update tests-main.yml * Update docker-build.yml commit f5168fd Author: Igor Melnyk <igoraries@gmail.com> Date: Wed Jun 12 05:54:54 2024 -0400 adds AOT (#1701) * adds AOT * Applied format changes * added docs and tests --------- Co-authored-by: Igor Melnyk <igor.melnyk@ibm.com> commit 79686e1 Author: jetlime <paul.houssel@yahoo.de> Date: Wed Jun 12 00:35:31 2024 +1000 ktotrainer: Refuse datasets which contain only one class of labels (#1724) * ktotrainer: refuse dataset which contain only one class of labels * ktotrainer: document new dataset constraint commit 34ebc4c Author: Luc Georges <McPatate@users.noreply.github.com> Date: Mon Jun 10 11:17:54 2024 +0200 feat(ci): add trufflehog secrets detection (#1721) * feat(ci): add trufflehog secrets detection * fix(ci): remove unnecessary permissions commit 1d84e2b Author: Michael <mnoukhov@gmail.com> Date: Fri Jun 7 11:42:08 2024 +0200 Fix default padding_value in dpo_config.py (#1692) dpo_config default padding value should be None, not 0, otherwise it by default overrides the padding value of any tokenizer to 0 commit 2f71b8b Author: Michael <mnoukhov@gmail.com> Date: Fri Jun 7 10:37:27 2024 +0200 fix yaml parser for derived config classes (#1713) fixes #1712 reformatted cli_utils with ruff commit 5bcb8ad Author: Kashif Rasul <kashif.rasul@gmail.com> Date: Fri Jun 7 08:48:17 2024 +0100 RDPO fix nll loss (#1705) commit b8b972f Author: Haoran Xu <45837851+fe1ixxu@users.noreply.github.com> Date: Thu Jun 6 14:06:47 2024 -0700 Add a variant of CPO, SimPO (#1703) * add a variant of cpo: simpo * correct cpo-simpo loss * avoid 0 int error in logging * add simpo description * Update trl/trainer/cpo_trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * fix formatting * add test for simpo * Update docs/source/cpo_trainer.mdx Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * add a docstring for simpogamma * move simpo description to the above docstring * change simpo description in the doc * formatting --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit 3eb9ccb Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu Jun 6 19:33:20 2024 +0200 set dev version (#1710) * Update setup.py * Update __init__.py commit 974b0d3 Author: Costa Huang <costa.huang@outlook.com> Date: Thu Jun 6 10:13:00 2024 -0400 0.9.4 release (#1708) commit 39a7d1c Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu Jun 6 15:50:17 2024 +0200 SFTTrainer: Fix backward Compatibility issue with `TrainingArguments` (#1707) * fix BC * fixup commit 0bdc638 Author: Guilherme Freire <guilhermebfreire@gmail.com> Date: Thu Jun 6 14:42:58 2024 +0100 Fixed doc string and docs for the SFTConfig update (#1706) commit 275d33b Author: Costa Huang <costa.huang@outlook.com> Date: Wed Jun 5 14:34:59 2024 -0400 0.9.3 release (#1699) commit c0819ee Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed Jun 5 17:29:03 2024 +0200 Update sft_trainer.py (#1698) commit a03e7cc Author: Costa Huang <costa.huang@outlook.com> Date: Wed Jun 5 11:00:19 2024 -0400 Release 0.9.2 (#1697) * Release: 0.9.0 * Release commit a13cb89 Author: Costa Huang <costa.huang@outlook.com> Date: Wed Jun 5 10:20:54 2024 -0400 Quick fix on GPT4-eval (#1696) * quick fix * precommit commit 84156f1 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Mon Jun 3 20:09:05 2024 +0200 Fix typo in DPOTrainer's warnings (#1688) commit 4eb0b90 Author: Alex Brooks <alex.brooks@ibm.com> Date: Mon Jun 3 10:24:32 2024 -0600 Skip packing validation (#1673) * Add test for skipping preproc if packing=True Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com> * Allow skipping of validation for packing=True Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com> * Use dummy dataset in no packing preproc test Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com> --------- Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com> commit 6c203f9 Author: Alexey Rozhkov <alexisrozhkov@gmail.com> Date: Mon Jun 3 10:16:22 2024 +0100 Fix overriding optimize_device_cache with optimize_cuda_cache in PPOConfig (#1690) * Don't override optimize_device_cache when optimize_cuda_cache is not provided Raise an exception when both optimize_cuda_cache and optimize_device_cache are set * Minor fix commit f18253b Author: Kashif Rasul <kashif.rasul@gmail.com> Date: Mon Jun 3 09:43:02 2024 +0100 intial RPO loss (#1686) * intial RPO loss * fix sign * clean up commit 151a452 Author: Samuel <s.kiegeland@gmx.de> Date: Wed May 29 20:29:38 2024 +0200 Fix max completion length (#1588) commit 488b502 Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed May 29 20:19:26 2024 +0200 fix (#1678) commit 3c0a10b Author: Wang, Yi <yi.a.wang@intel.com> Date: Mon May 27 20:52:20 2024 +0800 fix dataset load error (#1670) Signed-off-by: Wang, Yi <yi.a.wang@intel.com> commit b031adf Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri May 24 15:20:16 2024 +0200 FIX / PPO: Fix `enable_input_require_grads` issues with PPO models (#1664) * Update modeling_base.py * Update ppo_config.py * Update ppo_trainer.py * style commit e7cb597 Author: Costa Huang <costa.huang@outlook.com> Date: Thu May 23 11:37:16 2024 -0400 Fix ppov2 test case (#1661) * Fix PPOv2 / RLOO refactor's stuff * update terminology to use stop token commit bc8dfbf Author: Kashif Rasul <kashif.rasul@gmail.com> Date: Thu May 23 15:28:04 2024 +0200 update eval_strategy (#1662) commit e4ed7a3 Author: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Thu May 23 18:34:22 2024 +0530 do not upcast adapters when using FSDP+QLoRA (#1654) commit 9a7efbd Author: syrn1k <85796210+syrn1k@users.noreply.github.com> Date: Thu May 23 15:58:49 2024 +0300 🤫 TR-DPO implementation (#1593) * 🤫 TR-DPO implementation baseline * fix comments * docs * fix linters * test added * move configs to DPOConfig * fix typo * add docs * fix import * use state.global_step * fix order of arguments * make sure plugins are not none * Update trl/trainer/utils.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> * Update trl/trainer/utils.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> * checking that reference model weights have changed * sync_target_model as staticmethod * set reference model --------- Co-authored-by: Nikita Surnachev <n.surnachev@tinkoff.ru> Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> commit b344bce Author: Anush Kini <33577829+Abilityguy@users.noreply.github.com> Date: Thu May 23 18:27:25 2024 +0530 [DPO] Add 'robust' loss_type (#1653) * Initial commit * pre-commit fix * Minor change to comments * Added some documentation on how to use Robust DPO commit 35e12dc Author: Nicolinho <Nicolinho@users.noreply.github.com> Date: Thu May 23 14:36:15 2024 +0200 Fix inheritance order in PPOv2Config (#1659) * fix inheritance order in PPOv2Config * fix inheritance order in rloo_config commit 1da6be1 Author: Ali Bakly <anbakly@gmail.com> Date: Thu May 23 14:10:29 2024 +0200 docs: correct cDPO usage in DPOTrainer (#1655) commit e249cd8 Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu May 23 14:10:05 2024 +0200 add support for training collator (#1658) commit a02513c Author: Zach Mueller <muellerzr@gmail.com> Date: Thu May 23 06:48:00 2024 -0400 Apply deprecated `evaluation_strategy` (#1559) * Deprecate * Update tests/test_dpo_trainer.py --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit 13454d2 Author: Costa Huang <costa.huang@outlook.com> Date: Wed May 22 08:31:10 2024 -0400 PPO / Reinforce Trainers (#1540) * Add ppov2 trainer * make eos trick optional, remove unused args * quick fix * precommit * update debugging script * fix out of bound `drop_last=True`; use built-in scheduler * Add PPO examples * push changes * quick change * quick change * various bug fixes * remove unnecessary grad accumulation setting * push new changes * fix DS3 model saving * update ppo.py * refactor * quick change * refactor * update ppo trainer * refactor * quick test * add ds2 /ds3 7 processes config * add vllm trainer * quick change * experiment with reward normalization * push changes * quick push * push changes * push various changes * refactor to use ModelConfig * quick change * refactor * refactor * Simplify DS logic * quick update * remove unnecessary files * precommit * deepspeed fix; handle edge case when eos_token_id = 0 * add PPO tldr example * add TL;DR example * fix undefined var * utilize all samples in rloo * quick setting * remove the unnecessary `value_model` * use exact_div * allow saving the deepspeed model * refactor * remove dead code * Use some shared utilities * add some end-to-end test cases * add PPOv2 docs and RLOO docs / tests * update docs * quikc push * fix ci * fix type annotation for ci * quick update * update trainer docs commit 99f2c94 Author: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed May 15 19:55:46 2024 +0530 don't cast the trainable lora layers to half precision (#1644) * don't cast the trainable lora layers to half precision * quality commit 6401d08 Author: Wing Lian <wing.lian@gmail.com> Date: Tue May 14 09:41:07 2024 -0400 Pairwise Noise Contrastive Alignment (#1632) * add NCA paired preference loss * chore: lint * set more lenient tolerance for integration tests * Update tests/test_dpo_trainer.py * skip test * fix --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: younesbelkada <younesbelkada@gmail.com> commit d632a5b Author: bartoszzuk <57541034+bartoszzuk@users.noreply.github.com> Date: Tue May 14 12:25:54 2024 +0200 Fixed wrong logs prefixes in KTOTrainer (#1641) * Fixed wrong logs prefixes in KTOTrainer * Pre-commit formating commit 5aeb752 Author: Tiezhen WANG <38108242+xianbaoqian@users.noreply.github.com> Date: Fri May 10 23:19:15 2024 +0800 Update sft_llama2.py to work with the latest API (#1637) * Update sft_llama2.py to work with the latest API SFTTrainer now takes a STFConfig argument * Update dpo_llama2.py * precommit commit b8b8978 Author: Ilya Gusev <phoenixilya@gmail.com> Date: Fri May 10 15:43:13 2024 +0200 [ORPO] Correct label mask for pad tokens (#1625) * [ORPO] Correct label mask for pad tokens Recent [fix](57aebe9) for calculating NLL loss for a whole sequence introduced a bug. When input_ids are copied to labels, pad tokens are not masked. This PR aims to path this by masking labels based on the attention mask. * -100 -> label_pad_token_id Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> commit 8799952 Author: Costa Huang <costa.huang@outlook.com> Date: Fri May 10 09:32:20 2024 -0400 visualize rm prediction (#1636) * visualize rm prediction * quick update * quick check * quick fix * update eval steps commit 3b4c249 Author: Xiao Yu <39458711+jasonyux@users.noreply.github.com> Date: Fri May 3 18:19:35 2024 -0400 fixed adding bos and eos token unconditionally (#1591) * fixed adding bos and eos token unconditionally * fixed typo of tokenizer -> self.tokenizer. Also added update to ORPO * fixed code quality, and added BOS/EOS fix to KTO * code reformatting with pre-commit run --all-files * bug fix: check input id length before checking for EOS/BOS commit 0347f58 Author: lewtun <lewis.c.tunstall@gmail.com> Date: Fri May 3 15:59:59 2024 +0200 Fix ZeRO-3 generation context manager (#1617) * judge refactoring and unittest * format * init * doc * format * improve doc * basejudge * improve doc and add BaseAPIJudge * Doc * style * refactor callback * remove openai and pairrm judge from test * doc * rm dpo online example * new prompts and completions * skip hf judge and add hf token --------- Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Hi,
I have added support for AlignProp (https://align-prop.github.io/) for finetuning Stable Diffusion model using reward gradients.
AlignProp directly backpropagate gradients from the reward model to the diffusion weights. Thus is about 25x more sample and compute efficient than policy gradient based methods like DDPO.
The current implementation seems to train effectively, almost within an hour on a single A100 while using Aesthetic reward model. Please find the attached loss and reward curves + some qualitative results after training.
huggingface/diffusers#7312
Difference between DDPO and AlignProp:
DDPO uses PPO, which is a policy gradient method for aligning diffusion models. AlignProp doesn't use policy gradients instead it directly backpropagates gradients from the reward function to diffusion denoising process, to maximize reward.
AlignProp can only work when the reward function is differentiable, DDPO on other hand can handle non-differentiable reward functions, as it never backpropagates gradients from the reward function weights.
As AlignProp takes benefit of the differentiability of the reward function as it backpropagates gradient. It is significantly more sample efficient than DDPO.
The loss function in AlignProp is simply the negative of the reward value outputed by the reward function, while in DDPO it's the PPO loss function.
As the reward function is sitting on the RGB images. AlignProp requires to do the full denoising chain from Noise to RGB during training, while DDPO can instead sample random denoising timesteps, similar to diffusion training.
DDPO and AlignProp both use LoRA and gradient checkpointing.
CC: @parthos86 @sayakpaul @lvwerra @younesbelkada
Image Generations post training: