Support KILT for Pyserini's h/d/search #405

Merged 52 commits on Apr 29, 2021
Showing file changes from 26 of 52 commits.

Commits:
0c981e8
Support KILT for pyserini h/d/search
yuxuan-ji Mar 8, 2021
fe47a9b
Lazy import
yuxuan-ji Mar 8, 2021
829bc9b
Support KILT output
yuxuan-ji Mar 8, 2021
bdd9973
Delete write_result
yuxuan-ji Mar 14, 2021
5414bca
Create dir
yuxuan-ji Mar 14, 2021
dcb528c
Convert kilt dpr corpus
yuxuan-ji Mar 21, 2021
e9cafdb
Rename param for clarity and only passage delim when max_passage is on
yuxuan-ji Mar 21, 2021
32fc284
Only need one kilt format for now
yuxuan-ji Mar 21, 2021
5fb2ddb
Support explicit tokenizer
yuxuan-ji Mar 21, 2021
ef0f3df
Update script with index writer
yuxuan-ji Mar 21, 2021
35e7f86
Add some comments
yuxuan-ji Mar 21, 2021
3ceb6c0
Raise exception if file not fully iterated through
yuxuan-ji Mar 21, 2021
f1ae599
Add script to precompute embeddings
yuxuan-ji Mar 28, 2021
507c948
Load precompute embeddings from file instead of hardcoded dir filename
yuxuan-ji Mar 29, 2021
9a6ce9f
Merge & fix conflicts
yuxuan-ji Apr 22, 2021
11d8228
Add and fix integration tests for KILT support
yuxuan-ji Apr 22, 2021
29f0dc6
Add test and prebuilt index
yuxuan-ji Apr 22, 2021
03bf13f
Rename to json
yuxuan-ji Apr 22, 2021
3532e9a
Merge branch 'support-kilt' into kilt-integration-tests
yuxuan-ji Apr 22, 2021
da28e7b
Add eval script
yuxuan-ji Apr 24, 2021
3d55884
Merge branch 'support-kilt' into kilt-integration-tests
yuxuan-ji Apr 24, 2021
fc63d58
Add eval to integration test
yuxuan-ji Apr 25, 2021
8adcb46
Add score check
yuxuan-ji Apr 25, 2021
be86d67
Merge branch 'master' into support-kilt
yuxuan-ji Apr 26, 2021
6b0f844
Fixes to merge conflict
yuxuan-ji Apr 26, 2021
9d20757
Merge branch 'support-kilt' into kilt-integration-tests
yuxuan-ji Apr 26, 2021
afb0746
Normalize path
yuxuan-ji Apr 26, 2021
c5c09b8
Merge branch 'support-kilt' into kilt-integration-tests
yuxuan-ji Apr 26, 2021
5d8d4f0
Download nq queries
yuxuan-ji Apr 26, 2021
5b29aa1
Merge branch 'support-kilt' into kilt-integration-tests
yuxuan-ji Apr 26, 2021
2241cba
Download topics file in eval script
yuxuan-ji Apr 26, 2021
370af57
Merge branch 'support-kilt' into kilt-integration-tests
yuxuan-ji Apr 26, 2021
c05434c
Download kilt topics in test
yuxuan-ji Apr 26, 2021
4aef4b1
Add threads and batch size
yuxuan-ji Apr 26, 2021
ebb5de8
Simplify getting output writer
yuxuan-ji Apr 26, 2021
cd6df52
Merge branch 'support-kilt' into kilt-integration-tests
yuxuan-ji Apr 26, 2021
71946c4
Add newline
yuxuan-ji Apr 26, 2021
b1586fa
Merge branch 'support-kilt' into kilt-integration-tests
yuxuan-ji Apr 26, 2021
f6f4618
Add license header
yuxuan-ji Apr 26, 2021
5c8b3e7
Merge branch 'support-kilt' into kilt-integration-tests
yuxuan-ji Apr 26, 2021
de2e221
Add triviaQA
yuxuan-ji Apr 27, 2021
fe44775
Merge branch 'support-kilt' into kilt-integration-tests
yuxuan-ji Apr 27, 2021
be574e8
Typo
yuxuan-ji Apr 27, 2021
1bfbe3c
Merge branch 'master' into support-kilt
yuxuan-ji Apr 27, 2021
c62ed7c
Use tsvinttopicreader
yuxuan-ji Apr 27, 2021
da472a5
Merge branch 'support-kilt' of github.com:yuxuan-ji/pyserini into sup…
yuxuan-ji Apr 27, 2021
8011295
Merge pull request #1 from yuxuan-ji/kilt-integration-tests
yuxuan-ji Apr 27, 2021
83a11d2
Fix unittest
yuxuan-ji Apr 27, 2021
ec2baf3
Script fix
yuxuan-ji Apr 28, 2021
5240b14
Script fix
yuxuan-ji Apr 28, 2021
bebbc8a
Add script to convert 100w tsv into jsonl
yuxuan-ji Apr 28, 2021
51457f1
Merge branch 'support-kilt' of https://github.com/yuxuan-ji/pyserini …
yuxuan-ji Apr 28, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ pyserini/resources/jars/*.jar
collections/*
indexes/*
.vscode/
venv/
# build directories from `python3 setup.py sdist bdist_wheel`
build/
dist/
70 changes: 35 additions & 35 deletions pyserini/dsearch/__main__.py
@@ -17,14 +17,13 @@
import argparse
import os

import json
from tqdm import tqdm

from pyserini.dsearch import SimpleDenseSearcher, TctColBertQueryEncoder, \
QueryEncoder, DprQueryEncoder, AnceQueryEncoder, AutoQueryEncoder
from pyserini.query_iterator import query_iterator
from pyserini.search import get_topics
from pyserini.search.__main__ import write_result, write_result_max_passage
from pyserini.dsearch import SimpleDenseSearcher, TctColBertQueryEncoder, QueryEncoder, DprQueryEncoder,\
AnceQueryEncoder, AutoQueryEncoder
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.output_writer import get_output_writer, OutputFormat


# Fixes this error: "OMP: Error #15: Initializing libomp.a, but found libomp.dylib already initialized."
# https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial
@@ -37,14 +36,17 @@ def define_dsearch_args(parser):
parser.add_argument('--encoder', type=str, metavar='path to query encoder checkpoint or encoder name',
required=False,
help="Path to query encoder pytorch checkpoint or hgf encoder model name")
parser.add_argument('--tokenizer', type=str, metavar='name or path',
required=False,
help="Path to a hgf tokenizer name or path")
parser.add_argument('--encoded-queries', type=str, metavar='path to query encoded queries dir or queries name',
required=False,
help="Path to query encoder pytorch checkpoint or hgf encoder model name")
parser.add_argument('--device', type=str, metavar='device to run query encoder', required=False, default='cpu',
help="Device to run query encoder, cpu or [cuda:0, cuda:1, ...]")


def init_query_encoder(encoder, topics_name, encoded_queries, device):
def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, device):
encoded_queries_map = {
'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset',
'dpr-nq-dev': 'dpr_multi-nq-dev',
@@ -57,30 +59,35 @@ def init_query_encoder(encoder, topics_name, encoded_queries, device):
}
if encoder:
if 'dpr' in encoder:
return DprQueryEncoder(encoder_dir=encoder, device=device)
return DprQueryEncoder(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device)
elif 'tct_colbert' in encoder:
return TctColBertQueryEncoder(encoder_dir=encoder, device=device)
return TctColBertQueryEncoder(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device)
elif 'ance' in encoder:
return AnceQueryEncoder(encoder_dir=encoder, device=device)
return AnceQueryEncoder(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device)
elif 'sentence' in encoder:
return AutoQueryEncoder(encoder_dir=encoder, device=device, pooling='mean', l2_norm=True)
return AutoQueryEncoder(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device,
pooling='mean', l2_norm=True)
else:
return AutoQueryEncoder(encoder_dir=encoder, device=device)
return AutoQueryEncoder(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device)

if encoded_queries:
if os.path.exists(encoded_queries):
return QueryEncoder(encoded_queries)
return QueryEncoder.load_encoded_queries(encoded_queries)
if topics_name in encoded_queries_map:
return QueryEncoder.load_encoded_queries(encoded_queries_map[topics_name])
return None
raise ValueError(f'No encoded queries for topic {topics_name}')


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Search a Faiss index.')
parser.add_argument('--topics', type=str, metavar='topic_name', required=True,
help="Name of topics. Available: msmarco-passage-dev-subset.")
parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000, help="Number of hits.")
parser.add_argument('--msmarco', action='store_true', default=False, help="Output in MS MARCO format.")
parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value,
help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}")
parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value,
help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}")
parser.add_argument('--output', type=str, metavar='path', required=True, help="Path to output file.")
parser.add_argument('--max-passage', action='store_true',
default=False, help="Select only max passage from document.")
@@ -95,20 +102,10 @@ def init_query_encoder(encoder, topics_name, encoded_queries, device):
define_dsearch_args(parser)
args = parser.parse_args()

if os.path.exists(args.topics) and args.topics.endswith('.json'):
topics = json.load(open(args.topics))
else:
topics = get_topics(args.topics)

# invalid topics name
if topics == {}:
print(f'Topic {args.topics} Not Found')
exit()
query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format))
topics = query_iterator.topics

query_encoder = init_query_encoder(args.encoder, args.topics, args.encoded_queries, args.device)
if not query_encoder:
print(f'No encoded queries for topic {args.topics}')
exit()
query_encoder = init_query_encoder(args.encoder, args.tokenizer, args.topics, args.encoded_queries, args.device)

if os.path.exists(args.index):
# create searcher from index directory
@@ -126,10 +123,16 @@ def init_query_encoder(encoder, topics_name, encoded_queries, device):
print(f'Running {args.topics} topics, saving to {output_path}...')
tag = 'Faiss'

with open(output_path, 'w') as target_file:
output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w',
max_hits=args.hits, tag=tag, topics=topics,
use_max_passage=args.max_passage,
max_passage_delimiter=args.max_passage_delimiter,
max_passage_hits=args.max_passage_hits)

with output_writer:
batch_topics = list()
batch_topic_ids = list()
for index, (topic_id, text) in enumerate(tqdm(list(query_iterator(topics, args.topics)))):
for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))):
if args.batch_size <= 1 and args.threads <= 1:
hits = searcher.search(text, args.hits)
results = [(topic_id, hits)]
@@ -146,10 +149,7 @@ def init_query_encoder(encoder, topics_name, encoded_queries, device):
else:
continue

for result in results:
if args.max_passage:
write_result_max_passage(target_file, result, args.max_passage_delimiter,
args.max_passage_hits, args.msmarco, tag)
else:
write_result(target_file, result, args.hits, args.msmarco, tag)
for topic, hits in results:
output_writer.write(topic, hits)

results.clear()
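
Taken together, the changes above replace the ad-hoc write_result / write_result_max_passage helpers in pyserini/dsearch/__main__.py with a pluggable query iterator and output writer. Below is a minimal sketch of the resulting flow, assuming precomputed TCT-ColBERT query embeddings; the local index path is hypothetical, and the SimpleDenseSearcher(index_dir, query_encoder) call is an assumption standing in for the searcher construction collapsed in this diff.

```python
# Minimal sketch of the refactored dense-search flow (not the exact __main__ code).
# Assumptions: the index path below is hypothetical, and
# SimpleDenseSearcher(index_dir, query_encoder) stands in for the searcher
# construction that is collapsed in this diff.
from pyserini.dsearch import SimpleDenseSearcher, QueryEncoder
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.output_writer import get_output_writer, OutputFormat

query_iterator = get_query_iterator('msmarco-passage-dev-subset', TopicsFormat.DEFAULT)
topics = query_iterator.topics

# Precomputed query embeddings avoid loading an encoder checkpoint.
encoder = QueryEncoder.load_encoded_queries('tct_colbert-msmarco-passage-dev-subset')
searcher = SimpleDenseSearcher('indexes/msmarco-passage-tct_colbert', encoder)

output_writer = get_output_writer('run.msmarco-passage.trec', OutputFormat.TREC, 'w',
                                  max_hits=1000, tag='Faiss', topics=topics,
                                  use_max_passage=False,
                                  max_passage_delimiter='#',
                                  max_passage_hits=100)

with output_writer:
    for topic_id, text in query_iterator:
        hits = searcher.search(text, 1000)
        output_writer.write(topic_id, hits)
```

For the KILT runs this PR adds, the same loop applies with the KILT-specific TopicsFormat and OutputFormat values; their exact identifiers are defined outside the hunks shown here.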
20 changes: 12 additions & 8 deletions pyserini/dsearch/_dsearcher.py
@@ -80,13 +80,14 @@ def _load_embeddings(encoded_query_dir):

class TctColBertQueryEncoder(QueryEncoder):

def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None, device: str = 'cpu'):
def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
encoded_query_dir: str = None, device: str = 'cpu'):
super().__init__(encoded_query_dir)
if encoder_dir:
self.device = device
self.model = BertModel.from_pretrained(encoder_dir)
self.model.to(self.device)
self.tokenizer = BertTokenizer.from_pretrained(encoder_dir)
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or encoder_dir)
self.has_model = True
if (not self.has_model) and (not self.has_encoded_query):
raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
@@ -111,13 +112,14 @@ def encode(self, query: str):

class DprQueryEncoder(QueryEncoder):

def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None, device: str = 'cpu'):
def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
encoded_query_dir: str = None, device: str = 'cpu'):
super().__init__(encoded_query_dir)
if encoder_dir:
self.device = device
self.model = DPRQuestionEncoder.from_pretrained(encoder_dir)
self.model.to(self.device)
self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(encoder_dir)
self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_name or encoder_dir)
self.has_model = True
if (not self.has_model) and (not self.has_encoded_query):
raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
@@ -134,13 +136,14 @@ def encode(self, query: str):

class AnceQueryEncoder(QueryEncoder):

def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None, device: str = 'cpu'):
def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
encoded_query_dir: str = None, device: str = 'cpu'):
super().__init__(encoded_query_dir)
if encoder_dir:
self.device = device
self.model = AnceEncoder.from_pretrained(encoder_dir)
self.model.to(self.device)
self.tokenizer = RobertaTokenizer.from_pretrained(encoder_dir)
self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name or encoder_dir)
self.has_model = True
if (not self.has_model) and (not self.has_encoded_query):
raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
@@ -164,14 +167,15 @@ def encode(self, query: str):

class AutoQueryEncoder(QueryEncoder):

def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None, device: str = 'cpu',
def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
encoded_query_dir: str = None, device: str = 'cpu',
pooling: str = 'cls', l2_norm: bool = False):
super().__init__(encoded_query_dir)
if encoder_dir:
self.device = device
self.model = AutoModel.from_pretrained(encoder_dir)
self.model.to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(encoder_dir)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or encoder_dir)
self.has_model = True
self.pooling = pooling
self.l2_norm = l2_norm
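
The encoder classes above now take an explicit tokenizer_name and fall back to encoder_dir when it is omitted. A small usage sketch follows, assuming the publicly hosted facebook/dpr-question_encoder-multiset-base checkpoint; any checkpoint/tokenizer pair with matching vocabularies should behave the same way.

```python
# Sketch of the new explicit-tokenizer option; the checkpoint name below is an
# assumption (any DPR question-encoder checkpoint with a matching tokenizer works).
from pyserini.dsearch import DprQueryEncoder

# With tokenizer_name omitted, the tokenizer is loaded from encoder_dir as before;
# passing it lets a local checkpoint reuse a tokenizer published under another name.
encoder = DprQueryEncoder(encoder_dir='facebook/dpr-question_encoder-multiset-base',
                          tokenizer_name='facebook/dpr-question_encoder-multiset-base',
                          device='cpu')

embedding = encoder.encode('who wrote the declaration of independence?')
print(embedding.shape)  # one dense query vector, e.g. (768,) for DPR
```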