add_doc2query_pass.py
#!/usr/bin/env python
#
# Copyright 2014+ Carnegie Mellon University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Add predicted doc2query (docTTTTTquery) query fields to the MS MARCO *PASSAGE* collection:
https://github.com/castorini/docTTTTTquery
The script reads all the predictions into memory.
"""
import argparse
import json

from tqdm import tqdm

from flexneuart.text_proc.parse import SpacyTextParser
from flexneuart.io.stopwords import read_stop_words, STOPWORD_FILE
from flexneuart.io import jsonl_gen, FileWrapper
from flexneuart.config import SPACY_MODEL, DOCID_FIELD, TEXT_FIELD_NAME

# Names of the output fields that store the lemmatized and the unlemmatized doc2query text
DOC2QUERY_FIELD_TEXT = 'doc2query_text'
DOC2QUERY_FIELD_TEXT_UNLEMM = 'doc2query_text_unlemm'
parser = argparse.ArgumentParser(description='Add doc2query fields to existing JSONL data entries')

parser.add_argument('--input', metavar='input JSONL file',
                    help='input JSONL file (can be compressed)',
                    type=str, required=True)
parser.add_argument('--output', metavar='output JSONL file',
                    help='output JSONL file (can be compressed)',
                    type=str, required=True)
parser.add_argument('--target_fusion_field', metavar='target fusion field',
                    help='name of the target field that stores the concatenation of the original '
                         'lemmatized text and the lemmatized doc2query text',
                    type=str, required=True)
parser.add_argument('--predictions_path', metavar='doc2query predictions',
                    help='file with predicted queries for the passage data: one line per passage, '
                         'in the same order as the input entries',
                    type=str, required=True)

args = parser.parse_args()
print(args)
stop_words = read_stop_words(STOPWORD_FILE, lower_case=True)
print(stop_words)

nlp = SpacyTextParser(SPACY_MODEL, stop_words, keep_only_alpha_num=True, lower_case=True)

target_fusion_field = args.target_fusion_field

# Read all predicted queries into memory: one non-empty line per passage,
# in the same order as the entries of the input JSONL file.
predicted_queries = []
for line in tqdm(FileWrapper(args.predictions_path), desc='reading predictions'):
    line = line.strip()
    if line:
        predicted_queries.append(line)

print(f'Read predictions for {len(predicted_queries)} passages')
pass_qty = 0

with FileWrapper(args.output, 'w') as outf:
    for doce in tqdm(jsonl_gen(args.input), desc='adding doc2query fields'):
        doc_id = doce[DOCID_FIELD]

        # Parse the predicted query for the current passage: proc_text returns both a
        # lemmatized and an unlemmatized (tokenized, lower-cased) version of the text
        # (an illustrative before/after sketch of an entry is given at the end of this file).
        text, text_unlemm = nlp.proc_text(predicted_queries[pass_qty])

        doce[target_fusion_field] = doce[TEXT_FIELD_NAME] + ' ' + text
        doce[DOC2QUERY_FIELD_TEXT] = text
        doce[DOC2QUERY_FIELD_TEXT_UNLEMM] = text_unlemm

        pass_qty += 1
        outf.write(json.dumps(doce) + '\n')

if pass_qty != len(predicted_queries):
    raise Exception(f'Mismatch between the number of predicted queries: {len(predicted_queries)} '
                    f'and the number of passages: {pass_qty}')
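
# For illustration only: a hypothetical sketch of how a single JSONL entry changes.
# Field names are the constants imported/defined above, the values are made up, and
# the exact lemmatization depends on SPACY_MODEL and the stop-word list.
#
#   before: {DOCID_FIELD: '7187158', TEXT_FIELD_NAME: 'manhattan project research ...', ...}
#   after:  the same entry, plus
#       args.target_fusion_field    : original text + ' ' + lemmatized predicted query
#       DOC2QUERY_FIELD_TEXT        : lemmatized predicted query
#       DOC2QUERY_FIELD_TEXT_UNLEMM : unlemmatized (tokenized, lower-cased) predicted query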