Anserini: #2122 #1571

Merged: 8 commits, Aug 21, 2023.
23 changes: 21 additions & 2 deletions pyserini/encode/_slim.py
@@ -14,6 +14,8 @@ def __init__(self, model_name_or_path, tokenizer_name=None, fusion_weight=.99, d
self.model.to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
self.reverse_vocab = {v: k for k, v in self.tokenizer.vocab.items()}
self.weight_range = 5
self.quant_range = 256

def encode(self, text, max_length=256, topk=20, return_sparse=False, **kwargs):
inputs = self.tokenizer(
@@ -31,8 +33,15 @@ def encode(self, text, max_length=256, topk=20, return_sparse=False, **kwargs):
full_router_repr = torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1)
expert_weights, expert_ids = torch.topk(full_router_repr, dim=2, k=topk) # B x T x topk
min_expert_weight = torch.min(expert_weights, -1, True)[0]
sparse_expert_weights = torch.where(full_router_repr >= min_expert_weight, full_router_repr, 0)
return self._output_to_weight_dicts(expert_weights.cpu(), expert_ids.cpu(), sparse_expert_weights.cpu(), attention_mask.cpu(), return_sparse)[0]
sparse_expert_weights = torch.where(full_router_repr >= min_expert_weight, full_router_repr, torch.tensor(0, dtype=full_router_repr.dtype))
if return_sparse:
raw_weights, sparse_tok = self._output_to_weight_dicts(expert_weights.cpu(), expert_ids.cpu(), sparse_expert_weights.cpu(), attention_mask.cpu(), return_sparse)[0]
return self._get_encoded_query_token_wight_dicts([raw_weights])[0], sparse_tok
else:
raw_weights = self._output_to_weight_dicts(expert_weights.cpu(), expert_ids.cpu(), sparse_expert_weights.cpu(), attention_mask.cpu(), return_sparse)[0]
return self._get_encoded_query_token_wight_dicts([raw_weights])[0]
# return self._output_to_weight_dicts(expert_weights.cpu(), expert_ids.cpu(), sparse_expert_weights.cpu(), attention_mask.cpu(), return_sparse)[0]


def _output_to_weight_dicts(self, batch_expert_weights, batch_expert_ids, batch_sparse_expert_weights, batch_attention, return_sparse):
to_return = []
@@ -59,4 +68,14 @@ def _output_to_weight_dicts(self, batch_expert_weights, batch_expert_ids, batch_
to_return.append((fusion_vector, tok_vector))
else:
to_return.append(fusion_vector)
return to_return

def _get_encoded_query_token_wight_dicts(self, tok_weights):
to_return = []
for _tok_weight in tok_weights:
_weights = {}
for token, weight in _tok_weight.items():
weight_quanted = round(weight / self.weight_range * self.quant_range)
_weights[token] = weight_quanted
to_return.append(_weights)
return to_return
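
All three encoders now share the same quantization step: each raw token weight is scaled by quant_range / weight_range and rounded to an integer impact score. A minimal standalone sketch of the scheme, assuming the defaults shown in the diff and using illustrative weights:

```python
# Minimal standalone sketch of the quantization used above. WEIGHT_RANGE and
# QUANT_RANGE mirror the defaults in the diff; the example weights are
# illustrative, not produced by any model.
WEIGHT_RANGE = 5     # assumed upper bound on raw impact weights
QUANT_RANGE = 256    # number of quantization levels

def quantize(tok_weights):
    """Map raw float token weights onto integer impact scores."""
    return {tok: round(w / WEIGHT_RANGE * QUANT_RANGE) for tok, w in tok_weights.items()}

print(quantize({'lobster': 2.73, 'roll': 1.18}))  # -> {'lobster': 140, 'roll': 60}
```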
15 changes: 14 additions & 1 deletion pyserini/encode/_splade.py
@@ -12,6 +12,8 @@ def __init__(self, model_name_or_path, tokenizer_name=None, device='cpu'):
self.model.to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}
self.weight_range = 5
self.quant_range = 256

def encode(self, text, max_length=256, **kwargs):
inputs = self.tokenizer([text], max_length=max_length, padding='longest',
@@ -23,7 +25,8 @@ def encode(self, text, max_length=256, **kwargs):
batch_aggregated_logits, _ = torch.max(torch.log(1 + torch.relu(batch_logits))
* input_attention.unsqueeze(-1), dim=1)
batch_aggregated_logits = batch_aggregated_logits.cpu().detach().numpy()
return self._output_to_weight_dicts(batch_aggregated_logits)[0]
raw_weights = self._output_to_weight_dicts(batch_aggregated_logits)
return self._get_encoded_query_token_wight_dicts(raw_weights)[0]

def _output_to_weight_dicts(self, batch_aggregated_logits):
to_return = []
@@ -33,3 +36,13 @@ def _output_to_weight_dicts(self, batch_aggregated_logits):
d = {self.reverse_voc[k]: float(v) for k, v in zip(list(col), list(weights))}
to_return.append(d)
return to_return

def _get_encoded_query_token_wight_dicts(self, tok_weights):
[Review comment from a maintainer on this line] Hi @AileenLin - if we implement the quantization and API changes we discussed on the Java end, we wouldn't need this on the Python end, right?

to_return = []
for _tok_weight in tok_weights:
_weights = {}
for token, weight in _tok_weight.items():
weight_quanted = round(weight / self.weight_range * self.quant_range)
_weights[token] = weight_quanted
to_return.append(_weights)
return to_return
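
With this change, the Python-side SPLADE query encoder returns quantized integer weights rather than raw floats. A hedged usage sketch; the checkpoint name is illustrative, and any SPLADE model compatible with SpladeQueryEncoder should behave the same way:

```python
# Hedged sketch: encoding a query with the updated SpladeQueryEncoder.
# The checkpoint name is illustrative.
from pyserini.encode import SpladeQueryEncoder

encoder = SpladeQueryEncoder('naver/splade-cocondenser-ensembledistil', device='cpu')
weights = encoder.encode('what is a lobster roll?')
# Weights are now integers in roughly [0, 256] instead of raw floats.
print(sorted(weights.items(), key=lambda kv: -kv[1])[:5])
```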
16 changes: 15 additions & 1 deletion pyserini/encode/_unicoil.py
@@ -144,6 +144,8 @@ def __init__(self, model_name_or_path, tokenizer_name=None, device='cpu'):
self.model = UniCoilEncoder.from_pretrained(model_name_or_path)
self.model.to(self.device)
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
self.weight_range = 5
self.quant_range = 256

def encode(self, text, **kwargs):
max_length = 128 # hardcode for now
@@ -152,7 +154,8 @@ def encode(self, text, **kwargs):
return_tensors='pt').to(self.device)["input_ids"]
batch_weights = self.model(input_ids).cpu().detach().numpy()
batch_token_ids = input_ids.cpu().detach().numpy()
return self._output_to_weight_dicts(batch_token_ids, batch_weights)[0]
raw_weights = self._output_to_weight_dicts(batch_token_ids, batch_weights)
return self._get_encoded_query_token_wight_dicts(raw_weights)[0]

def _output_to_weight_dicts(self, batch_token_ids, batch_weights):
to_return = []
@@ -173,3 +176,14 @@ def _output_to_weight_dicts(self, batch_token_ids, batch_weights):
tok_weights[tok] += weight
to_return.append(tok_weights)
return to_return

def _get_encoded_query_token_wight_dicts(self, tok_weights):
to_return = []
for _tok_weight in tok_weights:
_weights = {}
for token, weight in _tok_weight.items():
weight_quanted = round(weight / self.weight_range * self.quant_range)
_weights[token] = weight_quanted
to_return.append(_weights)
return to_return

1 change: 1 addition & 0 deletions pyserini/pyclass.py
@@ -29,6 +29,7 @@
# Base Java classes
JString = autoclass('java.lang.String')
JFloat = autoclass('java.lang.Float')
JInt = autoclass('java.lang.Integer')
JPath = autoclass('java.nio.file.Path')
JPaths = autoclass('java.nio.file.Paths')
JList = autoclass('java.util.List')
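
JInt is exposed here because the searcher now passes quantized integer weights to the Java side, which expects boxed java.lang.Integer values rather than Float. A hedged sketch of how the alias is used when building a query map (the tokens and weights are illustrative):

```python
# Hedged sketch: building a Java-side query map from quantized weights.
# JHashMap and JInt come from pyserini.pyclass; the weights are illustrative.
from pyserini.pyclass import JHashMap, JInt

encoded_query = {'here': 156, 'test': 149}   # already-quantized token weights
jquery = JHashMap()
for token, weight in encoded_query.items():
    jquery.put(token, JInt(weight))          # boxed java.lang.Integer, not JFloat
```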
143 changes: 126 additions & 17 deletions pyserini/search/lucene/_impact_searcher.py
@@ -32,7 +32,7 @@
from pyserini.encode import QueryEncoder, TokFreqQueryEncoder, UniCoilQueryEncoder, \
CachedDataQueryEncoder, SpladeQueryEncoder, SlimQueryEncoder
from pyserini.index import Document
from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap
from pyserini.pyclass import autoclass, JFloat, JInt, JArrayList, JHashMap
from pyserini.util import download_prebuilt_index, download_encoded_corpus

logger = logging.getLogger(__name__)
@@ -53,16 +53,17 @@ class LuceneImpactSearcher:
QueryEncoder to encode query text
"""

def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str], min_idf=0, encoder_type: str='pytorch'):
def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str], min_idf=0, encoder_type: str='pytorch', prebuilt_index_name = None):
self.index_dir = index_dir
self.idf = self._compute_idf(index_dir)
self.min_idf = min_idf
self.object = JImpactSearcher(index_dir)
self.num_docs = self.object.get_total_num_docs()
self.encoder_type = encoder_type
self.query_encoder = query_encoder
self.prebuilt_index_name = prebuilt_index_name
if encoder_type == 'onnx':
if isinstance(query_encoder, str) or query_encoder is None:
if isinstance(query_encoder, str) and query_encoder is not None:
self.object.set_onnx_query_encoder(query_encoder)
else:
raise ValueError(f'Invalid query encoder type: {type(query_encoder)} for onnx encoder')
@@ -102,13 +103,12 @@ def from_prebuilt_index(cls, prebuilt_index_name: str, query_encoder: Union[Quer
return None

print(f'Initializing {prebuilt_index_name}...')
return cls(index_dir, query_encoder, min_idf, encoder_type)
return cls(index_dir, query_encoder, min_idf, encoder_type, prebuilt_index_name=prebuilt_index_name)

def encode(self, query):
if self.encoder_type == 'onnx':
encoded_query = self.object.encode_with_onnx(query)
else:
if self.encoder_type == 'pytorch':
encoded_query = self.query_encoder.encode(query)
else: raise ValueError(f'Invalid query encoder type: {type(self.query_encoder)} for encode')
return encoded_query

@staticmethod
@@ -140,13 +140,14 @@ def search(self, q: str, k: int = 10, fields=dict()) -> List[JImpactSearcherResu
for (field, boost) in fields.items():
jfields.put(field, JFloat(boost))

encoded_query = self.encode(q)

jquery = encoded_query
if self.encoder_type == 'pytorch':
encoded_query = self.encode(q)
jquery = encoded_query
for (token, weight) in encoded_query.items():
if token in self.idf and self.idf[token] > self.min_idf:
jquery.put(token, JFloat(weight))
jquery.put(token, JInt(weight))
else:
jquery = q

if not fields:
hits = self.object.search(jquery, k)
@@ -183,14 +184,14 @@ def batch_search(self, queries: List[str], qids: List[str],
query_lst = JArrayList()
qid_lst = JArrayList()
for q in queries:
encoded_query = self.encode(q)
jquery = JHashMap()
if self.encoder_type == 'pytorch':
encoded_query = self.encode(q)
for (token, weight) in encoded_query.items():
if token in self.idf and self.idf[token] > self.min_idf:
jquery.put(token, JFloat(weight))
jquery.put(token, JInt(weight))
else:
jquery = encoded_query
jquery = q
query_lst.add(jquery)

for qid in qids:
@@ -202,11 +203,28 @@
jfields.put(field, JFloat(boost))

if not fields:
results = self.object.batch_search(query_lst, qid_lst, int(k), int(threads))
if self.encoder_type == 'onnx':
results = self.object.batch_search_queries(query_lst, qid_lst, int(k), int(threads))
else:
results = self.object.batch_search(query_lst, qid_lst, int(k), int(threads))
else:
results = self.object.batch_search_fields(query_lst, qid_lst, int(k), int(threads), jfields)
return {r.getKey(): r.getValue() for r in results.entrySet().toArray()}

def set_analyzer(self, analyzer):
"""Set the Java ``Analyzer`` to use.

Parameters
----------
analyzer : JAnalyzer
Java ``Analyzer`` object.
"""
self.object.set_analyzer(analyzer)

def set_language(self, language):
"""Set language of LuceneSearcher"""
self.object.set_language(language)

def doc(self, docid: Union[str, int]) -> Optional[Document]:
"""Return the :class:`Document` corresponding to ``docid``. The ``docid`` is overloaded: if it is of type
``str``, it is treated as an external collection ``docid``; if it is of type ``int``, it is treated as an
@@ -228,6 +246,97 @@ def doc(self, docid: Union[str, int]) -> Optional[Document]:
return None
return Document(lucene_document)

def set_rm3(self):
self.object.set_rm3()

def set_rm3(self, fb_terms=10, fb_docs=10, original_query_weight=float(0.5), debug=False, filter_terms=True):
"""Configure RM3 pseudo-relevance feedback.

Parameters
----------
fb_terms : int
RM3 parameter for number of expansion terms.
fb_docs : int
RM3 parameter for number of expansion documents.
original_query_weight : float
RM3 parameter for weight to assign to the original query.
debug : bool
Print the original and expanded queries as debug output.
filter_terms: bool
Whether to remove non-English terms.
"""
if self.object.reader.getTermVectors(0):
self.object.set_rm3(None, fb_terms, fb_docs, original_query_weight, debug, filter_terms)
elif self.object.reader.document(0).getField('raw'):
self.object.set_rm3('JsonVectorCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms)
elif self.prebuilt_index_name in ['msmarco-v1-passage', 'msmarco-v1-doc', 'msmarco-v1-doc-segmented']:
self.object.set_rm3('JsonCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms)
elif self.prebuilt_index_name in ['msmarco-v2-passage', 'msmarco-v2-passage-augmented']:
self.object.set_rm3('MsMarcoV2PassageCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms)
elif self.prebuilt_index_name in ['msmarco-v2-doc', 'msmarco-v2-doc-segmented']:
self.object.set_rm3('MsMarcoV2DocCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms)
else:
raise TypeError("RM3 is not supported for indexes without document vectors or raw texts.")

def unset_rm3(self):
"""Disable RM3 pseudo-relevance feedback."""
self.object.unset_rm3()

def is_using_rm3(self) -> bool:
"""Check if RM3 pseudo-relevance feedback is being performed."""
return self.object.use_rm3()

def set_rocchio(self):
self.object.set_rocchio()


def set_rocchio(self, top_fb_terms=10, top_fb_docs=10, bottom_fb_terms=10, bottom_fb_docs=10,
alpha=1, beta=0.75, gamma=0, debug=False, use_negative=False):
"""Configure Rocchio pseudo-relevance feedback.

Parameters
----------
top_fb_terms : int
Rocchio parameter for number of relevant expansion terms.
top_fb_docs : int
Rocchio parameter for number of relevant expansion documents.
bottom_fb_terms : int
Rocchio parameter for number of non-relevant expansion terms.
bottom_fb_docs : int
Rocchio parameter for number of non-relevant expansion documents.
alpha : float
Rocchio parameter for weight to assign to the original query.
beta: float
Rocchio parameter for weight to assign to the relevant document vector.
gamma: float
Rocchio parameter for weight to assign to the nonrelevant document vector.
debug : bool
Print the original and expanded queries as debug output.
use_negative : bool
Rocchio parameter to use negative labels.
"""
if self.object.reader.getTermVectors(0):
self.object.set_rocchio(None, top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs,
alpha, beta, gamma, debug, use_negative)
elif self.object.reader.document(0).getField('raw'):
self.object.set_rocchio('JsonVectorCollection', top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs,
alpha, beta, gamma, debug, use_negative)
elif self.prebuilt_index_name in ['msmarco-v1-passage', 'msmarco-v1-doc', 'msmarco-v1-doc-segmented']:
self.object.set_rocchio('JsonCollection', top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs,
alpha, beta, gamma, debug, use_negative)
# Note, we don't have any Pyserini 2CRs that use Rocchio for MS MARCO v2, so there's currently no
# corresponding code branch here. To avoid introducing bugs (without 2CR tests), we'll add when it's needed.
else:
raise TypeError("Rocchio is not supported for indexes without document vectors or raw texts.")

def unset_rocchio(self):
"""Disable Rocchio pseudo-relevance feedback."""
self.object.unset_rocchio()

def is_using_rocchio(self) -> bool:
"""Check if Rocchio pseudo-relevance feedback is being performed."""
return self.object.use_rocchio()

def doc_by_field(self, field: str, q: str) -> Optional[Document]:
"""Return the :class:`Document` based on a ``field`` with ``id``. For example, this method can be used to fetch
document based on alternative primary keys that have been indexed, such as an article's DOI. Method returns
@@ -327,7 +436,7 @@ def search(self, q: str, k: int = 10, fields=dict()) -> List[JImpactSearcherResu
jquery = JHashMap()
for (token, weight) in fusion_encoded_query.items():
if token in self.idf and self.idf[token] > self.min_idf:
jquery.put(token, JFloat(weight))
jquery.put(token, JInt(weight))

if self.sparse_vecs is not None:
search_k = k * (self.min_idf + 1)
@@ -348,7 +457,7 @@ def batch_search(self, queries: List[str], qids: List[str],
jquery = JHashMap()
for (token, weight) in fusion_encoded_query.items():
if token in self.idf and self.idf[token] > self.min_idf:
jquery.put(token, JFloat(weight))
jquery.put(token, JInt(weight))
query_lst.add(jquery)
sparse_encoded_queries[qid] = sparse_encoded_query

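
Putting the searcher-side changes together, the sketch below exercises the ONNX path end to end; the prebuilt index name is illustrative, and the RM3 call assumes the index satisfies one of the conditions checked in set_rm3:

```python
# Hedged usage sketch of LuceneImpactSearcher after this PR.
# The prebuilt index name is illustrative; substitute one available locally.
from pyserini.search.lucene import LuceneImpactSearcher

# ONNX path: raw query strings are handed to the Java side, which encodes and
# quantizes them; batch_search dispatches to batch_search_queries.
searcher = LuceneImpactSearcher.from_prebuilt_index(
    'msmarco-v1-passage-splade-pp-ed',     # illustrative prebuilt index name
    'SpladePlusPlusEnsembleDistil',        # ONNX query encoder
    encoder_type='onnx')

hits = searcher.search('what is a lobster roll?', k=10)
for hit in hits[:3]:
    print(hit.docid, hit.score)

# RM3 feedback is now available on the impact searcher as well, provided the
# index exposes term vectors, raw text, or a recognized prebuilt index name.
searcher.set_rm3(fb_terms=10, fb_docs=10, original_query_weight=0.5)
print(searcher.is_using_rm3())
```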
9 changes: 5 additions & 4 deletions tests/test_encoder.py
@@ -225,10 +225,11 @@ def test_aggretriever_cocondenser_encoder_cmd(self):
def test_onnx_encode_unicoil(self):
temp_object = LuceneImpactSearcher(f'{self.index_dir}lucene9-index.cacm', 'SpladePlusPlusEnsembleDistil', encoder_type='onnx')

results = temp_object.encode("here is a test")
self.assertAlmostEqual(results.get("here"), 3.05345, delta=2e-4)
self.assertAlmostEqual(results.get("a"), 0.59636426, delta=2e-4)
self.assertAlmostEqual(results.get("test"), 2.9012794, delta=2e-4)
# this method is never called from _impact_searcher; it is invoked here only to check quantization correctness
results = temp_object.object.encodeWithOnnx("here is a test")
self.assertAlmostEqual(results.get("here"), 156, delta=2e-4)
self.assertAlmostEqual(results.get("a"), 31, delta=2e-4)
self.assertAlmostEqual(results.get("test"), 149, delta=2e-4)

temp_object.close()
del temp_object
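
The new integer expectations are the old float scores passed through the quantization above, round(w / weight_range * quant_range); a quick standalone check:

```python
# Sanity check (not part of the PR): the updated integer expectations equal the
# old float scores after quantization with weight_range=5, quant_range=256.
for raw, expected in [(3.05345, 156), (0.59636426, 31), (2.9012794, 149)]:
    assert round(raw / 5 * 256) == expected
```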