diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index d3812822b7..cfd46c1891 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -808,6 +808,9 @@ def compute_loss( with compute_loss_context_manager: loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: + loss = loss.to(self.args.device) + # force log the metrics self.store_metrics(metrics, train_eval="train")