diff --git a/haystack/document_stores/types/filter_policy.py b/haystack/document_stores/types/filter_policy.py index 786c30767d..4825ac31ec 100644 --- a/haystack/document_stores/types/filter_policy.py +++ b/haystack/document_stores/types/filter_policy.py @@ -81,6 +81,19 @@ def convert_logical_operators(filter_dict: Dict[str, Any]) -> Dict[str, Any]: return filter_dict +def logical_operators_to_str(filter_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert any logical operators found in a filter dictionary to string. + + :param filter_dict: A dictionary representing a filter with potential LogicalOperator enums. + :return: A new dictionary with LogicalOperator strings in place of enums. + """ + # If the dictionary represents a logical filter, update the 'operator' + if is_logical_filter(filter_dict) and isinstance(filter_dict["operator"], LogicalOperator): + filter_dict["operator"] = filter_dict["operator"].value + return filter_dict + + def is_comparison_filter(filter_item: Dict[str, Any]) -> bool: """ Check if the given filter is a comparison filter. @@ -340,16 +353,23 @@ def apply_filter_policy( runtime_filters = convert_logical_operators(runtime_filters) # now we merge filters if is_comparison_filter(init_filters) and is_comparison_filter(runtime_filters): - return combine_two_comparison_filters(init_filters, runtime_filters, default_logical_operator) + combined_filters = combine_two_comparison_filters(init_filters, runtime_filters, default_logical_operator) + return logical_operators_to_str(combined_filters) elif is_comparison_filter(init_filters) and is_logical_filter(runtime_filters): - return combine_init_comparison_and_runtime_logical_filters( + combined_filters = combine_init_comparison_and_runtime_logical_filters( init_filters, runtime_filters, default_logical_operator ) + return logical_operators_to_str(combined_filters) elif is_logical_filter(init_filters) and is_comparison_filter(runtime_filters): - return combine_runtime_comparison_and_init_logical_filters( + combined_filters = combine_runtime_comparison_and_init_logical_filters( runtime_filters, init_filters, default_logical_operator ) + return logical_operators_to_str(combined_filters) elif is_logical_filter(init_filters) and is_logical_filter(runtime_filters): - return combine_two_logical_filters(init_filters, runtime_filters) + combined_filters = combine_two_logical_filters(init_filters, runtime_filters) + return logical_operators_to_str(combined_filters) - return runtime_filters or init_filters + resulting_filter = runtime_filters or init_filters + if resulting_filter: + resulting_filter = logical_operators_to_str(resulting_filter) + return resulting_filter diff --git a/test/document_stores/test_filter_policy.py b/test/document_stores/test_filter_policy.py index abe761ddac..26ad193759 100644 --- a/test/document_stores/test_filter_policy.py +++ b/test/document_stores/test_filter_policy.py @@ -45,7 +45,7 @@ def test_merge_two_comparison_filters(): runtime_filters = {"field": "meta.type", "operator": "==", "value": "article"} result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) assert result == { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, {"field": "meta.type", "operator": "==", "value": "article"}, @@ -60,7 +60,7 @@ def test_merge_init_comparison_and_runtime_logical_filters(): """ init_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} runtime_filters = { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.type", "operator": "==", "value": "article"}, {"field": "meta.rating", "operator": ">=", "value": 3}, @@ -68,7 +68,7 @@ def test_merge_init_comparison_and_runtime_logical_filters(): } result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) assert result == { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.type", "operator": "==", "value": "article"}, {"field": "meta.rating", "operator": ">=", "value": 3}, @@ -99,7 +99,7 @@ def test_merge_runtime_comparison_and_init_logical_filters_with_string_operators } result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) assert result == { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.type", "operator": "==", "value": "article"}, {"field": "meta.rating", "operator": ">=", "value": 3}, @@ -115,7 +115,7 @@ def test_merge_runtime_comparison_and_init_logical_filters(): Result: AND operator with both filters """ init_filters = { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.type", "operator": "==", "value": "article"}, {"field": "meta.rating", "operator": ">=", "value": 3}, @@ -124,7 +124,7 @@ def test_merge_runtime_comparison_and_init_logical_filters(): runtime_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"} result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) assert result == { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.type", "operator": "==", "value": "article"}, {"field": "meta.rating", "operator": ">=", "value": 3}, @@ -139,14 +139,14 @@ def test_merge_two_logical_filters(): Result: AND operator with both filters """ init_filters = { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.type", "operator": "==", "value": "article"}, {"field": "meta.rating", "operator": ">=", "value": 3}, ], } runtime_filters = { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}, {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, @@ -154,7 +154,7 @@ def test_merge_two_logical_filters(): } result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) assert result == { - "operator": LogicalOperator.AND, + "operator": "AND", "conditions": [ {"field": "meta.type", "operator": "==", "value": "article"}, {"field": "meta.rating", "operator": ">=", "value": 3}, @@ -171,7 +171,7 @@ def test_merge_with_different_logical_operators(): """ init_filters = {"operator": "AND", "conditions": [{"field": "meta.type", "operator": "==", "value": "article"}]} runtime_filters = { - "operator": LogicalOperator.OR, + "operator": "OR", "conditions": [{"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}], } result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters) @@ -204,7 +204,7 @@ def test_merge_with_custom_logical_operator(logical_operator): default_logical_operator=LogicalOperator.from_str(logical_operator), ) assert result == { - "operator": LogicalOperator.from_str(logical_operator), + "operator": logical_operator, "conditions": [ {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, {"field": "meta.type", "operator": "==", "value": "article"},