diff --git a/.env b/.env index 5df3729..4db49ff 100644 --- a/.env +++ b/.env @@ -5,4 +5,5 @@ DB_PORT=3306 DB_NAME=fastapi OPENAI_API_KEY_GPT4= OPENAI_API_KEY_WEDNESDAY= -YOUR_SECRET_KEY=my_super_secret_key_here \ No newline at end of file +YOUR_SECRET_KEY=my_super_secret_key_here +# REDIS_URL=http://localhost:6379 \ No newline at end of file diff --git a/README.md b/README.md index d4e8f20..3593c7f 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,8 @@ - -# FastAPI Template This repository provides a template for creating and deploying a FastAPI project. Follow the steps below to set up the local environment, run the project, manage database migrations, and deploy the service on AWS ECS. @@ -69,6 +47,7 @@ To deploy the FastAPI application on AWS ECS, use the following script: ``` ./scripts/setup-ecs.sh develop ``` + The setup-ecs.sh script leverages AWS Copilot to deploy the service. Provide the environment name as an argument (e.g., develop). The script creates and deploys an environment, then deploys the FastAPI service on that environment. Note: Ensure you have AWS credentials configured and AWS Copilot installed for successful deployment. @@ -76,3 +55,8 @@ Note: Ensure you have AWS credentials configured and AWS Copilot installed for s #### New to AWS Copilot? If you are new to AWS Copilot or you want to learn more about AWS Copilot, please refer to [this helpful article](https://www.wednesday.is/writing-tutorials/how-to-use-copilot-to-deploy-projects-on-ecs) that guides you through the process of setting up AWS Copilot locally as well as also helps you understand how you can publish and update an application using 4 simple steps. +### 5. Redis Dependency +``` +docker run --name recorder-redis -p 6379:6379 -d redis:alpine +``` +or add the REDIS_URL in .env file diff --git a/app.py b/app.py index 105f404..be9aa5f 100644 --- a/app.py +++ b/app.py @@ -6,7 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from routes.users import user -from middlewares.rate_limiter_middleware import rate_limit_middleware +from middlewares.rate_limiter_middleware import RateLimitMiddleware # Initializing the swagger docs @@ -28,7 +28,7 @@ allow_methods=["*"], allow_headers=["*"], ) -app.middleware("http")(rate_limit_middleware) +app.add_middleware(RateLimitMiddleware) app.include_router(user, prefix='/user') # Default API route diff --git a/config/db.py b/config/db.py index c382986..57d5c6b 100644 --- a/config/db.py +++ b/config/db.py @@ -13,11 +13,12 @@ load_dotenv() # Set the default values for connecting locally -HOST = os.environ.get("DB_HOSTNAME") -PORT = os.environ.get("DB_PORT") -DBNAME = os.environ.get("DB_NAME") -USERNAME = os.environ.get("DB_USERNAME") -PASSWORD = os.environ.get("DB_PASSWORD") +HOST = os.environ.get("DB_HOSTNAME", "localhost") +PORT = os.environ.get("DB_PORT", "3306") +DBNAME = os.environ.get("DB_NAME", "mydbname") +USERNAME = os.environ.get("DB_USERNAME", "user") +PASSWORD = os.environ.get("DB_PASSWORD", "password") + if "pytest" in sys.modules: SQLALCHEMY_DATABASE_URL = "sqlite://" diff --git a/config/redis_config.py b/config/redis_config.py new file mode 100644 index 0000000..84ca4bf --- /dev/null +++ b/config/redis_config.py @@ -0,0 +1,6 @@ +import aioredis +import os + +async def get_redis_pool(): + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") + return aioredis.from_url(redis_url) diff --git a/daos/users.py b/daos/users.py index 8923ecc..fe8eab4 100644 --- a/daos/users.py +++ b/daos/users.py @@ -1,3 +1,5 @@ +import json +from aioredis import Redis from fastapi import HTTPException from sqlalchemy.orm import Session from constants import jwt_utils @@ -6,11 +8,17 @@ from schemas.users import CreateUser from schemas.users import Login from werkzeug.security import check_password_hash +from utils import redis_utils from utils.user_utils import check_existing_field, responseFormatter def get_user(user_id: int, dbSession: Session): try: + cache_key = f"user:{user_id}" + cached_user = redis_utils.get_redis().get(cache_key) + + if cached_user: + return json.loads(cached_user) # Check if the subject already exists in the database user = ( dbSession.query(User) @@ -26,7 +34,8 @@ def get_user(user_id: int, dbSession: Session): ) .first() ) - + if user: + Redis.set(cache_key, json.dumps(user)) if not user: raise Exception(messages["NO_USER_FOUND_FOR_ID"]) diff --git a/middlewares/rate_limiter_middleware.py b/middlewares/rate_limiter_middleware.py index 67a8241..8affa55 100644 --- a/middlewares/rate_limiter_middleware.py +++ b/middlewares/rate_limiter_middleware.py @@ -1,43 +1,42 @@ -from fastapi import Request, HTTPException -from datetime import datetime, timedelta -from collections import defaultdict -from typing import Callable +from fastapi import FastAPI +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request from fastapi.responses import JSONResponse - -_request_timestamps = defaultdict(list) +import aioredis +import datetime +import os MAX_REQUESTS = 10 -TIME_WINDOW = 60 - -async def rate_limit_middleware(request: Request, call_next: Callable): - global _request_timestamps - - client_ip = request.client.host - now = datetime.now() - - if client_ip not in _request_timestamps: - _request_timestamps[client_ip] = [] - - # Remove expired timestamps - _request_timestamps[client_ip] = [ - timestamp for timestamp in _request_timestamps[client_ip] - if now - timestamp <= timedelta(seconds=TIME_WINDOW) - ] - - # Check the number of requests - if len(_request_timestamps[client_ip]) >= MAX_REQUESTS: - # Provide a more informative HTTP 429 response - detail = { - "error": "Too Many Requests", - "message": f"You have exceeded the maximum number of requests ({MAX_REQUESTS}) in the time window ({TIME_WINDOW}s)." - } - return JSONResponse( - status_code=429, - content=detail, - ) - - # Log the current request - _request_timestamps[client_ip].append(now) +TIME_WINDOW = 60 + +class RateLimitMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + client_ip = request.client.host + now = datetime.datetime.now() + + # Updated for aioredis v2.x + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") + redis = aioredis.from_url(redis_url, encoding="utf-8", decode_responses=True) + + try: + request_count = await redis.get(client_ip) + request_count = int(request_count) if request_count else 0 + + if request_count >= MAX_REQUESTS: + ttl = await redis.ttl(client_ip) + detail = { + "error": "Too Many Requests", + "message": f"Rate limit exceeded. Try again in {ttl} seconds." + } + return JSONResponse(status_code=429, content=detail) + + pipe = redis.pipeline() + pipe.incr(client_ip) + pipe.expire(client_ip, TIME_WINDOW) + await pipe.execute() + finally: + pass + + response = await call_next(request) + return response - response = await call_next(request) - return response diff --git a/requirements.txt b/requirements.txt index 6900906..8d31bda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -108,4 +108,5 @@ itsdangerous pdf2image alembic email-validator -werkzeug \ No newline at end of file +werkzeug +aioredis \ No newline at end of file diff --git a/routes/users.py b/routes/users.py index bc42752..2120aee 100644 --- a/routes/users.py +++ b/routes/users.py @@ -7,6 +7,7 @@ from daos.users import login as signin from schemas.users import CreateUser from schemas.users import Login +from utils.redis_utils import get_redis from utils.user_utils import get_current_user user = APIRouter() @@ -22,11 +23,11 @@ def login(payload: Login, db: Session = Depends(create_local_session)): return response @user.get("/{user_id}", tags=["Users"]) -def profile(user_id, db: Session = Depends(create_local_session)): - response = get_user_dao(user_id, dbSession=db) +async def profile(user_id, db: Session = Depends(create_local_session), redis=Depends(get_redis)): + # Here, you can use 'redis' to fetch or store data in Redis cache + response = await get_user_dao(user_id, dbSession=db, redis=redis) return response - @user.get("/secure-route/", tags=["Users"], dependencies=[Depends(get_current_user)]) def secure_route(): return {"message": "If you see this, you're authenticated"} diff --git a/utils/redis_utils.py b/utils/redis_utils.py new file mode 100644 index 0000000..c79c9e4 --- /dev/null +++ b/utils/redis_utils.py @@ -0,0 +1,4 @@ +from config.redis_config import get_redis_pool + +async def get_redis(): + return await get_redis_pool() \ No newline at end of file