Skip to content

Commit

Permalink
Merge pull request #111 from IGNF/metrics-and-cm
Browse files Browse the repository at this point in the history
Log confusion matrices after each epoch
  • Loading branch information
CharlesGaydon authored Feb 8, 2024
2 parents f728d98 + ae377e6 commit d7b38bc
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 9 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# CHANGELOG

### 3.8.0
- dev: log confusion matrices to Comet after each epoch.
- fix: do not mix the two way to log IoUs to avoid known lightning [Common Pitfalls](https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html#common-pitfalls).

### 3.7.1
- fix: edge case when saving predictions under Classification channel, without saving entropy.

Expand Down
15 changes: 14 additions & 1 deletion myria3d/callbacks/comet_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_comet_logger(trainer: Trainer) -> Optional[CometLogger]:
return logger

warnings.warn(
"You are using comet related callback, but CometLogger was not found for some reason...",
"You are using comet related functions, but trainer has no CometLogger among its loggers.",
UserWarning,
)
return None
Expand Down Expand Up @@ -71,3 +71,16 @@ def setup(self, trainer, pl_module, stage):
log_path = os.getcwd()
log.info(f"----------------\n LOGS DIR is {log_path}\n ----------------")
logger.experiment.log_parameter("experiment_logs_dirpath", log_path)


def log_comet_cm(lightning_module, confmat, phase):
logger = get_comet_logger(trainer=lightning_module)
if logger:
labels = list(lightning_module.hparams.classification_dict.values())
logger.experiment.log_confusion_matrix(
matrix=confmat.cpu().numpy().tolist(),
labels=labels,
file_name=f"{phase}-confusion-matrix",
title="{phase} confusion matrix",
epoch=lightning_module.current_epoch,
)
20 changes: 13 additions & 7 deletions myria3d/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch_geometric.data import Batch
from torch_geometric.nn import knn_interpolate
from torchmetrics.classification import MulticlassJaccardIndex
from myria3d.callbacks.comet_callbacks import log_comet_cm

from myria3d.metrics.iou import iou
from myria3d.models.modules.pyg_randla_net import PyGRandLANet
Expand Down Expand Up @@ -143,12 +144,14 @@ def training_step(self, batch: Batch, batch_idx: int) -> dict:
with torch.no_grad():
preds = torch.argmax(logits.detach(), dim=1)
self.train_iou(preds, targets)
self.log("train/iou", self.train_iou, on_step=True, on_epoch=True, prog_bar=True)

return {"loss": loss, "logits": logits, "targets": targets}

def on_train_epoch_end(self) -> None:
self.train_iou.compute()
iou_epoch = self.train_iou.compute()
self.log("train/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.train_iou.confmat, "train")
log_comet_cm(self, self.train_iou.confmat, "train")
self.train_iou.reset()

def validation_step(self, batch: Batch, batch_idx: int) -> dict:
Expand All @@ -173,7 +176,7 @@ def validation_step(self, batch: Batch, batch_idx: int) -> dict:
preds = torch.argmax(logits.detach(), dim=1)
self.val_iou = self.val_iou.to(preds.device)
self.val_iou(preds, targets)
self.log("val/iou", self.val_iou, on_step=True, on_epoch=True, prog_bar=True)

return {"loss": loss, "logits": logits, "targets": targets}

def on_validation_epoch_end(self) -> None:
Expand All @@ -183,8 +186,10 @@ def on_validation_epoch_end(self) -> None:
outputs : output of validation_step
"""
self.val_iou.compute()
iou_epoch = self.val_iou.compute()
self.log("val/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.val_iou.confmat, "val")
log_comet_cm(self, self.val_iou.confmat, "val")
self.val_iou.reset()

def test_step(self, batch: Batch, batch_idx: int):
Expand All @@ -201,12 +206,11 @@ def test_step(self, batch: Batch, batch_idx: int):
targets, logits = self.forward(batch)
self.criterion = self.criterion.to(logits.device)
loss = self.criterion(logits, targets)
self.log("test/loss", loss, on_step=True, on_epoch=True)
self.log("test/loss", loss, on_step=False, on_epoch=True)

preds = torch.argmax(logits, dim=1)
self.test_iou = self.test_iou.to(preds.device)
self.test_iou(preds, targets)
self.log("test/iou", self.test_iou, on_step=False, on_epoch=True, prog_bar=True)

return {"loss": loss, "logits": logits, "targets": targets}

Expand All @@ -217,8 +221,10 @@ def on_test_epoch_end(self) -> None:
outputs : output of test
"""
self.test_iou.compute()
iou_epoch = self.test_iou.compute()
self.log("test/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.test_iou.confmat, "test")
log_comet_cm(self, self.test_iou.confmat, "test")
self.test_iou.reset()

def predict_step(self, batch: Batch) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion package_metadata.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__: "3.7.1"
__version__: "3.8.0"
__name__: "myria3d"
__url__: "https://github.com/IGNF/myria3d"
__description__: "Deep Learning for the Semantic Segmentation of Aerial Lidar Point Clouds"
Expand Down

0 comments on commit d7b38bc

Please sign in to comment.