Remove duplicate code in pyserini.encode and pyserini.search.faiss (#2008)

+ Reconciled the two different versions of QueryEncoder and AutoQueryEncoder;
  retaining only the version in pyserini.encode.
+ De-dupped encoder implementations in pyserini.search.faiss, folding into pyserini.encode as appropriate.

(Only remaining encoders in pyserini.search.faiss._searcher.py are ClipQueryEncoder and OpenAIQueryEncoder,
saved for a subsequent pass.)
lintool authored Oct 12, 2024
1 parent 2733088 commit e68d544
Showing 38 changed files with 815 additions and 818 deletions.
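
For users, the visible effect is a single canonical import path: query encoders now live in pyserini.encode, and the parallel copies in pyserini.search.faiss._searcher are gone. A minimal before/after sketch (the checkpoint name is an assumption for illustration, not taken from this diff):

# Before this commit (path now removed):
#   from pyserini.search.faiss._searcher import TctColBertQueryEncoder
# After this commit:
from pyserini.encode import TctColBertQueryEncoder

encoder = TctColBertQueryEncoder('castorini/tct_colbert-msmarco')  # assumed checkpoint
vector = encoder.encode('what is a lobster roll')  # 1-D numpy query embedding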
2 changes: 1 addition & 1 deletion integrations-optional/dense/test_ance.py
@@ -21,8 +21,8 @@
 import unittest
 
 from integrations.utils import clean_files, run_command, parse_score_qa, parse_score_msmarco
+from pyserini.encode import QueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss._searcher import QueryEncoder
 
 
 class TestAnce(unittest.TestCase):
2 changes: 1 addition & 1 deletion integrations-optional/dense/test_dpr.py
@@ -22,8 +22,8 @@
 import unittest
 
 from integrations.utils import clean_files, run_command, parse_score_qa
+from pyserini.encode import QueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss._searcher import QueryEncoder
 
 
 class TestDpr(unittest.TestCase):
2 changes: 1 addition & 1 deletion integrations-optional/dense/test_tct_colbert.py
@@ -21,8 +21,8 @@
 import unittest
 
 from integrations.utils import clean_files, run_command, parse_score
+from pyserini.encode import QueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss._searcher import QueryEncoder
 
 
 class TestTctColBert(unittest.TestCase):
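
All three optional integration tests swap the same import; they drive retrieval from pre-encoded query sets rather than a live model. A hedged sketch of that pattern (the encoded-query set name is illustrative; load_encoded_queries is the classmethod these tests rely on):

from pyserini.encode import QueryEncoder

# Downloads (if necessary) and loads a prebuilt set of query embeddings,
# then serves lookups by query string.
encoder = QueryEncoder.load_encoded_queries('tct_colbert-msmarco-passage-dev-subset')
vector = encoder.encode('a query that appears in the encoded set')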
(file path not captured in this view)

@@ -23,7 +23,7 @@
 
 class TestLtrMsmarcoDocument(unittest.TestCase):
     def test_reranking(self):
-        if(os.path.isdir('ltr_test')):
+        if os.path.isdir('ltr_test'):
             rmtree('ltr_test')
         os.mkdir('ltr_test')
         inp = 'run.msmarco-pass-doc.bm25.txt'
(file path not captured in this view)

@@ -20,12 +20,10 @@
 import unittest
 from shutil import rmtree
 
-from pyserini.search.lucene import LuceneSearcher
-
 
 class TestLtrMsmarcoPassage(unittest.TestCase):
     def test_reranking(self):
-        if(os.path.isdir('ltr_test')):
+        if os.path.isdir('ltr_test'):
             rmtree('ltr_test')
         os.mkdir('ltr_test')
         inp = 'run.msmarco-passage.bm25tuned.txt'
5 changes: 3 additions & 2 deletions integrations/clprf/test_trec_covid_r5.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 #
 
+import gzip
+import json
 import os
 import re
 import shutil
 import unittest
-import json
-import gzip
 from random import randint
+
 from pyserini.util import download_url, download_prebuilt_index
2 changes: 1 addition & 1 deletion integrations/lucenesearcher_anserini_checker.py
@@ -18,8 +18,8 @@
 import os
 from typing import List
 
-from pyserini.util import get_cache_home
 from pyserini.prebuilt_index_info import TF_INDEX_INFO
+from pyserini.util import get_cache_home
 
 
 class LuceneSearcherAnseriniMatchChecker:
2 changes: 1 addition & 1 deletion integrations/sparse/test_lucenesearcher_check_irst.py
@@ -16,8 +16,8 @@
 
 import os
 import unittest
-from shutil import rmtree
 from random import randint
+from shutil import rmtree
 
 from integrations.utils import run_command, parse_score
2 changes: 1 addition & 1 deletion integrations/sparse/test_search_pretokenized.py
(one addition and one deletion in this hunk; the deleted line was not captured in this view)

@@ -17,8 +17,8 @@
 import os
 import shutil
 import unittest
 
+from random import randint
 
 from integrations.lucenesearcher_score_checker import LuceneSearcherScoreChecker
2 changes: 1 addition & 1 deletion integrations/utils.py
@@ -15,8 +15,8 @@
 #
 
 import os
-import subprocess
 import shutil
+import subprocess
 
 
 def clean_files(files):
1 change: 0 additions & 1 deletion pyserini/analysis/_base.py
@@ -47,7 +47,6 @@
 
 # Wrappers around Anserini classes
 JAnalyzerUtils = autoclass('io.anserini.analysis.AnalyzerUtils')
 JDefaultEnglishAnalyzer = autoclass('io.anserini.analysis.DefaultEnglishAnalyzer')
 JTweetAnalyzer = autoclass('io.anserini.analysis.TweetAnalyzer')
 JHuggingFaceTokenizerAnalyzer = autoclass('io.anserini.analysis.HuggingFaceTokenizerAnalyzer')

(the deleted line in this hunk was not captured in this view)
3 changes: 2 additions & 1 deletion pyserini/demo/dpr.py
@@ -18,8 +18,9 @@
 import json
 import random
 
+from pyserini.encode import DprQueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss import FaissSearcher, DprQueryEncoder
+from pyserini.search.faiss import FaissSearcher
 from pyserini.search.hybrid import HybridSearcher
 from pyserini.search.lucene import LuceneSearcher
3 changes: 2 additions & 1 deletion pyserini/demo/msmarco.py
@@ -18,8 +18,9 @@
 import json
 import random
 
+from pyserini.encode import AnceQueryEncoder, TctColBertQueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss import FaissSearcher, TctColBertQueryEncoder, AnceQueryEncoder
+from pyserini.search.faiss import FaissSearcher
 from pyserini.search.hybrid import HybridSearcher
 from pyserini.search.lucene import LuceneSearcher
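
Both demos now take their encoder from pyserini.encode and wire it into the same searcher stack as before. A sketch of the resulting wiring (prebuilt index and checkpoint names are assumptions for illustration):

from pyserini.encode import TctColBertQueryEncoder
from pyserini.search.faiss import FaissSearcher
from pyserini.search.hybrid import HybridSearcher
from pyserini.search.lucene import LuceneSearcher

encoder = TctColBertQueryEncoder('castorini/tct_colbert-msmarco')  # assumed checkpoint
dense = FaissSearcher.from_prebuilt_index('msmarco-passage-tct_colbert-hnsw', encoder)  # assumed name
sparse = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage')  # assumed name
hits = HybridSearcher(dense, sparse).search('what is a lobster roll', k=10)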
8 changes: 3 additions & 5 deletions pyserini/dsearch.py
@@ -20,10 +20,8 @@
 import os
 import sys
 
-from pyserini.search.faiss import FaissSearcher
-from pyserini.search.faiss._searcher import TctColBertQueryEncoder, BinaryDenseSearcher
-
-__all__ = ['SimpleDenseSearcher', 'BinaryDenseSearcher', 'TctColBertQueryEncoder']
+from pyserini.encode import TctColBertQueryEncoder
+from pyserini.search.faiss import FaissSearcher, BinaryDenseFaissSearcher
 
 
 class SimpleDenseSearcher(FaissSearcher):

@@ -33,7 +31,7 @@ def __new__(cls, *args, **kwargs):
         return super().__new__(cls)
 
 
-class BinaryDenseSearcher(BinaryDenseSearcher):
+class BinaryDenseSearcher(BinaryDenseFaissSearcher):
     def __new__(cls, *args, **kwargs):
         print('pyserini.dsearch.BinaryDenseSearcher class has been deprecated, '
               'please use BinaryDenseSearcher from pyserini.search.faiss instead')
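
The file keeps the old entry points alive with a shim: each deprecated class subclasses its replacement and prints a notice from __new__, so existing callers keep working. A self-contained sketch of the pattern with hypothetical names (OldSearcher/NewSearcher are not from this diff):

class NewSearcher:
    """Stand-in for the maintained implementation."""


class OldSearcher(NewSearcher):
    """Deprecated alias kept for backward compatibility."""
    def __new__(cls, *args, **kwargs):
        print('OldSearcher has been deprecated, please use NewSearcher instead')
        return super().__new__(cls)


searcher = OldSearcher()  # warns at construction, then behaves as a NewSearcher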
2 changes: 2 additions & 0 deletions pyserini/encode/__init__.py
@@ -21,8 +21,10 @@
 from ._aggretriever import AggretrieverDocumentEncoder, AggretrieverQueryEncoder
 from ._ance import AnceEncoder, AnceDocumentEncoder, AnceQueryEncoder
 from ._auto import AutoQueryEncoder, AutoDocumentEncoder
+from ._bpr import BprQueryEncoder
 from ._cached_data import CachedDataQueryEncoder
 from ._cosdpr import CosDprEncoder, CosDprDocumentEncoder, CosDprQueryEncoder
+from ._dkrr import DkrrDprQueryEncoder
 from ._dpr import DprDocumentEncoder, DprQueryEncoder
 from ._openai import OpenAIDocumentEncoder, OpenAIQueryEncoder, OPENAI_API_RETRY_DELAY
 from ._slim import SlimQueryEncoder
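
With these two re-exports, the remaining encoders the FAISS searcher depends on also resolve from the package root. An illustrative sketch (the checkpoint name is an assumption, and the comment on the return value follows how BPR is used elsewhere in Pyserini rather than anything shown here):

from pyserini.encode import BprQueryEncoder, DkrrDprQueryEncoder

encoder = BprQueryEncoder(encoder_dir='castorini/bpr-nq-question-encoder')  # assumed checkpoint
embeddings = encoder.encode('who wrote hamlet')  # BPR pairs a dense vector with a binarized one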
64 changes: 34 additions & 30 deletions pyserini/encode/_aggretriever.py
@@ -27,7 +27,8 @@
 from transformers import AutoModelForMaskedLM, AutoTokenizer, PreTrainedModel
 from pyserini.encode import DocumentEncoder, QueryEncoder
 
-class BERTAggretrieverEncoder(PreTrainedModel):
+
+class BertAggretrieverEncoder(PreTrainedModel):
     config_class = BertConfig
     base_model_prefix = 'encoder'
     load_tf_weights = None

@@ -120,7 +121,7 @@ def forward(
         return torch.cat((semantic_reps, lexical_reps), -1)
 
 
-class DistlBERTAggretrieverEncoder(BERTAggretrieverEncoder):
+class DistlBertAggretrieverEncoder(BertAggretrieverEncoder):
     config_class = DistilBertConfig
     base_model_prefix = 'encoder'
     load_tf_weights = None

@@ -130,9 +131,9 @@ class AggretrieverDocumentEncoder(DocumentEncoder):
     def __init__(self, model_name: str, tokenizer_name=None, device='cuda:0'):
         self.device = device
         if 'distilbert' in model_name.lower():
-            self.model = DistlBERTAggretrieverEncoder.from_pretrained(model_name)
+            self.model = DistlBertAggretrieverEncoder.from_pretrained(model_name)
         else:
-            self.model = BERTAggretrieverEncoder.from_pretrained(model_name)
+            self.model = BertAggretrieverEncoder.from_pretrained(model_name)
         self.model.to(self.device)
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)

@@ -160,30 +161,33 @@ def encode(self, texts, titles=None, fp16=False, max_length=512, **kwargs):
 
 
 class AggretrieverQueryEncoder(QueryEncoder):
-    def __init__(self, model_name: str, tokenizer_name=None, device='cuda:0'):
-        self.device = device
-        if 'distilbert' in model_name.lower():
-            self.model = DistlBERTAggretrieverEncoder.from_pretrained(model_name)
-        else:
-            self.model = BERTAggretrieverEncoder.from_pretrained(model_name)
-        self.model.to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)
-
-    def encode(self, texts, fp16=False, max_length=32, **kwargs):
-        texts = [text for text in texts]
-        inputs = self.tokenizer(
-            texts,
-            max_length=max_length,
-            padding="longest",
-            truncation=True,
-            add_special_tokens=True,
-            return_tensors='pt'
-        )
-        inputs.to(self.device)
-        if fp16:
-            with autocast():
-                with torch.no_grad():
-                    outputs = self.model(**inputs)
-        else:
-            outputs = self.model(**inputs)
-        return outputs.detach().cpu().numpy()
+    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
+                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
+        if encoder_dir:
+            self.device = device
+            if 'distilbert' in encoder_dir.lower():
+                self.model = DistlBertAggretrieverEncoder.from_pretrained(encoder_dir)
+            else:
+                self.model = BertAggretrieverEncoder.from_pretrained(encoder_dir)
+            self.model.to(self.device)
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or encoder_dir)
+            self.has_model = True
+        if (not self.has_model) and (not self.has_encoded_query):
+            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
+
+    def encode(self, query: str, max_length: int=32):
+        if self.has_model:
+            inputs = self.tokenizer(
+                query,
+                max_length=max_length,
+                padding="longest",
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors='pt'
+            )
+            inputs.to(self.device)
+            outputs = self.model(**inputs)
+            embeddings = outputs.detach().cpu().numpy()
+            return embeddings.flatten()
+        else:
+            return super().encode(query)
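
The rewritten AggretrieverQueryEncoder adopts the constructor contract shared by the consolidated encoders: pass encoder_dir for a live model or encoded_query_dir for cached embeddings, and construction fails if neither is supplied. A usage sketch under that contract (the checkpoint name is an assumption):

from pyserini.encode import AggretrieverQueryEncoder

encoder = AggretrieverQueryEncoder(encoder_dir='castorini/aggretriever-distilbert')  # assumed checkpoint
vector = encoder.encode('what is a lobster roll')  # flattened 1-D numpy vector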
71 changes: 53 additions & 18 deletions pyserini/encode/_ance.py
@@ -14,10 +14,10 @@
 # limitations under the License.
 #
 
-from typing import Optional
+from typing import Optional, List
 
 import torch
-from transformers import PreTrainedModel, RobertaConfig, RobertaModel, RobertaTokenizer
+from transformers import PreTrainedModel, RobertaConfig, RobertaModel, RobertaTokenizer, requires_backends
 
 from pyserini.encode import DocumentEncoder, QueryEncoder

@@ -30,6 +30,7 @@ class AnceEncoder(PreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r'pooler', r'classifier']
 
     def __init__(self, config: RobertaConfig):
+        requires_backends(self, 'torch')
         super().__init__(config)
         self.config = config
         self.roberta = RobertaModel(config)

@@ -55,11 +56,7 @@ def init_weights(self):
         self.embeddingHead.apply(self._init_weights)
         self.norm.apply(self._init_weights)
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
+    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
         input_shape = input_ids.size()
         device = input_ids.device
         if attention_mask is None:

@@ -98,22 +95,60 @@ def encode(self, texts, titles=None, max_length=256, **kwargs):
 
 
 class AnceQueryEncoder(QueryEncoder):
+    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
+                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
+        super().__init__(encoded_query_dir)
+        if encoder_dir:
+            self.device = device
+            self.model = AnceEncoder.from_pretrained(encoder_dir)
+            self.model.to(self.device)
+            self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name or encoder_dir)
+            self.has_model = True
+            self.tokenizer.do_lower_case = True
+        if (not self.has_model) and (not self.has_encoded_query):
+            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
+
+    def encode(self, query: str):
+        if self.has_model:
+            inputs = self.tokenizer(
+                [query],
+                max_length=64,
+                padding='longest',
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors='pt'
+            )
+            inputs.to(self.device)
+            embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
+            return embeddings.flatten()
+        else:
+            return super().encode(query)
+
+    def prf_encode(self, query: str):
+        if self.has_model:
+            inputs = self.tokenizer(
+                [query],
+                max_length=512,
+                padding='longest',
+                truncation=True,
+                add_special_tokens=False,
+                return_tensors='pt'
+            )
+            inputs.to(self.device)
+            embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
+            return embeddings.flatten()
+        else:
+            return super().encode(query)
 
-    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
-        self.device = device
-        self.model = AnceEncoder.from_pretrained(model_name)
-        self.model.to(self.device)
-        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name or model_name)
-
-    def encode(self, query: str, **kwargs):
+    def prf_batch_encode(self, query: List[str]):
         inputs = self.tokenizer(
-            [query],
-            max_length=64,
+            query,
+            max_length=512,
             padding='longest',
             truncation=True,
-            add_special_tokens=True,
+            add_special_tokens=False,
             return_tensors='pt'
         )
         inputs.to(self.device)
         embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
-        return embeddings.flatten()
+        return embeddings
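
The consolidated AnceQueryEncoder exposes three entry points that differ only in tokenization: encode() truncates at 64 tokens with special tokens for ordinary queries, the PRF variants use 512 tokens without special tokens, and prf_batch_encode() returns the un-flattened matrix. A usage sketch (the checkpoint name is an assumption):

from pyserini.encode import AnceQueryEncoder

encoder = AnceQueryEncoder(encoder_dir='castorini/ance-msmarco-passage')  # assumed checkpoint
q_vec = encoder.encode('what is a lobster roll')  # flattened query embedding
prf_mat = encoder.prf_batch_encode(['query one', 'query two'])  # one row per query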
(Diffs for the remaining changed files were not loaded.)