Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support await tortoise.contrib.fastapi.RegisterTortoise #1662

Merged
merged 2 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ Added
^^^^^
- Add ObjectDoesNotExistError to show better 404 message. (#759)
- DoesNotExist and MultipleObjectsReturned support 'Type[Model]' argument. (#742)(#1650)
- Add argument use_tz and timezone to RegisterTortoise. (#1649)
- Support await `tortoise.contrib.fastapi.RegisterTortoise`. (#1662)

Fixed
^^^^^
- Fix `update_or_create` errors when field value changed. (#1584)
- Fix bandit check error (#1643)
- Fix potential race condition in ConnectionWrapper (#1656)
- Fix py312 warning for datetime.utcnow (#1661)

Changed
^^^^^^^
Expand Down
34 changes: 34 additions & 0 deletions tests/contrib/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from unittest.mock import AsyncMock, patch

from fastapi import FastAPI

from tortoise.contrib import test
from tortoise.contrib.fastapi import RegisterTortoise


class TestRegisterTortoise(test.TestCase):
@test.requireCapability(dialect="sqlite") # type:ignore[misc]
@patch("tortoise.Tortoise.init")
@patch("tortoise.connections.close_all")
async def test_await(
self,
mocked_close: AsyncMock,
mocked_init: AsyncMock,
) -> None:
app = FastAPI()
orm = await RegisterTortoise(
app,
db_url="sqlite://:memory:",
modules={"models": ["__main__"]},
)
mocked_init.assert_awaited_once()
mocked_init.assert_called_once_with(
config=None,
config_file=None,
db_url="sqlite://:memory:",
modules={"models": ["__main__"]},
use_tz=False,
timezone="UTC",
)
await orm.close_orm()
mocked_close.assert_awaited_once()
53 changes: 29 additions & 24 deletions tortoise/contrib/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import sys
import warnings
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from types import ModuleType
from typing import TYPE_CHECKING, Dict, Iterable, Optional, Union
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Union

from fastapi.responses import JSONResponse
from pydantic import BaseModel # pylint: disable=E0611
Expand All @@ -16,6 +17,11 @@
if TYPE_CHECKING:
from fastapi import FastAPI, Request

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class HTTPNotFoundError(BaseModel):
detail: str
Expand Down Expand Up @@ -80,6 +86,10 @@ class RegisterTortoise(AbstractAsyncContextManager):
add_exception_handlers:
True to add some automatic exception handlers for ``DoesNotExist`` & ``IntegrityError``.
This is not recommended for production systems as it may leak data.
use_tz:
A boolean that specifies if datetime will be timezone-aware by default or not.
timezone:
Timezone to use, default is UTC.

Raises
------
Expand Down Expand Up @@ -111,38 +121,26 @@ def __init__(
if add_exception_handlers:

@app.exception_handler(DoesNotExist)
async def doesnotexist_exception_handler(
request: "Request", exc: DoesNotExist
):
async def doesnotexist_exception_handler(request: "Request", exc: DoesNotExist):
return JSONResponse(status_code=404, content={"detail": str(exc)})

@app.exception_handler(IntegrityError)
async def integrityerror_exception_handler(
request: "Request", exc: IntegrityError
):
async def integrityerror_exception_handler(request: "Request", exc: IntegrityError):
return JSONResponse(
status_code=422,
content={
"detail": [
{"loc": [], "msg": str(exc), "type": "IntegrityError"}
]
},
content={"detail": [{"loc": [], "msg": str(exc), "type": "IntegrityError"}]},
)

async def init_orm(self) -> None: # pylint: disable=W0612
config, config_file = self.config, self.config_file
db_url, modules = self.db_url, self.modules
await Tortoise.init(
config=config,
config_file=config_file,
db_url=db_url,
modules=modules,
config=self.config,
config_file=self.config_file,
db_url=self.db_url,
modules=self.modules,
use_tz=self.use_tz,
timezone=self.timezone,
)
logger.info(
"Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps
)
logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps)
if self.generate_schemas:
logger.info("Tortoise-ORM generating schema")
await Tortoise.generate_schemas()
Expand All @@ -152,16 +150,23 @@ async def close_orm() -> None: # pylint: disable=W0612
await connections.close_all()
logger.info("Tortoise-ORM shutdown")

def __call__(self, *args, **kwargs) -> "RegisterTortoise":
def __call__(self, *args, **kwargs) -> Self:
return self

async def __aenter__(self) -> "RegisterTortoise":
async def __aenter__(self) -> Self:
await self.init_orm()
return self

async def __aexit__(self, *args, **kw):
async def __aexit__(self, *args, **kw) -> None:
await self.close_orm()

def __await__(self) -> Generator[None, None, Self]:
async def _self() -> Self:
await self.init_orm()
return self

return _self().__await__()


def register_tortoise(
app: "FastAPI",
Expand Down