diff --git a/tagstudio/src/core/library/alchemy/visitors.py b/tagstudio/src/core/library/alchemy/visitors.py index 5eed4580f..0d23d3302 100644 --- a/tagstudio/src/core/library/alchemy/visitors.py +++ b/tagstudio/src/core/library/alchemy/visitors.py @@ -1,9 +1,13 @@ from typing import TYPE_CHECKING import structlog -from sqlalchemy import and_, distinct, func, or_, select, text +from sqlalchemy import and_, column, distinct, func, or_, select, text, union_all from sqlalchemy.orm import Session -from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument +from sqlalchemy.sql.expression import ( + BinaryExpression, + ColumnExpressionArgument, + CompoundSelect, +) from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories from src.core.query_lang import BaseVisitor from src.core.query_lang.ast import AST, ANDList, Constraint, ConstraintType, Not, ORList, Property @@ -28,7 +32,7 @@ FROM tag_subtags ts INNER JOIN Subtags s ON ts.child_id = s.child_id ) -SELECT * FROM Subtags; +SELECT child_id FROM Subtags """) # noqa: E501 @@ -59,7 +63,10 @@ def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument: tag_ids.append(int(term.value)) continue case ConstraintType.Tag: - if len(ids := self.__get_tag_ids(term.value)) == 1: + if ( + isinstance((ids := self.__get_tag_ids(term.value)), list) + and len(ids) == 1 + ): tag_ids.append(ids[0]) continue @@ -113,7 +120,9 @@ def visit_property(self, node: Property) -> None: def visit_not(self, node: Not) -> ColumnExpressionArgument: return ~self.__entry_satisfies_ast(node.child) - def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]: + def __get_tag_ids( + self, tag_name: str, include_children: bool = True + ) -> list[int] | CompoundSelect: """Given a tag name find the ids of all tags that this name could refer to.""" with Session(self.lib.engine) as session: tag_ids = list( @@ -131,10 +140,13 @@ def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[in ) if not include_children: return tag_ids - outp = [] - for tag_id in tag_ids: - outp.extend(list(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id}))) - return outp + queries = [ + CHILDREN_QUERY.bindparams(tag_id=id).columns(column("child_id")) for id in tag_ids + ] + outp = union_all(*queries) + # if only one tag is found return that a list with that tag instead, + # in order to make use of the optimisations in __entry_has_all_tags + return t if len(t := list(session.scalars(outp))) == 1 else outp def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]: """Returns Binary Expression that is true if the Entry has all provided tag ids."""