From cc06d0bf061404e318155e7510871dea50329247 Mon Sep 17 00:00:00 2001 From: Shengyao Zhuang <46237844+ArvinZhuang@users.noreply.github.com> Date: Thu, 7 Apr 2022 10:24:18 +1000 Subject: [PATCH] refactor vector prf and add negative prf passages for rocchio (#1108) - add negative prf passages --- docs/experiments-vector-prf.md | 7 +++--- pyserini/search/faiss/__main__.py | 14 +++++++++--- pyserini/search/faiss/_prf.py | 36 +++++++++++++++++-------------- 3 files changed, 35 insertions(+), 22 deletions(-) diff --git a/docs/experiments-vector-prf.md b/docs/experiments-vector-prf.md index aca5b4e61..94213fa5c 100644 --- a/docs/experiments-vector-prf.md +++ b/docs/experiments-vector-prf.md @@ -159,7 +159,7 @@ one can use `--encoded-queries ance-msmarco-passage-dev-subset`. For ADORE model With these parameters, one can easily reproduce the results above, for example, to reproduce `TREC DL 2019 Passage with ANCE Average Vector PRF 3` the command will be: ``` -$ python -m pyserini.dsearch --topics dl19-passage \ +$ python -m pyserini.search.faiss --topics dl19-passage \ --index msmarco-passage-ance-bf \ --encoder castorini/ance-msmarco-passage \ --batch-size 64 \ @@ -171,14 +171,15 @@ $ python -m pyserini.dsearch --topics dl19-passage \ To reproduce `TREC DL 2019 Passage with ANCE Rocchio Vector PRF 5 Alpha 0.4 Beta 0.6`, the command will be: ``` -$ python -m pyserini.dsearch --topics dl19-passage \ +$ python -m pyserini.search.faiss --topics dl19-passage \ --index msmarco-passage-ance-bf \ --encoder castorini/ance-msmarco-passage \ --batch-size 64 \ --threads 12 \ --output runs/run.ance.dl19-passage.rocchio_prf5_a0.4_b0.6.trec \ - --prf-depth 5 \ --prf-method rocchio \ + --prf-depth 5 \ + --rocchio-topk 5 \ --rocchio-alpha 0.4 \ --rocchio-beta 0.6 ``` diff --git a/pyserini/search/faiss/__main__.py b/pyserini/search/faiss/__main__.py index 4bda6904d..976e7c29a 100644 --- a/pyserini/search/faiss/__main__.py +++ b/pyserini/search/faiss/__main__.py @@ -62,7 +62,13 @@ def define_dsearch_args(parser): default=0.9, help="The alpha parameter to control the contribution from the query vector") parser.add_argument('--rocchio-beta', type=float, metavar='beta parameter for rocchio', required=False, default=0.1, - help="The beta parameter to control the contribution from the average vector of the PRF passages") + help="The beta parameter to control the contribution from the average vector of the positive PRF passages") + parser.add_argument('--rocchio-gamma', type=float, metavar='gamma parameter for rocchio', required=False, default=0.1, + help="The gamma parameter to control the contribution from the average vector of the negative PRF passages") + parser.add_argument('--rocchio-topk', type=int, metavar='topk passages as positive for rocchio', required=False, default=3, + help="Set topk passages as positive PRF passages for rocchio") + parser.add_argument('--rocchio-bottomk', type=int, metavar='bottomk passages as negative for rocchio', required=False, default=0, + help="Set bottomk passages as negative PRF passages for rocchio, 0: do not use negatives prf passages.") parser.add_argument('--sparse-index', type=str, metavar='sparse lucene index containing contents', required=False, help='The path to sparse index containing the passage contents') parser.add_argument('--ance-prf-encoder', type=str, metavar='query encoder path for ANCE-PRF', required=False, @@ -172,7 +178,8 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de if args.prf_method.lower() == 'avg': prfRule = DenseVectorAveragePrf() elif args.prf_method.lower() == 'rocchio': - prfRule = DenseVectorRocchioPrf(args.rocchio_alpha, args.rocchio_beta) + prfRule = DenseVectorRocchioPrf(args.rocchio_alpha, args.rocchio_beta, args.rocchio_gamma, + args.rocchio_topk, args.rocchio_bottomk) # ANCE-PRF is using a new query encoder, so the input to DenseVectorAncePrf is different elif args.prf_method.lower() == 'ance-prf' and type(query_encoder) == AnceQueryEncoder: if os.path.exists(args.sparse_index): @@ -209,7 +216,8 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de if args.prf_method.lower() == 'ance-prf': prf_emb_q = prfRule.get_prf_q_emb(text, prf_candidates) else: - prf_emb_q = prfRule.get_prf_q_emb(emb_q, prf_candidates) + prf_emb_q = prfRule.get_prf_q_emb(emb_q[0], prf_candidates) + prf_emb_q = np.expand_dims(prf_emb_q, axis=0).astype('float32') hits = searcher.search(prf_emb_q, k=args.hits, **kwargs) else: hits = searcher.search(text, args.hits, **kwargs) diff --git a/pyserini/search/faiss/_prf.py b/pyserini/search/faiss/_prf.py index 79d6cbaa4..68167318d 100644 --- a/pyserini/search/faiss/_prf.py +++ b/pyserini/search/faiss/_prf.py @@ -34,8 +34,7 @@ def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDense return new query embeddings """ all_candidate_embs = [item.vectors for item in prf_candidates] - new_emb_qs = np.mean(np.vstack((emb_qs[0], all_candidate_embs)), axis=0) - new_emb_qs = np.array([new_emb_qs]).astype('float32') + new_emb_qs = np.mean(np.vstack((emb_qs, all_candidate_embs)), axis=0) return new_emb_qs def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None, @@ -61,26 +60,33 @@ def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = new_emb_qs = list() for index, topic_id in enumerate(topic_ids): qids.append(topic_id) - all_candidate_embs = [item.vectors for item in prf_candidates[topic_id]] - new_emb_q = np.mean(np.vstack((emb_qs[index], all_candidate_embs)), axis=0) - new_emb_qs.append(new_emb_q) + new_emb_qs.append(self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])) new_emb_qs = np.array(new_emb_qs).astype('float32') return new_emb_qs class DenseVectorRocchioPrf(DenseVectorPrf): - def __init__(self, alpha: float, beta: float): + def __init__(self, alpha: float, beta: float, gamma: float, topk: int, bottomk: int): """ Parameters ---------- alpha : float Rocchio parameter, controls the weight assigned to the original query embedding. beta : float - Rocchio parameter, controls the weight assigned to the document embeddings. + Rocchio parameter, controls the weight assigned to the positive document embeddings. + gamma : float + Rocchio parameter, controls the weight assigned to the negative document embeddings. + topk : int + Rocchio parameter, set topk documents as positive document feedbacks. + bottomk : int + Rocchio parameter, set bottomk documents as negative document feedbacks. """ DenseVectorPrf.__init__(self) self.alpha = alpha self.beta = beta + self.gamma = gamma + self.topk = topk + self.bottomk = bottomk def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None): """Perform Rocchio PRF with Dense Vectors @@ -99,10 +105,12 @@ def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDense """ all_candidate_embs = [item.vectors for item in prf_candidates] - weighted_mean_doc_embs = self.beta * np.mean(all_candidate_embs, axis=0) - weighted_query_embs = self.alpha * emb_qs[0] - new_emb_q = np.sum(np.vstack((weighted_query_embs, weighted_mean_doc_embs)), axis=0) - new_emb_q = np.array([new_emb_q]).astype('float32') + weighted_query_embs = self.alpha * emb_qs + weighted_mean_pos_doc_embs = self.beta * np.mean(all_candidate_embs[:self.topk], axis=0) + new_emb_q = weighted_query_embs + weighted_mean_pos_doc_embs + if self.bottomk > 0: + weighted_mean_neg_doc_embs = self.gamma * np.mean(all_candidate_embs[-self.bottomk:], axis=0) + new_emb_q -= weighted_mean_neg_doc_embs return new_emb_q def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None, @@ -127,11 +135,7 @@ def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = new_emb_qs = list() for index, topic_id in enumerate(topic_ids): qids.append(topic_id) - all_candidate_embs = [item.vectors for item in prf_candidates[topic_id]] - weighted_mean_doc_embs = self.beta * np.mean(all_candidate_embs, axis=0) - weighted_query_embs = self.alpha * emb_qs[index] - new_emb_q = np.sum(np.vstack((weighted_query_embs, weighted_mean_doc_embs)), axis=0) - new_emb_qs.append(new_emb_q) + new_emb_qs.append(self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])) new_emb_qs = np.array(new_emb_qs).astype('float32') return new_emb_qs