Padding-free DPO #2437

Open
wants to merge 55 commits into main

Commits (55)
ca99954
added eos token for ppotrainer
dame-cell Nov 30, 2024
fe1d5f6
remove the unnecessary stuff
dame-cell Nov 30, 2024
b15c635
Update ppo_config.py
dame-cell Nov 30, 2024
1bcb3a4
remove redundant EOS token fallback
dame-cell Dec 1, 2024
2ef2b24
remove redundant EOS token fallback
dame-cell Dec 1, 2024
42a0f73
remove some unnecessary tests stuff
dame-cell Dec 1, 2024
6130a91
added tests and update concatenated_inputs
dame-cell Dec 4, 2024
6732ed2
return only list and also a lot to do
dame-cell Dec 4, 2024
ce67292
padding free not tested but getting closer
dame-cell Dec 10, 2024
91e40aa
rebase and also reevaluate my approach
dame-cell Dec 10, 2024
8a34cb5
merge main
dame-cell Dec 10, 2024
1d38632
fix indentation
dame-cell Dec 10, 2024
814d69e
better tests
dame-cell Dec 10, 2024
7dae607
concatenated_forward now supports padding_free
dame-cell Dec 10, 2024
1a59d74
collator now does not return attention masks
dame-cell Dec 10, 2024
562c52e
position ids and no attention mask works
dame-cell Dec 11, 2024
d194054
update concatenated forward to support padding_free
dame-cell Dec 11, 2024
3855851
update concatenated forward to support padding_free
dame-cell Dec 11, 2024
d9adbfb
Merge branch 'main' into padding_free_dpo
dame-cell Dec 11, 2024
1145006
grad accumulation tests
dame-cell Dec 13, 2024
24f73a4
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 13, 2024
f6bd9e1
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
bbd99cf
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
ba4969d
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
187b1e5
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
1d9ce3e
fix indentation
dame-cell Dec 13, 2024
8a974cc
comments update
dame-cell Dec 13, 2024
7f0298b
fix some small issue
dame-cell Dec 13, 2024
2900275
fix some small issue
dame-cell Dec 13, 2024
58e779a
fix some small issue
dame-cell Dec 13, 2024
6a1e251
fix some small issue
dame-cell Dec 13, 2024
f92e056
update concatenated_forward to support padding_free
dame-cell Dec 13, 2024
457e3a1
fix some small issue
dame-cell Dec 13, 2024
f1789f4
fix some small issue
dame-cell Dec 13, 2024
0103728
So we need to make sure to correctly handle the list
dame-cell Dec 13, 2024
5837faa
by correctly updating concatenated_forward it works now
dame-cell Dec 13, 2024
0321c1d
refactoring concatenated_forward and batched same length seq for padd…
dame-cell Dec 14, 2024
a6e2163
update
dame-cell Dec 14, 2024
1328fc3
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
51a2cc6
Merge branch 'main' into padding_free_dpo
dame-cell Dec 17, 2024
986ed71
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
570b79a
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
9dd9564
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 17, 2024
b781876
Merge branch 'main' into padding_free_dpo
dame-cell Dec 18, 2024
525ecb2
Merge branch 'main' into padding_free_dpo
dame-cell Dec 19, 2024
c8ce9c8
Merge branch 'main' into padding_free_dpo
dame-cell Dec 19, 2024
b7fad73
Reverted PPO trainer to original version and updated DPO files
dame-cell Dec 19, 2024
55cd219
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 19, 2024
5e8df69
Updated DPO files
dame-cell Dec 19, 2024
ba1ded1
Merge branch 'main' into padding_free_dpo
dame-cell Dec 20, 2024
0784202
Merge branch 'main' into padding_free_dpo
dame-cell Dec 21, 2024
ddfed7c
update test_dpo_trainer.py
dame-cell Dec 21, 2024
955c7e8
update dpo_trainer.py
dame-cell Dec 21, 2024
ba4356e
update dpo_trainer.py
dame-cell Dec 21, 2024
c61abb5
update dpo_trainer.py
dame-cell Dec 21, 2024
27 changes: 27 additions & 0 deletions tests/test_dpo_trainer.py
@@ -551,6 +551,33 @@ def test_tr_dpo_trainer(self):
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

def test_dpo_trainer_padding_free_training(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
eval_strategy="steps",
beta=0.1,
report_to="none",
padding_free=True,
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
trainer = DPOTrainer(
model=self.model,
ref_model=None,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

trainer.train()

@require_no_wandb
def test_dpo_trainer_generate_during_eval_no_wandb(self):
with tempfile.TemporaryDirectory() as tmp_dir:
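To exercise just the new test locally, something like the following should work with a standard pytest setup (the -k expression is illustrative, not part of this PR):

pytest tests/test_dpo_trainer.py -k padding_free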
3 changes: 3 additions & 0 deletions trl/trainer/dpo_config.py
@@ -145,6 +145,8 @@ class DPOConfig(TrainingArguments):
for saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios
when working with very long prompts where labels are ignored (-100).
[Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM)
padding_free (`bool`, defaults to `False`):
Whether to use padding-free training. If set to `True`, the trainer concatenates the sequences in a batch into a single unpadded sequence and passes position IDs instead of an attention mask to the model.
"""

learning_rate: float = 1e-6
@@ -192,3 +194,4 @@ class DPOConfig(TrainingArguments):
rpo_alpha: Optional[float] = None
discopop_tau: float = 0.05
use_num_logits_to_keep: bool = False
padding_free: bool = False
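For reference, a minimal usage sketch of the new flag, mirroring the test above. The tiny model id is a placeholder, not something this PR pins down:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"  # placeholder model id
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# padding_free=True is the only change relative to a standard DPO run
training_args = DPOConfig(
    output_dir="dpo-padding-free",
    padding_free=True,
    beta=0.1,
    report_to="none",
)
dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
)
trainer.train()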
261 changes: 179 additions & 82 deletions trl/trainer/dpo_trainer.py
@@ -192,6 +192,7 @@ class DPOTrainer(Trainer):
The function to use to preprocess the logits before computing the metrics.
peft_config (`dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.

"""

_tag_names = ["trl", "dpo"]
@@ -347,7 +348,7 @@ def make_inputs_require_grad(module, input, output):
self.model_adapter_name = args.model_adapter_name
self.ref_adapter_name = args.ref_adapter_name
self.reference_free = args.reference_free

self.padding_free = args.padding_free
if ref_model:
self.ref_model = ref_model
elif self.is_peft_model or args.precompute_ref_log_probs:
@@ -1068,8 +1069,10 @@ def dpo_loss(

def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

We do this to avoid doing two forward passes, because it's faster for FSDP.
When `padding_free=True`, sequences are concatenated without padding tokens and processed as a single
continuous sequence with position IDs that reset at each sequence boundary, improving memory efficiency
and computation speed. When `False`, standard padded batch processing is used.
"""
num_examples = batch["prompt_input_ids"].shape[0]

@@ -1121,91 +1124,185 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx)

# Get the first column idx that is all zeros and remove every column after that
empty_cols = torch.sum(attention_mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1)
input_ids = input_ids[:, :first_empty_col]
attention_mask = attention_mask[:, :first_empty_col]
loss_mask = loss_mask[:, :first_empty_col]

# Truncate right
if self.args.max_length is not None:
input_ids = input_ids[:, : self.args.max_length]
attention_mask = attention_mask[:, : self.args.max_length]
loss_mask = loss_mask[:, : self.args.max_length]

if self.use_num_logits_to_keep:
# Compute num_logits_to_keep based on loss_mask pattern:
# [[0, 0, 0, x, x, x, x],
# [0, 0, 0, x, x, x, 0]]
# ^ start computing logits from here ([:, -(7-3+1):])
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
num_logits_to_keep = loss_mask.shape[1] - first_compute_index
model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label

outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs)

# Offset the logits by one to align with the labels
logits = outputs.logits[:, :-1, :]
labels = input_ids[:, 1:].clone()
loss_mask = loss_mask[:, 1:].bool()

if self.use_num_logits_to_keep:
# Align labels with logits
# logits: -, -, [x2, x3, x4, x5, x6]
# ^ --------- ^ after logits[:, :-1, :]
# labels: [y0, y1, y2, y3, y4, y5, y6]
# ^ --------- ^ with num_logits_to_keep=4, [:, -4:]
# loss_mask: [0, 0, 0, 1, 1, 1, 1]
labels = labels[:, -num_logits_to_keep:]
loss_mask = loss_mask[:, -num_logits_to_keep:]

if logits.shape[:2] != labels.shape[:2]:
# for llava, the returned logits include the image tokens (placed before the text tokens)
seq_len = labels.shape[1]
logits = logits[:, -seq_len:]

# Compute the log probabilities of the labels
labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps[~loss_mask] = 0
all_logps = per_token_logps.sum(-1)

output = {}

if self.use_weighting:
with torch.no_grad():
# Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827
logprobs = F.log_softmax(logits, dim=-1)
weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space
per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)
chosen_weights = all_weights[:num_examples]
rejected_weights = all_weights[num_examples:]
output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)

if self.args.rpo_alpha is not None:
# Only use the chosen logits for the RPO loss
chosen_logits = logits[:num_examples]
chosen_labels = labels[:num_examples]
if self.padding_free:
# Padding-free specific preprocessing
seq_lengths = attention_mask.sum(1)
if self.args.max_length is not None:
seq_lengths = torch.clamp(seq_lengths, max=self.args.max_length)
input_ids = input_ids[:, : self.args.max_length]

# Create concatenated sequence
total_length = seq_lengths.sum().item()
concatenated_input_ids = torch.zeros(total_length, dtype=input_ids.dtype, device=input_ids.device)
position_ids = torch.zeros(total_length, dtype=torch.long, device=input_ids.device)

# Fill tensors with sequence boundaries tracking
current_idx = 0
sequence_boundaries = []
for i in range(input_ids.size(0)):
length = seq_lengths[i].item()
valid_tokens = input_ids[i, :length]
concatenated_input_ids[current_idx : current_idx + length] = valid_tokens
position_ids[current_idx : current_idx + length] = torch.arange(length, device=input_ids.device)
sequence_boundaries.append((current_idx, current_idx + length))
current_idx += length

model_kwargs.pop("attention_mask", None)
model_inputs = {
"input_ids": concatenated_input_ids.unsqueeze(0),
"position_ids": position_ids.unsqueeze(0),
}
else:
# Get the first column idx that is all zeros and remove every column after that
empty_cols = torch.sum(attention_mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1)
input_ids = input_ids[:, :first_empty_col]
attention_mask = attention_mask[:, :first_empty_col]
loss_mask = loss_mask[:, :first_empty_col]

# Truncate right
if self.args.max_length is not None:
input_ids = input_ids[:, : self.args.max_length]
attention_mask = attention_mask[:, : self.args.max_length]
loss_mask = loss_mask[:, : self.args.max_length]

if self.use_num_logits_to_keep:
# Compute num_logits_to_keep based on loss_mask pattern:
# [[0, 0, 0, x, x, x, x],
# [0, 0, 0, x, x, x, 0]]
# ^ start computing logits from here ([:, -(7-3+1):])
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
num_logits_to_keep = loss_mask.shape[1] - first_compute_index
model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label

model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}

# Compute the log probabilities of the labels
output["nll_loss"] = F.cross_entropy(
torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0
)
outputs = model(**model_inputs, **model_kwargs)
# Process outputs
logits = outputs.logits
if self.padding_free:
# In padding_free mode:
# - We work with a single concatenated sequence
# - No padding tokens are used
# - Position IDs reset for each sequence
# - Use sequence boundaries to track where each sequence starts/ends
logits = logits[0, :-1, :] # Remove batch dim and last position
labels = concatenated_input_ids[1:].clone() # Shift right by 1 for next-token prediction

per_token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

# Use sequence boundaries to split back into individual sequences
batch_logps = []
for start, end in sequence_boundaries:
sequence_logps = per_token_logps[start : end - 1]
sequence_length = end - start - 1 # -1 because we're using end-1 above
normalized_logp = sequence_logps.sum() / sequence_length # normalize by sequence_length
batch_logps.append(normalized_logp)

all_logps = torch.stack(batch_logps)
else:
# In standard mode:
# - Work with padded batches
# - Use attention masks to handle padding
# - All sequences in batch processed in parallel
# - Need to handle padding tokens carefully
if self.use_num_logits_to_keep:
labels = labels[:, -num_logits_to_keep:]
loss_mask = loss_mask[:, -num_logits_to_keep:]

logits = logits[:, :-1, :] # Remove last position
labels = input_ids[:, 1:].clone() # Shift right by 1
loss_mask = loss_mask[:, 1:].bool() # Adjust mask for shifted positions

# Handle padding and alignment
if self.use_num_logits_to_keep:
labels = labels[:, -num_logits_to_keep:]
loss_mask = loss_mask[:, -num_logits_to_keep:]

if logits.shape[:2] != labels.shape[:2]:
seq_len = labels.shape[1]
logits = logits[:, -seq_len:]

# Mask out padding tokens
labels[~loss_mask] = 0
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps[~loss_mask] = 0
all_logps = per_token_logps.sum(-1)

# Prepare output dictionary
output = {}

# Add weighting if enabled
if self.use_weighting:
with torch.no_grad():
logprobs = F.log_softmax(logits, dim=-1)
weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1)
per_token_logps_adjusted = per_token_logps - weights_adjustment_factor

if self.padding_free:
# Use sequence boundaries to calculate weights
sequence_weights = [
per_token_logps_adjusted[start : end - 1].mean() for start, end in sequence_boundaries
]
all_weights = torch.stack(sequence_weights)
else:
# Use loss mask to handle padding in weight calculation
all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)

chosen_weights = all_weights[:num_examples]
rejected_weights = all_weights[num_examples:]
output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)

# Add RPO loss if enabled
if self.args.rpo_alpha is not None:
if self.padding_free:
# Use sequence boundaries to get chosen sequences
chosen_end_idx = sequence_boundaries[num_examples - 1][1] - 1
chosen_logits = logits[:chosen_end_idx]
chosen_labels = labels[:chosen_end_idx]
else:
# Simply split batch in half for chosen sequences
chosen_logits = logits[:num_examples]
chosen_labels = labels[:num_examples]

if self.loss_type == "ipo":
all_logps = all_logps / loss_mask.sum(-1)
output["nll_loss"] = F.cross_entropy(
chosen_logits.reshape(-1, chosen_logits.size(-1)), chosen_labels.reshape(-1), ignore_index=0
)

output["chosen_logps"] = all_logps[:num_examples]
output["rejected_logps"] = all_logps[num_examples:]
output["mean_chosen_logits"] = logits[:num_examples][loss_mask[:num_examples]].mean()
output["mean_rejected_logits"] = logits[num_examples:][loss_mask[num_examples:]].mean()
# Handle IPO loss type
if self.loss_type == "ipo":
if self.padding_free:
# Normalize by actual sequence lengths from boundaries
all_logps = all_logps / torch.tensor(
[end - start - 1 for start, end in sequence_boundaries], device=all_logps.device
)
else:
# Normalize by non-padded length from loss mask
all_logps = all_logps / loss_mask.sum(-1)

# Add log probabilities to output
output["chosen_logps"] = all_logps[:num_examples]
output["rejected_logps"] = all_logps[num_examples:]

# Calculate mean logits
if self.padding_free:
# Use sequence boundaries to separate chosen/rejected
chosen_end_idx = sequence_boundaries[num_examples - 1][1] - 1
output["mean_chosen_logits"] = logits[:chosen_end_idx].mean()
output["mean_rejected_logits"] = logits[chosen_end_idx:].mean()
else:
# Use loss mask to ignore padding tokens in mean calculation
output["mean_chosen_logits"] = logits[:num_examples][loss_mask[:num_examples]].mean()
output["mean_rejected_logits"] = logits[num_examples:][loss_mask[num_examples:]].mean()

if self.aux_loss_enabled:
output["aux_loss"] = outputs.aux_loss
# Add auxiliary loss if enabled
if self.aux_loss_enabled:
output["aux_loss"] = outputs.aux_loss

return output
return output

def get_batch_loss_metrics(
self,
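To make the padding-free path in concatenated_forward easier to follow, here is a self-contained sketch of the flattening it performs: pack two padded sequences into one continuous tensor, reset position IDs per sequence, record boundaries, then split per-token log-probs back per sequence. All values are made up for illustration; only the mechanics mirror the diff:

import torch

# Two padded sequences (0 = pad) with their attention masks.
input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
seq_lengths = attention_mask.sum(1)  # tensor([3, 2])

# Concatenate valid tokens into one sequence, resetting position IDs per sequence.
total_length = int(seq_lengths.sum())
flat_ids = torch.zeros(total_length, dtype=input_ids.dtype)
position_ids = torch.zeros(total_length, dtype=torch.long)
boundaries, idx = [], 0
for i in range(input_ids.size(0)):
    length = int(seq_lengths[i])
    flat_ids[idx : idx + length] = input_ids[i, :length]
    position_ids[idx : idx + length] = torch.arange(length)
    boundaries.append((idx, idx + length))
    idx += length

print(flat_ids)      # tensor([5, 6, 7, 8, 9])
print(position_ids)  # tensor([0, 1, 2, 0, 1])
print(boundaries)    # [(0, 3), (3, 5)]

# Stand-in for per-token log-probs of the shifted labels (length total_length - 1),
# where per_token_logps[j] scores the token at flat position j + 1.
per_token_logps = torch.randn(total_length - 1)

# Split back per sequence; slicing [start : end - 1] skips the cross-boundary
# position that would score the first token of the next sequence, matching the diff.
batch_logps = torch.stack(
    [per_token_logps[start : end - 1].sum() / (end - start - 1) for start, end in boundaries]
)
print(batch_logps)  # two per-sequence (length-normalized) log-probs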
1 change: 1 addition & 0 deletions trl/trainer/ppo_config.py
@@ -65,5 +65,6 @@ class PPOConfig(OnPolicyConfig):
cliprange: float = 0.2
vf_coef: float = 0.1
cliprange_value: float = 0.2
"""Clip range for the value function."""
gamma: float = 1.0
lam: float = 0.95
3 changes: 1 addition & 2 deletions trl/trainer/ppo_trainer.py
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -142,7 +142,6 @@ def __init__(
None # disable `pad_token_id` and `eos_token_id` because we just want to
)
self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding

# peft support
if not is_peft_available() and peft_config is not None:
raise ImportError(