Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def training(local_rank, config, logger=None):

# Setup evaluators
num_classes = config.num_classes
cm_metric = ConfusionMatrix(num_classes=num_classes)
cm_metric = ConfusionMatrix(num_classes=num_classes, average="recall")

val_metrics = {
"IoU": IoU(cm_metric),
Expand Down Expand Up @@ -241,16 +241,17 @@ def run_validation():
event_name=Events.ITERATION_COMPLETED(once=len(val_loader) // 2),
)

# Log confusion matrix to Trains:
if exp_tracking.has_trains:
from trains import Task
# Log confusion matrix to Trains:
if exp_tracking.has_trains:

@trainer.on(Events.COMPLETED)
def compute_and_log_cm():
cm = cm_metric.compute().cpu().numpy()

trains_logger = Task.current_task().get_logger()
if idist.get_rank() == 0:
from trains import Task

@trainer.on(Events.COMPLETED)
def log_cm():
cm = cm_metric.compute().numpy()
cm = cm / (cm.sum(axis=1)[:, None] + 1e-15)
trains_logger = Task.current_task().get_logger()
trains_logger.report_confusion_matrix(
title="Final Confusion Matrix",
series="cm-preds-gt",
Expand Down