diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f3d39ba445..d62cbd5372 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -24,6 +24,8 @@ title: AlignProp - local: bco_trainer title: BCO + - local: cgpo_trainer + title: CGPO - local: cpo_trainer title: CPO - local: ddpo_trainer diff --git a/docs/source/cgpo_trainer.md b/docs/source/cgpo_trainer.md new file mode 100644 index 0000000000..cf0d3c4249 --- /dev/null +++ b/docs/source/cgpo_trainer.md @@ -0,0 +1,108 @@ +# Constrained Generative Policy Optimization Trainer + +## Overview + +Constrained Generative Policy Optimization (CGPO) was proposed in [The Perfect Blend: Redefining RLHF with Mixture of Judges](https://huggingface.co/papers/2409.20370) by Tengyu Xu, Eryk Helenowski, Karthik Abinav Sankararaman, Di Jin, Kaiyan Peng, Eric Han, Shaoliang Nie, Chen Zhu, Hejia Zhang, Wenxuan Zhou, Zhouhao Zeng, Yun He, Karishma Mandyam, Arya Talabzadeh, Madian Khabsa, Gabriel Cohen, Yuandong Tian, Hao Ma, Sinong Wang and Han Fang. + +The abstract from the paper is the following: + +> Reinforcement learning from human feedback (RLHF) has become the leading approach for fine-tuning large language models (LLM). However, RLHF has limitations in multi-task learning (MTL) due to challenges of reward hacking and extreme multi-objective optimization (i.e., trade-off of multiple and/or sometimes conflicting objectives). Applying RLHF for MTL currently requires careful tuning of the weights for reward model and data combinations. This is often done via human intuition and does not generalize. In this work, we introduce a novel post-training paradigm which we called Constrained Generative Policy Optimization (CGPO). The core of CGPO is Mixture of Judges (MoJ) with cost-efficient constrained policy optimization with stratification, which can identify the perfect blend in RLHF in a principled manner. It shows strong empirical results with theoretical guarantees, does not require extensive hyper-parameter tuning, and is plug-and-play in common post-training pipelines. Together, this can detect and mitigate reward hacking behaviors while reaching a pareto-optimal point across an extremely large number of objectives. +Our results show that CGPO consistently outperforms other commonly used SoTA RLHF algorithms (such as PPO and DPO) on a wide range of tasks – general chat, STEM questions, instruction following, math, coding and knowledge. In particular, CGPO improves over PPO by 7.4% in AlpacaEval-2 (general chat), 12.5% in Arena-Hard (STEM & reasoning), 2% in IFEval (Instruction Following), 2% in both MATH and GSM8K (Math & reasoning), 5% in HumanEval (Coding), and 2% in the ARC challenge (Knowledge). We also observe that PPO is susceptible to severe reward hacking behaviors (it exhibits severe regression in popular coding benchmarks) which can be addressed by CGPO. CGPO represents a breakthrough in RLHF, simultaneously addressing reward-hacking and extreme multi-objective optimization, and thereby advancing the state-of-the-art in aligning general-purpose LLMs. + + +CGPO is designed to address the challenges of reward hacking and the complexities of multi-task learning in RLHF. It introduces three key innovations: +1. A 'Mixture of Judges' (MoJs) combining rule-based and LLM-based judges to collaboratively detect reward hacking and ensure adherence to task-specific constraints. +2. Task-specific optimization strategies (independent MoJs, optimizers and reward models). +3. Three new constrained RLHF optimizers: Calibrated-Regularized Policy Gradient (CRPG), Constrained Online Direct Preference Optimization (CODPO), and Calibrated-Regularized Reward Ranking Finetuning (CRRAFT).
+ +This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop). + +> [!WARNING] +> The `CGPOTrainer` currently only supports the single-task, single-objective setting. Multi-task, multi-objective CGPO will be added in a future release. + +## Usage tips + +The `CGPOTrainer` is a wrapper around the transformers [`Trainer`] class that takes in a reward model and a mixture of judges. It mainly requires three parameters to be set via the [`CGPOConfig`], namely: +* `rlhf_optimizer`: specifies the optimizer to use for policy optimization, with three possible options: `crpg`, `codpo` and `crraft`. +* `k`: defines the number of generations per prompt. +* `kl_threshold`: sets the maximum allowable KL divergence between the model and the reference model for each generated completion. + +Based on the paper's findings: for tasks requiring precise judges and extensive exploration, such as instruction following, math, and coding, use higher values for `k` and a more lenient KL threshold. Conversely, for tasks with less precise judges and where exploration is less critical, such as "general chat", use lower values of `k` and a stricter KL threshold. + +The basic API is as follows: + +```python +from datasets import Dataset +from trl import CGPOConfig, CGPOTrainer, MixtureOfConstraintJudges +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, +) + +NUM_DUMMY_SAMPLES = 100 + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1) +# CustomJudge1 and CustomJudge2 stand in for user-defined constraint judges +mixture_of_judges = MixtureOfConstraintJudges([CustomJudge1, CustomJudge2]) + +train_dataset = Dataset.from_dict( + { + "messages": [ + [ + {"role": "user", "content": "Hi, how are you?"}, + {"role": "assistant", "content": "I'm great thanks"}, + ] + ] + * NUM_DUMMY_SAMPLES + } +) +eval_dataset = Dataset.from_dict( + { + "messages": [ + [ + {"role": "user", "content": "What colour is the sky?"}, + {"role": "assistant", "content": "The sky is blue"}, + ] + ] + * NUM_DUMMY_SAMPLES + } +) + +training_args = CGPOConfig( + output_dir="cgpo-model", + per_device_train_batch_size=2, + k=4, + rlhf_optimizer="crpg", + kl_threshold=10.0, +) +trainer = CGPOTrainer( + model=model, + reward_model=reward_model, + mixture_of_judges=mixture_of_judges, + args=training_args, + processing_class=tokenizer, + train_dataset=train_dataset, + eval_dataset=eval_dataset, +) +trainer.train() +``` + +### ⚠️ Use the same chat template + +Make sure that the SFT model and reward model use the _same_ chat template. Otherwise, you may find the model completions are scored incorrectly during training.
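In the snippet above, `CustomJudge1` and `CustomJudge2` stand in for user-defined constraint judges; the `MixtureOfConstraintJudges` shipped with this PR is still a placeholder (it waits on the mixture-of-judges PR and only returns fixed or random judgements), so the final judge interface may change once that PR lands. As a rough sketch only: a rule-based constraint judge could expose a `judge(prompts, completions, gold_answers=None)` method returning `1` when a completion satisfies the constraint and `0` when it violates it, mirroring the convention of the placeholder judge in this PR. The class and method names below are illustrative, not part of the TRL API.

```python
from typing import List, Optional


class MaxCharsJudge:
    """Illustrative rule-based constraint judge that rejects overly long completions."""

    def __init__(self, max_chars: int = 512):
        self.max_chars = max_chars

    def judge(
        self,
        prompts: List[str],
        completions: Optional[List[str]] = None,
        gold_answers: Optional[List[str]] = None,
    ) -> List[int]:
        # One binary judgement per completion: 1 = constraint satisfied, 0 = violated.
        return [int(len(completion) <= self.max_chars) for completion in completions]
```

Such judges would then be wrapped in the mixture, e.g. `MixtureOfConstraintJudges([MaxCharsJudge(512), ...])`, following the usage shown above.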
+ +### Expected dataset format + +The dataset should be formatted as a list of "messages" where each message is a list of dictionaries with the following keys: +* `role`: either `system`, `assistant` or `user` +* `content`: the message content + + +## CGPOTrainer + +[[autodoc]] CGPOTrainer + +## CGPOConfig + +[[autodoc]] CGPOConfig diff --git a/tests/test_cgpo_trainer.py b/tests/test_cgpo_trainer.py new file mode 100644 index 0000000000..0178638d29 --- /dev/null +++ b/tests/test_cgpo_trainer.py @@ -0,0 +1,590 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import torch +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers.testing_utils import require_peft + +from trl import CGPOConfig, CGPOTrainer +from trl.trainer.cgpo_trainer import MixtureOfConstraintJudges +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +class CGPOTrainerTester(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.model_id, num_labels=1) + # to replace one the Mixture of containtPR is merged + self.moj = MixtureOfConstraintJudges() + + # Ensure the tokenizer has a chat template + if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None: + self.tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + @parameterized.expand(["crraft", "crpg", "codpo"]) + def test_cgpo_trainer(self, rlhf_optimizer): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer=rlhf_optimizer, + k=4, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = CGPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + mixture_of_judges=self.moj, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params 
have changed - ignore 0 biases + if param.sum() != 0: + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) + + @parameterized.expand(["crraft", "crpg", "codpo"]) + def test_cgpo_trainer_no_satisfied_constraints(self, rlhf_optimizer): + with tempfile.TemporaryDirectory() as tmp_dir: + moj = MixtureOfConstraintJudges(method="all_violated") + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer=rlhf_optimizer, + k=4, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = CGPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + mixture_of_judges=moj, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have not changed if no constraints are satisfied + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: + assert torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) + + @parameterized.expand(["crraft", "crpg", "codpo"]) + def test_cgpo_trainer_all_satisfied_constraints(self, rlhf_optimizer): + with tempfile.TemporaryDirectory() as tmp_dir: + moj = MixtureOfConstraintJudges(method="all_satisfied") + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer=rlhf_optimizer, + k=4, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = CGPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + mixture_of_judges=moj, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) + + def test_cgpo_trainer_no_moj(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer="crraft", + k=4, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + with self.assertRaisesRegex( + ValueError, + expected_regex="`mixture_of_judges` must be provided.", + ): + CGPOTrainer( + model=self.model, + ref_model=self.ref_model, + 
reward_model=self.reward_model, + mixture_of_judges=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + def test_cgpo_trainer_no_reward_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer="crraft", + k=4, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + with self.assertRaisesRegex( + ValueError, + expected_regex="`reward_model` must be provided.", + ): + CGPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=None, + mixture_of_judges=self.moj, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + def test_cgpo_trainer_wrong_rlhf_optimizer(self): + with tempfile.TemporaryDirectory() as tmp_dir: + wrong_rlhf_optimizer = "crraftss" + with self.assertRaisesRegex( + ValueError, + expected_regex=f"Invalid value for rlhf_optimizer: {wrong_rlhf_optimizer}. Must be one of 'crraft', 'codpo', or 'crpg'.", + ): + CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer=wrong_rlhf_optimizer, + k=4, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + def test_cgpo_trainer_no_kl_threshold(self): + with tempfile.TemporaryDirectory() as tmp_dir: + with self.assertRaisesRegex( + ValueError, + expected_regex="Training without setting the KL divergence threshold is not supported.", + ): + CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer="crraft", + k=4, + kl_threshold=None, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + @parameterized.expand(["crraft", "crpg", "codpo"]) + def test_cgpo_trainer_with_mini_batch(self, rlhf_optimizer): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer=rlhf_optimizer, + k=4, + local_genscore_mini_batch_size=8, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = CGPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + mixture_of_judges=self.moj, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + 
if param.sum() != 0: + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) + + @parameterized.expand(["crraft", "crpg", "codpo"]) + def test_cgpo_trainer_with_missing_eos_penalty(self, rlhf_optimizer): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer=rlhf_optimizer, + k=4, + missing_eos_penalty=1.0, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = CGPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + mixture_of_judges=self.moj, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) + + def test_cgpo_trainer_without_providing_ref_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer="crraft", + k=4, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = CGPOTrainer( + model=self.model, + reward_model=self.reward_model, + mixture_of_judges=self.moj, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) + + def test_cgpo_trainer_with_ref_model_is_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer="crraft", + k=4, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + with self.assertRaises(ValueError): + CGPOTrainer( + model=self.model, + ref_model=self.model, + reward_model=self.reward_model, + mixture_of_judges=self.moj, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], 
+ ) + + @require_peft + def test_cgpo_trainer_without_providing_ref_model_with_lora(self): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer="crraft", + k=4, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = CGPOTrainer( + model=self.model, + reward_model=self.reward_model, + mixture_of_judges=self.moj, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) + + def test_cgpo_trainer_padding_token_id_is_none(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer="crraft", + k=4, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + tokenizer.pad_token = None + + with self.assertRaisesRegex( + ValueError, + expected_regex="The tokenizer does not have a pad token. 
Please set `pad_token_id` in the tokenizer.", + ): + trainer = CGPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + mixture_of_judges=self.moj, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + @parameterized.expand(["crraft", "crpg", "codpo"]) + def test_cgpo_tags(self, optimizer_name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CGPOConfig( + output_dir=tmp_dir, + rlhf_optimizer=optimizer_name, + k=4, + kl_threshold=5.0, + temperature=0.9, + max_new_tokens=4, + per_device_train_batch_size=4, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = CGPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + mixture_of_judges=self.moj, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + for tag in ["cgpo", "trl", optimizer_name]: + self.assertIn(tag, trainer.model.model_tags) diff --git a/tests/test_utils.py b/tests/test_utils.py index 226861d96f..e18c974f66 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -250,3 +250,16 @@ def test_data_collator_for_chatml(self): # Verify that EOS token is at the end of labels self.assertEqual(labels[-1], self.eos_token_id, "The last token of labels should be EOS token.") + + def test_gold_answer(self): + data_no_gold = self.collator(self.examples) + + for example in self.examples: + example["gold_answer"] = "This is a gold answer" + + data_with_gold = self.collator(self.examples) + + # Verify that the batch do not contain the gold answers + self.assertNotIn("gold_answer", data_no_gold, "Batch should not contain gold answers.") + # Verify that the batch contain the gold answers + self.assertIn("gold_answer", data_with_gold, "Batch should contain gold answers.") diff --git a/trl/__init__.py b/trl/__init__.py index 405991e652..d0cc88cd2a 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -55,6 +55,8 @@ "BaseRankJudge", "BCOConfig", "BCOTrainer", + "CGPOConfig", + "CGPOTrainer", "CPOConfig", "CPOTrainer", "DataCollatorForCompletionOnlyLM", @@ -151,6 +153,8 @@ BaseRankJudge, BCOConfig, BCOTrainer, + CGPOConfig, + CGPOTrainer, CPOConfig, CPOTrainer, DataCollatorForCompletionOnlyLM, diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index f0eba412c6..d202fc3ba2 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -25,6 +25,8 @@ "bco_config": ["BCOConfig"], "bco_trainer": ["BCOTrainer"], "callbacks": ["LogCompletionsCallback", "RichProgressCallback", "SyncRefModelCallback", "WinRateCallback"], + "cgpo_config": ["CGPOConfig"], + "cgpo_trainer": ["CGPOTrainer"], "cpo_config": ["CPOConfig"], "cpo_trainer": ["CPOTrainer"], "ddpo_config": ["DDPOConfig"], @@ -89,6 +91,8 @@ from .bco_config import BCOConfig from .bco_trainer import BCOTrainer from .callbacks import LogCompletionsCallback, RichProgressCallback, SyncRefModelCallback, WinRateCallback + from .cgpo_config import CGPOConfig + from .cgpo_trainer import CGPOTrainer from .cpo_config import CPOConfig from .cpo_trainer import CPOTrainer from .ddpo_config import DDPOConfig diff --git a/trl/trainer/cgpo_config.py b/trl/trainer/cgpo_config.py new file mode 100644 index 
0000000000..017af430ce --- /dev/null +++ b/trl/trainer/cgpo_config.py @@ -0,0 +1,83 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Literal, Optional + +from transformers import TrainingArguments + + +@dataclass +class CGPOConfig(TrainingArguments): + r""" + Configuration class for the [`CGPOTrainer`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + k (`int`, *optional*, defaults to `4`): + The number of responses generated by the policy for each prompt at each iteration. + rlhf_optimizer (`Literal["crraft", "codpo", "crpg"]`, *optional*, defaults to `crraft`): + The RLHF optimizer to update the policy during CGPO training. Possible options: + - CRRAFT (Calibrated-Regularized Reward Ranking Fine-tuning) + - CODPO (Constrained Online Direct Preference Optimization) + - CRPG (Calibrated-Regularized Policy Gradient) + kl_threshold (`float`, *optional*, defaults to `None`): + Maximum allowable KL-divergence during policy updates, used to prevent significant deviation from the reference model + beta (`float`, *optional*, defaults to `0.1`): + Used when `rlhf_optimizer` is set to `codpo`. + It controls the deviation from the reference model. Higher beta means less deviation from the reference model. + lamb (`float`, *optional*, defaults to `5.0`): + Used when `rlhf_optimizer` is set to `codpo`. + It controls the strength of the regularization term added to the DPO loss. + local_genscore_mini_batch_size (`int`, *optional*, defaults to `None`): + The size of the local mini-batch used to generate with the policy, score with the reward model, and get the logits from the reference model. + max_new_tokens (`int`, *optional*, defaults to `64`): + Maximum number of tokens to generate per completion. + max_length (`int`, *optional*, defaults to `None`): + The maximum sequence length for the prompt and the baseline completions. + temperature (`float`, *optional*, defaults to `0.9`): + Temperature for sampling. The higher the temperature, the more random the completions. + missing_eos_penalty (`Optional[float]`, *optional*, defaults to `None`): + Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage + to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive + value. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. 
+ """ + + k: int = 4 + rlhf_optimizer: Literal["crraft", "codpo", "crpg"] = "crraft" + kl_threshold: float = None + beta: float = 0.1 + lamb: float = 5.0 + local_genscore_mini_batch_size: int = None + max_new_tokens: int = 64 + max_length: int = None + temperature: float = 0.9 + missing_eos_penalty: Optional[float] = None + disable_dropout: bool = True + + def __post_init__(self): + super().__post_init__() + + if self.rlhf_optimizer not in {"crraft", "codpo", "crpg"}: + raise ValueError( + f"Invalid value for rlhf_optimizer: {self.rlhf_optimizer}. Must be one of 'crraft', 'codpo', or 'crpg'." + ) + + if self.kl_threshold is None: + raise ValueError("Training without setting the KL divergence threshold is not supported.") diff --git a/trl/trainer/cgpo_trainer.py b/trl/trainer/cgpo_trainer.py new file mode 100644 index 0000000000..049933646a --- /dev/null +++ b/trl/trainer/cgpo_trainer.py @@ -0,0 +1,744 @@ +# CGPO Authors: Tengyu Xu, Eryk Helenowski, Karthik Abinav Sankararaman, Di Jin, Kaiyan Peng, Eric Han, Shaoliang Nie, Chen Zhu, Hejia Zhang, Wenxuan Zhou, Zhouhao Zeng, Yun He,Karishma Mandyam, Arya Talabzadeh, Madian Khabsa, Gabriel Cohen, Yuandong Tian, Hao Ma, Sinong Wang, Han Fang +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import random +import textwrap +import warnings +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import datasets +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data +from datasets import Dataset +from torch.utils.data import IterableDataset +from transformers import ( + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + GenerationConfig, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainerCallback, + is_apex_available, + is_wandb_available, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available, logging + +from ..models import create_reference_model +from ..models.utils import unwrap_model_for_generation +from .cgpo_config import CGPOConfig +from .utils import ( + DataCollatorForChatML, + batch_generation, + disable_dropout_in_model, + generate_model_card, + get_reward, + prepare_deepspeed, + truncate_right, +) + + +if is_peft_available(): + from peft import PeftModel, get_peft_model + +if is_apex_available(): + from apex import amp + +if is_wandb_available(): + import wandb + + +logger = logging.get_logger(__name__) + + +# should be removed when the mixture of judges PR is merged +class MixtureOfConstraintJudges: + "Placeholder waiting for https://github.com/huggingface/trl/pull/2159 to be merged" + + def __init__(self, method: Literal["all_violated", "all_satisfied", "all_random"] = "all_random"): + self.method = method + + def judge(self, prompts, completions=None, gold_answers=None, shuffle_order=None): + if self.method == "all_violated": + return [0 for _ in range(len(prompts))] + elif self.method == "all_satisfied": + return [1 for _ in range(len(prompts))] + else: + return [random.choice([0, 1]) for _ in range(len(prompts))] + + +class CGPOTrainer(Trainer): + r""" + Initialize the CGPOTrainer. + + Args: + model (`transformers.PreTrainedModel` or `torch.nn.Module`): + The model to train, preferably an `AutoModelForCausalLM`. + ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`): + The reference model to use for training. If None is specified, the reference model will be created from + the model. + reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`): + The reward model to score completions with, preferably an `AutoModelForSequenceClassification`. + mixture_of_judges (`MixtureOfConstraintJudges`): + The mixtures of judges to check if completions satisfy a set of contraints. + args (`CGPOConfig`): + The CGPO config arguments to use for training. + data_collator (`transformers.DataCollator`): + Ignored and replaced by an instance of `DataCollatorForChatML`. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class (`PreTrainedprocessing_classBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + peft_config (`Dict`): + The peft config to use for training. + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function to use to compute the metrics. 
Must take a `EvalPrediction` and return + a dictionary string to metric values. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + """ + + _tag_names = ["trl", "cgpo"] + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module], + ref_model: Union[PreTrainedModel, nn.Module, None] = None, + reward_model: Union[PreTrainedModel, nn.Module, None] = None, + mixture_of_judges: Optional[MixtureOfConstraintJudges] = None, + args: Optional[CGPOConfig] = None, + data_collator: Optional[DataCollator] = None, # type: ignore + train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + peft_config: Optional[Dict] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ) -> None: + if ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, either omit the `ref_model` argument or pass `None`." + ) + + self.ref_model = ref_model + + if reward_model is None: + raise ValueError("`reward_model` must be provided.") + else: + self.reward_model = reward_model + + if mixture_of_judges is None: + raise ValueError("`mixture_of_judges` must be provided.") + else: + self.moj = mixture_of_judges + + if args is None: + raise ValueError("`args` must be provided.") + + # Check that the processing_class is provided + if processing_class is None: + raise ValueError("`processing_class` must be provided.") + + # Convert to PEFT model if peft_config is provided + if peft_config is not None: + # Check if PEFT is available + if not is_peft_available(): + raise ImportError( + "PEFT is not available and passed `peft_config`. Please install PEFT with " + "`pip install peft` to use it." + ) + + # If the model is already a PeftModel, we need to merge and unload it. + # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + # Get peft model with the given config + model = get_peft_model(model, peft_config) + + # Disable dropout in the model if specified + if args.disable_dropout: + disable_dropout_in_model(model) + + # Handle the ref_model + # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to + # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create + # the ref model from the model by copying it and disable the gradients and set it in evaluation mode. 
+ if self.ref_model is None: # No ref model provided, the most common case + if peft_config is None: + self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode + else: + self.ref_model = None # we don't need a ref model here, we can just disable the adapter. + else: # rare case, the user provided a ref model + self.ref_model = ref_model + self.ref_model.eval() + + # Set the reward model in eval mode + if self.reward_model is not None: + self.reward_model.eval() + + if data_collator is not None: + warnings.warn( + "`CGPOTrainer` only supports training with a custom `DataCollatorForChatML. " + "The data collator will be replaced by a custom `DataCollatorForChatML`. " + "The specified data collator will not be used." + ) + + args.remove_unused_columns = False + data_collator = DataCollatorForChatML(processing_class, max_length=args.max_length) + + self.generation_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=0, + do_sample=True, + use_cache=False if args.gradient_checkpointing else True, + pad_token_id=processing_class.pad_token_id, + ) + + # Set custom EOS tokens if they are specified by the model's generation + # config. This is important for models with the Llama 3 chat template, + # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of + # turns or messages. + if hasattr(model.generation_config, "eos_token_id") and model.generation_config.eos_token_id is not None: + self.generation_config.eos_token_id = model.generation_config.eos_token_id + + self.k = args.k + self.rlhf_optimizer = args.rlhf_optimizer + self.beta = args.beta + self.kl_threshold = args.kl_threshold if args.kl_threshold is not None else torch.inf + self.lamb = args.lamb + self.local_genscore_mini_batch_size = ( + args.local_genscore_mini_batch_size + if args.local_genscore_mini_batch_size + else args.per_device_train_batch_size + ) + + if ( + self.local_genscore_mini_batch_size < args.per_device_train_batch_size + or self.local_genscore_mini_batch_size > (args.per_device_train_batch_size * self.k) + ): + raise ValueError( + "`local_genscore_mini_batch_size` should be higher than `per_device_train_batch_size` and smaller than `per_device_train_batch_size * k`." 
+ ) + + # to avoid divisions by 0 + self.epsilon = 1e-9 + self._tag_names.append(args.rlhf_optimizer) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator + if self.is_deepspeed_enabled: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + self.ref_model = prepare_deepspeed(self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16) + else: + self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True) + if self.ref_model is not None: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + self.stats = { + "constraints/judgements": [], + "constraints/rewards": [], + "objective/regularized_rewards": [], + } + + def _get_batch_logprobs( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor, context_length: int, use_ref_model: bool = False + ) -> torch.Tensor: + if not use_ref_model: + logits = self.model(input_ids, attention_mask=attention_mask).logits + + else: + if self.ref_model is not None: + logits = self.ref_model(input_ids, attention_mask=attention_mask).logits + else: + with self.model.disable_adapter(): + logits = self.model(input_ids, attention_mask=attention_mask).logits + + completion_logprobs = F.log_softmax(logits, dim=-1)[:, context_length - 1 : -1] + completion_ids = input_ids[:, context_length:] + completion_mask = attention_mask[:, context_length:] + + logprobs = torch.take_along_dim(completion_logprobs, completion_ids.unsqueeze(-1), dim=-1).squeeze(-1) + logprobs = torch.masked_fill(logprobs, ~completion_mask.bool(), 0.0) + return logprobs.sum(1) + + def crpg_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + context_length = inputs["context_length"] + prompt_completion_ids = inputs["prompt_completion_ids"] + prompt_completion_mask = inputs["prompt_completion_mask"] + judgements = inputs["judgements"] + rewards = inputs["rewards"] + completion_logprobs = inputs["completion_logprobs"] + bs = inputs["bs"] + mini_bs = self.local_genscore_mini_batch_size + full_bs = prompt_completion_ids.shape[0] + + with torch.no_grad(): + _, baseline_rewards, _ = get_reward( + self.reward_model, + inputs["input_ids"], + self.processing_class.pad_token_id, + context_length, + ) + baseline_rewards = baseline_rewards.repeat_interleave(repeats=self.k, dim=0) + + # compute the reference logprobs on generated samples + ref_logprobss = [] + for i in range(0, full_bs, mini_bs): + mini_batch_ids = prompt_completion_ids[i : i + mini_bs] + mini_batch_mask = prompt_completion_mask[i : i + mini_bs] + with torch.no_grad(): + mini_batch_ref_logprobs = self._get_batch_logprobs( + mini_batch_ids, mini_batch_mask, context_length, use_ref_model=True + ) + ref_logprobss.append(mini_batch_ref_logprobs) + + ref_logprobs = torch.cat(ref_logprobss) + + # kl_div regularizer + kl_div = completion_logprobs - ref_logprobs + kl_div_regularizer = torch.clamp(1 - kl_div / self.kl_threshold, 
min=0) + + # compute the constrained calibrated regularize rewards + regularized_rewards = judgements * torch.sigmoid(rewards - baseline_rewards) * kl_div_regularizer + mean_regularized_rewards = regularized_rewards.mean() + + total_loss = torch.tensor(0.0, device=self.model.device) + total_num_tokens = prompt_completion_mask[:, context_length:].sum() + for i in range(0, full_bs, bs): + # Compute loss on a batch of size `bs`, instead of the full batch (`bs` * self.k) to avoid OOM. + mini_batch_prompt_completion_ids = prompt_completion_ids[i : i + bs] + mini_batch_prompt_completion_mask = prompt_completion_mask[i : i + bs] + mini_batch_regularized_rewards = regularized_rewards[i : i + bs] + + # compute kl_divergence + logprobs = self._get_batch_logprobs( + mini_batch_prompt_completion_ids, mini_batch_prompt_completion_mask, context_length + ) + + losses = -logprobs * (mini_batch_regularized_rewards - mean_regularized_rewards) + loss = losses.sum() / total_num_tokens + loss = loss / self.k + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + self.accelerator.backward(loss) + + total_loss += loss + + self.stats["objective/regularized_rewards"].append(self.accelerator.gather(regularized_rewards).mean().item()) + + return total_loss + + def codpo_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + bs = inputs["bs"] + context_length = inputs["context_length"] + + # reshaping per prompt + judgements = inputs["judgements"].view(bs, self.k).bool() + rewards = inputs["rewards"].view(bs, self.k) + prompt_completion_ids = inputs["prompt_completion_ids"].view(bs, self.k, -1) + prompt_completion_mask = inputs["prompt_completion_mask"].view(bs, self.k, -1) + + positive_masked_rewards = judgements * rewards + negative_masked_rewards = torch.where(judgements, torch.inf, 1) * rewards + + # get chosen and rejected completions + chosen_idx = torch.argmax(positive_masked_rewards, dim=1) + # handle cases where all generations satisfy constraints + rejected_idx = torch.where( + judgements.sum(dim=1) == self.k, + torch.argmin(rewards, dim=1), + torch.argmin(negative_masked_rewards, dim=1), + ) + batch_indices = torch.arange(bs) + chosen_prompt_completion_ids = prompt_completion_ids[batch_indices, chosen_idx] + chosen_prompt_completion_mask = prompt_completion_mask[batch_indices, chosen_idx] + rejected_prompt_completion_ids = prompt_completion_ids[batch_indices, rejected_idx] + rejected_prompt_completion_mask = prompt_completion_mask[batch_indices, rejected_idx] + + # get the batch log probabilities from policy and ref + chosen_logprobs = self._get_batch_logprobs( + chosen_prompt_completion_ids, chosen_prompt_completion_mask, context_length + ) + rejected_logprobs = self._get_batch_logprobs( + rejected_prompt_completion_ids, rejected_prompt_completion_mask, context_length + ) + with torch.no_grad(): + chosen_ref_logprobs = self._get_batch_logprobs( + chosen_prompt_completion_ids, chosen_prompt_completion_mask, context_length, use_ref_model=True + ) + rejected_ref_logprobs = self._get_batch_logprobs( + rejected_prompt_completion_ids, rejected_prompt_completion_mask, context_length, use_ref_model=True + ) + + # computes the regularized dpo loss + pi_logratios = chosen_logprobs - rejected_logprobs + ref_logratios = chosen_ref_logprobs - rejected_ref_logprobs + logits = pi_logratios - ref_logratios + chosen_length = chosen_prompt_completion_mask[:, context_length:].sum(-1) + # eqn (14) in the paper + losses = 
-(F.logsigmoid(self.beta * logits) + self.lamb / chosen_length * chosen_logprobs) + + # Skip samples where no generations satisfy all constraints + positive_completion_mask = judgements.sum(dim=1) != 0 + loss = (losses * positive_completion_mask).sum() / (positive_completion_mask.sum() + self.epsilon) + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + self.accelerator.backward(loss) + + return loss + + def crraft_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + prompt_baseline_ids = inputs["input_ids"] + prompt_baseline_mask = inputs["attention_mask"] + rewards = inputs["rewards"] + judgements = inputs["judgements"] + prompt_completion_ids = inputs["prompt_completion_ids"] + prompt_completion_mask = inputs["prompt_completion_mask"] + completion_logprobs = inputs["completion_logprobs"] + + bs = inputs["bs"] + mini_bs = self.local_genscore_mini_batch_size + full_bs = prompt_completion_ids.shape[0] + context_length = inputs["context_length"] + + # get baseline rewards + with torch.no_grad(): + _, baseline_rewards, _ = get_reward( + self.reward_model, + prompt_baseline_ids, + self.processing_class.pad_token_id, + inputs["context_length"], + ) + + if self.args.missing_eos_penalty is not None: + baseline_ids = prompt_baseline_ids[:, context_length:] + contain_eos_token = torch.any(baseline_ids == self.generation_config.eos_token_id, dim=-1) + baseline_rewards[~contain_eos_token] -= self.args.missing_eos_penalty + + # compute kl div regularizer + ref_logprobss = [] + for i in range(0, full_bs, mini_bs): + mini_batch_ids = prompt_completion_ids[i : i + mini_bs] + mini_batch_mask = prompt_completion_mask[i : i + mini_bs] + with torch.no_grad(): + mini_batch_ref_logprobs = self._get_batch_logprobs( + mini_batch_ids, mini_batch_mask, context_length, use_ref_model=True + ) + ref_logprobss.append(mini_batch_ref_logprobs) + + ref_logprobs = torch.cat(ref_logprobss) + kl_div = completion_logprobs - ref_logprobs + kl_div_regularizer = kl_div <= self.kl_threshold + + baseline_rewards = baseline_rewards.repeat_interleave(repeats=self.k, dim=0) + # compute the contrainted regularized calibrated rewards + regularized_rewards = (kl_div_regularizer * judgements * torch.sigmoid(rewards - baseline_rewards)).view( + bs, self.k + ) + + # get the best completions and rewards per prompt + best_completion_indices = torch.argmax(regularized_rewards, dim=-1) + batch_indices = torch.arange(bs) + best_prompt_completion_ids = prompt_completion_ids.view(bs, self.k, -1)[batch_indices, best_completion_indices] + best_prompt_completion_mask = prompt_completion_mask.view(bs, self.k, -1)[ + batch_indices, best_completion_indices + ] + best_rewards = regularized_rewards[batch_indices, best_completion_indices] + + best_completion_logprobs = self._get_batch_logprobs( + best_prompt_completion_ids, best_prompt_completion_mask, context_length + ) + + no_positive_completion_mask = best_rewards == 0 + if no_positive_completion_mask.any(): + # only compute the baseline logprobs if we need to (ie some prompts do not have any positive samples) + # check that the baseline satisfy all constraints: judgements and kl + gold_answers = inputs.get("gold_answer", None) + baseline_judgements = torch.zeros_like(best_rewards, dtype=torch.bool) + prompts_text = [text for i, text in enumerate(inputs["prompts_text"]) if no_positive_completion_mask[i]] + if gold_answers is not None: + gold_answers = [text for i, text in enumerate(gold_answers) if 
no_positive_completion_mask[i]] + + baseline_completions_text = [ + text for i, text in enumerate(inputs["baseline_completions_text"]) if no_positive_completion_mask[i] + ] + no_positive_judgements = self.moj.judge(prompts_text, baseline_completions_text, gold_answers) + + baseline_judgements[no_positive_completion_mask] = torch.tensor( + no_positive_judgements, device=self.model.device, dtype=torch.bool + ) + + baseline_logprobs = torch.zeros_like(best_completion_logprobs) + baseline_logprobs[no_positive_completion_mask] = self._get_batch_logprobs( + prompt_baseline_ids[no_positive_completion_mask], + prompt_baseline_mask[no_positive_completion_mask], + context_length, + ) + with torch.no_grad(): + baseline_ref_logprobs = torch.zeros_like(best_completion_logprobs) + baseline_ref_logprobs[no_positive_completion_mask] = self._get_batch_logprobs( + prompt_baseline_ids[no_positive_completion_mask], + prompt_baseline_mask[no_positive_completion_mask], + context_length, + use_ref_model=True, + ) + + baseline_kl_div = baseline_logprobs - baseline_ref_logprobs + + baseline_kl_div_regularizer = baseline_kl_div <= self.kl_threshold + replace_by_baseline_mask = no_positive_completion_mask * baseline_kl_div_regularizer * baseline_judgements + + best_completion_logprobs = torch.where( + replace_by_baseline_mask, baseline_logprobs, best_completion_logprobs + ) + + baseline_mask = prompt_baseline_mask[:, context_length:].sum(1) + best_completion_mask = best_prompt_completion_mask[:, context_length:].sum(1) + best_completion_mask = torch.where( + replace_by_baseline_mask, + baseline_mask, + best_completion_mask, + ) + # calibrated reward for the baseline is always 0.5 ie sigmoid(0) + best_rewards = torch.where(replace_by_baseline_mask, 0.5, best_rewards) + + else: + best_completion_mask = best_prompt_completion_mask[:, context_length:].sum(1) + + best_rewards_mask = best_rewards != 0 + total_num_tokens = (best_completion_mask * best_rewards_mask).sum() + # compute loss as done in eqn (18) of the CGPO paper: https://huggingface.co/papers/2409.20370 + losses = -best_completion_logprobs * best_rewards + loss = losses.sum() / (total_num_tokens + self.epsilon) + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + self.accelerator.backward(loss) + + self.stats["objective/regularized_rewards"].append(self.accelerator.gather(regularized_rewards).mean().item()) + + return loss + + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + bs, context_length = inputs["prompts"].shape + inputs["prompts_text"] = self.processing_class.batch_decode(inputs["prompts"]) + inputs["baseline_completions_text"] = self.processing_class.batch_decode( + inputs["input_ids"][:, context_length:] + ) + gold_answers = inputs.get("gold_answer", None) + + # step 4 of algorithm 1 of the CGPO paper: https://huggingface.co/papers/2409.20370 + prompt_ids = inputs["prompts"].repeat_interleave(repeats=self.k, dim=0) + prompt_mask = inputs["prompt_attention_mask"].repeat_interleave(repeats=self.k, dim=0) + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + query_responses_ids, completion_logits = batch_generation( + unwrapped_model, + prompt_ids, + self.local_genscore_mini_batch_size, + self.processing_class.pad_token_id, + self.generation_config, + ) + + completion_ids = query_responses_ids[:, context_length:] + + query_responses = self.processing_class.batch_decode(completion_ids, 
skip_special_tokens=True) + # step 5 of algorithm 1 of the CGPO paper: https://huggingface.co/papers/2409.20370 + with torch.no_grad(): + prompt_repeated = [item for item in inputs["prompts_text"] for _ in range(self.k)] + if gold_answers is not None: + gold_answers = [item for item in gold_answers for _ in range(self.k)] + judgements = self.moj.judge(prompt_repeated, query_responses, gold_answers) + + completion_ids, completion_mask = truncate_right( + completion_ids, self.generation_config.eos_token_id, self.processing_class.pad_token_id + ) + completion_logprobs = F.log_softmax(completion_logits, dim=-1) + completion_logprobs = torch.take_along_dim(completion_logprobs, completion_ids.unsqueeze(-1), dim=-1).squeeze( + -1 + ) + completion_logprobs = torch.masked_fill(completion_logprobs, ~completion_mask.bool(), 0.0).sum(1) + + prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1) + prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1) + + rewards = [] + for i in range(0, prompt_completion_ids.shape[0], self.local_genscore_mini_batch_size): + mini_batch_prompt_completion_ids = prompt_completion_ids[i : i + self.local_genscore_mini_batch_size] + with torch.no_grad(): + _, mini_batch_rewards, _ = get_reward( + self.reward_model, + mini_batch_prompt_completion_ids, + self.processing_class.pad_token_id, + context_length, + ) + + rewards.append(mini_batch_rewards) + + rewards = torch.cat(rewards, dim=0) + # Completions that do not contain an eos token id are penalized. + if self.args.missing_eos_penalty is not None: + contain_eos_token = torch.any(completion_ids == self.generation_config.eos_token_id, dim=-1) + rewards[~contain_eos_token] -= self.args.missing_eos_penalty + + inputs["rewards"] = rewards + inputs["judgements"] = torch.tensor(judgements, device=self.model.device, dtype=torch.float) + inputs["bs"] = bs + inputs["context_length"] = context_length + inputs["prompt_completion_ids"] = prompt_completion_ids + inputs["prompt_completion_mask"] = prompt_completion_mask + inputs["completion_logprobs"] = completion_logprobs + + if self.rlhf_optimizer == "crraft": + loss = self.crraft_optimization(inputs) + elif self.rlhf_optimizer == "codpo": + loss = self.codpo_optimization(inputs) + elif self.rlhf_optimizer == "crpg": + loss = self.crpg_optimization(inputs) + else: + raise ValueError(f"{self.rlhf_optimizer} not supported.", "Choose between `codpo`, `crraft` and `crpg`.") + + self.stats["constraints/judgements"].append(self.accelerator.gather(inputs["judgements"]).mean().item()) + self.stats["constraints/rewards"].append(self.accelerator.gather(inputs["rewards"]).mean().item()) + + return loss.detach() / self.args.gradient_accumulation_steps + + # Same as Trainer.evaluate but log our metrics + def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval): + if self.control.should_log and self.state.global_step > self._globalstep_last_logged: + logs: Dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + if grad_norm is not None: + logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm + logs["learning_rate"] = self._get_learning_rate() + + # Add our metrics + for key, val in self.stats.items(): + if len(val) != 0: + # CODPO do not 
update kl div + logs[key] = sum(val) / len(val) + self.stats = {key: [] for key in self.stats} # reset stats + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + + self.log(logs) + + metrics = None + if self.control.should_evaluate: + metrics = self._evaluate(trial, ignore_keys_for_eval) + + if self.control.should_save: + self._save_checkpoint(model, trial, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent( + """\ + TO ADD + """ + ) + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="CGPO", + trainer_citation=citation, + paper_title="The Perfect Blend: Redefining RLHF with Mixture of Judges", + paper_id="2409.20370", + ) + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 81e826f971..a2bd14e3c1 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -246,6 +246,7 @@ class DataCollatorForChatML: max_length: int = None prompt_key: str = "prompt" messages_key: str = "messages" + gold_key: str = "gold_answer" def __post_init__(self): if self.tokenizer.pad_token_id is None: @@ -260,6 +261,7 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: prompts_input_ids = [] prompt_attention_mask = [] labels = [] + golds = [] if examples[0].get(self.gold_key, None) is not None else None for example in examples: formatted_prompt = example.get(self.prompt_key, None) @@ -306,6 +308,9 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: label[completion_start_idx:] = input_ids[-1][completion_start_idx:] labels.append(label) + if golds is not None: + golds.append(example[self.gold_key]) + # convert to list of tensors and pad input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids] attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask] @@ -319,7 +324,7 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id) prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0) - return { + batch = { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, @@ -327,6 +332,11 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: "prompt_attention_mask": prompt_attention_mask, } + if golds is not 
None: + batch["gold_answer"] = golds + + return batch + @dataclass class RewardDataCollatorWithPadding:
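A note on the `DataCollatorForChatML` change above: the new `gold_key`/`gold_answer` field is passed through as plain strings (it is not tokenized), so the judges can compare generated completions against a reference answer. Below is a minimal usage sketch under that assumption; the tokenizer checkpoint and example values are made up for illustration.

```python
from transformers import AutoTokenizer
from trl.trainer.utils import DataCollatorForChatML

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
collator = DataCollatorForChatML(tokenizer, max_length=512)

examples = [
    {
        "messages": [
            {"role": "user", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "2 + 2 = 4"},
        ],
        "gold_answer": "4",
    }
]

batch = collator(examples)
# Tokenized fields are padded tensors; the gold answers stay as a plain list of strings.
print(batch["gold_answer"])  # ["4"]
```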