Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ktotrainer: Refuse datasets which contain only one class of labels #1724

Merged
merged 2 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/kto_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ kto_dataset_dict = {
```

where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
A prompt can have multiple responses, and this is reflected in the entries being repeated in the dictionary's value arrays. The dataset must contain at least one desirable and at least one undesirable completion.


## Expected model format
The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
Expand Down
122 changes: 122 additions & 0 deletions tests/test_kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,74 @@ def _init_dummy_dataset(self):
# fmt: on
return Dataset.from_dict(dummy_dataset_dict)

def _init_dummy_dataset_only_desirable(self):
    """Build a dummy KTO dataset in which every completion is labeled desirable (True)."""
    prompts = [
        "Hey, hello",
        "How are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ]
    completions = [
        "hi nice to meet you",
        "leave me alone",
        "I don't have a name",
        "My name is Mary",
        "Python",
        "C++",
        "Java",
    ]
    # Every label is True, so the dataset has no undesirable examples at all.
    return Dataset.from_dict(
        {
            "prompt": prompts,
            "completion": completions,
            "label": [True] * len(prompts),
        }
    )

def _init_dummy_dataset_no_desirable(self):
    """Build a dummy KTO dataset in which every completion is labeled undesirable (False)."""
    prompts = [
        "Hey, hello",
        "How are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ]
    completions = [
        "hi nice to meet you",
        "leave me alone",
        "I don't have a name",
        "My name is Mary",
        "Python",
        "C++",
        "Java",
    ]
    # Every label is False, so the dataset has no desirable examples at all.
    return Dataset.from_dict(
        {
            "prompt": prompts,
            "completion": completions,
            "label": [False] * len(prompts),
        }
    )

@parameterized.expand(
[
["gpt2", "kto", True, True],
Expand Down Expand Up @@ -144,6 +212,60 @@ def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset):
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@require_no_wandb
def test_kto_trainer_no_desirable_input(self):
    """KTOTrainer must reject a dataset whose labels are all False (no desirable completions)."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = KTOConfig(
            output_dir=tmp_dir,
            remove_unused_columns=False,
        )

        dummy_dataset = self._init_dummy_dataset_no_desirable()

        # NOTE: assertRaises(..., msg=...) does NOT check the exception's text —
        # `msg` is only the failure message shown when nothing is raised. Use
        # assertRaisesRegex so the test verifies the *specific* ValueError is
        # raised, not just any ValueError from trainer construction.
        with self.assertRaisesRegex(
            ValueError,
            "The set of desirable completions cannot be empty",
        ):
            _ = KTOTrainer(
                model=self.model,
                ref_model=self.ref_model,
                args=training_args,
                tokenizer=self.tokenizer,
                train_dataset=dummy_dataset,
                eval_dataset=None,
            )

@require_no_wandb
def test_kto_trainer_only_desirable_input(self):
    """KTOTrainer must reject a dataset whose labels are all True (no undesirable completions)."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = KTOConfig(
            output_dir=tmp_dir,
            remove_unused_columns=False,
        )

        dummy_dataset = self._init_dummy_dataset_only_desirable()

        # NOTE: assertRaises(..., msg=...) does NOT check the exception's text —
        # `msg` is only the failure message shown when nothing is raised. Use
        # assertRaisesRegex so the test verifies the *specific* ValueError is
        # raised, not just any ValueError from trainer construction.
        with self.assertRaisesRegex(
            ValueError,
            "The set of undesirable completions cannot be empty",
        ):
            _ = KTOTrainer(
                model=self.model,
                ref_model=self.ref_model,
                args=training_args,
                tokenizer=self.tokenizer,
                train_dataset=dummy_dataset,
                eval_dataset=None,
            )

def test_tokenize_and_process_tokens(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
Expand Down
5 changes: 5 additions & 0 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,11 @@ def make_inputs_require_grad(module, input, output):
lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
)

if len(desirable) == 0:
raise ValueError("The set of desirable completions cannot be empty.")
elif len(undesirable) == 0:
raise ValueError("The set of undesirable completions cannot be empty.")

if len(desirable) != len(undesirable):
# The lower and upper bounds come from Eq. (8) of https://arxiv.org/abs/2402.01306
des_weight_lower_bound = round((len(undesirable) * self.undesirable_weight / len(desirable)) * 1, 2)
Expand Down
Loading