diff --git a/zerver/decorator.py b/zerver/decorator.py index a9a1787a3b545..8056e6b796c4d 100644 --- a/zerver/decorator.py +++ b/zerver/decorator.py @@ -532,7 +532,8 @@ def _wrapped_view_func( request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs ) -> HttpResponse: process_client(request, request.user, is_browser_view=True, query=view_func.__name__) - return rate_limit(view_func)(request, *args, **kwargs) + rate_limit(request) + return view_func(request, *args, **kwargs) return _wrapped_view_func @@ -723,10 +724,8 @@ def _wrapped_func_arguments( ) -> HttpResponse: user_profile = validate_api_key(request, None, api_key, False) if not skip_rate_limiting: - limited_func = rate_limit(view_func) - else: - limited_func = view_func - return limited_func(request, user_profile, *args, **kwargs) + rate_limit(request) + return view_func(request, user_profile, *args, **kwargs) return _wrapped_func_arguments @@ -787,10 +786,8 @@ def _wrapped_func_arguments( try: if not skip_rate_limiting: # Apply rate limiting - target_view_func = rate_limit(view_func) - else: - target_view_func = view_func - return target_view_func(request, profile, *args, **kwargs) + rate_limit(request) + return view_func(request, profile, *args, **kwargs) except Exception as err: if not webhook_client_name: raise err @@ -864,9 +861,7 @@ def authenticate_log_and_execute_json( **kwargs: object, ) -> HttpResponse: if not skip_rate_limiting: - limited_view_func = rate_limit(view_func) - else: - limited_view_func = view_func + rate_limit(request) if not request.user.is_authenticated: if not allow_unauthenticated: @@ -877,7 +872,7 @@ def authenticate_log_and_execute_json( is_browser_view=True, query=view_func.__name__, ) - return limited_view_func(request, request.user, *args, **kwargs) + return view_func(request, request.user, *args, **kwargs) user_profile = request.user validate_account_and_subdomain(request, user_profile) @@ -886,7 +881,7 @@ def authenticate_log_and_execute_json( raise JsonableError(_("Webhook bots can only access webhooks")) process_client(request, user_profile, is_browser_view=True, query=view_func.__name__) - return limited_view_func(request, user_profile, *args, **kwargs) + return view_func(request, user_profile, *args, **kwargs) # Checks if the user is logged in. If not, return an error (the @@ -1069,36 +1064,23 @@ def rate_limit_remote_server( raise e -def rate_limit(func: ViewFuncT) -> ViewFuncT: - """Rate-limits a view.""" - - @wraps(func) - def wrapped_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse: - - # It is really tempting to not even wrap our original function - # when settings.RATE_LIMITING is False, but it would make - # for awkward unit testing in some situations. - if not settings.RATE_LIMITING: - return func(request, *args, **kwargs) - - if client_is_exempt_from_rate_limiting(request): - return func(request, *args, **kwargs) - - user = request.user - remote_server = RequestNotes.get_notes(request).remote_server +def rate_limit(request: HttpRequest) -> None: + if not settings.RATE_LIMITING: + return - if settings.ZILENCER_ENABLED and remote_server is not None: - rate_limit_remote_server(request, remote_server, domain="api_by_remote_server") - elif not user.is_authenticated: - rate_limit_request_by_ip(request, domain="api_by_ip") - return func(request, *args, **kwargs) - else: - assert isinstance(user, UserProfile) - rate_limit_user(request, user, domain="api_by_user") + if client_is_exempt_from_rate_limiting(request): + return - return func(request, *args, **kwargs) + user = request.user + remote_server = RequestNotes.get_notes(request).remote_server - return cast(ViewFuncT, wrapped_func) # https://github.com/python/mypy/issues/1927 + if settings.ZILENCER_ENABLED and remote_server is not None: + rate_limit_remote_server(request, remote_server, domain="api_by_remote_server") + elif not user.is_authenticated: + rate_limit_request_by_ip(request, domain="api_by_ip") + else: + assert isinstance(user, UserProfile) + rate_limit_user(request, user, domain="api_by_user") def return_success_on_head_request( diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index 36e88cf082410..70768681d9c5f 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -630,10 +630,9 @@ def test_authenticated_rest_api_view_errors(self) -> None: class RateLimitTestCase(ZulipTestCase): def get_ratelimited_view(self) -> Callable[..., HttpResponse]: def f(req: Any) -> HttpResponse: + rate_limit(req) return json_response(msg="some value") - f = rate_limit(f) - return f def errors_disallowed(self) -> Any: diff --git a/zerver/tests/test_logging_handlers.py b/zerver/tests/test_logging_handlers.py index 346d57a0435f0..ab1c20084f5b1 100644 --- a/zerver/tests/test_logging_handlers.py +++ b/zerver/tests/test_logging_handlers.py @@ -2,7 +2,7 @@ import sys from functools import wraps from types import TracebackType -from typing import Callable, Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast +from typing import Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast from unittest import mock from unittest.mock import MagicMock, patch @@ -22,22 +22,19 @@ ] = None -def capture_and_throw(domain: Optional[str] = None) -> Callable[[ViewFuncT], ViewFuncT]: - def wrapper(view_func: ViewFuncT) -> ViewFuncT: - @wraps(view_func) - def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn: - global captured_request - captured_request = request - try: - raise Exception("Request error") - except Exception as e: - global captured_exc_info - captured_exc_info = sys.exc_info() - raise e - - return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927 +def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT: + @wraps(view_func) + def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn: + global captured_request + captured_request = request + try: + raise Exception("Request error") + except Exception as e: + global captured_exc_info + captured_exc_info = sys.exc_info() + raise e - return wrapper + return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927 class AdminNotifyHandlerTest(ZulipTestCase): @@ -78,17 +75,18 @@ def test_basic(self, mock_function: MagicMock) -> None: def simulate_error(self) -> logging.LogRecord: self.login("hamlet") - with patch("zerver.decorator.rate_limit") as rate_limit_patch, self.assertLogs( + with patch( + "zerver.lib.rest.authenticated_json_view", side_effect=capture_and_throw + ) as view_decorator_patch, self.assertLogs( "django.request", level="ERROR" ) as request_error_log, self.assertLogs( "zerver.middleware.json_error_handler", level="ERROR" ) as json_error_handler_log, self.settings( TEST_SUITE=False ): - rate_limit_patch.side_effect = capture_and_throw result = self.client_get("/json/users") self.assert_json_error(result, "Internal server error", status_code=500) - rate_limit_patch.assert_called_once() + view_decorator_patch.assert_called_once() self.assertEqual( request_error_log.output, ["ERROR:django.request:Internal Server Error: /json/users"] )