
Commit f546278

Add train metrics calculation and tb logging
1 parent 235d6ea commit f546278

File tree

1 file changed: +95 -24 lines changed

imagenet/main.py

Lines changed: 95 additions & 24 deletions
@@ -39,6 +39,17 @@
 # Model evaluation
 # ================

+class TrainMetrics(NamedTuple):
+    class_labels: List[int]
+    acc_balanced: float
+    f1_micro: float
+    f1_macro: float
+    prec_micro: float
+    prec_macro: float
+    rec_micro: float
+    rec_macro: float
+
+
 class ValidationMetrics(NamedTuple):
     class_labels: List[int]
     acc_balanced: float
@@ -63,7 +74,7 @@ class EarlyStopping:
     - https://github.com/Bjarten/early-stopping-pytorch
     """

-    def __init__(self, patience=3, min_delta=1, min_epochs=50):
+    def __init__(self, patience: int = 3, min_delta: float = 1, min_epochs: int = 50):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
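
Side note on the semantics this constructor implies: patience counts epochs without sufficient improvement, min_delta is the smallest drop in validation loss that still counts as improvement, and min_epochs guards against stopping too early. The class body is not part of this hunk, so the following is only a hypothetical sketch modeled on the linked Bjarten/early-stopping-pytorch pattern; everything beyond the constructor arguments is illustrative.

# Hypothetical sketch of the EarlyStopping call semantics; the real class body
# is not shown in this commit, so these internals are illustrative only.
class EarlyStoppingSketch:
    def __init__(self, patience: int = 3, min_delta: float = 1, min_epochs: int = 50):
        self.patience = patience        # epochs tolerated without improvement
        self.min_delta = min_delta      # minimum drop in val loss that counts as improvement
        self.min_epochs = min_epochs    # never stop before this many epochs
        self.counter = 0
        self.best_loss = float("inf")
        self.should_stop = False

    def __call__(self, val_loss: float, epoch: int) -> None:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss   # improved enough: reset the counter
            self.counter = 0
        else:
            self.counter += 1           # no meaningful improvement this epoch
        if epoch >= self.min_epochs and self.counter >= self.patience:
            self.should_stop = True

This matches how the training loop below uses the object: it is called once per epoch with the validation loss and then polled via should_stop.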
@@ -457,24 +468,25 @@ def get_target_class(cl: int) -> str:
     except Exception as e:
         log(f"Failed to add graph to tensorboard.")

-    early_stopping = EarlyStopping(patience=5, min_delta=5, min_epochs=50)
+    early_stopping = EarlyStopping(patience=5, min_delta=0.5, min_epochs=50)
     try:
         for epoch in range(args.start_epoch, args.epochs):
             if args.distributed:
                 train_sampler.set_epoch(epoch)

             # train for one epoch
-            train_loss = train(train_loader, model, criterion, optimizer, epoch, device, args)
+            train_acc1, train_loss, train_metrics = train(train_loader, model, criterion, optimizer, epoch, device,
+                                                          args)

             # evaluate on validation set
-            acc1, val_loss, metrics = validate(val_loader, model, criterion, args)
+            val_acc1, val_loss, val_metrics = validate(val_loader, model, criterion, args)
             scheduler.step()
             early_stopping(val_loss, epoch)

             # remember best acc@1 and save checkpoint
-            is_best = acc1 > best_acc1
-            best_acc1 = max(acc1, best_acc1)
-            best_metrics = metrics if metrics.f1_micro > best_metrics.f1_micro else best_metrics
+            is_best = val_acc1 > best_acc1
+            best_acc1 = max(val_acc1, best_acc1)
+            best_metrics = val_metrics if val_metrics.f1_micro > best_metrics.f1_micro else best_metrics

             if not args.multiprocessing_distributed or \
                     (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0) or \
@@ -491,31 +503,49 @@ def get_target_class(cl: int) -> str:
             if tensorboard_writer:
                 tensorboard_writer.add_scalars('Loss', dict(train=train_loss, val=val_loss), epoch + 1)
                 tensorboard_writer.add_scalars('Metrics/Accuracy',
-                                               dict(acc=acc1 / 100.0, balanced_acc=metrics.acc_balanced), epoch + 1)
-                tensorboard_writer.add_scalars('Metrics/F1', dict(micro=metrics.f1_micro, macro=metrics.f1_macro),
+                                               dict(val_acc=val_acc1 / 100.0,
+                                                    val_bacc=val_metrics.acc_balanced,
+                                                    train_acc=train_acc1 / 100.0,
+                                                    train_bacc=train_metrics.acc_balanced),
+                                               epoch + 1)
+                tensorboard_writer.add_scalars('Metrics/F1',
+                                               dict(val_micro=val_metrics.f1_micro,
+                                                    val_macro=val_metrics.f1_macro,
+                                                    train_micro=train_metrics.f1_micro,
+                                                    train_macro=train_metrics.f1_macro),
                                                epoch + 1)
                 tensorboard_writer.add_scalars('Metrics/Precision',
-                                               dict(micro=metrics.prec_micro, macro=metrics.prec_macro), epoch + 1)
-                tensorboard_writer.add_scalars('Metrics/Recall', dict(micro=metrics.rec_micro, macro=metrics.rec_macro),
+                                               dict(val_micro=val_metrics.prec_micro,
+                                                    val_macro=val_metrics.prec_macro,
+                                                    train_micro=train_metrics.prec_micro,
+                                                    train_macro=train_metrics.prec_macro),
+                                               epoch + 1)
+                tensorboard_writer.add_scalars('Metrics/Recall',
+                                               dict(val_micro=val_metrics.rec_micro,
+                                                    val_macro=val_metrics.rec_macro,
+                                                    train_micro=train_metrics.rec_micro,
+                                                    train_macro=train_metrics.rec_macro),
                                                epoch + 1)
                 tensorboard_writer.add_scalars('Metrics/F1/class',
-                                               {get_target_class(cl): f1 for cl, f1 in metrics.f1_per_class}, epoch + 1)
+                                               {get_target_class(cl): f1 for cl, f1 in val_metrics.f1_per_class},
+                                               epoch + 1)

                 if epoch < 10 or epoch % 5 == 0 or epoch == args.epochs - 1:
-                    class_names = [get_target_class(cl) for cl in list({l for l in metrics.class_labels})]
-                    fig_abs, _ = plot_confusion_matrix(metrics.conf_matrix, class_names=class_names, normalize=False)
-                    fig_rel, _ = plot_confusion_matrix(metrics.conf_matrix, class_names=class_names, normalize=True)
+                    class_names = [get_target_class(cl) for cl in list({l for l in val_metrics.class_labels})]
+                    fig_abs, _ = plot_confusion_matrix(val_metrics.conf_matrix, class_names=class_names,
+                                                       normalize=False)
+                    fig_rel, _ = plot_confusion_matrix(val_metrics.conf_matrix, class_names=class_names, normalize=True)
                     tensorboard_writer.add_figure('Confusion matrix', fig_abs, epoch + 1)
                     tensorboard_writer.add_figure('Confusion matrix/normalized', fig_rel, epoch + 1)

-                    for cl in metrics.class_labels:
+                    for cl in val_metrics.class_labels:
                         class_index = int(cl)
-                        labels_true = metrics.labels_true == class_index
-                        pred_probs = metrics.labels_probs[:, class_index]
+                        labels_true = val_metrics.labels_true == class_index
+                        pred_probs = val_metrics.labels_probs[:, class_index]
                         tensorboard_writer.add_pr_curve(f'PR curve/{get_target_class(class_index)}',
                                                         labels_true, pred_probs, epoch + 1)

-                    tensorboard_writer.add_figure('PR curve', metrics.fig_pr_curve_micro, epoch + 1)
+                    tensorboard_writer.add_figure('PR curve', val_metrics.fig_pr_curve_micro, epoch + 1)

             if early_stopping.should_stop:
                 log(f"Early stopping at epoch {epoch + 1}")
@@ -540,7 +570,7 @@ def get_target_class(cl: int) -> str:
     })


-def train(train_loader, model, criterion, optimizer, epoch, device, args) -> float:
+def train(train_loader, model, criterion, optimizer, epoch, device, args) -> Tuple[float, float, TrainMetrics]:
     batch_time = AverageMeter('Time', ':6.3f')
     data_time = AverageMeter('Data', ':6.3f')
     losses = AverageMeter('Loss', ':.4e')
@@ -555,6 +585,11 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args) -> flo
     # switch to train mode
     model.train()

+    # for train metrics
+    labels_true = np.array([], dtype=np.int64)
+    labels_pred = np.array([], dtype=np.int64)
+    labels_probs = []
+
     end = time.time()
     for i, (images, target) in enumerate(train_loader):
         # measure data loading time
@@ -579,14 +614,29 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args) -> flo
         loss.backward()
         optimizer.step()

+        with torch.no_grad():
+            predicted_values, predicted_indices = torch.max(output.data, 1)
+            labels_true = np.append(labels_true, target.cpu().numpy())
+            labels_pred = np.append(labels_pred, predicted_indices.cpu().numpy())
+
+            class_probs_batch = [F.softmax(el, dim=0) for el in output]
+            labels_probs.append(class_probs_batch)
+
         # measure elapsed time
         batch_time.update(time.time() - end)
         end = time.time()

         if i % args.print_freq == 0:
             progress.display(i + 1)

-    return loss.item()
+    if args.distributed:
+        acc_top1.all_reduce()
+        acc_top5.all_reduce()
+
+    labels_probs = torch.cat([torch.stack(batch) for batch in labels_probs]).cpu()
+    metrics = calculate_train_metrics(labels_true, labels_pred, labels_probs)
+
+    return acc_top1.avg, loss.item(), metrics


 def validate(val_loader, model, criterion, args) -> Tuple[float, float, "ValidationMetrics"]:
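
The bookkeeping added to train() mirrors what validate() already does: per-sample softmax vectors are collected batch by batch and then flattened into one (num_samples, num_classes) tensor before the metrics call. A self-contained sketch of that stack/cat step, using random data in place of real model outputs:

# Sketch of the stack/cat pattern used for labels_probs; the data here is random.
import torch
import torch.nn.functional as F

labels_probs = []
for _ in range(3):                                   # pretend: 3 batches
    output = torch.randn(4, 10)                      # batch of 4 logit vectors over 10 classes
    class_probs_batch = [F.softmax(el, dim=0) for el in output]
    labels_probs.append(class_probs_batch)

# one row per sample, one column per class
labels_probs = torch.cat([torch.stack(batch) for batch in labels_probs]).cpu()
print(labels_probs.shape)                            # torch.Size([12, 10])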
@@ -635,7 +685,7 @@ def run_validate(loader, base_progress=0) -> ValidationMetrics:

         labels_probs = torch.cat([torch.stack(batch) for batch in labels_probs]).cpu()

-        return metrics_labels_true_pred(labels_true, labels_pred, labels_probs)
+        return calculate_validation_metrics(labels_true, labels_pred, labels_probs)

     batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
     losses = AverageMeter('Loss', ':.4e', Summary.NONE)
@@ -786,8 +836,29 @@ def accuracy(output, target, topk=(1,)):
     return res


-def metrics_labels_true_pred(labels_true: np.array, labels_pred: np.array,
-                             labels_probs: torch.Tensor) -> ValidationMetrics:
+def calculate_train_metrics(labels_true: np.array, labels_pred: np.array,
+                            labels_probs: torch.Tensor) -> TrainMetrics:
+    unique_labels = list({l for l in labels_true})
+    f1_micro = f1_score(labels_true, labels_pred, average="micro")
+    f1_macro = f1_score(labels_true, labels_pred, average="macro")
+
+    acc_balanced = balanced_accuracy_score(labels_true, labels_pred)
+    prec_micro = precision_score(labels_true, labels_pred, average="micro")
+    prec_macro = precision_score(labels_true, labels_pred, average="macro")
+    rec_micro = recall_score(labels_true, labels_pred, average="micro")
+    rec_macro = recall_score(labels_true, labels_pred, average="macro")
+
+    return TrainMetrics(
+        unique_labels,
+        acc_balanced,
+        f1_micro, f1_macro,
+        prec_micro, prec_macro,
+        rec_micro, rec_macro
+    )
+
+
+def calculate_validation_metrics(labels_true: np.array, labels_pred: np.array,
+                                 labels_probs: torch.Tensor) -> ValidationMetrics:
     unique_labels = list({l for l in labels_true})
     f1_per_class = f1_score(labels_true, labels_pred, average=None, labels=unique_labels)
     f1_micro = f1_score(labels_true, labels_pred, average="micro")
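
For reference on the averaging modes used in both metric helpers: micro averaging pools all samples before computing the score, while macro averaging computes the score per class and takes the unweighted mean, so the two diverge on imbalanced label distributions. A toy example with hypothetical labels (scikit-learn only):

# Toy illustration of the micro/macro distinction behind calculate_train_metrics.
from sklearn.metrics import balanced_accuracy_score, f1_score, precision_score, recall_score

labels_true = [0, 0, 0, 0, 1, 1, 2, 2]
labels_pred = [0, 0, 0, 1, 1, 1, 2, 0]

print(f1_score(labels_true, labels_pred, average="micro"))    # pooled over all samples
print(f1_score(labels_true, labels_pred, average="macro"))    # unweighted mean over classes
print(precision_score(labels_true, labels_pred, average="macro"))
print(recall_score(labels_true, labels_pred, average="macro"))
print(balanced_accuracy_score(labels_true, labels_pred))      # mean of per-class recall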
