Skip to content

Commit

Permalink
rate_limit: Stop wrapping rate limited functions.
Browse files Browse the repository at this point in the history
This refactors `rate_limit` so that we no longer use it as a decorator.
This is a workaround to python/mypy#12909 as
`rate_limit` previous expects different parameters than its callers.

Our approach to test logging handlers also needs to be updated because
the view function is not decorated by `rate_limit`.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
  • Loading branch information
PIG208 committed Aug 5, 2022
1 parent c788e33 commit fe96381
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 62 deletions.
64 changes: 23 additions & 41 deletions zerver/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ def _wrapped_view_func(
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
) -> HttpResponse:
process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
return rate_limit(view_func)(request, *args, **kwargs)
rate_limit(request)
return view_func(request, *args, **kwargs)

return _wrapped_view_func

Expand Down Expand Up @@ -723,10 +724,8 @@ def _wrapped_func_arguments(
) -> HttpResponse:
user_profile = validate_api_key(request, None, api_key, False)
if not skip_rate_limiting:
limited_func = rate_limit(view_func)
else:
limited_func = view_func
return limited_func(request, user_profile, *args, **kwargs)
rate_limit(request)
return view_func(request, user_profile, *args, **kwargs)

return _wrapped_func_arguments

Expand Down Expand Up @@ -787,10 +786,8 @@ def _wrapped_func_arguments(
try:
if not skip_rate_limiting:
# Apply rate limiting
target_view_func = rate_limit(view_func)
else:
target_view_func = view_func
return target_view_func(request, profile, *args, **kwargs)
rate_limit(request)
return view_func(request, profile, *args, **kwargs)
except Exception as err:
if not webhook_client_name:
raise err
Expand Down Expand Up @@ -864,9 +861,7 @@ def authenticate_log_and_execute_json(
**kwargs: object,
) -> HttpResponse:
if not skip_rate_limiting:
limited_view_func = rate_limit(view_func)
else:
limited_view_func = view_func
rate_limit(request)

if not request.user.is_authenticated:
if not allow_unauthenticated:
Expand All @@ -877,7 +872,7 @@ def authenticate_log_and_execute_json(
is_browser_view=True,
query=view_func.__name__,
)
return limited_view_func(request, request.user, *args, **kwargs)
return view_func(request, request.user, *args, **kwargs)

user_profile = request.user
validate_account_and_subdomain(request, user_profile)
Expand All @@ -886,7 +881,7 @@ def authenticate_log_and_execute_json(
raise JsonableError(_("Webhook bots can only access webhooks"))

process_client(request, user_profile, is_browser_view=True, query=view_func.__name__)
return limited_view_func(request, user_profile, *args, **kwargs)
return view_func(request, user_profile, *args, **kwargs)


# Checks if the user is logged in. If not, return an error (the
Expand Down Expand Up @@ -1069,36 +1064,23 @@ def rate_limit_remote_server(
raise e


def rate_limit(func: ViewFuncT) -> ViewFuncT:
"""Rate-limits a view."""

@wraps(func)
def wrapped_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse:

# It is really tempting to not even wrap our original function
# when settings.RATE_LIMITING is False, but it would make
# for awkward unit testing in some situations.
if not settings.RATE_LIMITING:
return func(request, *args, **kwargs)

if client_is_exempt_from_rate_limiting(request):
return func(request, *args, **kwargs)

user = request.user
remote_server = RequestNotes.get_notes(request).remote_server
def rate_limit(request: HttpRequest) -> None:
if not settings.RATE_LIMITING:
return

if settings.ZILENCER_ENABLED and remote_server is not None:
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
elif not user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip")
return func(request, *args, **kwargs)
else:
assert isinstance(user, UserProfile)
rate_limit_user(request, user, domain="api_by_user")
if client_is_exempt_from_rate_limiting(request):
return

return func(request, *args, **kwargs)
user = request.user
remote_server = RequestNotes.get_notes(request).remote_server

return cast(ViewFuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
if settings.ZILENCER_ENABLED and remote_server is not None:
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
elif not user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip")
else:
assert isinstance(user, UserProfile)
rate_limit_user(request, user, domain="api_by_user")


def return_success_on_head_request(
Expand Down
3 changes: 1 addition & 2 deletions zerver/tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,10 +630,9 @@ def test_authenticated_rest_api_view_errors(self) -> None:
class RateLimitTestCase(ZulipTestCase):
def get_ratelimited_view(self) -> Callable[..., HttpResponse]:
def f(req: Any) -> HttpResponse:
rate_limit(req)
return json_response(msg="some value")

f = rate_limit(f)

return f

def errors_disallowed(self) -> Any:
Expand Down
36 changes: 17 additions & 19 deletions zerver/tests/test_logging_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from functools import wraps
from types import TracebackType
from typing import Callable, Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast
from typing import Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast
from unittest import mock
from unittest.mock import MagicMock, patch

Expand All @@ -22,22 +22,19 @@
] = None


def capture_and_throw(domain: Optional[str] = None) -> Callable[[ViewFuncT], ViewFuncT]:
def wrapper(view_func: ViewFuncT) -> ViewFuncT:
@wraps(view_func)
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn:
global captured_request
captured_request = request
try:
raise Exception("Request error")
except Exception as e:
global captured_exc_info
captured_exc_info = sys.exc_info()
raise e

return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927
def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT:
@wraps(view_func)
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn:
global captured_request
captured_request = request
try:
raise Exception("Request error")
except Exception as e:
global captured_exc_info
captured_exc_info = sys.exc_info()
raise e

return wrapper
return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927


class AdminNotifyHandlerTest(ZulipTestCase):
Expand Down Expand Up @@ -78,17 +75,18 @@ def test_basic(self, mock_function: MagicMock) -> None:

def simulate_error(self) -> logging.LogRecord:
self.login("hamlet")
with patch("zerver.decorator.rate_limit") as rate_limit_patch, self.assertLogs(
with patch(
"zerver.lib.rest.authenticated_json_view", side_effect=capture_and_throw
) as view_decorator_patch, self.assertLogs(
"django.request", level="ERROR"
) as request_error_log, self.assertLogs(
"zerver.middleware.json_error_handler", level="ERROR"
) as json_error_handler_log, self.settings(
TEST_SUITE=False
):
rate_limit_patch.side_effect = capture_and_throw
result = self.client_get("/json/users")
self.assert_json_error(result, "Internal server error", status_code=500)
rate_limit_patch.assert_called_once()
view_decorator_patch.assert_called_once()
self.assertEqual(
request_error_log.output, ["ERROR:django.request:Internal Server Error: /json/users"]
)
Expand Down

0 comments on commit fe96381

Please sign in to comment.