Skip to content

Commit

Permalink
refactor vector prf and add negative prf passages for rocchio (castor…
Browse files Browse the repository at this point in the history
…ini#1108)

- add negative prf passages
  • Loading branch information
ArvinZhuang authored Apr 7, 2022
1 parent 65ad9cb commit cc06d0b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 22 deletions.
7 changes: 4 additions & 3 deletions docs/experiments-vector-prf.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ one can use `--encoded-queries ance-msmarco-passage-dev-subset`. For ADORE model

With these parameters, one can easily reproduce the results above. For example, to reproduce `TREC DL 2019 Passage with ANCE Average Vector PRF 3`, the command is:
```
$ python -m pyserini.dsearch --topics dl19-passage \
$ python -m pyserini.search.faiss --topics dl19-passage \
--index msmarco-passage-ance-bf \
--encoder castorini/ance-msmarco-passage \
--batch-size 64 \
Expand All @@ -171,14 +171,15 @@ $ python -m pyserini.dsearch --topics dl19-passage \

To reproduce `TREC DL 2019 Passage with ANCE Rocchio Vector PRF 5 Alpha 0.4 Beta 0.6`, the command will be:
```
$ python -m pyserini.dsearch --topics dl19-passage \
$ python -m pyserini.search.faiss --topics dl19-passage \
--index msmarco-passage-ance-bf \
--encoder castorini/ance-msmarco-passage \
--batch-size 64 \
--threads 12 \
--output runs/run.ance.dl19-passage.rocchio_prf5_a0.4_b0.6.trec \
--prf-depth 5 \
--prf-method rocchio \
--prf-depth 5 \
--rocchio-topk 5 \
--rocchio-alpha 0.4 \
--rocchio-beta 0.6
```
Expand Down
14 changes: 11 additions & 3 deletions pyserini/search/faiss/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ def define_dsearch_args(parser):
default=0.9,
help="The alpha parameter to control the contribution from the query vector")
parser.add_argument('--rocchio-beta', type=float, metavar='beta parameter for rocchio', required=False, default=0.1,
help="The beta parameter to control the contribution from the average vector of the PRF passages")
help="The beta parameter to control the contribution from the average vector of the positive PRF passages")
parser.add_argument('--rocchio-gamma', type=float, metavar='gamma parameter for rocchio', required=False, default=0.1,
help="The gamma parameter to control the contribution from the average vector of the negative PRF passages")
parser.add_argument('--rocchio-topk', type=int, metavar='topk passages as positive for rocchio', required=False, default=3,
help="Set topk passages as positive PRF passages for rocchio")
parser.add_argument('--rocchio-bottomk', type=int, metavar='bottomk passages as negative for rocchio', required=False, default=0,
help="Set bottomk passages as negative PRF passages for rocchio, 0: do not use negatives prf passages.")
parser.add_argument('--sparse-index', type=str, metavar='sparse lucene index containing contents', required=False,
help='The path to sparse index containing the passage contents')
parser.add_argument('--ance-prf-encoder', type=str, metavar='query encoder path for ANCE-PRF', required=False,
Expand Down Expand Up @@ -172,7 +178,8 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de
if args.prf_method.lower() == 'avg':
prfRule = DenseVectorAveragePrf()
elif args.prf_method.lower() == 'rocchio':
prfRule = DenseVectorRocchioPrf(args.rocchio_alpha, args.rocchio_beta)
prfRule = DenseVectorRocchioPrf(args.rocchio_alpha, args.rocchio_beta, args.rocchio_gamma,
args.rocchio_topk, args.rocchio_bottomk)
# ANCE-PRF is using a new query encoder, so the input to DenseVectorAncePrf is different
elif args.prf_method.lower() == 'ance-prf' and type(query_encoder) == AnceQueryEncoder:
if os.path.exists(args.sparse_index):
Expand Down Expand Up @@ -209,7 +216,8 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de
if args.prf_method.lower() == 'ance-prf':
prf_emb_q = prfRule.get_prf_q_emb(text, prf_candidates)
else:
prf_emb_q = prfRule.get_prf_q_emb(emb_q, prf_candidates)
prf_emb_q = prfRule.get_prf_q_emb(emb_q[0], prf_candidates)
prf_emb_q = np.expand_dims(prf_emb_q, axis=0).astype('float32')
hits = searcher.search(prf_emb_q, k=args.hits, **kwargs)
else:
hits = searcher.search(text, args.hits, **kwargs)
Expand Down
36 changes: 20 additions & 16 deletions pyserini/search/faiss/_prf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDense
return new query embeddings
"""
all_candidate_embs = [item.vectors for item in prf_candidates]
new_emb_qs = np.mean(np.vstack((emb_qs[0], all_candidate_embs)), axis=0)
new_emb_qs = np.array([new_emb_qs]).astype('float32')
new_emb_qs = np.mean(np.vstack((emb_qs, all_candidate_embs)), axis=0)
return new_emb_qs

def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
Expand All @@ -61,26 +60,33 @@ def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray =
new_emb_qs = list()
for index, topic_id in enumerate(topic_ids):
qids.append(topic_id)
all_candidate_embs = [item.vectors for item in prf_candidates[topic_id]]
new_emb_q = np.mean(np.vstack((emb_qs[index], all_candidate_embs)), axis=0)
new_emb_qs.append(new_emb_q)
new_emb_qs.append(self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id]))
new_emb_qs = np.array(new_emb_qs).astype('float32')
return new_emb_qs


class DenseVectorRocchioPrf(DenseVectorPrf):
def __init__(self, alpha: float, beta: float):
def __init__(self, alpha: float, beta: float, gamma: float, topk: int, bottomk: int):
"""
Parameters
----------
alpha : float
Rocchio parameter, controls the weight assigned to the original query embedding.
beta : float
Rocchio parameter, controls the weight assigned to the document embeddings.
Rocchio parameter, controls the weight assigned to the positive document embeddings.
gamma : float
Rocchio parameter, controls the weight assigned to the negative document embeddings.
topk : int
Rocchio parameter, the first `topk` PRF candidates are used as positive feedback documents.
bottomk : int
Rocchio parameter, the last `bottomk` PRF candidates are used as negative feedback documents; 0 disables negative feedback.
"""
DenseVectorPrf.__init__(self)
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.topk = topk
self.bottomk = bottomk

def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
"""Perform Rocchio PRF with Dense Vectors
Expand All @@ -99,10 +105,12 @@ def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDense
"""

all_candidate_embs = [item.vectors for item in prf_candidates]
weighted_mean_doc_embs = self.beta * np.mean(all_candidate_embs, axis=0)
weighted_query_embs = self.alpha * emb_qs[0]
new_emb_q = np.sum(np.vstack((weighted_query_embs, weighted_mean_doc_embs)), axis=0)
new_emb_q = np.array([new_emb_q]).astype('float32')
weighted_query_embs = self.alpha * emb_qs
weighted_mean_pos_doc_embs = self.beta * np.mean(all_candidate_embs[:self.topk], axis=0)
new_emb_q = weighted_query_embs + weighted_mean_pos_doc_embs
if self.bottomk > 0:
weighted_mean_neg_doc_embs = self.gamma * np.mean(all_candidate_embs[-self.bottomk:], axis=0)
new_emb_q -= weighted_mean_neg_doc_embs
return new_emb_q

def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
Expand All @@ -127,11 +135,7 @@ def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray =
new_emb_qs = list()
for index, topic_id in enumerate(topic_ids):
qids.append(topic_id)
all_candidate_embs = [item.vectors for item in prf_candidates[topic_id]]
weighted_mean_doc_embs = self.beta * np.mean(all_candidate_embs, axis=0)
weighted_query_embs = self.alpha * emb_qs[index]
new_emb_q = np.sum(np.vstack((weighted_query_embs, weighted_mean_doc_embs)), axis=0)
new_emb_qs.append(new_emb_q)
new_emb_qs.append(self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id]))
new_emb_qs = np.array(new_emb_qs).astype('float32')
return new_emb_qs

Expand Down

0 comments on commit cc06d0b

Please sign in to comment.