From 509bb5a0305f822abf9f3832e2f9643c1a51c36d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xueguang=20Ma=20=E9=A9=AC=E9=9B=AA=E5=85=89?= Date: Fri, 8 Oct 2021 09:30:20 -0400 Subject: [PATCH] add splade query encode (#815) --- docs/experiments-spladev2.md | 41 +++++++++++++++++++++++++++++ pyserini/encode/_splade.py | 36 +++++++++++++++++++++++++ pyserini/search/_impact_searcher.py | 3 +++ 3 files changed, 80 insertions(+) create mode 100644 pyserini/encode/_splade.py diff --git a/docs/experiments-spladev2.md b/docs/experiments-spladev2.md index 45a3000d24..38b63741bc 100644 --- a/docs/experiments-spladev2.md +++ b/docs/experiments-spladev2.md @@ -88,6 +88,47 @@ QueriesRanked: 6980 The final evaluation metric is very close to the one reported in the paper (0.368). +Alternatively, we can use one-the-fly query encoding. + +First, download the model checkpoint from NAVER's github [repo](https://github.com/naver/splade/tree/main/weights/splade_max): +```bash +mkdir splade-distil-max +cd splade-distil-max +wget https://github.com/naver/splade/raw/main/weights/distilsplade_max/pytorch_model.bin +wget https://github.com/naver/splade/raw/main/weights/distilsplade_max/config.json +wget https://github.com/naver/splade/raw/main/weights/distilsplade_max/special_tokens_map.json +wget https://github.com/naver/splade/raw/main/weights/distilsplade_max/tokenizer_config.json +wget https://github.com/naver/splade/raw/main/weights/distilsplade_max/vocab.txt +cd .. +``` + +Then run retrieval with `--encoder splade-distil-max` + +```bash +python -m pyserini.search --topics msmarco-passage-dev-subset \ + --index indexes/lucene-index.msmarco-passage-distill-splade-max \ + --encoder splade-distil-max \ + --output runs/run.msmarco-passage-distill-splade-max.tsv \ + --impact \ + --hits 1000 --batch 36 --threads 12 \ + --output-format msmarco +``` + +And then evaluate: + +```bash +python -m pyserini.eval.msmarco_passage_eval msmarco-passage-dev-subset runs/run.msmarco-passage-distill-splade-max.tsv +``` + +The results should be as follows: + +``` +##################### +MRR @10: 0.3684321417201083 +QueriesRanked: 6980 +##################### +``` + ## Reproduction Log[*](reproducibility.md) + Results reproduced by [@lintool](https://github.com/lintool) on 2021-10-05 (commit [`58d286c`](https://github.com/castorini/pyserini/commit/58d286c3f9fe845e261c271f2a0f514462844d97)) diff --git a/pyserini/encode/_splade.py b/pyserini/encode/_splade.py new file mode 100644 index 0000000000..da4f422656 --- /dev/null +++ b/pyserini/encode/_splade.py @@ -0,0 +1,36 @@ +import torch +from transformers import AutoModelForMaskedLM, AutoTokenizer +import numpy as np + +from pyserini.encode import QueryEncoder + + +class SpladeQueryEncoder(QueryEncoder): + def __init__(self, model_name_or_path, tokenizer_name=None, device='cpu'): + self.device = device + self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path) + self.model.to(self.device) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path) + self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()} + + def encode(self, text, **kwargs): + max_length = 256 # hardcode for now + inputs = self.tokenizer([text], max_length=max_length, padding='longest', + truncation=True, add_special_tokens=True, + return_tensors='pt').to(self.device) + input_ids = inputs['input_ids'] + input_attention = inputs['attention_mask'] + batch_logits = self.model(input_ids)['logits'] + batch_aggregated_logits, _ = torch.max(torch.log(1 + torch.relu(batch_logits)) + * input_attention.unsqueeze(-1), dim=1) + batch_aggregated_logits = batch_aggregated_logits.cpu().detach().numpy() + return self._output_to_weight_dicts(batch_aggregated_logits)[0] + + def _output_to_weight_dicts(self, batch_aggregated_logits): + to_return = [] + for aggregated_logits in batch_aggregated_logits: + col = np.nonzero(aggregated_logits)[0] + weights = aggregated_logits[col] + d = {self.reverse_voc[k]: float(v) for k, v in zip(list(col), list(weights))} + to_return.append(d) + return to_return diff --git a/pyserini/search/_impact_searcher.py b/pyserini/search/_impact_searcher.py index 3713722a13..dc7827c2e7 100644 --- a/pyserini/search/_impact_searcher.py +++ b/pyserini/search/_impact_searcher.py @@ -26,6 +26,7 @@ from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap, JString from pyserini.util import download_prebuilt_index from pyserini.encode import QueryEncoder, TokFreqQueryEncoder, UniCoilQueryEncoder, CachedDataQueryEncoder +from ..encode._splade import SpladeQueryEncoder logger = logging.getLogger(__name__) @@ -226,6 +227,8 @@ def _init_query_encoder_from_str(query_encoder): return CachedDataQueryEncoder(query_encoder) elif 'unicoil' in query_encoder.lower(): return UniCoilQueryEncoder(query_encoder) + elif 'splade' in query_encoder.lower(): + return SpladeQueryEncoder(query_encoder) @staticmethod def _compute_idf(index_path):