From 43d2ac47bdd8a9ca31a2e2508e6755d0ea794075 Mon Sep 17 00:00:00 2001 From: nikhilleo10 Date: Wed, 22 Nov 2023 18:46:15 +0530 Subject: [PATCH] Add request id middleware --- app.py | 6 +++++- middlewares/request_id_injection.py | 22 ++++++++++++++++++++++ requirements.txt | 3 ++- routes/users.py | 6 ++++++ 4 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 middlewares/request_id_injection.py diff --git a/app.py b/app.py index 7640d81..780f668 100644 --- a/app.py +++ b/app.py @@ -10,10 +10,12 @@ from routes.users import user from fastapi_pagination import add_pagination from middlewares.rate_limiter_middleware import RateLimitMiddleware +from middlewares.request_id_injection import RequestIdInjection from pybreaker import CircuitBreakerError from dependencies import circuit_breaker from utils.slack_notification_utils import send_slack_message import traceback +from middlewares.request_id_injection import request_id_contextvar # Initializing the swagger docs app = FastAPI( @@ -26,6 +28,8 @@ origins = ["*"] +app.add_middleware(RequestIdInjection) + # CORS middleware app.add_middleware( CORSMiddleware, @@ -41,7 +45,7 @@ # Default API route @app.get("/") async def read_main(): - 1/0 + print('Request ID:', request_id_contextvar.get()) return {"response": "service up and running..!"} diff --git a/middlewares/request_id_injection.py b/middlewares/request_id_injection.py new file mode 100644 index 0000000..1f647d2 --- /dev/null +++ b/middlewares/request_id_injection.py @@ -0,0 +1,22 @@ +from fastapi import FastAPI +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from fastapi.responses import JSONResponse +import contextvars +import uuid + +request_id_contextvar = contextvars.ContextVar("request_id", default=None) + +class RequestIdInjection(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + request_id = str(uuid.uuid4()) + request_id_contextvar.set(request_id) + try: + return await call_next(request) + + except Exception as ex: + print(ex) + return JSONResponse(content={"success": False}, status_code=500) + + finally: + assert request_id_contextvar.get() == request_id diff --git a/requirements.txt b/requirements.txt index 2d9bce7..4db7060 100644 --- a/requirements.txt +++ b/requirements.txt @@ -115,4 +115,5 @@ flake8 black pre-commit pybreaker -pytest-asyncio \ No newline at end of file +pytest-asyncio +contextvars \ No newline at end of file diff --git a/routes/users.py b/routes/users.py index 55240fe..2529d1c 100644 --- a/routes/users.py +++ b/routes/users.py @@ -15,6 +15,7 @@ from utils.user_utils import get_current_user from typing import Annotated from fastapi.security import HTTPBearer +from middlewares.request_id_injection import request_id_contextvar user = APIRouter() @@ -23,12 +24,14 @@ @user.post("/register", tags=["Users"]) def register(payload: CreateUser, db: Session = Depends(create_local_session)): + print('Request ID:', request_id_contextvar.get()) response = create_user_dao(data=payload, dbSession=db) return response @user.post("/signin", tags=["Users"]) def login(payload: Login, db: Session = Depends(create_local_session)): + print('Request ID:', request_id_contextvar.get()) response = signin(data=payload, dbSession=db) return response @@ -40,6 +43,7 @@ async def profile( db: Session = Depends(create_local_session), redis=Depends(get_redis), ): + print('Request ID:', request_id_contextvar.get()) # Here, you can use 'redis' to fetch or store data in Redis cache response = await get_user_dao(user_id, dbSession=db, redis=redis) return response @@ -47,10 +51,12 @@ async def profile( @user.get("/", tags=["Users"], response_model=Page[UserOutResponse]) def list_users(db: Session = Depends(create_local_session)): + print('Request ID:', request_id_contextvar.get()) response = list_users_dao(dbSession=db) return response @user.get("/{user_id}/secure-route/", tags=["Users"], dependencies=[Depends(get_current_user)]) def secure_route(token: Annotated[str, Depends(httpBearerScheme)], user_id: int): + print('Request ID:', request_id_contextvar.get()) return {"message": "If you see this, you're authenticated"}