diff --git a/haystack/document_stores/types/__init__.py b/haystack/document_stores/types/__init__.py index df2032f79c..ed6becf8b4 100644 --- a/haystack/document_stores/types/__init__.py +++ b/haystack/document_stores/types/__init__.py @@ -2,8 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 -from .filter_policy import FilterPolicy +from .filter_policy import FilterPolicy, apply_filter_policy from .policy import DuplicatePolicy from .protocol import DocumentStore -__all__ = ["DocumentStore", "DuplicatePolicy", "FilterPolicy"] +__all__ = ["apply_filter_policy", "DocumentStore", "DuplicatePolicy", "FilterPolicy"] diff --git a/haystack/document_stores/types/filter_policy.py b/haystack/document_stores/types/filter_policy.py index a2be576d20..b0dc58d895 100644 --- a/haystack/document_stores/types/filter_policy.py +++ b/haystack/document_stores/types/filter_policy.py @@ -3,7 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional + +from haystack import logging + +logger = logging.getLogger(__name__) class FilterPolicy(Enum): @@ -28,18 +32,259 @@ def from_str(filter_policy: str) -> "FilterPolicy": :param filter_policy: The string to convert. :return: The corresponding FilterPolicy enum. """ - enum_map = {e.value: e for e in FilterPolicy} - policy = enum_map.get(filter_policy) + enum_map = {e.value.lower(): e for e in FilterPolicy} + policy = enum_map.get(filter_policy.lower() if filter_policy else "") if policy is None: msg = f"Unknown FilterPolicy type '{filter_policy}'. Supported types are: {list(enum_map.keys())}" raise ValueError(msg) return policy +def is_comparison_filter(filter_item: Dict[str, Any]) -> bool: + """ + Check if the given filter is a comparison filter. + + :param filter_item: The filter to check. + :returns: True if the filter is a comparison filter, False otherwise. + """ + return all(key in filter_item for key in ["field", "operator", "value"]) + + +def is_logical_filter(filter_item: Dict[str, Any]) -> bool: + """ + Check if the given filter is a logical filter. + + :param filter_item: The filter to check. + :returns: True if the filter is a logical filter, False otherwise. + """ + return "operator" in filter_item and "conditions" in filter_item + + +def combine_two_logical_filters( + init_logical_filter: Dict[str, Any], runtime_logical_filter: Dict[str, Any] +) -> Dict[str, Any]: + """ + Combine two logical filters, they must have the same operator. + + If `init_logical_filter["operator"]` and `runtime_logical_filter["operator"]` are the same, the conditions + of both filters are combined. Otherwise, the `init_logical_filter` is ignored and ` + runtime_logical_filter` is returned. + + __Example__: + + ```python + init_logical_filter = { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + ] + } + runtime_logical_filter = { + "operator": "AND", + "conditions": [ + {"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}, + {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, + ] + } + new_filters = combine_two_logical_filters( + init_logical_filter, runtime_logical_filter, "AND" + ) + # Output: + { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + {"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}, + {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, + ] + } + ``` + """ + if init_logical_filter["operator"] == runtime_logical_filter["operator"]: + return { + "operator": str(init_logical_filter["operator"]), + "conditions": init_logical_filter["conditions"] + runtime_logical_filter["conditions"], + } + + logger.warning( + "The provided logical operators, {parsed_operator} and {operator}, do not match so the parsed logical " + "filter, {init_logical_filter}, will be ignored and only the provided logical filter,{runtime_logical_filter}, " + "will be used. Update the logical operators to match to include the parsed filter.", + parsed_operator=init_logical_filter["operator"], + operator=runtime_logical_filter["operator"], + init_logical_filter=init_logical_filter, + runtime_logical_filter=runtime_logical_filter, + ) + runtime_logical_filter["operator"] = str(runtime_logical_filter["operator"]) + return runtime_logical_filter + + +def combine_init_comparison_and_runtime_logical_filters( + init_comparison_filter: Dict[str, Any], + runtime_logical_filter: Dict[str, Any], + logical_operator: Literal["AND", "OR", "NOT"], +) -> Dict[str, Any]: + """ + Combine a runtime logical filter with the init comparison filter using the provided logical_operator. + + We only add the init_comparison_filter if logical_operator matches the existing + runtime_logical_filter["operator"]. Otherwise, we return the runtime_logical_filter unchanged. + + __Example__: + + ```python + runtime_logical_filter = { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + ] + } + init_comparison_filter = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} + new_filters = combine_init_comparison_and_runtime_logical_filters( + init_comparison_filter, runtime_logical_filter, "AND" + ) + # Output: + { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, + ] + } + ``` + """ + if runtime_logical_filter["operator"] == logical_operator: + conditions = runtime_logical_filter["conditions"] + fields = {c.get("field") for c in conditions} + if init_comparison_filter["field"] not in fields: + conditions.append(init_comparison_filter) + else: + logger.warning( + "The init filter, {init_filter}, is ignored as the field is already present in the existing " + "filters, {filters}.", + init_filter=init_comparison_filter, + filters=runtime_logical_filter, + ) + return {"operator": str(runtime_logical_filter["operator"]), "conditions": conditions} + + logger.warning( + "The provided logical_operator, {logical_operator}, does not match the logical operator found in " + "the runtime filters, {filters_logical_operator}, so the init filter will be ignored.", + logical_operator=logical_operator, + filters_logical_operator=runtime_logical_filter["operator"], + ) + runtime_logical_filter["operator"] = str(runtime_logical_filter["operator"]) + return runtime_logical_filter + + +def combine_runtime_comparison_and_init_logical_filters( + runtime_comparison_filter: Dict[str, Any], + init_logical_filter: Dict[str, Any], + logical_operator: Literal["AND", "OR", "NOT"], +) -> Dict[str, Any]: + """ + Combine an init logical filter with the runtime comparison filter using the provided logical_operator. + + We only add the runtime_comparison_filter if logical_operator matches the existing + init_logical_filter["operator"]. Otherwise, we return the runtime_comparison_filter unchanged. + + __Example__: + + ```python + init_logical_filter = { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + ] + } + runtime_comparison_filter = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} + new_filters = combine_runtime_comparison_and_init_logical_filters( + runtime_comparison_filter, init_logical_filter, "AND" + ) + # Output: + { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, + ] + } + ``` + """ + if init_logical_filter["operator"] == logical_operator: + conditions = init_logical_filter["conditions"] + fields = {c.get("field") for c in conditions} + if runtime_comparison_filter["field"] in fields: + logger.warning( + "The runtime filter, {runtime_filter}, will overwrite the existing filter with the same " + "field in the init logical filter.", + runtime_filter=runtime_comparison_filter, + ) + conditions = [c for c in conditions if c.get("field") != runtime_comparison_filter["field"]] + conditions.append(runtime_comparison_filter) + return {"operator": str(init_logical_filter["operator"]), "conditions": conditions} + + logger.warning( + "The provided logical_operator, {logical_operator}, does not match the logical operator found in " + "the init logical filter, {filters_logical_operator}, so the init logical filter will be ignored.", + logical_operator=logical_operator, + filters_logical_operator=init_logical_filter["operator"], + ) + return runtime_comparison_filter + + +def combine_two_comparison_filters( + init_comparison_filter: Dict[str, Any], + runtime_comparison_filter: Dict[str, Any], + logical_operator: Literal["AND", "OR", "NOT"], +) -> Dict[str, Any]: + """ + Combine a comparison filter with the `init_comparison_filter` using the provided `logical_operator`. + + If `runtime_comparison_filter` and `init_comparison_filter` target the same field, `init_comparison_filter` + is ignored and `runtime_comparison_filter` is returned unchanged. + + __Example__: + + ```python + runtime_comparison_filter = {"field": "meta.type", "operator": "==", "value": "article"}, + init_comparison_filter = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, + new_filters = combine_two_comparison_filters( + init_comparison_filter, runtime_comparison_filter, "AND" + ) + # Output: + { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, + ] + } + ``` + """ + if runtime_comparison_filter["field"] == init_comparison_filter["field"]: + logger.warning( + "The parsed filter, {parsed_filter}, is ignored as the field is already present in the existing " + "filters, {filters}.", + parsed_filter=init_comparison_filter, + filters=runtime_comparison_filter, + ) + return runtime_comparison_filter + + return {"operator": str(logical_operator), "conditions": [init_comparison_filter, runtime_comparison_filter]} + + def apply_filter_policy( filter_policy: FilterPolicy, init_filters: Optional[Dict[str, Any]] = None, runtime_filters: Optional[Dict[str, Any]] = None, + default_logical_operator: Literal["AND", "OR", "NOT"] = "AND", ) -> Optional[Dict[str, Any]]: """ Apply the filter policy to the given initial and runtime filters to determine the final set of filters used. @@ -52,10 +297,23 @@ def apply_filter_policy( values from the runtime filters will overwrite those from the initial filters. :param init_filters: The initial filters set during the initialization of the relevant retriever. :param runtime_filters: The filters provided at runtime, usually during a query operation execution. These filters - can change for each query/retreiver run invocation. + can change for each query/retriever run invocation. + :param default_logical_operator: The default logical operator to use when merging filters (non-legacy filters only). :returns: A dictionary containing the resulting filters based on the provided policy. """ - if filter_policy == FilterPolicy.MERGE and runtime_filters: - return {**(init_filters or {}), **runtime_filters} - else: - return runtime_filters or init_filters + if filter_policy == FilterPolicy.MERGE and runtime_filters and init_filters: + # now we merge filters + if is_comparison_filter(init_filters) and is_comparison_filter(runtime_filters): + return combine_two_comparison_filters(init_filters, runtime_filters, default_logical_operator) + elif is_comparison_filter(init_filters) and is_logical_filter(runtime_filters): + return combine_init_comparison_and_runtime_logical_filters( + init_filters, runtime_filters, default_logical_operator + ) + elif is_logical_filter(init_filters) and is_comparison_filter(runtime_filters): + return combine_runtime_comparison_and_init_logical_filters( + runtime_filters, init_filters, default_logical_operator + ) + elif is_logical_filter(init_filters) and is_logical_filter(runtime_filters): + return combine_two_logical_filters(init_filters, runtime_filters) + + return runtime_filters or init_filters diff --git a/releasenotes/notes/implement-merge-filter-logic-99e6785a78f80ae9.yaml b/releasenotes/notes/implement-merge-filter-logic-99e6785a78f80ae9.yaml new file mode 100644 index 0000000000..c90479c2c6 --- /dev/null +++ b/releasenotes/notes/implement-merge-filter-logic-99e6785a78f80ae9.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Enhanced filter application logic to support merging of filters. It facilitates more precise retrieval filtering, allowing for both init and runtime complex filter combinations with logical operators. For more details see https://docs.haystack.deepset.ai/docs/metadata-filtering diff --git a/test/document_stores/test_filter_policy.py b/test/document_stores/test_filter_policy.py index b7efcd0672..d775ee356e 100644 --- a/test/document_stores/test_filter_policy.py +++ b/test/document_stores/test_filter_policy.py @@ -3,43 +3,178 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from typing import Any, Dict, Optional -from enum import Enum -from haystack.document_stores.types import FilterPolicy -from haystack.document_stores.types.filter_policy import apply_filter_policy +from haystack.document_stores.types import apply_filter_policy, FilterPolicy -def test_replace_policy_with_both_filters(): - init_filters = {"status": "active", "category": "news"} - runtime_filters = {"author": "John Doe"} - result = apply_filter_policy(FilterPolicy.REPLACE, init_filters, runtime_filters) - assert result == runtime_filters +def test_merge_two_comparison_filters(): + """ + Merging two comparison filters + Result: AND operator with both filters + """ + init_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} + runtime_filters = {"field": "meta.type", "operator": "==", "value": "article"} + result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) + assert result == { + "operator": "AND", + "conditions": [ + {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, + {"field": "meta.type", "operator": "==", "value": "article"}, + ], + } -def test_merge_policy_with_both_filters(): - init_filters = {"status": "active", "category": "news"} - runtime_filters = {"author": "John Doe"} +def test_merge_init_comparison_and_runtime_logical_filters(): + """ + Merging init comparison and runtime logical filters + Result: AND operator with both filters + """ + init_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} + runtime_filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + ], + } result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) - assert result == {"status": "active", "category": "news", "author": "John Doe"} + assert result == { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, + ], + } -def test_replace_policy_with_only_init_filters(): - init_filters = {"status": "active", "category": "news"} - runtime_filters = None - result = apply_filter_policy(FilterPolicy.REPLACE, init_filters, runtime_filters) - assert result == init_filters +def test_merge_runtime_comparison_and_init_logical_filters_with_string_operators(): + """ + Merging a runtime comparison filter with an init logical filter, but with string-based logical operators + Result: AND operator with both filters + """ + # Test with string-based logical operators + init_filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + ], + } + runtime_filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}, + {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, + ], + } + result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) + assert result == { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + {"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}, + {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, + ], + } -def test_merge_policy_with_only_init_filters(): - init_filters = {"status": "active", "category": "news"} - runtime_filters = None +def test_merge_runtime_comparison_and_init_logical_filters(): + """ + Merging a runtime comparison filter with an init logical filter + Result: AND operator with both filters + """ + init_filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + ], + } + runtime_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) - assert result == init_filters + assert result == { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, + ], + } -def test_merge_policy_with_overlapping_keys(): - init_filters = {"status": "active", "category": "news"} - runtime_filters = {"category": "science", "author": "John Doe"} +def test_merge_two_logical_filters(): + """ + Merging two logical filters + Result: AND operator with both filters + """ + init_filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + ], + } + runtime_filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}, + {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, + ], + } result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) - assert result == {"status": "active", "category": "science", "author": "John Doe"} + assert result == { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + {"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}, + {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, + ], + } + + +def test_merge_with_different_logical_operators(): + """ + Merging with a different logical operator + Result: warnings and runtime filters + """ + init_filters = {"operator": "AND", "conditions": [{"field": "meta.type", "operator": "==", "value": "article"}]} + runtime_filters = { + "operator": "OR", + "conditions": [{"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}], + } + result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) + assert result == runtime_filters + + +def test_merge_comparison_filters_with_same_field(): + """ + Merging comparison filters with the same field + Result: warnings and runtime filters + """ + init_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} + runtime_filters = {"field": "meta.date", "operator": "<=", "value": "2020-12-31"} + result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) + assert result == runtime_filters + + +@pytest.mark.parametrize("logical_operator", ["AND", "OR", "NOT"]) +def test_merge_with_custom_logical_operator(logical_operator: str): + """ + Merging with a custom logical operator + Result: The given logical operator with both filters + """ + init_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} + runtime_filters = {"field": "meta.type", "operator": "==", "value": "article"} + result = apply_filter_policy( + FilterPolicy.MERGE, init_filters, runtime_filters, default_logical_operator=logical_operator + ) + assert result == { + "operator": logical_operator, + "conditions": [ + {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, + {"field": "meta.type", "operator": "==", "value": "article"}, + ], + }