-
Notifications
You must be signed in to change notification settings - Fork 29
/
logger.py
67 lines (54 loc) · 1.89 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
* Copyright (C) 2019 Zhonghui You
* If you are using this code in your research, please cite the paper:
* Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks, in NeurIPS 2019.
"""
import torch
from config import cfg
import os
import json
import numpy as np
class MetricsRecorder:
    """Accumulate per-key metric observations and report their averages.

    ``add`` appends one observation for each key it is given; ``mean``
    returns the arithmetic mean (computed with ``np.mean``) of every key
    recorded so far.
    """

    def __init__(self):
        # key -> list of every value passed to add() for that key, in order
        self.rec = {}

    def add(self, pairs):
        """Append each value in the mapping ``pairs`` to its key's history."""
        for name, value in pairs.items():
            self.rec.setdefault(name, []).append(value)

    def mean(self):
        """Return a new dict mapping each recorded key to ``np.mean`` of its values."""
        return {name: np.mean(history) for name, history in self.rec.items()}
class Logger:
    """Persist per-epoch records, the run config, and model checkpoints.

    Everything is written under ``./logs/<cfg.base.task_name>``:

    * ``log.json``  -- mapping ``str(epoch) -> record`` (see ``save_record``)
    * ``cfg.json``  -- the current config as returned by ``cfg.raw()``
    * ``ckp.<epoch>.torch`` -- model ``state_dict`` checkpoints

    JSON files are replaced atomically (write to a temp file, then rename),
    so a failed write never leaves a truncated or corrupt file behind.
    """

    def __init__(self):
        self.base_path = './logs/' + cfg.base.task_name
        self.logfile = os.path.join(self.base_path, 'log.json')
        self.cfgfile = os.path.join(self.base_path, 'cfg.json')
        os.makedirs(self.base_path, exist_ok=True)
        # Keep an existing log (e.g. from an earlier run with the same task
        # name) and only initialise it when missing, so save_record always
        # finds a file to extend.
        if not os.path.isfile(self.logfile):
            self._write_json(self.logfile, {})
        self._dump_cfg()

    @staticmethod
    def _write_json(path, obj):
        """Serialise ``obj`` as JSON to ``path`` atomically.

        The data goes to ``<path>.tmp`` first and is moved over ``path`` with
        ``os.replace`` only after serialisation succeeded; if ``json.dump``
        raises (e.g. TypeError for a non-serialisable value), ``path`` is left
        untouched and the exception propagates.
        """
        tmp_path = path + '.tmp'
        with open(tmp_path, 'w') as fp:
            json.dump(obj, fp)
        os.replace(tmp_path, path)

    def _dump_cfg(self):
        """Write the current config (``cfg.raw()``) to ``cfg.json``."""
        self._write_json(self.cfgfile, cfg.raw())

    def save_record(self, epoch, record):
        """Store ``record`` under key ``str(epoch)`` in ``log.json``.

        An existing entry for the same epoch is overwritten. ``record`` must be
        JSON-serialisable; otherwise a TypeError is raised and the previously
        logged epochs remain intact on disk.
        """
        try:
            with open(self.logfile) as fp:
                log = json.load(fp)
        except FileNotFoundError:
            # Log removed or never created: start a fresh one instead of crashing.
            log = {}
        log[str(epoch)] = record
        self._write_json(self.logfile, log)

    def save_network(self, epoch, network):
        """Save ``network``'s weights as ``ckp.<epoch>.torch`` and record it in cfg.

        ``(Distributed)DataParallel`` wrappers are unwrapped first so the saved
        ``state_dict`` keys carry no ``module.`` prefix. Afterwards
        ``cfg.base.epoch`` and ``cfg.base.checkpoint_path`` are updated (this
        mutates the shared ``cfg`` object) and ``cfg.json`` is rewritten.
        """
        saving_path = os.path.join(self.base_path, 'ckp.%d.torch' % epoch)
        print('saving model ...')
        wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
        model = network.module if isinstance(network, wrappers) else network
        torch.save(model.state_dict(), saving_path)
        cfg.base.epoch = epoch
        cfg.base.checkpoint_path = saving_path
        self._dump_cfg()
# Module-level Logger instance shared by everything that imports this module.
# It is created at import time, so importing this module reads
# cfg.base.task_name, creates ./logs/<task_name>/ and writes cfg.json there.
# (The previous `logger = None` / `if logger is None:` guard was always true,
# since the assignment directly preceded the check, and has been dropped.)
logger = Logger()