init custom eval loop for further DPO evals (#766)

* init

* run

* Update custom eval loop to aid DPO debugging (#770)

* sample_during_eval -> generate_during_eval

* Remove unused return_tokens

* Add import utils for W&B, prevent test fails

* Optimize dataloader random batch selection

* Separate prompt and response in logs

Makes it much easier to quickly read the starts of the generations

* Simplify logging

* reset eval steps

* manual merge fixes

* revert merge

* remove self.max_length

* style

* fix max_length

---------

Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Nathan Lambert and tomaarsen authored Sep 26, 2023
1 parent d608fea commit ad8d50e
Showing 7 changed files with 139 additions and 19 deletions.
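
Before the per-file diffs, a quick usage sketch of the flag this commit introduces. This is a minimal sketch, not code from the commit: the model, tokenizer, and toy dataset are placeholders, and only the `generate_during_eval=True` argument, the keyword arguments visible in the diffs below, and the wandb requirement come from this change.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

# Placeholder model and tokenizer, for illustration only.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Tiny illustrative preference dataset (prompt/chosen/rejected); a real run needs far more examples.
pairs = Dataset.from_dict(
    {
        "prompt": ["Q: What is 2+2?\nA:", "Q: Capital of France?\nA:"],
        "chosen": [" 4", " Paris"],
        "rejected": [" 5", " Lyon"],
    }
)

training_args = TrainingArguments(
    output_dir="dpo-out",
    evaluation_strategy="steps",  # evaluation must actually run for generations to be sampled
    remove_unused_columns=False,
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,
    beta=0.1,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=pairs,
    eval_dataset=pairs,
    max_length=128,
    generate_during_eval=True,  # raises ValueError at init if wandb is not installed
)

With wandb installed, each evaluation pass samples one random eval batch, generates from both the policy and the reference model, and logs the results to a W&B table (see the trl/trainer/dpo_trainer.py diff below).
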
1 change: 1 addition & 0 deletions examples/dpo.py
@@ -157,6 +157,7 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
max_length=script_args.max_length,
max_target_length=script_args.max_target_length,
max_prompt_length=script_args.max_prompt_length,
generate_during_eval=True,
)

# 6. train
33 changes: 32 additions & 1 deletion tests/test_dpo_trainer.py
@@ -22,7 +22,7 @@

from trl import DPOTrainer

from .testing_utils import require_peft
from .testing_utils import require_no_wandb, require_peft


class DPOTrainerTester(unittest.TestCase):
@@ -213,3 +213,34 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@require_no_wandb
def test_dpo_trainer_generate_during_eval_no_wandb(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
)

dummy_dataset = self._init_dummy_dataset()

with self.assertRaisesRegex(
ValueError,
expected_regex="`generate_during_eval=True` requires Weights and Biases to be installed."
" Please install `wandb` to resolve.",
):
DPOTrainer(
model=self.model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
generate_during_eval=True,
)
21 changes: 20 additions & 1 deletion tests/testing_utils.py
@@ -15,7 +15,7 @@

import torch

from trl import is_peft_available
from trl import is_peft_available, is_wandb_available


def require_peft(test_case):
@@ -27,6 +27,25 @@ def require_peft(test_case):
return test_case


def require_wandb(test_case, required: bool = True):
"""
Decorator marking a test that requires wandb. Skips the test if wandb is not available.
"""
# XOR, i.e.:
# skip if available and required = False and
# skip if not available and required = True
if is_wandb_available() ^ required:
test_case = unittest.skip("test requires wandb")(test_case)
return test_case


def require_no_wandb(test_case):
"""
Decorator marking a test that requires no wandb. Skips the test if wandb is available.
"""
return require_wandb(test_case, required=False)


def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available.
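
As context for the new helpers above, a hedged sketch of how they would be applied; the test names here are hypothetical and not part of the commit. Since `require_wandb` skips a test when wandb is missing and `require_no_wandb` (i.e. `required=False`) skips it when wandb is installed, exactly one of the two tests below runs in any given environment.

import unittest

from .testing_utils import require_no_wandb, require_wandb


class WandbGatingExample(unittest.TestCase):
    @require_wandb
    def test_generation_logging_with_wandb(self):
        # Would exercise the generate_during_eval logging path.
        ...

    @require_no_wandb
    def test_generate_during_eval_raises_without_wandb(self):
        # Would assert the ValueError raised when wandb is absent, as in the test above.
        ...
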
2 changes: 1 addition & 1 deletion trl/__init__.py
@@ -5,7 +5,7 @@
from .core import set_seed
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import is_diffusers_available, is_peft_available
from .import_utils import is_diffusers_available, is_peft_available, is_wandb_available
from .models import (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
16 changes: 10 additions & 6 deletions trl/import_utils.py
@@ -21,11 +21,11 @@
_is_python_greater_3_8 = True


def is_peft_available():
def is_peft_available() -> bool:
return importlib.util.find_spec("peft") is not None


def is_torch_greater_2_0():
def is_torch_greater_2_0() -> bool:
if _is_python_greater_3_8:
from importlib.metadata import version

@@ -37,17 +37,21 @@ def is_torch_greater_2_0():
return torch_version >= "2.0"


def is_diffusers_available():
def is_diffusers_available() -> bool:
return importlib.util.find_spec("diffusers") is not None


def is_bitsandbytes_available():
def is_bitsandbytes_available() -> bool:
return importlib.util.find_spec("bitsandbytes") is not None


def is_torchvision_available():
def is_torchvision_available() -> bool:
return importlib.util.find_spec("torchvision") is not None


def is_rich_available():
def is_rich_available() -> bool:
return importlib.util.find_spec("rich") is not None


def is_wandb_available() -> bool:
return importlib.util.find_spec("wandb") is not None
79 changes: 73 additions & 6 deletions trl/trainer/dpo_trainer.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import warnings
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
@@ -20,10 +21,12 @@
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput

from ..import_utils import is_peft_available
from ..import_utils import is_peft_available, is_wandb_available
from ..models import create_reference_model
from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length

@@ -32,6 +35,10 @@
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


if is_wandb_available():
import wandb


class DPOTrainer(Trainer):
r"""
Initialize DPOTrainer.
@@ -81,6 +88,8 @@ class DPOTrainer(Trainer):
If no model is provided, we need to know if the model_init returns an encoder-decoder.
disable_dropout (`bool`, defaults to `True`):
Whether or not to disable dropouts in `model` and `ref_model`.
generate_during_eval (`bool`, defaults to `False`):
Whether to sample and log generations during evaluation step.
"""

def __init__(
@@ -109,6 +118,7 @@ def __init__(
peft_config: Optional[Dict] = None,
is_encoder_decoder: Optional[bool] = None,
disable_dropout: bool = True,
generate_during_eval: bool = False,
):
if not is_peft_available() and peft_config is not None:
raise ValueError(
@@ -119,6 +129,12 @@
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
model = get_peft_model(model, peft_config)

if generate_during_eval and not is_wandb_available():
raise ValueError(
"`generate_during_eval=True` requires Weights and Biases to be installed."
" Please install `wandb` to resolve."
)

if model is not None:
self.is_encoder_decoder = model.config.is_encoder_decoder
elif is_encoder_decoder is None:
@@ -193,6 +209,8 @@ def __init__(
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)

self.max_length = max_length
self.generate_during_eval = generate_during_eval
self.label_pad_token_id = label_pad_token_id
self.padding_value = padding_value

@@ -459,7 +477,7 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[
policy_output = model.generate(
batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.config.max_length,
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
@@ -469,23 +487,23 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[
reference_output = self.model.generate(
batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.config.max_length,
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
else:
reference_output = self.ref_model.generate(
batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.config.max_length,
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)

policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id)
policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)

reference_output = pad_to_length(reference_output, self.config.max_length, self.tokenizer.pad_token_id)
reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id)
reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)

return policy_output_decoded, reference_output_decoded
@@ -533,6 +551,55 @@ def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train",
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)

def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Overriding built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""

# Sample and save to game log if requested (for one batch to save time)
if self.generate_during_eval:
# Generate random indices within the range of the total number of samples
num_samples = len(dataloader.dataset)
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)

# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
random_batch_dataset = dataloader.dataset.select(random_indices)
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)

policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)

self.log(
{
"game_log": wandb.Table(
columns=["Prompt", "Policy", "Ref Model"],
rows=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(
random_batch["prompt"], policy_output_decoded, ref_output_decoded
)
],
)
}
)

# Base evaluation
initial_output = super().evaluation_loop(
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
)

return initial_output

def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
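
One note on the "Separate prompt and response in logs" item from the commit message: the table rows above slice the decoded prompt off each decoded generation, so the Policy and Ref Model columns start at the continuation rather than repeating the prompt. A minimal sketch of that slicing, with made-up strings and the assumption (as in the code) that each decoded output begins with the prompt text:

prompt = "Q: What is 2+2?\nA:"
decoded_policy_output = "Q: What is 2+2?\nA: 4, because two plus two is four."

policy_response = decoded_policy_output[len(prompt):]  # " 4, because two plus two is four."
row = [prompt, policy_response]  # one Prompt/Policy row of the logged game_log table
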
6 changes: 2 additions & 4 deletions trl/trainer/ppo_config.py
@@ -23,6 +23,7 @@
from trl.trainer.utils import exact_div

from ..core import flatten_dict
from ..import_utils import is_wandb_available


@dataclass
@@ -143,10 +144,7 @@ def __post_init__(self):
# check if wandb is installed
if self.log_with == "wandb":
# raise error if wandb is not installed
try:
import wandb # noqa: F401

except ImportError:
if not is_wandb_available():
raise ImportError(
"Please install wandb to use wandb logging. You can do this by running `pip install wandb`."
)
