Skip to content

Commit

Permalink
Merge pull request #375 from chaen/fix_mock_os
Browse files Browse the repository at this point in the history
tests: fix the mock_osdb
  • Loading branch information
chrisburr authored Jan 16, 2025
2 parents 2bf4e28 + 2a46eea commit 4ddba9d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
15 changes: 10 additions & 5 deletions diracx-testing/src/diracx/testing/mock_osdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,22 @@ async def client_context(self) -> AsyncIterator[None]:
yield

async def __aenter__(self):
await self._sql_db.__aenter__()
"""Enter the request context.
This is a no-op as the real OpenSearch class doesn't use transactions.
Instead we enter a transaction in each method that needs it.
"""
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self._sql_db.__aexit__(exc_type, exc_value, traceback)
pass

async def create_index_template(self) -> None:
async with self._sql_db.engine.begin() as conn:
await conn.run_sync(self._sql_db.metadata.create_all)

async def upsert(self, doc_id, document) -> None:
async with self:
async with self._sql_db:
values = {}
for key, value in document.items():
if key in self.fields:
Expand All @@ -106,7 +110,7 @@ async def search(
per_page: int = 100,
page: int | None = None,
) -> tuple[int, list[dict[Any, Any]]]:
async with self:
async with self._sql_db:
# Apply selection
if parameters:
columns = []
Expand Down Expand Up @@ -150,7 +154,8 @@ async def search(
return results

async def ping(self):
return await self._sql_db.ping()
async with self._sql_db:
return await self._sql_db.ping()


def fake_available_osdb_implementations(name, *, real_available_implementations):
Expand Down
18 changes: 14 additions & 4 deletions diracx-testing/src/diracx/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def configure(self, enabled_dependencies):
assert (
self.app.dependency_overrides == {} and self.app.lifetime_functions == []
), "configure cannot be nested"

for k, v in self.all_dependency_overrides.items():

class_name = k.__self__.__name__
Expand Down Expand Up @@ -284,17 +285,26 @@ async def create_db_schemas(self):
import sqlalchemy
from sqlalchemy.util.concurrency import greenlet_spawn

from diracx.db.os.utils import BaseOSDB
from diracx.db.sql.utils import BaseSQLDB
from diracx.testing.mock_osdb import MockOSDBMixin

for k, v in self.app.dependency_overrides.items():
# Ignore dependency overrides which aren't BaseSQLDB.transaction
if (
isinstance(v, UnavailableDependency)
or k.__func__ != BaseSQLDB.transaction.__func__
# Ignore dependency overrides which aren't BaseSQLDB.transaction or BaseOSDB.session
if isinstance(v, UnavailableDependency) or k.__func__ not in (
BaseSQLDB.transaction.__func__,
BaseOSDB.session.__func__,
):

continue

# The first argument of the overridden BaseSQLDB.transaction is the DB object
db = v.args[0]
# We expect the OS DB to be mocked with sqlite, so use the
# internal DB
if isinstance(db, MockOSDBMixin):
db = db._sql_db

assert isinstance(db, BaseSQLDB), (k, db)

# set PRAGMA foreign_keys=ON if sqlite
Expand Down

0 comments on commit 4ddba9d

Please sign in to comment.