-
Notifications
You must be signed in to change notification settings - Fork 0
/
logs.py
67 lines (55 loc) · 1.82 KB
/
logs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import pickle
# Filename prefixes for saved model state dicts, one per network variant
# (used by generate_model_checkpoint_name).
RGB_PREFIX = 'model_rgb_state_dict'
FLOW_PREFIX = 'model_flow_state_dict'
JOINT_PREFIX = 'model_twoStream_state_dict'

# Filename prefixes for the training / validation log files
# (used by generate_log_filenames).
LOG_PREFIX = 'log_stage'
VAL_LOG_PREFIX = 'val_log_stage'
class Logger:
    """Accumulate per-epoch and per-step metrics and persist them with pickle.

    Keyword arguments given to the constructor are stored unchanged in
    ``self.params``. Epoch and step records are appended to ``self.data`` and
    ``self.step_data`` respectively, each record being a single-key dict.
    """

    def __init__(self, **params):
        self.params = params
        # One {epoch: (acc, loss[, dual_loss])} dict per add_epoch_data call.
        self.data = []
        # One {step: (acc, loss)} dict per add_step_data call.
        self.step_data = []

    def add_epoch_data(self, epoch, acc, loss, dual_loss=False):
        """Append ``{epoch: metrics}`` to ``self.data``.

        ``metrics`` is ``(acc, loss)`` when ``dual_loss`` is left at its
        default, otherwise ``(acc, loss, dual_loss)``.
        """
        # The literal ``False`` object is the "not supplied" marker (identity
        # check), so values such as 0 or None are still recorded.
        metrics = (acc, loss) if dual_loss is False else (acc, loss, dual_loss)
        self.data.append({epoch: metrics})

    def add_step_data(self, step, acc, loss):
        """Append ``{step: (acc, loss)}`` to ``self.step_data``."""
        record = {step: (acc, loss)}
        self.step_data.append(record)

    def save(self, path):
        """Pickle this whole Logger instance to the file at ``path``."""
        with open(path, 'wb') as handle:
            pickle.dump(self, handle)

    @classmethod
    def load(cls, path):
        """Return the object unpickled from ``path`` (no type check is done).

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        with open(path, 'rb') as handle:
            return pickle.load(handle)
def generate_model_checkpoint_name(stage, n_frames, ms_block=False, loss=None, optional=''):
    """Return the ``.pth`` filename for a model checkpoint.

    The prefix depends on ``stage``: below 3 uses RGB_PREFIX (with
    ``'_stage2'`` appended when stage is exactly 2), 3 uses FLOW_PREFIX, and
    anything else uses JOINT_PREFIX. It is followed by ``_<n_frames>frames``,
    then ``_<loss>`` if ``loss`` is given, ``_msblock`` if ``ms_block`` is
    truthy, and finally ``optional`` plus the ``.pth`` extension.
    """
    if stage < 3:
        prefix = RGB_PREFIX + ('_stage2' if stage == 2 else '')
    elif stage == 3:
        prefix = FLOW_PREFIX
    else:
        prefix = JOINT_PREFIX

    # ``!s`` forces str() conversion, matching explicit str(n_frames).
    parts = [prefix, f"_{n_frames!s}frames"]
    if loss is not None:
        # Plain concatenation keeps the requirement that ``loss`` is a str.
        parts.append('_' + loss)
    if ms_block:
        parts.append('_msblock')
    parts.append(optional + ".pth")
    return ''.join(parts)
def generate_log_filenames(stage, n_frames, ms_block=False, loss=None, optional=''):
    """Return ``(train_name, val_name)`` filenames for the ``.obj`` log files.

    Both names share one suffix, ``<stage>_<n_frames>frames``, optionally
    followed by ``_<loss>`` and ``_msblock``, then ``optional`` and ``.obj``.
    They differ only in their prefix: LOG_PREFIX for training and
    VAL_LOG_PREFIX for validation.
    """
    # ``!s`` forces str() conversion, matching explicit str(stage)/str(n_frames).
    suffix = f"{stage!s}_{n_frames!s}frames"
    if loss is not None:
        # Plain concatenation keeps the requirement that ``loss`` is a str.
        suffix += '_' + loss
    if ms_block:
        suffix += '_msblock'
    suffix += optional + ".obj"
    return LOG_PREFIX + suffix, VAL_LOG_PREFIX + suffix