Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds BERT reranker example #59

Merged
merged 2 commits into from
Jul 16, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Currently, this repo contains implementations of the rerankers for [CovidQA](htt
* If you prefer Anaconda, use `conda env create -f environment.yml && conda activate pygaggle`.


# A simple reranking example
# A simple reranking example - T5
The code below exemplifies how to score two documents for a given query using a T5 reranker from [Document Ranking with a Pretrained
Sequence-to-Sequence Model](https://arxiv.org/pdf/2003.06713.pdf).
```python
Expand Down Expand Up @@ -65,3 +65,35 @@ scores = [result.score for result in reranker.rerank(query, documents)]
# scores = [-0.1782158613204956, -0.36637523770332336]
```

# A simple reranking example - BERT
You can also try the code below, which uses a BERT reranker from [Passage Re-ranking with BERT](https://arxiv.org/pdf/1901.04085.pdf).
Note that the T5 reranker produces slightly better scores than the BERT reranker.
```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pygaggle.model import BatchTokenizer
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import SequenceClassificationTransformerReranker

model_name = 'castorini/monobert-large-msmarco'
tokenizer_name = 'bert-large-uncased'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device).eval()

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
reranker = SequenceClassificationTransformerReranker(model, tokenizer)

query = Query('what causes low liver enzymes')

correct_doc = Text('Reduced production of liver enzymes may indicate dysfunction of the liver. This article explains the causes and symptoms of low liver enzymes. Scroll down to know how the production of the enzymes can be accelerated.')

wrong_doc = Text('Elevated liver enzymes often indicate inflammation or damage to cells in the liver. Inflamed or injured liver cells leak higher than normal amounts of certain chemicals, including liver enzymes, into the bloodstream, elevating liver enzymes on blood tests.')

documents = [correct_doc, wrong_doc]

scores = [result.score for result in reranker.rerank(query, documents)]
# scores = [-3.077077865600586, -5.45782470703125]
```