-
Notifications
You must be signed in to change notification settings - Fork 12
/
Recorder.py
104 lines (88 loc) · 3.99 KB
/
Recorder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import shutil
import time
import sys
class Recorder(object):
    """Snapshot the working directory's files and tee stdout into a log.

    On construction, the current directory (``./``) is copied into a
    timestamped ``code_<YYYYmmddHHMMSS>/`` folder under ``snapshot_pref``;
    afterwards everything printed to stdout is also appended to
    ``<snapshot_pref>/log.txt``.
    """

    def __init__(self, snapshot_pref, exclude_dirs=None, max_file_size=10):
        """
        :param snapshot_pref: Directory in which backups and logs are saved;
            created if it does not exist.
        :param exclude_dirs: Directory names to skip when copying code,
            e.g. ["results", "data"]. A directory is skipped when its path
            relative to the copied root contains one of these strings.
            ``None`` means nothing is excluded.
        :param max_file_size: Maximum size of a backed-up file, in MB; files
            of this size or larger are not copied.
        """
        date = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))
        os.makedirs(snapshot_pref, exist_ok=True)
        self.save_path = snapshot_pref
        # os.path.join so that a snapshot_pref without a trailing separator
        # still puts these files *inside* the directory.
        self.log_file = os.path.join(self.save_path, "log.txt")
        self.readme = os.path.join(self.save_path, "README.md")
        self.opt_file = os.path.join(self.save_path, "opt.log")
        self.code_path = os.path.join(self.save_path, "code_{}/".format(date))
        self.exclude_dirs = exclude_dirs
        self.max_file_size = max_file_size
        # Each run starts with an empty README (writereadme only appends).
        if os.path.isfile(self.readme):
            os.remove(self.readme)
        os.makedirs(self.code_path, exist_ok=True)
        self.copy_code(dst=self.code_path)
        self.tee_stdout(self.log_file)
        print("|===>Backups will be saved at", self.save_path)

    def copy_code(self, src="./", dst="./code/"):
        """Copy files under ``src`` into ``dst``, preserving relative layout.

        Skipped: directories whose path relative to ``src`` contains any
        string in ``self.exclude_dirs``; the ``dst`` directory itself (so a
        destination inside ``src`` is not copied into itself); ``.pyc`` files;
        and files of ``self.max_file_size`` MB or more. A file that cannot be
        read or copied is reported and skipped; the backup continues.

        :param src: Root directory to back up.
        :param dst: Destination directory; created as needed.
        """
        start_time = time.time()
        exclude_dirs = self.exclude_dirs or []
        max_bytes = self.max_file_size * 1024 * 1024
        src_abs = os.path.abspath(src)
        dst_abs = os.path.abspath(dst)

        def is_excluded(path):
            # Match against the path relative to src so that directory names
            # *above* src (e.g. /home/data/project) never exclude everything.
            rel = os.path.relpath(path, src_abs)
            if rel == os.curdir:
                rel = ""
            return any(rel.find(name) >= 0 for name in exclude_dirs)

        file_list = []
        for root, dirs, files in os.walk(src_abs):
            if root == dst_abs or is_excluded(root):
                # Prune: every descendant's relative path contains this one,
                # so it would be excluded as well; don't walk into it.
                dirs[:] = []
                continue
            file_list.extend(os.path.join(root, name) for name in files)

        for src_file in file_list:
            if os.path.splitext(src_file)[1] == ".pyc":
                continue
            dst_file = os.path.join(dst, os.path.relpath(src_file, src_abs))
            try:
                if os.path.getsize(src_file) >= max_bytes:
                    continue
                parent = os.path.dirname(dst_file)
                if parent:
                    os.makedirs(parent, exist_ok=True)
                shutil.copy2(src_file, dst_file)
            except OSError as err:
                # Best effort: one unreadable/broken file must not abort the
                # whole backup, but say which file failed and why.
                print("copy file error: {} ({})".format(src_file, err))
        print("|===>Backups using time: %.3f s" % (time.time() - start_time))

    def tee_stdout(self, log_path):
        """Duplicate everything written to ``sys.stdout`` into ``log_path``.

        The log is opened in append mode with line buffering and is never
        closed by this class. Each call wraps the current ``sys.stdout``, so
        calling it twice with the same path logs every line twice.
        """
        log_file = open(log_path, 'a', 1)
        stdout = sys.stdout

        class Tee(object):
            def write(self, string):
                log_file.write(string)
                return stdout.write(string)

            def flush(self):
                log_file.flush()
                stdout.flush()

            def __getattr__(self, name):
                # Delegate everything else (isatty, encoding, fileno, ...) to
                # the wrapped stdout so code probing it keeps working.
                return getattr(stdout, name)

        sys.stdout = Tee()

    def writeopt(self, opt):
        """Overwrite ``self.opt_file`` with one ``key: value`` line per
        attribute in ``opt.__dict__``."""
        with open(self.opt_file, "w") as f:
            for k, v in vars(opt).items():
                f.write(str(k) + ": " + str(v) + "\n")

    def writelog(self, input_data):
        """Append ``str(input_data)`` plus a newline to ``self.log_file``."""
        with open(self.log_file, 'a+') as txt_file:
            txt_file.write(str(input_data) + "\n")

    def writereadme(self, input_data):
        """Append ``str(input_data)`` plus a newline to ``self.readme``."""
        with open(self.readme, 'a+') as txt_file:
            txt_file.write(str(input_data) + "\n")

    def gennetwork(self, var):
        """Call ``self.graph.draw(var=var)``.

        NOTE(review): ``self.graph`` is never assigned in this class; it must
        be attached by the caller before use — confirm against callers.
        """
        self.graph.draw(var=var)

    def savenetwork(self):
        """Save ``self.graph`` to ``<save_path>/network.svg``.

        NOTE(review): requires ``self.graph`` to be attached externally.
        """
        self.graph.save(file_name=os.path.join(self.save_path, "network.svg"))

    # NOTE: disabled legacy helpers, kept as an inert string literal. They
    # depend on names (weight_folder, weight_fig_folder, DrawHistogram) that
    # are not defined in this class.
    """def writeweights(self, input_data, block_id, layer_id, epoch_id):
        txt_path = self.weight_folder + "conv_weight_" + str(epoch_id) + ".log"
        txt_file = open(txt_path, 'a+')
        write_str = "%d\t%d\t%d\t" % (epoch_id, block_id, layer_id)
        for x in input_data:
            write_str += str(x) + "\t"
        txt_file.write(write_str+"\n")
    def drawhist(self):
        drawer = DrawHistogram(txt_folder=self.weight_folder, fig_folder=self.weight_fig_folder)
        drawer.draw()"""