Skip to content

Commit

Permalink
feat: instead of hardcoding child tag ids into main query, include su…
Browse files Browse the repository at this point in the history
…bquery
  • Loading branch information
Computerdores committed Dec 30, 2024
1 parent b791159 commit 2615e7d
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions tagstudio/src/core/library/alchemy/visitors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import TYPE_CHECKING

import structlog
from sqlalchemy import and_, distinct, func, or_, select, text
from sqlalchemy import and_, column, distinct, func, or_, select, text, union_all
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
from sqlalchemy.sql.expression import (
BinaryExpression,
ColumnExpressionArgument,
CompoundSelect,
)
from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
from src.core.query_lang import BaseVisitor
from src.core.query_lang.ast import AST, ANDList, Constraint, ConstraintType, Not, ORList, Property
Expand All @@ -28,7 +32,7 @@
FROM tag_subtags ts
INNER JOIN Subtags s ON ts.child_id = s.child_id
)
SELECT * FROM Subtags;
SELECT child_id FROM Subtags
""") # noqa: E501


Expand Down Expand Up @@ -59,7 +63,10 @@ def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument:
tag_ids.append(int(term.value))
continue
case ConstraintType.Tag:
if len(ids := self.__get_tag_ids(term.value)) == 1:
if (
isinstance((ids := self.__get_tag_ids(term.value)), list)
and len(ids) == 1
):
tag_ids.append(ids[0])
continue

Expand Down Expand Up @@ -113,7 +120,9 @@ def visit_property(self, node: Property) -> None:
def visit_not(self, node: Not) -> ColumnExpressionArgument:
return ~self.__entry_satisfies_ast(node.child)

def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
def __get_tag_ids(
self, tag_name: str, include_children: bool = True
) -> list[int] | CompoundSelect:
"""Given a tag name find the ids of all tags that this name could refer to."""
with Session(self.lib.engine) as session:
tag_ids = list(
Expand All @@ -131,10 +140,13 @@ def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[in
)
if not include_children:
return tag_ids
outp = []
for tag_id in tag_ids:
outp.extend(list(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id})))
return outp
queries = [
CHILDREN_QUERY.bindparams(tag_id=id).columns(column("child_id")) for id in tag_ids
]
outp = union_all(*queries)
# if only one tag is found return that a list with that tag instead,
# in order to make use of the optimisations in __entry_has_all_tags
return t if len(t := list(session.scalars(outp))) == 1 else outp

def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""
Expand Down

0 comments on commit 2615e7d

Please sign in to comment.