Skip to content

Commit

Permalink
Add integration test for index partitioning for Faiss (#1074)
Browse files Browse the repository at this point in the history
  • Loading branch information
vivianliu0 authored Mar 8, 2022
1 parent 5773500 commit 7796685
Showing 1 changed file with 54 additions and 8 deletions.
62 changes: 54 additions & 8 deletions integrations/dense/test_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

"""Integration tests for create dense index """

import faiss
import os
import shutil
import unittest
from pyserini.search.faiss import FaissSearcher
from pyserini.search.lucene import LuceneImpactSearcher
from urllib.request import urlretrieve


class TestSearchIntegration(unittest.TestCase):
def setUp(self):
curdir = os.getcwd()
Expand All @@ -33,41 +33,87 @@ def setUp(self):
self.pyserini_root = '.'
self.temp_folders = []
self.corpus_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/corpus/jsonl/cacm.json'
self.corpus_path = f"{self.pyserini_root}/integrations/dense/temp_cacm/"
self.corpus_path = f'{self.pyserini_root}/integrations/dense/temp_cacm/'
os.makedirs(self.corpus_path, exist_ok=True)
self.temp_folders.append(self.corpus_path)
urlretrieve(self.corpus_url, os.path.join(self.corpus_path, 'cacm.json'))

def test_dpr_encode_as_faiss(self):
index_dir = f'{self.pyserini_root}/temp_index'
self.temp_folders.append(index_dir)
cmd1 = f"python -m pyserini.encode input --corpus {self.corpus_path} \
cmd1 = f'python -m pyserini.encode input --corpus {self.corpus_path} \
--fields text \
output --embeddings {index_dir} --to-faiss \
encoder --encoder facebook/dpr-ctx_encoder-multiset-base \
--fields text \
--batch 4 \
--device cpu"
--device cpu'
_ = os.system(cmd1)
searcher = FaissSearcher(
index_dir,
'facebook/dpr-question_encoder-multiset-base'
)
q_emb, hit = searcher.search("What is the solution of separable closed queueing networks?", k=1, return_vector=True)
q_emb, hit = searcher.search('What is the solution of separable closed queueing networks?', k=1, return_vector=True)
self.assertEqual(hit[0].docid, 'CACM-2445')
self.assertAlmostEqual(hit[0].vectors[0], -6.88267112e-01, places=4)
self.assertEqual(searcher.num_docs, 3204)

def test_dpr_encode_as_faiss_search_with_partitions(self):
# Create two partitions of the CACM index, search them individually, and merge results to compute top hit
index_dir = f'{self.pyserini_root}/temp_index'
os.makedirs(os.path.join(index_dir, 'partition1'), exist_ok=True)
os.makedirs(os.path.join(index_dir, 'partition2'), exist_ok=True)
self.temp_folders.append(index_dir)
cmd1 = f'python -m pyserini.encode input --corpus {self.corpus_path} \
--fields text \
output --embeddings {index_dir} --to-faiss \
encoder --encoder facebook/dpr-ctx_encoder-multiset-base \
--fields text \
--batch 4 \
--device cpu'
_ = os.system(cmd1)
index = faiss.read_index(os.path.join(index_dir, 'index'))
new_index_partition1 = faiss.IndexFlatIP(index.d)
new_index_partition2 = faiss.IndexFlatIP(index.d)
vectors_partition1 = index.reconstruct_n(0, index.ntotal // 2)
vectors_partition2 = index.reconstruct_n(index.ntotal // 2, index.ntotal - index.ntotal // 2)
new_index_partition1.add(vectors_partition1)
new_index_partition2.add(vectors_partition2)

faiss.write_index(new_index_partition1, os.path.join(index_dir, 'partition1/index'))
faiss.write_index(new_index_partition2, os.path.join(index_dir, 'partition2/index'))

with open(os.path.join(index_dir, 'partition1/docid'), 'w') as docid1, open(os.path.join(index_dir, 'partition2/docid'), 'w') as docid2:
with open(os.path.join(index_dir, 'docid'), 'r') as file:
for i in range(index.ntotal):
line = next(file)
if i < (index.ntotal // 2):
docid1.write(line)
else:
docid2.write(line)

searcher_partition1 = FaissSearcher(index_dir + '/partition1','facebook/dpr-question_encoder-multiset-base')
searcher_partition2 = FaissSearcher(index_dir + '/partition2','facebook/dpr-question_encoder-multiset-base')
q_emb, hit1 = searcher_partition1.search('What is the solution of separable closed queueing networks?', k=2, return_vector=True)
q_emb, hit2 = searcher_partition2.search('What is the solution of separable closed queueing networks?', k=2, return_vector=True)
merged_hits = hit1 + hit2
merged_hits.sort(key=lambda x: x.score, reverse=True)

self.assertEqual(merged_hits[0].docid, 'CACM-2445')
self.assertAlmostEqual(merged_hits[0].vectors[0], -6.88267112e-01, places=4)
self.assertEqual(searcher_partition1.num_docs, 1602)
self.assertEqual(searcher_partition2.num_docs, 1602)

def test_unicoil_encode_as_jsonl(self):
embedding_dir = f'{self.pyserini_root}/temp_embeddings'
self.temp_folders.append(embedding_dir)
cmd1 = f"python -m pyserini.encode input --corpus {self.corpus_path} \
cmd1 = f'python -m pyserini.encode input --corpus {self.corpus_path} \
--fields text \
output --embeddings {embedding_dir} \
encoder --encoder castorini/unicoil-msmarco-passage \
--fields text \
--batch 4 \
--device cpu"
--device cpu'
_ = os.system(cmd1)
index_dir = f'{self.pyserini_root}/temp_lucene'
self.temp_folders.append(index_dir)
Expand All @@ -78,7 +124,7 @@ def test_unicoil_encode_as_jsonl(self):
-impact -pretokenized -threads 12 -storeRaw'
_ = os.system(cmd2)
searcher = LuceneImpactSearcher(index_dir, query_encoder='castorini/unicoil-msmarco-passage')
hits = searcher.search("What is the solution of separable closed queueing networks?", k=1)
hits = searcher.search('What is the solution of separable closed queueing networks?', k=1)
hit = hits[0]
self.assertEqual(hit.docid, 'CACM-2712')
self.assertAlmostEqual(hit.score, 18.401899337768555, places=4)
Expand Down

0 comments on commit 7796685

Please sign in to comment.