-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmain_test.py
69 lines (56 loc) · 1.92 KB
/
main_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from test import *
# Silence every warning category for the rest of this script.
# simplefilter("ignore") installs the same catch-all filter that
# filterwarnings("ignore") does (no message/module pattern, category=Warning).
import warnings
warnings.simplefilter("ignore")
### Initialize
opt = Args()
opt.set_graphtype('pivotMds_grid')
# opt.scale = 400
print(vars(opt))  # dump the full option set so the run's configuration is logged

### Modify the epoch number to the suitable value
# Epoch suffix of the checkpoint to load (it is part of the saved file name).
epoch = 1500
# Naming scheme, '-'-separated (model names may themselves contain '_'):
# model_execute = "ModelName-DatasetName-TrialID"
# model_execute = "GraphLSTM_pyg-grid_v2-demo1"
model_execute = "GraphLSTM_pyg-pivotMds_grid-demo1"
# Checkpoint path: <main_data_folder>model_save/model_<model_execute>_<epoch>.pkl
model_name = 'model_' + model_execute + '_' + str(epoch) + '.pkl'
model_file = opt.main_data_folder + 'model_save/' + model_name
# At most one of these two flags should be True: pick the one matching the
# framework variant the checkpoint was trained with.
opt.DGL_input = False  # Set this value as True if the trained model is GraphLSTM_dgl; otherwise, set it as False.
opt.PYG_input = True  # Set this value as True if the trained model is GraphLSTM_pyg; otherwise, set it as False.
# Output folder for this test run (wiped and recreated before inference).
folder_prefix = opt.main_data_folder + 'testing_results/'
# Inference options, forwarded unchanged to model_inference via its
# "test_params"; their exact semantics are defined in the `test` module.
bfs_order = False
scale_constant = 1
pred_scale_constant = 1
PA_corrected = True
# True: map every checkpoint storage onto the CPU when loading
# (useful when no GPU is available).
cpu_mode = False
Scale_corrected = True
### Read Model
# Load the checkpoint BEFORE touching the results folder: a wrong model name or
# a missing checkpoint then fails without deleting the previous testing results.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True, which
# rejects a fully pickled model object; if loading fails there, pass
# weights_only=False (the argument does not exist before PyTorch 1.13).
if cpu_mode:
    # Remap every storage onto the CPU; equivalent to the former
    # `map_location=lambda storage, loc: storage`.
    model = torch.load(model_file, map_location='cpu')
else:
    model = torch.load(model_file)

### Clean Folder
# Start from an empty results folder: drop any output of a previous run, then
# recreate it (exist_ok/parents handled by makedirs).
if os.path.exists(folder_prefix):
    shutil.rmtree(folder_prefix)
os.makedirs(folder_prefix, exist_ok=True)
# ## Copy model in testing_results
# shutil.copyfile(model_file,folder_prefix+model_name)
## Initialize Dataset
# getdataset returns three splits (presumably train / validation / test);
# only the test split is used by this script.
graph_dataset, valid_graph_dataset, test_graph_dataset = getdataset(opt)

## Begin test samples from test dataset
# Per-sample options for the inference routine: the output folder, the loaded
# model, and the flags configured at the top of the script.
test_params = dict(
    folder=folder_prefix + 'test_random/',
    model=model,
    opt=opt,
    cpu_mode=cpu_mode,
    bfs_order=bfs_order,
    Scale_corrected=Scale_corrected,
    PA_corrected=PA_corrected,
    scale_constant=scale_constant,
    pred_scale_constant=pred_scale_constant,
)
# max_samples presumably caps how many test samples are processed - confirm in model_inference.
model_testdataset_inference_params = dict(
    max_samples=24,
    dataset=test_graph_dataset,
    test_params=test_params,
)
model_inference(model_testdataset_inference_params)