From 60fe6f4b4578d93133d0e50eb64d36f3ebf9105f Mon Sep 17 00:00:00 2001
From: Brian Sam-Bodden
Date: Sun, 12 May 2024 20:42:18 -0700
Subject: [PATCH] Reranker implementation using Cross Encoders (HuggingFace /
 SentenceTransformers)

---
 docs/user_guide/rerankers_06.ipynb        | 144 +++++++++++++++++-----
 redisvl/utils/rerank/__init__.py          |   6 +-
 redisvl/utils/rerank/hf_cross_encoder.py  | 129 +++++++++++++++++++
 tests/integration/test_rerankers.py       |  64 ++++++++--
 tests/unit/test_cross_encoder_reranker.py |  53 ++++++++
 5 files changed, 353 insertions(+), 43 deletions(-)
 create mode 100644 redisvl/utils/rerank/hf_cross_encoder.py
 create mode 100644 tests/unit/test_cross_encoder_reranker.py

diff --git a/docs/user_guide/rerankers_06.ipynb b/docs/user_guide/rerankers_06.ipynb
index d3ef7a66..adfc2af8 100644
--- a/docs/user_guide/rerankers_06.ipynb
+++ b/docs/user_guide/rerankers_06.ipynb
@@ -9,7 +9,10 @@
     "\n",
     "In this notebook, we will show how to use RedisVL to rerank search results\n",
     "(documents or chunks or records) based on the input query. Today RedisVL\n",
-    "supports reranking through the [Cohere /rerank API](https://docs.cohere.com/docs/rerank-2).\n",
+    "supports reranking through:\n",
+    "\n",
+    "- A reranker built on pre-trained [Cross-Encoders](https://sbert.net/examples/applications/cross-encoder/README.html), which can use either dedicated [Hugging Face cross-encoder models](https://huggingface.co/cross-encoder) or any Hugging Face model that implements a cross-encoder function ([example: BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)).\n",
+    "- The [Cohere /rerank API](https://docs.cohere.com/docs/rerank-2).\n",
     "\n",
     "Before running this notebook, be sure to:\n",
     "1. Have installed ``redisvl`` and have that environment active for this notebook.\n",
@@ -26,8 +29,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
+   "execution_count": 27,
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [],
    "source": [
     "# import necessary modules\n",
@@ -48,8 +53,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
+   "execution_count": 28,
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [],
    "source": [
     "query = \"What is the capital of the United States?\"\n",
@@ -75,24 +82,93 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Init the Reranker\n",
+    "### Using the Cross-Encoder Reranker\n",
     "\n",
-    "Initialize the reranker. Install the cohere library and provide the right Cohere API Key."
+    "To use the cross-encoder reranker, initialize an instance of `HFCrossEncoderReranker`, passing in a suitable model name (if no model is provided, `cross-encoder/ms-marco-MiniLM-L-6-v2` is used):"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 29,
+   "metadata": {
+    "metadata": {}
+   },
+   "outputs": [],
+   "source": [
+    "from redisvl.utils.rerank import HFCrossEncoderReranker\n",
+    "\n",
+    "cross_encoder_reranker = HFCrossEncoderReranker(\"BAAI/bge-reranker-base\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
    "metadata": {},
+   "source": [
+    "### Rerank documents with HFCrossEncoderReranker\n",
+    "\n",
+    "With this reranker instance, we can rerank and truncate the list of\n",
+    "documents based on relevance to the initial query.\n",
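+    "\n",
+    "You can also override `limit` and `return_score` per call. A quick sketch, reusing the `query` and `docs` defined above:\n",
+    "\n",
+    "```python\n",
+    "# sketch: return only the single best match, without scores\n",
+    "top_docs = cross_encoder_reranker.rank(query=query, docs=docs, limit=1, return_score=False)\n",
+    "```"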
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [],
    "source": [
-    "#!pip install cohere"
+    "results, scores = cross_encoder_reranker.rank(query=query, docs=docs)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 31,
+   "metadata": {
+    "metadata": {}
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.07461125403642654 -- {'content': 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.'}\n",
+      "0.05220315232872963 -- {'content': 'Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.'}\n",
+      "0.3802368640899658 -- {'content': 'Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.'}\n"
+     ]
+    }
+   ],
+   "source": [
+    "for result, score in zip(results, scores):\n",
+    "    print(score, \" -- \", result)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
    "metadata": {},
+   "source": [
+    "### Using the Cohere Reranker\n",
+    "\n",
+    "To initialize the Cohere reranker, you'll need to install the `cohere` library and provide a valid Cohere API key."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {
+    "metadata": {}
+   },
+   "outputs": [],
+   "source": [
+    "#!pip install cohere"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [],
    "source": [
     "import getpass\n",
@@ -103,38 +179,44 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
+   "execution_count": 34,
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [],
    "source": [
     "from redisvl.utils.rerank import CohereReranker\n",
     "\n",
-    "reranker = CohereReranker(limit=3, api_config={\"api_key\": api_key})"
+    "cohere_reranker = CohereReranker(limit=3, api_config={\"api_key\": api_key})"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Rerank documents\n",
+    "### Rerank documents with CohereReranker\n",
     "\n",
-    "Below we will use the `CohereReranker` to rerank and also truncate the list of\n",
+    "Below we will use the `CohereReranker` to rerank and truncate the list of\n",
     "documents above based on relevance to the initial query."
   ]
  },
 {
 "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
+   "execution_count": 35,
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [],
    "source": [
-    "results, scores = reranker.rank(query=query, docs=docs)"
+    "results, scores = cohere_reranker.rank(query=query, docs=docs)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
+   "execution_count": 36,
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [
     {
      "name": "stdout",
@@ -162,8 +244,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
-   "metadata": {},
+   "execution_count": 37,
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [],
    "source": [
     "docs = [\n",
@@ -192,17 +276,21 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
-   "metadata": {},
+   "execution_count": 38,
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [],
    "source": [
-    "results, scores = reranker.rank(query=query, docs=docs, rank_by=[\"passage\", \"source\"])"
+    "results, scores = cohere_reranker.rank(query=query, docs=docs, rank_by=[\"passage\", \"source\"])"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
-   "metadata": {},
+   "execution_count": 39,
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [
     {
      "name": "stdout",
@@ -236,7 +324,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.10.14"
+   "version": "3.11.9"
  },
 "orig_nbformat": 4,
 "vscode": {
diff --git a/redisvl/utils/rerank/__init__.py b/redisvl/utils/rerank/__init__.py
index ef7fa9e4..faafd809 100644
--- a/redisvl/utils/rerank/__init__.py
+++ b/redisvl/utils/rerank/__init__.py
@@ -1,7 +1,5 @@
 from redisvl.utils.rerank.base import BaseReranker
 from redisvl.utils.rerank.cohere import CohereReranker
+from redisvl.utils.rerank.hf_cross_encoder import HFCrossEncoderReranker
 
-__all__ = [
-    "BaseReranker",
-    "CohereReranker",
-]
+__all__ = ["BaseReranker", "CohereReranker", "HFCrossEncoderReranker"]
diff --git a/redisvl/utils/rerank/hf_cross_encoder.py b/redisvl/utils/rerank/hf_cross_encoder.py
new file mode 100644
index 00000000..2fc0f908
--- /dev/null
+++ b/redisvl/utils/rerank/hf_cross_encoder.py
@@ -0,0 +1,129 @@
+from typing import Any, Dict, List, Tuple, Union
+
+from sentence_transformers import CrossEncoder
+
+from redisvl.utils.rerank.base import BaseReranker
+
+
+class HFCrossEncoderReranker(BaseReranker):
+    """
+    The HFCrossEncoderReranker class uses cross-encoder models from Hugging Face
+    to rerank documents based on an input query.
+
+    This reranker loads a cross-encoder model using the `CrossEncoder` class
+    from the `sentence_transformers` library, which must be installed.
+
+    .. code-block:: python
+
+        from redisvl.utils.rerank import HFCrossEncoderReranker
+
+        # set up the HFCrossEncoderReranker with a specific model
+        reranker = HFCrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2", limit=3)
+        # rerank raw search results based on user input/query;
+        # rank() returns a (docs, scores) tuple by default
+        results, scores = reranker.rank(
+            query="your input query text here",
+            docs=[
+                {"content": "document 1"},
+                {"content": "document 2"},
+                {"content": "document 3"}
+            ]
+        )
+    """
+
+    def __init__(
+        self,
+        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
+        limit: int = 3,
+        return_score: bool = True,
+    ) -> None:
+        """
+        Initialize the HFCrossEncoderReranker with a specified model and ranking criteria.
+
+        Parameters:
+            model_name (str): The name or path of the cross-encoder model to use for reranking.
+                Defaults to 'cross-encoder/ms-marco-MiniLM-L-6-v2'.
+            limit (int): The maximum number of results to return after reranking. Must be a positive integer.
+            return_score (bool): Whether to return scores alongside the reranked results.
+        """
+        super().__init__(
+            model=model_name, rank_by=None, limit=limit, return_score=return_score
+        )
+        self.model: CrossEncoder = CrossEncoder(model_name)
+
+    def rank(
+        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
+    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
+        """
+        Rerank documents based on the provided query using the loaded cross-encoder model.
+
+        This method scores each document against the user's query and returns the
+        documents reordered by predicted relevance.
+
+        Parameters:
+            query (str): The user's search query.
+            docs (Union[List[Dict[str, Any]], List[str]]): The list of documents to be ranked,
+                either as dictionaries or strings.
+
+        Returns:
+            Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
+                The reranked list of documents and optionally associated scores.
+        """
+        limit = kwargs.get("limit", self.limit)
+        return_score = kwargs.get("return_score", self.return_score)
+
+        if not isinstance(query, str):
+            raise TypeError("query must be a string")
+
+        if not query:
+            raise ValueError("query cannot be empty")
+
+        if not isinstance(docs, list):
+            raise TypeError("docs must be a list")
+
+        if not docs:
+            return [] if not return_score else ([], [])
+
+        if all(isinstance(doc, dict) for doc in docs):
+            # dict documents must expose a "content" field; any without one are dropped
+            texts = [
+                str(doc["content"])
+                for doc in docs
+                if isinstance(doc, dict) and "content" in doc
+            ]
+            doc_subset = [
+                doc for doc in docs if isinstance(doc, dict) and "content" in doc
+            ]
+        else:
+            texts = [str(doc) for doc in docs]
+            doc_subset = [{"content": doc} for doc in docs]
+
+        scores = self.model.predict([(query, text) for text in texts])
+        scores = [float(score) for score in scores]
+        docs_with_scores = list(zip(doc_subset, scores))
+        docs_with_scores.sort(key=lambda x: x[1], reverse=True)
+        reranked_docs = [doc for doc, _ in docs_with_scores[:limit]]
+        # keep the returned scores aligned with the sorted (reranked) documents
+        scores = [score for _, score in docs_with_scores[:limit]]
+
+        if return_score:
+            return reranked_docs, scores
+        return reranked_docs
+
+    async def arank(
+        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
+    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
+        """
+        Asynchronously rerank documents based on the provided query using the loaded cross-encoder model.
+
+        This method scores each document against the user's query and returns the
+        documents reordered by predicted relevance.
+
+        Parameters:
+            query (str): The user's search query.
+            docs (Union[List[Dict[str, Any]], List[str]]): The list of documents to be ranked,
+                either as dictionaries or strings.
+
+        Returns:
+            Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
+                The reranked list of documents and optionally associated scores.
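+
+        .. code-block:: python
+
+            # usage sketch (assumes an HFCrossEncoderReranker instance named `reranker`)
+            reranked, scores = await reranker.arank("your query", ["doc one", "doc two"])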
+        """
+        # CrossEncoder inference is synchronous, so this simply delegates to rank()
+        return self.rank(query, docs, **kwargs)
diff --git a/tests/integration/test_rerankers.py b/tests/integration/test_rerankers.py
index 4866aa58..e7da8fa3 100644
--- a/tests/integration/test_rerankers.py
+++ b/tests/integration/test_rerankers.py
@@ -2,24 +2,34 @@
 
 import pytest
 
-from redisvl.utils.rerank import CohereReranker
+from redisvl.utils.rerank import CohereReranker, HFCrossEncoderReranker
 
 
 # Fixture for the reranker instance
 @pytest.fixture
-def reranker():
+def cohereReranker():
     skip_reranker = os.getenv("SKIP_RERANKERS", "False").lower() == "true"
     if skip_reranker:
         pytest.skip("Skipping reranker instantiation...")
     return CohereReranker()
 
 
+@pytest.fixture
+def hfCrossEncoderReranker():
+    return HFCrossEncoderReranker()
+
+
+@pytest.fixture
+def hfCrossEncoderRerankerWithCustomModel():
+    return HFCrossEncoderReranker("cross-encoder/stsb-distilroberta-base")
+
+
 # Test for basic ranking functionality
-def test_rank_documents(reranker):
+def test_rank_documents_cohere(cohereReranker):
     docs = ["document one", "document two", "document three"]
     query = "search query"
 
-    reranked_docs, scores = reranker.rank(query, docs)
+    reranked_docs, scores = cohereReranker.rank(query, docs)
 
     assert isinstance(reranked_docs, list)
     assert len(reranked_docs) == len(docs)  # Ensure we get back as many docs as we sent
@@ -28,11 +38,11 @@
 
 # Test for asynchronous ranking functionality
 @pytest.mark.asyncio
-async def test_async_rank_documents(reranker):
+async def test_async_rank_documents_cohere(cohereReranker):
     docs = ["document one", "document two", "document three"]
     query = "search query"
 
-    reranked_docs, scores = await reranker.arank(query, docs)
+    reranked_docs, scores = await cohereReranker.arank(query, docs)
 
     assert isinstance(reranked_docs, list)
     assert len(reranked_docs) == len(docs)  # Ensure we get back as many docs as we sent
@@ -40,17 +50,49 @@
 
 
 # Test handling of bad input
-def test_bad_input(reranker):
+def test_bad_input_cohere(cohereReranker):
     with pytest.raises(Exception):
-        reranker.rank("", [])  # Empty query or documents
+        cohereReranker.rank("", [])  # Empty query or documents
 
     with pytest.raises(Exception):
-        reranker.rank(123, ["valid document"])  # Invalid type for query
+        cohereReranker.rank(123, ["valid document"])  # Invalid type for query
 
     with pytest.raises(Exception):
-        reranker.rank("valid query", "not a list")  # Invalid type for documents
+        cohereReranker.rank("valid query", "not a list")  # Invalid type for documents
 
     with pytest.raises(Exception):
-        reranker.rank(
+        cohereReranker.rank(
             "valid query", [{"field": "valid document"}], rank_by=["invalid_field"]
         )  # Invalid rank_by field
+
+
+def test_rank_documents_cross_encoder(hfCrossEncoderReranker):
+    query = "I love you"
+    texts = ["I love you", "I like you", "I don't like you", "I hate you"]
+    reranked_docs, scores = hfCrossEncoderReranker.rank(query, texts)
+
+    for i in range(min(len(texts), hfCrossEncoderReranker.limit) - 1):
+        assert scores[i] > scores[i + 1]
+
+
+def test_rank_documents_cross_encoder_custom_model(
+    hfCrossEncoderRerankerWithCustomModel,
+):
+    query = "I love you"
+    texts = ["I love you", "I like you", "I don't like you", "I hate you"]
+    reranked_docs, scores = hfCrossEncoderRerankerWithCustomModel.rank(query, texts)
+
+    for i in range(min(len(texts), hfCrossEncoderRerankerWithCustomModel.limit) - 1):
+        assert scores[i] > scores[i + 1]
+
+
+@pytest.mark.asyncio
+async def test_async_rank_cross_encoder(hfCrossEncoderReranker):
+    docs = ["document one", "document two", "document three"]
+    query = "search query"
+
+    reranked_docs, scores = await hfCrossEncoderReranker.arank(query, docs)
+
+    assert isinstance(reranked_docs, list)
+    assert len(reranked_docs) == len(docs)  # Ensure we get back as many docs as we sent
+    assert all(isinstance(score, float) for score in scores)  # Scores should be floats
diff --git a/tests/unit/test_cross_encoder_reranker.py b/tests/unit/test_cross_encoder_reranker.py
new file mode 100644
index 00000000..8db57bb8
--- /dev/null
+++ b/tests/unit/test_cross_encoder_reranker.py
@@ -0,0 +1,53 @@
+import pytest
+
+from redisvl.utils.rerank.hf_cross_encoder import HFCrossEncoderReranker
+
+
+@pytest.fixture
+def reranker():
+    return HFCrossEncoderReranker()
+
+
+def test_rank_documents(reranker):
+    docs = ["document one", "document two", "document three"]
+    query = "search query"
+
+    reranked_docs, scores = reranker.rank(query, docs)
+
+    assert isinstance(reranked_docs, list)
+    assert len(reranked_docs) == reranker.limit
+    assert all(isinstance(score, float) for score in scores)
+
+
+@pytest.mark.asyncio
+async def test_async_rank_documents(reranker):
+    docs = ["document one", "document two", "document three"]
+    query = "search query"
+
+    reranked_docs, scores = await reranker.arank(query, docs)
+
+    assert isinstance(reranked_docs, list)
+    assert len(reranked_docs) == reranker.limit
+    assert all(isinstance(score, float) for score in scores)
+
+
+def test_bad_input(reranker):
+    with pytest.raises(ValueError):
+        reranker.rank("", [])  # Empty query
+
+    with pytest.raises(TypeError):
+        reranker.rank(123, ["valid document"])  # Invalid type for query
+
+    with pytest.raises(TypeError):
+        reranker.rank("valid query", "not a list")  # Invalid type for documents
+
+
+def test_rerank_empty(reranker):
+    docs = []
+    query = "search query"
+
+    reranked_docs = reranker.rank(query, docs, return_score=False)
+
+    assert isinstance(reranked_docs, list)
+    assert len(reranked_docs) == 0
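+
+
+# Sketch of dict-based input: rank() also accepts dict documents, as long as
+# each dict exposes a "content" field to score against the query.
+def test_rank_dict_documents(reranker):
+    docs = [
+        {"content": "document one"},
+        {"content": "document two"},
+        {"content": "document three"},
+    ]
+    query = "search query"
+
+    reranked_docs, scores = reranker.rank(query, docs)
+
+    # the original dicts come back, reordered and truncated to the limit
+    assert all(isinstance(doc, dict) and "content" in doc for doc in reranked_docs)
+    assert len(reranked_docs) == reranker.limit
+    assert all(isinstance(score, float) for score in scores)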