39
39
# Model evaluation
40
40
# ================
41
41
42
class TrainMetrics(NamedTuple):
    """Aggregate classification metrics collected over one training epoch."""

    # Distinct class ids seen in this epoch's ground-truth targets.
    class_labels: List[int]
    # Balanced accuracy (mean of per-class recall).
    acc_balanced: float
    # Micro/macro-averaged F1, precision and recall over all samples.
    f1_micro: float
    f1_macro: float
    prec_micro: float
    prec_macro: float
    rec_micro: float
    rec_macro: float
51
+
52
+
42
53
class ValidationMetrics (NamedTuple ):
43
54
class_labels : List [int ]
44
55
acc_balanced : float
@@ -63,7 +74,7 @@ class EarlyStopping:
63
74
- https://github.com/Bjarten/early-stopping-pytorch
64
75
"""
65
76
66
- def __init__ (self , patience = 3 , min_delta = 1 , min_epochs = 50 ):
77
+ def __init__ (self , patience : int = 3 , min_delta : float = 1 , min_epochs : int = 50 ):
67
78
self .patience = patience
68
79
self .min_delta = min_delta
69
80
self .counter = 0
@@ -457,24 +468,25 @@ def get_target_class(cl: int) -> str:
457
468
except Exception as e :
458
469
log (f"Failed to add graph to tensorboard." )
459
470
460
- early_stopping = EarlyStopping (patience = 5 , min_delta = 5 , min_epochs = 50 )
471
+ early_stopping = EarlyStopping (patience = 5 , min_delta = 0. 5 , min_epochs = 50 )
461
472
try :
462
473
for epoch in range (args .start_epoch , args .epochs ):
463
474
if args .distributed :
464
475
train_sampler .set_epoch (epoch )
465
476
466
477
# train for one epoch
467
- train_loss = train (train_loader , model , criterion , optimizer , epoch , device , args )
478
+ train_acc1 , train_loss , train_metrics = train (train_loader , model , criterion , optimizer , epoch , device ,
479
+ args )
468
480
469
481
# evaluate on validation set
470
- acc1 , val_loss , metrics = validate (val_loader , model , criterion , args )
482
+ val_acc1 , val_loss , val_metrics = validate (val_loader , model , criterion , args )
471
483
scheduler .step ()
472
484
early_stopping (val_loss , epoch )
473
485
474
486
# remember best acc@1 and save checkpoint
475
- is_best = acc1 > best_acc1
476
- best_acc1 = max (acc1 , best_acc1 )
477
- best_metrics = metrics if metrics .f1_micro > best_metrics .f1_micro else best_metrics
487
+ is_best = val_acc1 > best_acc1
488
+ best_acc1 = max (val_acc1 , best_acc1 )
489
+ best_metrics = val_metrics if val_metrics .f1_micro > best_metrics .f1_micro else best_metrics
478
490
479
491
if not args .multiprocessing_distributed or \
480
492
(args .multiprocessing_distributed and args .rank % ngpus_per_node == 0 ) or \
@@ -491,31 +503,49 @@ def get_target_class(cl: int) -> str:
491
503
if tensorboard_writer :
492
504
tensorboard_writer .add_scalars ('Loss' , dict (train = train_loss , val = val_loss ), epoch + 1 )
493
505
tensorboard_writer .add_scalars ('Metrics/Accuracy' ,
494
- dict (acc = acc1 / 100.0 , balanced_acc = metrics .acc_balanced ), epoch + 1 )
495
- tensorboard_writer .add_scalars ('Metrics/F1' , dict (micro = metrics .f1_micro , macro = metrics .f1_macro ),
506
+ dict (val_acc = val_acc1 / 100.0 ,
507
+ val_bacc = val_metrics .acc_balanced ,
508
+ train_acc = train_acc1 / 100.0 ,
509
+ train_bacc = train_metrics .acc_balanced ),
510
+ epoch + 1 )
511
+ tensorboard_writer .add_scalars ('Metrics/F1' ,
512
+ dict (val_micro = val_metrics .f1_micro ,
513
+ val_macro = val_metrics .f1_macro ,
514
+ train_micro = train_metrics .f1_micro ,
515
+ train_macro = train_metrics .f1_macro ),
496
516
epoch + 1 )
497
517
tensorboard_writer .add_scalars ('Metrics/Precision' ,
498
- dict (micro = metrics .prec_micro , macro = metrics .prec_macro ), epoch + 1 )
499
- tensorboard_writer .add_scalars ('Metrics/Recall' , dict (micro = metrics .rec_micro , macro = metrics .rec_macro ),
518
+ dict (val_micro = val_metrics .prec_micro ,
519
+ val_macro = val_metrics .prec_macro ,
520
+ train_micro = train_metrics .prec_micro ,
521
+ train_macro = train_metrics .prec_macro ),
522
+ epoch + 1 )
523
+ tensorboard_writer .add_scalars ('Metrics/Recall' ,
524
+ dict (val_micro = val_metrics .rec_micro ,
525
+ val_macro = val_metrics .rec_macro ,
526
+ train_micro = train_metrics .rec_micro ,
527
+ train_macro = train_metrics .rec_macro ),
500
528
epoch + 1 )
501
529
tensorboard_writer .add_scalars ('Metrics/F1/class' ,
502
- {get_target_class (cl ): f1 for cl , f1 in metrics .f1_per_class }, epoch + 1 )
530
+ {get_target_class (cl ): f1 for cl , f1 in val_metrics .f1_per_class },
531
+ epoch + 1 )
503
532
504
533
if epoch < 10 or epoch % 5 == 0 or epoch == args .epochs - 1 :
505
- class_names = [get_target_class (cl ) for cl in list ({l for l in metrics .class_labels })]
506
- fig_abs , _ = plot_confusion_matrix (metrics .conf_matrix , class_names = class_names , normalize = False )
507
- fig_rel , _ = plot_confusion_matrix (metrics .conf_matrix , class_names = class_names , normalize = True )
534
+ class_names = [get_target_class (cl ) for cl in list ({l for l in val_metrics .class_labels })]
535
+ fig_abs , _ = plot_confusion_matrix (val_metrics .conf_matrix , class_names = class_names ,
536
+ normalize = False )
537
+ fig_rel , _ = plot_confusion_matrix (val_metrics .conf_matrix , class_names = class_names , normalize = True )
508
538
tensorboard_writer .add_figure ('Confusion matrix' , fig_abs , epoch + 1 )
509
539
tensorboard_writer .add_figure ('Confusion matrix/normalized' , fig_rel , epoch + 1 )
510
540
511
- for cl in metrics .class_labels :
541
+ for cl in val_metrics .class_labels :
512
542
class_index = int (cl )
513
- labels_true = metrics .labels_true == class_index
514
- pred_probs = metrics .labels_probs [:, class_index ]
543
+ labels_true = val_metrics .labels_true == class_index
544
+ pred_probs = val_metrics .labels_probs [:, class_index ]
515
545
tensorboard_writer .add_pr_curve (f'PR curve/{ get_target_class (class_index )} ' ,
516
546
labels_true , pred_probs , epoch + 1 )
517
547
518
- tensorboard_writer .add_figure ('PR curve' , metrics .fig_pr_curve_micro , epoch + 1 )
548
+ tensorboard_writer .add_figure ('PR curve' , val_metrics .fig_pr_curve_micro , epoch + 1 )
519
549
520
550
if early_stopping .should_stop :
521
551
log (f"Early stopping at epoch { epoch + 1 } " )
@@ -540,7 +570,7 @@ def get_target_class(cl: int) -> str:
540
570
})
541
571
542
572
543
- def train (train_loader , model , criterion , optimizer , epoch , device , args ) -> float :
573
+ def train (train_loader , model , criterion , optimizer , epoch , device , args ) -> Tuple [ float , float , TrainMetrics ] :
544
574
batch_time = AverageMeter ('Time' , ':6.3f' )
545
575
data_time = AverageMeter ('Data' , ':6.3f' )
546
576
losses = AverageMeter ('Loss' , ':.4e' )
@@ -555,6 +585,11 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args) -> flo
555
585
# switch to train mode
556
586
model .train ()
557
587
588
+ # for train metrics
589
+ labels_true = np .array ([], dtype = np .int64 )
590
+ labels_pred = np .array ([], dtype = np .int64 )
591
+ labels_probs = []
592
+
558
593
end = time .time ()
559
594
for i , (images , target ) in enumerate (train_loader ):
560
595
# measure data loading time
@@ -579,14 +614,29 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args) -> flo
579
614
loss .backward ()
580
615
optimizer .step ()
581
616
617
+ with torch .no_grad ():
618
+ predicted_values , predicted_indices = torch .max (output .data , 1 )
619
+ labels_true = np .append (labels_true , target .cpu ().numpy ())
620
+ labels_pred = np .append (labels_pred , predicted_indices .cpu ().numpy ())
621
+
622
+ class_probs_batch = [F .softmax (el , dim = 0 ) for el in output ]
623
+ labels_probs .append (class_probs_batch )
624
+
582
625
# measure elapsed time
583
626
batch_time .update (time .time () - end )
584
627
end = time .time ()
585
628
586
629
if i % args .print_freq == 0 :
587
630
progress .display (i + 1 )
588
631
589
- return loss .item ()
632
+ if args .distributed :
633
+ acc_top1 .all_reduce ()
634
+ acc_top5 .all_reduce ()
635
+
636
+ labels_probs = torch .cat ([torch .stack (batch ) for batch in labels_probs ]).cpu ()
637
+ metrics = calculate_train_metrics (labels_true , labels_pred , labels_probs )
638
+
639
+ return acc_top1 .avg , loss .item (), metrics
590
640
591
641
592
642
def validate (val_loader , model , criterion , args ) -> Tuple [float , float , "ValidationMetrics" ]:
@@ -635,7 +685,7 @@ def run_validate(loader, base_progress=0) -> ValidationMetrics:
635
685
636
686
labels_probs = torch .cat ([torch .stack (batch ) for batch in labels_probs ]).cpu ()
637
687
638
- return metrics_labels_true_pred (labels_true , labels_pred , labels_probs )
688
+ return calculate_validation_metrics (labels_true , labels_pred , labels_probs )
639
689
640
690
batch_time = AverageMeter ('Time' , ':6.3f' , Summary .NONE )
641
691
losses = AverageMeter ('Loss' , ':.4e' , Summary .NONE )
@@ -786,8 +836,29 @@ def accuracy(output, target, topk=(1,)):
786
836
return res
787
837
788
838
789
- def metrics_labels_true_pred (labels_true : np .array , labels_pred : np .array ,
790
- labels_probs : torch .Tensor ) -> ValidationMetrics :
839
def calculate_train_metrics(labels_true: np.array, labels_pred: np.array,
                            labels_probs: torch.Tensor) -> TrainMetrics:
    """Compute epoch-level classification metrics for the training split.

    Args:
        labels_true: 1-D array of ground-truth class ids.
        labels_pred: 1-D array of predicted class ids, same length as
            ``labels_true``.
        labels_probs: per-sample class probabilities; accepted for interface
            symmetry with ``calculate_validation_metrics`` but not used here.

    Returns:
        A ``TrainMetrics`` tuple with the distinct labels seen plus balanced
        accuracy and micro/macro-averaged F1, precision and recall.
    """
    # set() is the idiomatic form of {l for l in labels_true} (ruff C401).
    # NOTE: set iteration order (hence class_labels order) is unspecified,
    # matching the original behavior.
    unique_labels = list(set(labels_true))

    f1_micro = f1_score(labels_true, labels_pred, average="micro")
    f1_macro = f1_score(labels_true, labels_pred, average="macro")

    acc_balanced = balanced_accuracy_score(labels_true, labels_pred)
    prec_micro = precision_score(labels_true, labels_pred, average="micro")
    prec_macro = precision_score(labels_true, labels_pred, average="macro")
    rec_micro = recall_score(labels_true, labels_pred, average="micro")
    rec_macro = recall_score(labels_true, labels_pred, average="macro")

    return TrainMetrics(
        unique_labels,
        acc_balanced,
        f1_micro, f1_macro,
        prec_micro, prec_macro,
        rec_micro, rec_macro,
    )
858
+
859
+
860
+ def calculate_validation_metrics (labels_true : np .array , labels_pred : np .array ,
861
+ labels_probs : torch .Tensor ) -> ValidationMetrics :
791
862
unique_labels = list ({l for l in labels_true })
792
863
f1_per_class = f1_score (labels_true , labels_pred , average = None , labels = unique_labels )
793
864
f1_micro = f1_score (labels_true , labels_pred , average = "micro" )
0 commit comments