From 4d40913387aa75fbc4ec15e3d4b15f0ae2f86861 Mon Sep 17 00:00:00 2001
From: Christophe Haen <christophe.haen@cern.ch>
Date: Thu, 16 Jan 2025 13:02:40 +0100
Subject: [PATCH 1/2] tests: Set up opensearch DBs in ClientFactory

---
 diracx-testing/src/diracx/testing/utils.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/diracx-testing/src/diracx/testing/utils.py b/diracx-testing/src/diracx/testing/utils.py
index c895f4d4..73bfef24 100644
--- a/diracx-testing/src/diracx/testing/utils.py
+++ b/diracx-testing/src/diracx/testing/utils.py
@@ -252,6 +252,7 @@ def configure(self, enabled_dependencies):
         assert (
             self.app.dependency_overrides == {} and self.app.lifetime_functions == []
         ), "configure cannot be nested"
+
         for k, v in self.all_dependency_overrides.items():
 
             class_name = k.__self__.__name__
@@ -284,17 +285,26 @@ async def create_db_schemas(self):
         import sqlalchemy
         from sqlalchemy.util.concurrency import greenlet_spawn
 
+        from diracx.db.os.utils import BaseOSDB
         from diracx.db.sql.utils import BaseSQLDB
+        from diracx.testing.mock_osdb import MockOSDBMixin
 
         for k, v in self.app.dependency_overrides.items():
-            # Ignore dependency overrides which aren't BaseSQLDB.transaction
-            if (
-                isinstance(v, UnavailableDependency)
-                or k.__func__ != BaseSQLDB.transaction.__func__
+            # Ignore dependency overrides which aren't BaseSQLDB.transaction or BaseOSDB.session
+            if isinstance(v, UnavailableDependency) or k.__func__ not in (
+                BaseSQLDB.transaction.__func__,
+                BaseOSDB.session.__func__,
             ):
+
                 continue
+
             # The first argument of the overridden BaseSQLDB.transaction is the DB object
             db = v.args[0]
+            # We expect the OS DB to be mocked with sqlite, so use the
+            # internal DB
+            if isinstance(db, MockOSDBMixin):
+                db = db._sql_db
+
             assert isinstance(db, BaseSQLDB), (k, db)
 
             # set PRAGMA foreign_keys=ON if sqlite

From 2a46eea73309a66fb3cbf83a7917cc54e1afe846 Mon Sep 17 00:00:00 2001
From: Chris Burr <christopher.burr@cern.ch>
Date: Thu, 16 Jan 2025 14:17:19 +0100
Subject: [PATCH 2/2] Don't enter a transaction in MockOSDBMixin when starting
 a request

In OpenSearch we're currently not using transactions. To accurately mock this with
the MySQL backend we instead enter a transaction in each method call.
---
 diracx-testing/src/diracx/testing/mock_osdb.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py
index 6e181a79..5b482102 100644
--- a/diracx-testing/src/diracx/testing/mock_osdb.py
+++ b/diracx-testing/src/diracx/testing/mock_osdb.py
@@ -72,18 +72,22 @@ async def client_context(self) -> AsyncIterator[None]:
             yield
 
     async def __aenter__(self):
-        await self._sql_db.__aenter__()
+        """Enter the request context.
+
+        This is a no-op as the real OpenSearch class doesn't use transactions.
+        Instead we enter a transaction in each method that needs it.
+        """
         return self
 
     async def __aexit__(self, exc_type, exc_value, traceback):
-        await self._sql_db.__aexit__(exc_type, exc_value, traceback)
+        pass
 
     async def create_index_template(self) -> None:
         async with self._sql_db.engine.begin() as conn:
             await conn.run_sync(self._sql_db.metadata.create_all)
 
     async def upsert(self, doc_id, document) -> None:
-        async with self:
+        async with self._sql_db:
             values = {}
             for key, value in document.items():
                 if key in self.fields:
@@ -106,7 +110,7 @@ async def search(
         per_page: int = 100,
         page: int | None = None,
     ) -> tuple[int, list[dict[Any, Any]]]:
-        async with self:
+        async with self._sql_db:
             # Apply selection
             if parameters:
                 columns = []
@@ -150,7 +154,8 @@ async def search(
         return results
 
     async def ping(self):
-        return await self._sql_db.ping()
+        async with self._sql_db:
+            return await self._sql_db.ping()
 
 
 def fake_available_osdb_implementations(name, *, real_available_implementations):