diff --git a/frontend/components/Domain/Recipe/RecipeCardSection.vue b/frontend/components/Domain/Recipe/RecipeCardSection.vue index d3e3c49e4c..30211b7f1d 100644 --- a/frontend/components/Domain/Recipe/RecipeCardSection.vue +++ b/frontend/components/Domain/Recipe/RecipeCardSection.vue @@ -217,7 +217,7 @@ export default defineComponent({ const queryFilter = computed(() => { const orderBy = props.query?.orderBy || preferences.value.orderBy; - return preferences.value.filterNull && orderBy ? `${orderBy} <> null` : null; + return preferences.value.filterNull && orderBy ? `${orderBy} IS NOT NULL` : null; }); async function fetchRecipes(pageCount = 1) { diff --git a/mealie/schema/response/query_filter.py b/mealie/schema/response/query_filter.py index e557862921..5ff5324c04 100644 --- a/mealie/schema/response/query_filter.py +++ b/mealie/schema/response/query_filter.py @@ -1,7 +1,7 @@ from __future__ import annotations -import datetime import re +from collections import deque from enum import Enum from typing import Any, TypeVar, cast from uuid import UUID @@ -9,16 +9,66 @@ from dateutil import parser as date_parser from dateutil.parser import ParserError from humps import decamelize -from sqlalchemy import Select, bindparam, inspect, text -from sqlalchemy.orm import Mapper +from sqlalchemy import ColumnElement, Select, and_, inspect, or_ +from sqlalchemy.orm import InstrumentedAttribute, Mapper from sqlalchemy.sql import sqltypes -from sqlalchemy.sql.expression import BindParameter from mealie.db.models._model_utils.guid import GUID Model = TypeVar("Model") +class RelationalKeyword(Enum): + IS = "IS" + IS_NOT = "IS NOT" + IN = "IN" + NOT_IN = "NOT IN" + CONTAINS_ALL = "CONTAINS ALL" + LIKE = "LIKE" + NOT_LIKE = "NOT LIKE" + + @classmethod + def parse_component(cls, component: str) -> list[str] | None: + """ + Try to parse a component using a relational keyword + + If no matching keyword is found, returns None + """ + + # extract the attribute name from the component + parsed_component = component.split(maxsplit=1) + if len(parsed_component) < 2: + return None + + # assume the component has already filtered out the value and try to match a keyword + # if we try to filter out the value without checking first, keywords with spaces won't parse correctly + possible_keyword = parsed_component[1].strip().lower() + for rel_kw in sorted([keyword.value for keyword in cls], key=len, reverse=True): + if rel_kw.lower() != possible_keyword: + continue + + parsed_component[1] = rel_kw + return parsed_component + + # there was no match, so the component may still have the value in it + try: + _possible_keyword, _value = parsed_component[-1].rsplit(maxsplit=1) + parsed_component = [parsed_component[0], _possible_keyword, _value] + except ValueError: + # the component has no value to filter out + return None + + possible_keyword = parsed_component[1].strip().lower() + for rel_kw in sorted([keyword.value for keyword in cls], key=len, reverse=True): + if rel_kw.lower() != possible_keyword: + continue + + parsed_component[1] = rel_kw + return parsed_component + + return None + + class RelationalOperator(Enum): EQ = "=" NOTEQ = "<>" @@ -27,6 +77,24 @@ class RelationalOperator(Enum): GTE = ">=" LTE = "<=" + @classmethod + def parse_component(cls, component: str) -> list[str] | None: + """ + Try to parse a component using a relational operator + + If no matching operator is found, returns None + """ + + for rel_op in sorted([operator.value for operator in cls], key=len, reverse=True): + if rel_op not in component: + continue + + parsed_component = [base_component.strip() for base_component in component.split(rel_op) if base_component] + parsed_component.insert(1, rel_op) + return parsed_component + + return None + class LogicalOperator(Enum): AND = "AND" @@ -36,31 +104,107 @@ class LogicalOperator(Enum): class QueryFilterComponent: """A single relational statement""" - def __init__(self, attribute_name: str, relational_operator: RelationalOperator, value: str) -> None: + @staticmethod + def strip_quotes_from_string(val: str) -> str: + if len(val) > 2 and val[0] == '"' and val[-1] == '"': + return val[1:-1] + else: + return val + + def __init__( + self, attribute_name: str, relationship: RelationalKeyword | RelationalOperator, value: str | list[str] + ) -> None: self.attribute_name = decamelize(attribute_name) - self.relational_operator = relational_operator - self.value = value + self.relationship = relationship # remove encasing quotes - if len(value) > 2 and value[0] == '"' and value[-1] == '"': - self.value = value[1:-1] + if isinstance(value, str): + value = self.strip_quotes_from_string(value) + + elif isinstance(value, list): + value = [self.strip_quotes_from_string(v) for v in value] + + # validate relationship/value pairs + if relationship in [ + RelationalKeyword.IN, + RelationalKeyword.NOT_IN, + RelationalKeyword.CONTAINS_ALL, + ] and not isinstance(value, list): + raise ValueError( + f"invalid query string: {relationship.value} must be given a list of values" + f"enclosed by {QueryFilter.l_list_sep} and {QueryFilter.r_list_sep}" + ) + + if relationship is RelationalKeyword.IS or relationship is RelationalKeyword.IS_NOT: + if not isinstance(value, str) or value.lower() not in ["null", "none"]: + raise ValueError( + f'invalid query string: "{relationship.value}" can only be used with "NULL", not "{value}"' + ) + + self.value = None + else: + self.value = value def __repr__(self) -> str: - return f"[{self.attribute_name} {self.relational_operator.value} {self.value}]" + return f"[{self.attribute_name} {self.relationship.value} {self.value}]" + + def validate(self, model_attr_type: Any) -> Any: + """Validate value against an model attribute's type and return a validated value, or raise a ValueError""" + + sanitized_values: list[Any] + if not isinstance(self.value, list): + sanitized_values = [self.value] + else: + sanitized_values = self.value + + for i, v in enumerate(sanitized_values): + # always allow querying for null values + if v is None: + continue + + if self.relationship is RelationalKeyword.LIKE or self.relationship is RelationalKeyword.NOT_LIKE: + if not isinstance(model_attr_type, sqltypes.String): + raise ValueError( + f'invalid query string: "{self.relationship.value}" can only be used with string columns' + ) + + if isinstance(model_attr_type, (GUID)): + try: + # we don't set value since a UUID is functionally identical to a string here + UUID(v) + except ValueError as e: + raise ValueError(f"invalid query string: invalid UUID '{v}'") from e + + if isinstance(model_attr_type, sqltypes.Date | sqltypes.DateTime): + try: + sanitized_values[i] = date_parser.parse(v) + except ParserError as e: + raise ValueError(f"invalid query string: unknown date or datetime format '{v}'") from e + + if isinstance(model_attr_type, sqltypes.Boolean): + try: + sanitized_values[i] = v.lower()[0] in ["t", "y"] or v == "1" + except IndexError as e: + raise ValueError("invalid query string") from e + + return sanitized_values if isinstance(self.value, list) else sanitized_values[0] class QueryFilter: - lsep: str = "(" - rsep: str = ")" + l_group_sep: str = "(" + r_group_sep: str = ")" + group_seps: set[str] = {l_group_sep, r_group_sep} - seps: set[str] = {lsep, rsep} + l_list_sep: str = "[" + r_list_sep: str = "]" + list_item_sep: str = "," def __init__(self, filter_string: str) -> None: # parse filter string components = QueryFilter._break_filter_string_into_components(filter_string) base_components = QueryFilter._break_components_into_base_components(components) - if base_components.count(QueryFilter.lsep) != base_components.count(QueryFilter.rsep): - raise ValueError("invalid filter string: parenthesis are unbalanced") + if base_components.count(QueryFilter.l_group_sep) != base_components.count(QueryFilter.r_group_sep): + raise ValueError("invalid query string: parenthesis are unbalanced") # parse base components into a filter group self.filter_components = QueryFilter._parse_base_components_into_filter_components(base_components) @@ -75,97 +219,125 @@ def __repr__(self) -> str: return f"<<{joined}>>" + @classmethod + def _consolidate_group(cls, group: list[ColumnElement], logical_operators: deque[LogicalOperator]) -> ColumnElement: + consolidated_group_builder: ColumnElement | None = None + for i, element in enumerate(reversed(group)): + if not i: + consolidated_group_builder = element + else: + operator = logical_operators.pop() + if operator is LogicalOperator.AND: + consolidated_group_builder = and_(consolidated_group_builder, element) + elif operator is LogicalOperator.OR: + consolidated_group_builder = or_(consolidated_group_builder, element) + else: + raise ValueError(f"invalid logical operator {operator}") + + if i == len(group) - 1: + return consolidated_group_builder.self_group() + def filter_query(self, query: Select, model: type[Model]) -> Select: - segments: list[str] = [] - params: list[BindParameter] = [] + # join tables and build model chain + attr_model_map: dict[int, Any] = {} + model_attr: InstrumentedAttribute for i, component in enumerate(self.filter_components): - if component in QueryFilter.seps: - segments.append(component) # type: ignore + if not isinstance(component, QueryFilterComponent): continue - if isinstance(component, LogicalOperator): - segments.append(component.value) - continue - - # for some reason typing doesn't like the lsep and rsep literals, so - # we explicitly mark this as a filter component instead cast doesn't - # actually do anything at runtime - component = cast(QueryFilterComponent, component) attribute_chain = component.attribute_name.split(".") if not attribute_chain: raise ValueError("invalid query string: attribute name cannot be empty") - attr_model: Any = model + current_model = model for j, attribute_link in enumerate(attribute_chain): - # last element - if j == len(attribute_chain) - 1: - if not hasattr(attr_model, attribute_link): - raise ValueError( - f"invalid query string: '{component.attribute_name}' does not exist on this schema" - ) - - attr_value = attribute_link - if j: - # use the nested table name, rather than the dot notation - component.attribute_name = f"{attr_model.__table__.name}.{attr_value}" - - continue - - # join on nested model try: - query = query.join(getattr(attr_model, attribute_link)) + model_attr = getattr(current_model, attribute_link) - mapper: Mapper = inspect(attr_model) + # at the end of the chain there are no more relationships to inspect + if j == len(attribute_chain) - 1: + break + + query = query.join(model_attr) + mapper: Mapper = inspect(current_model) relationship = mapper.relationships[attribute_link] - attr_model = relationship.mapper.class_ + current_model = relationship.mapper.class_ except (AttributeError, KeyError) as e: raise ValueError( f"invalid query string: '{component.attribute_name}' does not exist on this schema" ) from e + attr_model_map[i] = current_model - # convert values to their proper types - attr = getattr(attr_model, attr_value) - value: Any = component.value - - if isinstance(attr.type, (GUID)): - try: - # we don't set value since a UUID is functionally identical to a string here - UUID(value) - - except ValueError as e: - raise ValueError(f"invalid query string: invalid UUID '{component.value}'") from e - - if isinstance(attr.type, sqltypes.Date | sqltypes.DateTime): - # TODO: add support for IS NULL and IS NOT NULL - # in the meantime, this will work for the specific usecase of non-null dates/datetimes - if value in ["none", "null"] and component.relational_operator == RelationalOperator.NOTEQ: - component.relational_operator = RelationalOperator.GTE - value = datetime.datetime(datetime.MINYEAR, 1, 1) - + # build query filter + partial_group: list[ColumnElement] = [] + partial_group_stack: deque[list[ColumnElement]] = deque() + logical_operator_stack: deque[LogicalOperator] = deque() + for i, component in enumerate(self.filter_components): + if component == self.l_group_sep: + partial_group_stack.append(partial_group) + partial_group = [] + + elif component == self.r_group_sep: + if partial_group: + complete_group = self._consolidate_group(partial_group, logical_operator_stack) + partial_group = partial_group_stack.pop() + partial_group.append(complete_group) else: - try: - value = date_parser.parse(component.value) - - except ParserError as e: - raise ValueError( - f"invalid query string: unknown date or datetime format '{component.value}'" - ) from e - - if isinstance(attr.type, sqltypes.Boolean): - try: - value = component.value.lower()[0] in ["t", "y"] or component.value == "1" - - except IndexError as e: - raise ValueError("invalid query string") from e + partial_group = partial_group_stack.pop() + + elif isinstance(component, LogicalOperator): + logical_operator_stack.append(component) + + else: + component = cast(QueryFilterComponent, component) + model_attr = getattr(attr_model_map[i], component.attribute_name.split(".")[-1]) + + # Keywords + if component.relationship is RelationalKeyword.IS: + element = model_attr.is_(component.validate(model_attr.type)) + elif component.relationship is RelationalKeyword.IS_NOT: + element = model_attr.is_not(component.validate(model_attr.type)) + elif component.relationship is RelationalKeyword.IN: + element = model_attr.in_(component.validate(model_attr.type)) + elif component.relationship is RelationalKeyword.NOT_IN: + element = model_attr.not_in(component.validate(model_attr.type)) + elif component.relationship is RelationalKeyword.CONTAINS_ALL: + primary_model_attr: InstrumentedAttribute = getattr(model, component.attribute_name.split(".")[0]) + element = and_() + for v in component.validate(model_attr.type): + element = and_(element, primary_model_attr.any(model_attr == v)) + elif component.relationship is RelationalKeyword.LIKE: + element = model_attr.like(component.validate(model_attr.type)) + elif component.relationship is RelationalKeyword.NOT_LIKE: + element = model_attr.not_like(component.validate(model_attr.type)) + + # Operators + elif component.relationship is RelationalOperator.EQ: + element = model_attr == component.validate(model_attr.type) + elif component.relationship is RelationalOperator.NOTEQ: + element = model_attr != component.validate(model_attr.type) + elif component.relationship is RelationalOperator.GT: + element = model_attr > component.validate(model_attr.type) + elif component.relationship is RelationalOperator.LT: + element = model_attr < component.validate(model_attr.type) + elif component.relationship is RelationalOperator.GTE: + element = model_attr >= component.validate(model_attr.type) + elif component.relationship is RelationalOperator.LTE: + element = model_attr <= component.validate(model_attr.type) + else: + raise ValueError(f"invalid relationship {component.relationship}") - paramkey = f"P{i+1}" - segments.append(" ".join([component.attribute_name, component.relational_operator.value, f":{paramkey}"])) - params.append(bindparam(paramkey, value, attr.type)) + partial_group.append(element) - qs = text(" ".join(segments)).bindparams(*params) - query = query.filter(qs) - return query + # combine the completed groups into one filter + while True: + consolidated_group = self._consolidate_group(partial_group, logical_operator_stack) + if not partial_group_stack: + return query.filter(consolidated_group) + else: + partial_group = partial_group_stack.pop() + partial_group.append(consolidated_group) @staticmethod def _break_filter_string_into_components(filter_string: str) -> list[str]: @@ -176,7 +348,7 @@ def _break_filter_string_into_components(filter_string: str) -> list[str]: subcomponents = [] for component in components: # don't parse components comprised of only a separator - if component in QueryFilter.seps: + if component in QueryFilter.group_seps: subcomponents.append(component) continue @@ -187,7 +359,7 @@ def _break_filter_string_into_components(filter_string: str) -> list[str]: if c == '"': in_quotes = not in_quotes - if c in QueryFilter.seps and not in_quotes: + if c in QueryFilter.group_seps and not in_quotes: if new_component: subcomponents.append(new_component) @@ -208,25 +380,50 @@ def _break_filter_string_into_components(filter_string: str) -> list[str]: return components @staticmethod - def _break_components_into_base_components(components: list[str]) -> list[str]: + def _break_components_into_base_components(components: list[str]) -> list[str | list[str]]: """Further break down components by splitting at relational and logical operators""" - logical_operators = re.compile( - f'({"|".join(operator.value for operator in LogicalOperator)})', flags=re.IGNORECASE - ) + pattern = "|".join([f"\\b{operator.value}\\b" for operator in LogicalOperator]) + logical_operators = re.compile(f"({pattern})", flags=re.IGNORECASE) - base_components = [] + in_list = False + base_components: list[str | list] = [] + list_value_components = [] for component in components: - offset = 0 + # parse out lists as their own singular sub component + subcomponents = component.split(QueryFilter.l_list_sep) + for i, subcomponent in enumerate(subcomponents): + if not i: + continue + + for j, list_value_string in enumerate(subcomponent.split(QueryFilter.r_list_sep)): + if j % 2: + continue + + list_value_components.append( + [val.strip() for val in list_value_string.split(QueryFilter.list_item_sep)] + ) + + quote_offset = 0 subcomponents = component.split('"') for i, subcomponent in enumerate(subcomponents): + # we are in a list subcomponent, which is already handled + if in_list: + if QueryFilter.r_list_sep in subcomponent: + # filter out the remainder of the list subcomponent and continue parsing + base_components.append(list_value_components.pop(0)) + subcomponent = subcomponent.split(QueryFilter.r_list_sep, maxsplit=1)[-1].strip() + in_list = False + else: + continue + # don't parse components comprised of only a separator - if subcomponent in QueryFilter.seps: - offset += 1 + if subcomponent in QueryFilter.group_seps: + quote_offset += 1 base_components.append(subcomponent) continue - # this subscomponent was surrounded in quotes, so we keep it as-is - if (i + offset) % 2: + # this subcomponent was surrounded in quotes, so we keep it as-is + if (i + quote_offset) % 2: base_components.append(f'"{subcomponent.strip()}"') continue @@ -234,53 +431,70 @@ def _break_components_into_base_components(components: list[str]) -> list[str]: if not subcomponent: continue + # continue parsing this subcomponent up to the list, then skip over subsequent subcomponents + if not in_list and QueryFilter.l_list_sep in subcomponent: + subcomponent, _new_sub_component = subcomponent.split(QueryFilter.l_list_sep, maxsplit=1) + subcomponent = subcomponent.strip() + subcomponents.insert(i + 1, _new_sub_component) + quote_offset += 1 + in_list = True + # parse out logical operators new_components = [ base_component.strip() for base_component in logical_operators.split(subcomponent) if base_component ] - # parse out relational operators; each base_subcomponent has exactly zero or one relational operator - # we do them one at a time in descending length since some operators overlap (e.g. :> and >) + # parse out relational keywords and operators + # each base_subcomponent has exactly zero or one keyword or operator for component in new_components: if not component: continue - added_to_base_components = False - for rel_op in sorted([operator.value for operator in RelationalOperator], key=len, reverse=True): - if rel_op in component: - new_base_components = [ - base_component.strip() for base_component in component.split(rel_op) if base_component - ] - new_base_components.insert(1, rel_op) - base_components.extend(new_base_components) + # we try relational operators first since they aren't required to be surrounded by spaces + parsed_component = RelationalOperator.parse_component(component) + if parsed_component is not None: + base_components.extend(parsed_component) + continue - added_to_base_components = True - break + parsed_component = RelationalKeyword.parse_component(component) + if parsed_component is not None: + base_components.extend(parsed_component) + continue - if not added_to_base_components: - base_components.append(component) + # this component does not have any keywords or operators, so we just add it as-is + base_components.append(component) return base_components @staticmethod def _parse_base_components_into_filter_components( - base_components: list[str], + base_components: list[str | list[str]], ) -> list[str | QueryFilterComponent | LogicalOperator]: """Walk through base components and construct filter collections""" + relational_keywords = [kw.value for kw in RelationalKeyword] relational_operators = [op.value for op in RelationalOperator] logical_operators = [op.value for op in LogicalOperator] # parse QueryFilterComponents and logical operators components: list[str | QueryFilterComponent | LogicalOperator] = [] for i, base_component in enumerate(base_components): - if base_component in QueryFilter.seps: + if isinstance(base_component, list): + continue + + if base_component in QueryFilter.group_seps: components.append(base_component) - elif base_component in relational_operators: + elif base_component in relational_keywords or base_component in relational_operators: + relationship: RelationalKeyword | RelationalOperator + if base_component in relational_keywords: + relationship = RelationalKeyword(base_components[i]) + else: + relationship = RelationalOperator(base_components[i]) + components.append( QueryFilterComponent( - attribute_name=base_components[i - 1], - relational_operator=RelationalOperator(base_components[i]), + attribute_name=base_components[i - 1], # type: ignore + relationship=relationship, value=base_components[i + 1], ) ) diff --git a/tests/unit_tests/repository_tests/test_pagination.py b/tests/unit_tests/repository_tests/test_pagination.py index 322b1bf883..eaee104bcb 100644 --- a/tests/unit_tests/repository_tests/test_pagination.py +++ b/tests/unit_tests/repository_tests/test_pagination.py @@ -1,5 +1,6 @@ import time from collections import defaultdict +from datetime import datetime from random import randint from urllib.parse import parse_qsl, urlsplit @@ -9,7 +10,10 @@ from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_units import RepositoryUnit +from mealie.schema.recipe import Recipe +from mealie.schema.recipe.recipe_category import CategorySave, TagSave from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit +from mealie.schema.recipe.recipe_tool import RecipeToolSave from mealie.schema.response.pagination import PaginationQuery from mealie.services.seeder.seeder_service import SeederService from tests.utils import api_routes @@ -172,6 +176,256 @@ def test_pagination_filter_basic(query_units: tuple[RepositoryUnit, IngredientUn assert unit_results[0].id == unit_2.id +def test_pagination_filter_null(database: AllRepositories, unique_user: TestUser): + recipe_not_made_1 = database.recipes.create( + Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string()) + ) + recipe_not_made_2 = database.recipes.create( + Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string()) + ) + + # give one recipe a last made date + recipe_made = database.recipes.create( + Recipe( + user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), last_made=datetime.now() + ) + ) + + recipe_repo = database.recipes.by_group(unique_user.group_id) # type: ignore + + query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NONE") + recipe_results = recipe_repo.page_all(query).items + assert len(recipe_results) == 2 + result_ids = {result.id for result in recipe_results} + assert recipe_not_made_1.id in result_ids + assert recipe_not_made_2.id in result_ids + assert recipe_made.id not in result_ids + + query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NULL") + recipe_results = recipe_repo.page_all(query).items + assert len(recipe_results) == 2 + result_ids = {result.id for result in recipe_results} + assert recipe_not_made_1.id in result_ids + assert recipe_not_made_2.id in result_ids + assert recipe_made.id not in result_ids + + query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NOT NONE") + recipe_results = recipe_repo.page_all(query).items + assert len(recipe_results) == 1 + result_ids = {result.id for result in recipe_results} + assert recipe_not_made_1.id not in result_ids + assert recipe_not_made_2.id not in result_ids + assert recipe_made.id in result_ids + + query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NOT NULL") + recipe_results = recipe_repo.page_all(query).items + assert len(recipe_results) == 1 + result_ids = {result.id for result in recipe_results} + assert recipe_not_made_1.id not in result_ids + assert recipe_not_made_2.id not in result_ids + assert recipe_made.id in result_ids + + +def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]): + units_repo, unit_1, unit_2, unit_3 = query_units + + query = PaginationQuery(page=1, per_page=-1, query_filter=f"name IN [{unit_1.name}, {unit_2.name}]") + unit_results = units_repo.page_all(query).items + + assert len(unit_results) == 2 + result_ids = {unit.id for unit in unit_results} + assert unit_1.id in result_ids + assert unit_2.id in result_ids + assert unit_3.id not in result_ids + + query = PaginationQuery(page=1, per_page=-1, query_filter=f"name NOT IN [{unit_1.name}, {unit_2.name}]") + unit_results = units_repo.page_all(query).items + + assert len(unit_results) == 1 + result_ids = {unit.id for unit in unit_results} + assert unit_1.id not in result_ids + assert unit_2.id not in result_ids + assert unit_3.id in result_ids + + query = PaginationQuery(page=1, per_page=-1, query_filter=f'name IN ["{unit_3.name}"]') + unit_results = units_repo.page_all(query).items + + assert len(unit_results) == 1 + result_ids = {unit.id for unit in unit_results} + assert unit_1.id not in result_ids + assert unit_2.id not in result_ids + assert unit_3.id in result_ids + + +def test_pagination_filter_in_advanced(database: AllRepositories, unique_user: TestUser): + slug1, slug2 = (random_string(10) for _ in range(2)) + + tags = [ + TagSave(group_id=unique_user.group_id, name=slug1, slug=slug1), + TagSave(group_id=unique_user.group_id, name=slug2, slug=slug2), + ] + + tag_1, tag_2 = [database.tags.create(tag) for tag in tags] + + # Bootstrap the database with recipes + slug = random_string() + recipe_0 = database.recipes.create( + Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[]) + ) + + slug = random_string() + recipe_1 = database.recipes.create( + Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[tag_1]) + ) + + slug = random_string() + recipe_2 = database.recipes.create( + Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[tag_2]) + ) + + slug = random_string() + recipe_1_2 = database.recipes.create( + Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[tag_1, tag_2]) + ) + + query = PaginationQuery(page=1, per_page=-1, query_filter=f"tags.name IN [{tag_1.name}]") + recipe_results = database.recipes.page_all(query).items + assert len(recipe_results) == 2 + recipe_ids = {recipe.id for recipe in recipe_results} + assert recipe_0.id not in recipe_ids + assert recipe_1.id in recipe_ids + assert recipe_2.id not in recipe_ids + assert recipe_1_2.id in recipe_ids + + query = PaginationQuery(page=1, per_page=-1, query_filter=f"tags.name IN [{tag_1.name}, {tag_2.name}]") + recipe_results = database.recipes.page_all(query).items + assert len(recipe_results) == 3 + recipe_ids = {recipe.id for recipe in recipe_results} + assert recipe_0.id not in recipe_ids + assert recipe_1.id in recipe_ids + assert recipe_2.id in recipe_ids + assert recipe_1_2.id in recipe_ids + + query = PaginationQuery(page=1, per_page=-1, query_filter=f"tags.name CONTAINS ALL [{tag_1.name}, {tag_2.name}]") + recipe_results = database.recipes.page_all(query).items + assert len(recipe_results) == 1 + recipe_ids = {recipe.id for recipe in recipe_results} + assert recipe_0.id not in recipe_ids + assert recipe_1.id not in recipe_ids + assert recipe_2.id not in recipe_ids + assert recipe_1_2.id in recipe_ids + + +def test_pagination_filter_like(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]): + units_repo, unit_1, unit_2, unit_3 = query_units + + query = PaginationQuery(page=1, per_page=-1, query_filter=r'name LIKE "test u_it%"') + unit_results = units_repo.page_all(query).items + + assert len(unit_results) == 3 + result_ids = {unit.id for unit in unit_results} + assert unit_1.id in result_ids + assert unit_2.id in result_ids + assert unit_3.id in result_ids + + query = PaginationQuery(page=1, per_page=-1, query_filter=r'name LIKE "%unit 1"') + unit_results = units_repo.page_all(query).items + + assert len(unit_results) == 1 + result_ids = {unit.id for unit in unit_results} + assert unit_1.id in result_ids + assert unit_2.id not in result_ids + assert unit_3.id not in result_ids + + query = PaginationQuery(page=1, per_page=-1, query_filter=r'name NOT LIKE %t_1"') + unit_results = units_repo.page_all(query).items + + assert len(unit_results) == 2 + result_ids = {unit.id for unit in unit_results} + assert unit_1.id not in result_ids + assert unit_2.id in result_ids + assert unit_3.id in result_ids + + +def test_pagination_filter_keyword_namespace_conflict(database: AllRepositories, unique_user: TestUser): + recipe_rating_1 = database.recipes.create( + Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), rating=1) + ) + recipe_rating_2 = database.recipes.create( + Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), rating=2) + ) + + recipe_rating_3 = database.recipes.create( + Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), rating=3) + ) + + recipe_repo = database.recipes.by_group(unique_user.group_id) # type: ignore + + # "rating" contains the word "in", but we should not parse this as the keyword "IN" + query = PaginationQuery(page=1, per_page=-1, query_filter="rating > 2") + recipe_results = recipe_repo.page_all(query).items + + assert len(recipe_results) == 1 + result_ids = {recipe.id for recipe in recipe_results} + assert recipe_rating_1.id not in result_ids + assert recipe_rating_2.id not in result_ids + assert recipe_rating_3.id in result_ids + + query = PaginationQuery(page=1, per_page=-1, query_filter="rating in [1, 3]") + recipe_results = recipe_repo.page_all(query).items + + assert len(recipe_results) == 2 + result_ids = {recipe.id for recipe in recipe_results} + assert recipe_rating_1.id in result_ids + assert recipe_rating_2.id not in result_ids + assert recipe_rating_3.id in result_ids + + +def test_pagination_filter_logical_namespace_conflict(database: AllRepositories, unique_user: TestUser): + categories = [ + CategorySave(group_id=unique_user.group_id, name=random_string(10)), + CategorySave(group_id=unique_user.group_id, name=random_string(10)), + ] + category_1, category_2 = [database.categories.create(category) for category in categories] + + # Bootstrap the database with recipes + slug = random_string() + recipe_category_0 = database.recipes.create( + Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug) + ) + + slug = random_string() + recipe_category_1 = database.recipes.create( + Recipe( + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=slug, + slug=slug, + recipe_category=[category_1], + ) + ) + + slug = random_string() + recipe_category_2 = database.recipes.create( + Recipe( + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=slug, + slug=slug, + recipe_category=[category_2], + ) + ) + + # "recipeCategory" has the substring "or" in it, which shouldn't break queries + query = PaginationQuery(page=1, per_page=-1, query_filter=f'recipeCategory.id = "{category_1.id}"') + recipe_results = database.recipes.by_group(unique_user.group_id).page_all(query).items # type: ignore + assert len(recipe_results) == 1 + recipe_ids = {recipe.id for recipe in recipe_results} + assert recipe_category_0.id not in recipe_ids + assert recipe_category_1.id in recipe_ids + assert recipe_category_2.id not in recipe_ids + + def test_pagination_filter_datetimes( query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit] ): @@ -197,15 +451,183 @@ def test_pagination_filter_booleans(query_units: tuple[RepositoryUnit, Ingredien def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]): - units_repo = query_units[0] - unit_3 = query_units[3] + units_repo, unit_1, unit_2, unit_3 = query_units dt = str(unit_3.created_at.isoformat()) # type: ignore - qf = f'name="test unit 1" OR (useAbbreviation=f AND (name="test unit 2" OR createdAt > "{dt}"))' + qf = f'name="test unit 1" OR (useAbbreviation=f AND (name="{unit_2.name}" OR createdAt > "{dt}"))' + query = PaginationQuery(page=1, per_page=-1, query_filter=qf) + unit_results = units_repo.page_all(query).items + + assert len(unit_results) == 2 + result_ids = {unit.id for unit in unit_results} + assert unit_1.id in result_ids + assert unit_2.id in result_ids + assert unit_3.id not in result_ids + + qf = f'(name LIKE %_1 OR name IN ["{unit_2.name}"]) AND createdAt IS NOT NONE' query = PaginationQuery(page=1, per_page=-1, query_filter=qf) unit_results = units_repo.page_all(query).items + assert len(unit_results) == 2 - assert unit_3.id not in [unit.id for unit in unit_results] + result_ids = {unit.id for unit in unit_results} + assert unit_1.id in result_ids + assert unit_2.id in result_ids + assert unit_3.id not in result_ids + + +def test_pagination_filter_advanced_frontend_sort(database: AllRepositories, unique_user: TestUser): + categories = [ + CategorySave(group_id=unique_user.group_id, name=random_string(10)), + CategorySave(group_id=unique_user.group_id, name=random_string(10)), + ] + category_1, category_2 = [database.categories.create(category) for category in categories] + + slug1, slug2 = (random_string(10) for _ in range(2)) + tags = [ + TagSave(group_id=unique_user.group_id, name=slug1, slug=slug1), + TagSave(group_id=unique_user.group_id, name=slug2, slug=slug2), + ] + tag_1, tag_2 = [database.tags.create(tag) for tag in tags] + + tools = [ + RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)), + RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)), + ] + tool_1, tool_2 = [database.tools.create(tool) for tool in tools] + + # Bootstrap the database with recipes + slug = random_string() + recipe_ct0_tg0_tl0 = database.recipes.create( + Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug) + ) + + slug = random_string() + recipe_ct1_tg0_tl0 = database.recipes.create( + Recipe( + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=slug, + slug=slug, + recipe_category=[category_1], + ) + ) + + slug = random_string() + recipe_ct12_tg0_tl0 = database.recipes.create( + Recipe( + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=slug, + slug=slug, + recipe_category=[category_1, category_2], + ) + ) + + slug = random_string() + recipe_ct1_tg1_tl0 = database.recipes.create( + Recipe( + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=slug, + slug=slug, + recipe_category=[category_1], + tags=[tag_1], + ) + ) + + slug = random_string() + recipe_ct1_tg0_tl1 = database.recipes.create( + Recipe( + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=slug, + slug=slug, + recipe_category=[category_1], + tools=[tool_1], + ) + ) + + slug = random_string() + recipe_ct0_tg2_tl2 = database.recipes.create( + Recipe( + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=slug, + slug=slug, + tags=[tag_2], + tools=[tool_2], + ) + ) + + slug = random_string() + recipe_ct12_tg12_tl2 = database.recipes.create( + Recipe( + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=slug, + slug=slug, + recipe_category=[category_1, category_2], + tags=[tag_1, tag_2], + tools=[tool_2], + ) + ) + + repo = database.recipes.by_group(unique_user.group_id) # type: ignore + + qf = f'recipeCategory.id IN ["{category_1.id}"] AND tools.id IN ["{tool_1.id}"]' + query = PaginationQuery(page=1, per_page=-1, query_filter=qf) + recipe_results = repo.page_all(query).items + assert len(recipe_results) == 1 + recipe_ids = {recipe.id for recipe in recipe_results} + assert recipe_ct0_tg0_tl0.id not in recipe_ids + assert recipe_ct1_tg0_tl0.id not in recipe_ids + assert recipe_ct12_tg0_tl0.id not in recipe_ids + assert recipe_ct1_tg1_tl0.id not in recipe_ids + assert recipe_ct1_tg0_tl1.id in recipe_ids + assert recipe_ct0_tg2_tl2.id not in recipe_ids + assert recipe_ct12_tg12_tl2.id not in recipe_ids + + qf = f'recipeCategory.id CONTAINS ALL ["{category_1.id}", "{category_2.id}"] AND tags.id IN ["{tag_1.id}"]' + query = PaginationQuery(page=1, per_page=-1, query_filter=qf) + recipe_results = repo.page_all(query).items + assert len(recipe_results) == 1 + recipe_ids = {recipe.id for recipe in recipe_results} + assert recipe_ct0_tg0_tl0.id not in recipe_ids + assert recipe_ct1_tg0_tl0.id not in recipe_ids + assert recipe_ct12_tg0_tl0.id not in recipe_ids + assert recipe_ct1_tg1_tl0.id not in recipe_ids + assert recipe_ct1_tg0_tl1.id not in recipe_ids + assert recipe_ct0_tg2_tl2.id not in recipe_ids + assert recipe_ct12_tg12_tl2.id in recipe_ids + + qf = f'tags.id IN ["{tag_1.id}", "{tag_2.id}"] AND tools.id IN ["{tool_2.id}"]' + query = PaginationQuery(page=1, per_page=-1, query_filter=qf) + recipe_results = repo.page_all(query).items + assert len(recipe_results) == 2 + recipe_ids = {recipe.id for recipe in recipe_results} + assert recipe_ct0_tg0_tl0.id not in recipe_ids + assert recipe_ct1_tg0_tl0.id not in recipe_ids + assert recipe_ct12_tg0_tl0.id not in recipe_ids + assert recipe_ct1_tg1_tl0.id not in recipe_ids + assert recipe_ct1_tg0_tl1.id not in recipe_ids + assert recipe_ct0_tg2_tl2.id in recipe_ids + assert recipe_ct12_tg12_tl2.id in recipe_ids + + qf = ( + f'recipeCategory.id CONTAINS ALL ["{category_1.id}", "{category_2.id}"]' + f'AND tags.id IN ["{tag_1.id}", "{tag_2.id}"] AND tools.id IN ["{tool_1.id}", "{tool_2.id}"]' + ) + query = PaginationQuery(page=1, per_page=-1, query_filter=qf) + recipe_results = repo.page_all(query).items + assert len(recipe_results) == 1 + recipe_ids = {recipe.id for recipe in recipe_results} + assert recipe_ct0_tg0_tl0.id not in recipe_ids + assert recipe_ct1_tg0_tl0.id not in recipe_ids + assert recipe_ct12_tg0_tl0.id not in recipe_ids + assert recipe_ct1_tg1_tl0.id not in recipe_ids + assert recipe_ct1_tg0_tl1.id not in recipe_ids + assert recipe_ct0_tg2_tl2.id not in recipe_ids + assert recipe_ct12_tg12_tl2.id in recipe_ids @pytest.mark.parametrize( @@ -214,6 +636,13 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien pytest.param('(name="test name" AND useAbbreviation=f))', id="unbalanced parenthesis"), pytest.param('id="this is not a valid UUID"', id="invalid UUID"), pytest.param('createdAt="this is not a valid datetime format"', id="invalid datetime format"), + pytest.param('name IS "test name"', id="IS can only be used with NULL or NONE"), + pytest.param('name IS NOT "test name"', id="IS NOT can only be used with NULL or NONE"), + pytest.param('name IN "test name"', id="IN must use a list of values"), + pytest.param('name NOT IN "test name"', id="NOT IN must use a list of values"), + pytest.param('name CONTAINS ALL "test name"', id="CONTAINS ALL must use a list of values"), + pytest.param('createdAt LIKE "2023-02-25"', id="LIKE is only valid for string columns"), + pytest.param('createdAt NOT LIKE "2023-02-25"', id="NOT LIKE is only valid for string columns"), pytest.param('badAttribute="test value"', id="invalid attribute"), pytest.param('group.badAttribute="test value"', id="bad nested attribute"), pytest.param('group.preferences.badAttribute="test value"', id="bad double nested attribute"),