This repository has been archived by the owner on May 5, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 9
/
rate_limit_quotas.py
220 lines (183 loc) · 7.58 KB
/
rate_limit_quotas.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""
Rate limiter and quota system.
Framework-independent rate limiting mechanism that provides:
* IP address and token-based accounting
* customizable quotas based on IP address and token
* late limiting based on resource usage (time spent on API calls)
* bucketing based on day, week, month
* statistics
* metrics
* fast in-memory storage
Also provides a connector for Flask
"""
import time
import ipaddress
from typing import Dict, List, Optional, Tuple, Union
IpAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
IpAddrBucket = Dict[IpAddress, float]
IpAddrBuckets = Tuple[IpAddrBucket, IpAddrBucket, IpAddrBucket]
TokenBucket = Dict[str, float]
TokenBuckets = Tuple[TokenBucket, TokenBucket, TokenBucket]
class Limiter:
def __init__(
self,
limits: dict,
token_check_callback=None,
ipaddr_methods=["X-Real-Ip", "socket"],
whitelisted_ipaddrs=Optional[List[str]],
):
# Bucket sequence: month, week, day
self._ipaddr_limits = [
limits.get(l, None)
for l in ("ipaddr_per_month", "ipaddr_per_week", "ipaddr_per_day")
]
self._token_limits = [
limits.get(l, None)
for l in ("token_per_month", "token_per_week", "token_per_day")
]
self._ipaddr_buckets = ({}, {}, {}) # type: IpAddrBuckets
self._token_buckets = ({}, {}, {}) # type: TokenBuckets
self._token_check_callback = token_check_callback
self._ipaddr_extraction_methods = ipaddr_methods
self._last_quota_update_time = time.monotonic()
self._whitelisted_ipaddrs = set()
for ipa in whitelisted_ipaddrs or []:
self._whitelisted_ipaddrs.add(ipaddress.ip_address(ipa))
self.increment_quota_counters(1)
self.refresh_quota_counters_if_needed()
def increment_quota_counters(self, tdelta: int):
"""Delta: time from previous run in seconds"""
if tdelta <= 0:
return
iterable = (
(30 * 24, self._ipaddr_limits[0], self._ipaddr_buckets[0]),
(7 * 24, self._ipaddr_limits[0], self._ipaddr_buckets[0]),
(1 * 24, self._ipaddr_limits[0], self._ipaddr_buckets[0]),
(30 * 24, self._ipaddr_limits[0], self._ipaddr_buckets[0]),
(7 * 24, self._ipaddr_limits[0], self._ipaddr_buckets[0]),
(1 * 24, self._ipaddr_limits[0], self._ipaddr_buckets[0]),
)
for hours, limit, bucket in iterable:
vdelta = limit / hours / 3600 * tdelta
to_delete = []
for k, v in bucket.items():
v += vdelta
if v >= limit:
to_delete.append(k)
else:
bucket[k] = v
for k in to_delete:
del bucket[k]
def refresh_quota_counters_if_needed(self):
t = time.monotonic()
delta = t - self._last_quota_update_time
if delta > 3600:
self.increment_quota_counters(delta)
self._last_quota_update_time = t
def consume_quota(self, elapsed: float, ipaddr: Optional[IpAddress]=None, token=None) -> None:
"""Consume quota in seconds
"""
assert ipaddr or token
if ipaddr:
assert isinstance(ipaddr, ipaddress.IPv4Address)
for n, limit in enumerate(self._ipaddr_limits):
b = self._ipaddr_buckets[n]
b[ipaddr] = b.get(ipaddr, limit) - elapsed
else:
raise NotImplementedError()
def get_minimum_across_quotas(self, ipaddr=None, token=None) -> float:
assert ipaddr or token
if ipaddr:
iterable = zip(self._ipaddr_limits, self._ipaddr_buckets)
return min(bucket.get(ipaddr, limit) for limit, bucket in iterable)
else:
raise NotImplementedError()
def is_quota_available(self, ipaddr=None, token=None) -> bool:
"""Check if all quota buckets for an ipaddr/token are > 0
"""
# return False if any bucket reached 0
for bucket in self._ipaddr_buckets:
if ipaddr in bucket:
if bucket[ipaddr] <= 0:
return False
return True
def is_ipaddr_whitelisted(self, ipaddr: IpAddress) -> bool:
return ipaddr in self._whitelisted_ipaddrs
def get_lowest_daily_quotas_summary(self, n=20) -> List[Tuple[int, float]]:
"""Returns a summary of daily quotas with the lowest values
"""
li = sorted((val, ipa) for ipa, val in self._ipaddr_buckets[2].items())
li = li[:n]
return [(int(ipa.packed[0]), val) for val, ipa in li]
# # Flask-specific code # #
from flask import request, current_app
import flask
class FlaskLimiter:
def _get_client_ipaddr(self) -> IpAddress:
# https://github.com/alisaifee/flask-limiter/issues/41
for m in self._limiter._ipaddr_extraction_methods:
if m == "X-Forwarded-For":
raise NotImplementedError("X-Forwarded-For ")
elif m == "X-Real-Ip":
ipaddr = request.headers.get("X-Real-Ip", None)
if ipaddr:
return ipaddress.ip_address(ipaddr)
elif m == "socket":
return ipaddress.ip_address(request.remote_addr)
else:
raise NotImplementedError(f"IP address method {m} is unknown")
methods = ",".join(self._limiter._ipaddr_extraction_methods)
raise Exception(f"Unable to detect IP address using {methods}")
def _check_limits_callback(self):
"""Check rate limits before processing a request
Refresh quota counters when needed
"""
self._limiter.refresh_quota_counters_if_needed()
ipaddr = self._get_client_ipaddr()
# token = request.headers.get("Token", None)
# if token:
# check token validity
if not self._limiter.is_quota_available(ipaddr=ipaddr):
flask.abort(429)
self._request_start_time = time.monotonic()
log = current_app.logger
log.error("_check_limits_callback called")
def _after_request_callback(self, response):
"""Consume quota and injects HTTP headers when responding to a request
"""
log = current_app.logger
try:
assert response
tdelta = time.monotonic() - self._request_start_time
ipaddr = self._get_client_ipaddr()
if not self._limiter.is_ipaddr_whitelisted(ipaddr):
self._limiter.consume_quota(tdelta, ipaddr=ipaddr)
q = self._limiter.get_minimum_across_quotas(ipaddr=ipaddr)
response.headers.add("X-RateLimit-Remaining", q)
except Exception as e:
log.error(str(e), exc_info=True)
finally:
return response
def __init__(
self,
app,
limits: dict,
token_check_callback=None,
ipaddr_methods=["X-Real-Ip", "socket"],
whitelisted_ipaddrs=None,
):
"""
"""
self._limiter = Limiter(
limits,
token_check_callback=token_check_callback,
ipaddr_methods=ipaddr_methods,
whitelisted_ipaddrs=whitelisted_ipaddrs,
)
if app.extensions.get("limiter"):
raise Exception("The Flask app already has an extension named 'limiter'")
app.before_request(self._check_limits_callback)
app.after_request(self._after_request_callback)
app.extensions["limiter"] = self
def get_lowest_daily_quotas_summary(self, n=20) -> List[Tuple[int, float]]:
return self._limiter.get_lowest_daily_quotas_summary(n)