diff --git a/.gitignore b/.gitignore index b3f862cd8..e13685883 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ *.py[cod] .eggs .installed.cfg +.mypy_cache build develop-eggs dist @@ -20,6 +21,7 @@ lib64 parts sdist var +pip-wheel-metadata # Installer logs pip-log.txt diff --git a/.travis.yml b/.travis.yml index 860d21c95..c27fcfdd9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,22 +9,27 @@ cache: matrix: include: - - python: pypy3.5-6.0 + - python: pypy3.6-7.1.1 env: TOXENV=pypy3 - python: 3.8 env: TOXENV=pep8 - python: 3.8 env: TOXENV=pep8-examples - - python: 3.5 + - python: 3.8 + env: TOXENV=mypy + # NOTE(kgriffs): 3.5.2 is the default Python 3 version on Ubuntu 16.04 + # so we pin to that for testing to make sure we are working around + # and quirks that were fixed in later micro versions. + - python: 3.5.2 env: TOXENV=py35 - python: 3.6 env: TOXENV=py36 + - python: 3.6 + env: TOXENV=py36_cython - python: 3.7 env: TOXENV=py37 - python: 3.8 env: TOXENV=py38 - - python: 3.8 - env: TOXENV=mypy - python: 3.8 env: TOXENV=py38_cython - python: 3.8 @@ -47,7 +52,7 @@ matrix: - python: 3.8 env: TOXENV=check_vendored -script: tox +script: tox -- -v notifications: webhooks: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 484390c14..ef18ff56f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,20 +27,26 @@ Please note that all contributors and maintainers of this project are subject to Before submitting a pull request, please ensure you have added or updated tests as appropriate, and that all existing tests still pass with your changes. Please also ensure that your coding style follows PEP 8. -You can check all this by running the following from within the Falcon project directory (requires Python 3.8 to be installed on your system): +You can check all this by running the following from within the Falcon project directory (requires Python 3.8 and 3.5 to be installed on your system): ```bash $ tools/mintest.sh - ``` -You may also use Python 3.5, 3.6 or 3.7 if you don't have 3.8 installed on your system. Substitute "py35", "py36" or "py37" as appropriate. For example: +You may also use Python 3.6 or 3.7 if you don't have 3.8 installed on your system. Substitute "py36" or "py37" as appropriate. For example: ```bash $ pip install -U tox coverage $ rm -f .coverage.* -$ tox -e pep8 && tox -e py37 && tools/testing/combine_coverage.sh +$ tox -e pep8 && tox -e py35,py37 && tools/testing/combine_coverage.sh +``` + +If you are using pyenv, you will need to make sure both 3.8 and 3.5 are available in the current shell, e.g.: + +```bash +$ pyenv shell 3.8.0 3.5.8 +``` #### Reviews diff --git a/README.rst b/README.rst index ec7f7532a..76186f074 100644 --- a/README.rst +++ b/README.rst @@ -231,7 +231,8 @@ Installing it is as simple as: Installing the Falcon wheel is a great way to get up and running quickly in a development environment, but for an extra speed boost when deploying your application in production, Falcon can compile itself with -Cython. +Cython. Note, however, that Cython is currently incompatible with +the falcon.asgi module. The following commands tell pip to install Cython, and then to invoke Falcon's ``setup.py``, which will in turn detect the presence of Cython diff --git a/docs/api/testing.rst b/docs/api/testing.rst index 607710b82..ce91d080a 100644 --- a/docs/api/testing.rst +++ b/docs/api/testing.rst @@ -11,5 +11,5 @@ Reference simulate_request, simulate_get, simulate_head, simulate_post, simulate_put, simulate_options, simulate_patch, simulate_delete, TestClient, TestCase, SimpleTestResource, StartResponseMock, - capture_responder_args, rand_string, create_environ, redirected, - closed_wsgi_iterable + capture_responder_args, rand_string, create_environ, create_req, + create_asgi_req, redirected, closed_wsgi_iterable diff --git a/docs/user/install.rst b/docs/user/install.rst index 2ce1fffd6..bdad07903 100644 --- a/docs/user/install.rst +++ b/docs/user/install.rst @@ -35,7 +35,8 @@ Installing it is as simple as: Installing the Falcon wheel is a great way to get up and running quickly in a development environment, but for an extra speed boost when deploying your application in production, Falcon can compile itself with -Cython. +Cython. Note, however, that Cython is currently incompatible with +the falcon.asgi module. The following commands tell pip to install Cython, and then to invoke Falcon's ``setup.py``, which will in turn detect the presence of Cython diff --git a/falcon/__init__.py b/falcon/__init__.py index 1ec459773..8034e5980 100644 --- a/falcon/__init__.py +++ b/falcon/__init__.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Primary package for Falcon, the minimalist WSGI library. +"""Primary package for Falcon, the minimalist web API framework. -Falcon is a minimalist WSGI library for building speedy web APIs and app +Falcon is a minimalist web API framework for building speedy web APIs and app backends. The `falcon` package can be used to directly access most of the framework's classes, functions, and variables:: @@ -24,6 +24,9 @@ """ +import logging as _logging +import sys as _sys + # Hoist classes and functions into the falcon namespace from falcon.version import __version__ # NOQA from falcon.constants import * # NOQA @@ -44,3 +47,18 @@ from falcon.hooks import before, after # NOQA from falcon.request import Request, RequestOptions, Forwarded # NOQA from falcon.response import Response, ResponseOptions # NOQA + + +ASGI_SUPPORTED = _sys.version_info.minor > 5 +"""Set to ``True`` when ASGI is supported for the current Python version.""" + + +# NOTE(kgriffs): Special singleton to be used internally whenever using +# None would be ambiguous. +_UNSET = object() + + +# NOTE(kgriffs): Only to be used internally on the rare occasion that we +# need to log something that we can't communicate any other way. +_logger = _logging.getLogger('falcon') +_logger.addHandler(_logging.NullHandler()) diff --git a/falcon/app.py b/falcon/app.py index 207665bf2..c6d1a6211 100644 --- a/falcon/app.py +++ b/falcon/app.py @@ -15,6 +15,7 @@ """Falcon App class.""" from functools import wraps +from inspect import iscoroutinefunction import re import traceback @@ -63,9 +64,11 @@ class App: number of constants for common media types, such as ``falcon.MEDIA_MSGPACK``, ``falcon.MEDIA_YAML``, ``falcon.MEDIA_XML``, etc. - middleware(object or list): Either a single object or a list - of objects (instantiated classes) that implement the - following middleware component interface:: + middleware: Either a single middleware component object or an iterable + of objects (instantiated classes) that implement the following + middleware component interface. Note that it is only necessary + to implement the methods for the events you would like to + handle; Falcon simply skips over any missing middleware methods:: class ExampleComponent: def process_request(self, req, resp): @@ -164,18 +167,32 @@ def process_response(self, req, resp, resource, req_succeeded) _STREAM_BLOCK_SIZE = 8 * 1024 # 8 KiB + _STATIC_ROUTE_TYPE = routing.StaticRoute + + # NOTE(kgriffs): This makes it easier to tell what we are dealing with + # without having to import falcon.asgi to get at the falcon.asgi.App + # type (which we may not be able to do under Python 3.5). + _ASGI = False + + # NOTE(kgriffs): We do it like this rather than just implementing the + # methods directly on the class, so that we keep all the default + # responders colocated in the same module. This will make it more + # likely that the implementations of the async and non-async versions + # of the methods are kept in sync (pun intended). + _default_responder_bad_request = falcon.responders.bad_request + _default_responder_path_not_found = falcon.responders.path_not_found + __slots__ = ('_request_type', '_response_type', - '_error_handlers', '_media_type', '_router', '_sinks', + '_error_handlers', '_router', '_sinks', '_serialize_error', 'req_options', 'resp_options', '_middleware', '_independent_middleware', '_router_search', - '_static_routes', '_cors_enable') + '_static_routes', '_cors_enable', '_unprepared_middleware') def __init__(self, media_type=DEFAULT_MEDIA_TYPE, request_type=Request, response_type=Response, middleware=None, router=None, independent_middleware=True, cors_enable=False): self._sinks = [] - self._media_type = media_type self._static_routes = [] if cors_enable: @@ -197,9 +214,9 @@ def __init__(self, media_type=DEFAULT_MEDIA_TYPE, middleware = [middleware, cm] # set middleware - self._middleware = helpers.prepare_middleware( - middleware, independent_middleware=independent_middleware) + self._unprepared_middleware = [] self._independent_middleware = independent_middleware + self.add_middleware(middleware) self._router = router or routing.DefaultRouter() self._router_search = self._router.find @@ -221,6 +238,32 @@ def __init__(self, media_type=DEFAULT_MEDIA_TYPE, self.add_error_handler(falcon.HTTPError, self._http_error_handler) self.add_error_handler(falcon.HTTPStatus, self._http_status_handler) + def add_middleware(self, middleware): + """Add one or more additional middleware components. + + Arguments: + middleware: Either a single middleware component or an iterable + of components to add. The component(s) will be invoked, in + order, as if they had been appended to the original middleware + list passed to the class initializer. + """ + + # NOTE(kgriffs): Since this is called by the initializer, there is + # the chance that middleware may be None. + if middleware: + try: + self._unprepared_middleware += middleware + except TypeError: # middleware is not iterable; assume it is just one bare component + self._unprepared_middleware.append(middleware) + + # NOTE(kgriffs): Even if middleware is None or an empty list, we still + # need to make sure self._middleware is initialized if this is the + # first call to add_middleware(). + self._middleware = self._prepare_middleware( + self._unprepared_middleware, + independent_middleware=self._independent_middleware + ) + def __call__(self, env, start_response): # noqa: C901 """WSGI `app` method. @@ -248,83 +291,82 @@ def __call__(self, env, start_response): # noqa: C901 req_succeeded = False try: - try: - # NOTE(ealogar): The execution of request middleware - # should be before routing. This will allow request mw - # to modify the path. - # NOTE: if flag set to use independent middleware, execute - # request middleware independently. Otherwise, only queue - # response middleware after request middleware succeeds. - if self._independent_middleware: - for process_request in mw_req_stack: + # NOTE(ealogar): The execution of request middleware + # should be before routing. This will allow request mw + # to modify the path. + # NOTE: if flag set to use independent middleware, execute + # request middleware independently. Otherwise, only queue + # response middleware after request middleware succeeds. + if self._independent_middleware: + for process_request in mw_req_stack: + process_request(req, resp) + if resp.complete: + break + else: + for process_request, process_response in mw_req_stack: + if process_request and not resp.complete: process_request(req, resp) + if process_response: + dependent_mw_resp_stack.insert(0, process_response) + + if not resp.complete: + # NOTE(warsaw): Moved this to inside the try except + # because it is possible when using object-based + # traversal for _get_responder() to fail. An example is + # a case where an object does not have the requested + # next-hop child resource. In that case, the object + # being asked to dispatch to its child will raise an + # HTTP exception signalling the problem, e.g. a 404. + responder, params, resource, req.uri_template = self._get_responder(req) + except Exception as ex: + if not self._handle_exception(req, resp, ex, params): + raise + else: + try: + # NOTE(kgriffs): If the request did not match any + # route, a default responder is returned and the + # resource is None. In that case, we skip the + # resource middleware methods. Resource will also be + # None when a middleware method already set + # resp.complete to True. + if resource: + # Call process_resource middleware methods. + for process_resource in mw_rsrc_stack: + process_resource(req, resp, resource, params) if resp.complete: break - else: - for process_request, process_response in mw_req_stack: - if process_request and not resp.complete: - process_request(req, resp) - if process_response: - dependent_mw_resp_stack.insert(0, process_response) if not resp.complete: - # NOTE(warsaw): Moved this to inside the try except - # because it is possible when using object-based - # traversal for _get_responder() to fail. An example is - # a case where an object does not have the requested - # next-hop child resource. In that case, the object - # being asked to dispatch to its child will raise an - # HTTP exception signalling the problem, e.g. a 404. - responder, params, resource, req.uri_template = self._get_responder(req) + responder(req, resp, **params) + + req_succeeded = True except Exception as ex: if not self._handle_exception(req, resp, ex, params): raise - else: - try: - # NOTE(kgriffs): If the request did not match any - # route, a default responder is returned and the - # resource is None. In that case, we skip the - # resource middleware methods. Resource will also be - # None when a middleware method already set - # resp.complete to True. - if resource: - # Call process_resource middleware methods. - for process_resource in mw_rsrc_stack: - process_resource(req, resp, resource, params) - if resp.complete: - break - - if not resp.complete: - responder(req, resp, **params) - - req_succeeded = True - except Exception as ex: - if not self._handle_exception(req, resp, ex, params): - raise - finally: - # NOTE(kgriffs): It may not be useful to still execute - # response middleware methods in the case of an unhandled - # exception, but this is done for the sake of backwards - # compatibility, since it was incidentally the behavior in - # the 1.0 release before this section of the code was - # reworked. - - # Call process_response middleware methods. - for process_response in mw_resp_stack or dependent_mw_resp_stack: - try: - process_response(req, resp, resource, req_succeeded) - except Exception as ex: - if not self._handle_exception(req, resp, ex, params): - raise - req_succeeded = False + # Call process_response middleware methods. + for process_response in mw_resp_stack or dependent_mw_resp_stack: + try: + process_response(req, resp, resource, req_succeeded) + except Exception as ex: + if not self._handle_exception(req, resp, ex, params): + raise + + req_succeeded = False - # - # Set status and headers - # + body = [] + length = 0 + + try: + body, length = self._get_body(resp, env.get('wsgi.file_wrapper')) + except Exception as ex: + if not self._handle_exception(req, resp, ex, params): + raise + + req_succeeded = False resp_status = resp.status - media_type = self._media_type + default_media_type = self.resp_options.default_media_type if req.method == 'HEAD' or resp_status in _BODILESS_STATUS_CODES: body = [] @@ -338,11 +380,20 @@ def __call__(self, env, start_response): # noqa: C901 # presence of the Content-Length header is not similarly # enforced. if resp_status in _TYPELESS_STATUS_CODES: - media_type = None + default_media_type = None + elif ( + length is not None and + req.method == 'HEAD'and + resp_status not in _BODILESS_STATUS_CODES and + 'content-length' not in resp._headers + ): + # NOTE(kgriffs): We really should be returning a Content-Length + # in this case according to my reading of the RFCs. By + # optionally using len(data) we let a resource simulate HEAD + # by turning around and calling it's own on_get(). + resp._headers['content-length'] = str(length) else: - body, length = self._get_body(resp, env.get('wsgi.file_wrapper')) - # PERF(kgriffs): Böse mußt sein. Operate directly on resp._headers # to reduce overhead since this is a hot/critical code path. # NOTE(kgriffs): We always set content-length to match the @@ -353,7 +404,7 @@ def __call__(self, env, start_response): # noqa: C901 if length is not None: resp._headers['content-length'] = str(length) - headers = resp._wsgi_headers(media_type) + headers = resp._wsgi_headers(default_media_type) # Return the response per the WSGI spec. start_response(resp_status, headers) @@ -445,6 +496,10 @@ def add_static_route(self, prefix, directory, downloadable=False, fallback_filen For security reasons, the directory and the fallback_filename (if provided) should be read only for the account running the application. + Note: + For ASGI apps, file reads are made non-blocking by scheduling + them on the default executor. + Static routes are matched in LIFO order. Therefore, if the same prefix is used for two routes, the second one will override the first. This also means that more specific routes should be added @@ -479,8 +534,8 @@ def add_static_route(self, prefix, directory, downloadable=False, fallback_filen self._static_routes.insert( 0, - routing.StaticRoute(prefix, directory, downloadable=downloadable, - fallback_filename=fallback_filename) + self._STATIC_ROUTE_TYPE(prefix, directory, downloadable=downloadable, + fallback_filename=fallback_filename) ) def add_sink(self, sink, prefix=r'/'): @@ -603,9 +658,21 @@ def handle(req, resp, ex, params): """ def wrap_old_handler(old_handler): + # NOTE(kgriffs): This branch *is* actually tested by + # test_error_handlers.test_handler_signature_shim_asgi() (as + # verified manually via pdb), but for some reason coverage + # tracking isn't picking it up. + if iscoroutinefunction(old_handler): # pragma: no cover + @wraps(old_handler) + async def handler_async(req, resp, ex, params): + await old_handler(ex, req, resp, params) + + return handler_async + @wraps(old_handler) def handler(req, resp, ex, params): old_handler(ex, req, resp, params) + return handler if handler is None: @@ -698,11 +765,17 @@ def my_serializer(req, resp, exception): # Helpers that require self # ------------------------------------------------------------------------ + def _prepare_middleware(self, middleware=None, independent_middleware=False): + return helpers.prepare_middleware( + middleware=middleware, + independent_middleware=independent_middleware + ) + def _get_responder(self, req): """Search routes for a matching responder. Args: - req: The request object. + req (Request): The request object. Returns: tuple: A 4-member tuple consisting of a responder callable, @@ -746,7 +819,13 @@ def _get_responder(self, req): try: responder = method_map[method] except KeyError: - responder = falcon.responders.bad_request + # NOTE(kgriffs): Dirty hack! We use __class__ here to avoid + # binding self to the default responder method. We could + # decorate the function itself with @staticmethod, but it + # would perhaps be less obvious to the reader why this is + # needed when just looking at the code in the reponder + # module, so we just grab it directly here. + responder = self.__class__._default_responder_bad_request else: params = {} @@ -764,7 +843,7 @@ def _get_responder(self, req): responder = sr break else: - responder = falcon.responders.path_not_found + responder = self.__class__._default_responder_path_not_found return (responder, params, resource, uri_template) diff --git a/falcon/app_helpers.py b/falcon/app_helpers.py index 78afed9dd..f6b4ce2a6 100644 --- a/falcon/app_helpers.py +++ b/falcon/app_helpers.py @@ -14,19 +14,28 @@ """Utilities for the App class.""" +from inspect import iscoroutinefunction + from falcon import util +from falcon.errors import CompatibilityError +from falcon.util.sync import _wrap_non_coroutine_unsafe -def prepare_middleware(middleware=None, independent_middleware=False): - """Check middleware interface and prepare it to iterate. +def prepare_middleware(middleware, independent_middleware=False, asgi=False): + """Check middleware interfaces and prepare the methods for request handling. - Args: - middleware: list (or object) of input middleware - independent_middleware: bool whether should prepare request and - response middleware independently + Arguments: + middleware (iterable): An iterable of middleware objects. + + Keyword Args: + independent_middleware (bool): ``True`` if the request and + response middleware methods should be treated independently + (default ``False``) + asgi (bool): ``True`` if an ASGI app, ``False`` otherwise + (default ``False``) Returns: - list: A tuple of prepared middleware tuples + tuple: A tuple of prepared middleware method tuples """ # PERF(kgriffs): do getattr calls once, in advance, so we don't @@ -35,22 +44,78 @@ def prepare_middleware(middleware=None, independent_middleware=False): resource_mw = [] response_mw = [] - if middleware is None: - middleware = [] - else: - if not isinstance(middleware, list): - middleware = [middleware] - for component in middleware: - process_request = util.get_bound_method(component, - 'process_request') - process_resource = util.get_bound_method(component, - 'process_resource') - process_response = util.get_bound_method(component, - 'process_response') + # NOTE(kgriffs): Middleware that uses parts of the Request and Response + # interfaces that are the same between ASGI and WSGI (most of it is, + # and we should probably define this via ABC) can just implement + # the method names without the *_async postfix. If a middleware + # component wants to provide an alternative implementation that + # does some work that requires async def, or something specific about + # the ASGI Request/Response classes, the component can implement the + # *_async method in that case. + # + # Middleware that is WSGI-only or ASGI-only can simply implement all + # methods without the *_async postfix. Regardless, components should + # clearly document their compatibility with WSGI vs. ASGI. + + if asgi: + process_request = ( + util.get_bound_method(component, 'process_request_async') or + _wrap_non_coroutine_unsafe( + util.get_bound_method(component, 'process_request') + ) + ) + + process_resource = ( + util.get_bound_method(component, 'process_resource_async') or + _wrap_non_coroutine_unsafe( + util.get_bound_method(component, 'process_resource') + ) + ) + + process_response = ( + util.get_bound_method(component, 'process_response_async') or + _wrap_non_coroutine_unsafe( + util.get_bound_method(component, 'process_response') + ) + ) + + for m in (process_request, process_resource, process_response): + if m and not iscoroutinefunction(m): + msg = ( + '{} must be implemented as an awaitable coroutine. If ' + 'you would like to retain compatibility ' + 'with WSGI apps, the coroutine versions of the ' + 'middleware methods may be implemented side-by-side ' + 'by applying an *_async postfix to the method names. ' + ) + raise CompatibilityError(msg.format(m)) + + else: + process_request = util.get_bound_method(component, 'process_request') + process_resource = util.get_bound_method(component, 'process_resource') + process_response = util.get_bound_method(component, 'process_response') + + for m in (process_request, process_resource, process_response): + if m and iscoroutinefunction(m): + msg = ( + '{} may not implement coroutine methods and ' + 'remain compatible with WSGI apps without ' + 'using the *_async postfix to explicitly identify ' + 'the coroutine version of a given middleware ' + 'method.' + ) + raise CompatibilityError(msg.format(component)) if not (process_request or process_resource or process_response): - msg = '{0} does not implement the middleware interface' + if asgi and ( + hasattr(component, 'process_startup') or hasattr(component, 'process_shutdown') + ): + # NOTE(kgriffs): This middleware only has ASGI lifespan + # event handlers + continue + + msg = '{0} must implement at least one middleware method' raise TypeError(msg.format(component)) # NOTE: depending on whether we want to execute middleware diff --git a/falcon/asgi/__init__.py b/falcon/asgi/__init__.py new file mode 100644 index 000000000..c2eb2b405 --- /dev/null +++ b/falcon/asgi/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2019 by Kurt Griffiths. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ASGI package for Falcon, the minimalist web API framework. + +The `falcon.asgi` package can be used to directly access most of +the framework's ASGI-related classes, functions, and variables:: + + import falcon.asgi + + app = falcon.asgi.API() + +Some ASGI-related methods and classes are found in other modules +(most notably falcon.testing) when (A) they are compatible with Python 3.5, +and (B) their purpose is particularly cohesive with that of the module in +question. +""" + +import sys as _sys + +if _sys.version_info.minor < 6: + raise ImportError('falcon.asgi requires Python 3.6+') + +from .app import App # NOQA +from .structures import SSEvent # NOQA +from .request import Request # NOQA +from .response import Response # NOQA +from .stream import BoundedStream # NOQA diff --git a/falcon/asgi/_request_helpers.py b/falcon/asgi/_request_helpers.py new file mode 100644 index 000000000..27a5e6138 --- /dev/null +++ b/falcon/asgi/_request_helpers.py @@ -0,0 +1,36 @@ +# Copyright 2019 by Kurt Griffiths +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def header_property(header_name): + """Create a read-only header property. + + Args: + wsgi_name (str): Case-sensitive name of the header as it would + appear in the WSGI environ ``dict`` (i.e., 'HTTP_*') + + Returns: + A property instance than can be assigned to a class variable. + + """ + + header_name = header_name.lower() + + def fget(self): + try: + return self._asgi_headers[header_name] or None + except KeyError: + return None + + return property(fget) diff --git a/falcon/asgi/app.py b/falcon/asgi/app.py new file mode 100644 index 000000000..e62529cdb --- /dev/null +++ b/falcon/asgi/app.py @@ -0,0 +1,645 @@ +# Copyright 2019 by Kurt Griffiths +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ASGI application class.""" + +from inspect import isasyncgenfunction, iscoroutinefunction +import traceback + +import falcon.app +from falcon.app_helpers import prepare_middleware +from falcon.errors import CompatibilityError, UnsupportedError, UnsupportedScopeError +from falcon.http_error import HTTPError +from falcon.http_status import HTTPStatus +import falcon.routing +from falcon.util.misc import http_status_to_code +from falcon.util.sync import _wrap_non_coroutine_unsafe, get_loop +from .request import Request +from .response import Response +from .structures import SSEvent + + +__all__ = ['App'] + + +_EVT_RESP_EOF = {'type': 'http.response.body'} + +_BODILESS_STATUS_CODES = frozenset([ + 100, + 101, + 204, + 304, +]) + +_TYPELESS_STATUS_CODES = frozenset([ + 204, + 304, +]) + + +# TODO(kgriffs): Rename the WSGI class to App with an API alias kept for +# backwards-compatibility. +class App(falcon.app.App): + """ + + Keyword Arguments: + middleware: Either a single middleware component object or an iterable + of objects (instantiated classes) that implement the following + middleware component interface. + + The interface provides support for handling both ASGI worker + lifespan events and per-request events. + + A lifespan handler can be used to perform startup and/or shutdown + activities for the main event loop. An example of this would be + creating a connection pool and subsequently closing the connection + pool to release the connections. + + Note: + In a multi-process environment, lifespan events will be + triggered independently for the individual event loop associated + with each process. + + It is only necessary to implement the methods for the events you + would like to handle; Falcon simply skips over any missing + middleware methods:: + + class ExampleComponent: + async def process_startup(self, scope, event): + \"\"\"Process the ASGI lifespan startup event. + + Invoked when the server is ready to startup and + receive connections, but before it has started to + do so. + + To halt startup processing and signal to the server that it + should terminate, simply raise an exception and the + framework will convert it to a "lifespan.startup.failed" + event for the server. + + Arguments: + scope (dict): The ASGI scope dictionary for the + lifespan protocol. The lifespan scope exists + for the duration of the event loop. + event (dict): The ASGI event dictionary for the + startup event. + \"\"\" + + async def process_shutdown(self, scope, event): + \"\"\"Process the ASGI lifespan shutdown event. + + Invoked when the server has stopped accepting + connections and closed all active connections. + + To halt shutdown processing and signal to the server + that it should immediately terminate, simply raise an + exception and the framework will convert it to a + "lifespan.shutdown.failed" event for the server. + + Arguments: + scope (dict): The ASGI scope dictionary for the + lifespan protocol. The lifespan scope exists + for the duration of the event loop. + event (dict): The ASGI event dictionary for the + shutdown event. + \"\"\" + + async def process_request(self, req, resp): + \"\"\"Process the request before routing it. + + Note: + Because Falcon routes each request based on + req.path, a request can be effectively re-routed + by setting that attribute to a new value from + within process_request(). + + Args: + req: Request object that will eventually be + routed to an on_* responder method. + resp: Response object that will be routed to + the on_* responder. + \"\"\" + + async def process_resource(self, req, resp, resource, params): + \"\"\"Process the request and resource *after* routing. + + Note: + This method is only called when the request matches + a route to a resource. + + Args: + req: Request object that will be passed to the + routed responder. + resp: Response object that will be passed to the + responder. + resource: Resource object to which the request was + routed. May be None if no route was found for + the request. + params: A dict-like object representing any + additional params derived from the route's URI + template fields, that will be passed to the + resource's responder method as keyword + arguments. + \"\"\" + + async def process_response(self, req, resp, resource, req_succeeded) + \"\"\"Post-processing of the response (after routing). + + Args: + req: Request object. + resp: Response object. + resource: Resource object to which the request was + routed. May be None if no route was found + for the request. + req_succeeded: True if no exceptions were raised + while the framework processed and routed the + request; otherwise False. + \"\"\" + + (See also: :ref:`Middleware `) + + """ + + _STATIC_ROUTE_TYPE = falcon.routing.StaticRouteAsync + + # NOTE(kgriffs): This makes it easier to tell what we are dealing with + # without having to import falcon.asgi to get at the falcon.asgi.App + # type (which we may not be able to do under Python 3.5). + _ASGI = True + + _default_responder_bad_request = falcon.responders.bad_request_async + _default_responder_path_not_found = falcon.responders.path_not_found_async + + def __init__(self, *args, request_type=Request, response_type=Response, **kwargs): + super().__init__(*args, request_type=request_type, response_type=response_type, **kwargs) + + async def __call__(self, scope, receive, send): # noqa: C901 + try: + asgi_info = scope['asgi'] + + # NOTE(kgriffs): We only check this here because + # uvicorn does not explicitly set the 'asgi' key, which + # would normally mean we should assume '2.0', but uvicorn + # actually *does* support 3.0. But in that case, we will + # end up in the except clause, below, and not raise an + # error. + # PERF(kgriffs): This should usually be present, so use a + # try..except + try: + version = asgi_info['version'] + except KeyError: + # NOTE(kgriffs): According to the ASGI spec, "2.0" is + # the default version. + version = '2.0' + + if not version.startswith('3.'): + raise UnsupportedScopeError( + f'Falcon requires ASGI version 3.x. Detected: {asgi_info}' + ) + + except KeyError: + asgi_info = scope['asgi'] = {'version': '2.0'} + + # NOTE(kgriffs): The ASGI spec requires the 'type' key to be present. + scope_type = scope['type'] + if scope_type != 'http': + if scope_type == 'lifespan': + try: + spec_version = asgi_info['spec_version'] + except KeyError: + spec_version = '1.0' + + if not spec_version.startswith('1.') and not spec_version.startswith('2.'): + raise UnsupportedScopeError( + f'Only versions 1.x and 2.x of the ASGI "lifespan" scope are supported.' + ) + + await self._call_lifespan_handlers(spec_version, scope, receive, send) + return + + # NOTE(kgriffs): According to the ASGI spec: "Applications should + # actively reject any protocol that they do not understand with + # an Exception (of any type)." + raise UnsupportedScopeError( + f'The ASGI "{scope_type}" scope type is not supported.' + ) + + # PERF(kgriffs): This is slighter faster than using dict.get() + # TODO(kgriffs): Use this to determine what features are supported by + # the server (e.g., the headers key in the WebSocket Accept + # response). + spec_version = asgi_info['spec_version'] if 'spec_version' in asgi_info else '2.0' + + if not spec_version.startswith('2.'): + raise UnsupportedScopeError( + f'The ASGI http scope version {spec_version} is not supported.' + ) + + resp = self._response_type(options=self.resp_options) + req = self._request_type(scope, receive, options=self.req_options) + if self.req_options.auto_parse_form_urlencoded: + raise UnsupportedError( + 'The deprecated WSGI RequestOptions.auto_parse_form_urlencoded option ' + 'is not supported for ASGI apps. Please use Request.get_media() instead. ' + ) + + resource = None + responder = None + params = {} + + dependent_mw_resp_stack = [] + mw_req_stack, mw_rsrc_stack, mw_resp_stack = self._middleware + + req_succeeded = False + + try: + # NOTE(ealogar): The execution of request middleware + # should be before routing. This will allow request mw + # to modify the path. + # NOTE: if flag set to use independent middleware, execute + # request middleware independently. Otherwise, only queue + # response middleware after request middleware succeeds. + if self._independent_middleware: + for process_request in mw_req_stack: + await process_request(req, resp) + + if resp.complete: + break + else: + for process_request, process_response in mw_req_stack: + if process_request and not resp.complete: + await process_request(req, resp) + + if process_response: + dependent_mw_resp_stack.insert(0, process_response) + + if not resp.complete: + # NOTE(warsaw): Moved this to inside the try except + # because it is possible when using object-based + # traversal for _get_responder() to fail. An example is + # a case where an object does not have the requested + # next-hop child resource. In that case, the object + # being asked to dispatch to its child will raise an + # HTTP exception signaling the problem, e.g. a 404. + responder, params, resource, req.uri_template = self._get_responder(req) + + except Exception as ex: + if not await self._handle_exception(req, resp, ex, params): + raise + + else: + try: + # NOTE(kgriffs): If the request did not match any + # route, a default responder is returned and the + # resource is None. In that case, we skip the + # resource middleware methods. Resource will also be + # None when a middleware method already set + # resp.complete to True. + if resource: + # Call process_resource middleware methods. + for process_resource in mw_rsrc_stack: + await process_resource(req, resp, resource, params) + + if resp.complete: + break + + if not resp.complete: + await responder(req, resp, **params) + + req_succeeded = True + + except Exception as ex: + if not await self._handle_exception(req, resp, ex, params): + raise + + # Call process_response middleware methods. + for process_response in mw_resp_stack or dependent_mw_resp_stack: + try: + await process_response(req, resp, resource, req_succeeded) + + except Exception as ex: + if not await self._handle_exception(req, resp, ex, params): + raise + + req_succeeded = False + + data = b'' + + try: + data = await resp.render_body() + except Exception as ex: + if not await self._handle_exception(req, resp, ex, params): + raise + + req_succeeded = False + + resp_status = http_status_to_code(resp.status) + default_media_type = self.resp_options.default_media_type + + if req.method == 'HEAD' or resp_status in _BODILESS_STATUS_CODES: + # + # PERF(vytas): move check for the less common and much faster path + # of resp_status being in {204, 304} here; NB: this builds on the + # assumption _TYPELESS_STATUS_CODES <= _BODILESS_STATUS_CODES. + # + # NOTE(kgriffs): Based on wsgiref.validate's interpretation of + # RFC 2616, as commented in that module's source code. The + # presence of the Content-Length header is not similarly + # enforced. + # + # NOTE(kgriffs): Assuming the same for ASGI until proven otherwise. + # + if resp_status in _TYPELESS_STATUS_CODES: + default_media_type = None + elif ( + # NOTE(kgriffs): If they are going to stream using an + # async generator, we can't know in advance what the + # content length will be. + (data is not None or not resp.stream) and + + req.method == 'HEAD' and + resp_status not in _BODILESS_STATUS_CODES and + 'content-length' not in resp._headers + ): + # NOTE(kgriffs): We really should be returning a Content-Length + # in this case according to my reading of the RFCs. By + # optionally using len(data) we let a resource simulate HEAD + # by turning around and calling it's own on_get(). + resp._headers['content-length'] = str(len(data)) if data else '0' + + await send({ + 'type': 'http.response.start', + 'status': resp_status, + 'headers': resp._asgi_headers(default_media_type) + }) + + await send(_EVT_RESP_EOF) + self._schedule_callbacks(resp) + return + + sse_emitter = resp.sse + if sse_emitter: + if isasyncgenfunction(sse_emitter): + raise TypeError( + 'Response.sse must be an async iterable. This can be obtained by ' + 'simply executing the async generator function and then setting ' + 'the result to Response.sse, e.g.: resp.sse = some_asyncgen_function()' + ) + + await send({ + 'type': 'http.response.start', + 'status': resp_status, + 'headers': resp._asgi_headers('text/event-stream') + }) + + self._schedule_callbacks(resp) + + # TODO(kgriffs): Do we need to do anything special to handle when + # a connection is closed? + async for event in sse_emitter: + if not event: + event = SSEvent() + + await send({ + 'type': 'http.response.body', + 'body': event.serialize(), + 'more_body': True + }) + + await send({'type': 'http.response.body'}) + return + + if data is not None: + # PERF(kgriffs): Böse mußt sein. Operate directly on resp._headers + # to reduce overhead since this is a hot/critical code path. + # NOTE(kgriffs): We always set content-length to match the + # body bytes length, even if content-length is already set. The + # reason being that web servers and LBs behave unpredictably + # when the header doesn't match the body (sometimes choosing to + # drop the HTTP connection prematurely, for example). + resp._headers['content-length'] = str(len(data)) + + await send({ + 'type': 'http.response.start', + 'status': resp_status, + 'headers': resp._asgi_headers(default_media_type) + }) + + await send({ + 'type': 'http.response.body', + 'body': data + }) + + self._schedule_callbacks(resp) + return + + stream = resp.stream + if not stream: + resp._headers['content-length'] = '0' + + await send({ + 'type': 'http.response.start', + 'status': resp_status, + 'headers': resp._asgi_headers(default_media_type) + }) + + if stream: + # Detect whether this is one of the following: + # + # (a) async file-like object (e.g., aiofiles) + # (b) async generator + # (c) async iterator + # + + if hasattr(stream, 'read'): + while True: + data = await stream.read(self._STREAM_BLOCK_SIZE) + if data == b'': + break + else: + await send({ + 'type': 'http.response.body', + + # NOTE(kgriffs): Handle the case in which data == None + 'body': data or b'', + + 'more_body': True + }) + else: + # NOTE(kgriffs): Works for both async generators and iterators + try: + async for data in stream: + # NOTE(kgriffs): We can not rely on StopIteration + # because of Pep 479 that is implemented starting + # with Python 3.7. AFAICT this is only an issue + # when using an async iterator instead of an async + # generator. + if data is None: + break + + await send({ + 'type': 'http.response.body', + 'body': data, + 'more_body': True + }) + except TypeError as ex: + if isasyncgenfunction(stream): + raise TypeError( + 'The object assigned to Response.stream appears to ' + 'be an async generator function. A generator ' + 'object is expected instead. This can be obtained ' + 'simply by calling the generator function, e.g.: ' + 'resp.stream = some_asyncgen_function()' + ) + + raise TypeError( + 'Response.stream must be a generator or implement an ' + '__aiter__ method. Error raised while iterating over ' + 'Response.stream: ' + str(ex) + ) + + if hasattr(stream, 'close'): + await stream.close() + + await send(_EVT_RESP_EOF) + self._schedule_callbacks(resp) + + def add_error_handler(self, exception, handler=None): + if not handler: + try: + handler = exception.handle + except AttributeError: + # NOTE(kgriffs): Delegate to the parent method for error handling. + pass + + handler = _wrap_non_coroutine_unsafe(handler) + + if handler and not iscoroutinefunction(handler): + raise CompatibilityError( + 'The handler must be an awaitable coroutine function in order ' + 'to be used safely with an ASGI app.' + ) + + super().add_error_handler(exception, handler=handler) + + def add_route(self, uri_template, resource, **kwargs): + # TODO: Check for an _auto_async_wrap kwarg or env var and if there and True, + # go through the resource and wrap any non-couroutine objects. Then + # set that flag in the test cases. + + # NOTE(kgriffs): Inject an extra kwarg so that the compiled router + # will know to validate the responder methods to make sure they + # are async coroutines. + kwargs['_asgi'] = True + super().add_route(uri_template, resource, **kwargs) + + # ------------------------------------------------------------------------ + # Helper methods + # ------------------------------------------------------------------------ + + def _schedule_callbacks(self, resp): + callbacks = resp._registered_callbacks + if not callbacks: + return + + loop = get_loop() + + for cb in callbacks: + if iscoroutinefunction(cb): + loop.create_task(cb()) + else: + loop.run_in_executor(None, cb) + + async def _call_lifespan_handlers(self, ver, scope, receive, send): + while True: + event = await receive() + if event['type'] == 'lifespan.startup': + for handler in self._unprepared_middleware: + if hasattr(handler, 'process_startup'): + try: + await handler.process_startup(scope, event) + except Exception: + await send({ + 'type': 'lifespan.startup.failed', + 'message': traceback.format_exc(), + }) + return + + await send({'type': 'lifespan.startup.complete'}) + + elif event['type'] == 'lifespan.shutdown': + for handler in reversed(self._unprepared_middleware): + if hasattr(handler, 'process_shutdown'): + try: + await handler.process_shutdown(scope, event) + except Exception: + await send({ + 'type': 'lifespan.shutdown.failed', + 'message': traceback.format_exc(), + }) + return + + await send({'type': 'lifespan.shutdown.complete'}) + return + + def _prepare_middleware(self, middleware=None, independent_middleware=False): + return prepare_middleware( + middleware=middleware, + independent_middleware=independent_middleware, + asgi=True + ) + + async def _http_status_handler(self, req, resp, status, params): + self._compose_status_response(req, resp, status) + + async def _http_error_handler(self, req, resp, error, params): + self._compose_error_response(req, resp, error) + + async def _python_error_handler(self, req, resp, error, params): + falcon._logger.error('Unhandled exception in ASGI app', exc_info=error) + self._compose_error_response(req, resp, falcon.HTTPInternalServerError()) + + async def _handle_exception(self, req, resp, ex, params): + """Handle an exception raised from mw or a responder. + + Args: + ex: Exception to handle + req: Current request object to pass to the handler + registered for the given exception type + resp: Current response object to pass to the handler + registered for the given exception type + params: Responder params to pass to the handler + registered for the given exception type + + Returns: + bool: ``True`` if a handler was found and called for the + exception, ``False`` otherwise. + """ + err_handler = self._find_error_handler(ex) + + if err_handler is not None: + try: + await err_handler(req, resp, ex, params) + except HTTPStatus as status: + self._compose_status_response(req, resp, status) + except HTTPError as error: + self._compose_error_response(req, resp, error) + + return True + + # NOTE(kgriffs): No error handlers are defined for ex + # and it is not one of (HTTPStatus, HTTPError), since it + # would have matched one of the corresponding default + # handlers. + return False diff --git a/falcon/asgi/request.py b/falcon/asgi/request.py new file mode 100644 index 000000000..55c46d60f --- /dev/null +++ b/falcon/asgi/request.py @@ -0,0 +1,573 @@ +# Copyright 2019 by Kurt Griffiths +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ASGI Request class.""" + +from falcon import errors +from falcon import request_helpers as helpers # NOQA: Required by fixed up WSGI Request attrs +from falcon.constants import SINGLETON_HEADERS +from falcon.forwarded import _parse_forwarded_header # NOQA: Req. by fixed up WSGI Request attrs +from falcon.forwarded import Forwarded # NOQA +import falcon.request +from falcon.util.uri import parse_host, parse_query_string +from . import _request_helpers as asgi_helpers +from .stream import BoundedStream + + +__all__ = ['Request'] + + +class Request(falcon.request.Request): + """ + + query_string (str): Query string portion of the request URI, without + the preceding '?' character. + + remote_addr(str): IP address of the closest known client or proxy to + the WSGI server, or '127.0.0.1' if unknown. + + This property's value is equivalent to the last element of the + :py:attr:`~.access_route` property. + + access_route(list): IP address of the original client (if known), as + well as any known addresses of proxies fronting the WSGI server. + + The following request headers are checked, in order of + preference, to determine the addresses: + + - ``Forwarded`` + - ``X-Forwarded-For`` + - ``X-Real-IP`` + + In addition, the value of the 'client' field from the ASGI + connection scope will be appended to the end of the list if + not already included in one of the above headers. If the + 'client' field is not available, it will default to + '127.0.0.1'. + + Note: + Per `RFC 7239`_, the access route may contain "unknown" + and obfuscated identifiers, in addition to IPv4 and + IPv6 addresses + + .. _RFC 7239: https://tools.ietf.org/html/rfc7239 + + Warning: + Headers can be forged by any client or proxy. Use this + property with caution and validate all values before + using them. Do not rely on the access route to authorize + requests! + + """ + + __slots__ = [ + '_asgi_headers', + '_asgi_server_cached' + '_receive', + '_stream', + 'scope', + ] + + def __init__(self, scope, receive, options=None): + + # ===================================================================== + # Prepare headers + # ===================================================================== + + req_headers = {} + for header_name, header_value in scope['headers']: + # NOTE(kgriffs): According to ASGI 3.0, header names are always + # lowercased, and both name and value are byte strings. Although + # technically header names and values are restricted to US-ASCII + # we decode using the default 'utf-8' because it is a little + # faster than passing an encoding option. + header_name = header_name.decode() + header_value = header_value.decode() + + # NOTE(kgriffs): There are no standard request headers that + # allow multiple instances to appear in the request while also + # disallowing list syntax. + if header_name not in req_headers or header_name in SINGLETON_HEADERS: + req_headers[header_name] = header_value + else: + req_headers[header_name] += ',' + header_value + + self._asgi_headers = req_headers + + # ===================================================================== + # Misc. + # ===================================================================== + + self._asgi_server_cached = None # Lazy + + self.scope = scope + self.options = options if options else falcon.request.RequestOptions() + + self._wsgierrors = None + self.method = scope['method'] + + self.uri_template = None + self._media = None + + # TODO(kgriffs): ASGI does not specify whether 'path' may be empty, + # as was allowed for WSGI. + path = scope['path'] or '/' + + if (self.options.strip_url_path_trailing_slash and + len(path) != 1 and path.endswith('/')): + self.path = path[:-1] + else: + self.path = path + + query_string = scope['query_string'].decode() + self.query_string = query_string + if query_string: + self._params = parse_query_string( + query_string, + keep_blank=self.options.keep_blank_qs_values, + csv=self.options.auto_parse_qs_csv, + ) + + else: + self._params = {} + + self._cached_access_route = None + self._cached_forwarded = None + self._cached_forwarded_prefix = None + self._cached_forwarded_uri = None + self._cached_headers = req_headers + self._cached_prefix = None + self._cached_relative_uri = None + self._cached_uri = None + + if self.method == 'GET': + # PERF(kgriffs): Normally we expect no Content-Type header, so + # use this pattern which is a little bit faster than dict.get() + if 'content-type' in req_headers: + self.content_type = req_headers['content-type'] + else: + self.content_type = None + else: + # PERF(kgriffs): This is the most performant pattern when we expect + # the key to be present most of the time. + try: + self.content_type = req_headers['content-type'] + except KeyError: + self.content_type = None + + # ===================================================================== + # The request body stream is created lazily + # ===================================================================== + + # NOTE(kgriffs): The ASGI spec states that "you should not trigger + # on a connection opening alone". I take this to mean that the app + # should have the opportunity to respond with a 401, for example, + # without having to first read any of the body. This is accomplished + # in Falcon by only reading the first data event when the app attempts + # to read from req.stream for the first time, and in uvicorn + # (for example) by not confirming a 100 Continue request unless + # the app calls receive() to read the request body. + + self._stream = None + self._receive = receive + + # ===================================================================== + # Create a context object + # ===================================================================== + + self.context = self.context_type() + + # ------------------------------------------------------------------------ + # Properties + # + # Much of the logic from the ASGI Request class is duplicted in these + # property implementations; however, to make the code more DRY we would + # have to factor out the common logic, which would add overhead to these + # properties and slow them down. They are simple enough that we should + # be able to keep them in sync with the WSGI side without too much + # trouble. + # ------------------------------------------------------------------------ + + auth = asgi_helpers.header_property('Authorization') + expect = asgi_helpers.header_property('Expect') + if_range = asgi_helpers.header_property('If-Range') + referer = asgi_helpers.header_property('Referer') + user_agent = asgi_helpers.header_property('User-Agent') + + @property + def accept(self): + # NOTE(kgriffs): Per RFC, a missing accept header is + # equivalent to '*/*' + try: + return self._asgi_headers['accept'] or '*/*' + except KeyError: + return '*/*' + + @property + def content_length(self): + try: + value = self._asgi_headers['content-length'] + except KeyError: + return None + + # NOTE(kgriffs): Normalize an empty value to behave as if + # the header were not included; wsgiref, at least, inserts + # an empty CONTENT_LENGTH value if the request does not + # set the header. Gunicorn and uWSGI do not do this, but + # others might if they are trying to match wsgiref's + # behavior too closely. + if not value: + return None + + try: + value_as_int = int(value) + except ValueError: + msg = 'The value of the header must be a number.' + raise errors.HTTPInvalidHeader(msg, 'Content-Length') + + if value_as_int < 0: + msg = 'The value of the header must be a positive number.' + raise errors.HTTPInvalidHeader(msg, 'Content-Length') + + return value_as_int + + @property + def stream(self): + if not self._stream: + self._stream = BoundedStream(self._receive, self.content_length) + + return self._stream + + bounded_stream = stream + + @property + def root_path(self): + # PERF(kgriffs): try...except is faster than get() assuming that + # we normally expect the key to exist. Even though ASGI 3.0 + # allows servers to omit the key when the value is an + # empty string, at least uvicorn still includes it explicitly in + # that case. + try: + return self.scope['root_path'] + except KeyError: + pass + + return '' + + app = root_path + + @property + def scheme(self): + # PERF(kgriffs): Use try...except because we normally expect the + # key to be present. + try: + return self.scope['scheme'] + except KeyError: + pass + + return 'http' + + @property + def forwarded_scheme(self): + # PERF(kgriffs): Since the Forwarded header is still relatively + # new, we expect X-Forwarded-Proto to be more common, so + # try to avoid calling self.forwarded if we can, since it uses a + # try...catch that will usually result in a relatively expensive + # raised exception. + if 'forwarded' in self._asgi_headers: + first_hop = self.forwarded[0] + scheme = first_hop.scheme or self.scheme + else: + # PERF(kgriffs): This call should normally succeed, so + # just go for it without wasting time checking it + # first. Note also that the indexing operator is + # slightly faster than using get(). + try: + scheme = self._asgi_headers['x-forwarded-proto'].lower() + except KeyError: + scheme = self.scheme + + return scheme + + @property + def prefix(self): + if self._cached_prefix is None: + self._cached_prefix = ( + self.scheme + '://' + + self.netloc + + self.app + ) + + return self._cached_prefix + + @property + def host(self): + try: + # NOTE(kgriffs): Prefer the host header; the web server + # isn't supposed to mess with it, so it should be what + # the client actually sent. + host_header = self._asgi_headers['host'] + host, __ = parse_host(host_header) + except KeyError: + host, __ = self._asgi_server + + return host + + @property + def forwarded_host(self): + # PERF(kgriffs): Since the Forwarded header is still relatively + # new, we expect X-Forwarded-Host to be more common, so + # try to avoid calling self.forwarded if we can, since it uses a + # try...catch that will usually result in a relatively expensive + # raised exception. + if 'forwarded' in self._asgi_headers: + first_hop = self.forwarded[0] + host = first_hop.host or self.host + else: + # PERF(kgriffs): This call should normally succeed, assuming + # that the caller is expecting a forwarded header, so + # just go for it without wasting time checking it + # first. + try: + host = self._asgi_headers['x-forwarded-host'] + except KeyError: + host = self.host + + return host + + @property + def access_route(self): + if self._cached_access_route is None: + # PERF(kgriffs): 'client' is optional according to the ASGI spec + # but it will probably be present, hence the try...except. + try: + # NOTE(kgriffs): The ASGI spec states that this can be + # any iterable. So we need to read and cache it in + # case the iterable is forward-only. But that is + # effectively what we are doing since we only ever + # access this field when setting self._cached_access_route + client, __ = self.scope['client'] + except KeyError: + # NOTE(kgriffs): Default to localhost so that app logic does + # note have to special-case the handling of a missing + # client field in the connection scope. This should be + # a reasonable default, but we can change it later if + # that turns out not to be the case. + client = '127.0.0.1' + + headers = self._asgi_headers + + if 'forwarded' in headers: + self._cached_access_route = [] + for hop in self.forwarded: + if hop.src is not None: + host, __ = parse_host(hop.src) + self._cached_access_route.append(host) + elif 'x-forwarded-for' in headers: + addresses = headers['x-forwarded-for'].split(',') + self._cached_access_route = [ip.strip() for ip in addresses] + elif 'x-real-ip' in headers: + self._cached_access_route = [headers['x-real-ip']] + + if self._cached_access_route: + if self._cached_access_route[-1] != client: + self._cached_access_route.append(client) + else: + self._cached_access_route = [client] if client else [] + + return self._cached_access_route + + @property + def remote_addr(self): + route = self.access_route + return route[-1] + + @property + def port(self): + try: + host_header = self._asgi_headers['host'] + default_port = 80 if self.scheme == 'http' else 443 + __, port = parse_host(host_header, default_port=default_port) + except KeyError: + __, port = self._asgi_server + + return port + + @property + def netloc(self): + # PERF(kgriffs): try..except is faster than get() when we + # expect the key to be present most of the time. + try: + netloc_value = self._asgi_headers['host'] + except KeyError: + netloc_value, port = self._asgi_server + + if self.scheme == 'https': + if port != 443: + netloc_value = f'{netloc_value}:{port}' + else: + if port != 80: + netloc_value = f'{netloc_value}:{port}' + + return netloc_value + + @property + def media(self): + raise errors.UnsupportedError( + 'The media property is not supported for ASGI requests. ' + 'Please use the Request.get_media() coroutine function instead.' + ) + + async def get_media(self): + """Returns a deserialized form of the request stream. + + When called the first time, the request stream will be deserialized + using the Content-Type header as well as the media-type handlers + configured via :class:`falcon.RequestOptions`. The result will + be cached and returned in subsequent calls. + + If the matched media handler raises an error while attempting to + deserialize the request body, the exception will propagate up + to the caller. + + See :ref:`media` for more information regarding media handling. + + Warning: + This operation will consume the request stream the first time + it's called and cache the results. Follow-up calls will just + retrieve a cached version of the object. + + Returns: + media (object): The deserialized media representation. + """ + + if self._media is not None or self.stream.eof: + return self._media + + handler = self.options.media_handlers.find_by_media_type( + self.content_type, + self.options.default_media_type + ) + + try: + self._media = await handler.deserialize_async( + self.stream, + self.content_type, + self.content_length + ) + finally: + await self.stream.exhaust() + + return self._media + + @property + def if_match(self): + # TODO(kgriffs): It may make sense at some point to create a + # header property generator that DRY's up the memoization + # pattern for us. + # PERF(kgriffs): It probably isn't worth it to set + # self._cached_if_match to a special type/object to distinguish + # between the variable being unset and the header not being + # present in the request. The reason is that if the app + # gets a None back on the first reference to property, it + # probably isn't going to access the property again (TBD). + if self._cached_if_match is None: + header_value = self._asgi_headers.get('if-match') + if header_value: + self._cached_if_match = helpers._parse_etags(header_value) + + return self._cached_if_match + + @property + def if_none_match(self): + if self._cached_if_none_match is None: + header_value = self._asgi_headers.get('if-none-match') + if header_value: + self._cached_if_none_match = helpers._parse_etags(header_value) + + return self._cached_if_none_match + + # ------------------------------------------------------------------------ + # Public Methods + # ------------------------------------------------------------------------ + + def get_header(self, name, required=False, default=None): + """Retrieve the raw string value for the given header. + + Args: + name (str): Header name, case-insensitive (e.g., 'Content-Type') + + Keyword Args: + required (bool): Set to ``True`` to raise + ``HTTPBadRequest`` instead of returning gracefully when the + header is not found (default ``False``). + default (any): Value to return if the header + is not found (default ``None``). + + Returns: + str: The value of the specified header if it exists, or + the default value if the header is not found and is not + required. + + Raises: + HTTPBadRequest: The header was not found in the request, but + it was required. + + """ + + asgi_name = name.lower() + + # Use try..except to optimize for the header existing in most cases + try: + # Don't take the time to cache beforehand, using HTTP naming. + # This will be faster, assuming that most headers are looked + # up only once, and not all headers will be requested. + return self._asgi_headers[asgi_name] + + except KeyError: + if not required: + return default + + raise errors.HTTPMissingHeader(name) + + def log_error(self, message): + # NOTE(kgriffs): Normally the pythonic thing to do would be to simply + # set this method to None so that it can't even be called, but we + # raise an error here to help people who are porting from WSGI. + raise NotImplementedError( + "ASGI does not support writing to the server's log. " + 'Please use the standard library logging framework ' + 'instead.' + ) + + # ------------------------------------------------------------------------ + # Private Helpers + # ------------------------------------------------------------------------ + + @property + def _asgi_server(self): + if not self._asgi_server_cached: + try: + # NOTE(kgriffs): Since the ASGI spec states that 'server' + # can be any old iterable, we have to be careful to only + # read it once and cache the result in case the + # iterator is forward-only (not likely, but better + # safe than sorry). + self._asgi_server_cached = tuple(self.scope['server']) + except (KeyError, TypeError): + # NOTE(kgriffs): Not found, or was None + default_port = 80 if self.scheme == 'http' else 443 + self._asgi_server_cached = ('localhost', default_port) + + return self._asgi_server_cached diff --git a/falcon/asgi/response.py b/falcon/asgi/response.py new file mode 100644 index 000000000..65fa96c5a --- /dev/null +++ b/falcon/asgi/response.py @@ -0,0 +1,283 @@ +# Copyright 2019 by Kurt Griffiths +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ASGI Response class.""" + +from asyncio.coroutines import CoroWrapper +from inspect import iscoroutine + +from falcon import _UNSET +import falcon.response + +__all__ = ['Response'] + + +class Response(falcon.response.Response): + """ + + Attributes: + sse (coroutine): A Server-Sent Event (SSE) emitter, implemented as + an async coroutine function that returns an iterable + of :py:class:`falcon.asgi.SSEvent` instances. Each event will be + serialized and sent to the client as HTML5 Server-Sent Events. + + data (bytes): Byte string representing response content. + + Use this attribute in lieu of `body` when your content is + already a byte string (of type ``bytes``). + + Warning: + Always use the `body` attribute for text, or encode it + first to ``bytes`` when using the `data` attribute, to + ensure Unicode characters are properly encoded in the + HTTP response. + + Note: + Unlike the WSGI Response class, the ASGI Response class + does not implement the side-effect of serializing + the media object (if one is set) when the `data` + attribute is read. Instead, + :py:meth:`falcon.asgi.Response.render_body` should + be used to get the correct content for the response. + + stream: An async iterator or generator that yields a series of + byte strings that will be streamed to the ASGI server as a + series of "http.response.body" events. Falcon will assume the + body is complete when the iterable is exhausted or as soon as it + yields ``None`` rather than an instance of ``bytes``. + + If the object assigned to :py:attr:`~.stream` holds any resources + (such as a file handle) that must be explicitly released, the + object must implement a close() method. The close() method will + be called after exhausting the iterable. + + Note: + In order to be compatible with Python 3.7+ and PEP 479, + async iterators must return ``None`` instead of raising + :py:class:`StopIteration`. + + Note: + If the stream length is known in advance, you may wish to + also set the Content-Length header on the response. + + """ + + # PERF(kgriffs): These will be shadowed when set on an instance; let's + # us avoid having to implement __init__ and incur the overhead of + # an additional function call. + _sse = None + _registered_callbacks = None + _media_rendered = _UNSET + + @property + def sse(self): + return self._sse + + @sse.setter + def sse(self, value): + self._sse = value + + @property + def media(self): + return self._media + + @media.setter + def media(self, value): + self._media = value + self._media_rendered = _UNSET + + @property + def data(self): + return self._data + + @data.setter + def data(self, value): + self._data = value + + async def render_body(self): + """Get the raw content for the response body. + + This coroutine can be awaited to get the raw body data that should + be returned in the HTTP response. + + Returns: + bytes: The UTF-8 encoded value of the `body` attribute, if + set. Otherwise, the value of the `data` attribute if set, or + finally the serialized value of the `media` attribute. If + none of these attributes are set, ``None`` is returned. + """ + + body = self.body + if body is None: + data = self._data + + if data is None and self._media is not None: + # NOTE(kgriffs): We use a special _UNSET singleton since + # None is ambiguous (the media handler might return None). + if self._media_rendered is _UNSET: + if not self.content_type: + self.content_type = self.options.default_media_type + + handler = self.options.media_handlers.find_by_media_type( + self.content_type, + self.options.default_media_type + ) + + self._media_rendered = await handler.serialize_async( + self._media, + self.content_type + ) + + data = self._media_rendered + else: + try: + # NOTE(kgriffs): Normally we expect body to be a string + data = body.encode() + except AttributeError: + # NOTE(kgriffs): Assume it was a bytes object already + data = body + + return data + + def schedule(self, callback): + """Schedules a callback to run soon after sending the HTTP response. + + This method can be used to execute a background job after the + response has been returned to the client. + + If the callback is an async coroutine function, it will be scheduled + to run on the event loop as soon as possible. Alternatively, if a + synchronous callable is passed, it will be run on the event loop's + default ``Executor`` (which can be overridden via + :py:meth:`asyncio.AbstractEventLoop.set_default_executor`). + + The callback will be invoked without arguments. Use + :py:meth`functools.partial` to pass arguments to the callback + as needed. + + Note: + If an unhandled exception is raised while processing the request, + the callback will not be scheduled to run. + + Note: + When an SSE emitter has been set on the response, the callback will + be scheduled before the first call to the emitter. + + Warning: + Because coroutines run on the main request thread, care should + be taken to ensure they are non-blocking. Long-running operations + must use async libraries or delegate to an Executor pool to avoid + blocking the processing of subsequent requests. + + Warning: + Synchronous callables run on the event loop's default ``Executor``, + which uses an instance of ``ThreadPoolExecutor`` unless + :py:meth:`asyncio.AbstractEventLoop.set_default_executor` is used + to change it to something else. Due to the GIL, CPU-bound jobs + will block request processing for the current process unless + the default ``Executor`` is changed to one that is process-based + instead of thread-based (e.g., an instance of + :py:class:`concurrent.futures.ProcessPoolExecutor`). + + Args: + callback(object): An async coroutine function or a synchronous + callable. The callback will be called without arguments. + """ + + # NOTE(kgriffs): We also have to do the CoroWrapper check because + # iscoroutine is less reliable under Python 3.6. + if iscoroutine(callback) or isinstance(callback, CoroWrapper): + raise TypeError( + 'The callback object appears to ' + 'be a coroutine, rather than a coroutine function. Please ' + 'pass the function itself, rather than the result obtained ' + 'by calling the function. ' + ) + + if not self._registered_callbacks: + self._registered_callbacks = [callback] + else: + self._registered_callbacks.append(callback) + + def set_stream(self, stream, content_length): + """Convenience method for setting both `stream` and `content_length`. + + Although the `stream` and `content_length` properties may be set + directly, using this method ensures `content_length` is not + accidentally neglected when the length of the stream is known in + advance. Using this method is also slightly more performant + as compared to setting the properties individually. + + Note: + If the stream length is unknown, you can set `stream` + directly, and ignore `content_length`. In this case, the + ASGI server may choose to use chunked encoding for HTTP/1.1 + + Args: + stream: A readable, awaitable file-like object or async iterable that + retuns byte strings. If the object implements a close() method, it + will be called after reading all of the data. + content_length (int): Length of the stream, used for the + Content-Length header in the response. + """ + + self.stream = stream + + # PERF(kgriffs): Set directly rather than incur the overhead of + # the self.content_length property. + self._headers['content-length'] = str(content_length) + + # ------------------------------------------------------------------------ + # Helper methods + # ------------------------------------------------------------------------ + + def _asgi_headers(self, media_type=None): + """Convert headers into the format expected by ASGI servers. + + Header names must be lowercased and both name and value must be + byte strings. + + See also: https://asgi.readthedocs.io/en/latest/specs/www.html#response-start + + Args: + media_type: Default media type to use for the Content-Type + header if the header was not set explicitly (default ``None``). + + """ + + headers = self._headers + # PERF(vytas): uglier inline version of Response._set_media_type + if media_type is not None and 'content-type' not in headers: + headers['content-type'] = media_type + + items = [(n.encode(), v.encode()) for n, v in headers.items()] + + if self._extra_headers: + items += [(n.encode(), v.encode()) for n, v in self._extra_headers] + + # NOTE(kgriffs): It is important to append these after self._extra_headers + # in case the latter contains Set-Cookie headers that should be + # overridden by a call to unset_cookie(). + if self._cookies is not None: + # PERF(tbug): + # The below implementation is ~23% faster than + # the alternative: + # + # self._cookies.output().split("\\r\\n") + # + # Even without the .split("\\r\\n"), the below + # is still ~17% faster, so don't use .output() + items += [(b'set-cookie', c.OutputString().encode()) + for c in self._cookies.values()] + return items diff --git a/falcon/asgi/stream.py b/falcon/asgi/stream.py new file mode 100644 index 000000000..8d36b4efb --- /dev/null +++ b/falcon/asgi/stream.py @@ -0,0 +1,310 @@ +# Copyright 2019 by Kurt Griffiths +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ASGI BoundedStream class.""" + + +__all__ = ['BoundedStream'] + + +class BoundedStream: + """File-like async object for reading ASGI streams. + + Does not support synhcronous reading/iterating but is otherwise similar to io.IOBase. + + If content length is unknown, will read until the ASGI server indicates + there is no more body available. + + """ + + __slots__ = [ + '_buffer', + '_bytes_remaining', + '_closed', + '_pos', + '_receive', + ] + + def __init__(self, receive, content_length=None): + self._closed = False + + self._receive = receive + self._buffer = b'' + + # NOTE(kgriffs): If length is unknown we just set remaining bytes + # to a ridiculously high number so that we will keep reading + # until we get an event with more_body == False. We do not + # use sys.maxsize because 2**31 on 32-bit systems is not + # a large enough number (someone may have an API that accepts + # multi-GB payloads). + self._bytes_remaining = 2**63 if content_length is None else content_length + + self._pos = 0 + + def __aiter__(self): + # NOTE(kgriffs): Technically we should be returning an async iterator + # here instead of an async generator, but in practice the caller + # should be happy as long as the returned object is iterable. + return self._iter_content() + + # ------------------------------------------------------------------------- + # These methods are included to improve compatibility with Python's + # standard "file-like" IO interface. + # ------------------------------------------------------------------------- + + # NOTE(kgriffs): According to the Python docs, NotImplementedError is not + # meant to be used to mean "not supported"; rather, the method should + # just be left undefined; hence we do not implement readline(), + # readlines(), __iter__(), __next__(), flush(), seek(), + # truncate(), __del__(). + + def fileno(self): + """Raises an instance of OSError since a file descriptor is not used.""" + raise OSError('This IO object does not use a file descriptor') + + def isatty(self): + """Always returns ``False``.""" + return False + + def readable(self): + """Always returns ``True``.""" + return True + + def seekable(self): + """Always returns ``False``.""" + return False + + def writable(self): + """Always returns ``False``.""" + return False + + def tell(self): + """Returns the number of bytes read from the stream.""" + return self._pos + + @property + def closed(self): + return self._closed + + # ------------------------------------------------------------------------- + + @property + def eof(self): + return not self._buffer and self._bytes_remaining == 0 + + def close(self): + """Clear any buffered data and close this stream. + + Once the stream is closed, any operation on it will + raise a ValueError. + + As a convenience, it is allowed to call this method more than + once; only the first call, however, will have an effect. + """ + + if not self._closed: + self._buffer = b'' + self._bytes_remaining = 0 + + self._closed = True + + async def exhaust(self): + if self._closed: + raise ValueError( + 'This stream is closed; no futher operations on it are permitted.' + ) + + self._buffer = b'' + + while self._bytes_remaining > 0: + event = await self._receive() + + if event['type'] == 'http.disconnect': + self._bytes_remaining = 0 + else: + try: + num_bytes = len(event['body']) + except KeyError: + # NOTE(kgriffs): The ASGI spec states that 'body' is optional. + num_bytes = 0 + + self._bytes_remaining -= num_bytes + self._pos += num_bytes + + if not ('more_body' in event and event['more_body']): + self._bytes_remaining = 0 + + # NOTE(kgriffs): Ensure that if we read more than expected, this + # this value is normalized to zero. + self._bytes_remaining = 0 + + async def readall(self): + if self._closed: + raise ValueError( + 'This stream is closed; no futher operations on it are permitted.' + ) + + if self.eof: + return b'' + + if self._buffer: + next_chunk = self._buffer + self._buffer = b'' + chunks = [next_chunk] + else: + chunks = [] + + while self._bytes_remaining > 0: + event = await self._receive() + + # PERF(kgriffs): Use try..except because we normally expect the + # 'body' key to be present. + try: + next_chunk = event['body'] + except KeyError: + pass + else: + next_chunk_len = len(next_chunk) + + if next_chunk_len <= self._bytes_remaining: + chunks.append(next_chunk) + self._bytes_remaining -= next_chunk_len + else: + # NOTE(kgriffs): Do not read more data than we are + # expecting. This *should* never happen if the + # server enforces the content-length header, but + # it is better to be safe than sorry. + chunks.append(next_chunk[:self._bytes_remaining]) + self._bytes_remaining = 0 + + # NOTE(kgriffs): This also handles the case of receiving + # the event: {'type': 'http.disconnect'} + if not ('more_body' in event and event['more_body']): + self._bytes_remaining = 0 + + data = chunks[0] if len(chunks) == 1 else b''.join(chunks) + self._pos += len(data) + + return data + + async def read(self, size=None): + if self._closed: + raise ValueError( + 'This stream is closed; no futher operations on it are permitted.' + ) + + if self.eof: + return b'' + + if size is None or size == -1: + return await self.readall() + + if size <= 0: + return b'' + + if self._buffer: + num_bytes_available = len(self._buffer) + chunks = [self._buffer] + else: + num_bytes_available = 0 + chunks = [] + + while self._bytes_remaining > 0 and num_bytes_available < size: + event = await self._receive() + + # PERF(kgriffs): Use try..except because we normally expect the + # 'body' key to be present. + try: + next_chunk = event['body'] + except KeyError: + pass + else: + next_chunk_len = len(next_chunk) + + if next_chunk_len <= self._bytes_remaining: + chunks.append(next_chunk) + self._bytes_remaining -= next_chunk_len + num_bytes_available += next_chunk_len + else: + # NOTE(kgriffs): Do not read more data than we are + # expecting. This *should* never happen, but better + # safe than sorry. + chunks.append(next_chunk[:self._bytes_remaining]) + self._bytes_remaining = 0 + num_bytes_available += self._bytes_remaining + + # NOTE(kgriffs): This also handles the case of receiving + # the event: {'type': 'http.disconnect'} + if not ('more_body' in event and event['more_body']): + self._bytes_remaining = 0 + + self._buffer = chunks[0] if len(chunks) == 1 else b''.join(chunks) + + if num_bytes_available <= size: + data = self._buffer + self._buffer = b'' + else: + data = self._buffer[:size] + self._buffer = self._buffer[size:] + + self._pos += len(data) + + return data + + # NOTE: In docs, tell people to not mix reading different modes - make + # sure you exhaust in the finally if you are reading something + # in middleware, or a chance something else might read it. Don't want someone + # to end up trying to read a half-read thing anyway! + async def _iter_content(self): + if self._closed: + raise ValueError( + 'This stream is closed; no futher operations on it are permitted.' + ) + + if self.eof: + yield b'' + return + + while self._bytes_remaining > 0: + event = await self._receive() + + # PERF(kgriffs): Use try..except because we normally expect the + # 'body' key to be present. + try: + next_chunk = event['body'] + except KeyError: + pass + else: + next_chunk_len = len(next_chunk) + + if next_chunk_len <= self._bytes_remaining: + self._bytes_remaining -= next_chunk_len + self._pos += next_chunk_len + else: + # NOTE(kgriffs): We received more data than expected, + # so truncate to the expected length. + next_chunk = next_chunk[:self._bytes_remaining] + self._pos += self._bytes_remaining + self._bytes_remaining = 0 + + yield next_chunk + + # NOTE(kgriffs): Per the ASGI spec, more_body is optional + # and should be considered False if not present. + # NOTE(kgriffs): This also handles the case of receiving + # the event: {'type': 'http.disconnect'} + # PERF(kgriffs): event.get() is more elegant, but uses a + # few more CPU cycles. + if not ('more_body' in event and event['more_body']): + self._bytes_remaining = 0 diff --git a/falcon/asgi/structures.py b/falcon/asgi/structures.py new file mode 100644 index 000000000..7a29ea24a --- /dev/null +++ b/falcon/asgi/structures.py @@ -0,0 +1,92 @@ +from json import dumps as json_dumps + + +__all__ = ['SSEvent'] + + +class SSEvent: + __slots__ = [ + 'data', + 'text', + 'json', + 'event', + 'event_id', + 'retry', + 'comment', + ] + + def __init__( + self, + data=None, + text=None, + json=None, + event=None, + event_id=None, + retry=None, + comment=None + ): + # NOTE(kgriffs): Check up front since this makes it a lot easier + # to debug the source of the problem in the app vs. waiting for + # an error to be raised from the framework when it calls serialize() + # after the fact. + + if data and not isinstance(data, bytes): + raise TypeError('data must be a byte string') + + if text and not isinstance(text, str): + raise TypeError('text must be a string') + + if event and not isinstance(event, str): + raise TypeError('event name must be a string') + + if event_id and not isinstance(event_id, str): + raise TypeError('event_id must be a string') + + if comment and not isinstance(comment, str): + raise TypeError('comment must be a string') + + if retry and not isinstance(retry, int): + raise TypeError('retry must be an int') + + self.data = data + self.text = text + self.json = json + self.event = event + self.event_id = event_id + self.retry = retry + + self.comment = comment + + def serialize(self): + if self.comment is not None: + block = ': ' + self.comment + '\n' + else: + block = '' + + if self.event is not None: + block += 'event: ' + self.event + '\n' + + if self.event_id is not None: + # NOTE(kgriffs): f-strings are a tiny bit faster than str(). + block += f'id: {self.event_id}\n' + + if self.retry is not None: + block += f'retry: {self.retry}\n' + + if self.data is not None: + # NOTE(kgriffs): While this decode() may seem unnecessary, it + # does provide a check to ensure it is valid UTF-8. I'm also + # assuming for the moment that most people will not use this + # attribute, but rather the text and json ones instead. If that + # is true, it makes sense to construct the entire string + # first, then encode it all in one go at the end. + block += 'data: ' + self.data.decode() + '\n' + elif self.text is not None: + block += 'data: ' + self.text + '\n' + elif self.json is not None: + block += 'data: ' + json_dumps(self.json, ensure_ascii=False) + '\n' + + if not block: + return b': ping\n\n' + + return (block + '\n').encode() diff --git a/falcon/cmd/print_routes.py b/falcon/cmd/print_routes.py index 1915212ee..ccea297f5 100644 --- a/falcon/cmd/print_routes.py +++ b/falcon/cmd/print_routes.py @@ -50,7 +50,9 @@ def traverse(roots, parent='', verbose=False): print('->', parent + '/' + root.raw_segment) if verbose: for method, func in root.method_map.items(): - if func.__name__ != 'method_not_allowed': + # NOTE(kgriffs): Skip the default responder that the + # framework creates. + if not func.__name__.startswith('method_not_allowed'): if isinstance(func, partial): real_func = func.func else: diff --git a/falcon/constants.py b/falcon/constants.py index 8de859769..420de21a5 100644 --- a/falcon/constants.py +++ b/falcon/constants.py @@ -91,3 +91,16 @@ MEDIA_GIF = 'image/gif' DEFAULT_MEDIA_TYPE = MEDIA_JSON + +# NOTE(kgriffs): We do not expect more than one of these in the request +SINGLETON_HEADERS = frozenset([ + 'content-length', + 'content-type', + 'cookie', + 'expect', + 'from', + 'host', + 'max-forwards', + 'referer', + 'user-agent', +]) diff --git a/falcon/errors.py b/falcon/errors.py index d16c21a9f..59609fad8 100644 --- a/falcon/errors.py +++ b/falcon/errors.py @@ -47,6 +47,18 @@ class HeaderNotSupported(ValueError): """The specified header is not supported by this method.""" +class CompatibilityError(ValueError): + """The given method or value is not compatibile.""" + + +class UnsupportedScopeError(RuntimeError): + """The ASGI scope type is not supported by Falcon.""" + + +class UnsupportedError(RuntimeError): + """The method or operation is not supported.""" + + class HTTPBadRequest(HTTPError): """400 Bad Request. diff --git a/falcon/hooks.py b/falcon/hooks.py index e90f30263..82c050274 100644 --- a/falcon/hooks.py +++ b/falcon/hooks.py @@ -15,11 +15,12 @@ """Hook decorators.""" from functools import wraps -from inspect import getmembers +from inspect import getmembers, iscoroutinefunction import re from falcon import COMBINED_METHODS from falcon.util.misc import get_argnames +from falcon.util.sync import _wrap_non_coroutine_unsafe _DECORABLE_METHOD_NAME = re.compile(r'^on_({})(_\w+)?$'.format( @@ -148,13 +149,24 @@ def _wrap_with_after(responder, action, action_args, action_kwargs): responder_argnames = get_argnames(responder) extra_argnames = responder_argnames[2:] # Skip req, resp - @wraps(responder) - def do_after(self, req, resp, *args, **kwargs): - if args: - _merge_responder_args(args, kwargs, extra_argnames) + if iscoroutinefunction(responder): + action = _wrap_non_coroutine_unsafe(action) - responder(self, req, resp, **kwargs) - action(req, resp, self, *action_args, **action_kwargs) + @wraps(responder) + async def do_after(self, req, resp, *args, **kwargs): + if args: + _merge_responder_args(args, kwargs, extra_argnames) + + await responder(self, req, resp, **kwargs) + await action(req, resp, self, *action_args, **action_kwargs) + else: + @wraps(responder) + def do_after(self, req, resp, *args, **kwargs): + if args: + _merge_responder_args(args, kwargs, extra_argnames) + + responder(self, req, resp, **kwargs) + action(req, resp, self, *action_args, **action_kwargs) return do_after @@ -173,13 +185,24 @@ def _wrap_with_before(responder, action, action_args, action_kwargs): responder_argnames = get_argnames(responder) extra_argnames = responder_argnames[2:] # Skip req, resp - @wraps(responder) - def do_before(self, req, resp, *args, **kwargs): - if args: - _merge_responder_args(args, kwargs, extra_argnames) + if iscoroutinefunction(responder): + action = _wrap_non_coroutine_unsafe(action) + + @wraps(responder) + async def do_before(self, req, resp, *args, **kwargs): + if args: + _merge_responder_args(args, kwargs, extra_argnames) + + await action(req, resp, self, kwargs, *action_args, **action_kwargs) + await responder(self, req, resp, **kwargs) + else: + @wraps(responder) + def do_before(self, req, resp, *args, **kwargs): + if args: + _merge_responder_args(args, kwargs, extra_argnames) - action(req, resp, self, kwargs, *action_args, **action_kwargs) - responder(self, req, resp, **kwargs) + action(req, resp, self, kwargs, *action_args, **action_kwargs) + responder(self, req, resp, **kwargs) return do_before diff --git a/falcon/media/base.py b/falcon/media/base.py index dcb5ee17d..55b34a5cd 100644 --- a/falcon/media/base.py +++ b/falcon/media/base.py @@ -1,13 +1,20 @@ import abc +import io class BaseHandler(metaclass=abc.ABCMeta): """Abstract Base Class for an internet media type handler""" - @abc.abstractmethod # pragma: no cover def serialize(self, media, content_type): """Serialize the media object on a :any:`falcon.Response` + By default, this method raises an instance of + :py:class:`NotImplementedError`. Therefore, it must be + overridden in order to work with WSGI apps. Child classes + can ignore this method if they are only to be used + with ASGI apps, as long as they override + :py:meth:`~.BaseHandler.serialize_async`. + Args: media (object): A serializable object. content_type (str): Type of response content. @@ -15,16 +22,86 @@ def serialize(self, media, content_type): Returns: bytes: The resulting serialized bytes from the input object. """ + raise NotImplementedError() + + async def serialize_async(self, media, content_type): + """Serialize the media object on a :any:`falcon.Response` + + This method is similar to :py:meth:`~.BaseHandler.serialize` + except that it is asynchronous. The default implementation simply calls + :py:meth:`~.BaseHandler.serialize`. If the media object may be + awaitable, or is otherwise something that should be read + asynchronously, subclasses must override the default implementation + in order to handle that case. + + Note: + By default, the :py:meth:`~.BaseHandler.serialize` + method raises an instance of :py:class:`NotImplementedError`. + Therefore, child classes must either override + :py:meth:`~.BaseHandler.serialize` or + :py:meth:`~.BaseHandler.serialize_async` in order to be + compatible with ASGI apps. + + Args: + media (object): A serializable object. + content_type (str): Type of response content. + + Returns: + bytes: The resulting serialized bytes from the input object. + """ + return self.serialize(media, content_type) - @abc.abstractmethod # pragma: no cover def deserialize(self, stream, content_type, content_length): """Deserialize the :any:`falcon.Request` body. + By default, this method raises an instance of + :py:class:`NotImplementedError`. Therefore, it must be + overridden in order to work with WSGI apps. Child classes + can ignore this method if they are only to be used + with ASGI apps, as long as they override + :py:meth:`~.BaseHandler.deserialize_async`. + + + Args: + stream (object): Readable file-like object to deserialize. + content_type (str): Type of request content. + content_length (int): Length of request content. + + Returns: + object: A deserialized object. + """ + raise NotImplementedError() + + async def deserialize_async(self, stream, content_type, content_length): + """Deserialize the :any:`falcon.Request` body. + + This method is similar to :py:meth:`~.BaseHandler.deserialize` except + that it is asynchronous. The default implementation adapts the + synchronous :py:meth:`~.BaseHandler.deserialize` method + via :py:class:`io.BytesIO`. For improved performance, media handlers should + override this method. + + Note: + By default, the :py:meth:`~.BaseHandler.deserialize` + method raises an instance of :py:class:`NotImplementedError`. + Therefore, child classes must either override + :py:meth:`~.BaseHandler.deserialize` or + :py:meth:`~.BaseHandler.deserialize_async` in order to be + compatible with ASGI apps. + Args: - stream (object): Input data to deserialize. + stream (object): Asynchronous file-like object to deserialize. content_type (str): Type of request content. content_length (int): Length of request content. Returns: object: A deserialized object. """ + + data = await stream.read() + + # NOTE(kgriffs): Override content length to make sure it is correct, + # since we know what it is in this case. + content_length = len(data) + + return self.deserialize(io.BytesIO(data), content_type, content_length) diff --git a/falcon/media/handlers.py b/falcon/media/handlers.py index fc1d92863..1a083da34 100644 --- a/falcon/media/handlers.py +++ b/falcon/media/handlers.py @@ -38,7 +38,7 @@ def _resolve_media_type(self, media_type, all_media_types): def find_by_media_type(self, media_type, default): # PERF(jmvrbanac): Check via a quick methods first for performance if media_type == '*/*' or not media_type: - return self.data[default] + media_type = default try: return self.data[media_type] diff --git a/falcon/media/json.py b/falcon/media/json.py index 73e3ac796..ef59e4648 100644 --- a/falcon/media/json.py +++ b/falcon/media/json.py @@ -78,6 +78,17 @@ def deserialize(self, stream, content_type, content_length): 'Could not parse JSON body - {0}'.format(err) ) + async def deserialize_async(self, stream, content_type, content_length): + data = await stream.read() + + try: + return self.loads(data.decode('utf-8')) + except ValueError as err: + raise errors.HTTPBadRequest( + 'Invalid JSON', + 'Could not parse JSON body - {0}'.format(err) + ) + def serialize(self, media, content_type): result = self.dumps(media) @@ -91,3 +102,11 @@ def serialize(self, media, content_type): pass return result + + async def serialize_async(self, media, content_type): + result = self.dumps(media) + + if not isinstance(result, bytes): + return result.encode('utf-8') + + return result diff --git a/falcon/media/msgpack.py b/falcon/media/msgpack.py index a2a5ceace..8c87b0e24 100644 --- a/falcon/media/msgpack.py +++ b/falcon/media/msgpack.py @@ -41,5 +41,21 @@ def deserialize(self, stream, content_type, content_length): 'Could not parse MessagePack body - {0}'.format(err) ) + async def deserialize_async(self, stream, content_type, content_length): + data = await stream.read() + + try: + # NOTE(jmvrbanac): Using unpackb since we would need to manage + # a buffer for Unpacker() which wouldn't gain us much. + return self.msgpack.unpackb(data, raw=False) + except ValueError as err: + raise errors.HTTPBadRequest( + 'Invalid MessagePack', + 'Could not parse MessagePack body - {0}'.format(err) + ) + def serialize(self, media, content_type): return self.packer.pack(media) + + async def serialize_async(self, media, content_type): + return self.packer.pack(media) diff --git a/falcon/media/urlencoded.py b/falcon/media/urlencoded.py index 5aa7dae2f..d5aa7e0b5 100644 --- a/falcon/media/urlencoded.py +++ b/falcon/media/urlencoded.py @@ -38,3 +38,15 @@ def deserialize(self, stream, content_type, content_length): return parse_query_string(body, keep_blank=self.keep_blank, csv=self.csv) + + async def deserialize_async(self, stream, content_type, content_length): + body = await stream.read() + + # NOTE(kgriffs): According to http://goo.gl/6rlcux the + # body should be US-ASCII. Enforcing this also helps + # catch malicious input. + body = body.decode('ascii') + + return parse_query_string(body, + keep_blank=self.keep_blank, + csv=self.csv) diff --git a/falcon/media/validators/jsonschema.py b/falcon/media/validators/jsonschema.py index eba5d03e0..9597ee4bd 100644 --- a/falcon/media/validators/jsonschema.py +++ b/falcon/media/validators/jsonschema.py @@ -1,4 +1,5 @@ from functools import wraps +from inspect import iscoroutinefunction import falcon @@ -46,36 +47,83 @@ def on_post(self, req, resp): """ def decorator(func): - @wraps(func) - def wrapper(self, req, resp, *args, **kwargs): - if req_schema is not None: - try: - jsonschema.validate( - req.media, req_schema, - format_checker=jsonschema.FormatChecker() - ) - except jsonschema.ValidationError as e: - raise falcon.HTTPBadRequest( - 'Request data failed validation', - description=e.message - ) - - result = func(self, req, resp, *args, **kwargs) - - if resp_schema is not None: - try: - jsonschema.validate( - resp.media, resp_schema, - format_checker=jsonschema.FormatChecker() - ) - except jsonschema.ValidationError: - raise falcon.HTTPInternalServerError( - 'Response data failed validation' - # Do not return 'e.message' in the response to - # prevent info about possible internal response - # formatting bugs from leaking out to users. - ) - - return result - return wrapper + if iscoroutinefunction(func): + return _validate_async(func, req_schema, resp_schema) + + return _validate(func, req_schema, resp_schema) + return decorator + + +def _validate(func, req_schema=None, resp_schema=None): + @wraps(func) + def wrapper(self, req, resp, *args, **kwargs): + if req_schema is not None: + try: + jsonschema.validate( + req.media, req_schema, + format_checker=jsonschema.FormatChecker() + ) + except jsonschema.ValidationError as e: + raise falcon.HTTPBadRequest( + 'Request data failed validation', + description=e.message + ) + + result = func(self, req, resp, *args, **kwargs) + + if resp_schema is not None: + try: + jsonschema.validate( + resp.media, resp_schema, + format_checker=jsonschema.FormatChecker() + ) + except jsonschema.ValidationError: + raise falcon.HTTPInternalServerError( + 'Response data failed validation' + # Do not return 'e.message' in the response to + # prevent info about possible internal response + # formatting bugs from leaking out to users. + ) + + return result + + return wrapper + + +def _validate_async(func, req_schema=None, resp_schema=None): + @wraps(func) + async def wrapper(self, req, resp, *args, **kwargs): + if req_schema is not None: + m = await req.get_media() + + try: + jsonschema.validate( + m, req_schema, + format_checker=jsonschema.FormatChecker() + ) + except jsonschema.ValidationError as e: + raise falcon.HTTPBadRequest( + 'Request data failed validation', + description=e.message + ) + + result = await func(self, req, resp, *args, **kwargs) + + if resp_schema is not None: + try: + jsonschema.validate( + resp.media, resp_schema, + format_checker=jsonschema.FormatChecker() + ) + except jsonschema.ValidationError: + raise falcon.HTTPInternalServerError( + 'Response data failed validation' + # Do not return 'e.message' in the response to + # prevent info about possible internal response + # formatting bugs from leaking out to users. + ) + + return result + + return wrapper diff --git a/falcon/middlewares.py b/falcon/middlewares.py index 97fb8dfad..37539d68e 100644 --- a/falcon/middlewares.py +++ b/falcon/middlewares.py @@ -27,3 +27,6 @@ def process_response(self, req, resp, resource, req_succeeded): resp.set_header('Access-Control-Allow-Methods', allow) resp.set_header('Access-Control-Allow-Headers', allow_headers) resp.set_header('Access-Control-Max-Age', '86400') # 24 hours + + async def process_response_async(self, *args): + self.process_response(*args) diff --git a/falcon/request.py b/falcon/request.py index 5efb7e17e..d2fa6c9a6 100644 --- a/falcon/request.py +++ b/falcon/request.py @@ -146,7 +146,7 @@ class Request: If the hostname in the request is an IP address, the value for `subdomain` is undefined. - app (str): The initial portion of the request URI's path that + root_path (str): The initial portion of the request URI's path that corresponds to the application object, so that the application knows its virtual "location". This may be an empty string, if the application corresponds to the "root" @@ -154,8 +154,9 @@ class Request: (Corresponds to the "SCRIPT_NAME" environ variable defined by PEP-3333.) + app (str): Deprecated alias for :attr:`root_path`. uri (str): The fully-qualified URI for the request. - url (str): Alias for `uri`. + url (str): Alias for :attr:`uri`. forwarded_uri (str): Original URI for proxied requests. Uses :attr:`forwarded_scheme` and :attr:`forwarded_host` in order to reconstruct the original URI requested by the user @@ -544,15 +545,8 @@ def forwarded(self): # At some point we might look into this but I don't think # it's worth it right now. if self._cached_forwarded is None: - # PERF(kgriffs): If someone is calling this, they are probably - # confident that the header exists, so most of the time we - # expect this call to succeed. Therefore, we won't need to - # pay the penalty of a raised exception in most cases, and - # there is no need to spend extra cycles calling get() or - # checking beforehand whether the key is in the dict. - try: - forwarded = self.env['HTTP_FORWARDED'] - except KeyError: + forwarded = self.get_header('Forwarded') + if forwarded is None: return None self._cached_forwarded = _parse_forwarded_header(forwarded) @@ -657,16 +651,16 @@ def if_unmodified_since(self): @property def range(self): - try: - value = self.env['HTTP_RANGE'] - if '=' in value: - unit, sep, req_range = value.partition('=') - else: - msg = "The value must be prefixed with a range unit, e.g. 'bytes='" - raise errors.HTTPInvalidHeader(msg, 'Range') - except KeyError: + value = self.get_header('Range') + if value is None: return None + if '=' in value: + unit, sep, req_range = value.partition('=') + else: + msg = "The value must be prefixed with a range unit, e.g. 'bytes='" + raise errors.HTTPInvalidHeader(msg, 'Range') + if ',' in req_range: msg = 'The value must be a continuous range.' raise errors.HTTPInvalidHeader(msg, 'Range') @@ -694,20 +688,19 @@ def range(self): @property def range_unit(self): - try: - value = self.env['HTTP_RANGE'] - - if '=' in value: - unit, sep, req_range = value.partition('=') - return unit - else: - msg = "The value must be prefixed with a range unit, e.g. 'bytes='" - raise errors.HTTPInvalidHeader(msg, 'Range') - except KeyError: + value = self.get_header('Range') + if value is None: return None + if value and '=' in value: + unit, sep, req_range = value.partition('=') + return unit + else: + msg = "The value must be prefixed with a range unit, e.g. 'bytes='" + raise errors.HTTPInvalidHeader(msg, 'Range') + @property - def app(self): + def root_path(self): # PERF(kgriffs): try..except is faster than get() assuming that # we normally expect the key to exist. Even though PEP-3333 # allows WSGI servers to omit the key when the value is an @@ -718,6 +711,8 @@ def app(self): except KeyError: return '' + app = root_path + @property def scheme(self): return self.env['wsgi.url_scheme'] @@ -913,16 +908,23 @@ def access_route(self): self._cached_access_route = [ip.strip() for ip in addresses] elif 'HTTP_X_REAL_IP' in self.env: self._cached_access_route = [self.env['HTTP_X_REAL_IP']] - elif 'REMOTE_ADDR' in self.env: - self._cached_access_route = [self.env['REMOTE_ADDR']] + + if self._cached_access_route: + if self._cached_access_route[-1] != self.remote_addr: + self._cached_access_route.append(self.remote_addr) else: - self._cached_access_route = [] + self._cached_access_route = [self.remote_addr] return self._cached_access_route @property def remote_addr(self): - return self.env.get('REMOTE_ADDR') + try: + value = self.env['REMOTE_ADDR'] + except KeyError: + value = '127.0.0.1' + + return value @property def port(self): @@ -966,7 +968,7 @@ def netloc(self): @property def media(self): - if self._media is not None or self.bounded_stream.is_exhausted: + if self._media is not None or self.bounded_stream.eof: return self._media handler = self.options.media_handlers.find_by_media_type( diff --git a/falcon/request_helpers.py b/falcon/request_helpers.py index 4b91b4bdd..bc86157ad 100644 --- a/falcon/request_helpers.py +++ b/falcon/request_helpers.py @@ -182,13 +182,19 @@ class BoundedStream(io.IOBase): This class normalizes *wsgi.input* behavior between WSGI servers by implementing non-blocking behavior for the cases mentioned - above. + above. The caller is not allowed to read more than the number of + bytes specified by the Content-Length header in the request. Args: stream: Instance of ``socket._fileobject`` from ``environ['wsgi.input']`` stream_len: Expected content length of the stream. + Attributes: + eof (bool): ``True`` if there is no more data to read from + the stream, otherwise ``False``. + is_exhausted (bool): Deprecated alias for `eof`. + """ def __init__(self, stream, stream_len): @@ -239,7 +245,7 @@ def seekable(self): """Always returns ``False``.""" return False - def writeable(self): + def writable(self): """Always returns ``False``.""" return False @@ -305,10 +311,11 @@ def exhaust(self, chunk_size=64 * 1024): break @property - def is_exhausted(self): - """If the stream is exhausted this attribute is ``True``.""" + def eof(self): return self._bytes_remaining <= 0 + is_exhausted = eof + # NOTE(kgriffs): Alias for backwards-compat Body = BoundedStream diff --git a/falcon/responders.py b/falcon/responders.py index 375405b67..52588373b 100644 --- a/falcon/responders.py +++ b/falcon/responders.py @@ -14,8 +14,6 @@ """Default responder implementations.""" -from functools import partial, update_wrapper - from falcon.errors import HTTPBadRequest from falcon.errors import HTTPMethodNotAllowed from falcon.errors import HTTPNotFound @@ -27,45 +25,65 @@ def path_not_found(req, resp, **kwargs): raise HTTPNotFound() +async def path_not_found_async(req, resp, **kwargs): + """Raise 404 HTTPNotFound error""" + raise HTTPNotFound() + + def bad_request(req, resp, **kwargs): """Raise 400 HTTPBadRequest error""" raise HTTPBadRequest('Bad request', 'Invalid HTTP method') -def method_not_allowed(allowed_methods, req, resp, **kwargs): - """Raise 405 HTTPMethodNotAllowed error""" - raise HTTPMethodNotAllowed(allowed_methods) +async def bad_request_async(req, resp, **kwargs): + """Raise 400 HTTPBadRequest error""" + raise HTTPBadRequest('Bad request', 'Invalid HTTP method') -def create_method_not_allowed(allowed_methods): +def create_method_not_allowed(allowed_methods, asgi=False): """Create a responder for "405 Method Not Allowed" Args: allowed_methods: A list of HTTP methods (uppercase) that should be returned in the Allow header. - + asgi (bool): ``True`` if using an ASGI app, ``False`` otherwise + (default ``False``). """ - partial_method_not_allowed = partial(method_not_allowed, allowed_methods) - update_wrapper(partial_method_not_allowed, method_not_allowed) - return partial_method_not_allowed + if asgi: + async def method_not_allowed_responder_async(req, resp, **kwargs): + raise HTTPMethodNotAllowed(allowed_methods) + + return method_not_allowed_responder_async -def on_options(allowed, req, resp, **kwargs): - """Default options responder.""" - resp.status = HTTP_200 - resp.set_header('Allow', allowed) - resp.set_header('Content-Length', '0') + def method_not_allowed(req, resp, **kwargs): + raise HTTPMethodNotAllowed(allowed_methods) + return method_not_allowed -def create_default_options(allowed_methods): + +def create_default_options(allowed_methods, asgi=False): """Create a default responder for the OPTIONS method Args: - allowed_methods: A list of HTTP methods (uppercase) that should be - returned in the Allow header. - + allowed_methods (iterable): An iterable of HTTP methods (uppercase) + that should be returned in the Allow header. + asgi (bool): ``True`` if using an ASGI app, ``False`` otherwise + (default ``False``). """ allowed = ', '.join(allowed_methods) - partial_on_options = partial(on_options, allowed) - update_wrapper(partial_on_options, on_options) - return partial_on_options + + if asgi: + async def options_responder_async(req, resp, **kwargs): + resp.status = HTTP_200 + resp.set_header('Allow', allowed) + resp.set_header('Content-Length', '0') + + return options_responder_async + + def options_responder(req, resp, **kwargs): + resp.status = HTTP_200 + resp.set_header('Allow', allowed) + resp.set_header('Content-Length', '0') + + return options_responder diff --git a/falcon/response.py b/falcon/response.py index 1cfb607ac..d1fb09999 100644 --- a/falcon/response.py +++ b/falcon/response.py @@ -191,6 +191,11 @@ def __init__(self, options=None): @property def data(self): + # TODO(kgriffs): Remove the side-effect that accessing this + # property causes (do something similar to what we did on the + # ASGI side). This will be a breaking change, so caution is + # advised. + # NOTE(kgriffs): Test explicitly against None since the # app may have set it to an empty binary string. if self._data is not None: @@ -238,6 +243,11 @@ def media(self, obj): # rather than serializing immediately. That way, if media() is called # multiple times we don't waste time serializing objects that will # just be thrown away. + # + # TODO(kgriffs): This makes precedence harder to reason about, since + # it is no longer about what attributes have and have not been set, + # but also what order they were set in. On the ASGI side this has + # already been addressed. self._data = None @property @@ -267,11 +277,12 @@ def set_stream(self, stream, content_length): Note: If the stream length is unknown, you can set `stream` directly, and ignore `content_length`. In this case, the - WSGI server may choose to use chunked encoding or one + server may choose to use chunked encoding or one of the other strategies suggested by PEP-3333. Args: - stream: A readable file-like object. + stream: A readable file-like object in the case of WSGI, or an + async iterable in the case of ASGI. content_length (int): Length of the stream, used for the Content-Length header in the response. """ diff --git a/falcon/routing/__init__.py b/falcon/routing/__init__.py index 90c5d0b9f..0ced173c4 100644 --- a/falcon/routing/__init__.py +++ b/falcon/routing/__init__.py @@ -20,7 +20,7 @@ """ from falcon.routing.compiled import CompiledRouter, CompiledRouterOptions # NOQA -from falcon.routing.static import StaticRoute # NOQA +from falcon.routing.static import StaticRoute, StaticRouteAsync # NOQA from falcon.routing.util import map_http_methods # NOQA from falcon.routing.util import set_default_responders # NOQA from falcon.routing.util import compile_uri_template # NOQA diff --git a/falcon/routing/compiled.py b/falcon/routing/compiled.py index 890788aa4..a5da83d8f 100644 --- a/falcon/routing/compiled.py +++ b/falcon/routing/compiled.py @@ -15,12 +15,14 @@ """Default routing engine.""" from collections import UserDict +from inspect import iscoroutinefunction import keyword import re import textwrap from falcon.routing import converters from falcon.routing.util import map_http_methods, set_default_responders +from falcon.util.sync import _should_wrap_non_coroutines, wrap_sync_to_async _TAB_STR = ' ' * 4 @@ -141,8 +143,18 @@ class can use suffixed responders to distinguish requests resource. """ + # NOTE(kgriffs): falcon.asgi.App injects this private kwarg; it is + # only intended to be used internally. + asgi = kwargs.get('_asgi', False) + method_map = self.map_http_methods(resource, **kwargs) - set_default_responders(method_map) + + set_default_responders(method_map, asgi=asgi) + + if asgi: + self._require_coroutine_responders(method_map) + else: + self._require_non_coroutine_responders(method_map) # NOTE(kgriffs): Fields may have whitespace in them, so sub # those before checking the rest of the URI template. @@ -227,6 +239,46 @@ def find(self, uri, req=None): # Private # ----------------------------------------------------------------- + def _require_coroutine_responders(self, method_map): + for method, responder in method_map.items(): + # NOTE(kgriffs): We don't simply wrap non-async functions + # since they likely peform relatively long blocking + # operations that need to be explicitly made non-blocking + # by the developer; raising an error helps highlight this + # issue. + + if not iscoroutinefunction(responder): + if _should_wrap_non_coroutines(): + def let(responder=responder): + method_map[method] = wrap_sync_to_async(responder) + + let() + + else: + msg = ( + 'The {} responder must be a non-blocking ' + 'async coroutine (i.e., defined using async def) to ' + 'avoid blocking the main request thread.' + ) + msg = msg.format(responder) + raise TypeError(msg) + + def _require_non_coroutine_responders(self, method_map): + for method, responder in method_map.items(): + # NOTE(kgriffs): We don't simply wrap non-async functions + # since they likely peform relatively long blocking + # operations that need to be explicitly made non-blocking + # by the developer; raising an error helps highlight this + # issue. + + if iscoroutinefunction(responder): + msg = ( + 'The {} responder must be a regular synchronous ' + 'method to be used with a WSGI app.' + ) + msg = msg.format(responder) + raise TypeError(msg) + def _validate_template_segment(self, segment, used_names): """Validates a single path segment of a URI template. diff --git a/falcon/routing/static.py b/falcon/routing/static.py index 0fcc5e309..c44e29fe6 100644 --- a/falcon/routing/static.py +++ b/falcon/routing/static.py @@ -1,8 +1,10 @@ +from functools import partial import io import os import re import falcon +from falcon.util.sync import get_loop class StaticRoute: @@ -122,3 +124,24 @@ def __call__(self, req, resp): if self._downloadable: resp.downloadable_as = os.path.basename(file_path) + + +class StaticRouteAsync(StaticRoute): + """Subclass of StaticRoute with modifications to support ASGI apps.""" + + async def __call__(self, req, resp): + super().__call__(req, resp) + + # NOTE(kgriffs): Fixup resp.stream so that it is non-blocking + resp.stream = _AsyncFileReader(resp.stream) + + +class _AsyncFileReader: + """Adapts a standard file I/O object so that reads are non-blocking.""" + + def __init__(self, file): + self._file = file + self._loop = get_loop() + + async def read(self, size=-1): + return await self._loop.run_in_executor(None, partial(self._file.read, size)) diff --git a/falcon/routing/util.py b/falcon/routing/util.py index 15a30e103..4455820af 100644 --- a/falcon/routing/util.py +++ b/falcon/routing/util.py @@ -130,12 +130,14 @@ def map_http_methods(resource, suffix=None): return method_map -def set_default_responders(method_map): +def set_default_responders(method_map, asgi=False): """Maps HTTP methods not explicitly defined on a resource to default responders. Args: method_map: A dict with HTTP methods mapped to responders explicitly defined in a resource. + asgi (bool): ``True`` if using an ASGI app, ``False`` otherwise + (default ``False``). """ # Attach a resource for unsupported HTTP methods @@ -143,11 +145,11 @@ def set_default_responders(method_map): if 'OPTIONS' not in method_map: # OPTIONS itself is intentionally excluded from the Allow header - opt_responder = responders.create_default_options(allowed_methods) + opt_responder = responders.create_default_options(allowed_methods, asgi=asgi) method_map['OPTIONS'] = opt_responder allowed_methods.append('OPTIONS') - na_responder = responders.create_method_not_allowed(allowed_methods) + na_responder = responders.create_method_not_allowed(allowed_methods, asgi=asgi) for method in constants.COMBINED_METHODS: if method not in allowed_methods: diff --git a/falcon/testing/__init__.py b/falcon/testing/__init__.py index 9540d8bf2..d0fe8ef7f 100644 --- a/falcon/testing/__init__.py +++ b/falcon/testing/__init__.py @@ -78,6 +78,6 @@ def test_get_message(client): from falcon.testing.client import * # NOQA from falcon.testing.helpers import * # NOQA from falcon.testing.resource import capture_responder_args, set_resp_defaults # NOQA -from falcon.testing.resource import SimpleTestResource # NOQA +from falcon.testing.resource import SimpleTestResource, SimpleTestResourceAsync # NOQA from falcon.testing.srmock import StartResponseMock # NOQA from falcon.testing.test_case import TestCase # NOQA diff --git a/falcon/testing/client.py b/falcon/testing/client.py index 58abd6c50..836c9d9d0 100644 --- a/falcon/testing/client.py +++ b/falcon/testing/client.py @@ -18,7 +18,10 @@ WSGI callable, without having to stand up a WSGI server. """ +import asyncio import datetime as dt +import inspect +import time from typing import Dict, Optional, Union import warnings import wsgiref.validate @@ -26,8 +29,15 @@ from falcon.constants import COMBINED_METHODS, MEDIA_JSON from falcon.testing import helpers from falcon.testing.srmock import StartResponseMock -from falcon.util import CaseInsensitiveDict, http_cookies, http_date_to_dt, to_query_str -from falcon.util import json as util_json +from falcon.util import ( + CaseInsensitiveDict, + code_to_http_status, + get_loop, + http_cookies, + http_date_to_dt, + json as util_json, + to_query_str, +) warnings.filterwarnings( 'error', @@ -124,7 +134,7 @@ def same_site(self) -> Optional[int]: class Result: - """Encapsulates the result of a simulated WSGI request. + """Encapsulates the result of a simulated request. Args: iterable (iterable): An iterable that yields zero or more @@ -163,7 +173,7 @@ class Result: if the response is not valid JSON. """ - def __init__(self, iterable, status, headers): + def __init__(self, iterable, status, headers, events=None): self._text = None self._content = b''.join(iterable) @@ -171,6 +181,7 @@ def __init__(self, iterable, status, headers): self._status = status self._status_code = int(status[:3]) self._headers = CaseInsensitiveDict(headers) + self._events = events or [] cookies = http_cookies.SimpleCookie() for name, value in headers: @@ -238,19 +249,25 @@ def json(self) -> Optional[Union[dict, list, str, int, float, bool]]: return util_json.loads(self.text) +# NOTE(kgriffs): The default of asgi_disconnect_ttl was chosen to be +# relatively long (5 minutes) to help testers notice when something +# appears to be "hanging", which might indicates that the app is +# not handling the reception of events correctly. def simulate_request(app, method='GET', path='/', query_string=None, headers=None, body=None, json=None, file_wrapper=None, wsgierrors=None, params=None, params_csv=True, protocol='http', host=helpers.DEFAULT_HOST, - remote_addr=None, extras=None) -> Result: - """Simulates a request to a WSGI application. + remote_addr=None, extras=None, http_version='1.1', + port=None, root_path=None, asgi_chunk_size=4096, + asgi_disconnect_ttl=300) -> Result: + + """Simulates a request to a WSGI or ASGI application. - Performs a request against a WSGI application. Uses - :any:`wsgiref.validate` to ensure the response is valid - WSGI. + Performs a request against a WSGI or ASGI application. In the case of + WSGI, uses :any:`wsgiref.validate` to ensure the response is valid. Keyword Args: - app (callable): The WSGI application to call + app (callable): The WSGI or ASGI application to call method (str): An HTTP method to use in the request (default: 'GET') path (str): The URL path to request (default: '/'). @@ -259,8 +276,17 @@ def simulate_request(app, method='GET', path='/', query_string=None, The path may contain a query string. However, neither `query_string` nor `params` may be specified in this case. + root_path (str): The initial portion of the request URL's "path" that + corresponds to the application object, so that the application + knows its virtual "location". This defaults to the empty string, + indicating that the application corresponds to the "root" of the + server. protocol: The protocol to use for the URL scheme (default: 'http') + port (int): The TCP port to simulate. Defaults to + the standard port used by the given scheme (i.e., 80 for 'http' + and 443 for 'https'). A string may also be passed, as long as + it can be parsed as an int. params (dict): A dictionary of query string parameters, where each key is a parameter name, and each value is either a ``str`` or something that can be converted @@ -275,27 +301,48 @@ def simulate_request(app, method='GET', path='/', query_string=None, query_string (str): A raw query string to include in the request (default: ``None``). If specified, overrides `params`. - headers (dict): Additional headers to include in the request - (default: ``None``) - body (str): A string to send as the body of the request. The value - will be encoded as UTF-8. + headers (dict): Extra headers as a dict-like (Mapping) object, or an + iterable yielding a series of two-member (*name*, *value*) + iterables. Each pair of strings provides the name and value + for an HTTP header. If desired, multiple header values may be + combined into a single (*name*, *value*) pair by joining the values + with a comma when the header in question supports the list + format (see also RFC 7230 and RFC 7231). Header names are not + case-sensitive. + body (str): The body of the request (default ''). The value will be + encoded as UTF-8 in the WSGI environ. Alternatively, a byte string + may be passed, in which case it will be used as-is. json(JSON serializable): A JSON document to serialize as the body of the request (default: ``None``). If specified, overrides `body` and the Content-Type header in `headers`. file_wrapper (callable): Callable that returns an iterable, to be used as the value for *wsgi.file_wrapper* in the - environ (default: ``None``). This can be used to test + WSGI environ (default: ``None``). This can be used to test high-performance file transmission when `resp.stream` is set to a file-like object. host(str): A string to use for the hostname part of the fully qualified request URL (default: 'falconframework.org') remote_addr (str): A string to use as the remote IP address for the - request (default: '127.0.0.1') - wsgierrors (io): The stream to use as *wsgierrors* - (default ``sys.stderr``) - extras (dict): Additional CGI variables to add to the WSGI - ``environ`` dictionary for the request (default: ``None``) + request (default: '127.0.0.1'). For WSGI, this corresponds to + the 'REMOTE_ADDR' environ variable. For ASGI, this corresponds + to the IP address used for the 'client' field in the connection + scope. + http_version (str): The HTTP version to simulate. Must be either + '2', '2.0', 1.1', '1.0', or '1' (default '1.1'). If set to '1.0', + the Host header will not be added to the scope. + wsgierrors (io): The stream to use as *wsgierrors* in the WSGI + environ (default ``sys.stderr``) + asgi_chunk_size (int): The maximum number of bytes that will be + sent to the ASGI app in a single 'http.request' event (default + 4096). + asgi_disconnect_ttl (int): The maximum number of seconds to wait + since the request was initiated, before emitting an + 'http.disconnect' event when the app calls the + receive() function (default 300). + extras (dict): Additional values to add to the WSGI + ``environ`` dictionary or the ASGI scope for the request + (default: ``None``) Returns: :py:class:`~.Result`: The result of the request @@ -316,15 +363,6 @@ def simulate_request(app, method='GET', path='/', query_string=None, raise ValueError("query_string should not start with '?'") extras = extras or {} - if 'REQUEST_METHOD' in extras and extras['REQUEST_METHOD'] != method: - # NOTE(vytas): Even given the duct tape nature of overriding - # arbitrary environ variables, changing the method can potentially - # be very confusing, particularly when using specialized - # simulate_get/post/patch etc methods. - raise ValueError( - 'environ extras may not override the request method. Please ' - 'use the method parameter.' - ) if query_string is None: query_string = to_query_str( @@ -338,48 +376,157 @@ def simulate_request(app, method='GET', path='/', query_string=None, headers = headers or {} headers['Content-Type'] = MEDIA_JSON - env = helpers.create_environ( - method=method, - scheme=protocol, + if not _is_asgi_app(app): + env = helpers.create_environ( + method=method, + scheme=protocol, + path=path, + query_string=(query_string or ''), + headers=headers, + body=body, + file_wrapper=file_wrapper, + host=host, + remote_addr=remote_addr, + wsgierrors=wsgierrors, + http_version=http_version, + port=port, + root_path=root_path, + ) + + if 'REQUEST_METHOD' in extras and extras['REQUEST_METHOD'] != method: + # NOTE(vytas): Even given the duct tape nature of overriding + # arbitrary environ variables, changing the method can potentially + # be very confusing, particularly when using specialized + # simulate_get/post/patch etc methods. + raise ValueError( + 'WSGI environ extras may not override the request method. ' + 'Please use the method parameter.' + ) + + env.update(extras) + + srmock = StartResponseMock() + validator = wsgiref.validate.validator(app) + + iterable = validator(env, srmock) + + return Result(helpers.closed_wsgi_iterable(iterable), + srmock.status, srmock.headers) + + # --------------------------------------------------------------------- + # NOTE(kgriffs): 'lifespan' scope + # --------------------------------------------------------------------- + + lifespan_scope = { + 'type': 'lifespan', + 'asgi': { + 'version': '3.0', + 'spec_version': '2.0', + }, + } + + shutting_down = asyncio.Condition() + lifespan_event_emitter = helpers.ASGILifespanEventEmitter(shutting_down) + lifespan_event_collector = helpers.ASGIResponseEventCollector() + + # --------------------------------------------------------------------- + # NOTE(kgriffs): 'http' scope + # --------------------------------------------------------------------- + + content_length = None + + if body is not None: + if isinstance(body, str): + body = body.encode() + + content_length = len(body) + + http_scope = helpers.create_scope( path=path, - query_string=(query_string or ''), + query_string=query_string, + method=method, headers=headers, - body=body, - file_wrapper=file_wrapper, host=host, + scheme=protocol, + port=port, + http_version=http_version, remote_addr=remote_addr, - wsgierrors=wsgierrors, + root_path=root_path, + content_length=content_length, ) - if extras: - env.update(extras) - srmock = StartResponseMock() - validator = wsgiref.validate.validator(app) + if 'method' in extras and extras['method'] != method.upper(): + raise ValueError( + 'ASGI scope extras may not override the request method. ' + 'Please use the method parameter.' + ) + + http_scope.update(extras) + + disconnect_at = time.time() + max(0, asgi_disconnect_ttl) + req_event_emitter = helpers.ASGIRequestEventEmitter( + (body or b''), + disconnect_at, + chunk_size=asgi_chunk_size + ) + + resp_event_collector = helpers.ASGIResponseEventCollector() + + async def conductor(): + # NOTE(kgriffs): We assume this is a Falcon ASGI app, which supports + # the lifespan protocol and thus we do not need to catch + # exceptions that would signify no lifespan protocol support. + t = get_loop().create_task( + app(lifespan_scope, lifespan_event_emitter, lifespan_event_collector) + ) + + await _wait_for_startup(lifespan_event_collector.events) + + await app(http_scope, req_event_emitter, resp_event_collector) + + # NOTE(kgriffs): Notify lifespan_event_emitter that it is OK + # to proceed. + async with shutting_down: + shutting_down.notify() - iterable = validator(env, srmock) + await _wait_for_shutdown(lifespan_event_collector.events) + await t - result = Result(helpers.closed_wsgi_iterable(iterable), - srmock.status, srmock.headers) + helpers.invoke_coroutine_sync(conductor) - return result + return Result(resp_event_collector.body_chunks, + code_to_http_status(resp_event_collector.status), + resp_event_collector.headers, + events=resp_event_collector.events) def simulate_get(app, path, **kwargs) -> Result: - """Simulates a GET request to a WSGI application. + """Simulates a GET request to a WSGI or ASGI application. Equivalent to:: simulate_request(app, 'GET', path, **kwargs) Args: - app (callable): The WSGI application to call - path (str): The URL path to request. + app (callable): The application to call + path (str): The URL path to request Note: The path may contain a query string. However, neither `query_string` nor `params` may be specified in this case. Keyword Args: + root_path (str): The initial portion of the request URL's "path" that + corresponds to the application object, so that the application + knows its virtual "location". This defaults to the empty string, + indicating that the application corresponds to the "root" of the + server. + protocol: The protocol to use for the URL scheme + (default: 'http') + port (int): The TCP port to simulate. Defaults to + the standard port used by the given scheme (i.e., 80 for 'http' + and 443 for 'https'). A string may also be passed, as long as + it can be parsed as an int. params (dict): A dictionary of query string parameters, where each key is a parameter name, and each value is either a ``str`` or something that can be converted @@ -394,44 +541,76 @@ def simulate_get(app, path, **kwargs) -> Result: query_string (str): A raw query string to include in the request (default: ``None``). If specified, overrides `params`. - headers (dict): Additional headers to include in the request - (default: ``None``) + headers (dict): Extra headers as a dict-like (Mapping) object, or an + iterable yielding a series of two-member (*name*, *value*) + iterables. Each pair of strings provides the name and value + for an HTTP header. If desired, multiple header values may be + combined into a single (*name*, *value*) pair by joining the values + with a comma when the header in question supports the list + format (see also RFC 7230 and RFC 7231). Header names are not + case-sensitive. file_wrapper (callable): Callable that returns an iterable, to be used as the value for *wsgi.file_wrapper* in the - environ (default: ``None``). This can be used to test + WSGI environ (default: ``None``). This can be used to test high-performance file transmission when `resp.stream` is set to a file-like object. - protocol: The protocol to use for the URL scheme - (default: 'http') - host(str): A string to use for the hostname part of the fully qualified - request URL (default: 'falconframework.org') + host(str): A string to use for the hostname part of the fully + qualified request URL (default: 'falconframework.org') remote_addr (str): A string to use as the remote IP address for the - request (default: '127.0.0.1') - extras (dict): Additional CGI variables to add to the WSGI ``environ`` - dictionary for the request (default: ``None``) + request (default: '127.0.0.1'). For WSGI, this corresponds to + the 'REMOTE_ADDR' environ variable. For ASGI, this corresponds + to the IP address used for the 'client' field in the connection + scope. + http_version (str): The HTTP version to simulate. Must be either + '2', '2.0', 1.1', '1.0', or '1' (default '1.1'). If set to '1.0', + the Host header will not be added to the scope. + wsgierrors (io): The stream to use as *wsgierrors* in the WSGI + environ (default ``sys.stderr``) + asgi_chunk_size (int): The maximum number of bytes that will be + sent to the ASGI app in a single 'http.request' event (default + 4096). + asgi_disconnect_ttl (int): The maximum number of seconds to wait + since the request was initiated, before emitting an + 'http.disconnect' event when the app calls the + receive() function (default 300). + extras (dict): Additional values to add to the WSGI + ``environ`` dictionary or the ASGI scope for the request + (default: ``None``) Returns: :py:class:`~.Result`: The result of the request """ + return simulate_request(app, 'GET', path, **kwargs) def simulate_head(app, path, **kwargs) -> Result: - """Simulates a HEAD request to a WSGI application. + """Simulates a HEAD request to a WSGI or ASGI application. Equivalent to:: simulate_request(app, 'HEAD', path, **kwargs) Args: - app (callable): The WSGI application to call - path (str): The URL path to request. + app (callable): The application to call + path (str): The URL path to request Note: The path may contain a query string. However, neither `query_string` nor `params` may be specified in this case. Keyword Args: + root_path (str): The initial portion of the request URL's "path" that + corresponds to the application object, so that the application + knows its virtual "location". This defaults to the empty string, + indicating that the application corresponds to the "root" of the + server. + protocol: The protocol to use for the URL scheme + (default: 'http') + port (int): The TCP port to simulate. Defaults to + the standard port used by the given scheme (i.e., 80 for 'http' + and 443 for 'https'). A string may also be passed, as long as + it can be parsed as an int. params (dict): A dictionary of query string parameters, where each key is a parameter name, and each value is either a ``str`` or something that can be converted @@ -446,16 +625,36 @@ def simulate_head(app, path, **kwargs) -> Result: query_string (str): A raw query string to include in the request (default: ``None``). If specified, overrides `params`. - headers (dict): Additional headers to include in the request - (default: ``None``) - protocol: The protocol to use for the URL scheme - (default: 'http') - host(str): A string to use for the hostname part of the fully qualified - request URL (default: 'falconframework.org') + headers (dict): Extra headers as a dict-like (Mapping) object, or an + iterable yielding a series of two-member (*name*, *value*) + iterables. Each pair of strings provides the name and value + for an HTTP header. If desired, multiple header values may be + combined into a single (*name*, *value*) pair by joining the values + with a comma when the header in question supports the list + format (see also RFC 7230 and RFC 7231). Header names are not + case-sensitive. + host(str): A string to use for the hostname part of the fully + qualified request URL (default: 'falconframework.org') remote_addr (str): A string to use as the remote IP address for the - request (default: '127.0.0.1') - extras (dict): Additional CGI variables to add to the WSGI ``environ`` - dictionary for the request (default: ``None``) + request (default: '127.0.0.1'). For WSGI, this corresponds to + the 'REMOTE_ADDR' environ variable. For ASGI, this corresponds + to the IP address used for the 'client' field in the connection + scope. + http_version (str): The HTTP version to simulate. Must be either + '2', '2.0', 1.1', '1.0', or '1' (default '1.1'). If set to '1.0', + the Host header will not be added to the scope. + wsgierrors (io): The stream to use as *wsgierrors* in the WSGI + environ (default ``sys.stderr``) + asgi_chunk_size (int): The maximum number of bytes that will be + sent to the ASGI app in a single 'http.request' event (default + 4096). + asgi_disconnect_ttl (int): The maximum number of seconds to wait + since the request was initiated, before emitting an + 'http.disconnect' event when the app calls the + receive() function (default 300). + extras (dict): Additional values to add to the WSGI + ``environ`` dictionary or the ASGI scope for the request + (default: ``None``) Returns: :py:class:`~.Result`: The result of the request @@ -464,17 +663,28 @@ def simulate_head(app, path, **kwargs) -> Result: def simulate_post(app, path, **kwargs) -> Result: - """Simulates a POST request to a WSGI application. + """Simulates a POST request to a WSGI or ASGI application. Equivalent to:: simulate_request(app, 'POST', path, **kwargs) Args: - app (callable): The WSGI application to call + app (callable): The application to call path (str): The URL path to request Keyword Args: + root_path (str): The initial portion of the request URL's "path" that + corresponds to the application object, so that the application + knows its virtual "location". This defaults to the empty string, + indicating that the application corresponds to the "root" of the + server. + protocol: The protocol to use for the URL scheme + (default: 'http') + port (int): The TCP port to simulate. Defaults to + the standard port used by the given scheme (i.e., 80 for 'http' + and 443 for 'https'). A string may also be passed, as long as + it can be parsed as an int. params (dict): A dictionary of query string parameters, where each key is a parameter name, and each value is either a ``str`` or something that can be converted @@ -486,22 +696,51 @@ def simulate_post(app, path, **kwargs) -> Result: of the parameter (e.g., 'thing=1&thing=2&thing=3'). Otherwise, parameters will be encoded as comma-separated values (e.g., 'thing=1,2,3'). Defaults to ``True``. - headers (dict): Additional headers to include in the request - (default: ``None``) - body (str): A string to send as the body of the request. The value - will be encoded as UTF-8. + query_string (str): A raw query string to include in the + request (default: ``None``). If specified, overrides + `params`. + headers (dict): Extra headers as a dict-like (Mapping) object, or an + iterable yielding a series of two-member (*name*, *value*) + iterables. Each pair of strings provides the name and value + for an HTTP header. If desired, multiple header values may be + combined into a single (*name*, *value*) pair by joining the values + with a comma when the header in question supports the list + format (see also RFC 7230 and RFC 7231). Header names are not + case-sensitive. + body (str): The body of the request (default ''). The value will be + encoded as UTF-8 in the WSGI environ. Alternatively, a byte string + may be passed, in which case it will be used as-is. json(JSON serializable): A JSON document to serialize as the body of the request (default: ``None``). If specified, overrides `body` and the Content-Type header in `headers`. - protocol: The protocol to use for the URL scheme - (default: 'http') - host(str): A string to use for the hostname part of the fully qualified - request URL (default: 'falconframework.org') + file_wrapper (callable): Callable that returns an iterable, + to be used as the value for *wsgi.file_wrapper* in the + WSGI environ (default: ``None``). This can be used to test + high-performance file transmission when `resp.stream` is + set to a file-like object. + host(str): A string to use for the hostname part of the fully + qualified request URL (default: 'falconframework.org') remote_addr (str): A string to use as the remote IP address for the - request (default: '127.0.0.1') - extras (dict): Additional CGI variables to add to the WSGI ``environ`` - dictionary for the request (default: ``None``) + request (default: '127.0.0.1'). For WSGI, this corresponds to + the 'REMOTE_ADDR' environ variable. For ASGI, this corresponds + to the IP address used for the 'client' field in the connection + scope. + http_version (str): The HTTP version to simulate. Must be either + '2', '2.0', 1.1', '1.0', or '1' (default '1.1'). If set to '1.0', + the Host header will not be added to the scope. + wsgierrors (io): The stream to use as *wsgierrors* in the WSGI + environ (default ``sys.stderr``) + asgi_chunk_size (int): The maximum number of bytes that will be + sent to the ASGI app in a single 'http.request' event (default + 4096). + asgi_disconnect_ttl (int): The maximum number of seconds to wait + since the request was initiated, before emitting an + 'http.disconnect' event when the app calls the + receive() function (default 300). + extras (dict): Additional values to add to the WSGI + ``environ`` dictionary or the ASGI scope for the request + (default: ``None``) Returns: :py:class:`~.Result`: The result of the request @@ -510,17 +749,28 @@ def simulate_post(app, path, **kwargs) -> Result: def simulate_put(app, path, **kwargs) -> Result: - """Simulates a PUT request to a WSGI application. + """Simulates a PUT request to a WSGI or ASGI application. Equivalent to:: simulate_request(app, 'PUT', path, **kwargs) Args: - app (callable): The WSGI application to call + app (callable): The application to call path (str): The URL path to request Keyword Args: + root_path (str): The initial portion of the request URL's "path" that + corresponds to the application object, so that the application + knows its virtual "location". This defaults to the empty string, + indicating that the application corresponds to the "root" of the + server. + protocol: The protocol to use for the URL scheme + (default: 'http') + port (int): The TCP port to simulate. Defaults to + the standard port used by the given scheme (i.e., 80 for 'http' + and 443 for 'https'). A string may also be passed, as long as + it can be parsed as an int. params (dict): A dictionary of query string parameters, where each key is a parameter name, and each value is either a ``str`` or something that can be converted @@ -532,22 +782,51 @@ def simulate_put(app, path, **kwargs) -> Result: of the parameter (e.g., 'thing=1&thing=2&thing=3'). Otherwise, parameters will be encoded as comma-separated values (e.g., 'thing=1,2,3'). Defaults to ``True``. - headers (dict): Additional headers to include in the request - (default: ``None``) - body (str): A string to send as the body of the request. The value - will be encoded as UTF-8. + query_string (str): A raw query string to include in the + request (default: ``None``). If specified, overrides + `params`. + headers (dict): Extra headers as a dict-like (Mapping) object, or an + iterable yielding a series of two-member (*name*, *value*) + iterables. Each pair of strings provides the name and value + for an HTTP header. If desired, multiple header values may be + combined into a single (*name*, *value*) pair by joining the values + with a comma when the header in question supports the list + format (see also RFC 7230 and RFC 7231). Header names are not + case-sensitive. + body (str): The body of the request (default ''). The value will be + encoded as UTF-8 in the WSGI environ. Alternatively, a byte string + may be passed, in which case it will be used as-is. json(JSON serializable): A JSON document to serialize as the body of the request (default: ``None``). If specified, overrides `body` and the Content-Type header in `headers`. - protocol: The protocol to use for the URL scheme - (default: 'http') - host(str): A string to use for the hostname part of the fully qualified - request URL (default: 'falconframework.org') + file_wrapper (callable): Callable that returns an iterable, + to be used as the value for *wsgi.file_wrapper* in the + WSGI environ (default: ``None``). This can be used to test + high-performance file transmission when `resp.stream` is + set to a file-like object. + host(str): A string to use for the hostname part of the fully + qualified request URL (default: 'falconframework.org') remote_addr (str): A string to use as the remote IP address for the - request (default: '127.0.0.1') - extras (dict): Additional CGI variables to add to the WSGI ``environ`` - dictionary for the request (default: ``None``) + request (default: '127.0.0.1'). For WSGI, this corresponds to + the 'REMOTE_ADDR' environ variable. For ASGI, this corresponds + to the IP address used for the 'client' field in the connection + scope. + http_version (str): The HTTP version to simulate. Must be either + '2', '2.0', 1.1', '1.0', or '1' (default '1.1'). If set to '1.0', + the Host header will not be added to the scope. + wsgierrors (io): The stream to use as *wsgierrors* in the WSGI + environ (default ``sys.stderr``) + asgi_chunk_size (int): The maximum number of bytes that will be + sent to the ASGI app in a single 'http.request' event (default + 4096). + asgi_disconnect_ttl (int): The maximum number of seconds to wait + since the request was initiated, before emitting an + 'http.disconnect' event when the app calls the + receive() function (default 300). + extras (dict): Additional values to add to the WSGI + ``environ`` dictionary or the ASGI scope for the request + (default: ``None``) Returns: :py:class:`~.Result`: The result of the request @@ -556,17 +835,28 @@ def simulate_put(app, path, **kwargs) -> Result: def simulate_options(app, path, **kwargs) -> Result: - """Simulates an OPTIONS request to a WSGI application. + """Simulates an OPTIONS request to a WSGI or ASGI application. Equivalent to:: simulate_request(app, 'OPTIONS', path, **kwargs) Args: - app (callable): The WSGI application to call + app (callable): The application to call path (str): The URL path to request Keyword Args: + root_path (str): The initial portion of the request URL's "path" that + corresponds to the application object, so that the application + knows its virtual "location". This defaults to the empty string, + indicating that the application corresponds to the "root" of the + server. + protocol: The protocol to use for the URL scheme + (default: 'http') + port (int): The TCP port to simulate. Defaults to + the standard port used by the given scheme (i.e., 80 for 'http' + and 443 for 'https'). A string may also be passed, as long as + it can be parsed as an int. params (dict): A dictionary of query string parameters, where each key is a parameter name, and each value is either a ``str`` or something that can be converted @@ -578,16 +868,39 @@ def simulate_options(app, path, **kwargs) -> Result: of the parameter (e.g., 'thing=1&thing=2&thing=3'). Otherwise, parameters will be encoded as comma-separated values (e.g., 'thing=1,2,3'). Defaults to ``True``. - headers (dict): Additional headers to include in the request - (default: ``None``) - protocol: The protocol to use for the URL scheme - (default: 'http') - host(str): A string to use for the hostname part of the fully qualified - request URL (default: 'falconframework.org') + query_string (str): A raw query string to include in the + request (default: ``None``). If specified, overrides + `params`. + headers (dict): Extra headers as a dict-like (Mapping) object, or an + iterable yielding a series of two-member (*name*, *value*) + iterables. Each pair of strings provides the name and value + for an HTTP header. If desired, multiple header values may be + combined into a single (*name*, *value*) pair by joining the values + with a comma when the header in question supports the list + format (see also RFC 7230 and RFC 7231). Header names are not + case-sensitive. + host(str): A string to use for the hostname part of the fully + qualified request URL (default: 'falconframework.org') remote_addr (str): A string to use as the remote IP address for the - request (default: '127.0.0.1') - extras (dict): Additional CGI variables to add to the WSGI ``environ`` - dictionary for the request (default: ``None``) + request (default: '127.0.0.1'). For WSGI, this corresponds to + the 'REMOTE_ADDR' environ variable. For ASGI, this corresponds + to the IP address used for the 'client' field in the connection + scope. + http_version (str): The HTTP version to simulate. Must be either + '2', '2.0', 1.1', '1.0', or '1' (default '1.1'). If set to '1.0', + the Host header will not be added to the scope. + wsgierrors (io): The stream to use as *wsgierrors* in the WSGI + environ (default ``sys.stderr``) + asgi_chunk_size (int): The maximum number of bytes that will be + sent to the ASGI app in a single 'http.request' event (default + 4096). + asgi_disconnect_ttl (int): The maximum number of seconds to wait + since the request was initiated, before emitting an + 'http.disconnect' event when the app calls the + receive() function (default 300). + extras (dict): Additional values to add to the WSGI + ``environ`` dictionary or the ASGI scope for the request + (default: ``None``) Returns: :py:class:`~.Result`: The result of the request @@ -596,17 +909,28 @@ def simulate_options(app, path, **kwargs) -> Result: def simulate_patch(app, path, **kwargs) -> Result: - """Simulates a PATCH request to a WSGI application. + """Simulates a PATCH request to a WSGI or ASGI application. Equivalent to:: simulate_request(app, 'PATCH', path, **kwargs) Args: - app (callable): The WSGI application to call + app (callable): The application to call path (str): The URL path to request Keyword Args: + root_path (str): The initial portion of the request URL's "path" that + corresponds to the application object, so that the application + knows its virtual "location". This defaults to the empty string, + indicating that the application corresponds to the "root" of the + server. + protocol: The protocol to use for the URL scheme + (default: 'http') + port (int): The TCP port to simulate. Defaults to + the standard port used by the given scheme (i.e., 80 for 'http' + and 443 for 'https'). A string may also be passed, as long as + it can be parsed as an int. params (dict): A dictionary of query string parameters, where each key is a parameter name, and each value is either a ``str`` or something that can be converted @@ -618,22 +942,46 @@ def simulate_patch(app, path, **kwargs) -> Result: of the parameter (e.g., 'thing=1&thing=2&thing=3'). Otherwise, parameters will be encoded as comma-separated values (e.g., 'thing=1,2,3'). Defaults to ``True``. - headers (dict): Additional headers to include in the request - (default: ``None``) - body (str): A string to send as the body of the request. The value - will be encoded as UTF-8. + query_string (str): A raw query string to include in the + request (default: ``None``). If specified, overrides + `params`. + headers (dict): Extra headers as a dict-like (Mapping) object, or an + iterable yielding a series of two-member (*name*, *value*) + iterables. Each pair of strings provides the name and value + for an HTTP header. If desired, multiple header values may be + combined into a single (*name*, *value*) pair by joining the values + with a comma when the header in question supports the list + format (see also RFC 7230 and RFC 7231). Header names are not + case-sensitive. + body (str): The body of the request (default ''). The value will be + encoded as UTF-8 in the WSGI environ. Alternatively, a byte string + may be passed, in which case it will be used as-is. json(JSON serializable): A JSON document to serialize as the body of the request (default: ``None``). If specified, overrides `body` and the Content-Type header in `headers`. - protocol: The protocol to use for the URL scheme - (default: 'http') - host(str): A string to use for the hostname part of the fully qualified - request URL (default: 'falconframework.org') + host(str): A string to use for the hostname part of the fully + qualified request URL (default: 'falconframework.org') remote_addr (str): A string to use as the remote IP address for the - request (default: '127.0.0.1') - extras (dict): Additional CGI variables to add to the WSGI ``environ`` - dictionary for the request (default: ``None``) + request (default: '127.0.0.1'). For WSGI, this corresponds to + the 'REMOTE_ADDR' environ variable. For ASGI, this corresponds + to the IP address used for the 'client' field in the connection + scope. + http_version (str): The HTTP version to simulate. Must be either + '2', '2.0', 1.1', '1.0', or '1' (default '1.1'). If set to '1.0', + the Host header will not be added to the scope. + wsgierrors (io): The stream to use as *wsgierrors* in the WSGI + environ (default ``sys.stderr``) + asgi_chunk_size (int): The maximum number of bytes that will be + sent to the ASGI app in a single 'http.request' event (default + 4096). + asgi_disconnect_ttl (int): The maximum number of seconds to wait + since the request was initiated, before emitting an + 'http.disconnect' event when the app calls the + receive() function (default 300). + extras (dict): Additional values to add to the WSGI + ``environ`` dictionary or the ASGI scope for the request + (default: ``None``) Returns: :py:class:`~.Result`: The result of the request @@ -642,17 +990,28 @@ def simulate_patch(app, path, **kwargs) -> Result: def simulate_delete(app, path, **kwargs) -> Result: - """Simulates a DELETE request to a WSGI application. + """Simulates a DELETE request to a WSGI or ASGI application. Equivalent to:: simulate_request(app, 'DELETE', path, **kwargs) Args: - app (callable): The WSGI application to call + app (callable): The application to call path (str): The URL path to request Keyword Args: + root_path (str): The initial portion of the request URL's "path" that + corresponds to the application object, so that the application + knows its virtual "location". This defaults to the empty string, + indicating that the application corresponds to the "root" of the + server. + protocol: The protocol to use for the URL scheme + (default: 'http') + port (int): The TCP port to simulate. Defaults to + the standard port used by the given scheme (i.e., 80 for 'http' + and 443 for 'https'). A string may also be passed, as long as + it can be parsed as an int. params (dict): A dictionary of query string parameters, where each key is a parameter name, and each value is either a ``str`` or something that can be converted @@ -664,16 +1023,46 @@ def simulate_delete(app, path, **kwargs) -> Result: of the parameter (e.g., 'thing=1&thing=2&thing=3'). Otherwise, parameters will be encoded as comma-separated values (e.g., 'thing=1,2,3'). Defaults to ``True``. - headers (dict): Additional headers to include in the request - (default: ``None``) - protocol: The protocol to use for the URL scheme - (default: 'http') - host(str): A string to use for the hostname part of the fully qualified - request URL (default: 'falconframework.org') + query_string (str): A raw query string to include in the + request (default: ``None``). If specified, overrides + `params`. + headers (dict): Extra headers as a dict-like (Mapping) object, or an + iterable yielding a series of two-member (*name*, *value*) + iterables. Each pair of strings provides the name and value + for an HTTP header. If desired, multiple header values may be + combined into a single (*name*, *value*) pair by joining the values + with a comma when the header in question supports the list + format (see also RFC 7230 and RFC 7231). Header names are not + case-sensitive. + body (str): The body of the request (default ''). The value will be + encoded as UTF-8 in the WSGI environ. Alternatively, a byte string + may be passed, in which case it will be used as-is. + json(JSON serializable): A JSON document to serialize as the + body of the request (default: ``None``). If specified, + overrides `body` and the Content-Type header in + `headers`. + host(str): A string to use for the hostname part of the fully + qualified request URL (default: 'falconframework.org') remote_addr (str): A string to use as the remote IP address for the - request (default: '127.0.0.1') - extras (dict): Additional CGI variables to add to the WSGI ``environ`` - dictionary for the request (default: ``None``) + request (default: '127.0.0.1'). For WSGI, this corresponds to + the 'REMOTE_ADDR' environ variable. For ASGI, this corresponds + to the IP address used for the 'client' field in the connection + scope. + http_version (str): The HTTP version to simulate. Must be either + '2', '2.0', 1.1', '1.0', or '1' (default '1.1'). If set to '1.0', + the Host header will not be added to the scope. + wsgierrors (io): The stream to use as *wsgierrors* in the WSGI + environ (default ``sys.stderr``) + asgi_chunk_size (int): The maximum number of bytes that will be + sent to the ASGI app in a single 'http.request' event (default + 4096). + asgi_disconnect_ttl (int): The maximum number of seconds to wait + since the request was initiated, before emitting an + 'http.disconnect' event when the app calls the + receive() function (default 300). + extras (dict): Additional values to add to the WSGI + ``environ`` dictionary or the ASGI scope for the request + (default: ``None``) Returns: :py:class:`~.Result`: The result of the request @@ -682,7 +1071,7 @@ def simulate_delete(app, path, **kwargs) -> Result: class TestClient: - """Simulates requests to a WSGI application. + """Simulates requests to a WSGI or ASGI application. This class provides a contextual wrapper for Falcon's `simulate_*` test functions. It lets you replace this:: @@ -701,14 +1090,17 @@ class TestClient: overriding of request preparation by child classes. Args: - app (callable): A WSGI application to target when simulating + app (callable): A WSGI or ASGI application to target when simulating requests - Keyword Arguments: headers (dict): Default headers to set on every request (default ``None``). These defaults may be overridden by passing values for the same headers to one of the `simulate_*()` methods. + + Attributes: + app: The app that this client instance was configured to use. + """ def __init__(self, app, headers=None): @@ -784,3 +1176,53 @@ def simulate_request(self, *args, **kwargs) -> Result: kwargs['headers'] = merged_headers return simulate_request(self.app, *args, **kwargs) + + +# ----------------------------------------------------------------------------- +# Private +# ----------------------------------------------------------------------------- + + +def _is_asgi_app(app): + app_args = inspect.getfullargspec(app).args + num_app_args = len(app_args) + + # NOTE(kgriffs): Technically someone could name the "self" or "cls" + # arg something else, but we will make the simplifying + # assumption that this is rare enough to not worry about. + if app_args[0] in {'cls', 'self'}: + num_app_args -= 1 + + is_asgi = (num_app_args == 3) + + return is_asgi + + +async def _wait_for_startup(events): + # NOTE(kgriffs): This is covered, but our gate for some reason doesn't + # understand `while True`. + while True: # pragma: nocover + for e in events: + if e['type'] == 'lifespan.startup.failed': + raise RuntimeError('ASGI app returned lifespan.startup.failed. ' + e['message']) + + if any(e['type'] == 'lifespan.startup.complete' for e in events): + break + + # NOTE(kgriffs): Yield to the concurrent lifespan task + await asyncio.sleep(0.001) + + +async def _wait_for_shutdown(events): + # NOTE(kgriffs): This is covered, but our gate for some reason doesn't + # understand `while True`. + while True: # pragma: nocover + for e in events: + if e['type'] == 'lifespan.shutdown.failed': + raise RuntimeError('ASGI app returned lifespan.shutdown.failed. ' + e['message']) + + if any(e['type'] == 'lifespan.shutdown.complete' for e in events): + break + + # NOTE(kgriffs): Yield to the concurrent lifespan task + await asyncio.sleep(0.001) diff --git a/falcon/testing/helpers.py b/falcon/testing/helpers.py index f8f542669..212270cc3 100644 --- a/falcon/testing/helpers.py +++ b/falcon/testing/helpers.py @@ -23,14 +23,19 @@ """ +import asyncio import cgi import contextlib +import functools import io import itertools import random import sys +import time from typing import Any, Dict +from falcon.constants import SINGLETON_HEADERS +import falcon.request from falcon.util import http_now, uri # Constants @@ -40,6 +45,219 @@ httpnow = http_now +class ASGILifespanEventEmitter: + def __init__(self, shutting_down): + self._state = 0 + self._shutting_down = shutting_down + + async def emit(self): + if self._state == 0: + self._state += 1 + return {'type': 'lifespan.startup'} + + if self._state == 1: + self._state += 1 + # NOTE(kgriffs): This ensures the app ignores events it does + # not recognize. + return {'type': 'lifespan._nonstandard_event'} + + async with self._shutting_down: + await self._shutting_down.wait() + + return {'type': 'lifespan.shutdown'} + + __call__ = emit + + +class ASGIRequestEventEmitter: + """Emits events on-demand to an ASGI app. + + Keyword Args: + body (str): The body content to use when emitting http.request + events. May be an empty string. If a byte string, it will + be used as-is; otherwise it will be encoded as UTF-8 + (default b''). + disconnect_at (int): The Unix timestamp after which to begin + returning http.disconnect events (default now + 30s). + chunk_size (int): The maximum number of bytes to include in + a single http.request event (default 4096). + """ + + def __init__(self, body=None, disconnect_at=None, chunk_size=4096): + if body is None: + body = b'' + elif not isinstance(body, bytes): + body = body.encode() + + if disconnect_at is None: + disconnect_at = time.time() + 30 + + self._body = body + self._chunk_size = chunk_size + self._disconnect_at = disconnect_at + + self._emitted_empty_chunk_a = False + self._emitted_empty_chunk_b = False + + async def emit(self): + if self._body is None: + # NOTE(kgriffs): When there are no more events, an ASGI + # server will hang until the client connection + # disconnects. + while time.time() < self._disconnect_at: + await asyncio.sleep(1) + + if self._disconnect_at <= time.time(): + return {'type': 'http.disconnect'} + + event = {'type': 'http.request'} + + # NOTE(kgriffs): Return a couple variations on empty chunks + # every time, to ensure test coverage. + if not self._emitted_empty_chunk_a: + self._emitted_empty_chunk_a = True + event['more_body'] = True + return event + + if not self._emitted_empty_chunk_b: + self._emitted_empty_chunk_b = True + event['more_body'] = True + event['body'] = b'' + return event + + # NOTE(kgriffs): Part of the time just return an + # empty chunk to make sure the app handles that + # correctly. + if flip_coin(): + event['more_body'] = True + + # NOTE(kgriffs): Since ASGI specifies that + # 'body' is optional, we randomaly choose whether + # or not to explicitly set it to b'' to ensure + # the app handles both correctly. + if flip_coin(): + event['body'] = b'' + + return event + + chunk = self._body[:self._chunk_size] + self._body = self._body[self._chunk_size:] or None + + if chunk: + event['body'] = chunk + elif flip_coin(): + # NOTE(kgriffs): Since ASGI specifies that + # 'body' is optional, we randomaly choose whether + # or not to explicitly set it to b'' to ensure + # the app handles both correctly. + event['body'] = b'' + + if self._body: + event['more_body'] = True + elif flip_coin(): + # NOTE(kgriffs): The ASGI spec allows leaving off + # the 'more_body' key when it would be set to + # False, so randomly choose one of the approaches + # to make sure the app handles both cases. + event['more_body'] = False + + return event + + __call__ = emit + + +class ASGIResponseEventCollector: + """Collects and validates ASGI events returned by an app.""" + + _LIFESPAN_EVENT_TYPES = frozenset([ + 'lifespan.startup.complete', + 'lifespan.startup.failed', + 'lifespan.shutdown.complete', + 'lifespan.shutdown.failed', + ]) + + def __init__(self): + self.events = [] + self.headers = [] + self.status = None + self.body_chunks = [] + self.more_body = None + + async def collect(self, event): + if self.more_body is False: + # NOTE(kgriffs): According to the ASGI spec, once we get a + # message setting more_body to False, any further messages + # on the channel are ignored. + return + + self.events.append(event) + + event_type = event['type'] + if not isinstance(event_type, str): + raise TypeError('ASGI event type must be a Unicode string') + + if event_type == 'http.response.start': + for name, value in event.get('headers', []): + if not isinstance(name, bytes): + raise TypeError('ASGI header names must be byte strings') + if not isinstance(value, bytes): + raise TypeError('ASGI header names must be byte strings') + + name_decoded = name.decode() + if not name_decoded.islower(): + raise ValueError('ASGI header names must be lowercase') + + self.headers.append((name_decoded, value.decode())) + + self.status = event['status'] + + if not isinstance(self.status, int): + raise TypeError('ASGI status must be an int') + + elif event_type == 'http.response.body': + chunk = event.get('body', b'') + if not isinstance(chunk, bytes): + raise TypeError('ASGI body content must be a byte string') + + self.body_chunks.append(chunk) + + self.more_body = event.get('more_body', False) + if not isinstance(self.more_body, bool): + raise TypeError('ASGI more_body flag must be a bool') + + elif event_type not in self._LIFESPAN_EVENT_TYPES: + raise ValueError('Invalid ASGI event type: ' + event_type) + + __call__ = collect + + +def invoke_coroutine_sync(coroutine, *args, **kwargs): + """Invokes a coroutine function from a synchronous caller and runs until complete. + + Warning: + This method is very inefficient and should only be used + for testing purposes. It will create an event loop for the current + thread if one is not already running. + + Additional arguments not mentioned below are bound to the given + coroutine function via ``functools.partial()``. + + Args: + coroutine: A coroutine function to invoke. + *args: Additional args are passed through to the coroutine function. + + Keyword Args: + **kwargs: Additional args are passed through to the coroutine function. + """ + + loop = asyncio.get_event_loop() + return loop.run_until_complete( + functools.partial( + coroutine, *args, **kwargs + )() + ) + + # get_encoding_from_headers() is Copyright 2016 Kenneth Reitz, and is # used here under the terms of the Apache License, Version 2.0. def get_encoding_from_headers(headers): @@ -61,12 +279,20 @@ def get_encoding_from_headers(headers): if 'charset' in params: return params['charset'].strip("'\"") + # NOTE(kgriffs): Added checks for text/event-stream and application/json + if content_type in ('text/event-stream', 'application/json'): + return 'UTF-8' + if 'text' in content_type: return 'ISO-8859-1' return None +def flip_coin() -> int: + return random.randint(0, 1) == 0 + + def rand_string(min, max) -> str: """Returns a randomly-generated string, of a random length. @@ -82,42 +308,167 @@ def rand_string(min, max) -> str: for __ in range(string_length)]) -def create_environ(path='/', query_string='', protocol='HTTP/1.1', +def create_scope(path='/', query_string='', method='GET', headers=None, + host=DEFAULT_HOST, scheme=None, port=None, http_version='1.1', + remote_addr=None, root_path=None, content_length=None, + include_server=True) -> Dict[str, Any]: + + """Create a mock ASGI scope ``dict`` for simulating ASGI requests. + + Keyword Args: + path (str): The path for the request (default '/') + query_string (str): The query string to simulate, without a + leading '?' (default ''). The query string is passed as-is + (it will not be percent-encoded). + method (str): The HTTP method to use (default 'GET') + headers (dict): Headers as a dict-like (Mapping) object, or an + iterable yielding a series of two-member (*name*, *value*) + iterables. Each pair of strings provides the name and value + for an HTTP header. If desired, multiple header values may be + combined into a single (*name*, *value*) pair by joining the values + with a comma when the header in question supports the list + format (see also RFC 7230 and RFC 7231). When the + request will include a body, the Content-Length header should be + included in this list. Header names are not case-sensitive. + host(str): Hostname for the request (default 'falconframework.org'). + This also determines the the value of the Host header in the + request. + scheme (str): URL scheme, either 'http' or 'https' (default 'http') + port (int): The TCP port to simulate. Defaults to + the standard port used by the given scheme (i.e., 80 for 'http' + and 443 for 'https'). A string may also be passed, as long as + it can be parsed as an int. + http_version (str): The HTTP version to simulate. Must be either + '2', '2.0', 1.1', '1.0', or '1' (default '1.1'). If set to '1.0', + the Host header will not be added to the scope. + remote_addr (str): Remote address for the request to use for + the 'client' field in the connection scope (default None) + root_path (str): The root path this application is mounted at; same as + SCRIPT_NAME in WSGI (default ''). + content_length (int): The expected content length of the request + body (default ``None``). If specified, this value will be + used to set the Content-Length header in the request. + include_server (bool): Set to ``False`` to not set the 'server' key + in the scope ``dict`` (default ``True``). + """ + + http_version = _fixup_http_version(http_version) + + path = uri.decode(path, unquote_plus=False) + + # NOTE(kgriffs): Handles both None and '' + query_string = query_string.encode() if query_string else b'' + + if query_string and query_string.startswith(b'?'): + raise ValueError("query_string should not start with '?'") + + scope = { + 'type': 'http', + 'asgi': { + 'version': '3.0', + 'spec_version': '2.1', + }, + 'http_version': http_version, + 'method': method.upper(), + 'path': path, + 'query_string': query_string, + } + + # NOTE(kgriffs): Explicitly test against None so that the caller + # is able to simulate setting app to an empty string if they + # need to cover that branch in their code. + if root_path is not None: + # NOTE(kgriffs): Judging by the algorithm given in PEP-3333 for + # reconstructing the URL, SCRIPT_NAME is expected to contain a + # preceding slash character. Since ASGI states that this value is + # the same as WSGI's SCRIPT_NAME, we will follow suit here. + if root_path and not root_path.startswith('/'): + scope['root_path'] = '/' + root_path + else: + scope['root_path'] = root_path + + if scheme: + if scheme not in ('http', 'https'): + raise ValueError("scheme must be either 'http' or 'https'") + + scope['scheme'] = scheme + + if port is None: + if (scheme or 'http') == 'http': + port = 80 + else: + port = 443 + else: + port = int(port) + + if remote_addr: + # NOTE(kgriffs): Choose from the standard IANA dynamic range + remote_port = random.randint(49152, 65535) + + # NOTE(kgriffs): Expose as an iterable to ensure the framework/app + # isn't hard-coded to only work with a list or tuple. + scope['client'] = iter([remote_addr, remote_port]) + + if include_server: + scope['server'] = iter([host, port]) + + _add_headers_to_scope(scope, headers, content_length, host, port, scheme, http_version) + + return scope + + +def create_environ(path='/', query_string='', http_version='1.1', scheme='http', host=DEFAULT_HOST, port=None, - headers=None, app='', body='', method='GET', - wsgierrors=None, file_wrapper=None, remote_addr=None) -> Dict[str, Any]: + headers=None, app=None, body='', method='GET', + wsgierrors=None, file_wrapper=None, remote_addr=None, + root_path=None) -> Dict[str, Any]: + """Creates a mock PEP-3333 environ ``dict`` for simulating WSGI requests. Keyword Args: path (str): The path for the request (default '/') query_string (str): The query string to simulate, without a - leading '?' (default '') - protocol (str): The HTTP protocol to simulate - (default 'HTTP/1.1'). If set to 'HTTP/1.0', the Host header - will not be added to the environment. + leading '?' (default ''). The query string is passed as-is + (it will not be percent-encoded). + http_version (str): The HTTP version to simulate. Must be either + '2', '2.0', 1.1', '1.0', or '1' (default '1.1'). If set to '1.0', + the Host header will not be added to the scope. scheme (str): URL scheme, either 'http' or 'https' (default 'http') host(str): Hostname for the request (default 'falconframework.org') - port (str): The TCP port to simulate. Defaults to + port (int): The TCP port to simulate. Defaults to the standard port used by the given scheme (i.e., 80 for 'http' - and 443 for 'https'). - headers (dict): Headers as a ``dict`` or an iterable yielding - (*key*, *value*) ``tuple``'s - app (str): Value for the ``SCRIPT_NAME`` environ variable, described in + and 443 for 'https'). A string may also be passed, as long as + it can be parsed as an int. + headers (dict): Headers as a dict-like (Mapping) object, or an + iterable yielding a series of two-member (*name*, *value*) + iterables. Each pair of strings provides the name and value + for an HTTP header. If desired, multiple header values may be + combined into a single (*name*, *value*) pair by joining the values + with a comma when the header in question supports the list + format (see also RFC 7230 and RFC 7231). Header names are not + case-sensitive. + root_path (str): Value for the ``SCRIPT_NAME`` environ variable, described in PEP-333: 'The initial portion of the request URL's "path" that corresponds to the application object, so that the application knows its virtual "location". This may be an empty string, if the application corresponds to the "root" of the server.' (default '') + app (str): Deprecated alias for `root_path`. If both kwargs are passed, + `root_path` takes precedence. body (str): The body of the request (default ''). The value will be - encoded as UTF-8 in the WSGI environ. + encoded as UTF-8 in the WSGI environ. Alternatively, a byte string + may be passed, in which case it will be used as-is. method (str): The HTTP method to use (default 'GET') wsgierrors (io): The stream to use as *wsgierrors* (default ``sys.stderr``) file_wrapper: Callable that returns an iterable, to be used as the value for *wsgi.file_wrapper* in the environ. - remote_addr (str): Remote address for the request (default '127.0.0.1') + remote_addr (str): Remote address for the request to use as the + 'REMOTE_ADDR' environ variable (default None) """ + http_version = _fixup_http_version(http_version) + if query_string and query_string.startswith('?'): raise ValueError("query_string should not start with '?'") @@ -148,25 +499,27 @@ def create_environ(path='/', query_string='', protocol='HTTP/1.1', if port is None: port = '80' if scheme == 'http' else '443' else: - port = str(port) + # NOTE(kgriffs): Running it through int() first ensures that if + # a string was passed, it is a valid integer. + port = str(int(port)) + + root_path = root_path or app or '' # NOTE(kgriffs): Judging by the algorithm given in PEP-3333 for # reconstructing the URL, SCRIPT_NAME is expected to contain a # preceding slash character. - if app and not app.startswith('/'): - app = '/' + app + if root_path and not root_path.startswith('/'): + root_path = '/' + root_path env = { - 'SERVER_PROTOCOL': protocol, + 'SERVER_PROTOCOL': 'HTTP/' + http_version, 'SERVER_SOFTWARE': 'gunicorn/0.17.0', - 'SCRIPT_NAME': app, + 'SCRIPT_NAME': (root_path or ''), 'REQUEST_METHOD': method, 'PATH_INFO': path, 'QUERY_STRING': query_string, - 'HTTP_USER_AGENT': 'curl/7.24.0 (x86_64-apple-darwin12.0)', 'REMOTE_PORT': '65133', 'RAW_URI': '/', - 'REMOTE_ADDR': remote_addr or '127.0.0.1', 'SERVER_NAME': host, 'SERVER_PORT': port, @@ -179,10 +532,16 @@ def create_environ(path='/', query_string='', protocol='HTTP/1.1', 'wsgi.run_once': False } + # NOTE(kgriffs): It has been observed that WSGI servers do not always + # set the REMOTE_ADDR variable, so we don't always set it either, to + # ensure the framework/app handles that case correctly. + if remote_addr: + env['REMOTE_ADDR'] = remote_addr + if file_wrapper is not None: env['wsgi.file_wrapper'] = file_wrapper - if protocol != 'HTTP/1.0': + if http_version != '1.0': host_header = host if scheme == 'https': @@ -206,6 +565,62 @@ def create_environ(path='/', query_string='', protocol='HTTP/1.1', return env +def create_req(options=None, **kwargs) -> falcon.Request: + """Create and return a new Request instance. + + This function can be used to conveniently create a WSGI environ + and use it to instantiate a :py:class:`~.Request` object in one go. + + The arguments for this function are identical to those + of :py:meth:`falcon.testing.create_environ`, except an additional + `options` keyword argument may be set to an instance of + :py:class:`falcon.RequestOptions` to configure certain + aspects of request parsing in lieu of the defaults. + """ + + env = create_environ(**kwargs) + return falcon.request.Request(env, options=options) + + +def create_asgi_req(body=None, req_type=None, options=None, **kwargs) -> falcon.Request: + """Create and return a new ASGI Request instance. + + This function can be used to conveniently create an ASGI scope + and use it to instantiate a :py:class:`falcon.asgi.Request` object + in one go. + + The arguments for this function are identical to those + of :py:meth:`falcon.testing.create_environ`, with the addition of + `body`, `req_type`, and `options` arguments as documented below. + + Keyword Arguments: + body (bytes): The body data to use for the request (default b''). If + the value is a :py:class:`str`, it will be UTF-8 encoded to + a byte string. + req_type (object): A subclass of :py:class:`falcon.asgi.Request` + to instantiate. If not specified, the standard + :py:class:`falcon.asgi.Request` class will simply be used. + options (falcon.RequestOptions): An instance of + :py:class:`falcon.RequestOptions` that should be used to determine + certain aspects of request parsing in lieu of the defaults. + """ + + scope = create_scope(**kwargs) + + body = body or b'' + disconnect_at = time.time() + 300 + + req_event_emitter = ASGIRequestEventEmitter(body, disconnect_at) + + # NOTE(kgriffs): Import here in case the app is running under + # Python 3.5 (in which case as long as it does not call the + # present function, it won't trigger an import error). + import falcon.asgi + + req_type = req_type or falcon.asgi.Request + return req_type(scope, req_event_emitter, options=options) + + @contextlib.contextmanager def redirected(stdout=sys.stdout, stderr=sys.stderr): """ @@ -233,7 +648,7 @@ def closed_wsgi_iterable(iterable): `the PEP-3333 server/gateway side example `_. Finally, if the iterable has a ``close()`` method, it is called upon - exception or exausting iteration. + exception or exhausting iteration. Furthermore, the first bytestring yielded from iteration, if any, is prefetched before returning the wrapped iterator in order to ensure the @@ -270,21 +685,78 @@ def wrapper(): def _add_headers_to_environ(env, headers): - if not isinstance(headers, dict): - # Try to convert - headers = dict(headers) + try: + items = headers.items() + except AttributeError: + items = headers - for name, value in headers.items(): - name = name.upper().replace('-', '_') + for name, value in items: + name_wsgi = name.upper().replace('-', '_') + if name_wsgi not in ('CONTENT_TYPE', 'CONTENT_LENGTH'): + name_wsgi = 'HTTP_' + name_wsgi if value is None: value = '' else: value = value.strip() - if name == 'CONTENT_TYPE': - env[name] = value - elif name == 'CONTENT_LENGTH': - env[name] = value + if name_wsgi not in env or name.lower() in SINGLETON_HEADERS: + env[name_wsgi] = value + else: + env[name_wsgi] += ',' + value + + +def _add_headers_to_scope(scope, headers, content_length, host, port, scheme, http_version): + if headers: + try: + items = headers.items() + except AttributeError: + items = headers + + prepared_headers = [ + # NOTE(kgriffs): Expose as an iterable to ensure the framework/app + # isn't hard-coded to only work with a list or tuple. + # NOTE(kgriffs): Value is stripped if not empty, otherwise defaults + # to b'' to be consistent with _add_headers_to_environ(). + iter([name.lower().encode(), value.strip().encode() if value else b'']) + + # NOTE(kgriffs): Use tuple unpacking to support iterables + # that yield arbitary two-item iterable objects. + for name, value in items + ] + else: + prepared_headers = [] + + if content_length is not None: + value = str(content_length).encode() + prepared_headers.append((b'content-length', value)) + + if http_version != '1.0': + host_header = host + + if scheme == 'https': + if port != 443: + host_header += ':' + str(port) else: - env['HTTP_' + name] = value + if port != 80: + host_header += ':' + str(port) + + prepared_headers.append([b'host', host_header.encode()]) + + # NOTE(kgriffs): Make it an iterator to ensure the app is not expecting + # a specific type (ASGI only specified that it is an iterable). + scope['headers'] = iter(prepared_headers) + + +def _fixup_http_version(http_version) -> str: + if http_version not in ('2', '2.0', '1.1', '1.0', '1'): + raise ValueError('Invalid http_version specified: ' + http_version) + + # NOTE(kgrifs): Normalize so that they conform to the standard + # protocol names with prefixed with "HTTP/" + if http_version == '2.0': + http_version = '2' + elif http_version == '1': + http_version = '1.0' + + return http_version diff --git a/falcon/testing/resource.py b/falcon/testing/resource.py index 020e8ba3c..42837904b 100644 --- a/falcon/testing/resource.py +++ b/falcon/testing/resource.py @@ -35,15 +35,51 @@ def capture_responder_args(req, resp, resource, params): Adds the following attributes to the hooked responder's resource class: - * captured_req - * captured_resp - * captured_kwargs + * `captured_req` + * `captured_resp` + * `captured_kwargs` + + In addition, if the capture-req-body-bytes header is present in the + request, the following attribute is added: + + * `captured_req_body` + + Including the capture-req-media header in the request (set to any + value) will add the following attribute: + + * `capture-req-media` """ resource.captured_req = req resource.captured_resp = resp resource.captured_kwargs = params + resource.captured_req_media = None + resource.captured_req_body = None + + num_bytes = req.get_header('capture-req-body-bytes') + if num_bytes: + resource.captured_req_body = req.stream.read(int(num_bytes)) + elif req.get_header('capture-req-media'): + resource.captured_req_media = req.media + + +async def capture_responder_args_async(req, resp, resource, params): + """An asynchronous version of ``capture_responder_args()``.""" + + resource.captured_req = req + resource.captured_resp = resp + resource.captured_kwargs = params + + resource.captured_req_media = None + resource.captured_req_body = None + + num_bytes = req.get_header('capture-req-body-bytes') + if num_bytes: + resource.captured_req_body = await req.stream.read(int(num_bytes)) + elif req.get_header('capture-req-media'): + resource.captured_req_media = await req.get_media() + def set_resp_defaults(req, resp, resource, params): """Before hook for setting default response properties.""" @@ -58,6 +94,11 @@ def set_resp_defaults(req, resp, resource, params): resp.set_headers(resource._default_headers) +async def set_resp_defaults_async(req, resp, resource, params): + """Wraps capture_responder_args in a coroutine.""" + set_resp_defaults(req, resp, resource, params) + + class SimpleTestResource: """Mock resource for functional testing of framework components. @@ -126,3 +167,52 @@ def on_get(self, req, resp, **kwargs): @falcon.before(set_resp_defaults) def on_post(self, req, resp, **kwargs): pass + + +class SimpleTestResourceAsync(SimpleTestResource): + """Mock resource for functional testing of ASGI framework components. + + This class implements a simple test resource that can be extended + as needed to test middleware, hooks, and the Falcon framework + itself. It is identical to SimpleTestResource, except that it implements + asynchronous responders for use with the ASGI interface. + + Only noop ``on_get()`` and ``on_post()`` responders are implemented; + when overriding these, or adding additional responders in child + classes, they can be decorated with the + :py:meth:`falcon.testing.capture_responder_args` hook in + order to capture the *req*, *resp*, and *params* arguments that + are passed to the responder. Responders may also be decorated with + the :py:meth:`falcon.testing.set_resp_defaults` hook in order to + set *resp* properties to default *status*, *body*, and *header* + values. + + Keyword Arguments: + status (str): Default status string to use in responses + body (str): Default body string to use in responses + json (JSON serializable): Default JSON document to use in responses. + Will be serialized to a string and encoded as UTF-8. Either + *json* or *body* may be specified, but not both. + headers (dict): Default set of additional headers to include in + responses + + Attributes: + called (bool): Whether or not a req/resp was captured. + captured_req (falcon.Request): The last Request object passed + into any one of the responder methods. + captured_resp (falcon.Response): The last Response object passed + into any one of the responder methods. + captured_kwargs (dict): The last dictionary of kwargs, beyond + ``req`` and ``resp``, that were passed into any one of the + responder methods. + """ + + @falcon.before(capture_responder_args_async) + @falcon.before(set_resp_defaults_async) + async def on_get(self, req, resp, **kwargs): + pass + + @falcon.before(capture_responder_args_async) + @falcon.before(set_resp_defaults_async) + async def on_post(self, req, resp, **kwargs): + pass diff --git a/falcon/testing/test_case.py b/falcon/testing/test_case.py index 6f89c32ac..99c0e7569 100644 --- a/falcon/testing/test_case.py +++ b/falcon/testing/test_case.py @@ -45,10 +45,10 @@ class TestCase(unittest.TestCase, TestClient): :py:class:`unittest.TestCase` or :py:class:`testtools.TestCase`. Attributes: - app (object): A WSGI application to target when simulating + app (object): A WSGI or ASGI application to target when simulating requests (default: ``falcon.App()``). When testing your application, you will need to set this to your own instance - of ``falcon.App``. For example:: + of ``falcon.App`` or ``falcon.asgi.App``. For example:: from falcon import testing import myapp @@ -80,6 +80,3 @@ def setUp(self): # NOTE(kgriffs): Don't use super() to avoid triggering # unittest.TestCase.__init__() TestClient.__init__(self, app) - - # Reset to simulate "restarting" the WSGI container - falcon.request._maybe_wrap_wsgi_stream = True diff --git a/falcon/util/__init__.py b/falcon/util/__init__.py index e8722d5ca..c4362b49f 100644 --- a/falcon/util/__init__.py +++ b/falcon/util/__init__.py @@ -26,6 +26,7 @@ # Hoist misc. utils from falcon.util.structures import * # NOQA from falcon.util.misc import * # NOQA +from falcon.util.sync import * # NOQA from falcon.util.time import * # NOQA diff --git a/falcon/util/misc.py b/falcon/util/misc.py index a0aa9dab3..4cd3b2886 100644 --- a/falcon/util/misc.py +++ b/falcon/util/misc.py @@ -26,7 +26,10 @@ import datetime import functools +import http import inspect +import os +import sys import warnings from falcon import status_codes @@ -39,7 +42,9 @@ 'to_query_str', 'get_bound_method', 'get_argnames', - 'get_http_status' + 'get_http_status', + 'http_status_to_code', + 'code_to_http_status', ) @@ -48,6 +53,30 @@ utcnow = datetime.datetime.utcnow +# NOTE(kgriffs): This is tested in the gate but we do not want devs to +# have to install a specific version of 3.5 to check coverage on their +# workstations, so we use the nocover pragma here. +def _lru_cache_nop(*args, **kwargs): # pragma: nocover + def decorator(func): + return func + + return decorator + + +if ( + # NOTE(kgriffs): https://bugs.python.org/issue28969 + (sys.version_info.minor == 5 and sys.version_info.micro < 4) or + (sys.version_info.minor == 6 and sys.version_info.micro < 1) or + + # PERF(kgriffs): Using lru_cache is slower on pypy when the wrapped + # function is just doing a few non-IO operations. + (sys.implementation.name == 'pypy') +): + _lru_cache_safe = _lru_cache_nop # pragma: nocover +else: + _lru_cache_safe = functools.lru_cache + + # NOTE(kgriffs): We don't want our deprecations to be ignored by default, # so create our own type. # @@ -71,16 +100,17 @@ def deprecated(instructions): def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - message = 'Call to deprecated function {0}(...). {1}'.format( - func.__name__, - instructions) + if 'FALCON_TESTING_SESSION' not in os.environ: + message = 'Call to deprecated function {0}(...). {1}'.format( + func.__name__, + instructions) - frame = inspect.currentframe().f_back + frame = inspect.currentframe().f_back - warnings.warn_explicit(message, - category=DeprecatedWarning, - filename=inspect.getfile(frame.f_code), - lineno=frame.f_lineno) + warnings.warn_explicit(message, + category=DeprecatedWarning, + filename=inspect.getfile(frame.f_code), + lineno=frame.f_lineno) return func(*args, **kwargs) @@ -274,6 +304,7 @@ def get_argnames(func): return args +@deprecated('Please use falcon.util.code_to_http_status() instead.') def get_http_status(status_code, default_reason='Unknown'): """Gets both the http status code and description from just a code @@ -305,3 +336,70 @@ def get_http_status(status_code, default_reason='Unknown'): except AttributeError: # not found return str(code) + ' ' + default_reason + + +@_lru_cache_safe(maxsize=64) +def http_status_to_code(status): + """Normalize an HTTP status to an integer code. + + This function takes a member of http.HTTPStatus, an HTTP status + line string or byte string (e.g., '200 OK'), or an ``int`` and + returns the corresponding integer code. + + An LRU is used to minimize lookup time. + + Args: + status: The status code or enum to normalize + + Returns: + int: Integer code for the HTTP status (e.g., 200) + """ + + if isinstance(status, http.HTTPStatus): + return status.value + + if isinstance(status, int): + return status + + if isinstance(status, bytes): + status = status.decode() + + if not isinstance(status, str): + raise ValueError('status must be an int, str, or a member of http.HTTPStatus') + + if len(status) < 3: + raise ValueError('status strings must be at least three characters long') + + try: + return int(status[:3]) + except ValueError: + raise ValueError('status strings must start with a three-digit integer') + + +@_lru_cache_safe(maxsize=64) +def code_to_http_status(code): + """Convert an HTTP status code integer to a status line string. + + An LRU is used to minimize lookup time. + + Args: + code (int): The integer status code to convert to a status line. + + Returns: + str: HTTP status line corresponding to the given code. A newline + is not included at the end of the string. + """ + + try: + code = int(code) + if code < 100: + raise ValueError() + except (ValueError, TypeError): + raise ValueError('"{}" is not a valid status code'.format(code)) + + try: + # NOTE(kgriffs): We do this instead of using http.HTTPStatus since + # the Falcon module defines a larger number of codes. + return getattr(status_codes, 'HTTP_' + str(code)) + except AttributeError: + return str(code) diff --git a/falcon/util/sync.py b/falcon/util/sync.py new file mode 100644 index 000000000..34e2bf4b2 --- /dev/null +++ b/falcon/util/sync.py @@ -0,0 +1,182 @@ +import asyncio +from collections.abc import Coroutine +from concurrent.futures import ThreadPoolExecutor +from functools import partial, wraps +import inspect +import os + + +__all__ = [ + 'get_loop', + 'sync_to_async', + 'wrap_sync_to_async', + 'wrap_sync_to_async_unsafe', +] + + +_one_thread_to_rule_them_all = ThreadPoolExecutor(max_workers=1) + + +try: + get_loop = asyncio.get_running_loop + """Gets the running asyncio event loop.""" +except AttributeError: # pragma: nocover + # NOTE(kgriffs): This branch is definitely covered under py35 and py36 + # but for some reason the codecov gate doesn't pick this up, hence + # the pragma above. + + get_loop = asyncio.get_event_loop + """Gets the running asyncio event loop.""" + + +def wrap_sync_to_async_unsafe(func) -> Coroutine: + """Wrap a callable in a coroutine that executes the callable directly. + + This helper makes it easier to use synchronous callables with ASGI + apps. However, it is considered "unsafe" because it calls the wrapped + function directly in the same thread as the asyncio loop. Generally, you + should use :meth:`~.wrap_sync_to_async` instead. + + Warning: + This helper is only to be used for functions that do not perform any + blocking I/O or lengthy CPU-bound operations, since the entire async + loop will be blocked while the wrapped function is executed. + For a safer, non-blocking alternative that runs the function in a + thread pool executor, use :func:~.sync_to_async instead. + + Arguments: + func (callable): Function, method, or other callable to wrap + + Returns: + function: An awaitable coroutine function that wraps the + synchronous callable. + """ + + @wraps(func) + async def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +def wrap_sync_to_async(func, threadsafe=None) -> Coroutine: + """Wrap a callable in a coroutine that executes the callable in the background. + + This helper makes it easier to call functions that can not be + ported to use async natively (e.g., functions exported by a database + library that does not yet support asyncio). + + To execute blocking operations safely, without stalling the async + loop, the wrapped callable is scheduled to run in the background, on a + separate thread, when the wrapper is called. + + Normally, the default executor for the running loop is used to schedule the + synchronous callable. If the callable is not thread-safe, it can be + scheduled serially in a global single-threaded executor. + + Warning: + Wrapping a synchronous function safely adds a fair amount of overhead + to the function call, and should only be used when a native async + library is not available for the operation you wish to perform. + + Arguments: + func (callable): Function, method, or other callable to wrap + + Keyword Arguments: + threadsafe (bool): Set to ``False`` when the callable is not + thread-safe (default ``True``). When this argument is ``False``, + the wrapped callable will be scheduled to run serially in a + global single-threaded executor. + + Returns: + function: An awaitable coroutine function that wraps the + synchronous callable. + """ + + if threadsafe is None or threadsafe: + executor = None # Use default + else: + executor = _one_thread_to_rule_them_all + + @wraps(func) + async def wrapper(*args, **kwargs): + return await get_loop().run_in_executor(executor, partial(func, *args, **kwargs)) + + return wrapper + + +async def sync_to_async(func, *args, **kwargs): + """Schedules a synchronous callable on the loop's default executor and awaits the result. + + This helper makes it easier to call functions that can not be + ported to use async natively (e.g., functions exported by a database + library that does not yet support asyncio). + + To execute blocking operations safely, without stalling the async + loop, the wrapped callable is scheduled to run in the background, on a + separate thread, when the wrapper is called. + + The default executor for the running loop is used to schedule the + synchronous callable. + + Warning: + This helper can only be used to execute thread-safe callables. If + the callable is not thread-safe, it can be executed serially + by first wrapping it with :meth:`~.wrap_sync_to_async`, and then + executing the wrapper directly. + + Warning: + Calling a synchronous function safely from an asyncio event loop + adds a fair amount of overhead to the function call, and should + only be used when a native async library is not available for the + operation you wish to perform. + + Arguments: + func (callable): Function, method, or other callable to wrap + *args: All additional arguments are passed through to the callable. + + Keyword Arguments: + **kwargs: All keyword arguments are passed through to the callable. + + Returns: + function: An awaitable coroutine function that wraps the + synchronous callable. + """ + + return await get_loop().run_in_executor(None, partial(func, *args, **kwargs)) + + +def _should_wrap_non_coroutines() -> bool: + """Returns True IFF FALCON_ASGI_WRAP_NON_COROUTINES is set in the environ. + + This should only be used for Falcon's own test suite. + """ + + return 'FALCON_ASGI_WRAP_NON_COROUTINES' in os.environ + + +def _wrap_non_coroutine_unsafe(func): + """Wraps a coroutine using ``wrap_sync_to_async_unsafe()`` for internal test cases. + + This method is intended for Falcon's own test suite and should not be + used by apps themselves. It provides a convenient way to reuse sync + methods for ASGI test cases when it is safe to do so. + + Arguments: + func (callable): Function, method, or other callable to wrap + Returns: + When not in test mode, this function simply returns the callable + unchanged. Otherwise, if the callable is not a coroutine function, + it will be wrapped using ``wrap_sync_to_async_unsafe()``. + """ + + if func is None: + return func + + if not _should_wrap_non_coroutines(): + return func + + if inspect.iscoroutinefunction(func): + return func + + return wrap_sync_to_async_unsafe(func) diff --git a/requirements/tests b/requirements/tests index 5d3b78b4c..cbe50b6a0 100644 --- a/requirements/tests +++ b/requirements/tests @@ -4,6 +4,11 @@ pyyaml requests testtools +# ASGI Specific +uvicorn; python_version >= '3.6' +aiofiles; python_version >= '3.6' +daphne; python_version >= '3.6' + # Handler Specific msgpack mujson @@ -12,4 +17,4 @@ ujson python-rapidjson # TODO(kgriffs): orjson is failing to build on Travis -# orjson; platform_python_implementation != 'PyPy' +orjson; platform_python_implementation != 'PyPy' and platform_machine != 's390x' diff --git a/setup.cfg b/setup.cfg index 41c485bd5..ee8ce80f6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,4 +10,7 @@ test=pytest [tool:pytest] filterwarnings = ignore:Unknown REQUEST_METHOD. '(CONNECT|DELETE|GET|HEAD|OPTIONS|PATCH|POST|PUT|TRACE|CHECKIN|CHECKOUT|COPY|LOCK|MKCOL|MOVE|PROPFIND|PROPPATCH|REPORT|UNCHECKIN|UNLOCK|UPDATE|VERSION-CONTROL)':wsgiref.validate.WSGIWarning + ignore:Unknown REQUEST_METHOD. '(FOO|BAR|BREW|SETECASTRONOMY)':wsgiref.validate.WSGIWarning + ignore:"@coroutine" decorator is deprecated:DeprecationWarning + ignore:Using or importing the ABCs:DeprecationWarning ignore:cannot collect test class 'TestClient':pytest.PytestCollectionWarning diff --git a/setup.py b/setup.py index 5de90076e..7db19507a 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,19 @@ def list_modules(dirname, pattern): 'falcon.vendor.mimeparse', ] + modules_to_exclude = [ + # NOTE(kgriffs): Cython does not handle dynamically-created async + # methods correctly, so we do not cythonize the following modules. + 'falcon.hooks', + # NOTE(vytas): Middleware classes cannot be cythonized until + # asyncio.iscoroutinefunction recognizes cythonized coroutines: + # * https://github.com/cython/cython/issues/2273 + # * https://bugs.python.org/issue38225 + 'falcon.middlewares', + 'falcon.responders', + 'falcon.util.sync', + ] + cython_package_names = frozenset([ 'falcon.cyutil', ]) @@ -72,6 +85,7 @@ def list_modules(dirname, pattern): for module, ext in list_modules( path.join(MYDIR, *package.split('.')), ('*.pyx' if package in cython_package_names else '*.py')) + if (package + '.' + module) not in modules_to_exclude ] cmdclass = {'build_ext': build_ext} @@ -151,7 +165,7 @@ def load_description(): packages=find_packages(exclude=['tests']), include_package_data=True, zip_safe=False, - python_requires='>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*', + python_requires='>=3.5', install_requires=REQUIRES, cmdclass=cmdclass, ext_modules=ext_modules, diff --git a/tests/_util.py b/tests/_util.py new file mode 100644 index 000000000..721d1e138 --- /dev/null +++ b/tests/_util.py @@ -0,0 +1,77 @@ +from contextlib import contextmanager +import os + +import pytest + +import falcon +import falcon.testing + + +__all__ = [ + 'create_app', + 'create_req', + 'create_resp', + 'to_coroutine', +] + + +def create_app(asgi, **app_kwargs): + if asgi: + skipif_asgi_unsupported() + from falcon.asgi import App + else: + from falcon import App + + app = App(**app_kwargs) + return app + + +def create_req(asgi, options=None, **environ_or_scope_kwargs): + if asgi: + skipif_asgi_unsupported() + + req = falcon.testing.create_asgi_req( + options=options, + **environ_or_scope_kwargs + ) + + else: + req = falcon.testing.create_req( + options=options, + **environ_or_scope_kwargs + ) + + return req + + +def create_resp(asgi): + if asgi: + skipif_asgi_unsupported() + from falcon.asgi import Response + return Response() + + return falcon.Response() + + +def to_coroutine(callable): + async def wrapper(*args, **kwargs): + return callable(*args, **kwargs) + + return wrapper + + +def skipif_asgi_unsupported(): + if not falcon.ASGI_SUPPORTED: + pytest.skip('ASGI requires Python 3.6+') + + +@contextmanager +def disable_asgi_non_coroutine_wrapping(): + should_wrap = 'FALCON_ASGI_WRAP_NON_COROUTINES' in os.environ + if should_wrap: + del os.environ['FALCON_ASGI_WRAP_NON_COROUTINES'] + + yield + + if should_wrap: + os.environ['FALCON_ASGI_WRAP_NON_COROUTINES'] = 'Y' diff --git a/tests/asgi/__init__.py b/tests/asgi/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/asgi/_asgi_test_app.py b/tests/asgi/_asgi_test_app.py new file mode 100644 index 000000000..cf8a480b7 --- /dev/null +++ b/tests/asgi/_asgi_test_app.py @@ -0,0 +1,138 @@ +import asyncio +from collections import Counter +import time + +import falcon +import falcon.asgi +import falcon.util + + +class Things: + def __init__(self): + self._counter = Counter() + + async def on_get(self, req, resp): + await asyncio.sleep(0.01) + resp.body = req.remote_addr + + async def on_post(self, req, resp): + resp.data = await req.stream.read(req.content_length or 0) + resp.set_header('X-Counter', str(self._counter['backround:things:on_post'])) + + async def background_job_async(): + await asyncio.sleep(0.01) + self._counter['backround:things:on_post'] += 1 + + def background_job_sync(): + time.sleep(0.01) + self._counter['backround:things:on_post'] += 1000 + + resp.schedule(background_job_async) + resp.schedule(background_job_sync) + resp.schedule(background_job_async) + resp.schedule(background_job_sync) + + async def on_put(self, req, resp): + # NOTE(kgriffs): Test that reading past the end does + # not hang. + + chunks = [] + for i in range(req.content_length + 1): + # NOTE(kgriffs): In the ASGI interface, bounded_stream is an + # alias for req.stream. We'll use the alias here just as + # a sanity check. + chunk = await req.bounded_stream.read(1) + chunks.append(chunk) + + # NOTE(kgriffs): body should really be set to a string, but + # Falcon is lenient and will allow bytes as well (although + # it is slightly less performant). + # TODO(kgriffs): Perhaps in Falcon 4.0 be more strict? We would + # also have to change the WSGI behavior to match. + resp.body = b''.join(chunks) + + # ================================================================= + # NOTE(kgriffs): Test the sync_to_async helpers here to make sure + # they work as expected in the context of a real ASGI server. + # ================================================================= + safely_tasks = [] + safely_values = [] + + def callmesafely(a, b, c=None): + # NOTE(kgriffs): Sleep to prove that there isn't another instance + # running in parallel that is able to race ahead. + time.sleep(0.001) + safely_values.append((a, b, c)) + + cms = falcon.util.wrap_sync_to_async(callmesafely, threadsafe=False) + loop = falcon.util.get_loop() + + num_cms_tasks = 1000 + + for i in range(num_cms_tasks): + # NOTE(kgriffs): create_task() is used here, so that the coroutines + # are scheduled immediately in the order created; under Python + # 3.6, asyncio.gather() does not seem to always schedule + # them in order, so we do it this way to make it predictable. + safely_tasks.append(loop.create_task(cms(i, i + 1, c=i + 2))) + + await asyncio.gather(*safely_tasks) + + assert len(safely_values) == num_cms_tasks + for i, val in enumerate(safely_values): + assert safely_values[i] == (i, i + 1, i + 2) + + def callmeshirley(a=42, b=None): + return (a, b) + + assert (42, None) == await falcon.util.sync_to_async(callmeshirley) + assert (1, 2) == await falcon.util.sync_to_async(callmeshirley, 1, 2) + assert (5, None) == await falcon.util.sync_to_async(callmeshirley, 5) + assert (3, 4) == await falcon.util.sync_to_async(callmeshirley, 3, b=4) + + +class Bucket: + async def on_post(self, req, resp): + resp.body = await req.stream.read() + + +class Events: + async def on_get(self, req, resp): + async def emit(): + start = time.time() + while time.time() - start < 1: + yield falcon.asgi.SSEvent(text='hello world') + await asyncio.sleep(0.2) + + resp.sse = emit() + + +class LifespanHandler: + def __init__(self): + self.startup_succeeded = False + self.shutdown_succeeded = False + + async def process_startup(self, scope, event): + assert scope['type'] == 'lifespan' + assert event['type'] == 'lifespan.startup' + self.startup_succeeded = True + + async def process_shutdown(self, scope, event): + assert scope['type'] == 'lifespan' + assert event['type'] == 'lifespan.shutdown' + self.shutdown_succeeded = True + + +def create_app(): + app = falcon.asgi.App() + app.add_route('/', Things()) + app.add_route('/bucket', Bucket()) + app.add_route('/events', Events()) + + lifespan_handler = LifespanHandler() + app.add_middleware(lifespan_handler) + + return app + + +application = create_app() diff --git a/tests/asgi/test_asgi_servers.py b/tests/asgi/test_asgi_servers.py new file mode 100644 index 000000000..421c7a0b0 --- /dev/null +++ b/tests/asgi/test_asgi_servers.py @@ -0,0 +1,185 @@ +from contextlib import contextmanager +import os +import platform +import random +import subprocess +import time + +import pytest +import requests +import requests.exceptions + +from falcon import testing + + +_MODULE_DIR = os.path.abspath(os.path.dirname(__file__)) + +_PYPY = platform.python_implementation() == 'PyPy' + +_SERVER_HOST = '127.0.0.1' +_SIZE_1_KB = 1024 + +_random = random.Random() + + +_REQUEST_TIMEOUT = 10 + + +class TestASGIServer: + + def test_get(self, server_base_url): + resp = requests.get(server_base_url, timeout=_REQUEST_TIMEOUT) + assert resp.status_code == 200 + assert resp.text == '127.0.0.1' + + def test_put(self, server_base_url): + body = '{}' + resp = requests.put(server_base_url, data=body, timeout=_REQUEST_TIMEOUT) + assert resp.status_code == 200 + assert resp.text == '{}' + + def test_head_405(self, server_base_url): + body = '{}' + resp = requests.head(server_base_url, data=body, timeout=_REQUEST_TIMEOUT) + assert resp.status_code == 405 + + def test_post_multiple(self, server_base_url): + body = testing.rand_string(_SIZE_1_KB / 2, _SIZE_1_KB) + resp = requests.post(server_base_url, data=body, timeout=_REQUEST_TIMEOUT) + assert resp.status_code == 200 + assert resp.text == body + assert resp.headers['X-Counter'] == '0' + + time.sleep(1) + + resp = requests.post(server_base_url, data=body, timeout=_REQUEST_TIMEOUT) + assert resp.headers['X-Counter'] == '2002' + + def test_post_invalid_content_length(self, server_base_url): + headers = {'Content-Length': 'invalid'} + + try: + resp = requests.post(server_base_url, headers=headers, timeout=_REQUEST_TIMEOUT) + + # Daphne responds with a 400 + assert resp.status_code == 400 + + except requests.ConnectionError: + # NOTE(kgriffs): Uvicorn will kill the request so it does not + # even get to our app; the app logic is tested on the WSGI + # side. We leave this here in case something changes in + # the way uvicorn handles it or something and we want to + # get a heads-up if the request is no longer blocked. + pass + + def test_post_read_bounded_stream(self, server_base_url): + body = testing.rand_string(_SIZE_1_KB / 2, _SIZE_1_KB) + resp = requests.post(server_base_url + 'bucket', data=body, timeout=_REQUEST_TIMEOUT) + assert resp.status_code == 200 + assert resp.text == body + + def test_post_read_bounded_stream_no_body(self, server_base_url): + resp = requests.post(server_base_url + 'bucket', timeout=_REQUEST_TIMEOUT) + assert not resp.text + + def test_sse(self, server_base_url): + resp = requests.get(server_base_url + 'events', timeout=_REQUEST_TIMEOUT) + assert resp.status_code == 200 + + events = resp.text.split('\n\n') + assert len(events) > 2 + for e in events[:-1]: + assert e == 'data: hello world' + + assert not events[-1] + + +@contextmanager +def _run_server_isolated(process_factory, host, port): + # NOTE(kgriffs): We have to use subprocess because uvicorn has a tendency + # to corrupt our asyncio state and cause intermittent hangs in the test + # suite. + print('\n[Starting server process...]') + server = process_factory(host, port) + + time.sleep(0.2) + startup_succeeded = (server.poll() is None) + print('\n[Server process start {}]'.format('succeeded' if startup_succeeded else 'failed')) + + if startup_succeeded: + yield server + + print('\n[Sending SIGTERM to server process...]') + server.terminate() + + try: + server.communicate(timeout=10) + except subprocess.TimeoutExpired: + server.kill() + server.communicate() + + assert server.returncode == 0 + assert startup_succeeded + + +def _uvicorn_factory(host, port): + # NOTE(vytas): uvicorn+uvloop is not (well) supported on PyPy at the time + # of writing. + loop_options = ('--http', 'h11', '--loop', 'asyncio') if _PYPY else () + options = ( + '--host', host, + '--port', str(port), + + '_asgi_test_app:application' + ) + + return subprocess.Popen( + ('uvicorn',) + loop_options + options, + cwd=_MODULE_DIR, + ) + + +def _daphne_factory(host, port): + return subprocess.Popen( + ( + 'daphne', + + '--bind', host, + '--port', str(port), + + '--verbosity', '2', + '--access-log', '-', + + '_asgi_test_app:application' + ), + cwd=_MODULE_DIR, + ) + + +@pytest.fixture(params=[_uvicorn_factory, _daphne_factory]) +def server_base_url(request): + process_factory = request.param + + # NOTE(kgriffs): This facilitates parallel test execution as well as + # mitigating the problem of trying to reuse a port that the system + # hasn't cleaned up yet. + # NOTE(kgriffs): Use our own Random instance because we don't want + # pytest messing with the seed. + server_port = _random.randint(50000, 60000) + base_url = 'http://{}:{}/'.format(_SERVER_HOST, server_port) + + with _run_server_isolated(process_factory, _SERVER_HOST, server_port): + # NOTE(kgriffs): Let the server start up. Give up after 5 seconds. + start_ts = time.time() + while True: + wait_time = time.time() - start_ts + assert wait_time < 5 + + try: + requests.get(base_url, timeout=0.2) + except (requests.exceptions.Timeout, requests.exceptions.ConnectionError): + time.sleep(0.2) + else: + break + + yield base_url diff --git a/tests/asgi/test_boundedstream_asgi.py b/tests/asgi/test_boundedstream_asgi.py new file mode 100644 index 000000000..7209f6811 --- /dev/null +++ b/tests/asgi/test_boundedstream_asgi.py @@ -0,0 +1,224 @@ +import asyncio +import time + +import pytest + +from falcon import asgi, testing + + +@pytest.mark.parametrize('body', [ + b'', + b'\x00', + b'\x00\xFF', + b'catsup', + b'\xDE\xAD\xBE\xEF' * 512, + testing.rand_string(1, 2048), +]) +@pytest.mark.parametrize('extra_body', [True, False]) +@pytest.mark.parametrize('set_content_length', [True, False]) +def test_read_all(body, extra_body, set_content_length): + if extra_body and not set_content_length: + pytest.skip( + 'extra_body ignores set_content_length so we only need to test ' + 'one of the parameter permutations' + ) + + expected_body = body if isinstance(body, bytes) else body.encode() + + def stream(): + stream_body = body + content_length = None + + if extra_body: + # NOTE(kgriffs): Test emitting more data than expected to the app + content_length = len(expected_body) + stream_body += b'\x00' if isinstance(stream_body, bytes) else '~' + elif set_content_length: + content_length = len(expected_body) + + return _stream(stream_body, content_length=content_length) + + async def test_iteration(): + s = stream() + assert b''.join([chunk async for chunk in s]) == expected_body + assert await s.read() == b'' + assert await s.readall() == b'' + assert [chunk async for chunk in s][0] == b'' + assert s.tell() == len(expected_body) + assert s.eof + + async def test_readall_a(): + s = stream() + assert await s.readall() == expected_body + assert await s.read() == b'' + assert await s.readall() == b'' + assert [chunk async for chunk in s][0] == b'' + assert s.tell() == len(expected_body) + assert s.eof + + async def test_readall_b(): + s = stream() + assert await s.read() == expected_body + assert await s.readall() == b'' + assert await s.read() == b'' + assert [chunk async for chunk in s][0] == b'' + assert s.tell() == len(expected_body) + assert s.eof + + async def test_readall_c(): + s = stream() + body = await s.read(1) + body += await s.read(None) + assert body == expected_body + assert s.tell() == len(expected_body) + assert s.eof + + async def test_readall_d(): + s = stream() + assert not s.closed + + if expected_body: + assert not s.eof + elif set_content_length: + assert s.eof + else: + # NOTE(kgriffs): Stream doesn't know if there is more data + # coming or not until the first read. + assert not s.eof + + assert s.tell() == 0 + + assert await s.read(-2) == b'' + assert await s.read(-3) == b'' + assert await s.read(-100) == b'' + + assert await s.read(-1) == expected_body + assert await s.read(-1) == b'' + assert await s.readall() == b'' + assert await s.read() == b'' + assert [chunk async for chunk in s][0] == b'' + + assert await s.read(-2) == b'' + + assert s.tell() == len(expected_body) + assert s.eof + + assert not s.closed + s.close() + assert s.closed + + for t in (test_iteration, test_readall_a, test_readall_b, test_readall_c, test_readall_d): + testing.invoke_coroutine_sync(t) + + +def test_filelike(): + s = asgi.BoundedStream(testing.ASGIRequestEventEmitter()) + + for __ in range(2): + with pytest.raises(OSError): + s.fileno() + + assert not s.isatty() + assert s.readable() + assert not s.seekable() + assert not s.writable() + + s.close() + + assert s.closed + + # NOTE(kgriffs): Closing an already-closed stream is a noop. + s.close() + assert s.closed + + async def test_iteration(): + with pytest.raises(ValueError): + await s.read() + + with pytest.raises(ValueError): + await s.readall() + + with pytest.raises(ValueError): + await s.exhaust() + + with pytest.raises(ValueError): + async for chunk in s: + pass + + testing.invoke_coroutine_sync(test_iteration) + + +@pytest.mark.parametrize('body', [ + b'', + b'\x00', + b'\x00\xFF', + b'catsup', + b'\xDE\xAD\xBE\xEF' * 512, + testing.rand_string(1, 2048).encode(), +]) +@pytest.mark.parametrize('chunk_size', [1, 2, 10, 64, 100, 1000, 10000]) +def test_read_chunks(body, chunk_size): + def stream(): + return _stream(body) + + async def test_nonmixed(): + s = stream() + + assert await s.read(0) == b'' + + chunks = [] + + while not s.eof: + chunks.append(await s.read(chunk_size)) + + assert b''.join(chunks) == body + + async def test_mixed_a(): + s = stream() + + chunks = [] + + chunks.append(await s.read(chunk_size)) + chunks.append(await s.read(chunk_size)) + chunks.append(await s.readall()) + chunks.append(await s.read(chunk_size)) + + assert b''.join(chunks) == body + + async def test_mixed_b(): + s = stream() + + chunks = [] + + chunks.append(await s.read(chunk_size)) + chunks.append(await s.read(-1)) + + assert b''.join(chunks) == body + + for t in (test_nonmixed, test_mixed_a, test_mixed_b): + testing.invoke_coroutine_sync(t) + testing.invoke_coroutine_sync(t) + + +def test_exhaust_with_disconnect(): + async def t(): + emitter = testing.ASGIRequestEventEmitter( + b'123456798' * 1024, + disconnect_at=(time.time() + 0.5) + ) + s = asgi.BoundedStream(emitter) + + assert await s.read(1) == b'1' + assert await s.read(2) == b'23' + await asyncio.sleep(0.5) + await s.exhaust() + assert await s.read(1) == b'' + assert await s.read(100) == b'' + assert s.eof + + testing.invoke_coroutine_sync(t) + + +def _stream(body, content_length=None): + emitter = testing.ASGIRequestEventEmitter(body) + return asgi.BoundedStream(emitter, content_length=content_length) diff --git a/tests/asgi/test_hello_asgi.py b/tests/asgi/test_hello_asgi.py new file mode 100644 index 000000000..87392b847 --- /dev/null +++ b/tests/asgi/test_hello_asgi.py @@ -0,0 +1,315 @@ +import io +import os +import tempfile + +import aiofiles +import pytest + +import falcon +from falcon import testing +import falcon.asgi + +from _util import disable_asgi_non_coroutine_wrapping # NOQA + + +SIZE_1_KB = 1024 + + +@pytest.fixture +def client(): + return testing.TestClient(falcon.asgi.App()) + + +class DataReaderWithoutClose: + def __init__(self, data): + self._stream = io.BytesIO(data) + self.close_called = False + + async def read(self, num_bytes): + return self._stream.read(num_bytes) + + +class DataReader(DataReaderWithoutClose): + async def close(self): + self.close_called = True + + +class HelloResource: + sample_status = '200 OK' + sample_unicode = ('Hello World! \x80 - ' + testing.rand_string(0, 5)) + sample_utf8 = sample_unicode.encode('utf-8') + + def __init__(self, mode): + self.called = False + self.mode = mode + + async def on_get(self, req, resp): + self.called = True + self.req, self.resp = req, resp + + resp.status = falcon.HTTP_200 + + if 'stream' in self.mode: + if 'filelike' in self.mode: + stream = DataReader(self.sample_utf8) + else: + async def data_emitter(): + for b in self.sample_utf8: + yield bytes([b]) + + if 'stream_genfunc' in self.mode: + stream = data_emitter + elif 'stream_nongenfunc' in self.mode: + stream = 42 + else: + stream = data_emitter() + + if 'stream_len' in self.mode: + stream_len = len(self.sample_utf8) + else: + stream_len = None + + if 'use_helper' in self.mode: + resp.set_stream(stream, stream_len) + else: + resp.stream = stream + resp.content_length = stream_len + + if 'body' in self.mode: + if 'bytes' in self.mode: + resp.body = self.sample_utf8 + else: + resp.body = self.sample_unicode + + if 'data' in self.mode: + resp.data = self.sample_utf8 + + async def on_head(self, req, resp): + await self.on_get(req, resp) + + +class ClosingFilelikeHelloResource: + sample_status = '200 OK' + sample_unicode = ('Hello World! \x80' + testing.rand_string(0, 0)) + + sample_utf8 = sample_unicode.encode('utf-8') + + def __init__(self, stream_factory): + self.called = False + self.stream = stream_factory(self.sample_utf8) + self.stream_len = len(self.sample_utf8) + + async def on_get(self, req, resp): + self.called = True + self.req, self.resp = req, resp + resp.status = falcon.HTTP_200 + resp.set_stream(self.stream, self.stream_len) + + +class AIOFilesHelloResource: + def __init__(self): + self.sample_utf8 = testing.rand_string(8 * SIZE_1_KB, 16 * SIZE_1_KB).encode() + + fh, self.tempfile_name = tempfile.mkstemp() + with open(fh, 'wb') as f: + f.write(self.sample_utf8) + + self._aiofiles = None + + @property + def aiofiles_closed(self): + return not self._aiofiles or self._aiofiles.closed + + def cleanup(self): + os.remove(self.tempfile_name) + + async def on_get(self, req, resp): + self._aiofiles = await aiofiles.open(self.tempfile_name, 'rb') + resp.stream = self._aiofiles + + +class NoStatusResource: + async def on_get(self, req, resp): + pass + + +class PartialCoroutineResource: + def on_get(self, req, resp): + pass + + async def on_post(self, req, resp): + pass + + +class TestHelloWorld: + + def test_env_headers_list_of_tuples(self): + env = testing.create_environ(headers=[('User-Agent', 'Falcon-Test')]) + assert env['HTTP_USER_AGENT'] == 'Falcon-Test' + + def test_root_route(self, client): + doc = {'message': 'Hello world!'} + resource = testing.SimpleTestResourceAsync(json=doc) + client.app.add_route('/', resource) + + result = client.simulate_get() + assert result.json == doc + + def test_no_route(self, client): + result = client.simulate_get('/seenoevil') + assert result.status_code == 404 + + @pytest.mark.parametrize('path,resource,get_body', [ + ('/body', HelloResource('body'), lambda r: r.body.encode('utf-8')), + ('/bytes', HelloResource('body, bytes'), lambda r: r.body), + ('/data', HelloResource('data'), lambda r: r.data), + ]) + def test_body(self, client, path, resource, get_body): + client.app.add_route(path, resource) + + result = client.simulate_get(path) + resp = resource.resp + + content_length = int(result.headers['content-length']) + assert content_length == len(resource.sample_utf8) + + assert result.status == resource.sample_status + assert resp.status == resource.sample_status + assert get_body(resp) == resource.sample_utf8 + assert result.content == resource.sample_utf8 + + def test_no_body_on_head(self, client): + resource = HelloResource('body') + client.app.add_route('/body', resource) + result = client.simulate_head('/body') + + assert not result.content + assert result.status_code == 200 + assert resource.called + assert result.headers['content-length'] == str(len(HelloResource.sample_utf8)) + + def test_stream_chunked(self, client): + resource = HelloResource('stream') + client.app.add_route('/chunked-stream', resource) + + result = client.simulate_get('/chunked-stream') + + assert result.content == resource.sample_utf8 + assert 'content-length' not in result.headers + + def test_stream_known_len(self, client): + resource = HelloResource('stream, stream_len') + client.app.add_route('/stream', resource) + + result = client.simulate_get('/stream') + assert resource.called + + expected_len = int(resource.resp.content_length) + actual_len = int(result.headers['content-length']) + assert actual_len == expected_len + assert len(result.content) == expected_len + assert result.content == resource.sample_utf8 + + def test_filelike(self, client): + resource = HelloResource('stream, stream_len, filelike') + client.app.add_route('/filelike', resource) + + result = client.simulate_get('/filelike') + assert resource.called + + expected_len = int(resource.resp.content_length) + actual_len = int(result.headers['content-length']) + assert actual_len == expected_len + assert len(result.content) == expected_len + + result = client.simulate_get('/filelike') + assert resource.called + + expected_len = int(resource.resp.content_length) + actual_len = int(result.headers['content-length']) + assert actual_len == expected_len + assert len(result.content) == expected_len + + def test_genfunc_error(self, client): + resource = HelloResource('stream, stream_len, stream_genfunc') + client.app.add_route('/filelike', resource) + + with pytest.raises(TypeError): + client.simulate_get('/filelike') + + def test_nongenfunc_error(self, client): + resource = HelloResource('stream, stream_len, stream_nongenfunc') + client.app.add_route('/filelike', resource) + + with pytest.raises(TypeError): + client.simulate_get('/filelike') + + @pytest.mark.parametrize('stream_factory,assert_closed', [ + (DataReader, True), # Implements close() + (DataReaderWithoutClose, False), + ]) + def test_filelike_closing(self, client, stream_factory, assert_closed): + resource = ClosingFilelikeHelloResource(stream_factory) + client.app.add_route('/filelike-closing', resource) + + result = client.simulate_get('/filelike-closing') + assert resource.called + + expected_len = int(resource.resp.content_length) + actual_len = int(result.headers['content-length']) + assert actual_len == expected_len + assert len(result.content) == expected_len + + if assert_closed: + assert resource.stream.close_called + + def test_filelike_closing_aiofiles(self, client): + resource = AIOFilesHelloResource() + try: + client.app.add_route('/filelike-closing', resource) + + result = client.simulate_get('/filelike-closing') + + assert result.status_code == 200 + assert 'content-length' not in result.headers + assert result.content == resource.sample_utf8 + + assert resource.aiofiles_closed + + finally: + resource.cleanup() + + def test_filelike_using_helper(self, client): + resource = HelloResource('stream, stream_len, filelike, use_helper') + client.app.add_route('/filelike-helper', resource) + + result = client.simulate_get('/filelike-helper') + assert resource.called + + expected_len = int(resource.resp.content_length) + actual_len = int(result.headers['content-length']) + assert actual_len == expected_len + assert len(result.content) == expected_len + + def test_status_not_set(self, client): + client.app.add_route('/nostatus', NoStatusResource()) + + result = client.simulate_get('/nostatus') + + assert not result.content + assert result.status_code == 200 + + def test_coroutine_required(self, client): + with disable_asgi_non_coroutine_wrapping(): + with pytest.raises(TypeError) as exinfo: + client.app.add_route('/', PartialCoroutineResource()) + + assert 'responder must be a non-blocking async coroutine' in str(exinfo.value) + + def test_noncoroutine_required(self): + wsgi_app = falcon.App() + + with pytest.raises(TypeError) as exinfo: + wsgi_app.add_route('/', PartialCoroutineResource()) + + assert 'responder must be a regular synchronous method' in str(exinfo.value) diff --git a/tests/asgi/test_lifespan_handlers.py b/tests/asgi/test_lifespan_handlers.py new file mode 100644 index 000000000..2d7bcfdd5 --- /dev/null +++ b/tests/asgi/test_lifespan_handlers.py @@ -0,0 +1,205 @@ +import pytest + +from falcon import testing +from falcon.asgi import App + + +def test_at_least_one_event_method_required(): + class Foo: + pass + + app = App() + + with pytest.raises(TypeError): + app.add_middleware(Foo()) + + +def test_startup_only(): + class Foo: + async def process_startup(self, scope, event): + self._called = True + + foo = Foo() + + app = App() + app.add_middleware(foo) + client = testing.TestClient(app) + + client.simulate_get() + + assert foo._called + + +def test_startup_raises(): + class Foo: + def __init__(self): + self._shutdown_called = False + + async def process_startup(self, scope, event): + raise Exception('testing 123') + + async def process_shutdown(self, scope, event): + self._shutdown_called = True + + class Bar: + def __init__(self): + self._startup_called = False + self._shutdown_called = False + + async def process_startup(self, scope, event): + self._startup_called = True + + async def process_shutdown(self, scope, event): + self._shutdown_called = True + + foo = Foo() + bar = Bar() + + app = App() + app.add_middleware([foo, bar]) + client = testing.TestClient(app) + + with pytest.raises(RuntimeError) as excinfo: + client.simulate_get() + + message = str(excinfo.value) + + assert message.startswith('ASGI app returned lifespan.startup.failed.') + assert 'testing 123' in message + + assert not foo._shutdown_called + assert not bar._startup_called + assert not bar._shutdown_called + + +def test_shutdown_raises(): + class HandlerA: + def __init__(self): + self._startup_called = False + + async def process_startup(self, scope, event): + self._startup_called = True + + async def process_shutdown(self, scope, event): + raise Exception('testing 321') + + class HandlerB: + def __init__(self): + self._startup_called = False + self._shutdown_called = False + + async def process_startup(self, scope, event): + self._startup_called = True + + async def process_shutdown(self, scope, event): + self._shutdown_called = True + + a = HandlerA() + b1 = HandlerB() + b2 = HandlerB() + + app = App() + app.add_middleware(b1) + app.add_middleware([a, b2]) + client = testing.TestClient(app) + + with pytest.raises(RuntimeError) as excinfo: + client.simulate_get() + + message = str(excinfo.value) + + assert message.startswith('ASGI app returned lifespan.shutdown.failed.') + assert 'testing 321' in message + + assert a._startup_called + assert b1._startup_called + assert not b1._shutdown_called + assert b2._startup_called + assert b2._shutdown_called + + +def test_shutdown_only(): + class Foo: + async def process_shutdown(self, scope, event): + self._called = True + + foo = Foo() + + app = App() + app.add_middleware(foo) + client = testing.TestClient(app) + + client.simulate_get() + + assert foo._called + + +def test_multiple_handlers(): + counter = 0 + + class HandlerA: + async def process_startup(self, scope, event): + nonlocal counter + self._called_startup = counter + counter += 1 + + class HandlerB: + async def process_startup(self, scope, event): + nonlocal counter + self._called_startup = counter + counter += 1 + + async def process_shutdown(self, scope, event): + nonlocal counter + self._called_shutdown = counter + counter += 1 + + class HandlerC: + async def process_shutdown(self, scope, event): + nonlocal counter + self._called_shutdown = counter + counter += 1 + + class HandlerD: + async def process_startup(self, scope, event): + nonlocal counter + self._called_startup = counter + counter += 1 + + class HandlerE: + async def process_startup(self, scope, event): + nonlocal counter + self._called_startup = counter + counter += 1 + + async def process_shutdown(self, scope, event): + nonlocal counter + self._called_shutdown = counter + counter += 1 + + async def process_request(self, req, resp): + self._called_request = True + + app = App() + + a = HandlerA() + b = HandlerB() + c = HandlerC() + d = HandlerD() + e = HandlerE() + + app.add_middleware([a, b, c, d, e]) + + client = testing.TestClient(app) + client.simulate_get() + + assert a._called_startup == 0 + assert b._called_startup == 1 + assert d._called_startup == 2 + assert e._called_startup == 3 + + assert e._called_shutdown == 4 + assert c._called_shutdown == 5 + assert b._called_shutdown == 6 + + assert e._called_request diff --git a/tests/asgi/test_middleware_asgi.py b/tests/asgi/test_middleware_asgi.py new file mode 100644 index 000000000..0085ba358 --- /dev/null +++ b/tests/asgi/test_middleware_asgi.py @@ -0,0 +1,31 @@ +import pytest + +import falcon + + +class MiddlewareIncompatibleWithWSGI_A: + async def process_request(self, req, resp): + pass + + +class MiddlewareIncompatibleWithWSGI_B: + async def process_resource(self, req, resp, resource, params): + pass + + +class MiddlewareIncompatibleWithWSGI_C: + async def process_response(self, req, resp, resource, req_succeeded): + pass + + +@pytest.mark.parametrize('middleware', [ + MiddlewareIncompatibleWithWSGI_A(), + MiddlewareIncompatibleWithWSGI_B(), + MiddlewareIncompatibleWithWSGI_C(), + (MiddlewareIncompatibleWithWSGI_C(), MiddlewareIncompatibleWithWSGI_A()), +]) +def test_raise_on_incompatible(middleware): + api = falcon.API() + + with pytest.raises(falcon.CompatibilityError): + api.add_middleware(middleware) diff --git a/tests/asgi/test_request_asgi.py b/tests/asgi/test_request_asgi.py new file mode 100644 index 000000000..6a84bd8b7 --- /dev/null +++ b/tests/asgi/test_request_asgi.py @@ -0,0 +1,21 @@ +import pytest + +from falcon import testing, UnsupportedError + + +def test_missing_server_in_scope(): + req = testing.create_asgi_req(include_server=False, http_version='1.0') + assert req.host == 'localhost' + assert req.port == 80 + + +def test_log_error_not_supported(): + req = testing.create_asgi_req() + with pytest.raises(NotImplementedError): + req.log_error('Boink') + + +def test_media_prop_not_supported(): + req = testing.create_asgi_req() + with pytest.raises(UnsupportedError): + req.media diff --git a/tests/asgi/test_request_body_asgi.py b/tests/asgi/test_request_body_asgi.py new file mode 100644 index 000000000..3acbed3b3 --- /dev/null +++ b/tests/asgi/test_request_body_asgi.py @@ -0,0 +1,86 @@ +import pytest + +import falcon +import falcon.asgi +import falcon.request +import falcon.testing as testing + + +SIZE_1_KB = 1024 + + +@pytest.fixture +def resource(): + return testing.SimpleTestResourceAsync() + + +@pytest.fixture +def client(): + app = falcon.asgi.App() + return testing.TestClient(app) + + +class TestRequestBody: + def test_empty_body(self, client, resource): + client.app.add_route('/', resource) + client.simulate_request(path='/', body='') + stream = resource.captured_req.stream + assert stream.tell() == 0 + + def test_tiny_body(self, client, resource): + client.app.add_route('/', resource) + expected_body = '.' + + headers = {'capture-req-body-bytes': '1'} + client.simulate_request(path='/', body=expected_body, headers=headers) + stream = resource.captured_req.stream + + assert resource.captured_req_body == expected_body.encode('utf-8') + assert stream.tell() == 1 + + def test_tiny_body_overflow(self, client, resource): + client.app.add_route('/', resource) + expected_body = '.' + expected_len = len(expected_body) + + # Read too many bytes; shouldn't block + headers = {'capture-req-body-bytes': str(len(expected_body) + 1)} + client.simulate_request(path='/', body=expected_body, headers=headers) + stream = resource.captured_req.stream + + assert resource.captured_req_body == expected_body.encode('utf-8') + assert stream.tell() == expected_len + + def test_read_body(self, client, resource): + client.app.add_route('/', resource) + expected_body = testing.rand_string(SIZE_1_KB / 2, SIZE_1_KB) + expected_len = len(expected_body) + + headers = { + 'Content-Length': str(expected_len), + 'Capture-Req-Body-Bytes': '-1', + } + client.simulate_request(path='/', body=expected_body, headers=headers) + + content_len = resource.captured_req.get_header('content-length') + assert content_len == str(expected_len) + + stream = resource.captured_req.stream + + assert resource.captured_req_body == expected_body.encode('utf-8') + assert stream.tell() == expected_len + + def test_bounded_stream_alias(self): + scope = testing.create_scope() + req_event_emitter = testing.ASGIRequestEventEmitter(b'', 0) + req = falcon.asgi.Request(scope, req_event_emitter) + + assert req.bounded_stream is req.stream + + def test_request_repr(self): + scope = testing.create_scope() + req_event_emitter = testing.ASGIRequestEventEmitter(b'', 0) + req = falcon.asgi.Request(scope, req_event_emitter) + + _repr = '<%s: %s %r>' % (req.__class__.__name__, req.method, req.url) + assert req.__repr__() == _repr diff --git a/tests/asgi/test_request_context_asgi.py b/tests/asgi/test_request_context_asgi.py new file mode 100644 index 000000000..4e1a0a10e --- /dev/null +++ b/tests/asgi/test_request_context_asgi.py @@ -0,0 +1,53 @@ +import pytest + +from falcon.asgi import Request +import falcon.testing as testing + + +class TestRequestContext: + + def test_default_request_context(self,): + req = testing.create_asgi_req() + + req.context.hello = 'World' + assert req.context.hello == 'World' + assert req.context['hello'] == 'World' + + req.context['note'] = 'Default Request.context_type used to be dict.' + assert 'note' in req.context + assert hasattr(req.context, 'note') + assert req.context.get('note') == req.context['note'] + + def test_custom_request_context(self): + + # Define a Request-alike with a custom context type + class MyCustomContextType(): + pass + + class MyCustomRequest(Request): + context_type = MyCustomContextType + + req = testing.create_asgi_req(req_type=MyCustomRequest) + assert isinstance(req.context, MyCustomContextType) + + def test_custom_request_context_failure(self): + + # Define a Request-alike with a non-callable custom context type + class MyCustomRequest(Request): + context_type = False + + with pytest.raises(TypeError): + testing.create_asgi_req(req_type=MyCustomRequest) + + def test_custom_request_context_request_access(self): + + def create_context(req): + return {'uri': req.uri} + + # Define a Request-alike with a custom context type + class MyCustomRequest(Request): + context_type = create_context + + req = testing.create_asgi_req(req_type=MyCustomRequest) + assert isinstance(req.context, dict) + assert req.context['uri'] == req.uri diff --git a/tests/asgi/test_response_media_asgi.py b/tests/asgi/test_response_media_asgi.py new file mode 100644 index 000000000..fdaddc233 --- /dev/null +++ b/tests/asgi/test_response_media_asgi.py @@ -0,0 +1,164 @@ +import json + +import pytest + +import falcon +from falcon import errors, media, testing +import falcon.asgi + + +def create_client(resource, handlers=None): + app = falcon.asgi.App() + app.add_route('/', resource) + + if handlers: + app.resp_options.media_handlers.update(handlers) + + client = testing.TestClient(app, headers={'capture-resp-media': 'yes'}) + + return client + + +class SimpleMediaResource: + + def __init__(self, document, media_type=falcon.MEDIA_JSON): + self._document = document + self._media_type = media_type + + async def on_get(self, req, resp): + resp.content_type = self._media_type + resp.media = self._document + resp.status = falcon.HTTP_OK + + +@pytest.mark.parametrize('media_type', [ + ('*/*'), + (falcon.MEDIA_JSON), + ('application/json; charset=utf-8'), +]) +def test_json(media_type): + class TestResource: + async def on_get(self, req, resp): + resp.content_type = media_type + resp.media = {'something': True} + + body = await resp.render_body() + + assert json.loads(body.decode('utf-8')) == {'something': True} + + client = create_client(TestResource()) + client.simulate_get('/') + + +@pytest.mark.parametrize('document', [ + '', + 'I am a \u1d0a\ua731\u1d0f\u0274 string.', + ['\u2665', '\u2660', '\u2666', '\u2663'], + {'message': '\xa1Hello Unicode! \U0001F638'}, + { + 'description': 'A collection of primitive Python type examples.', + 'bool': False is not True and True is not False, + 'dict': {'example': 'mapping'}, + 'float': 1.0, + 'int': 1337, + 'list': ['a', 'sequence', 'of', 'items'], + 'none': None, + 'str': 'ASCII string', + 'unicode': 'Hello Unicode! \U0001F638', + }, +]) +def test_non_ascii_json_serialization(document): + client = create_client(SimpleMediaResource(document)) + resp = client.simulate_get('/') + assert resp.json == document + + +@pytest.mark.parametrize('media_type', [ + (falcon.MEDIA_MSGPACK), + ('application/msgpack; charset=utf-8'), + ('application/x-msgpack'), +]) +def test_msgpack(media_type): + + class TestResource: + async def on_get(self, req, resp): + resp.content_type = media_type + + # Bytes + resp.media = {b'something': True} + assert (await resp.render_body()) == b'\x81\xc4\tsomething\xc3' + + # Unicode + resp.media = {'something': True} + body = await resp.render_body() + assert body == b'\x81\xa9something\xc3' + + # Ensure that the result is being cached + assert (await resp.render_body()) is body + + client = create_client(TestResource(), handlers={ + 'application/msgpack': media.MessagePackHandler(), + 'application/x-msgpack': media.MessagePackHandler(), + }) + client.simulate_get('/') + + +def test_unknown_media_type(): + class TestResource: + async def on_get(self, req, resp): + resp.content_type = 'nope/json' + resp.media = {'something': True} + + try: + await resp.render_body() + except Exception as ex: + # NOTE(kgriffs): pytest.raises triggers a failed test even + # when the correct error is raises, so we check it like + # this instead. + assert isinstance(ex, errors.HTTPUnsupportedMediaType) + raise + + client = create_client(TestResource()) + result = client.simulate_get('/') + assert result.status_code == 415 + + +def test_default_media_type(): + doc = {'something': True} + + class TestResource: + async def on_get(self, req, resp): + resp.content_type = '' + resp.media = {'something': True} + + body = await resp.render_body() + assert json.loads(body.decode('utf-8')) == doc + assert resp.content_type == 'application/json' + + client = create_client(TestResource()) + result = client.simulate_get('/') + assert result.json == doc + + +def test_mimeparse_edgecases(): + doc = {'something': True} + + class TestResource: + async def on_get(self, req, resp): + resp.content_type = 'application/vnd.something' + with pytest.raises(errors.HTTPUnsupportedMediaType): + resp.media = {'something': False} + await resp.render_body() + + resp.content_type = 'invalid' + with pytest.raises(errors.HTTPUnsupportedMediaType): + resp.media = {'something': False} + await resp.render_body() + + # Clear the content type, shouldn't raise this time + resp.content_type = None + resp.media = doc + + client = create_client(TestResource()) + result = client.simulate_get('/') + assert result.json == doc diff --git a/tests/asgi/test_scheduled_callbacks.py b/tests/asgi/test_scheduled_callbacks.py new file mode 100644 index 000000000..19c94f6b5 --- /dev/null +++ b/tests/asgi/test_scheduled_callbacks.py @@ -0,0 +1,74 @@ +from collections import Counter +import time + +import pytest + +from falcon import testing +from falcon.asgi import App + + +def test_multiple(): + class SomeResource: + def __init__(self): + self.counter = Counter() + + async def on_get(self, req, resp): + async def background_job_async(): + self.counter['backround:on_get:async'] += 1 + + def background_job_sync(): + self.counter['backround:on_get:sync'] += 20 + + resp.schedule(background_job_async) + resp.schedule(background_job_sync) + resp.schedule(background_job_async) + resp.schedule(background_job_sync) + + async def on_post(self, req, resp): + async def background_job_async(): + self.counter['backround:on_get:async'] += 1000 + + def background_job_sync(): + self.counter['backround:on_get:sync'] += 2000 + + resp.schedule(background_job_async) + resp.schedule(background_job_async) + resp.schedule(background_job_sync) + resp.schedule(background_job_sync) + + async def on_put(self, req, resp): + async def background_job_async(): + self.counter['backround:on_get:async'] += 1000 + + c = background_job_async() + + try: + resp.schedule(c) + finally: + await c + + resource = SomeResource() + + app = App() + app.add_route('/', resource) + + client = testing.TestClient(app) + + client.simulate_get() + client.simulate_post() + + time.sleep(0.5) + + assert resource.counter['backround:on_get:async'] == 2002 + assert resource.counter['backround:on_get:sync'] == 4040 + + result = client.simulate_put() + assert result.status_code == 500 + + # NOTE(kgriffs): Remove default handlers so that we can check the raised + # exception is what we expecte. + app._error_handlers.clear() + with pytest.raises(TypeError) as exinfo: + client.simulate_put() + + assert 'coroutine' in str(exinfo.value) diff --git a/tests/asgi/test_scope.py b/tests/asgi/test_scope.py new file mode 100644 index 000000000..f3d603594 --- /dev/null +++ b/tests/asgi/test_scope.py @@ -0,0 +1,218 @@ +import asyncio + +import pytest + +from falcon import testing +from falcon.asgi import App +from falcon.errors import UnsupportedScopeError + + +def test_missing_asgi_version(): + scope = testing.create_scope() + del scope['asgi'] + + resource = _call_with_scope(scope) + + # NOTE(kgriffs): According to the ASGI spec, the version should + # default to "2.0". + assert resource.captured_req.scope['asgi']['version'] == '2.0' + + +@pytest.mark.parametrize('version, supported', [ + ('3.0', True), + ('3.1', True), + ('3.10', True), + ('30.0', False), + ('31.0', False), + ('4.0', False), + ('4.1', False), + ('4.10', False), + ('40.0', False), + ('41.0', False), + ('2.0', False), + ('2.1', False), + ('2.10', False), + (None, False), +]) +def test_supported_asgi_version(version, supported): + scope = testing.create_scope() + + if version: + scope['asgi']['version'] = version + else: + del scope['asgi']['version'] + + if supported: + _call_with_scope(scope) + else: + with pytest.raises(UnsupportedScopeError): + _call_with_scope(scope) + + +@pytest.mark.parametrize('scope_type', ['websocket', 'tubes', 'http3', 'htt']) +def test_unsupported_scope_type(scope_type): + scope = testing.create_scope() + scope['type'] = scope_type + with pytest.raises(UnsupportedScopeError): + _call_with_scope(scope) + + +@pytest.mark.parametrize('spec_version, supported', [ + ('0.0', False), + ('1.0', False), + ('11.0', False), + ('2.0', True), + ('2.1', True), + ('2.10', True), + ('20.0', False), + ('22.0', False), + ('3.0', False), + ('3.1', False), + ('30.0', False), +]) +def test_supported_http_spec(spec_version, supported): + scope = testing.create_scope() + scope['asgi']['spec_version'] = spec_version + + if supported: + _call_with_scope(scope) + else: + with pytest.raises(UnsupportedScopeError): + _call_with_scope(scope) + + +def test_lifespan_scope_default_version(): + app = App() + + resource = testing.SimpleTestResourceAsync() + + app.add_route('/', resource) + + shutting_down = asyncio.Condition() + req_event_emitter = testing.ASGILifespanEventEmitter(shutting_down) + resp_event_collector = testing.ASGIResponseEventCollector() + + scope = {'type': 'lifespan'} + + async def t(): + t = asyncio.get_event_loop().create_task( + app(scope, req_event_emitter, resp_event_collector) + ) + + # NOTE(kgriffs): Yield to the lifespan task above + await asyncio.sleep(0.001) + + async with shutting_down: + shutting_down.notify() + + await t + + testing.invoke_coroutine_sync(t) + + assert not resource.called + + +@pytest.mark.parametrize('spec_version, supported', [ + ('0.0', False), + ('1.0', True), + ('1.1', True), + ('1.10', True), + ('2.0', True), + ('2.1', True), + ('2.10', True), + ('3.0', False), + ('4.0', False), + ('11.0', False), + ('22.0', False), +]) +def test_lifespan_scope_version(spec_version, supported): + app = App() + + shutting_down = asyncio.Condition() + req_event_emitter = testing.ASGILifespanEventEmitter(shutting_down) + resp_event_collector = testing.ASGIResponseEventCollector() + + scope = { + 'type': 'lifespan', + 'asgi': {'spec_version': spec_version, 'version': '3.0'} + } + + if not supported: + with pytest.raises(UnsupportedScopeError): + testing.invoke_coroutine_sync( + app.__call__, scope, req_event_emitter, resp_event_collector + ) + + return + + async def t(): + t = asyncio.get_event_loop().create_task( + app(scope, req_event_emitter, resp_event_collector) + ) + + # NOTE(kgriffs): Yield to the lifespan task above + await asyncio.sleep(0.001) + + async with shutting_down: + shutting_down.notify() + + await t + + testing.invoke_coroutine_sync(t) + + +def test_query_string_values(): + with pytest.raises(ValueError): + testing.create_scope(query_string='?catsup=y') + + with pytest.raises(ValueError): + testing.create_scope(query_string='?') + + for qs in ('', None): + scope = testing.create_scope(query_string=qs) + assert scope['query_string'] == b'' + + resource = _call_with_scope(scope) + assert resource.captured_req.query_string == '' + + qs = 'a=1&b=2&c=%3E%20%3C' + scope = testing.create_scope(query_string=qs) + assert scope['query_string'] == qs.encode() + + resource = _call_with_scope(scope) + assert resource.captured_req.query_string == qs + + +@pytest.mark.parametrize('scheme, valid', [ + ('http', True), + ('https', True), + ('htt', False), + ('http:', False), + ('https:', False), + ('ftp', False), + ('gopher', False), +]) +def test_scheme(scheme, valid): + if valid: + testing.create_scope(scheme=scheme) + else: + with pytest.raises(ValueError): + testing.create_scope(scheme=scheme) + + +def _call_with_scope(scope): + app = App() + + resource = testing.SimpleTestResourceAsync() + + app.add_route('/', resource) + + req_event_emitter = testing.ASGIRequestEventEmitter() + resp_event_collector = testing.ASGIResponseEventCollector() + + testing.invoke_coroutine_sync( + app.__call__, scope, req_event_emitter, resp_event_collector + ) + + assert resource.called + return resource diff --git a/tests/asgi/test_sse.py b/tests/asgi/test_sse.py new file mode 100644 index 000000000..116a970e0 --- /dev/null +++ b/tests/asgi/test_sse.py @@ -0,0 +1,167 @@ +import pytest + +from falcon import testing +from falcon.asgi import App, SSEvent + + +def test_no_events(): + + class Emitter: + def __aiter__(self): + return self + + async def __anext__(self): + raise StopAsyncIteration + + class SomeResource: + async def on_get(self, req, resp): + self._called = True + resp.sse = Emitter() + + resource = SomeResource() + + app = App() + app.add_route('/', resource) + + client = testing.TestClient(app) + client.simulate_get() + + assert resource._called + + +def test_single_event(): + class SomeResource: + async def on_get(self, req, resp): + async def emitter(): + yield + + resp.sse = emitter() + + async def on_post(self, req, resp): + async def emitter(): + yield SSEvent() + + resp.sse = emitter() + + resource = SomeResource() + + app = App() + app.add_route('/', resource) + + client = testing.TestClient(app) + + result = client.simulate_get() + assert result.text == ': ping\n\n' + + result = client.simulate_post() + assert result.text == ': ping\n\n' + + +def test_multiple_events(): + class SomeResource: + async def on_get(self, req, resp): + async def emitter(): + yield SSEvent(data=b'ketchup') + yield SSEvent(data=b'mustard', event='condiment') + yield SSEvent(data=b'mayo', event='condiment', event_id='1234') + yield SSEvent(data=b'onions', event='topping', event_id='5678', retry=100) + yield SSEvent(text='guacamole \u1F951', retry=100, comment='Serve with chips.') + yield SSEvent(json={'condiment': 'salsa'}, retry=100) + + resp.sse = emitter() + + resource = SomeResource() + + app = App() + app.add_route('/', resource) + + client = testing.TestClient(app) + + result = client.simulate_get() + assert result.text == ( + 'data: ketchup\n' + '\n' + 'event: condiment\n' + 'data: mustard\n' + '\n' + 'event: condiment\n' + 'id: 1234\n' + 'data: mayo\n' + '\n' + 'event: topping\n' + 'id: 5678\n' + 'retry: 100\n' + 'data: onions\n' + '\n' + ': Serve with chips.\n' + 'retry: 100\n' + 'data: guacamole \u1F951\n' + '\n' + 'retry: 100\n' + 'data: {"condiment": "salsa"}\n' + '\n' + ) + + +def test_invalid_event_values(): + with pytest.raises(TypeError): + SSEvent(data='notbytes') + + with pytest.raises(TypeError): + SSEvent(data=12345) + + with pytest.raises(TypeError): + SSEvent(text=b'notbytes') + + with pytest.raises(TypeError): + SSEvent(text=23455) + + with pytest.raises(TypeError): + SSEvent(json=set()).serialize() + + with pytest.raises(TypeError): + SSEvent(event=b'name') + + with pytest.raises(TypeError): + SSEvent(event=1234) + + with pytest.raises(TypeError): + SSEvent(event_id=b'idbytes') + + with pytest.raises(TypeError): + SSEvent(event_id=52085) + + with pytest.raises(TypeError): + SSEvent(retry='5808.25') + + with pytest.raises(TypeError): + SSEvent(retry=5808.25) + + with pytest.raises(TypeError): + SSEvent(comment=b'somebytes') + + with pytest.raises(TypeError): + SSEvent(comment=1234) + + +def test_non_iterable(): + class SomeResource: + async def on_get(self, req, resp): + async def emitter(): + yield + + resp.sse = emitter + + resource = SomeResource() + + app = App() + app.add_route('/', resource) + + client = testing.TestClient(app) + + with pytest.raises(TypeError): + client.simulate_get() + + +# TODO: Test with uvicorn +# TODO: Test in browser with JavaScript diff --git a/tests/asgi/test_sync.py b/tests/asgi/test_sync.py new file mode 100644 index 000000000..e4b58b091 --- /dev/null +++ b/tests/asgi/test_sync.py @@ -0,0 +1,104 @@ +import asyncio +import time + +import pytest + +from falcon import testing +from falcon.asgi import App +import falcon.util + + +def test_sync_helpers(): + safely_values = [] + unsafely_values = [] + shirley_values = [] + + class SomeResource: + async def on_get(self, req, resp): + safely_coroutine_objects = [] + unsafely_coroutine_objects = [] + shirley_coroutine_objects = [] + + def callme_safely(a, b, c=None): + # NOTE(kgriffs): Sleep to prove that there isn't another + # instance running in parallel that is able to race ahead. + time.sleep(0.001) + safely_values.append((a, b, c)) + pass + + def callme_unsafely(a, b, c=None): + time.sleep(0.01) + unsafely_values.append((a, b, c)) + pass + + def callme_shirley(a=42, b=None): + time.sleep(0.01) + v = (a, b) + shirley_values.append(v) + + # NOTE(kgriffs): Test that returning values works as expected + return v + + # NOTE(kgriffs): Test setting threadsafe=True explicitly + cmus = falcon.util.wrap_sync_to_async(callme_unsafely, threadsafe=True) + cms = falcon.util.wrap_sync_to_async(callme_safely, threadsafe=False) + + loop = falcon.util.get_loop() + + # NOTE(kgriffs): create_task() is used here, so that the coroutines + # are scheduled immediately in the order created; under Python + # 3.6, asyncio.gather() does not seem to always schedule + # them in order, so we do it this way to make it predictable. + for i in range(1000): + safely_coroutine_objects.append( + loop.create_task(cms(i, i + 1, c=i + 2)) + ) + unsafely_coroutine_objects.append( + loop.create_task(cmus(i, i + 1, c=i + 2)) + ) + shirley_coroutine_objects.append( + loop.create_task(falcon.util.sync_to_async(callme_shirley, 24, b=i)) + ) + + await asyncio.gather( + *( + safely_coroutine_objects + + unsafely_coroutine_objects + + shirley_coroutine_objects + ) + ) + + assert (42, None) == await falcon.util.sync_to_async(callme_shirley) + assert (1, 2) == await falcon.util.sync_to_async(callme_shirley, 1, 2) + assert (3, 4) == await falcon.util.sync_to_async(callme_shirley, 3, b=4) + + assert (5, None) == await falcon.util.wrap_sync_to_async(callme_shirley)(5) + assert (42, 6) == await falcon.util.wrap_sync_to_async( + callme_shirley, threadsafe=True)(b=6) + + with pytest.raises(TypeError): + await falcon.util.sync_to_async(callme_shirley, -1, bogus=-1) + + resource = SomeResource() + + app = App() + app.add_route('/', resource) + + client = testing.TestClient(app) + + result = client.simulate_get() + assert result.status_code == 200 + + assert len(safely_values) == 1000 + for i, val in enumerate(safely_values): + assert val == (i, i + 1, i + 2) + + assert len(unsafely_values) == 1000 + assert any( + val != (i, i + 1, i + 2) + for i, val in enumerate(unsafely_values) + ) + + for i, val in enumerate(shirley_values): + assert val[0] in {24, 42, 1, 5, 3} + assert val[1] is None or (0 <= val[1] < 1000) diff --git a/tests/conftest.py b/tests/conftest.py index 361b74cae..268155380 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,16 @@ import falcon +@pytest.fixture(params=[True, False]) +def asgi(request): + is_asgi = request.param + + if is_asgi and not falcon.ASGI_SUPPORTED: + pytest.skip('ASGI requires Python 3.6+') + + return is_asgi + + # NOTE(kgriffs): Some modules actually run a wsgiref server, so # to ensure we reset the detection for the other modules, we just # run this fixture before each one is tested. @@ -16,3 +26,11 @@ def pytest_configure(config): plugin = config.pluginmanager.getplugin('mypy') if plugin: plugin.mypy_argv.append('--ignore-missing-imports') + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_protocol(item, nextitem): + if hasattr(item, 'cls') and item.cls: + item.cls._item = item + + yield diff --git a/tests/dump_asgi.py b/tests/dump_asgi.py new file mode 100644 index 000000000..d96562d57 --- /dev/null +++ b/tests/dump_asgi.py @@ -0,0 +1,23 @@ +import asyncio +import time + + +async def _say_hi(): + print(f'[{time.time()}] Hi!') + + +async def app(scope, receive, send): + await send({ + 'type': 'http.response.start', + 'status': 200, + 'headers': [ + [b'content-type', b'application/json'], + ] + }) + await send({ + 'type': 'http.response.body', + 'body': f'[{time.time()}] Hello world!'.encode(), + }) + + loop = asyncio.get_event_loop() + loop.create_task(_say_hi()) diff --git a/tests/test_after_hooks.py b/tests/test_after_hooks.py index e986ae67e..bef116e35 100644 --- a/tests/test_after_hooks.py +++ b/tests/test_after_hooks.py @@ -6,6 +6,8 @@ import falcon from falcon import testing +from _util import create_app, create_resp # NOQA + # -------------------------------------------------------------------- # Fixtures @@ -18,10 +20,10 @@ def wrapped_resource_aware(): @pytest.fixture -def client(): - app = falcon.App() +def client(asgi): + app = create_app(asgi) - resource = WrappedRespondersResource() + resource = WrappedRespondersResourceAsync() if asgi else WrappedRespondersResource() app.add_route('/', resource) return testing.TestClient(app) @@ -47,6 +49,10 @@ def serialize_body(req, resp, resource): resp.body = 'Nothing to see here. Move along.' +async def serialize_body_async(*args): + return serialize_body(*args) + + def fluffiness(req, resp, resource, animal=''): assert resource @@ -123,6 +129,25 @@ def on_post(self, req, resp): pass +class WrappedRespondersResourceAsync: + + @falcon.after(serialize_body_async) + @falcon.after(validate_output) + async def on_get(self, req, resp): + self.req = req + self.resp = resp + + @falcon.after(serialize_body_async) + async def on_put(self, req, resp): + self.req = req + self.resp = resp + resp.body = {'animal': 'falcon'} + + @falcon.after(Smartness()) + async def on_post(self, req, resp): + pass + + @falcon.after(cuteness, 'fluffy', postfix=' and innocent') @falcon.after(fluffiness, 'kitten') class WrappedClassResource: @@ -158,6 +183,13 @@ def on_get(self, req, resp, field1, field2): self.fields = (field1, field2) +class ClassResourceWithURIFieldsAsync: + + @falcon.after(fluffiness_in_the_head, 'fluffy') + async def on_get(self, req, resp, field1, field2): + self.fields = (field1, field2) + + class ClassResourceWithURIFieldsChild(ClassResourceWithURIFields): def on_get(self, req, resp, field1, field2): @@ -244,6 +276,30 @@ def test_resource_with_uri_fields(client, resource): assert resource.fields == ('82074', '58927') +def test_resource_with_uri_fields_async(): + app = create_app(asgi=True) + + resource = ClassResourceWithURIFieldsAsync() + app.add_route('/{field1}/{field2}', resource) + + result = testing.simulate_get(app, '/a/b') + + assert result.status_code == 200 + assert result.headers['X-Fluffiness'] == 'fluffy' + assert resource.fields == ('a', 'b') + + async def test_direct(): + resource = ClassResourceWithURIFieldsAsync() + + req = testing.create_asgi_req() + resp = create_resp(True) + + await resource.on_get(req, resp, '1', '2') + assert resource.fields == ('1', '2') + + testing.invoke_coroutine_sync(test_direct) + + @pytest.mark.parametrize( 'resource', [ diff --git a/tests/test_before_hooks.py b/tests/test_before_hooks.py index ff9b8d3b1..02a7c9417 100644 --- a/tests/test_before_hooks.py +++ b/tests/test_before_hooks.py @@ -7,6 +7,8 @@ import falcon import falcon.testing as testing +from _util import create_app, create_resp, disable_asgi_non_coroutine_wrapping # NOQA + def validate(req, resp, resource, params): assert resource @@ -23,6 +25,10 @@ def validate_param(req, resp, resource, params, param_name, maxval=100): raise falcon.HTTPBadRequest('Out of Range', msg) +async def validate_param_async(*args, **kwargs): + validate_param(*args, **kwargs) + + class ResourceAwareValidateParam: def __call__(self, req, resp, resource, params): assert resource @@ -41,11 +47,20 @@ def validate_field(req, resp, resource, params, field_name='test'): def parse_body(req, resp, resource, params): assert resource - length = req.content_length or 0 - if length != 0: + length = req.content_length + if length: params['doc'] = json.load(io.TextIOWrapper(req.bounded_stream, 'utf-8')) +async def parse_body_async(req, resp, resource, params): + assert resource + + length = req.content_length + if length: + data = await req.bounded_stream.read() + params['doc'] = json.loads(data.decode('utf-8')) + + def bunnies(req, resp, resource, params): assert resource params['bunnies'] = 'fuzzy' @@ -92,11 +107,9 @@ def things_in_the_head(header, value, req, resp, resource, params): class WrappedRespondersResource: @falcon.before(validate_param, 'limit', 100) - @falcon.before(parse_body) - def on_get(self, req, resp, doc): + def on_get(self, req, resp): self.req = req self.resp = resp - self.doc = doc @falcon.before(validate) def on_put(self, req, resp): @@ -115,6 +128,16 @@ def on_put(self, req, resp): super(WrappedRespondersResourceChild, self).on_put(req, resp) +class WrappedRespondersBodyParserResource: + + @falcon.before(validate_param, 'limit', 100) + @falcon.before(parse_body) + def on_get(self, req, resp, doc=None): + self.req = req + self.resp = resp + self.doc = doc + + @falcon.before(bunnies) class WrappedClassResource: @@ -239,8 +262,8 @@ def resource(): @pytest.fixture -def client(resource): - app = falcon.App() +def client(asgi, request, resource): + app = create_app(asgi) app.add_route('/', resource) return testing.TestClient(app) @@ -302,9 +325,62 @@ def test_field_validator(client, resource): assert result.status_code == 400 -def test_parser(client, resource): - client.simulate_get('/', body=json.dumps({'animal': 'falcon'})) - assert resource.doc == {'animal': 'falcon'} +@pytest.mark.parametrize( + 'body,doc', + [ + (json.dumps({'animal': 'falcon'}), {'animal': 'falcon'}), + ('{}', {}), + ('', None), + (None, None), + ] +) +def test_parser_sync(body, doc): + app = falcon.API() + + resource = WrappedRespondersBodyParserResource() + app.add_route('/', resource) + + testing.simulate_get(app, '/', body=body) + assert resource.doc == doc + + +@pytest.mark.parametrize( + 'body,doc', + [ + (json.dumps({'animal': 'falcon'}), {'animal': 'falcon'}), + ('{}', {}), + ('', None), + (None, None), + ] +) +def test_parser_async(body, doc): + with disable_asgi_non_coroutine_wrapping(): + class WrappedRespondersBodyParserAsyncResource: + @falcon.before(validate_param_async, 'limit', 100) + @falcon.before(parse_body_async) + async def on_get(self, req, resp, doc=None): + self.req = req + self.resp = resp + self.doc = doc + + app = create_app(asgi=True) + + resource = WrappedRespondersBodyParserAsyncResource() + app.add_route('/', resource) + + testing.simulate_get(app, '/', body=body) + assert resource.doc == doc + + async def test_direct(): + resource = WrappedRespondersBodyParserAsyncResource() + + req = testing.create_asgi_req() + resp = create_resp(True) + + await resource.on_get(req, resp, doc) + assert resource.doc == doc + + testing.invoke_coroutine_sync(test_direct) def test_wrapped_resource(client, wrapped_resource): @@ -402,11 +478,25 @@ def on_post_collection(self, req, resp): resp.status = falcon.HTTP_CREATED -@pytest.fixture -def app_client(): - items = PiggybackingCollection() +class PiggybackingCollectionAsync(PiggybackingCollection): + + @falcon.before(header_hook) + async def on_post_collection(self, req, resp): + self._sequence += 1 + itemid = self._sequence + + doc = await req.get_media() + + self._items[itemid] = dict(doc, itemid=itemid) + resp.location = '/items/{}'.format(itemid) + resp.status = falcon.HTTP_CREATED + + +@pytest.fixture(params=[True, False]) +def app_client(request): + items = PiggybackingCollectionAsync() if request.param else PiggybackingCollection() - app = falcon.App() + app = create_app(asgi=request.param) app.add_route('/items', items, suffix='collection') app.add_route('/items/{itemid:int}', items) diff --git a/tests/test_boundedstream.py b/tests/test_boundedstream.py index d28ce61a4..f2a8cb7b9 100644 --- a/tests/test_boundedstream.py +++ b/tests/test_boundedstream.py @@ -10,8 +10,8 @@ def bounded_stream(): return BoundedStream(io.BytesIO(), 1024) -def test_not_writeable(bounded_stream): - assert not bounded_stream.writeable() +def test_not_writable(bounded_stream): + assert not bounded_stream.writable() with pytest.raises(IOError): bounded_stream.write(b'something something') diff --git a/tests/test_cmd_print_api.py b/tests/test_cmd_print_api.py index 085fec03d..d4a11a30f 100644 --- a/tests/test_cmd_print_api.py +++ b/tests/test_cmd_print_api.py @@ -1,13 +1,11 @@ import io -from falcon import App +import pytest + from falcon.cmd import print_routes from falcon.testing import redirected -try: - import cython -except ImportError: - cython = None +from _util import create_app # NOQA class DummyResource: @@ -17,16 +15,27 @@ def on_get(self, req, resp): resp.status = '200 OK' -_api = App() -_api.add_route('/test', DummyResource()) +class DummyResourceAsync: + + async def on_get(self, req, resp): + resp.body = 'Test\n' + resp.status = '200 OK' + +@pytest.fixture +def app(asgi): + app = create_app(asgi) + app.add_route('/test', DummyResourceAsync() if asgi else DummyResource()) -def test_traverse_with_verbose(): + return app + + +def test_traverse_with_verbose(app): """Ensure traverse() finds the proper routes and outputs verbose info.""" output = io.StringIO() with redirected(stdout=output): - print_routes.traverse(_api._router._roots, verbose=True) + print_routes.traverse(app._router._roots, verbose=True) route, get_info, options_info = output.getvalue().strip().split('\n') assert '-> /test' == route @@ -37,23 +46,26 @@ def test_traverse_with_verbose(): get_info, options_info = options_info, get_info assert options_info.startswith('-->OPTIONS') - if cython: - assert options_info.endswith('[unknown file]') - else: - assert 'falcon/responders.py:' in options_info + assert 'falcon/responders.py:' in options_info assert get_info.startswith('-->GET') - # NOTE(vytas): This builds upon the fact that on_get is defined on line 14 - # in this file. Adjust the test if the said responder is relocated, or just - # check for any number if this becomes too painful to maintain. - assert get_info.endswith('tests/test_cmd_print_api.py:15') + + # NOTE(vytas): This builds upon the fact that on_get is defined on line + # 18 or 25 (in the case of DummyResourceAsync) in the present file. + # Adjust the test if the said responder is relocated, or just check for + # any number if this becomes too painful to maintain. + + assert ( + get_info.endswith('tests/test_cmd_print_api.py:13') or + get_info.endswith('tests/test_cmd_print_api.py:20') + ) -def test_traverse(): +def test_traverse(app): """Ensure traverse() finds the proper routes.""" output = io.StringIO() with redirected(stdout=output): - print_routes.traverse(_api._router._roots, verbose=False) + print_routes.traverse(app._router._roots, verbose=False) route = output.getvalue().strip() assert '-> /test' == route diff --git a/tests/test_cookies.py b/tests/test_cookies.py index 30d60fe58..7877d66a2 100644 --- a/tests/test_cookies.py +++ b/tests/test_cookies.py @@ -8,6 +8,8 @@ import falcon.testing as testing from falcon.util import http_date_to_dt, TimezoneGMT +from _util import create_app # NOQA + UNICODE_TEST_STRING = 'Unicode_\xc3\xa6\xc3\xb8' @@ -75,9 +77,9 @@ def on_delete(self, req, resp): resp.set_cookie('baz', 'foo', same_site='') -@pytest.fixture() -def client(): - app = falcon.App() +@pytest.fixture +def client(asgi): + app = create_app(asgi) app.add_route('/', CookieResource()) app.add_route('/test-convert', CookieResourceMaxAgeFloatString()) app.add_route('/same-site', CookieResourceSameSite()) diff --git a/tests/test_custom_router.py b/tests/test_custom_router.py index 164d43a33..710cd809b 100644 --- a/tests/test_custom_router.py +++ b/tests/test_custom_router.py @@ -3,8 +3,11 @@ import falcon from falcon import testing +from _util import create_app # NOQA -def test_custom_router_add_route_should_be_used(): + +@pytest.mark.parametrize('asgi', [True, False]) +def test_custom_router_add_route_should_be_used(asgi): check = [] class CustomRouter: @@ -14,17 +17,21 @@ def add_route(self, uri_template, *args, **kwargs): def find(self, uri): pass - app = falcon.App(router=CustomRouter()) + app = create_app(asgi=asgi, router=CustomRouter()) app.add_route('/test', 'resource') assert len(check) == 1 assert '/test' in check -def test_custom_router_find_should_be_used(): - - def resource(req, resp, **kwargs): - resp.body = '{{"uri_template": "{0}"}}'.format(req.uri_template) +@pytest.mark.parametrize('asgi', [True, False]) +def test_custom_router_find_should_be_used(asgi): + if asgi: + async def resource(req, resp, **kwargs): + resp.body = '{{"uri_template": "{0}"}}'.format(req.uri_template) + else: + def resource(req, resp, **kwargs): + resp.body = '{{"uri_template": "{0}"}}'.format(req.uri_template) class CustomRouter: def __init__(self): @@ -47,7 +54,7 @@ def find(self, uri, req=None): return None router = CustomRouter() - app = falcon.App(router=router) + app = create_app(asgi=asgi, router=router) client = testing.TestClient(app) response = client.simulate_request(path='/test/42') @@ -67,7 +74,8 @@ def find(self, uri, req=None): assert router.reached_backwards_compat -def test_can_pass_additional_params_to_add_route(): +@pytest.mark.parametrize('asgi', [True, False]) +def test_can_pass_additional_params_to_add_route(asgi): check = [] @@ -80,7 +88,7 @@ def add_route(self, uri_template, resource, **kwargs): def find(self, uri): pass - app = falcon.App(router=CustomRouter()) + app = create_app(asgi=asgi, router=CustomRouter()) app.add_route('/test', 'resource', name='my-url-name') assert len(check) == 1 @@ -93,9 +101,14 @@ def find(self, uri): app.add_route('/test', 'resource', 'xarg1', 'xarg2') -def test_custom_router_takes_req_positional_argument(): - def responder(req, resp): - resp.body = 'OK' +@pytest.mark.parametrize('asgi', [True, False]) +def test_custom_router_takes_req_positional_argument(asgi): + if asgi: + async def responder(req, resp): + resp.body = 'OK' + else: + def responder(req, resp): + resp.body = 'OK' class CustomRouter: def find(self, uri, req): @@ -103,15 +116,20 @@ def find(self, uri, req): return responder, {'GET': responder}, {}, None router = CustomRouter() - app = falcon.App(router=router) + app = create_app(asgi=asgi, router=router) client = testing.TestClient(app) response = client.simulate_request(path='/test') assert response.content == b'OK' -def test_custom_router_takes_req_keyword_argument(): - def responder(req, resp): - resp.body = 'OK' +@pytest.mark.parametrize('asgi', [True, False]) +def test_custom_router_takes_req_keyword_argument(asgi): + if asgi: + async def responder(req, resp): + resp.body = 'OK' + else: + def responder(req, resp): + resp.body = 'OK' class CustomRouter: def find(self, uri, req=None): @@ -119,7 +137,7 @@ def find(self, uri, req=None): return responder, {'GET': responder}, {}, None router = CustomRouter() - app = falcon.App(router=router) + app = create_app(asgi=asgi, router=router) client = testing.TestClient(app) response = client.simulate_request(path='/test') assert response.content == b'OK' diff --git a/tests/test_default_router.py b/tests/test_default_router.py index 8881fc98b..653ee9163 100644 --- a/tests/test_default_router.py +++ b/tests/test_default_router.py @@ -2,14 +2,14 @@ import pytest -import falcon from falcon import testing from falcon.routing import DefaultRouter +from _util import create_app # NOQA -@pytest.fixture -def client(): - return testing.TestClient(falcon.App()) + +def client(asgi): + return testing.TestClient(create_app(asgi)) @pytest.fixture @@ -244,13 +244,14 @@ def test_user_regression_special_chars(uri_template, path, expected_params): # ===================================================================== +@pytest.mark.parametrize('asgi', [True, False]) @pytest.mark.parametrize('uri_template', [ {}, set(), object() ]) -def test_not_str(uri_template): - app = falcon.App() +def test_not_str(asgi, uri_template): + app = create_app(asgi) with pytest.raises(TypeError): app.add_route(uri_template, ResourceWithId(-1)) diff --git a/tests/test_error_handlers.py b/tests/test_error_handlers.py index 35aac39fe..ab8399565 100644 --- a/tests/test_error_handlers.py +++ b/tests/test_error_handlers.py @@ -1,7 +1,9 @@ import pytest import falcon -from falcon import constants, testing +from falcon import ASGI_SUPPORTED, constants, testing + +from _util import create_app, disable_asgi_non_coroutine_wrapping # NOQA def capture_error(req, resp, ex, params): @@ -9,6 +11,10 @@ def capture_error(req, resp, ex, params): resp.body = 'error: %s' % str(ex) +async def capture_error_async(*args): + capture_error(*args) + + def handle_error_first(req, resp, ex, params): resp.status = falcon.HTTP_200 resp.body = 'first error handler' @@ -43,8 +49,8 @@ def on_delete(self, req, resp): @pytest.fixture -def client(): - app = falcon.App() +def client(asgi): + app = create_app(asgi) app.add_route('/', ErroredClassResource()) return testing.TestClient(app) @@ -73,6 +79,38 @@ def test_uncaught_python_error(self, client, assert result.headers['content-type'] == resp_content_type assert result.text.startswith(resp_start) + def test_caught_error_async(self, asgi): + if not asgi: + pytest.skip('Test only applies to ASGI') + + if not ASGI_SUPPORTED: + pytest.skip('ASGI requires Python 3.6+') + + import falcon.asgi + app = falcon.asgi.App() + app.add_route('/', ErroredClassResource()) + app.add_error_handler(Exception, capture_error_async) + + client = testing.TestClient(app) + + result = client.simulate_get() + assert result.text == 'error: Plain Exception' + + result = client.simulate_head() + assert result.status_code == 723 + assert not result.content + + def test_uncaught_error(self, client): + client.app._error_handlers.clear() + client.app.add_error_handler(CustomException, capture_error) + with pytest.raises(Exception): + client.simulate_get() + + def test_uncaught_error_else(self, client): + client.app._error_handlers.clear() + with pytest.raises(Exception): + client.simulate_get() + def test_converted_error(self, client): client.app.add_error_handler(CustomException) @@ -144,7 +182,7 @@ def exception_list_generator(): NotImplemented, 'Hello, world!', frozenset([ZeroDivisionError, int, NotImplementedError]), - iter([float, float]), + [float, float], ]) def test_invalid_add_exception_handler_input(self, client, exceptions): with pytest.raises(TypeError): @@ -172,3 +210,29 @@ def legacy_handler3(err, rq, rs, prms): client.simulate_delete() client.simulate_get() client.simulate_head() + + def test_handler_signature_shim_asgi(self): + def check_args(ex, req, resp): + assert isinstance(ex, BaseException) + assert isinstance(req, falcon.Request) + assert isinstance(resp, falcon.Response) + + async def legacy_handler(err, rq, rs, prms): + check_args(err, rq, rs) + + app = create_app(True) + app.add_route('/', ErroredClassResource()) + app.add_error_handler(Exception, legacy_handler) + client = testing.TestClient(app) + + client.simulate_get() + + def test_handler_must_be_coroutine_for_asgi(self): + async def legacy_handler(err, rq, rs, prms): + pass + + app = create_app(True) + + with disable_asgi_non_coroutine_wrapping(): + with pytest.raises(ValueError): + app.add_error_handler(Exception, capture_error) diff --git a/tests/test_headers.py b/tests/test_headers.py index ffda9d0d8..9490b4103 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -6,19 +6,24 @@ import falcon from falcon import testing +from _util import create_app, disable_asgi_non_coroutine_wrapping # NOQA + SAMPLE_BODY = testing.rand_string(0, 128 * 1024) @pytest.fixture -def client(): - app = falcon.App() +def client(asgi): + app = create_app(asgi) return testing.TestClient(app) @pytest.fixture(scope='function') -def cors_client(): - app = falcon.App(cors_enable=True) +def cors_client(asgi): + # NOTE(kgriffs): Disable wrapping to test that built-in middleware does + # not require it (since this will be the case for non-test apps). + with disable_asgi_non_coroutine_wrapping(): + app = create_app(asgi, cors_enable=True) return testing.TestClient(app) @@ -76,6 +81,7 @@ def on_get(self, req, resp): assert resp.get_header('x-client-should-never-see-this') == 'abc' resp.delete_header('x-client-should-never-see-this') + self.req = req self.resp = resp def on_head(self, req, resp): @@ -385,24 +391,26 @@ def test_default_media_type(self, client): resource = testing.SimpleTestResource(body='Hello world!') self._check_header(client, resource, 'Content-Type', falcon.DEFAULT_MEDIA_TYPE) + @pytest.mark.parametrize('asgi', [True, False]) @pytest.mark.parametrize('content_type,body', [ ('text/plain; charset=UTF-8', 'Hello Unicode! \U0001F638'), # NOTE(kgriffs): This only works because the client defaults to # ISO-8859-1 IFF the media type is 'text'. ('text/plain', 'Hello ISO-8859-1!'), ]) - def test_override_default_media_type(self, client, content_type, body): - client.app = falcon.App(media_type=content_type) + def test_override_default_media_type(self, asgi, client, content_type, body): + client.app = create_app(asgi=asgi, media_type=content_type) client.app.add_route('/', testing.SimpleTestResource(body=body)) result = client.simulate_get() assert result.text == body assert result.headers['Content-Type'] == content_type - def test_override_default_media_type_missing_encoding(self, client): + @pytest.mark.parametrize('asgi', [True, False]) + def test_override_default_media_type_missing_encoding(self, asgi, client): body = '{"msg": "Hello Unicode! \U0001F638"}' - client.app = falcon.App(media_type='application/json') + client.app = create_app(asgi=asgi, media_type='application/json') client.app.add_route('/', testing.SimpleTestResource(body=body)) result = client.simulate_get() @@ -782,6 +790,23 @@ def test_enabled_cors_handles_preflighting_no_headers_in_req(self, cors_client): assert result.headers['Access-Control-Allow-Headers'] == '*' assert result.headers['Access-Control-Max-Age'] == '86400' # 24 hours in seconds + def test_request_multiple_header(self, client): + resource = HeaderHelpersResource() + client.app.add_route('/', resource) + + client.simulate_request(headers=[ + # Singletone header; last one wins + ('Content-Type', 'text/plain'), + ('Content-Type', 'image/jpeg'), + + # Should be concatenated + ('X-Thing', '1'), + ('X-Thing', '2'), + ]) + + assert resource.req.content_type == 'image/jpeg' + assert resource.req.get_header('X-Thing') == '1,2' + # ---------------------------------------------------------------------- # Helpers # ---------------------------------------------------------------------- diff --git a/tests/test_hello.py b/tests/test_hello.py index dcf90d7ec..2939df105 100644 --- a/tests/test_hello.py +++ b/tests/test_hello.py @@ -147,11 +147,14 @@ def test_body(self, client, path, resource, get_body): assert result.content == resource.sample_utf8 def test_no_body_on_head(self, client): - client.app.add_route('/body', HelloResource('body')) + resource = HelloResource('body') + client.app.add_route('/body', resource) result = client.simulate_head('/body') assert not result.content assert result.status_code == 200 + assert resource.called + assert result.headers['content-length'] == str(len(HelloResource.sample_utf8)) def test_stream_chunked(self, client): resource = HelloResource('stream') diff --git a/tests/test_http_custom_method_routing.py b/tests/test_http_custom_method_routing.py index a3e403cae..06483d71f 100644 --- a/tests/test_http_custom_method_routing.py +++ b/tests/test_http_custom_method_routing.py @@ -2,6 +2,10 @@ import os import wsgiref.validate +try: + import cython +except ImportError: + cython = None import pytest import falcon @@ -9,10 +13,8 @@ import falcon.constants from falcon.routing import util -try: - import cython -except ImportError: - cython = None +from _util import create_app # NOQA + FALCON_CUSTOM_HTTP_METHODS = ['FOO', 'BAR'] @@ -37,10 +39,10 @@ def cleanup_constants(): @pytest.fixture -def custom_http_client(cleanup_constants, resource_things): +def custom_http_client(asgi, request, cleanup_constants, resource_things): falcon.constants.COMBINED_METHODS += FALCON_CUSTOM_HTTP_METHODS - app = falcon.App() + app = create_app(asgi) app.add_route('/things', resource_things) return testing.TestClient(app) @@ -94,8 +96,14 @@ def test_foo(custom_http_client, resource_things): """FOO is a supported method, so returns HTTP_204""" custom_http_client.app.add_route('/things', resource_things) - with pytest.warns(wsgiref.validate.WSGIWarning): - response = custom_http_client.simulate_request(path='/things', method='FOO') + def s(): + return custom_http_client.simulate_request(path='/things', method='FOO') + + if not custom_http_client.app._ASGI: + with pytest.warns(wsgiref.validate.WSGIWarning): + response = s() + else: + response = s() assert 'FOO' in falcon.constants.COMBINED_METHODS assert response.status == falcon.HTTP_204 @@ -106,8 +114,14 @@ def test_bar(custom_http_client, resource_things): """BAR is not supported by ResourceThing""" custom_http_client.app.add_route('/things', resource_things) - with pytest.warns(wsgiref.validate.WSGIWarning): - response = custom_http_client.simulate_request(path='/things', method='BAR') + def s(): + return custom_http_client.simulate_request(path='/things', method='BAR') + + if not custom_http_client.app._ASGI: + with pytest.warns(wsgiref.validate.WSGIWarning): + response = s() + else: + response = s() assert 'BAR' in falcon.constants.COMBINED_METHODS assert response.status == falcon.HTTP_405 diff --git a/tests/test_http_method_routing.py b/tests/test_http_method_routing.py index f8236a12a..6f2e272b0 100644 --- a/tests/test_http_method_routing.py +++ b/tests/test_http_method_routing.py @@ -1,11 +1,12 @@ from functools import wraps -import wsgiref.validate import pytest import falcon import falcon.testing as testing +from _util import create_app # NOQA + # RFC 7231, 5789 methods HTTP_METHODS = [ 'CONNECT', @@ -58,8 +59,8 @@ def resource_get_with_faulty_put(): @pytest.fixture -def client(): - app = falcon.App() +def client(asgi): + app = create_app(asgi) app.add_route('/stonewall', Stonewall()) @@ -278,8 +279,7 @@ def test_bogus_method(self, client, resource_things): client.app.add_route('/things', resource_things) client.app.add_route('/things/{id}/stuff/{sid}', resource_things) - with pytest.warns(wsgiref.validate.WSGIWarning): - response = client.simulate_request(path='/things', method='SETECASTRONOMY') + response = client.simulate_request(path='/things', method='SETECASTRONOMY') assert not resource_things.called assert response.status == falcon.HTTP_400 diff --git a/tests/test_httperror.py b/tests/test_httperror.py index ecdba2759..198edc320 100644 --- a/tests/test_httperror.py +++ b/tests/test_httperror.py @@ -11,10 +11,12 @@ import falcon.testing as testing from falcon.util import json +from _util import create_app # NOQA + @pytest.fixture -def client(): - app = falcon.App() +def client(asgi): + app = create_app(asgi) resource = FaultyResource() app.add_route('/fail', resource) @@ -380,11 +382,17 @@ def _simple_serializer(req, resp, exception): client.app.add_route('/notfound', NotFoundResourceWithBody()) client.app.set_error_serializer(_simple_serializer) + def s(): + return client.simulate_request(path=path, method=method) + if method not in falcon.COMBINED_METHODS: - with pytest.warns(wsgiref.validate.WSGIWarning): - resp = client.simulate_request(path=path, method=method) + if not client.app._ASGI: + with pytest.warns(wsgiref.validate.WSGIWarning): + resp = s() + else: + resp = s() else: - resp = client.simulate_request(path=path, method=method) + resp = s() assert resp.json['title'] assert resp.json['status'] == status @@ -702,8 +710,8 @@ def test_414_with_custom_kwargs(self, client): parsed_body = json.loads(response.content.decode()) assert parsed_body['code'] == code - def test_416(self, client): - client.app = falcon.App() + def test_416(self, client, asgi): + client.app = create_app(asgi) client.app.add_route('/416', RangeNotSatisfiableResource()) response = client.simulate_request(path='/416', headers={'accept': 'text/xml'}) diff --git a/tests/test_httpstatus.py b/tests/test_httpstatus.py index c6061210a..4277724c0 100644 --- a/tests/test_httpstatus.py +++ b/tests/test_httpstatus.py @@ -1,9 +1,27 @@ # -*- coding: utf-8 +import pytest + import falcon from falcon.http_status import HTTPStatus import falcon.testing as testing +from _util import create_app # NOQA + + +@pytest.fixture(params=[True, False]) +def client(request): + app = create_app(asgi=request.param) + app.add_route('/status', TestStatusResource()) + return testing.TestClient(app) + + +@pytest.fixture(params=[True, False]) +def hook_test_client(request): + app = create_app(asgi=request.param) + app.add_route('/status', TestHookResource()) + return testing.TestClient(app) + def before_hook(req, resp, resource, params): raise HTTPStatus(falcon.HTTP_200, @@ -74,106 +92,101 @@ def on_delete(self, req, resp): class TestHTTPStatus: - def test_raise_status_in_before_hook(self): + def test_raise_status_in_before_hook(self, client): """ Make sure we get the 200 raised by before hook """ - app = falcon.App() - app.add_route('/status', TestStatusResource()) - client = testing.TestClient(app) - response = client.simulate_request(path='/status', method='GET') assert response.status == falcon.HTTP_200 assert response.headers['x-failed'] == 'False' assert response.text == 'Pass' - def test_raise_status_in_responder(self): + def test_raise_status_in_responder(self, client): """ Make sure we get the 200 raised by responder """ - app = falcon.App() - app.add_route('/status', TestStatusResource()) - client = testing.TestClient(app) - response = client.simulate_request(path='/status', method='POST') assert response.status == falcon.HTTP_200 assert response.headers['x-failed'] == 'False' assert response.text == 'Pass' - def test_raise_status_runs_after_hooks(self): + def test_raise_status_runs_after_hooks(self, client): """ Make sure after hooks still run """ - app = falcon.App() - app.add_route('/status', TestStatusResource()) - client = testing.TestClient(app) - response = client.simulate_request(path='/status', method='PUT') assert response.status == falcon.HTTP_200 assert response.headers['x-failed'] == 'False' assert response.text == 'Pass' - def test_raise_status_survives_after_hooks(self): + def test_raise_status_survives_after_hooks(self, client): """ Make sure after hook doesn't overwrite our status """ - app = falcon.App() - app.add_route('/status', TestStatusResource()) - client = testing.TestClient(app) - response = client.simulate_request(path='/status', method='DELETE') assert response.status == falcon.HTTP_200 assert response.headers['x-failed'] == 'False' assert response.text == 'Pass' - def test_raise_status_empty_body(self): + def test_raise_status_empty_body(self, client): """ Make sure passing None to body results in empty body """ - app = falcon.App() - app.add_route('/status', TestStatusResource()) - client = testing.TestClient(app) - response = client.simulate_request(path='/status', method='PATCH') assert response.text == '' class TestHTTPStatusWithMiddleware: - def test_raise_status_in_process_request(self): + + def test_raise_status_in_process_request(self, hook_test_client): """ Make sure we can raise status from middleware process request """ + client = hook_test_client + class TestMiddleware: def process_request(self, req, resp): raise HTTPStatus(falcon.HTTP_200, headers={'X-Failed': 'False'}, body='Pass') - app = falcon.App(middleware=TestMiddleware()) - app.add_route('/status', TestHookResource()) - client = testing.TestClient(app) + # NOTE(kgriffs): Test the side-by-side support for dual WSGI and + # ASGI compatibility. + async def process_request_async(self, req, resp): + self.process_request(req, resp) + + client.app.add_middleware(TestMiddleware()) response = client.simulate_request(path='/status', method='GET') assert response.status == falcon.HTTP_200 assert response.headers['x-failed'] == 'False' assert response.text == 'Pass' - def test_raise_status_in_process_resource(self): + def test_raise_status_in_process_resource(self, hook_test_client): """ Make sure we can raise status from middleware process resource """ + client = hook_test_client + class TestMiddleware: def process_resource(self, req, resp, resource, params): raise HTTPStatus(falcon.HTTP_200, headers={'X-Failed': 'False'}, body='Pass') - app = falcon.App(middleware=TestMiddleware()) - app.add_route('/status', TestHookResource()) - client = testing.TestClient(app) + async def process_resource_async(self, *args): + self.process_resource(*args) + + # NOTE(kgriffs): Pass a list to test that add_middleware can handle it + client.app.add_middleware([TestMiddleware()]) response = client.simulate_request(path='/status', method='GET') assert response.status == falcon.HTTP_200 assert response.headers['x-failed'] == 'False' assert response.text == 'Pass' - def test_raise_status_runs_process_response(self): + def test_raise_status_runs_process_response(self, hook_test_client): """ Make sure process_response still runs """ + client = hook_test_client + class TestMiddleware: def process_response(self, req, resp, resource, req_succeeded): resp.status = falcon.HTTP_200 resp.set_header('X-Failed', 'False') resp.body = 'Pass' - app = falcon.App(middleware=TestMiddleware()) - app.add_route('/status', TestHookResource()) - client = testing.TestClient(app) + async def process_response_async(self, *args): + self.process_response(*args) + + # NOTE(kgriffs): Pass a generic iterable to test that add_middleware + # can handle it. + client.app.add_middleware(iter([TestMiddleware()])) response = client.simulate_request(path='/status', method='GET') assert response.status == falcon.HTTP_200 diff --git a/tests/test_media_handlers.py b/tests/test_media_handlers.py index b8ce617ce..6cbb2b355 100644 --- a/tests/test_media_handlers.py +++ b/tests/test_media_handlers.py @@ -8,7 +8,10 @@ import pytest import ujson -from falcon import media +from falcon import ASGI_SUPPORTED, media, testing + +from _util import create_app # NOQA + orjson = None rapidjson = None @@ -81,20 +84,177 @@ @pytest.mark.parametrize('func, body, expected', SERIALIZATION_PARAM_LIST) -def test_serialization(func, body, expected): - JH = media.JSONHandler(dumps=func) +def test_serialization(asgi, func, body, expected): + handler = media.JSONHandler(dumps=func) + + args = (body, b'application/javacript') + + if asgi: + result = testing.invoke_coroutine_sync(handler.serialize_async, *args) + else: + result = handler.serialize(*args) # NOTE(nZac) PyPy and CPython render the final string differently. One # includes spaces and the other doesn't. This replace will normalize that. - assert JH.serialize(body, b'application/javacript').replace(b' ', b'') == expected # noqa + assert result.replace(b' ', b'') == expected # noqa @pytest.mark.parametrize('func, body, expected', DESERIALIZATION_PARAM_LIST) -def test_deserialization(func, body, expected): - JH = media.JSONHandler(loads=func) - - assert JH.deserialize( - io.BytesIO(body), - 'application/javacript', - len(body) - ) == expected +def test_deserialization(asgi, func, body, expected): + handler = media.JSONHandler(loads=func) + + args = ['application/javacript', len(body)] + + if asgi: + if not ASGI_SUPPORTED: + pytest.skip('ASGI requires Python 3.6+') + + from falcon.asgi.stream import BoundedStream + + s = BoundedStream(testing.ASGIRequestEventEmitter(body)) + args.insert(0, s) + + result = testing.invoke_coroutine_sync(handler.deserialize_async, *args) + else: + args.insert(0, io.BytesIO(body)) + result = handler.deserialize(*args) + + assert result == expected + + +def test_deserialization_raises(asgi): + app = create_app(asgi) + + class SuchException(Exception): + pass + + class FaultyHandler(media.BaseHandler): + def deserialize(self, stream, content_type, content_length): + raise SuchException('Wow such error.') + + def deserialize_async(self, stream, content_type, content_length): + raise SuchException('Wow such error.') + + def serialize(self, media, content_type): + raise SuchException('Wow such error.') + + handlers = media.Handlers({'application/json': FaultyHandler()}) + app.req_options.media_handlers = handlers + app.resp_options.media_handlers = handlers + + class Resource: + def on_get(self, req, resp): + resp.media = {} + + def on_post(self, req, resp): + req.media + + class ResourceAsync: + async def on_get(self, req, resp): + resp.media = {} + + async def on_post(self, req, resp): + await req.get_media() + + app.add_route('/', ResourceAsync() if asgi else Resource()) + + # NOTE(kgriffs): Now that we install a default handler for + # Exception, we have to clear them to test the path we want + # to trigger. + # TODO(kgriffs): Since we always add a default error handler + # for Exception, should we take out the checks in the WSGI/ASGI + # callable and just always assume it will be handled? If so, + # it makes testing that the right exception is propagated harder; + # I suppose we'll have to look at what is logged. + app._error_handlers.clear() + + with pytest.raises(SuchException): + testing.simulate_get(app, '/') + + with pytest.raises(SuchException): + testing.simulate_post(app, '/', json={}) + + +def test_sync_methods_not_overridden(asgi): + app = create_app(asgi) + + class FaultyHandler(media.BaseHandler): + pass + + handlers = media.Handlers({'application/json': FaultyHandler()}) + app.req_options.media_handlers = handlers + app.resp_options.media_handlers = handlers + + class Resource: + def on_get(self, req, resp): + resp.media = {} + + def on_post(self, req, resp): + req.media + + class ResourceAsync: + async def on_get(self, req, resp): + resp.media = {} + + async def on_post(self, req, resp): + await req.get_media() + + app.add_route('/', ResourceAsync() if asgi else Resource()) + + result = testing.simulate_get(app, '/') + assert result.status_code == 500 + + result = testing.simulate_post(app, '/', json={}) + assert result.status_code == 500 + + +def test_async_methods_not_overridden(): + app = create_app(asgi=True) + + class SimpleHandler(media.BaseHandler): + def serialize(self, media, content_type): + return json.dumps(media).encode() + + def deserialize(self, stream, content_type, content_length): + return json.load(stream) + + handlers = media.Handlers({'application/json': SimpleHandler()}) + app.req_options.media_handlers = handlers + app.resp_options.media_handlers = handlers + + class ResourceAsync: + async def on_post(self, req, resp): + resp.media = await req.get_media() + + app.add_route('/', ResourceAsync()) + + doc = {'event': 'serialized'} + result = testing.simulate_post(app, '/', json=doc) + assert result.status_code == 200 + assert result.json == doc + + +def test_async_handler_returning_none(): + app = create_app(asgi=True) + + class SimpleHandler(media.BaseHandler): + def serialize(self, media, content_type): + return json.dumps(media).encode() + + def deserialize(self, stream, content_type, content_length): + return None + + handlers = media.Handlers({'application/json': SimpleHandler()}) + app.req_options.media_handlers = handlers + app.resp_options.media_handlers = handlers + + class ResourceAsync: + async def on_post(self, req, resp): + resp.media = [await req.get_media()] + + app.add_route('/', ResourceAsync()) + + doc = {'event': 'serialized'} + result = testing.simulate_post(app, '/', json=doc) + assert result.status_code == 200 + assert result.json == [None] diff --git a/tests/test_media_urlencoded.py b/tests/test_media_urlencoded.py index 615ce366f..5ca61c27e 100644 --- a/tests/test_media_urlencoded.py +++ b/tests/test_media_urlencoded.py @@ -6,6 +6,8 @@ from falcon import media from falcon import testing +from _util import create_app # NOQA + def test_deserialize_empty_form(): handler = media.URLEncodedFormHandler() @@ -28,6 +30,9 @@ def test_urlencoded_form_handler_serialize(data, expected): handler = media.URLEncodedFormHandler() assert handler.serialize(data, falcon.MEDIA_URLENCODED) == expected + value = testing.invoke_coroutine_sync(handler.serialize_async, data, falcon.MEDIA_URLENCODED) + assert value == expected + class MediaMirror: @@ -35,10 +40,16 @@ def on_post(self, req, resp): resp.media = req.media +class MediaMirrorAsync: + + async def on_post(self, req, resp): + resp.media = await req.get_media() + + @pytest.fixture -def client(): - app = falcon.App() - app.add_route('/media', MediaMirror()) +def client(asgi): + app = create_app(asgi) + app.add_route('/media', MediaMirrorAsync() if asgi else MediaMirror()) return testing.TestClient(app) @@ -46,7 +57,11 @@ def test_empty_form(client): resp = client.simulate_post( '/media', headers={'Content-Type': 'application/x-www-form-urlencoded'}) - assert resp.content == b'' + + # TODO(kgriffs): The ASGI side implements the recommended fixes from + # https://github.com/falconry/falcon/issues/1589 so we will need to + # update this assert once the WSGI side has been updated to suit. + assert resp.content == (b'{}' if client.app._ASGI else b'') @pytest.mark.parametrize('body,expected', [ diff --git a/tests/test_middleware.py b/tests/test_middleware.py index e87e634dc..1ca6ad84e 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -4,8 +4,12 @@ import pytest import falcon +import falcon.errors import falcon.testing as testing +from _util import create_app # NOQA + + _EXPECTED_BODY = {'status': 'ok'} context = {'executed_methods': []} # type: ignore @@ -42,6 +46,15 @@ def process_response(self, req, resp, resource, req_succeeded): context['end_time'] = datetime.utcnow() context['req_succeeded'] = req_succeeded + async def process_request_async(self, req, resp): + self.process_request(req, resp) + + async def process_resource_async(self, req, resp, resource, params): + self.process_resource(req, resp, resource, params) + + async def process_response_async(self, req, resp, resource, req_succeeded): + self.process_response(req, resp, resource, req_succeeded) + class TransactionIdMiddleware: @@ -57,6 +70,21 @@ def process_response(self, req, resp, resource, req_succeeded): pass +class TransactionIdMiddlewareAsync: + + def __init__(self): + self._mw = TransactionIdMiddleware() + + async def process_request_async(self, req, resp): + self._mw.process_request(req, resp) + + async def process_resource_async(self, req, resp, resource, params): + self._mw.process_resource(req, resp, resource, params) + + async def process_response_async(self, req, resp, resource, req_succeeded): + self._mw.process_response(req, resp, resource, req_succeeded) + + class ExecutedFirstMiddleware: def process_request(self, req, resp): @@ -152,9 +180,9 @@ def setup_method(self, method): class TestRequestTimeMiddleware(TestMiddleware): - def test_skip_process_resource(self): + def test_skip_process_resource(self, asgi): global context - app = falcon.App(middleware=[RequestTimeMiddleware()]) + app = create_app(asgi, middleware=[RequestTimeMiddleware()]) app.add_route('/', MiddlewareClassResource()) client = testing.TestClient(app) @@ -166,7 +194,7 @@ def test_skip_process_resource(self): assert 'end_time' in context assert not context['req_succeeded'] - def test_add_invalid_middleware(self): + def test_add_invalid_middleware(self, asgi): """Test than an invalid class can not be added as middleware""" class InvalidMiddleware(): def process_request(self, *args): @@ -174,22 +202,24 @@ def process_request(self, *args): mw_list = [RequestTimeMiddleware(), InvalidMiddleware] with pytest.raises(AttributeError): - falcon.App(middleware=mw_list) + create_app(asgi, middleware=mw_list) + mw_list = [RequestTimeMiddleware(), 'InvalidMiddleware'] with pytest.raises(TypeError): - falcon.App(middleware=mw_list) + create_app(asgi, middleware=mw_list) + mw_list = [{'process_request': 90}] with pytest.raises(TypeError): - falcon.App(middleware=mw_list) + create_app(asgi, middleware=mw_list) - def test_response_middleware_raises_exception(self): + def test_response_middleware_raises_exception(self, asgi): """Test that error in response middleware is propagated up""" class RaiseErrorMiddleware: def process_response(self, req, resp, resource): raise Exception('Always fail') - app = falcon.App(middleware=[RaiseErrorMiddleware()]) + app = create_app(asgi, middleware=[RaiseErrorMiddleware()]) app.add_route(TEST_ROUTE, MiddlewareClassResource()) client = testing.TestClient(app) @@ -198,10 +228,10 @@ def process_response(self, req, resp, resource): assert result.status_code == 500 @pytest.mark.parametrize('independent_middleware', [True, False]) - def test_log_get_request(self, independent_middleware): + def test_log_get_request(self, independent_middleware, asgi): """Test that Log middleware is executed""" global context - app = falcon.App(middleware=[RequestTimeMiddleware()], + app = create_app(asgi, middleware=[RequestTimeMiddleware()], independent_middleware=independent_middleware) app.add_route(TEST_ROUTE, MiddlewareClassResource()) @@ -223,10 +253,12 @@ def test_log_get_request(self, independent_middleware): class TestTransactionIdMiddleware(TestMiddleware): - def test_generate_trans_id_with_request(self): + def test_generate_trans_id_with_request(self, asgi): """Test that TransactionIdmiddleware is executed""" global context - app = falcon.App(middleware=TransactionIdMiddleware()) + + middleware = TransactionIdMiddlewareAsync() if asgi else TransactionIdMiddleware() + app = create_app(asgi, middleware=middleware) app.add_route(TEST_ROUTE, MiddlewareClassResource()) client = testing.TestClient(app) @@ -239,8 +271,9 @@ def test_generate_trans_id_with_request(self): class TestSeveralMiddlewares(TestMiddleware): + @pytest.mark.parametrize('independent_middleware', [True, False]) - def test_generate_trans_id_and_time_with_request(self, independent_middleware): + def test_generate_trans_id_and_time_with_request(self, independent_middleware, asgi): # NOTE(kgriffs): We test both so that we can cover the code paths # where only a single middleware method is implemented by a # component. @@ -248,11 +281,21 @@ def test_generate_trans_id_and_time_with_request(self, independent_middleware): cresp = CaptureResponseMiddleware() global context - app = falcon.App(independent_middleware=independent_middleware, - middleware=[TransactionIdMiddleware(), - RequestTimeMiddleware(), - creq, - cresp]) + app = create_app( + asgi, + independent_middleware=independent_middleware, + + # NOTE(kgriffs): Pass as a generic iterable to verify that works. + middleware=iter([ + TransactionIdMiddleware(), + RequestTimeMiddleware(), + ]) + ) + + # NOTE(kgriffs): Add a couple more after the fact to test + # add_middleware(). + app.add_middleware(creq) + app.add_middleware(cresp) app.add_route(TEST_ROUTE, MiddlewareClassResource()) client = testing.TestClient(app) @@ -270,9 +313,9 @@ def test_generate_trans_id_and_time_with_request(self, independent_middleware): assert context['end_time'] >= context['start_time'], \ 'process_response not executed after request' - def test_legacy_middleware_called_with_correct_args(self): + def test_legacy_middleware_called_with_correct_args(self, asgi): global context - app = falcon.App(middleware=[ExecutedFirstMiddleware()]) + app = create_app(asgi, middleware=[ExecutedFirstMiddleware()]) app.add_route(TEST_ROUTE, MiddlewareClassResource()) client = testing.TestClient(app) @@ -281,9 +324,9 @@ def test_legacy_middleware_called_with_correct_args(self): assert isinstance(context['resp'], falcon.Response) assert isinstance(context['resource'], MiddlewareClassResource) - def test_middleware_execution_order(self): + def test_middleware_execution_order(self, asgi): global context - app = falcon.App(independent_middleware=False, + app = create_app(asgi, independent_middleware=False, middleware=[ExecutedFirstMiddleware(), ExecutedLastMiddleware()]) @@ -305,9 +348,9 @@ def test_middleware_execution_order(self): ] assert expectedExecutedMethods == context['executed_methods'] - def test_independent_middleware_execution_order(self): + def test_independent_middleware_execution_order(self, asgi): global context - app = falcon.App(independent_middleware=True, + app = create_app(asgi, independent_middleware=True, middleware=[ExecutedFirstMiddleware(), ExecutedLastMiddleware()]) @@ -329,7 +372,7 @@ def test_independent_middleware_execution_order(self): ] assert expectedExecutedMethods == context['executed_methods'] - def test_multiple_reponse_mw_throw_exception(self): + def test_multiple_reponse_mw_throw_exception(self, asgi): """Test that error in inner middleware leaves""" global context @@ -348,11 +391,11 @@ def process_response(self, req, resp, resource, req_succeeded): context['executed_methods'].append('process_response') context['req_succeeded'].append(req_succeeded) - app = falcon.App(middleware=[ProcessResponseMiddleware(), - RaiseErrorMiddleware(), - ProcessResponseMiddleware(), - RaiseStatusMiddleware(), - ProcessResponseMiddleware()]) + app = create_app(asgi, middleware=[ProcessResponseMiddleware(), + RaiseErrorMiddleware(), + ProcessResponseMiddleware(), + RaiseStatusMiddleware(), + ProcessResponseMiddleware()]) app.add_route(TEST_ROUTE, MiddlewareClassResource()) client = testing.TestClient(app) @@ -365,32 +408,89 @@ def process_response(self, req, resp, resource, req_succeeded): assert context['executed_methods'] == expected_methods assert context['req_succeeded'] == [True, False, False] - def test_inner_mw_throw_exception(self): + def test_inner_mw_throw_exception(self, asgi): """Test that error in inner middleware leaves""" global context + class MyException(Exception): + pass + class RaiseErrorMiddleware: def process_request(self, req, resp): - raise Exception('Always fail') + raise MyException('Always fail') - app = falcon.App(middleware=[TransactionIdMiddleware(), - RequestTimeMiddleware(), - RaiseErrorMiddleware()]) + app = create_app(asgi, middleware=[TransactionIdMiddleware(), + RequestTimeMiddleware(), + RaiseErrorMiddleware()]) + + # NOTE(kgriffs): Now that we install a default handler for + # Exception, we have to clear them to test the path we want + # to trigger with RaiseErrorMiddleware + # TODO(kgriffs): Since we always add a default error handler + # for Exception, should we take out the checks in the WSGI/ASGI + # callable and just always assume it will be handled? If so, + # then we would remove the test here... + app._error_handlers.clear() app.add_route(TEST_ROUTE, MiddlewareClassResource()) client = testing.TestClient(app) - result = client.simulate_request(path=TEST_ROUTE) - assert result.status_code == 500 + with pytest.raises(MyException): + client.simulate_request(path=TEST_ROUTE) # RequestTimeMiddleware process_response should be executed assert 'transaction_id' in context assert 'start_time' in context assert 'mid_time' not in context - assert 'end_time' in context - def test_inner_mw_with_ex_handler_throw_exception(self): + # NOTE(kgriffs): Should not have been added since raising an + # unhandled error skips further processing, including response + # middleware methods. + assert 'end_time' not in context + + def test_inner_mw_throw_exception_while_processing_resp(self, asgi): + """Test that error in inner middleware leaves""" + global context + + class MyException(Exception): + pass + + class RaiseErrorMiddleware: + + def process_response(self, req, resp, resource, req_succeeded): + raise MyException('Always fail') + + app = create_app(asgi, middleware=[TransactionIdMiddleware(), + RequestTimeMiddleware(), + RaiseErrorMiddleware()]) + + # NOTE(kgriffs): Now that we install a default handler for + # Exception, we have to clear them to test the path we want + # to trigger with RaiseErrorMiddleware + # TODO(kgriffs): Since we always add a default error handler + # for Exception, should we take out the checks in the WSGI/ASGI + # callable and just always assume it will be handled? If so, + # then we would remove the test here... + app._error_handlers.clear() + + app.add_route(TEST_ROUTE, MiddlewareClassResource()) + client = testing.TestClient(app) + + with pytest.raises(MyException): + client.simulate_request(path=TEST_ROUTE) + + # RequestTimeMiddleware process_response should be executed + assert 'transaction_id' in context + assert 'start_time' in context + assert 'mid_time' in context + + # NOTE(kgriffs): Should not have been added since raising an + # unhandled error skips further processing, including response + # middleware methods. + assert 'end_time' not in context + + def test_inner_mw_with_ex_handler_throw_exception(self, asgi): """Test that error in inner middleware leaves""" global context @@ -399,9 +499,9 @@ class RaiseErrorMiddleware: def process_request(self, req, resp, resource): raise Exception('Always fail') - app = falcon.App(middleware=[TransactionIdMiddleware(), - RequestTimeMiddleware(), - RaiseErrorMiddleware()]) + app = create_app(asgi, middleware=[TransactionIdMiddleware(), + RequestTimeMiddleware(), + RaiseErrorMiddleware()]) def handler(req, resp, ex, params): context['error_handler'] = True @@ -420,7 +520,7 @@ def handler(req, resp, ex, params): assert 'end_time' in context assert 'error_handler' in context - def test_outer_mw_with_ex_handler_throw_exception(self): + def test_outer_mw_with_ex_handler_throw_exception(self, asgi): """Test that error in inner middleware leaves""" global context @@ -429,9 +529,9 @@ class RaiseErrorMiddleware: def process_request(self, req, resp): raise Exception('Always fail') - app = falcon.App(middleware=[TransactionIdMiddleware(), - RaiseErrorMiddleware(), - RequestTimeMiddleware()]) + app = create_app(asgi, middleware=[TransactionIdMiddleware(), + RaiseErrorMiddleware(), + RequestTimeMiddleware()]) def handler(req, resp, ex, params): context['error_handler'] = True @@ -450,7 +550,7 @@ def handler(req, resp, ex, params): assert 'end_time' in context assert 'error_handler' in context - def test_order_mw_executed_when_exception_in_resp(self): + def test_order_mw_executed_when_exception_in_resp(self, asgi): """Test that error in inner middleware leaves""" global context @@ -459,9 +559,9 @@ class RaiseErrorMiddleware: def process_response(self, req, resp, resource): raise Exception('Always fail') - app = falcon.App(middleware=[ExecutedFirstMiddleware(), - RaiseErrorMiddleware(), - ExecutedLastMiddleware()]) + app = create_app(asgi, middleware=[ExecutedFirstMiddleware(), + RaiseErrorMiddleware(), + ExecutedLastMiddleware()]) def handler(req, resp, ex, params): pass @@ -484,7 +584,7 @@ def handler(req, resp, ex, params): ] assert expectedExecutedMethods == context['executed_methods'] - def test_order_independent_mw_executed_when_exception_in_resp(self): + def test_order_independent_mw_executed_when_exception_in_resp(self, asgi): """Test that error in inner middleware leaves""" global context @@ -493,7 +593,7 @@ class RaiseErrorMiddleware: def process_response(self, req, resp, resource): raise Exception('Always fail') - app = falcon.App(independent_middleware=True, + app = create_app(asgi, independent_middleware=True, middleware=[ExecutedFirstMiddleware(), RaiseErrorMiddleware(), ExecutedLastMiddleware()]) @@ -519,18 +619,23 @@ def handler(req, resp, ex, params): ] assert expectedExecutedMethods == context['executed_methods'] - def test_order_mw_executed_when_exception_in_req(self): + def test_order_mw_executed_when_exception_in_req(self, asgi): """Test that error in inner middleware leaves""" global context class RaiseErrorMiddleware: - def process_request(self, req, resp): raise Exception('Always fail') - app = falcon.App(middleware=[ExecutedFirstMiddleware(), - RaiseErrorMiddleware(), - ExecutedLastMiddleware()]) + class RaiseErrorMiddlewareAsync: + async def process_request(self, req, resp): + raise Exception('Always fail') + + rem = RaiseErrorMiddlewareAsync() if asgi else RaiseErrorMiddleware() + + app = create_app(asgi, middleware=[ExecutedFirstMiddleware(), + rem, + ExecutedLastMiddleware()]) def handler(req, resp, ex, params): pass @@ -550,18 +655,23 @@ def handler(req, resp, ex, params): ] assert expectedExecutedMethods == context['executed_methods'] - def test_order_independent_mw_executed_when_exception_in_req(self): + def test_order_independent_mw_executed_when_exception_in_req(self, asgi): """Test that error in inner middleware leaves""" global context class RaiseErrorMiddleware: - def process_request(self, req, resp): raise Exception('Always fail') - app = falcon.App(independent_middleware=True, + class RaiseErrorMiddlewareAsync: + async def process_request(self, req, resp): + raise Exception('Always fail') + + rem = RaiseErrorMiddlewareAsync() if asgi else RaiseErrorMiddleware() + + app = create_app(asgi, independent_middleware=True, middleware=[ExecutedFirstMiddleware(), - RaiseErrorMiddleware(), + rem, ExecutedLastMiddleware()]) def handler(req, resp, ex, params): @@ -582,18 +692,25 @@ def handler(req, resp, ex, params): ] assert expectedExecutedMethods == context['executed_methods'] - def test_order_mw_executed_when_exception_in_rsrc(self): + def test_order_mw_executed_when_exception_in_rsrc(self, asgi): """Test that error in inner middleware leaves""" global context class RaiseErrorMiddleware: - def process_resource(self, req, resp, resource): raise Exception('Always fail') - app = falcon.App(middleware=[ExecutedFirstMiddleware(), - RaiseErrorMiddleware(), - ExecutedLastMiddleware()]) + class RaiseErrorMiddlewareAsync: + # NOTE(kgriffs): The *_async postfix is not required in this + # case, but we include it to make sure it works as expected. + async def process_resource_async(self, req, resp, resource): + raise Exception('Always fail') + + rem = RaiseErrorMiddlewareAsync() if asgi else RaiseErrorMiddleware() + + app = create_app(asgi, middleware=[ExecutedFirstMiddleware(), + rem, + ExecutedLastMiddleware()]) def handler(req, resp, ex, params): pass @@ -615,18 +732,23 @@ def handler(req, resp, ex, params): ] assert expectedExecutedMethods == context['executed_methods'] - def test_order_independent_mw_executed_when_exception_in_rsrc(self): + def test_order_independent_mw_executed_when_exception_in_rsrc(self, asgi): """Test that error in inner middleware leaves""" global context class RaiseErrorMiddleware: - def process_resource(self, req, resp, resource): raise Exception('Always fail') - app = falcon.App(independent_middleware=True, + class RaiseErrorMiddlewareAsync: + async def process_resource(self, req, resp, resource): + raise Exception('Always fail') + + rem = RaiseErrorMiddlewareAsync() if asgi else RaiseErrorMiddleware() + + app = create_app(asgi, independent_middleware=True, middleware=[ExecutedFirstMiddleware(), - RaiseErrorMiddleware(), + rem, ExecutedLastMiddleware()]) def handler(req, resp, ex, params): @@ -651,9 +773,9 @@ def handler(req, resp, ex, params): class TestRemoveBasePathMiddleware(TestMiddleware): - def test_base_path_is_removed_before_routing(self): + def test_base_path_is_removed_before_routing(self, asgi): """Test that RemoveBasePathMiddleware is executed before routing""" - app = falcon.App(middleware=RemoveBasePathMiddleware()) + app = create_app(asgi, middleware=RemoveBasePathMiddleware()) # We dont include /base_path as it will be removed in middleware app.add_route('/sub_path', MiddlewareClassResource()) @@ -669,7 +791,7 @@ def test_base_path_is_removed_before_routing(self): class TestResourceMiddleware(TestMiddleware): @pytest.mark.parametrize('independent_middleware', [True, False]) - def test_can_access_resource_params(self, independent_middleware): + def test_can_access_resource_params(self, asgi, independent_middleware): """Test that params can be accessed from within process_resource""" global context @@ -677,7 +799,7 @@ class Resource: def on_get(self, req, resp, **params): resp.body = json.dumps(params) - app = falcon.App(middleware=AccessParamsMiddleware(), + app = create_app(asgi, middleware=AccessParamsMiddleware(), independent_middleware=independent_middleware) app.add_route('/path/{id}', Resource()) client = testing.TestClient(app) @@ -690,7 +812,7 @@ def on_get(self, req, resp, **params): class TestEmptySignatureMiddleware(TestMiddleware): - def test_dont_need_params_in_signature(self): + def test_dont_need_params_in_signature(self, asgi): """ Verify that we don't need parameters in the process_* signatures (for side-effect-only middlewares, mostly). Makes no difference on py27 @@ -698,13 +820,13 @@ def test_dont_need_params_in_signature(self): https://github.com/falconry/falcon/issues/1254 """ - falcon.App(middleware=EmptySignatureMiddleware()) + create_app(asgi, middleware=EmptySignatureMiddleware()) class TestErrorHandling(TestMiddleware): - def test_error_composed_before_resp_middleware_called(self): + def test_error_composed_before_resp_middleware_called(self, asgi): mw = CaptureResponseMiddleware() - app = falcon.App(middleware=mw) + app = create_app(asgi, middleware=mw) app.add_route('/', MiddlewareClassResource()) client = testing.TestClient(app) @@ -722,18 +844,25 @@ def test_error_composed_before_resp_middleware_called(self): assert isinstance(mw.req, falcon.Request) assert isinstance(mw.resource, MiddlewareClassResource) - def test_http_status_raised_from_error_handler(self): + def test_http_status_raised_from_error_handler(self, asgi): mw = CaptureResponseMiddleware() - app = falcon.App(middleware=mw) + app = create_app(asgi, middleware=mw) app.add_route('/', MiddlewareClassResource()) client = testing.TestClient(app) + # NOTE(kgriffs): Use the old-style error handler signature to + # ensure our shim for that works as expected. def _http_error_handler(error, req, resp, params): raise falcon.HTTPStatus(falcon.HTTP_201) + async def _http_error_handler_async(error, req, resp, params): + raise falcon.HTTPStatus(falcon.HTTP_201) + + h = _http_error_handler_async if asgi else _http_error_handler + # NOTE(kgriffs): This will take precedence over the default # handler for facon.HTTPError. - app.add_error_handler(falcon.HTTPError, _http_error_handler) + app.add_error_handler(falcon.HTTPError, h) response = client.simulate_request(path='/', method='POST') assert response.status == falcon.HTTP_201 @@ -744,21 +873,21 @@ class TestShortCircuiting(TestMiddleware): def setup_method(self, method): super(TestShortCircuiting, self).setup_method(method) - def _make_client(self, independent_middleware=True): + def _make_client(self, asgi, independent_middleware=True): mw = [ RequestTimeMiddleware(), ResponseCacheMiddlware(), TransactionIdMiddleware(), ] - app = falcon.App(middleware=mw, independent_middleware=independent_middleware) + app = create_app(asgi, middleware=mw, independent_middleware=independent_middleware) app.add_route('/', MiddlewareClassResource()) app.add_route('/cached', MiddlewareClassResource()) app.add_route('/cached/resource', MiddlewareClassResource()) return testing.TestClient(app) - def test_process_request_not_cached(self): - response = self._make_client().simulate_get('/') + def test_process_request_not_cached(self, asgi): + response = self._make_client(asgi).simulate_get('/') assert response.status == falcon.HTTP_200 assert response.json == _EXPECTED_BODY assert 'transaction_id' in context @@ -767,8 +896,8 @@ def test_process_request_not_cached(self): assert 'end_time' in context @pytest.mark.parametrize('independent_middleware', [True, False]) - def test_process_request_cached(self, independent_middleware): - response = self._make_client(independent_middleware).simulate_get('/cached') + def test_process_request_cached(self, asgi, independent_middleware): + response = self._make_client(asgi, independent_middleware).simulate_get('/cached') assert response.status == falcon.HTTP_200 assert response.json == ResponseCacheMiddlware.PROCESS_REQUEST_CACHED_BODY @@ -788,8 +917,8 @@ def test_process_request_cached(self, independent_middleware): assert 'end_time' in context @pytest.mark.parametrize('independent_middleware', [True, False]) - def test_process_resource_cached(self, independent_middleware): - response = self._make_client(independent_middleware).simulate_get('/cached/resource') + def test_process_resource_cached(self, asgi, independent_middleware): + response = self._make_client(asgi, independent_middleware).simulate_get('/cached/resource') assert response.status == falcon.HTTP_200 assert response.json == ResponseCacheMiddlware.PROCESS_RESOURCE_CACHED_BODY @@ -820,9 +949,27 @@ class TestCORSMiddlewareWithAnotherMiddleware(TestMiddleware): (CaptureResponseMiddleware(),), iter([CaptureResponseMiddleware()]), ]) - def test_api_initialization_with_cors_enabled_and_middleware_param(self, mw): - app = falcon.App(middleware=mw, cors_enable=True) + def test_api_initialization_with_cors_enabled_and_middleware_param(self, mw, asgi): + app = create_app(asgi, middleware=mw, cors_enable=True) app.add_route('/', TestCorsResource()) client = testing.TestClient(app) result = client.simulate_get() assert result.headers['Access-Control-Allow-Origin'] == '*' + + +def test_async_postfix_method_must_be_coroutine(): + class FaultyComponentA: + def process_request_async(self, req, resp): + pass + + class FaultyComponentB: + def process_resource_async(self, req, resp, resource, params): + pass + + class FaultyComponentC: + def process_response_async(self, req, resp, resource, req_succeeded): + pass + + for mw in (FaultyComponentA, FaultyComponentB, FaultyComponentC): + with pytest.raises(falcon.errors.CompatibilityError): + create_app(True, middleware=[mw()]) diff --git a/tests/test_python_version_requirements.py b/tests/test_python_version_requirements.py new file mode 100644 index 000000000..d8217ac9e --- /dev/null +++ b/tests/test_python_version_requirements.py @@ -0,0 +1,12 @@ +import pytest + +from falcon import ASGI_SUPPORTED + + +def test_asgi(): + if ASGI_SUPPORTED: + # Should not raise + import falcon.asgi # NOQA + else: + with pytest.raises(ImportError): + import falcon.asgi # NOQA diff --git a/tests/test_query_params.py b/tests/test_query_params.py index a5360208a..2d4996a10 100644 --- a/tests/test_query_params.py +++ b/tests/test_query_params.py @@ -5,9 +5,11 @@ import pytest import falcon -from falcon.errors import HTTPInvalidParam +from falcon.errors import HTTPInvalidParam, UnsupportedError import falcon.testing as testing +from _util import create_app # NOQA + class Resource(testing.SimpleTestResource): @@ -43,9 +45,11 @@ def resource(): @pytest.fixture -def client(): - app = falcon.App() - app.req_options.auto_parse_form_urlencoded = True +def client(asgi): + app = create_app(asgi) + if not asgi: + app.req_options.auto_parse_form_urlencoded = True + return testing.TestClient(app) @@ -54,6 +58,12 @@ def simulate_request_get_query_params(client, path, query_string, **kwargs): def simulate_request_post_query_params(client, path, query_string, **kwargs): + if client.app._ASGI: + pytest.skip( + 'The ASGI implementation does not support ' + 'RequestOptions.auto_parse_form_urlencoded' + ) + headers = kwargs.setdefault('headers', {}) headers['Content-Type'] = 'application/x-www-form-urlencoded' if 'method' not in kwargs: @@ -899,10 +909,19 @@ def test_explicitly_disable_auto_parse(self, client, resource): req = resource.captured_req assert req.get_param('q') is None + def test_asgi_raises_error(self, resource): + app = create_app(asgi=True) + app.add_route('/', resource) + app.req_options.auto_parse_form_urlencoded = True + + with pytest.raises(UnsupportedError): + testing.simulate_get(app, '/') + +@pytest.mark.parametrize('asgi', [True, False]) class TestPostQueryParamsDefaultBehavior: - def test_dont_auto_parse_by_default(self): - app = falcon.App() + def test_dont_auto_parse_by_default(self, asgi): + app = create_app(asgi) resource = testing.SimpleTestResource() app.add_route('/', resource) diff --git a/tests/test_redirects.py b/tests/test_redirects.py index 54ca73565..2138af1b8 100644 --- a/tests/test_redirects.py +++ b/tests/test_redirects.py @@ -3,10 +3,12 @@ import falcon import falcon.testing as testing +from _util import create_app # NOQA + @pytest.fixture -def client(): - app = falcon.App() +def client(asgi): + app = create_app(asgi) resource = RedirectingResource() app.add_route('/', resource) @@ -15,8 +17,8 @@ def client(): @pytest.fixture -def client_exercising_headers(): - app = falcon.App() +def client_exercising_headers(asgi): + app = create_app(asgi) resource = RedirectingResourceWithHeaders() app.add_route('/', resource) diff --git a/tests/test_request_access_route.py b/tests/test_request_access_route.py index 006fb2b61..e25880b27 100644 --- a/tests/test_request_access_route.py +++ b/tests/test_request_access_route.py @@ -1,9 +1,25 @@ +import pytest + from falcon.request import Request import falcon.testing as testing +from _util import create_req # NOQA + + +def test_remote_addr_default(asgi): + req = create_req(asgi) + assert req.remote_addr == '127.0.0.1' + -def test_remote_addr_only(): - req = Request(testing.create_environ( +def test_remote_addr_non_default(asgi): + client_ip = '10.132.0.5' + req = create_req(asgi, remote_addr=client_ip) + assert req.remote_addr == client_ip + + +def test_remote_addr_only(asgi): + req = create_req( + asgi, host='example.com', path='/access_route', headers={ @@ -11,13 +27,14 @@ def test_remote_addr_only(): 'for="unknown", by=_hidden,for="\\"\\\\",' 'for="198\\.51\\.100\\.17\\:1236";' 'proto=https;host=example.com') - })) + }) assert req.remote_addr == '127.0.0.1' -def test_rfc_forwarded(): - req = Request(testing.create_environ( +def test_rfc_forwarded(asgi): + req = create_req( + asgi, host='example.com', path='/access_route', headers={ @@ -28,72 +45,92 @@ def test_rfc_forwarded(): 'for="_don\\\"t_\\try_this\\\\at_home_\\42",' 'for="198\\.51\\.100\\.17\\:1236";' 'proto=https;host=example.com') - })) + }) compares = ['192.0.2.43', '2001:db8:cafe::17', 'x', 'unknown', '"\\', '_don"t_try_this\\at_home_42', - '198.51.100.17'] + '198.51.100.17', '127.0.0.1'] - req.access_route == compares + assert req.access_route == compares # test cached - req.access_route == compares + assert req.access_route == compares -def test_malformed_rfc_forwarded(): - req = Request(testing.create_environ( +def test_malformed_rfc_forwarded(asgi): + req = create_req( + asgi, host='example.com', path='/access_route', headers={ 'Forwarded': 'for' - })) + }) - req.access_route == [] + assert req.access_route == ['127.0.0.1'] # test cached - req.access_route == [] + assert req.access_route == ['127.0.0.1'] -def test_x_forwarded_for(): - req = Request(testing.create_environ( +@pytest.mark.parametrize('include_localhost', [True, False]) +def test_x_forwarded_for(asgi, include_localhost): + + forwarded_for = ( + '192.0.2.43, 2001:db8:cafe::17,' + 'unknown, _hidden, 203.0.113.60' + ) + + if include_localhost: + forwarded_for += ', 127.0.0.1' + + req = create_req( + asgi, host='example.com', path='/access_route', - headers={ - 'X-Forwarded-For': ('192.0.2.43, 2001:db8:cafe::17,' - 'unknown, _hidden, 203.0.113.60') - })) + headers={'X-Forwarded-For': forwarded_for} + ) assert req.access_route == [ '192.0.2.43', '2001:db8:cafe::17', 'unknown', '_hidden', - '203.0.113.60' + '203.0.113.60', + '127.0.0.1', ] -def test_x_real_ip(): - req = Request(testing.create_environ( +def test_x_real_ip(asgi): + req = create_req( + asgi, host='example.com', path='/access_route', headers={ 'X-Real-IP': '2001:db8:cafe::17' - })) + }) - assert req.access_route == ['2001:db8:cafe::17'] + assert req.access_route == ['2001:db8:cafe::17', '127.0.0.1'] -def test_remote_addr(): - req = Request(testing.create_environ( +@pytest.mark.parametrize('remote_addr', ['10.0.0.1', '98.245.211.177']) +def test_remote_addr(asgi, remote_addr): + req = create_req( + asgi, host='example.com', - path='/access_route')) + path='/access_route', + remote_addr=remote_addr, + ) - assert req.access_route == ['127.0.0.1'] + assert req.access_route == [remote_addr] def test_remote_addr_missing(): env = testing.create_environ(host='example.com', path='/access_route') - del env['REMOTE_ADDR'] + + # NOTE(kgriffs): It should not be present, but include this check so + # that in the future if things change, we still cover this case. + if 'REMOTE_ADDR' in env: + del env['REMOTE_ADDR'] req = Request(env) - assert req.access_route == [] + assert req.access_route == ['127.0.0.1'] diff --git a/tests/test_request_attrs.py b/tests/test_request_attrs.py index 2e5f4a760..c0acd2783 100644 --- a/tests/test_request_attrs.py +++ b/tests/test_request_attrs.py @@ -10,7 +10,10 @@ import falcon.uri from falcon.util.structures import ETag -_PROTOCOLS = ['HTTP/1.0', 'HTTP/1.1'] +from _util import create_req # NOQA + + +_HTTP_VERSIONS = ['1.0', '1.1', '2'] def _make_etag(value, is_weak=False): @@ -30,9 +33,29 @@ def _make_etag(value, is_weak=False): return etag +def test_missing_qs(): + env = testing.create_environ() + if 'QUERY_STRING' in env: + del env['QUERY_STRING'] + + # Should not cause an exception when Request is instantiated + Request(env) + + +def test_app_missing(): + env = testing.create_environ() + del env['SCRIPT_NAME'] + req = Request(env) + + assert req.app == '' + + +@pytest.mark.parametrize('asgi', [True, False]) class TestRequestAttributes: def setup_method(self, method): + asgi = self._item.callspec.getparam('asgi') + self.qs = 'marker=deadbeef&limit=10' self.headers = { @@ -41,81 +64,81 @@ def setup_method(self, method): 'Authorization': '' } - self.app = '/test' + self.root_path = '/test' self.path = '/hello' self.relative_uri = self.path + '?' + self.qs - self.req = Request(testing.create_environ( - app=self.app, + self.req = create_req( + asgi, + root_path=self.root_path, port=8080, path='/hello', query_string=self.qs, - headers=self.headers)) + headers=self.headers) - self.req_noqs = Request(testing.create_environ( - app=self.app, + self.req_noqs = create_req( + asgi, + root_path=self.root_path, path='/hello', - headers=self.headers)) - - def test_missing_qs(self): - env = testing.create_environ() - if 'QUERY_STRING' in env: - del env['QUERY_STRING'] - - # Should not cause an exception when Request is instantiated - Request(env) + headers=self.headers) - def test_empty(self): + def test_empty(self, asgi): assert self.req.auth is None - def test_host(self): + def test_host(self, asgi): assert self.req.host == testing.DEFAULT_HOST - def test_subdomain(self): - req = Request(testing.create_environ( + def test_subdomain(self, asgi): + req = create_req( + asgi, host='com', path='/hello', - headers=self.headers)) + headers=self.headers) assert req.subdomain is None - req = Request(testing.create_environ( + req = create_req( + asgi, host='example.com', path='/hello', - headers=self.headers)) + headers=self.headers) assert req.subdomain == 'example' - req = Request(testing.create_environ( + req = create_req( + asgi, host='highwire.example.com', path='/hello', - headers=self.headers)) + headers=self.headers) assert req.subdomain == 'highwire' - req = Request(testing.create_environ( + req = create_req( + asgi, host='lb01.dfw01.example.com', port=8080, path='/hello', - headers=self.headers)) + headers=self.headers) assert req.subdomain == 'lb01' # NOTE(kgriffs): Behavior for IP addresses is undefined, # so just make sure it doesn't blow up. - req = Request(testing.create_environ( + req = create_req( + asgi, host='127.0.0.1', path='/hello', - headers=self.headers)) + headers=self.headers) assert type(req.subdomain) == str # NOTE(kgriffs): Test fallback to SERVER_NAME by using # HTTP 1.0, which will cause .create_environ to not set # HTTP_HOST. - req = Request(testing.create_environ( - protocol='HTTP/1.0', + req = create_req( + asgi, + http_version='1.0', host='example.com', path='/hello', - headers=self.headers)) + headers=self.headers) assert req.subdomain == 'example' - def test_reconstruct_url(self): + def test_reconstruct_url(self, asgi): req = self.req scheme = req.scheme @@ -136,7 +159,7 @@ def test_reconstruct_url(self): '/test/%E5%BB%B6%E5%AE%89', '/test/%C3%A4%C3%B6%C3%BC%C3%9F%E2%82%AC', ]) - def test_nonlatin_path(self, test_path): + def test_nonlatin_path(self, asgi, test_path): # NOTE(kgriffs): When a request comes in, web servers decode # the path. The decoded path may contain UTF-8 characters, # but according to the WSGI spec, no strings can contain chars @@ -153,15 +176,16 @@ def test_nonlatin_path(self, test_path): # path = tunnelled_path.encode('iso-8859-1').decode('utf-8', 'replace') # - req = Request(testing.create_environ( + req = create_req( + asgi, host='com', path=test_path, - headers=self.headers)) + headers=self.headers) assert req.path == falcon.uri.decode(test_path) - def test_uri(self): - prefix = 'http://' + testing.DEFAULT_HOST + ':8080' + self.app + def test_uri(self, asgi): + prefix = 'http://' + testing.DEFAULT_HOST + ':8080' + self.root_path uri = prefix + self.relative_uri assert self.req.url == uri @@ -171,15 +195,15 @@ def test_uri(self): assert self.req.uri == uri assert self.req.uri == uri - uri_noqs = ('http://' + testing.DEFAULT_HOST + self.app + self.path) + uri_noqs = ('http://' + testing.DEFAULT_HOST + self.root_path + self.path) assert self.req_noqs.uri == uri_noqs - def test_uri_https(self): + def test_uri_https(self, asgi): # ======================================================= # Default port, implicit # ======================================================= - req = Request(testing.create_environ( - path='/hello', scheme='https')) + req = create_req( + asgi, path='/hello', scheme='https') uri = ('https://' + testing.DEFAULT_HOST + '/hello') assert req.uri == uri @@ -187,8 +211,8 @@ def test_uri_https(self): # ======================================================= # Default port, explicit # ======================================================= - req = Request(testing.create_environ( - path='/hello', scheme='https', port=443)) + req = create_req( + asgi, path='/hello', scheme='https', port=443) uri = ('https://' + testing.DEFAULT_HOST + '/hello') assert req.uri == uri @@ -196,94 +220,100 @@ def test_uri_https(self): # ======================================================= # Non-default port # ======================================================= - req = Request(testing.create_environ( - path='/hello', scheme='https', port=22)) + req = create_req( + asgi, path='/hello', scheme='https', port=22) uri = ('https://' + testing.DEFAULT_HOST + ':22/hello') assert req.uri == uri - def test_uri_http_1_0(self): + def test_uri_http_1_0(self, asgi): # ======================================================= # HTTP, 80 # ======================================================= - req = Request(testing.create_environ( - protocol='HTTP/1.0', - app=self.app, + req = create_req( + asgi, + http_version='1.0', + root_path=self.root_path, port=80, path='/hello', query_string=self.qs, - headers=self.headers)) + headers=self.headers) uri = ('http://' + testing.DEFAULT_HOST + - self.app + self.relative_uri) + self.root_path + self.relative_uri) assert req.uri == uri # ======================================================= # HTTP, 80 # ======================================================= - req = Request(testing.create_environ( - protocol='HTTP/1.0', - app=self.app, + req = create_req( + asgi, + http_version='1.0', + root_path=self.root_path, port=8080, path='/hello', query_string=self.qs, - headers=self.headers)) + headers=self.headers) uri = ('http://' + testing.DEFAULT_HOST + ':8080' + - self.app + self.relative_uri) + self.root_path + self.relative_uri) assert req.uri == uri # ======================================================= # HTTP, 80 # ======================================================= - req = Request(testing.create_environ( - protocol='HTTP/1.0', + req = create_req( + asgi, + http_version='1.0', scheme='https', - app=self.app, + root_path=self.root_path, port=443, path='/hello', query_string=self.qs, - headers=self.headers)) + headers=self.headers) uri = ('https://' + testing.DEFAULT_HOST + - self.app + self.relative_uri) + self.root_path + self.relative_uri) assert req.uri == uri # ======================================================= # HTTP, 80 # ======================================================= - req = Request(testing.create_environ( - protocol='HTTP/1.0', + req = create_req( + asgi, + http_version='1.0', scheme='https', - app=self.app, + root_path=self.root_path, port=22, path='/hello', query_string=self.qs, - headers=self.headers)) + headers=self.headers) uri = ('https://' + testing.DEFAULT_HOST + ':22' + - self.app + self.relative_uri) + self.root_path + self.relative_uri) assert req.uri == uri - def test_relative_uri(self): - assert self.req.relative_uri == self.app + self.relative_uri - assert self.req_noqs.relative_uri == self.app + self.path + def test_relative_uri(self, asgi): + assert self.req.relative_uri == self.root_path + self.relative_uri + assert self.req_noqs.relative_uri == self.root_path + self.path - req_noapp = Request(testing.create_environ( + req_noapp = create_req( + asgi, path='/hello', query_string=self.qs, - headers=self.headers)) + headers=self.headers) assert req_noapp.relative_uri == self.relative_uri - req_noapp = Request(testing.create_environ( + req_noapp = create_req( + asgi, path='/hello/', query_string=self.qs, - headers=self.headers)) + headers=self.headers) relative_trailing_uri = self.path + '/?' + self.qs # NOTE(kgriffs): Call twice to check caching works @@ -292,117 +322,118 @@ def test_relative_uri(self): options = RequestOptions() options.strip_url_path_trailing_slash = False - req_noapp = Request(testing.create_environ( + req_noapp = create_req( + asgi, + options=options, path='/hello/', query_string=self.qs, - headers=self.headers), - options=options) + headers=self.headers) assert req_noapp.relative_uri == '/hello/' + '?' + self.qs - def test_client_accepts(self): + def test_client_accepts(self, asgi): headers = {'Accept': 'application/xml'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts('application/xml') headers = {'Accept': '*/*'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts('application/xml') assert req.client_accepts('application/json') assert req.client_accepts('application/x-msgpack') headers = {'Accept': 'application/x-msgpack'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert not req.client_accepts('application/xml') assert not req.client_accepts('application/json') assert req.client_accepts('application/x-msgpack') headers = {} # NOTE(kgriffs): Equivalent to '*/*' per RFC - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts('application/xml') headers = {'Accept': 'application/json'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert not req.client_accepts('application/xml') headers = {'Accept': 'application/x-msgpack'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts('application/x-msgpack') headers = {'Accept': 'application/xm'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert not req.client_accepts('application/xml') headers = {'Accept': 'application/*'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts('application/json') assert req.client_accepts('application/xml') assert req.client_accepts('application/x-msgpack') headers = {'Accept': 'text/*'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts('text/plain') assert req.client_accepts('text/csv') assert not req.client_accepts('application/xhtml+xml') headers = {'Accept': 'text/*, application/xhtml+xml; q=0.0'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts('text/plain') assert req.client_accepts('text/csv') assert not req.client_accepts('application/xhtml+xml') headers = {'Accept': 'text/*; q=0.1, application/xhtml+xml; q=0.5'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts('text/plain') assert req.client_accepts('application/xhtml+xml') headers = {'Accept': 'text/*, application/*'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts('text/plain') assert req.client_accepts('application/xml') assert req.client_accepts('application/json') assert req.client_accepts('application/x-msgpack') headers = {'Accept': 'text/*,application/*'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts('text/plain') assert req.client_accepts('application/xml') assert req.client_accepts('application/json') assert req.client_accepts('application/x-msgpack') - def test_client_accepts_bogus(self): + def test_client_accepts_bogus(self, asgi): headers = {'Accept': '~'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert not req.client_accepts('text/plain') assert not req.client_accepts('application/json') - def test_client_accepts_props(self): + def test_client_accepts_props(self, asgi): headers = {'Accept': 'application/xml'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts_xml assert not req.client_accepts_json assert not req.client_accepts_msgpack headers = {'Accept': 'application/*'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts_xml assert req.client_accepts_json assert req.client_accepts_msgpack headers = {'Accept': 'application/json'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert not req.client_accepts_xml assert req.client_accepts_json assert not req.client_accepts_msgpack headers = {'Accept': 'application/x-msgpack'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert not req.client_accepts_xml assert not req.client_accepts_json assert req.client_accepts_msgpack headers = {'Accept': 'application/msgpack'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert not req.client_accepts_xml assert not req.client_accepts_json assert req.client_accepts_msgpack @@ -410,14 +441,14 @@ def test_client_accepts_props(self): headers = { 'Accept': 'application/json,application/xml,application/x-msgpack' } - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.client_accepts_xml assert req.client_accepts_json assert req.client_accepts_msgpack - def test_client_prefers(self): + def test_client_prefers(self, asgi): headers = {'Accept': 'application/xml'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) preferred_type = req.client_prefers(['application/xml']) assert preferred_type == 'application/xml' @@ -429,62 +460,62 @@ def test_client_prefers(self): assert preferred_type == 'application/xml' headers = {'Accept': 'text/*; q=0.1, application/xhtml+xml; q=0.5'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) preferred_type = req.client_prefers(['application/xhtml+xml']) assert preferred_type == 'application/xhtml+xml' headers = {'Accept': '3p12845j;;;asfd;'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) preferred_type = req.client_prefers(['application/xhtml+xml']) assert preferred_type is None - def test_range(self): + def test_range(self, asgi): headers = {'Range': 'bytes=10-'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.range == (10, -1) headers = {'Range': 'bytes=10-20'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.range == (10, 20) headers = {'Range': 'bytes=-10240'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.range == (-10240, -1) headers = {'Range': 'bytes=0-2'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.range == (0, 2) headers = {'Range': ''} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPInvalidHeader): req.range - req = Request(testing.create_environ()) + req = create_req(asgi) assert req.range is None - def test_range_unit(self): + def test_range_unit(self, asgi): headers = {'Range': 'bytes=10-'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.range == (10, -1) assert req.range_unit == 'bytes' headers = {'Range': 'items=10-'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.range == (10, -1) assert req.range_unit == 'items' headers = {'Range': ''} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPInvalidHeader): req.range_unit - req = Request(testing.create_environ()) + req = create_req(asgi) assert req.range_unit is None - def test_range_invalid(self): + def test_range_invalid(self, asgi): headers = {'Range': 'bytes=10240'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range @@ -493,60 +524,61 @@ def test_range_invalid(self): 'invalid. The range offsets are missing.') self._test_error_details(headers, 'range', falcon.HTTPInvalidHeader, - 'Invalid header value', expected_desc) + 'Invalid header value', expected_desc, + asgi) headers = {'Range': 'bytes=--'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range headers = {'Range': 'bytes=-3-'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range headers = {'Range': 'bytes=-3-4'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range headers = {'Range': 'bytes=3-3-4'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range headers = {'Range': 'bytes=3-3-'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range headers = {'Range': 'bytes=3-3- '} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range headers = {'Range': 'bytes=fizbit'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range headers = {'Range': 'bytes=a-'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range headers = {'Range': 'bytes=a-3'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range headers = {'Range': 'bytes=-b'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range headers = {'Range': 'bytes=3-b'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) with pytest.raises(falcon.HTTPBadRequest): req.range @@ -556,7 +588,8 @@ def test_range_invalid(self): 'according to RFC 7233.') self._test_error_details(headers, 'range', falcon.HTTPInvalidHeader, - 'Invalid header value', expected_desc) + 'Invalid header value', expected_desc, + asgi) headers = {'Range': 'bytes=0-0,-1'} expected_desc = ('The value provided for the Range ' @@ -564,7 +597,8 @@ def test_range_invalid(self): 'continuous range.') self._test_error_details(headers, 'range', falcon.HTTPInvalidHeader, - 'Invalid header value', expected_desc) + 'Invalid header value', expected_desc, + asgi) headers = {'Range': '10-'} expected_desc = ('The value provided for the Range ' @@ -572,59 +606,64 @@ def test_range_invalid(self): "prefixed with a range unit, e.g. 'bytes='") self._test_error_details(headers, 'range', falcon.HTTPInvalidHeader, - 'Invalid header value', expected_desc) + 'Invalid header value', expected_desc, + asgi) - def test_missing_attribute_header(self): - req = Request(testing.create_environ()) + def test_missing_attribute_header(self, asgi): + req = create_req(asgi) assert req.range is None - req = Request(testing.create_environ()) + req = create_req(asgi) assert req.content_length is None - def test_content_length(self): + def test_content_length(self, asgi): headers = {'content-length': '5656'} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.content_length == 5656 headers = {'content-length': ''} - req = Request(testing.create_environ(headers=headers)) + req = create_req(asgi, headers=headers) assert req.content_length is None - def test_bogus_content_length_nan(self): + def test_bogus_content_length_nan(self, asgi): headers = {'content-length': 'fuzzy-bunnies'} expected_desc = ('The value provided for the ' 'Content-Length header is invalid. The value ' 'of the header must be a number.') self._test_error_details(headers, 'content_length', falcon.HTTPInvalidHeader, - 'Invalid header value', expected_desc) + 'Invalid header value', expected_desc, + asgi) - def test_bogus_content_length_neg(self): + def test_bogus_content_length_neg(self, asgi): headers = {'content-length': '-1'} expected_desc = ('The value provided for the Content-Length ' 'header is invalid. The value of the header ' 'must be a positive number.') self._test_error_details(headers, 'content_length', falcon.HTTPInvalidHeader, - 'Invalid header value', expected_desc) + 'Invalid header value', expected_desc, + asgi) @pytest.mark.parametrize('header,attr', [ ('Date', 'date'), ('If-Modified-Since', 'if_modified_since'), ('If-Unmodified-Since', 'if_unmodified_since'), ]) - def test_date(self, header, attr): + def test_date(self, asgi, header, attr): date = datetime.datetime(2013, 4, 4, 5, 19, 18) date_str = 'Thu, 04 Apr 2013 05:19:18 GMT' - self._test_header_expected_value(header, date_str, attr, date) + headers = {header: date_str} + req = create_req(asgi, headers=headers) + assert getattr(req, attr) == date @pytest.mark.parametrize('header,attr', [ ('Date', 'date'), ('If-Modified-Since', 'if_modified_since'), ('If-Unmodified-Since', 'if_unmodified_since'), ]) - def test_date_invalid(self, header, attr): + def test_date_invalid(self, asgi, header, attr): # Date formats don't conform to RFC 1123 headers = {header: 'Thu, 04 Apr 2013'} @@ -635,91 +674,90 @@ def test_date_invalid(self, header, attr): self._test_error_details(headers, attr, falcon.HTTPInvalidHeader, 'Invalid header value', - expected_desc.format(header)) + expected_desc.format(header), + asgi) headers = {header: ''} self._test_error_details(headers, attr, falcon.HTTPInvalidHeader, 'Invalid header value', - expected_desc.format(header)) + expected_desc.format(header), + asgi) @pytest.mark.parametrize('attr', ('date', 'if_modified_since', 'if_unmodified_since')) - def test_date_missing(self, attr): - req = Request(testing.create_environ()) + def test_date_missing(self, asgi, attr): + req = create_req(asgi) assert getattr(req, attr) is None - def test_attribute_headers(self): - date = 'Wed, 21 Oct 2015 07:28:00 GMT' - auth = 'HMAC_SHA1 c590afa9bb59191ffab30f223791e82d3fd3e3af' - agent = 'testing/1.0.1' - default_agent = 'curl/7.24.0 (x86_64-apple-darwin12.0)' - referer = 'https://www.google.com/' - - self._test_attribute_header('Accept', 'x-falcon', 'accept', - default='*/*') - - self._test_attribute_header('Authorization', auth, 'auth') - - self._test_attribute_header('Content-Type', 'text/plain', - 'content_type') - self._test_attribute_header('Expect', '100-continue', 'expect') - - self._test_attribute_header('If-Range', date, 'if_range') + @pytest.mark.parametrize('name,value,attr,default', [ + ('Accept', 'x-falcon', 'accept', '*/*'), + ('Authorization', 'HMAC_SHA1 c590afa9bb59191ffab30f223791e82d3fd3e3af', 'auth', None), + ('Content-Type', 'text/plain', 'content_type', None), + ('Expect', '100-continue', 'expect', None), + ('If-Range', 'Wed, 21 Oct 2015 07:28:00 GMT', 'if_range', None), + ('User-Agent', 'testing/3.0', 'user_agent', None), + ('Referer', 'https://www.google.com/', 'referer', None), + ]) + def test_attribute_headers(self, asgi, name, value, attr, default): + headers = {name: value} + req = create_req(asgi, headers=headers) + assert getattr(req, attr) == value - self._test_attribute_header('User-Agent', agent, 'user_agent', - default=default_agent) - self._test_attribute_header('Referer', referer, 'referer') + req = create_req(asgi) + assert getattr(req, attr) == default - def test_method(self): + def test_method(self, asgi): assert self.req.method == 'GET' - self.req = Request(testing.create_environ(path='', method='HEAD')) + self.req = create_req(asgi, path='', method='HEAD') assert self.req.method == 'HEAD' - def test_empty_path(self): - self.req = Request(testing.create_environ(path='')) + def test_empty_path(self, asgi): + self.req = create_req(asgi, path='') assert self.req.path == '/' - def test_content_type_method(self): + def test_content_type_method(self, asgi): assert self.req.get_header('content-type') == 'text/plain' - def test_content_length_method(self): + def test_content_length_method(self, asgi): assert self.req.get_header('content-length') == '4829' # TODO(kgriffs): Migrate to pytest and parametrized fixtures # to DRY things up a bit. - @pytest.mark.parametrize('protocol', _PROTOCOLS) - def test_port_explicit(self, protocol): + @pytest.mark.parametrize('http_version', _HTTP_VERSIONS) + def test_port_explicit(self, asgi, http_version): port = 9000 - req = Request(testing.create_environ( - protocol=protocol, + req = create_req( + asgi, + http_version=http_version, port=port, - app=self.app, + root_path=self.root_path, path='/hello', query_string=self.qs, - headers=self.headers)) + headers=self.headers) assert req.port == port - @pytest.mark.parametrize('protocol', _PROTOCOLS) - def test_scheme_https(self, protocol): + @pytest.mark.parametrize('http_version', _HTTP_VERSIONS) + def test_scheme_https(self, asgi, http_version): scheme = 'https' - req = Request(testing.create_environ( - protocol=protocol, + req = create_req( + asgi, + http_version=http_version, scheme=scheme, - app=self.app, + root_path=self.root_path, path='/hello', query_string=self.qs, - headers=self.headers)) + headers=self.headers) assert req.scheme == scheme assert req.port == 443 @pytest.mark.parametrize( - 'protocol, set_forwarded_proto', - list(itertools.product(_PROTOCOLS, [True, False])) + 'http_version, set_forwarded_proto', + list(itertools.product(_HTTP_VERSIONS, [True, False])) ) - def test_scheme_http(self, protocol, set_forwarded_proto): + def test_scheme_http(self, asgi, http_version, set_forwarded_proto): scheme = 'http' forwarded_scheme = 'HttPs' @@ -728,13 +766,14 @@ def test_scheme_http(self, protocol, set_forwarded_proto): if set_forwarded_proto: headers['X-Forwarded-Proto'] = forwarded_scheme - req = Request(testing.create_environ( - protocol=protocol, + req = create_req( + asgi, + http_version=http_version, scheme=scheme, - app=self.app, + root_path=self.root_path, path='/hello', query_string=self.qs, - headers=headers)) + headers=headers) assert req.scheme == scheme assert req.port == 80 @@ -744,63 +783,58 @@ def test_scheme_http(self, protocol, set_forwarded_proto): else: assert req.forwarded_scheme == scheme - @pytest.mark.parametrize('protocol', _PROTOCOLS) - def test_netloc_default_port(self, protocol): - req = Request(testing.create_environ( - protocol=protocol, - app=self.app, + @pytest.mark.parametrize('http_version', _HTTP_VERSIONS) + def test_netloc_default_port(self, asgi, http_version): + req = create_req( + asgi, + http_version=http_version, + root_path=self.root_path, path='/hello', query_string=self.qs, - headers=self.headers)) + headers=self.headers) assert req.netloc == 'falconframework.org' - @pytest.mark.parametrize('protocol', _PROTOCOLS) - def test_netloc_nondefault_port(self, protocol): - req = Request(testing.create_environ( - protocol=protocol, + @pytest.mark.parametrize('http_version', _HTTP_VERSIONS) + def test_netloc_nondefault_port(self, asgi, http_version): + req = create_req( + asgi, + http_version=http_version, port='8080', - app=self.app, + root_path=self.root_path, path='/hello', query_string=self.qs, - headers=self.headers)) + headers=self.headers) assert req.netloc == 'falconframework.org:8080' - @pytest.mark.parametrize('protocol', _PROTOCOLS) - def test_netloc_from_env(self, protocol): + @pytest.mark.parametrize('http_version', _HTTP_VERSIONS) + def test_netloc_from_env(self, asgi, http_version): port = 9000 host = 'example.org' - env = testing.create_environ( - protocol=protocol, + + req = create_req( + asgi, + http_version=http_version, host=host, port=port, - app=self.app, + root_path=self.root_path, path='/hello', query_string=self.qs, headers=self.headers) - req = Request(env) - assert req.port == port assert req.netloc == '{}:{}'.format(host, port) - def test_app_present(self): - req = Request(testing.create_environ(app='/moving-pictures')) + def test_app_present(self, asgi): + req = create_req(asgi, root_path='/moving-pictures') assert req.app == '/moving-pictures' - def test_app_blank(self): - req = Request(testing.create_environ(app='')) - assert req.app == '' - - def test_app_missing(self): - env = testing.create_environ() - del env['SCRIPT_NAME'] - req = Request(env) - + def test_app_blank(self, asgi): + req = create_req(asgi, root_path='') assert req.app == '' - @pytest.mark.parametrize('etag,expected', [ + @pytest.mark.parametrize('etag,expected_value', [ ('', None), (' ', None), (' ', None), @@ -898,18 +932,37 @@ def test_app_missing(self): ] ), ]) - def test_etag(self, etag, expected): - self._test_header_etag('If-Match', etag, 'if_match', expected) - self._test_header_etag('If-None-Match', etag, 'if_none_match', expected) + @pytest.mark.parametrize('name,attr', [ + ('If-Match', 'if_match'), + ('If-None-Match', 'if_none_match'), + ]) + def test_etag(self, asgi, name, attr, etag, expected_value): + headers = {name: etag} + req = create_req(asgi, headers=headers) - def test_etag_is_missing(self): + # NOTE(kgriffs): Loop in order to test caching + for __ in range(3): + value = getattr(req, attr) + + if expected_value is None: + assert value is None + return + + assert value is not None + + for element, expected_element in zip(value, expected_value): + assert element == expected_element + if isinstance(expected_element, ETag): + assert element.is_weak == expected_element.is_weak + + def test_etag_is_missing(self, asgi): # NOTE(kgriffs): Loop in order to test caching for __ in range(3): assert self.req.if_match is None assert self.req.if_none_match is None @pytest.mark.parametrize('header_value', ['', ' ', ' ']) - def test_etag_parsing_helper(self, header_value): + def test_etag_parsing_helper(self, asgi, header_value): # NOTE(kgriffs): Test a couple of cases that are not directly covered # elsewhere (but that we want the helper to still support # for the sake of avoiding suprises if they are ever called without @@ -921,41 +974,9 @@ def test_etag_parsing_helper(self, header_value): # Helpers # ------------------------------------------------------------------------- - def _test_attribute_header(self, name, value, attr, default=None): - headers = {name: value} - req = Request(testing.create_environ(headers=headers)) - assert getattr(req, attr) == value - - req = Request(testing.create_environ()) - assert getattr(req, attr) == default - - def _test_header_expected_value(self, name, value, attr, expected_value): - headers = {name: value} - req = Request(testing.create_environ(headers=headers)) - assert getattr(req, attr) == expected_value - - def _test_header_etag(self, name, value, attr, expected_value): - headers = {name: value} - req = Request(testing.create_environ(headers=headers)) - - # NOTE(kgriffs): Loop in order to test caching - for __ in range(3): - value = getattr(req, attr) - - if expected_value is None: - assert value is None - return - - assert value is not None - - for element, expected_element in zip(value, expected_value): - assert element == expected_element - if isinstance(expected_element, ETag): - assert element.is_weak == expected_element.is_weak - def _test_error_details(self, headers, attr_name, - error_type, title, description): - req = Request(testing.create_environ(headers=headers)) + error_type, title, description, asgi): + req = create_req(asgi, headers=headers) try: getattr(req, attr_name) diff --git a/tests/test_request_body.py b/tests/test_request_body.py index a72834667..eb33f1be4 100644 --- a/tests/test_request_body.py +++ b/tests/test_request_body.py @@ -30,8 +30,8 @@ def _get_wrapped_stream(self, req): stream = stream.stream if isinstance(stream, InputWrapper): stream = stream.input - if isinstance(stream, io.BytesIO): - return stream + + return stream def test_empty_body(self, client, resource): client.app.add_route('/', resource) diff --git a/tests/test_request_context.py b/tests/test_request_context.py index 301e8bbc1..5f707fd4e 100644 --- a/tests/test_request_context.py +++ b/tests/test_request_context.py @@ -6,9 +6,8 @@ class TestRequestContext: - def test_default_request_context(self): - env = testing.create_environ() - req = Request(env) + def test_default_request_context(self,): + req = testing.create_req() req.context.hello = 'World' assert req.context.hello == 'World' diff --git a/tests/test_request_forwarded.py b/tests/test_request_forwarded.py index 2103629cf..49e82b763 100644 --- a/tests/test_request_forwarded.py +++ b/tests/test_request_forwarded.py @@ -1,15 +1,15 @@ import pytest -from falcon.request import Request -import falcon.testing as testing +from _util import create_req # NOQA -def test_no_forwarded_headers(): - req = Request(testing.create_environ( +def test_no_forwarded_headers(asgi): + req = create_req( + asgi, host='example.com', path='/languages', - app='backoffice' - )) + root_path='backoffice' + ) assert req.forwarded is None assert req.forwarded_uri == req.uri @@ -17,12 +17,13 @@ def test_no_forwarded_headers(): assert req.forwarded_prefix == 'http://example.com/backoffice' -def test_x_forwarded_host(): - req = Request(testing.create_environ( +def test_x_forwarded_host(asgi): + req = create_req( + asgi, host='suchproxy.suchtesting.com', path='/languages', headers={'X-Forwarded-Host': 'something.org'} - )) + ) assert req.forwarded is None assert req.forwarded_host == 'something.org' @@ -32,12 +33,13 @@ def test_x_forwarded_host(): assert req.forwarded_prefix == 'http://something.org' # Check cached value -def test_x_forwarded_proto(): - req = Request(testing.create_environ( +def test_x_forwarded_proto(asgi): + req = create_req( + asgi, host='example.org', path='/languages', headers={'X-Forwarded-Proto': 'HTTPS'} - )) + ) assert req.forwarded is None assert req.forwarded_scheme == 'https' @@ -46,14 +48,15 @@ def test_x_forwarded_proto(): assert req.forwarded_prefix == 'https://example.org' -def test_forwarded_host(): - req = Request(testing.create_environ( +def test_forwarded_host(asgi): + req = create_req( + asgi, host='suchproxy02.suchtesting.com', path='/languages', headers={ 'Forwarded': 'host=something.org , host=suchproxy01.suchtesting.com' } - )) + ) assert req.forwarded is not None for f in req.forwarded: @@ -70,8 +73,9 @@ def test_forwarded_host(): assert req.forwarded_prefix == 'http://something.org' -def test_forwarded_multiple_params(): - req = Request(testing.create_environ( +def test_forwarded_multiple_params(asgi): + req = create_req( + asgi, host='suchproxy02.suchtesting.com', path='/languages', headers={ @@ -80,7 +84,7 @@ def test_forwarded_multiple_params(): 'by=203.0.113.43;host=suchproxy01.suchtesting.com;proto=httP' ) } - )) + ) assert req.forwarded is not None @@ -101,15 +105,16 @@ def test_forwarded_multiple_params(): assert req.forwarded_prefix == 'https://something.org' -def test_forwarded_missing_first_hop_host(): - req = Request(testing.create_environ( +def test_forwarded_missing_first_hop_host(asgi): + req = create_req( + asgi, host='suchproxy02.suchtesting.com', path='/languages', - app='doge', + root_path='doge', headers={ 'Forwarded': 'for=108.166.30.185,host=suchproxy01.suchtesting.com' } - )) + ) assert req.forwarded[0].host is None assert req.forwarded[0].src == '108.166.30.185' @@ -124,15 +129,16 @@ def test_forwarded_missing_first_hop_host(): assert req.forwarded_prefix == 'http://suchproxy02.suchtesting.com/doge' -def test_forwarded_quote_escaping(): - req = Request(testing.create_environ( +def test_forwarded_quote_escaping(asgi): + req = create_req( + asgi, host='suchproxy02.suchtesting.com', path='/languages', - app='doge', + root_path='doge', headers={ 'Forwarded': 'for="1\\.2\\.3\\.4";some="extra,\\"info\\""' } - )) + ) assert req.forwarded[0].host is None assert req.forwarded[0].src == '1.2.3.4' @@ -145,16 +151,17 @@ def test_forwarded_quote_escaping(): ('for=1.2.3.4;by="4.3.2.\\1"thing="blah"', '4.3.2.1'), ('for=1.2.3.4;by="4.3.\\2\\.1" thing="blah"', '4.3.2.1'), ]) -def test_escape_malformed_requests(forwarded, expected_dest): +def test_escape_malformed_requests(forwarded, expected_dest, asgi): - req = Request(testing.create_environ( + req = create_req( + asgi, host='suchproxy02.suchtesting.com', path='/languages', - app='doge', + root_path='doge', headers={ 'Forwarded': forwarded } - )) + ) assert len(req.forwarded) == 1 assert req.forwarded[0].src == '1.2.3.4' diff --git a/tests/test_request_media.py b/tests/test_request_media.py index 97c7d4399..e3c5e5365 100644 --- a/tests/test_request_media.py +++ b/tests/test_request_media.py @@ -1,37 +1,81 @@ import pytest -import falcon from falcon import errors, media, testing +from _util import create_app # NOQA -def create_client(handlers=None): - res = testing.SimpleTestResource() - app = falcon.App() - app.add_route('/', res) +def create_client(asgi, handlers=None, resource=None): + if not resource: + resource = testing.SimpleTestResourceAsync() if asgi else testing.SimpleTestResource() + + app = create_app(asgi) + app.add_route('/', resource) if handlers: app.req_options.media_handlers.update(handlers) - client = testing.TestClient(app) - client.resource = res + client = testing.TestClient(app, headers={'capture-req-media': 'yes'}) + client.resource = resource return client +@pytest.fixture(params=[True, False]) +def client(request): + return create_client(request.param) + + +class ResourceCachedMedia: + def on_post(self, req, resp, **kwargs): + self.captured_req_media = req.media + + # NOTE(kgriffs): Ensure that the media object is cached + assert self.captured_req_media is req.media + + +class ResourceCachedMediaAsync: + async def on_post(self, req, resp, **kwargs): + self.captured_req_media = await req.get_media() + + # NOTE(kgriffs): Ensure that the media object is cached + assert self.captured_req_media is await req.get_media() + + +class ResourceInvalidMedia: + def __init__(self, expected_error): + self._expected_error = expected_error + + def on_post(self, req, resp, **kwargs): + with pytest.raises(self._expected_error) as error: + req.media + + self.captured_error = error + + +class ResourceInvalidMediaAsync: + def __init__(self, expected_error): + self._expected_error = expected_error + + async def on_post(self, req, resp, **kwargs): + with pytest.raises(self._expected_error) as error: + await req.get_media() + + self.captured_error = error + + @pytest.mark.parametrize('media_type', [ (None), ('*/*'), ('application/json'), ('application/json; charset=utf-8'), ]) -def test_json(media_type): - client = create_client() +def test_json(client, media_type): expected_body = b'{"something": true}' headers = {'Content-Type': media_type} client.simulate_post('/', body=expected_body, headers=headers) - media = client.resource.captured_req.media + media = client.resource.captured_req_media assert media is not None assert media.get('something') is True @@ -41,8 +85,8 @@ def test_json(media_type): ('application/msgpack; charset=utf-8'), ('application/x-msgpack'), ]) -def test_msgpack(media_type): - client = create_client({ +def test_msgpack(asgi, media_type): + client = create_client(asgi, { 'application/msgpack': media.MessagePackHandler(), 'application/x-msgpack': media.MessagePackHandler(), }) @@ -52,91 +96,77 @@ def test_msgpack(media_type): expected_body = b'\x81\xc4\tsomething\xc3' client.simulate_post('/', body=expected_body, headers=headers) - req_media = client.resource.captured_req.media + req_media = client.resource.captured_req_media assert req_media.get(b'something') is True # Unicode expected_body = b'\x81\xa9something\xc3' client.simulate_post('/', body=expected_body, headers=headers) - req_media = client.resource.captured_req.media + req_media = client.resource.captured_req_media assert req_media.get('something') is True @pytest.mark.parametrize('media_type', [ ('nope/json'), ]) -def test_unknown_media_type(media_type): - client = create_client() +def test_unknown_media_type(asgi, media_type): + client = _create_client_invalid_media(asgi, errors.HTTPUnsupportedMediaType) + headers = {'Content-Type': media_type} client.simulate_post('/', body=b'something', headers=headers) - with pytest.raises(errors.HTTPUnsupportedMediaType) as err: - client.resource.captured_req.media - title_msg = '415 Unsupported Media Type' description_msg = '{} is an unsupported media type.'.format(media_type) - assert err.value.title == title_msg - assert err.value.description == description_msg + assert client.resource.captured_error.value.title == title_msg + assert client.resource.captured_error.value.description == description_msg @pytest.mark.parametrize('media_type', [ ('application/json'), ]) -def test_exhausted_stream(media_type): - client = create_client({ +def test_exhausted_stream(asgi, media_type): + client = create_client(asgi, { 'application/json': media.JSONHandler(), }) headers = {'Content-Type': media_type} client.simulate_post('/', body='', headers=headers) - assert client.resource.captured_req.media is None + assert client.resource.captured_req_media is None + +def test_invalid_json(asgi): + client = _create_client_invalid_media(asgi, errors.HTTPBadRequest) -def test_invalid_json(): - client = create_client() expected_body = b'{' headers = {'Content-Type': 'application/json'} client.simulate_post('/', body=expected_body, headers=headers) - with pytest.raises(errors.HTTPBadRequest) as err: - client.resource.captured_req.media + assert 'Could not parse JSON body' in client.resource.captured_error.value.description - assert 'Could not parse JSON body' in err.value.description +def test_invalid_msgpack(asgi): + handlers = { + 'application/msgpack': media.MessagePackHandler() + } + client = _create_client_invalid_media(asgi, errors.HTTPBadRequest, handlers=handlers) -def test_invalid_msgpack(): - client = create_client({'application/msgpack': media.MessagePackHandler()}) expected_body = '/////////////////////' headers = {'Content-Type': 'application/msgpack'} client.simulate_post('/', body=expected_body, headers=headers) - with pytest.raises(errors.HTTPBadRequest) as err: - client.resource.captured_req.media - desc = 'Could not parse MessagePack body - unpack(b) received extra data.' - assert err.value.description == desc + assert client.resource.captured_error.value.description == desc -def test_invalid_stream_fails_gracefully(): - client = create_client() +def test_invalid_stream_fails_gracefully(client): client.simulate_post('/') req = client.resource.captured_req req.headers['Content-Type'] = 'application/json' req._bounded_stream = None - assert req.media is None - - -def test_use_cached_media(): - client = create_client() - client.simulate_post('/') - - req = client.resource.captured_req - req._media = {'something': True} - - assert req.media == {'something': True} + assert client.resource.captured_req_media is None class NopeHandler(media.BaseHandler): @@ -148,8 +178,8 @@ def deserialize(self, *args, **kwargs): pass -def test_complete_consumption(): - client = create_client({ +def test_complete_consumption(asgi): + client = create_client(asgi, { 'nope/nope': NopeHandler() }) body = b'{"something": "abracadabra"}' @@ -157,27 +187,27 @@ def test_complete_consumption(): client.simulate_post('/', body=body, headers=headers) - req_media = client.resource.captured_req.media + req_media = client.resource.captured_req_media assert req_media is None req_bounded_stream = client.resource.captured_req.bounded_stream - assert not req_bounded_stream.read() + assert req_bounded_stream.eof @pytest.mark.parametrize('payload', [False, 0, 0.0, '', [], {}]) -def test_empty_json_media(payload): - client = create_client() +def test_empty_json_media(asgi, payload): + resource = ResourceCachedMediaAsync() if asgi else ResourceCachedMedia() + client = create_client(asgi, resource=resource) client.simulate_post('/', json=payload) + assert resource.captured_req_media == payload - req = client.resource.captured_req - for access in range(3): - assert req.media == payload - -def test_null_json_media(): - client = create_client() +def test_null_json_media(client): client.simulate_post('/', body='null', headers={'Content-Type': 'application/json'}) + assert client.resource.captured_req_media is None - req = client.resource.captured_req - for access in range(3): - assert req.media is None + +def _create_client_invalid_media(asgi, error_type, handlers=None): + resource_type = ResourceInvalidMediaAsync if asgi else ResourceInvalidMedia + resource = resource_type(error_type) + return create_client(asgi, handlers=handlers, resource=resource) diff --git a/tests/test_response.py b/tests/test_response.py index 80b59b502..c3ea3d06a 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,25 +1,28 @@ import pytest -import falcon from falcon import MEDIA_TEXT +from _util import create_resp # NOQA -def test_response_set_content_type_set(): - resp = falcon.Response() + +@pytest.fixture(params=[True, False]) +def resp(request): + return create_resp(asgi=request.param) + + +def test_response_set_content_type_set(resp): resp._set_media_type(MEDIA_TEXT) assert resp._headers['content-type'] == MEDIA_TEXT -def test_response_set_content_type_not_set(): - resp = falcon.Response() +def test_response_set_content_type_not_set(resp): assert 'content-type' not in resp._headers resp._set_media_type() assert 'content-type' not in resp._headers -def test_response_get_headers(): - resp = falcon.Response() +def test_response_get_headers(resp): resp.append_header('x-things1', 'thing-1') resp.append_header('x-things2', 'thing-2') resp.append_header('X-Things3', 'Thing-3') @@ -34,9 +37,7 @@ def test_response_get_headers(): assert 'set-cookie' not in headers -def test_response_attempt_to_set_read_only_headers(): - resp = falcon.Response() - +def test_response_attempt_to_set_read_only_headers(resp): resp.append_header('x-things1', 'thing-1') resp.append_header('x-things2', 'thing-2') resp.append_header('x-things3', 'thing-3a') @@ -51,9 +52,7 @@ def test_response_attempt_to_set_read_only_headers(): assert headers['x-things3'] == 'thing-3a, thing-3b' -def test_response_removed_stream_len(): - resp = falcon.Response() - +def test_response_removed_stream_len(resp): with pytest.raises(AttributeError): resp.stream_len = 128 diff --git a/tests/test_response_body.py b/tests/test_response_body.py index d735b9d77..b8b8cc6e3 100644 --- a/tests/test_response_body.py +++ b/tests/test_response_body.py @@ -1,21 +1,92 @@ +import pytest import falcon +from falcon import testing +from _util import create_app, create_resp # NOQA -class TestResponseBody: - def test_append_body(self): - text = 'Hello beautiful world! ' - resp = falcon.Response() - resp.body = '' +@pytest.fixture +def resp(asgi): + return create_resp(asgi) - for token in text.split(): - resp.body += token - resp.body += ' ' - assert resp.body == text +def test_append_body(resp): + text = 'Hello beautiful world! ' + resp.body = '' - def test_response_repr(self): - resp = falcon.Response() - _repr = '<%s: %s>' % (resp.__class__.__name__, resp.status) - assert resp.__repr__() == _repr + for token in text.split(): + resp.body += token + resp.body += ' ' + + assert resp.body == text + + +def test_response_repr(resp): + _repr = '<%s: %s>' % (resp.__class__.__name__, resp.status) + assert resp.__repr__() == _repr + + +def test_content_length_set_on_head_with_no_body(asgi): + class NoBody: + def on_get(self, req, resp): + pass + + on_head = on_get + + app = create_app(asgi) + app.add_route('/', NoBody()) + + result = testing.simulate_head(app, '/') + + assert result.status_code == 200 + assert result.headers['content-length'] == '0' + + +@pytest.mark.parametrize('method', ['GET', 'HEAD']) +def test_content_length_not_set_when_streaming_response(asgi, method): + class SynthesizedHead: + def on_get(self, req, resp): + def words(): + for word in ('Hello', ',', ' ', 'World!'): + yield word.encode() + + resp.content_type = falcon.MEDIA_TEXT + resp.stream = words() + + on_head = on_get + + class SynthesizedHeadAsync: + async def on_get(self, req, resp): + # NOTE(kgriffs): Using an iterator in lieu of a generator + # makes this code parsable by 3.5 and also tests our support + # for iterators vs. generators. + class Words: + def __init__(self): + self._stream = iter(('Hello', ',', ' ', 'World!')) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._stream).encode() + except StopIteration: + pass # Test Falcon's PEP 479 support + + resp.content_type = falcon.MEDIA_TEXT + resp.stream = Words() + + on_head = on_get + + app = create_app(asgi) + app.add_route('/', SynthesizedHeadAsync() if asgi else SynthesizedHead()) + + result = testing.simulate_request(app, method) + + assert result.status_code == 200 + assert result.headers['content-type'] == falcon.MEDIA_TEXT + assert 'content-length' not in result.headers + + if method == 'GET': + assert result.text == 'Hello, World!' diff --git a/tests/test_response_context.py b/tests/test_response_context.py index 3396bd5a8..e7c2ee543 100644 --- a/tests/test_response_context.py +++ b/tests/test_response_context.py @@ -2,11 +2,23 @@ from falcon import Response +from _util import skipif_asgi_unsupported # NOQA + + +@pytest.fixture(params=[True, False]) +def resp_type(request): + if request.param: + skipif_asgi_unsupported() + import falcon.asgi + return falcon.asgi.Response + + return Response + class TestResponseContext: - def test_default_response_context(self): - resp = Response() + def test_default_response_context(self, resp_type): + resp = resp_type() resp.context.hello = 'World!' assert resp.context.hello == 'World!' @@ -17,31 +29,31 @@ def test_default_response_context(self): assert hasattr(resp.context, 'note') assert resp.context.get('note') == resp.context['note'] - def test_custom_response_context(self): + def test_custom_response_context(self, resp_type): class MyCustomContextType: pass - class MyCustomResponse(Response): + class MyCustomResponse(resp_type): context_type = MyCustomContextType resp = MyCustomResponse() assert isinstance(resp.context, MyCustomContextType) - def test_custom_response_context_failure(self): + def test_custom_response_context_failure(self, resp_type): - class MyCustomResponse(Response): + class MyCustomResponse(resp_type): context_type = False with pytest.raises(TypeError): MyCustomResponse() - def test_custom_response_context_factory(self): + def test_custom_response_context_factory(self, resp_type): def create_context(resp): return {'resp': resp} - class MyCustomResponse(Response): + class MyCustomResponse(resp_type): context_type = create_context resp = MyCustomResponse() diff --git a/tests/test_response_media.py b/tests/test_response_media.py index 359e0b107..700e8a607 100644 --- a/tests/test_response_media.py +++ b/tests/test_response_media.py @@ -6,6 +6,11 @@ from falcon import errors, media, testing +@pytest.fixture +def client(): + return create_client() + + def create_client(handlers=None): res = testing.SimpleTestResource() @@ -38,8 +43,7 @@ def on_get(self, req, resp): (falcon.MEDIA_JSON), ('application/json; charset=utf-8'), ]) -def test_json(media_type): - client = create_client() +def test_json(client, media_type): client.simulate_get('/') resp = client.resource.captured_resp @@ -99,8 +103,7 @@ def test_msgpack(media_type): assert resp.data == b'\x81\xa9something\xc3' -def test_unknown_media_type(): - client = create_client() +def test_unknown_media_type(client): client.simulate_get('/') resp = client.resource.captured_resp @@ -112,20 +115,17 @@ def test_unknown_media_type(): assert err.value.description == 'nope/json is an unsupported media type.' -def test_use_cached_media(): - expected = {'something': True} - - client = create_client() +def test_use_cached_media(client): client.simulate_get('/') resp = client.resource.captured_resp - resp._media = expected + expected = {'something': True} + resp._media = expected assert resp.media == expected -def test_default_media_type(): - client = create_client() +def test_default_media_type(client): client.simulate_get('/') resp = client.resource.captured_resp @@ -136,8 +136,7 @@ def test_default_media_type(): assert resp.content_type == 'application/json' -def test_mimeparse_edgecases(): - client = create_client() +def test_mimeparse_edgecases(client): client.simulate_get('/') resp = client.resource.captured_resp diff --git a/tests/test_sinks.py b/tests/test_sinks.py index 31ff4dab6..f1515337c 100644 --- a/tests/test_sinks.py +++ b/tests/test_sinks.py @@ -5,6 +5,8 @@ import falcon import falcon.testing as testing +from _util import create_app # NOQA + class Proxy: def forward(self, req): @@ -12,7 +14,6 @@ def forward(self, req): class Sink: - def __init__(self): self._proxy = Proxy() @@ -21,8 +22,9 @@ def __call__(self, req, resp, **kwargs): self.kwargs = kwargs -def sink_too(req, resp): - resp.status = falcon.HTTP_781 +class SinkAsync(Sink): + async def __call__(self, req, resp, **kwargs): + super().__call__(req, resp, **kwargs) class BookCollection(testing.SimpleTestResource): @@ -35,13 +37,13 @@ def resource(): @pytest.fixture -def sink(): - return Sink() +def sink(asgi): + return SinkAsync() if asgi else Sink() @pytest.fixture -def client(): - app = falcon.App() +def client(asgi): + app = create_app(asgi) return testing.TestClient(app) @@ -78,7 +80,14 @@ def test_named_groups(self, client, sink, resource): response = client.simulate_request(path='/user/sally') assert response.status == falcon.HTTP_404 - def test_multiple_patterns(self, client, sink, resource): + def test_multiple_patterns(self, asgi, client, sink, resource): + if asgi: + async def sink_too(req, resp): + resp.status = falcon.HTTP_781 + else: + def sink_too(req, resp): + resp.status = falcon.HTTP_781 + client.app.add_sink(sink, r'/foo') client.app.add_sink(sink_too, r'/foo') # Last duplicate wins diff --git a/tests/test_slots.py b/tests/test_slots.py index 3ff225258..64f48b87d 100644 --- a/tests/test_slots.py +++ b/tests/test_slots.py @@ -1,22 +1,25 @@ import pytest -from falcon import Request, Response import falcon.testing as testing class TestSlots: - def test_slots_request(self): - env = testing.create_environ() - req = Request(env) + def test_slots_request(self, asgi): + req = testing.create_asgi_req() if asgi else testing.create_req() try: req.doesnt = 'exist' except AttributeError: pytest.fail('Unable to add additional variables dynamically') - def test_slots_response(self): - resp = Response() + def test_slots_response(self, asgi): + if asgi: + import falcon.asgi + resp = falcon.asgi.Response() + else: + import falcon + resp = falcon.Response() try: resp.doesnt = 'exist' diff --git a/tests/test_static.py b/tests/test_static.py index 8141553d0..d8e60b639 100644 --- a/tests/test_static.py +++ b/tests/test_static.py @@ -6,16 +6,23 @@ import pytest import falcon -from falcon.request import Request -from falcon.response import Response -from falcon.routing import StaticRoute +from falcon.routing import StaticRoute, StaticRouteAsync import falcon.testing as testing +import _util # NOQA -@pytest.fixture -def client(): - app = falcon.App() - return testing.TestClient(app) + +@pytest.fixture(params=[True, False]) +def client(request): + app = _util.create_app(asgi=request.param) + client = testing.TestClient(app) + client.asgi = request.param + return client + + +def create_sr(asgi, *args, **kwargs): + sr_type = StaticRouteAsync if asgi else StaticRoute + return sr_type(*args, **kwargs) @pytest.mark.parametrize('uri', [ @@ -78,21 +85,26 @@ def client(): # Invalid unicode character '/static/\ufffdsomething', ]) -def test_bad_path(uri, monkeypatch): - monkeypatch.setattr(io, 'open', lambda path, mode: path) +def test_bad_path(asgi, uri, monkeypatch): + monkeypatch.setattr(io, 'open', lambda path, mode: io.BytesIO()) - sr = StaticRoute('/static', '/var/www/statics') + sr_type = StaticRouteAsync if asgi else StaticRoute + sr = sr_type('/static', '/var/www/statics') - req = Request(testing.create_environ( + req = _util.create_req( + asgi, host='test.com', path=uri, - app='statics' - )) + root_path='statics' + ) - resp = Response() + resp = _util.create_resp(asgi) with pytest.raises(falcon.HTTPNotFound): - sr(req, resp) + if asgi: + testing.invoke_coroutine_sync(sr, req, resp) + else: + sr(req, resp) @pytest.mark.parametrize('prefix, directory', [ @@ -101,9 +113,9 @@ def test_bad_path(uri, monkeypatch): ('/static', 'statics'), ('/static', '../statics'), ]) -def test_invalid_args(prefix, directory, client): +def test_invalid_args(client, prefix, directory): with pytest.raises(ValueError): - StaticRoute(prefix, directory) + create_sr(client.asgi, prefix, directory) with pytest.raises(ValueError): client.app.add_static_route(prefix, directory) @@ -118,7 +130,7 @@ def test_invalid_args(prefix, directory, client): def test_invalid_args_fallback_filename(client, default): prefix, directory = '/static', '/var/www/statics' with pytest.raises(ValueError, match='fallback_filename'): - StaticRoute(prefix, directory, fallback_filename=default) + create_sr(client.asgi, prefix, directory, fallback_filename=default) with pytest.raises(ValueError, match='fallback_filename'): client.app.add_static_route(prefix, directory, fallback_filename=default) @@ -141,30 +153,39 @@ def test_invalid_args_fallback_filename(client, default): ('/some/download', '/foo/../bar/../report.zip', '/report.zip', 'application/zip'), ('/some/download', '/foo/bar/../../report.zip', '/report.zip', 'application/zip'), ]) -def test_good_path(uri_prefix, uri_path, expected_path, mtype, monkeypatch): - monkeypatch.setattr(io, 'open', lambda path, mode: path) +def test_good_path(asgi, uri_prefix, uri_path, expected_path, mtype, monkeypatch): + monkeypatch.setattr(io, 'open', lambda path, mode: io.BytesIO(path.encode())) - sr = StaticRoute(uri_prefix, '/var/www/statics') + sr = create_sr(asgi, uri_prefix, '/var/www/statics') req_path = uri_prefix[:-1] if uri_prefix.endswith('/') else uri_prefix req_path += uri_path - req = Request(testing.create_environ( + req = _util.create_req( + asgi, host='test.com', path=req_path, - app='statics' - )) + root_path='statics' + ) + + resp = _util.create_resp(asgi) - resp = Response() + if asgi: + async def run(): + await sr(req, resp) + return await resp.stream.read() - sr(req, resp) + body = testing.invoke_coroutine_sync(run) + else: + sr(req, resp) + body = resp.stream.read() assert resp.content_type == mtype - assert resp.stream == '/var/www/statics' + expected_path + assert body.decode() == '/var/www/statics' + expected_path def test_lifo(client, monkeypatch): - monkeypatch.setattr(io, 'open', lambda path, mode: [path.encode('utf-8')]) + monkeypatch.setattr(io, 'open', lambda path, mode: io.BytesIO(path.encode())) client.app.add_static_route('/downloads', '/opt/somesite/downloads') client.app.add_static_route('/downloads/archive', '/opt/somesite/x') @@ -179,7 +200,7 @@ def test_lifo(client, monkeypatch): def test_lifo_negative(client, monkeypatch): - monkeypatch.setattr(io, 'open', lambda path, mode: [path.encode('utf-8')]) + monkeypatch.setattr(io, 'open', lambda path, mode: io.BytesIO(path.encode())) client.app.add_static_route('/downloads/archive', '/opt/somesite/x') client.app.add_static_route('/downloads', '/opt/somesite/downloads') @@ -194,7 +215,7 @@ def test_lifo_negative(client, monkeypatch): def test_downloadable(client, monkeypatch): - monkeypatch.setattr(io, 'open', lambda path, mode: [path.encode('utf-8')]) + monkeypatch.setattr(io, 'open', lambda path, mode: io.BytesIO(path.encode())) client.app.add_static_route('/downloads', '/opt/somesite/downloads', downloadable=True) client.app.add_static_route('/assets/', '/opt/somesite/assets') @@ -228,32 +249,48 @@ def test_downloadable_not_found(client): ('index.html_files/test.txt', 'index.html', 'index.html_files/test.txt', 'text/plain'), ]) @pytest.mark.parametrize('downloadable', [True, False]) -def test_fallback_filename(uri, default, expected, content_type, downloadable, +def test_fallback_filename(asgi, uri, default, expected, content_type, downloadable, monkeypatch): - def mockOpen(path, mode): + def mock_open(path, mode): if default in path: - return path + return io.BytesIO(path.encode()) + raise IOError() - monkeypatch.setattr(io, 'open', mockOpen) + monkeypatch.setattr(io, 'open', mock_open) monkeypatch.setattr('os.path.isfile', lambda file: default in file) - sr = StaticRoute('/static', '/var/www/statics', downloadable=downloadable, - fallback_filename=default) + sr = create_sr( + asgi, + '/static', + '/var/www/statics', + downloadable=downloadable, + fallback_filename=default + ) req_path = '/static/' + uri - req = Request(testing.create_environ( + req = _util.create_req( + asgi, host='test.com', path=req_path, - app='statics' - )) - resp = Response() - sr(req, resp) + root_path='statics' + ) + resp = _util.create_resp(asgi) + + if asgi: + async def run(): + await sr(req, resp) + return await resp.stream.read() + + body = testing.invoke_coroutine_sync(run) + else: + sr(req, resp) + body = resp.stream.read() assert sr.match(req.path) - assert resp.stream == os.path.join('/var/www/statics', expected) + assert body.decode() == os.path.join('/var/www/statics', expected) assert resp.content_type == content_type if downloadable: @@ -275,7 +312,7 @@ def test_e2e_fallback_filename(client, monkeypatch, strip_slash, path, fallback, def mockOpen(path, mode): if 'index' in path and 'raise' not in path: - return [path.encode('utf-8')] + return io.BytesIO(path.encode()) raise IOError() monkeypatch.setattr(io, 'open', mockOpen) @@ -308,9 +345,9 @@ def test(prefix, directory, expected): ('index2', '/staticfoo', False), ('index2', '/static/foo', True), ]) -def test_match(default, path, expected, monkeypatch): +def test_match(asgi, default, path, expected, monkeypatch): monkeypatch.setattr('os.path.isfile', lambda file: True) - sr = StaticRoute('/static', '/var/www/statics', fallback_filename=default) + sr = create_sr(asgi, '/static', '/var/www/statics', fallback_filename=default) assert sr.match(path) == expected diff --git a/tests/test_testing.py b/tests/test_testing.py index e1647cfd1..96b7970f4 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -1,7 +1,8 @@ +import time + import pytest -from falcon import status_codes -from falcon.testing import closed_wsgi_iterable, TestClient +from falcon import API, status_codes, testing def another_dummy_wsgi_app(environ, start_response): @@ -11,7 +12,7 @@ def another_dummy_wsgi_app(environ, start_response): def test_testing_client_handles_wsgi_generator_app(): - client = TestClient(another_dummy_wsgi_app) + client = testing.TestClient(another_dummy_wsgi_app) response = client.simulate_get('/nevermind') @@ -26,4 +27,134 @@ def test_testing_client_handles_wsgi_generator_app(): (b'Hello, ', b'World', b'!\n'), ]) def test_closed_wsgi_iterable(items): - assert tuple(closed_wsgi_iterable(items)) == items + assert tuple(testing.closed_wsgi_iterable(items)) == items + + +@pytest.mark.parametrize('version, valid', [ + ('1', True), + ('1.0', True), + ('1.1', True), + ('2', True), + ('2.0', True), + ('', False), + ('0', False), + ('1.2', False), + ('2.1', False), + ('3', False), + ('3.1', False), + ('11', False), + ('22', False), +]) +def test_simulate_request_http_version(version, valid): + app = API() + + if valid: + testing.simulate_request(app, http_version=version) + else: + with pytest.raises(ValueError): + testing.simulate_request(app, http_version=version) + + +def test_asgi_request_event_emitter_hang(): + # NOTE(kgriffs): This tests the ASGI server behavior that + # ASGIRequestEventEmitter simulates when emit() is called + # again after there are not more events available. + + expected_elasped_min = 1 + disconnect_at = time.time() + expected_elasped_min + + emit = testing.ASGIRequestEventEmitter(disconnect_at=disconnect_at) + + async def t(): + start = time.time() + while True: + event = await emit() + if not event.get('more_body', False): + break + elapsed = time.time() - start + + assert elapsed < 0.1 + + start = time.time() + await emit() + elapsed = time.time() - start + + assert (elapsed + 0.1) > expected_elasped_min + + testing.invoke_coroutine_sync(t) + + +def test_ignore_extra_asgi_events(): + collect = testing.ASGIResponseEventCollector() + + async def t(): + await collect({'type': 'http.response.start', 'status': 200}) + await collect({'type': 'http.response.body', 'more_body': False}) + + # NOTE(kgriffs): Events after more_body is False are ignored to conform + # to the ASGI spec. + await collect({'type': 'http.response.body'}) + assert len(collect.events) == 2 + + testing.invoke_coroutine_sync(t) + + +def test_invalid_asgi_events(): + collect = testing.ASGIResponseEventCollector() + + def make_event(headers=None, status=200): + return { + 'type': 'http.response.start', + 'headers': headers or [], + 'status': status + } + + async def t(): + with pytest.raises(TypeError): + await collect({'type': 123}) + + with pytest.raises(TypeError): + headers = [ + ('notbytes', b'bytes') + ] + await collect(make_event(headers)) + + with pytest.raises(TypeError): + headers = [ + (b'bytes', 'notbytes') + ] + await collect(make_event(headers)) + + with pytest.raises(ValueError): + headers = [ + # NOTE(kgriffs): Name must be lowercase + (b'Content-Type', b'application/json') + ] + await collect(make_event(headers)) + + with pytest.raises(TypeError): + await collect(make_event(status='200')) + + with pytest.raises(TypeError): + await collect(make_event(status=200.1)) + + with pytest.raises(TypeError): + await collect({'type': 'http.response.body', 'body': 'notbytes'}) + + with pytest.raises(TypeError): + await collect({'type': 'http.response.body', 'more_body': ''}) + + with pytest.raises(ValueError): + # NOTE(kgriffs): Invalid type + await collect({'type': 'http.response.bod'}) + + testing.invoke_coroutine_sync(t) + + +def test_is_asgi_app_cls(): + class Foo: + @classmethod + def class_meth(cls, scope, receive, send): + pass + + assert testing.client._is_asgi_app(Foo.class_meth) diff --git a/tests/test_uri_templates.py b/tests/test_uri_templates.py index 78ec17f38..661edb478 100644 --- a/tests/test_uri_templates.py +++ b/tests/test_uri_templates.py @@ -14,6 +14,8 @@ from falcon import testing from falcon.routing.util import SuffixedMethodNotFoundError +from _util import create_app # NOQA + _TEST_UUID = uuid.uuid4() _TEST_UUID_2 = uuid.uuid4() @@ -120,8 +122,8 @@ def resource(): @pytest.fixture -def client(): - return testing.TestClient(falcon.App()) +def client(asgi): + return testing.TestClient(create_app(asgi)) def test_root_path(client, resource): diff --git a/tests/test_utils.py b/tests/test_utils.py index 09385469e..a9d63be32 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,9 @@ from datetime import datetime import functools +import http import itertools +import os import random from urllib.parse import quote, unquote_plus @@ -13,6 +15,13 @@ from falcon import util from falcon.util import json, misc, structures, uri +from _util import create_app, to_coroutine # NOQA + + +@pytest.fixture +def app(asgi): + return create_app(asgi) + def _arbitrary_uris(count, length): return ( @@ -37,8 +46,10 @@ def test_deprecated_decorator(self): def old_thing(): pass + del os.environ['FALCON_TESTING_SESSION'] with pytest.warns(UserWarning) as rec: old_thing() + os.environ['FALCON_TESTING_SESSION'] = '1' warn = rec.pop() assert msg in str(warn.message) @@ -379,6 +390,56 @@ def test_get_http_status(self): falcon.get_http_status('-404.3') assert falcon.get_http_status(123, 'Go Away') == '123 Go Away' + @pytest.mark.parametrize( + 'v_in,v_out', + [ + (703, falcon.HTTP_703), + (404, falcon.HTTP_404), + (404.9, falcon.HTTP_404), + ('404', falcon.HTTP_404), + (123, '123'), + ] + ) + def test_code_to_http_status(self, v_in, v_out): + assert falcon.code_to_http_status(v_in) == v_out + + @pytest.mark.parametrize( + 'v', + ['not_a_number', 0, '0', 99, '99', '404.3', -404.3, '-404', '-404.3'] + ) + def test_code_to_http_status_neg(self, v): + with pytest.raises(ValueError): + falcon.code_to_http_status(v) + + @pytest.mark.parametrize( + 'v_in,v_out', + [ + # NOTE(kgriffs): Include some codes not used elsewhere so that + # we get past the LRU. + (http.HTTPStatus(505), 505), + (712, 712), + ('712', 712), + (b'404 Not Found', 404), + (b'712 NoSQL', 712), + ('404 Not Found', 404), + ('123 Wow Such Status', 123), + + # NOTE(kgriffs): Test LRU + (http.HTTPStatus(505), 505), + ('123 Wow Such Status', 123), + ] + ) + def test_http_status_to_code(self, v_in, v_out): + assert falcon.http_status_to_code(v_in) == v_out + + @pytest.mark.parametrize( + 'v', + ['', ' ', '1', '12', '99', 'catsup', b'', 5.2] + ) + def test_http_status_to_code_neg(self, v): + with pytest.raises(ValueError): + falcon.http_status_to_code(v) + def test_etag_dumps_to_header_format(self): etag = structures.ETag('67ab43') @@ -429,14 +490,17 @@ def test_etag_strong_vs_weak_comparison(self): falcon.HTTP_METHODS * 2 ) ) -def test_simulate_request_protocol(protocol, method): +def test_simulate_request_protocol(asgi, protocol, method): sink_called = [False] def sink(req, resp): sink_called[0] = True assert req.protocol == protocol - app = falcon.App() + if asgi: + sink = to_coroutine(sink) + + app = create_app(asgi) app.add_sink(sink, '/test') client = testing.TestClient(app) @@ -459,13 +523,16 @@ def sink(req, resp): testing.simulate_patch, testing.simulate_delete, ]) -def test_simulate_free_functions(simulate): +def test_simulate_free_functions(asgi, simulate): sink_called = [False] def sink(req, resp): sink_called[0] = True - app = falcon.App() + if asgi: + sink = to_coroutine(sink) + + app = create_app(asgi) app.add_sink(sink, '/test') simulate(app, '/test') @@ -491,8 +558,7 @@ def test_none_header_value_in_create_environ(self): env = testing.create_environ('/', headers={'X-Foo': None}) assert env['HTTP_X_FOO'] == '' - def test_decode_empty_result(self): - app = falcon.App() + def test_decode_empty_result(self, app): client = testing.TestClient(app) response = client.simulate_request(path='/') assert response.text == '' @@ -500,8 +566,7 @@ def test_decode_empty_result(self): def test_httpnow_alias_for_backwards_compat(self): assert testing.httpnow is util.http_now - def test_default_headers(self): - app = falcon.App() + def test_default_headers(self, app): resource = testing.SimpleTestResource() app.add_route('/', resource) @@ -517,8 +582,7 @@ def test_default_headers(self): client.simulate_get(headers=None) assert resource.captured_req.auth == headers['Authorization'] - def test_default_headers_with_override(self): - app = falcon.App() + def test_default_headers_with_override(self, app): resource = testing.SimpleTestResource() app.add_route('/', resource) @@ -538,8 +602,7 @@ def test_default_headers_with_override(self): assert resource.captured_req.accept == headers['Accept'] assert resource.captured_req.get_header('X-Override-Me') == override_after - def test_status(self): - app = falcon.App() + def test_status(self, app): resource = testing.SimpleTestResource(status=falcon.HTTP_702) app.add_route('/', resource) client = testing.TestClient(app) @@ -552,26 +615,28 @@ def test_wsgi_iterable_not_closeable(self): assert not result.content assert result.json is None - def test_path_must_start_with_slash(self): - app = falcon.App() + def test_path_must_start_with_slash(self, app): app.add_route('/', testing.SimpleTestResource()) client = testing.TestClient(app) with pytest.raises(ValueError): client.simulate_get('foo') - def test_cached_text_in_result(self): - app = falcon.App() + def test_cached_text_in_result(self, app): app.add_route('/', testing.SimpleTestResource(body='test')) client = testing.TestClient(app) result = client.simulate_get() assert result.text == result.text - def test_simple_resource_body_json_xor(self): + @pytest.mark.parametrize('resource_type', [ + testing.SimpleTestResource, + testing.SimpleTestResourceAsync, + ]) + def test_simple_resource_body_json_xor(self, resource_type): with pytest.raises(ValueError): - testing.SimpleTestResource(body='', json={}) + resource_type(body='', json={}) - def test_query_string(self): + def test_query_string(self, app): class SomeResource: def on_get(self, req, resp): doc = {} @@ -583,7 +648,6 @@ def on_get(self, req, resp): resp.body = json.dumps(doc) - app = falcon.App() app.req_options.auto_parse_qs_csv = True app.add_route('/', SomeResource()) client = testing.TestClient(app) @@ -614,15 +678,13 @@ def on_get(self, req, resp): params_csv=False) assert result.json['query_string'] == expected_qs - def test_query_string_no_question(self): - app = falcon.App() + def test_query_string_no_question(self, app): app.add_route('/', testing.SimpleTestResource()) client = testing.TestClient(app) with pytest.raises(ValueError): client.simulate_get(query_string='?x=1') - def test_query_string_in_path(self): - app = falcon.App() + def test_query_string_in_path(self, app): resource = testing.SimpleTestResource() app.add_route('/thing', resource) client = testing.TestClient(app) @@ -660,25 +722,25 @@ def test_query_string_in_path(self): 'next': None, }, ]) - def test_simulate_json_body(self, document): - app = falcon.App() - resource = testing.SimpleTestResource() + def test_simulate_json_body(self, asgi, document): + resource = testing.SimpleTestResourceAsync() if asgi else testing.SimpleTestResource() + app = create_app(asgi) app.add_route('/', resource) json_types = ('application/json', 'application/json; charset=UTF-8') client = testing.TestClient(app) - client.simulate_post('/', json=document) - captured_body = resource.captured_req.bounded_stream.read().decode('utf-8') - assert json.loads(captured_body) == document + client.simulate_post('/', json=document, headers={'capture-req-body-bytes': '-1'}) + assert json.loads(resource.captured_req_body.decode()) == document assert resource.captured_req.content_type in json_types headers = { 'Content-Type': 'x-falcon/peregrine', 'X-Falcon-Type': 'peregrine', + 'capture-req-media': 'y' } body = 'If provided, `json` parameter overrides `body`.' client.simulate_post('/', headers=headers, body=body, json=document) - assert resource.captured_req.media == document + assert resource.captured_req_media == document assert resource.captured_req.content_type in json_types assert resource.captured_req.get_header('X-Falcon-Type') == 'peregrine' @@ -689,13 +751,12 @@ def test_simulate_json_body(self, document): '104.24.101.85', '2606:4700:30::6818:6455', ]) - def test_simulate_remote_addr(self, remote_addr): + def test_simulate_remote_addr(self, app, remote_addr): class ShowMyIPResource: def on_get(self, req, resp): resp.body = req.remote_addr resp.content_type = falcon.MEDIA_TEXT - app = falcon.App() app.add_route('/', ShowMyIPResource()) client = testing.TestClient(app) @@ -707,8 +768,7 @@ def on_get(self, req, resp): else: assert resp.text == remote_addr - def test_simulate_hostname(self): - app = falcon.App() + def test_simulate_hostname(self, app): resource = testing.SimpleTestResource() app.add_route('/', resource) @@ -720,7 +780,7 @@ def test_simulate_hostname(self): @pytest.mark.parametrize('extras,expected_headers', [ ( {}, - (('user-agent', 'curl/7.24.0 (x86_64-apple-darwin12.0)'),), + (('user-agent', None),), ), ( {'HTTP_USER_AGENT': 'URL/Emacs', 'HTTP_X_FALCON': 'peregrine'}, @@ -738,17 +798,20 @@ def test_simulate_with_environ_extras(self, extras, expected_headers): for header, value in expected_headers: assert resource.captured_req.get_header(header) == value - def test_override_method_with_extras(self): - app = falcon.App() + def test_override_method_with_extras(self, asgi): + app = create_app(asgi) app.add_route('/', testing.SimpleTestResource(body='test')) client = testing.TestClient(app) with pytest.raises(ValueError): - client.simulate_get('/', extras={'REQUEST_METHOD': 'PATCH'}) + if asgi: + client.simulate_get('/', extras={'method': 'PATCH'}) + else: + client.simulate_get('/', extras={'REQUEST_METHOD': 'PATCH'}) - resp = client.simulate_get('/', extras={'REQUEST_METHOD': 'GET'}) - assert resp.status_code == 200 - assert resp.text == 'test' + result = client.simulate_get('/', extras={'REQUEST_METHOD': 'GET'}) + assert result.status_code == 200 + assert result.text == 'test' class TestNoApiClass(testing.TestCase): @@ -759,10 +822,16 @@ def test_something(self): class TestSetupApi(testing.TestCase): def setUp(self): super(TestSetupApi, self).setUp() - self.api = falcon.App() + self.app = falcon.API() + self.app.add_route('/', testing.SimpleTestResource(body='test')) def test_something(self): - self.assertTrue(isinstance(self.api, falcon.App)) + self.assertTrue(isinstance(self.app, falcon.API)) + self.assertTrue(isinstance(self.app, falcon.App)) + + result = self.simulate_get() + assert result.status_code == 200 + assert result.text == 'test' def test_get_argnames(): diff --git a/tests/test_validators.py b/tests/test_validators.py index 0d3c13012..7bf6e0b16 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,15 +1,29 @@ +import typing + try: - import jsonschema + import jsonschema as _jsonschema # NOQA except ImportError: - jsonschema = None - + pass import pytest import falcon from falcon import testing from falcon.media import validators -basic_schema = { +from _util import create_app, disable_asgi_non_coroutine_wrapping # NOQA + + +# NOTE(kgriffs): Default to None if missing. We do it like this, here, instead +# of in the body of the except statement, above, to avoid flake8 import +# ordering errors. +jsonschema = globals().get('_jsonschema') + + +_VALID_MEDIA = {'message': 'something'} +_INVALID_MEDIA = {} # type: typing.Dict[str, str] + + +_TEST_SCHEMA = { 'type': 'object', 'properies': { 'message': { @@ -27,93 +41,156 @@ class Resource: - @validators.jsonschema.validate(req_schema=basic_schema) + @validators.jsonschema.validate(req_schema=_TEST_SCHEMA) def request_validated(self, req, resp): assert req.media is not None return resp - @validators.jsonschema.validate(resp_schema=basic_schema) + @validators.jsonschema.validate(resp_schema=_TEST_SCHEMA) def response_validated(self, req, resp): assert resp.media is not None return resp - @validators.jsonschema.validate(req_schema=basic_schema, resp_schema=basic_schema) + @validators.jsonschema.validate(req_schema=_TEST_SCHEMA, resp_schema=_TEST_SCHEMA) def both_validated(self, req, resp): assert req.media is not None assert resp.media is not None return req, resp - @validators.jsonschema.validate(req_schema=basic_schema, resp_schema=basic_schema) + @validators.jsonschema.validate(req_schema=_TEST_SCHEMA, resp_schema=_TEST_SCHEMA) def on_put(self, req, resp): assert req.media is not None - resp.media = GoodData.media + resp.media = _VALID_MEDIA + + +class ResourceAsync: + @validators.jsonschema.validate(req_schema=_TEST_SCHEMA) + async def request_validated(self, req, resp): + # NOTE(kgriffs): Verify that we can await req.get_media() multiple times + for i in range(3): + m = await req.get_media() + assert m == _VALID_MEDIA + + assert m is not None + return resp + + @validators.jsonschema.validate(resp_schema=_TEST_SCHEMA) + async def response_validated(self, req, resp): + assert resp.media is not None + return resp + + @validators.jsonschema.validate(req_schema=_TEST_SCHEMA, resp_schema=_TEST_SCHEMA) + async def both_validated(self, req, resp): + m = await req.get_media() + assert m is not None + + assert resp.media is not None + + return req, resp + + @validators.jsonschema.validate(req_schema=_TEST_SCHEMA, resp_schema=_TEST_SCHEMA) + async def on_put(self, req, resp): + m = await req.get_media() + assert m is not None + resp.media = _VALID_MEDIA + + +class _MockReq: + def __init__(self, valid=True): + self.media = _VALID_MEDIA if valid else {} + + +class _MockReqAsync: + def __init__(self, valid=True): + self._media = _VALID_MEDIA if valid else {} + + async def get_media(self): + return self._media -class GoodData: - media = {'message': 'something'} # type: ignore +def MockReq(asgi, valid=True): + return _MockReqAsync(valid) if asgi else _MockReq(valid) -class BadData: - media = {} # type: ignore +class MockResp: + def __init__(self, valid=True): + self.media = _VALID_MEDIA if valid else {} + + +def call_method(asgi, method_name, *args): + resource = ResourceAsync() if asgi else Resource() + + if asgi: + return testing.invoke_coroutine_sync(getattr(resource, method_name), *args) + + return getattr(resource, method_name)(*args) @skip_missing_dep -def test_req_schema_validation_success(): - data = GoodData() - assert Resource().request_validated(GoodData(), data) is data +def test_req_schema_validation_success(asgi): + data = MockResp() + assert call_method(asgi, 'request_validated', MockReq(asgi), data) is data @skip_missing_dep -def test_req_schema_validation_failure(): +def test_req_schema_validation_failure(asgi): with pytest.raises(falcon.HTTPBadRequest) as excinfo: - Resource().request_validated(BadData(), None) + call_method(asgi, 'request_validated', MockReq(asgi, False), None) assert excinfo.value.description == "'message' is a required property" @skip_missing_dep -def test_resp_schema_validation_success(): - data = GoodData() - assert Resource().response_validated(GoodData(), data) is data +def test_resp_schema_validation_success(asgi): + data = MockResp() + assert call_method(asgi, 'response_validated', MockReq(asgi), data) is data @skip_missing_dep -def test_resp_schema_validation_failure(): +def test_resp_schema_validation_failure(asgi): with pytest.raises(falcon.HTTPInternalServerError) as excinfo: - Resource().response_validated(GoodData(), BadData()) + call_method(asgi, 'response_validated', MockReq(asgi), MockResp(False)) assert excinfo.value.title == 'Response data failed validation' @skip_missing_dep -def test_both_schemas_validation_success(): - req_data = GoodData() - resp_data = GoodData() +def test_both_schemas_validation_success(asgi): + req = MockReq(asgi) + resp = MockResp() - result = Resource().both_validated(req_data, resp_data) + result = call_method(asgi, 'both_validated', req, resp) - assert result[0] is req_data - assert result[1] is resp_data + assert result[0] is req + assert result[1] is resp - client = testing.TestClient(falcon.App()) - client.app.add_route('/test', Resource()) - result = client.simulate_put('/test', json=GoodData.media) - assert result.json == resp_data.media + client = testing.TestClient(create_app(asgi)) + resource = ResourceAsync() if asgi else Resource() + client.app.add_route('/test', resource) + + result = client.simulate_put('/test', json=_VALID_MEDIA) + assert result.json == resp.media @skip_missing_dep -def test_both_schemas_validation_failure(): +def test_both_schemas_validation_failure(asgi): + bad_resp = MockResp(False) + with pytest.raises(falcon.HTTPInternalServerError) as excinfo: - Resource().both_validated(GoodData(), BadData()) + call_method(asgi, 'both_validated', MockReq(asgi), bad_resp) assert excinfo.value.title == 'Response data failed validation' with pytest.raises(falcon.HTTPBadRequest) as excinfo: - Resource().both_validated(BadData(), GoodData()) + call_method(asgi, 'both_validated', MockReq(asgi, False), MockResp()) assert excinfo.value.title == 'Request data failed validation' - client = testing.TestClient(falcon.App()) - client.app.add_route('/test', Resource()) - result = client.simulate_put('/test', json=BadData.media) + client = testing.TestClient(create_app(asgi)) + resource = ResourceAsync() if asgi else Resource() + + with disable_asgi_non_coroutine_wrapping(): + client.app.add_route('/test', resource) + + result = client.simulate_put('/test', json=_INVALID_MEDIA) assert result.status_code == 400 diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index 2360020e6..74ecbc84d 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -104,6 +104,7 @@ def _setup_wsgi_server(): stop_event = multiprocessing.Event() process = multiprocessing.Process( target=_run_server, + daemon=True, # NOTE(kgriffs): Pass these explicitly since if multiprocessing is # using the 'spawn' start method, we can't depend on closures. diff --git a/tools/mintest.sh b/tools/mintest.sh index bbc0bb3cd..771974dc4 100755 --- a/tools/mintest.sh +++ b/tools/mintest.sh @@ -3,4 +3,4 @@ pip install -U tox coverage rm -f .coverage.* -tox -e pep8 && tox -e py38 && tools/testing/combine_coverage.sh +tox -e pep8 && tox -e py35,py38 && tools/testing/combine_coverage.sh diff --git a/tox.ini b/tox.ini index 0f4df6442..e797748cf 100644 --- a/tox.ini +++ b/tox.ini @@ -25,6 +25,9 @@ envlist = py38, [testenv] setenv = PIP_CONFIG_FILE={toxinidir}/pip.conf + FALCON_ASGI_WRAP_NON_COROUTINES=Y + FALCON_TESTING_SESSION=Y + PYTHONASYNCIODEBUG=1 deps = -r{toxinidir}/requirements/tests commands = {toxinidir}/tools/clean.sh {toxinidir}/falcon pytest tests [] @@ -39,6 +42,14 @@ whitelist_externals = mkdir commands = "{toxinidir}/tools/clean.sh" "{toxinidir}/falcon" coverage run -m pytest tests [] +[testenv:py35] +deps = {[testenv]deps} + pytest-randomly + jsonschema +whitelist_externals = {[with-coverage]whitelist_externals} +commands = "{toxinidir}/tools/clean.sh" "{toxinidir}/falcon" + coverage run -m pytest tests --ignore=tests/asgi [] + [testenv:py38] deps = {[testenv]deps} pytest-randomly @@ -50,9 +61,6 @@ commands = {[with-coverage]commands} # Additional test suite environments # -------------------------------------------------------------------- -[testenv:pypy] -basepython = pypy - [testenv:pypy3] basepython = pypy3 @@ -67,6 +75,8 @@ deps = -r{toxinidir}/requirements/tests [testenv:py3_debug] basepython = python3.8 deps = {[with-debug-tools]deps} + uvicorn + jsonschema # -------------------------------------------------------------------- # mypy @@ -84,18 +94,31 @@ commands = {toxinidir}/tools/clean.sh {toxinidir}/falcon [with-cython] deps = -r{toxinidir}/requirements/tests cython +setenv = + PIP_CONFIG_FILE={toxinidir}/pip.conf + FALCON_CYTHON=Y + FALCON_ASGI_WRAP_NON_COROUTINES=Y + FALCON_TESTING_SESSION=Y + PYTHONASYNCIODEBUG=1 +commands = pytest tests [] [testenv:py35_cython] basepython = python3.5 deps = {[with-cython]deps} +setenv = {[with-cython]setenv} +commands = {[with-cython]commands} [testenv:py36_cython] basepython = python3.6 deps = {[with-cython]deps} +setenv = {[with-cython]setenv} +commands = {[with-cython]commands} [testenv:py37_cython] basepython = python3.7 deps = {[with-cython]deps} +setenv = {[with-cython]setenv} +commands = {[with-cython]commands} [testenv:py38_cython] basepython = python3.8