Skip to content

Commit

Permalink
Ensure scoping works with context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed Dec 27, 2022
1 parent 1870a8d commit f22dfbd
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
5 changes: 2 additions & 3 deletions flask_limiter/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,6 @@ def __evaluate_limits(self, endpoint: str, limits: List[Limit]) -> None:
limit_scope = f"{endpoint}:{lim.scope}"
else:
limit_scope = lim.scope or endpoint

if lim.is_exempt or lim.method_exempt:
continue

Expand Down Expand Up @@ -1122,11 +1121,11 @@ def __enter__(self) -> None:
# on the limit manager's knowledge of decorated limits might be worth it.
if not self.is_static:
self.limiter.limit_manager.add_decorated_runtime_limit(
qualified_location, self.dynamic_limit
qualified_location, self.dynamic_limit, override=True
)
else:
self.limiter.limit_manager.add_decorated_static_limit(
qualified_location, *self.static_limits
qualified_location, *self.static_limits, override=True
)

self.limiter.limit_manager.add_endpoint_hint(
Expand Down
22 changes: 16 additions & 6 deletions flask_limiter/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,26 @@ def set_application_limits(self, limits: List[LimitGroup]) -> None:
def set_default_limits(self, limits: List[LimitGroup]) -> None:
self._default_limits = limits

def add_decorated_runtime_limit(self, route: str, limit: LimitGroup) -> None:
self._runtime_decorated_limits.setdefault(route, OrderedSet()).add(limit)
def add_decorated_runtime_limit(
self, route: str, limit: LimitGroup, override: bool = False
) -> None:
if not override:
self._runtime_decorated_limits.setdefault(route, OrderedSet()).add(limit)
else:
self._runtime_decorated_limits[route] = OrderedSet([limit])

def add_runtime_blueprint_limits(self, blueprint: str, limit: LimitGroup) -> None:
self._runtime_blueprint_limits.setdefault(blueprint, OrderedSet()).add(limit)

def add_decorated_static_limit(self, route: str, *limits: Limit) -> None:
self._static_decorated_limits.setdefault(route, OrderedSet()).update(
OrderedSet(limits)
)
def add_decorated_static_limit(
self, route: str, *limits: Limit, override: bool = False
) -> None:
if not override:
self._static_decorated_limits.setdefault(route, OrderedSet()).update(
OrderedSet(limits)
)
else:
self._static_decorated_limits[route] = OrderedSet(limits)

def add_static_blueprint_limits(self, blueprint: str, *limits: Limit) -> None:
self._static_blueprint_limits.setdefault(blueprint, OrderedSet()).update(
Expand Down
16 changes: 16 additions & 0 deletions tests/test_context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,19 @@ def t1():
with app.test_client() as cli:
assert 200 == cli.get("/t1").status_code
assert 429 == cli.get("/t1").status_code


def test_scoped_context_manager(extension_factory):
app, limiter = extension_factory()

@app.route("/t1/<int:param>")
def t1(param: int):
with limiter.limit("1/second", scope=param):
return "p1"

with hiro.Timeline().freeze() as timeline:
with app.test_client() as cli:
assert 200 == cli.get("/t1/1").status_code
assert 429 == cli.get("/t1/1").status_code
assert 200 == cli.get("/t1/2").status_code
assert 429 == cli.get("/t1/2").status_code

0 comments on commit f22dfbd

Please sign in to comment.