Skip to content

Commit

Permalink
Export pytest_django.DjangoDbBlocker for typing purposes
Browse files Browse the repository at this point in the history
For users who want to type `django_db_blocker` in their tests.
  • Loading branch information
bluetech committed Nov 7, 2023
1 parent 53eead4 commit d93631f
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 17 deletions.
18 changes: 10 additions & 8 deletions docs/database.rst
Original file line number Diff line number Diff line change
Expand Up @@ -295,19 +295,21 @@ access for the specified block::

You can also manage the access manually via these methods:

.. py:method:: django_db_blocker.unblock()
.. py:class:: pytest_django.DjangoDbBlocker
Enable database access. Should be followed by a call to
:func:`~django_db_blocker.restore`.
.. py:method:: django_db_blocker.unblock()
.. py:method:: django_db_blocker.block()
Enable database access. Should be followed by a call to
:func:`~django_db_blocker.restore` or used as a context manager.

Disable database access. Should be followed by a call to
:func:`~django_db_blocker.restore`.
.. py:method:: django_db_blocker.block()
.. py:method:: django_db_blocker.restore()
Disable database access. Should be followed by a call to
:func:`~django_db_blocker.restore` or used as a context manager.

Restore the previous state of the database blocking.
.. py:method:: django_db_blocker.restore()
Restore the previous state of the database blocking.

Examples
########
Expand Down
4 changes: 4 additions & 0 deletions pytest_django/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
__version__ = "unknown"


from .plugin import DjangoDbBlocker


__all__ = [
"__version__",
"DjangoDbBlocker",
]
6 changes: 4 additions & 2 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import django
import django.test

from . import DjangoDbBlocker


_DjangoDbDatabases = Optional[Union[Literal["__all__"], Iterable[str]]]
_DjangoDbAvailableApps = Optional[List[str]]
Expand Down Expand Up @@ -114,7 +116,7 @@ def django_db_createdb(request: pytest.FixtureRequest) -> bool:
def django_db_setup(
request: pytest.FixtureRequest,
django_test_environment: None,
django_db_blocker,
django_db_blocker: DjangoDbBlocker,
django_db_use_migrations: bool,
django_db_keepdb: bool,
django_db_createdb: bool,
Expand Down Expand Up @@ -154,7 +156,7 @@ def django_db_setup(
def _django_db_helper(
request: pytest.FixtureRequest,
django_db_setup: None,
django_db_blocker,
django_db_blocker: DjangoDbBlocker,
) -> Generator[None, None, None]:
from django import VERSION

Expand Down
26 changes: 19 additions & 7 deletions pytest_django/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import os
import pathlib
import sys
import types
from functools import reduce
from typing import TYPE_CHECKING, ContextManager, Generator, List, NoReturn

Expand Down Expand Up @@ -495,7 +496,7 @@ def django_test_environment(request: pytest.FixtureRequest) -> Generator[None, N


@pytest.fixture(scope="session")
def django_db_blocker() -> _DatabaseBlocker | None:
def django_db_blocker() -> DjangoDbBlocker | None:
"""Wrapper around Django's database access.
This object can be used to re-enable database access. This fixture is used
Expand Down Expand Up @@ -525,7 +526,7 @@ def _django_db_marker(request: pytest.FixtureRequest) -> None:
@pytest.fixture(autouse=True, scope="class")
def _django_setup_unittest(
request: pytest.FixtureRequest,
django_db_blocker: _DatabaseBlocker,
django_db_blocker: DjangoDbBlocker,
) -> Generator[None, None, None]:
"""Setup a django unittest, internal to pytest-django."""
if not django_settings_is_configured() or not is_django_unittest(request):
Expand Down Expand Up @@ -743,23 +744,34 @@ def _django_clear_site_cache() -> None:


class _DatabaseBlockerContextManager:
def __init__(self, db_blocker) -> None:
def __init__(self, db_blocker: DjangoDbBlocker) -> None:
self._db_blocker = db_blocker

def __enter__(self) -> None:
pass

def __exit__(self, exc_type, exc_value, traceback) -> None:
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
self._db_blocker.restore()


class _DatabaseBlocker:
class DjangoDbBlocker:
"""Manager for django.db.backends.base.base.BaseDatabaseWrapper.
This is the object returned by django_db_blocker.
"""

def __init__(self) -> None:
def __init__(self, *, _ispytest: bool = False) -> None:
if not _ispytest: # pragma: no cover
raise TypeError(
"The DjangoDbBlocker constructor is private. "
"use the django_db_blocker fixture instead."
)

self._history = [] # type: ignore[var-annotated]
self._real_ensure_connection = None

Expand Down Expand Up @@ -801,7 +813,7 @@ def restore(self) -> None:
self._dj_db_wrapper.ensure_connection = self._history.pop()


_blocking_manager = _DatabaseBlocker()
_blocking_manager = DjangoDbBlocker(_ispytest=True)


def validate_urls(marker) -> list[str]:
Expand Down

0 comments on commit d93631f

Please sign in to comment.