Skip to content

Commit

Permalink
add encoded queires integrate tests (#490)
Browse files Browse the repository at this point in the history
* add integrate tests with encoded queries
  • Loading branch information
MXueguang authored Apr 21, 2021
1 parent de3e715 commit f29307a
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 57 deletions.
49 changes: 36 additions & 13 deletions integrations/test_ance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import socket
import unittest
from integrations.utils import clean_files, run_command, parse_score

from pyserini.search import get_topics
from pyserini.dsearch import QueryEncoder

class TestSearchIntegration(unittest.TestCase):
def setUp(self):
Expand All @@ -33,8 +34,8 @@ def setUp(self):
self.threads = 36
self.batch_size = 144

def test_msmarco_passage_ance_bf(self):
output_file = 'test_run.msmarco-passage.ance.bf.tsv'
def test_msmarco_passage_ance_bf_otf(self):
output_file = 'test_run.msmarco-passage.ance.bf.otf.tsv'
self.temp_files.append(output_file)
cmd1 = f'python -m pyserini.dsearch --topics msmarco-passage-dev-subset \
--index msmarco-passage-ance-bf \
Expand All @@ -48,11 +49,16 @@ def test_msmarco_passage_ance_bf(self):
stdout, stderr = run_command(cmd2)
score = parse_score(stdout, "MRR @10")
self.assertEqual(status, 0)
self.assertEqual(stderr, '')
self.assertAlmostEqual(score, 0.3302, delta=0.0001)

def test_msmarco_doc_ance_bf(self):
output_file = 'test_run.msmarco-doc.passage.ance-maxp.txt '
def test_msmarco_passage_ance_encoded_queries(self):
encoder = QueryEncoder.load_encoded_queries('ance-msmarco-passage-dev-subset')
topics = get_topics('msmarco-passage-dev-subset')
for t in topics:
self.assertTrue(topics[t]['title'] in encoder.embedding)

def test_msmarco_doc_ance_bf_otf(self):
output_file = 'test_run.msmarco-doc.passage.ance-maxp.otf.txt '
self.temp_files.append(output_file)
cmd1 = f'python -m pyserini.dsearch --topics msmarco-doc-dev \
--index msmarco-doc-ance-maxp-bf \
Expand All @@ -69,13 +75,18 @@ def test_msmarco_doc_ance_bf(self):
stdout, stderr = run_command(cmd2)
score = parse_score(stdout, "MRR @100")
self.assertEqual(status, 0)
self.assertEqual(stderr, '')
# We get a small difference, 0.3794 on macOS.
self.assertAlmostEqual(score, 0.3797, delta=0.0003)

def test_nq_test_ance_bf(self):
output_file = 'test_run.ance.nq-test.multi.bf.trec'
retrieval_file = 'test_run.ance.nq-test.multi.bf.json'
def test_msmarco_doc_ance_bf_encoded_queries(self):
encoder = QueryEncoder.load_encoded_queries('ance_maxp-msmarco-doc-dev')
topics = get_topics('maxp-msmarco-doc-dev')
for t in topics:
self.assertTrue(topics[t]['title'] in encoder.embedding)

def test_nq_test_ance_bf_otf(self):
output_file = 'test_run.ance.nq-test.multi.bf.otf.trec'
retrieval_file = 'test_run.ance.nq-test.multi.bf.otf.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.dsearch --topics dpr-nq-test \
--index wikipedia-ance-multi-bf \
Expand All @@ -95,9 +106,15 @@ def test_nq_test_ance_bf(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.8224, places=4)

def test_trivia_test_ance_bf(self):
output_file = 'test_run.ance.trivia-test.multi.bf.trec'
retrieval_file = 'test_run.ance.trivia-test.multi.bf.json'
def test_nq_test_ance_encoded_queries(self):
encoder = QueryEncoder.load_encoded_queries('dpr_multi-nq-test')
topics = get_topics('dpr-nq-test')
for t in topics:
self.assertTrue(topics[t]['title'] in encoder.embedding)

def test_trivia_test_ance_bf_otf(self):
output_file = 'test_run.ance.trivia-test.multi.bf.otf.trec'
retrieval_file = 'test_run.ance.trivia-test.multi.bf.otf.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.dsearch --topics dpr-trivia-test \
--index wikipedia-ance-multi-bf \
Expand All @@ -117,6 +134,12 @@ def test_trivia_test_ance_bf(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.8010, places=4)

def test_trivia_test_ance_encoded_queries(self):
encoder = QueryEncoder.load_encoded_queries('dpr_multi-trivia-test')
topics = get_topics('dpr-trivia-test')
for t in topics:
self.assertTrue(topics[t]['title'] in encoder.embedding)

def tearDown(self):
clean_files(self.temp_files)

Expand Down
10 changes: 9 additions & 1 deletion integrations/test_distilbert_kd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import socket
import unittest
from integrations.utils import clean_files, run_command, parse_score
from pyserini.search import get_topics
from pyserini.dsearch import QueryEncoder


class TestSearchIntegration(unittest.TestCase):
Expand All @@ -33,7 +35,7 @@ def setUp(self):
self.threads = 36
self.batch_size = 144

def test_msmarco_passage_distilbert_kd_bf(self):
def test_msmarco_passage_distilbert_kd_bf_otf(self):
output_file = 'test_run.msmarco-passage.distilbert-dot-margin_mse-T2.bf.tsv'
self.temp_files.append(output_file)
cmd1 = f'python -m pyserini.dsearch --topics msmarco-passage-dev-subset \
Expand All @@ -51,6 +53,12 @@ def test_msmarco_passage_distilbert_kd_bf(self):
self.assertEqual(stderr, '')
self.assertAlmostEqual(score, 0.3251, delta=0.0001)

def test_msmarco_passage_distilbert_kd_encoded_queries(self):
encoder = QueryEncoder.load_encoded_queries('distilbert_kd-msmarco-passage-dev-subset')
topics = get_topics('msmarco-passage-dev-subset')
for t in topics:
self.assertTrue(topics[t]['title'] in encoder.embedding)

def tearDown(self):
clean_files(self.temp_files)

Expand Down
92 changes: 62 additions & 30 deletions integrations/test_dpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import socket
import unittest
from integrations.utils import clean_files, run_command, parse_score
from pyserini.search import get_topics
from pyserini.dsearch import QueryEncoder


class TestSearchIntegration(unittest.TestCase):
Expand All @@ -33,9 +35,9 @@ def setUp(self):
self.threads = 36
self.batch_size = 144

def test_dpr_nq_test_bf(self):
output_file = 'test_run.dpr.nq-test.multi.bf.trec'
retrieval_file = 'test_run.dpr.nq-test.multi.bf.json'
def test_dpr_nq_test_bf_otf(self):
output_file = 'test_run.dpr.nq-test.multi.bf.otf.trec'
retrieval_file = 'test_run.dpr.nq-test.multi.bf.otf.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.dsearch --topics dpr-nq-test \
--index wikipedia-dpr-multi-bf \
Expand All @@ -55,9 +57,9 @@ def test_dpr_nq_test_bf(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.7947, places=4)

def test_dpr_nq_test_bf_bm25_hybrid(self):
output_file = 'test_run.dpr.nq-test.multi.bf.bm25.trec'
retrieval_file = 'test_run.dpr.nq-test.multi.bf.bm25.json'
def test_dpr_nq_test_bf_bm25_hybrid_otf(self):
output_file = 'test_run.dpr.nq-test.multi.bf.otf.bm25.trec'
retrieval_file = 'test_run.dpr.nq-test.multi.bf.otf.bm25.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.hsearch dense --index wikipedia-dpr-multi-bf \
--encoder facebook/dpr-question_encoder-multiset-base \
Expand All @@ -79,9 +81,15 @@ def test_dpr_nq_test_bf_bm25_hybrid(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.8260, places=4)

def test_dpr_trivia_test_bf(self):
output_file = 'test_run.dpr.trivia-test.multi.bf.trec'
retrieval_file = 'test_run.dpr.trivia-test.multi.bf.json'
def test_dpr_nq_test_encoded_queries(self):
encoder = QueryEncoder.load_encoded_queries('dpr_multi-nq-test')
topics = get_topics('dpr-nq-test')
for t in topics:
self.assertTrue(topics[t]['title'] in encoder.embedding)

def test_dpr_trivia_test_bf_otf(self):
output_file = 'test_run.dpr.trivia-test.multi.bf.otf.trec'
retrieval_file = 'test_run.dpr.trivia-test.multi.bf.otf.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.dsearch --topics dpr-trivia-test \
--encoder facebook/dpr-question_encoder-multiset-base \
Expand All @@ -101,9 +109,9 @@ def test_dpr_trivia_test_bf(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.7887, places=4)

def test_dpr_trivia_test_bf_bm25_hybrid(self):
output_file = 'test_run.dpr.trivia-test.multi.bf.bm25.trec'
retrieval_file = 'test_run.dpr.trivia-test.multi.bf.bm25.json'
def test_dpr_trivia_test_bf_bm25_hybrid_otf(self):
output_file = 'test_run.dpr.trivia-test.multi.bf.otf.bm25.trec'
retrieval_file = 'test_run.dpr.trivia-test.multi.bf.otf.bm25.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.hsearch dense --index wikipedia-dpr-multi-bf \
--encoder facebook/dpr-question_encoder-multiset-base \
Expand All @@ -125,9 +133,15 @@ def test_dpr_trivia_test_bf_bm25_hybrid(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.8264, places=4)

def test_dpr_wq_test_bf(self):
output_file = 'test_run.dpr.wq-test.multi.bf.trec'
retrieval_file = 'test_run.dpr.wq-test.multi.bf.json'
def test_dpr_trivia_test_encoded_queries(self):
encoder = QueryEncoder.load_encoded_queries('dpr_multi-trivia-test')
topics = get_topics('dpr-trivia-test')
for t in topics:
self.assertTrue(topics[t]['title'] in encoder.embedding)

def test_dpr_wq_test_bf_otf(self):
output_file = 'test_run.dpr.wq-test.multi.bf.otf.trec'
retrieval_file = 'test_run.dpr.wq-test.multi.bf.otf.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.dsearch --topics dpr-wq-test \
--index wikipedia-dpr-multi-bf \
Expand All @@ -147,9 +161,9 @@ def test_dpr_wq_test_bf(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.7505, places=4)

def test_dpr_wq_test_bf_bm25_hybrid(self):
output_file = 'test_run.dpr.wq-test.multi.bf.bm25.trec'
retrieval_file = 'test_run.dpr.wq-test.multi.bf.bm25.json'
def test_dpr_wq_test_bf_bm25_hybrid_otf(self):
output_file = 'test_run.dpr.wq-test.multi.bf.otf.bm25.trec'
retrieval_file = 'test_run.dpr.wq-test.multi.bf.otf.bm25.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.hsearch dense --index wikipedia-dpr-multi-bf \
--encoder facebook/dpr-question_encoder-multiset-base \
Expand All @@ -171,9 +185,15 @@ def test_dpr_wq_test_bf_bm25_hybrid(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.7712, places=4)

def test_dpr_curated_test_bf(self):
output_file = 'test_run.dpr.curated-test.multi.bf.trec'
retrieval_file = 'test_run.dpr.curated-test.multi.bf.json'
def test_dpr_wq_test_encoded_queries(self):
encoder = QueryEncoder.load_encoded_queries('dpr_multi-wq-test')
topics = get_topics('dpr-wq-test')
for t in topics:
self.assertTrue(topics[t]['title'] in encoder.embedding)

def test_dpr_curated_test_bf_otf(self):
output_file = 'test_run.dpr.curated-test.multi.bf.otf.trec'
retrieval_file = 'test_run.dpr.curated-test.multi.bf.otf.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.dsearch --topics dpr-curated-test \
--index wikipedia-dpr-multi-bf \
Expand All @@ -193,9 +213,9 @@ def test_dpr_curated_test_bf(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.8876, places=4)

def test_dpr_curated_test_bf_bm25_hybrid(self):
output_file = 'test_run.dpr.curated-test.multi.bf.bm25.trec'
retrieval_file = 'test_run.dpr.curated-test.multi.bf.bm25.json'
def test_dpr_curated_test_bf_bm25_hybrid_otf(self):
output_file = 'test_run.dpr.curated-test.multi.bf.otf.bm25.trec'
retrieval_file = 'test_run.dpr.curated-test.multi.bf.otf.bm25.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.hsearch dense --index wikipedia-dpr-multi-bf \
--encoder facebook/dpr-question_encoder-multiset-base \
Expand All @@ -217,9 +237,15 @@ def test_dpr_curated_test_bf_bm25_hybrid(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.9006, places=4)

def test_dpr_squad_test_bf(self):
output_file = 'test_run.dpr.squad-test.multi.bf.trec'
retrieval_file = 'test_run.dpr.squad-test.multi.bf.json'
def test_dpr_curated_test_encoded_queries(self):
encoder = QueryEncoder.load_encoded_queries('dpr_multi-curated-test')
topics = get_topics('dpr-curated-test')
for t in topics:
self.assertTrue(topics[t]['title'] in encoder.embedding)

def test_dpr_squad_test_bf_otf(self):
output_file = 'test_run.dpr.squad-test.multi.bf.otf.trec'
retrieval_file = 'test_run.dpr.squad-test.multi.bf.otf.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.dsearch --topics dpr-squad-test \
--index wikipedia-dpr-multi-bf \
Expand All @@ -239,9 +265,9 @@ def test_dpr_squad_test_bf(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.5199, places=4)

def test_dpr_squad_test_bf_bm25_hybrid(self):
output_file = 'test_run.dpr.squad-test.multi.bf.bm25.trec'
retrieval_file = 'test_run.dpr.squad-test.multi.bf.bm25.json'
def test_dpr_squad_test_bf_bm25_hybrid_otf(self):
output_file = 'test_run.dpr.squad-test.multi.bf.otf.bm25.trec'
retrieval_file = 'test_run.dpr.squad-test.multi.bf.otf.bm25.json'
self.temp_files.extend([output_file, retrieval_file])
cmd1 = f'python -m pyserini.hsearch dense --index wikipedia-dpr-multi-bf \
--encoder facebook/dpr-question_encoder-multiset-base \
Expand All @@ -263,6 +289,12 @@ def test_dpr_squad_test_bf_bm25_hybrid(self):
self.assertEqual(status2, 0)
self.assertAlmostEqual(score, 0.7511, places=4)

def test_dpr_squad_test_encoded_queries(self):
encoder = QueryEncoder.load_encoded_queries('dpr_multi-squad-test')
topics = get_topics('dpr-squad-test')
for t in topics:
self.assertTrue(topics[t]['title'] in encoder.embedding)

def tearDown(self):
clean_files(self.temp_files)

Expand Down
12 changes: 10 additions & 2 deletions integrations/test_sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import socket
import unittest
from integrations.utils import clean_files, run_command, parse_score
from pyserini.search import get_topics
from pyserini.dsearch import QueryEncoder


class TestSearchIntegration(unittest.TestCase):
Expand All @@ -33,8 +35,8 @@ def setUp(self):
self.threads = 36
self.batch_size = 144

def test_msmarco_passage_sbert_bf(self):
output_file = 'test_run.msmarco-passage.sbert.bf.tsv'
def test_msmarco_passage_sbert_bf_otf(self):
output_file = 'test_run.msmarco-passage.sbert.bf.otf.tsv'
self.temp_files.append(output_file)
cmd1 = f'python -m pyserini.dsearch --topics msmarco-passage-dev-subset \
--index msmarco-passage-sbert-bf \
Expand All @@ -51,6 +53,12 @@ def test_msmarco_passage_sbert_bf(self):
self.assertEqual(stderr, '')
self.assertAlmostEqual(score, 0.3314, delta=0.0001)

def test_msmarco_passage_sbert_encoded_queries(self):
encoder = QueryEncoder.load_encoded_queries('sbert-msmarco-passage-dev-subset')
topics = get_topics('msmarco-passage-dev-subset')
for t in topics:
self.assertTrue(topics[t]['title'] in encoder.embedding)

def tearDown(self):
clean_files(self.temp_files)

Expand Down
Loading

0 comments on commit f29307a

Please sign in to comment.