Skip to content

Commit

Permalink
Revert back to Literal instead of LogicalOperator Enum
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Aug 8, 2024
1 parent aeac10d commit a1cc6f9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 98 deletions.
4 changes: 2 additions & 2 deletions haystack/document_stores/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

from .filter_policy import FilterPolicy, LogicalOperator, apply_filter_policy
from .filter_policy import FilterPolicy, apply_filter_policy
from .policy import DuplicatePolicy
from .protocol import DocumentStore

__all__ = ["apply_filter_policy", "DocumentStore", "DuplicatePolicy", "FilterPolicy", "LogicalOperator"]
__all__ = ["apply_filter_policy", "DocumentStore", "DuplicatePolicy", "FilterPolicy"]
81 changes: 20 additions & 61 deletions haystack/document_stores/types/filter_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,13 @@
# SPDX-License-Identifier: Apache-2.0

from enum import Enum
from typing import Any, Dict, Optional
from typing import Any, Dict, Literal, Optional

from haystack import logging

logger = logging.getLogger(__name__)


class LogicalOperator(Enum):
AND = "AND"
OR = "OR"
NOT = "NOT"

def __str__(self):
return self.value

@staticmethod
def from_str(operator_label: str) -> "LogicalOperator":
"""
Convert a string to a LogicalOperator enum.
:param operator_label: The string to convert.
:return: The corresponding LogicalOperator enum.
"""
enum_map = {e.value.lower(): e for e in LogicalOperator}
operator = enum_map.get(operator_label.lower() if operator_label else "")
if operator is None:
msg = f"Unknown LogicalOperator type '{operator}'. Supported types are: {list(enum_map.keys())}"
raise ValueError(msg)
return operator


class FilterPolicy(Enum):
"""
Policy to determine how filters are applied in retrievers interacting with document stores.
Expand Down Expand Up @@ -64,23 +40,6 @@ def from_str(filter_policy: str) -> "FilterPolicy":
return policy


def convert_logical_operators(filter_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert string-based logical operators in a filter dictionary to LogicalOperator enums.
:param filter_dict: A dictionary representing a filter with potential string logical operators.
:return: A new dictionary with LogicalOperator enums in place of string operators.
"""
# If the dictionary represents a logical filter, update the 'operator'
if is_logical_filter(filter_dict) and isinstance(filter_dict["operator"], str):
try:
filter_dict["operator"] = LogicalOperator.from_str(filter_dict["operator"])
except ValueError as e:
raise ValueError(f"Error converting logical operator: {e}")

return filter_dict


def is_comparison_filter(filter_item: Dict[str, Any]) -> bool:
"""
Check if the given filter is a comparison filter.
Expand Down Expand Up @@ -115,21 +74,21 @@ def combine_two_logical_filters(
```python
init_logical_filter = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
]
}
runtime_logical_filter = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.genre", "operator": "IN", "value": ["economy", "politics"]},
{"field": "meta.publisher", "operator": "==", "value": "nytimes"},
]
}
new_filters = combine_two_logical_filters(
init_logical_filter, runtime_logical_filter, LogicalOperator.AND
init_logical_filter, runtime_logical_filter, "AND"
)
# Output:
{
Expand Down Expand Up @@ -163,7 +122,9 @@ def combine_two_logical_filters(


def combine_init_comparison_and_runtime_logical_filters(
init_comparison_filter: Dict[str, Any], runtime_logical_filter: Dict[str, Any], logical_operator: LogicalOperator
init_comparison_filter: Dict[str, Any],
runtime_logical_filter: Dict[str, Any],
logical_operator: Literal["AND", "OR", "NOT"],
) -> Dict[str, Any]:
"""
Combine a runtime logical filter with the init comparison filter using the provided logical_operator.
Expand All @@ -175,15 +136,15 @@ def combine_init_comparison_and_runtime_logical_filters(
```python
runtime_logical_filter = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
]
}
init_comparison_filter = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}
new_filters = combine_init_comparison_and_runtime_logical_filters(
init_comparison_filter, runtime_logical_filter, LogicalOperator.AND
init_comparison_filter, runtime_logical_filter, "AND"
)
# Output:
{
Expand Down Expand Up @@ -221,7 +182,9 @@ def combine_init_comparison_and_runtime_logical_filters(


def combine_runtime_comparison_and_init_logical_filters(
runtime_comparison_filter: Dict[str, Any], init_logical_filter: Dict[str, Any], logical_operator: LogicalOperator
runtime_comparison_filter: Dict[str, Any],
init_logical_filter: Dict[str, Any],
logical_operator: Literal["AND", "OR", "NOT"],
) -> Dict[str, Any]:
"""
Combine an init logical filter with the runtime comparison filter using the provided logical_operator.
Expand All @@ -233,15 +196,15 @@ def combine_runtime_comparison_and_init_logical_filters(
```python
init_logical_filter = {
"operator": LogicalOperator.AND,
"operator": "AND",
"conditions": [
{"field": "meta.type", "operator": "==", "value": "article"},
{"field": "meta.rating", "operator": ">=", "value": 3},
]
}
runtime_comparison_filter = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}
new_filters = combine_runtime_comparison_and_init_logical_filters(
runtime_comparison_filter, init_logical_filter, LogicalOperator.AND
runtime_comparison_filter, init_logical_filter, "AND"
)
# Output:
{
Expand Down Expand Up @@ -277,7 +240,9 @@ def combine_runtime_comparison_and_init_logical_filters(


def combine_two_comparison_filters(
init_comparison_filter: Dict[str, Any], runtime_comparison_filter: Dict[str, Any], logical_operator: LogicalOperator
init_comparison_filter: Dict[str, Any],
runtime_comparison_filter: Dict[str, Any],
logical_operator: Literal["AND", "OR", "NOT"],
) -> Dict[str, Any]:
"""
Combine a comparison filter with the `init_comparison_filter` using the provided `logical_operator`.
Expand All @@ -291,7 +256,7 @@ def combine_two_comparison_filters(
runtime_comparison_filter = {"field": "meta.type", "operator": "==", "value": "article"},
init_comparison_filter = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"},
new_filters = combine_two_comparison_filters(
init_comparison_filter, runtime_comparison_filter, LogicalOperator.AND
init_comparison_filter, runtime_comparison_filter, "AND"
)
# Output:
{
Expand Down Expand Up @@ -319,7 +284,7 @@ def apply_filter_policy(
filter_policy: FilterPolicy,
init_filters: Optional[Dict[str, Any]] = None,
runtime_filters: Optional[Dict[str, Any]] = None,
default_logical_operator: LogicalOperator = LogicalOperator.AND,
default_logical_operator: Literal["AND", "OR", "NOT"] = "AND",
) -> Optional[Dict[str, Any]]:
"""
Apply the filter policy to the given initial and runtime filters to determine the final set of filters used.
Expand All @@ -337,9 +302,6 @@ def apply_filter_policy(
:returns: A dictionary containing the resulting filters based on the provided policy.
"""
if filter_policy == FilterPolicy.MERGE and runtime_filters and init_filters:
# first convert string-based logical operators to LogicalOperator enums
init_filters = convert_logical_operators(init_filters)
runtime_filters = convert_logical_operators(runtime_filters)
# now we merge filters
if is_comparison_filter(init_filters) and is_comparison_filter(runtime_filters):
return combine_two_comparison_filters(init_filters, runtime_filters, default_logical_operator)
Expand All @@ -354,7 +316,4 @@ def apply_filter_policy(
elif is_logical_filter(init_filters) and is_logical_filter(runtime_filters):
return combine_two_logical_filters(init_filters, runtime_filters)

resulting_filter = runtime_filters or init_filters
if resulting_filter and is_logical_filter(resulting_filter):
resulting_filter["operator"] = str(resulting_filter["operator"])
return resulting_filter
return runtime_filters or init_filters
38 changes: 3 additions & 35 deletions test/document_stores/test_filter_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,7 @@

import pytest

from haystack.document_stores.types import LogicalOperator, apply_filter_policy, FilterPolicy


def test_logical_operator_from_str():
"""
Test the conversion of a string to a LogicalOperator enum.
"""
assert LogicalOperator.from_str("AND") == LogicalOperator.AND
assert LogicalOperator.from_str("OR") == LogicalOperator.OR
assert LogicalOperator.from_str("NOT") == LogicalOperator.NOT

with pytest.raises(ValueError):
LogicalOperator.from_str(None)

with pytest.raises(ValueError):
LogicalOperator.from_str("INVALID")


def test_filter_policy_from_str():
"""
Test the conversion of a string to a FilterPolicy enum.
"""
assert FilterPolicy.from_str("REPLACE") == FilterPolicy.REPLACE
assert FilterPolicy.from_str("MERGE") == FilterPolicy.MERGE

with pytest.raises(ValueError):
FilterPolicy.from_str(None)

with pytest.raises(ValueError):
FilterPolicy.from_str("INVALID")
from haystack.document_stores.types import apply_filter_policy, FilterPolicy


def test_merge_two_comparison_filters():
Expand Down Expand Up @@ -190,18 +161,15 @@ def test_merge_comparison_filters_with_same_field():


@pytest.mark.parametrize("logical_operator", ["AND", "OR", "NOT"])
def test_merge_with_custom_logical_operator(logical_operator):
def test_merge_with_custom_logical_operator(logical_operator: str):
"""
Merging with a custom logical operator
Result: The given logical operator with both filters
"""
init_filters = {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}
runtime_filters = {"field": "meta.type", "operator": "==", "value": "article"}
result = apply_filter_policy(
FilterPolicy.MERGE,
init_filters,
runtime_filters,
default_logical_operator=LogicalOperator.from_str(logical_operator),
FilterPolicy.MERGE, init_filters, runtime_filters, default_logical_operator=logical_operator
)
assert result == {
"operator": logical_operator,
Expand Down

0 comments on commit a1cc6f9

Please sign in to comment.