-
Notifications
You must be signed in to change notification settings - Fork 1
/
document_retriever.py
61 lines (48 loc) · 1.86 KB
/
document_retriever.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from pyserini.search.lucene import LuceneSearcher
import re
import json
from tqdm import tqdm
from vncorenlp import VnCoreNLP
# Base directory that holds the VnCoreNLP jar, the paragraph corpus
# (data/paragraphs2.jsonl) and both Lucene indexes (indexes/...).
ROOT_DIR = '/code'

# Heavy retrieval resources; they stay None until load_model() assigns them:
#   vncorenlp_model -- VnCoreNLP annotator (used here for tokenize())
#   paragraphs      -- list of JSON records parsed from paragraphs2.jsonl
#   uni_searcher    -- LuceneSearcher over indexes/paragraphs2
#   ngram_searcher  -- LuceneSearcher over indexes/paragraphs_tokenized2
vncorenlp_model = paragraphs = uni_searcher = ngram_searcher = None
def load_model():
    """Load the VnCoreNLP annotator, the paragraph corpus and both Lucene searchers.

    Populates the module-level globals ``vncorenlp_model``, ``paragraphs``,
    ``uni_searcher`` and ``ngram_searcher`` from files under ``ROOT_DIR``.
    The retrieve_* functions in this module read those globals, so this must
    run before any of them.

    Returns:
        The VnCoreNLP instance (the same object stored in ``vncorenlp_model``).
    """
    print('Loading retrieval models...')
    global vncorenlp_model, paragraphs, uni_searcher, ngram_searcher
    # NOTE(review): only tokenize() is used in this module; pos/parse may be
    # needed by callers of the returned model -- confirm before trimming.
    vncorenlp_model = VnCoreNLP(f'{ROOT_DIR}/VnCoreNLP-1.1.1.jar', annotators="wseg,pos,parse")
    # Decode explicitly as UTF-8: the corpus is Vietnamese (both searchers are
    # set to 'vi' below), and the platform default encoding (e.g. cp1252 on
    # Windows) cannot decode it.
    # Keep line order intact: retrieve_documents() looks paragraphs up by
    # int(docid), i.e. by position in this list.
    with open(f'{ROOT_DIR}/data/paragraphs2.jsonl', 'r', encoding='utf-8') as f:
        paragraphs = [json.loads(line) for line in tqdm(f)]
    print(len(paragraphs))
    # Index built from raw paragraph text, queried by retrieve_documents_unigram().
    uni_searcher = LuceneSearcher(f'{ROOT_DIR}/indexes/paragraphs2')
    uni_searcher.set_language('vi')
    # Index built from word-segmented text, queried by retrieve_documents_ngram().
    ngram_searcher = LuceneSearcher(f'{ROOT_DIR}/indexes/paragraphs_tokenized2')
    ngram_searcher.set_language('vi')
    return vncorenlp_model
def clean(text: str):
    """Normalise *text* before it is sent to a Lucene searcher.

    Lower-cases the text, replaces every run of non-word characters
    (punctuation and whitespace alike) with a single space, and strips the
    ends. Underscores are word characters and are kept, so compound tokens
    joined with '_' (e.g. ``hà_nội``) survive intact.

    >>> clean("Hà_Nội, là  thủ đô!")
    'hà_nội là thủ đô'
    """
    # Raw string: '\W' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    # Whitespace is itself matched by \W, so each whitespace run is already
    # collapsed to one space here and no separate '\s+' pass is needed.
    # NOTE(review): \W treats combining diacritics as non-word characters, so
    # NFD-decomposed Vietnamese input would be split apart; consider NFC
    # normalisation if the indexes were built from NFC text.
    return re.sub(r'\W+', ' ', text.lower()).strip()
def retrieve_documents_unigram(question, k):
    """Query the raw-text (unigram) Lucene index with *question*.

    The question is passed through clean() and then handed to
    ``uni_searcher``, which load_model() must already have created.

    Returns:
        The hit list produced by ``uni_searcher.search`` for *k* hits.
    """
    normalised_query = clean(question)
    hits = uni_searcher.search(normalised_query, k=k)
    return hits
def retrieve_documents_ngram(question, k):
    """Query the word-segmented (n-gram) Lucene index with *question*.

    The question is first run through VnCoreNLP's tokenize(). Its result is
    iterated as a sequence of sentences, each a sequence of tokens. Every
    token of every sentence is joined with single spaces, normalised by
    clean(), and searched on ``ngram_searcher``. Both ``vncorenlp_model`` and
    ``ngram_searcher`` are created by load_model().

    Returns:
        The hit list produced by ``ngram_searcher.search`` for *k* hits.
    """
    # Flatten the per-sentence token lists into one token stream.
    # Presumably multi-syllable words come back joined by '_'; clean() keeps
    # '_' because it is a word character.
    flat_tokens = []
    for sentence_tokens in vncorenlp_model.tokenize(question):
        flat_tokens.extend(sentence_tokens)
    segmented_query = clean(' '.join(flat_tokens))
    return ngram_searcher.search(segmented_query, k=k)
def retrieve_documents(question, k=20):
    """Retrieve paragraph texts for *question* from both Lucene indexes.

    The *k* hit budget is split between the unigram searcher and the
    word-segmented (n-gram) searcher. Hits from both are merged by docid,
    keeping the higher score when a paragraph is found by both. The merged
    hits are then sorted by that score, highest first. Ties keep first-seen
    order, which puts unigram hits first.

    Args:
        question: Raw natural-language question.
        k: Total number of hits to request across both searchers. The
            unigram searcher gets ceil(k / 2) and the n-gram searcher
            floor(k / 2), so the requests add up to exactly k. For even k
            each searcher gets k / 2.

    Returns:
        A list of the matching paragraphs' ``'contents'`` values, best first.
        It holds at most k entries (fewer when both searchers return the same
        docid) and is empty when k <= 0.
    """
    total = int(k)
    # Split as ceil/floor so the two requests sum to exactly k. Halving with
    # floor for both would lose a hit on odd k and request zero hits at k == 1.
    uni_k = (total + 1) // 2
    ngram_k = total // 2
    # Don't ask a searcher for zero hits; Lucene rejects a non-positive count.
    uni_res = retrieve_documents_unigram(question, k=uni_k) if uni_k > 0 else []
    ngram_res = retrieve_documents_ngram(question, k=ngram_k) if ngram_k > 0 else []

    # Max-score fusion keyed by docid; the dict preserves first-seen order.
    best_score = {}
    for hit in [*uni_res, *ngram_res]:
        previous = best_score.get(hit.docid)
        best_score[hit.docid] = hit.score if previous is None else max(hit.score, previous)

    # sorted() is stable even with reverse=True, so equal scores keep insertion order.
    ranked = sorted(best_score.items(), key=lambda pair: pair[1], reverse=True)
    # NOTE(review): assumes each docid is the 0-based line number of the
    # paragraph in paragraphs2.jsonl, matching how load_model() fills
    # `paragraphs` -- confirm against the indexing script.
    return [paragraphs[int(docid)]['contents'] for docid, _score in ranked]