diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index 0ef2c9e3200..2acd6dc7178 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -467,7 +467,11 @@ def update_metrics(self, targets, predictions): continue of_obj.update_metrics(targets[of_name], predictions[of_name]) + # HACK (Tim): get the device of the targets to transfer self.eval_loss_metric to the same device + target_device = list(targets.values())[0].device + eval_loss, additional_losses = self.eval_loss(targets, predictions) + self.eval_loss_metric = self.eval_loss_metric.to(target_device) self.eval_loss_metric.update(eval_loss) self.eval_additional_losses_metrics.update(additional_losses)