From 6b02c293d4363a593e31091da1b61a6c1229754b Mon Sep 17 00:00:00 2001 From: Charles Gaydon Date: Thu, 25 Apr 2024 15:48:45 +0200 Subject: [PATCH] Move logged items to gpu to prevent error in ddp --- myria3d/callbacks/metric_callbacks.py | 7 +++++-- myria3d/models/model.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/myria3d/callbacks/metric_callbacks.py b/myria3d/callbacks/metric_callbacks.py index f73d3edb..2f8e33d5 100644 --- a/myria3d/callbacks/metric_callbacks.py +++ b/myria3d/callbacks/metric_callbacks.py @@ -56,13 +56,16 @@ def _end_of_batch(self, phase: str, outputs): def _end_of_epoch(self, phase: str, pl_module): for metric_name, metric in self.metrics[phase].items(): metric_name_for_log = f"{phase}/{metric_name}" + value = metric.to(pl_module.device).compute() self.log( metric_name_for_log, - metric, + value, on_epoch=True, on_step=False, metric_attribute=metric_name_for_log, ) + metric.reset() # always reset state when using compute(). + class_names = pl_module.hparams.classification_dict.values() for metric_name, metric in self.metrics_by_class[phase].items(): values = metric.to(pl_module.device).compute() @@ -75,7 +78,7 @@ def _end_of_epoch(self, phase: str, pl_module): on_epoch=True, metric_attribute=metric_name_for_log, ) - metric.reset() # always reset when using compute(). + metric.reset() # always reset state when using compute(). def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self._end_of_batch("train", outputs) diff --git a/myria3d/models/model.py b/myria3d/models/model.py index 3dbd5fc5..1d9df54e 100755 --- a/myria3d/models/model.py +++ b/myria3d/models/model.py @@ -78,7 +78,7 @@ def on_test_start(self) -> None: self.test_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device) def log_all_class_ious(self, confmat, phase: str): - ious = iou(confmat) + ious = iou(confmat).to(self.device) for class_iou, class_name in zip(ious, self.hparams.classification_dict.values()): metric_name = f"{phase}/iou_CLASS_{class_name}" self.log(