Skip to content

Commit

Permalink
Add corresponding code patches for Anserini #2122 (#1571)
Browse files Browse the repository at this point in the history
castorini/anserini#2122
Add ability to parse raw text into docvectors on-the-fly for impact indexes

castorini/anserini#2165
Fix a misalignment between the SearchCollection and SimpleImpactSearcher implementations, with corresponding changes to the 2CR scripts
  • Loading branch information
AileenLin authored Aug 21, 2023
1 parent 34862fd commit b713a51
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 25 deletions.
23 changes: 21 additions & 2 deletions pyserini/encode/_slim.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def __init__(self, model_name_or_path, tokenizer_name=None, fusion_weight=.99, d
self.model.to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
self.reverse_vocab = {v: k for k, v in self.tokenizer.vocab.items()}
self.weight_range = 5
self.quant_range = 256

def encode(self, text, max_length=256, topk=20, return_sparse=False, **kwargs):
inputs = self.tokenizer(
Expand All @@ -31,8 +33,15 @@ def encode(self, text, max_length=256, topk=20, return_sparse=False, **kwargs):
full_router_repr = torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1)
expert_weights, expert_ids = torch.topk(full_router_repr, dim=2, k=topk) # B x T x topk
min_expert_weight = torch.min(expert_weights, -1, True)[0]
sparse_expert_weights = torch.where(full_router_repr >= min_expert_weight, full_router_repr, 0)
return self._output_to_weight_dicts(expert_weights.cpu(), expert_ids.cpu(), sparse_expert_weights.cpu(), attention_mask.cpu(), return_sparse)[0]
sparse_expert_weights = torch.where(full_router_repr >= min_expert_weight, full_router_repr, torch.tensor(0, dtype=full_router_repr.dtype))
if return_sparse:
raw_weights, sparse_tok = self._output_to_weight_dicts(expert_weights.cpu(), expert_ids.cpu(), sparse_expert_weights.cpu(), attention_mask.cpu(), return_sparse)[0]
return self._get_encoded_query_token_wight_dicts([raw_weights])[0], sparse_tok
else:
raw_weights = self._output_to_weight_dicts(expert_weights.cpu(), expert_ids.cpu(), sparse_expert_weights.cpu(), attention_mask.cpu(), return_sparse)[0]
return self._get_encoded_query_token_wight_dicts([raw_weights])[0]
# return self._output_to_weight_dicts(expert_weights.cpu(), expert_ids.cpu(), sparse_expert_weights.cpu(), attention_mask.cpu(), return_sparse)[0]


def _output_to_weight_dicts(self, batch_expert_weights, batch_expert_ids, batch_sparse_expert_weights, batch_attention, return_sparse):
to_return = []
Expand All @@ -59,4 +68,14 @@ def _output_to_weight_dicts(self, batch_expert_weights, batch_expert_ids, batch_
to_return.append((fusion_vector, tok_vector))
else:
to_return.append(fusion_vector)
return to_return

def _get_encoded_query_token_wight_dicts(self, tok_weights):
to_return = []
for _tok_weight in tok_weights:
_weights = {}
for token, weight in _tok_weight.items():
weight_quanted = round(weight / self.weight_range * self.quant_range)
_weights[token] = weight_quanted
to_return.append(_weights)
return to_return
15 changes: 14 additions & 1 deletion pyserini/encode/_splade.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def __init__(self, model_name_or_path, tokenizer_name=None, device='cpu'):
self.model.to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}
self.weight_range = 5
self.quant_range = 256

def encode(self, text, max_length=256, **kwargs):
inputs = self.tokenizer([text], max_length=max_length, padding='longest',
Expand All @@ -23,7 +25,8 @@ def encode(self, text, max_length=256, **kwargs):
batch_aggregated_logits, _ = torch.max(torch.log(1 + torch.relu(batch_logits))
* input_attention.unsqueeze(-1), dim=1)
batch_aggregated_logits = batch_aggregated_logits.cpu().detach().numpy()
return self._output_to_weight_dicts(batch_aggregated_logits)[0]
raw_weights = self._output_to_weight_dicts(batch_token_ids, batch_weights)
return self._get_encoded_query_token_wight_dicts(raw_weights)[0]

def _output_to_weight_dicts(self, batch_aggregated_logits):
to_return = []
Expand All @@ -33,3 +36,13 @@ def _output_to_weight_dicts(self, batch_aggregated_logits):
d = {self.reverse_voc[k]: float(v) for k, v in zip(list(col), list(weights))}
to_return.append(d)
return to_return

def _get_encoded_query_token_wight_dicts(self, tok_weights):
to_return = []
for _tok_weight in tok_weights:
_weights = {}
for token, weight in _tok_weight.items():
weight_quanted = round(weight / self.weight_range * self.quant_range)
_weights[token] = weight_quanted
to_return.append(_weights)
return to_return
16 changes: 15 additions & 1 deletion pyserini/encode/_unicoil.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def __init__(self, model_name_or_path, tokenizer_name=None, device='cpu'):
self.model = UniCoilEncoder.from_pretrained(model_name_or_path)
self.model.to(self.device)
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
self.weight_range = 5
self.quant_range = 256

def encode(self, text, **kwargs):
max_length = 128 # hardcode for now
Expand All @@ -152,7 +154,8 @@ def encode(self, text, **kwargs):
return_tensors='pt').to(self.device)["input_ids"]
batch_weights = self.model(input_ids).cpu().detach().numpy()
batch_token_ids = input_ids.cpu().detach().numpy()
return self._output_to_weight_dicts(batch_token_ids, batch_weights)[0]
raw_weights = self._output_to_weight_dicts(batch_token_ids, batch_weights)
return self._get_encoded_query_token_wight_dicts(raw_weights)[0]

def _output_to_weight_dicts(self, batch_token_ids, batch_weights):
to_return = []
Expand All @@ -173,3 +176,14 @@ def _output_to_weight_dicts(self, batch_token_ids, batch_weights):
tok_weights[tok] += weight
to_return.append(tok_weights)
return to_return

def _get_encoded_query_token_wight_dicts(self, tok_weights):
to_return = []
for _tok_weight in tok_weights:
_weights = {}
for token, weight in _tok_weight.items():
weight_quanted = round(weight / self.weight_range * self.quant_range)
_weights[token] = weight_quanted
to_return.append(_weights)
return to_return

1 change: 1 addition & 0 deletions pyserini/pyclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# Base Java classes
JString = autoclass('java.lang.String')
JFloat = autoclass('java.lang.Float')
JInt = autoclass('java.lang.Integer')
JPath = autoclass('java.nio.file.Path')
JPaths = autoclass('java.nio.file.Paths')
JList = autoclass('java.util.List')
Expand Down
143 changes: 126 additions & 17 deletions pyserini/search/lucene/_impact_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from pyserini.encode import QueryEncoder, TokFreqQueryEncoder, UniCoilQueryEncoder, \
CachedDataQueryEncoder, SpladeQueryEncoder, SlimQueryEncoder
from pyserini.index import Document
from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap
from pyserini.pyclass import autoclass, JFloat, JInt, JArrayList, JHashMap
from pyserini.util import download_prebuilt_index, download_encoded_corpus

logger = logging.getLogger(__name__)
Expand All @@ -53,16 +53,17 @@ class LuceneImpactSearcher:
QueryEncoder to encode query text
"""

def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str], min_idf=0, encoder_type: str='pytorch'):
def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str], min_idf=0, encoder_type: str='pytorch', prebuilt_index_name = None):
self.index_dir = index_dir
self.idf = self._compute_idf(index_dir)
self.min_idf = min_idf
self.object = JImpactSearcher(index_dir)
self.num_docs = self.object.get_total_num_docs()
self.encoder_type = encoder_type
self.query_encoder = query_encoder
self.prebuilt_index_name = prebuilt_index_name
if encoder_type == 'onnx':
if isinstance(query_encoder, str) or query_encoder is None:
if isinstance(query_encoder, str) and query_encoder is not None:
self.object.set_onnx_query_encoder(query_encoder)
else:
raise ValueError(f'Invalid query encoder type: {type(query_encoder)} for onnx encoder')
Expand Down Expand Up @@ -102,13 +103,12 @@ def from_prebuilt_index(cls, prebuilt_index_name: str, query_encoder: Union[Quer
return None

print(f'Initializing {prebuilt_index_name}...')
return cls(index_dir, query_encoder, min_idf, encoder_type)
return cls(index_dir, query_encoder, min_idf, encoder_type, prebuilt_index_name=prebuilt_index_name)

def encode(self, query):
    """Encode a query string with the configured query encoder.

    Parameters
    ----------
    query : str
        Query text to encode.

    Returns
    -------
    Token-to-weight mapping: produced by the ONNX encoder on the underlying
    Java searcher when ``encoder_type == 'onnx'`` (presumably a Java map —
    TODO confirm), or by the PyTorch ``query_encoder`` when
    ``encoder_type == 'pytorch'``.

    Raises
    ------
    ValueError
        If ``encoder_type`` is neither 'onnx' nor 'pytorch'.
    """
    if self.encoder_type == 'onnx':
        return self.object.encode_with_onnx(query)
    if self.encoder_type == 'pytorch':
        return self.query_encoder.encode(query)
    # Bug fix: the previous error message referenced the undefined name
    # `query_encoder`, which raised NameError instead of the intended
    # ValueError; report the offending encoder_type instead.
    raise ValueError(f'Invalid query encoder type: {self.encoder_type!r} for encode')

@staticmethod
Expand Down Expand Up @@ -140,13 +140,14 @@ def search(self, q: str, k: int = 10, fields=dict()) -> List[JImpactSearcherResu
for (field, boost) in fields.items():
jfields.put(field, JFloat(boost))

encoded_query = self.encode(q)

jquery = encoded_query
if self.encoder_type == 'pytorch':
encoded_query = self.encode(q)
jquery = encoded_query
for (token, weight) in encoded_query.items():
if token in self.idf and self.idf[token] > self.min_idf:
jquery.put(token, JFloat(weight))
jquery.put(token, JInt(weight))
else:
jquery = q

if not fields:
hits = self.object.search(jquery, k)
Expand Down Expand Up @@ -183,14 +184,14 @@ def batch_search(self, queries: List[str], qids: List[str],
query_lst = JArrayList()
qid_lst = JArrayList()
for q in queries:
encoded_query = self.encode(q)
jquery = JHashMap()
if self.encoder_type == 'pytorch':
encoded_query = self.encode(q)
for (token, weight) in encoded_query.items():
if token in self.idf and self.idf[token] > self.min_idf:
jquery.put(token, JFloat(weight))
jquery.put(token, JInt(weight))
else:
jquery = encoded_query
jquery = q
query_lst.add(jquery)

for qid in qids:
Expand All @@ -202,11 +203,28 @@ def batch_search(self, queries: List[str], qids: List[str],
jfields.put(field, JFloat(boost))

if not fields:
results = self.object.batch_search(query_lst, qid_lst, int(k), int(threads))
if self.encoder_type == 'onnx':
results = self.object.batch_search_queries(query_lst, qid_lst, int(k), int(threads))
else:
results = self.object.batch_search(query_lst, qid_lst, int(k), int(threads))
else:
results = self.object.batch_search_fields(query_lst, qid_lst, int(k), int(threads), jfields)
return {r.getKey(): r.getValue() for r in results.entrySet().toArray()}

def set_analyzer(self, analyzer):
    """Set the Java ``Analyzer`` to use.

    Delegates directly to the underlying Java searcher object.

    Parameters
    ----------
    analyzer : JAnalyzer
        Java ``Analyzer`` object.
    """
    self.object.set_analyzer(analyzer)

def set_language(self, language):
    """Set the language of the searcher.

    Delegates directly to the underlying Java searcher object.

    Parameters
    ----------
    language : str
        Language identifier passed to the Java searcher; presumably an
        ISO language code — TODO confirm against the Java API.
    """
    self.object.set_language(language)

def doc(self, docid: Union[str, int]) -> Optional[Document]:
"""Return the :class:`Document` corresponding to ``docid``. The ``docid`` is overloaded: if it is of type
``str``, it is treated as an external collection ``docid``; if it is of type ``int``, it is treated as an
Expand All @@ -228,6 +246,97 @@ def doc(self, docid: Union[str, int]) -> Optional[Document]:
return None
return Document(lucene_document)

def set_rm3(self, fb_terms=10, fb_docs=10, original_query_weight=float(0.5), debug=False, filter_terms=True):
    """Configure RM3 pseudo-relevance feedback.

    Note: a zero-argument duplicate ``def set_rm3(self)`` previously preceded
    this definition; Python keeps only the last definition in a class body,
    so the stub was dead code and has been removed. Calling with no
    arguments now uses the defaults below.

    Parameters
    ----------
    fb_terms : int
        RM3 parameter for number of expansion terms.
    fb_docs : int
        RM3 parameter for number of expansion documents.
    original_query_weight : float
        RM3 parameter for weight to assign to the original query.
    debug : bool
        Print the original and expanded queries as debug output.
    filter_terms : bool
        Whether to remove non-English terms.

    Raises
    ------
    TypeError
        If the index stores neither document vectors nor raw texts, and is
        not one of the recognized prebuilt indexes.
    """
    # Prefer stored term vectors; fall back to raw JSON document vectors;
    # finally fall back to collection classes known for prebuilt indexes.
    if self.object.reader.getTermVectors(0):
        self.object.set_rm3(None, fb_terms, fb_docs, original_query_weight, debug, filter_terms)
    elif self.object.reader.document(0).getField('raw'):
        self.object.set_rm3('JsonVectorCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms)
    elif self.prebuilt_index_name in ['msmarco-v1-passage', 'msmarco-v1-doc', 'msmarco-v1-doc-segmented']:
        self.object.set_rm3('JsonCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms)
    elif self.prebuilt_index_name in ['msmarco-v2-passage', 'msmarco-v2-passage-augmented']:
        self.object.set_rm3('MsMarcoV2PassageCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms)
    elif self.prebuilt_index_name in ['msmarco-v2-doc', 'msmarco-v2-doc-segmented']:
        self.object.set_rm3('MsMarcoV2DocCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms)
    else:
        raise TypeError("RM3 is not supported for indexes without document vectors or raw texts.")

def unset_rm3(self):
    """Disable RM3 pseudo-relevance feedback (delegates to the underlying Java searcher)."""
    self.object.unset_rm3()

def is_using_rm3(self) -> bool:
    """Check if RM3 pseudo-relevance feedback is being performed.

    Returns
    -------
    bool
        True if RM3 is currently enabled on the underlying Java searcher.
    """
    return self.object.use_rm3()

def set_rocchio(self, top_fb_terms=10, top_fb_docs=10, bottom_fb_terms=10, bottom_fb_docs=10,
                alpha=1, beta=0.75, gamma=0, debug=False, use_negative=False):
    """Configure Rocchio pseudo-relevance feedback.

    Note: a zero-argument duplicate ``def set_rocchio(self)`` previously
    preceded this definition; Python keeps only the last definition in a
    class body, so the stub was dead code and has been removed. Calling with
    no arguments now uses the defaults below.

    Parameters
    ----------
    top_fb_terms : int
        Rocchio parameter for number of relevant expansion terms.
    top_fb_docs : int
        Rocchio parameter for number of relevant expansion documents.
    bottom_fb_terms : int
        Rocchio parameter for number of non-relevant expansion terms.
    bottom_fb_docs : int
        Rocchio parameter for number of non-relevant expansion documents.
    alpha : float
        Rocchio parameter for weight to assign to the original query.
    beta : float
        Rocchio parameter for weight to assign to the relevant document vector.
    gamma : float
        Rocchio parameter for weight to assign to the nonrelevant document vector.
    debug : bool
        Print the original and expanded queries as debug output.
    use_negative : bool
        Rocchio parameter to use negative labels.

    Raises
    ------
    TypeError
        If the index stores neither document vectors nor raw texts, and is
        not one of the recognized prebuilt indexes.
    """
    # Prefer stored term vectors; fall back to raw JSON document vectors;
    # finally fall back to collection classes known for prebuilt indexes.
    if self.object.reader.getTermVectors(0):
        self.object.set_rocchio(None, top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs,
                                alpha, beta, gamma, debug, use_negative)
    elif self.object.reader.document(0).getField('raw'):
        self.object.set_rocchio('JsonVectorCollection', top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs,
                                alpha, beta, gamma, debug, use_negative)
    elif self.prebuilt_index_name in ['msmarco-v1-passage', 'msmarco-v1-doc', 'msmarco-v1-doc-segmented']:
        self.object.set_rocchio('JsonCollection', top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs,
                                alpha, beta, gamma, debug, use_negative)
    # Note, we don't have any Pyserini 2CRs that use Rocchio for MS MARCO v2, so there's currently no
    # corresponding code branch here. To avoid introducing bugs (without 2CR tests), we'll add when it's needed.
    else:
        raise TypeError("Rocchio is not supported for indexes without document vectors or raw texts.")

def unset_rocchio(self):
    """Disable Rocchio pseudo-relevance feedback (delegates to the underlying Java searcher)."""
    self.object.unset_rocchio()

def is_using_rocchio(self) -> bool:
    """Check if Rocchio pseudo-relevance feedback is being performed.

    Returns
    -------
    bool
        True if Rocchio is currently enabled on the underlying Java searcher.
    """
    return self.object.use_rocchio()

def doc_by_field(self, field: str, q: str) -> Optional[Document]:
"""Return the :class:`Document` based on a ``field`` with ``id``. For example, this method can be used to fetch
document based on alternative primary keys that have been indexed, such as an article's DOI. Method returns
Expand Down Expand Up @@ -327,7 +436,7 @@ def search(self, q: str, k: int = 10, fields=dict()) -> List[JImpactSearcherResu
jquery = JHashMap()
for (token, weight) in fusion_encoded_query.items():
if token in self.idf and self.idf[token] > self.min_idf:
jquery.put(token, JFloat(weight))
jquery.put(token, JInt(weight))

if self.sparse_vecs is not None:
search_k = k * (self.min_idf + 1)
Expand All @@ -348,7 +457,7 @@ def batch_search(self, queries: List[str], qids: List[str],
jquery = JHashMap()
for (token, weight) in fusion_encoded_query.items():
if token in self.idf and self.idf[token] > self.min_idf:
jquery.put(token, JFloat(weight))
jquery.put(token, JInt(weight))
query_lst.add(jquery)
sparse_encoded_queries[qid] = sparse_encoded_query

Expand Down
9 changes: 5 additions & 4 deletions tests/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,11 @@ def test_aggretriever_cocondenser_encoder_cmd(self):
def test_onnx_encode_unicoil(self):
    """Check that ONNX-side encoding produces the expected quantized weights.

    Bug fix: the diff artifact left the stale pre-change lines in place
    (``results = temp_object.encode(...)`` plus float-valued assertions),
    which ran before the reassignment below and contradicted the new
    quantized expectations; only the current assertions are kept.
    """
    temp_object = LuceneImpactSearcher(f'{self.index_dir}lucene9-index.cacm', 'SpladePlusPlusEnsembleDistil', encoder_type='onnx')

    # this function will never be called in _impact_searcher, here to check quantization correctness
    results = temp_object.object.encodeWithOnnx("here is a test")
    self.assertAlmostEqual(results.get("here"), 156, delta=2e-4)
    self.assertAlmostEqual(results.get("a"), 31, delta=2e-4)
    self.assertAlmostEqual(results.get("test"), 149, delta=2e-4)

    temp_object.close()
    del temp_object
Expand Down

0 comments on commit b713a51

Please sign in to comment.