From a1cc6f942270682104fa850980a9204fefacea98 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 8 Aug 2024 17:29:26 +0200 Subject: [PATCH] Revert back to Literal instead of LogicalOperator Enum --- haystack/document_stores/types/__init__.py | 4 +- .../document_stores/types/filter_policy.py | 81 +++++-------------- test/document_stores/test_filter_policy.py | 38 +-------- 3 files changed, 25 insertions(+), 98 deletions(-) diff --git a/haystack/document_stores/types/__init__.py b/haystack/document_stores/types/__init__.py index 1924b78a00..ed6becf8b4 100644 --- a/haystack/document_stores/types/__init__.py +++ b/haystack/document_stores/types/__init__.py @@ -2,8 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 -from .filter_policy import FilterPolicy, LogicalOperator, apply_filter_policy +from .filter_policy import FilterPolicy, apply_filter_policy from .policy import DuplicatePolicy from .protocol import DocumentStore -__all__ = ["apply_filter_policy", "DocumentStore", "DuplicatePolicy", "FilterPolicy", "LogicalOperator"] +__all__ = ["apply_filter_policy", "DocumentStore", "DuplicatePolicy", "FilterPolicy"] diff --git a/haystack/document_stores/types/filter_policy.py b/haystack/document_stores/types/filter_policy.py index 03e0114eb7..b0dc58d895 100644 --- a/haystack/document_stores/types/filter_policy.py +++ b/haystack/document_stores/types/filter_policy.py @@ -3,37 +3,13 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional from haystack import logging logger = logging.getLogger(__name__) -class LogicalOperator(Enum): - AND = "AND" - OR = "OR" - NOT = "NOT" - - def __str__(self): - return self.value - - @staticmethod - def from_str(operator_label: str) -> "LogicalOperator": - """ - Convert a string to a LogicalOperator enum. - - :param operator_label: The string to convert. - :return: The corresponding LogicalOperator enum. - """ - enum_map = {e.value.lower(): e for e in LogicalOperator} - operator = enum_map.get(operator_label.lower() if operator_label else "") - if operator is None: - msg = f"Unknown LogicalOperator type '{operator}'. Supported types are: {list(enum_map.keys())}" - raise ValueError(msg) - return operator - - class FilterPolicy(Enum): """ Policy to determine how filters are applied in retrievers interacting with document stores. @@ -64,23 +40,6 @@ def from_str(filter_policy: str) -> "FilterPolicy": return policy -def convert_logical_operators(filter_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Convert string-based logical operators in a filter dictionary to LogicalOperator enums. - - :param filter_dict: A dictionary representing a filter with potential string logical operators. - :return: A new dictionary with LogicalOperator enums in place of string operators. - """ - # If the dictionary represents a logical filter, update the 'operator' - if is_logical_filter(filter_dict) and isinstance(filter_dict["operator"], str): - try: - filter_dict["operator"] = LogicalOperator.from_str(filter_dict["operator"]) - except ValueError as e: - raise ValueError(f"Error converting logical operator: {e}") - - return filter_dict - - def is_comparison_filter(filter_item: Dict[str, Any]) -> bool: """ Check if the given filter is a comparison filter. @@ -115,21 +74,21 @@ def combine_two_logical_filters( ```python init_logical_filter = { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.type", "operator": "==", "value": "article"}, {"field": "meta.rating", "operator": ">=", "value": 3}, ] } runtime_logical_filter = { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}, {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, ] } new_filters = combine_two_logical_filters( - init_logical_filter, runtime_logical_filter, LogicalOperator.AND + init_logical_filter, runtime_logical_filter, "AND" ) # Output: { @@ -163,7 +122,9 @@ def combine_two_logical_filters( def combine_init_comparison_and_runtime_logical_filters( - init_comparison_filter: Dict[str, Any], runtime_logical_filter: Dict[str, Any], logical_operator: LogicalOperator + init_comparison_filter: Dict[str, Any], + runtime_logical_filter: Dict[str, Any], + logical_operator: Literal["AND", "OR", "NOT"], ) -> Dict[str, Any]: """ Combine a runtime logical filter with the init comparison filter using the provided logical_operator. @@ -175,7 +136,7 @@ def combine_init_comparison_and_runtime_logical_filters( ```python runtime_logical_filter = { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.type", "operator": "==", "value": "article"}, {"field": "meta.rating", "operator": ">=", "value": 3}, @@ -183,7 +144,7 @@ def combine_init_comparison_and_runtime_logical_filters( } init_comparison_filter = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} new_filters = combine_init_comparison_and_runtime_logical_filters( - init_comparison_filter, runtime_logical_filter, LogicalOperator.AND + init_comparison_filter, runtime_logical_filter, "AND" ) # Output: { @@ -221,7 +182,9 @@ def combine_init_comparison_and_runtime_logical_filters( def combine_runtime_comparison_and_init_logical_filters( - runtime_comparison_filter: Dict[str, Any], init_logical_filter: Dict[str, Any], logical_operator: LogicalOperator + runtime_comparison_filter: Dict[str, Any], + init_logical_filter: Dict[str, Any], + logical_operator: Literal["AND", "OR", "NOT"], ) -> Dict[str, Any]: """ Combine an init logical filter with the runtime comparison filter using the provided logical_operator. @@ -233,7 +196,7 @@ def combine_runtime_comparison_and_init_logical_filters( ```python init_logical_filter = { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.type", "operator": "==", "value": "article"}, {"field": "meta.rating", "operator": ">=", "value": 3}, @@ -241,7 +204,7 @@ def combine_runtime_comparison_and_init_logical_filters( } runtime_comparison_filter = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} new_filters = combine_runtime_comparison_and_init_logical_filters( - runtime_comparison_filter, init_logical_filter, LogicalOperator.AND + runtime_comparison_filter, init_logical_filter, "AND" ) # Output: { @@ -277,7 +240,9 @@ def combine_runtime_comparison_and_init_logical_filters( def combine_two_comparison_filters( - init_comparison_filter: Dict[str, Any], runtime_comparison_filter: Dict[str, Any], logical_operator: LogicalOperator + init_comparison_filter: Dict[str, Any], + runtime_comparison_filter: Dict[str, Any], + logical_operator: Literal["AND", "OR", "NOT"], ) -> Dict[str, Any]: """ Combine a comparison filter with the `init_comparison_filter` using the provided `logical_operator`. @@ -291,7 +256,7 @@ def combine_two_comparison_filters( runtime_comparison_filter = {"field": "meta.type", "operator": "==", "value": "article"}, init_comparison_filter = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, new_filters = combine_two_comparison_filters( - init_comparison_filter, runtime_comparison_filter, LogicalOperator.AND + init_comparison_filter, runtime_comparison_filter, "AND" ) # Output: { @@ -319,7 +284,7 @@ def apply_filter_policy( filter_policy: FilterPolicy, init_filters: Optional[Dict[str, Any]] = None, runtime_filters: Optional[Dict[str, Any]] = None, - default_logical_operator: LogicalOperator = LogicalOperator.AND, + default_logical_operator: Literal["AND", "OR", "NOT"] = "AND", ) -> Optional[Dict[str, Any]]: """ Apply the filter policy to the given initial and runtime filters to determine the final set of filters used. @@ -337,9 +302,6 @@ def apply_filter_policy( :returns: A dictionary containing the resulting filters based on the provided policy. """ if filter_policy == FilterPolicy.MERGE and runtime_filters and init_filters: - # first convert string-based logical operators to LogicalOperator enums - init_filters = convert_logical_operators(init_filters) - runtime_filters = convert_logical_operators(runtime_filters) # now we merge filters if is_comparison_filter(init_filters) and is_comparison_filter(runtime_filters): return combine_two_comparison_filters(init_filters, runtime_filters, default_logical_operator) @@ -354,7 +316,4 @@ def apply_filter_policy( elif is_logical_filter(init_filters) and is_logical_filter(runtime_filters): return combine_two_logical_filters(init_filters, runtime_filters) - resulting_filter = runtime_filters or init_filters - if resulting_filter and is_logical_filter(resulting_filter): - resulting_filter["operator"] = str(resulting_filter["operator"]) - return resulting_filter + return runtime_filters or init_filters diff --git a/test/document_stores/test_filter_policy.py b/test/document_stores/test_filter_policy.py index 26ad193759..d775ee356e 100644 --- a/test/document_stores/test_filter_policy.py +++ b/test/document_stores/test_filter_policy.py @@ -4,36 +4,7 @@ import pytest -from haystack.document_stores.types import LogicalOperator, apply_filter_policy, FilterPolicy - - -def test_logical_operator_from_str(): - """ - Test the conversion of a string to a LogicalOperator enum. - """ - assert LogicalOperator.from_str("AND") == LogicalOperator.AND - assert LogicalOperator.from_str("OR") == LogicalOperator.OR - assert LogicalOperator.from_str("NOT") == LogicalOperator.NOT - - with pytest.raises(ValueError): - LogicalOperator.from_str(None) - - with pytest.raises(ValueError): - LogicalOperator.from_str("INVALID") - - -def test_filter_policy_from_str(): - """ - Test the conversion of a string to a FilterPolicy enum. - """ - assert FilterPolicy.from_str("REPLACE") == FilterPolicy.REPLACE - assert FilterPolicy.from_str("MERGE") == FilterPolicy.MERGE - - with pytest.raises(ValueError): - FilterPolicy.from_str(None) - - with pytest.raises(ValueError): - FilterPolicy.from_str("INVALID") +from haystack.document_stores.types import apply_filter_policy, FilterPolicy def test_merge_two_comparison_filters(): @@ -190,7 +161,7 @@ def test_merge_comparison_filters_with_same_field(): @pytest.mark.parametrize("logical_operator", ["AND", "OR", "NOT"]) -def test_merge_with_custom_logical_operator(logical_operator): +def test_merge_with_custom_logical_operator(logical_operator: str): """ Merging with a custom logical operator Result: The given logical operator with both filters @@ -198,10 +169,7 @@ def test_merge_with_custom_logical_operator(logical_operator): init_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} runtime_filters = {"field": "meta.type", "operator": "==", "value": "article"} result = apply_filter_policy( - FilterPolicy.MERGE, - init_filters, - runtime_filters, - default_logical_operator=LogicalOperator.from_str(logical_operator), + FilterPolicy.MERGE, init_filters, runtime_filters, default_logical_operator=logical_operator ) assert result == { "operator": logical_operator,