Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): build protocol layer to make pynest framework agnostic #98

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/BlankApp/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import uvicorn
from examples.BlankApp.src.app_module import app

if __name__ == "__main__":
uvicorn.run("src.app_module:http_server", host="0.0.0.0", port=8010, reload=True)
app.adapter.run()
1 change: 0 additions & 1 deletion examples/BlankApp/src/app_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,3 @@ class AppModule:
debug=True,
)

http_server: FastAPI = app.get_server()
19 changes: 14 additions & 5 deletions examples/BlankApp/src/user/user_controller.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
from nest.core import Controller, Depends, Get, Post

from nest.core import Controller, Get, Post
from nest.core.protocols import Param, Query, Header, Body
import uuid
from .user_model import User
from .user_service import UserService


@Controller("user")
@Controller("user", tag="user")
class UserController:
def __init__(self, service: UserService):
self.service = service

@Get("/")
@Get("/", response_model=list[User])
def get_user(self):
return self.service.get_user()

@Get("/{user_id}", response_model=User)
def get_user_by_id(self, user_id: Param[uuid.UUID]):
return self.service.get_user_by_id(user_id)

@Post("/")
def add_user(self, user: User):
def add_user(self, user: Body[User]):
return self.service.add_user(user)

@Get("/test/new-user/{user_id}")
def test_new_user(self, user_id: Param[uuid.UUID]):
return self.service.get_user_by_id(user_id)
8 changes: 5 additions & 3 deletions examples/BlankApp/src/user/user_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pydantic import BaseModel
from dataclasses import dataclass
from uuid import UUID


class User(BaseModel):
@dataclass
class User:
id: UUID
name: str
3 changes: 3 additions & 0 deletions examples/BlankApp/src/user/user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ def get_user(self):
def add_user(self, user: User):
self.database.append(user)
return user

def get_user_by_id(self, user_id: str):
return next((user for user in self.database if user.id == user_id), None)
2 changes: 1 addition & 1 deletion nest/common/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def has(self, token: str):
return True if self.get(token) is not None else False


class Module:
class NestModule:
def __init__(self, metatype: Type[object], container):
self._id = str(uuid.uuid4())
self._metatype = metatype
Expand Down
16 changes: 0 additions & 16 deletions nest/common/route_resolver.py

This file was deleted.

2 changes: 0 additions & 2 deletions nest/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from fastapi import Depends

from nest.core.decorators import (
Controller,
Delete,
Expand Down
Empty file added nest/core/adapters/__init__.py
Empty file.
Empty file.
Empty file.
Empty file.
142 changes: 142 additions & 0 deletions nest/core/adapters/fastapi/fastapi_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from typing import Any, Callable, List, Optional
import uvicorn
from fastapi import FastAPI, APIRouter
from fastapi.middleware import Middleware

from nest.core.protocols import (
WebFrameworkAdapterProtocol,
RouterProtocol, Container,
)

from nest.core.adapters.fastapi.utils import wrap_instance_method


class FastAPIRouterAdapter(RouterProtocol):
"""
An adapter for registering routes in FastAPI.
"""

def __init__(self, base_path: str = "") -> None:
"""
Initialize with an optional base path.
"""
print("Initializing FastAPIRouterAdapter")
self._base_path = base_path
self._router = APIRouter(prefix=self._base_path)

def add_route(
self,
path: str,
endpoint: Callable[..., Any],
methods: List[str],
*,
name: Optional[str] = None,
) -> None:
"""
Register an HTTP route with FastAPI's APIRouter.
"""
self._router.add_api_route(path, endpoint, methods=methods, name=name)


def get_router(self) -> APIRouter:
"""
Return the underlying FastAPI APIRouter.
"""
return self._router


###############################################################################
# FastAPI Adapter
###############################################################################

class FastAPIAdapter(WebFrameworkAdapterProtocol):
"""
A FastAPI-based implementation of WebFrameworkAdapterProtocol.
"""

def __init__(self) -> None:
self._app: Optional[FastAPI] = None
self._router_adapter = FastAPIRouterAdapter()
self._middlewares: List[Middleware] = []
self._initialized = False

def create_app(self, **kwargs: Any) -> FastAPI:
"""
Create and configure the FastAPI application.
"""
print("Creating FastAPI app")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove print

self._app = FastAPI(**kwargs)
self._app.include_router(self._router_adapter.get_router())
# Add any pre-collected middlewares
for mw in self._middlewares:
self._app.add_middleware(mw.cls, **mw.options)

self._initialized = True
return self._app

def get_router(self) -> RouterProtocol:
"""
Return the RouterProtocol implementation.
"""
return self._router_adapter

def add_middleware(
self,
middleware_cls: Any,
**options: Any,
) -> None:
"""
Add middleware to the FastAPI application.
"""
if not self._app:
# Collect middlewares before app creation
self._middlewares.append(Middleware(middleware_cls, **options))
else:
# Add middleware directly if app is already created
self._app.add_middleware(middleware_cls, **options)

def run(self, host: str = "127.0.0.1", port: int = 8000, **kwargs) -> None:
"""
Run the FastAPI application using Uvicorn.
"""
if not self._initialized or not self._app:
raise RuntimeError("FastAPI app not created yet. Call create_app() first.")

uvicorn.run(self._app, host=host, port=port, **kwargs)

async def startup(self) -> None:
"""
Handle any startup tasks if necessary.
"""
if self._app:
await self._app.router.startup()

async def shutdown(self) -> None:
"""
Handle any shutdown tasks if necessary.
"""
if self._app:
await self._app.router.shutdown()

def register_routes(self, container: Container) -> None:
"""
Register multiple routes at once.
"""
for module in container.modules.values():
for controller_cls in module.controllers.values():
instance = container.get_instance(controller_cls)

route_definitions = getattr(controller_cls, "__pynest_routes__", [])
for route_definition in route_definitions:
path = route_definition["path"]
method = route_definition["method"]
original_method = route_definition["endpoint"]

final_endpoint = wrap_instance_method(instance, controller_cls, original_method)

self._router_adapter.add_route(
path=path,
endpoint=final_endpoint,
methods=[method],
name=f"{controller_cls.__name__}.{original_method.__name__}",
)
87 changes: 87 additions & 0 deletions nest/core/adapters/fastapi/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Annotated, Callable

from fastapi import Path, Query, Header, Body, File, UploadFile, Response, Request, BackgroundTasks, Form
from nest.core.protocols import Param, Query as QueryParam, Header as HeaderParam, Body as BodyParam, \
Cookie as CookieParam, File as FileParam, Form as FormParam
import functools
import inspect
import typing


def wrap_instance_method(
instance,
cls,
method: Callable,
) -> Callable:
"""
1. Create a new plain function that calls `method(instance, ...)`.
2. Rewrite its signature so that 'self' is removed, and Param/Query/Body become Annotated[...] for FastAPI.
3. Return that new function, which you can pass to fastapi's router.

This avoids "invalid method signature" by not rewriting the bound method in place.
"""

# The unbound function object:
if hasattr(method, "__func__"):
# If 'method' is a bound method, get the actual function
unbound_func = method.__func__
else:
# If it's already an unbound function, use it
unbound_func = method

# Create a wrapper function that calls the unbound function with 'instance' as the first arg
@functools.wraps(unbound_func)
def wrapper(*args, **kwargs):
return unbound_func(instance, *args, **kwargs)

# Now rewrite the wrapper's signature:
# - removing 'self'
# - converting Param/Query/Body to Annotated
new_wrapper = rewrite_signature_for_fastapi(wrapper)
return new_wrapper


def rewrite_signature_for_fastapi(func: Callable) -> Callable:
"""
A function that modifies the signature to remove "self"
and convert Param/Query/Header/Body to FastAPI’s annotated params.
"""
sig = inspect.signature(func)
new_params = []

old_parameters = list(sig.parameters.values())

# 1) If the first param is named 'self', skip it entirely from the new signature
# (because we have a BOUND method).
if old_parameters and old_parameters[0].name == "self":
old_parameters = old_parameters[1:]

for param in old_parameters:
annotation = param.annotation

if typing.get_origin(annotation) == Param:
inner_type = typing.get_args(annotation)[0]
new_annotation = Annotated[inner_type, Path()]
new_params.append(param.replace(annotation=new_annotation))

elif typing.get_origin(annotation) == QueryParam:
inner_type = typing.get_args(annotation)[0]
new_annotation = Annotated[inner_type, Query()]
new_params.append(param.replace(annotation=new_annotation))

elif typing.get_origin(annotation) == HeaderParam:
inner_type = typing.get_args(annotation)[0]
new_annotation = Annotated[inner_type, Header()]
new_params.append(param.replace(annotation=new_annotation))

elif typing.get_origin(annotation) == BodyParam:
inner_type = typing.get_args(annotation)[0]
new_annotation = Annotated[inner_type, Body()]
new_params.append(param.replace(annotation=new_annotation))
else:
# unchanged param
new_params.append(param)

new_sig = sig.replace(parameters=new_params)
func.__signature__ = new_sig
return func
Loading
Loading