
Padding free dpo #2437

Open · wants to merge 55 commits into main

Conversation

@dame-cell (Author) commented Dec 4, 2024

What does this PR do?

New feature: #2422 (let DPOTrainer support padding_free)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

For now this is just a draft; I will keep working on it.

@osanseviero

@dame-cell (Author) commented Dec 10, 2024

Not fully done yet, but for now everything seems to be working: if padding_free is set to True, the trainer will not pad, and attention_mask will not be used.

For now, here are the remaining tasks (a rough sketch of the intended behavior follows the list):

  • Ensure that when padding_free=True the trainer does not pad
  • Ensure that when padding_free=True the trainer does not use or return attention_mask
  • Ensure that when padding_free=True we use position_ids
  • Write tests
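
A rough sketch of what this behavior means in practice, making no assumption about the PR's actual implementation: the sequences of a batch are concatenated into one row, attention_mask is dropped, and position_ids restart at 0 for every sequence (the helper name below is made up for illustration):

import torch

def pack_without_padding(sequences):
    # sequences: list of 1-D LongTensors of different lengths, with no pad tokens
    input_ids = torch.cat(sequences).unsqueeze(0)            # shape (1, total_len)
    position_ids = torch.cat(
        [torch.arange(len(seq)) for seq in sequences]
    ).unsqueeze(0)                                            # restarts at 0 for every sequence
    return {"input_ids": input_ids, "position_ids": position_ids}

# Example: two sequences of lengths 4 and 3
batch = pack_without_padding([torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10, 11])])
# batch["position_ids"] -> tensor([[0, 1, 2, 3, 0, 1, 2]])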

@dame-cell (Author)

Most of the work is done; only some small things are left, like handling lists and converting them to tensors.

dame-cell marked this pull request as ready for review on December 11, 2024.
@dame-cell (Author) commented Dec 11, 2024

Hey @osanseviero,

The main idea for using padding_free is mostly in place now, but there are still a few things that need to be done. It would be awesome if you could take a look at the code and let me know if there's anything else I should address or add.

I've made it so the user can directly do this:

trainer = DPOTrainer(
    model=self.model,
    ref_model=None,
    args=training_args,
    tokenizer=self.tokenizer,
    padding_free=True,  # when True it will not use any padding
    train_dataset=dummy_dataset["train"],
    eval_dataset=dummy_dataset["test"],
)

with tempfile.TemporaryDirectory() as tmp_dir:
    training_args = DPOConfig(
        output_dir=tmp_dir,
        per_device_train_batch_size=2,
        max_steps=3,
        remove_unused_columns=False,
        gradient_accumulation_steps=4,
A Contributor commented on the lines above:

It would be good to have tests for this with gradient accumulation too, perhaps using pytest.mark.parametrize?
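
For what it's worth, a minimal sketch of such a parametrized test. The test name and parameter values are illustrative; model, tokenizer, and dummy_dataset are assumed to come from the same test fixtures shown in the snippet above, and padding_free is the constructor argument proposed in this PR:

import pytest
from trl import DPOConfig, DPOTrainer

@pytest.mark.parametrize("gradient_accumulation_steps", [1, 2, 4])
def test_dpo_padding_free_gradient_accumulation(gradient_accumulation_steps, tmp_path):
    # Hypothetical test: train for a few steps with padding_free=True and
    # varying gradient accumulation, and check that training completes.
    training_args = DPOConfig(
        output_dir=str(tmp_path),
        per_device_train_batch_size=2,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_steps=3,
        remove_unused_columns=False,
        report_to="none",
    )
    trainer = DPOTrainer(
        model=model,            # assumed fixture
        ref_model=None,
        args=training_args,
        tokenizer=tokenizer,    # assumed fixture
        padding_free=True,      # constructor argument proposed in this PR
        train_dataset=dummy_dataset["train"],
        eval_dataset=dummy_dataset["test"],
    )
    trainer.train()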

@dame-cell (Author) replied:

All right, will do. Thanks for reviewing 😎

@@ -53,6 +53,9 @@ class PPOConfig(OnPolicyConfig):
Discount factor.
A Member commented on this diff:

This modification shouldn't be here, right?

@dame-cell (Author) replied:

Oh, my bad. I'll fix it right now.

@qgallouedec (Member) commented Dec 21, 2024:

You still have modifications in the PPO files.

@qgallouedec (Member) commented Dec 14, 2024

Here (trl/trainer/dpo_trainer.py, lines 1115 to 1123 at commit 6d4ed07):

# Flush left to reduce the memory usage
# [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
#  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
for i in range(attention_mask.size(0)):
    first_one_idx = torch.nonzero(attention_mask[i])[0].item()
    input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
    attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
    loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx)

After flushing left, we could remove the pad tokens and add position ids:

# Flush left to reduce the memory usage
# [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
#  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
for i in range(attention_mask.size(0)):
    first_one_idx = torch.nonzero(attention_mask[i])[0].item()
    input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
    attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
    loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx)

if self.padding_free:
    # input =              input =            pos_ids =            input =                    pos_ids =
    # [[x, x, x, x],  ->   [[x, x, x, x], and [[0, 1, 2, 3],  ->   [x, x, x, x, x, x, x]  and [0, 1, 2, 3, 0, 1, 2]
    #  [x, x, x, 0]]        [x, x, x]]          [0, 1, 2]]

    ... # code here
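
For illustration only, a minimal sketch of what that placeholder could do, assuming flush-left has already moved all padding to the right: drop padded positions, build per-sequence position ids, and pack everything into one row. Variable names and the exact packed layout are assumptions, not the final PR code:

if self.padding_free:
    # Keep only the non-padded tokens of each row.
    seq_lengths = attention_mask.sum(dim=1)                  # number of real tokens per row
    input_ids_list, loss_mask_list, position_ids_list = [], [], []
    for i in range(attention_mask.size(0)):
        length = int(seq_lengths[i])
        input_ids_list.append(input_ids[i, :length])
        loss_mask_list.append(loss_mask[i, :length])
        position_ids_list.append(torch.arange(length, device=input_ids.device))
    # Pack into a single row; attention_mask is no longer needed because
    # position_ids encode where each sequence starts and ends.
    input_ids = torch.cat(input_ids_list).unsqueeze(0)
    loss_mask = torch.cat(loss_mask_list).unsqueeze(0)
    position_ids = torch.cat(position_ids_list).unsqueeze(0)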

@dame-cell (Author)

All right, awesome, this actually makes more sense 😭

@dame-cell (Author) commented Dec 14, 2024

Before I push my code again, I want to benchmark padding vs. padding_free to show the performance difference.

@qgallouedec (Member)

You can push it, no worries; we can still refine it afterwards.

@dame-cell (Author) commented Dec 15, 2024

Thank you for your understanding! I wanted to let you know that I’m a bit tied up today and tomorrow, so I might not be able to push the code right away. I’ll try to get to it as soon as possible, but please feel free to let me know if there’s a hard deadline I should prioritize.

Thanks for your patience!
I'll keep working on it and try to push it by tomorrow if I can.

@qgallouedec (Member)

No rush on our side :)

@dame-cell (Author) commented Dec 17, 2024

All right, I think this does it. I checked whether we can train on a single T4 GPU in a Colab notebook: using the provided example script dpo.py with a small update, I was able to train a model with padding_free=True.

python trl/examples/scripts/dpo.py \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --learning_rate 5.0e-6 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --logging_steps 1 \
    --output_dir Qwen2-0.5B-DPO \
    --no_remove_unused_columns \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16

Without padding_free it kept running out of memory (OOM); is this normal?
I have not updated the docs yet because I'm not 100% sure this works correctly until after a review.

dame-cell marked this pull request as ready for review on December 17, 2024.
@dame-cell (Author)

@osanseviero Just wanted to follow up on this PR and see if there’s any feedback so far. I’m happy to clarify anything or make updates if needed. Let me know whenever you get a chance—thanks so much for your time! 🙌

@qgallouedec (Member)

You still need to revert the changes applied to the PPO files, and run the pre-commit hooks.

@dame-cell (Author) commented Dec 21, 2024

The new push changes some code because of the following problems. This is training without padding_free:

{'loss': 0.6933, 'grad_norm': 28.669958114624023, 'learning_rate': 4.880000000000001e-06, 'rewards/chosen': -0.0004301070875953883, 'rewards/rejected': -0.00019319055718369782, 'rewards/accuracies': 0.5, 'rewards/margins': -0.0002369165886193514, 'logps/chosen': -359.64996337890625, 'logps/rejected': -218.63882446289062, 'logits/chosen': -3.0950357913970947, 'logits/rejected': -2.8118934631347656, 'epoch': 0.02}

And this is training with padding_free:

{'loss': 0.6931, 'grad_norm': 197.54409790039062, 'learning_rate': 4.960000000000001e-06, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/chosen': -972.31005859375, 'logps/rejected': -1276.14697265625, 'logits/chosen': -2.9294216632843018, 'logits/rejected': -2.5271308422088623, 'epoch': 0.01}
{'loss': 0.6791, 'grad_norm': 124.33696746826172, 'learning_rate': 4.92e-06, 'rewards/chosen': 0.008736038580536842, 'rewards/rejected': -0.019646836444735527, 'rewards/accuracies': 1.0, 'rewards/margins': 0.02838287316262722, 'logps/chosen': -656.6194458007812, 'logps/rejected': -656.0517578125, 'logits/chosen': -2.9904067516326904, 'logits/rejected': -2.4395949840545654, 'epoch': 0.02}
{'loss': 0.6568, 'grad_norm': 147.94912719726562, 'learning_rate': 4.880000000000001e-06, 'rewards/chosen': 0.020291520282626152, 'rewards/rejected': -0.05392017588019371, 'rewards/accuracies': 1.0, 'rewards/margins': 0.07421170175075531, 'logps/chosen': -653.7568359375, 'logps/rejected': -713.071044921875, 'logits/chosen': -3.067491292953491, 'logits/rejected': -2.60148286819458, 'epoch': 0.02}
{'loss': 0.6229, 'grad_norm': 213.63063049316406, 'learning_rate': 4.84e-06, 'rewards/chosen': 0.02659912034869194, 'rewards/rejected': -0.12024041265249252, 'rewards/accuracies': 1.0, 'rewards/margins': 0.14683951437473297, 'logps/chosen': -1233.0953369140625, 'logps/rejected': -761.375732421875, 'logits/chosen': -2.791830062866211, 'logits/rejected': -2.6060781478881836, 'epoch': 0.03}

The grad_norm is very high for padding_free, and rewards/accuracies is always 1.0, which is not correct.
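
As a debugging aid rather than part of the PR: with a packed batch, the per-token log-probs have to be split back into per-sequence sums using the right lengths and ordering, and a shifted chosen/rejected split would produce exactly this kind of constant-accuracy symptom. A purely illustrative sanity-check sketch (names and the packing order are assumptions):

import torch

def per_sequence_logps(token_logps, seq_lengths):
    # token_logps: 1-D tensor of per-token log-probs for the packed batch
    # seq_lengths: sequence lengths in packing order
    #              (here assumed: all chosen completions first, then all rejected ones)
    chunks = torch.split(token_logps, seq_lengths)
    return torch.stack([chunk.sum() for chunk in chunks])

# Example: 2 chosen sequences (lengths 4 and 3) followed by 2 rejected ones (lengths 2 and 5)
logps = per_sequence_logps(torch.randn(14), [4, 3, 2, 5])
chosen_logps, rejected_logps = logps[:2], logps[2:]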

Successfully merging this pull request may close these issues:

  • Let DPOTrainer Support padding_free (#2422)