diff --git a/scripts/kilt/anserini_retriever.py b/scripts/kilt/anserini_retriever.py index 0406ce331..e1a1b0b45 100644 --- a/scripts/kilt/anserini_retriever.py +++ b/scripts/kilt/anserini_retriever.py @@ -63,7 +63,7 @@ def _get_predictions_thread(arguments): doc_scores = [] if use_bigrams: - tokens = filter(lambda word: word not in STOPWORDS, word_tokenize(query)) + tokens = filter(lambda word: word.lower() not in STOPWORDS, word_tokenize(query)) if stem_bigrams: tokens = map(stemmer.stem, tokens) bigram_query = bigrams(tokens) diff --git a/scripts/kilt/convert_kilt_dpr_to_pyserini_format.py b/scripts/kilt/convert_kilt_dpr_to_pyserini_format.py new file mode 100644 index 000000000..f71eb9cd8 --- /dev/null +++ b/scripts/kilt/convert_kilt_dpr_to_pyserini_format.py @@ -0,0 +1,30 @@ +import argparse +import pickle +import csv +from tqdm import tqdm + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Convert KILT-dpr corpus into the docids file read by pyserini-dpr') + parser.add_argument('--input', required=True, help='Path to the kilt_w100_title.tsv file') + parser.add_argument('--mapping', required=True, help='Path to the mapping_KILT_title.p file') + parser.add_argument('--output', required=True, help='Name and path of the output file') + + args = parser.parse_args() + + KILT_mapping = pickle.load(open(args.mapping, "rb")) + + not_found = set() + with open(args.input, 'r') as f, open(args.output, 'w') as outp: + tsv = csv.reader(f, delimiter='\t') + next(tsv) # skip headers + for row in tqdm(tsv, mininterval=10.0, maxinterval=20.0): + title = row[2] + if title not in KILT_mapping: + not_found.add(title) + _ = outp.write('N/A\n') + else: + _ = outp.write(f'{KILT_mapping[title]}\n') + + print('Done!') + print(f'Not found: {not_found}') diff --git a/scripts/kilt/convert_kilt_to_document_jsonl.py b/scripts/kilt/convert_kilt_to_document_jsonl.py index 045ed5225..8acb5297e 100644 --- a/scripts/kilt/convert_kilt_to_document_jsonl.py +++ b/scripts/kilt/convert_kilt_to_document_jsonl.py @@ -26,7 +26,7 @@ doc["id"] = raw["_id"] doc["contents"] = "".join(raw["text"]) if args.bigrams: - tokens = filter(lambda word: word not in STOPWORDS, word_tokenize(doc["contents"])) + tokens = filter(lambda word: word.lower() not in STOPWORDS, word_tokenize(doc["contents"])) if args.stem: tokens = map(stemmer.stem, tokens) bigram_doc = bigrams(tokens) diff --git a/scripts/kilt/convert_kilt_to_passage_jsonl.py b/scripts/kilt/convert_kilt_to_passage_jsonl.py index 80205fcde..fe482f3d1 100644 --- a/scripts/kilt/convert_kilt_to_passage_jsonl.py +++ b/scripts/kilt/convert_kilt_to_passage_jsonl.py @@ -42,7 +42,7 @@ doc["id"] = f"{raw['_id']}-{i}" p = texts[i] if args.bigrams: - tokens = filter(lambda word: word not in STOPWORDS, word_tokenize(p)) + tokens = filter(lambda word: word.lower() not in STOPWORDS, word_tokenize(p)) if args.stem: tokens = map(stemmer.stem, tokens) bigram_doc = bigrams(tokens)