-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathretrieve_bm25.py
127 lines (92 loc) · 4.26 KB
/
retrieve_bm25.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from pyserini.search import LuceneSearcher, get_topics, get_qrels
import tempfile
from src.trec_eval import run_retriever, EvalFunction
import json
import argparse
# Command-line interface: the benchmark to evaluate on and the path to the
# tab-separated query-augmentation file.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-e", "--eval_dataset",
    required=True,
    type=str,
    # Only these three values are handled by the rest of the script, so
    # anything else is rejected here with a usage message.
    choices=["trec-dl-19", "trec-dl-20", "ms-marco-dev"],
    help="eval dataset",
)
parser.add_argument(
    "-q", "--query_augment_file_path",
    required=True,
    type=str,
    help="query_augment_file_path",
)
args = parser.parse_args()
# parameter setup
# Numeric id the rest of the script branches on: 19 / 20 select the TREC DL
# tracks, 21 selects the MS MARCO dev subset.
_DATASET_IDS = {
    'trec-dl-19': 19,
    'trec-dl-20': 20,
    'ms-marco-dev': 21,
}
if args.eval_dataset not in _DATASET_IDS:
    # Fail immediately with a readable message instead of leaving `data_num`
    # unbound and hitting a NameError further down.
    raise ValueError(
        f"Unsupported eval dataset {args.eval_dataset!r}; "
        f"expected one of {sorted(_DATASET_IDS)}"
    )
data_num = _DATASET_IDS[args.eval_dataset]
# load query augmentation file
# Each line is split on tabs and the second field is kept as the augmentation
# text. Lines with fewer than two fields contribute an empty string so that the
# list keeps one entry per input line.
if args.query_augment_file_path is not None:
    query_augment_results = []
    with open(args.query_augment_file_path, "r") as f:
        for line in f:
            fields = line.strip().split('\t')
            # Explicit length check instead of a bare `except:`, which would
            # also have hidden unrelated errors (e.g. KeyboardInterrupt).
            query_augment_results.append(fields[1] if len(fields) > 1 else '')
# Retrieve passages using pyserini BM25.
# The searcher is built on the prebuilt 'msmarco-v1-passage' index.
searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage')

# Topic-set name per dataset id; any id other than 19 or 20 falls back to the
# MS MARCO dev subset.
_topic_set_by_id = {19: 'dl19-passage', 20: 'dl20'}
topics = get_topics(_topic_set_by_id.get(data_num, 'msmarco-passage-dev-subset'))
#### for Query expansion
# Each topic title is rewritten as the original query repeated `aug_num` times
# followed by its augmentation text. `k` is the minimum number of repetitions.
k = 5
if args.query_augment_file_path is not None:
    # Augmentations are paired with topics purely by position (i-th line with
    # the i-th topic key) -- NOTE(review): confirm the file was generated in
    # the same order as `topics` iterates.
    if len(query_augment_results) < len(topics):
        raise ValueError(
            f"Query augmentation file has {len(query_augment_results)} lines "
            f"but there are {len(topics)} topics"
        )
    for i, key in enumerate(topics):
        title = topics[key]['title']
        augmentation = query_augment_results[i]
        # Repeat the query roughly once per query-length of augmentation text,
        # but never fewer than k times. max(..., 1) guards against a
        # ZeroDivisionError when a title is empty.
        aug_num = max(k, len(augmentation) // max(len(title), 1))
        topics[key]['title'] = (title + ' ') * aug_num + ' ' + augmentation
# Load the relevance judgements that belong to the selected topic set and run
# retrieval. The dev subset (data_num == 21) has its own qrels key and uses a
# deeper cut-off (1000 hits instead of 100).
if data_num == 21:
    # f'dl{data_num}-passage' would resolve to 'dl21-passage' here, which does
    # not correspond to the 'msmarco-passage-dev-subset' topics loaded for this
    # dataset nor to the qrels name used at evaluation time.
    qrels = get_qrels('msmarco-passage-dev-subset')
    rank_results = run_retriever(topics, searcher, qrels, k=1000)
else:
    qrels = get_qrels(f'dl{data_num}-passage')
    rank_results = run_retriever(topics, searcher, qrels, k=100)
# # # #######################################################
# # # save retrieval result
# ret_results = []
# for result in rank_results:
# question = result['query']
# answers = ['']
# ctxs = []
# for hit in result['hits']:
# # ctxs.append({'id': hit['rank'], 'title': '', 'text': f"[{hit['rank']}] " + hit['content']})
# ctxs.append({'id': hit['rank'], 'title': '', 'text': f"Passage: " + hit['content']})
# ret_results.append({'question': question, 'answers': answers, 'ctxs': ctxs})
# with open(f'', 'w') as outfile:
# json.dump(ret_results, outfile, indent=4)
# exit()
# # # #######################################################
def write_eval_file(rank_results, file):
    """Write ranked hits to ``file``, one line per hit.

    Every hit is written as ``<qid> Q0 <docid> <rank> <score> rank``. The rank
    restarts at 1 for each entry of ``rank_results`` and follows the order of
    that entry's ``'hits'`` list.

    Args:
        rank_results: Sequence of dicts, each with a ``'hits'`` list whose
            items provide ``'qid'``, ``'docid'`` and ``'score'``.
        file: Path of the output file; it is created or overwritten.

    Returns:
        Always ``True``.
    """
    with open(file, 'w') as f:
        for result in rank_results:
            # enumerate(start=1) yields the 1-based rank directly.
            for rank, hit in enumerate(result['hits'], start=1):
                f.write(f"{hit['qid']} Q0 {hit['docid']} {rank} {hit['score']} rank\n")
    return True
# Evaluate nDCG@(num)
# Only the path of the temporary file is needed. Closing the handle right away
# avoids leaking an open file descriptor for the rest of the run. The file is
# created with delete=False, so it stays on disk after the handle is closed.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    temp_file = tmp.name
write_eval_file(rank_results, temp_file)
if data_num == 19 or data_num == 20:
    # nDCG at cut-offs 1, 5 and 10 against dl19-passage / dl20-passage.
    # A fresh argument list is built for every call.
    for cutoff in (1, 5, 10):
        EvalFunction.eval(['-c', '-m', f'ndcg_cut.{cutoff}', f'dl{data_num}-passage', temp_file])
    # ###
    # EvalFunction.eval(['-c', '-m', f'recall.10', f'dl{data_num}-passage', temp_file])
    # EvalFunction.eval(['-c', '-m', f'recall.50', f'dl{data_num}-passage', temp_file])
    # EvalFunction.eval(['-c', '-m', f'recall.100', f'dl{data_num}-passage', temp_file])
    # EvalFunction.eval(['-c', '-M', '1', '-m', f'recip_rank', f'dl{data_num}-passage', temp_file])
    # EvalFunction.eval(['-c', '-M', '5', '-m', f'recip_rank', f'dl{data_num}-passage', temp_file])
    # EvalFunction.eval(['-c', '-M', '10', '-m', f'recip_rank', f'dl{data_num}-passage', temp_file])
else:
    # MS MARCO dev subset: reciprocal rank with -M 10, plus recall at 50 and 1000.
    EvalFunction.eval(['-c', '-M', '10', '-m', 'recip_rank', 'msmarco-passage-dev-subset', temp_file])
    EvalFunction.eval(['-c', '-m', 'recall.50', 'msmarco-passage-dev-subset', temp_file])
    EvalFunction.eval(['-c', '-m', 'recall.1000', 'msmarco-passage-dev-subset', temp_file])