Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Rate Limiting Middleware and Update Redis Integration #7

Merged
merged 2 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .env
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ DB_PORT=3306
DB_NAME=fastapi
OPENAI_API_KEY_GPT4=
OPENAI_API_KEY_WEDNESDAY=
YOUR_SECRET_KEY=my_super_secret_key_here
YOUR_SECRET_KEY=my_super_secret_key_here
# REDIS_URL=http://localhost:6379
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,10 @@ alembic revision -m 'initialize all models'
## Upgrade migrations
```
alembic upgrade head
```
```

## Redis Dependency
```
docker run --name recorder-redis -p 6379:6379 -d redis:alpine
```
or add the REDIS_URL in .env file
4 changes: 2 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from routes.users import user
from middlewares.rate_limiter_middleware import rate_limit_middleware
from middlewares.rate_limiter_middleware import RateLimitMiddleware


# Initializing the swagger docs
Expand All @@ -28,7 +28,7 @@
allow_methods=["*"],
allow_headers=["*"],
)
app.middleware("http")(rate_limit_middleware)
app.add_middleware(RateLimitMiddleware)
app.include_router(user, prefix='/user')

# Default API route
Expand Down
11 changes: 6 additions & 5 deletions config/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
load_dotenv()

# Set the default values for connecting locally
HOST = os.environ.get("DB_HOSTNAME")
PORT = os.environ.get("DB_PORT")
DBNAME = os.environ.get("DB_NAME")
USERNAME = os.environ.get("DB_USERNAME")
PASSWORD = os.environ.get("DB_PASSWORD")
HOST = os.environ.get("DB_HOSTNAME", "localhost")
PORT = os.environ.get("DB_PORT", "3306")
DBNAME = os.environ.get("DB_NAME", "mydbname")
USERNAME = os.environ.get("DB_USERNAME", "user")
PASSWORD = os.environ.get("DB_PASSWORD", "password")


saurabh-wednesday marked this conversation as resolved.
Show resolved Hide resolved
if "pytest" in sys.modules:
SQLALCHEMY_DATABASE_URL = "sqlite://"
Expand Down
6 changes: 6 additions & 0 deletions config/redis_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import aioredis
import os

async def get_redis_pool():
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
return aioredis.from_url(redis_url)
11 changes: 10 additions & 1 deletion daos/users.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
from aioredis import Redis
from fastapi import HTTPException
from sqlalchemy.orm import Session
from constants import jwt_utils
Expand All @@ -6,11 +8,17 @@
from schemas.users import CreateUser
from schemas.users import Login
from werkzeug.security import check_password_hash
from utils import redis_utils
from utils.user_utils import check_existing_field, responseFormatter


def get_user(user_id: int, dbSession: Session):
try:
cache_key = f"user:{user_id}"
cached_user = redis_utils.get_redis().get(cache_key)

if cached_user:
return json.loads(cached_user)
# Check if the subject already exists in the database
user = (
dbSession.query(User)
Expand All @@ -26,7 +34,8 @@ def get_user(user_id: int, dbSession: Session):
)
.first()
)

if user:
Redis.set(cache_key, json.dumps(user))
if not user:
raise Exception(messages["NO_USER_FOUND_FOR_ID"])

Expand Down
77 changes: 38 additions & 39 deletions middlewares/rate_limiter_middleware.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,42 @@
from fastapi import Request, HTTPException
from datetime import datetime, timedelta
from collections import defaultdict
from typing import Callable
from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from fastapi.responses import JSONResponse

_request_timestamps = defaultdict(list)
import aioredis
import datetime
import os

MAX_REQUESTS = 10
TIME_WINDOW = 60

async def rate_limit_middleware(request: Request, call_next: Callable):
global _request_timestamps

client_ip = request.client.host
now = datetime.now()

if client_ip not in _request_timestamps:
_request_timestamps[client_ip] = []

# Remove expired timestamps
_request_timestamps[client_ip] = [
timestamp for timestamp in _request_timestamps[client_ip]
if now - timestamp <= timedelta(seconds=TIME_WINDOW)
]

# Check the number of requests
if len(_request_timestamps[client_ip]) >= MAX_REQUESTS:
# Provide a more informative HTTP 429 response
detail = {
"error": "Too Many Requests",
"message": f"You have exceeded the maximum number of requests ({MAX_REQUESTS}) in the time window ({TIME_WINDOW}s)."
}
return JSONResponse(
status_code=429,
content=detail,
)

# Log the current request
_request_timestamps[client_ip].append(now)
TIME_WINDOW = 60

class RateLimitMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
client_ip = request.client.host
now = datetime.datetime.now()

# Updated for aioredis v2.x
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
redis = aioredis.from_url(redis_url, encoding="utf-8", decode_responses=True)

try:
request_count = await redis.get(client_ip)
request_count = int(request_count) if request_count else 0

if request_count >= MAX_REQUESTS:
ttl = await redis.ttl(client_ip)
detail = {
"error": "Too Many Requests",
"message": f"Rate limit exceeded. Try again in {ttl} seconds."
}
return JSONResponse(status_code=429, content=detail)

pipe = redis.pipeline()
pipe.incr(client_ip)
pipe.expire(client_ip, TIME_WINDOW)
await pipe.execute()
finally:
pass

response = await call_next(request)
return response

response = await call_next(request)
return response
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,5 @@ itsdangerous
pdf2image
alembic
email-validator
werkzeug
werkzeug
aioredis
7 changes: 4 additions & 3 deletions routes/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from daos.users import login as signin
from schemas.users import CreateUser
from schemas.users import Login
from utils.redis_utils import get_redis
from utils.user_utils import get_current_user

user = APIRouter()
Expand All @@ -22,11 +23,11 @@ def login(payload: Login, db: Session = Depends(create_local_session)):
return response

@user.get("/{user_id}", tags=["Users"])
def profile(user_id, db: Session = Depends(create_local_session)):
response = get_user_dao(user_id, dbSession=db)
async def profile(user_id, db: Session = Depends(create_local_session), redis=Depends(get_redis)):
# Here, you can use 'redis' to fetch or store data in Redis cache
response = await get_user_dao(user_id, dbSession=db, redis=redis)
return response


@user.get("/secure-route/", tags=["Users"], dependencies=[Depends(get_current_user)])
def secure_route():
return {"message": "If you see this, you're authenticated"}
4 changes: 4 additions & 0 deletions utils/redis_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from config.redis_config import get_redis_pool

async def get_redis():
return await get_redis_pool()