Forked from JunHao-Zhu/FusionQuery — main.py (91 lines, 79 LOC, 3.59 KB)
import os
import logging
import argparse
import json
from datetime import datetime
from sentence_transformers import SentenceTransformer
from FusionQuery.framework import FusionQuery
from utils.utility import *
def set_logger(name):
log_file = os.path.join('./logger', name)
logging.basicConfig(
format='%(asctime)s %(levelname)-8s %(message)s',
level=logging.INFO,
datefmt='%Y-%m-%d %H:%M:%S',
filename=log_file,
filemode='w'
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)
def parse_args(args=None):
    """Parse command-line options for the fusion-query experiment.

    Args:
        args: Optional list of argument strings to parse; defaults to
            sys.argv[1:] when None, preserving normal CLI behavior.

    Returns:
        argparse.Namespace holding all parsed options.
    """
    parser = argparse.ArgumentParser(description='fusion query')
    parser.add_argument('--data_root', type=str, default="./data/movie",
                        help='dataset name')
    parser.add_argument("--data_name", type=str, default="movie",
                        help="saved root")
    parser.add_argument("--lm_path", type=str, default="../../.cache/huggingface/transformers/sentence-bert")
    parser.add_argument("--fusion_model", type=str, default="FusionQuery",
                        help="select from (FusionQuery, CASE, DART, LTM, TruthFinder, MajorityVoter)")
    parser.add_argument("--types", nargs='+', required=True)
    parser.add_argument("--iters", type=int, default=20, help="iteration for EM algorithm")
    parser.add_argument("--thres_for_query", type=float, default=0.9, help="threshold for the query stage")
    parser.add_argument("--thres_for_fusion", type=float, default=0.5, help="threshold for value veracity")
    parser.add_argument("--gpu", type=int, default=0, help="gpu device")
    parser.add_argument("--seed", type=int, default=2021, help='random seed')
    # Bug fix: forward `args` so callers can supply a custom argv list;
    # previously the parameter was accepted but silently ignored.
    return parser.parse_args(args)
def main():
    """Drive the full pipeline: parse args, load data and model, evaluate queries."""
    args = parse_args()

    model_name = args.fusion_model
    # The movie benchmark ships 210 queries; every other dataset uses 100.
    query_count = 210 if args.data_name == "movie" else 100
    device = "cpu" if args.gpu == -1 else "cuda:{}".format(args.gpu)

    with open("config.json", 'r') as config_file:
        config = json.load(config_file)
    if args.iters > 0:
        config["max_iters"] = args.iters

    set_random_seed(args.seed)

    # Build a filesystem-safe log name from dataset, model, types, and time
    # (':' is not a legal filename character on some platforms).
    cur_time = ' ' + datetime.now().strftime("%F %T")
    logger_name = ('-'.join([args.data_name, model_name, 'time', '-'.join(args.types)])
                   + '-' + cur_time.replace(':', '-'))
    set_logger(logger_name)
    logging.info(args)
    logging.info(config[model_name])
    logging.info("seed: {}".format(args.seed))

    logging.info("{:*^50}".format("[Data Preparation]"))
    encoder = SentenceTransformer(args.lm_path).to(device)
    source_graphs = prepare_graph(args.data_root + "/data2kg", args.data_name,
                                  args.types, lm=encoder, line_transform=True)
    num_sources = len(source_graphs)

    logging.info("{:*^50}".format("[Query Preparation]"))
    query_graphs = prepare_query(args.data_root, qry_num=query_count, lm=encoder)
    logging.info("{:*^50}".format("[END]"))

    # Distinct name keeps the model-name string separate from the model object.
    aggregator = load_fusion_model(model_name, num_sources, config)
    pipeline = FusionQuery(aggregator=aggregator,
                           src_graphs=source_graphs,
                           threshold=[args.thres_for_query] * num_sources,
                           veracity_thres=args.thres_for_fusion,
                           device=device)
    pipeline.evaluate(query_graphs, timing=True)
    pipeline.statistic.print_stat_info()


if __name__ == '__main__':
    main()