diff --git a/docs/source/judges.mdx b/docs/source/judges.mdx index 329cc6c917..818c5f1f0a 100644 --- a/docs/source/judges.mdx +++ b/docs/source/judges.mdx @@ -52,10 +52,18 @@ judge.judge( ) # Outputs: [0, 1] ``` +## AllTrueJudge + +[[autodoc]] AllTrueJudge + ## BaseJudge [[autodoc]] BaseJudge +## BaseBinaryJudge + +[[autodoc]] BaseBinaryJudge + ## BaseRankJudge [[autodoc]] BaseRankJudge @@ -64,6 +72,10 @@ judge.judge( [[autodoc]] BasePairwiseJudge +## RandomBinaryJudge + +[[autodoc]] RandomBinaryJudge + ## RandomRankJudge [[autodoc]] RandomRankJudge diff --git a/tests/test_judges.py b/tests/test_judges.py index 748cc85666..ce13d7bcca 100644 --- a/tests/test_judges.py +++ b/tests/test_judges.py @@ -15,27 +15,53 @@ import time import unittest -from trl import HfPairwiseJudge, PairRMJudge, RandomPairwiseJudge, RandomRankJudge +from trl import ( + AllTrueJudge, + HfPairwiseJudge, + PairRMJudge, + RandomBinaryJudge, + RandomPairwiseJudge, + RandomRankJudge, +) from .testing_utils import require_llm_blender class TestJudges(unittest.TestCase): - def _get_prompts_and_completions(self): + def _get_prompts_and_pairwise_completions(self): prompts = ["The capital of France is", "The biggest planet in the solar system is"] completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]] return prompts, completions + def _get_prompts_and_single_completions(self): + prompts = ["What's the capital of France?", "What's the color of the sky?"] + completions = ["Marseille", "blue"] + return prompts, completions + + def test_all_true_judge(self): + judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()]) + prompts, completions = self._get_prompts_and_single_completions() + judgements = judge.judge(prompts=prompts, completions=completions) + self.assertEqual(len(judgements), 2) + self.assertTrue(all(judgement in {0, 1, -1} for judgement in judgements)) + + def test_random_binary_judge(self): + judge = RandomBinaryJudge() + prompts, completions = 
self._get_prompts_and_single_completions() + judgements = judge.judge(prompts=prompts, completions=completions) + self.assertEqual(len(judgements), 2) + self.assertTrue(all(judgement in {0, 1, -1} for judgement in judgements)) + def test_random_pairwise_judge(self): judge = RandomPairwiseJudge() - prompts, completions = self._get_prompts_and_completions() + prompts, completions = self._get_prompts_and_pairwise_completions() ranks = judge.judge(prompts=prompts, completions=completions) self.assertEqual(len(ranks), 2) self.assertTrue(all(isinstance(rank, int) for rank in ranks)) def test_random_rank_judge(self): judge = RandomRankJudge() - prompts, completions = self._get_prompts_and_completions() + prompts, completions = self._get_prompts_and_pairwise_completions() ranks = judge.judge(prompts=prompts, completions=completions) self.assertEqual(len(ranks), 2) self.assertTrue(all(isinstance(rank, list) for rank in ranks)) @@ -44,7 +70,7 @@ def test_random_rank_judge(self): @unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.") def test_hugging_face_judge(self): judge = HfPairwiseJudge() - prompts, completions = self._get_prompts_and_completions() + prompts, completions = self._get_prompts_and_pairwise_completions() ranks = judge.judge(prompts=prompts, completions=completions) self.assertEqual(len(ranks), 2) self.assertTrue(all(isinstance(rank, int) for rank in ranks)) @@ -62,7 +88,7 @@ def load_pair_rm_judge(self): @require_llm_blender def test_pair_rm_judge(self): judge = self.load_pair_rm_judge() - prompts, completions = self._get_prompts_and_completions() + prompts, completions = self._get_prompts_and_pairwise_completions() ranks = judge.judge(prompts=prompts, completions=completions) self.assertEqual(len(ranks), 2) self.assertTrue(all(isinstance(rank, int) for rank in ranks)) @@ -71,7 +97,7 @@ def test_pair_rm_judge(self): @require_llm_blender def test_pair_rm_judge_return_scores(self): judge = self.load_pair_rm_judge() 
- prompts, completions = self._get_prompts_and_completions() + prompts, completions = self._get_prompts_and_pairwise_completions() probs = judge.judge(prompts=prompts, completions=completions, return_scores=True) self.assertEqual(len(probs), 2) self.assertTrue(all(isinstance(prob, float) for prob in probs)) diff --git a/trl/__init__.py b/trl/__init__.py index 1c12c2ade9..7467667066 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -49,7 +49,9 @@ "trainer": [ "AlignPropConfig", "AlignPropTrainer", + "AllTrueJudge", "BaseJudge", + "BaseBinaryJudge", "BasePairwiseJudge", "BaseRankJudge", "BCOConfig", @@ -79,6 +81,7 @@ "PairRMJudge", "PPOConfig", "PPOTrainer", + "RandomBinaryJudge", "RandomPairwiseJudge", "RandomRankJudge", "RewardConfig", @@ -138,6 +141,8 @@ from .trainer import ( AlignPropConfig, AlignPropTrainer, + AllTrueJudge, + BaseBinaryJudge, BaseJudge, BasePairwiseJudge, BaseRankJudge, @@ -168,6 +173,7 @@ PairRMJudge, PPOConfig, PPOTrainer, + RandomBinaryJudge, RandomPairwiseJudge, RandomRankJudge, RewardConfig, diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index f0eba412c6..b6878ae11e 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -34,12 +34,15 @@ "gkd_trainer": ["GKDTrainer"], "iterative_sft_trainer": ["IterativeSFTTrainer"], "judges": [ + "AllTrueJudge", "BaseJudge", + "BaseBinaryJudge", "BasePairwiseJudge", "BaseRankJudge", "HfPairwiseJudge", "OpenAIPairwiseJudge", "PairRMJudge", + "RandomBinaryJudge", "RandomPairwiseJudge", "RandomRankJudge", ], @@ -98,12 +101,15 @@ from .gkd_trainer import GKDTrainer from .iterative_sft_trainer import IterativeSFTTrainer from .judges import ( + AllTrueJudge, + BaseBinaryJudge, BaseJudge, BasePairwiseJudge, BaseRankJudge, HfPairwiseJudge, OpenAIPairwiseJudge, PairRMJudge, + RandomBinaryJudge, RandomPairwiseJudge, RandomRankJudge, ) diff --git a/trl/trainer/judges.py b/trl/trainer/judges.py index af56ec3d9b..f79588b4a0 100644 --- a/trl/trainer/judges.py +++ 
b/trl/trainer/judges.py @@ -144,6 +144,54 @@ def judge(self, prompts: List[str], completions: List[List[str]], shuffle_order: raise NotImplementedError("Judge subclasses must implement the `judge` method.") +class BaseBinaryJudge(BaseJudge): + """ + Base class for binary judges. + """ + + @abstractmethod + def judge( + self, + prompts: List[str], + completions: List[str], + gold_completions: Optional[List[str]] = None, + shuffle_order: bool = True, + ) -> List[int]: + """ + Judge the completion for a given prompt. Used to assess if a completion satisfies a constraint. + + This base class should be used to implement binary evaluations as done in section 4.1.4 of the + [CGPO paper](https://huggingface.co/papers/2409.20370). + It is relevant for assessing whether or not a prompt completion pair satisfies a specific constraint. + + Args: + prompts (`List[str]`): List of prompts. + completions (`List[str]`): List of completions. + gold_completions (`List[str]`, `optional`): List of gold completions if they exist. + shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias. + + Returns: + List[int]: A list of binary labels: + - 1 indicates that the completion satisfies the evaluated constraint. + - 0 indicates that the completion does not satisfy the evaluated constraint. + + Note: + If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference has failed. + For instance, this could occur if the underlying language model or rule-based constraint returned an invalid answer. + In such cases, the caller should handle these invalid indices appropriately, possibly by implementing fallback logic or error handling. + """ + raise NotImplementedError("Judge subclasses must implement the `judge` method.") + + +class RandomBinaryJudge(BaseBinaryJudge): + """ + Random binary judge, for testing purposes. 
+ """ + + def judge(self, prompts, completions, gold_completions=None, shuffle_order=True): + return [random.choice([0, 1, -1]) for _ in range(len(prompts))] + + class RandomRankJudge(BaseRankJudge): """ Random rank, for testing purposes. @@ -392,3 +440,47 @@ def get_rank(prompt, candidates): # Return the ranks return ranks + + +class AllTrueJudge(BaseBinaryJudge): + """ + Unify the decision of multiple [`BaseBinaryJudge`] instances. + + Returns `1` only if all inner binary judges return `1`. If any judge returns `-1`, indicating a failure in its + process, this judge also returns `-1`, even when another judge returns `0`. Otherwise, it returns `0`. + + Implements the Mixture of Judges as described in the [CGPO paper](https://huggingface.co/papers/2409.20370). + + Args: + judges (`List[BaseBinaryJudge]`): A list of [`BaseBinaryJudge`] instances whose decisions will be unified. + """ + + def __init__(self, judges: List[BaseBinaryJudge]): + self.judges = judges + + def judge( + self, + prompts: List[str], + completions: List[str], + gold_completions: Optional[List[str]] = None, + shuffle_order: bool = True, + ) -> List[int]: + all_binary_judgments = [ + judge.judge(prompts, completions, gold_completions, shuffle_order) for judge in self.judges + ] + output = [] + for binary_judgments in zip(*all_binary_judgments): + # Check that all values are in {0, 1, -1} + if any(binary_judgment not in {0, 1, -1} for binary_judgment in binary_judgments): + raise ValueError( + f"Invalid binary judgment: {binary_judgments}, expected list of values in {{0, 1, -1}}." + ) + + # Unify the decision + if -1 in binary_judgments: + output.append(-1) + elif all(binary_judgment == 1 for binary_judgment in binary_judgments): + output.append(1) + else: + output.append(0) + return output