diff --git a/falcon/_typing.py b/falcon/_typing.py index d82a5bac5..93611fb9c 100644 --- a/falcon/_typing.py +++ b/falcon/_typing.py @@ -62,43 +62,50 @@ class _Unset(Enum): UNSET = auto() +# GENERICS _T = TypeVar('_T') +# there are used to type callables in a way that accept subclasses +_BE = TypeVar('_BE', bound=BaseException) +_REQ = TypeVar('_REQ', bound='Request') +_A_REQ = TypeVar('_A_REQ', bound='AsgiRequest') +_RESP = TypeVar('_RESP', bound='Response') +_A_RESP = TypeVar('_A_RESP', bound='AsgiResponse') + _UNSET = _Unset.UNSET UnsetOr = Union[Literal[_Unset.UNSET], _T] Link = Dict[str, str] CookieArg = Mapping[str, Union[str, Cookie]] # Error handlers -ErrorHandler = Callable[['Request', 'Response', BaseException, Dict[str, Any]], None] +ErrorHandler = Callable[[_REQ, _RESP, _BE, Dict[str, Any]], None] -class AsgiErrorHandler(Protocol): +class AsgiErrorHandler(Protocol[_A_REQ, _A_RESP, _BE]): async def __call__( self, - req: AsgiRequest, - resp: Optional[AsgiResponse], - error: BaseException, + req: _A_REQ, + resp: Optional[_A_RESP], + error: _BE, params: Dict[str, Any], + /, *, ws: Optional[WebSocket] = ..., ) -> None: ... # Error serializers -ErrorSerializer = Callable[['Request', 'Response', 'HTTPError'], None] +ErrorSerializer = Callable[[_REQ, _RESP, 'HTTPError'], None] # Sinks SinkPrefix = Union[str, Pattern[str]] -class SinkCallable(Protocol): - def __call__(self, req: Request, resp: Response, **kwargs: str) -> None: ... +class SinkCallable(Protocol[_REQ, _RESP]): + def __call__(self, req: _REQ, resp: _RESP, /, **kwargs: str) -> None: ... -class AsgiSinkCallable(Protocol): - async def __call__( - self, req: AsgiRequest, resp: AsgiResponse, **kwargs: str - ) -> None: ... +class AsgiSinkCallable(Protocol[_A_REQ, _A_RESP]): + async def __call__(self, req: _A_REQ, resp: _A_RESP, /, **kwargs: str) -> None: ... HeaderMapping = Mapping[str, str] @@ -111,62 +118,56 @@ async def __call__( # WSGI -class ResponderMethod(Protocol): +class ResponderMethod(Protocol[_REQ, _RESP]): def __call__( self, resource: Resource, - req: Request, - resp: Response, + req: _REQ, + resp: _RESP, + /, **kwargs: Any, ) -> None: ... -class ResponderCallable(Protocol): - def __call__(self, req: Request, resp: Response, **kwargs: Any) -> None: ... +class ResponderCallable(Protocol[_REQ, _RESP]): + def __call__(self, req: _REQ, resp: _RESP, /, **kwargs: Any) -> None: ... -ProcessRequestMethod = Callable[['Request', 'Response'], None] -ProcessResourceMethod = Callable[ - ['Request', 'Response', Resource, Dict[str, Any]], None -] -ProcessResponseMethod = Callable[['Request', 'Response', Resource, bool], None] +ProcessRequestMethod = Callable[[_REQ, _RESP], None] +ProcessResourceMethod = Callable[[_REQ, _RESP, Resource, Dict[str, Any]], None] +ProcessResponseMethod = Callable[[_REQ, _RESP, Resource, bool], None] # ASGI -class AsgiResponderMethod(Protocol): +class AsgiResponderMethod(Protocol[_A_REQ, _A_RESP]): async def __call__( self, resource: Resource, - req: AsgiRequest, - resp: AsgiResponse, + req: _A_REQ, + resp: _A_RESP, + /, **kwargs: Any, ) -> None: ... -class AsgiResponderCallable(Protocol): - async def __call__( - self, req: AsgiRequest, resp: AsgiResponse, **kwargs: Any - ) -> None: ... +class AsgiResponderCallable(Protocol[_A_REQ, _A_RESP]): + async def __call__(self, req: _A_REQ, resp: _A_RESP, /, **kwargs: Any) -> None: ... -class AsgiResponderWsCallable(Protocol): - async def __call__( - self, req: AsgiRequest, ws: WebSocket, **kwargs: Any - ) -> None: ... +class AsgiResponderWsCallable(Protocol[_A_REQ]): + async def __call__(self, req: _A_REQ, ws: WebSocket, /, **kwargs: Any) -> None: ... AsgiReceive = Callable[[], Awaitable['AsgiEvent']] AsgiSend = Callable[['AsgiSendMsg'], Awaitable[None]] -AsgiProcessRequestMethod = Callable[['AsgiRequest', 'AsgiResponse'], Awaitable[None]] +AsgiProcessRequestMethod = Callable[[_A_REQ, _A_RESP], Awaitable[None]] AsgiProcessResourceMethod = Callable[ - ['AsgiRequest', 'AsgiResponse', Resource, Dict[str, Any]], Awaitable[None] -] -AsgiProcessResponseMethod = Callable[ - ['AsgiRequest', 'AsgiResponse', Resource, bool], Awaitable[None] + [_A_REQ, _A_RESP, Resource, Dict[str, Any]], Awaitable[None] ] -AsgiProcessRequestWsMethod = Callable[['AsgiRequest', 'WebSocket'], Awaitable[None]] +AsgiProcessResponseMethod = Callable[[_A_REQ, _A_RESP, Resource, bool], Awaitable[None]] +AsgiProcessRequestWsMethod = Callable[['_A_REQ', 'WebSocket'], Awaitable[None]] AsgiProcessResourceWsMethod = Callable[ - ['AsgiRequest', 'WebSocket', Resource, Dict[str, Any]], Awaitable[None] + [_A_REQ, 'WebSocket', Resource, Dict[str, Any]], Awaitable[None] ] ResponseCallbacks = Union[ Tuple[Callable[[], None], Literal[False]], @@ -182,9 +183,9 @@ async def __call__( ] -class FindMethod(Protocol): +class FindMethod(Protocol[_REQ]): def __call__( - self, uri: str, req: Optional[Request] + self, uri: str, req: Optional[_REQ] ) -> Optional[Tuple[object, MethodDict, Dict[str, Any], Optional[str]]]: ... diff --git a/falcon/app.py b/falcon/app.py index f74083693..bc1c18062 100644 --- a/falcon/app.py +++ b/falcon/app.py @@ -36,7 +36,6 @@ Pattern, Tuple, Type, - TypeVar, Union, ) import warnings @@ -45,6 +44,9 @@ from falcon import constants from falcon import responders from falcon import routing +from falcon._typing import _BE +from falcon._typing import _REQ +from falcon._typing import _RESP from falcon._typing import AsgiResponderCallable from falcon._typing import AsgiResponderWsCallable from falcon._typing import AsgiSinkCallable @@ -90,7 +92,6 @@ status.HTTP_304, ] ) -_BE = TypeVar('_BE', bound=BaseException) class App: @@ -261,7 +262,9 @@ def process_response( ) _cors_enable: bool - _error_handlers: Dict[Type[BaseException], ErrorHandler] + _error_handlers: Dict[ + Type[BaseException], ErrorHandler[Request, Response, BaseException] + ] _independent_middleware: bool _middleware: helpers.PreparedMiddlewareResult _request_type: Type[Request] @@ -817,20 +820,20 @@ def add_sink(self, sink: SinkCallable, prefix: SinkPrefix = r'/') -> None: def add_error_handler( self, exception: Type[_BE], - handler: Callable[[Request, Response, _BE, Dict[str, Any]], None], + handler: ErrorHandler[_REQ, _RESP, _BE], ) -> None: ... @overload def add_error_handler( self, - exception: Union[Type[BaseException], Iterable[Type[BaseException]]], - handler: Optional[ErrorHandler] = None, + exception: Union[Type[_BE], Iterable[Type[_BE]]], + handler: Optional[ErrorHandler[_REQ, _RESP, _BE]] = None, ) -> None: ... def add_error_handler( # type: ignore[misc] self, exception: Union[Type[BaseException], Iterable[Type[BaseException]]], - handler: Optional[ErrorHandler] = None, + handler: Optional[ErrorHandler[_REQ, _RESP, _BE]] = None, ) -> None: """Register a handler for one or more exception types. diff --git a/falcon/asgi/app.py b/falcon/asgi/app.py index f3b637802..130335704 100644 --- a/falcon/asgi/app.py +++ b/falcon/asgi/app.py @@ -22,8 +22,6 @@ import traceback from typing import ( Any, - Awaitable, - Callable, ClassVar, Dict, Iterable, @@ -33,13 +31,15 @@ Tuple, Type, TYPE_CHECKING, - TypeVar, Union, ) from falcon import constants from falcon import responders from falcon import routing +from falcon._typing import _A_REQ +from falcon._typing import _A_RESP +from falcon._typing import _BE from falcon._typing import _UNSET from falcon._typing import AsgiErrorHandler from falcon._typing import AsgiReceive @@ -93,7 +93,6 @@ _TYPELESS_STATUS_CODES = frozenset([204, 304]) _FALLBACK_WS_ERROR_CODE = 3011 -_BE = TypeVar('_BE', bound=BaseException) class App(falcon.app.App): @@ -351,7 +350,9 @@ async def process_resource_ws( 'ws_options', ) - _error_handlers: Dict[Type[BaseException], AsgiErrorHandler] # type: ignore[assignment] + _error_handlers: Dict[ + Type[BaseException], AsgiErrorHandler[Request, Response, BaseException] + ] # type: ignore[assignment] _middleware: AsyncPreparedMiddlewareResult # type: ignore[assignment] _middleware_ws: AsyncPreparedMiddlewareWsResult _request_type: Type[Request] @@ -862,20 +863,20 @@ def add_sink(self, sink: AsgiSinkCallable, prefix: SinkPrefix = r'/') -> None: def add_error_handler( self, exception: Type[_BE], - handler: Callable[[Request, Response, _BE, Dict[str, Any]], Awaitable[None]], + handler: AsgiErrorHandler[_A_REQ, _A_RESP, _BE], ) -> None: ... @overload def add_error_handler( self, - exception: Union[Type[BaseException], Iterable[Type[BaseException]]], - handler: Optional[AsgiErrorHandler] = None, + exception: Union[Type[_BE], Iterable[Type[_BE]]], + handler: Optional[AsgiErrorHandler[_A_REQ, _A_RESP, _BE]] = None, ) -> None: ... def add_error_handler( # type: ignore[misc] self, exception: Union[Type[BaseException], Iterable[Type[BaseException]]], - handler: Optional[AsgiErrorHandler] = None, + handler: Optional[AsgiErrorHandler[_A_REQ, _A_RESP, _BE]] = None, ) -> None: """Register a handler for one or more exception types. diff --git a/tests/typing_only/__init__.py b/tests/typing_only/__init__.py new file mode 100644 index 000000000..0999bfb27 --- /dev/null +++ b/tests/typing_only/__init__.py @@ -0,0 +1,4 @@ +"""In this packages there are files that must pass mypy typing. + +Currently only positive test (meaning no errors) are supported. +""" diff --git a/tests/typing_only/error_handlers.py b/tests/typing_only/error_handlers.py new file mode 100644 index 000000000..43dea6b97 --- /dev/null +++ b/tests/typing_only/error_handlers.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import Any + +import falcon +from falcon import asgi + + +class MyRequest(falcon.Request): + def some(self) -> bool: + return True + + +class MyResponse(falcon.Response): + def method(self) -> bool: + return True + + +def hook_1( + req: MyRequest, resp: falcon.Response, err: ValueError, params: dict[str, Any] +) -> None: ... +def hook_2( + req: falcon.Request, resp: MyResponse, err: NameError, params: dict[str, Any] +) -> None: ... +def hook_3( + req: MyRequest, resp: MyResponse, err: LookupError, params: dict[str, Any] +) -> None: ... +def hook_4( + req: falcon.Request, + resp: falcon.Response, + err: AttributeError, + params: dict[str, Any], +) -> None: ... + + +app1 = falcon.App() +app1.add_error_handler(ValueError, hook_1) +app1.add_error_handler(NameError, hook_2) +app1.add_error_handler(LookupError, hook_3) +app1.add_error_handler(AttributeError, hook_4) +app1.add_error_handler([ValueError], hook_1) +app1.add_error_handler([NameError], hook_2) +app1.add_error_handler([LookupError], hook_3) +app1.add_error_handler([AttributeError], hook_4) +# TODO: test these errors somehow +# app1.add_error_handler(BufferError, hook_1) +# app1.add_error_handler(BufferError, hook_2) +# app1.add_error_handler(BufferError, hook_3) +# app1.add_error_handler(BufferError, hook_4) + +app2 = falcon.App(request_type=MyRequest, response_type=MyResponse) +app2.add_error_handler(ValueError, hook_1) +app2.add_error_handler(NameError, hook_2) +app2.add_error_handler(LookupError, hook_3) +app2.add_error_handler(AttributeError, hook_4) +app2.add_error_handler([ValueError], hook_1) +app2.add_error_handler([NameError], hook_2) +app2.add_error_handler([LookupError], hook_3) +app2.add_error_handler([AttributeError], hook_4) + +# ---- +# asgi +# ---- + + +class AMyRequest(asgi.Request): + def some(self) -> bool: + return True + + +class AMyResponse(asgi.Response): + def method(self) -> bool: + return True + + +async def a_hook_1( + req: AMyRequest, + resp: asgi.Response | None, + err: ValueError, + params: dict[str, Any], + *, + ws: asgi.WebSocket | None = None, +) -> None: ... +async def a_hook_2( + req: asgi.Request, + resp: AMyResponse | None, + err: NameError, + params: dict[str, Any], + *, + ws: asgi.WebSocket | None = None, +) -> None: ... +async def a_hook_3( + req: AMyRequest, + resp: AMyResponse | None, + err: LookupError, + params: dict[str, Any], + *, + ws: asgi.WebSocket | None = None, +) -> None: ... +async def a_hook_4( + req: asgi.Request, + resp: asgi.Response | None, + err: AttributeError, + params: dict[str, Any], + *, + ws: asgi.WebSocket | None = None, +) -> None: ... + + +a_app1 = asgi.App() +a_app1.add_error_handler(ValueError, a_hook_1) +a_app1.add_error_handler(NameError, a_hook_2) +a_app1.add_error_handler(LookupError, a_hook_3) +a_app1.add_error_handler(AttributeError, a_hook_4) +a_app1.add_error_handler([ValueError], a_hook_1) +a_app1.add_error_handler([NameError], a_hook_2) +a_app1.add_error_handler([LookupError], a_hook_3) +a_app1.add_error_handler([AttributeError], a_hook_4) + + +a_app2 = asgi.App(request_type=AMyRequest, response_type=AMyResponse) +a_app2.add_error_handler(ValueError, a_hook_1) +a_app2.add_error_handler(NameError, a_hook_2) +a_app2.add_error_handler(LookupError, a_hook_3) +a_app2.add_error_handler(AttributeError, a_hook_4) +a_app2.add_error_handler([ValueError], a_hook_1) +a_app2.add_error_handler([NameError], a_hook_2) +a_app2.add_error_handler([LookupError], a_hook_3) +a_app2.add_error_handler([AttributeError], a_hook_4)