From 581539a7737a473739b1e3465df7d7e4e7203ede Mon Sep 17 00:00:00 2001
From: Ronak
Date: Thu, 3 Feb 2022 06:34:22 -0500
Subject: [PATCH] MS MARCO v1 and v2 scripts rewrite (#66)

---
 convert_msmarco_doc_to_anserini.py          |   1 -
 convert_msmarco_passages_doc_to_anserini.py |  22 ++---
 msmarco-v1/augment_corpus.py                |  76 +++++++++++++++
 msmarco-v2/augment_corpus.py                | 100 ++++++++++----------
 setup.cfg                                   |   2 +
 5 files changed, 138 insertions(+), 63 deletions(-)
 create mode 100644 msmarco-v1/augment_corpus.py
 create mode 100644 setup.cfg

diff --git a/convert_msmarco_doc_to_anserini.py b/convert_msmarco_doc_to_anserini.py
index adb4aa7..5ef8049 100644
--- a/convert_msmarco_doc_to_anserini.py
+++ b/convert_msmarco_doc_to_anserini.py
@@ -45,7 +45,6 @@ def generate_output_dict(doc, predicted_queries):
         output_dict = generate_output_dict(doc, predicted_queries)
         f_out.write(json.dumps(output_dict) + '\n')
 
-f_corpus.close()
 f_out.close()
 
 print('Done!')
diff --git a/convert_msmarco_passages_doc_to_anserini.py b/convert_msmarco_passages_doc_to_anserini.py
index b7a99f1..63760ce 100644
--- a/convert_msmarco_passages_doc_to_anserini.py
+++ b/convert_msmarco_passages_doc_to_anserini.py
@@ -1,6 +1,6 @@
 '''
-Segment the documents and append their url, title, predicted queries to them. Then, they are saved into
-json which can be used for indexing.
+Segment the documents and append their url, title, predicted queries to them.
+Then, they are saved into json which can be used for indexing.
 '''
 
 import argparse
@@ -9,14 +9,13 @@
 import os
 import spacy
 from tqdm import tqdm
-import re
+
 
 def create_segments(doc_text, max_length, stride):
     doc_text = doc_text.strip()
     doc = nlp(doc_text[:10000])
     sentences = [sent.string.strip() for sent in doc.sents]
     segments = []
-
     for i in range(0, len(sentences), stride):
         segment = " ".join(sentences[i:i+max_length])
         segments.append(segment)
@@ -24,6 +23,7 @@ def create_segments(doc_text, max_length, stride):
             break
     return segments
 
+
 parser = argparse.ArgumentParser(
     description='Concatenate MS MARCO original docs with predicted queries')
 parser.add_argument('--original_docs_path', required=True, help='MS MARCO .tsv corpus file.')
@@ -45,10 +45,11 @@ def create_segments(doc_text, max_length, stride):
 
 print('Spliting documents...')
 doc_id_ref = None
-if args.predictions_path == None:
+
+if args.predictions_path is None:
     doc_ids_queries = zip(open(args.doc_ids_path))
 else:
-    doc_ids_queries = zip(open(args.doc_ids_path),open(args.predictions_path))
+    doc_ids_queries = zip(open(args.doc_ids_path), open(args.predictions_path))
 for doc_id_query in tqdm(doc_ids_queries):
     doc_id = doc_id_query[0].strip()
     if doc_id != doc_id_ref:
@@ -62,15 +63,14 @@ def create_segments(doc_text, max_length, stride):
             doc_seg = f'{doc_id}#{seg_id}'
             if seg_id < len(segments):
                 segment = segments[seg_id]
-                if args.predictions_path == None:
+                if args.predictions_path is None:
                     expanded_text = f'{doc_url} {doc_title} {segment}'
                 else:
                     predicted_queries_partial = doc_id_query[1]
                     expanded_text = f'{doc_url} {doc_title} {segment} {predicted_queries_partial}'
                 output_dict = {'id': doc_seg, 'contents': expanded_text}
-                f_out.write(json.dumps(output_dict) + '\n')
-    doc_id_ref = doc_id
-
+                f_out.write(json.dumps(output_dict) + '\n')
+    doc_id_ref = doc_id
 f_corpus.close()
 f_out.close()
-print('Done!')
\ No newline at end of file
+print('Done!')
diff --git a/msmarco-v1/augment_corpus.py b/msmarco-v1/augment_corpus.py
new file mode 100644
index 0000000..85b4496
--- /dev/null
+++ b/msmarco-v1/augment_corpus.py
@@ -0,0 +1,76 @@
+#
+# Pyserini: Reproducible IR research with sparse and dense representations
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+from datasets import load_dataset
+import os
+import json
+from tqdm import tqdm
+from pyserini.search import SimpleSearcher
+
+
+def augment_corpus_with_doc2query_t5(dataset, searcher, f_out, num_queries, text_key="contents"):
+    print('Output docs...')
+    output = open(f_out, 'w')
+    counter = 0
+    set_d2q_ids = set()
+    for i in tqdm(range(len(dataset))):
+        docid = dataset[i]["id"]
+        set_d2q_ids.add(docid)
+        output_dict = json.loads(searcher.doc(docid).raw())
+        if num_queries == -1:
+            concatenated_queries = " ".join(dataset[i]["predicted_queries"])
+        else:
+            concatenated_queries = " ".join(dataset[i]["predicted_queries"][:num_queries])
+        output_dict[text_key] = f"{output_dict[text_key]}\n{concatenated_queries}"
+        counter += 1
+        output.write(json.dumps(output_dict) + '\n')
+    counter_no_exp = 0
+    for i in tqdm(range(searcher.num_docs)):
+        if searcher.doc(i).docid() not in set_d2q_ids:
+            output_dict = json.loads(searcher.doc(i).raw())
+            counter_no_exp += 1
+            output_dict[text_key] = f"{output_dict[text_key]}\n"
+            output.write(json.dumps(output_dict) + '\n')
+    output.close()
+    print(f'{counter + counter_no_exp} lines output. {counter_no_exp} lines with no expansions.')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Concatenate MS MARCO V1 corpus with predicted queries')
+    parser.add_argument('--hgf_d2q_dataset', required=True,
+                        choices=['castorini/msmarco_v1_passage_doc2query-t5_expansions',
+                                 'castorini/msmarco_v1_doc_segmented_doc2query-t5_expansions',
+                                 'castorini/msmarco_v1_doc_doc2query-t5_expansions'])
+    parser.add_argument('--prebuilt_index', required=True, help='Prebuilt index name')
+    parser.add_argument('--output_psg_path', required=True, help='Output file for d2q-t5 augmented corpus.')
+    parser.add_argument('--num_queries', default=-1, type=int, help='Number of expansions used.')
+    parser.add_argument('--cache_dir', default=".", type=str, help='Path to cache the hgf dataset')
+    args = parser.parse_args()
+
+    os.makedirs(args.output_psg_path, exist_ok=True)
+
+    dataset = load_dataset(args.hgf_d2q_dataset, split="train", cache_dir=args.cache_dir)
+    if args.prebuilt_index in ['msmarco-v1-passage', 'msmarco-v1-doc-segmented', 'msmarco-v1-doc']:
+        searcher = SimpleSearcher.from_prebuilt_index(args.prebuilt_index)
+    else:
+        searcher = SimpleSearcher(args.prebuilt_index)
+    augment_corpus_with_doc2query_t5(dataset,
+                                     searcher,
+                                     os.path.join(args.output_psg_path, "docs.jsonl"),
+                                     args.num_queries)
+    print('Done!')
diff --git a/msmarco-v2/augment_corpus.py b/msmarco-v2/augment_corpus.py
index 002880c..d8a3f59 100644
--- a/msmarco-v2/augment_corpus.py
+++ b/msmarco-v2/augment_corpus.py
@@ -17,86 +17,84 @@
 import argparse
 from datasets import load_dataset
 import os
-import gzip
 import json
 from tqdm import tqdm
-import glob
 from multiprocessing import Pool, Manager
+from pyserini.search import SimpleSearcher
 
 
-def load_docs(docid_to_doc, f_ins, text_key="passage"):
-    print("Loading docs")
-    counter = 0
-    if text_key == "passage":
-        id_key = "pid"
-    else:
-        id_key = "docid"
-    for f_in in f_ins:
-        with gzip.open(f_in, 'rt', encoding='utf8') as in_fh:
-            for json_string in tqdm(in_fh):
-                input_dict = json.loads(json_string)
-                docid_to_doc[input_dict[id_key]] = input_dict
-                counter += 1
-    print(f'{counter} docs loaded. Done!')
-
-def augment_corpus_with_doc2query_t5(dataset, f_out, start, end, num_queries, text_key="passage"):
+def augment_corpus_with_doc2query_t5(dataset, expdocid_dict, f_out, start, end, num_queries, text_key="passage"):
     print('Output docs...')
     output = open(f_out, 'w')
     counter = 0
     for i in tqdm(range(start, end)):
         docid = dataset[i]["id"]
-        output_dict = docid_to_doc[docid]
-        concatenated_queries = " ".join(dataset[i]["predicted_queries"][:num_queries])
-        output_dict[text_key] = f"{output_dict[text_key]} {concatenated_queries}"
+        output_dict = docid2doc[docid]
+        expdocid_dict[docid] = 1
+        if num_queries == -1:
+            concatenated_queries = " ".join(dataset[i]["predicted_queries"])
+        else:
+            concatenated_queries = " ".join(dataset[i]["predicted_queries"][:num_queries])
+        output_dict[text_key] = output_dict[text_key].replace("\n", " ")
+        output_dict[text_key] = f"{output_dict[text_key]}\n{concatenated_queries}"
         counter += 1
-        output.write(json.dumps(output_dict) + '\n')
+        output.write(json.dumps(output_dict) + '\n')
     output.close()
     print(f'{counter} lines output. Done!')
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Concatenate MS MARCO v2 corpus with predicted queries')
-    parser.add_argument('--hgf_d2q_dataset', required=True,
+        description='Concatenate MS MARCO V2 corpus with predicted queries')
+    parser.add_argument('--hgf_d2q_dataset', required=True,
                         choices=['castorini/msmarco_v2_passage_doc2query-t5_expansions',
-                                 'castorini/msmarco_v2_doc_segmented_doc2query-t5_expansions'])
-    parser.add_argument('--original_psg_path', required=True, help='Input corpus path')
+                                 'castorini/msmarco_v2_doc_segmented_doc2query-t5_expansions',
+                                 'castorini/msmarco_v2_doc_doc2query-t5_expansions'])
+    parser.add_argument('--index_path', required=True, help='Input index path')
     parser.add_argument('--output_psg_path', required=True, help='Output file for d2q-t5 augmented corpus.')
     parser.add_argument('--num_workers', default=1, type=int, help='Number of workers used.')
-    parser.add_argument('--num_queries', default=20, type=int, help='Number of expansions used.')
-    parser.add_argument('--task', default="passage", type=str, help='One of passage or document.')
+    parser.add_argument('--num_queries', default=-1, type=int, help='Number of expansions used.')
    parser.add_argument('--cache_dir', default=".", type=str, help='Path to cache the hgf dataset')
+    parser.add_argument('--task', default="passage", type=str, help='One of passage or document.')
     args = parser.parse_args()
 
-    psg_files = glob.glob(os.path.join(args.original_psg_path, '*.gz'))
     os.makedirs(args.output_psg_path, exist_ok=True)
-
-
-    manager = Manager()
-    docid_to_doc = manager.dict()
-
     dataset = load_dataset(args.hgf_d2q_dataset, split="train", cache_dir=args.cache_dir)
+    if args.index_path in ['msmarco-v2-passage', 'msmarco-v2-passage-augmented',
+                           'msmarco-v2-doc-segmented', 'msmarco-v2-doc']:
+        searcher = SimpleSearcher.from_prebuilt_index(args.index_path)
+    else:
+        searcher = SimpleSearcher(args.index_path)
+    if searcher.num_docs != len(dataset):
+        print("Total number of expanded queries: {}".format(len(dataset)))
+    print('Total passages loaded: {}'.format(searcher.num_docs))
+    manager = Manager()
+    docid2doc = manager.dict()
+    for i in tqdm(range(searcher.num_docs)):
+        doc = searcher.doc(i)
+        docid2doc[doc.docid()] = json.loads(doc.raw())
     pool = Pool(args.num_workers)
-    num_files_per_worker = (len(psg_files) // args.num_workers)
-    for i in range(args.num_workers):
-        pool.apply_async(load_docs, (docid_to_doc, psg_files[i*num_files_per_worker: min(len(dataset), (i+1)*num_files_per_worker)], args.task))
-    pool.close()
-    pool.join()
-    assert len(docid_to_doc) == len(dataset)
-    print('Total passages loaded: {}'.format(len(docid_to_doc)))
-
-
-    pool = Pool(args.num_workers)
-    num_examples_per_worker = (len(docid_to_doc)//args.num_workers) + 1
+    expdocid_dict = manager.dict()
     for i in range(args.num_workers):
         f_out = os.path.join(args.output_psg_path, 'dt5q_aug_psg' + str(i) + '.json')
-        pool.apply_async(augment_corpus_with_doc2query_t5 ,(dataset, f_out,
-                                                            i*(num_examples_per_worker),
-                                                            min(len(docid_to_doc), (i+1)*num_examples_per_worker),
-                                                            args.num_queries, args.task))
-
+        print(f_out)
+        start = i * (searcher.num_docs // args.num_workers)
+        end = (i + 1) * (searcher.num_docs // args.num_workers)
+        if i == args.num_workers - 1:
+            end = searcher.num_docs
+        pool.apply_async(augment_corpus_with_doc2query_t5,
+                         args=(dataset, expdocid_dict, f_out, start, end, args.num_queries, args.task, ))
     pool.close()
     pool.join()
+    if len(docid2doc) != len(expdocid_dict):
+        f_out = os.path.join(args.output_psg_path, 'dt5q_aug_psg' + str(args.num_workers - 1) + '.json')
+        with open(f_out, 'a') as output:
+            for id in tqdm(docid2doc.keys()):
+                if id not in expdocid_dict:
+                    print(f"doc {id} not expanded")
+                    output.write(json.dumps(docid2doc[id]) + '\n')
+                    expdocid_dict[id] = 1
+    assert len(docid2doc) == len(expdocid_dict)
     print('Done!')
+    print(f'{searcher.num_docs} documents and {len(dataset)} expanded documents.')
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..79a16af
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[flake8]
+max-line-length = 120
\ No newline at end of file