-
Notifications
You must be signed in to change notification settings - Fork 15
/
logger.py
98 lines (89 loc) · 3.78 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from collections import defaultdict
from collections.abc import Iterable
from numbers import Number

import torch
from torch.utils.tensorboard import SummaryWriter

from utils import ntuple
class Logger:
    """Accumulate running metric means and optionally mirror them to TensorBoard.

    Metrics are stored under names of the form ``'<tag>/<key>'``.

    Attributes:
        log_path: Path handed to ``SummaryWriter`` when writing is enabled.
        writer: The open ``SummaryWriter``, or ``None`` when writing is off.
        tracker: Most recent raw value recorded for each name.
        counter: Accumulated weight per name (a number, or a list of numbers
            for sequence-valued metrics).
        mean: Running weighted mean per name (same shape as ``counter``).
        history: Per-name list of means snapshotted by ``safe(False)``.
        iterator: Per-name step counter used as the TensorBoard global step.
        hist: Per-name samples collected for names containing ``'local'``
            (case-insensitive), used for histograms.
    """

    def __init__(self, log_path):
        self.log_path = log_path
        self.writer = None
        self.tracker = defaultdict(int)
        self.counter = defaultdict(int)
        self.mean = defaultdict(int)
        self.history = defaultdict(list)
        self.iterator = defaultdict(int)
        self.hist = defaultdict(list)

    def safe(self, write):
        """Switch TensorBoard writing on or off.

        Args:
            write: If truthy, open a ``SummaryWriter`` on ``log_path``.
                Otherwise close the current writer (if any) and append the
                current mean of every metric to ``history``.
        """
        if write:
            # Close a writer left open by an earlier safe(True) so it is not
            # leaked when the reference is overwritten.
            if self.writer is not None:
                self.writer.close()
            self.writer = SummaryWriter(self.log_path)
            return
        if self.writer is not None:
            self.writer.close()
            self.writer = None
        for name, value in self.mean.items():
            # Sequence means are updated in place by append(); store a copy so
            # the snapshot cannot change afterwards.
            snapshot = list(value) if isinstance(value, list) else value
            self.history[name].append(snapshot)

    def reset(self):
        """Discard tracked values, weights, means and histogram samples.

        ``history`` and ``iterator`` are intentionally left untouched.
        """
        self.tracker = defaultdict(int)
        self.counter = defaultdict(int)
        self.mean = defaultdict(int)
        self.hist = defaultdict(list)

    def append(self, result, tag, n=1, mean=True):
        """Record the values in ``result`` and update their running means.

        Args:
            result: Mapping from metric key to value. Values are numbers or
                sized, indexable sequences of numbers when ``mean`` is true.
            tag: Prefix used to build each metric name as ``'<tag>/<key>'``.
            n: Weight of this observation in the running mean. For sequence
                values it is expanded per element through ``ntuple``.
            mean: If false, only ``tracker`` is updated.

        Raises:
            ValueError: If ``mean`` is true and a value is neither a number
                nor an iterable.
        """
        for key in result:
            value = result[key]
            name = '{}/{}'.format(tag, key)
            self.tracker[name] = value
            if not mean:
                continue
            collect_samples = 'local' in name.lower()
            if isinstance(value, Number):
                self.counter[name] += n
                if collect_samples:
                    self.hist[name].append(value)
                total = self.counter[name]
                self.mean[name] = ((total - n) * self.mean[name] + n * value) / total
            elif isinstance(value, Iterable):
                size = len(value)
                if name not in self.mean:
                    self.counter[name] = [0] * size
                    self.mean[name] = [0] * size
                # Keep the expanded weights in a local: rebinding ``n`` here
                # would leak a tuple into the next key of ``result``.
                weights = ntuple(size)(n)
                counts = self.counter[name]
                means = self.mean[name]
                for i in range(size):
                    counts[i] += weights[i]
                    if collect_samples:
                        # NOTE(review): the scalar branch stores the metric
                        # value, this one stores the weight. Behavior kept as
                        # found; confirm whether value[i] was intended.
                        self.hist[name].append(weights[i])
                    means[i] = ((counts[i] - weights[i]) * means[i]
                                + weights[i] * value[i]) / counts[i]
            else:
                raise ValueError('Not valid data type')

    def write(self, tag, metric_names):
        """Print a summary line for ``tag`` and mirror it to TensorBoard.

        The formatted metrics are inserted at position 2 of the list tracked
        under ``'<tag>/info'``; the joined line is printed and, when a writer
        is open, logged as text. Each metric mean is also logged as a scalar
        (first element only for sequence-valued means).

        Args:
            tag: Prefix of the metric names to report.
            metric_names: Metric keys (without the tag prefix) to report.

        Raises:
            ValueError: If a stored mean is neither a number nor an iterable.
        """
        evaluation_info = []
        for metric in metric_names:
            # Build the name from its parts instead of splitting it back, so
            # tags or metric keys containing '/' are handled.
            name = '{}/{}'.format(tag, metric)
            value = self.mean[name]
            if isinstance(value, Number):
                evaluation_info.append('{}: {:.4f}'.format(metric, value))
                if self.writer is not None:
                    self.iterator[name] += 1
                    self.writer.add_scalar(name, value, self.iterator[name])
            elif isinstance(value, Iterable):
                value = tuple(value)
                evaluation_info.append('{}: {}'.format(metric, value))
                if self.writer is not None:
                    self.iterator[name] += 1
                    self.writer.add_scalar(name, value[0], self.iterator[name])
                    # NOTE(review): histograms are only emitted for
                    # sequence-valued metrics, as in the original code.
                    if 'local' in name.lower():
                        samples = self.hist[name]
                        if samples:
                            # add_histogram expects a tensor/ndarray, not a
                            # plain Python list.
                            self.writer.add_histogram(f'{name}_hist', torch.as_tensor(samples),
                                                      self.iterator[name])
            else:
                raise ValueError('Not valid data type')
        info_name = '{}/info'.format(tag)
        # Work on a copy: splicing into the tracked list would mutate the
        # caller's data and duplicate the metrics on a repeated write().
        info_parts = list(self.tracker[info_name]) if info_name in self.tracker else []
        info_parts[2:2] = evaluation_info
        info = ' '.join(info_parts)
        print(info)
        if self.writer is not None:
            self.iterator[info_name] += 1
            self.writer.add_text(info_name, info, self.iterator[info_name])

    def flush(self):
        """Flush pending TensorBoard events; a no-op when no writer is open."""
        if self.writer is not None:
            self.writer.flush()