From cf614edd9802c3ac659cb7ed1e4a9590356f0bd5 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Wed, 4 Jun 2025 17:13:31 -0400 Subject: [PATCH 01/20] Check for replica set or sharded cluster --- django_mongodb_backend/features.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index fa73461d9..8fa777c77 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -36,8 +36,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_temporal_subtraction = True # MongoDB stores datetimes in UTC. supports_timezones = False - # Not implemented: https://github.com/mongodb/django-mongodb-backend/issues/7 - supports_transactions = False supports_unspecified_pk = True uses_savepoints = False @@ -97,6 +95,22 @@ class DatabaseFeatures(BaseDatabaseFeatures): "expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or", } + @cached_property + def supports_transactions(self): + """Confirm support for transactions.""" + is_replica_set = False + is_sharded_cluster = False + with self.connection.cursor(): + client = self.connection.connection + hello_response = client.admin.command("hello") + if "setName" in hello_response: + is_replica_set = True + if "msg" in client.admin.command("hello") and hello_response["msg"] == "isdbgrid": + is_sharded_cluster = True + if is_replica_set or is_sharded_cluster: + return True + return False + @cached_property def django_test_expected_failures(self): expected_failures = super().django_test_expected_failures From 392bc2fb1c1579b57fb34a94485d75bb97e7756b Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Fri, 6 Jun 2025 09:17:30 -0400 Subject: [PATCH 02/20] Check for wired tiger storage engine --- django_mongodb_backend/features.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 8fa777c77..d358d7321 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -103,11 +103,17 @@ def supports_transactions(self): with self.connection.cursor(): client = self.connection.connection hello_response = client.admin.command("hello") + server_status = client.admin.command("serverStatus") if "setName" in hello_response: is_replica_set = True - if "msg" in client.admin.command("hello") and hello_response["msg"] == "isdbgrid": + if "msg" in hello_response and hello_response["msg"] == "isdbgrid": is_sharded_cluster = True - if is_replica_set or is_sharded_cluster: + if ( + "storageEngine" in server_status + and server_status["storageEngine"].get("name") == "wiredTiger" + ): + is_wired_tiger = True + if (is_replica_set or is_sharded_cluster) and is_wired_tiger: return True return False From ca235e2e563c34f30f2e69424c26670812fb9d47 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Fri, 6 Jun 2025 19:35:39 -0400 Subject: [PATCH 03/20] Add transaction support --- django_mongodb_backend/base.py | 22 +++++++++++++++++++--- django_mongodb_backend/compiler.py | 4 +++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index c6110dbb7..6325e184a 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -139,6 +139,7 @@ def _isnull_operator(a, b): introspection_class = DatabaseIntrospection ops_class = DatabaseOperations validation_class = DatabaseValidation + session = None def get_collection(self, name, **kwargs): collection = Collection(self.database, name, **kwargs) @@ -190,13 +191,28 @@ def _driver_info(self): return None def _commit(self): - pass + if self.session: + self.session.commit_transaction() + self.session.end_session() + self.session = None def _rollback(self): pass - def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False): - self.autocommit = autocommit + def _start_session(self): + if self.session is None: + self.session = self.connection.start_session() + self.session.start_transaction() + + def _start_transaction_under_autocommit(self): + self._start_session() + + def _set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False): + if not autocommit: + self._start_session() + else: + if self.session: + self.commit() def _close(self): # Normally called by close(), this method is also called by some tests. diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 12da13a1c..0176e6b74 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -685,7 +685,9 @@ def execute_sql(self, returning_fields=None): @wrap_database_errors def insert(self, docs, returning_fields=None): """Store a list of documents using field columns as element names.""" - inserted_ids = self.collection.insert_many(docs).inserted_ids + inserted_ids = self.collection.insert_many( + docs, session=self.connection.session + ).inserted_ids return [(x,) for x in inserted_ids] if returning_fields else [] @cached_property From 001cc472ce2139fb2fc000335e8144253523475f Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Fri, 6 Jun 2025 21:31:46 -0400 Subject: [PATCH 04/20] pass session to queries --- django_mongodb_backend/base.py | 7 +++---- django_mongodb_backend/compiler.py | 4 +++- django_mongodb_backend/features.py | 12 ++---------- django_mongodb_backend/query.py | 8 ++++++-- django_mongodb_backend/queryset.py | 2 +- tests/raw_query_/test_raw_aggregate.py | 3 ++- 6 files changed, 17 insertions(+), 19 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index 6325e184a..8f7b64932 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -197,7 +197,9 @@ def _commit(self): self.session = None def _rollback(self): - pass + if self.session: + self.session.abort_transaction() + self.session = None def _start_session(self): if self.session is None: @@ -210,9 +212,6 @@ def _start_transaction_under_autocommit(self): def _set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False): if not autocommit: self._start_session() - else: - if self.session: - self.commit() def _close(self): # Normally called by close(), this method is also called by some tests. diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 0176e6b74..74a5d60e0 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -770,7 +770,9 @@ def execute_sql(self, result_type): @wrap_database_errors def update(self, criteria, pipeline): - return self.collection.update_many(criteria, pipeline).matched_count + return self.collection.update_many( + criteria, pipeline, session=self.connection.session + ).matched_count def check_query(self): super().check_query() diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index d358d7321..14abf6c81 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -84,6 +84,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): # Value.as_mql() doesn't call output_field.get_db_prep_save(): # https://github.com/mongodb/django-mongodb-backend/issues/282 "model_fields.test_jsonfield.TestSaveLoad.test_bulk_update_custom_get_prep_value", + # to debug + "transactions.tests.AtomicMiscTests.test_mark_for_rollback_on_error_in_transaction", } # $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3. _django_test_expected_failures_bitwise = { @@ -505,16 +507,6 @@ def django_test_expected_failures(self): "Connection health checks not implemented.": { "backends.base.test_base.ConnectionHealthChecksTests", }, - "transaction.atomic() is not supported.": { - "backends.base.test_base.DatabaseWrapperLoggingTests", - "migrations.test_executor.ExecutorTests.test_atomic_operation_in_non_atomic_migration", - "migrations.test_operations.OperationTests.test_run_python_atomic", - }, - "transaction.rollback() is not supported.": { - "transactions.tests.AtomicMiscTests.test_mark_for_rollback_on_error_in_autocommit", - "transactions.tests.AtomicMiscTests.test_mark_for_rollback_on_error_in_transaction", - "transactions.tests.NonAutocommitTests.test_orm_query_after_error_and_rollback", - }, "migrate --fake-initial is not supported.": { "migrations.test_commands.MigrateTests.test_migrate_fake_initial", "migrations.test_commands.MigrateTests.test_migrate_fake_split_initial", diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index 049775205..8c18fcb03 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -63,7 +63,9 @@ def delete(self): """Execute a delete query.""" if self.compiler.subqueries: raise NotSupportedError("Cannot use QuerySet.delete() when a subquery is required.") - return self.compiler.collection.delete_many(self.match_mql).deleted_count + return self.compiler.collection.delete_many( + self.match_mql, session=self.compiler.connection.session + ).deleted_count @wrap_database_errors def get_cursor(self): @@ -71,7 +73,9 @@ def get_cursor(self): Return a pymongo CommandCursor that can be iterated on to give the results of the query. """ - return self.compiler.collection.aggregate(self.get_pipeline()) + return self.compiler.collection.aggregate( + self.get_pipeline(), session=self.compiler.connection.session + ) def get_pipeline(self): pipeline = [] diff --git a/django_mongodb_backend/queryset.py b/django_mongodb_backend/queryset.py index 4a2d884da..b02496fee 100644 --- a/django_mongodb_backend/queryset.py +++ b/django_mongodb_backend/queryset.py @@ -35,7 +35,7 @@ def __init__(self, pipeline, using, model): def _execute_query(self): connection = connections[self.using] collection = connection.get_collection(self.model._meta.db_table) - self.cursor = collection.aggregate(self.pipeline) + self.cursor = collection.aggregate(self.pipeline, session=connection.session) def __str__(self): return str(self.pipeline) diff --git a/tests/raw_query_/test_raw_aggregate.py b/tests/raw_query_/test_raw_aggregate.py index 72cd74d02..ce87311a6 100644 --- a/tests/raw_query_/test_raw_aggregate.py +++ b/tests/raw_query_/test_raw_aggregate.py @@ -182,7 +182,8 @@ def test_different_db_key_order(self): { field.name: getattr(author, field.name) for field in reversed(Author._meta.concrete_fields) - } + }, + session=connection.session, ) query = [] authors = Author.objects.all() From 04d3f2fb1856f69ea85486c0b902f6f889026b33 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Sat, 7 Jun 2025 08:30:04 -0400 Subject: [PATCH 05/20] disable transactions if not supported --- django_mongodb_backend/base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index 8f7b64932..25b4a2d7e 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -191,12 +191,16 @@ def _driver_info(self): return None def _commit(self): + if not self.features.supports_transactions: + return if self.session: self.session.commit_transaction() self.session.end_session() self.session = None def _rollback(self): + if not self.features.supports_transactions: + return if self.session: self.session.abort_transaction() self.session = None @@ -207,9 +211,13 @@ def _start_session(self): self.session.start_transaction() def _start_transaction_under_autocommit(self): + if not self.features.supports_transactions: + return self._start_session() def _set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False): + if self.features.supports_transactions: + return if not autocommit: self._start_session() From 176205ba17e53fe9df91b18ffe5cc81243cd20a7 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Sat, 7 Jun 2025 08:32:53 -0400 Subject: [PATCH 06/20] add replica set tests --- .github/workflows/test-python-replica.yml | 58 +++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 .github/workflows/test-python-replica.yml diff --git a/.github/workflows/test-python-replica.yml b/.github/workflows/test-python-replica.yml new file mode 100644 index 000000000..a51481bfd --- /dev/null +++ b/.github/workflows/test-python-replica.yml @@ -0,0 +1,58 @@ +name: Python Tests on a replica test + +on: + pull_request: + paths: + - '**.py' + - '!setup.py' + - '.github/workflows/test-python-replica.yml' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash -eux {0} + +jobs: + build: + name: Django Test Suite + runs-on: ubuntu-latest + steps: + - name: Checkout django-mongodb-backend + uses: actions/checkout@v4 + with: + persist-credentials: false + - name: install django-mongodb-backend + run: | + pip3 install --upgrade pip + pip3 install -e . + - name: Checkout Django + uses: actions/checkout@v4 + with: + repository: 'mongodb-forks/django' + ref: 'mongodb-5.2.x' + path: 'django_repo' + persist-credentials: false + - name: Install system packages for Django's Python test dependencies + run: | + sudo apt-get update + sudo apt-get install libmemcached-dev + - name: Install Django and its Python test dependencies + run: | + cd django_repo/tests/ + pip3 install -e .. + pip3 install -r requirements/py3.txt + - name: Copy the test settings file + run: cp .github/workflows/mongodb_settings.py django_repo/tests/ + - name: Copy the test runner file + run: cp .github/workflows/runtests.py django_repo/tests/runtests_.py + - name: Start MongoDB + uses: supercharge/mongodb-github-action@90004df786821b6308fb02299e5835d0dae05d0d # 1.12.0 + with: + mongodb-version: 6.0 + mongodb-replica-set: test-rs + - name: Run tests + run: python3 django_repo/tests/runtests.py --settings mongodb_settings -v 2 From b3b956f57e52a69e1a90351764864f57c894bd42 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Sat, 7 Jun 2025 09:03:33 -0400 Subject: [PATCH 07/20] skip union tests on transactions --- django_mongodb_backend/features.py | 5 +++++ tests/queries_/test_objectid.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 14abf6c81..3dcecc2e7 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -589,3 +589,8 @@ def supports_atlas_search(self): return False else: return True + + @cached_property + def supports_select_union(self): + # Stage not supported inside of a multi-document transaction: $unionWith + return not self.supports_transactions diff --git a/tests/queries_/test_objectid.py b/tests/queries_/test_objectid.py index 490d1b334..36d9561a7 100644 --- a/tests/queries_/test_objectid.py +++ b/tests/queries_/test_objectid.py @@ -1,6 +1,6 @@ from bson import ObjectId from django.core.exceptions import ValidationError -from django.test import TestCase +from django.test import TestCase, skipUnlessDBFeature from .models import Order, OrderItem, Tag @@ -75,6 +75,7 @@ def test_filter_parent_by_children_values_obj(self): parent_qs = Tag.objects.filter(children__id__in=child_ids).distinct().order_by("name") self.assertSequenceEqual(parent_qs, [self.t1]) + @skipUnlessDBFeature("supports_select_union") def test_filter_group_id_union_with_str(self): """Combine queries using union with string values.""" qs_a = Tag.objects.filter(group_id=self.group_id_str_1) @@ -82,6 +83,7 @@ def test_filter_group_id_union_with_str(self): union_qs = qs_a.union(qs_b).order_by("name") self.assertSequenceEqual(union_qs, [self.t3, self.t4]) + @skipUnlessDBFeature("supports_select_union") def test_filter_group_id_union_with_obj(self): """Combine queries using union with ObjectId values.""" qs_a = Tag.objects.filter(group_id=self.group_id_obj_1) From 8087a7c6177a9d98bb47e77fbe6321584fa8c44d Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Sat, 7 Jun 2025 10:18:07 -0400 Subject: [PATCH 08/20] add transaction logging --- django_mongodb_backend/base.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index 25b4a2d7e..fff19330e 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -3,6 +3,7 @@ from django.core.exceptions import ImproperlyConfigured from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.backends.utils import debug_transaction from django.utils.asyncio import async_unsafe from django.utils.functional import cached_property from pymongo.collection import Collection @@ -194,7 +195,8 @@ def _commit(self): if not self.features.supports_transactions: return if self.session: - self.session.commit_transaction() + with debug_transaction(self, "session.commit_transaction()"): + self.session.commit_transaction() self.session.end_session() self.session = None @@ -202,13 +204,15 @@ def _rollback(self): if not self.features.supports_transactions: return if self.session: - self.session.abort_transaction() + with debug_transaction(self, "session.abort_transaction()"): + self.session.abort_transaction() self.session = None def _start_session(self): if self.session is None: self.session = self.connection.start_session() - self.session.start_transaction() + with debug_transaction(self, "session.start_transaction()"): + self.session.start_transaction() def _start_transaction_under_autocommit(self): if not self.features.supports_transactions: From c9b639aa607dcfcf5f92da31d0438083c044de6f Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Sat, 7 Jun 2025 11:02:38 -0400 Subject: [PATCH 09/20] add skips for composite_pk tests --- django_mongodb_backend/features.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 3dcecc2e7..c60c32719 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -545,8 +545,18 @@ def django_test_expected_failures(self): "foreign_object.test_tuple_lookups.TupleLookupsTests", }, "ColPairs is not supported.": { - # 'ColPairs' object has no attribute 'as_mql' "auth_tests.test_views.CustomUserCompositePrimaryKeyPasswordResetTest", + "composite_pk.test_aggregate.CompositePKAggregateTests", + "composite_pk.test_create.CompositePKCreateTests", + "composite_pk.test_delete.CompositePKDeleteTests", + "composite_pk.test_filter.CompositePKFilterTests", + "composite_pk.test_get.CompositePKGetTests", + "composite_pk.test_models.CompositePKModelsTests", + "composite_pk.test_order_by.CompositePKOrderByTests", + "composite_pk.test_update.CompositePKUpdateTests", + "composite_pk.test_values.CompositePKValuesTests", + "composite_pk.tests.CompositePKTests", + "composite_pk.tests.CompositePKFixturesTests", }, "Custom lookups are not supported.": { "custom_lookups.tests.BilateralTransformTests", From 939159905d06a04538a1792c0884bc6c50751c8c Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Sat, 7 Jun 2025 20:36:44 -0400 Subject: [PATCH 10/20] fix backend_ tests --- tests/backend_/test_base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/backend_/test_base.py b/tests/backend_/test_base.py index 7695b6f4d..9d4e7006f 100644 --- a/tests/backend_/test_base.py +++ b/tests/backend_/test_base.py @@ -1,7 +1,7 @@ from django.core.exceptions import ImproperlyConfigured from django.db import connection from django.db.backends.signals import connection_created -from django.test import SimpleTestCase, TestCase +from django.test import SimpleTestCase, TransactionTestCase from django_mongodb_backend.base import DatabaseWrapper @@ -15,7 +15,9 @@ def test_database_name_empty(self): DatabaseWrapper(settings).get_connection_params() -class DatabaseWrapperConnectionTests(TestCase): +class DatabaseWrapperConnectionTests(TransactionTestCase): + available_apps = ["backend_"] + def test_set_autocommit(self): self.assertIs(connection.get_autocommit(), True) connection.set_autocommit(False) From 4eb66e99fc70645596a22517f4102579935c61db Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Sat, 7 Jun 2025 20:37:44 -0400 Subject: [PATCH 11/20] add expected transaction failures --- django_mongodb_backend/features.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index c60c32719..0c9e5719d 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -96,6 +96,11 @@ class DatabaseFeatures(BaseDatabaseFeatures): "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null", "expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or", } + _django_test_expected_failures_transactions = { + # When get_or_create() fails with IntegrityError, the transaction is no longer usable. + "get_or_create.tests.UpdateOrCreateTests.test_manual_primary_key_test", + "get_or_create.tests.UpdateOrCreateTestsWithManualPKs.test_create_with_duplicate_primary_key", + } @cached_property def supports_transactions(self): @@ -125,6 +130,8 @@ def django_test_expected_failures(self): expected_failures.update(self._django_test_expected_failures) if not self.is_mongodb_6_3: expected_failures.update(self._django_test_expected_failures_bitwise) + if self.supports_transactions: + expected_failures.update(self._django_test_expected_failures_transactions) return expected_failures django_test_skips = { From 749ab890e5e39e65d1a935cecee4469a63e9fd2c Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Mon, 9 Jun 2025 08:32:18 -0400 Subject: [PATCH 12/20] move session to instance variable; clear commit hooks in close_pool() --- django_mongodb_backend/base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index fff19330e..7f66ab1d0 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -2,6 +2,7 @@ import os from django.core.exceptions import ImproperlyConfigured +from django.db import DEFAULT_DB_ALIAS from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.utils import debug_transaction from django.utils.asyncio import async_unsafe @@ -140,7 +141,10 @@ def _isnull_operator(a, b): introspection_class = DatabaseIntrospection ops_class = DatabaseOperations validation_class = DatabaseValidation - session = None + + def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): + super().__init__(settings_dict, alias=alias) + self.session = None def get_collection(self, name, **kwargs): collection = Collection(self.database, name, **kwargs) @@ -237,6 +241,9 @@ def close(self): def close_pool(self): """Close the MongoClient.""" + # Clear commit hooks and session. + self.run_on_commit = [] + self.session = None connection = self.connection if connection is None: return From f26e114b8061571f1a6792a73b679ac43fbcb78c Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Mon, 9 Jun 2025 08:34:25 -0400 Subject: [PATCH 13/20] remove redundant test script Transactions are tested in Atlas tests. --- .github/workflows/test-python-atlas.yml | 2 +- .github/workflows/test-python-replica.yml | 58 ----------------------- 2 files changed, 1 insertion(+), 59 deletions(-) delete mode 100644 .github/workflows/test-python-replica.yml diff --git a/.github/workflows/test-python-atlas.yml b/.github/workflows/test-python-atlas.yml index 175dfe183..f84a99345 100644 --- a/.github/workflows/test-python-atlas.yml +++ b/.github/workflows/test-python-atlas.yml @@ -53,4 +53,4 @@ jobs: working-directory: . run: bash .github/workflows/start_local_atlas.sh mongodb/mongodb-atlas-local:7 - name: Run tests - run: python3 django_repo/tests/runtests_.py + run: python3 django_repo/tests/runtests.py --settings mongodb_settings -v 2 diff --git a/.github/workflows/test-python-replica.yml b/.github/workflows/test-python-replica.yml deleted file mode 100644 index a51481bfd..000000000 --- a/.github/workflows/test-python-replica.yml +++ /dev/null @@ -1,58 +0,0 @@ -name: Python Tests on a replica test - -on: - pull_request: - paths: - - '**.py' - - '!setup.py' - - '.github/workflows/test-python-replica.yml' - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -defaults: - run: - shell: bash -eux {0} - -jobs: - build: - name: Django Test Suite - runs-on: ubuntu-latest - steps: - - name: Checkout django-mongodb-backend - uses: actions/checkout@v4 - with: - persist-credentials: false - - name: install django-mongodb-backend - run: | - pip3 install --upgrade pip - pip3 install -e . - - name: Checkout Django - uses: actions/checkout@v4 - with: - repository: 'mongodb-forks/django' - ref: 'mongodb-5.2.x' - path: 'django_repo' - persist-credentials: false - - name: Install system packages for Django's Python test dependencies - run: | - sudo apt-get update - sudo apt-get install libmemcached-dev - - name: Install Django and its Python test dependencies - run: | - cd django_repo/tests/ - pip3 install -e .. - pip3 install -r requirements/py3.txt - - name: Copy the test settings file - run: cp .github/workflows/mongodb_settings.py django_repo/tests/ - - name: Copy the test runner file - run: cp .github/workflows/runtests.py django_repo/tests/runtests_.py - - name: Start MongoDB - uses: supercharge/mongodb-github-action@90004df786821b6308fb02299e5835d0dae05d0d # 1.12.0 - with: - mongodb-version: 6.0 - mongodb-replica-set: test-rs - - name: Run tests - run: python3 django_repo/tests/runtests.py --settings mongodb_settings -v 2 From 57c3683412cec4611679980265a2f331c42f9ed5 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Mon, 9 Jun 2025 10:29:25 -0400 Subject: [PATCH 14/20] add tests for DatabaseFeatures.supports_transactions --- django_mongodb_backend/features.py | 39 +++++++-------- tests/backend_/test_features.py | 76 ++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 22 deletions(-) create mode 100644 tests/backend_/test_features.py diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 0c9e5719d..8ccf0ac83 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -102,28 +102,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): "get_or_create.tests.UpdateOrCreateTestsWithManualPKs.test_create_with_duplicate_primary_key", } - @cached_property - def supports_transactions(self): - """Confirm support for transactions.""" - is_replica_set = False - is_sharded_cluster = False - with self.connection.cursor(): - client = self.connection.connection - hello_response = client.admin.command("hello") - server_status = client.admin.command("serverStatus") - if "setName" in hello_response: - is_replica_set = True - if "msg" in hello_response and hello_response["msg"] == "isdbgrid": - is_sharded_cluster = True - if ( - "storageEngine" in server_status - and server_status["storageEngine"].get("name") == "wiredTiger" - ): - is_wired_tiger = True - if (is_replica_set or is_sharded_cluster) and is_wired_tiger: - return True - return False - @cached_property def django_test_expected_failures(self): expected_failures = super().django_test_expected_failures @@ -611,3 +589,20 @@ def supports_atlas_search(self): def supports_select_union(self): # Stage not supported inside of a multi-document transaction: $unionWith return not self.supports_transactions + + @cached_property + def supports_transactions(self): + """ + Transactions are enabled if the MongoDB configuration supports it: + MongoDB must be configured as a replica set or sharded cluster, and + the store engine must be WiredTiger. + """ + self.connection.ensure_connection() + client = self.connection.connection.admin + hello_response = client.command("hello") + is_replica_set = "setName" in hello_response + is_sharded_cluster = hello_response.get("msg") == "isdbgrid" + if is_replica_set or is_sharded_cluster: + engine = client.command("serverStatus").get("storageEngine", {}) + return engine.get("name") == "wiredTiger" + return False diff --git a/tests/backend_/test_features.py b/tests/backend_/test_features.py new file mode 100644 index 000000000..3b18e64ae --- /dev/null +++ b/tests/backend_/test_features.py @@ -0,0 +1,76 @@ +from unittest.mock import patch + +from django.db import connection +from django.test import TestCase + + +class SupportsTransactionsTests(TestCase): + def setUp(self): + # Clear the cached property. + del connection.features.supports_transactions + + def tearDown(self): + del connection.features.supports_transactions + + def test_replica_set(self): + """A replica set supports transactions.""" + + def mocked_command(command): + if command == "hello": + return {"setName": "foo"} + if command == "serverStatus": + return {"storageEngine": {"name": "wiredTiger"}} + raise Exception("Unexpected command") + + with patch("pymongo.synchronous.database.Database.command", wraps=mocked_command): + self.assertIs(connection.features.supports_transactions, True) + + def test_replica_set_other_storage_engine(self): + """No support on a non-wiredTiger replica set.""" + + def mocked_command(command): + if command == "hello": + return {"setName": "foo"} + if command == "serverStatus": + return {"storageEngine": {"name": "other"}} + raise Exception("Unexpected command") + + with patch("pymongo.synchronous.database.Database.command", wraps=mocked_command): + self.assertIs(connection.features.supports_transactions, False) + + def test_sharded_cluster(self): + """A sharded cluster with wiredTiger storage engine supports them.""" + + def mocked_command(command): + if command == "hello": + return {"msg": "isdbgrid"} + if command == "serverStatus": + return {"storageEngine": {"name": "wiredTiger"}} + raise Exception("Unexpected command") + + with patch("pymongo.synchronous.database.Database.command", wraps=mocked_command): + self.assertIs(connection.features.supports_transactions, True) + + def test_sharded_cluster_other_storage_engine(self): + """No support on a non-wiredTiger shared cluster.""" + + def mocked_command(command): + if command == "hello": + return {"msg": "isdbgrid"} + if command == "serverStatus": + return {"storageEngine": {"name": "other"}} + raise Exception("Unexpected command") + + with patch("pymongo.synchronous.database.Database.command", wraps=mocked_command): + self.assertIs(connection.features.supports_transactions, False) + + def test_no_support(self): + """No support on a non-replica set, non-sharded cluster.""" + + def mocked_command(command): + if command == "hello": + return {} + raise Exception("Unexpected command") + + with patch("pymongo.synchronous.database.Database.command", wraps=mocked_command): + self.assertIs(connection.features.supports_transactions, False) From 65258b34146f10456d44a8e8805e3f9846dd3729 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Mon, 9 Jun 2025 10:48:50 -0400 Subject: [PATCH 15/20] docs --- django_mongodb_backend/features.py | 3 ++- docs/source/conf.py | 1 + docs/source/ref/database.rst | 42 +++++++++++++++++++++++++++++ docs/source/releases/5.2.x.rst | 10 +++++++ docs/source/topics/known-issues.rst | 6 +---- 5 files changed, 56 insertions(+), 6 deletions(-) diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 8ccf0ac83..eb75d4ae6 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -97,7 +97,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): "expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or", } _django_test_expected_failures_transactions = { - # When get_or_create() fails with IntegrityError, the transaction is no longer usable. + # When update_or_create() fails with IntegrityError, the transaction + # is no longer usable. "get_or_create.tests.UpdateOrCreateTests.test_manual_primary_key_test", "get_or_create.tests.UpdateOrCreateTestsWithManualPKs.test_create_with_duplicate_primary_key", } diff --git a/docs/source/conf.py b/docs/source/conf.py index a4e549381..2f1c8675a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -45,6 +45,7 @@ "pymongo": ("https://pymongo.readthedocs.io/en/stable/", None), "python": ("https://docs.python.org/3/", None), "atlas": ("https://www.mongodb.com/docs/atlas/", None), + "manual": ("https://www.mongodb.com/docs/manual/", None), } root_doc = "contents" diff --git a/docs/source/ref/database.rst b/docs/source/ref/database.rst index 365e1c73d..3b0d0930c 100644 --- a/docs/source/ref/database.rst +++ b/docs/source/ref/database.rst @@ -41,3 +41,45 @@ effect. Rather, if you need to close the connection pool, use .. versionadded:: 5.2.0b0 Support for connection pooling and ``connection.close_pool()`` were added. + +.. _transactions: + +Transactions +============ + +.. versionadded:: 5.2.0b2 + +Support for :doc:`Django's transactions APIs ` +is enabled if the MongoDB configuration supports them: MongoDB must be +configured as a :doc:`replica set ` or :doc:`sharded +cluster `, and the store engine must be :doc:`WiredTiger +`. + +If transactions aren't supported, query execution uses Django and MongoDB's +default behavior of autocommit mode. Each query is immediately committed to the +database. Django's transaction management APIs, such as +:func:`~django.db.transaction.atomic`, function as no-ops. + +.. _transactions-limitations: + +Limitations +----------- + +MongoDB's transaction limitations that are applicable to Django are: + +- :meth:`QuerySet.union() ` is not + supported inside a transaction. +- If a transaction raises an exception, the transaction is no longer usable. + For example, if the update stage of :meth:`QuerySet.update_or_create() + ` fails with + :class:`~django.db.IntegrityError` due to a unique constraint violation, the + create stage won't be able to proceed. + :class:`pymongo.errors.OperationFailure` is raised, wrapped by + :class:`django.db.DatabaseError`. +- Savepoints (i.e. nested :func:`~django.db.transaction.atomic` blocks) aren't + supported. The outermost :func:`~django.db.transaction.atomic` will start + a transaction while any subsequent :func:`~django.db.transaction.atomic` + blocks will have no effect. +- Migration operations aren't :ref:`wrapped in a transaction + ` because of MongoDB restrictions such as + adding indexes to existing collections while in a transaction. diff --git a/docs/source/releases/5.2.x.rst b/docs/source/releases/5.2.x.rst index de4b6efcb..e5e9b05ff 100644 --- a/docs/source/releases/5.2.x.rst +++ b/docs/source/releases/5.2.x.rst @@ -2,6 +2,16 @@ Django MongoDB Backend 5.2.x ============================ +5.2.0 beta 2 +============ + +*Unreleased* + +New features +------------ + +- Added support for :ref:`database transactions `. + 5.2.0 beta 1 ============ diff --git a/docs/source/topics/known-issues.rst b/docs/source/topics/known-issues.rst index 4b9edee70..60628f816 100644 --- a/docs/source/topics/known-issues.rst +++ b/docs/source/topics/known-issues.rst @@ -80,11 +80,7 @@ Database functions Transaction management ====================== -Query execution uses Django and MongoDB's default behavior of autocommit mode. -Each query is immediately committed to the database. - -Django's :doc:`transaction management APIs ` -are not supported. +See :ref:`transactions` for details. Database introspection ====================== From 6b86a8b29123b69bef7cf1b99663e46e1a52834e Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Tue, 10 Jun 2025 08:15:22 -0400 Subject: [PATCH 16/20] add validation of broken transactions --- django_mongodb_backend/base.py | 4 ++++ django_mongodb_backend/compiler.py | 2 ++ django_mongodb_backend/features.py | 30 ++++++++++++++++++++++++++---- django_mongodb_backend/query.py | 2 ++ 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index 7f66ab1d0..dcc492863 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -259,6 +259,10 @@ def close_pool(self): def cursor(self): return Cursor() + def validate_no_broken_transaction(self): + if self.features.supports_transactions: + super().validate_no_broken_transaction() + def get_database_version(self): """Return a tuple of the database's version.""" return tuple(self.connection.server_info()["versionArray"]) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 74a5d60e0..b7e264f8b 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -685,6 +685,7 @@ def execute_sql(self, returning_fields=None): @wrap_database_errors def insert(self, docs, returning_fields=None): """Store a list of documents using field columns as element names.""" + self.connection.validate_no_broken_transaction() inserted_ids = self.collection.insert_many( docs, session=self.connection.session ).inserted_ids @@ -770,6 +771,7 @@ def execute_sql(self, result_type): @wrap_database_errors def update(self, criteria, pipeline): + self.connection.validate_no_broken_transaction() return self.collection.update_many( criteria, pipeline, session=self.connection.session ).matched_count diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index eb75d4ae6..92697bde1 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -48,8 +48,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): "aggregation.tests.AggregateTestCase.test_order_by_aggregate_transform", # 'NulledTransform' object has no attribute 'as_mql'. "lookup.tests.LookupTests.test_exact_none_transform", - # "Save with update_fields did not affect any rows." - "basic.tests.SelectOnSaveTests.test_select_on_save_lying_update", # BaseExpression.convert_value() crashes with Decimal128. "aggregation.tests.AggregateTestCase.test_combine_different_types", "annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation", @@ -84,8 +82,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): # Value.as_mql() doesn't call output_field.get_db_prep_save(): # https://github.com/mongodb/django-mongodb-backend/issues/282 "model_fields.test_jsonfield.TestSaveLoad.test_bulk_update_custom_get_prep_value", - # to debug - "transactions.tests.AtomicMiscTests.test_mark_for_rollback_on_error_in_transaction", } # $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3. _django_test_expected_failures_bitwise = { @@ -96,11 +92,35 @@ class DatabaseFeatures(BaseDatabaseFeatures): "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null", "expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or", } + _django_test_expected_failures_no_transactions = { + # "Save with update_fields did not affect any rows." instead of + # "An error occurred in the current transaction. You can't execute + # queries until the end of the 'atomic' block." + "basic.tests.SelectOnSaveTests.test_select_on_save_lying_update", + } _django_test_expected_failures_transactions = { # When update_or_create() fails with IntegrityError, the transaction # is no longer usable. "get_or_create.tests.UpdateOrCreateTests.test_manual_primary_key_test", "get_or_create.tests.UpdateOrCreateTestsWithManualPKs.test_create_with_duplicate_primary_key", + # Tests that require savepoints + "admin_views.tests.AdminViewBasicTest.test_disallowed_to_field", + "admin_views.tests.AdminViewPermissionsTest.test_add_view", + "admin_views.tests.AdminViewPermissionsTest.test_change_view", + "admin_views.tests.AdminViewPermissionsTest.test_change_view_save_as_new", + "admin_views.tests.AdminViewPermissionsTest.test_delete_view", + "auth_tests.test_views.ChangelistTests.test_view_user_password_is_readonly", + "fixtures.tests.FixtureLoadingTests.test_loaddata_app_option", + "fixtures.tests.FixtureLoadingTests.test_unmatched_identifier_loading", + "fixtures_model_package.tests.FixtureTestCase.test_loaddata", + "get_or_create.tests.GetOrCreateTests.test_get_or_create_invalid_params", + "get_or_create.tests.UpdateOrCreateTests.test_integrity", + "many_to_many.tests.ManyToManyTests.test_add", + "many_to_one.tests.ManyToOneTests.test_fk_assignment_and_related_object_cache", + "model_fields.test_booleanfield.BooleanFieldTests.test_null_default", + "model_fields.test_floatfield.TestFloatField.test_float_validates_object", + "multiple_database.tests.QueryTestCase.test_generic_key_cross_database_protection", + "multiple_database.tests.QueryTestCase.test_m2m_cross_database_protection", } @cached_property @@ -111,6 +131,8 @@ def django_test_expected_failures(self): expected_failures.update(self._django_test_expected_failures_bitwise) if self.supports_transactions: expected_failures.update(self._django_test_expected_failures_transactions) + else: + expected_failures.update(self._django_test_expected_failures_no_transactions) return expected_failures django_test_skips = { diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index 8c18fcb03..d59bc1631 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -61,6 +61,7 @@ def __repr__(self): @wrap_database_errors def delete(self): """Execute a delete query.""" + self.compiler.connection.validate_no_broken_transaction() if self.compiler.subqueries: raise NotSupportedError("Cannot use QuerySet.delete() when a subquery is required.") return self.compiler.collection.delete_many( @@ -73,6 +74,7 @@ def get_cursor(self): Return a pymongo CommandCursor that can be iterated on to give the results of the query. """ + self.compiler.connection.validate_no_broken_transaction() return self.compiler.collection.aggregate( self.get_pipeline(), session=self.compiler.connection.session ) From 6eba5a446ca3fe33ee6ece6325bd2dbf2b0aee45 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Mon, 16 Jun 2025 13:59:21 -0400 Subject: [PATCH 17/20] rename _start_session() --- django_mongodb_backend/base.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index dcc492863..65bfca9b9 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -212,22 +212,30 @@ def _rollback(self): self.session.abort_transaction() self.session = None - def _start_session(self): + def _start_transaction(self): if self.session is None: self.session = self.connection.start_session() with debug_transaction(self, "session.start_transaction()"): self.session.start_transaction() def _start_transaction_under_autocommit(self): + # Implementing this hook (intended only for SQLite), allows + # BaseDatabaseWrapper.set_autocommit() to use it to start a transaction + # rather than set_autocommit(), bypassing set_autocommit()'s call to + # debug_transaction(self, "BEGIN") which isn't semantic for a no-SQL + # backend. if not self.features.supports_transactions: return - self._start_session() + self._start_transaction() def _set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False): if self.features.supports_transactions: return + # Besides @transaction.atomic() (which uses + # _start_transaction_under_autocommit(), disabling autocommit is + # another way to start a transaction. if not autocommit: - self._start_session() + self._start_transaction() def _close(self): # Normally called by close(), this method is also called by some tests. From 5b3b757b45333965d883e7088ff587c3b5964621 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Mon, 16 Jun 2025 16:58:30 -0400 Subject: [PATCH 18/20] add @requires_transaction_support --- django_mongodb_backend/base.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index 65bfca9b9..1c35fce27 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -34,6 +34,17 @@ def __exit__(self, exception_type, exception_value, exception_traceback): pass +def requires_transaction_support(func): + """Make a method a no-op if transactions aren't supported.""" + + def wrapper(self, *args, **kwargs): + if not self.features.supports_transactions: + return + func(self, *args, **kwargs) + + return wrapper + + class DatabaseWrapper(BaseDatabaseWrapper): data_types = { "AutoField": "int", @@ -195,18 +206,16 @@ def _driver_info(self): return DriverInfo("django-mongodb-backend", django_mongodb_backend_version) return None + @requires_transaction_support def _commit(self): - if not self.features.supports_transactions: - return if self.session: with debug_transaction(self, "session.commit_transaction()"): self.session.commit_transaction() self.session.end_session() self.session = None + @requires_transaction_support def _rollback(self): - if not self.features.supports_transactions: - return if self.session: with debug_transaction(self, "session.abort_transaction()"): self.session.abort_transaction() @@ -218,19 +227,17 @@ def _start_transaction(self): with debug_transaction(self, "session.start_transaction()"): self.session.start_transaction() + @requires_transaction_support def _start_transaction_under_autocommit(self): # Implementing this hook (intended only for SQLite), allows # BaseDatabaseWrapper.set_autocommit() to use it to start a transaction # rather than set_autocommit(), bypassing set_autocommit()'s call to # debug_transaction(self, "BEGIN") which isn't semantic for a no-SQL # backend. - if not self.features.supports_transactions: - return self._start_transaction() + @requires_transaction_support def _set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False): - if self.features.supports_transactions: - return # Besides @transaction.atomic() (which uses # _start_transaction_under_autocommit(), disabling autocommit is # another way to start a transaction. @@ -267,9 +274,9 @@ def close_pool(self): def cursor(self): return Cursor() + @requires_transaction_support def validate_no_broken_transaction(self): - if self.features.supports_transactions: - super().validate_no_broken_transaction() + super().validate_no_broken_transaction() def get_database_version(self): """Return a tuple of the database's version.""" From 2ffbaeb1b5010dca0c053356713ab0c2fc84e07f Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Mon, 16 Jun 2025 18:34:13 -0400 Subject: [PATCH 19/20] always end session --- django_mongodb_backend/base.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index 1c35fce27..fc21fa5b6 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -211,22 +211,27 @@ def _commit(self): if self.session: with debug_transaction(self, "session.commit_transaction()"): self.session.commit_transaction() - self.session.end_session() - self.session = None + self._end_session() @requires_transaction_support def _rollback(self): if self.session: with debug_transaction(self, "session.abort_transaction()"): self.session.abort_transaction() - self.session = None + self._end_session() def _start_transaction(self): + # Private API, specific to this backend. if self.session is None: self.session = self.connection.start_session() with debug_transaction(self, "session.start_transaction()"): self.session.start_transaction() + def _end_session(self): + # Private API, specific to this backend. + self.session.end_session() + self.session = None + @requires_transaction_support def _start_transaction_under_autocommit(self): # Implementing this hook (intended only for SQLite), allows @@ -258,7 +263,8 @@ def close_pool(self): """Close the MongoClient.""" # Clear commit hooks and session. self.run_on_commit = [] - self.session = None + if self.session: + self._end_session() connection = self.connection if connection is None: return From d0a88a190f9b6c998eb6cd5314916c8b1f59fd67 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Tue, 17 Jun 2025 10:44:16 -0400 Subject: [PATCH 20/20] Add queryable encryption config --- django_mongodb_backend/__init__.py | 4 +- django_mongodb_backend/features.py | 10 +++++ django_mongodb_backend/utils.py | 60 ++++++++++++++++++++++++++ tests/backend_/utils/test_parse_uri.py | 10 ++++- 4 files changed, 81 insertions(+), 3 deletions(-) diff --git a/django_mongodb_backend/__init__.py b/django_mongodb_backend/__init__.py index 00700421a..1c9f88f39 100644 --- a/django_mongodb_backend/__init__.py +++ b/django_mongodb_backend/__init__.py @@ -2,7 +2,7 @@ # Check Django compatibility before other imports which may fail if the # wrong version of Django is installed. -from .utils import check_django_compatability, parse_uri +from .utils import check_django_compatability, get_auto_encryption_options, parse_uri check_django_compatability() @@ -15,7 +15,7 @@ from .lookups import register_lookups # noqa: E402 from .query import register_nodes # noqa: E402 -__all__ = ["parse_uri"] +__all__ = ["get_auto_encryption_options", "parse_uri"] register_aggregates() register_checks() diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 92697bde1..487abffec 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -629,3 +629,13 @@ def supports_transactions(self): engine = client.command("serverStatus").get("storageEngine", {}) return engine.get("name") == "wiredTiger" return False + + @cached_property + def supports_queryable_encryption(self): + """ + Queryable Encryption is available if the server is Atlas or Enterprise. + """ + self.connection.ensure_connection() + client = self.connection.connection.admin + build_info = client.command("buildInfo") + return "enterprise" in build_info.get("modules") diff --git a/django_mongodb_backend/utils.py b/django_mongodb_backend/utils.py index ced60bc8e..15ebc1032 100644 --- a/django_mongodb_backend/utils.py +++ b/django_mongodb_backend/utils.py @@ -1,5 +1,8 @@ import copy +import os import time +from pathlib import Path +from urllib.parse import urlencode import django from django.conf import settings @@ -8,6 +11,7 @@ from django.utils.functional import SimpleLazyObject from django.utils.text import format_lazy from django.utils.version import get_version_tuple +from pymongo.encryption_options import AutoEncryptionOpts from pymongo.uri_parser import parse_uri as pymongo_parse_uri @@ -28,6 +32,62 @@ def check_django_compatability(): ) +# Queryable Encryption-related functions based on helpers from Python Queryable +# Encryption Tutorial +# https://github.com/mongodb/docs/tree/master/source/includes/qe-tutorials/python/ +def _get_kms_provider_credentials(kms_provider_name): + """ + "A KMS is a remote service that securely stores and manages your encryption keys." + + Via https://www.mongodb.com/docs/manual/core/queryable-encryption/quick-start/ + + Here we check the provider name and return the appropriate credentials. + """ + # TODO: Add support for other KMS providers. + if kms_provider_name == "local": + if not Path("./customer-master-key.txt").exists: + try: + path = "customer-master-key.txt" + file_bytes = os.urandom(96) + with Path.open(path, "wb") as f: + f.write(file_bytes) + except Exception as e: + raise Exception( + "Unable to write Customer Master Key to file due to the following error: " + ) from e + + try: + path = "./customer-master-key.txt" + with Path.open(path, "rb") as f: + local_master_key = f.read() + if len(local_master_key) != 96: + raise Exception("Expected the customer master key file to be 96 bytes.") + return { + "local": {"key": local_master_key}, + } + except Exception as e: + raise Exception( + "Unable to read Customer Master Key from file due to the following error: " + ) from e + else: + raise ValueError( + "Unrecognized value for kms_provider_name encountered while retrieving KMS credentials." + ) + + +def get_auto_encryption_options(kms_provider_name): + key_vault_database_name = "encryption" + key_vault_collection_name = "__keyVault" + key_vault_namespace = f"{key_vault_database_name}.{key_vault_collection_name}" + kms_provider_credentials = _get_kms_provider_credentials(kms_provider_name) + auto_encryption_opts = AutoEncryptionOpts( + kms_provider_credentials, + key_vault_namespace, + crypt_shared_lib_path=os.environ.get("SHARED_LIB_PATH"), + ) + return urlencode(auto_encryption_opts) + + def parse_uri(uri, *, db_name=None, test=None): """ Convert the given uri into a dictionary suitable for Django's DATABASES diff --git a/tests/backend_/utils/test_parse_uri.py b/tests/backend_/utils/test_parse_uri.py index 3198a4630..c5511baf8 100644 --- a/tests/backend_/utils/test_parse_uri.py +++ b/tests/backend_/utils/test_parse_uri.py @@ -1,10 +1,11 @@ from unittest.mock import patch +from urllib.parse import parse_qs import pymongo from django.core.exceptions import ImproperlyConfigured from django.test import SimpleTestCase -from django_mongodb_backend import parse_uri +from django_mongodb_backend import get_auto_encryption_options, parse_uri class ParseURITests(SimpleTestCase): @@ -94,3 +95,10 @@ def test_invalid_credentials(self): def test_no_scheme(self): with self.assertRaisesMessage(pymongo.errors.InvalidURI, "Invalid URI scheme"): parse_uri("cluster0.example.mongodb.net") + + def test_queryable_encryption_config(self): + auto_encryption_options = get_auto_encryption_options("local") + settings_dict = parse_uri( + f"mongodb://cluster0.example.mongodb.net/myDatabase{auto_encryption_options}" + ) + self.assertEqual(settings_dict["OPTIONS"], parse_qs(auto_encryption_options))