Skip to content

Commit

Permalink
Add pytest_django.DjangoCaptureOnCommitCallbacks for typing purposes
Browse files Browse the repository at this point in the history
This allows typing the `django_capture_on_commit_callbacks` fixture.
  • Loading branch information
bluetech committed Nov 8, 2023
1 parent 017bd77 commit 28484f4
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 6 deletions.
8 changes: 8 additions & 0 deletions docs/helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,14 @@ Example usage::
assert mailoutbox[0].subject == 'Contact Form'
assert mailoutbox[0].body == 'I like your site'

If you use type annotations, you can annotate the fixture like this::

from pytest_django import DjangoCaptureOnCommitCallbacks

def test_on_commit(
django_capture_on_commit_callbacks: DjangoCaptureOnCommitCallbacks,
):
...

.. fixture:: mailoutbox

Expand Down
2 changes: 2 additions & 0 deletions pytest_django/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
__version__ = "unknown"


from .fixtures import DjangoCaptureOnCommitCallbacks
from .plugin import DjangoDbBlocker


__all__ = [
"__version__",
"DjangoCaptureOnCommitCallbacks",
"DjangoDbBlocker",
]
19 changes: 17 additions & 2 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
ContextManager,
Generator,
Iterable,
List,
Literal,
Optional,
Protocol,
Tuple,
Union,
)
Expand Down Expand Up @@ -647,8 +650,20 @@ def django_assert_max_num_queries(pytestconfig: pytest.Config):
return partial(_assert_num_queries, pytestconfig, exact=False)


class DjangoCaptureOnCommitCallbacks(Protocol):
"""The type of the `django_capture_on_commit_callbacks` fixture."""

def __call__(
self,
*,
using: str = ...,
execute: bool = ...,
) -> ContextManager[list[Callable[[], Any]]]:
pass # pragma: no cover


@pytest.fixture()
def django_capture_on_commit_callbacks():
def django_capture_on_commit_callbacks() -> DjangoCaptureOnCommitCallbacks:
from django.test import TestCase

return TestCase.captureOnCommitCallbacks
return TestCase.captureOnCommitCallbacks # type: ignore[no-any-return]
12 changes: 8 additions & 4 deletions tests/test_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from .helpers import DjangoPytester

from pytest_django import DjangoDbBlocker
from pytest_django import DjangoCaptureOnCommitCallbacks, DjangoDbBlocker
from pytest_django_test.app.models import Item


Expand Down Expand Up @@ -232,7 +232,9 @@ def test_queries(django_assert_num_queries):


@pytest.mark.django_db
def test_django_capture_on_commit_callbacks(django_capture_on_commit_callbacks) -> None:
def test_django_capture_on_commit_callbacks(
django_capture_on_commit_callbacks: DjangoCaptureOnCommitCallbacks,
) -> None:
if not connection.features.supports_transactions:
pytest.skip("transactions required for this test")

Expand All @@ -255,7 +257,9 @@ def test_django_capture_on_commit_callbacks(django_capture_on_commit_callbacks)


@pytest.mark.django_db(databases=["default", "second"])
def test_django_capture_on_commit_callbacks_multidb(django_capture_on_commit_callbacks) -> None:
def test_django_capture_on_commit_callbacks_multidb(
django_capture_on_commit_callbacks: DjangoCaptureOnCommitCallbacks,
) -> None:
if not connection.features.supports_transactions:
pytest.skip("transactions required for this test")

Expand All @@ -282,7 +286,7 @@ def test_django_capture_on_commit_callbacks_multidb(django_capture_on_commit_cal

@pytest.mark.django_db(transaction=True)
def test_django_capture_on_commit_callbacks_transactional(
django_capture_on_commit_callbacks,
django_capture_on_commit_callbacks: DjangoCaptureOnCommitCallbacks,
) -> None:
if not connection.features.supports_transactions:
pytest.skip("transactions required for this test")
Expand Down

0 comments on commit 28484f4

Please sign in to comment.