-
Notifications
You must be signed in to change notification settings - Fork 55
/
eval_tools.py
149 lines (122 loc) · 5.28 KB
/
eval_tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import csv
import json
import time
from collections import OrderedDict, defaultdict
import torch
from mivolo.data.misc import cumulative_error, cumulative_score
from timm.utils import AverageMeter, accuracy
def time_sync():
    """Return ``time.time()`` taken after all queued CUDA work has finished.

    CUDA kernels are launched asynchronously, so a plain timestamp could be
    taken while GPU work is still pending. When no CUDA device is available,
    this is just ``time.time()``.
    """
    if not torch.cuda.is_available():
        return time.time()
    # Wait for outstanding GPU kernels before reading the clock.
    torch.cuda.synchronize()
    return time.time()
def write_results(results_file, results, format="csv"):
    """Write evaluation results to ``results_file`` as JSON or CSV.

    The destination file is always created (or truncated).

    Args:
        results_file: destination path.
        results: for ``format == "json"``, any JSON-serializable object.
            Otherwise (CSV), a single mapping (written as one row) or a
            list/tuple of mappings; the CSV header is taken from the keys of
            the first row.
        format: ``"json"`` writes indented JSON; any other value writes CSV.
            (The name shadows the builtin but is kept for caller compatibility.)

    In CSV mode an empty list/tuple leaves an empty file.
    """
    is_json = format == "json"
    # The csv module must control line endings itself (newline=""); otherwise
    # its "\r\n" row terminator is translated to "\r\r\n" on Windows, which
    # produces blank rows between records.
    with open(results_file, mode="w", newline=None if is_json else "") as cf:
        if is_json:
            json.dump(results, cf, indent=4)
            return

        rows = results if isinstance(results, (list, tuple)) else [results]
        if not rows:
            return

        writer = csv.DictWriter(cf, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
class Metrics:
def __init__(self, l_for_cs, draw_hist, age_classes=None):
self.batch_time = AverageMeter()
self.preproc_batch_time = AverageMeter()
self.seen = 0
self.losses = AverageMeter()
self.top1_m_gender = AverageMeter()
self.top1_m_age = AverageMeter()
if age_classes is None:
self.is_regression = True
self.av_csl_age = AverageMeter()
self.max_error = AverageMeter()
self.per_age_error = defaultdict(list)
self.l_for_cs = l_for_cs
else:
self.is_regression = False
self.draw_hist = draw_hist
def update_regression_age_metrics(self, age_out, age_target):
batch_size = age_out.size(0)
age_abs_err = torch.abs(age_out - age_target)
age_acc1 = torch.sum(age_abs_err) / age_out.shape[0]
age_csl = cumulative_score(age_out, age_target, self.l_for_cs)
me = cumulative_error(age_out, age_target, 20)
self.top1_m_age.update(age_acc1.item(), batch_size)
self.av_csl_age.update(age_csl.item(), batch_size)
self.max_error.update(me.item(), batch_size)
if self.draw_hist:
for i in range(age_out.shape[0]):
self.per_age_error[int(age_target[i].item())].append(age_abs_err[i].item())
def update_age_accuracy(self, age_out, age_target):
batch_size = age_out.size(0)
if batch_size == 0:
return
correct = torch.sum(age_out == age_target)
age_acc1 = correct * 100.0 / batch_size
self.top1_m_age.update(age_acc1.item(), batch_size)
def update_gender_accuracy(self, gender_out, gender_target):
if gender_out is None or gender_out.size(0) == 0:
return
batch_size = gender_out.size(0)
gender_acc1 = accuracy(gender_out, gender_target, topk=(1,))[0]
if gender_acc1 is not None:
self.top1_m_gender.update(gender_acc1.item(), batch_size)
def update_loss(self, loss, batch_size):
self.losses.update(loss.item(), batch_size)
def update_time(self, process_time, preprocess_time, batch_size):
self.seen += batch_size
self.batch_time.update(process_time)
self.preproc_batch_time.update(preprocess_time)
def get_info_str(self, batch_size):
avg_time = (self.preproc_batch_time.sum + self.batch_time.sum) / self.batch_time.count
cur_time = self.batch_time.val + self.preproc_batch_time.val
middle_info = (
"Time: {cur_time:.3f}s ({avg_time:.3f}s, {rate_avg:>7.2f}/s) "
"Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) "
"Gender Acc: {top1gender.val:>7.2f} ({top1gender.avg:>7.2f}) ".format(
cur_time=cur_time,
avg_time=avg_time,
rate_avg=batch_size / avg_time,
loss=self.losses,
top1gender=self.top1_m_gender,
)
)
if self.is_regression:
age_info = (
"Age CS@{l_for_cs}: {csl.val:>7.4f} ({csl.avg:>7.4f}) "
"Age CE@20: {max_error.val:>7.4f} ({max_error.avg:>7.4f}) "
"Age ME: {top1age.val:>7.2f} ({top1age.avg:>7.2f})".format(
top1age=self.top1_m_age, csl=self.av_csl_age, max_error=self.max_error, l_for_cs=self.l_for_cs
)
)
else:
age_info = "Age Acc: {top1age.val:>7.2f} ({top1age.avg:>7.2f})".format(top1age=self.top1_m_age)
return middle_info + age_info
def get_result(self):
age_top1a = self.top1_m_age.avg
gender_top1 = self.top1_m_gender.avg if self.top1_m_gender.count > 0 else None
mean_per_image_time = self.batch_time.sum / self.seen
mean_preprocessing_time = self.preproc_batch_time.sum / self.seen
results = OrderedDict(
mean_inference_time=mean_per_image_time * 1e3,
mean_preprocessing_time=mean_preprocessing_time * 1e3,
agetop1=round(age_top1a, 4),
agetop1_err=round(100 - age_top1a, 4),
)
if self.is_regression:
results.update(
dict(
max_error=self.max_error.avg,
csl=self.av_csl_age.avg,
per_age_error=self.per_age_error,
)
)
if gender_top1 is not None:
results.update(dict(gendertop1=round(gender_top1, 4), gendertop1_err=round(100 - gender_top1, 4)))
return results