-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from wednesday-solutions/feat/redis
Implement Rate Limiting Middleware and Update Redis Integration
- Loading branch information
Showing
10 changed files
with
81 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import aioredis | ||
import os | ||
|
||
async def get_redis_pool(): | ||
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") | ||
return aioredis.from_url(redis_url) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,42 @@ | ||
from fastapi import Request, HTTPException | ||
from datetime import datetime, timedelta | ||
from collections import defaultdict | ||
from typing import Callable | ||
from fastapi import FastAPI | ||
from starlette.middleware.base import BaseHTTPMiddleware | ||
from starlette.requests import Request | ||
from fastapi.responses import JSONResponse | ||
|
||
_request_timestamps = defaultdict(list) | ||
import aioredis | ||
import datetime | ||
import os | ||
|
||
MAX_REQUESTS = 10 | ||
TIME_WINDOW = 60 | ||
|
||
async def rate_limit_middleware(request: Request, call_next: Callable): | ||
global _request_timestamps | ||
|
||
client_ip = request.client.host | ||
now = datetime.now() | ||
|
||
if client_ip not in _request_timestamps: | ||
_request_timestamps[client_ip] = [] | ||
|
||
# Remove expired timestamps | ||
_request_timestamps[client_ip] = [ | ||
timestamp for timestamp in _request_timestamps[client_ip] | ||
if now - timestamp <= timedelta(seconds=TIME_WINDOW) | ||
] | ||
|
||
# Check the number of requests | ||
if len(_request_timestamps[client_ip]) >= MAX_REQUESTS: | ||
# Provide a more informative HTTP 429 response | ||
detail = { | ||
"error": "Too Many Requests", | ||
"message": f"You have exceeded the maximum number of requests ({MAX_REQUESTS}) in the time window ({TIME_WINDOW}s)." | ||
} | ||
return JSONResponse( | ||
status_code=429, | ||
content=detail, | ||
) | ||
|
||
# Log the current request | ||
_request_timestamps[client_ip].append(now) | ||
TIME_WINDOW = 60 | ||
|
||
class RateLimitMiddleware(BaseHTTPMiddleware): | ||
async def dispatch(self, request: Request, call_next): | ||
client_ip = request.client.host | ||
now = datetime.datetime.now() | ||
|
||
# Updated for aioredis v2.x | ||
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") | ||
redis = aioredis.from_url(redis_url, encoding="utf-8", decode_responses=True) | ||
|
||
try: | ||
request_count = await redis.get(client_ip) | ||
request_count = int(request_count) if request_count else 0 | ||
|
||
if request_count >= MAX_REQUESTS: | ||
ttl = await redis.ttl(client_ip) | ||
detail = { | ||
"error": "Too Many Requests", | ||
"message": f"Rate limit exceeded. Try again in {ttl} seconds." | ||
} | ||
return JSONResponse(status_code=429, content=detail) | ||
|
||
pipe = redis.pipeline() | ||
pipe.incr(client_ip) | ||
pipe.expire(client_ip, TIME_WINDOW) | ||
await pipe.execute() | ||
finally: | ||
pass | ||
|
||
response = await call_next(request) | ||
return response | ||
|
||
response = await call_next(request) | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -108,4 +108,5 @@ itsdangerous | |
pdf2image | ||
alembic | ||
email-validator | ||
werkzeug | ||
werkzeug | ||
aioredis |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from config.redis_config import get_redis_pool | ||
|
||
async def get_redis(): | ||
return await get_redis_pool() |