Skip to content

Add support for more generic CBVs by way of a generic router #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions fastapi_utils/cbv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import functools
import inspect
from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints
from copy import copy
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, get_type_hints

from fastapi import APIRouter, Depends
from pydantic.typing import is_classvar
Expand All @@ -8,6 +11,7 @@
T = TypeVar("T")

CBV_CLASS_KEY = "__cbv_class__"
GENERIC_CBV_ROUTERS_KEY = "__generic_cbv_routers__"


def cbv(router: APIRouter) -> Callable[[Type[T]], Type[T]]:
Expand All @@ -17,24 +21,50 @@ def decorator(cls: Type[T]) -> Type[T]:
return decorator


def generic_cbv(router: APIRouter) -> Callable[[Type[T]], Type[T]]:
def decorator(cls: Type[T]) -> Type[T]:
generic_routers = getattr(cls, GENERIC_CBV_ROUTERS_KEY, None)
if generic_routers is None:
generic_routers = []
setattr(cls, GENERIC_CBV_ROUTERS_KEY, generic_routers)
generic_routers.append(router)
return cls

return decorator


def _cbv(router: APIRouter, cls: Type[T]) -> Type[T]:
_init_cbv(cls)
cbv_router = APIRouter()
functions = inspect.getmembers(cls, inspect.isfunction)
routes_by_endpoint = {
route.endpoint: route for route in router.routes if isinstance(route, (Route, WebSocketRoute))
}
routes_by_endpoint = _routes_by_endpoint(router)
generic_routes_by_endpoint = {}
for generic_router in getattr(cls, GENERIC_CBV_ROUTERS_KEY, []):
generic_routes_by_endpoint.update(_routes_by_endpoint(generic_router))
for _, func in functions:
route = routes_by_endpoint.get(func)
if route is None:
continue
router.routes.remove(route)
route = generic_routes_by_endpoint.get(func)
if route is None:
continue
else:
router.routes.remove(route)
route = copy(route)
route.endpoint = replace_method_with_copy(cls, func)
_update_cbv_route_endpoint_signature(cls, route)
cbv_router.routes.append(route)
router.include_router(cbv_router)
return cls


def _routes_by_endpoint(router: Optional[APIRouter]) -> Dict[Callable[..., Any], Union[Route, WebSocketRoute]]:
return (
{}
if router is None
else {route.endpoint: route for route in router.routes if isinstance(route, (Route, WebSocketRoute))}
)


def _update_cbv_route_endpoint_signature(cls: Type[Any], route: Union[Route, WebSocketRoute]) -> None:
old_endpoint = route.endpoint
old_signature = inspect.signature(old_endpoint)
Expand Down Expand Up @@ -78,3 +108,18 @@ def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
setattr(cls, "__signature__", new_signature)
setattr(cls, "__init__", new_init)
setattr(cls, CBV_CLASS_KEY, True)


def replace_method_with_copy(cls: Type[Any], function: FunctionType) -> FunctionType:
copied = FunctionType(
function.__code__,
function.__globals__,
name=function.__name__,
argdefs=function.__defaults__,
closure=function.__closure__,
)
functools.update_wrapper(copied, function)
copied.__qualname__ = f"{cls.__name__}.{function.__name__}"
copied.__kwdefaults__ = function.__kwdefaults__
setattr(cls, function.__name__, copied)
return copied
71 changes: 71 additions & 0 deletions tests/test_generic_cbv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Optional

from fastapi import APIRouter, Depends, FastAPI
from starlette.testclient import TestClient

from fastapi_utils.cbv import cbv, generic_cbv


def get_a(a: int) -> int:
return a


def get_double_b(b: int) -> int:
return 2 * b


def get_string(c: Optional[str] = None) -> Optional[str]:
return c


router = APIRouter()


@generic_cbv(router)
class BaseGenericCBV:
number: int = Depends(None)

@router.get("/")
async def echo_number(self) -> int:
return self.number


other_router = APIRouter()


@generic_cbv(other_router)
class GenericCBV(BaseGenericCBV):
string: Optional[str] = Depends(None)

@router.get("/string")
async def echo_string(self) -> Optional[str]:
return self.string


router_a = APIRouter()
router_b = APIRouter()


@cbv(router_a)
class CBVA(GenericCBV):
number = Depends(get_a)
string = Depends(get_string)


@cbv(router_b)
class CBVB(GenericCBV):
number = Depends(get_double_b)
string = Depends(get_string)


app = FastAPI()
app.include_router(router_a, prefix="/a")
app.include_router(router_b, prefix="/b")


def test_generic_cbv() -> None:
assert TestClient(app).get("/a/", params={"a": 1}).json() == 1
assert TestClient(app).get("/b/", params={"b": 1}).json() == 2

assert TestClient(app).get("/a/string", params={"a": 1, "c": "hello"}).json() == "hello"
assert TestClient(app).get("/b/string", params={"b": 1, "c": "world"}).json() == "world"