Skip to content

Commit

Permalink
feat: Add filter_policy init parameter to in memory retrievers (#7795)
Browse files Browse the repository at this point in the history
* Add filter_policy init parameter to in-memory retrievers
  • Loading branch information
vblagoje committed Jul 3, 2024
1 parent d3cdbcc commit ac9b02f
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 9 deletions.
20 changes: 16 additions & 4 deletions haystack/components/retrievers/in_memory/bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import FilterPolicy


@component
Expand Down Expand Up @@ -40,6 +41,7 @@ def __init__(
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
scale_score: bool = False,
filter_policy: FilterPolicy = FilterPolicy.REPLACE,
):
"""
Create the InMemoryBM25Retriever component.
Expand All @@ -52,7 +54,7 @@ def __init__(
The maximum number of documents to retrieve.
:param scale_score:
Scales the BM25 score to a unit interval in the range of 0 to 1, where 1 means extremely relevant. If set to `False`, uses raw similarity scores.
:param filter_policy: The filter policy to apply during retrieval.
:raises ValueError:
If the specified `top_k` is not > 0.
"""
Expand All @@ -67,6 +69,7 @@ def __init__(
self.filters = filters
self.top_k = top_k
self.scale_score = scale_score
self.filter_policy = filter_policy

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -83,7 +86,12 @@ def to_dict(self) -> Dict[str, Any]:
"""
docstore = self.document_store.to_dict()
return default_to_dict(
self, document_store=docstore, filters=self.filters, top_k=self.top_k, scale_score=self.scale_score
self,
document_store=docstore,
filters=self.filters,
top_k=self.top_k,
scale_score=self.scale_score,
filter_policy=self.filter_policy.value,
)

@classmethod
Expand All @@ -101,6 +109,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "InMemoryBM25Retriever":
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
if "filter_policy" in init_params:
init_params["filter_policy"] = FilterPolicy.from_str(init_params["filter_policy"])
data["init_parameters"]["document_store"] = InMemoryDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
Expand Down Expand Up @@ -132,8 +142,10 @@ def run(
:raises ValueError:
If the specified DocumentStore is not found or is not a InMemoryDocumentStore instance.
"""
if filters is None:
filters = self.filters
if self.filter_policy == FilterPolicy.MERGE and filters:
filters = {**(self.filters or {}), **filters}
else:
filters = filters or self.filters
if top_k is None:
top_k = self.top_k
if scale_score is None:
Expand Down
14 changes: 11 additions & 3 deletions haystack/components/retrievers/in_memory/embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import FilterPolicy


@component
Expand Down Expand Up @@ -50,6 +51,7 @@ def __init__(
top_k: int = 10,
scale_score: bool = False,
return_embedding: bool = False,
filter_policy: FilterPolicy = FilterPolicy.REPLACE,
):
"""
Create the InMemoryEmbeddingRetriever component.
Expand All @@ -64,7 +66,7 @@ def __init__(
Scales the BM25 score to a unit interval in the range of 0 to 1, where 1 means extremely relevant. If set to `False`, uses raw similarity scores.
:param return_embedding:
Whether to return the embedding of the retrieved Documents.
:param filter_policy: The filter policy to apply during retrieval.
:raises ValueError:
If the specified top_k is not > 0.
"""
Expand All @@ -80,6 +82,7 @@ def __init__(
self.top_k = top_k
self.scale_score = scale_score
self.return_embedding = return_embedding
self.filter_policy = filter_policy

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -102,6 +105,7 @@ def to_dict(self) -> Dict[str, Any]:
top_k=self.top_k,
scale_score=self.scale_score,
return_embedding=self.return_embedding,
filter_policy=self.filter_policy.value,
)

@classmethod
Expand All @@ -119,6 +123,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "InMemoryEmbeddingRetriever":
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
if "filter_policy" in init_params:
init_params["filter_policy"] = FilterPolicy.from_str(init_params["filter_policy"])
data["init_parameters"]["document_store"] = InMemoryDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
Expand Down Expand Up @@ -153,8 +159,10 @@ def run(
:raises ValueError:
If the specified DocumentStore is not found or is not an InMemoryDocumentStore instance.
"""
if filters is None:
filters = self.filters
if self.filter_policy == FilterPolicy.MERGE and filters:
filters = {**(self.filters or {}), **filters}
else:
filters = filters or self.filters
if top_k is None:
top_k = self.top_k
if scale_score is None:
Expand Down
3 changes: 2 additions & 1 deletion haystack/document_stores/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

from .filter_policy import FilterPolicy
from .policy import DuplicatePolicy
from .protocol import DocumentStore

__all__ = ["DocumentStore", "DuplicatePolicy"]
__all__ = ["DocumentStore", "DuplicatePolicy", "FilterPolicy"]
35 changes: 35 additions & 0 deletions haystack/document_stores/types/filter_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from enum import Enum


class FilterPolicy(Enum):
"""
Policy to determine how filters are applied in retrievers interacting with document stores.
"""

# Runtime filters replace init filters during retriever run invocation.
REPLACE = "replace"

# Runtime filters are merged with init filters, with runtime filters overwriting init values.
MERGE = "merge"

def __str__(self):
return self.value

@staticmethod
def from_str(filter_policy: str) -> "FilterPolicy":
"""
Convert a string to a FilterPolicy enum.
:param filter_policy: The string to convert.
:return: The corresponding FilterPolicy enum.
"""
enum_map = {e.value: e for e in FilterPolicy}
policy = enum_map.get(filter_policy)
if policy is None:
msg = f"Unknown FilterPolicy type '{filter_policy}'. Supported types are: {list(enum_map.keys())}"
raise ValueError(msg)
return policy
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
---
enhancements:
- |
Provides users the ability to customize text extraction from PDF files. It is particularly useful for PDFs with unusual layouts, such as those containing multiple text columns. For instance, users can configure the object to retain the reading order.
Provides users the ability to customize text extraction from PDF files. It is particularly useful for PDFs with unusual layouts, such as those containing multiple text columns. For instance, users can configure the object to retain the reading order.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Introduced a 'filter_policy' init parameter for both InMemoryBM25Retriever and InMemoryEmbeddingRetriever, allowing users to define how runtime filters should be applied with options to either 'replace' the initial filters or 'merge' them, providing greater flexibility in filtering query results.
4 changes: 4 additions & 0 deletions test/components/retrievers/test_in_memory_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from haystack import Pipeline, DeserializationError
from haystack.document_stores.types import FilterPolicy
from haystack.testing.factory import document_store_class
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.dataclasses import Document
Expand Down Expand Up @@ -56,6 +57,7 @@ def test_to_dict(self):
"filters": None,
"top_k": 10,
"scale_score": False,
"filter_policy": "replace",
},
}

Expand All @@ -74,6 +76,7 @@ def test_to_dict_with_custom_init_parameters(self):
"filters": {"name": "test.txt"},
"top_k": 5,
"scale_score": True,
"filter_policy": "replace",
},
}

Expand All @@ -96,6 +99,7 @@ def test_from_dict(self):
assert component.filters == {"name": "test.txt"}
assert component.top_k == 5
assert component.scale_score is False
assert component.filter_policy == FilterPolicy.REPLACE

def test_from_dict_without_docstore(self):
data = {"type": "InMemoryBM25Retriever", "init_parameters": {}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np

from haystack import Pipeline, DeserializationError
from haystack.document_stores.types import FilterPolicy
from haystack.testing.factory import document_store_class
from haystack.components.retrievers.in_memory.embedding_retriever import InMemoryEmbeddingRetriever
from haystack.dataclasses import Document
Expand Down Expand Up @@ -47,6 +48,7 @@ def test_to_dict(self):
"top_k": 10,
"scale_score": False,
"return_embedding": False,
"filter_policy": "replace",
},
}

Expand All @@ -70,6 +72,7 @@ def test_to_dict_with_custom_init_parameters(self):
"top_k": 5,
"scale_score": True,
"return_embedding": True,
"filter_policy": "replace",
},
}

Expand All @@ -83,13 +86,15 @@ def test_from_dict(self):
},
"filters": {"name": "test.txt"},
"top_k": 5,
"filter_policy": "merge",
},
}
component = InMemoryEmbeddingRetriever.from_dict(data)
assert isinstance(component.document_store, InMemoryDocumentStore)
assert component.filters == {"name": "test.txt"}
assert component.top_k == 5
assert component.scale_score is False
assert component.filter_policy == FilterPolicy.MERGE

def test_from_dict_without_docstore(self):
data = {
Expand Down

0 comments on commit ac9b02f

Please sign in to comment.