diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index fb33b260..124b759d 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -302,7 +302,7 @@ def _compound_searches_queries(self, search_replacements): search.as_mql(self, self.connection), { "$addFields": { - result_col.as_mql(self, self.connection, as_path=True): { + result_col.as_mql(self, self.connection).removeprefix("$"): { "$meta": score_function } } diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 010880ae..99c50bb1 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -5,6 +5,7 @@ from bson import Decimal128 from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import NotSupportedError +from django.db.models import F from django.db.models.expressions import ( Case, Col, @@ -53,7 +54,7 @@ def case(self, compiler, connection): } -def col(self, compiler, connection, as_path=False): # noqa: ARG001 +def col(self, compiler, connection): # noqa: ARG001 # If the column is part of a subquery and belongs to one of the parent # queries, it will be stored for reference using $let in a $lookup stage. # If the query is built with `alias_cols=False`, treat the column as @@ -71,7 +72,7 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001 # Add the column's collection's alias for columns in joined collections. has_alias = self.alias and self.alias != compiler.collection_name prefix = f"{self.alias}." if has_alias else "" - if not as_path: + if not getattr(self, "_as_path", False): prefix = f"${prefix}" return f"{prefix}{self.target.column}" @@ -209,6 +210,13 @@ def value(self, compiler, connection): # noqa: ARG001 return value +class Path(F): + def resolve_expression(self, *args, **kwargs): + expr = super().resolve_expression(*args, **kwargs) + expr._as_path = True + return expr + + def register_expressions(): Case.as_mql = case Col.as_mql = col diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 0cbfe788..3d5634d6 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -1,13 +1,14 @@ from django.db import NotSupportedError from django.db.models import CharField, Expression, FloatField, TextField -from django.db.models.expressions import F, Value +from django.db.models.expressions import Value from django.db.models.lookups import Lookup from ..query_utils import process_lhs, process_rhs +from .builtins import Path -def cast_as_field(path): - return F(path) if isinstance(path, str) else path +def cast_as_path(path): + return Path(path) if isinstance(path, str) else path class Operator: @@ -146,7 +147,7 @@ class SearchAutocomplete(SearchExpression): """ def __init__(self, path, query, *, fuzzy=None, token_order=None, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.query = query self.fuzzy = fuzzy self.token_order = token_order @@ -154,11 +155,11 @@ def __init__(self, path, query, *, fuzzy=None, token_order=None, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -186,17 +187,17 @@ class SearchEquals(SearchExpression): """ def __init__(self, path, value, *, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.value = value self.score = score super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "value": self.value, } if self.score: @@ -223,16 +224,16 @@ class SearchExists(SearchExpression): """ def __init__(self, path, *, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.score = score super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), } if self.score: params["score"] = self.score.as_mql(compiler, connection) @@ -255,17 +256,17 @@ class SearchIn(SearchExpression): """ def __init__(self, path, value, *, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.value = value self.score = score super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "value": self.value, } if self.score: @@ -294,7 +295,7 @@ class SearchPhrase(SearchExpression): """ def __init__(self, path, query, *, slop=None, synonyms=None, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.query = query self.slop = slop self.synonyms = synonyms @@ -302,11 +303,11 @@ def __init__(self, path, query, *, slop=None, synonyms=None, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -338,17 +339,17 @@ class SearchQueryString(SearchExpression): """ def __init__(self, path, query, *, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.query = query self.score = score super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "defaultPath": self.path.as_mql(compiler, connection, as_path=True), + "defaultPath": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -378,7 +379,7 @@ class SearchRange(SearchExpression): """ def __init__(self, path, *, lt=None, lte=None, gt=None, gte=None, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.lt = lt self.lte = lte self.gt = gt @@ -387,11 +388,11 @@ def __init__(self, path, *, lt=None, lte=None, gt=None, gte=None, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), } if self.score: params["score"] = self.score.as_mql(compiler, connection) @@ -424,18 +425,18 @@ class SearchRegex(SearchExpression): """ def __init__(self, path, query, *, allow_analyzed_field=None, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.query = query self.allow_analyzed_field = allow_analyzed_field self.score = score super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -472,7 +473,7 @@ class SearchText(SearchExpression): """ def __init__(self, path, query, *, fuzzy=None, match_criteria=None, synonyms=None, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.query = query self.fuzzy = fuzzy self.match_criteria = match_criteria @@ -481,11 +482,11 @@ def __init__(self, path, query, *, fuzzy=None, match_criteria=None, synonyms=Non super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -520,18 +521,18 @@ class SearchWildcard(SearchExpression): """ def __init__(self, path, query, allow_analyzed_field=None, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.query = query self.allow_analyzed_field = allow_analyzed_field self.score = score super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -566,18 +567,18 @@ class SearchGeoShape(SearchExpression): """ def __init__(self, path, relation, geometry, *, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.relation = relation self.geometry = geometry self.score = score super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "relation": self.relation, "geometry": self.geometry, } @@ -610,18 +611,18 @@ class SearchGeoWithin(SearchExpression): """ def __init__(self, path, kind, geometry, *, score=None): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.kind = kind self.geometry = geometry self.score = score super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), self.kind: self.geometry, } if self.score: @@ -855,7 +856,7 @@ def __init__( exact=None, filter=None, ): - self.path = cast_as_field(path) + self.path = cast_as_path(path) self.query_vector = query_vector self.limit = limit self.num_candidates = num_candidates @@ -879,7 +880,7 @@ def __ror__(self, other): raise NotSupportedError("SearchVector cannot be combined") def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def _get_query_index(self, fields, compiler): for search_indexes in compiler.collection.list_search_indexes(): @@ -894,7 +895,7 @@ def _get_query_index(self, fields, compiler): def as_mql(self, compiler, connection): params = { "index": self._get_query_index(self.get_search_fields(compiler, connection), compiler), - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "queryVector": self.query_vector, "limit": self.limit, } @@ -924,6 +925,7 @@ class SearchTextLookup(Lookup): def __init__(self, lhs, rhs): super().__init__(lhs, rhs) + self.lhs._as_path = True self.lhs = SearchText(self.lhs, self.rhs) self.rhs = Value(0) diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 95163236..61971e8c 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -184,14 +184,14 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql(self, compiler, connection, as_path=False): + def as_mql(self, compiler, connection): previous = self key_transforms = [] while isinstance(previous, KeyTransform): key_transforms.insert(0, previous.key_name) previous = previous.lhs - if as_path: - mql = previous.as_mql(compiler, connection, as_path=True) + if getattr(self, "_as_path", False): + mql = previous.as_mql(compiler, connection).removeprefix("$") mql_path = ".".join(key_transforms) return f"{mql}.{mql_path}" mql = previous.as_mql(compiler, connection)