From e635fd067b9c241e45509ff407ceb2076ab7f765 Mon Sep 17 00:00:00 2001 From: saurabh-wednesday Date: Mon, 20 Nov 2023 13:24:58 +0530 Subject: [PATCH] feat(redis): added config and support for cache --- .env | 3 +- README.md | 8 ++- app.py | 4 +- config/db.py | 11 ++-- config/redis_config.py | 6 ++ daos/users.py | 11 +++- middlewares/rate_limiter_middleware.py | 77 +++++++++++++------------- requirements.txt | 3 +- routes/users.py | 7 ++- utils/redis_utils.py | 4 ++ 10 files changed, 81 insertions(+), 53 deletions(-) create mode 100644 config/redis_config.py create mode 100644 utils/redis_utils.py diff --git a/.env b/.env index 5df3729..4db49ff 100644 --- a/.env +++ b/.env @@ -5,4 +5,5 @@ DB_PORT=3306 DB_NAME=fastapi OPENAI_API_KEY_GPT4= OPENAI_API_KEY_WEDNESDAY= -YOUR_SECRET_KEY=my_super_secret_key_here \ No newline at end of file +YOUR_SECRET_KEY=my_super_secret_key_here +# REDIS_URL=http://localhost:6379 \ No newline at end of file diff --git a/README.md b/README.md index 63f8276..8f6125c 100644 --- a/README.md +++ b/README.md @@ -16,4 +16,10 @@ alembic revision -m 'initialize all models' ## Upgrade migrations ``` alembic upgrade head -``` \ No newline at end of file +``` + +## Redis Dependency +``` +docker run --name recorder-redis -p 6379:6379 -d redis:alpine +``` +or add the REDIS_URL in .env file \ No newline at end of file diff --git a/app.py b/app.py index 105f404..be9aa5f 100644 --- a/app.py +++ b/app.py @@ -6,7 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from routes.users import user -from middlewares.rate_limiter_middleware import rate_limit_middleware +from middlewares.rate_limiter_middleware import RateLimitMiddleware # Initializing the swagger docs @@ -28,7 +28,7 @@ allow_methods=["*"], allow_headers=["*"], ) -app.middleware("http")(rate_limit_middleware) +app.add_middleware(RateLimitMiddleware) app.include_router(user, prefix='/user') # Default API route diff --git a/config/db.py b/config/db.py index 3bc2f86..b42a9e6 100644 --- a/config/db.py +++ b/config/db.py @@ -12,11 +12,12 @@ load_dotenv() # Set the default values for connecting locally -HOST = os.environ.get("DB_HOSTNAME") -PORT = os.environ.get("DB_PORT") -DBNAME = os.environ.get("DB_NAME") -USERNAME = os.environ.get("DB_USERNAME") -PASSWORD = os.environ.get("DB_PASSWORD") +HOST = os.environ.get("DB_HOSTNAME", "localhost") +PORT = os.environ.get("DB_PORT", "3306") +DBNAME = os.environ.get("DB_NAME", "mydbname") +USERNAME = os.environ.get("DB_USERNAME", "user") +PASSWORD = os.environ.get("DB_PASSWORD", "password") + if "pytest" in sys.modules: SQLALCHEMY_DATABASE_URL = "sqlite://" diff --git a/config/redis_config.py b/config/redis_config.py new file mode 100644 index 0000000..84ca4bf --- /dev/null +++ b/config/redis_config.py @@ -0,0 +1,6 @@ +import aioredis +import os + +async def get_redis_pool(): + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") + return aioredis.from_url(redis_url) diff --git a/daos/users.py b/daos/users.py index 8923ecc..fe8eab4 100644 --- a/daos/users.py +++ b/daos/users.py @@ -1,3 +1,5 @@ +import json +from aioredis import Redis from fastapi import HTTPException from sqlalchemy.orm import Session from constants import jwt_utils @@ -6,11 +8,17 @@ from schemas.users import CreateUser from schemas.users import Login from werkzeug.security import check_password_hash +from utils import redis_utils from utils.user_utils import check_existing_field, responseFormatter def get_user(user_id: int, dbSession: Session): try: + cache_key = f"user:{user_id}" + cached_user = redis_utils.get_redis().get(cache_key) + + if cached_user: + return json.loads(cached_user) # Check if the subject already exists in the database user = ( dbSession.query(User) @@ -26,7 +34,8 @@ def get_user(user_id: int, dbSession: Session): ) .first() ) - + if user: + Redis.set(cache_key, json.dumps(user)) if not user: raise Exception(messages["NO_USER_FOUND_FOR_ID"]) diff --git a/middlewares/rate_limiter_middleware.py b/middlewares/rate_limiter_middleware.py index 67a8241..8affa55 100644 --- a/middlewares/rate_limiter_middleware.py +++ b/middlewares/rate_limiter_middleware.py @@ -1,43 +1,42 @@ -from fastapi import Request, HTTPException -from datetime import datetime, timedelta -from collections import defaultdict -from typing import Callable +from fastapi import FastAPI +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request from fastapi.responses import JSONResponse - -_request_timestamps = defaultdict(list) +import aioredis +import datetime +import os MAX_REQUESTS = 10 -TIME_WINDOW = 60 - -async def rate_limit_middleware(request: Request, call_next: Callable): - global _request_timestamps - - client_ip = request.client.host - now = datetime.now() - - if client_ip not in _request_timestamps: - _request_timestamps[client_ip] = [] - - # Remove expired timestamps - _request_timestamps[client_ip] = [ - timestamp for timestamp in _request_timestamps[client_ip] - if now - timestamp <= timedelta(seconds=TIME_WINDOW) - ] - - # Check the number of requests - if len(_request_timestamps[client_ip]) >= MAX_REQUESTS: - # Provide a more informative HTTP 429 response - detail = { - "error": "Too Many Requests", - "message": f"You have exceeded the maximum number of requests ({MAX_REQUESTS}) in the time window ({TIME_WINDOW}s)." - } - return JSONResponse( - status_code=429, - content=detail, - ) - - # Log the current request - _request_timestamps[client_ip].append(now) +TIME_WINDOW = 60 + +class RateLimitMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + client_ip = request.client.host + now = datetime.datetime.now() + + # Updated for aioredis v2.x + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") + redis = aioredis.from_url(redis_url, encoding="utf-8", decode_responses=True) + + try: + request_count = await redis.get(client_ip) + request_count = int(request_count) if request_count else 0 + + if request_count >= MAX_REQUESTS: + ttl = await redis.ttl(client_ip) + detail = { + "error": "Too Many Requests", + "message": f"Rate limit exceeded. Try again in {ttl} seconds." + } + return JSONResponse(status_code=429, content=detail) + + pipe = redis.pipeline() + pipe.incr(client_ip) + pipe.expire(client_ip, TIME_WINDOW) + await pipe.execute() + finally: + pass + + response = await call_next(request) + return response - response = await call_next(request) - return response diff --git a/requirements.txt b/requirements.txt index 6900906..8d31bda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -108,4 +108,5 @@ itsdangerous pdf2image alembic email-validator -werkzeug \ No newline at end of file +werkzeug +aioredis \ No newline at end of file diff --git a/routes/users.py b/routes/users.py index bc42752..2120aee 100644 --- a/routes/users.py +++ b/routes/users.py @@ -7,6 +7,7 @@ from daos.users import login as signin from schemas.users import CreateUser from schemas.users import Login +from utils.redis_utils import get_redis from utils.user_utils import get_current_user user = APIRouter() @@ -22,11 +23,11 @@ def login(payload: Login, db: Session = Depends(create_local_session)): return response @user.get("/{user_id}", tags=["Users"]) -def profile(user_id, db: Session = Depends(create_local_session)): - response = get_user_dao(user_id, dbSession=db) +async def profile(user_id, db: Session = Depends(create_local_session), redis=Depends(get_redis)): + # Here, you can use 'redis' to fetch or store data in Redis cache + response = await get_user_dao(user_id, dbSession=db, redis=redis) return response - @user.get("/secure-route/", tags=["Users"], dependencies=[Depends(get_current_user)]) def secure_route(): return {"message": "If you see this, you're authenticated"} diff --git a/utils/redis_utils.py b/utils/redis_utils.py new file mode 100644 index 0000000..c79c9e4 --- /dev/null +++ b/utils/redis_utils.py @@ -0,0 +1,4 @@ +from config.redis_config import get_redis_pool + +async def get_redis(): + return await get_redis_pool() \ No newline at end of file