diff --git a/README.md b/README.md index 35a0dde4..74ea3da6 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Currently, this repo contains implementations of the rerankers for [CovidQA](htt * If you prefer Anaconda, use `conda env create -f environment.yml && conda activate pygaggle`. -# A simple reranking example +# A simple reranking example - T5 The code below exemplifies how to score two documents for a given query using a T5 reranker from [Document Ranking with a Pretrained Sequence-to-Sequence Model](https://arxiv.org/pdf/2003.06713.pdf). ```python @@ -65,3 +65,35 @@ scores = [result.score for result in reranker.rerank(query, documents)] # scores = [-0.1782158613204956, -0.36637523770332336] ``` +# A simple reranking example - BERT +You can also try the code below, which uses a BERT reranker from [Passage Re-ranking with BERT](https://arxiv.org/pdf/1901.04085.pdf). +Note that the T5 reranker produces slightly better scores than the BERT reranker. +```python +import torch +from transformers import AutoTokenizer, AutoModelForSequenceClassification +from pygaggle.model import BatchTokenizer +from pygaggle.rerank.base import Query, Text +from pygaggle.rerank.transformer import SequenceClassificationTransformerReranker + +model_name = 'castorini/monobert-large-msmarco' +tokenizer_name = 'bert-large-uncased' + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +model = AutoModelForSequenceClassification.from_pretrained(model_name) +model = model.to(device).eval() + +tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) +reranker = SequenceClassificationTransformerReranker(model, tokenizer) + +query = Query('what causes low liver enzymes') + +correct_doc = Text('Reduced production of liver enzymes may indicate dysfunction of the liver. This article explains the causes and symptoms of low liver enzymes. Scroll down to know how the production of the enzymes can be accelerated.') + +wrong_doc = Text('Elevated liver enzymes often indicate inflammation or damage to cells in the liver. Inflamed or injured liver cells leak higher than normal amounts of certain chemicals, including liver enzymes, into the bloodstream, elevating liver enzymes on blood tests.') + +documents = [correct_doc, wrong_doc] + +scores = [result.score for result in reranker.rerank(query, documents)] +# scores = [-3.077077865600586, -5.45782470703125] +``` \ No newline at end of file