-
Notifications
You must be signed in to change notification settings - Fork 1
/
script_margin_study.py
104 lines (76 loc) · 2.89 KB
/
script_margin_study.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
from script_run_train import run_script
import yaml
from datetime import datetime
from utils.utils import dump_info
def statup_session(**arg):
    """Prepare the network and session configs for one run, then launch it.

    Keyword arguments (all read from ``arg``):
        session: name of the source session YAML under ``sessions/``.
        model:   name of the source model YAML under ``model_cfg/``.
        margin:  value written to ``loss_function.margin`` of the session copy.
        results: results file name forwarded to ``run_script``.
        cmd:     command/script forwarded to ``run_script``.
        plot:    plot flag forwarded to ``run_script`` (optional, defaults to 0).

    Raises:
        KeyError: if any of session/model/margin/results/cmd is missing.
    """
    session = arg['session']
    model = arg['model']
    margin = arg['margin']
    results = arg['results']
    cmd = arg['cmd']
    # Read ``plot`` from the caller's kwargs instead of relying on a module
    # global that only exists when this file is executed as a script.
    plot = arg.get('plot', 0)
    print("[SCRIPT FILE] " + model)
    # The run always uses working copies written under these fixed names.
    dest_network = 'network'
    dest_session = 'session'
    network_cfg(model, dest_network)
    set_margin(session, dest_session, margin)
    run_script(cmd=cmd,
               model=dest_network,
               session=dest_session,
               plot=plot,
               results=results)
def set_margin(src, dest, value):
    """Write a copy of session config *src* with a new loss margin as *dest*.

    Reads ``sessions/<src>.yaml``, sets ``loss_function.margin`` to *value*,
    applies a fixed set of train/test overrides used by this study, and
    writes the result to ``sessions/<dest>.yaml`` (an existing file of that
    name is removed first).

    Args:
        src: base name of the source session file (no extension).
        dest: base name of the session file to write (no extension).
        value: margin value stored under ``loss_function.margin``.

    Returns:
        The modified session dict.
    """
    dest_path = 'sessions/' + dest + '.yaml'
    src_path = 'sessions/' + src + '.yaml'
    if os.path.isfile(dest_path):
        os.remove(dest_path)
    # Use a context manager so the source file handle is closed promptly.
    with open(src_path) as src_file:
        session = yaml.load(src_file, Loader=yaml.FullLoader)
    session['loss_function']['margin'] = value
    # Fixed overrides applied to every run of the margin study.
    session['train']['max_epochs'] = 21
    session['train']['report_val'] = 3
    session['train']['fraction'] = round(1/5, 2)   # 0.2
    session['test']['fraction'] = round(1/3, 2)    # 0.33
    with open(dest_path, 'w') as dest_file:
        yaml.dump(session, dest_file)
    return session
def network_cfg(src_model, dest_model):
    """Write a fully-trainable copy of model config *src_model* as *dest_model*.

    Reads ``model_cfg/<src_model>.yaml``, sets ``train = True`` on its
    ``backbone``, ``attention`` and ``outlayer`` sections, and writes the
    result to ``model_cfg/<dest_model>.yaml`` (an existing file of that name
    is removed first).

    Args:
        src_model: base name of the source model config (no extension).
        dest_model: base name of the model config to write (no extension).

    Returns:
        The modified network config dict.
    """
    src_path = 'model_cfg/' + src_model + '.yaml'
    dest_path = 'model_cfg/' + dest_model + '.yaml'
    if os.path.isfile(dest_path):
        os.remove(dest_path)
    # Load from the ``src_model`` argument (not a module-level ``model``
    # global, which only exists when this file runs as a script), and close
    # the file handle when done.
    with open(src_path) as src_file:
        network = yaml.load(src_file, Loader=yaml.FullLoader)
    for section in ('backbone', 'attention', 'outlayer'):
        network[section]['train'] = True
    with open(dest_path, 'w') as dest_file:
        yaml.dump(network, dest_file)
    return network
if __name__ == '__main__':
    import sys

    # Margin study: for every cross-validation sequence, model variant and
    # margin value, run CMD once; results are appended to ``results``.
    TYPE_ = 'cross_val'
    CMD = 'train_knn.py'
    SEQUENCES = ['ex0', 'ex2', 'ex5', 'ex6', 'ex8']
    models = ['1bb_1a_norm', '2bb_1a_norm', '3bb_1a_norm', '4bb_1a_norm', '5bb_1a_norm']
    margin_array = [0.0, 0.3, 0.5, 0.7, 0.8, 0.85, 0.9, 0.95]
    plot = 0
    results = 'margin_tunning.txt'
    for ex in SEQUENCES:
        dump_info(results, "", flag='a')
        # The session name depends only on the sequence, e.g. 'ex5' -> 'cross_val_05',
        # so compute it once per sequence rather than once per run.
        s = '%02d' % (int(ex[-1]))
        session = TYPE_ + '_' + s
        for model in models:
            for margin in margin_array:
                try:
                    statup_session(
                        cmd=CMD,
                        model=model,
                        session=session,
                        results=results,
                        margin=margin,
                        plot=plot)
                except KeyboardInterrupt:
                    print("[WRN] Exiting APP")
                    # sys.exit is always available; the bare exit() helper
                    # is injected by `site` and may be missing (e.g. python -S).
                    sys.exit(0)
        print("***********************************************")