Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _compound_searches_queries(self, search_replacements):
search.as_mql(self, self.connection),
{
"$addFields": {
result_col.as_mql(self, self.connection, as_path=True): {
result_col.as_mql(self, self.connection).removeprefix("$"): {
"$meta": score_function
}
}
Expand Down
12 changes: 10 additions & 2 deletions django_mongodb_backend/expressions/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from bson import Decimal128
from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db import NotSupportedError
from django.db.models import F
from django.db.models.expressions import (
Case,
Col,
Expand Down Expand Up @@ -53,7 +54,7 @@ def case(self, compiler, connection):
}


def col(self, compiler, connection, as_path=False): # noqa: ARG001
def col(self, compiler, connection): # noqa: ARG001
# If the column is part of a subquery and belongs to one of the parent
# queries, it will be stored for reference using $let in a $lookup stage.
# If the query is built with `alias_cols=False`, treat the column as
Expand All @@ -71,7 +72,7 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
# Add the column's collection's alias for columns in joined collections.
has_alias = self.alias and self.alias != compiler.collection_name
prefix = f"{self.alias}." if has_alias else ""
if not as_path:
if not getattr(self, "_as_path", False):
prefix = f"${prefix}"
return f"{prefix}{self.target.column}"

Expand Down Expand Up @@ -209,6 +210,13 @@ def value(self, compiler, connection): # noqa: ARG001
return value


class Path(F):
def resolve_expression(self, *args, **kwargs):
expr = super().resolve_expression(*args, **kwargs)
expr._as_path = True
return expr


def register_expressions():
Case.as_mql = case
Col.as_mql = col
Expand Down
86 changes: 44 additions & 42 deletions django_mongodb_backend/expressions/search.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from django.db import NotSupportedError
from django.db.models import CharField, Expression, FloatField, TextField
from django.db.models.expressions import F, Value
from django.db.models.expressions import Value
from django.db.models.lookups import Lookup

from ..query_utils import process_lhs, process_rhs
from .builtins import Path


def cast_as_field(path):
return F(path) if isinstance(path, str) else path
def cast_as_path(path):
return Path(path) if isinstance(path, str) else path


class Operator:
Expand Down Expand Up @@ -146,19 +147,19 @@ class SearchAutocomplete(SearchExpression):
"""

def __init__(self, path, query, *, fuzzy=None, token_order=None, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.query = query
self.fuzzy = fuzzy
self.token_order = token_order
self.score = score
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
"query": self.query,
}
if self.score:
Expand Down Expand Up @@ -186,17 +187,17 @@ class SearchEquals(SearchExpression):
"""

def __init__(self, path, value, *, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.value = value
self.score = score
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
"value": self.value,
}
if self.score:
Expand All @@ -223,16 +224,16 @@ class SearchExists(SearchExpression):
"""

def __init__(self, path, *, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.score = score
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
}
if self.score:
params["score"] = self.score.as_mql(compiler, connection)
Expand All @@ -255,17 +256,17 @@ class SearchIn(SearchExpression):
"""

def __init__(self, path, value, *, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.value = value
self.score = score
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
"value": self.value,
}
if self.score:
Expand Down Expand Up @@ -294,19 +295,19 @@ class SearchPhrase(SearchExpression):
"""

def __init__(self, path, query, *, slop=None, synonyms=None, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.query = query
self.slop = slop
self.synonyms = synonyms
self.score = score
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
"query": self.query,
}
if self.score:
Expand Down Expand Up @@ -338,17 +339,17 @@ class SearchQueryString(SearchExpression):
"""

def __init__(self, path, query, *, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.query = query
self.score = score
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"defaultPath": self.path.as_mql(compiler, connection, as_path=True),
"defaultPath": self.path.as_mql(compiler, connection),
"query": self.query,
}
if self.score:
Expand Down Expand Up @@ -378,7 +379,7 @@ class SearchRange(SearchExpression):
"""

def __init__(self, path, *, lt=None, lte=None, gt=None, gte=None, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.lt = lt
self.lte = lte
self.gt = gt
Expand All @@ -387,11 +388,11 @@ def __init__(self, path, *, lt=None, lte=None, gt=None, gte=None, score=None):
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
}
if self.score:
params["score"] = self.score.as_mql(compiler, connection)
Expand Down Expand Up @@ -424,18 +425,18 @@ class SearchRegex(SearchExpression):
"""

def __init__(self, path, query, *, allow_analyzed_field=None, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.query = query
self.allow_analyzed_field = allow_analyzed_field
self.score = score
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
"query": self.query,
}
if self.score:
Expand Down Expand Up @@ -472,7 +473,7 @@ class SearchText(SearchExpression):
"""

def __init__(self, path, query, *, fuzzy=None, match_criteria=None, synonyms=None, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.query = query
self.fuzzy = fuzzy
self.match_criteria = match_criteria
Expand All @@ -481,11 +482,11 @@ def __init__(self, path, query, *, fuzzy=None, match_criteria=None, synonyms=Non
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
"query": self.query,
}
if self.score:
Expand Down Expand Up @@ -520,18 +521,18 @@ class SearchWildcard(SearchExpression):
"""

def __init__(self, path, query, allow_analyzed_field=None, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.query = query
self.allow_analyzed_field = allow_analyzed_field
self.score = score
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
"query": self.query,
}
if self.score:
Expand Down Expand Up @@ -566,18 +567,18 @@ class SearchGeoShape(SearchExpression):
"""

def __init__(self, path, relation, geometry, *, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.relation = relation
self.geometry = geometry
self.score = score
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
"relation": self.relation,
"geometry": self.geometry,
}
Expand Down Expand Up @@ -610,18 +611,18 @@ class SearchGeoWithin(SearchExpression):
"""

def __init__(self, path, kind, geometry, *, score=None):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.kind = kind
self.geometry = geometry
self.score = score
super().__init__()

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def search_operator(self, compiler, connection):
params = {
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
self.kind: self.geometry,
}
if self.score:
Expand Down Expand Up @@ -855,7 +856,7 @@ def __init__(
exact=None,
filter=None,
):
self.path = cast_as_field(path)
self.path = cast_as_path(path)
self.query_vector = query_vector
self.limit = limit
self.num_candidates = num_candidates
Expand All @@ -879,7 +880,7 @@ def __ror__(self, other):
raise NotSupportedError("SearchVector cannot be combined")

def get_search_fields(self, compiler, connection):
return {self.path.as_mql(compiler, connection, as_path=True)}
return {self.path.as_mql(compiler, connection)}

def _get_query_index(self, fields, compiler):
for search_indexes in compiler.collection.list_search_indexes():
Expand All @@ -894,7 +895,7 @@ def _get_query_index(self, fields, compiler):
def as_mql(self, compiler, connection):
params = {
"index": self._get_query_index(self.get_search_fields(compiler, connection), compiler),
"path": self.path.as_mql(compiler, connection, as_path=True),
"path": self.path.as_mql(compiler, connection),
"queryVector": self.query_vector,
"limit": self.limit,
}
Expand Down Expand Up @@ -924,6 +925,7 @@ class SearchTextLookup(Lookup):

def __init__(self, lhs, rhs):
super().__init__(lhs, rhs)
self.lhs._as_path = True
self.lhs = SearchText(self.lhs, self.rhs)
self.rhs = Value(0)

Expand Down
6 changes: 3 additions & 3 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,14 @@ def get_transform(self, name):
f"{suggestion}"
)

def as_mql(self, compiler, connection, as_path=False):
def as_mql(self, compiler, connection):
previous = self
key_transforms = []
while isinstance(previous, KeyTransform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs
if as_path:
mql = previous.as_mql(compiler, connection, as_path=True)
if getattr(self, "_as_path", False):
mql = previous.as_mql(compiler, connection).removeprefix("$")
mql_path = ".".join(key_transforms)
return f"{mql}.{mql_path}"
mql = previous.as_mql(compiler, connection)
Expand Down
Loading