diff --git a/jupyter_server/auth/authorizer.py b/jupyter_server/auth/authorizer.py index f22dbe5463..96b96e97c5 100644 --- a/jupyter_server/auth/authorizer.py +++ b/jupyter_server/auth/authorizer.py @@ -9,7 +9,7 @@ # Distributed under the terms of the Modified BSD License. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Awaitable from traitlets import Instance from traitlets.config import LoggingConfigurable @@ -44,7 +44,7 @@ class Authorizer(LoggingConfigurable): def is_authorized( self, handler: JupyterHandler, user: User, action: str, resource: str - ) -> bool: + ) -> Awaitable | bool: """A method to determine if ``user`` is authorized to perform ``action`` (read, write, or execute) on the ``resource`` type. diff --git a/jupyter_server/auth/decorator.py b/jupyter_server/auth/decorator.py index fd38cda1e7..a92866b4e8 100644 --- a/jupyter_server/auth/decorator.py +++ b/jupyter_server/auth/decorator.py @@ -2,9 +2,11 @@ """ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +import asyncio from functools import wraps from typing import Any, Callable, Optional, TypeVar, Union, cast +from jupyter_core.utils import ensure_async from tornado.log import app_log from tornado.web import HTTPError @@ -42,7 +44,7 @@ def authorized( def wrapper(method): @wraps(method) - def inner(self, *args, **kwargs): + async def inner(self, *args, **kwargs): # default values for action, resource nonlocal action nonlocal resource @@ -61,8 +63,15 @@ def inner(self, *args, **kwargs): raise HTTPError(status_code=403, log_message=message) # If the user is allowed to do this action, # call the method. - if self.authorizer.is_authorized(self, user, action, resource): - return method(self, *args, **kwargs) + authorized = await ensure_async( + self.authorizer.is_authorized(self, user, action, resource) + ) + if authorized: + out = method(self, *args, **kwargs) + # If the method is a coroutine, await it + if asyncio.iscoroutine(out): + return await out + return out # else raise an exception. else: raise HTTPError(status_code=403, log_message=message) diff --git a/tests/auth/test_authorizer.py b/tests/auth/test_authorizer.py index 08c49eadf0..775c131da2 100644 --- a/tests/auth/test_authorizer.py +++ b/tests/auth/test_authorizer.py @@ -1,12 +1,18 @@ """Tests for authorization""" +import asyncio import json import os +from typing import Awaitable import pytest from jupyter_client.kernelspec import NATIVE_KERNEL_NAME from nbformat import writes from nbformat.v4 import new_notebook +from traitlets import Bool +from jupyter_server.auth.authorizer import Authorizer +from jupyter_server.auth.identity import User +from jupyter_server.base.handlers import JupyterHandler from jupyter_server.services.security import csp_report_uri @@ -217,3 +223,45 @@ async def test_authorized_requests( code = await send_request(url, body=body, method=method) assert code in expected_codes + + +class AsyncAuthorizerTest(Authorizer): + """Test that an asynchronous authorizer would still work.""" + + called = Bool(False) + + async def mock_async_fetch(self): + """Mock an async fetch""" + # Mock a hang for a half a second. + await asyncio.sleep(0.5) + return True + + async def is_authorized( + self, handler: JupyterHandler, user: User, action: str, resource: str + ) -> Awaitable | bool: + response = await self.mock_async_fetch() + self.called = True + return response + + +@pytest.mark.parametrize( + "jp_server_config,", + [ + { + "ServerApp": {"authorizer_class": AsyncAuthorizerTest}, + "jpserver_extensions": {"jupyter_server_terminals": True}, + } + ], +) +async def test_async_authorizer( + request, + io_loop, + send_request, + tmp_path, + jp_serverapp, +): + code = await send_request("/api/status", method="GET") + assert code == 200 + # Ensure that the authorizor method finished its request. + assert hasattr(jp_serverapp.authorizer, "called") + assert jp_serverapp.authorizer.called is True