-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
101 lines (79 loc) · 3.7 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score
def calc_multi_cls_measures(probs, label):
    """Calculate multi-class classification measures (accuracy and
    per-class precision/recall).

    :probs: NxC numpy array storing per-class probabilities, one row per case
    :label: 1-D array-like of ground-truth class indices
    :returns: a dictionary with keys 'accuracy' (float), 'precisions' and
        'recalls' (numpy arrays of length C, one entry per class; classes
        with no predicted/true samples score 0 rather than raising a warning)
    """
    n_classes = probs.shape[1]
    # Predicted class is the arg-max over the class (column) axis.
    preds = np.argmax(probs, axis=1)
    accuracy = accuracy_score(label, preds)
    # average=None keeps one score per class; labels pins the output length
    # to C even when some classes are absent from this batch.
    # zero_division=0 (int sentinel, not the float 0.) silences the
    # undefined-metric warning and reports 0 for empty classes.
    precisions = precision_score(label, preds, average=None,
                                 labels=range(n_classes), zero_division=0)
    recalls = recall_score(label, preds, average=None,
                           labels=range(n_classes), zero_division=0)
    metric_collects = {'accuracy': accuracy, 'precisions': precisions,
                       'recalls': recalls}
    return metric_collects
def print_progress(epoch=None, n_epoch=None, n_iter=None, iters_one_batch=None,
                   mean_loss=None, cur_lr=None, metric_collects=None,
                   prefix=None):
    """Print one line of training progress; every field is optional.

    :epoch: epoch number
    :n_epoch: total number of epochs
    :n_iter: current iteration number
    :iters_one_batch: number of iterations per batch
    :mean_loss: mean loss of current batch
    :cur_lr: current learning rate
    :metric_collects: dictionary returned by calc_multi_cls_measures; if
        None, the accuracy/precision/recall section is skipped (previously
        this default crashed with a TypeError)
    :prefix: optional string prepended to the whole line
    :returns: None
    """
    parts = []
    if epoch is not None:
        parts.append('Ep: {0}/{1}|'.format(epoch, n_epoch))
    if n_iter is not None:
        parts.append('It: {0}/{1}|'.format(n_iter, iters_one_batch))
    if mean_loss is not None:
        parts.append('Loss: {0:.4f}|'.format(mean_loss))
    if metric_collects is not None:
        accuracy = metric_collects['accuracy']
        precisions = metric_collects['precisions']
        recalls = metric_collects['recalls']
        n_classes = len(precisions)
        parts.append('Acc: {:.4f}|'.format(accuracy))
        # Class 0 (background) is deliberately excluded from the report,
        # hence the [1:] slices and the n_classes - 1 format slots.
        templ = 'Pr: ' + ', '.join(['{:.4f}'] * (n_classes - 1)) + '|'
        parts.append(templ.format(*(precisions[1:].tolist())))
        templ = 'Re: ' + ', '.join(['{:.4f}'] * (n_classes - 1)) + '|'
        parts.append(templ.format(*(recalls[1:].tolist())))
    if cur_lr is not None:
        parts.append('lr: {0}'.format(cur_lr))
    log_str = ''.join(parts)
    log_str = log_str if prefix is None else prefix + log_str
    print(log_str)
def print_epoch_progress(train_loss, val_loss, time_duration, train_metric,
                         val_metric):
    """Print a one-line train/validation summary after each epoch.

    :train_loss: average training loss
    :val_loss: average validation loss
    :time_duration: wall-clock time (seconds) for the current epoch
    :train_metric: performance dictionary for training
        (see calc_multi_cls_measures)
    :val_metric: performance dictionary for validation
    :returns: None
    """
    def joined(values):
        # Render a sequence of scores as '0.1234, 0.5678, ...'.
        return ', '.join('{:.4f}'.format(v) for v in values)

    pieces = [
        'Train/Val| Loss: {:.4f}/{:.4f}|'.format(train_loss, val_loss),
        'Acc: {:.4f}/{:.4f}|'.format(train_metric['accuracy'],
                                     val_metric['accuracy']),
        # Class 0 is excluded from the per-class report ([1:] slices).
        'Pr: ' + joined(train_metric['precisions'][1:]) + '/'
               + joined(val_metric['precisions'][1:]) + '|',
        'Re: ' + joined(train_metric['recalls'][1:]) + '/'
               + joined(val_metric['recalls'][1:]) + '|',
        'T(s) {:.2f}'.format(time_duration),
    ]
    print(''.join(pieces))