-
Notifications
You must be signed in to change notification settings - Fork 3
/
generate_configs.py
56 lines (44 loc) · 2.04 KB
/
generate_configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import glob
import yaml
from cfgnode import CfgNode
def same_netconfig_for_diligent_datasets(template_file, new_exp_name, output_yml, data_path, objects):
    """Write one YAML config per object, all derived from a single template.

    The template is loaded once. For every name in ``objects`` the
    ``experiment.log_path`` and ``dataset.data_path`` entries are overwritten
    and the resulting config is dumped to ``<output_yml>/<name>.yml``.

    Args:
        template_file: Path of the YAML template to load.
        new_exp_name: Experiment name; the log path becomes
            ``./runs/<new_exp_name>/<name>``.
        output_yml: Directory receiving the generated ``.yml`` files. It is
            created (including parents) when it does not exist yet.
        data_path: Dataset root; the data path of each object becomes
            ``<data_path>/<name>PNG``.
        objects: Sequence of object names.
    """
    with open(template_file, "r") as f:
        # NOTE(review): the template is assumed to be a trusted local file;
        # yaml.safe_load would be the stricter choice for external input.
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
    cfg = CfgNode(cfg_dict)

    # open(..., 'w') does not create missing parent directories, so make sure
    # the target directory exists before writing the first file. An empty
    # path means "current directory" and needs no creation.
    if output_yml:
        os.makedirs(output_yml, exist_ok=True)

    for name in objects:
        # The same cfg object is reused; only the two per-object paths change.
        cfg.experiment.log_path = os.path.join('./runs', new_exp_name, name)
        cfg.dataset.data_path = os.path.join(data_path, name + 'PNG')
        with open(os.path.join(output_yml, name + '.yml'), 'w') as file:
            cfg.dump(stream=file)
def same_netconfig_for_other_datasets(template_file, new_exp_name, output_yml, data_path, objects):
    """Write one YAML config per object, all derived from a single template.

    The template is loaded once. For every name in ``objects`` the
    ``experiment.log_path`` and ``dataset.data_path`` entries are overwritten
    and the resulting config is dumped to ``<output_yml>/<name>.yml``.

    Args:
        template_file: Path of the YAML template to load.
        new_exp_name: Experiment name; the log path becomes
            ``./runs/<new_exp_name>/<name>``.
        output_yml: Directory receiving the generated ``.yml`` files. It is
            created (including parents) when it does not exist yet.
        data_path: Dataset root; the data path of each object becomes
            ``<data_path>/<name>`` (no suffix is appended to the name).
        objects: Sequence of object names.
    """
    with open(template_file, "r") as f:
        # NOTE(review): the template is assumed to be a trusted local file;
        # yaml.safe_load would be the stricter choice for external input.
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
    cfg = CfgNode(cfg_dict)

    # open(..., 'w') does not create missing parent directories, so make sure
    # the target directory exists before writing the first file. An empty
    # path means "current directory" and needs no creation.
    if output_yml:
        os.makedirs(output_yml, exist_ok=True)

    for name in objects:
        # The same cfg object is reused; only the two per-object paths change.
        cfg.experiment.log_path = os.path.join('./runs', new_exp_name, name)
        cfg.dataset.data_path = os.path.join(data_path, name)
        with open(os.path.join(output_yml, name + '.yml'), 'w') as file:
            cfg.dump(stream=file)
if __name__ == '__main__':
    # Both dataset families are generated from the same template file.
    template = 'configs/template.yml'

    # Objects whose data directories carry a 'PNG' suffix under pmsData.
    diligent_objects = [
        'ball', 'bear', 'buddha', 'cat', 'cow',
        'goblet', 'harvest', 'pot1', 'pot2', 'reading',
    ]
    # Objects whose data directories are named exactly like the object.
    apple_objects = ['apple', 'gourd1', 'gourd2']

    same_netconfig_for_diligent_datasets(
        template_file=template,
        new_exp_name='paper_config/diligent',
        output_yml='./configs/diligent',
        data_path='./data/DiLiGenT/pmsData',
        objects=diligent_objects,
    )

    same_netconfig_for_other_datasets(
        template_file=template,
        new_exp_name='paper_config/apple',
        output_yml='./configs/apple',
        data_path='./data/Apple_Dataset',
        objects=apple_objects,
    )

    print('done')