
Use any reward model for online methods #2276

Merged · 18 commits · Oct 28, 2024
Changes from 15 commits
9 changes: 3 additions & 6 deletions docs/source/online_dpo_trainer.md
@@ -79,20 +79,17 @@ Instead of a judge, you can choose to use a reward model -- see [Reward Bench](ht

- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
+ reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")

trainer = OnlineDPOTrainer(
...
- judge=judge,
+ reward_model=reward_model,
+ reward_processing_class=reward_tokenizer,
Member:
Is the reason to use a processing class in case we want to support other modalities beyond text?

Member Author:
Possibly. And since the tokenizer is now called `processing_class` within trainers, I'd recommend always aligning with it (even if only the textual modality is supported). Unless you have a good reason not to.

...
)
```

<Tip warning={true}>

Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.

</Tip>

### Encourage EOS token generation

When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
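
A minimal sketch of such a configuration (the output directory and numeric values below are illustrative assumptions, not the collapsed snippet from the original docs):

```python
from trl import OnlineDPOConfig

training_args = OnlineDPOConfig(
    output_dir="Qwen2-0.5B-OnlineDPO",  # illustrative output path
    max_new_tokens=64,                  # cap on generated completion length
    missing_eos_penalty=1.0,            # subtracted from the score when no EOS token is generated
)
```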
6 changes: 6 additions & 0 deletions examples/scripts/dpo_online.py
@@ -93,8 +93,13 @@
trust_remote_code=model_config.trust_remote_code,
**model_kwargs,
)
reward_tokenizer = AutoTokenizer.from_pretrained(
training_args.reward_model_path,
trust_remote_code=model_config.trust_remote_code,
Member:
Should we also set `truncation=True` along with `truncation_side="left"` to ensure the labels aren't lost on long inputs? We might also need to allow people to set `max_length` since the context window of the tokenizer might be different from the policy - would that be best stored in the `ScriptArguments`?

Member Author:
Ok for truncation and truncation side. Not sure what's the best way to let the user set `max_length`. Ok for doing this in a follow-up PR?

Member:
Sounds good to follow up in a separate PR - it should generally be safe for RMs that define the max length implicitly in their config anyway.

)
else:
reward_model = None
reward_tokenizer = None

if training_args.judge is not None:
judge_cls = JUDGES[training_args.judge]
@@ -123,6 +128,7 @@
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=tokenizer,
reward_processing_class=reward_tokenizer,
peft_config=get_peft_config(model_config),
)
generation_config = GenerationConfig(
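A hedged sketch of the truncation settings discussed in the review thread above (the checkpoint name is reused from the docs for illustration, and explicit `max_length` handling is left to the follow-up PR mentioned by the author):

```python
from transformers import AutoTokenizer

# Left truncation drops the oldest prompt tokens, so the completion being scored
# is never cut off when an input exceeds the reward tokenizer's context window.
reward_tokenizer = AutoTokenizer.from_pretrained(
    "trl-lib/Qwen2-0.5B-Reward",
    truncation_side="left",
)

# Truncation itself (and an explicit max_length, if the reward model's window
# differs from the policy's) would then be applied at tokenization time:
inputs = reward_tokenizer(
    ["a very long prompt followed by a completion ..."],
    truncation=True,
    max_length=reward_tokenizer.model_max_length,
    return_tensors="pt",
)
```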
6 changes: 6 additions & 0 deletions tests/test_judges.py
@@ -18,6 +18,12 @@


class TestJudges(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Initialize once to download the model. This ensures it’s downloaded before running tests, preventing issues
# where concurrent tests attempt to load the model while it’s still downloading.
PairRMJudge()

def _get_prompts_and_completions(self):
prompts = ["The capital of France is", "The biggest planet in the solar system is"]
completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
34 changes: 23 additions & 11 deletions tests/test_online_dpo_trainer.py
@@ -20,7 +20,8 @@
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge, is_llmblender_available
from trl import OnlineDPOConfig, OnlineDPOTrainer, RandomPairwiseJudge, is_llmblender_available
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


if is_peft_available():
@@ -33,6 +34,9 @@ def setUp(self):
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
self.reward_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
self.reward_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
self.reward_tokenizer.pad_token = self.reward_tokenizer.eos_token
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.tokenizer.pad_token = self.tokenizer.eos_token

@@ -53,9 +57,10 @@ def test_training(self, config_name):
model=self.model,
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
)
trainer.train()

@@ -79,9 +84,10 @@ def test_training_with_ref_model(self):
ref_model=self.ref_model,
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
)
trainer.train()

@@ -103,9 +109,11 @@ def test_ref_model_is_model(self):
OnlineDPOTrainer(
model=self.model,
ref_model=self.model, # ref_model can't be the same as model
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
)

@require_peft
@@ -126,9 +134,10 @@ def test_training_with_peft(self):
model=self.model,
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
peft_config=lora_config,
)

@@ -156,9 +165,10 @@ def test_training_with_peft_and_ref_model(self):
ref_model=self.ref_model,
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
peft_config=lora_config,
)

@@ -188,9 +198,10 @@ def test_training_with_peft_model_and_peft_config(self):
model=model,
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
peft_config=lora_train_config,
)

@@ -200,7 +211,8 @@ def test_training_with_peft_model_and_peft_config(self):
self.assertIn("train_loss", trainer.state.log_history[-1])

@unittest.skipIf(not is_llmblender_available(), "llm-blender is not available")
def test_training_with_judge(self):
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
def test_training_with_judge(self, config_name):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = OnlineDPOConfig(
output_dir=tmp_dir,
@@ -210,15 +222,15 @@
eval_strategy="steps",
report_to="none",
)
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)

trainer = OnlineDPOTrainer(
model=self.model,
judge=PairRMJudge(),
judge=RandomPairwiseJudge(),
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
)
trainer.train()

5 changes: 3 additions & 2 deletions tests/test_trainers_args.py
@@ -272,12 +272,13 @@ def test_online_dpo(self, beta_list):
reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
trainer = OnlineDPOTrainer(
args=training_args,
processing_class=tokenizer,
model=model,
ref_model=ref_model,
reward_model=reward_model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
reward_processing_class=tokenizer,
)
self.assertEqual(trainer.args.max_new_tokens, 42)
self.assertEqual(trainer.args.temperature, 0.5)
4 changes: 2 additions & 2 deletions tests/test_xpo_trainer.py
@@ -20,7 +20,7 @@
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import PairRMJudge, XPOConfig, XPOTrainer, is_llmblender_available
from trl import RandomPairwiseJudge, XPOConfig, XPOTrainer, is_llmblender_available


if is_peft_available():
@@ -171,7 +171,7 @@ def test_xpo_trainer_judge_training(self, config_name):
report_to="none",
)
dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
judge = PairRMJudge()
judge = RandomPairwiseJudge()

trainer = XPOTrainer(
model=self.model,
1 change: 1 addition & 0 deletions trl/trainer/nash_md_trainer.py
@@ -122,6 +122,7 @@ def __init__(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model
peft_config=peft_config,
compute_metrics=compute_metrics,
callbacks=callbacks,
66 changes: 48 additions & 18 deletions trl/trainer/online_dpo_trainer.py
@@ -44,7 +44,7 @@
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging

from ..data_utils import is_conversational, maybe_apply_chat_template
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..models import create_reference_model
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge
@@ -137,6 +137,7 @@ def __init__(
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
peft_config: Optional[Dict] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
@@ -161,6 +162,7 @@ def __init__(
raise ValueError("Either `reward_model` or `judge` must be provided.")

self.reward_model = reward_model
self.reward_processing_class = reward_processing_class
self.judge = judge

if args.missing_eos_penalty is not None and judge is not None:
@@ -428,18 +430,23 @@ def training_step(
ref_logprobs = torch.take_along_dim(ref_all_logprobs, completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
del ref_output, ref_logits, ref_all_logprobs # free memory

# Get the reward from the reward model or judge:
if self.judge is not None:
completions = self.processing_class.batch_decode(
prompt_completion_ids[:, context_length:], skip_special_tokens=True
)
completions = [completion.strip() for completion in completions] # remove the leading space
Member Author:
I think we don't need `strip`.

# Decode the completions, and format them if the input is conversational
device = prompt_completion_ids.device
completions_ids = prompt_completion_ids[:, context_length:]
completions = self.processing_class.batch_decode(completions_ids, skip_special_tokens=True)
if is_conversational({"prompt": prompts[0]}):
completions = [[{"role": "assistant", "content": completion}] for completion in completions]

# Get the reward from the reward model or judge
if self.judge is not None:
# Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
# directly understandable by the judge and could alter its judgment. To avoid this and make the judge
# independent of the model's chat template, we use the raw conversation data, and apply our own chat
# template to it.
if is_conversational({"prompt": prompts[0]}):
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
environment = jinja2.Environment()
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
prompts = [template.render(messages=message) for message in prompts]
prompts = [template.render(messages=prompt) for prompt in prompts]
completions = [template.render(messages=completion) for completion in completions]

ranks_of_first_completion = self.judge.judge(
@@ -449,24 +456,47 @@
# convert ranks to a True/False mask:
# when rank == 0, it means the first completion is the best
# when rank == 1, it means the second completion is the best
mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=prompt_completion_ids.device)
mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
else:
_, scores, _ = get_reward(
self.reward_model, prompt_completion_ids, self.processing_class.pad_token_id, context_length
)
# The reward model may not have the same chat template or tokenizer as the model, so we need to use the
# raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
Member:
Why do we do this? Is it to align a prompt with chosen/rejected?

Member Author:
At this point, we have:

prompts = ["What color is the sky?", "What's the capital of France?"]
completions = ["Blue", "Lyon", "Green", "Paris"]

and later, we need to concat the prompts and the completions to compute the reward

prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
_, scores, _ = get_reward(
    self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
)

so we need to repeat the prompt.

if is_conversational({"prompt": prompts[0]}):
examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
prompts = [example["prompt"] for example in examples]
completions = [example["completion"] for example in examples]

# Tokenize the prompts
prompts_ids = self.reward_processing_class(
prompts, padding=True, return_tensors="pt", padding_side="left"
)["input_ids"].to(device)
context_length = prompts_ids.shape[1]

# Tokenize the completions
completions_ids = self.reward_processing_class(
completions, padding=True, return_tensors="pt", padding_side="right"
)["input_ids"].to(device)

# Concatenate the prompts and completions and get the reward
prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
with torch.inference_mode():
_, scores, _ = get_reward(
self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
)

# Filter completion. Ensure that the sample contains stop_token_id
# Completions not passing that filter will receive a lower score.
if self.args.missing_eos_penalty is not None:
scores[~contain_eos_token] -= self.args.missing_eos_penalty
# Filter completion. Ensure that the sample contains stop_token_id
# Completions not passing that filter will receive a lower score.
if self.args.missing_eos_penalty is not None:
scores[~contain_eos_token] -= self.args.missing_eos_penalty

# Split the scores in 2 (the prompts of the first half are the same as the second half)
first_half, second_half = scores.split(num_examples)

# Get the indices of the chosen and rejected examples
mask = first_half >= second_half

num_examples_range = torch.arange(num_examples, device=prompt_completion_ids.device)
num_examples_range = torch.arange(num_examples, device=device)
chosen_indices = num_examples_range + (~mask * num_examples)
rejected_indices = num_examples_range + (mask * num_examples)

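A toy, self-contained illustration (assumed prompts, completions, and scores; not the trainer code) of the doubling-and-splitting logic from the review thread above: the prompt list is repeated so each of the two generations per prompt lines up with its prompt, and the reward scores are split back into two halves to pick chosen and rejected completions:

```python
import torch

prompts = ["What color is the sky?", "What's the capital of France?"]
completions = ["Blue", "Lyon", "Green", "Paris"]  # two completions per prompt, flattened
num_examples = len(prompts)

prompts = 2 * prompts  # [p0, p1] -> [p0, p1, p0, p1], aligned with `completions`

# Pretend reward-model scores, one per (prompt, completion) pair.
scores = torch.tensor([0.9, 0.1, 0.2, 0.8])

# The two halves share the same prompts in the same order, so comparing them
# element-wise decides which generation wins for each prompt.
first_half, second_half = scores.split(num_examples)
mask = first_half >= second_half  # True -> the first completion is chosen

num_examples_range = torch.arange(num_examples)
chosen_indices = num_examples_range + (~mask * num_examples)   # tensor([0, 3]) -> "Blue", "Paris"
rejected_indices = num_examples_range + (mask * num_examples)  # tensor([2, 1]) -> "Green", "Lyon"
```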
1 change: 1 addition & 0 deletions trl/trainer/xpo_trainer.py
@@ -121,6 +121,7 @@ def __init__(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
reward_processing_class=processing_class, # for now, XPOTrainer can't use any reward model
peft_config=peft_config,
compute_metrics=compute_metrics,
callbacks=callbacks,