diff --git a/README.md b/README.md index a3c366f..d2a9fc8 100644 --- a/README.md +++ b/README.md @@ -357,7 +357,15 @@ async def my_handler(var: str = Depends(Path())): ``` -## Overridiing dependencies +## ExtraOpenAPI + +This dependency is used to add additional swagger fields to the endpoint's swagger +that is using this dependency. It might be even indirect dependency. + +You can check how this thing can be used in our [examples/swagger_auth.py](https://github.com/taskiq-python/aiohttp-deps/tree/master/examples/swagger_auth.py). + + +## Overriding dependencies Sometimes for tests you don't want to calculate actual functions and you want to pass another functions instead. diff --git a/aiohttp_deps/__init__.py b/aiohttp_deps/__init__.py index ae8625c..fa1b831 100644 --- a/aiohttp_deps/__init__.py +++ b/aiohttp_deps/__init__.py @@ -6,7 +6,7 @@ from aiohttp_deps.keys import DEPENDENCY_OVERRIDES_KEY, VALUES_OVERRIDES_KEY from aiohttp_deps.router import Router from aiohttp_deps.swagger import extra_openapi, openapi_response, setup_swagger -from aiohttp_deps.utils import Form, Header, Json, Path, Query +from aiohttp_deps.utils import ExtraOpenAPI, Form, Header, Json, Path, Query from aiohttp_deps.view import View __all__ = [ @@ -21,6 +21,7 @@ "Query", "Form", "Path", + "ExtraOpenAPI", "openapi_response", "DEPENDENCY_OVERRIDES_KEY", "VALUES_OVERRIDES_KEY", diff --git a/aiohttp_deps/swagger.py b/aiohttp_deps/swagger.py index 23f3eed..351fdf6 100644 --- a/aiohttp_deps/swagger.py +++ b/aiohttp_deps/swagger.py @@ -6,6 +6,7 @@ Awaitable, Callable, Dict, + List, Optional, Tuple, TypeVar, @@ -19,7 +20,7 @@ from aiohttp_deps.initializer import InjectableFuncHandler, InjectableViewHandler from aiohttp_deps.keys import SWAGGER_SCHEMA_KEY -from aiohttp_deps.utils import Form, Header, Json, Path, Query +from aiohttp_deps.utils import ExtraOpenAPI, Form, Header, Json, Path, Query _T = TypeVar("_T") @@ -119,6 +120,7 @@ def _add_route_def( # noqa: C901 openapi_schema["components"]["schemas"].update(extra_openapi_schemas) params: Dict[Tuple[str, str], Any] = {} + updaters: List[Callable[[Dict[str, Any]], None]] = [] def _insert_in_params(data: Dict[str, Any]) -> None: element = params.get((data["name"], data["in"])) @@ -191,8 +193,18 @@ def _insert_in_params(data: Dict[str, Any]) -> None: "schema": schema, }, ) + elif isinstance(dependency.dependency, ExtraOpenAPI): + if dependency.dependency.updater is not None: + updaters.append(dependency.dependency.updater) + if dependency.dependency.extra_openapi is not None: + extra_openapi = always_merger.merge( + extra_openapi, + dependency.dependency.extra_openapi, + ) route_info["parameters"] = list(params.values()) + for updater in updaters: + updater(route_info) openapi_schema["paths"][route.resource.canonical].update( {method.lower(): always_merger.merge(route_info, extra_openapi)}, ) @@ -207,6 +219,7 @@ def setup_swagger( # noqa: C901 title: str = "AioHTTP", description: Optional[str] = None, version: str = "1.0.0", + extra_openapi: Optional[Dict[str, Any]] = None, ) -> Callable[[web.Application], Awaitable[None]]: """ Add swagger documentation. @@ -230,8 +243,11 @@ def setup_swagger( # noqa: C901 :param title: Title of an application. :param description: description of an application. :param version: version of an application. + :param extra_openapi: extra openAPI dict that will be merged with generated schema. :return: startup event handler. """ + if extra_openapi is None: + extra_openapi = {} async def event_handler(app: web.Application) -> None: # noqa: C901 openapi_schema = { @@ -252,12 +268,12 @@ async def event_handler(app: web.Application) -> None: # noqa: C901 if hide_options and route.method.upper() == "OPTIONS": continue if isinstance(route._handler, InjectableFuncHandler): - extra_openapi = getattr( + route_extra_openapi = getattr( route._handler.original_handler, "__extra_openapi__", {}, ) - extra_schemas = getattr( + route_extra_schemas = getattr( route._handler.original_handler, "__extra_openapi_schemas__", {}, @@ -268,8 +284,8 @@ async def event_handler(app: web.Application) -> None: # noqa: C901 route, # type: ignore route.method, route._handler.graph, - extra_openapi=extra_openapi, - extra_openapi_schemas=extra_schemas, + extra_openapi=route_extra_openapi, + extra_openapi_schemas=route_extra_schemas, ) except Exception as exc: # pragma: no cover logger.warn( @@ -280,12 +296,12 @@ async def event_handler(app: web.Application) -> None: # noqa: C901 elif isinstance(route._handler, InjectableViewHandler): for key, graph in route._handler.graph_map.items(): - extra_openapi = getattr( + route_extra_openapi = getattr( getattr(route._handler.original_handler, key), "__extra_openapi__", {}, ) - extra_schemas = getattr( + route_extra_schemas = getattr( getattr(route._handler.original_handler, key), "__extra_openapi_schemas__", {}, @@ -296,8 +312,8 @@ async def event_handler(app: web.Application) -> None: # noqa: C901 route, # type: ignore key, graph, - extra_openapi=extra_openapi, - extra_openapi_schemas=extra_schemas, + extra_openapi=route_extra_openapi, + extra_openapi_schemas=route_extra_schemas, ) except Exception as exc: # pragma: no cover logger.warn( @@ -306,7 +322,7 @@ async def event_handler(app: web.Application) -> None: # noqa: C901 exc_info=True, ) - app[SWAGGER_SCHEMA_KEY] = openapi_schema + app[SWAGGER_SCHEMA_KEY] = always_merger.merge(openapi_schema, extra_openapi) app.router.add_get( schema_url, diff --git a/aiohttp_deps/utils.py b/aiohttp_deps/utils.py index f3a6f85..e983d61 100644 --- a/aiohttp_deps/utils.py +++ b/aiohttp_deps/utils.py @@ -1,6 +1,6 @@ import inspect import json -from typing import Any, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import pydantic from aiohttp import web @@ -344,3 +344,37 @@ def __call__( headers={"Content-Type": "application/json"}, text=json.dumps(errors), ) from err + + +class ExtraOpenAPI: + """ + Update swagger for the endpoint. + + You can use this dependency to add swagger to an endpoint from + a dependency. It's useful when you want to add some extra swagger + to the route when some specific dependency is used by it. + """ + + def __init__( + self, + extra_openapi: Optional[Dict[str, Any]] = None, + swagger_updater: Optional[Callable[[Dict[str, Any]], None]] = None, + ) -> None: + """ + Initialize the dependency. + + :param swagger_updater: function that takes final swagger endpoint and + updates it. + :param extra_swagger: extra swagger to add to the endpoint. This one might + override other extra_swagger on the endpoint. + """ + self.updater = swagger_updater + self.extra_openapi = extra_openapi + + def __call__(self) -> None: + """ + This method is called when dependency is resolved. + + It's empty, becuase it's used by the swagger function and + there is no actual dependency. + """ diff --git a/examples/swagger_auth.py b/examples/swagger_auth.py new file mode 100644 index 0000000..469e69a --- /dev/null +++ b/examples/swagger_auth.py @@ -0,0 +1,87 @@ +import base64 + +from aiohttp import web +from pydantic import BaseModel + +from aiohttp_deps import Depends, ExtraOpenAPI, Header, Router, init, setup_swagger + + +class UserInfo(BaseModel): + """Abstract user model.""" + + id: int + name: str + password: str + + +router = Router() + +# Here we create a simple user storage. +# In real-world applications, you would use a database. +users = { + "john": UserInfo(id=1, name="John Doe", password="123"), # noqa: S106 + "caren": UserInfo(id=2, name="Caren Doe", password="321"), # noqa: S106 +} + + +def get_current_user( + # Current auth header. + authorization: str = Depends(Header()), + # We don't need a name to this variable, + # because it will only affect the API schema, + # but won't be used in runtime. + _: None = Depends( + ExtraOpenAPI( + extra_openapi={ + "security": [{"basicAuth": []}], + }, + ), + ), +) -> UserInfo: + """This function checks if the user authorized.""" + # Here we check if the authorization header is present. + if not authorization.startswith("Basic"): + raise web.HTTPUnauthorized(reason="Unsupported authorization type") + # We decode credentials from the header. + # And check if the user exists. + creds = base64.b64decode(authorization.split(" ")[1]).decode() + username, password = creds.split(":") + found_user = users.get(username) + if found_user is None: + raise web.HTTPUnauthorized(reason="User not found") + if found_user.password != password: + raise web.HTTPUnauthorized(reason="Invalid password") + return found_user + + +@router.get("/") +async def index(current_user: UserInfo = Depends(get_current_user)) -> web.Response: + """Index handler returns current user.""" + return web.json_response(current_user.model_dump(mode="json")) + + +app = web.Application() +app.router.add_routes(router) +app.on_startup.extend( + [ + init, + setup_swagger( + # Here we add security schemes used + # to authorize users. + extra_openapi={ + "components": { + "securitySchemes": { + # We only support basic auth. + "basicAuth": { + "type": "http", + "scheme": "basic", + }, + }, + }, + }, + ), + ], +) + +if __name__ == "__main__": + web.run_app(app) diff --git a/tests/test_swagger.py b/tests/test_swagger.py index 175d46b..6caef8a 100644 --- a/tests/test_swagger.py +++ b/tests/test_swagger.py @@ -8,6 +8,7 @@ from aiohttp_deps import ( Depends, + ExtraOpenAPI, Form, Header, Json, @@ -780,3 +781,60 @@ async def my_handler() -> None: schema = await response.json() assert "get" in schema["paths"]["/"] assert method.lower() not in schema["paths"]["/"] + + +@pytest.mark.anyio +async def test_extra_openapi_dep_func( + my_app: web.Application, + aiohttp_client: ClientGenerator, +) -> None: + openapi_url = "/my_api_def.json" + my_app.on_startup.append(setup_swagger(schema_url=openapi_url)) + + async def dep( + _: None = Depends(ExtraOpenAPI(extra_openapi={"responses": {"200": {}}})), + ) -> None: + """Test dep that adds swagger through a dependency.""" + + async def my_handler(_: None = Depends(dep)) -> None: + """Nothing.""" + + my_app.router.add_get("/a", my_handler) + + client = await aiohttp_client(my_app) + resp = await client.get(openapi_url) + assert resp.status == 200 + resp_json = await resp.json() + + handler_info = resp_json["paths"]["/a"]["get"] + assert handler_info["responses"] == {"200": {}} + + +@pytest.mark.anyio +async def test_extra_openapi_dep_updater_func( + my_app: web.Application, + aiohttp_client: ClientGenerator, +) -> None: + openapi_url = "/my_api_def.json" + my_app.on_startup.append(setup_swagger(schema_url=openapi_url)) + + def schema_updater(schema: Dict[str, Any]) -> None: + schema["responses"] = {"200": {}} + + async def dep( + _: None = Depends(ExtraOpenAPI(swagger_updater=schema_updater)), + ) -> None: + """Test dep that adds swagger through a dependency.""" + + async def my_handler(_: None = Depends(dep)) -> None: + """Nothing.""" + + my_app.router.add_get("/a", my_handler) + + client = await aiohttp_client(my_app) + resp = await client.get(openapi_url) + assert resp.status == 200 + resp_json = await resp.json() + + handler_info = resp_json["paths"]["/a"]["get"] + assert handler_info["responses"] == {"200": {}}