Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement apply_filter_policy and FilterPolicy.MERGE for the new filters #8042

Merged
merged 19 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions haystack/document_stores/types/filter_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ def convert_logical_operators(filter_dict: Dict[str, Any]) -> Dict[str, Any]:
return filter_dict


def logical_operators_to_str(filter_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert any logical operators found in a filter dictionary to string.

:param filter_dict: A dictionary representing a filter with potential LogicalOperator enums.
:return: A new dictionary with LogicalOperator strings in place of enums.
"""
# If the dictionary represents a logical filter, update the 'operator'
if is_logical_filter(filter_dict) and isinstance(filter_dict["operator"], LogicalOperator):
filter_dict["operator"] = filter_dict["operator"].value
return filter_dict
shadeMe marked this conversation as resolved.
Show resolved Hide resolved


def is_comparison_filter(filter_item: Dict[str, Any]) -> bool:
"""
Check if the given filter is a comparison filter.
Expand Down Expand Up @@ -133,7 +146,7 @@ def combine_two_logical_filters(
)
# Output:
{
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand Down Expand Up @@ -186,7 +199,7 @@ def combine_init_comparison_and_runtime_logical_filters(
)
# Output:
{
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand Down Expand Up @@ -243,7 +256,7 @@ def combine_runtime_comparison_and_init_logical_filters(
)
# Output:
{
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand Down Expand Up @@ -293,7 +306,7 @@ def combine_two_comparison_filters(
)
# Output:
{
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.date", "operator": ">=", "value": "2015-01-01"},
Expand Down Expand Up @@ -340,16 +353,23 @@ def apply_filter_policy(
runtime_filters = convert_logical_operators(runtime_filters)
# now we merge filters
if is_comparison_filter(init_filters) and is_comparison_filter(runtime_filters):
return combine_two_comparison_filters(init_filters, runtime_filters, default_logical_operator)
combined_filters = combine_two_comparison_filters(init_filters, runtime_filters, default_logical_operator)
return logical_operators_to_str(combined_filters)
elif is_comparison_filter(init_filters) and is_logical_filter(runtime_filters):
return combine_init_comparison_and_runtime_logical_filters(
combined_filters = combine_init_comparison_and_runtime_logical_filters(
init_filters, runtime_filters, default_logical_operator
)
return logical_operators_to_str(combined_filters)
elif is_logical_filter(init_filters) and is_comparison_filter(runtime_filters):
return combine_runtime_comparison_and_init_logical_filters(
combined_filters = combine_runtime_comparison_and_init_logical_filters(
runtime_filters, init_filters, default_logical_operator
)
return logical_operators_to_str(combined_filters)
elif is_logical_filter(init_filters) and is_logical_filter(runtime_filters):
return combine_two_logical_filters(init_filters, runtime_filters)
combined_filters = combine_two_logical_filters(init_filters, runtime_filters)
return logical_operators_to_str(combined_filters)

return runtime_filters or init_filters
resulting_filter = runtime_filters or init_filters
if resulting_filter:
resulting_filter = logical_operators_to_str(resulting_filter)
return resulting_filter
22 changes: 11 additions & 11 deletions test/document_stores/test_filter_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_merge_two_comparison_filters():
runtime_filters = {"field": "meta.type", "operator": "==", "value": "article"}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
assert result == {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.date", "operator": ">=", "value": "2015-01-01"},
{"field": "meta.type", "operator": "==", "value": "article"},
Expand All @@ -60,15 +60,15 @@ def test_merge_init_comparison_and_runtime_logical_filters():
"""
init_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}
runtime_filters = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
],
}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
assert result == {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_merge_runtime_comparison_and_init_logical_filters_with_string_operators
}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
assert result == {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand All @@ -115,7 +115,7 @@ def test_merge_runtime_comparison_and_init_logical_filters():
Result: AND operator with both filters
"""
init_filters = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand All @@ -124,7 +124,7 @@ def test_merge_runtime_comparison_and_init_logical_filters():
runtime_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
assert result == {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand All @@ -139,22 +139,22 @@ def test_merge_two_logical_filters():
Result: AND operator with both filters
"""
init_filters = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
],
}
runtime_filters = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]},
{"field": "meta.publisher", "operator": "==", "value": "nytimes"},
],
}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
assert result == {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
Expand All @@ -171,7 +171,7 @@ def test_merge_with_different_logical_operators():
"""
init_filters = {"operator": "AND", "conditions": [{"field": "meta.type", "operator": "==", "value": "article"}]}
runtime_filters = {
"operator": LogicalOperator.OR,
"operator": "OR",
"conditions": [{"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]}],
}
result = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)
Expand Down Expand Up @@ -204,7 +204,7 @@ def test_merge_with_custom_logical_operator(logical_operator):
default_logical_operator=LogicalOperator.from_str(logical_operator),
)
assert result == {
"operator": LogicalOperator.from_str(logical_operator),
"operator": logical_operator,
"conditions": [
{"field": "meta.date", "operator": ">=", "value": "2015-01-01"},
{"field": "meta.type", "operator": "==", "value": "article"},
Expand Down