Skip to content

Commit

Permalink
feat: support pydantic2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
long2ice committed Jul 18, 2023
1 parent a819a79 commit f9bce0b
Show file tree
Hide file tree
Showing 9 changed files with 929 additions and 996 deletions.
17 changes: 7 additions & 10 deletions examples/fastapi/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# pylint: disable=E0611,E0401
from typing import List

from fastapi import FastAPI, HTTPException
from fastapi import FastAPI
from models import User_Pydantic, UserIn_Pydantic, Users
from pydantic import BaseModel
from starlette.exceptions import HTTPException

from tortoise.contrib.fastapi import HTTPNotFoundError, register_tortoise
from tortoise.contrib.fastapi import register_tortoise

app = FastAPI(title="Tortoise ORM FastAPI example")

Expand All @@ -19,28 +20,24 @@ async def get_users():
return await User_Pydantic.from_queryset(Users.all())


@app.post("/users", response_model=User_Pydantic)
@app.post("/users", response_model=User_Pydantic) # type: ignore
async def create_user(user: UserIn_Pydantic):
user_obj = await Users.create(**user.dict(exclude_unset=True))
return await User_Pydantic.from_tortoise_orm(user_obj)


@app.get(
"/user/{user_id}", response_model=User_Pydantic, responses={404: {"model": HTTPNotFoundError}}
)
@app.get("/user/{user_id}", response_model=User_Pydantic) # type: ignore
async def get_user(user_id: int):
return await User_Pydantic.from_queryset_single(Users.get(id=user_id))


@app.put(
"/user/{user_id}", response_model=User_Pydantic, responses={404: {"model": HTTPNotFoundError}}
)
@app.put("/user/{user_id}", response_model=User_Pydantic) # type: ignore
async def update_user(user_id: int, user: UserIn_Pydantic):
await Users.filter(id=user_id).update(**user.dict(exclude_unset=True))
return await User_Pydantic.from_queryset_single(Users.get(id=user_id))


@app.delete("/user/{user_id}", response_model=Status, responses={404: {"model": HTTPNotFoundError}})
@app.delete("/user/{user_id}", response_model=Status) # type: ignore
async def delete_user(user_id: int):
deleted_count = await Users.filter(id=user_id).delete()
if not deleted_count:
Expand Down
1,326 changes: 640 additions & 686 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ sanic = "*"
# Sample integration - Starlette
starlette = "*"
# Pydantic support
pydantic = "*"
pydantic = "^2.0"
# FastAPI support
fastapi = "*"
asgi_lifespan = "*"
Expand Down
399 changes: 202 additions & 197 deletions tests/contrib/test_pydantic.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions tests/test_early_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_early_init(self):
self.maxDiff = None
Event_TooEarly = pydantic_model_creator(Event)
self.assertEqual(
Event_TooEarly.schema(),
Event_TooEarly.model_json_schema(),
{
"title": "Event",
"type": "object",
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_early_init(self):

Event_Pydantic = pydantic_model_creator(Event)
self.assertEqual(
Event_Pydantic.schema(),
Event_Pydantic.model_json_schema(),
{
"title": "Event",
"type": "object",
Expand Down
27 changes: 12 additions & 15 deletions tests/testmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import List, Union

import pytz
from pydantic import Extra as PydanticExtra
from pydantic import ConfigDict

from tortoise import fields
from tortoise.exceptions import NoValuesFetched, ValidationError
Expand Down Expand Up @@ -854,6 +854,11 @@ class Pair(Model):
)


def camelize_var(var_name: str):
var_parts: List[str] = var_name.split("_")
return var_parts[0] + "".join([part.title() for part in var_parts[1:]])


class CamelCaseAliasPerson(Model):
"""CamelCaseAliasPerson model.
Expand All @@ -869,17 +874,9 @@ class CamelCaseAliasPerson(Model):
class PydanticMeta:
"""Defines the default config for pydantic model generator."""

class PydanticConfig:
"""Defines the default pydantic config for the model."""

@staticmethod
def camelize_var(var_name: str):
var_parts: List[str] = var_name.split("_")
return var_parts[0] + "".join([part.title() for part in var_parts[1:]])

title = "My custom title"
extra = PydanticExtra.ignore
alias_generator = camelize_var
allow_population_by_field_name = True

config_class = PydanticConfig
model_config = ConfigDict(
title="My custom title",
extra="ignore",
alias_generator=camelize_var,
populate_by_name=True,
)
13 changes: 7 additions & 6 deletions tortoise/contrib/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from types import ModuleType
from typing import Dict, Iterable, Optional, Union

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi import FastAPI
from pydantic import BaseModel # pylint: disable=E0611
from starlette.requests import Request
from starlette.responses import JSONResponse

from tortoise import Tortoise, connections
from tortoise.exceptions import DoesNotExist, IntegrityError
Expand Down Expand Up @@ -88,26 +89,26 @@ def register_tortoise(
For any configuration error
"""

@app.on_event("startup")
@app.on_event("startup") # type: ignore
async def init_orm() -> None: # pylint: disable=W0612
await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules)
logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps)
if generate_schemas:
logger.info("Tortoise-ORM generating schema")
await Tortoise.generate_schemas()

@app.on_event("shutdown")
@app.on_event("shutdown") # type: ignore
async def close_orm() -> None: # pylint: disable=W0612
await connections.close_all()
logger.info("Tortoise-ORM shutdown")

if add_exception_handlers:

@app.exception_handler(DoesNotExist)
@app.exception_handler(DoesNotExist) # type: ignore
async def doesnotexist_exception_handler(request: Request, exc: DoesNotExist):
return JSONResponse(status_code=404, content={"detail": str(exc)})

@app.exception_handler(IntegrityError)
@app.exception_handler(IntegrityError) # type: ignore
async def integrityerror_exception_handler(request: Request, exc: IntegrityError):
return JSONResponse(
status_code=422,
Expand Down
31 changes: 15 additions & 16 deletions tortoise/contrib/pydantic/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING, List, Type, Union

import pydantic
from pydantic import BaseConfig, BaseModel # pylint: disable=E0611
from pydantic import BaseModel, ConfigDict, RootModel

from tortoise import fields

Expand All @@ -28,7 +28,7 @@ def _get_fetch_fields(
# noinspection PyProtectedMember
if field_name in model_class._meta.fetch_fields and issubclass(field_type, PydanticModel):
subclass_fetch_fields = _get_fetch_fields(
field_type, getattr(field_type.__config__, "orig_model")
field_type, field_type.model_config["orig_model"]
)
if subclass_fetch_fields:
fetch_fields.extend([field_name + "__" + f for f in subclass_fetch_fields])
Expand All @@ -45,11 +45,10 @@ class PydanticModel(BaseModel):
`model properties <https://pydantic-docs.helpmanual.io/usage/models/#model-properties>`__
"""

class Config(BaseConfig):
orm_mode = True # It should be in ORM mode to convert tortoise data to pydantic
model_config = ConfigDict(from_attributes=True)

# noinspection PyMethodParameters
@pydantic.validator("*", pre=True, each_item=False) # It is a classmethod!
@pydantic.field_validator("*") # It is a classmethod!
def _tortoise_convert(cls, value): # pylint: disable=E0213
# Computed fields
if callable(value):
Expand Down Expand Up @@ -81,10 +80,10 @@ async def from_tortoise_orm(cls, obj: "Model") -> "PydanticModel":
:param obj: The Model instance you want serialized.
"""
# Get fields needed to fetch
fetch_fields = _get_fetch_fields(cls, getattr(cls.__config__, "orig_model"))
fetch_fields = _get_fetch_fields(cls, cls.model_config["orig_model"]) # type: ignore
# Fetch fields
await obj.fetch_related(*fetch_fields)
return super().from_orm(obj)
return cls.model_validate(obj)

@classmethod
async def from_queryset_single(cls, queryset: "QuerySetSingle") -> "PydanticModel":
Expand All @@ -96,8 +95,8 @@ async def from_queryset_single(cls, queryset: "QuerySetSingle") -> "PydanticMode
:param queryset: a queryset on the model this PydanticModel is based on.
"""
fetch_fields = _get_fetch_fields(cls, getattr(cls.__config__, "orig_model"))
return cls.from_orm(await queryset.prefetch_related(*fetch_fields))
fetch_fields = _get_fetch_fields(cls, cls.model_config["orig_model"]) # type: ignore
return cls.model_validate(await queryset.prefetch_related(*fetch_fields))

@classmethod
async def from_queryset(cls, queryset: "QuerySet") -> "List[PydanticModel]":
Expand All @@ -109,11 +108,11 @@ async def from_queryset(cls, queryset: "QuerySet") -> "List[PydanticModel]":
:param queryset: a queryset on the model this PydanticModel is based on.
"""
fetch_fields = _get_fetch_fields(cls, getattr(cls.__config__, "orig_model"))
return [cls.from_orm(e) for e in await queryset.prefetch_related(*fetch_fields)]
fetch_fields = _get_fetch_fields(cls, cls.model_config["orig_model"]) # type: ignore
return [cls.model_validate(e) for e in await queryset.prefetch_related(*fetch_fields)]


class PydanticListModel(BaseModel):
class PydanticListModel(RootModel):
"""
Pydantic BaseModel for List of Tortoise Models
Expand All @@ -131,8 +130,8 @@ async def from_queryset(cls, queryset: "QuerySet") -> "PydanticListModel":
:param queryset: a queryset on the model this PydanticListModel is based on.
"""
submodel = getattr(cls.__config__, "submodel")
fetch_fields = _get_fetch_fields(submodel, getattr(submodel.__config__, "orig_model"))
return cls(
__root__=[submodel.from_orm(e) for e in await queryset.prefetch_related(*fetch_fields)]
submodel = cls.model_config["submodel"] # type: ignore
fetch_fields = _get_fetch_fields(submodel, submodel.model_config["orig_model"])
return cls.model_validate(
[submodel.model_validate(e) for e in await queryset.prefetch_related(*fetch_fields)]
)
Loading

0 comments on commit f9bce0b

Please sign in to comment.