forked from castorini/anserini
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add splade query encode (castorini#815)
- Loading branch information
Showing
3 changed files
with
80 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
from transformers import AutoModelForMaskedLM, AutoTokenizer | ||
import numpy as np | ||
|
||
from pyserini.encode import QueryEncoder | ||
|
||
|
||
class SpladeQueryEncoder(QueryEncoder): | ||
def __init__(self, model_name_or_path, tokenizer_name=None, device='cpu'): | ||
self.device = device | ||
self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path) | ||
self.model.to(self.device) | ||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path) | ||
self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()} | ||
|
||
def encode(self, text, **kwargs): | ||
max_length = 256 # hardcode for now | ||
inputs = self.tokenizer([text], max_length=max_length, padding='longest', | ||
truncation=True, add_special_tokens=True, | ||
return_tensors='pt').to(self.device) | ||
input_ids = inputs['input_ids'] | ||
input_attention = inputs['attention_mask'] | ||
batch_logits = self.model(input_ids)['logits'] | ||
batch_aggregated_logits, _ = torch.max(torch.log(1 + torch.relu(batch_logits)) | ||
* input_attention.unsqueeze(-1), dim=1) | ||
batch_aggregated_logits = batch_aggregated_logits.cpu().detach().numpy() | ||
return self._output_to_weight_dicts(batch_aggregated_logits)[0] | ||
|
||
def _output_to_weight_dicts(self, batch_aggregated_logits): | ||
to_return = [] | ||
for aggregated_logits in batch_aggregated_logits: | ||
col = np.nonzero(aggregated_logits)[0] | ||
weights = aggregated_logits[col] | ||
d = {self.reverse_voc[k]: float(v) for k, v in zip(list(col), list(weights))} | ||
to_return.append(d) | ||
return to_return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters