Skip to content

Commit

Permalink
feat: session ratelimit, flexible config
Browse files Browse the repository at this point in the history
  • Loading branch information
Colerar committed Oct 27, 2023
1 parent 9fab4a3 commit f72e546
Showing 1 changed file with 81 additions and 40 deletions.
121 changes: 81 additions & 40 deletions backend/funix/decorator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""
Funix decorator. The central logic of Funix.
"""
import time
from collections import deque
from copy import deepcopy
from enum import Enum, auto
from functools import wraps
from importlib import import_module
from inspect import Parameter, Signature, getsource, signature
Expand All @@ -12,8 +15,6 @@
from typing import Any, Optional
from urllib.request import urlopen
from uuid import uuid4
from collections import deque
import time

from flask import Response, request, session
from requests import post
Expand Down Expand Up @@ -242,10 +243,73 @@
Kumo callback token.
"""

call_history: dict = {}
"""
Rate limit call history
"""

class LimitSource(Enum):
"""
rate limit based on what value
"""

# Based on browser session
SESSION = auto()

# Based on IP
IP = auto()


class Limiter:
call_history: dict
# How many calls client can send between each interval set by `time_frame`
max_calls: int
# Max call interval time, in seconds
time_frame: int
source: LimitSource

def __init__(
self,
max_calls: int = 10,
time_frame: int = 60,
source: LimitSource = LimitSource.SESSION,
):
self.source = source
self.max_calls = max_calls
self.time_frame = time_frame
self.call_history = {}

@staticmethod
def ip(max_calls: int):
return Limiter(max_calls=max_calls, source=LimitSource.IP)

@staticmethod
def session(max_calls: int):
return Limiter(max_calls=max_calls, source=LimitSource.SESSION)

def rate_limit(self) -> Optional[Response]:
call_history = self.call_history
match self.source:
case LimitSource.IP:
source = request.remote_addr
case LimitSource.SESSION:
source = session.get("__funix_id")

if source not in call_history:
call_history[source] = deque()

queue = call_history[source]
current_time = time.time()

while len(queue) > 0 and current_time - queue[0] > self.time_frame:
queue.popleft()

if len(queue) >= self.max_calls:
time_passed = current_time - queue[0]
time_to_wait = int(self.time_frame - time_passed)
error_message = (
f"Rate limit exceeded. Please try again in {time_to_wait} seconds."
)
return Response(error_message, status=429, mimetype="text/plain")

queue.append(current_time)
return None


def set_kumo_info(url: str, token: str) -> None:
Expand Down Expand Up @@ -461,8 +525,7 @@ def funix(
argument_config: ArgumentConfigType = None,
pre_fill: PreFillType = None,
menu: Optional[str] = None,
max_calls: Optional[int] = None,
time_frame: Optional[int] = None,
rate_limit: Limiter | list[Limiter] = list(),
):
"""
Decorator for functions to convert them to web apps
Expand Down Expand Up @@ -496,8 +559,7 @@ def funix(
menu(str):
full module path of the function, for `path` only.
You don't need to set it unless you are funixing a directory and package.
max_calls(int): How many calls client can send between each interval set by `time_frame`
time_frame(int): Max call interval time, in seconds
rate_limit(Limiter | list[Limiter]): rate limiters, an object or a list
Returns:
function: the decorated function
Expand Down Expand Up @@ -1393,36 +1455,15 @@ def wrapper():
Any: The function's result
"""

global call_history

if max_calls is not None or time_frame is not None:
new_time_frame = 60
new_max_calls = 5
if time_frame:
new_time_frame = time_frame
if max_calls:
new_max_calls = max_calls

ip = request.remote_addr

if ip not in call_history:
call_history[ip] = deque()

queue = call_history[ip]
current_time = time.time()

while len(queue) > 0 and current_time - queue[0] > new_time_frame:
queue.popleft()

if len(queue) >= new_max_calls:
time_passed = current_time - queue[0]
time_to_wait = int(new_time_frame - time_passed)
error_message = f"Rate limit exceeded. Please try again in {time_to_wait} seconds."
return Response(
error_message, status=429, mimetype="text/plain"
)

queue.append(current_time)
if isinstance(rate_limit, Limiter):
limit_result = rate_limit.rate_limit()
if limit_result is not None:
return limit_result
else:
for limiter in rate_limit:
limit_result = limiter.rate_limit()
if limit_result is not None:
return limit_result

try:
if not session.get("__funix_id"):
Expand Down

0 comments on commit f72e546

Please sign in to comment.