This repository has been archived by the owner on May 23, 2024. It is now read-only.
forked from soap117/DeepRule
test_chart.py
106 lines (85 loc) · 3.6 KB
#!/usr/bin/env python
import os
import json
import torch
import pprint
import argparse
import importlib

import matplotlib
matplotlib.use("Agg")

from config import system_configs
from nnet.py_factory import NetworkFactory
from db.datasets import datasets

torch.backends.cudnn.benchmark = False


def parse_args():
    parser = argparse.ArgumentParser(description="Test CornerNet")
    parser.add_argument("--cfg_file", dest="cfg_file",
                        help="config file", default="CornerNetLine", type=str)
    parser.add_argument("--testiter", dest="testiter",
                        help="test at iteration i",
                        default=50000, type=int)
    parser.add_argument("--split", dest="split",
                        help="which split to use",
                        default="validation", type=str)
    parser.add_argument("--cache_path", dest="cache_path", type=str)
    parser.add_argument("--result_path", dest="result_path", type=str)
    parser.add_argument("--tar_data_path", dest="tar_data_path", type=str)
    parser.add_argument("--suffix", dest="suffix", default=None, type=str)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--data_dir", dest="data_dir",
                        default="c:/work/linedata(1023)", type=str)
    args = parser.parse_args()
    return args


def make_dirs(directories):
    for directory in directories:
        if not os.path.exists(directory):
            os.makedirs(directory)


def test(db, split, testiter, debug=False, suffix=None):
    with torch.no_grad():
        # Results are written to <result_dir>/<testiter>/<split>[/<suffix>].
        result_dir = system_configs.result_dir
        result_dir = os.path.join(result_dir, str(testiter), split)
        if suffix is not None:
            result_dir = os.path.join(result_dir, suffix)
        make_dirs([result_dir])

        test_iter = system_configs.max_iter if testiter is None else testiter
        print("loading parameters at iteration: {}".format(test_iter))

        print("building neural network...")
        nnet = NetworkFactory(db)
        print("loading parameters...")
        nnet.load_params(test_iter)

        # Resolve the testing routine that matches the config name, e.g.
        # testfile/test_CornerNetLine.py for --cfg_file CornerNetLine
        # (`args` is the module-level variable set in __main__ below).
        # The imported module must expose testing(db, nnet, result_dir, debug=...);
        # a sketch of that interface follows the script.
        path = "testfile.test_%s" % args.cfg_file
        testing = importlib.import_module(path).testing

        nnet.cuda()
        nnet.eval_mode()
        testing(db, nnet, result_dir, debug=debug)


if __name__ == "__main__":
    args = parse_args()

    # Resolve the JSON config file for this run, e.g. config/CornerNetLine.json
    # or config/CornerNetLine-<suffix>.json when --suffix is given.
    if args.suffix is None:
        cfg_file = os.path.join(system_configs.config_dir, args.cfg_file + ".json")
    else:
        cfg_file = os.path.join(system_configs.config_dir, args.cfg_file + "-{}.json".format(args.suffix))
    print("cfg_file: {}".format(cfg_file))

    with open(cfg_file, "r") as f:
        configs = json.load(f)

    # Override path-related settings with the command-line arguments.
    configs["system"]["snapshot_name"] = args.cfg_file
    configs["system"]["data_dir"] = args.data_dir
    configs["system"]["cache_dir"] = args.cache_path
    configs["system"]["result_dir"] = args.result_path
    configs["system"]["tar_data_dir"] = args.tar_data_path
    system_configs.update_config(configs["system"])

    # Map the --split argument onto the split names defined in the config.
    train_split = system_configs.train_split
    val_split = system_configs.val_split
    test_split = system_configs.test_split
    split = {
        "training": train_split,
        "validation": val_split,
        "testing": test_split
    }[args.split]

    print("loading all datasets...")
    dataset = system_configs.dataset
    print("split: {}".format(split))
    testing_db = datasets[dataset](configs["db"], split)

    print("system config...")
    pprint.pprint(system_configs.full)

    print("db config...")
    pprint.pprint(testing_db.configs)

    test(testing_db, args.split, args.testiter, args.debug, args.suffix)
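
The testing routine is resolved dynamically as testfile.test_<cfg_file>, so each model variant ships its own entry point. The sketch below is a hypothetical minimal module satisfying that contract; only the call testing(db, nnet, result_dir, debug=debug) is taken from test_chart.py itself, while the file name and the body are placeholders.

# testfile/test_CornerNetLine.py -- hypothetical minimal entry point.
# Only the signature is implied by test_chart.py; the body is a placeholder.
import os
import json


def testing(db, nnet, result_dir, debug=False):
    # A real DeepRule test module would iterate over the images in `db`,
    # run the network on each one and decode chart components; this stub
    # only illustrates the call contract and writes an empty result file.
    results = {}
    with open(os.path.join(result_dir, "results.json"), "w") as f:
        json.dump(results, f)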
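
For reference, here is a hedged sketch of the config shape the script relies on: the keys listed are exactly the ones it reads or overwrites, while the values and the dataset name are assumptions, since real DeepRule configs carry many more fields.

# Minimal shape of config/CornerNetLine.json as consumed by test_chart.py.
# Only keys the script touches are shown; the values are illustrative
# placeholders, not the real DeepRule defaults.
example_config = {
    "system": {
        "dataset": "Cls",            # placeholder; must match a key in db/datasets.py
        "train_split": "train",      # placeholder split names
        "val_split": "val",
        "test_split": "test",
        "max_iter": 50000,           # fallback iteration when testiter is None
        # snapshot_name, data_dir, cache_dir, result_dir and tar_data_dir
        # are overwritten from the command line in __main__.
    },
    "db": {
        # dataset-specific options, passed verbatim to datasets[dataset](...)
    },
}

A hypothetical invocation would then be: python test_chart.py --cfg_file CornerNetLine --split testing --cache_path ./cache --result_path ./results --tar_data_path ./tar --data_dir ./linedata (all paths illustrative).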