Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow eval in Online DPO #2476

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 21 additions & 41 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,10 @@
ProcessorMixin,
Trainer,
TrainerCallback,
is_apex_available,
is_wandb_available,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging
from transformers.utils import is_peft_available, logging
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
Expand All @@ -56,7 +54,6 @@
SIMPLE_CHAT_TEMPLATE,
DPODataCollatorWithPadding,
disable_dropout_in_model,
empty_cache,
generate_model_card,
get_reward,
prepare_deepspeed,
Expand All @@ -67,17 +64,6 @@
if is_peft_available():
from peft import PeftModel, get_peft_model

if is_apex_available():
from apex import amp


if is_sagemaker_mp_enabled():
from smdistributed.modelparallel import __version__ as SMP_VERSION

IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

else:
IS_SAGEMAKER_MP_POST_1_10 = False

if is_wandb_available():
import wandb
Expand Down Expand Up @@ -391,11 +377,13 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None

return self.accelerator.prepare(eval_dataloader)

def training_step(
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, Union[torch.Tensor, Any]],
return_outputs: bool = False,
num_items_in_batch: Optional[int] = None,
) -> torch.Tensor:
model.train()

# Apply chat template and tokenize the input.
# We do this on-the-fly to enable the use of reward models and policies with different tokenizers / chat templates.
batch_size = len(next(iter(inputs.values())))
Expand Down Expand Up @@ -579,28 +567,7 @@ def training_step(
self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
self.stats["beta"].append(self.beta)

if (
self.args.torch_empty_cache_steps is not None
and self.state.global_step % self.args.torch_empty_cache_steps == 0
):
empty_cache()

kwargs = {}

# For LOMO optimizers you need to explicitly use the learnign rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()

if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training

if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
self.accelerator.backward(loss, **kwargs)

return loss.detach() / self.args.gradient_accumulation_steps
return (loss, None) if return_outputs else loss

# Same as Trainer._maybe_log_save_evaluate but log our metrics
# start_time defaults to None to allow compatibility with transformers<=4.46
Expand Down Expand Up @@ -645,6 +612,19 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
self._save_checkpoint(model, trial)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

def prediction_step(
self,
model: nn.Module,
inputs: dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
with torch.no_grad():
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
loss = loss.mean().detach()
return (loss, None, None)

# Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
# This can be removed once the minimum transformers version is updated to 4.47.
# Refer to https://github.com/huggingface/trl/pull/2288 for more details.
Expand Down
Loading