Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🤝 Mixture of judges #2159

Merged
merged 46 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
013aae4
base judge
gaetanlop Oct 3, 2024
0ea5a48
adding mixture of judges
gaetanlop Oct 3, 2024
517cfb0
update doc
gaetanlop Oct 3, 2024
9e5ed12
update doc
gaetanlop Oct 3, 2024
3406e53
formatting
gaetanlop Oct 3, 2024
568d2b9
fix small typo in doc
gaetanlop Oct 3, 2024
466292e
fix randomcontraintjudge
gaetanlop Oct 3, 2024
a3d90df
Merge branch 'main' into cgpo_mixture_of_judges
qgallouedec Oct 4, 2024
3f0b8b0
replace arxiv by hf papers
gaetanlop Oct 4, 2024
8995ab4
formatting
gaetanlop Oct 4, 2024
896259e
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 4, 2024
ef1feb0
fix naming in __init__
gaetanlop Oct 4, 2024
3da4a06
run precommi
gaetanlop Oct 4, 2024
765768b
adding gold answers to judges
gaetanlop Oct 7, 2024
8aaaaa1
cgpo llm judges
gaetanlop Oct 7, 2024
cfc84ed
fix init
gaetanlop Oct 7, 2024
a1e8eeb
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 7, 2024
6898285
output type
gaetanlop Oct 7, 2024
f5639a1
adjust booleans in test
gaetanlop Oct 7, 2024
289b855
adapt moj doc
gaetanlop Oct 7, 2024
308e743
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 9, 2024
2c6de87
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 11, 2024
dedc859
renaming and removing factuality and safety judges
gaetanlop Oct 11, 2024
ba0fffb
fix typo in import
gaetanlop Oct 11, 2024
226de82
fix small typo in naming
gaetanlop Oct 11, 2024
5626cd4
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 14, 2024
567b798
formatting
gaetanlop Oct 14, 2024
1c33494
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 21, 2024
64c9de8
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 24, 2024
559cd1b
Update trl/trainer/judges.py
gaetanlop Oct 24, 2024
2c29ef5
update parameter name
gaetanlop Oct 25, 2024
bd1bed8
update tests
gaetanlop Oct 25, 2024
21e3ccd
update doc
gaetanlop Oct 25, 2024
43d6cca
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 29, 2024
9eca0f8
Update trl/trainer/judges.py
gaetanlop Oct 29, 2024
d5b32f0
Update doc
gaetanlop Oct 29, 2024
ac88c63
fix alltruejudge type
gaetanlop Oct 29, 2024
999154b
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Nov 4, 2024
6d7c8bc
Merge branch 'main' into cgpo_mixture_of_judges
qgallouedec Nov 15, 2024
8b36bc7
Refactor judge variable names and update test names
qgallouedec Nov 15, 2024
30f3ca8
Clarify judgment logic
qgallouedec Nov 15, 2024
60af3c4
Merge branch 'cgpo_mixture_of_judges' of https://github.com/gaetanlop…
qgallouedec Nov 15, 2024
fcfcd14
Fix invalid binary judgment check in AllTrueJudge class
qgallouedec Nov 15, 2024
7e7f9e7
Fix invalid binary judgment check in AllTrueJudge class
qgallouedec Nov 15, 2024
294588f
Merge branch 'main' into cgpo_mixture_of_judges
qgallouedec Nov 18, 2024
f049351
Merge branch 'main' into cgpo_mixture_of_judges
qgallouedec Nov 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/judges.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,18 @@ judge.judge(
) # Outputs: [0, 1]
```

## AllTrueJudge

[[autodoc]] AllTrueJudge

## BaseJudge

[[autodoc]] BaseJudge

## BaseBinaryJudge

[[autodoc]] BaseBinaryJudge

## BaseRankJudge

[[autodoc]] BaseRankJudge
Expand All @@ -64,6 +72,10 @@ judge.judge(

[[autodoc]] BasePairwiseJudge

## RandomBinaryJudge

[[autodoc]] RandomBinaryJudge

## RandomRankJudge

[[autodoc]] RandomRankJudge
Expand Down
40 changes: 33 additions & 7 deletions tests/test_judges.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,53 @@
import time
import unittest

from trl import HfPairwiseJudge, PairRMJudge, RandomPairwiseJudge, RandomRankJudge
from trl import (
AllTrueJudge,
HfPairwiseJudge,
PairRMJudge,
RandomBinaryJudge,
RandomPairwiseJudge,
RandomRankJudge,
)

from .testing_utils import require_llm_blender


class TestJudges(unittest.TestCase):
def _get_prompts_and_completions(self):
def _get_prompts_and_pairwise_completions(self):
prompts = ["The capital of France is", "The biggest planet in the solar system is"]
completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
return prompts, completions

def _get_prompts_and_single_completions(self):
prompts = ["What's the capital of France?", "What's the color of the sky?"]
completions = ["Marseille", "blue"]
return prompts, completions

def test_all_true_judge(self):
    """An AllTrueJudge over two random binary judges emits one verdict per prompt, each in {0, 1, -1}."""
    mixture = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
    prompts, completions = self._get_prompts_and_single_completions()
    verdicts = mixture.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(verdicts), 2)
    self.assertTrue(set(verdicts).issubset({0, 1, -1}))

def test_random_binary_judge(self):
    """RandomBinaryJudge returns one label per prompt, each drawn from {0, 1, -1}."""
    binary_judge = RandomBinaryJudge()
    prompts, completions = self._get_prompts_and_single_completions()
    verdicts = binary_judge.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(verdicts), 2)
    self.assertTrue(set(verdicts).issubset({0, 1, -1}))

def test_random_pairwise_judge(self):
    """RandomPairwiseJudge returns an integer rank for each prompt."""
    pairwise_judge = RandomPairwiseJudge()
    prompts, completions = self._get_prompts_and_pairwise_completions()
    ranks = pairwise_judge.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(ranks), 2)
    self.assertTrue(all(isinstance(rank, int) for rank in ranks))

def test_random_rank_judge(self):
    """RandomRankJudge returns a list (a ranking over candidates) for each prompt."""
    rank_judge = RandomRankJudge()
    prompts, completions = self._get_prompts_and_pairwise_completions()
    ranks = rank_judge.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(ranks), 2)
    self.assertTrue(all(isinstance(rank, list) for rank in ranks))
Expand All @@ -44,7 +70,7 @@ def test_random_rank_judge(self):
@unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.")
def test_hugging_face_judge(self):
    """HfPairwiseJudge returns an integer rank per prompt (manual test: hits the HF Inference API)."""
    hf_judge = HfPairwiseJudge()
    prompts, completions = self._get_prompts_and_pairwise_completions()
    ranks = hf_judge.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(ranks), 2)
    self.assertTrue(all(isinstance(rank, int) for rank in ranks))
Expand All @@ -62,7 +88,7 @@ def load_pair_rm_judge(self):
@require_llm_blender
def test_pair_rm_judge(self):
    """PairRMJudge returns an integer rank per prompt (requires llm-blender)."""
    pair_rm_judge = self.load_pair_rm_judge()
    prompts, completions = self._get_prompts_and_pairwise_completions()
    ranks = pair_rm_judge.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(ranks), 2)
    self.assertTrue(all(isinstance(rank, int) for rank in ranks))
Expand All @@ -71,7 +97,7 @@ def test_pair_rm_judge(self):
@require_llm_blender
def test_pair_rm_judge_return_scores(self):
judge = self.load_pair_rm_judge()
prompts, completions = self._get_prompts_and_completions()
prompts, completions = self._get_prompts_and_pairwise_completions()
probs = judge.judge(prompts=prompts, completions=completions, return_scores=True)
self.assertEqual(len(probs), 2)
self.assertTrue(all(isinstance(prob, float) for prob in probs))
Expand Down
6 changes: 6 additions & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
"trainer": [
"AlignPropConfig",
"AlignPropTrainer",
"AllTrueJudge",
"BaseJudge",
"BaseBinaryJudge",
"BasePairwiseJudge",
"BaseRankJudge",
"BCOConfig",
Expand Down Expand Up @@ -79,6 +81,7 @@
"PairRMJudge",
"PPOConfig",
"PPOTrainer",
"RandomBinaryJudge",
"RandomPairwiseJudge",
"RandomRankJudge",
"RewardConfig",
Expand Down Expand Up @@ -138,6 +141,8 @@
from .trainer import (
AlignPropConfig,
AlignPropTrainer,
AllTrueJudge,
BaseBinaryJudge,
BaseJudge,
BasePairwiseJudge,
BaseRankJudge,
Expand Down Expand Up @@ -168,6 +173,7 @@
PairRMJudge,
PPOConfig,
PPOTrainer,
RandomBinaryJudge,
RandomPairwiseJudge,
RandomRankJudge,
RewardConfig,
Expand Down
6 changes: 6 additions & 0 deletions trl/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@
"gkd_trainer": ["GKDTrainer"],
"iterative_sft_trainer": ["IterativeSFTTrainer"],
"judges": [
"AllTrueJudge",
"BaseJudge",
"BaseBinaryJudge",
"BasePairwiseJudge",
"BaseRankJudge",
"HfPairwiseJudge",
"OpenAIPairwiseJudge",
"PairRMJudge",
"RandomBinaryJudge",
"RandomPairwiseJudge",
"RandomRankJudge",
],
Expand Down Expand Up @@ -98,12 +101,15 @@
from .gkd_trainer import GKDTrainer
from .iterative_sft_trainer import IterativeSFTTrainer
from .judges import (
AllTrueJudge,
BaseBinaryJudge,
BaseJudge,
BasePairwiseJudge,
BaseRankJudge,
HfPairwiseJudge,
OpenAIPairwiseJudge,
PairRMJudge,
RandomBinaryJudge,
RandomPairwiseJudge,
RandomRankJudge,
)
Expand Down
92 changes: 92 additions & 0 deletions trl/trainer/judges.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,54 @@ def judge(self, prompts: List[str], completions: List[List[str]], shuffle_order:
raise NotImplementedError("Judge subclasses must implement the `judge` method.")


class BaseBinaryJudge(BaseJudge):
    """
    Base class for binary judges.
    """

    @abstractmethod
    def judge(
        self,
        prompts: List[str],
        completions: List[str],
        gold_completions: Optional[List[str]] = None,
        shuffle_order: bool = True,
    ) -> List[int]:
        """
        Judge the completion for a given prompt. Used to assess if a completion satisfies a constraint.

        This base class should be used to implement binary evaluations as done in section 4.1.4 of the
        [CGPO paper](https://huggingface.co/papers/2409.20370).
        It is relevant for assessing whether or not a prompt completion pair satisfies a specific constraint.

        Args:
            prompts (`List[str]`): List of prompts.
            completions (`List[str]`): List of completions.
            gold_completions (`List[str]`, `optional`): List of gold completions if it exists.
            shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias.

        Returns:
            List[int]: A list of binary labels:
                - 1 indicates that the completion satisfies the evaluated constraint.
                - 0 indicates that the completion does not satisfy the evaluated constraint.

        Note:
            If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference has failed.
            For instance, this could occur if the underlying language model or rule-based constraint returned an invalid answer.
            In such cases, the caller should handle these invalid indices appropriately, possibly by implementing fallback logic or error handling.
        """
        raise NotImplementedError("Judge subclasses must implement the `judge` method.")


class RandomBinaryJudge(BaseBinaryJudge):
    """
    Random binary judge, for testing purposes.
    """

    def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
        # One uniformly random verdict per prompt; -1 simulates a failed judgment.
        possible_verdicts = (0, 1, -1)
        verdicts = []
        for _ in prompts:
            verdicts.append(random.choice(possible_verdicts))
        return verdicts


class RandomRankJudge(BaseRankJudge):
"""
Random rank, for testing purposes.
Expand Down Expand Up @@ -392,3 +440,47 @@ def get_rank(prompt, candidates):

# Return the ranks
return ranks


class AllTrueJudge(BaseBinaryJudge):
    """
    Unify the decision of multiple [`BaseBinaryJudge`] instances.

    Returns `1` only if all inner binary judges return `1`. If any judge returns `0`, it returns `0`.
    If any judge returns `-1`, indicating a failure in its process, this judge will also return `-1`.

    Implements the Mixture of Judges as described in the [CGPO paper](https://huggingface.co/papers/2409.20370).

    Args:
        judges (`List[BaseBinaryJudge]`): A list of [`BaseBinaryJudge`] instances whose decisions will be unified.
    """

    def __init__(self, judges: List[BaseBinaryJudge]):
        self.judges = judges

    def judge(
        self,
        prompts: List[str],
        completions: List[str],
        gold_completions: Optional[List[str]] = None,
        shuffle_order: bool = True,
    ) -> List[int]:
        # Query every inner judge once over the full batch.
        per_judge_results = [
            inner_judge.judge(prompts, completions, gold_completions, shuffle_order)
            for inner_judge in self.judges
        ]
        valid_values = {0, 1, -1}
        unified = []
        # Walk the per-prompt tuples of judgments and collapse each into a single decision.
        for judgments in zip(*per_judge_results):
            if not all(judgment in valid_values for judgment in judgments):
                raise ValueError(
                    f"Invalid binary judgment: {judgments}, expected list of values in {{0, 1, -1}}."
                )
            if -1 in judgments:
                # Any inner failure makes the unified judgment a failure.
                unified.append(-1)
            elif 0 in judgments:
                # At least one constraint violated.
                unified.append(0)
            else:
                # All judges returned 1.
                unified.append(1)
        return unified
Loading