Skip to content

Commit

Permalink
Prevent dups in conjs and disjs by using proper data-structures
Browse files Browse the repository at this point in the history
Related to #480
  • Loading branch information
Suor committed Oct 9, 2024
1 parent 4b31c30 commit 8f0b432
Showing 1 changed file with 30 additions and 23 deletions.
53 changes: 30 additions & 23 deletions cacheops/tree.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import product
from funcy import group_by, join_with, lcat, lmap
from funcy import group_by, join_with, lcat, lmap, cat

from django.db.models import Subquery
from django.db.models.query import QuerySet
Expand All @@ -23,8 +23,8 @@ def dnfs(qs):
conditions on joined models and subrequests are ignored.
__in is converted into = or = or = ...
"""
SOME = object()
SOME_TREE = [[(None, None, SOME, True)]]
SOME = Some()
SOME_TREE = {frozenset({(None, None, SOME, True)})}

def negate(term):
return (term[0], term[1], term[2], not term[3])
Expand All @@ -51,37 +51,37 @@ def _dnf(where):

attname = where.lhs.target.attname
if isinstance(where, Exact):
return [[(where.lhs.alias, attname, where.rhs, True)]]
return {frozenset({(where.lhs.alias, attname, where.rhs, True)})}
elif isinstance(where, IsNull):
return [[(where.lhs.alias, attname, None, where.rhs)]]
return {frozenset({(where.lhs.alias, attname, None, where.rhs)})}
elif isinstance(where, In) and len(where.rhs) < settings.CACHEOPS_LONG_DISJUNCTION:
return [[(where.lhs.alias, attname, v, True)] for v in where.rhs]
return {frozenset({(where.lhs.alias, attname, v, True)}) for v in where.rhs}
else:
return SOME_TREE
elif isinstance(where, NothingNode):
return []
return set()
elif isinstance(where, (ExtraWhere, SubqueryConstraint, Exists)):
return SOME_TREE
elif len(where) == 0:
return [[]]
return {frozenset()}
else:
chilren_dnfs = lmap(_dnf, where.children)
children_dnfs = lmap(_dnf, where.children)

if len(chilren_dnfs) == 0:
return [[]]
elif len(chilren_dnfs) == 1:
result = chilren_dnfs[0]
if len(children_dnfs) == 0:
return {frozenset()}
elif len(children_dnfs) == 1:
result = children_dnfs[0]
else:
# Just unite children joined with OR
if where.connector == OR:
result = lcat(chilren_dnfs)
result = set(cat(children_dnfs))
# Use Cartesian product to AND children
else:
result = lmap(lcat, product(*chilren_dnfs))
result = {frozenset(cat(conjs)) for conjs in product(*children_dnfs)}

# Negating and expanding brackets
if where.negated:
result = [lmap(negate, p) for p in product(*result)]
result = {frozenset(map(negate, conjs)) for conjs in product(*result)}

return result

Expand Down Expand Up @@ -119,22 +119,23 @@ def add_join_conds(dnf, query):
join_exts[join.table_alias, target_col].append((join.parent_alias, parent_col))

if not join_exts:
return
return dnf

for conj in dnf:
# NOTE: using list comprehension over genexp here since we change the thing we iterate
conj.extend([
return {
conj | {
(join_alias, join_col, v, negation)
for alias, col, v, negation in conj
for (join_alias, join_col) in join_exts[alias, col]
])
for join_alias, join_col in join_exts[alias, col]
}
for conj in dnf
}

def query_dnf(query):
def table_for(alias):
return alias if alias == main_alias else query.alias_map[alias].table_name

dnf = _dnf(query.where)
add_join_conds(dnf, query)
dnf = add_join_conds(dnf, query)

# NOTE: we exclude content_type as it never changes and will hold dead invalidation info
main_alias = query.model._meta.db_table
Expand All @@ -156,3 +157,9 @@ def table_for(alias):
dnfs_.update(join_with(lcat, subqueries))

return dnfs_


class Some:
def __str__(self):
return 'SOME'
__repr__ = __str__

0 comments on commit 8f0b432

Please sign in to comment.