Skip to content

Commit

Permalink
Convert all apply_filter_policy logical operators to str representations
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Aug 8, 2024
1 parent 519de1c commit a2621ad
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
30 changes: 25 additions & 5 deletions haystack/document_stores/types/filter_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ def convert_logical_operators(filter_dict: Dict[str, Any]) -> Dict[str, Any]:
return filter_dict


def logical_operators_to_str(filter_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert any logical operators found in a filter dictionary to string.
:param filter_dict: A dictionary representing a filter with potential LogicalOperator enums.
:return: A new dictionary with LogicalOperator strings in place of enums.
"""
# If the dictionary represents a logical filter, update the 'operator'
if is_logical_filter(filter_dict) and isinstance(filter_dict["operator"], LogicalOperator):
filter_dict["operator"] = filter_dict["operator"].value
return filter_dict


def is_comparison_filter(filter_item: Dict[str, Any]) -> bool:
"""
Check if the given filter is a comparison filter.
Expand Down Expand Up @@ -340,16 +353,23 @@ def apply_filter_policy(
runtime_filters = convert_logical_operators(runtime_filters)
# now we merge filters
if is_comparison_filter(init_filters) and is_comparison_filter(runtime_filters):
return combine_two_comparison_filters(init_filters, runtime_filters, default_logical_operator)
combined_filters = combine_two_comparison_filters(init_filters, runtime_filters, default_logical_operator)
return logical_operators_to_str(combined_filters)
elif is_comparison_filter(init_filters) and is_logical_filter(runtime_filters):
return combine_init_comparison_and_runtime_logical_filters(
combined_filters = combine_init_comparison_and_runtime_logical_filters(
init_filters, runtime_filters, default_logical_operator
)
return logical_operators_to_str(combined_filters)
elif is_logical_filter(init_filters) and is_comparison_filter(runtime_filters):
return combine_runtime_comparison_and_init_logical_filters(
combined_filters = combine_runtime_comparison_and_init_logical_filters(
runtime_filters, init_filters, default_logical_operator
)
return logical_operators_to_str(combined_filters)
elif is_logical_filter(init_filters) and is_logical_filter(runtime_filters):
return combine_two_logical_filters(init_filters, runtime_filters)
combined_filters = combine_two_logical_filters(init_filters, runtime_filters)
return logical_operators_to_str(combined_filters)

return runtime_filters or init_filters
resulting_filter = runtime_filters or init_filters
if resulting_filter:
resulting_filter = logical_operators_to_str(resulting_filter)
return resulting_filter
22 changes: 11 additions & 11 deletions test/document_stores/test_filter_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_merge_two_comparison_filters():
runtime_filters = {"field": "meta.type", "operator": "==", "value": "article"}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
assert result == {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.date", "operator": ">=", "value": "2015-01-01"},
{"field": "meta.type", "operator": "==", "value": "article"},
Expand All @@ -60,15 +60,15 @@ def test_merge_init_comparison_and_runtime_logical_filters():
"""
init_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}
runtime_filters = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
],
}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
assert result == {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_merge_runtime_comparison_and_init_logical_filters_with_string_operators
}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
assert result == {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand All @@ -115,7 +115,7 @@ def test_merge_runtime_comparison_and_init_logical_filters():
Result: AND operator with both filters
"""
init_filters = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand All @@ -124,7 +124,7 @@ def test_merge_runtime_comparison_and_init_logical_filters():
runtime_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
assert result == {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand All @@ -139,22 +139,22 @@ def test_merge_two_logical_filters():
Result: AND operator with both filters
"""
init_filters = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
],
}
runtime_filters = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]},
{"field": "meta.publisher", "operator": "==", "value": "nytimes"},
],
}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
assert result == {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand All @@ -171,7 +171,7 @@ def test_merge_with_different_logical_operators():
"""
init_filters = {"operator": "AND", "conditions": [{"field": "meta.type", "operator": "==", "value": "article"}]}
runtime_filters = {
"operator": LogicalOperator.OR,
"operator": "OR",
"conditions": [{"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}],
}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
Expand Down Expand Up @@ -204,7 +204,7 @@ def test_merge_with_custom_logical_operator(logical_operator):
default_logical_operator=LogicalOperator.from_str(logical_operator),
)
assert result == {
"operator": LogicalOperator.from_str(logical_operator),
"operator": logical_operator,
"conditions": [
{"field": "meta.date", "operator": ">=", "value": "2015-01-01"},
{"field": "meta.type", "operator": "==", "value": "article"},
Expand Down

0 comments on commit a2621ad

Please sign in to comment.