
Atlas search lookups #325


Draft: wants to merge 14 commits into base: main
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -81,4 +81,4 @@ repos:
rev: "v2.2.6"
hooks:
- id: codespell
args: ["-L", "nin"]
args: ["-L", "nin", "-L", "searchin"]
2 changes: 1 addition & 1 deletion django_mongodb_backend/__init__.py
@@ -8,7 +8,7 @@

from .aggregates import register_aggregates # noqa: E402
from .checks import register_checks # noqa: E402
- from .expressions import register_expressions # noqa: E402
+ from .expressions.builtins import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
from .indexes import register_indexes # noqa: E402
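Judging from the import change above and the new imports in compiler.py below, the old django_mongodb_backend/expressions.py module has been converted into a package. A sketch of the assumed layout (inferred from the imports in this PR, not shown directly in the diff):

    django_mongodb_backend/
        expressions/
            __init__.py   # the "Empty file." entry further down in this diff
            builtins.py   # previously expressions.py; register_expressions() lives here
            search.py     # new module providing SearchExpression, SearchVector, etc.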
136 changes: 113 additions & 23 deletions django_mongodb_backend/compiler.py
@@ -17,6 +17,7 @@
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

from .expressions.search import SearchExpression, SearchVector
from .query import MongoQuery, wrap_database_errors


@@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs):
# A list of OrderBy objects for this query.
self.order_by_objs = None
self.subqueries = []
# Atlas search calls
self.search_pipeline = []

def _get_group_alias_column(self, expr, annotation_group_idx):
"""Generate a dummy field for use in the ids fields in $group."""
@@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias):
column_target.set_attributes_from_name(alias)
return Col(self.collection_name, column_target)

def _get_replace_expr(self, sub_expr, group, alias):
column_target = sub_expr.output_field.clone()
column_target.db_column = alias
column_target.set_attributes_from_name(alias)
inner_column = Col(self.collection_name, column_target)
if getattr(sub_expr, "distinct", False):
# If the expression should return distinct values, use
# $addToSet to deduplicate.
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
group[alias] = {"$addToSet": rhs}
replacing_expr = sub_expr.copy()
replacing_expr.set_source_expressions([inner_column, None])
else:
group[alias] = sub_expr.as_mql(self, self.connection)
replacing_expr = inner_column
# Count must return 0 rather than null.
if isinstance(sub_expr, Count):
replacing_expr = Coalesce(replacing_expr, 0)
# Variance = StdDev^2
if isinstance(sub_expr, Variance):
replacing_expr = Power(replacing_expr, 2)
return replacing_expr

def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
"""
Prepare expressions for the aggregation pipeline.
@@ -80,29 +106,33 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
alias = (
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
)
column_target = sub_expr.output_field.clone()
column_target.db_column = alias
column_target.set_attributes_from_name(alias)
inner_column = Col(self.collection_name, column_target)
if sub_expr.distinct:
# If the expression should return distinct values, use
# $addToSet to deduplicate.
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
group[alias] = {"$addToSet": rhs}
replacing_expr = sub_expr.copy()
replacing_expr.set_source_expressions([inner_column, None])
else:
group[alias] = sub_expr.as_mql(self, self.connection)
replacing_expr = inner_column
# Count must return 0 rather than null.
if isinstance(sub_expr, Count):
replacing_expr = Coalesce(replacing_expr, 0)
# Variance = StdDev^2
if isinstance(sub_expr, Variance):
replacing_expr = Power(replacing_expr, 2)
replacements[sub_expr] = replacing_expr
replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
return replacements, group

def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
searches = {}
for sub_expr in self._get_search_expressions(expression):
if sub_expr not in replacements:
alias = f"__search_expr.search{next(search_idx)}"
replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)

def _prepare_search_query_for_aggregation_pipeline(self, order_by):
replacements = {}
annotation_group_idx = itertools.count(start=1)
for expr in self.query.annotation_select.values():
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)

for expr, _ in order_by:
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)

self._prepare_search_expressions_for_pipeline(
self.having, annotation_group_idx, replacements
)
self._prepare_search_expressions_for_pipeline(
self.get_where(), annotation_group_idx, replacements
)
return replacements

def _prepare_annotations_for_aggregation_pipeline(self, order_by):
"""Prepare annotations for the aggregation pipeline."""
replacements = {}
@@ -207,9 +237,57 @@ def _build_aggregation_pipeline(self, ids, group):
pipeline.append({"$unset": "_id"})
return pipeline

def _compound_searches_queries(self, search_replacements):
Collaborator Author (review comment): I want to preserve this function for the future; we will probably want hybrid search, and this part of the code could be useful for it. I know it is a bit odd to check that there is only one replacement and then still iterate over it, and the exception could be raised earlier than this point. Let me know if you want me to refactor this code.

if not search_replacements:
return []
if len(search_replacements) > 1:
has_search = any(not isinstance(search, SearchVector) for search in search_replacements)
has_vector_search = any(
isinstance(search, SearchVector) for search in search_replacements
)
if has_search and has_vector_search:
raise ValueError(
"Cannot combine a `$vectorSearch` with a `$search` operator. "
"If you need to combine them, consider restructuring your query logic or "
"running them as separate queries."
)
if not has_search:
raise ValueError(
"Cannot combine two `$vectorSearch` operator. "
"If you need to combine them, consider restructuring your query logic or "
"running them as separate queries."
)
raise ValueError(
"Only one $search operation is allowed per query. "
f"Received {len(search_replacements)} search expressions. "
"To combine multiple search expressions, use either a CompoundExpression for "
"fine-grained control or CombinedSearchExpression for simple logical combinations."
)
pipeline = []
for search, result_col in search_replacements.items():
score_function = (
"vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
)
pipeline.extend(
[
search.as_mql(self, self.connection),
{
"$addFields": {
result_col.as_mql(self, self.connection, as_path=True): {
"$meta": score_function
}
}
},
]
)
return pipeline

def pre_sql_setup(self, with_col_aliases=False):
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by)
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
all_replacements = {**search_replacements, **group_replacements}
self.search_pipeline = self._compound_searches_queries(search_replacements)
# query.group_by is either:
# - None: no GROUP BY
# - True: group by select fields
@@ -234,6 +312,9 @@ def pre_sql_setup(self, with_col_aliases=False):
for target, expr in self.query.annotation_select.items()
}
self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by]
if (where := self.get_where()) and search_replacements:
where = where.replace_expressions(search_replacements)
self.set_where(where)
return extra_select, order_by, group_by

def execute_sql(
@@ -557,10 +638,16 @@ def get_lookup_pipeline(self):
return result

def _get_aggregate_expressions(self, expr):
return self._get_all_expressions_of_type(expr, Aggregate)

def _get_search_expressions(self, expr):
return self._get_all_expressions_of_type(expr, SearchExpression)

def _get_all_expressions_of_type(self, expr, target_type):
stack = [expr]
while stack:
expr = stack.pop()
if isinstance(expr, Aggregate):
if isinstance(expr, target_type):
yield expr
elif hasattr(expr, "get_source_expressions"):
stack.extend(expr.get_source_expressions())
@@ -629,6 +716,9 @@ def _get_ordering(self):
def get_where(self):
return getattr(self, "where", self.query.where)

def set_where(self, value):
self.where = value

def explain_query(self):
# Validate format (none supported) and options.
options = self.connection.ops.explain_query_prefix(
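For a single search annotation, _compound_searches_queries() emits the expression's own stage followed by an $addFields stage that stores the relevance score under the replacement column's path. A minimal sketch of the resulting stages, assuming a hypothetical text search whose annotation was replaced by the alias __search_expr.search1; the actual $search body comes from the expression's as_mql() and will differ:

    # Illustrative only; the operator body and index name are assumptions.
    search_pipeline = [
        # Stage produced by the SearchExpression itself via as_mql().
        {"$search": {"index": "default", "text": {"query": "coffee", "path": "headline"}}},
        # Stage added by _compound_searches_queries(); for a SearchVector
        # expression the $meta key would be "vectorSearchScore" instead.
        {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}},
    ]

Combining more than one search expression in a single query, whether two $search expressions, two $vectorSearch expressions, or a mix of both, raises one of the ValueErrors above instead of producing a pipeline.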
4 changes: 4 additions & 0 deletions django_mongodb_backend/creation.py
@@ -22,6 +22,10 @@ def _destroy_test_db(self, test_database_name, verbosity):

for collection in self.connection.introspection.table_names():
if not collection.startswith("system."):
if self.connection.features.supports_atlas_search:
db_collection = self.connection.database.get_collection(collection)
for search_indexes in db_collection.list_search_indexes():
db_collection.drop_search_index(search_indexes["name"])
self.connection.database.drop_collection(collection)

def create_test_db(self, *args, **kwargs):
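A standalone sketch of the same cleanup using plain pymongo, with an assumed client and test database name; list_search_indexes() and drop_search_index() are the collection methods used above and require a pymongo version that supports Atlas Search index management:

    from pymongo import MongoClient

    client = MongoClient()             # connection settings are assumed
    database = client["test_project"]  # hypothetical test database name

    for name in database.list_collection_names():
        if name.startswith("system."):
            continue
        collection = database.get_collection(name)
        # Remove every Atlas Search index attached to the collection before
        # dropping the collection itself, mirroring _destroy_test_db() above.
        for search_index in collection.list_search_indexes():
            collection.drop_search_index(search_index["name"])
        database.drop_collection(name)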
Empty file.
django_mongodb_backend/expressions/builtins.py (moved from django_mongodb_backend/expressions.py)
Expand Up @@ -25,7 +25,7 @@
)
from django.db.models.sql import Query

- from .query_utils import process_lhs
+ from ..query_utils import process_lhs


def case(self, compiler, connection):
Expand Down Expand Up @@ -53,7 +53,7 @@ def case(self, compiler, connection):
}


- def col(self, compiler, connection): # noqa: ARG001
+ def col(self, compiler, connection, as_path=False): # noqa: ARG001
# If the column is part of a subquery and belongs to one of the parent
# queries, it will be stored for reference using $let in a $lookup stage.
# If the query is built with `alias_cols=False`, treat the column as
Expand All @@ -71,7 +71,7 @@ def col(self, compiler, connection): # noqa: ARG001
# Add the column's collection's alias for columns in joined collections.
has_alias = self.alias and self.alias != compiler.collection_name
prefix = f"{self.alias}." if has_alias else ""
return f"${prefix}{self.target.column}"
return f"{prefix}{self.target.column}" if as_path else f"${prefix}{self.target.column}"


def col_pairs(self, compiler, connection):
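The new as_path flag exists because _compound_searches_queries() uses the column's MQL as the key of an $addFields stage, and field paths used as keys must not carry the leading "$" that value-side references use. A short illustration with assumed values, for a column targeting headline with no joined-collection alias:

    # col_expr is the Col created for the replacement column; values are illustrative.
    col_expr.as_mql(compiler, connection)                 # -> "$headline" (value reference)
    col_expr.as_mql(compiler, connection, as_path=True)   # -> "headline"  (field path)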