-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
83 lines (66 loc) · 2.91 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import sys
sys.path.append('./src')
from dgld.utils.evaluation import split_auc, score_other, curve_obtain, split_ap
from dgld.utils.common import seed_everything
from dgld.utils.argparser import parse_all_args
from dgld.utils.load_data import load_data, load_custom_data, load_truth_data
from dgld.utils.inject_anomalies import inject_contextual_anomalies, inject_structural_anomalies
from dgld.utils.common_params import Q_MAP, K, P
from dgld.utils.log import Dgldlog
from dgld.models import *
import numpy as np
import os
# Datasets loaded with their own anomaly labels via ``load_truth_data``;
# every other dataset is loaded with ``load_data`` and gets synthetic
# anomalies injected.
truth_list = ['weibo', 'reddit2']

# Model class names (exported by ``dgld.models``) this entry point may build.
SUPPORTED_MODELS = ('RAND',)


def _load_graph(data_name, args, seed):
    """Return the graph for *data_name*, injecting anomalies when needed.

    Datasets in ``truth_list`` are read from ``args.data_path`` unchanged.
    Any other dataset is loaded with ``load_data`` and receives contextual
    then structural anomalies, both seeded with *seed* so each run is
    reproducible.
    """
    if data_name in truth_list:
        return load_truth_data(data_path=args.data_path, dataset_name=data_name)
    graph = load_data(data_name)
    graph = inject_contextual_anomalies(graph=graph, k=K, p=P, q=Q_MAP[data_name], seed=seed)
    return inject_structural_anomalies(graph=graph, p=P, q=Q_MAP[data_name], seed=seed)


def _build_model(model_name, model_kwargs):
    """Instantiate the model class named *model_name* with *model_kwargs*.

    Raises:
        ValueError: if *model_name* is not listed in ``SUPPORTED_MODELS``.
    """
    if model_name not in SUPPORTED_MODELS:
        raise ValueError(f"{model_name} is not implemented!")
    # Look the class up by name rather than eval()-ing a string that is
    # assembled from command-line input.
    return globals()[model_name](**model_kwargs)


def _mean_std(values):
    """Return ``(mean, std)`` of *values* using NumPy (population std, ddof=0)."""
    arr = np.asarray(values)
    return np.mean(arr), np.std(arr)


def _report(header, metric, values):
    """Print the per-run *values* of *metric* followed by their mean and std."""
    mean, std = _mean_std(values)
    print(header, flush=True)
    print(f"final {metric} list: ", values, flush=True)
    print(f"final {metric} mean: ", mean, flush=True)
    print(f"final {metric} std: ", std, flush=True)
    print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&", flush=True)


if __name__ == "__main__":
    args_dict, args = parse_all_args()
    print(args_dict, flush=True)
    data_name = args_dict['dataset']
    log = Dgldlog(args.save_path, args.exp_name, args)

    res_list_final = []
    res_list_ap = []
    # Run r (0-based) uses seed r + 1, so the seeds are 1..runs.
    for seed in range(1, args.runs + 1):
        log.update_runs()
        seed_everything(seed)
        args_dict['seed'] = seed

        graph = _load_graph(data_name, args, seed)
        label = graph.ndata['label']

        model = _build_model(args.model, args_dict["model"])
        model.fit(graph, **args_dict["fit"])
        result = model.predict(graph, **args_dict["predict"])

        # NOTE(review): the ROC points returned here are never used; drop
        # this call if curve_obtain has no side effects.
        curve_obtain(label, result)
        final_score, _, _ = split_auc(label, result)
        res_list_final.append(final_score)
        res_list_ap.append(score_other(label, result))
        print(args_dict, flush=True)

    _report("########### OVERALL AUC #############", "auc", res_list_final)
    _report("########### OTHER OVERALL PERFORMANCE #############", "ap", res_list_ap)
    # Hard exit: skips atexit handlers and interpreter cleanup (the prints
    # above already flush). Presumably used so lingering worker threads
    # cannot keep the process alive -- confirm Dgldlog needs no final flush.
    os._exit(0)