diff --git a/backend/funix/decorator/__init__.py b/backend/funix/decorator/__init__.py index 578683f..597ad17 100644 --- a/backend/funix/decorator/__init__.py +++ b/backend/funix/decorator/__init__.py @@ -1,7 +1,10 @@ """ Funix decorator. The central logic of Funix. """ +import time +from collections import deque from copy import deepcopy +from enum import Enum, auto from functools import wraps from importlib import import_module from inspect import Parameter, Signature, getsource, signature @@ -12,8 +15,6 @@ from typing import Any, Optional from urllib.request import urlopen from uuid import uuid4 -from collections import deque -import time from flask import Response, request, session from requests import post @@ -242,10 +243,73 @@ Kumo callback token. """ -call_history: dict = {} -""" -Rate limit call history -""" + +class LimitSource(Enum): + """ + rate limit based on what value + """ + + # Based on browser session + SESSION = auto() + + # Based on IP + IP = auto() + + +class Limiter: + call_history: dict + # How many calls client can send between each interval set by `time_frame` + max_calls: int + # Max call interval time, in seconds + time_frame: int + source: LimitSource + + def __init__( + self, + max_calls: int = 10, + time_frame: int = 60, + source: LimitSource = LimitSource.SESSION, + ): + self.source = source + self.max_calls = max_calls + self.time_frame = time_frame + self.call_history = {} + + @staticmethod + def ip(max_calls: int): + return Limiter(max_calls=max_calls, source=LimitSource.IP) + + @staticmethod + def session(max_calls: int): + return Limiter(max_calls=max_calls, source=LimitSource.SESSION) + + def rate_limit(self) -> Optional[Response]: + call_history = self.call_history + match self.source: + case LimitSource.IP: + source = request.remote_addr + case LimitSource.SESSION: + source = session.get("__funix_id") + + if source not in call_history: + call_history[source] = deque() + + queue = call_history[source] + current_time = time.time() + + while len(queue) > 0 and current_time - queue[0] > self.time_frame: + queue.popleft() + + if len(queue) >= self.max_calls: + time_passed = current_time - queue[0] + time_to_wait = int(self.time_frame - time_passed) + error_message = ( + f"Rate limit exceeded. Please try again in {time_to_wait} seconds." + ) + return Response(error_message, status=429, mimetype="text/plain") + + queue.append(current_time) + return None def set_kumo_info(url: str, token: str) -> None: @@ -461,8 +525,7 @@ def funix( argument_config: ArgumentConfigType = None, pre_fill: PreFillType = None, menu: Optional[str] = None, - max_calls: Optional[int] = None, - time_frame: Optional[int] = None, + rate_limit: Limiter | list[Limiter] = list(), ): """ Decorator for functions to convert them to web apps @@ -496,8 +559,7 @@ def funix( menu(str): full module path of the function, for `path` only. You don't need to set it unless you are funixing a directory and package. - max_calls(int): How many calls client can send between each interval set by `time_frame` - time_frame(int): Max call interval time, in seconds + rate_limit(Limiter | list[Limiter]): rate limiters, an object or a list Returns: function: the decorated function @@ -1393,36 +1455,15 @@ def wrapper(): Any: The function's result """ - global call_history - - if max_calls is not None or time_frame is not None: - new_time_frame = 60 - new_max_calls = 5 - if time_frame: - new_time_frame = time_frame - if max_calls: - new_max_calls = max_calls - - ip = request.remote_addr - - if ip not in call_history: - call_history[ip] = deque() - - queue = call_history[ip] - current_time = time.time() - - while len(queue) > 0 and current_time - queue[0] > new_time_frame: - queue.popleft() - - if len(queue) >= new_max_calls: - time_passed = current_time - queue[0] - time_to_wait = int(new_time_frame - time_passed) - error_message = f"Rate limit exceeded. Please try again in {time_to_wait} seconds." - return Response( - error_message, status=429, mimetype="text/plain" - ) - - queue.append(current_time) + if isinstance(rate_limit, Limiter): + limit_result = rate_limit.rate_limit() + if limit_result is not None: + return limit_result + else: + for limiter in rate_limit: + limit_result = limiter.rate_limit() + if limit_result is not None: + return limit_result try: if not session.get("__funix_id"):