Skip to content

Commit

Permalink
Support await tortoise.contrib.fastapi.RegisterTortoise (#1662)
Browse files Browse the repository at this point in the history
* Support await RegisterTortoise

* Update changelog
  • Loading branch information
waketzheng authored Jun 24, 2024
1 parent bc18384 commit b87f485
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 24 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ Added
^^^^^
- Add ObjectDoesNotExistError to show better 404 message. (#759)
- DoesNotExist and MultipleObjectsReturned support 'Type[Model]' argument. (#742)(#1650)
- Add argument use_tz and timezone to RegisterTortoise. (#1649)
- Support await `tortoise.contrib.fastapi.RegisterTortoise`. (#1662)

Fixed
^^^^^
- Fix `update_or_create` errors when field value changed. (#1584)
- Fix bandit check error (#1643)
- Fix potential race condition in ConnectionWrapper (#1656)
- Fix py312 warning for datetime.utcnow (#1661)

Changed
^^^^^^^
Expand Down
34 changes: 34 additions & 0 deletions tests/contrib/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from unittest.mock import AsyncMock, patch

from fastapi import FastAPI

from tortoise.contrib import test
from tortoise.contrib.fastapi import RegisterTortoise


class TestRegisterTortoise(test.TestCase):
@test.requireCapability(dialect="sqlite") # type:ignore[misc]
@patch("tortoise.Tortoise.init")
@patch("tortoise.connections.close_all")
async def test_await(
self,
mocked_close: AsyncMock,
mocked_init: AsyncMock,
) -> None:
app = FastAPI()
orm = await RegisterTortoise(
app,
db_url="sqlite://:memory:",
modules={"models": ["__main__"]},
)
mocked_init.assert_awaited_once()
mocked_init.assert_called_once_with(
config=None,
config_file=None,
db_url="sqlite://:memory:",
modules={"models": ["__main__"]},
use_tz=False,
timezone="UTC",
)
await orm.close_orm()
mocked_close.assert_awaited_once()
53 changes: 29 additions & 24 deletions tortoise/contrib/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import sys
import warnings
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from types import ModuleType
from typing import TYPE_CHECKING, Dict, Iterable, Optional, Union
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Union

from fastapi.responses import JSONResponse
from pydantic import BaseModel # pylint: disable=E0611
Expand All @@ -16,6 +17,11 @@
if TYPE_CHECKING:
from fastapi import FastAPI, Request

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class HTTPNotFoundError(BaseModel):
detail: str
Expand Down Expand Up @@ -80,6 +86,10 @@ class RegisterTortoise(AbstractAsyncContextManager):
add_exception_handlers:
True to add some automatic exception handlers for ``DoesNotExist`` & ``IntegrityError``.
This is not recommended for production systems as it may leak data.
use_tz:
A boolean that specifies if datetime will be timezone-aware by default or not.
timezone:
Timezone to use, default is UTC.
Raises
------
Expand Down Expand Up @@ -111,38 +121,26 @@ def __init__(
if add_exception_handlers:

@app.exception_handler(DoesNotExist)
async def doesnotexist_exception_handler(
request: "Request", exc: DoesNotExist
):
async def doesnotexist_exception_handler(request: "Request", exc: DoesNotExist):
return JSONResponse(status_code=404, content={"detail": str(exc)})

@app.exception_handler(IntegrityError)
async def integrityerror_exception_handler(
request: "Request", exc: IntegrityError
):
async def integrityerror_exception_handler(request: "Request", exc: IntegrityError):
return JSONResponse(
status_code=422,
content={
"detail": [
{"loc": [], "msg": str(exc), "type": "IntegrityError"}
]
},
content={"detail": [{"loc": [], "msg": str(exc), "type": "IntegrityError"}]},
)

async def init_orm(self) -> None: # pylint: disable=W0612
config, config_file = self.config, self.config_file
db_url, modules = self.db_url, self.modules
await Tortoise.init(
config=config,
config_file=config_file,
db_url=db_url,
modules=modules,
config=self.config,
config_file=self.config_file,
db_url=self.db_url,
modules=self.modules,
use_tz=self.use_tz,
timezone=self.timezone,
)
logger.info(
"Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps
)
logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps)
if self.generate_schemas:
logger.info("Tortoise-ORM generating schema")
await Tortoise.generate_schemas()
Expand All @@ -152,16 +150,23 @@ async def close_orm() -> None: # pylint: disable=W0612
await connections.close_all()
logger.info("Tortoise-ORM shutdown")

def __call__(self, *args, **kwargs) -> "RegisterTortoise":
def __call__(self, *args, **kwargs) -> Self:
return self

async def __aenter__(self) -> "RegisterTortoise":
async def __aenter__(self) -> Self:
await self.init_orm()
return self

async def __aexit__(self, *args, **kw):
async def __aexit__(self, *args, **kw) -> None:
await self.close_orm()

def __await__(self) -> Generator[None, None, Self]:
async def _self() -> Self:
await self.init_orm()
return self

return _self().__await__()


def register_tortoise(
app: "FastAPI",
Expand Down

0 comments on commit b87f485

Please sign in to comment.