From 6b03411a5c356316596946261dfa568ae2eb0b3b Mon Sep 17 00:00:00 2001
From: Eugene Xu <105043487@qq.com>
Date: Sat, 10 Apr 2021 16:26:30 +0800
Subject: [PATCH] PR (#1)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* iostream: check that stream is open before trying to read (#2670)

* curl_httpclient: fix disabled decompress_response by setting None (NULL) instead of "none" for ENCODING. Reported by Andrey Oparin.

* tests: run test_gzip for curl_httpclient also. Move simple_httpclient test_gzip to the shared httpclient tests, to test the decompress_response option for curl_httpclient as well.

* mypy: Enable no_implicit_optional. "Implicit-optional" mode is on by default, but that default is intended to change in the indefinite future (python/peps#689, python/typing#275). Go ahead and change to the future explicit use of Optional.

* gen.with_timeout: Don't log CancelledError after timeout. See also: commit a237a995a1d54ad6e07c1ecdf5103ff8f45073b5

* Fix ReST syntax

* locks: Remove redundant CancelledError handling. CancelledError is now always considered "quiet" (and concurrent.futures.CancelledError is no longer the same as asyncio.CancelledError).

* ci: Re-enable nightly python. Fixes #2677

* gen: Clean up docs for with_timeout. Mark CancelledError change as 6.0.3.

* *: Modernize IO error handling. Where possible, replace use of errno with the exception hierarchy available since python 3.3. Remove explicit handling of EINTR, which has been automatic since python 3.5.

* netutil: Ignore EADDRNOTAVAIL when binding to localhost ipv6. This happens in docker with default configurations and is generally harmless. Fixes #2274

* test: Skip test_source_port_fail when running as root. Root is always allowed to bind to low port numbers, so we can't simulate failure in this case. This is the last remaining failure when running tests in docker.

* docs: Add notice about WindowsSelectorEventLoop on py38. Fixes #2608

* Bump version of twisted to pick up security fix

* Release notes for 6.0.3

* SSLIOStream: Handle CertificateErrors like other errors. Fixes: tornadoweb/tornado#2689

* Update database backend reference

* Strip documentation about removed argument

* Mark Template autoescape kwarg as Optional

* httputil: cache header normalization with @lru_cache instead of hand-rolling. Tornado is now py3-only so @lru_cache is always available. Performance is about the same; benchmark below (Python 3.7 on Linux).

  before, cached: 0.9121252089971676
  before, uncached: 13.358482279989403
  after, cached: 0.9175888689933345
  after, uncached: 11.085199063003529

```py
from time import perf_counter

names = [f'sOMe-RanDOM-hEAdeR-{i}' for i in range(1000)]

from tornado.httputil import _normalize_header
start = perf_counter()
for i in range(10000):
    # _normalize_header.cache_clear()
    for name in names:
        _normalize_header(name)
print(perf_counter() - start)

from tornado.httputil import _NormalizedHeaderCache
start = perf_counter()
_normalized_headers = _NormalizedHeaderCache(1000)
for i in range(10000):
    # _normalized_headers = _NormalizedHeaderCache(1000)
    for name in names:
        _normalized_headers[name]
print(perf_counter() - start)
```
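  For reference, the cached normalizer amounts to a one-liner under @lru_cache. A minimal sketch, assuming the Http-Header-Case convention (tornado/httputil.py has the authoritative definition):

```py
import functools


@functools.lru_cache(maxsize=1000)
def _normalize_header(name: str) -> str:
    """Map a header name to Http-Header-Case.

    >>> _normalize_header("coNtent-TYPE")
    'Content-Type'
    """
    return "-".join([w.capitalize() for w in name.split("-")])
```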
* httputil: use compiled re patterns. This is slightly faster than using the builtin cache. With the benchmark below (Python 3.7, Linux):

  before: 0.7284867879934609
  after: 0.2657967659761198

```py
import re
from time import perf_counter

line = 'HTTP/1.1'
_http_version_re = re.compile(r"^HTTP/1\.[0-9]$")

start = perf_counter()
for i in range(1000000):
    _http_version_re.match(line)
print(perf_counter() - start)

start = perf_counter()
for i in range(1000000):
    re.match(r"^HTTP/1\.[0-9]$", line)
print(perf_counter() - start)
```

* test: Disable TLS 1.3 in one test. This test started failing on windows CI with an upgrade to python 3.7.4 (which bundles a newer version of openssl). Disable tls 1.3 for now. Possibly related to #2536

* spelling corrections
  * maintainance -> maintenance
  * recieving -> receiving

* tests: replace remaining assertEquals() with assertEqual(). assertEquals() is deprecated, and python3.7/pytest can warn about it.

* httputil.parse_body_arguments: allow incomplete url-escaping. Support x-www-form-urlencoded bodies with values consisting of encoded bytes which are not url-encoded into ascii (it seems other web frameworks often support this). Add bytes qs support to escape.parse_qs_bytes; leave str qs support for backwards compatibility.

* Clear fewer headers on 1xx/204/304 responses. This function is called on more than just 304 responses; it’s important to permit the Allow header on 204 responses. Also, the relevant RFCs have changed significantly. Fixes #2726. Signed-off-by: Anders Kaseorg

* Fix extra data sending at HEAD response with Transfer-Encoding: Chunked

* Omit Transfer-Encoding header for HEAD response

* Add test for unescaping with groups

* Fix unescaping of regex routes. Previously, only the part before the first '(' would be correctly unescaped.

* Use HTTPS link for tornado website.

* Simplify chained comparison.

* build(deps): bump twisted from 19.2.1 to 19.7.0 in /maint
  Bumps [twisted](https://github.com/twisted/twisted) from 19.2.1 to 19.7.0.
  - [Release notes](https://github.com/twisted/twisted/releases)
  - [Changelog](https://github.com/twisted/twisted/blob/trunk/NEWS.rst)
  - [Commits](https://github.com/twisted/twisted/compare/twisted-19.2.1...twisted-19.7.0)
  Signed-off-by: dependabot[bot]

* Dead link handling. Added an extra set for handling dead links, and reporting. One consequence of this is that using this script will "work" offline, but will report that all the links were not fetched.

* ci: Pin version of black. A new release of black changed the way some of our files are formatted, so use a fixed version in CI.

* demos: Fix lint in webspider demo. Updates #2765

* build: Revamp test/CI configuration. Reduce tox matrix to one env per python version, with two extra builds for lint and docs. Delegate to tox from travis-ci. Add 3.8 to testing. Simplify by dropping coverage reporting and "no-deps" test runs.

* process: correct docs of fork_processes exit behavior. Fixes #2771

* Remove legacy Python support in speedups.c

* ci: Don't run full test suite on python 3.5.2

* web: Update hashing algorithm in StaticFileHandler (#2778). Addresses #2776.

* build: Run docs and lint on py38. This requires moving some noqa comments due to 3.8's changes to the ast module.

* lint: Upgrade to new version of black

* lint: Use newer mypy. This required some minor code changes, mainly some adjustments in tests (which are now analyzed more thoroughly in spite of being mostly unannotated), and some changes to placement of type:ignore comments.

* use bcrypt's checkpw instead of ==
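  A minimal sketch of the checkpw change (hashpw, gensalt, and checkpw are the real bcrypt package APIs; the handler code in the demo is paraphrased):

```py
import bcrypt

stored = bcrypt.hashpw(b"correct horse battery staple", bcrypt.gensalt())


def check_password(candidate: bytes) -> bool:
    # Before: re-hash and compare with ==, which leaks timing information.
    #     bcrypt.hashpw(candidate, stored) == stored
    # After: let bcrypt do a constant-time comparison.
    return bcrypt.checkpw(candidate, stored)


print(check_password(b"correct horse battery staple"))  # True
print(check_password(b"wrong"))  # False
```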
* Fix case of JavaScript, GitHub and CSS.

* Fix syntax error in nested routing example

* test: Add gitattributes for test data files. This ensures that the tests pass on Windows regardless of the user's git CRLF settings.

* test: Use selector event loop on windows. This gets most of the tests working again on windows with py38.

* test: Add some more skips on windows. Alternate resolvers behave differently on this platform for unknown reasons.

* test: Add hasattr check for SIGCHLD. This name is not present on all platforms.

* testing: Add level argument to ExpectLog. This makes it possible for tests to be a little more precise, and also makes them less dependent on exactly how the test is run (runtests.py sets the logging level to info, but when running tests directly from an editor it may use the default of warnings-only). CI only runs the tests with runtests.py, so this might regress, but I'm not building anything to prevent that yet (options include running the tests differently in CI or making ExpectLog always use a fixed log configuration instead of picking up the current one).
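  A sketch of how the new argument reads in a test (ExpectLog lives in tornado.testing; the log message here is made up):

```py
import logging

from tornado.log import gen_log
from tornado.testing import AsyncTestCase, ExpectLog


class LevelTest(AsyncTestCase):
    def test_warning_is_logged(self):
        # Without level=, the match depends on whatever level the test
        # runner happened to configure; with it, the match is explicit.
        with ExpectLog(gen_log, "something went wrong", level=logging.WARNING):
            gen_log.warning("something went wrong")
```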
* ci: Add python 3.8 to windows CI

* asyncio: AnyThreadEventLoopPolicy should always use selectors on windows

* iostream: resolve reads that may be completed while closing. Fixes an issue where a read may fail with StreamClosedError if the stream is closed mid-read.

* avoid premature _check_closed in _start_read. _start_read can resolve with _try_inline_read, which can succeed even after the stream has been closed if the buffer was populated by a prior read. Preserve the fix for asserts being hit when dealing with closed sockets.

* catch UnsatisfiableReadError in close

* iostream: Add tests for behavior around close with read_until. Updates #2719

* iostream: Expand comments around recent subtle changes

* Fix Google OAuth example (from 6.0, OAuth2Mixin->authorize_redirect is an ordinary synchronous function)

* Add Python 3.8 classifier to setup.py

* Standardize type documentation for HTTPRequest init

* travis-ci.com doesn't like it when you have matrix and jobs; .org still allows this for some reason.

* Master branch release notes for version 6.0.4

* maint: Bump bleach version for a security fix

* iostream: Update comment. Update comment from #2690 about ssl module exceptions.

* Added default User-Agent to the simple http client if not provided. The User-Agent format is "Tornado\{Tornado_Version}". If self.request.user_agent isn't set and self.request.headers has no User-Agent in its keys, the default User-Agent is added. Fixes: #2702

* Revert "docs: Use python 3.7 via conda for readthedocs builds". This reverts commit e7e31e5642ae56da3f768d9829036eab99f0c988. We were using conda to get access to python 3.7 before rtd supported it in their regular builds, but this led to problems pinning a specific version of sphinx. See https://github.com/readthedocs/readthedocs.org/issues/6870

* fix newly detected E741 cases

* fix typos

* revert genericize change

* stop ping_callback

* fix types for max_age_days and expires_days parameters

* test: Add a sleep to deflake a test. Not sure why this has recently started happening in some environments, but killing a process too soon causes the wrong exit status in some python builds on macOS.

* ci: Drop tox-venv. Its README says it is mostly obsolete due to improvements in virtualenv. Using it appears to cause problems related to https://github.com/pypa/setuptools/issues/1934 because virtualenv installs the wheel package by default but venv doesn't.

* ci: Allow failures on nightly python due to cffi incompatibility

* template: Clarify docs on escaping. Originally from #2831, which went to the wrong branch.

* test: Use default timeouts in sigchild test. The 1s timeout used here has become flaky with the introduction of a sleep (before the timeout even starts).

* auth: Fix example code. Continuation of #2811. The oauth2 version of authorize_redirect is no longer a coroutine, so don't use await in example code. The oauth1 version is still a coroutine, but one twitter example was incorrectly calling it with yield instead of await.

* platform: Remove tornado.platform.auto.set_close_exec. This function is obsolete: since python 3.4, file descriptors created by python are non-inheritable by default (and in the event you create a file descriptor another way, a standard function os.set_inheritable is available). The windows implementation of this function was also apparently broken, but this went unnoticed because the default behavior on windows is for file descriptors to be non-inheritable. Fixes #2867

* iostream,platform: Remove _set_nonblocking function. This functionality is now provided directly in the `os` module.

* test: Use larger time values in testing_test. This test was flaky on appveyor. Also expand comments about what exactly the test is doing.

* Remove text about callback (removed) in run_on_executor

* curl_httpclient: set CURLOPT_PROXY to NULL if pycurl supports it. This restores curl's default behaviour: use environment variables. This option was set to "" to disable proxy in 905a215a286041c986005859c378c0445c127cbb but curl uses environment variables by default.

* httpclient_test: Improve error reporting. Without this try/finally, if this test ever fails, errors can be reported in a confusing way.

* iostream_test: Improve cleanup. Closing the file descriptor without removing the corresponding handler is technically incorrect, although the default IOLoops don't have a problem with it.

* test: Add missing level to ExpectLog call

* asyncio: Improve support for Python 3.8 on Windows. This commit removes the need for applications to work around the backwards-incompatible change to the default event loop. Instead, Tornado will detect the use of the windows proactor event loop and start a selector event loop in a separate thread. Closes #2804
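  For context, the application-level workaround this change makes unnecessary looked roughly like the following (WindowsSelectorEventLoopPolicy is the stdlib asyncio API on Python 3.8+/Windows; after this commit Tornado detects the proactor loop and runs a selector on a helper thread itself):

```py
import asyncio
import sys

# Pre-6.1 workaround for Python 3.8+ on Windows, where the default
# proactor event loop lacks the add_reader/add_writer methods Tornado uses.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
```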
* asyncio: Rework AddThreadSelectorEventLoop. Running a whole event loop on the other thread leads to tricky synchronization problems. Instead, keep as much as possible on the main thread, and call out to a second thread only for the blocking select system call itself.

* test: Add an option to disable assertion that logs are empty. Use this on windows due to a log spam issue in asyncio.

* asyncio: Refactor selector to callbacks instead of coroutine. Restarting the event loop to "cleanly" shut down a coroutine introduces other problems (mainly manifesting as errors logged while running tornado.test.gen_test). Replace the coroutine with a pair of callbacks so we don't need to do anything special to shut down without logging warnings.

* docs: Pin version of sphinxcontrib-asyncio. The just-released version 0.3.0 is incompatible with our older pinned version of sphinx.

* docs: Pin version of sphinxcontrib-asyncio. The just-released version 0.3.0 is incompatible with our older pinned version of sphinx.

* Added arm64 jobs for Travis-CI

* CLN: Remove utf-8 coding cookies in source files. On Python 3, utf-8 is the default python source code encoding, so the coding cookies on files that specify utf-8 are not needed anymore.
  modified: tornado/_locale_data.py
  modified: tornado/locale.py
  modified: tornado/test/curl_httpclient_test.py
  modified: tornado/test/httpclient_test.py
  modified: tornado/test/httputil_test.py
  modified: tornado/test/options_test.py
  modified: tornado/test/util_test.py

* Allow non-yielding functions in `tornado.gen.coroutine`'s type hint (#2909). The `@gen.coroutine` decorator allows non-yielding functions, so I reflected that in the type hint. Requires usage of `@typing.overload` due to python/mypy#9435

* Update super usage (#2912). On Python 3, super does not need to be called with arguments, whereas on Python 2, super needs to be called with a class object and an instance. This commit updates the super usage using automated regex-based search and replace. After the automated changes were made, each change was individually checked before committing.
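  The mechanical rewrite looks like this (illustrative classes, not taken from the diff):

```py
class Connection:
    def __init__(self, host: str) -> None:
        self.host = host


class SecureConnection(Connection):
    def __init__(self, host: str) -> None:
        # Python 2 style, as matched by the regex:
        #     super(SecureConnection, self).__init__(host)
        # Python 3 style, after the automated rewrite:
        super().__init__(host)
        self.secure = True
```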
* Update links on home page
  * Updated http links to the https versions when possible.
  * Updated links to Google Groups to match their new URL format.
  * Updated links to other projects to match their new locations.
  * And finally, updated link to FriendFeed to go to the Wikipedia page, because friendfeed.com is just a redirect to facebook.com now :-( :-(

* Modified ".travis.yml" to test its own built wheel. Signed-off-by: odidev

* tests: httpclient may turn all methods into GET for 303 redirect

* websocket_test: test websocket_connect redirect. Raises exception instead of "uncaught exception" and then test timeout.

* websocket: set follow_redirects to False to prevent silent failure when the websocket client gets a 3xx redirect response, because it does not currently support redirects. Partial fix for issue #2405

* simple_httpclient: after 303 redirect, turn all methods into GET, not just POST (but still not HEAD), following the behavior of libcurl > 7.70.

* httpclient_test: add test for connect_timeout=0 request_timeout=0

* simple_httpclient: handle connect_timeout or request_timeout of 0. Using a connect_timeout or request_timeout of 0 was effectively invalid for simple_httpclient: it would skip the actual request entirely (because the bulk of the logic was inside "if timeout:"). This was not checked for or raised as an error, it just behaved unexpectedly. Change simple_httpclient to always assert these timeouts are not None and to support the 0 value similar to curl (where request_timeout=0 means no timeout, and connect_timeout=0 means the curl default of 300 seconds, which is very, very long for a tcp connection).
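  Under the new behavior, a request like the following is legal and means "no timeout" rather than silently skipping the request (the URL is hypothetical; connect_timeout and request_timeout are real HTTPRequest arguments):

```py
from tornado.httpclient import AsyncHTTPClient, HTTPRequest


async def fetch_without_timeouts() -> bytes:
    client = AsyncHTTPClient()
    # 0 now means "no request timeout" (and libcurl's default connect
    # timeout), instead of accidentally bypassing the request logic.
    request = HTTPRequest(
        "http://example.com/slow", connect_timeout=0, request_timeout=0
    )
    response = await client.fetch(request)
    return response.body
```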
* httpclient: document connect_timeout/request_timeout 0 value. Not exactly true for curl_httpclient (libcurl uses a connect_timeout of 300 seconds if no connect timeout is set) but close enough.

* test: update Travis-CI matrix pypy version to 3.6-7.3.1

* httpclient_test: new test for invalid gzip Content-Encoding. This caused an infinite loop in simple_httpclient.

* http: fix infinite loop hang with invalid gzip data

* test: Refactor CI configuration
  - Add osx and windows builds on travis
  - Stop running -full test suites on every python version on arm64
  - Use cibuildwheel to build for all python versions in one job per platform
  - Bring a single test configuration and linters up to a first "quick" stage before starting the whole matrix
  - Push the resulting wheels (and sdist) to pypi on tag builds

* Add release notes for 6.1, bump version to 6.1b1

* ci: Switch from testpypi to real pypi

* Add deprecation notice for Python 3.5

* Update how to register application with Google

* Fix await vs yield in the example

* gen: Explicitly track contextvars, fixing contextvars.reset. The asyncio event loop provides enough contextvars support out of the box for basic contextvars functionality to work in tornado coroutines, but not `contextvars.reset`. Prior to this change, each yield created a new "level" of context, when an entire coroutine should be on the same level. This is necessary for the reset method to work. Fixes #2731

* test: Add a timeout to SyncHTTPClient test

* asyncio: Manage our own thread instead of an executor. Python 3.9 changed the behavior of ThreadPoolExecutor at interpreter shutdown (after the already-tricky import-order issues around atexit hooks). Avoid these issues by managing the thread by hand.

* ci,setup: Add python 3.9 to tox, cibuildwheel and setup.py

* Bump version to 6.1b2

* Set version to 6.1 final

* ci: Work around outdated windows root certificates

* Bump main branch to 6.2.dev1

* Remove appveyor configs

* Drop support for python 3.5

* iostream: Add platform assertion for mypy. Without this mypy would fail when run on windows.

* maint: Prune requirements lists. Remove dependencies that are rarely used outside of tox. The main motivation is to give dependabot less to worry about when an indirect dependency has a security vulnerability.

* *: Update black to newest version

* Update mypy to latest version

* docs: Upgrade to latest version of sphinx. This version attempts to resolve types found in type annotations, but in many cases it can't find them, so silence a bunch of warnings. (Looks like deferred annotation processing will make this better, but we won't be able to use that until we drop Python 3.6.)

* docs: Pin specific versions of requirements

* docs: Stop using autodoc for t.p.twisted. This way we don't have to install twisted into the docs build environment. Add some more detail while I'm here.

* platform: Deprecate twisted and cares resolvers. These were most interesting when the default resolver blocked the main thread. Now that the default is to use a thread pool, there is little if any demand for alternative resolvers just to avoid threads.

* Issue #2954: prevent logging error messages for missing translation files. Every missing translation file for the existing locales logged an error message: Cannot load translation for 'ps': [Errno 2] No such file or directory: '/usr/share/locale/ps/LC_MESSAGES/foo.mo'

* WaitIterator: don't re-use _running_future. When used with asyncio.Future, WaitIterator may skip indices in some cases. This is caused by multiple _return_result calls one after another, without having the chain_future call finish in between. This is fixed here by not hanging on to the _running_future anymore, which forces subsequent _return_result calls to add to _finished, instead of causing the previous result to be silently dropped. Fixes #2034

* Fix return type of _return_result

* docs: fix simple typo, authentiate -> authenticate. There is a small typo in tornado/netutil.py. Should read `authenticate` rather than `authentiate`.

* Avoid 2GB read limitation on SSLIOStream

* Remove trailing whitespace

* locale: Format with black

* wsgi: Update docstring example for python 3. Fixes #2960

* Remove WebSocketHandler.stream. It was no longer used and always set to None.

* Add 'address' keyword to control bound address (#2969)

* format code according to result of flake8 check

* Add comment explaining workaround

* change comment

* should use python3 unicode in 'blog' demo (#2977)

* leave previous versionchanged

* leave previous versionchanged

* write_message method of WebSocketClientConnection now accepts dict as input

* write_message method of WebSocketClientConnection now accepts dict as input

* Uppercase A in Any

* BaseIOStream.write(): support typed memoryview. Making sure that ``len(data) == data.nbytes`` by casting memoryviews to bytes.

* Allow setting max_body_size to 0

* fix line too long

* fix E127

* what

* But this is not beautiful

* Is this okay

* build(deps): bump jinja2 from 2.11.2 to 2.11.3 in /docs
  Bumps [jinja2](https://github.com/pallets/jinja) from 2.11.2 to 2.11.3.
  - [Release notes](https://github.com/pallets/jinja/releases)
  - [Changelog](https://github.com/pallets/jinja/blob/master/CHANGES.rst)
  - [Commits](https://github.com/pallets/jinja/compare/2.11.2...2.11.3)
  Signed-off-by: dependabot[bot]

* build(deps): bump pygments from 2.7.2 to 2.7.4 in /docs
  Bumps [pygments](https://github.com/pygments/pygments) from 2.7.2 to 2.7.4.
  - [Release notes](https://github.com/pygments/pygments/releases)
  - [Changelog](https://github.com/pygments/pygments/blob/master/CHANGES)
  - [Commits](https://github.com/pygments/pygments/compare/2.7.2...2.7.4)
  Signed-off-by: dependabot[bot]

Co-authored-by: Ben Darnell
Co-authored-by: Zachary Sailer
Co-authored-by: Pierce Lopez
Co-authored-by: Robin Roth
Co-authored-by: Petr Viktorin
Co-authored-by: Martijn van Oosterhout
Co-authored-by: Michael V. DePalatis
Co-authored-by: Remi Rampin
Co-authored-by: Ran Benita
Co-authored-by: Semen Zhydenko
Co-authored-by: Anders Kaseorg
Co-authored-by: Bulat Khasanov
Co-authored-by: supakeen
Co-authored-by: John Bampton
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jeff van Santen
Co-authored-by: Bruno P. Kinoshita
Co-authored-by: Gareth T
Co-authored-by: Min RK
Co-authored-by: bn0ir
Co-authored-by: James Bourbeau
Co-authored-by: Recursing
Co-authored-by: Flavio Garcia
Co-authored-by: Ben Darnell
Co-authored-by: marc
Co-authored-by: agnewee
Co-authored-by: Jeff Hunter
Co-authored-by: 依云
Co-authored-by: odidev
Co-authored-by: Sai Rahul Poruri
Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com>
Co-authored-by: Amit Patel
Co-authored-by: Debby
Co-authored-by: = <=>
Co-authored-by: Eugene Toder
Co-authored-by: Florian Best
Co-authored-by: Alexander Clausen
Co-authored-by: Tim Gates
Co-authored-by: bfis
Co-authored-by: Eugene Toder
Co-authored-by: youguanxinqing
Co-authored-by: kriskros341
Co-authored-by: Mads R. B.
Kristensen Co-authored-by: Sakuya --- .flake8 | 4 + .gitattributes | 4 + .travis.yml | 170 +- MANIFEST.in | 1 + README.rst | 2 +- appveyor.yml | 76 - demos/appengine/README | 48 - demos/appengine/app.yaml | 12 - demos/appengine/blog.py | 166 -- demos/appengine/static/blog.css | 153 -- demos/appengine/templates/archive.html | 31 - demos/appengine/templates/base.html | 29 - demos/appengine/templates/compose.html | 40 - demos/appengine/templates/entry.html | 5 - demos/appengine/templates/feed.xml | 26 - demos/appengine/templates/home.html | 8 - demos/appengine/templates/modules/entry.html | 8 - demos/benchmark/stack_context_benchmark.py | 80 - demos/blog/Dockerfile | 12 +- demos/blog/README | 66 +- demos/blog/blog.py | 262 +- demos/blog/docker-compose.yml | 14 +- demos/blog/requirements.txt | 5 +- demos/blog/schema.sql | 34 +- demos/chat/chatdemo.py | 101 +- demos/chat/static/chat.js | 1 - demos/facebook/facebook.py | 45 +- demos/file_upload/file_receiver.py | 20 +- demos/file_upload/file_uploader.py | 49 +- demos/helloworld/helloworld.py | 4 +- demos/s3server/s3server.py | 149 +- demos/tcpecho/client.py | 1 - demos/twitter/twitterdemo.py | 58 +- demos/websocket/chatdemo.py | 17 +- demos/websocket/templates/index.html | 2 +- demos/webspider/webspider.py | 103 +- docs/caresresolver.rst | 4 + docs/concurrent.rst | 4 +- docs/conf.py | 87 +- docs/escape.rst | 12 +- docs/faq.rst | 28 +- docs/gen.rst | 49 +- docs/guide/async.rst | 80 +- docs/guide/coroutines.rst | 210 +- docs/guide/intro.rst | 13 +- docs/guide/running.rst | 64 +- docs/guide/security.rst | 11 +- docs/guide/structure.rst | 74 +- docs/guide/templates.rst | 15 +- docs/httpclient.rst | 4 +- docs/httpserver.rst | 5 +- docs/index.rst | 72 +- docs/ioloop.rst | 15 - docs/locks.rst | 7 +- docs/releases.rst | 8 + docs/releases/v2.2.0.rst | 2 +- docs/releases/v2.3.0.rst | 4 +- docs/releases/v2.4.0.rst | 2 +- docs/releases/v2.4.1.rst | 2 +- docs/releases/v3.0.0.rst | 16 +- docs/releases/v3.0.2.rst | 2 +- docs/releases/v3.1.0.rst | 18 +- docs/releases/v3.2.0.rst | 2 +- docs/releases/v4.0.0.rst | 22 +- docs/releases/v4.1.0.rst | 4 +- docs/releases/v4.2.0.rst | 2 +- docs/releases/v4.3.0.rst | 2 +- docs/releases/v4.4.0.rst | 2 +- docs/releases/v4.5.0.rst | 2 +- docs/releases/v5.0.0.rst | 8 +- docs/releases/v5.1.0.rst | 195 ++ docs/releases/v5.1.1.rst | 14 + docs/releases/v6.0.0.rst | 162 ++ docs/releases/v6.0.1.rst | 11 + docs/releases/v6.0.2.rst | 13 + docs/releases/v6.0.3.rst | 14 + docs/releases/v6.0.4.rst | 21 + docs/releases/v6.1.0.rst | 106 + docs/requirements.in | 3 + docs/requirements.txt | 27 +- docs/stack_context.rst | 5 - docs/twisted.rst | 66 +- docs/utilities.rst | 1 - docs/web.rst | 47 +- docs/websocket.rst | 1 + docs/wsgi.rst | 12 - maint/README | 2 +- {demos => maint}/benchmark/benchmark.py | 0 {demos => maint}/benchmark/chunk_benchmark.py | 0 {demos => maint}/benchmark/gen_benchmark.py | 0 maint/benchmark/parsing_benchmark.py | 112 + .../benchmark/template_benchmark.py | 0 maint/circlerefs/circlerefs.py | 1 - maint/requirements.in | 28 +- maint/requirements.txt | 76 +- maint/scripts/test_resolvers.py | 2 - maint/test/appengine/README | 8 - maint/test/appengine/common/cgi_runtests.py | 59 - maint/test/appengine/common/runtests.py | 58 - maint/test/appengine/py27/app.yaml | 9 - maint/test/appengine/py27/cgi_runtests.py | 1 - maint/test/appengine/py27/runtests.py | 1 - maint/test/appengine/py27/tornado | 1 - maint/test/appengine/setup.py | 4 - maint/test/appengine/tox.ini | 15 - maint/test/cython/tox.ini | 3 +- 
maint/test/mypy/.gitignore | 1 + maint/test/mypy/bad.py | 6 + maint/test/mypy/good.py | 11 + maint/test/mypy/setup.py | 3 + maint/test/mypy/tox.ini | 14 + maint/test/pyuv/tox.ini | 13 - maint/test/redbot/red_test.py | 2 +- maint/test/websocket/fuzzingclient.json | 57 +- maint/test/websocket/run-client.sh | 2 +- maint/test/websocket/run-server.sh | 6 +- maint/test/websocket/tox.ini | 2 +- maint/vm/windows/bootstrap.py | 1 - setup.cfg | 15 + setup.py | 110 +- tornado/__init__.py | 6 +- tornado/_locale_data.py | 4 - tornado/auth.py | 909 ++++--- tornado/autoreload.py | 89 +- tornado/concurrent.py | 592 +---- tornado/curl_httpclient.py | 376 +-- tornado/escape.py | 311 +-- tornado/gen.py | 993 ++------ tornado/http1connection.py | 526 +++-- tornado/httpclient.py | 409 ++-- tornado/httpserver.py | 196 +- tornado/httputil.py | 594 +++-- tornado/ioloop.py | 901 +++---- tornado/iostream.py | 1202 +++++----- tornado/locale.py | 304 ++- tornado/locks.py | 333 +-- tornado/log.py | 195 +- tornado/netutil.py | 304 ++- tornado/options.py | 350 ++- tornado/platform/asyncio.py | 445 +++- tornado/platform/auto.py | 58 - tornado/platform/auto.pyi | 4 - tornado/platform/caresresolver.py | 52 +- tornado/platform/common.py | 113 - tornado/platform/epoll.py | 25 - tornado/platform/interface.py | 66 - tornado/platform/kqueue.py | 90 - tornado/platform/posix.py | 69 - tornado/platform/select.py | 75 - tornado/platform/twisted.py | 559 +---- tornado/platform/windows.py | 20 - tornado/process.py | 186 +- tornado/{test/__init__.py => py.typed} | 0 tornado/queues.py | 172 +- tornado/routing.py | 240 +- tornado/simple_httpclient.py | 657 +++-- tornado/speedups.c | 7 - tornado/stack_context.py | 389 --- tornado/tcpclient.py | 158 +- tornado/tcpserver.py | 132 +- tornado/template.py | 371 +-- tornado/test/__main__.py | 2 - tornado/test/asyncio_test.py | 86 +- tornado/test/auth_test.py | 709 +++--- tornado/test/autoreload_test.py | 113 +- tornado/test/concurrent_test.py | 358 +-- tornado/test/curl_httpclient_test.py | 143 +- tornado/test/escape_test.py | 390 +-- tornado/test/gen_test.py | 1270 +++------- .../test/gettext_translations/extract_me.py | 1 - tornado/test/http1connection_test.py | 12 +- tornado/test/httpclient_test.py | 680 ++++-- tornado/test/httpserver_test.py | 864 ++++--- tornado/test/httputil_test.py | 359 +-- tornado/test/import_test.py | 95 +- tornado/test/ioloop_test.py | 512 ++-- tornado/test/iostream_test.py | 708 +++--- tornado/test/locale_test.py | 120 +- tornado/test/locks_test.py | 148 +- tornado/test/log_test.py | 108 +- tornado/test/netutil_test.py | 145 +- tornado/test/options_test.py | 229 +- tornado/test/options_test_types.cfg | 11 + tornado/test/options_test_types_str.cfg | 8 + tornado/test/process_test.py | 124 +- tornado/test/queues_test.py | 134 +- tornado/test/resolve_test_helper.py | 3 +- tornado/test/routing_test.py | 107 +- tornado/test/runtests.py | 266 ++- tornado/test/simple_httpclient_test.py | 641 ++--- tornado/test/stack_context_test.py | 296 --- tornado/test/tcpclient_test.py | 307 +-- tornado/test/tcpserver_test.py | 79 +- tornado/test/template_test.py | 386 +-- tornado/test/test.crt | 31 +- tornado/test/test.key | 40 +- tornado/test/testing_test.py | 198 +- tornado/test/twisted_test.py | 630 +---- tornado/test/util.py | 50 +- tornado/test/util_test.py | 144 +- tornado/test/web_test.py | 2104 +++++++++-------- tornado/test/websocket_test.py | 583 +++-- tornado/test/windows_test.py | 25 - tornado/test/wsgi_test.py | 87 +- tornado/testing.py | 493 ++-- tornado/util.py | 291 +-- 
tornado/web.py | 1728 ++++++++------ tornado/websocket.py | 1001 +++++--- tornado/wsgi.py | 271 +-- tox.ini | 170 +- 210 files changed, 15891 insertions(+), 17830 deletions(-) create mode 100644 .gitattributes delete mode 100644 appveyor.yml delete mode 100644 demos/appengine/README delete mode 100644 demos/appengine/app.yaml delete mode 100644 demos/appengine/blog.py delete mode 100644 demos/appengine/static/blog.css delete mode 100644 demos/appengine/templates/archive.html delete mode 100644 demos/appengine/templates/base.html delete mode 100644 demos/appengine/templates/compose.html delete mode 100644 demos/appengine/templates/entry.html delete mode 100644 demos/appengine/templates/feed.xml delete mode 100644 demos/appengine/templates/home.html delete mode 100644 demos/appengine/templates/modules/entry.html delete mode 100755 demos/benchmark/stack_context_benchmark.py create mode 100644 docs/releases/v5.1.0.rst create mode 100644 docs/releases/v5.1.1.rst create mode 100644 docs/releases/v6.0.0.rst create mode 100644 docs/releases/v6.0.1.rst create mode 100644 docs/releases/v6.0.2.rst create mode 100644 docs/releases/v6.0.3.rst create mode 100644 docs/releases/v6.0.4.rst create mode 100644 docs/releases/v6.1.0.rst create mode 100644 docs/requirements.in delete mode 100644 docs/stack_context.rst rename {demos => maint}/benchmark/benchmark.py (100%) rename {demos => maint}/benchmark/chunk_benchmark.py (100%) rename {demos => maint}/benchmark/gen_benchmark.py (100%) create mode 100644 maint/benchmark/parsing_benchmark.py rename {demos => maint}/benchmark/template_benchmark.py (100%) delete mode 100644 maint/test/appengine/README delete mode 100755 maint/test/appengine/common/cgi_runtests.py delete mode 100755 maint/test/appengine/common/runtests.py delete mode 100644 maint/test/appengine/py27/app.yaml delete mode 120000 maint/test/appengine/py27/cgi_runtests.py delete mode 120000 maint/test/appengine/py27/runtests.py delete mode 120000 maint/test/appengine/py27/tornado delete mode 100644 maint/test/appengine/setup.py delete mode 100644 maint/test/appengine/tox.ini create mode 100644 maint/test/mypy/.gitignore create mode 100644 maint/test/mypy/bad.py create mode 100644 maint/test/mypy/good.py create mode 100644 maint/test/mypy/setup.py create mode 100644 maint/test/mypy/tox.ini delete mode 100644 maint/test/pyuv/tox.ini create mode 100644 setup.cfg delete mode 100644 tornado/platform/auto.py delete mode 100644 tornado/platform/auto.pyi delete mode 100644 tornado/platform/common.py delete mode 100644 tornado/platform/epoll.py delete mode 100644 tornado/platform/interface.py delete mode 100644 tornado/platform/kqueue.py delete mode 100644 tornado/platform/posix.py delete mode 100644 tornado/platform/select.py delete mode 100644 tornado/platform/windows.py rename tornado/{test/__init__.py => py.typed} (100%) delete mode 100644 tornado/stack_context.py create mode 100644 tornado/test/options_test_types.cfg create mode 100644 tornado/test/options_test_types_str.cfg delete mode 100644 tornado/test/stack_context_test.py delete mode 100644 tornado/test/windows_test.py diff --git a/.flake8 b/.flake8 index 1c2c768d1f..18c72168cb 100644 --- a/.flake8 +++ b/.flake8 @@ -10,4 +10,8 @@ ignore = E402, # E722 do not use bare except E722, + # flake8 and black disagree about + # W503 line break before binary operator + # E203 whitespace before ':' + W503,E203 doctests = true diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..facf16e431 --- /dev/null +++ b/.gitattributes @@ 
-0,0 +1,4 @@ +# Tests of static file handling assume unix-style line endings. +tornado/test/static/*.txt text eol=lf +tornado/test/static/dir/*.html text eol=lf +tornado/test/templates/*.html text eol=lf diff --git a/.travis.yml b/.travis.yml index 5b5faa660f..f7b1e825ef 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,89 +1,101 @@ # https://travis-ci.org/tornadoweb/tornado -dist: trusty -# Use containers instead of full VMs for faster startup. -sudo: false +os: linux +dist: xenial +language: python +addons: + apt: + packages: + - libgnutls-dev -matrix: - fast_finish: true +env: + global: + - CIBW_BUILD="cp3[6789]*" + - CIBW_TEST_COMMAND="python3 -m tornado.test" + - CIBW_TEST_COMMAND_WINDOWS="python -m tornado.test --fail-if-logs=false" -language: python -# For a list of available versions, run -# aws s3 ls s3://travis-python-archives/binaries/ubuntu/14.04/x86_64/ -python: - - 2.7 - - pypy2.7-5.8.0 - - 3.4 - - 3.5 - - 3.6 - - nightly - - pypy3.5-5.8.0 +# Before starting the full build matrix, run one test configuration +# and the linter (the `black` linter is especially likely to catch +# first-time contributors). +stages: + - quick + - test + +jobs: + fast_finish: true + include: + # We have two and a half types of builds: Wheel builds run on all supported + # platforms and run the basic test suite for all supported python versions. + # Sdist builds (the "and a half") just build an sdist and run some basic + # validation. Both of these upload their artifacts to pypi if there is a + # tag on the build and the key is available. + # + # Tox builds run a more comprehensive test suite with more configurations + # and dependencies (getting all these dependencies installed for wheel + # builds is a pain, and slows things down because we don't use as much + # parallelism there. We could parallelize wheel builds more but we're also + # amortizing download costs across the different builds). The wheel builds + # take longer, so we run them before the tox builds for better bin packing + # in our allotted concurrency. + - python: '3.8' + arch: amd64 + services: docker + env: BUILD_WHEEL=1 + - python: '3.8' + arch: arm64 + services: docker + env: BUILD_WHEEL=1 ASYNC_TEST_TIMEOUT=15 + - os: windows + env: PATH=/c/Python38:/c/Python38/Scripts:$PATH BUILD_WHEEL=1 + language: shell + before_install: + - choco install python --version 3.8.0 + # Windows build images have outdated root certificates; until that's + # fixed use certifi instead. + # https://github.com/joerick/cibuildwheel/issues/452 + - python -m pip install certifi + - export SSL_CERT_FILE=`python -c "import certifi; print(certifi.where())"` + - os: osx + env: BUILD_WHEEL=1 + language: shell + + - python: '3.8' + arch: amd64 + env: BUILD_SDIST=1 + + - python: '3.6' + env: TOX_ENV=py36-full + - python: '3.7' + env: TOX_ENV=py37-full + - python: '3.8' + env: TOX_ENV=py38-full + - python: '3.9-dev' + env: TOX_ENV=py39-full + - python: nightly + env: TOX_ENV=py3 + - python: pypy3.6-7.3.1 + # Pypy is a lot slower due to jit warmup costs, so don't run the "full" + # test config there. + env: TOX_ENV=pypy3 + # Docs and lint python versions must be synced with those in tox.ini + - python: '3.8' + env: TOX_ENV=docs + + # the quick stage runs first, but putting this at the end lets us take + # advantage of travis-ci's defaults and not repeat stage:test in the others. 
+ - python: '3.8' + env: TOX_ENV=py38,lint + stage: quick install: - # On nightly, upgrade setuptools first to work around - # https://github.com/pypa/setuptools/issues/1257 - - if [[ $TRAVIS_PYTHON_VERSION == 'nightly' ]]; then travis_retry pip install -U setuptools; fi - - if [[ $TRAVIS_PYTHON_VERSION == 2* ]]; then travis_retry pip install mock monotonic; fi - - if [[ $TRAVIS_PYTHON_VERSION == 'pypy' ]]; then travis_retry pip install mock; fi - # TODO(bdarnell): pycares tests are currently disabled on travis due to ipv6 issues. - #- if [[ $TRAVIS_PYTHON_VERSION != 'pypy'* ]]; then travis_retry pip install pycares; fi - - if [[ $TRAVIS_PYTHON_VERSION != 'pypy'* ]]; then travis_retry pip install pycurl; fi - # Twisted runs on 2.x and 3.3+, but is flaky on pypy. - - if [[ $TRAVIS_PYTHON_VERSION != 'pypy'* ]]; then travis_retry pip install Twisted; fi - - if [[ $TRAVIS_PYTHON_VERSION == '2.7' || $TRAVIS_PYTHON_VERSION == '3.6' ]]; then travis_retry pip install sphinx sphinx_rtd_theme; fi - - if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then travis_retry pip install flake8; fi - # On travis the extension should always be built - - if [[ $TRAVIS_PYTHON_VERSION != 'pypy'* ]]; then export TORNADO_EXTENSION=1; fi - - travis_retry python setup.py install - - travis_retry pip install codecov virtualenv - # Create a separate no-dependencies virtualenv to make sure all imports - # of optional-dependencies are guarded. - - virtualenv ./nodeps - - ./nodeps/bin/python -VV - - ./nodeps/bin/python setup.py install - - curl-config --version; pip freeze + - if [[ -n "$TOX_ENV" ]]; then pip3 install tox; fi + - if [[ -n "$BUILD_WHEEL" ]]; then pip3 install cibuildwheel; fi + - if [[ -n "$BUILD_WHEEL" || -n "$BUILD_SDIST" ]]; then pip3 install twine; fi script: - # Run the tests once from the source directory to detect issues - # involving relative __file__ paths; see - # https://github.com/tornadoweb/tornado/issues/1780 - - unset TORNADO_EXTENSION && python -m tornado.test - # For all other test variants, get out of the source directory before - # running tests to ensure that we get the installed speedups module - # instead of the source directory which doesn't have it. - - cd maint - # Copy the coveragerc down so coverage.py can find it. - - cp ../.coveragerc . - - if [[ $TRAVIS_PYTHON_VERSION != 'pypy'* ]]; then export TORNADO_EXTENSION=1; fi - - export TARGET="-m tornado.test.runtests" - # Travis workers are often overloaded and cause our tests to exceed - # the default timeout of 5s. - - export ASYNC_TEST_TIMEOUT=15 - # We use "python -m coverage" instead of the "bin/coverage" script - # so we can pass additional arguments to python. - # coverage needs a function that was removed in python 3.6 so we can't - # run it with nightly cpython. Coverage is very slow on pypy. 
- - if [[ $TRAVIS_PYTHON_VERSION != nightly && $TRAVIS_PYTHON_VERSION != 'pypy'* ]]; then export RUN_COVERAGE=1; fi - - if [[ "$RUN_COVERAGE" == 1 ]]; then export TARGET="-m coverage run $TARGET"; fi - - python $TARGET - - if [[ $TRAVIS_PYTHON_VERSION == 2* ]]; then python $TARGET --ioloop=tornado.platform.select.SelectIOLoop; fi - - python -O $TARGET - - LANG=C python $TARGET - - LANG=en_US.utf-8 python $TARGET - - if [[ $TRAVIS_PYTHON_VERSION == 3* ]]; then python -bb $TARGET; fi - - if [[ $TRAVIS_PYTHON_VERSION == 2* ]]; then python $TARGET --httpclient=tornado.curl_httpclient.CurlAsyncHTTPClient; fi - - if [[ $TRAVIS_PYTHON_VERSION == 2* ]]; then python $TARGET --ioloop=tornado.platform.twisted.TwistedIOLoop; fi - - if [[ $TRAVIS_PYTHON_VERSION == 2* ]]; then python $TARGET --resolver=tornado.platform.twisted.TwistedResolver; fi - - if [[ $TRAVIS_PYTHON_VERSION == 2* ]]; then python $TARGET --ioloop=tornado.ioloop.PollIOLoop --ioloop_time_monotonic; fi - #- if [[ $TRAVIS_PYTHON_VERSION != pypy* ]]; then python $TARGET --resolver=tornado.platform.caresresolver.CaresResolver; fi - - if [[ $TRAVIS_PYTHON_VERSION != 'pypy3' ]]; then ../nodeps/bin/python -m tornado.test.runtests; fi - # make coverage reports for Codecov to find - - if [[ "$RUN_COVERAGE" == 1 ]]; then coverage xml; fi - - export TORNADO_EXTENSION=0 - - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then cd ../docs && mkdir sphinx-out && sphinx-build -E -n -W -b html . sphinx-out; fi - - if [[ $TRAVIS_PYTHON_VERSION == '2.7' || $TRAVIS_PYTHON_VERSION == 3.6 ]]; then cd ../docs && mkdir sphinx-doctest-out && sphinx-build -E -n -b doctest . sphinx-out; fi - - if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then flake8; fi + - if [[ -n "$TOX_ENV" ]]; then tox -e $TOX_ENV -- $TOX_ARGS; fi + - if [[ -n "$BUILD_WHEEL" ]]; then cibuildwheel --output-dir dist && ls -l dist; fi + - if [[ -n "$BUILD_SDIST" ]]; then python setup.py check sdist && ls -l dist; fi after_success: - # call codecov from project root - - if [[ "$RUN_COVERAGE" == 1 ]]; then cd ../ && codecov; fi + - if [[ ( -n "$BUILD_WHEEL" || -n "$BUILD_SDIST" ) && -n "$TRAVIS_TAG" && -n "$TWINE_PASSWORD" ]]; then twine upload -u __token__ dist/*; fi diff --git a/MANIFEST.in b/MANIFEST.in index 2ef76aefd9..d99e4bb930 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,7 @@ recursive-include demos *.py *.yaml *.html *.css *.js *.xml *.sql README recursive-include docs * prune docs/build +include tornado/py.typed include tornado/speedups.c include tornado/test/README include tornado/test/csv_translations/fr_FR.csv diff --git a/README.rst b/README.rst index c177ef1291..2c9561d527 100644 --- a/README.rst +++ b/README.rst @@ -45,4 +45,4 @@ Documentation ------------- Documentation and links to additional resources are available at -http://www.tornadoweb.org +https://www.tornadoweb.org diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index 72f9d881d3..0000000000 --- a/appveyor.yml +++ /dev/null @@ -1,76 +0,0 @@ -# Appveyor is Windows CI: https://ci.appveyor.com/project/bdarnell/tornado -environment: - global: - TORNADO_EXTENSION: "1" - - # We only build with 3.5+ because it works out of the box, while other - # versions require lots of machinery. - # - # We produce binary wheels for 32- and 64-bit builds, but because - # the tests are so slow on Windows (6 minutes vs 15 seconds on Linux - # or MacOS), we don't want to test the full matrix. 
We do full - # tests on a couple of configurations and on the others we limit - # the tests to the websocket module (which, because it exercises the - # C extension module, is most likely to exhibit differences between - # 32- and 64-bits) - matrix: - - PYTHON: "C:\\Python35" - PYTHON_VERSION: "3.5.x" - PYTHON_ARCH: "32" - TOX_ENV: "py35" - TOX_ARGS: "" - - - PYTHON: "C:\\Python35-x64" - PYTHON_VERSION: "3.5.x" - PYTHON_ARCH: "64" - TOX_ENV: "py35" - TOX_ARGS: "tornado.test.websocket_test" - - - PYTHON: "C:\\Python36" - PYTHON_VERSION: "3.6.x" - PYTHON_ARCH: "32" - TOX_ENV: "py36" - TOX_ARGS: "tornado.test.websocket_test" - - - PYTHON: "C:\\Python36-x64" - PYTHON_VERSION: "3.6.x" - PYTHON_ARCH: "64" - TOX_ENV: "py36" - TOX_ARGS: "" - -install: - # Make sure the right python version is first on the PATH. - - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%" - - # Check that we have the expected version and architecture for Python - - "python --version" - - "python -c \"import struct; print(struct.calcsize('P') * 8)\"" - - # Upgrade to the latest version of pip to avoid it displaying warnings - # about it being out of date. - - "python -m pip install --disable-pip-version-check --user --upgrade pip" - - - "python -m pip install tox wheel" - -build: false # Not a C# project, build stuff at the test step instead. - -test_script: - # Build the compiled extension and run the project tests. - # This is a bit of a hack that doesn't scale with new python versions, - # but for now it lets us avoid duplication with .travis.yml and tox.ini. - # Running "py3x-full" would be nice but it's failing on installing - # dependencies with no useful logs. - - "tox -e %TOX_ENV% -- %TOX_ARGS%" - -after_test: - # If tests are successful, create binary packages for the project. - - "python setup.py bdist_wheel" - - ps: "ls dist" - -artifacts: - # Archive the generated packages in the ci.appveyor.com build report. - - path: dist\* - -#on_success: -# - TODO: upload the content of dist/*.whl to a public wheelhouse -# diff --git a/demos/appengine/README b/demos/appengine/README deleted file mode 100644 index e4aead6701..0000000000 --- a/demos/appengine/README +++ /dev/null @@ -1,48 +0,0 @@ -Running the Tornado AppEngine example -===================================== -This example is designed to run in Google AppEngine, so there are a couple -of steps to get it running. You can download the Google AppEngine Python -development environment at http://code.google.com/appengine/downloads.html. - -1. Link or copy the tornado code directory into this directory: - - ln -s ../../tornado tornado - - AppEngine doesn't use the Python modules installed on this machine. - You need to have the 'tornado' module copied or linked for AppEngine - to find it. - -3. Install and run dev_appserver - - If you don't already have the App Engine SDK, download it from - http://code.google.com/appengine/downloads.html - - To start the tornado demo, run the dev server on this directory: - - dev_appserver.py . - -4. Visit http://localhost:8080/ in your browser - - If you sign in as an administrator, you will be able to create and - edit blog posts. If you sign in as anybody else, you will only see - the existing blog posts. - - -If you want to deploy the blog in production: - -1. Register a new appengine application and put its id in app.yaml - - First register a new application at http://appengine.google.com/. - Then edit app.yaml in this directory and change the "application" - setting from "tornado-appenginge" to your new application id. - -2. 
Deploy to App Engine - - If you registered an application id, you can now upload your new - Tornado blog by running this command: - - appcfg update . - - After that, visit application_id.appspot.com, where application_id - is the application you registered. - diff --git a/demos/appengine/app.yaml b/demos/appengine/app.yaml deleted file mode 100644 index c90cecdba1..0000000000 --- a/demos/appengine/app.yaml +++ /dev/null @@ -1,12 +0,0 @@ -application: tornado-appengine -version: 2 -runtime: python27 -api_version: 1 -threadsafe: yes - -handlers: -- url: /static/ - static_dir: static - -- url: /.* - script: blog.application diff --git a/demos/appengine/blog.py b/demos/appengine/blog.py deleted file mode 100644 index e2b2ef5042..0000000000 --- a/demos/appengine/blog.py +++ /dev/null @@ -1,166 +0,0 @@ -# -# Copyright 2009 Facebook -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import functools -import os.path -import re -import tornado.escape -import tornado.web -import tornado.wsgi -import unicodedata - -from google.appengine.api import users -from google.appengine.ext import db - - -class Entry(db.Model): - """A single blog entry.""" - author = db.UserProperty() - title = db.StringProperty(required=True) - slug = db.StringProperty(required=True) - body_source = db.TextProperty(required=True) - html = db.TextProperty(required=True) - published = db.DateTimeProperty(auto_now_add=True) - updated = db.DateTimeProperty(auto_now=True) - - -def administrator(method): - """Decorate with this method to restrict to site admins.""" - @functools.wraps(method) - def wrapper(self, *args, **kwargs): - if not self.current_user: - if self.request.method == "GET": - self.redirect(self.get_login_url()) - return - raise tornado.web.HTTPError(403) - elif not self.current_user.administrator: - if self.request.method == "GET": - self.redirect("/") - return - raise tornado.web.HTTPError(403) - else: - return method(self, *args, **kwargs) - return wrapper - - -class BaseHandler(tornado.web.RequestHandler): - """Implements Google Accounts authentication methods.""" - def get_current_user(self): - user = users.get_current_user() - if user: - user.administrator = users.is_current_user_admin() - return user - - def get_login_url(self): - return users.create_login_url(self.request.uri) - - def get_template_namespace(self): - # Let the templates access the users module to generate login URLs - ns = super(BaseHandler, self).get_template_namespace() - ns['users'] = users - return ns - - -class HomeHandler(BaseHandler): - def get(self): - entries = db.Query(Entry).order('-published').fetch(limit=5) - if not entries: - if not self.current_user or self.current_user.administrator: - self.redirect("/compose") - return - self.render("home.html", entries=entries) - - -class EntryHandler(BaseHandler): - def get(self, slug): - entry = db.Query(Entry).filter("slug =", slug).get() - if not entry: - raise tornado.web.HTTPError(404) - self.render("entry.html", entry=entry) - - -class ArchiveHandler(BaseHandler): - def get(self): - 
entries = db.Query(Entry).order('-published') - self.render("archive.html", entries=entries) - - -class FeedHandler(BaseHandler): - def get(self): - entries = db.Query(Entry).order('-published').fetch(limit=10) - self.set_header("Content-Type", "application/atom+xml") - self.render("feed.xml", entries=entries) - - -class ComposeHandler(BaseHandler): - @administrator - def get(self): - key = self.get_argument("key", None) - entry = Entry.get(key) if key else None - self.render("compose.html", entry=entry) - - @administrator - def post(self): - key = self.get_argument("key", None) - if key: - entry = Entry.get(key) - entry.title = self.get_argument("title") - entry.body_source = self.get_argument("body_source") - entry.html = tornado.escape.linkify( - self.get_argument("body_source")) - else: - title = self.get_argument("title") - slug = unicodedata.normalize("NFKD", title).encode( - "ascii", "ignore") - slug = re.sub(r"[^\w]+", " ", slug) - slug = "-".join(slug.lower().strip().split()) - if not slug: - slug = "entry" - while True: - existing = db.Query(Entry).filter("slug =", slug).get() - if not existing or str(existing.key()) == key: - break - slug += "-2" - entry = Entry( - author=self.current_user, - title=title, - slug=slug, - body_source=self.get_argument("body_source"), - html=tornado.escape.linkify(self.get_argument("body_source")), - ) - entry.put() - self.redirect("/entry/" + entry.slug) - - -class EntryModule(tornado.web.UIModule): - def render(self, entry): - return self.render_string("modules/entry.html", entry=entry) - - -settings = { - "blog_title": u"Tornado Blog", - "template_path": os.path.join(os.path.dirname(__file__), "templates"), - "ui_modules": {"Entry": EntryModule}, - "xsrf_cookies": True, -} -application = tornado.web.Application([ - (r"/", HomeHandler), - (r"/archive", ArchiveHandler), - (r"/feed", FeedHandler), - (r"/entry/([^/]+)", EntryHandler), - (r"/compose", ComposeHandler), -], **settings) - -application = tornado.wsgi.WSGIAdapter(application) diff --git a/demos/appengine/static/blog.css b/demos/appengine/static/blog.css deleted file mode 100644 index 3ebef875e8..0000000000 --- a/demos/appengine/static/blog.css +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright 2009 Facebook - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. You may obtain - * a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- */ - -body { - background: white; - color: black; - margin: 15px; - margin-top: 0; -} - -body, -input, -textarea { - font-family: Georgia, serif; - font-size: 12pt; -} - -table { - border-collapse: collapse; - border: 0; -} - -td { - border: 0; - padding: 0; -} - -h1, -h2, -h3, -h4 { - font-family: "Helvetica Nue", Helvetica, Arial, sans-serif; - margin: 0; -} - -h1 { - font-size: 20pt; -} - -pre, -code { - font-family: monospace; - color: #060; -} - -pre { - margin-left: 1em; - padding-left: 1em; - border-left: 1px solid silver; - line-height: 14pt; -} - -a, -a code { - color: #00c; -} - -#body { - max-width: 800px; - margin: auto; -} - -#header { - background-color: #3b5998; - padding: 5px; - padding-left: 10px; - padding-right: 10px; - margin-bottom: 1em; -} - -#header, -#header a { - color: white; -} - -#header h1 a { - text-decoration: none; -} - -#footer, -#content { - margin-left: 10px; - margin-right: 10px; -} - -#footer { - margin-top: 3em; -} - -.entry h1 a { - color: black; - text-decoration: none; -} - -.entry { - margin-bottom: 2em; -} - -.entry .date { - margin-top: 3px; -} - -.entry p { - margin: 0; - margin-bottom: 1em; -} - -.entry .body { - margin-top: 1em; - line-height: 16pt; -} - -.compose td { - vertical-align: middle; - padding-bottom: 5px; -} - -.compose td.field { - padding-right: 10px; -} - -.compose .title, -.compose .submit { - font-family: "Helvetica Nue", Helvetica, Arial, sans-serif; - font-weight: bold; -} - -.compose .title { - font-size: 20pt; -} - -.compose .title, -.compose .body_source { - width: 100%; -} - -.compose .body_source { - height: 500px; - line-height: 16pt; -} diff --git a/demos/appengine/templates/archive.html b/demos/appengine/templates/archive.html deleted file mode 100644 index d501464976..0000000000 --- a/demos/appengine/templates/archive.html +++ /dev/null @@ -1,31 +0,0 @@ -{% extends "base.html" %} - -{% block head %} - -{% end %} - -{% block body %} -
    - {% for entry in entries %} -
  • - -
    {{ locale.format_date(entry.published, full_format=True, shorter=True) }}
    -
  • - {% end %} -
-{% end %} diff --git a/demos/appengine/templates/base.html b/demos/appengine/templates/base.html deleted file mode 100644 index 7ea0efa9f3..0000000000 --- a/demos/appengine/templates/base.html +++ /dev/null @@ -1,29 +0,0 @@ - - - - - {{ handler.settings["blog_title"] }} - - - {% block head %}{% end %} - - -
- -
{% block body %}{% end %}
-
- {% block bottom %}{% end %} - - diff --git a/demos/appengine/templates/compose.html b/demos/appengine/templates/compose.html deleted file mode 100644 index 39045e0394..0000000000 --- a/demos/appengine/templates/compose.html +++ /dev/null @@ -1,40 +0,0 @@ -{% extends "base.html" %} - -{% block body %} -
-
-
- - {% if entry %} - - {% end %} - {% module xsrf_form_html() %} -
-{% end %} - -{% block bottom %} - - -{% end %} diff --git a/demos/appengine/templates/entry.html b/demos/appengine/templates/entry.html deleted file mode 100644 index f3f495b496..0000000000 --- a/demos/appengine/templates/entry.html +++ /dev/null @@ -1,5 +0,0 @@ -{% extends "base.html" %} - -{% block body %} - {% module Entry(entry) %} -{% end %} diff --git a/demos/appengine/templates/feed.xml b/demos/appengine/templates/feed.xml deleted file mode 100644 index a98826c8d3..0000000000 --- a/demos/appengine/templates/feed.xml +++ /dev/null @@ -1,26 +0,0 @@ - - - {% set date_format = "%Y-%m-%dT%H:%M:%SZ" %} - {{ handler.settings["blog_title"] }} - {% if len(entries) > 0 %} - {{ max(e.updated for e in entries).strftime(date_format) }} - {% else %} - {{ datetime.datetime.utcnow().strftime(date_format) }} - {% end %} - http://{{ request.host }}/ - - - {{ handler.settings["blog_title"] }} - {% for entry in entries %} - - http://{{ request.host }}/entry/{{ entry.slug }} - {{ entry.title }} - - {{ entry.updated.strftime(date_format) }} - {{ entry.published.strftime(date_format) }} - -
{% raw entry.html %}
-
-
- {% end %} -
diff --git a/demos/appengine/templates/home.html b/demos/appengine/templates/home.html deleted file mode 100644 index 8e990ca56c..0000000000 --- a/demos/appengine/templates/home.html +++ /dev/null @@ -1,8 +0,0 @@ -{% extends "base.html" %} - -{% block body %} - {% for entry in entries %} - {% module Entry(entry) %} - {% end %} - -{% end %} diff --git a/demos/appengine/templates/modules/entry.html b/demos/appengine/templates/modules/entry.html deleted file mode 100644 index 201c04118c..0000000000 --- a/demos/appengine/templates/modules/entry.html +++ /dev/null @@ -1,8 +0,0 @@ -
-

{{ entry.title }}

-
{{ locale.format_date(entry.published, full_format=True, shorter=True) }}
-
{% raw entry.html %}
- {% if current_user and current_user.administrator %} - - {% end %} -
diff --git a/demos/benchmark/stack_context_benchmark.py b/demos/benchmark/stack_context_benchmark.py deleted file mode 100755 index 2b4a388fea..0000000000 --- a/demos/benchmark/stack_context_benchmark.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python -"""Benchmark for stack_context functionality.""" -import collections -import contextlib -import functools -import subprocess -import sys - -from tornado import stack_context - - -class Benchmark(object): - def enter_exit(self, count): - """Measures the overhead of the nested "with" statements - when using many contexts. - """ - if count < 0: - return - with self.make_context(): - self.enter_exit(count - 1) - - def call_wrapped(self, count): - """Wraps and calls a function at each level of stack depth - to measure the overhead of the wrapped function. - """ - # This queue is analogous to IOLoop.add_callback, but lets us - # benchmark the stack_context in isolation without system call - # overhead. - queue = collections.deque() - self.call_wrapped_inner(queue, count) - while queue: - queue.popleft()() - - def call_wrapped_inner(self, queue, count): - if count < 0: - return - with self.make_context(): - queue.append(stack_context.wrap( - functools.partial(self.call_wrapped_inner, queue, count - 1))) - - -class StackBenchmark(Benchmark): - def make_context(self): - return stack_context.StackContext(self.__context) - - @contextlib.contextmanager - def __context(self): - yield - - -class ExceptionBenchmark(Benchmark): - def make_context(self): - return stack_context.ExceptionStackContext(self.__handle_exception) - - def __handle_exception(self, typ, value, tb): - pass - - -def main(): - base_cmd = [ - sys.executable, '-m', 'timeit', '-s', - 'from stack_context_benchmark import StackBenchmark, ExceptionBenchmark'] - cmds = [ - 'StackBenchmark().enter_exit(50)', - 'StackBenchmark().call_wrapped(50)', - 'StackBenchmark().enter_exit(500)', - 'StackBenchmark().call_wrapped(500)', - - 'ExceptionBenchmark().enter_exit(50)', - 'ExceptionBenchmark().call_wrapped(50)', - 'ExceptionBenchmark().enter_exit(500)', - 'ExceptionBenchmark().call_wrapped(500)', - ] - for cmd in cmds: - print(cmd) - subprocess.check_call(base_cmd + [cmd]) - - -if __name__ == '__main__': - main() diff --git a/demos/blog/Dockerfile b/demos/blog/Dockerfile index 9ba708f382..4e3c7250be 100644 --- a/demos/blog/Dockerfile +++ b/demos/blog/Dockerfile @@ -1,17 +1,13 @@ -FROM python:2.7 +FROM python:3.7 EXPOSE 8888 -RUN apt-get update && apt-get install -y mysql-client - -# based on python:2.7-onbuild, but if we use that image directly -# the above apt-get line runs too late. RUN mkdir -p /usr/src/app WORKDIR /usr/src/app COPY requirements.txt /usr/src/app/ -RUN pip install -r requirements.txt +RUN pip install --no-cache-dir -r requirements.txt -COPY . /usr/src/app +COPY . . -CMD python blog.py --mysql_host=mysql +ENTRYPOINT ["python3", "blog.py"] diff --git a/demos/blog/README b/demos/blog/README index 72f0774f39..f54ad0abc8 100644 --- a/demos/blog/README +++ b/demos/blog/README @@ -1,63 +1,65 @@ Running the Tornado Blog example app ==================================== -This demo is a simple blogging engine that uses MySQL to store posts and -Google Accounts for author authentication. Since it depends on MySQL, you -need to set up MySQL and the database schema for the demo to run. + +This demo is a simple blogging engine that uses a database to store posts. +You must have PostgreSQL or CockroachDB installed to run this demo. 
If you have `docker` and `docker-compose` installed, the demo and all its prerequisites can be installed with `docker-compose up`. -1. Install prerequisites and build tornado +1. Install a database if needed + + Consult the documentation at either https://www.postgresql.org or + https://www.cockroachlabs.com to install one of these databases for + your platform. + +2. Install Python prerequisites - See http://www.tornadoweb.org/ for installation instructions. If you can - run the "helloworld" example application, your environment is set up - correctly. + This demo requires Python 3.6 or newer, and the packages listed in + requirements.txt. Install them with `pip install -r requirements.txt` -2. Install MySQL if needed +3. Create a database and user for the blog. - Consult the documentation for your platform. Under Ubuntu Linux you - can run "apt-get install mysql". Under OS X you can download the - MySQL PKG file from http://dev.mysql.com/downloads/mysql/ + Connect to the database with `psql -U postgres` (for PostgreSQL) or + `cockroach sql` (for CockroachDB). -3. Install Python prerequisites + Create a database and user, and grant permissions: - Install the packages MySQL-python, torndb, and markdown (e.g. using pip or - easy_install). Note that these packages currently only work on - Python 2. Tornado supports Python 3, but this blog demo does not. + CREATE DATABASE blog; + CREATE USER blog WITH PASSWORD 'blog'; + GRANT ALL ON DATABASE blog TO blog; -3. Connect to MySQL and create a database and user for the blog. + (If using CockroachDB in insecure mode, omit the `WITH PASSWORD 'blog'`) - Connect to MySQL as a user that can create databases and users: - mysql -u root +4. Create the tables in your new database (optional): - Create a database named "blog": - mysql> CREATE DATABASE blog; + The blog application will create its tables automatically when starting up. + It's also possible to create them separately. - Allow the "blog" user to connect with the password "blog": - mysql> GRANT ALL PRIVILEGES ON blog.* TO 'blog'@'localhost' IDENTIFIED BY 'blog'; + You can use the provided schema.sql file by running this command for PostgreSQL: -4. Create the tables in your new database. + psql -U blog -d blog < schema.sql - You can use the provided schema.sql file by running this command: - mysql --user=blog --password=blog --database=blog < schema.sql + Or this one for CockroachDB: + + cockroach sql -u blog -d blog < schema.sql You can run the above command again later if you want to delete the contents of the blog and start over after testing. 5. Run the blog example - With the default user, password, and database you can just run: + For PostgreSQL, you can just run ./blog.py - If you've changed anything, you can alter the default MySQL settings - with arguments on the command line, e.g.: - ./blog.py --mysql_user=casey --mysql_password=happiness --mysql_database=foodblog + For CockroachDB, run + ./blog.py --db_port=26257 + + If you've changed anything from the defaults, use the other `--db_*` flags. 6. Visit your new blog - Open http://localhost:8888/ in your web browser. You will be redirected to - a Google account sign-in page because the blog uses Google accounts for - authentication. + Open http://localhost:8888/ in your web browser. Currently the first user to connect will automatically be given the ability to create and edit posts.
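(Editor's note, not part of the patch: before starting the demo, the database settings from the README above can be sanity-checked with a short aiopg script. This is a minimal sketch, assuming the `blog` database/user from step 3 and the default host/port; it uses the same `aiopg` pool and cursor style as the rewritten blog.py in the next diff.)

```py
#!/usr/bin/env python3
# Minimal sketch: verify the database settings from the README above.
# Assumes the "blog" database/user created in step 3; adjust as needed.
import aiopg
import tornado.ioloop


async def check():
    # aiopg.create_pool() takes the same keyword arguments as psycopg2.connect().
    async with aiopg.create_pool(
        host="127.0.0.1", port=5432, dbname="blog", user="blog", password="blog"
    ) as pool:
        # Same cursor idiom the blog demo uses: the pool acquires a
        # connection and hands back a cursor in one step.
        with (await pool.cursor()) as cur:
            await cur.execute("SELECT 1")
            print(await cur.fetchone())  # (1,) means the credentials work


if __name__ == "__main__":
    tornado.ioloop.IOLoop.current().run_sync(check)
```

If this prints `(1,)`, `./blog.py` should be able to reach the same database with the default `--db_*` flags.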
diff --git a/demos/blog/blog.py b/demos/blog/blog.py index d629957270..a16ddf3e7f 100755 --- a/demos/blog/blog.py +++ b/demos/blog/blog.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # # Copyright 2009 Facebook # @@ -14,18 +14,16 @@ # License for the specific language governing permissions and limitations # under the License. +import aiopg import bcrypt -import concurrent.futures -import MySQLdb import markdown import os.path +import psycopg2 import re -import subprocess -import torndb import tornado.escape -from tornado import gen import tornado.httpserver import tornado.ioloop +import tornado.locks import tornado.options import tornado.web import unicodedata @@ -33,18 +31,32 @@ from tornado.options import define, options define("port", default=8888, help="run on the given port", type=int) -define("mysql_host", default="127.0.0.1:3306", help="blog database host") -define("mysql_database", default="blog", help="blog database name") -define("mysql_user", default="blog", help="blog database user") -define("mysql_password", default="blog", help="blog database password") +define("db_host", default="127.0.0.1", help="blog database host") +define("db_port", default=5432, help="blog database port") +define("db_database", default="blog", help="blog database name") +define("db_user", default="blog", help="blog database user") +define("db_password", default="blog", help="blog database password") -# A thread pool to be used for password hashing with bcrypt. -executor = concurrent.futures.ThreadPoolExecutor(2) +class NoResultError(Exception): + pass + + +async def maybe_create_tables(db): + try: + with (await db.cursor()) as cur: + await cur.execute("SELECT COUNT(*) FROM entries LIMIT 1") + await cur.fetchone() + except psycopg2.ProgrammingError: + with open("schema.sql") as f: + schema = f.read() + with (await db.cursor()) as cur: + await cur.execute(schema) class Application(tornado.web.Application): - def __init__(self): + def __init__(self, db): + self.db = db handlers = [ (r"/", HomeHandler), (r"/archive", ArchiveHandler), @@ -56,7 +68,7 @@ def __init__(self): (r"/auth/logout", AuthLogoutHandler), ] settings = dict( - blog_title=u"Tornado Blog", + blog_title="Tornado Blog", template_path=os.path.join(os.path.dirname(__file__), "templates"), static_path=os.path.join(os.path.dirname(__file__), "static"), ui_modules={"Entry": EntryModule}, @@ -65,45 +77,71 @@ def __init__(self): login_url="/auth/login", debug=True, ) - super(Application, self).__init__(handlers, **settings) - # Have one global connection to the blog DB across all handlers - self.db = torndb.Connection( - host=options.mysql_host, database=options.mysql_database, - user=options.mysql_user, password=options.mysql_password) - - self.maybe_create_tables() - - def maybe_create_tables(self): - try: - self.db.get("SELECT COUNT(*) from entries;") - except MySQLdb.ProgrammingError: - subprocess.check_call(['mysql', - '--host=' + options.mysql_host, - '--database=' + options.mysql_database, - '--user=' + options.mysql_user, - '--password=' + options.mysql_password], - stdin=open('schema.sql')) + super().__init__(handlers, **settings) class BaseHandler(tornado.web.RequestHandler): - @property - def db(self): - return self.application.db - - def get_current_user(self): + def row_to_obj(self, row, cur): + """Convert a SQL row to an object supporting dict and attribute access.""" + obj = tornado.util.ObjectDict() + for val, desc in zip(row, cur.description): + obj[desc.name] = val + return obj + + async def execute(self, stmt, 
*args): + """Execute a SQL statement. + + Must be called with ``await self.execute(...)`` + """ + with (await self.application.db.cursor()) as cur: + await cur.execute(stmt, args) + + async def query(self, stmt, *args): + """Query for a list of results. + + Typical usage:: + + results = await self.query(...) + + Or:: + + for row in await self.query(...) + """ + with (await self.application.db.cursor()) as cur: + await cur.execute(stmt, args) + return [self.row_to_obj(row, cur) for row in await cur.fetchall()] + + async def queryone(self, stmt, *args): + """Query for exactly one result. + + Raises NoResultError if there are no results, or ValueError if + there is more than one. + """ + results = await self.query(stmt, *args) + if len(results) == 0: + raise NoResultError() + elif len(results) > 1: + raise ValueError("Expected 1 result, got %d" % len(results)) + return results[0] + + async def prepare(self): + # get_current_user cannot be a coroutine, so set + # self.current_user in prepare instead. user_id = self.get_secure_cookie("blogdemo_user") - if not user_id: - return None - return self.db.get("SELECT * FROM authors WHERE id = %s", int(user_id)) + if user_id: + self.current_user = await self.queryone( + "SELECT * FROM authors WHERE id = %s", int(user_id) + ) - def any_author_exists(self): - return bool(self.db.get("SELECT * FROM authors LIMIT 1")) + async def any_author_exists(self): + return bool(await self.query("SELECT * FROM authors LIMIT 1")) class HomeHandler(BaseHandler): - def get(self): - entries = self.db.query("SELECT * FROM entries ORDER BY published " - "DESC LIMIT 5") + async def get(self): + entries = await self.query( + "SELECT * FROM entries ORDER BY published DESC LIMIT 5" + ) if not entries: self.redirect("/compose") return @@ -111,67 +149,80 @@ def get(self): class EntryHandler(BaseHandler): - def get(self, slug): - entry = self.db.get("SELECT * FROM entries WHERE slug = %s", slug) - if not entry: - raise tornado.web.HTTPError(404) + async def get(self, slug): + # queryone() raises NoResultError instead of returning None, + # so map a missing slug to a 404 explicitly. + try: + entry = await self.queryone("SELECT * FROM entries WHERE slug = %s", slug) + except NoResultError: + raise tornado.web.HTTPError(404) self.render("entry.html", entry=entry) class ArchiveHandler(BaseHandler): - def get(self): - entries = self.db.query("SELECT * FROM entries ORDER BY published " - "DESC") + async def get(self): + entries = await self.query("SELECT * FROM entries ORDER BY published DESC") self.render("archive.html", entries=entries) class FeedHandler(BaseHandler): - def get(self): - entries = self.db.query("SELECT * FROM entries ORDER BY published " - "DESC LIMIT 10") + async def get(self): + entries = await self.query( + "SELECT * FROM entries ORDER BY published DESC LIMIT 10" + ) self.set_header("Content-Type", "application/atom+xml") self.render("feed.xml", entries=entries) class ComposeHandler(BaseHandler): @tornado.web.authenticated - def get(self): + async def get(self): id = self.get_argument("id", None) entry = None if id: - entry = self.db.get("SELECT * FROM entries WHERE id = %s", int(id)) + entry = await self.queryone("SELECT * FROM entries WHERE id = %s", int(id)) self.render("compose.html", entry=entry) @tornado.web.authenticated - def post(self): + async def post(self): id = self.get_argument("id", None) title = self.get_argument("title") text = self.get_argument("markdown") html = markdown.markdown(text) if id: - entry = self.db.get("SELECT * FROM entries WHERE id = %s", int(id)) - if not entry: + try: + entry = await self.queryone( + "SELECT * FROM entries WHERE id = %s", int(id) + ) + except NoResultError: raise tornado.web.HTTPError(404)
slug = entry.slug - self.db.execute( + await self.execute( "UPDATE entries SET title = %s, markdown = %s, html = %s " - "WHERE id = %s", title, text, html, int(id)) + "WHERE id = %s", + title, + text, + html, + int(id), + ) else: - slug = unicodedata.normalize("NFKD", title).encode( - "ascii", "ignore") + slug = unicodedata.normalize("NFKD", title) slug = re.sub(r"[^\w]+", " ", slug) slug = "-".join(slug.lower().strip().split()) + slug = slug.encode("ascii", "ignore").decode("ascii") if not slug: slug = "entry" while True: - e = self.db.get("SELECT * FROM entries WHERE slug = %s", slug) + e = await self.query("SELECT * FROM entries WHERE slug = %s", slug) if not e: break slug += "-2" - self.db.execute( - "INSERT INTO entries (author_id,title,slug,markdown,html," - "published) VALUES (%s,%s,%s,%s,%s,UTC_TIMESTAMP())", - self.current_user.id, title, slug, text, html) + await self.execute( + "INSERT INTO entries (author_id,title,slug,markdown,html,published,updated)" + "VALUES (%s,%s,%s,%s,%s,CURRENT_TIMESTAMP,CURRENT_TIMESTAMP)", + self.current_user.id, + title, + slug, + text, + html, + ) self.redirect("/entry/" + slug) @@ -179,41 +230,49 @@ class AuthCreateHandler(BaseHandler): def get(self): self.render("create_author.html") - @gen.coroutine - def post(self): - if self.any_author_exists(): + async def post(self): + if await self.any_author_exists(): raise tornado.web.HTTPError(400, "author already created") - hashed_password = yield executor.submit( - bcrypt.hashpw, tornado.escape.utf8(self.get_argument("password")), - bcrypt.gensalt()) - author_id = self.db.execute( + hashed_password = await tornado.ioloop.IOLoop.current().run_in_executor( + None, + bcrypt.hashpw, + tornado.escape.utf8(self.get_argument("password")), + bcrypt.gensalt(), + ) + author = await self.queryone( "INSERT INTO authors (email, name, hashed_password) " - "VALUES (%s, %s, %s)", - self.get_argument("email"), self.get_argument("name"), - hashed_password) - self.set_secure_cookie("blogdemo_user", str(author_id)) + "VALUES (%s, %s, %s) RETURNING id", + self.get_argument("email"), + self.get_argument("name"), + tornado.escape.to_unicode(hashed_password), + ) + self.set_secure_cookie("blogdemo_user", str(author.id)) self.redirect(self.get_argument("next", "/")) class AuthLoginHandler(BaseHandler): - def get(self): + async def get(self): # If there are no authors, redirect to the account creation page. 
- if not self.any_author_exists(): + if not await self.any_author_exists(): self.redirect("/auth/create") else: self.render("login.html", error=None) - @gen.coroutine - def post(self): - author = self.db.get("SELECT * FROM authors WHERE email = %s", - self.get_argument("email")) - if not author: + async def post(self): + try: + author = await self.queryone( + "SELECT * FROM authors WHERE email = %s", self.get_argument("email") + ) + except NoResultError: self.render("login.html", error="email not found") return - hashed_password = yield executor.submit( - bcrypt.hashpw, tornado.escape.utf8(self.get_argument("password")), - tornado.escape.utf8(author.hashed_password)) - if hashed_password == author.hashed_password: + password_equal = await tornado.ioloop.IOLoop.current().run_in_executor( + None, + bcrypt.checkpw, + tornado.escape.utf8(self.get_argument("password")), + tornado.escape.utf8(author.hashed_password), + ) + if password_equal: self.set_secure_cookie("blogdemo_user", str(author.id)) self.redirect(self.get_argument("next", "/")) else: @@ -231,12 +290,27 @@ def render(self, entry): return self.render_string("modules/entry.html", entry=entry) -def main(): +async def main(): tornado.options.parse_command_line() - http_server = tornado.httpserver.HTTPServer(Application()) - http_server.listen(options.port) - tornado.ioloop.IOLoop.current().start() + + # Create the global connection pool. + async with aiopg.create_pool( + host=options.db_host, + port=options.db_port, + user=options.db_user, + password=options.db_password, + dbname=options.db_database, + ) as db: + await maybe_create_tables(db) + app = Application(db) + app.listen(options.port) + + # In this demo the server will simply run until interrupted + # with Ctrl-C, but if you want to shut down more gracefully, + # call shutdown_event.set(). + shutdown_event = tornado.locks.Event() + await shutdown_event.wait() if __name__ == "__main__": - main() + tornado.ioloop.IOLoop.current().run_sync(main) diff --git a/demos/blog/docker-compose.yml b/demos/blog/docker-compose.yml index 247c94beaf..95f8e84f4b 100644 --- a/demos/blog/docker-compose.yml +++ b/demos/blog/docker-compose.yml @@ -1,15 +1,15 @@ -mysql: - image: mysql:5.6 +postgres: + image: postgres:10.3 environment: - MYSQL_ROOT_PASSWORD: its_a_secret_to_everybody - MYSQL_USER: blog - MYSQL_PASSWORD: blog - MYSQL_DATABASE: blog + POSTGRES_USER: blog + POSTGRES_PASSWORD: blog + POSTGRES_DB: blog ports: - "3306" blog: build: . 
links: - - mysql + - postgres ports: - "8888:8888" + command: --db_host=postgres diff --git a/demos/blog/requirements.txt b/demos/blog/requirements.txt index 8669e33bd5..f4c727a021 100644 --- a/demos/blog/requirements.txt +++ b/demos/blog/requirements.txt @@ -1,6 +1,5 @@ +aiopg bcrypt -futures -MySQL-python markdown +psycopg2 tornado -torndb diff --git a/demos/blog/schema.sql b/demos/blog/schema.sql index a63e91fdef..1820f17720 100644 --- a/demos/blog/schema.sql +++ b/demos/blog/schema.sql @@ -14,32 +14,30 @@ -- To create the database: -- CREATE DATABASE blog; --- GRANT ALL PRIVILEGES ON blog.* TO 'blog'@'localhost' IDENTIFIED BY 'blog'; +-- CREATE USER blog WITH PASSWORD 'blog'; +-- GRANT ALL ON DATABASE blog TO blog; -- -- To reload the tables: --- mysql --user=blog --password=blog --database=blog < schema.sql +-- psql -U blog -d blog < schema.sql -SET SESSION storage_engine = "InnoDB"; -SET SESSION time_zone = "+0:00"; -ALTER DATABASE CHARACTER SET "utf8"; +DROP TABLE IF EXISTS authors; +CREATE TABLE authors ( + id SERIAL PRIMARY KEY, + email VARCHAR(100) NOT NULL UNIQUE, + name VARCHAR(100) NOT NULL, + hashed_password VARCHAR(100) NOT NULL +); DROP TABLE IF EXISTS entries; CREATE TABLE entries ( - id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + id SERIAL PRIMARY KEY, author_id INT NOT NULL REFERENCES authors(id), slug VARCHAR(100) NOT NULL UNIQUE, title VARCHAR(512) NOT NULL, - markdown MEDIUMTEXT NOT NULL, - html MEDIUMTEXT NOT NULL, - published DATETIME NOT NULL, - updated TIMESTAMP NOT NULL, - KEY (published) + markdown TEXT NOT NULL, + html TEXT NOT NULL, + published TIMESTAMP NOT NULL, + updated TIMESTAMP NOT NULL ); -DROP TABLE IF EXISTS authors; -CREATE TABLE authors ( - id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, - email VARCHAR(100) NOT NULL UNIQUE, - name VARCHAR(100) NOT NULL, - hashed_password VARCHAR(100) NOT NULL -); +CREATE INDEX ON entries (published); diff --git a/demos/chat/chatdemo.py b/demos/chat/chatdemo.py index 89149c4209..c109b222a2 100755 --- a/demos/chat/chatdemo.py +++ b/demos/chat/chatdemo.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # # Copyright 2009 Facebook # @@ -14,58 +14,45 @@ # License for the specific language governing permissions and limitations # under the License. -import logging +import asyncio import tornado.escape import tornado.ioloop +import tornado.locks import tornado.web import os.path import uuid -from tornado.concurrent import Future -from tornado import gen from tornado.options import define, options, parse_command_line define("port", default=8888, help="run on the given port", type=int) -define("debug", default=False, help="run in debug mode") +define("debug", default=True, help="run in debug mode") class MessageBuffer(object): def __init__(self): - self.waiters = set() + # cond is notified whenever the message cache is updated + self.cond = tornado.locks.Condition() self.cache = [] self.cache_size = 200 - def wait_for_messages(self, cursor=None): - # Construct a Future to return to our caller. This allows - # wait_for_messages to be yielded from a coroutine even though - # it is not a coroutine itself. We will set the result of the - # Future when results are available. 
- result_future = Future() - if cursor: - new_count = 0 - for msg in reversed(self.cache): - if msg["id"] == cursor: - break - new_count += 1 - if new_count: - result_future.set_result(self.cache[-new_count:]) - return result_future - self.waiters.add(result_future) - return result_future - - def cancel_wait(self, future): - self.waiters.remove(future) - # Set an empty result to unblock any coroutines waiting. - future.set_result([]) - - def new_messages(self, messages): - logging.info("Sending new message to %r listeners", len(self.waiters)) - for future in self.waiters: - future.set_result(messages) - self.waiters = set() - self.cache.extend(messages) + def get_messages_since(self, cursor): + """Returns a list of messages newer than the given cursor. + + ``cursor`` should be the ``id`` of the last message received. + """ + results = [] + for msg in reversed(self.cache): + if msg["id"] == cursor: + break + results.append(msg) + results.reverse() + return results + + def add_message(self, message): + self.cache.append(message) if len(self.cache) > self.cache_size: - self.cache = self.cache[-self.cache_size:] + self.cache = self.cache[-self.cache_size :] + self.cond.notify_all() # Making this a non-singleton is left as an exercise for the reader. @@ -78,36 +65,46 @@ def get(self): class MessageNewHandler(tornado.web.RequestHandler): + """Post a new message to the chat room.""" + def post(self): - message = { - "id": str(uuid.uuid4()), - "body": self.get_argument("body"), - } - # to_basestring is necessary for Python 3's json encoder, - # which doesn't accept byte strings. - message["html"] = tornado.escape.to_basestring( - self.render_string("message.html", message=message)) + message = {"id": str(uuid.uuid4()), "body": self.get_argument("body")} + # render_string() returns a byte string, which is not supported + # in json, so we must convert it to a character string. + message["html"] = tornado.escape.to_unicode( + self.render_string("message.html", message=message) + ) if self.get_argument("next", None): self.redirect(self.get_argument("next")) else: self.write(message) - global_message_buffer.new_messages([message]) + global_message_buffer.add_message(message) class MessageUpdatesHandler(tornado.web.RequestHandler): - @gen.coroutine - def post(self): + """Long-polling request for new messages. + + Waits until new messages are available before returning anything. + """ + + async def post(self): cursor = self.get_argument("cursor", None) - # Save the future returned by wait_for_messages so we can cancel - # it in wait_for_messages - self.future = global_message_buffer.wait_for_messages(cursor=cursor) - messages = yield self.future + messages = global_message_buffer.get_messages_since(cursor) + while not messages: + # Save the Future returned here so we can cancel it in + # on_connection_close. 
+ self.wait_future = global_message_buffer.cond.wait() + try: + await self.wait_future + except asyncio.CancelledError: + return + messages = global_message_buffer.get_messages_since(cursor) if self.request.connection.stream.closed(): return self.write(dict(messages=messages)) def on_connection_close(self): - global_message_buffer.cancel_wait(self.future) + self.wait_future.cancel() def main(): diff --git a/demos/chat/static/chat.js index 151a5880bc..48a63c4137 100644 --- a/demos/chat/static/chat.js +++ b/demos/chat/static/chat.js @@ -116,7 +116,6 @@ var updater = { newMessages: function(response) { if (!response.messages) return; - updater.cursor = response.cursor; var messages = response.messages; updater.cursor = messages[messages.length - 1].id; console.log(messages.length, "new messages, cursor:", updater.cursor); diff --git a/demos/facebook/facebook.py index 2f3355928c..dc054b9a9d 100755 --- a/demos/facebook/facebook.py +++ b/demos/facebook/facebook.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # # Copyright 2009 Facebook # @@ -61,12 +61,10 @@ def get_current_user(self): class MainHandler(BaseHandler, tornado.auth.FacebookGraphMixin): @tornado.web.authenticated - @tornado.web.asynchronous - def get(self): - self.facebook_request("/me/home", self._on_stream, - access_token=self.current_user["access_token"]) - - def _on_stream(self, stream): + async def get(self): + stream = await self.facebook_request( + "/me/home", access_token=self.current_user["access_token"] + ) if stream is None: # Session may have expired self.redirect("/auth/login") @@ -75,28 +73,29 @@ def _on_stream(self, stream): class AuthLoginHandler(BaseHandler, tornado.auth.FacebookGraphMixin): - @tornado.web.asynchronous - def get(self): - my_url = (self.request.protocol + "://" + self.request.host + - "/auth/login?next=" + - tornado.escape.url_escape(self.get_argument("next", "/"))) + async def get(self): + my_url = ( + self.request.protocol + + "://" + + self.request.host + + "/auth/login?next=" + + tornado.escape.url_escape(self.get_argument("next", "/")) + ) if self.get_argument("code", False): - self.get_authenticated_user( + user = await self.get_authenticated_user( redirect_uri=my_url, client_id=self.settings["facebook_api_key"], client_secret=self.settings["facebook_secret"], code=self.get_argument("code"), - callback=self._on_auth) + ) + self.set_secure_cookie("fbdemo_user", tornado.escape.json_encode(user)) + self.redirect(self.get_argument("next", "/")) return - self.authorize_redirect(redirect_uri=my_url, - client_id=self.settings["facebook_api_key"], - extra_params={"scope": "user_posts"}) - - def _on_auth(self, user): - if not user: - raise tornado.web.HTTPError(500, "Facebook auth failed") - self.set_secure_cookie("fbdemo_user", tornado.escape.json_encode(user)) - self.redirect(self.get_argument("next", "/")) + self.authorize_redirect( + redirect_uri=my_url, + client_id=self.settings["facebook_api_key"], + extra_params={"scope": "user_posts"}, + ) class AuthLogoutHandler(BaseHandler, tornado.auth.FacebookGraphMixin): diff --git a/demos/file_upload/file_receiver.py index 3b3e98673a..53489704c4 100755 --- a/demos/file_upload/file_receiver.py +++ b/demos/file_upload/file_receiver.py @@ -25,12 +25,13 @@ class POSTHandler(tornado.web.RequestHandler): def post(self): for field_name, files in self.request.files.items(): for info in files: - filename, content_type = info['filename'],
info['content_type'] - body = info['body'] - logging.info('POST "%s" "%s" %d bytes', - filename, content_type, len(body)) + filename, content_type = info["filename"], info["content_type"] + body = info["body"] + logging.info( + 'POST "%s" "%s" %d bytes', filename, content_type, len(body) + ) - self.write('OK') + self.write("OK") @tornado.web.stream_request_body @@ -43,16 +44,13 @@ def data_received(self, chunk): def put(self, filename): filename = unquote(filename) - mtype = self.request.headers.get('Content-Type') + mtype = self.request.headers.get("Content-Type") logging.info('PUT "%s" "%s" %d bytes', filename, mtype, self.bytes_read) - self.write('OK') + self.write("OK") def make_app(): - return tornado.web.Application([ - (r"/post", POSTHandler), - (r"/(.*)", PUTHandler), - ]) + return tornado.web.Application([(r"/post", POSTHandler), (r"/(.*)", PUTHandler)]) if __name__ == "__main__": diff --git a/demos/file_upload/file_uploader.py b/demos/file_upload/file_uploader.py index 9f1f84d516..f48991407f 100755 --- a/demos/file_upload/file_uploader.py +++ b/demos/file_upload/file_uploader.py @@ -9,7 +9,6 @@ See also file_receiver.py in this directory, a server that receives uploads. """ -from __future__ import print_function import mimetypes import os import sys @@ -34,16 +33,18 @@ def multipart_producer(boundary, filenames, write): for filename in filenames: filename_bytes = filename.encode() - mtype = mimetypes.guess_type(filename)[0] or 'application/octet-stream' + mtype = mimetypes.guess_type(filename)[0] or "application/octet-stream" buf = ( - (b'--%s\r\n' % boundary_bytes) + - (b'Content-Disposition: form-data; name="%s"; filename="%s"\r\n' % - (filename_bytes, filename_bytes)) + - (b'Content-Type: %s\r\n' % mtype.encode()) + - b'\r\n' + (b"--%s\r\n" % boundary_bytes) + + ( + b'Content-Disposition: form-data; name="%s"; filename="%s"\r\n' + % (filename_bytes, filename_bytes) + ) + + (b"Content-Type: %s\r\n" % mtype.encode()) + + b"\r\n" ) yield write(buf) - with open(filename, 'rb') as f: + with open(filename, "rb") as f: while True: # 16k at a time. chunk = f.read(16 * 1024) @@ -51,9 +52,9 @@ def multipart_producer(boundary, filenames, write): break yield write(chunk) - yield write(b'\r\n') + yield write(b"\r\n") - yield write(b'--%s--\r\n' % (boundary_bytes,)) + yield write(b"--%s--\r\n" % (boundary_bytes,)) # Using HTTP PUT, upload one raw file. This is preferred for large files since @@ -62,19 +63,21 @@ def multipart_producer(boundary, filenames, write): def post(filenames): client = httpclient.AsyncHTTPClient() boundary = uuid4().hex - headers = {'Content-Type': 'multipart/form-data; boundary=%s' % boundary} + headers = {"Content-Type": "multipart/form-data; boundary=%s" % boundary} producer = partial(multipart_producer, boundary, filenames) - response = yield client.fetch('http://localhost:8888/post', - method='POST', - headers=headers, - body_producer=producer) + response = yield client.fetch( + "http://localhost:8888/post", + method="POST", + headers=headers, + body_producer=producer, + ) print(response) @gen.coroutine def raw_producer(filename, write): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: while True: # 16K at a time. 
chunk = f.read(16 * 1024) @@ -89,14 +92,16 @@ def raw_producer(filename, write): def put(filenames): client = httpclient.AsyncHTTPClient() for filename in filenames: - mtype = mimetypes.guess_type(filename)[0] or 'application/octet-stream' - headers = {'Content-Type': mtype} + mtype = mimetypes.guess_type(filename)[0] or "application/octet-stream" + headers = {"Content-Type": mtype} producer = partial(raw_producer, filename) url_path = quote(os.path.basename(filename)) - response = yield client.fetch('http://localhost:8888/%s' % url_path, - method='PUT', - headers=headers, - body_producer=producer) + response = yield client.fetch( + "http://localhost:8888/%s" % url_path, + method="PUT", + headers=headers, + body_producer=producer, + ) print(response) diff --git a/demos/helloworld/helloworld.py b/demos/helloworld/helloworld.py index 06eac9290f..7f64e82405 100755 --- a/demos/helloworld/helloworld.py +++ b/demos/helloworld/helloworld.py @@ -31,9 +31,7 @@ def get(self): def main(): tornado.options.parse_command_line() - application = tornado.web.Application([ - (r"/", MainHandler), - ]) + application = tornado.web.Application([(r"/", MainHandler)]) http_server = tornado.httpserver.HTTPServer(application) http_server.listen(options.port) tornado.ioloop.IOLoop.current().start() diff --git a/demos/s3server/s3server.py b/demos/s3server/s3server.py index c01f06c469..11be1c2c24 100644 --- a/demos/s3server/s3server.py +++ b/demos/s3server/s3server.py @@ -42,14 +42,19 @@ from tornado import ioloop from tornado import web from tornado.util import unicode_type +from tornado.options import options, define try: long except NameError: long = int +define("port", default=9888, help="TCP port to listen on") +define("root_directory", default="/tmp/s3", help="Root storage directory") +define("bucket_depth", default=0, help="Bucket file system depth limit") -def start(port, root_directory="/tmp/s3", bucket_depth=0): + +def start(port, root_directory, bucket_depth): """Starts the mock S3 server on the given port at the given path.""" application = S3Application(root_directory, bucket_depth) http_server = httpserver.HTTPServer(application) @@ -64,12 +69,16 @@ class S3Application(web.Application): to prevent hitting file system limits for number of files in each directories. 1 means one level of directories, 2 means 2, etc. 
""" + def __init__(self, root_directory, bucket_depth=0): - web.Application.__init__(self, [ - (r"/", RootHandler), - (r"/([^/]+)/(.+)", ObjectHandler), - (r"/([^/]+)/", BucketHandler), - ]) + web.Application.__init__( + self, + [ + (r"/", RootHandler), + (r"/([^/]+)/(.+)", ObjectHandler), + (r"/([^/]+)/", BucketHandler), + ], + ) self.directory = os.path.abspath(root_directory) if not os.path.exists(self.directory): os.makedirs(self.directory) @@ -82,14 +91,12 @@ class BaseRequestHandler(web.RequestHandler): def render_xml(self, value): assert isinstance(value, dict) and len(value) == 1 self.set_header("Content-Type", "application/xml; charset=UTF-8") - name = value.keys()[0] + name = list(value.keys())[0] parts = [] - parts.append('<' + escape.utf8(name) + - ' xmlns="http://doc.s3.amazonaws.com/2006-03-01">') - self._render_parts(value.values()[0], parts) - parts.append('') - self.finish('\n' + - ''.join(parts)) + parts.append("<" + name + ' xmlns="http://doc.s3.amazonaws.com/2006-03-01">') + self._render_parts(value[name], parts) + parts.append("") + self.finish('\n' + "".join(parts)) def _render_parts(self, value, parts=[]): if isinstance(value, (unicode_type, bytes)): @@ -99,25 +106,25 @@ def _render_parts(self, value, parts=[]): elif isinstance(value, datetime.datetime): parts.append(value.strftime("%Y-%m-%dT%H:%M:%S.000Z")) elif isinstance(value, dict): - for name, subvalue in value.iteritems(): + for name, subvalue in value.items(): if not isinstance(subvalue, list): subvalue = [subvalue] for subsubvalue in subvalue: - parts.append('<' + escape.utf8(name) + '>') + parts.append("<" + name + ">") self._render_parts(subsubvalue, parts) - parts.append('') + parts.append("") else: raise Exception("Unknown S3 value type %r", value) def _object_path(self, bucket, object_name): if self.application.bucket_depth < 1: - return os.path.abspath(os.path.join( - self.application.directory, bucket, object_name)) + return os.path.abspath( + os.path.join(self.application.directory, bucket, object_name) + ) hash = hashlib.md5(object_name).hexdigest() - path = os.path.abspath(os.path.join( - self.application.directory, bucket)) + path = os.path.abspath(os.path.join(self.application.directory, bucket)) for i in range(self.application.bucket_depth): - path = os.path.join(path, hash[:2 * (i + 1)]) + path = os.path.join(path, hash[: 2 * (i + 1)]) return os.path.join(path, object_name) @@ -128,14 +135,13 @@ def get(self): for name in names: path = os.path.join(self.application.directory, name) info = os.stat(path) - buckets.append({ - "Name": name, - "CreationDate": datetime.datetime.utcfromtimestamp( - info.st_ctime), - }) - self.render_xml({"ListAllMyBucketsResult": { - "Buckets": {"Bucket": buckets}, - }}) + buckets.append( + { + "Name": name, + "CreationDate": datetime.datetime.utcfromtimestamp(info.st_ctime), + } + ) + self.render_xml({"ListAllMyBucketsResult": {"Buckets": {"Bucket": buckets}}}) class BucketHandler(BaseRequestHandler): @@ -143,11 +149,9 @@ def get(self, bucket_name): prefix = self.get_argument("prefix", u"") marker = self.get_argument("marker", u"") max_keys = int(self.get_argument("max-keys", 50000)) - path = os.path.abspath(os.path.join(self.application.directory, - bucket_name)) + path = os.path.abspath(os.path.join(self.application.directory, bucket_name)) terse = int(self.get_argument("terse", 0)) - if not path.startswith(self.application.directory) or \ - not os.path.isdir(path): + if not path.startswith(self.application.directory) or not os.path.isdir(path): raise 
web.HTTPError(404) object_names = [] for root, dirs, files in os.walk(path): @@ -177,36 +181,39 @@ def get(self, bucket_name): c = {"Key": object_name} if not terse: info = os.stat(object_path) - c.update({ - "LastModified": datetime.datetime.utcfromtimestamp( - info.st_mtime), - "Size": info.st_size, - }) + c.update( + { + "LastModified": datetime.datetime.utcfromtimestamp( + info.st_mtime + ), + "Size": info.st_size, + } + ) contents.append(c) marker = object_name - self.render_xml({"ListBucketResult": { - "Name": bucket_name, - "Prefix": prefix, - "Marker": marker, - "MaxKeys": max_keys, - "IsTruncated": truncated, - "Contents": contents, - }}) + self.render_xml( + { + "ListBucketResult": { + "Name": bucket_name, + "Prefix": prefix, + "Marker": marker, + "MaxKeys": max_keys, + "IsTruncated": truncated, + "Contents": contents, + } + } + ) def put(self, bucket_name): - path = os.path.abspath(os.path.join( - self.application.directory, bucket_name)) - if not path.startswith(self.application.directory) or \ - os.path.exists(path): + path = os.path.abspath(os.path.join(self.application.directory, bucket_name)) + if not path.startswith(self.application.directory) or os.path.exists(path): raise web.HTTPError(403) os.makedirs(path) self.finish() def delete(self, bucket_name): - path = os.path.abspath(os.path.join( - self.application.directory, bucket_name)) - if not path.startswith(self.application.directory) or \ - not os.path.isdir(path): + path = os.path.abspath(os.path.join(self.application.directory, bucket_name)) + if not path.startswith(self.application.directory) or not os.path.isdir(path): raise web.HTTPError(404) if len(os.listdir(path)) > 0: raise web.HTTPError(403) @@ -219,25 +226,22 @@ class ObjectHandler(BaseRequestHandler): def get(self, bucket, object_name): object_name = urllib.unquote(object_name) path = self._object_path(bucket, object_name) - if not path.startswith(self.application.directory) or \ - not os.path.isfile(path): + if not path.startswith(self.application.directory) or not os.path.isfile(path): raise web.HTTPError(404) info = os.stat(path) self.set_header("Content-Type", "application/unknown") - self.set_header("Last-Modified", datetime.datetime.utcfromtimestamp( - info.st_mtime)) - object_file = open(path, "rb") - try: + self.set_header( + "Last-Modified", datetime.datetime.utcfromtimestamp(info.st_mtime) + ) + with open(path, "rb") as object_file: self.finish(object_file.read()) - finally: - object_file.close() def put(self, bucket, object_name): object_name = urllib.unquote(object_name) - bucket_dir = os.path.abspath(os.path.join( - self.application.directory, bucket)) - if not bucket_dir.startswith(self.application.directory) or \ - not os.path.isdir(bucket_dir): + bucket_dir = os.path.abspath(os.path.join(self.application.directory, bucket)) + if not bucket_dir.startswith(self.application.directory) or not os.path.isdir( + bucket_dir + ): raise web.HTTPError(404) path = self._object_path(bucket, object_name) if not path.startswith(bucket_dir) or os.path.isdir(path): @@ -245,17 +249,20 @@ def put(self, bucket, object_name): directory = os.path.dirname(path) if not os.path.exists(directory): os.makedirs(directory) - object_file = open(path, "w") - object_file.write(self.request.body) - object_file.close() + with open(path, "w") as object_file: + object_file.write(self.request.body) self.finish() def delete(self, bucket, object_name): object_name = urllib.unquote(object_name) path = self._object_path(bucket, object_name) - if not 
path.startswith(self.application.directory) or \ - not os.path.isfile(path): + if not path.startswith(self.application.directory) or not os.path.isfile(path): raise web.HTTPError(404) os.unlink(path) self.set_status(204) self.finish() + + +if __name__ == "__main__": + options.parse_command_line() + start(options.port, options.root_directory, options.bucket_depth) diff --git a/demos/tcpecho/client.py b/demos/tcpecho/client.py index 51d3a8d5f2..a2ead08bc8 100755 --- a/demos/tcpecho/client.py +++ b/demos/tcpecho/client.py @@ -1,6 +1,5 @@ #!/usr/bin/env python -from __future__ import print_function from tornado.ioloop import IOLoop from tornado import gen from tornado.tcpclient import TCPClient diff --git a/demos/twitter/twitterdemo.py b/demos/twitter/twitterdemo.py index c674c65c2d..4bd3022531 100755 --- a/demos/twitter/twitterdemo.py +++ b/demos/twitter/twitterdemo.py @@ -26,22 +26,31 @@ from tornado.options import define, options, parse_command_line, parse_config_file from tornado.web import Application, RequestHandler, authenticated -define('port', default=8888, help="port to listen on") -define('config_file', default='secrets.cfg', - help='filename for additional configuration') - -define('debug', default=False, group='application', - help="run in debug mode (with automatic reloading)") +define("port", default=8888, help="port to listen on") +define( + "config_file", default="secrets.cfg", help="filename for additional configuration" +) + +define( + "debug", + default=False, + group="application", + help="run in debug mode (with automatic reloading)", +) # The following settings should probably be defined in secrets.cfg -define('twitter_consumer_key', type=str, group='application') -define('twitter_consumer_secret', type=str, group='application') -define('cookie_secret', type=str, group='application', - default='__TODO:_GENERATE_YOUR_OWN_RANDOM_VALUE__', - help="signing key for secure cookies") +define("twitter_consumer_key", type=str, group="application") +define("twitter_consumer_secret", type=str, group="application") +define( + "cookie_secret", + type=str, + group="application", + default="__TODO:_GENERATE_YOUR_OWN_RANDOM_VALUE__", + help="signing key for secure cookies", +) class BaseHandler(RequestHandler): - COOKIE_NAME = 'twitterdemo_user' + COOKIE_NAME = "twitterdemo_user" def get_current_user(self): user_json = self.get_secure_cookie(self.COOKIE_NAME) @@ -55,19 +64,19 @@ class MainHandler(BaseHandler, TwitterMixin): @gen.coroutine def get(self): timeline = yield self.twitter_request( - '/statuses/home_timeline', - access_token=self.current_user['access_token']) - self.render('home.html', timeline=timeline) + "/statuses/home_timeline", access_token=self.current_user["access_token"] + ) + self.render("home.html", timeline=timeline) class LoginHandler(BaseHandler, TwitterMixin): @gen.coroutine def get(self): - if self.get_argument('oauth_token', None): + if self.get_argument("oauth_token", None): user = yield self.get_authenticated_user() del user["description"] self.set_secure_cookie(self.COOKIE_NAME, json_encode(user)) - self.redirect(self.get_argument('next', '/')) + self.redirect(self.get_argument("next", "/")) else: yield self.authorize_redirect(callback_uri=self.request.full_url()) @@ -82,18 +91,15 @@ def main(): parse_config_file(options.config_file) app = Application( - [ - ('/', MainHandler), - ('/login', LoginHandler), - ('/logout', LogoutHandler), - ], - login_url='/login', - **options.group_dict('application')) + [("/", MainHandler), ("/login", LoginHandler), 
("/logout", LogoutHandler)], + login_url="/login", + **options.group_dict("application") + ) app.listen(options.port) - logging.info('Listening on http://localhost:%d' % options.port) + logging.info("Listening on http://localhost:%d" % options.port) IOLoop.current().start() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/demos/websocket/chatdemo.py b/demos/websocket/chatdemo.py index a3fa2a69f0..1a7a3042c3 100755 --- a/demos/websocket/chatdemo.py +++ b/demos/websocket/chatdemo.py @@ -34,17 +34,14 @@ class Application(tornado.web.Application): def __init__(self): - handlers = [ - (r"/", MainHandler), - (r"/chatsocket", ChatSocketHandler), - ] + handlers = [(r"/", MainHandler), (r"/chatsocket", ChatSocketHandler)] settings = dict( cookie_secret="__TODO:_GENERATE_YOUR_OWN_RANDOM_VALUE_HERE__", template_path=os.path.join(os.path.dirname(__file__), "templates"), static_path=os.path.join(os.path.dirname(__file__), "static"), xsrf_cookies=True, ) - super(Application, self).__init__(handlers, **settings) + super().__init__(handlers, **settings) class MainHandler(tornado.web.RequestHandler): @@ -71,7 +68,7 @@ def on_close(self): def update_cache(cls, chat): cls.cache.append(chat) if len(cls.cache) > cls.cache_size: - cls.cache = cls.cache[-cls.cache_size:] + cls.cache = cls.cache[-cls.cache_size :] @classmethod def send_updates(cls, chat): @@ -85,12 +82,10 @@ def send_updates(cls, chat): def on_message(self, message): logging.info("got message %r", message) parsed = tornado.escape.json_decode(message) - chat = { - "id": str(uuid.uuid4()), - "body": parsed["body"], - } + chat = {"id": str(uuid.uuid4()), "body": parsed["body"]} chat["html"] = tornado.escape.to_basestring( - self.render_string("message.html", message=chat)) + self.render_string("message.html", message=chat) + ) ChatSocketHandler.update_cache(chat) ChatSocketHandler.send_updates(chat) diff --git a/demos/websocket/templates/index.html b/demos/websocket/templates/index.html index 91a4536394..d022ee750d 100644 --- a/demos/websocket/templates/index.html +++ b/demos/websocket/templates/index.html @@ -16,7 +16,7 @@
- +
diff --git a/demos/webspider/webspider.py b/demos/webspider/webspider.py index dd8e6b385b..16c3840fa7 100755 --- a/demos/webspider/webspider.py +++ b/demos/webspider/webspider.py @@ -1,42 +1,29 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import time from datetime import timedelta -try: - from HTMLParser import HTMLParser - from urlparse import urljoin, urldefrag -except ImportError: - from html.parser import HTMLParser - from urllib.parse import urljoin, urldefrag +from html.parser import HTMLParser +from urllib.parse import urljoin, urldefrag -from tornado import httpclient, gen, ioloop, queues +from tornado import gen, httpclient, ioloop, queues -base_url = 'http://www.tornadoweb.org/en/stable/' +base_url = "http://www.tornadoweb.org/en/stable/" concurrency = 10 -@gen.coroutine -def get_links_from_url(url): +async def get_links_from_url(url): """Download the page at `url` and parse it for links. Returned links have had the fragment after `#` removed, and have been made absolute so, e.g. the URL 'gen.html#tornado.gen.coroutine' becomes 'http://www.tornadoweb.org/en/stable/gen.html'. """ - try: - response = yield httpclient.AsyncHTTPClient().fetch(url) - print('fetched %s' % url) + response = await httpclient.AsyncHTTPClient().fetch(url) + print("fetched %s" % url) - html = response.body if isinstance(response.body, str) \ - else response.body.decode(errors='ignore') - urls = [urljoin(url, remove_fragment(new_url)) - for new_url in get_links(html)] - except Exception as e: - print('Exception: %s %s' % (e, url)) - raise gen.Return([]) - - raise gen.Return(urls) + html = response.body.decode(errors="ignore") + return [urljoin(url, remove_fragment(new_url)) for new_url in get_links(html)] def remove_fragment(url): @@ -51,8 +38,8 @@ def __init__(self): self.urls = [] def handle_starttag(self, tag, attrs): - href = dict(attrs).get('href') - if href and tag == 'a': + href = dict(attrs).get("href") + if href and tag == "a": self.urls.append(href) url_seeker = URLSeeker() @@ -60,48 +47,52 @@ def handle_starttag(self, tag, attrs): return url_seeker.urls -@gen.coroutine -def main(): +async def main(): q = queues.Queue() start = time.time() - fetching, fetched = set(), set() - - @gen.coroutine - def fetch_url(): - current_url = yield q.get() - try: - if current_url in fetching: - return + fetching, fetched, dead = set(), set(), set() - print('fetching %s' % current_url) - fetching.add(current_url) - urls = yield get_links_from_url(current_url) - fetched.add(current_url) + async def fetch_url(current_url): + if current_url in fetching: + return - for new_url in urls: - # Only follow links beneath the base URL - if new_url.startswith(base_url): - yield q.put(new_url) + print("fetching %s" % current_url) + fetching.add(current_url) + urls = await get_links_from_url(current_url) + fetched.add(current_url) - finally: - q.task_done() + for new_url in urls: + # Only follow links beneath the base URL + if new_url.startswith(base_url): + await q.put(new_url) - @gen.coroutine - def worker(): - while True: - yield fetch_url() + async def worker(): + async for url in q: + if url is None: + return + try: + await fetch_url(url) + except Exception as e: + print("Exception: %s %s" % (e, url)) + dead.add(url) + finally: + q.task_done() - q.put(base_url) + await q.put(base_url) # Start workers, then wait for the work queue to be empty. 
+ workers = gen.multi([worker() for _ in range(concurrency)]) + await q.join(timeout=timedelta(seconds=300)) + assert fetching == (fetched | dead) + print("Done in %d seconds, fetched %s URLs." % (time.time() - start, len(fetched))) + print("Unable to fetch %s URLs." % len(dead)) + + # Signal all the workers to exit. for _ in range(concurrency): - worker() - yield q.join(timeout=timedelta(seconds=300)) - assert fetching == fetched - print('Done in %d seconds, fetched %s URLs.' % ( - time.time() - start, len(fetched))) + await q.put(None) + await workers -if __name__ == '__main__': +if __name__ == "__main__": io_loop = ioloop.IOLoop.current() io_loop.run_sync(main) diff --git a/docs/caresresolver.rst index b5d6ddd101..4e0058eac0 100644 --- a/docs/caresresolver.rst +++ b/docs/caresresolver.rst @@ -18,3 +18,7 @@ wrapper ``pycares``). so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is the default for ``tornado.simple_httpclient``, but other libraries may default to ``AF_UNSPEC``. + + .. deprecated:: 6.2 + This class is deprecated and will be removed in Tornado 7.0. Use the default + thread-based resolver instead. diff --git a/docs/concurrent.rst index 5904a60b89..f7a855a38f 100644 --- a/docs/concurrent.rst +++ b/docs/concurrent.rst @@ -11,9 +11,7 @@ .. class:: Future - ``tornado.concurrent.Future`` is an alias for `asyncio.Future` - on Python 3. On Python 2, it provides an equivalent - implementation. + ``tornado.concurrent.Future`` is an alias for `asyncio.Future`. In Tornado, the main way in which applications interact with ``Future`` objects is by ``awaiting`` or ``yielding`` them in diff --git a/docs/conf.py index 39345d282f..efa1c01d03 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,14 +1,15 @@ -# Ensure we get the local copy of tornado instead of what's on the standard path import os +import sphinx.errors import sys -import time + +# Ensure we get the local copy of tornado instead of what's on the standard path sys.path.insert(0, os.path.abspath("..")) import tornado master_doc = "index" project = "Tornado" -copyright = "2009-%s, The Tornado Authors" % time.strftime("%Y") +copyright = "The Tornado Authors" version = release = tornado.version @@ -18,10 +19,11 @@ "sphinx.ext.doctest", "sphinx.ext.intersphinx", "sphinx.ext.viewcode", + "sphinxcontrib.asyncio", ] -primary_domain = 'py' -default_role = 'py:obj' +primary_domain = "py" +default_role = "py:obj" autodoc_member_order = "bysource" autoclass_content = "both" @@ -37,22 +39,18 @@ "tornado.platform.asyncio", "tornado.platform.caresresolver", "tornado.platform.twisted", + "tornado.simple_httpclient", ] # I wish this could go in a per-module file... coverage_ignore_classes = [ # tornado.gen "Runner", - - # tornado.ioloop - "PollIOLoop", - # tornado.web "ChunkedTransferEncoding", "GZipContentEncoding", "OutputTransform", "TemplateModule", "url", - # tornado.websocket "WebSocketProtocol", "WebSocketProtocol13", @@ -63,32 +61,83 @@ # various modules "doctests", "main", - # tornado.escape # parse_qs_bytes should probably be documented but it's complicated by # having different implementations between py2 and py3.
"parse_qs_bytes", - # tornado.gen "Multi", ] -html_favicon = 'favicon.ico' +html_favicon = "favicon.ico" latex_documents = [ - ('index', 'tornado.tex', 'Tornado Documentation', 'The Tornado Authors', 'manual', False), + ( + "index", + "tornado.tex", + "Tornado Documentation", + "The Tornado Authors", + "manual", + False, + ) ] -intersphinx_mapping = { - 'python': ('https://docs.python.org/3.6/', None), -} +intersphinx_mapping = {"python": ("https://docs.python.org/3/", None)} -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' +on_rtd = os.environ.get("READTHEDOCS", None) == "True" # On RTD we can't import sphinx_rtd_theme, but it will be applied by # default anyway. This block will use the same theme when building locally # as on RTD. if not on_rtd: import sphinx_rtd_theme - html_theme = 'sphinx_rtd_theme' + + html_theme = "sphinx_rtd_theme" html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Suppress warnings about "class reference target not found" for these types. +# In most cases these types come from type annotations and are for mypy's use. +missing_references = { + # Generic type variables; nothing to link to. + "_IOStreamType", + "_S", + "_T", + # Standard library types which are defined in one module and documented + # in another. We could probably remap them to their proper location if + # there's not an upstream fix in python and/or sphinx. + "_asyncio.Future", + "_io.BytesIO", + "asyncio.AbstractEventLoop.run_forever", + "asyncio.events.AbstractEventLoop", + "concurrent.futures._base.Executor", + "concurrent.futures._base.Future", + "futures.Future", + "socket.socket", + "TextIO", + # Other stuff. I'm not sure why some of these are showing up, but + # I'm just listing everything here to avoid blocking the upgrade of sphinx. + "Future", + "httputil.HTTPServerConnectionDelegate", + "httputil.HTTPServerRequest", + "OutputTransform", + "Pattern", + "RAISE", + "Rule", + "tornado.ioloop._Selectable", + "tornado.locks._ReleasingContextManager", + "tornado.options._Mockable", + "tornado.web._ArgDefaultMarker", + "tornado.web._HandlerDelegate", + "traceback", + "WSGIAppType", + "Yieldable", +} + + +def missing_reference_handler(app, env, node, contnode): + if node["reftarget"] in missing_references: + raise sphinx.errors.NoUri + + +def setup(app): + app.connect("missing-reference", missing_reference_handler) diff --git a/docs/escape.rst b/docs/escape.rst index 54f1ca9d2d..2a03eddb38 100644 --- a/docs/escape.rst +++ b/docs/escape.rst @@ -17,19 +17,15 @@ Byte/unicode conversions ------------------------ - These functions are used extensively within Tornado itself, - but should not be directly needed by most applications. Note that - much of the complexity of these functions comes from the fact that - Tornado supports both Python 2 and Python 3. .. autofunction:: utf8 .. autofunction:: to_unicode .. function:: native_str + .. function:: to_basestring - Converts a byte or unicode string into type `str`. Equivalent to - `utf8` on Python 2 and `to_unicode` on Python 3. - - .. autofunction:: to_basestring + Converts a byte or unicode string into type `str`. These functions + were used to help transition from Python 2 to Python 3 but are now + deprecated aliases for `to_unicode`. .. autofunction:: recursive_unicode diff --git a/docs/faq.rst b/docs/faq.rst index d45173c86b..1628073a32 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -28,14 +28,13 @@ No matter what the real code is doing, to achieve concurrency blocking code must be replaced with non-blocking equivalents. 
This means one of three things: 1. *Find a coroutine-friendly equivalent.* For `time.sleep`, use - `tornado.gen.sleep` instead:: + `tornado.gen.sleep` (or `asyncio.sleep`) instead:: class CoroutineSleepHandler(RequestHandler): - @gen.coroutine - def get(self): + async def get(self): for i in range(5): print(i) - yield gen.sleep(1) + await gen.sleep(1) When this option is available, it is usually the best approach. See the `Tornado wiki `_ @@ -44,16 +43,16 @@ code must be replaced with non-blocking equivalents. This means one of three thi 2. *Find a callback-based equivalent.* Similar to the first option, callback-based libraries are available for many tasks, although they are slightly more complicated to use than a library designed for - coroutines. These are typically used with `tornado.gen.Task` as an - adapter:: + coroutines. Adapt the callback-based function into a future:: class CoroutineTimeoutHandler(RequestHandler): - @gen.coroutine - def get(self): + async def get(self): io_loop = IOLoop.current() for i in range(5): print(i) - yield gen.Task(io_loop.add_timeout, io_loop.time() + 1) + f = tornado.concurrent.Future() + do_something_with_callback(f.set_result) + result = await f Again, the `Tornado wiki `_ @@ -65,21 +64,18 @@ code must be replaced with non-blocking equivalents. This means one of three thi that can be used for any blocking function whether an asynchronous counterpart exists or not:: - executor = concurrent.futures.ThreadPoolExecutor(8) - class ThreadPoolHandler(RequestHandler): - @gen.coroutine - def get(self): + async def get(self): for i in range(5): print(i) - yield executor.submit(time.sleep, 1) + await IOLoop.current().run_in_executor(None, time.sleep, 1) See the :doc:`Asynchronous I/O ` chapter of the Tornado user's guide for more on blocking and asynchronous functions. -My code is asynchronous, but it's not running in parallel in two browser tabs. ------------------------------------------------------------------------------- +My code is asynchronous. Why is it not running in parallel in two browser tabs? +------------------------------------------------------------------------------- Even when a handler is asynchronous and non-blocking, it can be surprisingly tricky to verify this. Browsers will recognize that you are trying to diff --git a/docs/gen.rst b/docs/gen.rst index 0f0bfdeefb..4cb5a4f434 100644 --- a/docs/gen.rst +++ b/docs/gen.rst @@ -13,26 +13,21 @@ .. autofunction:: coroutine - .. autofunction:: engine + .. autoexception:: Return Utility functions ----------------- - .. autoexception:: Return - - .. autofunction:: with_timeout + .. autofunction:: with_timeout(timeout: Union[float, datetime.timedelta], future: Yieldable, quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ()) .. autofunction:: sleep - .. autodata:: moment - :annotation: - .. autoclass:: WaitIterator :members: - .. autofunction:: multi + .. autofunction:: multi(Union[List[Yieldable], Dict[Any, Yieldable]], quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ()) - .. autofunction:: multi_future + .. autofunction:: multi_future(Union[List[Yieldable], Dict[Any, Yieldable]], quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ()) .. autofunction:: convert_yielded @@ -40,37 +35,5 @@ .. autofunction:: is_coroutine_function - Legacy interface - ---------------- - - Before support for `Futures <.Future>` was introduced in Tornado 3.0, - coroutines used subclasses of `YieldPoint` in their ``yield`` expressions. 
- These classes are still supported but should generally not be used - except for compatibility with older interfaces. None of these classes - are compatible with native (``await``-based) coroutines. - - .. autoclass:: YieldPoint - :members: - - .. autoclass:: Callback - - .. autoclass:: Wait - - .. autoclass:: WaitAll - - .. autoclass:: MultiYieldPoint - - .. autofunction:: Task - - .. class:: Arguments - - The result of a `Task` or `Wait` whose callback had more than one - argument (or keyword arguments). - - The `Arguments` object is a `collections.namedtuple` and can be - used either as a tuple ``(args, kwargs)`` or an object with attributes - ``args`` and ``kwargs``. - - .. deprecated:: 5.1 - - This class will be removed in 6.0. + .. autodata:: moment + :annotation: diff --git a/docs/guide/async.rst b/docs/guide/async.rst index 60f8a23b34..1b5526ffc5 100644 --- a/docs/guide/async.rst +++ b/docs/guide/async.rst @@ -53,6 +53,12 @@ transparent to its callers (systems like `gevent comparable to asynchronous systems, but they do not actually make things asynchronous). +Asynchronous operations in Tornado generally return placeholder +objects (``Futures``), with the exception of some low-level components +like the `.IOLoop` that use callbacks. ``Futures`` are usually +transformed into their result with the ``await`` or ``yield`` +keywords. + Examples ~~~~~~~~ @@ -70,65 +76,59 @@ Here is a sample synchronous function: .. testoutput:: :hide: -And here is the same function rewritten to be asynchronous with a -callback argument: +And here is the same function rewritten asynchronously as a native coroutine: .. testcode:: - from tornado.httpclient import AsyncHTTPClient + from tornado.httpclient import AsyncHTTPClient - def asynchronous_fetch(url, callback): - http_client = AsyncHTTPClient() - def handle_response(response): - callback(response.body) - http_client.fetch(url, callback=handle_response) + async def asynchronous_fetch(url): + http_client = AsyncHTTPClient() + response = await http_client.fetch(url) + return response.body .. testoutput:: :hide: -And again with a `.Future` instead of a callback: +Or for compatibility with older versions of Python, using the `tornado.gen` module: -.. testcode:: +.. testcode:: - from tornado.concurrent import Future + from tornado.httpclient import AsyncHTTPClient + from tornado import gen - def async_fetch_future(url): + @gen.coroutine + def async_fetch_gen(url): http_client = AsyncHTTPClient() - my_future = Future() - fetch_future = http_client.fetch(url) - fetch_future.add_done_callback( - lambda f: my_future.set_result(f.result())) - return my_future - -.. testoutput:: - :hide: + response = yield http_client.fetch(url) + raise gen.Return(response.body) -The raw `.Future` version is more complex, but ``Futures`` are -nonetheless recommended practice in Tornado because they have two -major advantages. Error handling is more consistent since the -``Future.result`` method can simply raise an exception (as opposed to -the ad-hoc error handling common in callback-oriented interfaces), and -``Futures`` lend themselves well to use with coroutines. Coroutines -will be discussed in depth in the next section of this guide. Here is -the coroutine version of our sample function, which is very similar to -the original synchronous version: +Coroutines are a little magical, but what they do internally is something like this: .. 
testcode:: - from tornado import gen + from tornado.concurrent import Future - @gen.coroutine - def fetch_coroutine(url): + def async_fetch_manual(url): http_client = AsyncHTTPClient() - response = yield http_client.fetch(url) - raise gen.Return(response.body) + my_future = Future() + fetch_future = http_client.fetch(url) + def on_fetch(f): + my_future.set_result(f.result().body) + fetch_future.add_done_callback(on_fetch) + return my_future .. testoutput:: :hide: -The statement ``raise gen.Return(response.body)`` is an artifact of -Python 2, in which generators aren't allowed to return -values. To overcome this, Tornado coroutines raise a special kind of -exception called a `.Return`. The coroutine catches this exception and -treats it like a returned value. In Python 3.3 and later, a ``return -response.body`` achieves the same result. +Notice that the coroutine returns its `.Future` before the fetch is +done. This is what makes coroutines *asynchronous*. + +Anything you can do with coroutines you can also do by passing +callback objects around, but coroutines provide an important +simplification by letting you organize your code in the same way you +would if it were synchronous. This is especially important for error +handling, since ``try``/``except`` blocks work as you would expect in +coroutines while this is difficult to achieve with callbacks. +Coroutines will be discussed in depth in the next section of this +guide. diff --git a/docs/guide/coroutines.rst b/docs/guide/coroutines.rst index 8089c4f886..795631bf74 100644 --- a/docs/guide/coroutines.rst +++ b/docs/guide/coroutines.rst @@ -6,9 +6,9 @@ Coroutines from tornado import gen **Coroutines** are the recommended way to write asynchronous code in -Tornado. Coroutines use the Python ``yield`` keyword to suspend and -resume execution instead of a chain of callbacks (cooperative -lightweight threads as seen in frameworks like `gevent +Tornado. Coroutines use the Python ``await`` or ``yield`` keyword to +suspend and resume execution instead of a chain of callbacks +(cooperative lightweight threads as seen in frameworks like `gevent `_ are sometimes called coroutines as well, but in Tornado all coroutines use explicit context switches and are called as asynchronous functions). @@ -21,53 +21,80 @@ happen. Example:: - from tornado import gen - - @gen.coroutine - def fetch_coroutine(url): + async def fetch_coroutine(url): http_client = AsyncHTTPClient() - response = yield http_client.fetch(url) - # In Python versions prior to 3.3, returning a value from - # a generator is not allowed and you must use - # raise gen.Return(response.body) - # instead. + response = await http_client.fetch(url) return response.body .. _native_coroutines: -Python 3.5: ``async`` and ``await`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Native vs decorated coroutines +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Python 3.5 introduces the ``async`` and ``await`` keywords (functions -using these keywords are also called "native coroutines"). Starting in -Tornado 4.3, you can use them in place of most ``yield``-based -coroutines (see the following paragraphs for limitations). Simply use -``async def foo()`` in place of a function definition with the -``@gen.coroutine`` decorator, and ``await`` in place of yield. 
The -rest of this document still uses the ``yield`` style for compatibility -with older versions of Python, but ``async`` and ``await`` will run -faster when they are available:: +Python 3.5 introduced the ``async`` and ``await`` keywords (functions +using these keywords are also called "native coroutines"). For +compatibility with older versions of Python, you can use "decorated" +or "yield-based" coroutines using the `tornado.gen.coroutine` +decorator. - async def fetch_coroutine(url): - http_client = AsyncHTTPClient() - response = await http_client.fetch(url) - return response.body +Native coroutines are the recommended form whenever possible. Only use +decorated coroutines when compatibility with older versions of Python +is required. Examples in the Tornado documentation will generally use +the native form. -The ``await`` keyword is less versatile than the ``yield`` keyword. -For example, in a ``yield``-based coroutine you can yield a list of -``Futures``, while in a native coroutine you must wrap the list in -`tornado.gen.multi`. This also eliminates the integration with -`concurrent.futures`. You can use `tornado.gen.convert_yielded` -to convert anything that would work with ``yield`` into a form that -will work with ``await``:: +Translation between the two forms is generally straightforward:: - async def f(): - executor = concurrent.futures.ThreadPoolExecutor() - await tornado.gen.convert_yielded(executor.submit(g)) + # Decorated: # Native: + + # Normal function declaration + # with decorator # "async def" keywords + @gen.coroutine + def a(): async def a(): + # "yield" all async funcs # "await" all async funcs + b = yield c() b = await c() + # "return" and "yield" + # cannot be mixed in + # Python 2, so raise a + # special exception. # Return normally + raise gen.Return(b) return b + +Other differences between the two forms of coroutine are outlined below. + +- Native coroutines: + + - are generally faster. + - can use ``async for`` and ``async with`` + statements which make some patterns much simpler. + - do not run at all unless you ``await`` or + ``yield`` them. Decorated coroutines can start running "in the + background" as soon as they are called. Note that for both kinds of + coroutines it is important to use ``await`` or ``yield`` so that + any exceptions have somewhere to go. + +- Decorated coroutines: + + - have additional integration with the + `concurrent.futures` package, allowing the result of + ``executor.submit`` to be yielded directly. For native coroutines, + use `.IOLoop.run_in_executor` instead. + - support some shorthand for waiting on multiple + objects by yielding a list or dict. Use `tornado.gen.multi` to do + this in native coroutines. + - can support integration with other packages + including Twisted via a registry of conversion functions. + To access this functionality in native coroutines, use + `tornado.gen.convert_yielded`. + - always return a `.Future` object. Native + coroutines return an *awaitable* object that is not a `.Future`. In + Tornado the two are mostly interchangeable. How it works ~~~~~~~~~~~~ +This section explains the operation of decorated coroutines. Native +coroutines are conceptually similar, but a little more complicated +because of the extra integration with the Python runtime. + A function containing ``yield`` is a **generator**. All generators are asynchronous; when called they return a generator object instead of running to completion. 
The ``@gen.coroutine`` decorator
@@ -97,12 +124,11 @@ How to call a coroutine
 ~~~~~~~~~~~~~~~~~~~~~~~

 Coroutines do not raise exceptions in the normal way: any exception
-they raise will be trapped in the `.Future` until it is yielded. This
-means it is important to call coroutines in the right way, or you may
-have errors that go unnoticed::
+they raise will be trapped in the awaitable object until it is
+yielded. This means it is important to call coroutines in the right
+way, or you may have errors that go unnoticed::

-    @gen.coroutine
-    def divide(x, y):
+    async def divide(x, y):
         return x / y

     def bad_call():
@@ -111,17 +137,16 @@ have errors that go unnoticed::
         divide(1, 0)

 In nearly all cases, any function that calls a coroutine must be a
-coroutine itself, and use the ``yield`` keyword in the call. When you
-are overriding a method defined in a superclass, consult the
-documentation to see if coroutines are allowed (the documentation
-should say that the method "may be a coroutine" or "may return a
-`.Future`")::
-
-    @gen.coroutine
-    def good_call():
-        # yield will unwrap the Future returned by divide() and raise
+coroutine itself, and use the ``await`` or ``yield`` keyword in the
+call. When you are overriding a method defined in a superclass,
+consult the documentation to see if coroutines are allowed (the
+documentation should say that the method "may be a coroutine" or "may
+return a `.Future`")::
+
+    async def good_call():
+        # await will unwrap the object returned by divide() and raise
         # the exception.
-        yield divide(1, 0)
+        await divide(1, 0)

 Sometimes you may want to "fire and forget" a coroutine without waiting
 for its result. In this case it is recommended to use `.IOLoop.spawn_callback`,
@@ -156,49 +181,68 @@ The simplest way to call a blocking function from a coroutine is to
 use `.IOLoop.run_in_executor`, which returns ``Futures`` that are
 compatible with coroutines::

-    @gen.coroutine
-    def call_blocking():
-        yield IOLoop.current().run_in_executor(blocking_func, args)
+    async def call_blocking():
+        await IOLoop.current().run_in_executor(None, blocking_func, args)

 Parallelism
 ^^^^^^^^^^^

-The `.coroutine` decorator recognizes lists and dicts whose values are
+The `.multi` function accepts lists and dicts whose values are
 ``Futures``, and waits for all of those ``Futures`` in parallel:

 .. testcode::

-    @gen.coroutine
-    def parallel_fetch(url1, url2):
-        resp1, resp2 = yield [http_client.fetch(url1),
-                              http_client.fetch(url2)]
+    from tornado.gen import multi

-    @gen.coroutine
-    def parallel_fetch_many(urls):
-        responses = yield [http_client.fetch(url) for url in urls]
+    async def parallel_fetch(url1, url2):
+        resp1, resp2 = await multi([http_client.fetch(url1),
+                                    http_client.fetch(url2)])
+
+    async def parallel_fetch_many(urls):
+        responses = await multi([http_client.fetch(url) for url in urls])
         # responses is a list of HTTPResponses in the same order

-    @gen.coroutine
-    def parallel_fetch_dict(urls):
-        responses = yield {url: http_client.fetch(url)
-                           for url in urls}
+    async def parallel_fetch_dict(urls):
+        responses = await multi({url: http_client.fetch(url)
+                                 for url in urls})
         # responses is a dict {url: HTTPResponse}

..
testoutput:: :hide: -Lists and dicts must be wrapped in `tornado.gen.multi` for use with -``await``:: +In decorated coroutines, it is possible to ``yield`` the list or dict directly:: - async def parallel_fetch(url1, url2): - resp1, resp2 = await gen.multi([http_client.fetch(url1), - http_client.fetch(url2)]) + @gen.coroutine + def parallel_fetch_decorated(url1, url2): + resp1, resp2 = yield [http_client.fetch(url1), + http_client.fetch(url2)] Interleaving ^^^^^^^^^^^^ Sometimes it is useful to save a `.Future` instead of yielding it -immediately, so you can start another operation before waiting: +immediately, so you can start another operation before waiting. + +.. testcode:: + + from tornado.gen import convert_yielded + + async def get(self): + # convert_yielded() starts the native coroutine in the background. + # This is equivalent to asyncio.ensure_future() (both work in Tornado). + fetch_future = convert_yielded(self.fetch_next_chunk()) + while True: + chunk = await fetch_future + if chunk is None: break + self.write(chunk) + fetch_future = convert_yielded(self.fetch_next_chunk()) + await self.flush() + +.. testoutput:: + :hide: + +This is a little easier to do with decorated coroutines, because they +start immediately when called: .. testcode:: @@ -215,12 +259,6 @@ immediately, so you can start another operation before waiting: .. testoutput:: :hide: -This pattern is most usable with ``@gen.coroutine``. If -``fetch_next_chunk()`` uses ``async def``, then it must be called as -``fetch_future = -tornado.gen.convert_yielded(self.fetch_next_chunk())`` to start the -background processing. - Looping ^^^^^^^ @@ -247,11 +285,10 @@ Running in the background coroutine can contain a ``while True:`` loop and use `tornado.gen.sleep`:: - @gen.coroutine - def minute_loop(): + async def minute_loop(): while True: - yield do_something() - yield gen.sleep(60) + await do_something() + await gen.sleep(60) # Coroutines that loop forever are generally started with # spawn_callback(). @@ -262,9 +299,8 @@ previous loop runs every ``60+N`` seconds, where ``N`` is the running time of ``do_something()``. To run exactly every 60 seconds, use the interleaving pattern from above:: - @gen.coroutine - def minute_loop2(): + async def minute_loop2(): while True: nxt = gen.sleep(60) # Start the clock. - yield do_something() # Run while the clock is ticking. - yield nxt # Wait for the timer to run out. + await do_something() # Run while the clock is ticking. + await nxt # Wait for the timer to run out. diff --git a/docs/guide/intro.rst b/docs/guide/intro.rst index d17e74420f..8d87ba62b2 100644 --- a/docs/guide/intro.rst +++ b/docs/guide/intro.rst @@ -20,12 +20,13 @@ Tornado can be roughly divided into four major components: components and can also be used to implement other protocols. * A coroutine library (`tornado.gen`) which allows asynchronous code to be written in a more straightforward way than chaining - callbacks. + callbacks. This is similar to the native coroutine feature introduced + in Python 3.5 (``async def``). Native coroutines are recommended + in place of the `tornado.gen` module when available. The Tornado web framework and HTTP server together offer a full-stack alternative to `WSGI `_. 
-While it is possible to use the Tornado web framework in a WSGI -container (`.WSGIAdapter`), or use the Tornado HTTP server as a -container for other WSGI frameworks (`.WSGIContainer`), each of these -combinations has limitations and to take full advantage of Tornado you -will need to use the Tornado's web framework and HTTP server together. +While it is possible to use the Tornado HTTP server as a container for +other WSGI frameworks (`.WSGIContainer`), this combination has +limitations and to take full advantage of Tornado you will need to use +Tornado's web framework and HTTP server together. diff --git a/docs/guide/running.rst b/docs/guide/running.rst index e7c5b3494a..8cf34f0502 100644 --- a/docs/guide/running.rst +++ b/docs/guide/running.rst @@ -22,8 +22,9 @@ configuring a WSGI container to find your application, you write a Configure your operating system or process manager to run this program to start the server. Please note that it may be necessary to increase the number of open files per process (to avoid "Too many open files"-Error). -To raise this limit (setting it to 50000 for example) you can use the ulimit command, -modify /etc/security/limits.conf or setting ``minfds`` in your supervisord config. +To raise this limit (setting it to 50000 for example) you can use the +``ulimit`` command, modify ``/etc/security/limits.conf`` or set +``minfds`` in your `supervisord `_ config. Processes and ports ~~~~~~~~~~~~~~~~~~~ @@ -33,8 +34,9 @@ multiple Python processes to take full advantage of multi-CPU machines. Typically it is best to run one process per CPU. Tornado includes a built-in multi-process mode to start several -processes at once. This requires a slight alteration to the standard -main function: +processes at once (note that multi-process mode does not work on +Windows). This requires a slight alteration to the standard main +function: .. testcode:: @@ -50,8 +52,8 @@ main function: This is the easiest way to start multiple processes and have them all share the same port, although it has some limitations. First, each -child process will have its own IOLoop, so it is important that -nothing touch the global IOLoop instance (even indirectly) before the +child process will have its own ``IOLoop``, so it is important that +nothing touches the global ``IOLoop`` instance (even indirectly) before the fork. Second, it is difficult to do zero-downtime updates in this model. Finally, since all the processes share the same port it is more difficult to monitor them individually. @@ -67,10 +69,10 @@ to present a single address to outside visitors. Running behind a load balancer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -When running behind a load balancer like nginx, it is recommended to -pass ``xheaders=True`` to the `.HTTPServer` constructor. This will tell -Tornado to use headers like ``X-Real-IP`` to get the user's IP address -instead of attributing all traffic to the balancer's IP address. +When running behind a load balancer like `nginx `_, +it is recommended to pass ``xheaders=True`` to the `.HTTPServer` constructor. +This will tell Tornado to use headers like ``X-Real-IP`` to get the user's +IP address instead of attributing all traffic to the balancer's IP address. This is a barebones nginx config file that is structurally similar to the one we use at FriendFeed. 
It assumes nginx and the Tornado servers @@ -169,7 +171,7 @@ You can serve static files from Tornado by specifying the ], **settings) This setting will automatically make all requests that start with -``/static/`` serve from that static directory, e.g., +``/static/`` serve from that static directory, e.g. ``http://localhost:8888/static/foo.png`` will serve the file ``foo.png`` from the specified static directory. We also automatically serve ``/robots.txt`` and ``/favicon.ico`` from the static directory @@ -248,7 +250,7 @@ individual flag takes precedence): server down in a way that debug mode cannot currently recover from. * ``compiled_template_cache=False``: Templates will not be cached. * ``static_hash_cache=False``: Static file hashes (used by the - ``static_url`` function) will not be cached + ``static_url`` function) will not be cached. * ``serve_traceback=True``: When an exception in a `.RequestHandler` is not caught, an error page including a stack trace will be generated. @@ -273,41 +275,3 @@ On some platforms (including Windows and Mac OSX prior to 10.6), the process cannot be updated "in-place", so when a code change is detected the old server exits and a new one starts. This has been known to confuse some IDEs. - - -WSGI and Google App Engine -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Tornado is normally intended to be run on its own, without a WSGI -container. However, in some environments (such as Google App Engine), -only WSGI is allowed and applications cannot run their own servers. -In this case Tornado supports a limited mode of operation that does -not support asynchronous operation but allows a subset of Tornado's -functionality in a WSGI-only environment. The features that are -not allowed in WSGI mode include coroutines, the ``@asynchronous`` -decorator, `.AsyncHTTPClient`, the ``auth`` module, and WebSockets. - -You can convert a Tornado `.Application` to a WSGI application -with `tornado.wsgi.WSGIAdapter`. In this example, configure -your WSGI container to find the ``application`` object: - -.. testcode:: - - import tornado.web - import tornado.wsgi - - class MainHandler(tornado.web.RequestHandler): - def get(self): - self.write("Hello, world") - - tornado_app = tornado.web.Application([ - (r"/", MainHandler), - ]) - application = tornado.wsgi.WSGIAdapter(tornado_app) - -.. testoutput:: - :hide: - -See the `appengine example application -`_ for a -full-featured AppEngine app built on Tornado. diff --git a/docs/guide/security.rst b/docs/guide/security.rst index f71f323aed..b65cd3f370 100644 --- a/docs/guide/security.rst +++ b/docs/guide/security.rst @@ -169,7 +169,7 @@ not be appropriate for non-browser-based login schemes. Check out the `Tornado Blog example application `_ for a complete example that uses authentication (and stores user data in a -MySQL database). +PostgreSQL database). Third party authentication ~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -189,15 +189,14 @@ the Google credentials in a cookie for later access: class GoogleOAuth2LoginHandler(tornado.web.RequestHandler, tornado.auth.GoogleOAuth2Mixin): - @tornado.gen.coroutine - def get(self): + async def get(self): if self.get_argument('code', False): - user = yield self.get_authenticated_user( + user = await self.get_authenticated_user( redirect_uri='http://your.site.com/auth/google', code=self.get_argument('code')) # Save the user with e.g. 
set_secure_cookie else: - yield self.authorize_redirect( + await self.authorize_redirect( redirect_uri='http://your.site.com/auth/google', client_id=self.settings['google_oauth']['key'], scope=['profile', 'email'], @@ -280,7 +279,7 @@ all requests:: For ``PUT`` and ``DELETE`` requests (as well as ``POST`` requests that do not use form-encoded arguments), the XSRF token may also be passed via an HTTP header named ``X-XSRFToken``. The XSRF cookie is normally -set when ``xsrf_form_html`` is used, but in a pure-Javascript application +set when ``xsrf_form_html`` is used, but in a pure-JavaScript application that does not use any regular forms you may need to access ``self.xsrf_token`` manually (just reading the property is enough to set the cookie as a side effect). diff --git a/docs/guide/structure.rst b/docs/guide/structure.rst index 0f20d33864..407edf414b 100644 --- a/docs/guide/structure.rst +++ b/docs/guide/structure.rst @@ -178,7 +178,7 @@ In addition to ``get()``/``post()``/etc, certain other methods in necessary. On every request, the following sequence of calls takes place: -1. A new `.RequestHandler` object is created on each request +1. A new `.RequestHandler` object is created on each request. 2. `~.RequestHandler.initialize()` is called with the initialization arguments from the `.Application` configuration. ``initialize`` should typically just save the arguments passed into member @@ -193,9 +193,8 @@ place: etc. If the URL regular expression contains capturing groups, they are passed as arguments to this method. 5. When the request is finished, `~.RequestHandler.on_finish()` is - called. For synchronous handlers this is immediately after - ``get()`` (etc) return; for asynchronous handlers it is after the - call to `~.RequestHandler.finish()`. + called. This is generally after ``get()`` or another HTTP method + returns. All methods designed to be overridden are noted as such in the `.RequestHandler` documentation. Some of the most commonly @@ -207,12 +206,12 @@ overridden methods include: disconnects; applications may choose to detect this case and halt further processing. Note that there is no guarantee that a closed connection can be detected promptly. -- `~.RequestHandler.get_current_user` - see :ref:`user-authentication` +- `~.RequestHandler.get_current_user` - see :ref:`user-authentication`. - `~.RequestHandler.get_user_locale` - returns `.Locale` object to use - for the current user + for the current user. - `~.RequestHandler.set_default_headers` - may be used to set additional headers on the response (such as a custom ``Server`` - header) + header). Error Handling ~~~~~~~~~~~~~~ @@ -261,7 +260,7 @@ redirect users elsewhere. There is also an optional parameter considered permanent. The default value of ``permanent`` is ``False``, which generates a ``302 Found`` HTTP response code and is appropriate for things like redirecting users after successful -``POST`` requests. If ``permanent`` is true, the ``301 Moved +``POST`` requests. If ``permanent`` is ``True``, the ``301 Moved Permanently`` HTTP response code is used, which is useful for e.g. redirecting to a canonical URL for a page in an SEO-friendly manner. @@ -295,65 +294,18 @@ To send a temporary redirect with a `.RedirectHandler`, add Asynchronous handlers ~~~~~~~~~~~~~~~~~~~~~ -Tornado handlers are synchronous by default: when the -``get()``/``post()`` method returns, the request is considered -finished and the response is sent. 
Since all other requests are -blocked while one handler is running, any long-running handler should -be made asynchronous so it can call its slow operations in a -non-blocking way. This topic is covered in more detail in -:doc:`async`; this section is about the particulars of -asynchronous techniques in `.RequestHandler` subclasses. - -The simplest way to make a handler asynchronous is to use the -`.coroutine` decorator or ``async def``. This allows you to perform -non-blocking I/O with the ``yield`` or ``await`` keywords, and no -response will be sent until the coroutine has returned. See -:doc:`coroutines` for more details. - -In some cases, coroutines may be less convenient than a -callback-oriented style, in which case the `.tornado.web.asynchronous` -decorator can be used instead. When this decorator is used the response -is not automatically sent; instead the request will be kept open until -some callback calls `.RequestHandler.finish`. It is up to the application -to ensure that this method is called, or else the user's browser will -simply hang. - -Here is an example that makes a call to the FriendFeed API using -Tornado's built-in `.AsyncHTTPClient`: +Certain handler methods (including ``prepare()`` and the HTTP verb +methods ``get()``/``post()``/etc) may be overridden as coroutines to +make the handler asynchronous. -.. testcode:: - - class MainHandler(tornado.web.RequestHandler): - @tornado.web.asynchronous - def get(self): - http = tornado.httpclient.AsyncHTTPClient() - http.fetch("http://friendfeed-api.com/v2/feed/bret", - callback=self.on_response) - - def on_response(self, response): - if response.error: raise tornado.web.HTTPError(500) - json = tornado.escape.json_decode(response.body) - self.write("Fetched " + str(len(json["entries"])) + " entries " - "from the FriendFeed API") - self.finish() - -.. testoutput:: - :hide: - -When ``get()`` returns, the request has not finished. When the HTTP -client eventually calls ``on_response()``, the request is still open, -and the response is finally flushed to the client with the call to -``self.finish()``. - -For comparison, here is the same example using a coroutine: +For example, here is a simple handler using a coroutine: .. testcode:: class MainHandler(tornado.web.RequestHandler): - @tornado.gen.coroutine - def get(self): + async def get(self): http = tornado.httpclient.AsyncHTTPClient() - response = yield http.fetch("http://friendfeed-api.com/v2/feed/bret") + response = await http.fetch("http://friendfeed-api.com/v2/feed/bret") json = tornado.escape.json_decode(response.body) self.write("Fetched " + str(len(json["entries"])) + " entries " "from the FriendFeed API") diff --git a/docs/guide/templates.rst b/docs/guide/templates.rst index 8755d25e13..61ce753e6a 100644 --- a/docs/guide/templates.rst +++ b/docs/guide/templates.rst @@ -66,9 +66,9 @@ directory as your Python file, you could render this template with: :hide: Tornado templates support *control statements* and *expressions*. -Control statements are surrounded by ``{%`` and ``%}``, e.g., +Control statements are surrounded by ``{%`` and ``%}``, e.g. ``{% if len(items) > 2 %}``. Expressions are surrounded by ``{{`` and -``}}``, e.g., ``{{ items[0] }}``. +``}}``, e.g. ``{{ items[0] }}``. Control statements more or less map exactly to Python statements. We support ``if``, ``for``, ``while``, and ``try``, all of which are @@ -78,7 +78,7 @@ detail in the documentation for the `tornado.template`. Expressions can be any Python expression, including function calls. 
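As a quick illustration of the syntax just described (a minimal sketch using
the `tornado.template` module directly; the template string and the ``items``
argument are invented example inputs)::

    from tornado import template

    # {% for %} ... {% end %} is a control statement; {{ ... }} is an
    # expression, which may be any Python expression, including calls.
    t = template.Template("{% for item in items %}{{ item.upper() }} {% end %}")
    print(t.generate(items=["one", "two"]))  # b'ONE TWO '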
Template code is executed in a namespace that includes the following -objects and functions (Note that this list applies to templates +objects and functions. (Note that this list applies to templates rendered using `.RequestHandler.render` and `~.RequestHandler.render_string`. If you're using the `tornado.template` module directly outside of a `.RequestHandler` many @@ -132,11 +132,12 @@ instead of ``None``. Note that while Tornado's automatic escaping is helpful in avoiding XSS vulnerabilities, it is not sufficient in all cases. Expressions -that appear in certain locations, such as in Javascript or CSS, may need +that appear in certain locations, such as in JavaScript or CSS, may need additional escaping. Additionally, either care must be taken to always use double quotes and `.xhtml_escape` in HTML attributes that may contain untrusted content, or a separate escaping function must be used for -attributes (see e.g. http://wonko.com/post/html-escaping) +attributes (see e.g. +`this blog post `_). Internationalization ~~~~~~~~~~~~~~~~~~~~ @@ -212,7 +213,7 @@ formats: the ``.mo`` format used by `gettext` and related tools, and a simple ``.csv`` format. An application will generally call either `tornado.locale.load_translations` or `tornado.locale.load_gettext_translations` once at startup; see those -methods for more details on the supported formats.. +methods for more details on the supported formats. You can get the list of supported locales in your application with `tornado.locale.get_supported_locales()`. The user's locale is chosen @@ -234,7 +235,7 @@ packaged with their own CSS and JavaScript. For example, if you are implementing a blog, and you want to have blog entries appear on both the blog home page and on each blog entry page, you can make an ``Entry`` module to render them on both pages. First, -create a Python module for your UI modules, e.g., ``uimodules.py``:: +create a Python module for your UI modules, e.g. ``uimodules.py``:: class Entry(tornado.web.UIModule): def render(self, entry, show_comments=False): diff --git a/docs/httpclient.rst b/docs/httpclient.rst index cf3bc8e715..178dc1480a 100644 --- a/docs/httpclient.rst +++ b/docs/httpclient.rst @@ -47,7 +47,9 @@ Implementations ~~~~~~~~~~~~~~~ .. automodule:: tornado.simple_httpclient - :members: + + .. autoclass:: SimpleAsyncHTTPClient + :members: .. module:: tornado.curl_httpclient diff --git a/docs/httpserver.rst b/docs/httpserver.rst index 88c74376bd..74d411ddd0 100644 --- a/docs/httpserver.rst +++ b/docs/httpserver.rst @@ -5,5 +5,8 @@ HTTP Server ----------- - .. autoclass:: HTTPServer + .. autoclass:: HTTPServer(request_callback: Union[httputil.HTTPServerConnectionDelegate, Callable[[httputil.HTTPServerRequest], None]], no_keep_alive: bool = False, xheaders: bool = False, ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None, protocol: Optional[str] = None, decompress_request: bool = False, chunk_size: Optional[int] = None, max_header_size: Optional[int] = None, idle_connection_timeout: Optional[float] = None, body_timeout: Optional[float] = None, max_body_size: Optional[int] = None, max_buffer_size: Optional[int] = None, trusted_downstream: Optional[List[str]] = None) :members: + + The public interface of this class is mostly inherited from + `.TCPServer` and is documented under that class. diff --git a/docs/index.rst b/docs/index.rst index 0892e92c22..6f59bd7085 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,21 +9,21 @@ .. 
|Tornado Web Server| image:: tornado.png :alt: Tornado Web Server -`Tornado `_ is a Python web framework and +`Tornado `_ is a Python web framework and asynchronous networking library, originally developed at `FriendFeed -`_. By using non-blocking network I/O, Tornado +`_. By using non-blocking network I/O, Tornado can scale to tens of thousands of open connections, making it ideal for -`long polling `_, -`WebSockets `_, and other +`long polling `_, +`WebSockets `_, and other applications that require a long-lived connection to each user. Quick links ----------- * Current version: |version| (`download from PyPI `_, :doc:`release notes `) -* `Source (github) `_ -* Mailing lists: `discussion `_ and `announcements `_ -* `Stack Overflow `_ +* `Source (GitHub) `_ +* Mailing lists: `discussion `_ and `announcements `_ +* `Stack Overflow `_ * `Wiki `_ Hello, world @@ -73,6 +73,14 @@ that the function passed to ``run_in_executor`` should avoid referencing any Tornado objects. ``run_in_executor`` is the recommended way to interact with blocking code. +``asyncio`` Integration +----------------------- + +Tornado is integrated with the standard library `asyncio` module and +shares the same event loop (by default since Tornado 5.0). In general, +libraries designed for use with `asyncio` can be mixed freely with +Tornado. + Installation ------------ @@ -81,42 +89,36 @@ Installation pip install tornado -Tornado is listed in `PyPI `_ and +Tornado is listed in `PyPI `_ and can be installed with ``pip``. Note that the source distribution includes demo applications that are not present when Tornado is installed in this way, so you may wish to download a copy of the source tarball or clone the `git repository `_ as well. -**Prerequisites**: Tornado runs on Python 2.7, and 3.4+. -The updates to the `ssl` module in Python 2.7.9 are required -(in some distributions, these updates may be available in -older python versions). In addition to the requirements -which will be installed automatically by ``pip`` or ``setup.py install``, -the following optional packages may be useful: +**Prerequisites**: Tornado 6.0 requires Python 3.6 or newer (See +`Tornado 5.1 `_ if +compatibility with Python 2.7 is required). The following optional +packages may be useful: -* `pycurl `_ is used by the optional +* `pycurl `_ is used by the optional ``tornado.curl_httpclient``. Libcurl version 7.22 or higher is required. -* `Twisted `_ may be used with the classes in +* `Twisted `_ may be used with the classes in `tornado.platform.twisted`. -* `pycares `_ is an alternative +* `pycares `_ is an alternative non-blocking DNS resolver that can be used when threads are not appropriate. -* `monotonic `_ or `Monotime - `_ add support for a - monotonic clock, which improves reliability in environments where - clock adjustments are frequent. No longer needed in Python 3. - -**Platforms**: Tornado should run on any Unix-like platform, although -for the best performance and scalability only Linux (with ``epoll``) -and BSD (with ``kqueue``) are recommended for production deployment -(even though Mac OS X is derived from BSD and supports kqueue, its -networking performance is generally poor so it is recommended only for -development use). Tornado will also run on Windows, although this -configuration is not officially supported and is recommended only for -development use. 
Without reworking Tornado IOLoop interface, it's not -possible to add a native Tornado Windows IOLoop implementation or -leverage Windows' IOCP support from frameworks like AsyncIO or Twisted. + +**Platforms**: Tornado is designed for Unix-like platforms, with best +performance and scalability on systems supporting ``epoll`` (Linux), +``kqueue`` (BSD/macOS), or ``/dev/poll`` (Solaris). + +Tornado will also run on Windows, although this configuration is not +officially supported or recommended for production use. Some features +are missing on Windows (including multi-process mode) and scalability +is limited (Even though Tornado is built on ``asyncio``, which +supports Windows, Tornado does not use the APIs that are necessary for +scalable networking on Windows). Documentation ------------- @@ -145,17 +147,17 @@ Discussion and support ---------------------- You can discuss Tornado on `the Tornado developer mailing list -`_, and report bugs on +`_, and report bugs on the `GitHub issue tracker `_. Links to additional resources can be found on the `Tornado wiki `_. New releases are announced on the `announcements mailing list -`_. +`_. Tornado is available under the `Apache License, Version 2.0 `_. This web site and all documentation is licensed under `Creative -Commons 3.0 `_. +Commons 3.0 `_. diff --git a/docs/ioloop.rst b/docs/ioloop.rst index 2c5fdc2cd2..5b748d3695 100644 --- a/docs/ioloop.rst +++ b/docs/ioloop.rst @@ -45,18 +45,3 @@ .. automethod:: IOLoop.time .. autoclass:: PeriodicCallback :members: - - Debugging and error handling - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - - .. automethod:: IOLoop.handle_callback_exception - .. automethod:: IOLoop.set_blocking_signal_threshold - .. automethod:: IOLoop.set_blocking_log_threshold - .. automethod:: IOLoop.log_stack - - Methods for subclasses - ^^^^^^^^^^^^^^^^^^^^^^ - - .. automethod:: IOLoop.initialize - .. automethod:: IOLoop.close_fd - .. automethod:: IOLoop.split_fd diff --git a/docs/locks.rst b/docs/locks.rst index 9f991880c3..df30351cea 100644 --- a/docs/locks.rst +++ b/docs/locks.rst @@ -10,9 +10,10 @@ similar to those provided in the standard library's `asyncio package .. warning:: - Note that these primitives are not actually thread-safe and cannot be used in - place of those from the standard library--they are meant to coordinate Tornado - coroutines in a single-threaded app, not to protect shared objects in a + Note that these primitives are not actually thread-safe and cannot + be used in place of those from the standard library's `threading` + module--they are meant to coordinate Tornado coroutines in a + single-threaded app, not to protect shared objects in a multithreaded app. .. automodule:: tornado.locks diff --git a/docs/releases.rst b/docs/releases.rst index 6f87edc30b..a478821d78 100644 --- a/docs/releases.rst +++ b/docs/releases.rst @@ -4,6 +4,14 @@ Release notes .. toctree:: :maxdepth: 2 + releases/v6.1.0 + releases/v6.0.4 + releases/v6.0.3 + releases/v6.0.2 + releases/v6.0.1 + releases/v6.0.0 + releases/v5.1.1 + releases/v5.1.0 releases/v5.0.2 releases/v5.0.1 releases/v5.0.0 diff --git a/docs/releases/v2.2.0.rst b/docs/releases/v2.2.0.rst index a3298c557c..922c6d2ff5 100644 --- a/docs/releases/v2.2.0.rst +++ b/docs/releases/v2.2.0.rst @@ -56,7 +56,7 @@ Backwards-incompatible changes * ``IOStream.write`` now works correctly when given an empty string. 
* ``IOStream.read_until`` (and ``read_until_regex``) now perform better - when there is a lot of buffered data, which improves peformance of + when there is a lot of buffered data, which improves performance of ``SimpleAsyncHTTPClient`` when downloading files with lots of chunks. * `.SSLIOStream` now works correctly when ``ssl_version`` is set to diff --git a/docs/releases/v2.3.0.rst b/docs/releases/v2.3.0.rst index d24f46c547..5be231d032 100644 --- a/docs/releases/v2.3.0.rst +++ b/docs/releases/v2.3.0.rst @@ -102,9 +102,9 @@ Other modules function is called repeatedly. * `tornado.locale.get_supported_locales` no longer takes a meaningless ``cls`` argument. -* `.StackContext` instances now have a deactivation callback that can be +* ``StackContext`` instances now have a deactivation callback that can be used to prevent further propagation. * `tornado.testing.AsyncTestCase.wait` now resets its timeout on each call. -* `tornado.wsgi.WSGIApplication` now parses arguments correctly on Python 3. +* ``tornado.wsgi.WSGIApplication`` now parses arguments correctly on Python 3. * Exception handling on Python 3 has been improved; previously some exceptions such as `UnicodeDecodeError` would generate ``TypeErrors`` diff --git a/docs/releases/v2.4.0.rst b/docs/releases/v2.4.0.rst index bbf07bf82e..5bbff30bcb 100644 --- a/docs/releases/v2.4.0.rst +++ b/docs/releases/v2.4.0.rst @@ -57,7 +57,7 @@ HTTP clients * New method `.RequestHandler.get_template_namespace` can be overridden to add additional variables without modifying keyword arguments to ``render_string``. -* `.RequestHandler.add_header` now works with `.WSGIApplication`. +* `.RequestHandler.add_header` now works with ``WSGIApplication``. * `.RequestHandler.get_secure_cookie` now handles a potential error case. * ``RequestHandler.__init__`` now calls ``super().__init__`` to ensure that all constructors are called when multiple inheritance is used. diff --git a/docs/releases/v2.4.1.rst b/docs/releases/v2.4.1.rst index 22b09ca428..82eabc94e0 100644 --- a/docs/releases/v2.4.1.rst +++ b/docs/releases/v2.4.1.rst @@ -7,7 +7,7 @@ Nov 24, 2012 Bug fixes ~~~~~~~~~ -* Fixed a memory leak in `tornado.stack_context` that was especially likely +* Fixed a memory leak in ``tornado.stack_context`` that was especially likely with long-running ``@gen.engine`` functions. * `tornado.auth.TwitterMixin` now works on Python 3. * Fixed a bug in which ``IOStream.read_until_close`` with a streaming callback diff --git a/docs/releases/v3.0.0.rst b/docs/releases/v3.0.0.rst index a1d34e12f2..53c9771f30 100644 --- a/docs/releases/v3.0.0.rst +++ b/docs/releases/v3.0.0.rst @@ -10,7 +10,7 @@ Highlights * The ``callback`` argument to many asynchronous methods is now optional, and these methods return a `.Future`. The `tornado.gen` module now understands ``Futures``, and these methods can be used - directly without a `.gen.Task` wrapper. + directly without a ``gen.Task`` wrapper. * New function `.IOLoop.current` returns the `.IOLoop` that is running on the current thread (as opposed to `.IOLoop.instance`, which returns a specific thread's (usually the main thread's) IOLoop. @@ -136,7 +136,7 @@ Multiple modules calling a callback you return a value with ``raise gen.Return(value)`` (or simply ``return value`` in Python 3.3). * Generators may now yield `.Future` objects. 
-* Callbacks produced by `.gen.Callback` and `.gen.Task` are now automatically +* Callbacks produced by ``gen.Callback`` and ``gen.Task`` are now automatically stack-context-wrapped, to minimize the risk of context leaks when used with asynchronous functions that don't do their own wrapping. * Fixed a memory leak involving generators, `.RequestHandler.flush`, @@ -167,7 +167,7 @@ Multiple modules when instantiating an implementation subclass directly. * Secondary `.AsyncHTTPClient` callbacks (``streaming_callback``, ``header_callback``, and ``prepare_curl_callback``) now respect - `.StackContext`. + ``StackContext``. `tornado.httpserver` ~~~~~~~~~~~~~~~~~~~~ @@ -311,9 +311,9 @@ Multiple modules `tornado.platform.twisted` ~~~~~~~~~~~~~~~~~~~~~~~~~~ -* New class `tornado.platform.twisted.TwistedIOLoop` allows Tornado +* New class ``tornado.platform.twisted.TwistedIOLoop`` allows Tornado code to be run on the Twisted reactor (as opposed to the existing - `.TornadoReactor`, which bridges the gap in the other direction). + ``TornadoReactor``, which bridges the gap in the other direction). * New class `tornado.platform.twisted.TwistedResolver` is an asynchronous implementation of the `.Resolver` interface. @@ -343,10 +343,10 @@ Multiple modules * Fixed a bug in which ``SimpleAsyncHTTPClient`` callbacks were being run in the client's ``stack_context``. -`tornado.stack_context` -~~~~~~~~~~~~~~~~~~~~~~~ +``tornado.stack_context`` +~~~~~~~~~~~~~~~~~~~~~~~~~ -* `.stack_context.wrap` now runs the wrapped callback in a more consistent +* ``stack_context.wrap`` now runs the wrapped callback in a more consistent environment by recreating contexts even if they already exist on the stack. * Fixed a bug in which stack contexts could leak from one callback diff --git a/docs/releases/v3.0.2.rst b/docs/releases/v3.0.2.rst index 7eac09dba0..70e7d52b3e 100644 --- a/docs/releases/v3.0.2.rst +++ b/docs/releases/v3.0.2.rst @@ -9,4 +9,4 @@ Jun 2, 2013 June 11 `_. It also now uses HTTPS when talking to Twitter. * Fixed a potential memory leak with a long chain of `.gen.coroutine` - or `.gen.engine` functions. + or ``gen.engine`` functions. diff --git a/docs/releases/v3.1.0.rst b/docs/releases/v3.1.0.rst index edbf39dba6..b4ae0e12ee 100644 --- a/docs/releases/v3.1.0.rst +++ b/docs/releases/v3.1.0.rst @@ -28,7 +28,7 @@ Multiple modules are asynchronous in `.OAuthMixin` and derived classes, although they do not take a callback. The `.Future` these methods return must be yielded if they are called from a function decorated with `.gen.coroutine` - (but not `.gen.engine`). + (but not ``gen.engine``). * `.TwitterMixin` now uses ``/account/verify_credentials`` to get information about the logged-in user, which is more robust against changing screen names. @@ -147,11 +147,11 @@ Multiple modules * `.Subprocess.set_exit_callback` now works for subprocesses created without an explicit ``io_loop`` parameter. -`tornado.stack_context` -~~~~~~~~~~~~~~~~~~~~~~~ +``tornado.stack_context`` +~~~~~~~~~~~~~~~~~~~~~~~~~ -* `tornado.stack_context` has been rewritten and is now much faster. -* New function `.run_with_stack_context` facilitates the use of stack +* ``tornado.stack_context`` has been rewritten and is now much faster. +* New function ``run_with_stack_context`` facilitates the use of stack contexts with coroutines. `tornado.tcpserver` @@ -170,7 +170,7 @@ Multiple modules ~~~~~~~~~~~~~~~~~ * `tornado.testing.AsyncTestCase.wait` now raises the correct exception - when it has been modified by `tornado.stack_context`. 
+ when it has been modified by ``tornado.stack_context``. * `tornado.testing.gen_test` can now be called as ``@gen_test(timeout=60)`` to give some tests a longer timeout than others. * The environment variable ``ASYNC_TEST_TIMEOUT`` can now be set to @@ -210,11 +210,11 @@ Multiple modules instead of being turned into spaces. * `.RequestHandler.send_error` will now only be called once per request, even if multiple exceptions are caught by the stack context. -* The `tornado.web.asynchronous` decorator is no longer necessary for +* The ``tornado.web.asynchronous`` decorator is no longer necessary for methods that return a `.Future` (i.e. those that use the `.gen.coroutine` - or `.return_future` decorators) + or ``return_future`` decorators) * `.RequestHandler.prepare` may now be asynchronous if it returns a - `.Future`. The `~tornado.web.asynchronous` decorator is not used with + `.Future`. The ``tornado.web.asynchronous`` decorator is not used with ``prepare``; one of the `.Future`-related decorators should be used instead. * ``RequestHandler.current_user`` may now be assigned to normally. * `.RequestHandler.redirect` no longer silently strips control characters diff --git a/docs/releases/v3.2.0.rst b/docs/releases/v3.2.0.rst index 09057030ac..f0be99961a 100644 --- a/docs/releases/v3.2.0.rst +++ b/docs/releases/v3.2.0.rst @@ -79,7 +79,7 @@ New modules `tornado.ioloop` ~~~~~~~~~~~~~~~~ -* `.IOLoop` now uses `~.IOLoop.handle_callback_exception` consistently for +* `.IOLoop` now uses ``IOLoop.handle_callback_exception`` consistently for error logging. * `.IOLoop` now frees callback objects earlier, reducing memory usage while idle. diff --git a/docs/releases/v4.0.0.rst b/docs/releases/v4.0.0.rst index dd60ea8c3a..6b470d3bc1 100644 --- a/docs/releases/v4.0.0.rst +++ b/docs/releases/v4.0.0.rst @@ -96,14 +96,14 @@ Other notes will be created on demand when needed. * The internals of the `tornado.gen` module have been rewritten to improve performance when using ``Futures``, at the expense of some - performance degradation for the older `.YieldPoint` interfaces. + performance degradation for the older ``YieldPoint`` interfaces. * New function `.with_timeout` wraps a `.Future` and raises an exception if it doesn't complete in a given amount of time. * New object `.moment` can be yielded to allow the IOLoop to run for one iteration before resuming. -* `.Task` is now a function returning a `.Future` instead of a `.YieldPoint` +* ``Task`` is now a function returning a `.Future` instead of a ``YieldPoint`` subclass. This change should be transparent to application code, but - allows `.Task` to take advantage of the newly-optimized `.Future` + allows ``Task`` to take advantage of the newly-optimized `.Future` handling. `tornado.http1connection` @@ -135,7 +135,7 @@ Other notes * The ``connection`` attribute of `.HTTPServerRequest` is now documented for public use; applications are expected to write their responses via the `.HTTPConnection` interface. -* The `.HTTPServerRequest.write` and `.HTTPServerRequest.finish` methods +* The ``HTTPServerRequest.write`` and ``HTTPServerRequest.finish`` methods are now deprecated. (`.RequestHandler.write` and `.RequestHandler.finish` are *not* deprecated; this only applies to the methods on `.HTTPServerRequest`) @@ -233,7 +233,7 @@ Other notes `tornado.platform.twisted` ~~~~~~~~~~~~~~~~~~~~~~~~~~ -* `.TwistedIOLoop` now works on Python 3.3+ (with Twisted 14.0.0+). +* ``TwistedIOLoop`` now works on Python 3.3+ (with Twisted 14.0.0+). 
``tornado.simple_httpclient`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -251,8 +251,8 @@ Other notes * ``simple_httpclient`` now raises the original exception (e.g. an `IOError`) in more cases, instead of converting everything to ``HTTPError``. -`tornado.stack_context` -~~~~~~~~~~~~~~~~~~~~~~~ +``tornado.stack_context`` +~~~~~~~~~~~~~~~~~~~~~~~~~ * The stack context system now has less performance overhead when no stack contexts are active. @@ -324,8 +324,8 @@ Other notes `tornado.wsgi` ~~~~~~~~~~~~~~ -* New class `.WSGIAdapter` supports running a Tornado `.Application` on +* New class ``WSGIAdapter`` supports running a Tornado `.Application` on a WSGI server in a way that is more compatible with Tornado's non-WSGI - `.HTTPServer`. `.WSGIApplication` is deprecated in favor of using - `.WSGIAdapter` with a regular `.Application`. -* `.WSGIAdapter` now supports gzipped output. + `.HTTPServer`. ``WSGIApplication`` is deprecated in favor of using + ``WSGIAdapter`` with a regular `.Application`. +* ``WSGIAdapter`` now supports gzipped output. diff --git a/docs/releases/v4.1.0.rst b/docs/releases/v4.1.0.rst index 29ad19146a..74cd30a49f 100644 --- a/docs/releases/v4.1.0.rst +++ b/docs/releases/v4.1.0.rst @@ -60,7 +60,7 @@ Backwards-compatibility notes yieldable in coroutines. * New function `tornado.gen.sleep` is a coroutine-friendly analogue to `time.sleep`. -* `.gen.engine` now correctly captures the stack context for its callbacks. +* ``gen.engine`` now correctly captures the stack context for its callbacks. `tornado.httpclient` ~~~~~~~~~~~~~~~~~~~~ @@ -167,7 +167,7 @@ Backwards-compatibility notes `tornado.web` ~~~~~~~~~~~~~ -* The `.asynchronous` decorator now understands `concurrent.futures.Future` +* The ``asynchronous`` decorator now understands `concurrent.futures.Future` in addition to `tornado.concurrent.Future`. * `.StaticFileHandler` no longer logs a stack trace if the connection is closed while sending the file. diff --git a/docs/releases/v4.2.0.rst b/docs/releases/v4.2.0.rst index 93493ee113..bacfb13a05 100644 --- a/docs/releases/v4.2.0.rst +++ b/docs/releases/v4.2.0.rst @@ -162,7 +162,7 @@ Then the Tornado equivalent is:: * The `.IOLoop` constructor now has a ``make_current`` keyword argument to control whether the new `.IOLoop` becomes `.IOLoop.current()`. * Third-party implementations of `.IOLoop` should accept ``**kwargs`` - in their `~.IOLoop.initialize` methods and pass them to the superclass + in their ``IOLoop.initialize`` methods and pass them to the superclass implementation. * `.PeriodicCallback` is now more efficient when the clock jumps forward by a large amount. diff --git a/docs/releases/v4.3.0.rst b/docs/releases/v4.3.0.rst index e6ea1589c7..b19b297c1b 100644 --- a/docs/releases/v4.3.0.rst +++ b/docs/releases/v4.3.0.rst @@ -12,7 +12,7 @@ Highlights Inside a function defined with ``async def``, use ``await`` instead of ``yield`` to wait on an asynchronous operation. Coroutines defined with async/await will be faster than those defined with ``@gen.coroutine`` and - ``yield``, but do not support some features including `.Callback`/`.Wait` or + ``yield``, but do not support some features including ``Callback``/``Wait`` or the ability to yield a Twisted ``Deferred``. See :ref:`the users' guide ` for more. 
* The async/await keywords are also available when compiling with Cython in diff --git a/docs/releases/v4.4.0.rst b/docs/releases/v4.4.0.rst index fa7c81702a..5ac3018b9f 100644 --- a/docs/releases/v4.4.0.rst +++ b/docs/releases/v4.4.0.rst @@ -26,7 +26,7 @@ General ~~~~~~~~~~~~~ * `.with_timeout` now accepts any yieldable object (except - `.YieldPoint`), not just `tornado.concurrent.Future`. + ``YieldPoint``), not just `tornado.concurrent.Future`. `tornado.httpclient` ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/releases/v4.5.0.rst b/docs/releases/v4.5.0.rst index 9cdf0ad5c2..5a4ce9e258 100644 --- a/docs/releases/v4.5.0.rst +++ b/docs/releases/v4.5.0.rst @@ -58,7 +58,7 @@ General changes - Fixed an issue in which a generator object could be garbage collected prematurely (most often when weak references are used. - New function `.is_coroutine_function` identifies functions wrapped - by `.coroutine` or `.engine`. + by `.coroutine` or ``engine``. ``tornado.http1connection`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/releases/v5.0.0.rst b/docs/releases/v5.0.0.rst index 7f1325b8de..dd0bd02439 100644 --- a/docs/releases/v5.0.0.rst +++ b/docs/releases/v5.0.0.rst @@ -193,8 +193,8 @@ Other notes - The ``io_loop`` argument to `.PeriodicCallback` has been removed. - It is now possible to create a `.PeriodicCallback` in one thread and start it in another without passing an explicit event loop. -- The `.IOLoop.set_blocking_signal_threshold` and - `.IOLoop.set_blocking_log_threshold` methods are deprecated because +- The ``IOLoop.set_blocking_signal_threshold`` and + ``IOLoop.set_blocking_log_threshold`` methods are deprecated because they are not implemented for the `asyncio` event loop`. Use the ``PYTHONASYNCIODEBUG=1`` environment variable instead. - `.IOLoop.clear_current` now works if it is called before any @@ -257,8 +257,8 @@ Other notes `tornado.platform.twisted` ~~~~~~~~~~~~~~~~~~~~~~~~~~ -- The ``io_loop`` arguments to `.TornadoReactor`, `.TwistedResolver`, - and `tornado.platform.twisted.install` have been removed. +- The ``io_loop`` arguments to ``TornadoReactor``, `.TwistedResolver`, + and ``tornado.platform.twisted.install`` have been removed. `tornado.process` ~~~~~~~~~~~~~~~~~ diff --git a/docs/releases/v5.1.0.rst b/docs/releases/v5.1.0.rst new file mode 100644 index 0000000000..00def8f38c --- /dev/null +++ b/docs/releases/v5.1.0.rst @@ -0,0 +1,195 @@ +What's new in Tornado 5.1 +========================= + +July 12, 2018 +------------- + +Deprecation notice +~~~~~~~~~~~~~~~~~~ + +- Tornado 6.0 will drop support for Python 2.7 and 3.4. The minimum + supported Python version will be 3.5.2. +- The ``tornado.stack_context`` module is deprecated and will be removed + in Tornado 6.0. The reason for this is that it is not feasible to + provide this module's semantics in the presence of ``async def`` + native coroutines. ``ExceptionStackContext`` is mainly obsolete + thanks to coroutines. ``StackContext`` lacks a direct replacement + although the new ``contextvars`` package (in the Python standard + library beginning in Python 3.7) may be an alternative. +- Callback-oriented code often relies on ``ExceptionStackContext`` to + handle errors and prevent leaked connections. In order to avoid the + risk of silently introducing subtle leaks (and to consolidate all of + Tornado's interfaces behind the coroutine pattern), ``callback`` + arguments throughout the package are deprecated and will be removed + in version 6.0. 
All functions that had a ``callback`` argument + removed now return a `.Future` which should be used instead. +- Where possible, deprecation warnings are emitted when any of these + deprecated interfaces is used. However, Python does not display + deprecation warnings by default. To prepare your application for + Tornado 6.0, run Python with the ``-Wd`` argument or set the + environment variable ``PYTHONWARNINGS`` to ``d``. If your + application runs on Python 3 without deprecation warnings, it should + be able to move to Tornado 6.0 without disruption. + +`tornado.auth` +~~~~~~~~~~~~~~ + +- `.OAuthMixin._oauth_get_user_future` may now be a native coroutine. +- All ``callback`` arguments in this package are deprecated and will + be removed in 6.0. Use the coroutine interfaces instead. +- The ``OAuthMixin._oauth_get_user`` method is deprecated and will be removed in + 6.0. Override `~.OAuthMixin._oauth_get_user_future` instead. + +`tornado.autoreload` +~~~~~~~~~~~~~~~~~~~~ + +- The command-line autoreload wrapper is now preserved if an internal + autoreload fires. +- The command-line wrapper no longer starts duplicated processes on Windows + when combined with internal autoreload. + +`tornado.concurrent` +~~~~~~~~~~~~~~~~~~~~ + +- `.run_on_executor` now returns `.Future` objects that are compatible + with ``await``. +- The ``callback`` argument to `.run_on_executor` is deprecated and will + be removed in 6.0. +- ``return_future`` is deprecated and will be removed in 6.0. + +`tornado.gen` +~~~~~~~~~~~~~ + +- Some older portions of this module are deprecated and will be removed + in 6.0. This includes ``engine``, ``YieldPoint``, ``Callback``, + ``Wait``, ``WaitAll``, ``MultiYieldPoint``, and ``Task``. +- Functions decorated with ``@gen.coroutine`` will no longer accept + ``callback`` arguments in 6.0. + +`tornado.httpclient` +~~~~~~~~~~~~~~~~~~~~ + +- The behavior of ``raise_error=False`` is changing in 6.0. Currently + it suppresses all errors; in 6.0 it will only suppress the errors + raised due to completed responses with non-200 status codes. +- The ``callback`` argument to `.AsyncHTTPClient.fetch` is deprecated + and will be removed in 6.0. +- `tornado.httpclient.HTTPError` has been renamed to + `.HTTPClientError` to avoid ambiguity in code that also has to deal + with `tornado.web.HTTPError`. The old name remains as an alias. +- ``tornado.curl_httpclient`` now supports non-ASCII characters in + username and password arguments. +- `.HTTPResponse.request_time` now behaves consistently across + ``simple_httpclient`` and ``curl_httpclient``, excluding time spent + in the ``max_clients`` queue in both cases (previously this time was + included in ``simple_httpclient`` but excluded in + ``curl_httpclient``). In both cases the time is now computed using + a monotonic clock where available. +- `.HTTPResponse` now has a ``start_time`` attribute recording a + wall-clock (`time.time`) timestamp at which the request started + (after leaving the ``max_clients`` queue if applicable). + +`tornado.httputil` +~~~~~~~~~~~~~~~~~~ + +- `.parse_multipart_form_data` now recognizes non-ASCII filenames in + RFC 2231/5987 (``filename*=``) format. +- ``HTTPServerRequest.write`` is deprecated and will be removed in 6.0. Use + the methods of ``request.connection`` instead. +- Malformed HTTP headers are now logged less noisily. + +`tornado.ioloop` +~~~~~~~~~~~~~~~~ + +- `.PeriodicCallback` now supports a ``jitter`` argument to randomly + vary the timeout (see the sketch after these notes).
+- ``IOLoop.set_blocking_signal_threshold``, + ``IOLoop.set_blocking_log_threshold``, ``IOLoop.log_stack``, + and ``IOLoop.handle_callback_exception`` are deprecated and will + be removed in 6.0. +- Fixed a `KeyError` in `.IOLoop.close` when `.IOLoop` objects are + being opened and closed in multiple threads. + +`tornado.iostream` +~~~~~~~~~~~~~~~~~~ + +- All ``callback`` arguments in this module are deprecated except for + `.BaseIOStream.set_close_callback`. They will be removed in 6.0. +- ``streaming_callback`` arguments to `.BaseIOStream.read_bytes` and + `.BaseIOStream.read_until_close` are deprecated and will be removed + in 6.0. + +`tornado.netutil` +~~~~~~~~~~~~~~~~~ + +- Improved compatibility with GNU Hurd. + +`tornado.options` +~~~~~~~~~~~~~~~~~ + +- `tornado.options.parse_config_file` now allows setting options to + strings (which will be parsed the same way as + `tornado.options.parse_command_line`) in addition to the specified + type for the option. + +`tornado.platform.twisted` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``TornadoReactor`` and ``TwistedIOLoop`` are deprecated and will be + removed in 6.0. Instead, Tornado will always use the asyncio event loop + and Twisted can be configured to do so as well. + +``tornado.stack_context`` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- The ``tornado.stack_context`` module is deprecated and will be removed + in 6.0. + +`tornado.testing` +~~~~~~~~~~~~~~~~~ + +- `.AsyncHTTPTestCase.fetch` now takes a ``raise_error`` argument. + This argument has the same semantics as `.AsyncHTTPClient.fetch`, + but defaults to false because tests often need to deal with non-200 + responses (and for backwards-compatibility). +- The `.AsyncTestCase.stop` and `.AsyncTestCase.wait` methods are + deprecated. + +`tornado.web` +~~~~~~~~~~~~~ + +- New method `.RequestHandler.detach` can be used from methods + that are not decorated with ``@asynchronous`` (the decorator + was required to use ``self.request.connection.detach()``). +- `.RequestHandler.finish` and `.RequestHandler.render` now return + ``Futures`` that can be used to wait for the last part of the + response to be sent to the client. +- `.FallbackHandler` now calls ``on_finish`` for the benefit of + subclasses that may have overridden it. +- The ``asynchronous`` decorator is deprecated and will be removed in 6.0. +- The ``callback`` argument to `.RequestHandler.flush` is deprecated + and will be removed in 6.0. + + +`tornado.websocket` +~~~~~~~~~~~~~~~~~~~ + +- When compression is enabled, memory limits now apply to the + post-decompression size of the data, protecting against DoS attacks. +- `.websocket_connect` now supports subprotocols. +- `.WebSocketHandler` and `.WebSocketClientConnection` now have + ``selected_subprotocol`` attributes to see the subprotocol in use. +- The `.WebSocketHandler.select_subprotocol` method is now called with + an empty list instead of a list containing an empty string if no + subprotocols were requested by the client. +- `.WebSocketHandler.open` may now be a coroutine. +- The ``data`` argument to `.WebSocketHandler.ping` is now optional. +- Client-side websocket connections no longer buffer more than one + message in memory at a time. +- Exception logging now uses `.RequestHandler.log_exception`. + +`tornado.wsgi` +~~~~~~~~~~~~~~ + +- ``WSGIApplication`` and ``WSGIAdapter`` are deprecated and will be removed + in Tornado 6.0.
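As a quick illustration of the ``jitter`` argument referenced in the 5.1 ioloop notes above (the callback and the 1000 ms period are invented for the example):

```py
from tornado import ioloop


def heartbeat():
    print("tick")


if __name__ == "__main__":
    # Base period of 1000 ms; jitter=0.5 spreads each interval by roughly
    # +/-25% (jitter is interpreted as a fraction of callback_time), which
    # helps avoid thundering-herd effects when many timers share a schedule.
    timer = ioloop.PeriodicCallback(heartbeat, 1000, jitter=0.5)
    timer.start()
    ioloop.IOLoop.current().start()
```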
diff --git a/docs/releases/v5.1.1.rst b/docs/releases/v5.1.1.rst new file mode 100644 index 0000000000..7fc4fb881a --- /dev/null +++ b/docs/releases/v5.1.1.rst @@ -0,0 +1,14 @@ +What's new in Tornado 5.1.1 +=========================== + +Sep 16, 2018 +------------ + +Bug fixes +~~~~~~~~~ + +- Fixed a case in which the `.Future` returned by + `.RequestHandler.finish` could fail to resolve. +- The `.TwitterMixin.authenticate_redirect` method works again. +- Improved error handling in the `tornado.auth` module, fixing hanging + requests when a network or other error occurs. diff --git a/docs/releases/v6.0.0.rst b/docs/releases/v6.0.0.rst new file mode 100644 index 0000000000..d3d2dfbc0b --- /dev/null +++ b/docs/releases/v6.0.0.rst @@ -0,0 +1,162 @@ +What's new in Tornado 6.0 +========================= + +Mar 1, 2019 +----------- + +Backwards-incompatible changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Python 2.7 and 3.4 are no longer supported; the minimum supported + Python version is 3.5.2. +- APIs deprecated in Tornado 5.1 have been removed. This includes the + ``tornado.stack_context`` module and most ``callback`` arguments + throughout the package. All removed APIs emitted + `DeprecationWarning` when used in Tornado 5.1, so running your + application with the ``-Wd`` Python command-line flag or the + environment variable ``PYTHONWARNINGS=d`` should tell you whether + your application is ready to move to Tornado 6.0. +- `.WebSocketHandler.get` is now a coroutine and must be called + accordingly in any subclasses that override this method (but note + that overriding ``get`` is not recommended; either ``prepare`` or + ``open`` should be used instead). + +General changes +~~~~~~~~~~~~~~~ + +- Tornado now includes type annotations compatible with ``mypy``. + These annotations will be used when type-checking your application + with ``mypy``, and may be usable in editors and other tools. +- Tornado now uses native coroutines internally, improving performance. + +`tornado.auth` +~~~~~~~~~~~~~~ + +- All ``callback`` arguments in this package have been removed. Use + the coroutine interfaces instead. +- The ``OAuthMixin._oauth_get_user`` method has been removed. + Override `~.OAuthMixin._oauth_get_user_future` instead. + +`tornado.concurrent` +~~~~~~~~~~~~~~~~~~~~ + +- The ``callback`` argument to `.run_on_executor` has been removed. +- ``return_future`` has been removed. + +`tornado.gen` +~~~~~~~~~~~~~ + +- Some older portions of this module have been removed. This includes + ``engine``, ``YieldPoint``, ``Callback``, ``Wait``, ``WaitAll``, + ``MultiYieldPoint``, and ``Task``. +- Functions decorated with ``@gen.coroutine`` no longer accept + ``callback`` arguments. + +`tornado.httpclient` +~~~~~~~~~~~~~~~~~~~~ + +- The behavior of ``raise_error=False`` has changed. It now only + suppresses the errors raised due to completed responses with non-200 + status codes (previously it suppressed all errors). +- The ``callback`` argument to `.AsyncHTTPClient.fetch` has been removed + (see the migration sketch below). + +`tornado.httputil` +~~~~~~~~~~~~~~~~~~ + +- ``HTTPServerRequest.write`` has been removed. Use the methods of + ``request.connection`` instead. +- Unrecognized ``Content-Encoding`` values now log warnings only for + content types that we would otherwise attempt to parse. + +`tornado.ioloop` +~~~~~~~~~~~~~~~~ + +- ``IOLoop.set_blocking_signal_threshold``, + ``IOLoop.set_blocking_log_threshold``, ``IOLoop.log_stack``, + and ``IOLoop.handle_callback_exception`` have been removed. +- Improved performance of `.IOLoop.add_callback`.
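The ``callback`` removals above mostly translate mechanically to ``await``. A hedged before-and-after sketch of the `.AsyncHTTPClient.fetch` migration (the URL is illustrative):

```py
from tornado.httpclient import AsyncHTTPClient
from tornado.ioloop import IOLoop


async def main():
    client = AsyncHTTPClient()
    # Tornado 5.1 (deprecated):
    #     client.fetch("http://example.com/", callback=on_response)
    # Tornado 6.0: await the Future that fetch() returns instead.
    response = await client.fetch("http://example.com/")
    print(response.code, len(response.body))


if __name__ == "__main__":
    IOLoop.current().run_sync(main)
```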
+ +`tornado.iostream` +~~~~~~~~~~~~~~~~~~ + +- All ``callback`` arguments in this module have been removed except + for `.BaseIOStream.set_close_callback`. +- ``streaming_callback`` arguments to `.BaseIOStream.read_bytes` and + `.BaseIOStream.read_until_close` have been removed. +- Eliminated unnecessary logging of "Errno 0". + +`tornado.log` +~~~~~~~~~~~~~ + +- Log files opened by this module are now explicitly set to UTF-8 encoding. + +`tornado.netutil` +~~~~~~~~~~~~~~~~~ + +- The results of ``getaddrinfo`` are now sorted by address family to + avoid partial failures and deadlocks. + +`tornado.platform.twisted` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``TornadoReactor`` and ``TwistedIOLoop`` have been removed. + +``tornado.simple_httpclient`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- The default HTTP client now supports the ``network_interface`` + request argument to specify the source IP for the connection. +- If a server returns a 3xx response code without a ``Location`` + header, the response is raised or returned directly instead of + trying and failing to follow the redirect. +- When following redirects, methods other than ``POST`` will no longer + be transformed into ``GET`` requests. 301 (permanent) redirects are + now treated the same way as 302 (temporary) and 303 (see other) + redirects in this respect. +- Following redirects now works with ``body_producer``. + +``tornado.stack_context`` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- The ``tornado.stack_context`` module has been removed. + +`tornado.tcpserver` +~~~~~~~~~~~~~~~~~~~ + +- `.TCPServer.start` now supports a ``max_restarts`` argument (same as + `.fork_processes`). + +`tornado.testing` +~~~~~~~~~~~~~~~~~ + +- `.AsyncHTTPTestCase` now drops all references to the `.Application` + during ``tearDown``, allowing its memory to be reclaimed sooner. +- `.AsyncTestCase` now cancels all pending coroutines in ``tearDown``, + in an effort to reduce warnings from the python runtime about + coroutines that were not awaited. Note that this may cause + ``asyncio.CancelledError`` to be logged in other places. Coroutines + that expect to be running at test shutdown may need to catch this + exception. + +`tornado.web` +~~~~~~~~~~~~~ + +- The ``asynchronous`` decorator has been removed. +- The ``callback`` argument to `.RequestHandler.flush` has been removed. +- `.StaticFileHandler` now supports large negative values for the + ``Range`` header and returns an appropriate error for ``end > + start``. +- It is now possible to set ``expires_days`` in ``xsrf_cookie_kwargs``. + +`tornado.websocket` +~~~~~~~~~~~~~~~~~~~ + +- Pings and other messages sent while the connection is closing are + now silently dropped instead of logging exceptions. +- Errors raised by ``open()`` are now caught correctly when this method + is a coroutine. + +`tornado.wsgi` +~~~~~~~~~~~~~~ + +- ``WSGIApplication`` and ``WSGIAdapter`` have been removed. diff --git a/docs/releases/v6.0.1.rst b/docs/releases/v6.0.1.rst new file mode 100644 index 0000000000..c9da7507e6 --- /dev/null +++ b/docs/releases/v6.0.1.rst @@ -0,0 +1,11 @@ +What's new in Tornado 6.0.1 +=========================== + +Mar 3, 2019 +----------- + +Bug fixes +~~~~~~~~~ + +- Fixed issues with type annotations that caused errors while + importing Tornado on Python 3.5.2. 
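A sketch of how the new ``max_restarts`` argument to `.TCPServer.start` described in the notes above might be used (the echo server, port, and process counts are invented for illustration):

```py
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError
from tornado.tcpserver import TCPServer


class EchoServer(TCPServer):
    async def handle_stream(self, stream, address):
        try:
            data = await stream.read_until(b"\n")
            await stream.write(data)
        except StreamClosedError:
            pass


if __name__ == "__main__":
    server = EchoServer()
    server.bind(8888)
    # New in 6.0: stop restarting a crashing child after 5 attempts
    # (the same meaning max_restarts has in tornado.process.fork_processes).
    server.start(num_processes=2, max_restarts=5)
    IOLoop.current().start()
```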
diff --git a/docs/releases/v6.0.2.rst b/docs/releases/v6.0.2.rst new file mode 100644 index 0000000000..3d394a3edc --- /dev/null +++ b/docs/releases/v6.0.2.rst @@ -0,0 +1,13 @@ +What's new in Tornado 6.0.2 +=========================== + +Mar 23, 2019 +------------ + +Bug fixes +~~~~~~~~~ + +- `.WebSocketHandler.set_nodelay` works again. +- Accessing ``HTTPResponse.body`` now returns an empty byte string + instead of raising ``ValueError`` for error responses that don't + have a body (it returned None in this case in Tornado 5). diff --git a/docs/releases/v6.0.3.rst b/docs/releases/v6.0.3.rst new file mode 100644 index 0000000000..c112a0286d --- /dev/null +++ b/docs/releases/v6.0.3.rst @@ -0,0 +1,14 @@ +What's new in Tornado 6.0.3 +=========================== + +Jun 22, 2019 +------------ + +Bug fixes +~~~~~~~~~ + +- `.gen.with_timeout` always treats ``asyncio.CancelledError`` as a + ``quiet_exception`` (this improves compatibility with Python 3.8, + which changed ``CancelledError`` to a ``BaseException``). +- ``IOStream`` now checks for closed streams earlier, avoiding + spurious logged errors in some situations (mainly with websockets). diff --git a/docs/releases/v6.0.4.rst b/docs/releases/v6.0.4.rst new file mode 100644 index 0000000000..f9864bff4a --- /dev/null +++ b/docs/releases/v6.0.4.rst @@ -0,0 +1,21 @@ +What's new in Tornado 6.0.4 +=========================== + +Mar 3, 2020 +----------- + +General changes +~~~~~~~~~~~~~~~ + +- Binary wheels are now available for Python 3.8 on Windows. Note that it is + still necessary to use + ``asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())`` for + this platform/version. + +Bug fixes +~~~~~~~~~ + +- Fixed an issue in `.IOStream` (introduced in 6.0.0) that resulted in + ``StreamClosedError`` being incorrectly raised if a stream is closed mid-read + but there is enough buffered data to satisfy the read. +- `.AnyThreadEventLoopPolicy` now always uses the selector event loop on Windows. \ No newline at end of file diff --git a/docs/releases/v6.1.0.rst b/docs/releases/v6.1.0.rst new file mode 100644 index 0000000000..7de6350ab5 --- /dev/null +++ b/docs/releases/v6.1.0.rst @@ -0,0 +1,106 @@ +What's new in Tornado 6.1.0 +=========================== + +Oct 30, 2020 +------------ + +Deprecation notice +~~~~~~~~~~~~~~~~~~ + +- This is the last release of Tornado to support Python 3.5. Future versions + will require Python 3.6 or newer. + +General changes +~~~~~~~~~~~~~~~ + +- Windows support has been improved. Tornado is now compatible with the proactor + event loop (which became the default in Python 3.8) by automatically falling + back to running a selector in a second thread. This means that it is no longer + necessary to explicitly configure a selector event loop, although doing so may + improve performance. This does not change the fact that Tornado is significantly + less scalable on Windows than on other platforms. +- Binary wheels are now provided for Windows, MacOS, and Linux (amd64 and arm64). + +`tornado.gen` +~~~~~~~~~~~~~ + +- `.coroutine` now has better support for the Python 3.7+ ``contextvars`` module. + In particular, the ``ContextVar.reset`` method is now supported. + +`tornado.http1connection` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``HEAD`` requests to handlers that used chunked encoding no longer produce malformed output. +- Certain kinds of malformed ``gzip`` data no longer cause an infinite loop. 
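The Windows workaround mentioned in the 6.0.4 notes above amounts to one guarded line at program startup, before any event loop is created (a sketch, not part of the release notes):

```py
import asyncio
import sys

# Python 3.8 made the proactor event loop the default on Windows, which
# Tornado 6.0.x does not support; fall back to the selector loop there.
if sys.platform == "win32" and sys.version_info >= (3, 8):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
```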
+ +`tornado.httpclient` +~~~~~~~~~~~~~~~~~~~~ + +- Setting ``decompress_response=False`` now works correctly with + ``curl_httpclient``. +- Mixing requests with and without proxies works correctly in ``curl_httpclient`` + (assuming the version of pycurl is recent enough). +- A default ``User-Agent`` of ``Tornado/$VERSION`` is now used if the + ``user_agent`` parameter is not specified. +- After a 303 redirect, ``tornado.simple_httpclient`` always uses ``GET``. + Previously this would use ``GET`` if the original request was a ``POST`` and + would otherwise reuse the original request method. For ``curl_httpclient``, the + behavior depends on the version of ``libcurl`` (with the most recent versions + using ``GET`` after 303 regardless of the original method). +- Setting ``request_timeout`` and/or ``connect_timeout`` to zero is now supported + to disable the timeout. + +`tornado.httputil` +~~~~~~~~~~~~~~~~~~ + +- Header parsing is now faster. +- `.parse_body_arguments` now accepts incompletely-escaped non-ASCII inputs. + +`tornado.iostream` +~~~~~~~~~~~~~~~~~~ + +- `ssl.CertificateError` during the SSL handshake is now handled correctly. +- Reads that are resolved while the stream is closing are now handled correctly. + +`tornado.log` +~~~~~~~~~~~~~ + +- When colored logging is enabled, ``logging.CRITICAL`` messages are now + recognized and colored magenta. + +`tornado.netutil` +~~~~~~~~~~~~~~~~~ + +- ``EADDRNOTAVAIL`` is now ignored when binding to ``localhost`` with IPv6. This + error is common in docker. + +`tornado.platform.asyncio` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- `.AnyThreadEventLoopPolicy` now also configures a selector event loop for + these threads (the proactor event loop only works on the main thread) + +``tornado.platform.auto`` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- The ``set_close_exec`` function has been removed. + +`tornado.testing` +~~~~~~~~~~~~~~~~~ + +- `.ExpectLog` now has a ``level`` argument to ensure that the given log level + is enabled. + +`tornado.web` +~~~~~~~~~~~~~ + +- ``RedirectHandler.get`` now accepts keyword arguments. +- When sending 304 responses, more headers (including ``Allow``) are now preserved. +- ``reverse_url`` correctly handles escaped characters in the regex route. +- Default ``Etag`` headers are now generated with SHA-512 instead of MD5. + +`tornado.websocket` +~~~~~~~~~~~~~~~~~~~ + +- The ``ping_interval`` timer is now stopped when the connection is closed. +- `.websocket_connect` now raises an error when it encounters a redirect instead of hanging. 
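A sketch of the new ``level`` argument to `.ExpectLog` from the 6.1 testing notes above (the test name and logged message are invented for the example):

```py
import logging
import unittest

from tornado.log import app_log
from tornado.testing import ExpectLog


class LoggingTest(unittest.TestCase):
    def test_warning_is_captured(self):
        # level (new in 6.1) ensures WARNING records are actually emitted
        # while the context is active, even if the surrounding test runner
        # configured a quieter log threshold.
        with ExpectLog(app_log, "disk almost full", level=logging.WARNING):
            app_log.warning("disk almost full on /dev/sda1")


if __name__ == "__main__":
    unittest.main()
```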
diff --git a/docs/requirements.in b/docs/requirements.in new file mode 100644 index 0000000000..334534d739 --- /dev/null +++ b/docs/requirements.in @@ -0,0 +1,3 @@ +sphinx +sphinxcontrib-asyncio +sphinx_rtd_theme diff --git a/docs/requirements.txt b/docs/requirements.txt index 082785c979..ab250c658e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1 +1,26 @@ -Twisted +alabaster==0.7.12 +babel==2.8.0 +certifi==2020.6.20 +chardet==3.0.4 +docutils==0.16 +idna==2.10 +imagesize==1.2.0 +jinja2==2.11.3 +markupsafe==1.1.1 +packaging==20.4 +pygments==2.7.4 +pyparsing==2.4.7 +pytz==2020.1 +requests==2.24.0 +six==1.15.0 +snowballstemmer==2.0.0 +sphinx-rtd-theme==0.5.0 +sphinx==3.2.1 +sphinxcontrib-applehelp==1.0.2 +sphinxcontrib-asyncio==0.3.0 +sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-htmlhelp==1.0.3 +sphinxcontrib-jsmath==1.0.1 +sphinxcontrib-qthelp==1.0.3 +sphinxcontrib-serializinghtml==1.1.4 +urllib3==1.25.11 diff --git a/docs/stack_context.rst b/docs/stack_context.rst deleted file mode 100644 index 489a37fdc5..0000000000 --- a/docs/stack_context.rst +++ /dev/null @@ -1,5 +0,0 @@ -``tornado.stack_context`` --- Exception handling across asynchronous callbacks -============================================================================== - -.. automodule:: tornado.stack_context - :members: diff --git a/docs/twisted.rst b/docs/twisted.rst index a6e1946f11..5d8fe8fbc8 100644 --- a/docs/twisted.rst +++ b/docs/twisted.rst @@ -1,24 +1,60 @@ ``tornado.platform.twisted`` --- Bridges between Twisted and Tornado -======================================================================== +==================================================================== -.. automodule:: tornado.platform.twisted +.. module:: tornado.platform.twisted - Twisted on Tornado - ------------------ +.. deprecated:: 6.0 - .. autoclass:: TornadoReactor - :members: + This module is no longer recommended for new code. Instead of using + direct integration between Tornado and Twisted, new applications should + rely on the integration with ``asyncio`` provided by both packages. - .. autofunction:: install +Importing this module has the side effect of registering Twisted's ``Deferred`` +class with Tornado's ``@gen.coroutine`` so that ``Deferred`` objects can be +used with ``yield`` in coroutines using this decorator (importing this module has +no effect on native coroutines using ``async def``). - Tornado on Twisted - ------------------ +.. function:: install() - .. autoclass:: TwistedIOLoop - :members: + Install ``AsyncioSelectorReactor`` as the default Twisted reactor. - Twisted DNS resolver - -------------------- + .. deprecated:: 5.1 + + This function is provided for backwards compatibility; code + that does not require compatibility with older versions of + Tornado should use + ``twisted.internet.asyncioreactor.install()`` directly. + + .. versionchanged:: 6.0.3 + + In Tornado 5.x and before, this function installed a reactor + based on the Tornado ``IOLoop``. When that reactor + implementation was removed in Tornado 6.0.0, this function was + removed as well. It was restored in Tornado 6.0.3 using the + ``asyncio`` reactor instead. + +Twisted DNS resolver +-------------------- + +.. class:: TwistedResolver + + Twisted-based asynchronous resolver. + + This is a non-blocking and non-threaded resolver. It is + recommended only when threads cannot be used, since it has + limitations compared to the standard ``getaddrinfo``-based + `~tornado.netutil.Resolver` and + `~tornado.netutil.DefaultExecutorResolver`. 
Specifically, it returns at + most one result, and arguments other than ``host`` and ``family`` + are ignored. It may fail to resolve when ``family`` is not + ``socket.AF_UNSPEC``. + + Requires Twisted 12.1 or newer. + + .. versionchanged:: 5.0 + The ``io_loop`` argument (deprecated since version 4.1) has been removed. + + .. deprecated:: 6.2 + This class is deprecated and will be removed in Tornado 7.0. Use the default + thread-based resolver instead. - .. autoclass:: TwistedResolver - :members: diff --git a/docs/utilities.rst b/docs/utilities.rst index 55536626ed..4c6edf586a 100644 --- a/docs/utilities.rst +++ b/docs/utilities.rst @@ -7,6 +7,5 @@ Utilities concurrent log options - stack_context testing util diff --git a/docs/web.rst b/docs/web.rst index 77f2fff4c6..720d75678e 100644 --- a/docs/web.rst +++ b/docs/web.rst @@ -9,7 +9,7 @@ Request handlers ---------------- - .. autoclass:: RequestHandler + .. autoclass:: RequestHandler(...) Entry points ^^^^^^^^^^^^ @@ -21,14 +21,14 @@ .. _verbs: Implement any of the following methods (collectively known as the - HTTP verb methods) to handle the corresponding HTTP method. - These methods can be made asynchronous with one of the following - decorators: `.gen.coroutine`, `.return_future`, or `asynchronous`. + HTTP verb methods) to handle the corresponding HTTP method. These + methods can be made asynchronous with the ``async def`` keyword or + `.gen.coroutine` decorator. The arguments to these methods come from the `.URLSpec`: Any capturing groups in the regular expression become arguments to the HTTP verb methods (keyword arguments if the group is named, - positional arguments if its unnamed). + positional arguments if it's unnamed). To support a method not on this list, override the class variable ``SUPPORTED_METHODS``:: @@ -50,11 +50,24 @@ Input ^^^^^ - .. automethod:: RequestHandler.get_argument + The ``argument`` methods provide support for HTML form-style + arguments. These methods are available in both singular and plural + forms because HTML forms are ambiguous and do not distinguish + between a singular argument and a list containing one entry. If you + wish to use other formats for arguments (for example, JSON), parse + ``self.request.body`` yourself:: + + def prepare(self): + if self.request.headers['Content-Type'] == 'application/x-json': + self.args = json_decode(self.request.body) + # Access self.args directly instead of using self.get_argument. + + + .. automethod:: RequestHandler.get_argument(name: str, default: Union[None, str, RAISE] = RAISE, strip: bool = True) -> Optional[str] .. automethod:: RequestHandler.get_arguments - .. automethod:: RequestHandler.get_query_argument + .. automethod:: RequestHandler.get_query_argument(name: str, default: Union[None, str, RAISE] = RAISE, strip: bool = True) -> Optional[str] .. automethod:: RequestHandler.get_query_arguments - .. automethod:: RequestHandler.get_body_argument + .. automethod:: RequestHandler.get_body_argument(name: str, default: Union[None, str, RAISE] = RAISE, strip: bool = True) -> Optional[str] .. automethod:: RequestHandler.get_body_arguments .. automethod:: RequestHandler.decode_argument .. attribute:: RequestHandler.request @@ -125,6 +138,7 @@ .. automethod:: RequestHandler.compute_etag .. automethod:: RequestHandler.create_template_loader .. autoattribute:: RequestHandler.current_user + .. automethod:: RequestHandler.detach .. automethod:: RequestHandler.get_browser_locale .. automethod:: RequestHandler.get_current_user .. 
automethod:: RequestHandler.get_login_url @@ -145,9 +159,9 @@ Application configuration - ----------------------------- - .. autoclass:: Application - :members: + ------------------------- + + .. autoclass:: Application(handlers: Optional[List[Union[Rule, Tuple]]] = None, default_host: Optional[str] = None, transforms: Optional[List[Type[OutputTransform]]] = None, **settings) .. attribute:: settings @@ -184,7 +198,7 @@ `RequestHandler` object). The default implementation writes to the `logging` module's root logger. May also be customized by overriding `Application.log_request`. - * ``serve_traceback``: If true, the default error page + * ``serve_traceback``: If ``True``, the default error page will include the traceback of the error. This option is new in Tornado 3.2; previously this functionality was controlled by the ``debug`` setting. @@ -211,7 +225,7 @@ * ``login_url``: The `authenticated` decorator will redirect to this url if the user is not logged in. Can be further customized by overriding `RequestHandler.get_login_url` - * ``xsrf_cookies``: If true, :ref:`xsrf` will be enabled. + * ``xsrf_cookies``: If ``True``, :ref:`xsrf` will be enabled. * ``xsrf_cookie_version``: Controls the version of new XSRF cookies produced by this server. Should generally be left at the default (which will always be the highest supported @@ -265,13 +279,18 @@ should be a dictionary of keyword arguments to be passed to the handler's ``initialize`` method. + .. automethod:: Application.listen + .. automethod:: Application.add_handlers(handlers: List[Union[Rule, Tuple]]) + .. automethod:: Application.get_handler_delegate + .. automethod:: Application.reverse_url + .. automethod:: Application.log_request + .. autoclass:: URLSpec The ``URLSpec`` class is also available under the name ``tornado.web.url``. Decorators ---------- - .. autofunction:: asynchronous .. autofunction:: authenticated .. autofunction:: addslash .. autofunction:: removeslash diff --git a/docs/websocket.rst b/docs/websocket.rst index 96255589be..76bc05227e 100644 --- a/docs/websocket.rst +++ b/docs/websocket.rst @@ -16,6 +16,7 @@ .. automethod:: WebSocketHandler.on_message .. automethod:: WebSocketHandler.on_close .. automethod:: WebSocketHandler.select_subprotocol + .. autoattribute:: WebSocketHandler.selected_subprotocol .. automethod:: WebSocketHandler.on_ping Output diff --git a/docs/wsgi.rst b/docs/wsgi.rst index a54b7aaae7..75d544aad0 100644 --- a/docs/wsgi.rst +++ b/docs/wsgi.rst @@ -3,17 +3,5 @@ .. automodule:: tornado.wsgi - Running Tornado apps on WSGI servers - ------------------------------------ - - .. autoclass:: WSGIAdapter - :members: - - .. autoclass:: WSGIApplication - :members: - - Running WSGI apps on Tornado servers - ------------------------------------ - .. autoclass:: WSGIContainer :members: diff --git a/maint/README b/maint/README index 9a9122b3b0..2ea722be39 100644 --- a/maint/README +++ b/maint/README @@ -1,3 +1,3 @@ This directory contains tools and scripts that are used in the development -and maintainance of Tornado itself, but are probably not of interest to +and maintenance of Tornado itself, but are probably not of interest to Tornado users. 
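A minimal application tying together the `RequestHandler` entry points and `Application` settings documented in ``docs/web.rst`` above (the handler, route, and settings values are illustrative, not taken from the docs):

```py
import tornado.ioloop
import tornado.web


class MainHandler(tornado.web.RequestHandler):
    def get(self):
        # HTML form-style argument with a default, as described above.
        name = self.get_argument("name", default="world")
        self.write("Hello, %s" % name)


def make_app():
    return tornado.web.Application(
        [tornado.web.url(r"/", MainHandler, name="main")],
        serve_traceback=True,  # include tracebacks on the default error page
        xsrf_cookies=True,     # enable XSRF protection
        login_url="/login",    # used by the @authenticated decorator
    )


if __name__ == "__main__":
    app = make_app()
    app.listen(8888)
    print(app.reverse_url("main"))  # -> "/"
    tornado.ioloop.IOLoop.current().start()
```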
diff --git a/demos/benchmark/benchmark.py b/maint/benchmark/benchmark.py similarity index 100% rename from demos/benchmark/benchmark.py rename to maint/benchmark/benchmark.py diff --git a/demos/benchmark/chunk_benchmark.py b/maint/benchmark/chunk_benchmark.py similarity index 100% rename from demos/benchmark/chunk_benchmark.py rename to maint/benchmark/chunk_benchmark.py diff --git a/demos/benchmark/gen_benchmark.py b/maint/benchmark/gen_benchmark.py similarity index 100% rename from demos/benchmark/gen_benchmark.py rename to maint/benchmark/gen_benchmark.py diff --git a/maint/benchmark/parsing_benchmark.py b/maint/benchmark/parsing_benchmark.py new file mode 100644 index 0000000000..d0bfcc8950 --- /dev/null +++ b/maint/benchmark/parsing_benchmark.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +import re +import timeit +from enum import Enum +from typing import Callable + +from tornado.httputil import HTTPHeaders +from tornado.options import define, options, parse_command_line + + +define("benchmark", type=str) +define("num_runs", type=int, default=1) + + +_CRLF_RE = re.compile(r"\r?\n") +_TEST_HEADERS = ( + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp," + "image/apng,*/*;q=0.8,application/signed-exchange;v=b3\r\n" + "Accept-Encoding: gzip, deflate, br\r\n" + "Accept-Language: ru-RU,ru;q=0.9,en-US;q=0.8,en;q=0.7\r\n" + "Cache-Control: max-age=0\r\n" + "Connection: keep-alive\r\n" + "Host: example.com\r\n" + "Upgrade-Insecure-Requests: 1\r\n" + "User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " + "(KHTML, like Gecko) Chrome/73.0.3683.103 Safari/537.36\r\n" +) + + +def headers_split_re(headers: str) -> None: + for line in _CRLF_RE.split(headers): + pass + + +def headers_split_simple(headers: str) -> None: + for line in headers.split("\n"): + if line.endswith("\r"): + line = line[:-1] + + +def headers_parse_re(headers: str) -> HTTPHeaders: + h = HTTPHeaders() + for line in _CRLF_RE.split(headers): + if line: + h.parse_line(line) + return h + + +def headers_parse_simple(headers: str) -> HTTPHeaders: + h = HTTPHeaders() + for line in headers.split("\n"): + if line.endswith("\r"): + line = line[:-1] + if line: + h.parse_line(line) + return h + + +def run_headers_split(): + regex_time = timeit.timeit(lambda: headers_split_re(_TEST_HEADERS), number=100000) + print("regex", regex_time) + + simple_time = timeit.timeit( + lambda: headers_split_simple(_TEST_HEADERS), number=100000 + ) + print("str.split", simple_time) + + print("speedup", regex_time / simple_time) + + +def run_headers_full(): + regex_time = timeit.timeit(lambda: headers_parse_re(_TEST_HEADERS), number=10000) + print("regex", regex_time) + + simple_time = timeit.timeit( + lambda: headers_parse_simple(_TEST_HEADERS), number=10000 + ) + print("str.split", simple_time) + + print("speedup", regex_time / simple_time) + + +class Benchmark(Enum): + def __new__(cls, arg_value: str, func: Callable[[], None]): + member = object.__new__(cls) + member._value_ = arg_value + member.func = func + return member + + HEADERS_SPLIT = ("headers-split", run_headers_split) + HEADERS_FULL = ("headers-full", run_headers_full) + + +def main(): + parse_command_line() + + try: + func = Benchmark(options.benchmark).func + except ValueError: + known_benchmarks = [benchmark.value for benchmark in Benchmark] + print( + "Unknown benchmark: '{}', supported values are: {}" + .format(options.benchmark, ", ".join(known_benchmarks)) + ) + return + + for _ in range(options.num_runs): + func() + + +if __name__ == 
'__main__': + main() diff --git a/demos/benchmark/template_benchmark.py b/maint/benchmark/template_benchmark.py similarity index 100% rename from demos/benchmark/template_benchmark.py rename to maint/benchmark/template_benchmark.py diff --git a/maint/circlerefs/circlerefs.py b/maint/circlerefs/circlerefs.py index 5cc4e1f6de..bd8214aa82 100755 --- a/maint/circlerefs/circlerefs.py +++ b/maint/circlerefs/circlerefs.py @@ -7,7 +7,6 @@ increases memory footprint and CPU overhead, so we try to eliminate circular references created by normal operation. """ -from __future__ import print_function import gc import traceback diff --git a/maint/requirements.in b/maint/requirements.in index eeb2f4d6a3..3f704d00cc 100644 --- a/maint/requirements.in +++ b/maint/requirements.in @@ -1,27 +1,13 @@ # Requirements for tools used in the development of tornado. -# This list is for python 3.5; for 2.7 add: -# - futures -# - mock # -# Use virtualenv instead of venv; tox seems to get confused otherwise. +# This mainly contains tools that should be installed for editor integration. +# Other tools we use are installed only via tox or CI scripts. # -# maint/requirements.txt contains the pinned versions of all direct and -# indirect dependencies; this file only contains direct dependencies -# and is useful for upgrading. +# This is a manual recreation of the lockfile pattern: maint/requirements.txt +# is the lockfile, and maint/requirements.in is the input file containing only +# direct dependencies. -# Tornado's optional dependencies -Twisted -pycares -pycurl - -# Other useful tools -Sphinx -autopep8 -coverage +black flake8 -pep8 -pyflakes -sphinx-rtd-theme +mypy tox -twine -virtualenv diff --git a/maint/requirements.txt b/maint/requirements.txt index eaa02ea57b..0ef29bc3bc 100644 --- a/maint/requirements.txt +++ b/maint/requirements.txt @@ -1,44 +1,34 @@ -alabaster==0.7.10 -attrs==17.4.0 -Automat==0.6.0 -autopep8==1.3.4 -Babel==2.5.3 -certifi==2018.1.18 -chardet==3.0.4 -constantly==15.1.0 -coverage==4.5.1 -docutils==0.14 -flake8==3.5.0 -hyperlink==18.0.0 -idna==2.6 -imagesize==1.0.0 -incremental==17.5.0 -Jinja2==2.10 -MarkupSafe==1.0 +# Requirements for tools used in the development of tornado. +# +# This mainly contains tools that should be installed for editor integration. +# Other tools we use are installed only via tox or CI scripts. +# This is a manual recreation of the lockfile pattern: maint/requirements.txt +# is the lockfile, and maint/requirements.in is the input file containing only +# direct dependencies. 
+ +black==20.8b1 +flake8==3.8.4 +mypy==0.790 +tox==3.20.1 +## The following requirements were added by pip freeze: +appdirs==1.4.4 +attrs==20.2.0 +click==7.1.2 +colorama==0.4.4 +distlib==0.3.1 +filelock==3.0.12 mccabe==0.6.1 -packaging==17.1 -pep8==1.7.1 -pkginfo==1.4.2 -pluggy==0.6.0 -py==1.5.2 -pycares==2.3.0 -pycodestyle==2.3.1 -pycurl==7.43.0.1 -pyflakes==1.6.0 -Pygments==2.2.0 -pyparsing==2.2.0 -pytz==2018.3 -requests==2.18.4 -requests-toolbelt==0.8.0 -six==1.11.0 -snowballstemmer==1.2.1 -Sphinx==1.7.1 -sphinx-rtd-theme==0.2.4 -sphinxcontrib-websupport==1.0.1 -tox==2.9.1 -tqdm==4.19.8 -twine==1.10.0 -Twisted==17.9.0 -urllib3==1.22 -virtualenv==15.1.0 -zope.interface==4.4.3 +mypy-extensions==0.4.3 +packaging==20.4 +pathspec==0.8.0 +pluggy==0.13.1 +py==1.9.0 +pycodestyle==2.6.0 +pyflakes==2.2.0 +pyparsing==2.4.7 +regex==2020.10.28 +six==1.15.0 +toml==0.10.1 +typed-ast==1.4.1 +typing-extensions==3.7.4.3 +virtualenv==20.1.0 diff --git a/maint/scripts/test_resolvers.py b/maint/scripts/test_resolvers.py index 2a466c1ac9..82dec30e66 100755 --- a/maint/scripts/test_resolvers.py +++ b/maint/scripts/test_resolvers.py @@ -1,6 +1,4 @@ #!/usr/bin/env python -from __future__ import print_function - import pprint import socket diff --git a/maint/test/appengine/README b/maint/test/appengine/README deleted file mode 100644 index 8d534f28d1..0000000000 --- a/maint/test/appengine/README +++ /dev/null @@ -1,8 +0,0 @@ -Unit test support for app engine. Currently very limited as most of -our tests depend on direct network access, but these tests ensure that the -modules that are supposed to work on app engine don't depend on any -forbidden modules. - -The code lives in maint/appengine/common, but should be run from the py25 -or py27 subdirectories (which contain an app.yaml and a bunch of symlinks). -runtests.py is the entry point; cgi_runtests.py is used internally. diff --git a/maint/test/appengine/common/cgi_runtests.py b/maint/test/appengine/common/cgi_runtests.py deleted file mode 100755 index f28aa6ca01..0000000000 --- a/maint/test/appengine/common/cgi_runtests.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python -from __future__ import absolute_import, division, print_function - -import sys -import unittest - -# Most of our tests depend on IOLoop, which is not usable on app engine. 
-# Run the tests that work, and check that everything else is at least -# importable (via tornado.test.import_test) -TEST_MODULES = [ - 'tornado.httputil.doctests', - 'tornado.iostream.doctests', - 'tornado.util.doctests', - #'tornado.test.auth_test', - #'tornado.test.concurrent_test', - #'tornado.test.curl_httpclient_test', - 'tornado.test.escape_test', - #'tornado.test.gen_test', - #'tornado.test.httpclient_test', - #'tornado.test.httpserver_test', - 'tornado.test.httputil_test', - 'tornado.test.import_test', - #'tornado.test.ioloop_test', - #'tornado.test.iostream_test', - 'tornado.test.locale_test', - #'tornado.test.netutil_test', - #'tornado.test.log_test', - 'tornado.test.options_test', - #'tornado.test.process_test', - #'tornado.test.simple_httpclient_test', - #'tornado.test.stack_context_test', - 'tornado.test.template_test', - #'tornado.test.testing_test', - #'tornado.test.twisted_test', - 'tornado.test.util_test', - #'tornado.test.web_test', - #'tornado.test.websocket_test', - #'tornado.test.wsgi_test', -] - - -def all(): - return unittest.defaultTestLoader.loadTestsFromNames(TEST_MODULES) - - -def main(): - print("Content-Type: text/plain\r\n\r\n", end="") - - try: - unittest.main(defaultTest='all', argv=sys.argv[:1]) - except SystemExit as e: - if e.code == 0: - print("PASS") - else: - raise - - -if __name__ == '__main__': - main() diff --git a/maint/test/appengine/common/runtests.py b/maint/test/appengine/common/runtests.py deleted file mode 100755 index ca7abe119f..0000000000 --- a/maint/test/appengine/common/runtests.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python -from __future__ import absolute_import, division, print_function - -import contextlib -import errno -import os -import random -import signal -import socket -import subprocess -import sys -import time -import urllib2 - -try: - xrange -except NameError: - xrange = range - -if __name__ == "__main__": - tornado_root = os.path.abspath(os.path.join(os.path.dirname(__file__), - '../../..')) - # dev_appserver doesn't seem to set SO_REUSEADDR - port = random.randrange(10000, 11000) - # does dev_appserver.py ever live anywhere but /usr/local/bin? - proc = subprocess.Popen([sys.executable, - "/usr/local/bin/dev_appserver.py", - os.path.dirname(os.path.abspath(__file__)), - "--port=%d" % port, - "--skip_sdk_update_check", - ], - cwd=tornado_root) - - try: - for i in xrange(50): - with contextlib.closing(socket.socket()) as sock: - err = sock.connect_ex(('localhost', port)) - if err == 0: - break - elif err != errno.ECONNREFUSED: - raise Exception("Got unexpected socket error %d" % err) - time.sleep(0.1) - else: - raise Exception("Server didn't start listening") - - resp = urllib2.urlopen("http://localhost:%d/" % port) - print(resp.read()) - finally: - # dev_appserver sometimes ignores SIGTERM (especially on 2.5), - # so try a few times to kill it. 
- for sig in [signal.SIGTERM, signal.SIGTERM, signal.SIGKILL]: - os.kill(proc.pid, sig) - res = os.waitpid(proc.pid, os.WNOHANG) - if res != (0, 0): - break - time.sleep(0.1) - else: - os.waitpid(proc.pid, 0) diff --git a/maint/test/appengine/py27/app.yaml b/maint/test/appengine/py27/app.yaml deleted file mode 100644 index e5dea072da..0000000000 --- a/maint/test/appengine/py27/app.yaml +++ /dev/null @@ -1,9 +0,0 @@ -application: tornado-tests-appengine27 -version: 1 -runtime: python27 -threadsafe: false -api_version: 1 - -handlers: -- url: / - script: cgi_runtests.py \ No newline at end of file diff --git a/maint/test/appengine/py27/cgi_runtests.py b/maint/test/appengine/py27/cgi_runtests.py deleted file mode 120000 index a9fc90e99c..0000000000 --- a/maint/test/appengine/py27/cgi_runtests.py +++ /dev/null @@ -1 +0,0 @@ -../common/cgi_runtests.py \ No newline at end of file diff --git a/maint/test/appengine/py27/runtests.py b/maint/test/appengine/py27/runtests.py deleted file mode 120000 index 2cce26b0fb..0000000000 --- a/maint/test/appengine/py27/runtests.py +++ /dev/null @@ -1 +0,0 @@ -../common/runtests.py \ No newline at end of file diff --git a/maint/test/appengine/py27/tornado b/maint/test/appengine/py27/tornado deleted file mode 120000 index d4f6cc317d..0000000000 --- a/maint/test/appengine/py27/tornado +++ /dev/null @@ -1 +0,0 @@ -../../../../tornado \ No newline at end of file diff --git a/maint/test/appengine/setup.py b/maint/test/appengine/setup.py deleted file mode 100644 index 5d2d3141d2..0000000000 --- a/maint/test/appengine/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -# Dummy setup file to make tox happy. In the appengine world things aren't -# installed through setup.py -import distutils.core -distutils.core.setup() diff --git a/maint/test/appengine/tox.ini b/maint/test/appengine/tox.ini deleted file mode 100644 index ca7a861aee..0000000000 --- a/maint/test/appengine/tox.ini +++ /dev/null @@ -1,15 +0,0 @@ -# App Engine tests require the SDK to be installed separately. -# Version 1.6.1 or newer is required (older versions don't work when -# python is run from a virtualenv) -# -# These are currently excluded from the main tox.ini because their -# logs are spammy and they're a little flaky. -[tox] -envlist = py27-appengine - -[testenv] -changedir = {toxworkdir} - -[testenv:py27-appengine] -basepython = python2.7 -commands = python {toxinidir}/py27/runtests.py {posargs:} diff --git a/maint/test/cython/tox.ini b/maint/test/cython/tox.ini index bbf8f15748..c79ab7db5e 100644 --- a/maint/test/cython/tox.ini +++ b/maint/test/cython/tox.ini @@ -1,6 +1,6 @@ [tox] # This currently segfaults on pypy. -envlist = py27,py35,py36 +envlist = py27,py36 [testenv] deps = @@ -13,5 +13,4 @@ commands = python -m unittest cythonapp_test # defaults for the others. 
basepython = py27: python2.7 - py35: python3.5 py36: python3.6 diff --git a/maint/test/mypy/.gitignore b/maint/test/mypy/.gitignore new file mode 100644 index 0000000000..dc3112749e --- /dev/null +++ b/maint/test/mypy/.gitignore @@ -0,0 +1 @@ +UNKNOWN.egg-info diff --git a/maint/test/mypy/bad.py b/maint/test/mypy/bad.py new file mode 100644 index 0000000000..3e6b6342e2 --- /dev/null +++ b/maint/test/mypy/bad.py @@ -0,0 +1,6 @@ +from tornado.web import RequestHandler + + +class MyHandler(RequestHandler): + def get(self) -> str: # Deliberate type error + return "foo" diff --git a/maint/test/mypy/good.py b/maint/test/mypy/good.py new file mode 100644 index 0000000000..5ee2d3ddcb --- /dev/null +++ b/maint/test/mypy/good.py @@ -0,0 +1,11 @@ +from tornado import gen +from tornado.web import RequestHandler + + +class MyHandler(RequestHandler): + def get(self) -> None: + self.write("foo") + + async def post(self) -> None: + await gen.sleep(1) + self.write("foo") diff --git a/maint/test/mypy/setup.py b/maint/test/mypy/setup.py new file mode 100644 index 0000000000..606849326a --- /dev/null +++ b/maint/test/mypy/setup.py @@ -0,0 +1,3 @@ +from setuptools import setup + +setup() diff --git a/maint/test/mypy/tox.ini b/maint/test/mypy/tox.ini new file mode 100644 index 0000000000..42235252d6 --- /dev/null +++ b/maint/test/mypy/tox.ini @@ -0,0 +1,14 @@ +# Test that the py.typed marker file is respected and client +# application code can be typechecked using tornado's published +# annotations. +[tox] +envlist = py37 + +[testenv] +deps = + ../../.. + mypy +whitelist_externals = /bin/sh +commands = + mypy good.py + /bin/sh -c '! mypy bad.py' diff --git a/maint/test/pyuv/tox.ini b/maint/test/pyuv/tox.ini deleted file mode 100644 index 8b6d569a7a..0000000000 --- a/maint/test/pyuv/tox.ini +++ /dev/null @@ -1,13 +0,0 @@ -[tox] -envlist = py27 -setupdir = ../../.. - -[testenv] -commands = - python -m tornado.test.runtests --ioloop=tornaduv.UVLoop {posargs:} -# twisted tests don't work on pyuv IOLoop currently. -deps = - pyuv - tornaduv - futures - mock diff --git a/maint/test/redbot/red_test.py b/maint/test/redbot/red_test.py index 055b479565..ac4b5ad25b 100755 --- a/maint/test/redbot/red_test.py +++ b/maint/test/redbot/red_test.py @@ -248,7 +248,7 @@ def get_app(self): return Application(self.get_handlers(), gzip=True, **self.get_app_kwargs()) def get_allowed_errors(self): - return super(GzipHTTPTest, self).get_allowed_errors() + [ + return super().get_allowed_errors() + [ # TODO: The Etag is supposed to change when Content-Encoding is # used. 
This should be fixed, but it's difficult to do with the # way GZipContentEncoding fits into the pipeline, and in practice diff --git a/maint/test/websocket/fuzzingclient.json b/maint/test/websocket/fuzzingclient.json index 7b4cb318cf..2ac091f37a 100644 --- a/maint/test/websocket/fuzzingclient.json +++ b/maint/test/websocket/fuzzingclient.json @@ -1,17 +1,42 @@ { - "options": {"failByDrop": false}, - "outdir": "./reports/servers", - - "servers": [ - {"agent": "Tornado/py27", "url": "ws://localhost:9001", - "options": {"version": 18}}, - {"agent": "Tornado/py35", "url": "ws://localhost:9002", - "options": {"version": 18}}, - {"agent": "Tornado/pypy", "url": "ws://localhost:9003", - "options": {"version": 18}} - ], - - "cases": ["*"], - "exclude-cases": ["9.*", "12.*.1","12.2.*", "12.3.*", "12.4.*", "12.5.*", "13.*.1"], - "exclude-agent-cases": {} -} + "options": { + "failByDrop": false + }, + "outdir": "./reports/servers", + "servers": [ + { + "agent": "Tornado/py27", + "url": "ws://localhost:9001", + "options": { + "version": 18 + } + }, + { + "agent": "Tornado/py39", + "url": "ws://localhost:9002", + "options": { + "version": 18 + } + }, + { + "agent": "Tornado/pypy", + "url": "ws://localhost:9003", + "options": { + "version": 18 + } + } + ], + "cases": [ + "*" + ], + "exclude-cases": [ + "9.*", + "12.*.1", + "12.2.*", + "12.3.*", + "12.4.*", + "12.5.*", + "13.*.1" + ], + "exclude-agent-cases": {} +} \ No newline at end of file diff --git a/maint/test/websocket/run-client.sh b/maint/test/websocket/run-client.sh index bd35f4dca8..f32e72aff9 100755 --- a/maint/test/websocket/run-client.sh +++ b/maint/test/websocket/run-client.sh @@ -10,7 +10,7 @@ FUZZING_SERVER_PID=$! sleep 1 .tox/py27/bin/python client.py --name='Tornado/py27' -.tox/py35/bin/python client.py --name='Tornado/py35' +.tox/py39/bin/python client.py --name='Tornado/py39' .tox/pypy/bin/python client.py --name='Tornado/pypy' kill $FUZZING_SERVER_PID diff --git a/maint/test/websocket/run-server.sh b/maint/test/websocket/run-server.sh index 2a83871366..401795a005 100755 --- a/maint/test/websocket/run-server.sh +++ b/maint/test/websocket/run-server.sh @@ -15,8 +15,8 @@ tox .tox/py27/bin/python server.py --port=9001 & PY27_SERVER_PID=$! -.tox/py35/bin/python server.py --port=9002 & -PY35_SERVER_PID=$! +.tox/py39/bin/python server.py --port=9002 & +PY39_SERVER_PID=$! .tox/pypy/bin/python server.py --port=9003 & PYPY_SERVER_PID=$! @@ -26,7 +26,7 @@ sleep 1 .tox/py27/bin/wstest -m fuzzingclient kill $PY27_SERVER_PID -kill $PY35_SERVER_PID +kill $PY39_SERVER_PID kill $PYPY_SERVER_PID wait diff --git a/maint/test/websocket/tox.ini b/maint/test/websocket/tox.ini index 289d127b10..7c4b72ebc6 100644 --- a/maint/test/websocket/tox.ini +++ b/maint/test/websocket/tox.ini @@ -2,7 +2,7 @@ # to install autobahn and build the speedups module. # See run.sh for the real test runner. [tox] -envlist = py27, py35, pypy +envlist = py27, py39, pypy setupdir=../../.. 
[testenv] diff --git a/maint/vm/windows/bootstrap.py b/maint/vm/windows/bootstrap.py index 281bf573cf..9bfb5c7230 100755 --- a/maint/vm/windows/bootstrap.py +++ b/maint/vm/windows/bootstrap.py @@ -21,7 +21,6 @@ To run under cygwin (which must be installed separately), run cd /cygdrive/e; python -m tornado.test.runtests """ -from __future__ import absolute_import, division, print_function import os import subprocess diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000..a365d2c33c --- /dev/null +++ b/setup.cfg @@ -0,0 +1,15 @@ +[metadata] +license_file = LICENSE + +[mypy] +python_version = 3.6 +no_implicit_optional = True + +[mypy-tornado.*,tornado.platform.*] +disallow_untyped_defs = True + +# It's generally too tedious to require type annotations in tests, but +# we do want to type check them as much as type inference allows. +[mypy-tornado.test.*] +disallow_untyped_defs = False +check_untyped_defs = True diff --git a/setup.py b/setup.py index be3e3d1b08..7c9b35c5af 100644 --- a/setup.py +++ b/setup.py @@ -13,9 +13,10 @@ # License for the specific language governing permissions and limitations # under the License. +# type: ignore + import os import platform -import ssl import sys import warnings @@ -80,11 +81,16 @@ def run(self): build_ext.run(self) except Exception: e = sys.exc_info()[1] - sys.stdout.write('%s\n' % str(e)) - warnings.warn(self.warning_message % ("Extension modules", - "There was an issue with " - "your platform configuration" - " - see above.")) + sys.stdout.write("%s\n" % str(e)) + warnings.warn( + self.warning_message + % ( + "Extension modules", + "There was an issue with " + "your platform configuration" + " - see above.", + ) + ) def build_extension(self, ext): name = ext.name @@ -92,60 +98,48 @@ def build_extension(self, ext): build_ext.build_extension(self, ext) except Exception: e = sys.exc_info()[1] - sys.stdout.write('%s\n' % str(e)) - warnings.warn(self.warning_message % ("The %s extension " - "module" % (name,), - "The output above " - "this warning shows how " - "the compilation " - "failed.")) + sys.stdout.write("%s\n" % str(e)) + warnings.warn( + self.warning_message + % ( + "The %s extension " "module" % (name,), + "The output above " + "this warning shows how " + "the compilation " + "failed.", + ) + ) kwargs = {} -version = "5.1.dev1" +with open("tornado/__init__.py") as f: + ns = {} + exec(f.read(), ns) + version = ns["version"] -with open('README.rst') as f: - kwargs['long_description'] = f.read() +with open("README.rst") as f: + kwargs["long_description"] = f.read() -if (platform.python_implementation() == 'CPython' and - os.environ.get('TORNADO_EXTENSION') != '0'): +if ( + platform.python_implementation() == "CPython" + and os.environ.get("TORNADO_EXTENSION") != "0" +): # This extension builds and works on pypy as well, although pypy's jit # produces equivalent performance. - kwargs['ext_modules'] = [ - Extension('tornado.speedups', - sources=['tornado/speedups.c']), + kwargs["ext_modules"] = [ + Extension("tornado.speedups", sources=["tornado/speedups.c"]) ] - if os.environ.get('TORNADO_EXTENSION') != '1': + if os.environ.get("TORNADO_EXTENSION") != "1": # Unless the user has specified that the extension is mandatory, # fall back to the pure-python implementation on any build failure. - kwargs['cmdclass'] = {'build_ext': custom_build_ext} + kwargs["cmdclass"] = {"build_ext": custom_build_ext} if setuptools is not None: - # If setuptools is not available, you're on your own for dependencies. 
- install_requires = [] - if sys.version_info < (3, 2): - install_requires.append('futures') - if sys.version_info < (3, 4): - install_requires.append('singledispatch') - if sys.version_info < (3, 5): - install_requires.append('backports_abc>=0.4') - kwargs['install_requires'] = install_requires - - python_requires = '>= 2.7, !=3.0.*, !=3.1.*, !=3.2.*, != 3.3.*' - kwargs['python_requires'] = python_requires - -# Verify that the SSL module has all the modern upgrades. Check for several -# names individually since they were introduced at different versions, -# although they should all be present by Python 3.4 or 2.7.9. -if (not hasattr(ssl, 'SSLContext') or - not hasattr(ssl, 'create_default_context') or - not hasattr(ssl, 'match_hostname')): - raise ImportError("Tornado requires an up-to-date SSL module. This means " - "Python 2.7.9+ or 3.4+ (although some distributions have " - "backported the necessary changes to older versions).") + python_requires = ">= 3.6" + kwargs["python_requires"] = python_requires setup( name="tornado", @@ -155,12 +149,15 @@ def build_extension(self, ext): # data files need to be listed both here (which determines what gets # installed) and in MANIFEST.in (which determines what gets included # in the sdist tarball) + "tornado": ["py.typed"], "tornado.test": [ "README", "csv_translations/fr_FR.csv", "gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo", "gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po", "options_test.cfg", + "options_test_types.cfg", + "options_test_types_str.cfg", "static/robots.txt", "static/sample.xml", "static/sample.xml.gz", @@ -176,18 +173,19 @@ def build_extension(self, ext): author_email="python-tornado@googlegroups.com", url="http://www.tornadoweb.org/", license="http://www.apache.org/licenses/LICENSE-2.0", - description=("Tornado is a Python web framework and asynchronous networking library," - " originally developed at FriendFeed."), + description=( + "Tornado is a Python web framework and asynchronous networking library," + " originally developed at FriendFeed." + ), classifiers=[ - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", ], **kwargs ) diff --git a/tornado/__init__.py b/tornado/__init__.py index 9a9d2202ef..7c889e2d5a 100644 --- a/tornado/__init__.py +++ b/tornado/__init__.py @@ -15,8 +15,6 @@ """The Tornado web server and tools.""" -from __future__ import absolute_import, division, print_function - # version is a human-readable version number. # version_info is a four-tuple for programmatic comparison. 
The first @@ -24,5 +22,5 @@ # is zero for an official release, positive for a development branch, # or negative for a release candidate or beta (after the base version # number has been incremented) -version = "5.1.dev1" -version_info = (5, 1, 0, -100) +version = "6.2.dev1" +version_info = (6, 2, 0, -100) diff --git a/tornado/_locale_data.py b/tornado/_locale_data.py index a2c503907d..c706230ee5 100644 --- a/tornado/_locale_data.py +++ b/tornado/_locale_data.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Copyright 2012 Facebook # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -16,8 +14,6 @@ """Data used by the tornado.locale module.""" -from __future__ import absolute_import, division, print_function - LOCALE_NAMES = { "af_ZA": {"name_en": u"Afrikaans", "name": u"Afrikaans"}, "am_ET": {"name_en": u"Amharic", "name": u"አማርኛ"}, diff --git a/tornado/auth.py b/tornado/auth.py index 0069efcb48..d1cf29b39d 100644 --- a/tornado/auth.py +++ b/tornado/auth.py @@ -37,15 +37,14 @@ class GoogleOAuth2LoginHandler(tornado.web.RequestHandler, tornado.auth.GoogleOAuth2Mixin): - @tornado.gen.coroutine - def get(self): + async def get(self): if self.get_argument('code', False): - user = yield self.get_authenticated_user( + user = await self.get_authenticated_user( redirect_uri='http://your.site.com/auth/google', code=self.get_argument('code')) # Save the user with e.g. set_secure_cookie else: - yield self.authorize_redirect( + self.authorize_redirect( redirect_uri='http://your.site.com/auth/google', client_id=self.settings['google_oauth']['key'], scope=['profile', 'email'], @@ -55,93 +54,29 @@ def get(self): .. testoutput:: :hide: - -.. versionchanged:: 4.0 - All of the callback interfaces in this module are now guaranteed - to run their callback with an argument of ``None`` on error. - Previously some functions would do this while others would simply - terminate the request on their own. This change also ensures that - errors are more consistently reported through the ``Future`` interfaces. """ -from __future__ import absolute_import, division, print_function - import base64 import binascii -import functools import hashlib import hmac import time +import urllib.parse import uuid -import warnings -from tornado.concurrent import (Future, return_future, chain_future, - future_set_exc_info, - future_set_result_unless_cancelled) -from tornado import gen from tornado import httpclient from tornado import escape from tornado.httputil import url_concat -from tornado.log import gen_log -from tornado.stack_context import ExceptionStackContext -from tornado.util import unicode_type, ArgReplacer, PY3 +from tornado.util import unicode_type +from tornado.web import RequestHandler -if PY3: - import urllib.parse as urlparse - import urllib.parse as urllib_parse - long = int -else: - import urlparse - import urllib as urllib_parse +from typing import List, Any, Dict, cast, Iterable, Union, Optional class AuthError(Exception): pass -def _auth_future_to_callback(callback, future): - try: - result = future.result() - except AuthError as e: - gen_log.warning(str(e)) - result = None - callback(result) - - -def _auth_return_future(f): - """Similar to tornado.concurrent.return_future, but uses the auth - module's legacy callback interface. - - Note that when using this decorator the ``callback`` parameter - inside the function will actually be a future. - - .. deprecated:: 5.1 - Will be removed in 6.0. 
- """ - replacer = ArgReplacer(f, 'callback') - - @functools.wraps(f) - def wrapper(*args, **kwargs): - future = Future() - callback, args, kwargs = replacer.replace(future, args, kwargs) - if callback is not None: - warnings.warn("callback arguments are deprecated, use the returned Future instead", - DeprecationWarning) - future.add_done_callback( - functools.partial(_auth_future_to_callback, callback)) - - def handle_exception(typ, value, tb): - if future.done(): - return False - else: - future_set_exc_info(future, (typ, value, tb)) - return True - with ExceptionStackContext(handle_exception): - f(*args, **kwargs) - return future - return wrapper - - class OpenIdMixin(object): """Abstract implementation of OpenID and Attribute Exchange. @@ -149,10 +84,12 @@ class OpenIdMixin(object): * ``_OPENID_ENDPOINT``: the identity provider's URI. """ - @return_future - def authenticate_redirect(self, callback_uri=None, - ax_attrs=["name", "email", "language", "username"], - callback=None): + + def authenticate_redirect( + self, + callback_uri: Optional[str] = None, + ax_attrs: List[str] = ["name", "email", "language", "username"], + ) -> None: """Redirects to the authentication URL for this service. After authentication, the service will redirect back to the given @@ -163,24 +100,22 @@ def authenticate_redirect(self, callback_uri=None, all those attributes for your app, you can request fewer with the ax_attrs keyword argument. - .. versionchanged:: 3.1 - Returns a `.Future` and takes an optional callback. These are - not strictly necessary as this method is synchronous, - but they are supplied for consistency with - `OAuthMixin.authorize_redirect`. - - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument and returned awaitable will be removed - in Tornado 6.0; this will be an ordinary synchronous function. + The ``callback`` argument was removed and this method no + longer returns an awaitable object. It is now an ordinary + synchronous function. """ - callback_uri = callback_uri or self.request.uri + handler = cast(RequestHandler, self) + callback_uri = callback_uri or handler.request.uri + assert callback_uri is not None args = self._openid_args(callback_uri, ax_attrs=ax_attrs) - self.redirect(self._OPENID_ENDPOINT + "?" + urllib_parse.urlencode(args)) - callback() + endpoint = self._OPENID_ENDPOINT # type: ignore + handler.redirect(endpoint + "?" + urllib.parse.urlencode(args)) - @_auth_return_future - def get_authenticated_user(self, callback, http_client=None): + async def get_authenticated_user( + self, http_client: Optional[httpclient.AsyncHTTPClient] = None + ) -> Dict[str, Any]: """Fetches the authenticated user data upon redirect. This method should be called by the handler that receives the @@ -191,51 +126,60 @@ def get_authenticated_user(self, callback, http_client=None): The result of this method will generally be used to set a cookie. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed in 6.0. - Use the returned awaitable object instead. + The ``callback`` argument was removed. Use the returned + awaitable object instead. 
""" + handler = cast(RequestHandler, self) # Verify the OpenID response via direct request to the OP - args = dict((k, v[-1]) for k, v in self.request.arguments.items()) + args = dict( + (k, v[-1]) for k, v in handler.request.arguments.items() + ) # type: Dict[str, Union[str, bytes]] args["openid.mode"] = u"check_authentication" - url = self._OPENID_ENDPOINT + url = self._OPENID_ENDPOINT # type: ignore if http_client is None: http_client = self.get_auth_http_client() - fut = http_client.fetch(url, method="POST", body=urllib_parse.urlencode(args)) - fut.add_done_callback(functools.partial( - self._on_authentication_verified, callback)) - - def _openid_args(self, callback_uri, ax_attrs=[], oauth_scope=None): - url = urlparse.urljoin(self.request.full_url(), callback_uri) + resp = await http_client.fetch( + url, method="POST", body=urllib.parse.urlencode(args) + ) + return self._on_authentication_verified(resp) + + def _openid_args( + self, + callback_uri: str, + ax_attrs: Iterable[str] = [], + oauth_scope: Optional[str] = None, + ) -> Dict[str, str]: + handler = cast(RequestHandler, self) + url = urllib.parse.urljoin(handler.request.full_url(), callback_uri) args = { "openid.ns": "http://specs.openid.net/auth/2.0", - "openid.claimed_id": - "http://specs.openid.net/auth/2.0/identifier_select", - "openid.identity": - "http://specs.openid.net/auth/2.0/identifier_select", + "openid.claimed_id": "http://specs.openid.net/auth/2.0/identifier_select", + "openid.identity": "http://specs.openid.net/auth/2.0/identifier_select", "openid.return_to": url, - "openid.realm": urlparse.urljoin(url, '/'), + "openid.realm": urllib.parse.urljoin(url, "/"), "openid.mode": "checkid_setup", } if ax_attrs: - args.update({ - "openid.ns.ax": "http://openid.net/srv/ax/1.0", - "openid.ax.mode": "fetch_request", - }) + args.update( + { + "openid.ns.ax": "http://openid.net/srv/ax/1.0", + "openid.ax.mode": "fetch_request", + } + ) ax_attrs = set(ax_attrs) - required = [] + required = [] # type: List[str] if "name" in ax_attrs: ax_attrs -= set(["name", "firstname", "fullname", "lastname"]) required += ["firstname", "fullname", "lastname"] - args.update({ - "openid.ax.type.firstname": - "http://axschema.org/namePerson/first", - "openid.ax.type.fullname": - "http://axschema.org/namePerson", - "openid.ax.type.lastname": - "http://axschema.org/namePerson/last", - }) + args.update( + { + "openid.ax.type.firstname": "http://axschema.org/namePerson/first", + "openid.ax.type.fullname": "http://axschema.org/namePerson", + "openid.ax.type.lastname": "http://axschema.org/namePerson/last", + } + ) known_attrs = { "email": "http://axschema.org/contact/email", "language": "http://axschema.org/pref/language", @@ -246,47 +190,45 @@ def _openid_args(self, callback_uri, ax_attrs=[], oauth_scope=None): required.append(name) args["openid.ax.required"] = ",".join(required) if oauth_scope: - args.update({ - "openid.ns.oauth": - "http://specs.openid.net/extensions/oauth/1.0", - "openid.oauth.consumer": self.request.host.split(":")[0], - "openid.oauth.scope": oauth_scope, - }) + args.update( + { + "openid.ns.oauth": "http://specs.openid.net/extensions/oauth/1.0", + "openid.oauth.consumer": handler.request.host.split(":")[0], + "openid.oauth.scope": oauth_scope, + } + ) return args - def _on_authentication_verified(self, future, response_fut): - try: - response = response_fut.result() - except Exception as e: - future.set_exception(AuthError( - "Error response %s" % e)) - return + def _on_authentication_verified( + self, response: 
httpclient.HTTPResponse + ) -> Dict[str, Any]: + handler = cast(RequestHandler, self) if b"is_valid:true" not in response.body: - future.set_exception(AuthError( - "Invalid OpenID response: %s" % response.body)) - return + raise AuthError("Invalid OpenID response: %r" % response.body) # Make sure we got back at least an email from attribute exchange ax_ns = None - for name in self.request.arguments: - if name.startswith("openid.ns.") and \ - self.get_argument(name) == u"http://openid.net/srv/ax/1.0": - ax_ns = name[10:] + for key in handler.request.arguments: + if ( + key.startswith("openid.ns.") + and handler.get_argument(key) == u"http://openid.net/srv/ax/1.0" + ): + ax_ns = key[10:] break - def get_ax_arg(uri): + def get_ax_arg(uri: str) -> str: if not ax_ns: return u"" prefix = "openid." + ax_ns + ".type." ax_name = None - for name in self.request.arguments.keys(): - if self.get_argument(name) == uri and name.startswith(prefix): - part = name[len(prefix):] + for name in handler.request.arguments.keys(): + if handler.get_argument(name) == uri and name.startswith(prefix): + part = name[len(prefix) :] ax_name = "openid." + ax_ns + ".value." + part break if not ax_name: return u"" - return self.get_argument(ax_name, u"") + return handler.get_argument(ax_name, u"") email = get_ax_arg("http://axschema.org/contact/email") name = get_ax_arg("http://axschema.org/namePerson") @@ -314,12 +256,12 @@ def get_ax_arg(uri): user["locale"] = locale if username: user["username"] = username - claimed_id = self.get_argument("openid.claimed_id", None) + claimed_id = handler.get_argument("openid.claimed_id", None) if claimed_id: user["claimed_id"] = claimed_id - future_set_result_unless_cancelled(future, user) + return user - def get_auth_http_client(self): + def get_auth_http_client(self) -> httpclient.AsyncHTTPClient: """Returns the `.AsyncHTTPClient` instance to be used for auth requests. May be overridden by subclasses to use an HTTP client other than @@ -344,9 +286,13 @@ class OAuthMixin(object): Subclasses must also override the `_oauth_get_user_future` and `_oauth_consumer_token` methods. """ - @return_future - def authorize_redirect(self, callback_uri=None, extra_params=None, - http_client=None, callback=None): + + async def authorize_redirect( + self, + callback_uri: Optional[str] = None, + extra_params: Optional[Dict[str, Any]] = None, + http_client: Optional[httpclient.AsyncHTTPClient] = None, + ) -> None: """Redirects the user to obtain OAuth authorization for this service. The ``callback_uri`` may be omitted if you have previously @@ -368,35 +314,31 @@ def authorize_redirect(self, callback_uri=None, extra_params=None, Now returns a `.Future` and takes an optional callback, for compatibility with `.gen.coroutine`. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed in 6.0. - Use the returned awaitable object instead. + The ``callback`` argument was removed. Use the returned + awaitable object instead. 
""" if callback_uri and getattr(self, "_OAUTH_NO_CALLBACKS", False): raise Exception("This service does not support oauth_callback") if http_client is None: http_client = self.get_auth_http_client() + assert http_client is not None if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a": - fut = http_client.fetch( - self._oauth_request_token_url(callback_uri=callback_uri, - extra_params=extra_params)) - fut.add_done_callback(functools.partial( - self._on_request_token, - self._OAUTH_AUTHORIZE_URL, - callback_uri, - callback)) + response = await http_client.fetch( + self._oauth_request_token_url( + callback_uri=callback_uri, extra_params=extra_params + ) + ) else: - fut = http_client.fetch(self._oauth_request_token_url()) - fut.add_done_callback( - functools.partial( - self._on_request_token, self._OAUTH_AUTHORIZE_URL, - callback_uri, - callback)) - - @_auth_return_future - def get_authenticated_user(self, callback, http_client=None): + response = await http_client.fetch(self._oauth_request_token_url()) + url = self._OAUTH_AUTHORIZE_URL # type: ignore + self._on_request_token(url, callback_uri, response) + + async def get_authenticated_user( + self, http_client: Optional[httpclient.AsyncHTTPClient] = None + ) -> Dict[str, Any]: """Gets the OAuth authorized user and access token. This method should be called from the handler for your @@ -407,37 +349,47 @@ def get_authenticated_user(self, callback, http_client=None): also contain other fields such as ``name``, depending on the service used. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed in 6.0. - Use the returned awaitable object instead. + The ``callback`` argument was removed. Use the returned + awaitable object instead. """ - future = callback - request_key = escape.utf8(self.get_argument("oauth_token")) - oauth_verifier = self.get_argument("oauth_verifier", None) - request_cookie = self.get_cookie("_oauth_request_token") + handler = cast(RequestHandler, self) + request_key = escape.utf8(handler.get_argument("oauth_token")) + oauth_verifier = handler.get_argument("oauth_verifier", None) + request_cookie = handler.get_cookie("_oauth_request_token") if not request_cookie: - future.set_exception(AuthError( - "Missing OAuth request token cookie")) - return - self.clear_cookie("_oauth_request_token") + raise AuthError("Missing OAuth request token cookie") + handler.clear_cookie("_oauth_request_token") cookie_key, cookie_secret = [ - base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")] + base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|") + ] if cookie_key != request_key: - future.set_exception(AuthError( - "Request token does not match cookie")) - return - token = dict(key=cookie_key, secret=cookie_secret) + raise AuthError("Request token does not match cookie") + token = dict( + key=cookie_key, secret=cookie_secret + ) # type: Dict[str, Union[str, bytes]] if oauth_verifier: token["verifier"] = oauth_verifier if http_client is None: http_client = self.get_auth_http_client() - fut = http_client.fetch(self._oauth_access_token_url(token)) - fut.add_done_callback(functools.partial(self._on_access_token, callback)) - - def _oauth_request_token_url(self, callback_uri=None, extra_params=None): + assert http_client is not None + response = await http_client.fetch(self._oauth_access_token_url(token)) + access_token = _oauth_parse_response(response.body) + user = await self._oauth_get_user_future(access_token) + if not user: + raise AuthError("Error getting user") 
+ user["access_token"] = access_token + return user + + def _oauth_request_token_url( + self, + callback_uri: Optional[str] = None, + extra_params: Optional[Dict[str, Any]] = None, + ) -> str: + handler = cast(RequestHandler, self) consumer_token = self._oauth_consumer_token() - url = self._OAUTH_REQUEST_TOKEN_URL + url = self._OAUTH_REQUEST_TOKEN_URL # type: ignore args = dict( oauth_consumer_key=escape.to_basestring(consumer_token["key"]), oauth_signature_method="HMAC-SHA1", @@ -449,8 +401,9 @@ def _oauth_request_token_url(self, callback_uri=None, extra_params=None): if callback_uri == "oob": args["oauth_callback"] = "oob" elif callback_uri: - args["oauth_callback"] = urlparse.urljoin( - self.request.full_url(), callback_uri) + args["oauth_callback"] = urllib.parse.urljoin( + handler.request.full_url(), callback_uri + ) if extra_params: args.update(extra_params) signature = _oauth10a_signature(consumer_token, "GET", url, args) @@ -458,32 +411,35 @@ def _oauth_request_token_url(self, callback_uri=None, extra_params=None): signature = _oauth_signature(consumer_token, "GET", url, args) args["oauth_signature"] = signature - return url + "?" + urllib_parse.urlencode(args) - - def _on_request_token(self, authorize_url, callback_uri, callback, - response_fut): - try: - response = response_fut.result() - except Exception as e: - raise Exception("Could not get request token: %s" % e) + return url + "?" + urllib.parse.urlencode(args) + + def _on_request_token( + self, + authorize_url: str, + callback_uri: Optional[str], + response: httpclient.HTTPResponse, + ) -> None: + handler = cast(RequestHandler, self) request_token = _oauth_parse_response(response.body) - data = (base64.b64encode(escape.utf8(request_token["key"])) + b"|" + - base64.b64encode(escape.utf8(request_token["secret"]))) - self.set_cookie("_oauth_request_token", data) + data = ( + base64.b64encode(escape.utf8(request_token["key"])) + + b"|" + + base64.b64encode(escape.utf8(request_token["secret"])) + ) + handler.set_cookie("_oauth_request_token", data) args = dict(oauth_token=request_token["key"]) if callback_uri == "oob": - self.finish(authorize_url + "?" + urllib_parse.urlencode(args)) - callback() + handler.finish(authorize_url + "?" + urllib.parse.urlencode(args)) return elif callback_uri: - args["oauth_callback"] = urlparse.urljoin( - self.request.full_url(), callback_uri) - self.redirect(authorize_url + "?" + urllib_parse.urlencode(args)) - callback() + args["oauth_callback"] = urllib.parse.urljoin( + handler.request.full_url(), callback_uri + ) + handler.redirect(authorize_url + "?" 
+ urllib.parse.urlencode(args)) - def _oauth_access_token_url(self, request_token): + def _oauth_access_token_url(self, request_token: Dict[str, Any]) -> str: consumer_token = self._oauth_consumer_token() - url = self._OAUTH_ACCESS_TOKEN_URL + url = self._OAUTH_ACCESS_TOKEN_URL # type: ignore args = dict( oauth_consumer_key=escape.to_basestring(consumer_token["key"]), oauth_token=escape.to_basestring(request_token["key"]), @@ -496,41 +452,31 @@ def _oauth_access_token_url(self, request_token): args["oauth_verifier"] = request_token["verifier"] if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a": - signature = _oauth10a_signature(consumer_token, "GET", url, args, - request_token) + signature = _oauth10a_signature( + consumer_token, "GET", url, args, request_token + ) else: - signature = _oauth_signature(consumer_token, "GET", url, args, - request_token) + signature = _oauth_signature( + consumer_token, "GET", url, args, request_token + ) args["oauth_signature"] = signature - return url + "?" + urllib_parse.urlencode(args) - - def _on_access_token(self, future, response_fut): - try: - response = response_fut.result() - except Exception: - future.set_exception(AuthError("Could not fetch access token")) - return + return url + "?" + urllib.parse.urlencode(args) - access_token = _oauth_parse_response(response.body) - fut = self._oauth_get_user_future(access_token) - fut = gen.convert_yielded(fut) - fut.add_done_callback( - functools.partial(self._on_oauth_get_user, access_token, future)) - - def _oauth_consumer_token(self): + def _oauth_consumer_token(self) -> Dict[str, Any]: """Subclasses must override this to return their OAuth consumer keys. The return value should be a `dict` with keys ``key`` and ``secret``. """ raise NotImplementedError() - @return_future - def _oauth_get_user_future(self, access_token, callback): + async def _oauth_get_user_future( + self, access_token: Dict[str, Any] + ) -> Dict[str, Any]: """Subclasses must override this to get basic information about the user. - Should return a `.Future` whose result is a dictionary + Should be a coroutine whose result is a dictionary containing information about the user, which may have been retrieved by using ``access_token`` to make a request to the service. @@ -538,40 +484,23 @@ def _oauth_get_user_future(self, access_token, callback): The access token will be added to the returned dictionary to make the result of `get_authenticated_user`. - For backwards compatibility, the callback-based ``_oauth_get_user`` - method is also supported. - .. versionchanged:: 5.1 Subclasses may also define this method with ``async def``. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``_oauth_get_user`` fallback is deprecated and support for it - will be removed in 6.0. + A synchronous fallback to ``_oauth_get_user`` was removed. """ - warnings.warn("_oauth_get_user is deprecated, override _oauth_get_user_future instead", - DeprecationWarning) - # By default, call the old-style _oauth_get_user, but new code - # should override this method instead. 
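With the synchronous ``_oauth_get_user`` fallback gone, a subclass now supplies a consumer token and a coroutine, as ``TwitterMixin`` does further down. A minimal sketch with invented endpoints and placeholder credentials:

```py
from typing import Any, Dict

from tornado.auth import OAuthMixin


class ExampleServiceMixin(OAuthMixin):
    # Invented URLs; a real mixin points these at its provider.
    _OAUTH_REQUEST_TOKEN_URL = "https://example.com/oauth/request_token"
    _OAUTH_ACCESS_TOKEN_URL = "https://example.com/oauth/access_token"
    _OAUTH_AUTHORIZE_URL = "https://example.com/oauth/authorize"

    def _oauth_consumer_token(self) -> Dict[str, Any]:
        # Placeholder credentials for illustration only.
        return dict(key="my-consumer-key", secret="my-consumer-secret")

    async def _oauth_get_user_future(
        self, access_token: Dict[str, Any]
    ) -> Dict[str, Any]:
        # Must be a coroutine in 6.0; typically fetches a profile
        # (e.g. via self._oauth_request_parameters()) and returns it.
        return {"username": "example-user"}
```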
- self._oauth_get_user(access_token, callback) - - def _oauth_get_user(self, access_token, callback): raise NotImplementedError() - def _on_oauth_get_user(self, access_token, future, user_future): - if user_future.exception() is not None: - future.set_exception(user_future.exception()) - return - user = user_future.result() - if not user: - future.set_exception(AuthError("Error getting user")) - return - user["access_token"] = access_token - future_set_result_unless_cancelled(future, user) - - def _oauth_request_parameters(self, url, access_token, parameters={}, - method="GET"): + def _oauth_request_parameters( + self, + url: str, + access_token: Dict[str, Any], + parameters: Dict[str, Any] = {}, + method: str = "GET", + ) -> Dict[str, Any]: """Returns the OAuth parameters as a dict for the given request. parameters should include all POST arguments and query string arguments @@ -590,15 +519,17 @@ def _oauth_request_parameters(self, url, access_token, parameters={}, args.update(base_args) args.update(parameters) if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a": - signature = _oauth10a_signature(consumer_token, method, url, args, - access_token) + signature = _oauth10a_signature( + consumer_token, method, url, args, access_token + ) else: - signature = _oauth_signature(consumer_token, method, url, args, - access_token) + signature = _oauth_signature( + consumer_token, method, url, args, access_token + ) base_args["oauth_signature"] = escape.to_basestring(signature) return base_args - def get_auth_http_client(self): + def get_auth_http_client(self) -> httpclient.AsyncHTTPClient: """Returns the `.AsyncHTTPClient` instance to be used for auth requests. May be overridden by subclasses to use an HTTP client other than @@ -618,10 +549,16 @@ class OAuth2Mixin(object): * ``_OAUTH_AUTHORIZE_URL``: The service's authorization url. * ``_OAUTH_ACCESS_TOKEN_URL``: The service's access token url. """ - @return_future - def authorize_redirect(self, redirect_uri=None, client_id=None, - client_secret=None, extra_params=None, - callback=None, scope=None, response_type="code"): + + def authorize_redirect( + self, + redirect_uri: Optional[str] = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + extra_params: Optional[Dict[str, Any]] = None, + scope: Optional[List[str]] = None, + response_type: str = "code", + ) -> None: """Redirects the user to obtain OAuth authorization for this service. Some providers require that you register a redirect URL with @@ -630,47 +567,53 @@ def authorize_redirect(self, redirect_uri=None, client_id=None, ``get_authenticated_user`` in the handler for your redirect URL to complete the authorization process. - .. versionchanged:: 3.1 - Returns a `.Future` and takes an optional callback. These are - not strictly necessary as this method is synchronous, - but they are supplied for consistency with - `OAuthMixin.authorize_redirect`. - - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument and returned awaitable will be removed - in Tornado 6.0; this will be an ordinary synchronous function. + The ``callback`` argument and returned awaitable were removed; + this is now an ordinary synchronous function. 
""" - args = { - "redirect_uri": redirect_uri, - "client_id": client_id, - "response_type": response_type - } + handler = cast(RequestHandler, self) + args = {"response_type": response_type} + if redirect_uri is not None: + args["redirect_uri"] = redirect_uri + if client_id is not None: + args["client_id"] = client_id if extra_params: args.update(extra_params) if scope: - args['scope'] = ' '.join(scope) - self.redirect( - url_concat(self._OAUTH_AUTHORIZE_URL, args)) - callback() - - def _oauth_request_token_url(self, redirect_uri=None, client_id=None, - client_secret=None, code=None, - extra_params=None): - url = self._OAUTH_ACCESS_TOKEN_URL - args = dict( - redirect_uri=redirect_uri, - code=code, - client_id=client_id, - client_secret=client_secret, - ) + args["scope"] = " ".join(scope) + url = self._OAUTH_AUTHORIZE_URL # type: ignore + handler.redirect(url_concat(url, args)) + + def _oauth_request_token_url( + self, + redirect_uri: Optional[str] = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + code: Optional[str] = None, + extra_params: Optional[Dict[str, Any]] = None, + ) -> str: + url = self._OAUTH_ACCESS_TOKEN_URL # type: ignore + args = {} # type: Dict[str, str] + if redirect_uri is not None: + args["redirect_uri"] = redirect_uri + if code is not None: + args["code"] = code + if client_id is not None: + args["client_id"] = client_id + if client_secret is not None: + args["client_secret"] = client_secret if extra_params: args.update(extra_params) return url_concat(url, args) - @_auth_return_future - def oauth2_request(self, url, callback, access_token=None, - post_args=None, **args): + async def oauth2_request( + self, + url: str, + access_token: Optional[str] = None, + post_args: Optional[Dict[str, Any]] = None, + **args: Any + ) -> Any: """Fetches the given URL auth an OAuth2 access token. If the request is a POST, ``post_args`` should be provided. Query @@ -683,16 +626,15 @@ def oauth2_request(self, url, callback, access_token=None, class MainHandler(tornado.web.RequestHandler, tornado.auth.FacebookGraphMixin): @tornado.web.authenticated - @tornado.gen.coroutine - def get(self): - new_entry = yield self.oauth2_request( + async def get(self): + new_entry = await self.oauth2_request( "https://graph.facebook.com/me/feed", post_args={"message": "I am posting from my Tornado application!"}, access_token=self.current_user["access_token"]) if not new_entry: # Call failed; perhaps missing permission? - yield self.authorize_redirect() + self.authorize_redirect() return self.finish("Posted a message!") @@ -701,10 +643,9 @@ def get(self): .. versionadded:: 4.3 - .. deprecated:: 5.1 + .. versionchanged::: 6.0 - The ``callback`` argument is deprecated and will be removed in 6.0. - Use the returned awaitable object instead. + The ``callback`` argument was removed. Use the returned awaitable object instead. """ all_args = {} if access_token: @@ -712,25 +653,17 @@ def get(self): all_args.update(args) if all_args: - url += "?" + urllib_parse.urlencode(all_args) - callback = functools.partial(self._on_oauth2_request, callback) + url += "?" 
+ urllib.parse.urlencode(all_args) http = self.get_auth_http_client() if post_args is not None: - fut = http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args)) + response = await http.fetch( + url, method="POST", body=urllib.parse.urlencode(post_args) + ) else: - fut = http.fetch(url) - fut.add_done_callback(callback) - - def _on_oauth2_request(self, future, response_fut): - try: - response = response_fut.result() - except Exception as e: - future.set_exception(AuthError("Error response %s" % e)) - return + response = await http.fetch(url) + return escape.json_decode(response.body) - future_set_result_unless_cancelled(future, escape.json_decode(response.body)) - - def get_auth_http_client(self): + def get_auth_http_client(self) -> httpclient.AsyncHTTPClient: """Returns the `.AsyncHTTPClient` instance to be used for auth requests. May be overridden by subclasses to use an HTTP client other than @@ -758,13 +691,12 @@ class TwitterMixin(OAuthMixin): class TwitterLoginHandler(tornado.web.RequestHandler, tornado.auth.TwitterMixin): - @tornado.gen.coroutine - def get(self): + async def get(self): if self.get_argument("oauth_token", None): - user = yield self.get_authenticated_user() + user = await self.get_authenticated_user() # Save the user using e.g. set_secure_cookie() else: - yield self.authorize_redirect() + await self.authorize_redirect() .. testoutput:: :hide: @@ -774,6 +706,7 @@ def get(self): and all of the custom Twitter user attributes described at https://dev.twitter.com/docs/api/1.1/get/users/show """ + _OAUTH_REQUEST_TOKEN_URL = "https://api.twitter.com/oauth/request_token" _OAUTH_ACCESS_TOKEN_URL = "https://api.twitter.com/oauth/access_token" _OAUTH_AUTHORIZE_URL = "https://api.twitter.com/oauth/authorize" @@ -781,8 +714,7 @@ def get(self): _OAUTH_NO_CALLBACKS = False _TWITTER_BASE_URL = "https://api.twitter.com/1.1" - @return_future - def authenticate_redirect(self, callback_uri=None, callback=None): + async def authenticate_redirect(self, callback_uri: Optional[str] = None) -> None: """Just like `~OAuthMixin.authorize_redirect`, but auto-redirects if authorized. @@ -793,20 +725,24 @@ def authenticate_redirect(self, callback_uri=None, callback=None): Now returns a `.Future` and takes an optional callback, for compatibility with `.gen.coroutine`. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed in 6.0. - Use the returned awaitable object instead. + The ``callback`` argument was removed. Use the returned + awaitable object instead. """ http = self.get_auth_http_client() - http.fetch(self._oauth_request_token_url(callback_uri=callback_uri), - functools.partial( - self._on_request_token, self._OAUTH_AUTHENTICATE_URL, - None, callback)) - - @_auth_return_future - def twitter_request(self, path, callback=None, access_token=None, - post_args=None, **args): + response = await http.fetch( + self._oauth_request_token_url(callback_uri=callback_uri) + ) + self._on_request_token(self._OAUTH_AUTHENTICATE_URL, None, response) + + async def twitter_request( + self, + path: str, + access_token: Dict[str, Any], + post_args: Optional[Dict[str, Any]] = None, + **args: Any + ) -> Any: """Fetches the given API path, e.g., ``statuses/user_timeline/btaylor`` The path should not include the format or API version number. 
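The hunks in this file repeat one mechanical transformation: a ``fetch`` plus ``add_done_callback``/``functools.partial`` chain collapses into a single ``await``, and each mixin method starts by casting ``self`` to ``RequestHandler`` so mypy can check the handler attributes it touches. The shape of the change, reduced to a toy mixin (names invented):

```py
from typing import cast

from tornado import httpclient
from tornado.web import RequestHandler


class ExampleMixin(object):
    # 5.1 style threaded the response through callbacks:
    #     fut = http.fetch(url)
    #     fut.add_done_callback(functools.partial(self._on_fetch, callback))
    # 6.0 style is straight-line coroutine code:
    async def fetch_and_finish(self, url: str) -> None:
        handler = cast(RequestHandler, self)  # mixin has no base class of its own
        response = await httpclient.AsyncHTTPClient().fetch(url)
        handler.finish(response.body)
```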
@@ -829,27 +765,26 @@ def twitter_request(self, path, callback=None, access_token=None, class MainHandler(tornado.web.RequestHandler, tornado.auth.TwitterMixin): @tornado.web.authenticated - @tornado.gen.coroutine - def get(self): - new_entry = yield self.twitter_request( + async def get(self): + new_entry = await self.twitter_request( "/statuses/update", post_args={"status": "Testing Tornado Web Server"}, access_token=self.current_user["access_token"]) if not new_entry: # Call failed; perhaps missing permission? - yield self.authorize_redirect() + await self.authorize_redirect() return self.finish("Posted a message!") .. testoutput:: :hide: - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed in 6.0. - Use the returned awaitable object instead. + The ``callback`` argument was removed. Use the returned + awaitable object instead. """ - if path.startswith('http:') or path.startswith('https:'): + if path.startswith("http:") or path.startswith("https:"): # Raw urls are useful for e.g. search which doesn't follow the # usual pattern: http://search.twitter.com/search.json url = path @@ -862,42 +797,38 @@ def get(self): all_args.update(post_args or {}) method = "POST" if post_args is not None else "GET" oauth = self._oauth_request_parameters( - url, access_token, all_args, method=method) + url, access_token, all_args, method=method + ) args.update(oauth) if args: - url += "?" + urllib_parse.urlencode(args) + url += "?" + urllib.parse.urlencode(args) http = self.get_auth_http_client() - http_callback = functools.partial(self._on_twitter_request, callback, url) if post_args is not None: - fut = http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args)) + response = await http.fetch( + url, method="POST", body=urllib.parse.urlencode(post_args) + ) else: - fut = http.fetch(url) - fut.add_done_callback(http_callback) - - def _on_twitter_request(self, future, url, response_fut): - try: - response = response_fut.result() - except Exception as e: - future.set_exception(AuthError( - "Error response %s fetching %s" % (e, url))) - return - future_set_result_unless_cancelled(future, escape.json_decode(response.body)) + response = await http.fetch(url) + return escape.json_decode(response.body) - def _oauth_consumer_token(self): - self.require_setting("twitter_consumer_key", "Twitter OAuth") - self.require_setting("twitter_consumer_secret", "Twitter OAuth") + def _oauth_consumer_token(self) -> Dict[str, Any]: + handler = cast(RequestHandler, self) + handler.require_setting("twitter_consumer_key", "Twitter OAuth") + handler.require_setting("twitter_consumer_secret", "Twitter OAuth") return dict( - key=self.settings["twitter_consumer_key"], - secret=self.settings["twitter_consumer_secret"]) - - @gen.coroutine - def _oauth_get_user_future(self, access_token): - user = yield self.twitter_request( - "/account/verify_credentials", - access_token=access_token) + key=handler.settings["twitter_consumer_key"], + secret=handler.settings["twitter_consumer_secret"], + ) + + async def _oauth_get_user_future( + self, access_token: Dict[str, Any] + ) -> Dict[str, Any]: + user = await self.twitter_request( + "/account/verify_credentials", access_token=access_token + ) if user: user["username"] = user["screen_name"] - raise gen.Return(user) + return user class GoogleOAuth2Mixin(OAuth2Mixin): @@ -908,24 +839,25 @@ class GoogleOAuth2Mixin(OAuth2Mixin): * Go to the Google Dev Console at http://console.developers.google.com * Select a project, or create a new 
one. - * In the sidebar on the left, select APIs & Auth. - * In the list of APIs, find the Google+ API service and set it to ON. * In the sidebar on the left, select Credentials. - * In the OAuth section of the page, select Create New Client ID. - * Set the Redirect URI to point to your auth handler + * Click CREATE CREDENTIALS and click OAuth client ID. + * Under Application type, select Web application. + * Name OAuth 2.0 client and click Create. * Copy the "Client secret" and "Client ID" to the application settings as - {"google_oauth": {"key": CLIENT_ID, "secret": CLIENT_SECRET}} + ``{"google_oauth": {"key": CLIENT_ID, "secret": CLIENT_SECRET}}`` .. versionadded:: 3.2 """ + _OAUTH_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth" _OAUTH_ACCESS_TOKEN_URL = "https://www.googleapis.com/oauth2/v4/token" _OAUTH_USERINFO_URL = "https://www.googleapis.com/oauth2/v1/userinfo" _OAUTH_NO_CALLBACKS = False - _OAUTH_SETTINGS_KEY = 'google_oauth' + _OAUTH_SETTINGS_KEY = "google_oauth" - @_auth_return_future - def get_authenticated_user(self, redirect_uri, code, callback): + async def get_authenticated_user( + self, redirect_uri: str, code: str + ) -> Dict[str, Any]: """Handles the login for the Google user, returning an access token. The result is a dictionary containing an ``access_token`` field @@ -942,19 +874,18 @@ def get_authenticated_user(self, redirect_uri, code, callback): class GoogleOAuth2LoginHandler(tornado.web.RequestHandler, tornado.auth.GoogleOAuth2Mixin): - @tornado.gen.coroutine - def get(self): + async def get(self): if self.get_argument('code', False): - access = yield self.get_authenticated_user( + access = await self.get_authenticated_user( redirect_uri='http://your.site.com/auth/google', code=self.get_argument('code')) - user = yield self.oauth2_request( + user = await self.oauth2_request( "https://www.googleapis.com/oauth2/v1/userinfo", access_token=access["access_token"]) # Save the user and access token with # e.g. set_secure_cookie. else: - yield self.authorize_redirect( + self.authorize_redirect( redirect_uri='http://your.site.com/auth/google', client_id=self.settings['google_oauth']['key'], scope=['profile', 'email'], @@ -964,48 +895,47 @@ def get(self): .. testoutput:: :hide: - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed in 6.0. - Use the returned awaitable object instead. + The ``callback`` argument was removed. Use the returned awaitable object instead. 
""" # noqa: E501 + handler = cast(RequestHandler, self) http = self.get_auth_http_client() - body = urllib_parse.urlencode({ - "redirect_uri": redirect_uri, - "code": code, - "client_id": self.settings[self._OAUTH_SETTINGS_KEY]['key'], - "client_secret": self.settings[self._OAUTH_SETTINGS_KEY]['secret'], - "grant_type": "authorization_code", - }) - - fut = http.fetch(self._OAUTH_ACCESS_TOKEN_URL, - method="POST", - headers={'Content-Type': 'application/x-www-form-urlencoded'}, - body=body) - fut.add_done_callback(functools.partial(self._on_access_token, callback)) - - def _on_access_token(self, future, response_fut): - """Callback function for the exchange to the access token.""" - try: - response = response_fut.result() - except Exception as e: - future.set_exception(AuthError('Google auth error: %s' % str(e))) - return + body = urllib.parse.urlencode( + { + "redirect_uri": redirect_uri, + "code": code, + "client_id": handler.settings[self._OAUTH_SETTINGS_KEY]["key"], + "client_secret": handler.settings[self._OAUTH_SETTINGS_KEY]["secret"], + "grant_type": "authorization_code", + } + ) - args = escape.json_decode(response.body) - future_set_result_unless_cancelled(future, args) + response = await http.fetch( + self._OAUTH_ACCESS_TOKEN_URL, + method="POST", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + body=body, + ) + return escape.json_decode(response.body) class FacebookGraphMixin(OAuth2Mixin): """Facebook authentication using the new Graph API and OAuth2.""" + _OAUTH_ACCESS_TOKEN_URL = "https://graph.facebook.com/oauth/access_token?" _OAUTH_AUTHORIZE_URL = "https://www.facebook.com/dialog/oauth?" _OAUTH_NO_CALLBACKS = False _FACEBOOK_BASE_URL = "https://graph.facebook.com" - @_auth_return_future - def get_authenticated_user(self, redirect_uri, client_id, client_secret, - code, callback, extra_fields=None): + async def get_authenticated_user( + self, + redirect_uri: str, + client_id: str, + client_secret: str, + code: str, + extra_fields: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: """Handles the login for the Facebook user, returning a user object. Example usage: @@ -1014,17 +944,16 @@ def get_authenticated_user(self, redirect_uri, client_id, client_secret, class FacebookGraphLoginHandler(tornado.web.RequestHandler, tornado.auth.FacebookGraphMixin): - @tornado.gen.coroutine - def get(self): + async def get(self): if self.get_argument("code", False): - user = yield self.get_authenticated_user( + user = await self.get_authenticated_user( redirect_uri='/auth/facebookgraph/', client_id=self.settings["facebook_api_key"], client_secret=self.settings["facebook_secret"], code=self.get_argument("code")) # Save the user with e.g. set_secure_cookie else: - yield self.authorize_redirect( + self.authorize_redirect( redirect_uri='/auth/facebookgraph/', client_id=self.settings["facebook_api_key"], extra_params={"scope": "read_stream,offline_access"}) @@ -1048,10 +977,9 @@ def get(self): The ``session_expires`` field was updated to support changes made to the Facebook API in March 2017. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed in 6.0. - Use the returned awaitable object instead. + The ``callback`` argument was removed. Use the returned awaitable object instead. 
""" http = self.get_auth_http_client() args = { @@ -1061,42 +989,35 @@ def get(self): "client_secret": client_secret, } - fields = set(['id', 'name', 'first_name', 'last_name', - 'locale', 'picture', 'link']) + fields = set( + ["id", "name", "first_name", "last_name", "locale", "picture", "link"] + ) if extra_fields: fields.update(extra_fields) - fut = http.fetch(self._oauth_request_token_url(**args)) - fut.add_done_callback(functools.partial(self._on_access_token, redirect_uri, client_id, - client_secret, callback, fields)) - - @gen.coroutine - def _on_access_token(self, redirect_uri, client_id, client_secret, - future, fields, response_fut): - try: - response = response_fut.result() - except Exception as e: - future.set_exception(AuthError('Facebook auth error: %s' % str(e))) - return - + response = await http.fetch( + self._oauth_request_token_url(**args) # type: ignore + ) args = escape.json_decode(response.body) session = { "access_token": args.get("access_token"), - "expires_in": args.get("expires_in") + "expires_in": args.get("expires_in"), } + assert session["access_token"] is not None - user = yield self.facebook_request( + user = await self.facebook_request( path="/me", access_token=session["access_token"], - appsecret_proof=hmac.new(key=client_secret.encode('utf8'), - msg=session["access_token"].encode('utf8'), - digestmod=hashlib.sha256).hexdigest(), - fields=",".join(fields) + appsecret_proof=hmac.new( + key=client_secret.encode("utf8"), + msg=session["access_token"].encode("utf8"), + digestmod=hashlib.sha256, + ).hexdigest(), + fields=",".join(fields), ) if user is None: - future_set_result_unless_cancelled(future, None) - return + return None fieldmap = {} for field in fields: @@ -1106,13 +1027,21 @@ def _on_access_token(self, redirect_uri, client_id, client_secret, # older versions in which the server used url-encoding and # this code simply returned the string verbatim. # This should change in Tornado 5.0. - fieldmap.update({"access_token": session["access_token"], - "session_expires": str(session.get("expires_in"))}) - future_set_result_unless_cancelled(future, fieldmap) - - @_auth_return_future - def facebook_request(self, path, callback, access_token=None, - post_args=None, **args): + fieldmap.update( + { + "access_token": session["access_token"], + "session_expires": str(session.get("expires_in")), + } + ) + return fieldmap + + async def facebook_request( + self, + path: str, + access_token: Optional[str] = None, + post_args: Optional[Dict[str, Any]] = None, + **args: Any + ) -> Any: """Fetches the given relative API path, e.g., "/btaylor/picture" If the request is a POST, ``post_args`` should be provided. Query @@ -1134,16 +1063,15 @@ def facebook_request(self, path, callback, access_token=None, class MainHandler(tornado.web.RequestHandler, tornado.auth.FacebookGraphMixin): @tornado.web.authenticated - @tornado.gen.coroutine - def get(self): - new_entry = yield self.facebook_request( + async def get(self): + new_entry = await self.facebook_request( "/me/feed", post_args={"message": "I am posting from my Tornado application!"}, access_token=self.current_user["access_token"]) if not new_entry: # Call failed; perhaps missing permission? - yield self.authorize_redirect() + self.authorize_redirect() return self.finish("Posted a message!") @@ -1160,35 +1088,39 @@ def get(self): .. versionchanged:: 3.1 Added the ability to override ``self._FACEBOOK_BASE_URL``. - .. deprecated:: 5.1 + .. 
versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed in 6.0. - Use the returned awaitable object instead. + The ``callback`` argument was removed. Use the returned awaitable object instead. """ url = self._FACEBOOK_BASE_URL + path - # Thanks to the _auth_return_future decorator, our "callback" - # argument is a Future, which we cannot pass as a callback to - # oauth2_request. Instead, have oauth2_request return a - # future and chain them together. - oauth_future = self.oauth2_request(url, access_token=access_token, - post_args=post_args, **args) - chain_future(oauth_future, callback) + return await self.oauth2_request( + url, access_token=access_token, post_args=post_args, **args + ) -def _oauth_signature(consumer_token, method, url, parameters={}, token=None): +def _oauth_signature( + consumer_token: Dict[str, Any], + method: str, + url: str, + parameters: Dict[str, Any] = {}, + token: Optional[Dict[str, Any]] = None, +) -> bytes: """Calculates the HMAC-SHA1 OAuth signature for the given request. See http://oauth.net/core/1.0/#signing_process """ - parts = urlparse.urlparse(url) + parts = urllib.parse.urlparse(url) scheme, netloc, path = parts[:3] normalized_url = scheme.lower() + "://" + netloc.lower() + path base_elems = [] base_elems.append(method.upper()) base_elems.append(normalized_url) - base_elems.append("&".join("%s=%s" % (k, _oauth_escape(str(v))) - for k, v in sorted(parameters.items()))) + base_elems.append( + "&".join( + "%s=%s" % (k, _oauth_escape(str(v))) for k, v in sorted(parameters.items()) + ) + ) base_string = "&".join(_oauth_escape(e) for e in base_elems) key_elems = [escape.utf8(consumer_token["secret"])] @@ -1199,42 +1131,53 @@ def _oauth_signature(consumer_token, method, url, parameters={}, token=None): return binascii.b2a_base64(hash.digest())[:-1] -def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None): +def _oauth10a_signature( + consumer_token: Dict[str, Any], + method: str, + url: str, + parameters: Dict[str, Any] = {}, + token: Optional[Dict[str, Any]] = None, +) -> bytes: """Calculates the HMAC-SHA1 OAuth 1.0a signature for the given request. 
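Both signature helpers build the same percent-escaped base string and HMAC it with a ``&``-joined key; 1.0a differs only in quoting the secrets before joining. The arithmetic on throwaway values (no real service involved):

```py
import binascii
import hashlib
import hmac
import urllib.parse


def esc(val: str) -> str:
    # Same escaping rule as _oauth_escape above.
    return urllib.parse.quote(val, safe="~")


method = "GET"
url = "https://example.com/oauth/request_token"
params = {"oauth_nonce": "abc123", "oauth_timestamp": "1234567890"}

param_str = "&".join("%s=%s" % (k, esc(v)) for k, v in sorted(params.items()))
base_string = "&".join(esc(e) for e in [method.upper(), url, param_str])

# Signing key: consumer secret and (here empty) token secret, "&"-joined.
key = b"&".join([b"consumer-secret", b""])
digest = hmac.new(key, base_string.encode("utf-8"), hashlib.sha1)
print(binascii.b2a_base64(digest.digest())[:-1])
```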
See http://oauth.net/core/1.0a/#signing_process """ - parts = urlparse.urlparse(url) + parts = urllib.parse.urlparse(url) scheme, netloc, path = parts[:3] normalized_url = scheme.lower() + "://" + netloc.lower() + path base_elems = [] base_elems.append(method.upper()) base_elems.append(normalized_url) - base_elems.append("&".join("%s=%s" % (k, _oauth_escape(str(v))) - for k, v in sorted(parameters.items()))) + base_elems.append( + "&".join( + "%s=%s" % (k, _oauth_escape(str(v))) for k, v in sorted(parameters.items()) + ) + ) base_string = "&".join(_oauth_escape(e) for e in base_elems) - key_elems = [escape.utf8(urllib_parse.quote(consumer_token["secret"], safe='~'))] - key_elems.append(escape.utf8(urllib_parse.quote(token["secret"], safe='~') if token else "")) + key_elems = [escape.utf8(urllib.parse.quote(consumer_token["secret"], safe="~"))] + key_elems.append( + escape.utf8(urllib.parse.quote(token["secret"], safe="~") if token else "") + ) key = b"&".join(key_elems) hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1) return binascii.b2a_base64(hash.digest())[:-1] -def _oauth_escape(val): +def _oauth_escape(val: Union[str, bytes]) -> str: if isinstance(val, unicode_type): val = val.encode("utf-8") - return urllib_parse.quote(val, safe="~") + return urllib.parse.quote(val, safe="~") -def _oauth_parse_response(body): +def _oauth_parse_response(body: bytes) -> Dict[str, Any]: # I can't find an officially-defined encoding for oauth responses and # have never seen anyone use non-ascii. Leave the response in a byte # string for python 2, and use utf8 on python 3. - body = escape.native_str(body) - p = urlparse.parse_qs(body, keep_blank_values=False) + body_str = escape.native_str(body) + p = urllib.parse.parse_qs(body_str, keep_blank_values=False) token = dict(key=p["oauth_token"][0], secret=p["oauth_token_secret"][0]) # Add the extra parameters the Provider included to the token diff --git a/tornado/autoreload.py b/tornado/autoreload.py index 2f91127031..db47262be2 100644 --- a/tornado/autoreload.py +++ b/tornado/autoreload.py @@ -33,9 +33,8 @@ other import-time failures, while debug mode catches changes once the server has started. -This module depends on `.IOLoop`, so it will not work in WSGI applications -and Google App Engine. It also will not work correctly when `.HTTPServer`'s -multi-process mode is used. +This module will not work correctly when `.HTTPServer`'s multi-process +mode is used. Reloading loses any Python interpreter command-line arguments (e.g. ``-u``) because it re-executes Python using ``sys.executable`` and ``sys.argv``. @@ -44,8 +43,6 @@ """ -from __future__ import absolute_import, division, print_function - import os import sys @@ -96,20 +93,29 @@ try: import signal except ImportError: - signal = None + signal = None # type: ignore + +import typing +from typing import Callable, Dict + +if typing.TYPE_CHECKING: + from typing import List, Optional, Union # noqa: F401 # os.execv is broken on Windows and can't properly parse command line # arguments and executable name if they contain whitespaces. subprocess # fixes that behavior. -_has_execv = sys.platform != 'win32' +_has_execv = sys.platform != "win32" _watched_files = set() _reload_hooks = [] _reload_attempted = False _io_loops = weakref.WeakKeyDictionary() # type: ignore +_autoreload_is_main = False +_original_argv = None # type: Optional[List[str]] +_original_spec = None -def start(check_time=500): +def start(check_time: int = 500) -> None: """Begins watching source files for changes. .. 
versionchanged:: 5.0 @@ -121,13 +127,13 @@ def start(check_time=500): _io_loops[io_loop] = True if len(_io_loops) > 1: gen_log.warning("tornado.autoreload started more than once in the same process") - modify_times = {} + modify_times = {} # type: Dict[str, float] callback = functools.partial(_reload_on_update, modify_times) scheduler = ioloop.PeriodicCallback(callback, check_time) scheduler.start() -def wait(): +def wait() -> None: """Wait for a watched file to change, then restart the process. Intended to be used at the end of scripts like unit test runners, @@ -139,7 +145,7 @@ def wait(): io_loop.start() -def watch(filename): +def watch(filename: str) -> None: """Add a file to the watch list. All imported modules are watched by default. @@ -147,18 +153,17 @@ def watch(filename): _watched_files.add(filename) -def add_reload_hook(fn): +def add_reload_hook(fn: Callable[[], None]) -> None: """Add a function to be called before reloading the process. Note that for open file and socket handles it is generally preferable to set the ``FD_CLOEXEC`` flag (using `fcntl` or - ``tornado.platform.auto.set_close_exec``) instead - of using a reload hook to close them. + `os.set_inheritable`) instead of using a reload hook to close them. """ _reload_hooks.append(fn) -def _reload_on_update(modify_times): +def _reload_on_update(modify_times: Dict[str, float]) -> None: if _reload_attempted: # We already tried to reload and it didn't work, so don't try again. return @@ -184,7 +189,7 @@ def _reload_on_update(modify_times): _check_file(modify_times, path) -def _check_file(modify_times, path): +def _check_file(modify_times: Dict[str, float], path: str) -> None: try: modified = os.stat(path).st_mtime except Exception: @@ -197,12 +202,12 @@ def _check_file(modify_times, path): _reload() -def _reload(): +def _reload() -> None: global _reload_attempted _reload_attempted = True for fn in _reload_hooks: fn() - if hasattr(signal, "setitimer"): + if sys.platform != "win32": # Clear the alarm signal set by # ioloop.set_blocking_log_threshold so it doesn't fire # after the exec. @@ -214,19 +219,24 @@ def _reload(): # __spec__ is not available (Python < 3.4), check instead if # sys.path[0] is an empty string and add the current directory to # $PYTHONPATH. - spec = getattr(sys.modules['__main__'], '__spec__', None) - if spec: - argv = ['-m', spec.name] + sys.argv[1:] + if _autoreload_is_main: + assert _original_argv is not None + spec = _original_spec + argv = _original_argv else: + spec = getattr(sys.modules["__main__"], "__spec__", None) argv = sys.argv - path_prefix = '.' + os.pathsep - if (sys.path[0] == '' and - not os.environ.get("PYTHONPATH", "").startswith(path_prefix)): - os.environ["PYTHONPATH"] = (path_prefix + - os.environ.get("PYTHONPATH", "")) + if spec: + argv = ["-m", spec.name] + argv[1:] + else: + path_prefix = "." + os.pathsep + if sys.path[0] == "" and not os.environ.get("PYTHONPATH", "").startswith( + path_prefix + ): + os.environ["PYTHONPATH"] = path_prefix + os.environ.get("PYTHONPATH", "") if not _has_execv: subprocess.Popen([sys.executable] + argv) - sys.exit(0) + os._exit(0) else: try: os.execv(sys.executable, [sys.executable] + argv) @@ -242,7 +252,9 @@ def _reload(): # Unfortunately the errno returned in this case does not # appear to be consistent, so we can't easily check for # this error specifically. 
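Per the docstrings above, the watch list covers all imported modules by default, and ``watch``/``add_reload_hook`` extend it. A minimal development-server sketch (the config filename is invented):

```py
import tornado.autoreload
import tornado.ioloop
import tornado.web

app = tornado.web.Application([], debug=True)  # debug=True enables autoreload

# Also restart when a non-Python file changes, and log each restart.
tornado.autoreload.watch("app-config.yaml")  # hypothetical config file
tornado.autoreload.add_reload_hook(lambda: print("reloading..."))

if __name__ == "__main__":
    app.listen(8888)
    tornado.ioloop.IOLoop.current().start()
```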
- os.spawnv(os.P_NOWAIT, sys.executable, [sys.executable] + argv) + os.spawnv( + os.P_NOWAIT, sys.executable, [sys.executable] + argv # type: ignore + ) # At this point the IOLoop has been closed and finally # blocks will experience errors if we allow the stack to # unwind, so just exit uncleanly. @@ -256,7 +268,7 @@ def _reload(): """ -def main(): +def main() -> None: """Command-line wrapper to re-run a script whenever its source changes. Scripts may be specified by filename or module name:: @@ -269,7 +281,18 @@ def main(): can catch import-time problems like syntax errors that would otherwise prevent the script from reaching its call to `wait`. """ + # Remember that we were launched with autoreload as main. + # The main module can be tricky; set the variables both in our globals + # (which may be __main__) and the real importable version. + import tornado.autoreload + + global _autoreload_is_main + global _original_argv, _original_spec + tornado.autoreload._autoreload_is_main = _autoreload_is_main = True original_argv = sys.argv + tornado.autoreload._original_argv = _original_argv = original_argv + original_spec = getattr(sys.modules["__main__"], "__spec__", None) + tornado.autoreload._original_spec = _original_spec = original_spec sys.argv = sys.argv[:] if len(sys.argv) >= 3 and sys.argv[1] == "-m": mode = "module" @@ -286,6 +309,7 @@ def main(): try: if mode == "module": import runpy + runpy.run_module(module, run_name="__main__", alter_sys=True) elif mode == "script": with open(script) as f: @@ -316,19 +340,20 @@ def main(): # SyntaxErrors are special: their innermost stack frame is fake # so extract_tb won't see it and we have to get the filename # from the exception object. - watch(e.filename) + if e.filename is not None: + watch(e.filename) else: logging.basicConfig() gen_log.info("Script exited normally") # restore sys.argv so subsequent executions will include autoreload sys.argv = original_argv - if mode == 'module': + if mode == "module": # runpy did a fake import of the module as __main__, but now it's # no longer in sys.modules. Figure out where it is and watch it. loader = pkgutil.get_loader(module) if loader is not None: - watch(loader.get_filename()) + watch(loader.get_filename()) # type: ignore wait() diff --git a/tornado/concurrent.py b/tornado/concurrent.py index 850766818d..6e05346b56 100644 --- a/tornado/concurrent.py +++ b/tornado/concurrent.py @@ -14,393 +14,67 @@ # under the License. """Utilities for working with ``Future`` objects. -``Futures`` are a pattern for concurrent programming introduced in -Python 3.2 in the `concurrent.futures` package, and also adopted (in a -slightly different form) in Python 3.4's `asyncio` package. This -package defines a ``Future`` class that is an alias for `asyncio.Future` -when available, and a compatible implementation for older versions of -Python. It also includes some utility functions for interacting with -``Future`` objects. - -While this package is an important part of Tornado's internal +Tornado previously provided its own ``Future`` class, but now uses +`asyncio.Future`. This module contains utility functions for working +with `asyncio.Future` in a way that is backwards-compatible with +Tornado's old ``Future`` implementation. + +While this module is an important part of Tornado's internal implementation, applications rarely need to interact with it directly. 
+ """ -from __future__ import absolute_import, division, print_function +import asyncio +from concurrent import futures import functools -import platform -import textwrap -import traceback import sys -import warnings +import types from tornado.log import app_log -from tornado.stack_context import ExceptionStackContext, wrap -from tornado.util import raise_exc_info, ArgReplacer, is_finalizing - -try: - from concurrent import futures -except ImportError: - futures = None - -try: - import asyncio -except ImportError: - asyncio = None - -try: - import typing -except ImportError: - typing = None +import typing +from typing import Any, Callable, Optional, Tuple, Union -# Can the garbage collector handle cycles that include __del__ methods? -# This is true in cpython beginning with version 3.4 (PEP 442). -_GC_CYCLE_FINALIZERS = (platform.python_implementation() == 'CPython' and - sys.version_info >= (3, 4)) +_T = typing.TypeVar("_T") class ReturnValueIgnoredError(Exception): + # No longer used; was previously used by @return_future pass -# This class and associated code in the future object is derived -# from the Trollius project, a backport of asyncio to Python 2.x - 3.x - - -class _TracebackLogger(object): - """Helper to log a traceback upon destruction if not cleared. - - This solves a nasty problem with Futures and Tasks that have an - exception set: if nobody asks for the exception, the exception is - never logged. This violates the Zen of Python: 'Errors should - never pass silently. Unless explicitly silenced.' - - However, we don't want to log the exception as soon as - set_exception() is called: if the calling code is written - properly, it will get the exception and handle it properly. But - we *do* want to log it if result() or exception() was never called - -- otherwise developers waste a lot of time wondering why their - buggy code fails silently. - - An earlier attempt added a __del__() method to the Future class - itself, but this backfired because the presence of __del__() - prevents garbage collection from breaking cycles. A way out of - this catch-22 is to avoid having a __del__() method on the Future - class itself, but instead to have a reference to a helper object - with a __del__() method that logs the traceback, where we ensure - that the helper object doesn't participate in cycles, and only the - Future has a reference to it. - - The helper object is added when set_exception() is called. When - the Future is collected, and the helper is present, the helper - object is also collected, and its __del__() method will log the - traceback. When the Future's result() or exception() method is - called (and a helper object is present), it removes the the helper - object, after calling its clear() method to prevent it from - logging. - - One downside is that we do a fair amount of work to extract the - traceback from the exception, even when it is never logged. It - would seem cheaper to just store the exception object, but that - references the traceback, which references stack frames, which may - reference the Future, which references the _TracebackLogger, and - then the _TracebackLogger would be included in a cycle, which is - what we're trying to avoid! As an optimization, we don't - immediately format the exception; we only do the work when - activate() is called, which call is delayed until after all the - Future's callbacks have run. 
Since usually a Future has at least - one callback (typically set by 'yield From') and usually that - callback extracts the callback, thereby removing the need to - format the exception. - - PS. I don't claim credit for this solution. I first heard of it - in a discussion about closing files when they are collected. - """ - - __slots__ = ('exc_info', 'formatted_tb') - - def __init__(self, exc_info): - self.exc_info = exc_info - self.formatted_tb = None - - def activate(self): - exc_info = self.exc_info - if exc_info is not None: - self.exc_info = None - self.formatted_tb = traceback.format_exception(*exc_info) - - def clear(self): - self.exc_info = None - self.formatted_tb = None - - def __del__(self, is_finalizing=is_finalizing): - if not is_finalizing() and self.formatted_tb: - app_log.error('Future exception was never retrieved: %s', - ''.join(self.formatted_tb).rstrip()) - - -class Future(object): - """Placeholder for an asynchronous result. - - A ``Future`` encapsulates the result of an asynchronous - operation. In synchronous applications ``Futures`` are used - to wait for the result from a thread or process pool; in - Tornado they are normally used with `.IOLoop.add_future` or by - yielding them in a `.gen.coroutine`. - - `tornado.concurrent.Future` is an alias for `asyncio.Future` when - that package is available (Python 3.4+). Unlike - `concurrent.futures.Future`, the ``Futures`` used by Tornado and - `asyncio` are not thread-safe (and therefore faster for use with - single-threaded event loops). - - In addition to ``exception`` and ``set_exception``, Tornado's - ``Future`` implementation supports storing an ``exc_info`` triple - to support better tracebacks on Python 2. To set an ``exc_info`` - triple, use `future_set_exc_info`, and to retrieve one, call - `result()` (which will raise it). - - .. versionchanged:: 4.0 - `tornado.concurrent.Future` is always a thread-unsafe ``Future`` - with support for the ``exc_info`` methods. Previously it would - be an alias for the thread-safe `concurrent.futures.Future` - if that package was available and fall back to the thread-unsafe - implementation if it was not. - - .. versionchanged:: 4.1 - If a `.Future` contains an error but that error is never observed - (by calling ``result()``, ``exception()``, or ``exc_info()``), - a stack trace will be logged when the `.Future` is garbage collected. - This normally indicates an error in the application, but in cases - where it results in undesired logging it may be necessary to - suppress the logging by ensuring that the exception is observed: - ``f.add_done_callback(lambda f: f.exception())``. - - .. versionchanged:: 5.0 - - This class was previoiusly available under the name - ``TracebackFuture``. This name, which was deprecated since - version 4.0, has been removed. When `asyncio` is available - ``tornado.concurrent.Future`` is now an alias for - `asyncio.Future`. Like `asyncio.Future`, callbacks are now - always scheduled on the `.IOLoop` and are never run - synchronously. - - """ - def __init__(self): - self._done = False - self._result = None - self._exc_info = None - - self._log_traceback = False # Used for Python >= 3.4 - self._tb_logger = None # Used for Python <= 3.3 - - self._callbacks = [] - - # Implement the Python 3.5 Awaitable protocol if possible - # (we can't use return and yield together until py33). - if sys.version_info >= (3, 3): - exec(textwrap.dedent(""" - def __await__(self): - return (yield self) - """)) - else: - # Py2-compatible version for use with cython. 
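All of this hand-rolled awaitable machinery can go because ``tornado.concurrent.Future`` becomes a plain alias for `asyncio.Future` later in this diff, and `asyncio.Future` implements ``__await__`` natively. A quick sketch of what the alias buys (the coroutine is illustrative):

```py
import asyncio
from tornado.concurrent import Future

assert Future is asyncio.Future  # the 6.0 alias

async def main():
    f = Future()     # an asyncio.Future, awaitable out of the box
    f.set_result(42)
    print(await f)   # 42

asyncio.run(main())
```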
- def __await__(self): - result = yield self - # StopIteration doesn't take args before py33, - # but Cython recognizes the args tuple. - e = StopIteration() - e.args = (result,) - raise e - - def cancel(self): - """Cancel the operation, if possible. - - Tornado ``Futures`` do not support cancellation, so this method always - returns False. - """ - return False - - def cancelled(self): - """Returns True if the operation has been cancelled. - - Tornado ``Futures`` do not support cancellation, so this method - always returns False. - """ - return False - - def running(self): - """Returns True if this operation is currently running.""" - return not self._done - - def done(self): - """Returns True if the future has finished running.""" - return self._done - - def _clear_tb_log(self): - self._log_traceback = False - if self._tb_logger is not None: - self._tb_logger.clear() - self._tb_logger = None - - def result(self, timeout=None): - """If the operation succeeded, return its result. If it failed, - re-raise its exception. - - This method takes a ``timeout`` argument for compatibility with - `concurrent.futures.Future` but it is an error to call it - before the `Future` is done, so the ``timeout`` is never used. - """ - self._clear_tb_log() - if self._result is not None: - return self._result - if self._exc_info is not None: - try: - raise_exc_info(self._exc_info) - finally: - self = None - self._check_done() - return self._result - - def exception(self, timeout=None): - """If the operation raised an exception, return the `Exception` - object. Otherwise returns None. - - This method takes a ``timeout`` argument for compatibility with - `concurrent.futures.Future` but it is an error to call it - before the `Future` is done, so the ``timeout`` is never used. - """ - self._clear_tb_log() - if self._exc_info is not None: - return self._exc_info[1] - else: - self._check_done() - return None - - def add_done_callback(self, fn): - """Attaches the given callback to the `Future`. - - It will be invoked with the `Future` as its argument when the Future - has finished running and its result is available. In Tornado - consider using `.IOLoop.add_future` instead of calling - `add_done_callback` directly. - """ - if self._done: - from tornado.ioloop import IOLoop - IOLoop.current().add_callback(fn, self) - else: - self._callbacks.append(fn) - - def set_result(self, result): - """Sets the result of a ``Future``. - - It is undefined to call any of the ``set`` methods more than once - on the same object. - """ - self._result = result - self._set_done() - def set_exception(self, exception): - """Sets the exception of a ``Future.``""" - self.set_exc_info( - (exception.__class__, - exception, - getattr(exception, '__traceback__', None))) +Future = asyncio.Future - def exc_info(self): - """Returns a tuple in the same format as `sys.exc_info` or None. +FUTURES = (futures.Future, Future) - .. versionadded:: 4.0 - """ - self._clear_tb_log() - return self._exc_info - def set_exc_info(self, exc_info): - """Sets the exception information of a ``Future.`` - - Preserves tracebacks on Python 2. - - .. versionadded:: 4.0 - """ - self._exc_info = exc_info - self._log_traceback = True - if not _GC_CYCLE_FINALIZERS: - self._tb_logger = _TracebackLogger(exc_info) - - try: - self._set_done() - finally: - # Activate the logger after all callbacks have had a - # chance to call result() or exception(). 
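The deleted traceback logger is redundant because `asyncio.Future` ships the same safety net: if a future holding an exception is garbage-collected before anyone retrieves it, the event loop's exception handler logs it. A minimal sketch of the behavior that takes over (the coroutine is illustrative):

```py
import asyncio

async def main():
    f = asyncio.get_event_loop().create_future()
    f.set_exception(RuntimeError("never retrieved"))
    # Calling f.exception() or awaiting f here would suppress the log;
    # since nobody does, asyncio logs "Future exception was never
    # retrieved" when the future is collected.
    del f

asyncio.run(main())
```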
- if self._log_traceback and self._tb_logger is not None: - self._tb_logger.activate() - self._exc_info = exc_info - - def _check_done(self): - if not self._done: - raise Exception("DummyFuture does not support blocking for results") - - def _set_done(self): - self._done = True - if self._callbacks: - from tornado.ioloop import IOLoop - loop = IOLoop.current() - for cb in self._callbacks: - loop.add_callback(cb, self) - self._callbacks = None - - # On Python 3.3 or older, objects with a destructor part of a reference - # cycle are never destroyed. It's no longer the case on Python 3.4 thanks to - # the PEP 442. - if _GC_CYCLE_FINALIZERS: - def __del__(self, is_finalizing=is_finalizing): - if is_finalizing() or not self._log_traceback: - # set_exception() was not called, or result() or exception() - # has consumed the exception - return - - tb = traceback.format_exception(*self._exc_info) - - app_log.error('Future %r exception was never retrieved: %s', - self, ''.join(tb).rstrip()) - - -if asyncio is not None: - Future = asyncio.Future # noqa - -if futures is None: - FUTURES = Future # type: typing.Union[type, typing.Tuple[type, ...]] -else: - FUTURES = (futures.Future, Future) - - -def is_future(x): +def is_future(x: Any) -> bool: return isinstance(x, FUTURES) -class DummyExecutor(object): - def submit(self, fn, *args, **kwargs): - future = Future() +class DummyExecutor(futures.Executor): + def submit( + self, fn: Callable[..., _T], *args: Any, **kwargs: Any + ) -> "futures.Future[_T]": + future = futures.Future() # type: futures.Future[_T] try: future_set_result_unless_cancelled(future, fn(*args, **kwargs)) except Exception: future_set_exc_info(future, sys.exc_info()) return future - def shutdown(self, wait=True): + def shutdown(self, wait: bool = True) -> None: pass dummy_executor = DummyExecutor() -def run_on_executor(*args, **kwargs): +def run_on_executor(*args: Any, **kwargs: Any) -> Callable: """Decorator to run a synchronous method asynchronously on an executor. - The decorated method may be called with a ``callback`` keyword - argument and returns a future. + Returns a future. The executor to be used is determined by the ``executor`` attributes of ``self``. To use a different attribute name, pass a @@ -432,24 +106,25 @@ def foo(self): The ``callback`` argument is deprecated and will be removed in 6.0. The decorator itself is discouraged in new code but will not be removed in 6.0. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. """ - def run_on_executor_decorator(fn): + # Fully type-checking decorators is tricky, and this one is + # discouraged anyway so it doesn't have all the generic magic. 
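Usage of the decorator defined here is unchanged apart from the removed ``callback`` path; the returned future is now the only way to get the result. A sketch (class and method names are invented for illustration):

```py
from concurrent.futures import ThreadPoolExecutor
from tornado.concurrent import run_on_executor

class Resizer:
    def __init__(self):
        # run_on_executor looks up this attribute name by default
        self.executor = ThreadPoolExecutor(max_workers=4)

    @run_on_executor
    def resize(self, data: bytes) -> bytes:
        return data[:100]  # stand-in for real blocking work

# In a coroutine:  thumbnail = await Resizer().resize(raw_bytes)
```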
+ def run_on_executor_decorator(fn: Callable) -> Callable[..., Future]: executor = kwargs.get("executor", "executor") @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - callback = kwargs.pop("callback", None) - async_future = Future() + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Future: + async_future = Future() # type: Future conc_future = getattr(self, executor).submit(fn, self, *args, **kwargs) chain_future(conc_future, async_future) - if callback: - warnings.warn("callback arguments are deprecated, use the returned Future instead", - DeprecationWarning) - from tornado.ioloop import IOLoop - IOLoop.current().add_future( - async_future, lambda future: callback(future.result())) return async_future + return wrapper + if args and kwargs: raise ValueError("cannot combine positional and keyword args") if len(args) == 1: @@ -462,110 +137,7 @@ def wrapper(self, *args, **kwargs): _NO_RESULT = object() -def return_future(f): - """Decorator to make a function that returns via callback return a - `Future`. - - This decorator was provided to ease the transition from - callback-oriented code to coroutines. It is not recommended for - new code. - - The wrapped function should take a ``callback`` keyword argument - and invoke it with one argument when it has finished. To signal failure, - the function can simply raise an exception (which will be - captured by the `.StackContext` and passed along to the ``Future``). - - From the caller's perspective, the callback argument is optional. - If one is given, it will be invoked when the function is complete - with ``Future.result()`` as an argument. If the function fails, the - callback will not be run and an exception will be raised into the - surrounding `.StackContext`. - - If no callback is given, the caller should use the ``Future`` to - wait for the function to complete (perhaps by yielding it in a - `.gen.engine` function, or passing it to `.IOLoop.add_future`). - - Usage: - - .. testcode:: - - @return_future - def future_func(arg1, arg2, callback): - # Do stuff (possibly asynchronous) - callback(result) - - @gen.engine - def caller(callback): - yield future_func(arg1, arg2) - callback() - - .. - - Note that ``@return_future`` and ``@gen.engine`` can be applied to the - same function, provided ``@return_future`` appears first. However, - consider using ``@gen.coroutine`` instead of this combination. - - .. versionchanged:: 5.1 - - Now raises a `.DeprecationWarning` if a callback argument is passed to - the decorated function and deprecation warnings are enabled. - - .. deprecated:: 5.1 - - New code should use coroutines directly instead of wrapping - callback-based code with this decorator. - """ - replacer = ArgReplacer(f, 'callback') - - @functools.wraps(f) - def wrapper(*args, **kwargs): - future = Future() - callback, args, kwargs = replacer.replace( - lambda value=_NO_RESULT: future_set_result_unless_cancelled(future, value), - args, kwargs) - - def handle_error(typ, value, tb): - future_set_exc_info(future, (typ, value, tb)) - return True - exc_info = None - with ExceptionStackContext(handle_error): - try: - result = f(*args, **kwargs) - if result is not None: - raise ReturnValueIgnoredError( - "@return_future should not be used with functions " - "that return values") - except: - exc_info = sys.exc_info() - raise - if exc_info is not None: - # If the initial synchronous part of f() raised an exception, - # go ahead and raise it to the caller directly without waiting - # for them to inspect the Future. 
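For code still on the `@return_future` pattern deleted below, the replacement is simply a coroutine; a hypothetical before/after:

```py
# Before (Tornado <= 5.x):
#
#     @return_future
#     def fetch_name(callback):
#         callback("tornado")
#
# After (Tornado 6.0): return the value from a coroutine instead.
async def fetch_name() -> str:
    return "tornado"

# Callers drop the callback and await the result:  name = await fetch_name()
```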
- future.result() - - # If the caller passed in a callback, schedule it to be called - # when the future resolves. It is important that this happens - # just before we return the future, or else we risk confusing - # stack contexts with multiple exceptions (one here with the - # immediate exception, and again when the future resolves and - # the callback triggers its exception by calling future.result()). - if callback is not None: - warnings.warn("callback arguments are deprecated, use the returned Future instead", - DeprecationWarning) - - def run_callback(future): - result = future.result() - if result is _NO_RESULT: - callback() - else: - callback(future.result()) - future_add_done_callback(future, wrap(run_callback)) - return future - return wrapper - - -def chain_future(a, b): +def chain_future(a: "Future[_T]", b: "Future[_T]") -> None: """Chain two futures together so that when one completes, so does the other. The result (success or failure) of ``a`` will be copied to ``b``, unless @@ -577,29 +149,35 @@ def chain_future(a, b): `concurrent.futures.Future`. """ - def copy(future): + + def copy(future: "Future[_T]") -> None: assert future is a if b.done(): return - if (hasattr(a, 'exc_info') and - a.exc_info() is not None): - future_set_exc_info(b, a.exc_info()) - elif a.exception() is not None: - b.set_exception(a.exception()) + if hasattr(a, "exc_info") and a.exc_info() is not None: # type: ignore + future_set_exc_info(b, a.exc_info()) # type: ignore else: - b.set_result(a.result()) + a_exc = a.exception() + if a_exc is not None: + b.set_exception(a_exc) + else: + b.set_result(a.result()) + if isinstance(a, Future): future_add_done_callback(a, copy) else: # concurrent.futures.Future from tornado.ioloop import IOLoop + IOLoop.current().add_future(a, copy) -def future_set_result_unless_cancelled(future, value): +def future_set_result_unless_cancelled( + future: "Union[futures.Future[_T], Future[_T]]", value: _T +) -> None: """Set the given ``value`` as the `Future`'s result, if not cancelled. - Avoids asyncio.InvalidStateError when calling set_result() on + Avoids ``asyncio.InvalidStateError`` when calling ``set_result()`` on a cancelled `asyncio.Future`. .. versionadded:: 5.0 @@ -608,23 +186,69 @@ def future_set_result_unless_cancelled(future, value): future.set_result(value) -def future_set_exc_info(future, exc_info): +def future_set_exception_unless_cancelled( + future: "Union[futures.Future[_T], Future[_T]]", exc: BaseException +) -> None: + """Set the given ``exc`` as the `Future`'s exception. + + If the Future is already canceled, logs the exception instead. If + this logging is not desired, the caller should explicitly check + the state of the Future and call ``Future.set_exception`` instead of + this wrapper. + + Avoids ``asyncio.InvalidStateError`` when calling ``set_exception()`` on + a cancelled `asyncio.Future`. + + .. versionadded:: 6.0 + + """ + if not future.cancelled(): + future.set_exception(exc) + else: + app_log.error("Exception after Future was cancelled", exc_info=exc) + + +def future_set_exc_info( + future: "Union[futures.Future[_T], Future[_T]]", + exc_info: Tuple[ + Optional[type], Optional[BaseException], Optional[types.TracebackType] + ], +) -> None: """Set the given ``exc_info`` as the `Future`'s exception. - Understands both `asyncio.Future` and Tornado's extensions to - enable better tracebacks on Python 2. + Understands both `asyncio.Future` and the extensions in older + versions of Tornado to enable better tracebacks on Python 2. .. 
versionadded:: 5.0 + + .. versionchanged:: 6.0 + + If the future is already cancelled, this function is a no-op. + (previously ``asyncio.InvalidStateError`` would be raised) + """ - if hasattr(future, 'set_exc_info'): - # Tornado's Future - future.set_exc_info(exc_info) - else: - # asyncio.Future - future.set_exception(exc_info[1]) + if exc_info[1] is None: + raise Exception("future_set_exc_info called with no exception") + future_set_exception_unless_cancelled(future, exc_info[1]) + + +@typing.overload +def future_add_done_callback( + future: "futures.Future[_T]", callback: Callable[["futures.Future[_T]"], None] +) -> None: + pass + + +@typing.overload # noqa: F811 +def future_add_done_callback( + future: "Future[_T]", callback: Callable[["Future[_T]"], None] +) -> None: + pass -def future_add_done_callback(future, callback): +def future_add_done_callback( # noqa: F811 + future: "Union[futures.Future[_T], Future[_T]]", callback: Callable[..., None] +) -> None: """Arrange to call ``callback`` when ``future`` is complete. ``callback`` is invoked with one argument, the ``future``. diff --git a/tornado/curl_httpclient.py b/tornado/curl_httpclient.py index 54fc5b36da..3860944956 100644 --- a/tornado/curl_httpclient.py +++ b/tornado/curl_httpclient.py @@ -15,44 +15,60 @@ """Non-blocking HTTP client implementation using pycurl.""" -from __future__ import absolute_import, division, print_function - import collections import functools import logging -import pycurl # type: ignore +import pycurl import threading import time from io import BytesIO from tornado import httputil from tornado import ioloop -from tornado import stack_context from tornado.escape import utf8, native_str -from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main +from tornado.httpclient import ( + HTTPRequest, + HTTPResponse, + HTTPError, + AsyncHTTPClient, + main, +) +from tornado.log import app_log + +from typing import Dict, Any, Callable, Union, Tuple, Optional +import typing + +if typing.TYPE_CHECKING: + from typing import Deque # noqa: F401 -curl_log = logging.getLogger('tornado.curl_httpclient') +curl_log = logging.getLogger("tornado.curl_httpclient") class CurlAsyncHTTPClient(AsyncHTTPClient): - def initialize(self, max_clients=10, defaults=None): - super(CurlAsyncHTTPClient, self).initialize(defaults=defaults) - self._multi = pycurl.CurlMulti() + def initialize( # type: ignore + self, max_clients: int = 10, defaults: Optional[Dict[str, Any]] = None + ) -> None: + super().initialize(defaults=defaults) + # Typeshed is incomplete for CurlMulti, so just use Any for now. + self._multi = pycurl.CurlMulti() # type: Any self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout) self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket) self._curls = [self._curl_create() for i in range(max_clients)] self._free_list = self._curls[:] - self._requests = collections.deque() - self._fds = {} - self._timeout = None + self._requests = ( + collections.deque() + ) # type: Deque[Tuple[HTTPRequest, Callable[[HTTPResponse], None], float]] + self._fds = {} # type: Dict[int, int] + self._timeout = None # type: Optional[object] # libcurl has bugs that sometimes cause it to not report all # relevant file descriptors and timeouts to TIMERFUNCTION/ # SOCKETFUNCTION. Mitigate the effects of such bugs by # forcing a periodic scan of all active requests. 
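Stepping back to the ``future_set_*``/``future_add_done_callback`` helpers above: their job is to let producer code resolve a future that the consumer may have cancelled in the meantime. A sketch of the intended pattern (the standalone coroutine is illustrative):

```py
import asyncio
from tornado.concurrent import (
    future_add_done_callback,
    future_set_result_unless_cancelled,
)

async def main():
    f = asyncio.get_event_loop().create_future()
    future_add_done_callback(f, lambda fut: print("cancelled =", fut.cancelled()))
    f.cancel()
    # f.set_result(1) would now raise asyncio.InvalidStateError;
    # the helper just ignores the already-cancelled future.
    future_set_result_unless_cancelled(f, 1)

asyncio.run(main())
```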
self._force_timeout_callback = ioloop.PeriodicCallback( - self._handle_force_timeout, 1000) + self._handle_force_timeout, 1000 + ) self._force_timeout_callback.start() # Work around a bug in libcurl 7.29.0: Some fields in the curl @@ -64,27 +80,29 @@ def initialize(self, max_clients=10, defaults=None): self._multi.add_handle(dummy_curl_handle) self._multi.remove_handle(dummy_curl_handle) - def close(self): + def close(self) -> None: self._force_timeout_callback.stop() if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) for curl in self._curls: curl.close() self._multi.close() - super(CurlAsyncHTTPClient, self).close() + super().close() # Set below properties to None to reduce the reference count of current # instance, because those properties hold some methods of current # instance that will cause circular reference. - self._force_timeout_callback = None + self._force_timeout_callback = None # type: ignore self._multi = None - def fetch_impl(self, request, callback): - self._requests.append((request, callback)) + def fetch_impl( + self, request: HTTPRequest, callback: Callable[[HTTPResponse], None] + ) -> None: + self._requests.append((request, callback, self.io_loop.time())) self._process_queue() self._set_timeout(0) - def _handle_socket(self, event, fd, multi, data): + def _handle_socket(self, event: int, fd: int, multi: Any, data: bytes) -> None: """Called by libcurl when it wants to change the file descriptors it cares about. """ @@ -92,7 +110,7 @@ def _handle_socket(self, event, fd, multi, data): pycurl.POLL_NONE: ioloop.IOLoop.NONE, pycurl.POLL_IN: ioloop.IOLoop.READ, pycurl.POLL_OUT: ioloop.IOLoop.WRITE, - pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE + pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE, } if event == pycurl.POLL_REMOVE: if fd in self._fds: @@ -110,18 +128,18 @@ def _handle_socket(self, event, fd, multi, data): # instead of update. if fd in self._fds: self.io_loop.remove_handler(fd) - self.io_loop.add_handler(fd, self._handle_events, - ioloop_event) + self.io_loop.add_handler(fd, self._handle_events, ioloop_event) self._fds[fd] = ioloop_event - def _set_timeout(self, msecs): + def _set_timeout(self, msecs: int) -> None: """Called by libcurl to schedule a timeout.""" if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = self.io_loop.add_timeout( - self.io_loop.time() + msecs / 1000.0, self._handle_timeout) + self.io_loop.time() + msecs / 1000.0, self._handle_timeout + ) - def _handle_events(self, fd, events): + def _handle_events(self, fd: int, events: int) -> None: """Called by IOLoop when there is activity on one of our file descriptors.
""" @@ -139,19 +157,17 @@ def _handle_events(self, fd, events): break self._finish_pending_requests() - def _handle_timeout(self): + def _handle_timeout(self) -> None: """Called by IOLoop when the requested timeout has passed.""" - with stack_context.NullContext(): - self._timeout = None - while True: - try: - ret, num_handles = self._multi.socket_action( - pycurl.SOCKET_TIMEOUT, 0) - except pycurl.error as e: - ret = e.args[0] - if ret != pycurl.E_CALL_MULTI_PERFORM: - break - self._finish_pending_requests() + self._timeout = None + while True: + try: + ret, num_handles = self._multi.socket_action(pycurl.SOCKET_TIMEOUT, 0) + except pycurl.error as e: + ret = e.args[0] + if ret != pycurl.E_CALL_MULTI_PERFORM: + break + self._finish_pending_requests() # In theory, we shouldn't have to do this because curl will # call _set_timeout whenever the timeout changes. However, @@ -170,21 +186,20 @@ def _handle_timeout(self): if new_timeout >= 0: self._set_timeout(new_timeout) - def _handle_force_timeout(self): + def _handle_force_timeout(self) -> None: """Called by IOLoop periodically to ask libcurl to process any events it may have forgotten about. """ - with stack_context.NullContext(): - while True: - try: - ret, num_handles = self._multi.socket_all() - except pycurl.error as e: - ret = e.args[0] - if ret != pycurl.E_CALL_MULTI_PERFORM: - break - self._finish_pending_requests() - - def _finish_pending_requests(self): + while True: + try: + ret, num_handles = self._multi.socket_all() + except pycurl.error as e: + ret = e.args[0] + if ret != pycurl.E_CALL_MULTI_PERFORM: + break + self._finish_pending_requests() + + def _finish_pending_requests(self) -> None: """Process any requests that were completed by the last call to multi.socket_action. """ @@ -198,53 +213,62 @@ def _finish_pending_requests(self): break self._process_queue() - def _process_queue(self): - with stack_context.NullContext(): - while True: - started = 0 - while self._free_list and self._requests: - started += 1 - curl = self._free_list.pop() - (request, callback) = self._requests.popleft() - curl.info = { - "headers": httputil.HTTPHeaders(), - "buffer": BytesIO(), - "request": request, - "callback": callback, - "curl_start_time": time.time(), - } - try: - self._curl_setup_request( - curl, request, curl.info["buffer"], - curl.info["headers"]) - except Exception as e: - # If there was an error in setup, pass it on - # to the callback. Note that allowing the - # error to escape here will appear to work - # most of the time since we are still in the - # caller's original stack frame, but when - # _process_queue() is called from - # _finish_pending_requests the exceptions have - # nowhere to go. - self._free_list.append(curl) - callback(HTTPResponse( - request=request, - code=599, - error=e)) - else: - self._multi.add_handle(curl) - - if not started: - break - - def _finish(self, curl, curl_error=None, curl_message=None): - info = curl.info - curl.info = None + def _process_queue(self) -> None: + while True: + started = 0 + while self._free_list and self._requests: + started += 1 + curl = self._free_list.pop() + (request, callback, queue_start_time) = self._requests.popleft() + # TODO: Don't smuggle extra data on an attribute of the Curl object. 
+ curl.info = { # type: ignore + "headers": httputil.HTTPHeaders(), + "buffer": BytesIO(), + "request": request, + "callback": callback, + "queue_start_time": queue_start_time, + "curl_start_time": time.time(), + "curl_start_ioloop_time": self.io_loop.current().time(), # type: ignore + } + try: + self._curl_setup_request( + curl, + request, + curl.info["buffer"], # type: ignore + curl.info["headers"], # type: ignore + ) + except Exception as e: + # If there was an error in setup, pass it on + # to the callback. Note that allowing the + # error to escape here will appear to work + # most of the time since we are still in the + # caller's original stack frame, but when + # _process_queue() is called from + # _finish_pending_requests the exceptions have + # nowhere to go. + self._free_list.append(curl) + callback(HTTPResponse(request=request, code=599, error=e)) + else: + self._multi.add_handle(curl) + + if not started: + break + + def _finish( + self, + curl: pycurl.Curl, + curl_error: Optional[int] = None, + curl_message: Optional[str] = None, + ) -> None: + info = curl.info # type: ignore + curl.info = None # type: ignore self._multi.remove_handle(curl) self._free_list.append(curl) buffer = info["buffer"] if curl_error: - error = CurlError(curl_error, curl_message) + assert curl_message is not None + error = CurlError(curl_error, curl_message) # type: Optional[CurlError] + assert error is not None code = error.code effective_url = None buffer.close() @@ -257,7 +281,7 @@ def _finish(self, curl, curl_error=None, curl_message=None): # the various curl timings are documented at # http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html time_info = dict( - queue=info["curl_start_time"] - info["request"].start_time, + queue=info["curl_start_ioloop_time"] - info["queue_start_time"], namelookup=curl.getinfo(pycurl.NAMELOOKUP_TIME), connect=curl.getinfo(pycurl.CONNECT_TIME), appconnect=curl.getinfo(pycurl.APPCONNECT_TIME), @@ -267,29 +291,45 @@ def _finish(self, curl, curl_error=None, curl_message=None): redirect=curl.getinfo(pycurl.REDIRECT_TIME), ) try: - info["callback"](HTTPResponse( - request=info["request"], code=code, headers=info["headers"], - buffer=buffer, effective_url=effective_url, error=error, - reason=info['headers'].get("X-Http-Reason", None), - request_time=time.time() - info["curl_start_time"], - time_info=time_info)) + info["callback"]( + HTTPResponse( + request=info["request"], + code=code, + headers=info["headers"], + buffer=buffer, + effective_url=effective_url, + error=error, + reason=info["headers"].get("X-Http-Reason", None), + request_time=self.io_loop.time() - info["curl_start_ioloop_time"], + start_time=info["curl_start_time"], + time_info=time_info, + ) + ) except Exception: self.handle_callback_exception(info["callback"]) - def handle_callback_exception(self, callback): - self.io_loop.handle_callback_exception(callback) + def handle_callback_exception(self, callback: Any) -> None: + app_log.error("Exception in callback %r", callback, exc_info=True) - def _curl_create(self): + def _curl_create(self) -> pycurl.Curl: curl = pycurl.Curl() if curl_log.isEnabledFor(logging.DEBUG): curl.setopt(pycurl.VERBOSE, 1) curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug) - if hasattr(pycurl, 'PROTOCOLS'): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12) + if hasattr( + pycurl, "PROTOCOLS" + ): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12) curl.setopt(pycurl.PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS) curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | 
pycurl.PROTO_HTTPS) return curl - def _curl_setup_request(self, curl, request, buffer, headers): + def _curl_setup_request( + self, + curl: pycurl.Curl, + request: HTTPRequest, + buffer: BytesIO, + headers: httputil.HTTPHeaders, + ) -> None: curl.setopt(pycurl.URL, native_str(request.url)) # libcurl's magic "Expect: 100-continue" behavior causes delays @@ -307,32 +347,35 @@ def _curl_setup_request(self, curl, request, buffer, headers): if "Pragma" not in request.headers: request.headers["Pragma"] = "" - curl.setopt(pycurl.HTTPHEADER, - ["%s: %s" % (native_str(k), native_str(v)) - for k, v in request.headers.get_all()]) + curl.setopt( + pycurl.HTTPHEADER, + [ + "%s: %s" % (native_str(k), native_str(v)) + for k, v in request.headers.get_all() + ], + ) - curl.setopt(pycurl.HEADERFUNCTION, - functools.partial(self._curl_header_callback, - headers, request.header_callback)) + curl.setopt( + pycurl.HEADERFUNCTION, + functools.partial( + self._curl_header_callback, headers, request.header_callback + ), + ) if request.streaming_callback: - def write_function(chunk): - self.io_loop.add_callback(request.streaming_callback, chunk) + + def write_function(b: Union[bytes, bytearray]) -> int: + assert request.streaming_callback is not None + self.io_loop.add_callback(request.streaming_callback, b) + return len(b) + else: - write_function = buffer.write - if bytes is str: # py2 - curl.setopt(pycurl.WRITEFUNCTION, write_function) - else: # py3 - # Upstream pycurl doesn't support py3, but ubuntu 12.10 includes - # a fork/port. That version has a bug in which it passes unicode - # strings instead of bytes to the WRITEFUNCTION. This means that - # if you use a WRITEFUNCTION (which tornado always does), you cannot - # download arbitrary binary data. This needs to be fixed in the - # ported pycurl package, but in the meantime this lambda will - # make it work for downloading (utf8) text. 
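The new ``write_function`` above returns ``len(b)``, matching libcurl's WRITEFUNCTION contract (a count that differs from the chunk size signals an error to curl). From the caller's side this is the ``streaming_callback`` path; a sketch with a placeholder URL:

```py
from tornado.httpclient import AsyncHTTPClient

AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")

async def download() -> None:
    chunks = []
    client = AsyncHTTPClient()
    # with streaming_callback set, chunks arrive here instead of response.body
    await client.fetch("https://example.com/big.bin",
                       streaming_callback=chunks.append)
    print(sum(len(c) for c in chunks), "bytes received")
```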
- curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s))) + write_function = buffer.write # type: ignore + curl.setopt(pycurl.WRITEFUNCTION, write_function) curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects) curl.setopt(pycurl.MAXREDIRS, request.max_redirects) + assert request.connect_timeout is not None curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout)) + assert request.request_timeout is not None curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout)) if request.user_agent: curl.setopt(pycurl.USERAGENT, native_str(request.user_agent)) @@ -343,25 +386,30 @@ def write_function(chunk): if request.decompress_response: curl.setopt(pycurl.ENCODING, "gzip,deflate") else: - curl.setopt(pycurl.ENCODING, "none") + curl.setopt(pycurl.ENCODING, None) if request.proxy_host and request.proxy_port: curl.setopt(pycurl.PROXY, request.proxy_host) curl.setopt(pycurl.PROXYPORT, request.proxy_port) if request.proxy_username: - credentials = '%s:%s' % (request.proxy_username, - request.proxy_password) + assert request.proxy_password is not None + credentials = httputil.encode_username_password( + request.proxy_username, request.proxy_password + ) curl.setopt(pycurl.PROXYUSERPWD, credentials) - if (request.proxy_auth_mode is None or - request.proxy_auth_mode == "basic"): + if request.proxy_auth_mode is None or request.proxy_auth_mode == "basic": curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC) elif request.proxy_auth_mode == "digest": curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST) else: raise ValueError( - "Unsupported proxy_auth_mode %s" % request.proxy_auth_mode) + "Unsupported proxy_auth_mode %s" % request.proxy_auth_mode + ) else: - curl.setopt(pycurl.PROXY, '') + try: + curl.unsetopt(pycurl.PROXY) + except TypeError: # not supported, disable proxy + curl.setopt(pycurl.PROXY, "") curl.unsetopt(pycurl.PROXYUSERPWD) if request.validate_cert: curl.setopt(pycurl.SSL_VERIFYPEER, 1) @@ -404,7 +452,7 @@ def write_function(chunk): elif request.allow_nonstandard_methods or request.method in custom_methods: curl.setopt(pycurl.CUSTOMREQUEST, request.method) else: - raise KeyError('unknown method ' + request.method) + raise KeyError("unknown method " + request.method) body_expected = request.method in ("POST", "PATCH", "PUT") body_present = request.body is not None @@ -412,12 +460,14 @@ def write_function(chunk): # Some HTTP methods nearly always have bodies while others # almost never do. Fail in this case unless the user has # opted out of sanity checks with allow_nonstandard_methods. - if ((body_expected and not body_present) or - (body_present and not body_expected)): + if (body_expected and not body_present) or ( + body_present and not body_expected + ): raise ValueError( - 'Body must %sbe None for method %s (unless ' - 'allow_nonstandard_methods is true)' % - ('not ' if body_expected else '', request.method)) + "Body must %sbe None for method %s (unless " + "allow_nonstandard_methods is true)" + % ("not " if body_expected else "", request.method) + ) if body_expected or body_present: if request.method == "GET": @@ -426,23 +476,23 @@ def write_function(chunk): # unless we use CUSTOMREQUEST). While the spec doesn't # forbid clients from sending a body, it arguably # disallows the server from doing anything with them. 
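The sanity check retained above rejects a missing or unexpected body early rather than letting libcurl misbehave; ``allow_nonstandard_methods`` is the documented escape hatch (the URL below is a placeholder):

```py
from tornado.httpclient import HTTPRequest

# DELETE normally has no body, so this would raise ValueError in fetch():
#     HTTPRequest("https://example.com/item/1", method="DELETE", body="why")

# The explicit opt-out:
req = HTTPRequest(
    "https://example.com/item/1",
    method="DELETE",
    body="why",
    allow_nonstandard_methods=True,
)

# Note: GET stays special-cased here; curl_httpclient raises for any GET body
# because libcurl cannot send one.
```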
- raise ValueError('Body must be None for GET request') - request_buffer = BytesIO(utf8(request.body or '')) + raise ValueError("Body must be None for GET request") + request_buffer = BytesIO(utf8(request.body or "")) - def ioctl(cmd): - if cmd == curl.IOCMD_RESTARTREAD: + def ioctl(cmd: int) -> None: + if cmd == curl.IOCMD_RESTARTREAD: # type: ignore request_buffer.seek(0) + curl.setopt(pycurl.READFUNCTION, request_buffer.read) curl.setopt(pycurl.IOCTLFUNCTION, ioctl) if request.method == "POST": - curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or '')) + curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or "")) else: curl.setopt(pycurl.UPLOAD, True) - curl.setopt(pycurl.INFILESIZE, len(request.body or '')) + curl.setopt(pycurl.INFILESIZE, len(request.body or "")) if request.auth_username is not None: - userpwd = "%s:%s" % (request.auth_username, request.auth_password or '') - + assert request.auth_password is not None if request.auth_mode is None or request.auth_mode == "basic": curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC) elif request.auth_mode == "digest": @@ -450,9 +500,16 @@ def ioctl(cmd): else: raise ValueError("Unsupported auth_mode %s" % request.auth_mode) - curl.setopt(pycurl.USERPWD, native_str(userpwd)) - curl_log.debug("%s %s (username: %r)", request.method, request.url, - request.auth_username) + userpwd = httputil.encode_username_password( + request.auth_username, request.auth_password + ) + curl.setopt(pycurl.USERPWD, userpwd) + curl_log.debug( + "%s %s (username: %r)", + request.method, + request.url, + request.auth_username, + ) else: curl.unsetopt(pycurl.USERPWD) curl_log.debug("%s %s", request.method, request.url) @@ -466,7 +523,7 @@ def ioctl(cmd): if request.ssl_options is not None: raise ValueError("ssl_options not supported in curl_httpclient") - if threading.activeCount() > 1: + if threading.active_count() > 1: # libcurl/pycurl is not thread-safe by default. When multiple threads # are used, signals should be disabled. This has the side effect # of disabling DNS timeouts in some environments (when libcurl is @@ -479,8 +536,13 @@ def ioctl(cmd): if request.prepare_curl_callback is not None: request.prepare_curl_callback(curl) - def _curl_header_callback(self, headers, header_callback, header_line): - header_line = native_str(header_line.decode('latin1')) + def _curl_header_callback( + self, + headers: httputil.HTTPHeaders, + header_callback: Callable[[str], None], + header_line_bytes: bytes, + ) -> None: + header_line = native_str(header_line_bytes.decode("latin1")) if header_callback is not None: self.io_loop.add_callback(header_callback, header_line) # header_line as returned by curl includes the end-of-line characters. 
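The options exercised in this hunk (request and proxy auth, NOSIGNAL, header callbacks) all map to plain `HTTPRequest` parameters. A sketch with placeholder credentials and URL:

```py
from tornado.httpclient import AsyncHTTPClient, HTTPRequest

AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")

req = HTTPRequest(
    "https://example.com/protected",  # placeholder URL
    auth_username="alice",
    auth_password="s3cret",
    auth_mode="digest",               # selects pycurl.HTTPAUTH_DIGEST above
    header_callback=print,            # called once per received header line
)

async def main() -> None:
    response = await AsyncHTTPClient().fetch(req)
    print(response.code)
```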
@@ -497,21 +559,21 @@ def _curl_header_callback(self, headers, header_callback, header_line): return headers.parse_line(header_line) - def _curl_debug(self, debug_type, debug_msg): - debug_types = ('I', '<', '>', '<', '>') + def _curl_debug(self, debug_type: int, debug_msg: str) -> None: + debug_types = ("I", "<", ">", "<", ">") if debug_type == 0: debug_msg = native_str(debug_msg) - curl_log.debug('%s', debug_msg.strip()) + curl_log.debug("%s", debug_msg.strip()) elif debug_type in (1, 2): debug_msg = native_str(debug_msg) for line in debug_msg.splitlines(): - curl_log.debug('%s %s', debug_types[debug_type], line) + curl_log.debug("%s %s", debug_types[debug_type], line) elif debug_type == 4: - curl_log.debug('%s %r', debug_types[debug_type], debug_msg) + curl_log.debug("%s %r", debug_types[debug_type], debug_msg) class CurlError(HTTPError): - def __init__(self, errno, message): + def __init__(self, errno: int, message: str) -> None: HTTPError.__init__(self, 599, message) self.errno = errno diff --git a/tornado/escape.py b/tornado/escape.py index a79ece66ce..3cf7ff2e4a 100644 --- a/tornado/escape.py +++ b/tornado/escape.py @@ -19,35 +19,28 @@ have crept in over time. """ -from __future__ import absolute_import, division, print_function - +import html.entities import json import re +import urllib.parse -from tornado.util import PY3, unicode_type, basestring_type - -if PY3: - from urllib.parse import parse_qs as _parse_qs - import html.entities as htmlentitydefs - import urllib.parse as urllib_parse - unichr = chr -else: - from urlparse import parse_qs as _parse_qs - import htmlentitydefs - import urllib as urllib_parse - -try: - import typing # noqa -except ImportError: - pass +from tornado.util import unicode_type +import typing +from typing import Union, Any, Optional, Dict, List, Callable -_XHTML_ESCAPE_RE = re.compile('[&<>"\']') -_XHTML_ESCAPE_DICT = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', - '\'': '&#39;'} +_XHTML_ESCAPE_RE = re.compile("[&<>\"']") +_XHTML_ESCAPE_DICT = { + "&": "&amp;", + "<": "&lt;", + ">": "&gt;", + '"': "&quot;", + "'": "&#39;", +} -def xhtml_escape(value): + +def xhtml_escape(value: Union[str, bytes]) -> str: """Escapes a string so it is valid within HTML or XML. Escapes the characters ``<``, ``>``, ``"``, ``'``, and ``&``. @@ -58,11 +51,12 @@ def xhtml_escape(value): Added the single quote to the list of escaped characters. """ - return _XHTML_ESCAPE_RE.sub(lambda match: _XHTML_ESCAPE_DICT[match.group(0)], - to_basestring(value)) + return _XHTML_ESCAPE_RE.sub( + lambda match: _XHTML_ESCAPE_DICT[match.group(0)], to_basestring(value) + ) -def xhtml_unescape(value): +def xhtml_unescape(value: Union[str, bytes]) -> str: """Un-escapes an XML-escaped string.""" return re.sub(r"&(#?)(\w+?);", _convert_entity, _unicode(value)) @@ -70,28 +64,31 @@ def xhtml_unescape(value): # The fact that json_encode wraps json.dumps is an implementation detail. # Please see https://github.com/tornadoweb/tornado/pull/706 # before sending a pull request that adds **kwargs to this function. -def json_encode(value): +def json_encode(value: Any) -> str: """JSON-encodes the given Python object.""" # JSON permits but does not require forward slashes to be escaped. # This is useful when json data is emitted in a <script> tag # in HTML, as it prevents </script> tags from prematurely terminating - # the javascript. Some json libraries do this escaping by default, + # the JavaScript. Some json libraries do this escaping by default, # although python's standard library does not, so we do it here.
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped return json.dumps(value).replace("</", "<\\/") -def json_decode(value): +def json_decode(value: Union[str, bytes]) -> Any: + """Returns Python objects for the given JSON string. + + Supports both `str` and `bytes` inputs. + """ return json.loads(to_basestring(value)) -def squeeze(value): +def squeeze(value: str) -> str: """Replace all sequences of whitespace chars with a single space.""" return re.sub(r"[\x00-\x20]+", " ", value).strip() -def url_escape(value, plus=True): +def url_escape(value: Union[str, bytes], plus: bool = True) -> str: """Returns a URL-encoded version of the given value. If ``plus`` is true (the default), spaces will be represented @@ -102,89 +99,93 @@ def url_escape(value, plus=True): .. versionadded:: 3.1 The ``plus`` argument """ - quote = urllib_parse.quote_plus if plus else urllib_parse.quote + quote = urllib.parse.quote_plus if plus else urllib.parse.quote return quote(utf8(value)) -# python 3 changed things around enough that we need two separate -# implementations of url_unescape. We also need our own implementation -# of parse_qs since python 3's version insists on decoding everything. -if not PY3: - def url_unescape(value, encoding='utf-8', plus=True): - """Decodes the given value from a URL. +@typing.overload +def url_unescape(value: Union[str, bytes], encoding: None, plus: bool = True) -> bytes: + pass - The argument may be either a byte or unicode string. - If encoding is None, the result will be a byte string. Otherwise, - the result is a unicode string in the specified encoding. +@typing.overload # noqa: F811 +def url_unescape( + value: Union[str, bytes], encoding: str = "utf-8", plus: bool = True +) -> str: + pass - If ``plus`` is true (the default), plus signs will be interpreted - as spaces (literal plus signs must be represented as "%2B"). This - is appropriate for query strings and form-encoded values but not - for the path component of a URL. Note that this default is the - reverse of Python's urllib module. - .. versionadded:: 3.1 - The ``plus`` argument - """ - unquote = (urllib_parse.unquote_plus if plus else urllib_parse.unquote) - if encoding is None: - return unquote(utf8(value)) - else: - return unicode_type(unquote(utf8(value)), encoding) - - parse_qs_bytes = _parse_qs -else: - def url_unescape(value, encoding='utf-8', plus=True): - """Decodes the given value from a URL. - - The argument may be either a byte or unicode string. - - If encoding is None, the result will be a byte string. Otherwise, - the result is a unicode string in the specified encoding. - - If ``plus`` is true (the default), plus signs will be interpreted - as spaces (literal plus signs must be represented as "%2B"). This - is appropriate for query strings and form-encoded values but not - for the path component of a URL. Note that this default is the - reverse of Python's urllib module. - - .. versionadded:: 3.1 - The ``plus`` argument - """ - if encoding is None: - if plus: - # unquote_to_bytes doesn't have a _plus variant - value = to_basestring(value).replace('+', ' ') - return urllib_parse.unquote_to_bytes(value) - else: - unquote = (urllib_parse.unquote_plus if plus - else urllib_parse.unquote) - return unquote(to_basestring(value), encoding=encoding) - - def parse_qs_bytes(qs, keep_blank_values=False, strict_parsing=False): - """Parses a query string like urlparse.parse_qs, but returns the - values as byte strings. - - Keys still become type str (interpreted as latin1 in python3!)
- because it's too painful to keep them as byte strings in - python3 and in practice they're nearly always ascii anyway. - """ - # This is gross, but python3 doesn't give us another way. - # Latin1 is the universal donor of character encodings. - result = _parse_qs(qs, keep_blank_values, strict_parsing, - encoding='latin1', errors='strict') - encoded = {} - for k, v in result.items(): - encoded[k] = [i.encode('latin1') for i in v] - return encoded +def url_unescape( # noqa: F811 + value: Union[str, bytes], encoding: Optional[str] = "utf-8", plus: bool = True +) -> Union[str, bytes]: + """Decodes the given value from a URL. + + The argument may be either a byte or unicode string. + + If encoding is None, the result will be a byte string. Otherwise, + the result is a unicode string in the specified encoding. + + If ``plus`` is true (the default), plus signs will be interpreted + as spaces (literal plus signs must be represented as "%2B"). This + is appropriate for query strings and form-encoded values but not + for the path component of a URL. Note that this default is the + reverse of Python's urllib module. + + .. versionadded:: 3.1 + The ``plus`` argument + """ + if encoding is None: + if plus: + # unquote_to_bytes doesn't have a _plus variant + value = to_basestring(value).replace("+", " ") + return urllib.parse.unquote_to_bytes(value) + else: + unquote = urllib.parse.unquote_plus if plus else urllib.parse.unquote + return unquote(to_basestring(value), encoding=encoding) + + +def parse_qs_bytes( + qs: Union[str, bytes], keep_blank_values: bool = False, strict_parsing: bool = False +) -> Dict[str, List[bytes]]: + """Parses a query string like urlparse.parse_qs, + but takes bytes and returns the values as byte strings. + + Keys still become type str (interpreted as latin1 in python3!) + because it's too painful to keep them as byte strings in + python3 and in practice they're nearly always ascii anyway. + """ + # This is gross, but python3 doesn't give us another way. + # Latin1 is the universal donor of character encodings. + if isinstance(qs, bytes): + qs = qs.decode("latin1") + result = urllib.parse.parse_qs( + qs, keep_blank_values, strict_parsing, encoding="latin1", errors="strict" + ) + encoded = {} + for k, v in result.items(): + encoded[k] = [i.encode("latin1") for i in v] + return encoded _UTF8_TYPES = (bytes, type(None)) -def utf8(value): - # type: (typing.Union[bytes,unicode_type,None])->typing.Union[bytes,None] +@typing.overload +def utf8(value: bytes) -> bytes: + pass + + +@typing.overload # noqa: F811 +def utf8(value: str) -> bytes: + pass + + +@typing.overload # noqa: F811 +def utf8(value: None) -> None: + pass + + +def utf8(value: Union[None, str, bytes]) -> Optional[bytes]: # noqa: F811 """Converts a string argument to a byte string. If the argument is already a byte string or None, it is returned unchanged. 
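The net effect of the `parse_qs_bytes` rewrite above: `bytes` query strings are now accepted directly, keys come back as `str` (latin1-interpreted) and values stay `bytes`. A quick sketch of it together with the `utf8`/`to_unicode` overloads:

```py
from tornado.escape import parse_qs_bytes, to_unicode, utf8

args = parse_qs_bytes(b"name=caf%C3%A9&tag=a&tag=b")
print(args["name"])  # [b'caf\xc3\xa9'] -- percent-decoded but left as bytes
print(args["tag"])   # [b'a', b'b']

# the utf8/to_unicode overloads round-trip str/bytes and pass None through
assert to_unicode(utf8("café")) == "café"
assert utf8(None) is None
```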
@@ -193,16 +194,29 @@ def utf8(value): if isinstance(value, _UTF8_TYPES): return value if not isinstance(value, unicode_type): - raise TypeError( - "Expected bytes, unicode, or None; got %r" % type(value) - ) + raise TypeError("Expected bytes, unicode, or None; got %r" % type(value)) return value.encode("utf-8") _TO_UNICODE_TYPES = (unicode_type, type(None)) -def to_unicode(value): +@typing.overload +def to_unicode(value: str) -> str: + pass + + +@typing.overload # noqa: F811 +def to_unicode(value: bytes) -> str: + pass + + +@typing.overload # noqa: F811 +def to_unicode(value: None) -> None: + pass + + +def to_unicode(value: Union[None, str, bytes]) -> Optional[str]: # noqa: F811 """Converts a string argument to a unicode string. If the argument is already a unicode string or None, it is returned @@ -211,9 +225,7 @@ def to_unicode(value): if isinstance(value, _TO_UNICODE_TYPES): return value if not isinstance(value, bytes): - raise TypeError( - "Expected bytes, unicode, or None; got %r" % type(value) - ) + raise TypeError("Expected bytes, unicode, or None; got %r" % type(value)) return value.decode("utf-8") @@ -223,39 +235,19 @@ def to_unicode(value): # When dealing with the standard library across python 2 and 3 it is # sometimes useful to have a direct conversion to the native string type -if str is unicode_type: - native_str = to_unicode -else: - native_str = utf8 - -_BASESTRING_TYPES = (basestring_type, type(None)) +native_str = to_unicode +to_basestring = to_unicode -def to_basestring(value): - """Converts a string argument to a subclass of basestring. - - In python2, byte and unicode strings are mostly interchangeable, - so functions that deal with a user-supplied argument in combination - with ascii string constants can use either and should return the type - the user supplied. In python3, the two types are not interchangeable, - so this method is needed to convert byte strings to unicode. - """ - if isinstance(value, _BASESTRING_TYPES): - return value - if not isinstance(value, bytes): - raise TypeError( - "Expected bytes, unicode, or None; got %r" % type(value) - ) - return value.decode("utf-8") - - -def recursive_unicode(obj): +def recursive_unicode(obj: Any) -> Any: """Walks a simple data structure, converting byte strings to unicode. Supports lists, tuples, and dictionaries. """ if isinstance(obj, dict): - return dict((recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items()) + return dict( + (recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items() + ) elif isinstance(obj, list): return list(recursive_unicode(i) for i in obj) elif isinstance(obj, tuple): @@ -273,13 +265,20 @@ def recursive_unicode(obj): # This regex should avoid those problems. # Use to_unicode instead of tornado.util.u - we don't want backslashes getting # processed as escapes. 
-_URL_RE = re.compile(to_unicode( - r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&amp;|&quot;)*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&amp;|&quot;)*\)))+)""" # noqa: E501 -)) - - -def linkify(text, shorten=False, extra_params="", - require_protocol=False, permitted_protocols=["http", "https"]): +_URL_RE = re.compile( + to_unicode( + r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&amp;|&quot;)*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&amp;|&quot;)*\)))+)""" # noqa: E501 + ) +) + + +def linkify( + text: Union[str, bytes], + shorten: bool = False, + extra_params: Union[str, Callable[[str], str]] = "", + require_protocol: bool = False, + permitted_protocols: List[str] = ["http", "https"], +) -> str: """Converts plain text into HTML with links. For example: ``linkify("Hello http://tornadoweb.org!")`` would return @@ -312,7 +311,7 @@ def extra_params_cb(url): if extra_params and not callable(extra_params): extra_params = " " + extra_params.strip() - def make_link(m): + def make_link(m: typing.Match) -> str: url = m.group(1) proto = m.group(2) if require_protocol and not proto: @@ -323,7 +322,7 @@ def make_link(m): href = m.group(1) if not proto: - href = "http://" + href # no proto specified, use http + href = "http://" + href # no proto specified, use http if callable(extra_params): params = " " + extra_params(href).strip() @@ -345,14 +344,18 @@ def make_link(m): # The path is usually not that interesting once shortened # (no more slug, etc), so it really just provides a little # extra indication of shortening. - url = url[:proto_len] + parts[0] + "/" + \ - parts[1][:8].split('?')[0].split('.')[0] + url = ( + url[:proto_len] + + parts[0] + + "/" + + parts[1][:8].split("?")[0].split(".")[0] + ) if len(url) > max_len * 1.5: # still too long url = url[:max_len] if url != before_clip: - amp = url.rfind('&amp;') + amp = url.rfind("&amp;") # avoid splitting html char entities if amp > max_len - 5: url = url[:amp] @@ -374,13 +377,13 @@ def make_link(m): return _URL_RE.sub(make_link, text) -def _convert_entity(m): +def _convert_entity(m: typing.Match) -> str: if m.group(1) == "#": try: - if m.group(2)[:1].lower() == 'x': - return unichr(int(m.group(2)[1:], 16)) + if m.group(2)[:1].lower() == "x": + return chr(int(m.group(2)[1:], 16)) else: - return unichr(int(m.group(2))) + return chr(int(m.group(2))) except ValueError: return "&#%s;" % m.group(2) try: @@ -389,10 +392,10 @@ def _convert_entity(m): return "&%s;" % m.group(2) -def _build_unicode_map(): +def _build_unicode_map() -> Dict[str, str]: unicode_map = {} - for name, value in htmlentitydefs.name2codepoint.items(): - unicode_map[name] = unichr(value) + for name, value in html.entities.name2codepoint.items(): + unicode_map[name] = chr(value) return unicode_map diff --git a/tornado/gen.py b/tornado/gen.py index cc59402e5e..a6370259b3 100644 --- a/tornado/gen.py +++ b/tornado/gen.py @@ -17,25 +17,7 @@ technically asynchronous, but it is written as a single generator instead of a collection of separate functions. -For example, the following asynchronous handler: - -.. testcode:: - - class AsyncHandler(RequestHandler): - @asynchronous - def get(self): - http_client = AsyncHTTPClient() - http_client.fetch("http://example.com", - callback=self.on_fetch) - - def on_fetch(self, response): - do_something_with_response(response) - self.render("template.html") - -.. testoutput:: - :hide: - -could be written with ``gen`` as: +For example, here's a coroutine-based handler: .. testcode:: @@ -50,12 +32,12 @@ def get(self): ..
testoutput:: :hide: -Most asynchronous functions in Tornado return a `.Future`; -yielding this object returns its ``Future.result``. +Asynchronous functions in Tornado return an ``Awaitable`` or `.Future`; +yielding this object returns its result. -You can also yield a list or dict of ``Futures``, which will be -started at the same time and run in parallel; a list or dict of results will -be returned when they are all finished: +You can also yield a list or dict of other yieldable objects, which +will be started at the same time and run in parallel; a list or dict +of results will be returned when they are all finished: .. testcode:: @@ -72,13 +54,9 @@ def get(self): .. testoutput:: :hide: -If the `~functools.singledispatch` library is available (standard in -Python 3.4, available via the `singledispatch -`_ package on older -versions), additional types of objects may be yielded. Tornado includes -support for ``asyncio.Future`` and Twisted's ``Deferred`` class when -``tornado.platform.asyncio`` and ``tornado.platform.twisted`` are imported. -See the `convert_yielded` function to extend this mechanism. +If ``tornado.platform.twisted`` is imported, it is also possible to +yield Twisted's ``Deferred`` objects. See the `convert_yielded` +function to extend this mechanism. .. versionchanged:: 3.2 Dict support added. @@ -88,64 +66,46 @@ def get(self): via ``singledispatch``. """ -from __future__ import absolute_import, division, print_function - +import asyncio +import builtins import collections +from collections.abc import Generator +import concurrent.futures +import datetime import functools -import itertools -import os +from functools import singledispatch +from inspect import isawaitable import sys import types -import warnings -from tornado.concurrent import (Future, is_future, chain_future, future_set_exc_info, - future_add_done_callback, future_set_result_unless_cancelled) +from tornado.concurrent import ( + Future, + is_future, + chain_future, + future_set_exc_info, + future_add_done_callback, + future_set_result_unless_cancelled, +) from tornado.ioloop import IOLoop from tornado.log import app_log -from tornado import stack_context -from tornado.util import PY3, raise_exc_info, TimeoutError +from tornado.util import TimeoutError try: - try: - # py34+ - from functools import singledispatch # type: ignore - except ImportError: - from singledispatch import singledispatch # backport + import contextvars except ImportError: - # In most cases, singledispatch is required (to avoid - # difficult-to-diagnose problems in which the functionality - # available differs depending on which invisble packages are - # installed). However, in Google App Engine third-party - # dependencies are more trouble so we allow this module to be - # imported without it. 
- if 'APPENGINE_RUNTIME' not in os.environ: - raise - singledispatch = None + contextvars = None # type: ignore -try: - try: - # py35+ - from collections.abc import Generator as GeneratorType # type: ignore - except ImportError: - from backports_abc import Generator as GeneratorType # type: ignore +import typing +from typing import Union, Any, Callable, List, Type, Tuple, Awaitable, Dict, overload - try: - # py35+ - from inspect import isawaitable # type: ignore - except ImportError: - from backports_abc import isawaitable -except ImportError: - if 'APPENGINE_RUNTIME' not in os.environ: - raise - from types import GeneratorType +if typing.TYPE_CHECKING: + from typing import Sequence, Deque, Optional, Set, Iterable # noqa: F401 - def isawaitable(x): # type: ignore - return False +_T = typing.TypeVar("_T") -if PY3: - import builtins -else: - import __builtin__ as builtins +_Yieldable = Union[ + None, Awaitable, List[Awaitable], Dict[Any, Awaitable], concurrent.futures.Future +] class KeyReuseError(Exception): @@ -168,7 +128,7 @@ class ReturnValueIgnoredError(Exception): pass -def _value_from_stopiteration(e): +def _value_from_stopiteration(e: Union[StopIteration, "Return"]) -> Any: try: # StopIteration has a value attribute beginning in py33. # So does our Return class. @@ -183,8 +143,8 @@ def _value_from_stopiteration(e): return None -def _create_future(): - future = Future() +def _create_future() -> Future: + future = Future() # type: Future # Fixup asyncio debug info by removing extraneous stack entries source_traceback = getattr(future, "_source_traceback", ()) while source_traceback: @@ -198,68 +158,32 @@ def _create_future(): return future -def engine(func): - """Callback-oriented decorator for asynchronous generators. - - This is an older interface; for new code that does not need to be - compatible with versions of Tornado older than 3.0 the - `coroutine` decorator is recommended instead. +def _fake_ctx_run(f: Callable[..., _T], *args: Any, **kw: Any) -> _T: + return f(*args, **kw) - This decorator is similar to `coroutine`, except it does not - return a `.Future` and the ``callback`` argument is not treated - specially. - In most cases, functions decorated with `engine` should take - a ``callback`` argument and invoke it with their result when - they are finished. One notable exception is the - `~tornado.web.RequestHandler` :ref:`HTTP verb methods `, - which use ``self.finish()`` in place of a callback argument. +@overload +def coroutine( + func: Callable[..., "Generator[Any, Any, _T]"] +) -> Callable[..., "Future[_T]"]: + ... - .. deprecated:: 5.1 - This decorator will be removed in 6.0. Use `coroutine` or - ``async def`` instead. - """ - warnings.warn("gen.engine is deprecated, use gen.coroutine or async def instead", - DeprecationWarning) - func = _make_coroutine_wrapper(func, replace_callback=False) - - @functools.wraps(func) - def wrapper(*args, **kwargs): - future = func(*args, **kwargs) - - def final_callback(future): - if future.result() is not None: - raise ReturnValueIgnoredError( - "@gen.engine functions cannot return values: %r" % - (future.result(),)) - # The engine interface doesn't give us any way to return - # errors but to raise them into the stack context. - # Save the stack context here to use when the Future has resolved. - future_add_done_callback(future, stack_context.wrap(final_callback)) - return wrapper +@overload +def coroutine(func: Callable[..., _T]) -> Callable[..., "Future[_T]"]: + ... 
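For reviewers, a small sketch (not part of the patch) of what the `@overload` pair above is for: it lets a type checker carry a decorated generator's return type through to the returned `Future`. Only the public `gen.coroutine`/`gen.sleep` APIs and `IOLoop.run_sync` are assumed:

```py
from tornado import gen
from tornado.ioloop import IOLoop

@gen.coroutine
def double(x):
    # Decorated generator: under the first overload, a body typed as
    # Generator[Any, Any, int] makes the decorator return Future[int].
    yield gen.sleep(0.01)
    return x * 2  # plain return works in py3 generators

print(IOLoop.current().run_sync(lambda: double(21)))  # 42
```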
-def coroutine(func): +def coroutine( + func: Union[Callable[..., "Generator[Any, Any, _T]"], Callable[..., _T]] +) -> Callable[..., "Future[_T]"]: """Decorator for asynchronous generators. - Any generator that yields objects from this module must be wrapped - in either this decorator or `engine`. - - Coroutines may "return" by raising the special exception - `Return(value) `. In Python 3.3+, it is also possible for - the function to simply use the ``return value`` statement (prior to - Python 3.3 generators were not allowed to also return values). - In all versions of Python a coroutine that simply wishes to exit - early may use the ``return`` statement without a value. + For compatibility with older versions of Python, coroutines may + also "return" by raising the special exception `Return(value) + `. - Functions with this decorator return a `.Future`. Additionally, - they may be called with a ``callback`` keyword argument, which - will be invoked with the future's result when it resolves. If the - coroutine fails, the callback will not be run and an exception - will be raised into the surrounding `.StackContext`. The - ``callback`` argument is not visible inside the decorated - function; it is handled by the decorator itself. + Functions with this decorator return a `.Future`. .. warning:: @@ -271,40 +195,25 @@ def coroutine(func): `.IOLoop.run_sync` for top-level calls, or passing the `.Future` to `.IOLoop.add_future`. - .. deprecated:: 5.1 - - The ``callback`` argument is deprecated and will be removed in 6.0. - Use the returned awaitable object instead. - """ - return _make_coroutine_wrapper(func, replace_callback=True) - + .. versionchanged:: 6.0 -def _make_coroutine_wrapper(func, replace_callback): - """The inner workings of ``@gen.coroutine`` and ``@gen.engine``. + The ``callback`` argument was removed. Use the returned + awaitable object instead. - The two decorators differ in their treatment of the ``callback`` - argument, so we cannot simply implement ``@engine`` in terms of - ``@coroutine``. """ - # On Python 3.5, set the coroutine flag on our generator, to allow it - # to be used with 'await'. - wrapped = func - if hasattr(types, 'coroutine'): - func = types.coroutine(func) - @functools.wraps(wrapped) + @functools.wraps(func) def wrapper(*args, **kwargs): + # type: (*Any, **Any) -> Future[_T] + # This function is type-annotated with a comment to work around + # https://bitbucket.org/pypy/pypy/issues/2868/segfault-with-args-type-annotation-in future = _create_future() - - if replace_callback and 'callback' in kwargs: - warnings.warn("callback arguments are deprecated, use the returned Future instead", - DeprecationWarning, stacklevel=2) - callback = kwargs.pop('callback') - IOLoop.current().add_future( - future, lambda future: callback(future.result())) - + if contextvars is not None: + ctx_run = contextvars.copy_context().run # type: Callable + else: + ctx_run = _fake_ctx_run try: - result = func(*args, **kwargs) + result = ctx_run(func, *args, **kwargs) except (Return, StopIteration) as e: result = _value_from_stopiteration(e) except Exception: @@ -313,25 +222,20 @@ def wrapper(*args, **kwargs): return future finally: # Avoid circular references - future = None + future = None # type: ignore else: - if isinstance(result, GeneratorType): + if isinstance(result, Generator): # Inline the first iteration of Runner.run. 
This lets us # avoid the cost of creating a Runner when the coroutine # never actually yields, which in turn allows us to # use "optional" coroutines in critical path code without # performance penalty for the synchronous case. try: - orig_stack_contexts = stack_context._state.contexts - yielded = next(result) - if stack_context._state.contexts is not orig_stack_contexts: - yielded = _create_future() - yielded.set_exception( - stack_context.StackContextInconsistentError( - 'stack_context inconsistency (probably caused ' - 'by yield within a "with StackContext" block)')) + yielded = ctx_run(next, result) except (StopIteration, Return) as e: - future_set_result_unless_cancelled(future, _value_from_stopiteration(e)) + future_set_result_unless_cancelled( + future, _value_from_stopiteration(e) + ) except Exception: future_set_exc_info(future, sys.exc_info()) else: @@ -342,8 +246,8 @@ def wrapper(*args, **kwargs): # We do this by exploiting the public API # add_done_callback() instead of putting a private # attribute on the Future. - # (Github issues #1769, #2229). - runner = Runner(result, future, yielded) + # (GitHub issues #1769, #2229). + runner = Runner(ctx_run, result, future, yielded) future.add_done_callback(lambda _: runner) yielded = None try: @@ -357,22 +261,22 @@ def wrapper(*args, **kwargs): # benchmarks (relative to the refcount-based scheme # used in the absence of cycles). We can avoid the # cycle by clearing the local variable after we return it. - future = None + future = None # type: ignore future_set_result_unless_cancelled(future, result) return future - wrapper.__wrapped__ = wrapped - wrapper.__tornado_coroutine__ = True + wrapper.__wrapped__ = func # type: ignore + wrapper.__tornado_coroutine__ = True # type: ignore return wrapper -def is_coroutine_function(func): +def is_coroutine_function(func: Any) -> bool: """Return whether *func* is a coroutine function, i.e. a function wrapped with `~.gen.coroutine`. .. versionadded:: 4.5 """ - return getattr(func, '__tornado_coroutine__', False) + return getattr(func, "__tornado_coroutine__", False) class Return(Exception): @@ -395,30 +299,32 @@ def fetch_json(url): but it is never necessary to ``raise gen.Return()``. The ``return`` statement can be used with no arguments instead. """ - def __init__(self, value=None): - super(Return, self).__init__() + + def __init__(self, value: Any = None) -> None: + super().__init__() self.value = value # Cython recognizes subclasses of StopIteration with a .args tuple. self.args = (value,) class WaitIterator(object): - """Provides an iterator to yield the results of futures as they finish. + """Provides an iterator to yield the results of awaitables as they finish. - Yielding a set of futures like this: + Yielding a set of awaitables like this: - ``results = yield [future1, future2]`` + ``results = yield [awaitable1, awaitable2]`` - pauses the coroutine until both ``future1`` and ``future2`` + pauses the coroutine until both ``awaitable1`` and ``awaitable2`` return, and then restarts the coroutine with the results of both - futures. If either future is an exception, the expression will - raise that exception and all the results will be lost. + awaitables. If either awaitable raises an exception, the + expression will raise that exception and all the results will be + lost. 
- If you need to get the result of each future as soon as possible, - or if you need the result of some futures even if others produce + If you need to get the result of each awaitable as soon as possible, + or if you need the result of some awaitables even if others produce errors, you can use ``WaitIterator``:: - wait_iterator = gen.WaitIterator(future1, future2) + wait_iterator = gen.WaitIterator(awaitable1, awaitable2) while not wait_iterator.done(): try: result = yield wait_iterator.next() @@ -434,7 +340,7 @@ class WaitIterator(object): input arguments*. If you need to know which future produced the current result, you can use the attributes ``WaitIterator.current_future``, or ``WaitIterator.current_index`` - to get the index of the future from the input list. (if keyword + to get the index of the awaitable from the input list. (if keyword arguments were used in the construction of the `WaitIterator`, ``current_index`` will use the corresponding keyword). @@ -455,26 +361,29 @@ class WaitIterator(object): Added ``async for`` support in Python 3.5. """ - def __init__(self, *args, **kwargs): + + _unfinished = {} # type: Dict[Future, Union[int, str]] + + def __init__(self, *args: Future, **kwargs: Future) -> None: if args and kwargs: - raise ValueError( - "You must provide args or kwargs, not both") + raise ValueError("You must provide args or kwargs, not both") if kwargs: self._unfinished = dict((f, k) for (k, f) in kwargs.items()) - futures = list(kwargs.values()) + futures = list(kwargs.values()) # type: Sequence[Future] else: self._unfinished = dict((f, i) for (i, f) in enumerate(args)) futures = args - self._finished = collections.deque() - self.current_index = self.current_future = None - self._running_future = None + self._finished = collections.deque() # type: Deque[Future] + self.current_index = None # type: Optional[Union[str, int]] + self.current_future = None # type: Optional[Future] + self._running_future = None # type: Optional[Future] for future in futures: future_add_done_callback(future, self._done_callback) - def done(self): + def done(self) -> bool: """Returns True if this iterator has no more results.""" if self._finished or self._unfinished: return False @@ -482,7 +391,7 @@ def done(self): self.current_index = self.current_future = None return True - def next(self): + def next(self) -> Future: """Returns a `.Future` that will yield the next available result. Note that this `.Future` will not be the same object as any of @@ -491,233 +400,45 @@ def next(self): self._running_future = Future() if self._finished: - self._return_result(self._finished.popleft()) + return self._return_result(self._finished.popleft()) return self._running_future - def _done_callback(self, done): + def _done_callback(self, done: Future) -> None: if self._running_future and not self._running_future.done(): self._return_result(done) else: self._finished.append(done) - def _return_result(self, done): + def _return_result(self, done: Future) -> Future: """Called set the returned future's state that of the future we yielded, and set the current future for the iterator. 
""" + if self._running_future is None: + raise Exception("no future is running") chain_future(done, self._running_future) + res = self._running_future + self._running_future = None self.current_future = done self.current_index = self._unfinished.pop(done) - def __aiter__(self): + return res + + def __aiter__(self) -> typing.AsyncIterator: return self - def __anext__(self): + def __anext__(self) -> Future: if self.done(): # Lookup by name to silence pyflakes on older versions. - raise getattr(builtins, 'StopAsyncIteration')() + raise getattr(builtins, "StopAsyncIteration")() return self.next() -class YieldPoint(object): - """Base class for objects that may be yielded from the generator. - - .. deprecated:: 4.0 - Use `Futures <.Future>` instead. This class and all its subclasses - will be removed in 6.0 - """ - def __init__(self): - warnings.warn("YieldPoint is deprecated, use Futures instead", - DeprecationWarning) - - def start(self, runner): - """Called by the runner after the generator has yielded. - - No other methods will be called on this object before ``start``. - """ - raise NotImplementedError() - - def is_ready(self): - """Called by the runner to determine whether to resume the generator. - - Returns a boolean; may be called more than once. - """ - raise NotImplementedError() - - def get_result(self): - """Returns the value to use as the result of the yield expression. - - This method will only be called once, and only after `is_ready` - has returned true. - """ - raise NotImplementedError() - - -class Callback(YieldPoint): - """Returns a callable object that will allow a matching `Wait` to proceed. - - The key may be any value suitable for use as a dictionary key, and is - used to match ``Callbacks`` to their corresponding ``Waits``. The key - must be unique among outstanding callbacks within a single run of the - generator function, but may be reused across different runs of the same - function (so constants generally work fine). - - The callback may be called with zero or one arguments; if an argument - is given it will be returned by `Wait`. - - .. deprecated:: 4.0 - Use `Futures <.Future>` instead. This class will be removed in 6.0. - """ - def __init__(self, key): - warnings.warn("gen.Callback is deprecated, use Futures instead", - DeprecationWarning) - self.key = key - - def start(self, runner): - self.runner = runner - runner.register_callback(self.key) - - def is_ready(self): - return True - - def get_result(self): - return self.runner.result_callback(self.key) - - -class Wait(YieldPoint): - """Returns the argument passed to the result of a previous `Callback`. - - .. deprecated:: 4.0 - Use `Futures <.Future>` instead. This class will be removed in 6.0. - """ - def __init__(self, key): - warnings.warn("gen.Wait is deprecated, use Futures instead", - DeprecationWarning) - self.key = key - - def start(self, runner): - self.runner = runner - - def is_ready(self): - return self.runner.is_ready(self.key) - - def get_result(self): - return self.runner.pop_result(self.key) - - -class WaitAll(YieldPoint): - """Returns the results of multiple previous `Callbacks `. - - The argument is a sequence of `Callback` keys, and the result is - a list of results in the same order. - - `WaitAll` is equivalent to yielding a list of `Wait` objects. - - .. deprecated:: 4.0 - Use `Futures <.Future>` instead. This class will be removed in 6.0. 
- """ - def __init__(self, keys): - warnings.warn("gen.WaitAll is deprecated, use gen.multi instead", - DeprecationWarning) - self.keys = keys - - def start(self, runner): - self.runner = runner - - def is_ready(self): - return all(self.runner.is_ready(key) for key in self.keys) - - def get_result(self): - return [self.runner.pop_result(key) for key in self.keys] - - -def Task(func, *args, **kwargs): - """Adapts a callback-based asynchronous function for use in coroutines. - - Takes a function (and optional additional arguments) and runs it with - those arguments plus a ``callback`` keyword argument. The argument passed - to the callback is returned as the result of the yield expression. - - .. versionchanged:: 4.0 - ``gen.Task`` is now a function that returns a `.Future`, instead of - a subclass of `YieldPoint`. It still behaves the same way when - yielded. - - .. deprecated:: 5.1 - This function is deprecated and will be removed in 6.0. - """ - warnings.warn("gen.Task is deprecated, use Futures instead", - DeprecationWarning) - future = _create_future() - - def handle_exception(typ, value, tb): - if future.done(): - return False - future_set_exc_info(future, (typ, value, tb)) - return True - - def set_result(result): - if future.done(): - return - future_set_result_unless_cancelled(future, result) - with stack_context.ExceptionStackContext(handle_exception): - func(*args, callback=_argument_adapter(set_result), **kwargs) - return future - - -class YieldFuture(YieldPoint): - def __init__(self, future): - """Adapts a `.Future` to the `YieldPoint` interface. - - .. versionchanged:: 5.0 - The ``io_loop`` argument (deprecated since version 4.1) has been removed. - - .. deprecated:: 5.1 - This class will be removed in 6.0. - """ - warnings.warn("YieldFuture is deprecated, use Futures instead", - DeprecationWarning) - self.future = future - self.io_loop = IOLoop.current() - - def start(self, runner): - if not self.future.done(): - self.runner = runner - self.key = object() - runner.register_callback(self.key) - self.io_loop.add_future(self.future, runner.result_callback(self.key)) - else: - self.runner = None - self.result_fn = self.future.result - - def is_ready(self): - if self.runner is not None: - return self.runner.is_ready(self.key) - else: - return True - - def get_result(self): - if self.runner is not None: - return self.runner.pop_result(self.key).result() - else: - return self.result_fn() - - -def _contains_yieldpoint(children): - """Returns True if ``children`` contains any YieldPoints. - - ``children`` may be a dict or a list, as used by `MultiYieldPoint` - and `multi_future`. - """ - if isinstance(children, dict): - return any(isinstance(i, YieldPoint) for i in children.values()) - if isinstance(children, list): - return any(isinstance(i, YieldPoint) for i in children) - return False - - -def multi(children, quiet_exceptions=()): +def multi( + children: Union[List[_Yieldable], Dict[Any, _Yieldable]], + quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (), +) -> "Union[Future[List], Future[Dict]]": """Runs multiple asynchronous operations in parallel. ``children`` may either be a list or a dict whose values are @@ -738,11 +459,6 @@ def multi(children, quiet_exceptions=()): one. All others will be logged, unless they are of types contained in the ``quiet_exceptions`` argument. - If any of the inputs are `YieldPoints `, the returned - yieldable object is a `YieldPoint`. Otherwise, returns a `.Future`. 
- This means that the result of `multi` can be used in a native - coroutine if and only if all of its children can be. - In a ``yield``-based coroutine, it is not normally necessary to call this function directly, since the coroutine runner will do it automatically when a list or dict is yielded. However, @@ -764,91 +480,22 @@ def multi(children, quiet_exceptions=()): .. versionchanged:: 4.3 Replaced the class ``Multi`` and the function ``multi_future`` with a unified function ``multi``. Added support for yieldables - other than `YieldPoint` and `.Future`. + other than ``YieldPoint`` and `.Future`. """ - if _contains_yieldpoint(children): - return MultiYieldPoint(children, quiet_exceptions=quiet_exceptions) - else: - return multi_future(children, quiet_exceptions=quiet_exceptions) + return multi_future(children, quiet_exceptions=quiet_exceptions) Multi = multi -class MultiYieldPoint(YieldPoint): - """Runs multiple asynchronous operations in parallel. - - This class is similar to `multi`, but it always creates a stack - context even when no children require it. It is not compatible with - native coroutines. - - .. versionchanged:: 4.2 - If multiple ``YieldPoints`` fail, any exceptions after the first - (which is raised) will be logged. Added the ``quiet_exceptions`` - argument to suppress this logging for selected exception types. - - .. versionchanged:: 4.3 - Renamed from ``Multi`` to ``MultiYieldPoint``. The name ``Multi`` - remains as an alias for the equivalent `multi` function. - - .. deprecated:: 4.3 - Use `multi` instead. This class will be removed in 6.0. - """ - def __init__(self, children, quiet_exceptions=()): - warnings.warn("MultiYieldPoint is deprecated, use Futures instead", - DeprecationWarning) - self.keys = None - if isinstance(children, dict): - self.keys = list(children.keys()) - children = children.values() - self.children = [] - for i in children: - if not isinstance(i, YieldPoint): - i = convert_yielded(i) - if is_future(i): - i = YieldFuture(i) - self.children.append(i) - assert all(isinstance(i, YieldPoint) for i in self.children) - self.unfinished_children = set(self.children) - self.quiet_exceptions = quiet_exceptions - - def start(self, runner): - for i in self.children: - i.start(runner) - - def is_ready(self): - finished = list(itertools.takewhile( - lambda i: i.is_ready(), self.unfinished_children)) - self.unfinished_children.difference_update(finished) - return not self.unfinished_children - - def get_result(self): - result_list = [] - exc_info = None - for f in self.children: - try: - result_list.append(f.get_result()) - except Exception as e: - if exc_info is None: - exc_info = sys.exc_info() - else: - if not isinstance(e, self.quiet_exceptions): - app_log.error("Multiple exceptions in yield list", - exc_info=True) - if exc_info is not None: - raise_exc_info(exc_info) - if self.keys is not None: - return dict(zip(self.keys, result_list)) - else: - return list(result_list) - - -def multi_future(children, quiet_exceptions=()): +def multi_future( + children: Union[List[_Yieldable], Dict[Any, _Yieldable]], + quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (), +) -> "Union[Future[List], Future[Dict]]": """Wait for multiple asynchronous futures in parallel. - This function is similar to `multi`, but does not support - `YieldPoints `. + Since Tornado 6.0, this function is exactly the same as `multi`. .. versionadded:: 4.0 @@ -861,49 +508,51 @@ def multi_future(children, quiet_exceptions=()): Use `multi` instead. 
""" if isinstance(children, dict): - keys = list(children.keys()) - children = children.values() + keys = list(children.keys()) # type: Optional[List] + children_seq = children.values() # type: Iterable else: keys = None - children = list(map(convert_yielded, children)) - assert all(is_future(i) or isinstance(i, _NullFuture) for i in children) - unfinished_children = set(children) + children_seq = children + children_futs = list(map(convert_yielded, children_seq)) + assert all(is_future(i) or isinstance(i, _NullFuture) for i in children_futs) + unfinished_children = set(children_futs) future = _create_future() - if not children: - future_set_result_unless_cancelled(future, - {} if keys is not None else []) + if not children_futs: + future_set_result_unless_cancelled(future, {} if keys is not None else []) - def callback(f): - unfinished_children.remove(f) + def callback(fut: Future) -> None: + unfinished_children.remove(fut) if not unfinished_children: result_list = [] - for f in children: + for f in children_futs: try: result_list.append(f.result()) except Exception as e: if future.done(): if not isinstance(e, quiet_exceptions): - app_log.error("Multiple exceptions in yield list", - exc_info=True) + app_log.error( + "Multiple exceptions in yield list", exc_info=True + ) else: future_set_exc_info(future, sys.exc_info()) if not future.done(): if keys is not None: - future_set_result_unless_cancelled(future, - dict(zip(keys, result_list))) + future_set_result_unless_cancelled( + future, dict(zip(keys, result_list)) + ) else: future_set_result_unless_cancelled(future, result_list) - listening = set() - for f in children: + listening = set() # type: Set[Future] + for f in children_futs: if f not in listening: listening.add(f) future_add_done_callback(f, callback) return future -def maybe_future(x): +def maybe_future(x: Any) -> Future: """Converts ``x`` into a `.Future`. If ``x`` is already a `.Future`, it is simply returned; otherwise @@ -924,7 +573,11 @@ def maybe_future(x): return fut -def with_timeout(timeout, future, quiet_exceptions=()): +def with_timeout( + timeout: Union[float, datetime.timedelta], + future: _Yieldable, + quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (), +) -> Future: """Wraps a `.Future` (or other yieldable object) in a timeout. Raises `tornado.util.TimeoutError` if the input future does not @@ -933,10 +586,9 @@ def with_timeout(timeout, future, quiet_exceptions=()): an absolute time relative to `.IOLoop.time`) If the wrapped `.Future` fails after it has timed out, the exception - will be logged unless it is of a type contained in ``quiet_exceptions`` - (which may be an exception type or a sequence of types). - - Does not support `YieldPoint` subclasses. + will be logged unless it is either of a type contained in + ``quiet_exceptions`` (which may be an exception type or a sequence of + types), or an ``asyncio.CancelledError``. The wrapped `.Future` is not canceled when the timeout expires, permitting it to be reused. `asyncio.wait_for` is similar to this @@ -951,50 +603,55 @@ def with_timeout(timeout, future, quiet_exceptions=()): .. versionchanged:: 4.4 Added support for yieldable objects other than `.Future`. + .. versionchanged:: 6.0.3 + ``asyncio.CancelledError`` is now always considered "quiet". + """ - # TODO: allow YieldPoints in addition to other yieldables? - # Tricky to do with stack_context semantics. 
- # # It's tempting to optimize this by cancelling the input future on timeout # instead of creating a new one, but A) we can't know if we are the only # one waiting on the input future, so cancelling it might disrupt other # callers and B) concurrent futures can only be cancelled while they are # in the queue, so cancellation cannot reliably bound our waiting time. - future = convert_yielded(future) + future_converted = convert_yielded(future) result = _create_future() - chain_future(future, result) + chain_future(future_converted, result) io_loop = IOLoop.current() - def error_callback(future): + def error_callback(future: Future) -> None: try: future.result() + except asyncio.CancelledError: + pass except Exception as e: if not isinstance(e, quiet_exceptions): - app_log.error("Exception in Future %r after timeout", - future, exc_info=True) + app_log.error( + "Exception in Future %r after timeout", future, exc_info=True + ) - def timeout_callback(): + def timeout_callback() -> None: if not result.done(): result.set_exception(TimeoutError("Timeout")) # In case the wrapped future goes on to fail, log it. - future_add_done_callback(future, error_callback) - timeout_handle = io_loop.add_timeout( - timeout, timeout_callback) - if isinstance(future, Future): + future_add_done_callback(future_converted, error_callback) + + timeout_handle = io_loop.add_timeout(timeout, timeout_callback) + if isinstance(future_converted, Future): # We know this future will resolve on the IOLoop, so we don't # need the extra thread-safety of IOLoop.add_future (and we also # don't care about StackContext here. future_add_done_callback( - future, lambda future: io_loop.remove_timeout(timeout_handle)) + future_converted, lambda future: io_loop.remove_timeout(timeout_handle) + ) else: # concurrent.futures.Futures may resolve on any thread, so we # need to route them back to the IOLoop. io_loop.add_future( - future, lambda future: io_loop.remove_timeout(timeout_handle)) + future_converted, lambda future: io_loop.remove_timeout(timeout_handle) + ) return result -def sleep(duration): +def sleep(duration: float) -> "Future[None]": """Return a `.Future` that resolves after the given number of seconds. When used with ``yield`` in a coroutine, this is a non-blocking @@ -1009,8 +666,9 @@ def sleep(duration): .. versionadded:: 4.1 """ f = _create_future() - IOLoop.current().call_later(duration, - lambda: future_set_result_unless_cancelled(f, None)) + IOLoop.current().call_later( + duration, lambda: future_set_result_unless_cancelled(f, None) + ) return f @@ -1019,22 +677,28 @@ class _NullFuture(object): It's not actually a `Future` to avoid depending on a particular event loop. Handled as a special case in the coroutine runner. + + We lie and tell the type checker that a _NullFuture is a Future so + we don't have to leak _NullFuture into lots of public APIs. But + this means that the type checker can't warn us when we're passing + a _NullFuture into a code path that doesn't understand what to do + with it. """ - def result(self): + + def result(self) -> None: return None - def done(self): + def done(self) -> bool: return True # _null_future is used as a dummy value in the coroutine runner. It differs # from moment in that moment always adds a delay of one IOLoop iteration # while _null_future is processed as soon as possible. 
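To make the `with_timeout` semantics above concrete, a usage sketch (not part of the patch) built only from the documented `gen.with_timeout`, `gen.sleep`, and `tornado.util.TimeoutError` APIs; note that the wrapped future keeps running after the timeout fires:

```py
from datetime import timedelta
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.util import TimeoutError

async def main():
    # Finishes well inside the deadline: resolves normally.
    await gen.with_timeout(timedelta(seconds=1), gen.sleep(0.01))
    try:
        # Misses the deadline: raises tornado.util.TimeoutError,
        # but the inner sleep future is NOT cancelled.
        await gen.with_timeout(timedelta(seconds=0.05), gen.sleep(1))
    except TimeoutError:
        print("timed out")

IOLoop.current().run_sync(main)
```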
-_null_future = _NullFuture() +_null_future = typing.cast(Future, _NullFuture()) -moment = _NullFuture() -moment.__doc__ = \ - """A special object which may be yielded to allow the IOLoop to run for +moment = typing.cast(Future, _NullFuture()) +moment.__doc__ = """A special object which may be yielded to allow the IOLoop to run for one iteration. This is not needed in normal use but it can be helpful in long-running @@ -1042,6 +706,9 @@ def done(self): Usage: ``yield gen.moment`` +In native coroutines, the equivalent of ``yield gen.moment`` is +``await asyncio.sleep(0)``. + .. versionadded:: 4.0 .. deprecated:: 4.5 @@ -1051,68 +718,33 @@ def done(self): class Runner(object): - """Internal implementation of `tornado.gen.engine`. + """Internal implementation of `tornado.gen.coroutine`. Maintains information about pending callbacks and their results. The results of the generator are stored in ``result_future`` (a `.Future`) """ - def __init__(self, gen, result_future, first_yielded): + + def __init__( + self, + ctx_run: Callable, + gen: "Generator[_Yieldable, Any, _T]", + result_future: "Future[_T]", + first_yielded: _Yieldable, + ) -> None: + self.ctx_run = ctx_run self.gen = gen self.result_future = result_future - self.future = _null_future - self.yield_point = None - self.pending_callbacks = None - self.results = None + self.future = _null_future # type: Union[None, Future] self.running = False self.finished = False - self.had_exception = False self.io_loop = IOLoop.current() - # For efficiency, we do not create a stack context until we - # reach a YieldPoint (stack contexts are required for the historical - # semantics of YieldPoints, but not for Futures). When we have - # done so, this field will be set and must be called at the end - # of the coroutine. - self.stack_context_deactivate = None if self.handle_yield(first_yielded): - gen = result_future = first_yielded = None - self.run() - - def register_callback(self, key): - """Adds ``key`` to the list of callbacks.""" - if self.pending_callbacks is None: - # Lazily initialize the old-style YieldPoint data structures. - self.pending_callbacks = set() - self.results = {} - if key in self.pending_callbacks: - raise KeyReuseError("key %r is already pending" % (key,)) - self.pending_callbacks.add(key) - - def is_ready(self, key): - """Returns true if a result is available for ``key``.""" - if self.pending_callbacks is None or key not in self.pending_callbacks: - raise UnknownKeyError("key %r is not pending" % (key,)) - return key in self.results - - def set_result(self, key, result): - """Sets the result for ``key`` and attempts to resume the generator.""" - self.results[key] = result - if self.yield_point is not None and self.yield_point.is_ready(): - try: - future_set_result_unless_cancelled(self.future, - self.yield_point.get_result()) - except: - future_set_exc_info(self.future, sys.exc_info()) - self.yield_point = None - self.run() - - def pop_result(self, key): - """Returns the result for ``key`` and unregisters it.""" - self.pending_callbacks.remove(key) - return self.results.pop(key) - - def run(self): + gen = result_future = first_yielded = None # type: ignore + self.ctx_run(self.run) + + def run(self) -> None: """Starts or resumes the generator, running until it reaches a yield point that is not ready. 
""" @@ -1122,23 +754,23 @@ def run(self): self.running = True while True: future = self.future + if future is None: + raise Exception("No pending future") if not future.done(): return self.future = None try: - orig_stack_contexts = stack_context._state.contexts exc_info = None try: value = future.result() except Exception: - self.had_exception = True exc_info = sys.exc_info() future = None if exc_info is not None: try: - yielded = self.gen.throw(*exc_info) + yielded = self.gen.throw(*exc_info) # type: ignore finally: # Break up a reference to itself # for faster GC on CPython. @@ -1146,33 +778,19 @@ def run(self): else: yielded = self.gen.send(value) - if stack_context._state.contexts is not orig_stack_contexts: - self.gen.throw( - stack_context.StackContextInconsistentError( - 'stack_context inconsistency (probably caused ' - 'by yield within a "with StackContext" block)')) except (StopIteration, Return) as e: self.finished = True self.future = _null_future - if self.pending_callbacks and not self.had_exception: - # If we ran cleanly without waiting on all callbacks - # raise an error (really more of a warning). If we - # had an exception then some callbacks may have been - # orphaned, so skip the check in that case. - raise LeakedCallbackError( - "finished without waiting for callbacks %r" % - self.pending_callbacks) - future_set_result_unless_cancelled(self.result_future, - _value_from_stopiteration(e)) - self.result_future = None - self._deactivate_stack_context() + future_set_result_unless_cancelled( + self.result_future, _value_from_stopiteration(e) + ) + self.result_future = None # type: ignore return except Exception: self.finished = True self.future = _null_future future_set_exc_info(self.result_future, sys.exc_info()) - self.result_future = None - self._deactivate_stack_context() + self.result_future = None # type: ignore return if not self.handle_yield(yielded): return @@ -1180,164 +798,56 @@ def run(self): finally: self.running = False - def handle_yield(self, yielded): - # Lists containing YieldPoints require stack contexts; - # other lists are handled in convert_yielded. - if _contains_yieldpoint(yielded): - yielded = multi(yielded) - - if isinstance(yielded, YieldPoint): - # YieldPoints are too closely coupled to the Runner to go - # through the generic convert_yielded mechanism. + def handle_yield(self, yielded: _Yieldable) -> bool: + try: + self.future = convert_yielded(yielded) + except BadYieldError: self.future = Future() - - def start_yield_point(): - try: - yielded.start(self) - if yielded.is_ready(): - future_set_result_unless_cancelled(self.future, yielded.get_result()) - else: - self.yield_point = yielded - except Exception: - self.future = Future() - future_set_exc_info(self.future, sys.exc_info()) - - if self.stack_context_deactivate is None: - # Start a stack context if this is the first - # YieldPoint we've seen. 
- with stack_context.ExceptionStackContext( - self.handle_exception) as deactivate: - self.stack_context_deactivate = deactivate - - def cb(): - start_yield_point() - self.run() - self.io_loop.add_callback(cb) - return False - else: - start_yield_point() - else: - try: - self.future = convert_yielded(yielded) - except BadYieldError: - self.future = Future() - future_set_exc_info(self.future, sys.exc_info()) + future_set_exc_info(self.future, sys.exc_info()) if self.future is moment: - self.io_loop.add_callback(self.run) + self.io_loop.add_callback(self.ctx_run, self.run) return False + elif self.future is None: + raise Exception("no pending future") elif not self.future.done(): - def inner(f): + + def inner(f: Any) -> None: # Break a reference cycle to speed GC. - f = None # noqa - self.run() - self.io_loop.add_future( - self.future, inner) + f = None # noqa: F841 + self.ctx_run(self.run) + + self.io_loop.add_future(self.future, inner) return False return True - def result_callback(self, key): - return stack_context.wrap(_argument_adapter( - functools.partial(self.set_result, key))) - - def handle_exception(self, typ, value, tb): + def handle_exception( + self, typ: Type[Exception], value: Exception, tb: types.TracebackType + ) -> bool: if not self.running and not self.finished: self.future = Future() future_set_exc_info(self.future, (typ, value, tb)) - self.run() + self.ctx_run(self.run) return True else: return False - def _deactivate_stack_context(self): - if self.stack_context_deactivate is not None: - self.stack_context_deactivate() - self.stack_context_deactivate = None - - -Arguments = collections.namedtuple('Arguments', ['args', 'kwargs']) - - -def _argument_adapter(callback): - """Returns a function that when invoked runs ``callback`` with one arg. - - If the function returned by this function is called with exactly - one argument, that argument is passed to ``callback``. Otherwise - the args tuple and kwargs dict are wrapped in an `Arguments` object. - """ - def wrapper(*args, **kwargs): - if kwargs or len(args) > 1: - callback(Arguments(args, kwargs)) - elif args: - callback(args[0]) - else: - callback(None) - return wrapper - # Convert Awaitables into Futures. try: - import asyncio -except ImportError: - # Py2-compatible version for use with Cython. - # Copied from PEP 380. - @coroutine - def _wrap_awaitable(x): - if hasattr(x, '__await__'): - _i = x.__await__() - else: - _i = iter(x) - try: - _y = next(_i) - except StopIteration as _e: - _r = _value_from_stopiteration(_e) - else: - while 1: - try: - _s = yield _y - except GeneratorExit as _e: - try: - _m = _i.close - except AttributeError: - pass - else: - _m() - raise _e - except BaseException as _e: - _x = sys.exc_info() - try: - _m = _i.throw - except AttributeError: - raise _e - else: - try: - _y = _m(*_x) - except StopIteration as _e: - _r = _value_from_stopiteration(_e) - break - else: - try: - if _s is None: - _y = next(_i) - else: - _y = _i.send(_s) - except StopIteration as _e: - _r = _value_from_stopiteration(_e) - break - raise Return(_r) -else: - try: - _wrap_awaitable = asyncio.ensure_future - except AttributeError: - # asyncio.ensure_future was introduced in Python 3.4.4, but - # Debian jessie still ships with 3.4.2 so try the old name. - _wrap_awaitable = getattr(asyncio, 'async') + _wrap_awaitable = asyncio.ensure_future +except AttributeError: + # asyncio.ensure_future was introduced in Python 3.4.4, but + # Debian jessie still ships with 3.4.2 so try the old name. 
+ _wrap_awaitable = getattr(asyncio, "async") -def convert_yielded(yielded): +def convert_yielded(yielded: _Yieldable) -> Future: """Convert a yielded object into a `.Future`. - The default implementation accepts lists, dictionaries, and Futures. + The default implementation accepts lists, dictionaries, and + Futures. This has the side effect of starting any coroutines that + did not start themselves, similar to `asyncio.ensure_future`. If the `~functools.singledispatch` library is available, this function may be extended to support additional types. For example:: @@ -1347,21 +857,20 @@ def _(asyncio_future): return tornado.platform.asyncio.to_tornado_future(asyncio_future) .. versionadded:: 4.1 + """ - # Lists and dicts containing YieldPoints were handled earlier. if yielded is None or yielded is moment: return moment elif yielded is _null_future: return _null_future elif isinstance(yielded, (list, dict)): - return multi(yielded) + return multi(yielded) # type: ignore elif is_future(yielded): - return yielded + return typing.cast(Future, yielded) elif isawaitable(yielded): - return _wrap_awaitable(yielded) + return _wrap_awaitable(yielded) # type: ignore else: raise BadYieldError("yielded unknown object %r" % (yielded,)) -if singledispatch is not None: - convert_yielded = singledispatch(convert_yielded) +convert_yielded = singledispatch(convert_yielded) diff --git a/tornado/http1connection.py b/tornado/http1connection.py index 1c5eadf84c..72088d6023 100644 --- a/tornado/http1connection.py +++ b/tornado/http1connection.py @@ -18,23 +18,29 @@ .. versionadded:: 4.0 """ -from __future__ import absolute_import, division, print_function - +import asyncio +import logging import re +import types -from tornado.concurrent import (Future, future_add_done_callback, - future_set_result_unless_cancelled) +from tornado.concurrent import ( + Future, + future_add_done_callback, + future_set_result_unless_cancelled, +) from tornado.escape import native_str, utf8 from tornado import gen from tornado import httputil from tornado import iostream from tornado.log import gen_log, app_log -from tornado import stack_context -from tornado.util import GzipDecompressor, PY3 +from tornado.util import GzipDecompressor + + +from typing import cast, Optional, Type, Awaitable, Callable, Union, Tuple class _QuietException(Exception): - def __init__(self): + def __init__(self) -> None: pass @@ -43,24 +49,38 @@ class _ExceptionLoggingContext(object): log any exceptions with the given logger. Any exceptions caught are converted to _QuietException """ - def __init__(self, logger): + + def __init__(self, logger: logging.Logger) -> None: self.logger = logger - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, typ, value, tb): + def __exit__( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: types.TracebackType, + ) -> None: if value is not None: + assert typ is not None self.logger.error("Uncaught exception", exc_info=(typ, value, tb)) raise _QuietException class HTTP1ConnectionParameters(object): - """Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`. 
- """ - def __init__(self, no_keep_alive=False, chunk_size=None, - max_header_size=None, header_timeout=None, max_body_size=None, - body_timeout=None, decompress=False): + """Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.""" + + def __init__( + self, + no_keep_alive: bool = False, + chunk_size: Optional[int] = None, + max_header_size: Optional[int] = None, + header_timeout: Optional[float] = None, + max_body_size: Optional[int] = None, + body_timeout: Optional[float] = None, + decompress: bool = False, + ) -> None: """ :arg bool no_keep_alive: If true, always close the connection after one request. @@ -87,7 +107,14 @@ class HTTP1Connection(httputil.HTTPConnection): This class can be on its own for clients, or via `HTTP1ServerConnection` for servers. """ - def __init__(self, stream, is_client, params=None, context=None): + + def __init__( + self, + stream: iostream.IOStream, + is_client: bool, + params: Optional[HTTP1ConnectionParameters] = None, + context: Optional[object] = None, + ) -> None: """ :arg stream: an `.IOStream` :arg bool is_client: client or server @@ -104,8 +131,11 @@ def __init__(self, stream, is_client, params=None, context=None): self.no_keep_alive = params.no_keep_alive # The body limits can be altered by the delegate, so save them # here instead of just referencing self.params later. - self._max_body_size = (self.params.max_body_size or - self.stream.max_buffer_size) + self._max_body_size = ( + self.params.max_body_size + if self.params.max_body_size is not None + else self.stream.max_buffer_size + ) self._body_timeout = self.params.body_timeout # _write_finished is set to True when finish() has been called, # i.e. there will be no more data sent. Data may still be in the @@ -115,7 +145,7 @@ def __init__(self, stream, is_client, params=None, context=None): self._read_finished = False # _finish_future resolves when all data has been written and flushed # to the IOStream. - self._finish_future = Future() + self._finish_future = Future() # type: Future[None] # If true, the connection should be closed after this request # (after the response has been written in the server side, # and after it has been read in the client) @@ -124,18 +154,18 @@ def __init__(self, stream, is_client, params=None, context=None): # Save the start lines after we read or write them; they # affect later processing (e.g. 304 responses and HEAD methods # have content-length but no bodies) - self._request_start_line = None - self._response_start_line = None - self._request_headers = None + self._request_start_line = None # type: Optional[httputil.RequestStartLine] + self._response_start_line = None # type: Optional[httputil.ResponseStartLine] + self._request_headers = None # type: Optional[httputil.HTTPHeaders] # True if we are writing output with chunked encoding. - self._chunking_output = None + self._chunking_output = False # While reading a body with a content-length, this is the # amount left to read. - self._expected_content_remaining = None + self._expected_content_remaining = None # type: Optional[int] # A Future for our outgoing writes, returned by IOStream.write. - self._pending_write = None + self._pending_write = None # type: Optional[Future[None]] - def read_response(self, delegate): + def read_response(self, delegate: httputil.HTTPMessageDelegate) -> Awaitable[bool]: """Read a single HTTP response. 
Typical client-mode usage is to write a request using `write_headers`, @@ -143,55 +173,64 @@ def read_response(self, delegate): :arg delegate: a `.HTTPMessageDelegate` - Returns a `.Future` that resolves to None after the full response has - been read. + Returns a `.Future` that resolves to a bool after the full response has + been read. The result is true if the stream is still open. """ if self.params.decompress: delegate = _GzipMessageDelegate(delegate, self.params.chunk_size) return self._read_message(delegate) - @gen.coroutine - def _read_message(self, delegate): + async def _read_message(self, delegate: httputil.HTTPMessageDelegate) -> bool: need_delegate_close = False try: header_future = self.stream.read_until_regex( - b"\r?\n\r?\n", - max_bytes=self.params.max_header_size) + b"\r?\n\r?\n", max_bytes=self.params.max_header_size + ) if self.params.header_timeout is None: - header_data = yield header_future + header_data = await header_future else: try: - header_data = yield gen.with_timeout( + header_data = await gen.with_timeout( self.stream.io_loop.time() + self.params.header_timeout, header_future, - quiet_exceptions=iostream.StreamClosedError) + quiet_exceptions=iostream.StreamClosedError, + ) except gen.TimeoutError: self.close() - raise gen.Return(False) - start_line, headers = self._parse_headers(header_data) + return False + start_line_str, headers = self._parse_headers(header_data) if self.is_client: - start_line = httputil.parse_response_start_line(start_line) - self._response_start_line = start_line + resp_start_line = httputil.parse_response_start_line(start_line_str) + self._response_start_line = resp_start_line + start_line = ( + resp_start_line + ) # type: Union[httputil.RequestStartLine, httputil.ResponseStartLine] + # TODO: this will need to change to support client-side keepalive + self._disconnect_on_finish = False else: - start_line = httputil.parse_request_start_line(start_line) - self._request_start_line = start_line + req_start_line = httputil.parse_request_start_line(start_line_str) + self._request_start_line = req_start_line self._request_headers = headers - - self._disconnect_on_finish = not self._can_keep_alive( - start_line, headers) + start_line = req_start_line + self._disconnect_on_finish = not self._can_keep_alive( + req_start_line, headers + ) need_delegate_close = True with _ExceptionLoggingContext(app_log): - header_future = delegate.headers_received(start_line, headers) - if header_future is not None: - yield header_future + header_recv_future = delegate.headers_received(start_line, headers) + if header_recv_future is not None: + await header_recv_future if self.stream is None: # We've been detached. need_delegate_close = False - raise gen.Return(False) + return False skip_body = False if self.is_client: - if (self._request_start_line is not None and - self._request_start_line.method == 'HEAD'): + assert isinstance(start_line, httputil.ResponseStartLine) + if ( + self._request_start_line is not None + and self._request_start_line.method == "HEAD" + ): skip_body = True code = start_line.code if code == 304: @@ -199,37 +238,37 @@ def _read_message(self, delegate): # but do not actually have a body. # http://tools.ietf.org/html/rfc7230#section-3.3 skip_body = True - if code >= 100 and code < 200: + if 100 <= code < 200: # 1xx responses should never indicate the presence of # a body. 
- if ('Content-Length' in headers or - 'Transfer-Encoding' in headers): + if "Content-Length" in headers or "Transfer-Encoding" in headers: raise httputil.HTTPInputError( - "Response code %d cannot have body" % code) + "Response code %d cannot have body" % code + ) # TODO: client delegates will get headers_received twice # in the case of a 100-continue. Document or change? - yield self._read_message(delegate) + await self._read_message(delegate) else: - if (headers.get("Expect") == "100-continue" and - not self._write_finished): + if headers.get("Expect") == "100-continue" and not self._write_finished: self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n") if not skip_body: body_future = self._read_body( - start_line.code if self.is_client else 0, headers, delegate) + resp_start_line.code if self.is_client else 0, headers, delegate + ) if body_future is not None: if self._body_timeout is None: - yield body_future + await body_future else: try: - yield gen.with_timeout( + await gen.with_timeout( self.stream.io_loop.time() + self._body_timeout, body_future, - quiet_exceptions=iostream.StreamClosedError) + quiet_exceptions=iostream.StreamClosedError, + ) except gen.TimeoutError: - gen_log.info("Timeout reading body from %s", - self.context) + gen_log.info("Timeout reading body from %s", self.context) self.stream.close() - raise gen.Return(False) + return False self._read_finished = True if not self._write_finished or self.is_client: need_delegate_close = False @@ -238,57 +277,58 @@ def _read_message(self, delegate): # If we're waiting for the application to produce an asynchronous # response, and we're not detached, register a close callback # on the stream (we didn't need one while we were reading) - if (not self._finish_future.done() and - self.stream is not None and - not self.stream.closed()): + if ( + not self._finish_future.done() + and self.stream is not None + and not self.stream.closed() + ): self.stream.set_close_callback(self._on_connection_close) - yield self._finish_future + await self._finish_future if self.is_client and self._disconnect_on_finish: self.close() if self.stream is None: - raise gen.Return(False) + return False except httputil.HTTPInputError as e: - gen_log.info("Malformed HTTP message from %s: %s", - self.context, e) + gen_log.info("Malformed HTTP message from %s: %s", self.context, e) if not self.is_client: - yield self.stream.write(b'HTTP/1.1 400 Bad Request\r\n\r\n') + await self.stream.write(b"HTTP/1.1 400 Bad Request\r\n\r\n") self.close() - raise gen.Return(False) + return False finally: if need_delegate_close: with _ExceptionLoggingContext(app_log): delegate.on_connection_close() - header_future = None + header_future = None # type: ignore self._clear_callbacks() - raise gen.Return(True) + return True - def _clear_callbacks(self): + def _clear_callbacks(self) -> None: """Clears the callback attributes. This allows the request handler to be garbage collected more quickly in CPython by breaking up reference cycles. """ self._write_callback = None - self._write_future = None - self._close_callback = None + self._write_future = None # type: Optional[Future[None]] + self._close_callback = None # type: Optional[Callable[[], None]] if self.stream is not None: self.stream.set_close_callback(None) - def set_close_callback(self, callback): + def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None: """Sets a callback that will be run when the connection is closed. 
Note that this callback is slightly different from `.HTTPMessageDelegate.on_connection_close`: The `.HTTPMessageDelegate` method is called when the connection is - closed while recieving a message. This callback is used when + closed while receiving a message. This callback is used when there is not an active delegate (for example, on the server side this callback is used if the client closes the connection after sending its request but before receiving all the response. """ - self._close_callback = stack_context.wrap(callback) + self._close_callback = callback - def _on_connection_close(self): + def _on_connection_close(self) -> None: # Note that this callback is only registered on the IOStream # when we have finished reading the request and are waiting for # the application to produce its response. @@ -300,14 +340,14 @@ def _on_connection_close(self): future_set_result_unless_cancelled(self._finish_future, None) self._clear_callbacks() - def close(self): + def close(self) -> None: if self.stream is not None: self.stream.close() self._clear_callbacks() if not self._finish_future.done(): future_set_result_unless_cancelled(self._finish_future, None) - def detach(self): + def detach(self) -> iostream.IOStream: """Take control of the underlying stream. Returns the underlying `.IOStream` object and stops all further @@ -317,108 +357,127 @@ def detach(self): """ self._clear_callbacks() stream = self.stream - self.stream = None + self.stream = None # type: ignore if not self._finish_future.done(): future_set_result_unless_cancelled(self._finish_future, None) return stream - def set_body_timeout(self, timeout): + def set_body_timeout(self, timeout: float) -> None: """Sets the body timeout for a single request. Overrides the value from `.HTTP1ConnectionParameters`. """ self._body_timeout = timeout - def set_max_body_size(self, max_body_size): + def set_max_body_size(self, max_body_size: int) -> None: """Sets the body size limit for a single request. Overrides the value from `.HTTP1ConnectionParameters`. """ self._max_body_size = max_body_size - def write_headers(self, start_line, headers, chunk=None, callback=None): + def write_headers( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + chunk: Optional[bytes] = None, + ) -> "Future[None]": """Implements `.HTTPConnection.write_headers`.""" lines = [] if self.is_client: + assert isinstance(start_line, httputil.RequestStartLine) self._request_start_line = start_line - lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1]))) + lines.append(utf8("%s %s HTTP/1.1" % (start_line[0], start_line[1]))) # Client requests with a non-empty body must have either a # Content-Length or a Transfer-Encoding. self._chunking_output = ( - start_line.method in ('POST', 'PUT', 'PATCH') and - 'Content-Length' not in headers and - 'Transfer-Encoding' not in headers) + start_line.method in ("POST", "PUT", "PATCH") + and "Content-Length" not in headers + and ( + "Transfer-Encoding" not in headers + or headers["Transfer-Encoding"] == "chunked" + ) + ) else: + assert isinstance(start_line, httputil.ResponseStartLine) + assert self._request_start_line is not None + assert self._request_headers is not None self._response_start_line = start_line - lines.append(utf8('HTTP/1.1 %d %s' % (start_line[1], start_line[2]))) + lines.append(utf8("HTTP/1.1 %d %s" % (start_line[1], start_line[2]))) self._chunking_output = ( # TODO: should this use # self._request_start_line.version or # start_line.version? 
- self._request_start_line.version == 'HTTP/1.1' and + self._request_start_line.version == "HTTP/1.1" + # Omit payload header field for HEAD request. + and self._request_start_line.method != "HEAD" # 1xx, 204 and 304 responses have no body (not even a zero-length # body), and so should not have either Content-Length or # Transfer-Encoding headers. - start_line.code not in (204, 304) and - (start_line.code < 100 or start_line.code >= 200) and + and start_line.code not in (204, 304) + and (start_line.code < 100 or start_line.code >= 200) # No need to chunk the output if a Content-Length is specified. - 'Content-Length' not in headers and + and "Content-Length" not in headers # Applications are discouraged from touching Transfer-Encoding, # but if they do, leave it alone. - 'Transfer-Encoding' not in headers) + and "Transfer-Encoding" not in headers + ) # If connection to a 1.1 client will be closed, inform client - if (self._request_start_line.version == 'HTTP/1.1' and self._disconnect_on_finish): - headers['Connection'] = 'close' + if ( + self._request_start_line.version == "HTTP/1.1" + and self._disconnect_on_finish + ): + headers["Connection"] = "close" # If a 1.0 client asked for keep-alive, add the header. - if (self._request_start_line.version == 'HTTP/1.0' and - self._request_headers.get('Connection', '').lower() == 'keep-alive'): - headers['Connection'] = 'Keep-Alive' + if ( + self._request_start_line.version == "HTTP/1.0" + and self._request_headers.get("Connection", "").lower() == "keep-alive" + ): + headers["Connection"] = "Keep-Alive" if self._chunking_output: - headers['Transfer-Encoding'] = 'chunked' - if (not self.is_client and - (self._request_start_line.method == 'HEAD' or - start_line.code == 304)): + headers["Transfer-Encoding"] = "chunked" + if not self.is_client and ( + self._request_start_line.method == "HEAD" + or cast(httputil.ResponseStartLine, start_line).code == 304 + ): self._expected_content_remaining = 0 - elif 'Content-Length' in headers: - self._expected_content_remaining = int(headers['Content-Length']) + elif "Content-Length" in headers: + self._expected_content_remaining = int(headers["Content-Length"]) else: self._expected_content_remaining = None # TODO: headers are supposed to be of type str, but we still have some # cases that let bytes slip through. Remove these native_str calls when those # are fixed. 
- header_lines = (native_str(n) + ": " + native_str(v) for n, v in headers.get_all()) - if PY3: - lines.extend(l.encode('latin1') for l in header_lines) - else: - lines.extend(header_lines) + header_lines = ( + native_str(n) + ": " + native_str(v) for n, v in headers.get_all() + ) + lines.extend(line.encode("latin1") for line in header_lines) for line in lines: - if b'\n' in line: - raise ValueError('Newline in header: ' + repr(line)) + if b"\n" in line: + raise ValueError("Newline in header: " + repr(line)) future = None if self.stream.closed(): future = self._write_future = Future() future.set_exception(iostream.StreamClosedError()) future.exception() else: - if callback is not None: - self._write_callback = stack_context.wrap(callback) - else: - future = self._write_future = Future() + future = self._write_future = Future() data = b"\r\n".join(lines) + b"\r\n\r\n" if chunk: data += self._format_chunk(chunk) self._pending_write = self.stream.write(data) - self._pending_write.add_done_callback(self._on_write_complete) + future_add_done_callback(self._pending_write, self._on_write_complete) return future - def _format_chunk(self, chunk): + def _format_chunk(self, chunk: bytes) -> bytes: if self._expected_content_remaining is not None: self._expected_content_remaining -= len(chunk) if self._expected_content_remaining < 0: # Close the stream now to stop further framing errors. self.stream.close() raise httputil.HTTPOutputError( - "Tried to write more data than Content-Length") + "Tried to write more data than Content-Length" + ) if self._chunking_output and chunk: # Don't write out empty chunks because that means END-OF-STREAM # with chunked encoding @@ -426,7 +485,7 @@ def _format_chunk(self, chunk): else: return chunk - def write(self, chunk, callback=None): + def write(self, chunk: bytes) -> "Future[None]": """Implements `.HTTPConnection.write`. 
For backwards compatibility it is allowed but deprecated to @@ -439,23 +498,23 @@ def write(self, chunk, callback=None): self._write_future.set_exception(iostream.StreamClosedError()) self._write_future.exception() else: - if callback is not None: - self._write_callback = stack_context.wrap(callback) - else: - future = self._write_future = Future() + future = self._write_future = Future() self._pending_write = self.stream.write(self._format_chunk(chunk)) - self._pending_write.add_done_callback(self._on_write_complete) + future_add_done_callback(self._pending_write, self._on_write_complete) return future - def finish(self): + def finish(self) -> None: """Implements `.HTTPConnection.finish`.""" - if (self._expected_content_remaining is not None and - self._expected_content_remaining != 0 and - not self.stream.closed()): + if ( + self._expected_content_remaining is not None + and self._expected_content_remaining != 0 + and not self.stream.closed() + ): self.stream.close() raise httputil.HTTPOutputError( - "Tried to write %d bytes less than Content-Length" % - self._expected_content_remaining) + "Tried to write %d bytes less than Content-Length" + % self._expected_content_remaining + ) if self._chunking_output: if not self.stream.closed(): self._pending_write = self.stream.write(b"0\r\n\r\n") @@ -476,7 +535,7 @@ def finish(self): else: future_add_done_callback(self._pending_write, self._finish_request) - def _on_write_complete(self, future): + def _on_write_complete(self, future: "Future[None]") -> None: exc = future.exception() if exc is not None and not isinstance(exc, iostream.StreamClosedError): future.result() @@ -489,7 +548,9 @@ def _on_write_complete(self, future): self._write_future = None future_set_result_unless_cancelled(future, None) - def _can_keep_alive(self, start_line, headers): + def _can_keep_alive( + self, start_line: httputil.RequestStartLine, headers: httputil.HTTPHeaders + ) -> bool: if self.params.no_keep_alive: return False connection_header = headers.get("Connection") @@ -497,15 +558,17 @@ def _can_keep_alive(self, start_line, headers): connection_header = connection_header.lower() if start_line.version == "HTTP/1.1": return connection_header != "close" - elif ("Content-Length" in headers or - headers.get("Transfer-Encoding", "").lower() == "chunked" or - getattr(start_line, 'method', None) in ("HEAD", "GET")): + elif ( + "Content-Length" in headers + or headers.get("Transfer-Encoding", "").lower() == "chunked" + or getattr(start_line, "method", None) in ("HEAD", "GET") + ): # start_line may be a request or response start line; only # the former has a method attribute. return connection_header == "keep-alive" return False - def _finish_request(self, future): + def _finish_request(self, future: "Optional[Future[None]]") -> None: self._clear_callbacks() if not self.is_client and self._disconnect_on_finish: self.close() @@ -516,45 +579,54 @@ def _finish_request(self, future): if not self._finish_future.done(): future_set_result_unless_cancelled(self._finish_future, None) - def _parse_headers(self, data): + def _parse_headers(self, data: bytes) -> Tuple[str, httputil.HTTPHeaders]: # The lstrip removes newlines that some implementations sometimes # insert between messages of a reused connection. Per RFC 7230, # we SHOULD ignore at least one empty line before the request. 
# http://tools.ietf.org/html/rfc7230#section-3.5 - data = native_str(data.decode('latin1')).lstrip("\r\n") + data_str = native_str(data.decode("latin1")).lstrip("\r\n") # RFC 7230 section allows for both CRLF and bare LF. - eol = data.find("\n") - start_line = data[:eol].rstrip("\r") - headers = httputil.HTTPHeaders.parse(data[eol:]) + eol = data_str.find("\n") + start_line = data_str[:eol].rstrip("\r") + headers = httputil.HTTPHeaders.parse(data_str[eol:]) return start_line, headers - def _read_body(self, code, headers, delegate): + def _read_body( + self, + code: int, + headers: httputil.HTTPHeaders, + delegate: httputil.HTTPMessageDelegate, + ) -> Optional[Awaitable[None]]: if "Content-Length" in headers: if "Transfer-Encoding" in headers: # Response cannot contain both Content-Length and # Transfer-Encoding headers. # http://tools.ietf.org/html/rfc7230#section-3.3.3 raise httputil.HTTPInputError( - "Response with both Transfer-Encoding and Content-Length") + "Response with both Transfer-Encoding and Content-Length" + ) if "," in headers["Content-Length"]: # Proxies sometimes cause Content-Length headers to get # duplicated. If all the values are identical then we can # use them but if they differ it's an error. - pieces = re.split(r',\s*', headers["Content-Length"]) + pieces = re.split(r",\s*", headers["Content-Length"]) if any(i != pieces[0] for i in pieces): raise httputil.HTTPInputError( - "Multiple unequal Content-Lengths: %r" % - headers["Content-Length"]) + "Multiple unequal Content-Lengths: %r" + % headers["Content-Length"] + ) headers["Content-Length"] = pieces[0] try: - content_length = int(headers["Content-Length"]) + content_length = int(headers["Content-Length"]) # type: Optional[int] except ValueError: # Handles non-integer Content-Length value. raise httputil.HTTPInputError( - "Only integer Content-Length is allowed: %s" % headers["Content-Length"]) + "Only integer Content-Length is allowed: %s" + % headers["Content-Length"] + ) - if content_length > self._max_body_size: + if cast(int, content_length) > self._max_body_size: raise httputil.HTTPInputError("Content-Length too long") else: content_length = None @@ -563,10 +635,10 @@ def _read_body(self, code, headers, delegate): # This response code is not allowed to have a non-empty body, # and has an implicit length of zero instead of read-until-close. 
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3 - if ("Transfer-Encoding" in headers or - content_length not in (None, 0)): + if "Transfer-Encoding" in headers or content_length not in (None, 0): raise httputil.HTTPInputError( - "Response with code %d should not have body" % code) + "Response with code %d should not have body" % code + ) content_length = 0 if content_length is not None: @@ -577,110 +649,133 @@ def _read_body(self, code, headers, delegate): return self._read_body_until_close(delegate) return None - @gen.coroutine - def _read_fixed_body(self, content_length, delegate): + async def _read_fixed_body( + self, content_length: int, delegate: httputil.HTTPMessageDelegate + ) -> None: while content_length > 0: - body = yield self.stream.read_bytes( - min(self.params.chunk_size, content_length), partial=True) + body = await self.stream.read_bytes( + min(self.params.chunk_size, content_length), partial=True + ) content_length -= len(body) if not self._write_finished or self.is_client: with _ExceptionLoggingContext(app_log): ret = delegate.data_received(body) if ret is not None: - yield ret + await ret - @gen.coroutine - def _read_chunked_body(self, delegate): + async def _read_chunked_body(self, delegate: httputil.HTTPMessageDelegate) -> None: # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1 total_size = 0 while True: - chunk_len = yield self.stream.read_until(b"\r\n", max_bytes=64) - chunk_len = int(chunk_len.strip(), 16) + chunk_len_str = await self.stream.read_until(b"\r\n", max_bytes=64) + chunk_len = int(chunk_len_str.strip(), 16) if chunk_len == 0: - crlf = yield self.stream.read_bytes(2) - if crlf != b'\r\n': - raise httputil.HTTPInputError("improperly terminated chunked request") + crlf = await self.stream.read_bytes(2) + if crlf != b"\r\n": + raise httputil.HTTPInputError( + "improperly terminated chunked request" + ) return total_size += chunk_len if total_size > self._max_body_size: raise httputil.HTTPInputError("chunked body too large") bytes_to_read = chunk_len while bytes_to_read: - chunk = yield self.stream.read_bytes( - min(bytes_to_read, self.params.chunk_size), partial=True) + chunk = await self.stream.read_bytes( + min(bytes_to_read, self.params.chunk_size), partial=True + ) bytes_to_read -= len(chunk) if not self._write_finished or self.is_client: with _ExceptionLoggingContext(app_log): ret = delegate.data_received(chunk) if ret is not None: - yield ret + await ret # chunk ends with \r\n - crlf = yield self.stream.read_bytes(2) + crlf = await self.stream.read_bytes(2) assert crlf == b"\r\n" - @gen.coroutine - def _read_body_until_close(self, delegate): - body = yield self.stream.read_until_close() + async def _read_body_until_close( + self, delegate: httputil.HTTPMessageDelegate + ) -> None: + body = await self.stream.read_until_close() if not self._write_finished or self.is_client: with _ExceptionLoggingContext(app_log): - delegate.data_received(body) + ret = delegate.data_received(body) + if ret is not None: + await ret class _GzipMessageDelegate(httputil.HTTPMessageDelegate): - """Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``. 
- """ - def __init__(self, delegate, chunk_size): + """Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``.""" + + def __init__(self, delegate: httputil.HTTPMessageDelegate, chunk_size: int) -> None: self._delegate = delegate self._chunk_size = chunk_size - self._decompressor = None + self._decompressor = None # type: Optional[GzipDecompressor] - def headers_received(self, start_line, headers): + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: if headers.get("Content-Encoding") == "gzip": self._decompressor = GzipDecompressor() # Downstream delegates will only see uncompressed data, # so rename the content-encoding header. # (but note that curl_httpclient doesn't do this). - headers.add("X-Consumed-Content-Encoding", - headers["Content-Encoding"]) + headers.add("X-Consumed-Content-Encoding", headers["Content-Encoding"]) del headers["Content-Encoding"] return self._delegate.headers_received(start_line, headers) - @gen.coroutine - def data_received(self, chunk): + async def data_received(self, chunk: bytes) -> None: if self._decompressor: compressed_data = chunk while compressed_data: decompressed = self._decompressor.decompress( - compressed_data, self._chunk_size) + compressed_data, self._chunk_size + ) if decompressed: ret = self._delegate.data_received(decompressed) if ret is not None: - yield ret + await ret compressed_data = self._decompressor.unconsumed_tail + if compressed_data and not decompressed: + raise httputil.HTTPInputError( + "encountered unconsumed gzip data without making progress" + ) else: ret = self._delegate.data_received(chunk) if ret is not None: - yield ret + await ret - def finish(self): + def finish(self) -> None: if self._decompressor is not None: tail = self._decompressor.flush() if tail: - # I believe the tail will always be empty (i.e. - # decompress will return all it can). The purpose - # of the flush call is to detect errors such - # as truncated input. But in case it ever returns - # anything, treat it as an extra chunk - self._delegate.data_received(tail) + # The tail should always be empty: decompress returned + # all that it can in data_received and the only + # purpose of the flush call is to detect errors such + # as truncated input. If we did legitimately get a new + # chunk at this point we'd need to change the + # interface to make finish() a coroutine. + raise ValueError( + "decompressor.flush returned data; possible truncated input" + ) return self._delegate.finish() - def on_connection_close(self): + def on_connection_close(self) -> None: return self._delegate.on_connection_close() class HTTP1ServerConnection(object): """An HTTP/1.x server.""" - def __init__(self, stream, params=None, context=None): + + def __init__( + self, + stream: iostream.IOStream, + params: Optional[HTTP1ConnectionParameters] = None, + context: Optional[object] = None, + ) -> None: """ :arg stream: an `.IOStream` :arg params: a `.HTTP1ConnectionParameters` or None @@ -692,10 +787,9 @@ def __init__(self, stream, params=None, context=None): params = HTTP1ConnectionParameters() self.params = params self.context = context - self._serving_future = None + self._serving_future = None # type: Optional[Future[None]] - @gen.coroutine - def close(self): + async def close(self) -> None: """Closes the connection. Returns a `.Future` that resolves after the serving loop has exited. 
@@ -703,33 +797,37 @@ def close(self): self.stream.close() # Block until the serving loop is done, but ignore any exceptions # (start_serving is already responsible for logging them). + assert self._serving_future is not None try: - yield self._serving_future + await self._serving_future except Exception: pass - def start_serving(self, delegate): + def start_serving(self, delegate: httputil.HTTPServerConnectionDelegate) -> None: """Starts serving requests on this connection. :arg delegate: a `.HTTPServerConnectionDelegate` """ assert isinstance(delegate, httputil.HTTPServerConnectionDelegate) - self._serving_future = self._server_request_loop(delegate) + fut = gen.convert_yielded(self._server_request_loop(delegate)) + self._serving_future = fut # Register the future on the IOLoop so its errors get logged. - self.stream.io_loop.add_future(self._serving_future, - lambda f: f.result()) + self.stream.io_loop.add_future(fut, lambda f: f.result()) - @gen.coroutine - def _server_request_loop(self, delegate): + async def _server_request_loop( + self, delegate: httputil.HTTPServerConnectionDelegate + ) -> None: try: while True: - conn = HTTP1Connection(self.stream, False, - self.params, self.context) + conn = HTTP1Connection(self.stream, False, self.params, self.context) request_delegate = delegate.start_request(self, conn) try: - ret = yield conn.read_response(request_delegate) - except (iostream.StreamClosedError, - iostream.UnsatisfiableReadError): + ret = await conn.read_response(request_delegate) + except ( + iostream.StreamClosedError, + iostream.UnsatisfiableReadError, + asyncio.CancelledError, + ): return except _QuietException: # This exception was already logged. @@ -741,6 +839,6 @@ def _server_request_loop(self, delegate): return if not ret: return - yield gen.moment + await asyncio.sleep(0) finally: delegate.on_close(self) diff --git a/tornado/httpclient.py b/tornado/httpclient.py index f0a2df8871..3011c371b8 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -20,8 +20,6 @@ * ``curl_httpclient`` is faster. -* ``curl_httpclient`` was the default prior to Tornado 2.0. - Note that if you are using ``curl_httpclient``, it is highly recommended that you use a recent version of ``libcurl`` and ``pycurl``. Currently the minimum supported version of libcurl is @@ -38,19 +36,25 @@ AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient") """ -from __future__ import absolute_import, division, print_function - +import datetime import functools +from io import BytesIO +import ssl import time -import warnings import weakref -from tornado.concurrent import Future, future_set_result_unless_cancelled +from tornado.concurrent import ( + Future, + future_set_result_unless_cancelled, + future_set_exception_unless_cancelled, +) from tornado.escape import utf8, native_str -from tornado import gen, httputil, stack_context +from tornado import gen, httputil from tornado.ioloop import IOLoop from tornado.util import Configurable +from typing import Type, Any, Union, Dict, Callable, Optional, cast + class HTTPClient(object): """A blocking HTTP client. @@ -81,7 +85,12 @@ class HTTPClient(object): Use `AsyncHTTPClient` instead. """ - def __init__(self, async_client_class=None, **kwargs): + + def __init__( + self, + async_client_class: "Optional[Type[AsyncHTTPClient]]" = None, + **kwargs: Any + ) -> None: # Initialize self._closed at the beginning of the constructor # so that an exception raised here doesn't lead to confusing # failures in __del__. 
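The blocking ``HTTPClient`` wrapper introduced here is typically used as below (a sketch; the URL is a placeholder and the fetch requires network access):

```py
from tornado.httpclient import HTTPClient, HTTPError

# HTTPClient owns a private IOLoop and drives each fetch to completion
# with run_sync; it is intended for scripts, not for use inside a server.
client = HTTPClient()
try:
    response = client.fetch("http://example.com/")  # placeholder URL
    print(response.code, len(response.body))
except HTTPError as e:
    print("request failed:", e)
finally:
    client.close()  # also closes the private IOLoop
```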
@@ -89,23 +98,30 @@ def __init__(self, async_client_class=None, **kwargs): self._io_loop = IOLoop(make_current=False) if async_client_class is None: async_client_class = AsyncHTTPClient + # Create the client while our IOLoop is "current", without # clobbering the thread's real current IOLoop (if any). - self._async_client = self._io_loop.run_sync( - gen.coroutine(lambda: async_client_class(**kwargs))) + async def make_client() -> "AsyncHTTPClient": + await gen.sleep(0) + assert async_client_class is not None + return async_client_class(**kwargs) + + self._async_client = self._io_loop.run_sync(make_client) self._closed = False - def __del__(self): + def __del__(self) -> None: self.close() - def close(self): + def close(self) -> None: """Closes the HTTPClient, freeing any resources used.""" if not self._closed: self._async_client.close() self._io_loop.close() self._closed = True - def fetch(self, request, **kwargs): + def fetch( + self, request: Union["HTTPRequest", str], **kwargs: Any + ) -> "HTTPResponse": """Executes a request, returning an `HTTPResponse`. The request may be either a string URL or an `HTTPRequest` object. @@ -115,8 +131,9 @@ def fetch(self, request, **kwargs): If an error occurs during the fetch, we raise an `HTTPError` unless the ``raise_error`` keyword argument is set to False. """ - response = self._io_loop.run_sync(functools.partial( - self._async_client.fetch, request, **kwargs)) + response = self._io_loop.run_sync( + functools.partial(self._async_client.fetch, request, **kwargs) + ) return response @@ -125,15 +142,15 @@ class AsyncHTTPClient(Configurable): Example usage:: - def handle_response(response): - if response.error: - print("Error: %s" % response.error) + async def f(): + http_client = AsyncHTTPClient() + try: + response = await http_client.fetch("http://www.google.com") + except Exception as e: + print("Error: %s" % e) else: print(response.body) - http_client = AsyncHTTPClient() - http_client.fetch("http://www.google.com/", handle_response) - The constructor for this class is magic in several respects: It actually creates an instance of an implementation-specific subclass, and instances are reused as a kind of pseudo-singleton @@ -158,23 +175,27 @@ def handle_response(response): The ``io_loop`` argument (deprecated since version 4.1) has been removed. 
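The ``make_client`` indirection earlier in this hunk exists so that ``async_client_class`` is instantiated while the private loop is current. A minimal sketch of that mechanism (``make_obj`` is illustrative, not Tornado API):

```py
from tornado.ioloop import IOLoop

# run_sync executes the coroutine inside `loop`, so IOLoop.current()
# resolves to `loop` there without clobbering the thread's own loop.
loop = IOLoop(make_current=False)

async def make_obj():
    # Stand-in for async_client_class(**kwargs), whose constructor
    # captures IOLoop.current().
    return IOLoop.current()

assert loop.run_sync(make_obj) is loop
loop.close()
```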
""" + + _instance_cache = None # type: Dict[IOLoop, AsyncHTTPClient] + @classmethod - def configurable_base(cls): + def configurable_base(cls) -> Type[Configurable]: return AsyncHTTPClient @classmethod - def configurable_default(cls): + def configurable_default(cls) -> Type[Configurable]: from tornado.simple_httpclient import SimpleAsyncHTTPClient + return SimpleAsyncHTTPClient @classmethod - def _async_clients(cls): - attr_name = '_async_client_dict_' + cls.__name__ + def _async_clients(cls) -> Dict[IOLoop, "AsyncHTTPClient"]: + attr_name = "_async_client_dict_" + cls.__name__ if not hasattr(cls, attr_name): setattr(cls, attr_name, weakref.WeakKeyDictionary()) return getattr(cls, attr_name) - def __new__(cls, force_instance=False, **kwargs): + def __new__(cls, force_instance: bool = False, **kwargs: Any) -> "AsyncHTTPClient": io_loop = IOLoop.current() if force_instance: instance_cache = None @@ -182,7 +203,7 @@ def __new__(cls, force_instance=False, **kwargs): instance_cache = cls._async_clients() if instance_cache is not None and io_loop in instance_cache: return instance_cache[io_loop] - instance = super(AsyncHTTPClient, cls).__new__(cls, **kwargs) + instance = super(AsyncHTTPClient, cls).__new__(cls, **kwargs) # type: ignore # Make sure the instance knows which cache to remove itself from. # It can't simply call _async_clients() because we may be in # __new__(AsyncHTTPClient) but instance.__class__ may be @@ -192,14 +213,14 @@ def __new__(cls, force_instance=False, **kwargs): instance_cache[instance.io_loop] = instance return instance - def initialize(self, defaults=None): + def initialize(self, defaults: Optional[Dict[str, Any]] = None) -> None: self.io_loop = IOLoop.current() self.defaults = dict(HTTPRequest._DEFAULTS) if defaults is not None: self.defaults.update(defaults) self._closed = False - def close(self): + def close(self) -> None: """Destroys this HTTP client, freeing any file descriptors used. This method is **not needed in normal use** due to the way @@ -216,11 +237,21 @@ def close(self): return self._closed = True if self._instance_cache is not None: - if self._instance_cache.get(self.io_loop) is not self: + cached_val = self._instance_cache.pop(self.io_loop, None) + # If there's an object other than self in the instance + # cache for our IOLoop, something has gotten mixed up. A + # value of None appears to be possible when this is called + # from a destructor (HTTPClient.__del__) as the weakref + # gets cleared before the destructor runs. + if cached_val is not None and cached_val is not self: raise RuntimeError("inconsistent AsyncHTTPClient cache") - del self._instance_cache[self.io_loop] - def fetch(self, request, callback=None, raise_error=True, **kwargs): + def fetch( + self, + request: Union[str, "HTTPRequest"], + raise_error: bool = True, + **kwargs: Any + ) -> "Future[HTTPResponse]": """Executes a request, asynchronously returning an `HTTPResponse`. The request may be either a string URL or an `HTTPRequest` object. @@ -240,17 +271,14 @@ def fetch(self, request, callback=None, raise_error=True, **kwargs): Instead, you must check the response's ``error`` attribute or call its `~HTTPResponse.rethrow` method. - .. deprecated:: 5.1 - - The ``callback`` argument is deprecated and will be removed - in 6.0. Use the returned `.Future` instead. + .. versionchanged:: 6.0 - The ``raise_error=False`` argument currently suppresses - *all* errors, encapsulating them in `HTTPResponse` objects - with a 599 response code. 
This will change in Tornado 6.0: - ``raise_error=False`` will only affect the `HTTPError` - raised when a non-200 response code is used. + The ``callback`` argument was removed. Use the returned + `.Future` instead. + The ``raise_error=False`` argument only affects the + `HTTPError` raised when a non-200 response code is used, + instead of suppressing all errors. """ if self._closed: raise RuntimeError("fetch() called on closed AsyncHTTPClient") @@ -258,49 +286,35 @@ def fetch(self, request, callback=None, raise_error=True, **kwargs): request = HTTPRequest(url=request, **kwargs) else: if kwargs: - raise ValueError("kwargs can't be used if request is an HTTPRequest object") + raise ValueError( + "kwargs can't be used if request is an HTTPRequest object" + ) # We may modify this (to add Host, Accept-Encoding, etc), # so make sure we don't modify the caller's object. This is also # where normal dicts get converted to HTTPHeaders objects. request.headers = httputil.HTTPHeaders(request.headers) - request = _RequestProxy(request, self.defaults) - future = Future() - if callback is not None: - warnings.warn("callback arguments are deprecated, use the returned Future instead", - DeprecationWarning) - callback = stack_context.wrap(callback) - - def handle_future(future): - exc = future.exception() - if isinstance(exc, HTTPError) and exc.response is not None: - response = exc.response - elif exc is not None: - response = HTTPResponse( - request, 599, error=exc, - request_time=time.time() - request.start_time) - else: - response = future.result() - self.io_loop.add_callback(callback, response) - future.add_done_callback(handle_future) - - def handle_response(response): - if raise_error and response.error: - if isinstance(response.error, HTTPError): - response.error.response = response - future.set_exception(response.error) - else: - if response.error and not response._error_is_response_code: - warnings.warn("raise_error=False will allow '%s' to be raised in the future" % - response.error, DeprecationWarning) - future_set_result_unless_cancelled(future, response) - self.fetch_impl(request, handle_response) + request_proxy = _RequestProxy(request, self.defaults) + future = Future() # type: Future[HTTPResponse] + + def handle_response(response: "HTTPResponse") -> None: + if response.error: + if raise_error or not response._error_is_response_code: + future_set_exception_unless_cancelled(future, response.error) + return + future_set_result_unless_cancelled(future, response) + + self.fetch_impl(cast(HTTPRequest, request_proxy), handle_response) return future - def fetch_impl(self, request, callback): + def fetch_impl( + self, request: "HTTPRequest", callback: Callable[["HTTPResponse"], None] + ) -> None: raise NotImplementedError() @classmethod - def configure(cls, impl, **kwargs): + def configure( + cls, impl: "Union[None, str, Type[Configurable]]", **kwargs: Any + ) -> None: """Configures the `AsyncHTTPClient` subclass to use. ``AsyncHTTPClient()`` actually creates an instance of a subclass. @@ -325,6 +339,8 @@ def configure(cls, impl, **kwargs): class HTTPRequest(object): """HTTP client request object.""" + _headers = None # type: Union[Dict[str, str], httputil.HTTPHeaders] + # Default values for HTTPRequest parameters. # Merged with the values on the request object by AsyncHTTPClient # implementations. 
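The 6.0 ``raise_error`` behavior settled in this hunk, as a sketch (placeholder URL; assumes the server answers with a 404):

```py
import asyncio

from tornado.httpclient import AsyncHTTPClient, HTTPClientError

async def main() -> None:
    client = AsyncHTTPClient()
    url = "http://example.com/missing"  # placeholder URL
    # With raise_error=False, a non-2xx status comes back as an ordinary
    # HTTPResponse; only non-HTTP failures (e.g. connection errors)
    # still raise.
    resp = await client.fetch(url, raise_error=False)
    print(resp.code)
    try:
        await client.fetch(url)  # default raise_error=True
    except HTTPClientError as e:
        print("raised HTTP %d" % e.code)

asyncio.run(main())
```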
@@ -334,24 +350,49 @@ class HTTPRequest(object): follow_redirects=True, max_redirects=5, decompress_response=True, - proxy_password='', + proxy_password="", allow_nonstandard_methods=False, - validate_cert=True) - - def __init__(self, url, method="GET", headers=None, body=None, - auth_username=None, auth_password=None, auth_mode=None, - connect_timeout=None, request_timeout=None, - if_modified_since=None, follow_redirects=None, - max_redirects=None, user_agent=None, use_gzip=None, - network_interface=None, streaming_callback=None, - header_callback=None, prepare_curl_callback=None, - proxy_host=None, proxy_port=None, proxy_username=None, - proxy_password=None, proxy_auth_mode=None, - allow_nonstandard_methods=None, validate_cert=None, - ca_certs=None, allow_ipv6=None, client_key=None, - client_cert=None, body_producer=None, - expect_100_continue=False, decompress_response=None, - ssl_options=None): + validate_cert=True, + ) + + def __init__( + self, + url: str, + method: str = "GET", + headers: Optional[Union[Dict[str, str], httputil.HTTPHeaders]] = None, + body: Optional[Union[bytes, str]] = None, + auth_username: Optional[str] = None, + auth_password: Optional[str] = None, + auth_mode: Optional[str] = None, + connect_timeout: Optional[float] = None, + request_timeout: Optional[float] = None, + if_modified_since: Optional[Union[float, datetime.datetime]] = None, + follow_redirects: Optional[bool] = None, + max_redirects: Optional[int] = None, + user_agent: Optional[str] = None, + use_gzip: Optional[bool] = None, + network_interface: Optional[str] = None, + streaming_callback: Optional[Callable[[bytes], None]] = None, + header_callback: Optional[Callable[[str], None]] = None, + prepare_curl_callback: Optional[Callable[[Any], None]] = None, + proxy_host: Optional[str] = None, + proxy_port: Optional[int] = None, + proxy_username: Optional[str] = None, + proxy_password: Optional[str] = None, + proxy_auth_mode: Optional[str] = None, + allow_nonstandard_methods: Optional[bool] = None, + validate_cert: Optional[bool] = None, + ca_certs: Optional[str] = None, + allow_ipv6: Optional[bool] = None, + client_key: Optional[str] = None, + client_cert: Optional[str] = None, + body_producer: Optional[ + Callable[[Callable[[bytes], None]], "Future[None]"] + ] = None, + expect_100_continue: bool = False, + decompress_response: Optional[bool] = None, + ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None, + ) -> None: r"""All parameters except ``url`` are optional. :arg str url: URL to fetch @@ -360,7 +401,9 @@ def __init__(self, url, method="GET", headers=None, body=None, :type headers: `~tornado.httputil.HTTPHeaders` or `dict` :arg body: HTTP request body as a string (byte or unicode; if unicode the utf-8 encoding will be used) - :arg body_producer: Callable used for lazy/asynchronous request bodies. + :type body: `str` or `bytes` + :arg collections.abc.Callable body_producer: Callable used for + lazy/asynchronous request bodies. It is called with one argument, a ``write`` function, and should return a `.Future`. It should call the write function with new data as it becomes available. 
The write function returns a @@ -378,9 +421,9 @@ def __init__(self, url, method="GET", headers=None, body=None, supports "basic" and "digest"; ``simple_httpclient`` only supports "basic" :arg float connect_timeout: Timeout for initial connection in seconds, - default 20 seconds + default 20 seconds (0 means no timeout) :arg float request_timeout: Timeout for entire request in seconds, - default 20 seconds + default 20 seconds (0 means no timeout) :arg if_modified_since: Timestamp for ``If-Modified-Since`` header :type if_modified_since: `datetime` or `float` :arg bool follow_redirects: Should redirects be followed automatically @@ -392,8 +435,8 @@ def __init__(self, url, method="GET", headers=None, body=None, New in Tornado 4.0. :arg bool use_gzip: Deprecated alias for ``decompress_response`` since Tornado 4.0. - :arg str network_interface: Network interface to use for request. - ``curl_httpclient`` only; see note below. + :arg str network_interface: Network interface or source IP to use for request. + See ``curl_httpclient`` note below. :arg collections.abc.Callable streaming_callback: If set, ``streaming_callback`` will be run with each chunk of data as it is received, and ``HTTPResponse.body`` and ``HTTPResponse.buffer`` will be empty in @@ -433,11 +476,11 @@ def __init__(self, url, method="GET", headers=None, body=None, ``simple_httpclient`` (unsupported by ``curl_httpclient``). Overrides ``validate_cert``, ``ca_certs``, ``client_key``, and ``client_cert``. - :arg bool allow_ipv6: Use IPv6 when available? Default is true. + :arg bool allow_ipv6: Use IPv6 when available? Default is True. :arg bool expect_100_continue: If true, send the ``Expect: 100-continue`` header and wait for a continue response before sending the request body. Only supported with - simple_httpclient. + ``simple_httpclient``. .. note:: @@ -465,10 +508,11 @@ def __init__(self, url, method="GET", headers=None, body=None, """ # Note that some of these attributes go through property setters # defined below. - self.headers = headers + self.headers = headers # type: ignore if if_modified_since: self.headers["If-Modified-Since"] = httputil.format_timestamp( - if_modified_since) + if_modified_since + ) self.proxy_host = proxy_host self.proxy_port = proxy_port self.proxy_username = proxy_username @@ -476,7 +520,7 @@ def __init__(self, url, method="GET", headers=None, body=None, self.proxy_auth_mode = proxy_auth_mode self.url = url self.method = method - self.body = body + self.body = body # type: ignore self.body_producer = body_producer self.auth_username = auth_username self.auth_password = auth_password @@ -487,7 +531,7 @@ def __init__(self, url, method="GET", headers=None, body=None, self.max_redirects = max_redirects self.user_agent = user_agent if decompress_response is not None: - self.decompress_response = decompress_response + self.decompress_response = decompress_response # type: Optional[bool] else: self.decompress_response = use_gzip self.network_interface = network_interface @@ -505,90 +549,96 @@ def __init__(self, url, method="GET", headers=None, body=None, self.start_time = time.time() @property - def headers(self): - return self._headers + def headers(self) -> httputil.HTTPHeaders: + # TODO: headers may actually be a plain dict until fairly late in + # the process (AsyncHTTPClient.fetch), but practically speaking, + # whenever the property is used they're already HTTPHeaders. 
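A sketch of the ``body_producer`` contract documented above (``producer`` and the endpoint are illustrative): the producer is called with the ``write`` function and awaits each write for flow control.

```py
from tornado.httpclient import AsyncHTTPClient, HTTPRequest

async def producer(write):
    # write() returns a Future; awaiting it provides flow control so the
    # whole body is never buffered in memory at once.
    for piece in (b"part-1,", b"part-2"):
        await write(piece)

async def upload():
    req = HTTPRequest(
        url="http://example.com/upload",  # placeholder URL
        method="POST",
        body_producer=producer,
    )
    return await AsyncHTTPClient().fetch(req)
```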
+ return self._headers # type: ignore

@headers.setter
- def headers(self, value):
+ def headers(self, value: Union[Dict[str, str], httputil.HTTPHeaders]) -> None:
if value is None:
self._headers = httputil.HTTPHeaders()
else:
- self._headers = value
+ self._headers = value # type: ignore

@property
- def body(self):
+ def body(self) -> bytes:
return self._body

@body.setter
- def body(self, value):
+ def body(self, value: Union[bytes, str]) -> None:
self._body = utf8(value)

- @property
- def body_producer(self):
- return self._body_producer
-
- @body_producer.setter
- def body_producer(self, value):
- self._body_producer = stack_context.wrap(value)
-
- @property
- def streaming_callback(self):
- return self._streaming_callback
-
- @streaming_callback.setter
- def streaming_callback(self, value):
- self._streaming_callback = stack_context.wrap(value)
-
- @property
- def header_callback(self):
- return self._header_callback
-
- @header_callback.setter
- def header_callback(self, value):
- self._header_callback = stack_context.wrap(value)
-
- @property
- def prepare_curl_callback(self):
- return self._prepare_curl_callback
-
- @prepare_curl_callback.setter
- def prepare_curl_callback(self, value):
- self._prepare_curl_callback = stack_context.wrap(value)
-

class HTTPResponse(object):
"""HTTP Response object.

Attributes:

- * request: HTTPRequest object
+ * ``request``: HTTPRequest object

- * code: numeric HTTP status code, e.g. 200 or 404
+ * ``code``: numeric HTTP status code, e.g. 200 or 404

- * reason: human-readable reason phrase describing the status code
+ * ``reason``: human-readable reason phrase describing the status code

- * headers: `tornado.httputil.HTTPHeaders` object
+ * ``headers``: `tornado.httputil.HTTPHeaders` object

- * effective_url: final location of the resource after following any
+ * ``effective_url``: final location of the resource after following any
redirects

- * buffer: ``cStringIO`` object for response body
+ * ``buffer``: ``io.BytesIO`` object for response body
+
+ * ``body``: response body as bytes (created on demand from ``self.buffer``)

- * body: response body as bytes (created on demand from ``self.buffer``)
+ * ``error``: Exception object, if any

- * error: Exception object, if any
+ * ``request_time``: seconds from request start to finish. Includes all
+ network operations from DNS resolution to receiving the last byte of
+ data. Does not include time spent in the queue (due to the
+ ``max_clients`` option). If redirects were followed, only includes
+ the final request.

- * request_time: seconds from request start to finish
+ * ``start_time``: Time at which the HTTP operation started, based on
+ `time.time` (not the monotonic clock used by `.IOLoop.time`). May
+ be ``None`` if the request timed out while in the queue.

- * time_info: dictionary of diagnostic timing information from the request.
- Available data are subject to change, but currently uses timings
+ * ``time_info``: dictionary of diagnostic timing information from the
+ request. Available data are subject to change, but currently uses timings
available from http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html,
plus ``queue``, which is the delay (if any) introduced by waiting for a
slot under `AsyncHTTPClient`'s ``max_clients`` setting.
+
+ .. versionadded:: 5.1
+
+ Added the ``start_time`` attribute.
+
+ .. versionchanged:: 5.1
+
+ The ``request_time`` attribute previously included time spent in the queue
+ for ``simple_httpclient``, but not in ``curl_httpclient``.
Now queueing time + is excluded in both implementations. ``request_time`` is now more accurate for + ``curl_httpclient`` because it uses a monotonic clock when available. """ - def __init__(self, request, code, headers=None, buffer=None, - effective_url=None, error=None, request_time=None, - time_info=None, reason=None): + + # I'm not sure why these don't get type-inferred from the references in __init__. + error = None # type: Optional[BaseException] + _error_is_response_code = False + request = None # type: HTTPRequest + + def __init__( + self, + request: HTTPRequest, + code: int, + headers: Optional[httputil.HTTPHeaders] = None, + buffer: Optional[BytesIO] = None, + effective_url: Optional[str] = None, + error: Optional[BaseException] = None, + request_time: Optional[float] = None, + time_info: Optional[Dict[str, float]] = None, + reason: Optional[str] = None, + start_time: Optional[float] = None, + ) -> None: if isinstance(request, _RequestProxy): self.request = request.request else: @@ -600,7 +650,7 @@ def __init__(self, request, code, headers=None, buffer=None, else: self.headers = httputil.HTTPHeaders() self.buffer = buffer - self._body = None + self._body = None # type: Optional[bytes] if effective_url is None: self.effective_url = request.url else: @@ -609,30 +659,30 @@ def __init__(self, request, code, headers=None, buffer=None, if error is None: if self.code < 200 or self.code >= 300: self._error_is_response_code = True - self.error = HTTPError(self.code, message=self.reason, - response=self) + self.error = HTTPError(self.code, message=self.reason, response=self) else: self.error = None else: self.error = error + self.start_time = start_time self.request_time = request_time self.time_info = time_info or {} @property - def body(self): + def body(self) -> bytes: if self.buffer is None: - return None + return b"" elif self._body is None: self._body = self.buffer.getvalue() return self._body - def rethrow(self): + def rethrow(self) -> None: """If there was an error on the request, raise an `HTTPError`.""" if self.error: raise self.error - def __repr__(self): + def __repr__(self) -> str: args = ",".join("%s=%r" % i for i in sorted(self.__dict__.items())) return "%s(%s)" % (self.__class__.__name__, args) @@ -657,13 +707,19 @@ class HTTPClientError(Exception): `tornado.web.HTTPError`. The name ``tornado.httpclient.HTTPError`` remains as an alias. """ - def __init__(self, code, message=None, response=None): + + def __init__( + self, + code: int, + message: Optional[str] = None, + response: Optional[HTTPResponse] = None, + ) -> None: self.code = code self.message = message or httputil.responses.get(code, "Unknown") self.response = response - super(HTTPClientError, self).__init__(code, message, response) + super().__init__(code, message, response) - def __str__(self): + def __str__(self) -> str: return "HTTP %d: %s" % (self.code, self.message) # There is a cyclic reference between self and self.response, @@ -681,11 +737,14 @@ class _RequestProxy(object): Used internally by AsyncHTTPClient implementations. 
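The error-derivation rule in ``HTTPResponse.__init__`` above can be exercised by constructing a response by hand, as client implementations do (a sketch with a placeholder URL):

```py
from io import BytesIO

from tornado.httpclient import HTTPClientError, HTTPRequest, HTTPResponse

# A non-2xx code with no explicit error yields an HTTPClientError that
# carries the response; rethrow() raises it on demand.
req = HTTPRequest(url="http://example.com/")  # placeholder URL
resp = HTTPResponse(req, 404, buffer=BytesIO(b"not found"))
assert resp.body == b"not found"  # materialized lazily from the buffer
try:
    resp.rethrow()
except HTTPClientError as e:
    assert e.code == 404 and e.response is resp
```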
""" - def __init__(self, request, defaults): + + def __init__( + self, request: HTTPRequest, defaults: Optional[Dict[str, Any]] + ) -> None: self.request = request self.defaults = defaults - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: request_attr = getattr(self.request, name) if request_attr is not None: return request_attr @@ -695,8 +754,9 @@ def __getattr__(self, name): return None -def main(): +def main() -> None: from tornado.options import define, options, parse_command_line + define("print_headers", type=bool, default=False) define("print_body", type=bool, default=True) define("follow_redirects", type=bool, default=True) @@ -707,12 +767,13 @@ def main(): client = HTTPClient() for arg in args: try: - response = client.fetch(arg, - follow_redirects=options.follow_redirects, - validate_cert=options.validate_cert, - proxy_host=options.proxy_host, - proxy_port=options.proxy_port, - ) + response = client.fetch( + arg, + follow_redirects=options.follow_redirects, + validate_cert=options.validate_cert, + proxy_host=options.proxy_host, + proxy_port=options.proxy_port, + ) except HTTPError as e: if e.response is not None: response = e.response diff --git a/tornado/httpserver.py b/tornado/httpserver.py index 3498d71fb6..cd4a468120 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -25,22 +25,25 @@ class except to start a server at the beginning of the process to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias. """ -from __future__ import absolute_import, division, print_function - import socket +import ssl from tornado.escape import native_str from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters -from tornado import gen from tornado import httputil from tornado import iostream from tornado import netutil from tornado.tcpserver import TCPServer from tornado.util import Configurable +import typing +from typing import Union, Any, Dict, Callable, List, Type, Tuple, Optional, Awaitable + +if typing.TYPE_CHECKING: + from typing import Set # noqa: F401 + -class HTTPServer(TCPServer, Configurable, - httputil.HTTPServerConnectionDelegate): +class HTTPServer(TCPServer, Configurable, httputil.HTTPServerConnectionDelegate): r"""A non-blocking, single-threaded HTTP server. A server is defined by a subclass of `.HTTPServerConnectionDelegate`, @@ -137,7 +140,8 @@ class HTTPServer(TCPServer, Configurable, .. versionchanged:: 5.0 The ``io_loop`` argument has been removed. """ - def __init__(self, *args, **kwargs): + + def __init__(self, *args: Any, **kwargs: Any) -> None: # Ignore args to __init__; real initialization belongs in # initialize since we're Configurable. 
(there's something # weird in initialization order between this class, @@ -145,13 +149,29 @@ def __init__(self, *args, **kwargs): # completely) pass - def initialize(self, request_callback, no_keep_alive=False, - xheaders=False, ssl_options=None, protocol=None, - decompress_request=False, - chunk_size=None, max_header_size=None, - idle_connection_timeout=None, body_timeout=None, - max_body_size=None, max_buffer_size=None, - trusted_downstream=None): + def initialize( + self, + request_callback: Union[ + httputil.HTTPServerConnectionDelegate, + Callable[[httputil.HTTPServerRequest], None], + ], + no_keep_alive: bool = False, + xheaders: bool = False, + ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None, + protocol: Optional[str] = None, + decompress_request: bool = False, + chunk_size: Optional[int] = None, + max_header_size: Optional[int] = None, + idle_connection_timeout: Optional[float] = None, + body_timeout: Optional[float] = None, + max_body_size: Optional[int] = None, + max_buffer_size: Optional[int] = None, + trusted_downstream: Optional[List[str]] = None, + ) -> None: + # This method's signature is not extracted with autodoc + # because we want its arguments to appear on the class + # constructor. When changing this signature, also update the + # copy in httpserver.rst. self.request_callback = request_callback self.xheaders = xheaders self.protocol = protocol @@ -162,38 +182,55 @@ def initialize(self, request_callback, no_keep_alive=False, header_timeout=idle_connection_timeout or 3600, max_body_size=max_body_size, body_timeout=body_timeout, - no_keep_alive=no_keep_alive) - TCPServer.__init__(self, ssl_options=ssl_options, - max_buffer_size=max_buffer_size, - read_chunk_size=chunk_size) - self._connections = set() + no_keep_alive=no_keep_alive, + ) + TCPServer.__init__( + self, + ssl_options=ssl_options, + max_buffer_size=max_buffer_size, + read_chunk_size=chunk_size, + ) + self._connections = set() # type: Set[HTTP1ServerConnection] self.trusted_downstream = trusted_downstream @classmethod - def configurable_base(cls): + def configurable_base(cls) -> Type[Configurable]: return HTTPServer @classmethod - def configurable_default(cls): + def configurable_default(cls) -> Type[Configurable]: return HTTPServer - @gen.coroutine - def close_all_connections(self): + async def close_all_connections(self) -> None: + """Close all open connections and asynchronously wait for them to finish. + + This method is used in combination with `~.TCPServer.stop` to + support clean shutdowns (especially for unittests). Typical + usage would call ``stop()`` first to stop accepting new + connections, then ``await close_all_connections()`` to wait for + existing connections to finish. + + This method does not currently close open websocket connections. + + Note that this method is a coroutine and must be called with ``await``. 
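The clean-shutdown recipe from the docstring above, in one place (a sketch; ``PingHandler`` and the port are arbitrary):

```py
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.web import Application, RequestHandler

class PingHandler(RequestHandler):
    def get(self) -> None:
        self.write("pong")

async def main() -> None:
    server = HTTPServer(Application([("/ping", PingHandler)]))
    server.listen(8888)  # arbitrary port
    # ... serve for a while ...
    server.stop()                         # refuse new connections
    await server.close_all_connections()  # wait for in-flight requests

IOLoop.current().run_sync(main)
```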
+ + """ while self._connections: # Peek at an arbitrary element of the set conn = next(iter(self._connections)) - yield conn.close() - - def handle_stream(self, stream, address): - context = _HTTPRequestContext(stream, address, - self.protocol, - self.trusted_downstream) - conn = HTTP1ServerConnection( - stream, self.conn_params, context) + await conn.close() + + def handle_stream(self, stream: iostream.IOStream, address: Tuple) -> None: + context = _HTTPRequestContext( + stream, address, self.protocol, self.trusted_downstream + ) + conn = HTTP1ServerConnection(stream, self.conn_params, context) self._connections.add(conn) conn.start_serving(self) - def start_request(self, server_conn, request_conn): + def start_request( + self, server_conn: object, request_conn: httputil.HTTPConnection + ) -> httputil.HTTPMessageDelegate: if isinstance(self.request_callback, httputil.HTTPServerConnectionDelegate): delegate = self.request_callback.start_request(server_conn, request_conn) else: @@ -204,37 +241,56 @@ def start_request(self, server_conn, request_conn): return delegate - def on_close(self, server_conn): - self._connections.remove(server_conn) + def on_close(self, server_conn: object) -> None: + self._connections.remove(typing.cast(HTTP1ServerConnection, server_conn)) class _CallableAdapter(httputil.HTTPMessageDelegate): - def __init__(self, request_callback, request_conn): + def __init__( + self, + request_callback: Callable[[httputil.HTTPServerRequest], None], + request_conn: httputil.HTTPConnection, + ) -> None: self.connection = request_conn self.request_callback = request_callback - self.request = None + self.request = None # type: Optional[httputil.HTTPServerRequest] self.delegate = None - self._chunks = [] + self._chunks = [] # type: List[bytes] - def headers_received(self, start_line, headers): + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: self.request = httputil.HTTPServerRequest( - connection=self.connection, start_line=start_line, - headers=headers) + connection=self.connection, + start_line=typing.cast(httputil.RequestStartLine, start_line), + headers=headers, + ) + return None - def data_received(self, chunk): + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: self._chunks.append(chunk) + return None - def finish(self): - self.request.body = b''.join(self._chunks) + def finish(self) -> None: + assert self.request is not None + self.request.body = b"".join(self._chunks) self.request._parse_body() self.request_callback(self.request) - def on_connection_close(self): - self._chunks = None + def on_connection_close(self) -> None: + del self._chunks class _HTTPRequestContext(object): - def __init__(self, stream, address, protocol, trusted_downstream=None): + def __init__( + self, + stream: iostream.IOStream, + address: Tuple, + protocol: Optional[str], + trusted_downstream: Optional[List[str]] = None, + ) -> None: self.address = address # Save the socket's address family now so we know how to # interpret self.address even after the stream is closed @@ -244,12 +300,14 @@ def __init__(self, stream, address, protocol, trusted_downstream=None): else: self.address_family = None # In HTTPServerRequest we want an IP, not a full socket address. 
- if (self.address_family in (socket.AF_INET, socket.AF_INET6) and - address is not None): + if ( + self.address_family in (socket.AF_INET, socket.AF_INET6) + and address is not None + ): self.remote_ip = address[0] else: # Unix (or other) socket; fake the remote address. - self.remote_ip = '0.0.0.0' + self.remote_ip = "0.0.0.0" if protocol: self.protocol = protocol elif isinstance(stream, iostream.SSLIOStream): @@ -260,7 +318,7 @@ def __init__(self, stream, address, protocol, trusted_downstream=None): self._orig_protocol = self.protocol self.trusted_downstream = set(trusted_downstream or []) - def __str__(self): + def __str__(self) -> str: if self.address_family in (socket.AF_INET, socket.AF_INET6): return self.remote_ip elif isinstance(self.address, bytes): @@ -271,12 +329,12 @@ def __str__(self): else: return str(self.address) - def _apply_xheaders(self, headers): + def _apply_xheaders(self, headers: httputil.HTTPHeaders) -> None: """Rewrite the ``remote_ip`` and ``protocol`` fields.""" # Squid uses X-Forwarded-For, others use X-Real-Ip ip = headers.get("X-Forwarded-For", self.remote_ip) # Skip trusted downstream hosts in X-Forwarded-For list - for ip in (cand.strip() for cand in reversed(ip.split(','))): + for ip in (cand.strip() for cand in reversed(ip.split(","))): if ip not in self.trusted_downstream: break ip = headers.get("X-Real-Ip", ip) @@ -284,16 +342,16 @@ def _apply_xheaders(self, headers): self.remote_ip = ip # AWS uses X-Forwarded-Proto proto_header = headers.get( - "X-Scheme", headers.get("X-Forwarded-Proto", - self.protocol)) + "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol) + ) if proto_header: # use only the last proto entry if there is more than one - # TODO: support trusting mutiple layers of proxied protocol - proto_header = proto_header.split(',')[-1].strip() + # TODO: support trusting multiple layers of proxied protocol + proto_header = proto_header.split(",")[-1].strip() if proto_header in ("http", "https"): self.protocol = proto_header - def _unapply_xheaders(self): + def _unapply_xheaders(self) -> None: """Undo changes from `_apply_xheaders`. Xheaders are per-request so they should not leak to the next @@ -304,27 +362,37 @@ def _unapply_xheaders(self): class _ProxyAdapter(httputil.HTTPMessageDelegate): - def __init__(self, delegate, request_conn): + def __init__( + self, + delegate: httputil.HTTPMessageDelegate, + request_conn: httputil.HTTPConnection, + ) -> None: self.connection = request_conn self.delegate = delegate - def headers_received(self, start_line, headers): - self.connection.context._apply_xheaders(headers) + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: + # TODO: either make context an official part of the + # HTTPConnection interface or figure out some other way to do this. 
+ self.connection.context._apply_xheaders(headers) # type: ignore return self.delegate.headers_received(start_line, headers) - def data_received(self, chunk): + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: return self.delegate.data_received(chunk) - def finish(self): + def finish(self) -> None: self.delegate.finish() self._cleanup() - def on_connection_close(self): + def on_connection_close(self) -> None: self.delegate.on_connection_close() self._cleanup() - def _cleanup(self): - self.connection.context._unapply_xheaders() + def _cleanup(self) -> None: + self.connection.context._unapply_xheaders() # type: ignore HTTPRequest = httputil.HTTPServerRequest diff --git a/tornado/httputil.py b/tornado/httputil.py index 9c607b8c85..c0c57e6e95 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -19,91 +19,61 @@ via `tornado.web.RequestHandler.request`. """ -from __future__ import absolute_import, division, print_function - import calendar -import collections +import collections.abc import copy import datetime import email.utils -import numbers +from functools import lru_cache +from http.client import responses +import http.cookies import re +from ssl import SSLError import time -import warnings +import unicodedata +from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl from tornado.escape import native_str, parse_qs_bytes, utf8 from tornado.log import gen_log -from tornado.util import ObjectDict, PY3 - -if PY3: - import http.cookies as Cookie - from http.client import responses - from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl -else: - import Cookie - from httplib import responses - from urllib import urlencode - from urlparse import urlparse, urlunparse, parse_qsl +from tornado.util import ObjectDict, unicode_type # responses is unused in this file, but we re-export it to other files. # Reference it so pyflakes doesn't complain. responses -try: - from ssl import SSLError -except ImportError: - # ssl is unavailable on app engine. - class _SSLError(Exception): - pass - # Hack around a mypy limitation. We can't simply put "type: ignore" - # on the class definition itself; must go through an assignment. - SSLError = _SSLError # type: ignore - -try: - import typing # noqa: F401 -except ImportError: - pass - - -# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line -# terminator and ignore any preceding CR. -_CRLF_RE = re.compile(r'\r?\n') - - -class _NormalizedHeaderCache(dict): - """Dynamic cached mapping of header names to Http-Header-Case. - - Implemented as a dict subclass so that cache hits are as fast as a - normal dict lookup, without the overhead of a python function - call. - - >>> normalized_headers = _NormalizedHeaderCache(10) - >>> normalized_headers["coNtent-TYPE"] +import typing +from typing import ( + Tuple, + Iterable, + List, + Mapping, + Iterator, + Dict, + Union, + Optional, + Awaitable, + Generator, + AnyStr, +) + +if typing.TYPE_CHECKING: + from typing import Deque # noqa: F401 + from asyncio import Future # noqa: F401 + import unittest # noqa: F401 + + +@lru_cache(1000) +def _normalize_header(name: str) -> str: + """Map a header name to Http-Header-Case. 
+ + >>> _normalize_header("coNtent-TYPE") 'Content-Type' """ - def __init__(self, size): - super(_NormalizedHeaderCache, self).__init__() - self.size = size - self.queue = collections.deque() - - def __missing__(self, key): - normalized = "-".join([w.capitalize() for w in key.split("-")]) - self[key] = normalized - self.queue.append(key) - if len(self.queue) > self.size: - # Limit the size of the cache. LRU would be better, but this - # simpler approach should be fine. In Python 2.7+ we could - # use OrderedDict (or in 3.2+, @functools.lru_cache). - old_key = self.queue.popleft() - del self[old_key] - return normalized + return "-".join([w.capitalize() for w in name.split("-")]) -_normalized_headers = _NormalizedHeaderCache(1000) - - -class HTTPHeaders(collections.MutableMapping): +class HTTPHeaders(collections.abc.MutableMapping): """A dictionary that maintains ``Http-Header-Case`` for all keys. Supports multiple values per key via a pair of new methods, @@ -131,12 +101,28 @@ class HTTPHeaders(collections.MutableMapping): Set-Cookie: A=B Set-Cookie: C=D """ - def __init__(self, *args, **kwargs): + + @typing.overload + def __init__(self, __arg: Mapping[str, List[str]]) -> None: + pass + + @typing.overload # noqa: F811 + def __init__(self, __arg: Mapping[str, str]) -> None: + pass + + @typing.overload # noqa: F811 + def __init__(self, *args: Tuple[str, str]) -> None: + pass + + @typing.overload # noqa: F811 + def __init__(self, **kwargs: str) -> None: + pass + + def __init__(self, *args: typing.Any, **kwargs: str) -> None: # noqa: F811 self._dict = {} # type: typing.Dict[str, str] self._as_list = {} # type: typing.Dict[str, typing.List[str]] - self._last_key = None - if (len(args) == 1 and len(kwargs) == 0 and - isinstance(args[0], HTTPHeaders)): + self._last_key = None # type: Optional[str] + if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], HTTPHeaders): # Copy constructor for k, v in args[0].get_all(): self.add(k, v) @@ -146,25 +132,24 @@ def __init__(self, *args, **kwargs): # new public methods - def add(self, name, value): - # type: (str, str) -> None + def add(self, name: str, value: str) -> None: """Adds a new value for the given key.""" - norm_name = _normalized_headers[name] + norm_name = _normalize_header(name) self._last_key = norm_name if norm_name in self: - self._dict[norm_name] = (native_str(self[norm_name]) + ',' + - native_str(value)) + self._dict[norm_name] = ( + native_str(self[norm_name]) + "," + native_str(value) + ) self._as_list[norm_name].append(value) else: self[norm_name] = value - def get_list(self, name): + def get_list(self, name: str) -> List[str]: """Returns all values for the given header as a list.""" - norm_name = _normalized_headers[name] + norm_name = _normalize_header(name) return self._as_list.get(norm_name, []) - def get_all(self): - # type: () -> typing.Iterable[typing.Tuple[str, str]] + def get_all(self) -> Iterable[Tuple[str, str]]: """Returns an iterable of all (name, value) pairs. If a header has multiple values, multiple pairs will be @@ -174,7 +159,7 @@ def get_all(self): for value in values: yield (name, value) - def parse_line(self, line): + def parse_line(self, line: str) -> None: """Updates the dictionary with a single header line. 
>>> h = HTTPHeaders() @@ -186,7 +171,7 @@ def parse_line(self, line): # continuation of a multi-line header if self._last_key is None: raise HTTPInputError("first header line cannot start with whitespace") - new_part = ' ' + line.lstrip() + new_part = " " + line.lstrip() self._as_list[self._last_key][-1] += new_part self._dict[self._last_key] += new_part else: @@ -197,7 +182,7 @@ def parse_line(self, line): self.add(name, value.strip()) @classmethod - def parse(cls, headers): + def parse(cls, headers: str) -> "HTTPHeaders": """Returns a dictionary from HTTP header text. >>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n") @@ -211,34 +196,37 @@ def parse(cls, headers): """ h = cls() - for line in _CRLF_RE.split(headers): + # RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line + # terminator and ignore any preceding CR. + for line in headers.split("\n"): + if line.endswith("\r"): + line = line[:-1] if line: h.parse_line(line) return h # MutableMapping abstract method implementations. - def __setitem__(self, name, value): - norm_name = _normalized_headers[name] + def __setitem__(self, name: str, value: str) -> None: + norm_name = _normalize_header(name) self._dict[norm_name] = value self._as_list[norm_name] = [value] - def __getitem__(self, name): - # type: (str) -> str - return self._dict[_normalized_headers[name]] + def __getitem__(self, name: str) -> str: + return self._dict[_normalize_header(name)] - def __delitem__(self, name): - norm_name = _normalized_headers[name] + def __delitem__(self, name: str) -> None: + norm_name = _normalize_header(name) del self._dict[norm_name] del self._as_list[norm_name] - def __len__(self): + def __len__(self) -> int: return len(self._dict) - def __iter__(self): + def __iter__(self) -> Iterator[typing.Any]: return iter(self._dict) - def copy(self): + def copy(self) -> "HTTPHeaders": # defined in dict but not in MutableMapping. return HTTPHeaders(self) @@ -247,7 +235,7 @@ def copy(self): # the appearance that HTTPHeaders is a single container. __copy__ = copy - def __str__(self): + def __str__(self) -> str: lines = [] for name, value in self.get_all(): lines.append("%s: %s\n" % (name, value)) @@ -348,9 +336,26 @@ class HTTPServerRequest(object): .. versionchanged:: 4.0 Moved from ``tornado.httpserver.HTTPRequest``. 
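A minimal sketch of the multi-value behavior that `add` and `get_list` above implement; the commented output follows from the comma-joined/list bookkeeping in `add`:

```py
from tornado.httputil import HTTPHeaders

h = HTTPHeaders()
h.add("set-cookie", "A=B")       # key normalized to Set-Cookie
h.add("Set-Cookie", "C=D")
print(h["set-cookie"])           # A=B,C=D   (comma-joined view)
print(h.get_list("Set-Cookie"))  # ['A=B', 'C=D']
for name, value in h.get_all():  # one (name, value) pair per value
    print(name, value)
```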
""" - def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None, - body=None, host=None, files=None, connection=None, - start_line=None, server_connection=None): + + path = None # type: str + query = None # type: str + + # HACK: Used for stream_request_body + _body_future = None # type: Future[None] + + def __init__( + self, + method: Optional[str] = None, + uri: Optional[str] = None, + version: str = "HTTP/1.0", + headers: Optional[HTTPHeaders] = None, + body: Optional[bytes] = None, + host: Optional[str] = None, + files: Optional[Dict[str, List["HTTPFile"]]] = None, + connection: Optional["HTTPConnection"] = None, + start_line: Optional["RequestStartLine"] = None, + server_connection: Optional[object] = None, + ) -> None: if start_line is not None: method, uri, version = start_line self.method = method @@ -360,9 +365,9 @@ def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None, self.body = body or b"" # set remote IP and protocol - context = getattr(connection, 'context', None) - self.remote_ip = getattr(context, 'remote_ip', None) - self.protocol = getattr(context, 'protocol', "http") + context = getattr(connection, "context", None) + self.remote_ip = getattr(context, "remote_ip", None) + self.protocol = getattr(context, "protocol", "http") self.host = host or self.headers.get("Host") or "127.0.0.1" self.host_name = split_host_and_port(self.host.lower())[0] @@ -372,31 +377,19 @@ def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None, self._start_time = time.time() self._finish_time = None - self.path, sep, self.query = uri.partition('?') + if uri is not None: + self.path, sep, self.query = uri.partition("?") self.arguments = parse_qs_bytes(self.query, keep_blank_values=True) self.query_arguments = copy.deepcopy(self.arguments) - self.body_arguments = {} - - def supports_http_1_1(self): - """Returns True if this request supports HTTP/1.1 semantics. - - .. deprecated:: 4.0 - - Applications are less likely to need this information with - the introduction of `.HTTPConnection`. If you still need - it, access the ``version`` attribute directly. This method - will be removed in Tornado 6.0. - - """ - warnings.warn("supports_http_1_1() is deprecated, use request.version instead", - DeprecationWarning) - return self.version == "HTTP/1.1" + self.body_arguments = {} # type: Dict[str, List[bytes]] @property - def cookies(self): - """A dictionary of Cookie.Morsel objects.""" + def cookies(self) -> Dict[str, http.cookies.Morsel]: + """A dictionary of ``http.cookies.Morsel`` objects.""" if not hasattr(self, "_cookies"): - self._cookies = Cookie.SimpleCookie() + self._cookies = ( + http.cookies.SimpleCookie() + ) # type: http.cookies.SimpleCookie if "Cookie" in self.headers: try: parsed = parse_cookie(self.headers["Cookie"]) @@ -413,44 +406,20 @@ def cookies(self): pass return self._cookies - def write(self, chunk, callback=None): - """Writes the given chunk to the response stream. - - .. deprecated:: 4.0 - Use ``request.connection`` and the `.HTTPConnection` methods - to write the response. This method will be removed in Tornado 6.0. - """ - warnings.warn("req.write deprecated, use req.connection.write and write_headers instead", - DeprecationWarning) - assert isinstance(chunk, bytes) - assert self.version.startswith("HTTP/1."), \ - "deprecated interface only supported in HTTP/1.x" - self.connection.write(chunk, callback=callback) - - def finish(self): - """Finishes this HTTP request on the open connection. - - .. 
deprecated:: 4.0 - Use ``request.connection`` and the `.HTTPConnection` methods - to write the response. This method will be removed in Tornado 6.0. - """ - warnings.warn("req.finish deprecated, use req.connection.finish instead", - DeprecationWarning) - self.connection.finish() - self._finish_time = time.time() - - def full_url(self): + def full_url(self) -> str: """Reconstructs the full URL for this request.""" return self.protocol + "://" + self.host + self.uri - def request_time(self): + def request_time(self) -> float: """Returns the amount of time it took for this request to execute.""" if self._finish_time is None: return time.time() - self._start_time else: return self._finish_time - self._start_time - def get_ssl_certificate(self, binary_form=False): + def get_ssl_certificate( + self, binary_form: bool = False + ) -> Union[None, Dict, bytes]: """Returns the client's SSL certificate, if any. To use client certificates, the HTTPServer's @@ -470,21 +439,28 @@ def get_ssl_certificate(self, binary_form=False): http://docs.python.org/library/ssl.html#sslsocket-objects """ try: - return self.connection.stream.socket.getpeercert( - binary_form=binary_form) + if self.connection is None: + return None + # TODO: add a method to HTTPConnection for this so it can work with HTTP/2 + return self.connection.stream.socket.getpeercert( # type: ignore + binary_form=binary_form + ) except SSLError: return None - def _parse_body(self): + def _parse_body(self) -> None: parse_body_arguments( - self.headers.get("Content-Type", ""), self.body, - self.body_arguments, self.files, - self.headers) + self.headers.get("Content-Type", ""), + self.body, + self.body_arguments, + self.files, + self.headers, + ) for k, v in self.body_arguments.items(): self.arguments.setdefault(k, []).extend(v) - def __repr__(self): + def __repr__(self) -> str: attrs = ("protocol", "host", "method", "uri", "version", "remote_ip") args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs]) return "%s(%s)" % (self.__class__.__name__, args) @@ -496,6 +472,7 @@ class HTTPInputError(Exception): .. versionadded:: 4.0 """ + pass @@ -504,6 +481,7 @@ class HTTPOutputError(Exception): .. versionadded:: 4.0 """ + pass @@ -512,7 +490,10 @@ class HTTPServerConnectionDelegate(object): .. versionadded:: 4.0 """ - def start_request(self, server_conn, request_conn): + + def start_request( + self, server_conn: object, request_conn: "HTTPConnection" + ) -> "HTTPMessageDelegate": """This method is called by the server when a new request has started. :arg server_conn: is an opaque object representing the long-lived @@ -524,7 +505,7 @@ def start_request(self, server_conn, request_conn): """ raise NotImplementedError() - def on_close(self, server_conn): + def on_close(self, server_conn: object) -> None: """This method is called when a connection has been closed. :arg server_conn: is a server connection that has previously been @@ -538,7 +519,13 @@ class HTTPMessageDelegate(object): .. versionadded:: 4.0 """ - def headers_received(self, start_line, headers): + + # TODO: genericize this class to avoid exposing the Union. + def headers_received( + self, + start_line: Union["RequestStartLine", "ResponseStartLine"], + headers: HTTPHeaders, + ) -> Optional[Awaitable[None]]: """Called when the HTTP headers have been received and parsed. 
:arg start_line: a `.RequestStartLine` or `.ResponseStartLine` @@ -553,18 +540,18 @@ def headers_received(self, start_line, headers): """ pass - def data_received(self, chunk): + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: """Called when a chunk of data has been received. May return a `.Future` for flow control. """ pass - def finish(self): + def finish(self) -> None: """Called after the last chunk of data has been received.""" pass - def on_connection_close(self): + def on_connection_close(self) -> None: """Called if the connection is closed without finishing the request. If ``headers_received`` is called, either ``finish`` or @@ -578,7 +565,13 @@ class HTTPConnection(object): .. versionadded:: 4.0 """ - def write_headers(self, start_line, headers, chunk=None, callback=None): + + def write_headers( + self, + start_line: Union["RequestStartLine", "ResponseStartLine"], + headers: HTTPHeaders, + chunk: Optional[bytes] = None, + ) -> "Future[None]": """Write an HTTP header block. :arg start_line: a `.RequestStartLine` or `.ResponseStartLine`. @@ -586,29 +579,39 @@ def write_headers(self, start_line, headers, chunk=None, callback=None): :arg chunk: the first (optional) chunk of data. This is an optimization so that small responses can be written in the same call as their headers. - :arg callback: a callback to be run when the write is complete. The ``version`` field of ``start_line`` is ignored. - Returns a `.Future` if no callback is given. + Returns a future for flow control. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. """ raise NotImplementedError() - def write(self, chunk, callback=None): + def write(self, chunk: bytes) -> "Future[None]": """Writes a chunk of body data. - The callback will be run when the write is complete. If no callback - is given, returns a Future. + Returns a future for flow control. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. """ raise NotImplementedError() - def finish(self): - """Indicates that the last body data has been written. - """ + def finish(self) -> None: + """Indicates that the last body data has been written.""" raise NotImplementedError() -def url_concat(url, args): +def url_concat( + url: str, + args: Union[ + None, Dict[str, str], List[Tuple[str, str]], Tuple[Tuple[str, str], ...] + ], +) -> str: """Concatenate url and arguments regardless of whether url has existing query parameters. @@ -633,16 +636,20 @@ def url_concat(url, args): parsed_query.extend(args) else: err = "'args' parameter should be dict, list or tuple. Not {0}".format( - type(args)) + type(args) + ) raise TypeError(err) final_query = urlencode(parsed_query) - url = urlunparse(( - parsed_url[0], - parsed_url[1], - parsed_url[2], - parsed_url[3], - final_query, - parsed_url[5])) + url = urlunparse( + ( + parsed_url[0], + parsed_url[1], + parsed_url[2], + parsed_url[3], + final_query, + parsed_url[5], + ) + ) return url @@ -656,10 +663,13 @@ class HTTPFile(ObjectDict): * ``body`` * ``content_type`` """ + pass -def _parse_request_range(range_header): +def _parse_request_range( + range_header: str, +) -> Optional[Tuple[Optional[int], Optional[int]]]: """Parses a Range header. Returns either ``None`` or tuple ``(start, end)``. 
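To make the `HTTPMessageDelegate` interface above concrete, a hypothetical implementation (the class and attribute names here are invented for illustration, not part of the patch):

```py
from typing import Awaitable, Optional

from tornado import httputil

class BodyCollector(httputil.HTTPMessageDelegate):
    """Hypothetical delegate that buffers the message body."""

    def headers_received(self, start_line, headers) -> Optional[Awaitable[None]]:
        self.chunks = []  # type: list
        return None

    def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
        # Returning an awaitable here would apply backpressure;
        # returning None accepts the chunk immediately.
        self.chunks.append(chunk)
        return None

    def finish(self) -> None:
        self.body = b"".join(self.chunks)

    def on_connection_close(self) -> None:
        self.chunks = []
```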
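For reference, the annotations on `url_concat` above do not change its documented behavior, which matches the upstream doctests:

```py
from tornado.httputil import url_concat

url_concat("http://example.com/foo", dict(c="d"))
# 'http://example.com/foo?c=d'
url_concat("http://example.com/foo?a=b", dict(c="d"))
# 'http://example.com/foo?a=b&c=d'
url_concat("http://example.com/foo?a=b", [("c", "d"), ("c", "d2")])
# 'http://example.com/foo?a=b&c=d&c=d2'
```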
@@ -708,7 +718,7 @@ def _parse_request_range(range_header): return (start, end) -def _get_content_range(start, end, total): +def _get_content_range(start: Optional[int], end: Optional[int], total: int) -> str: """Returns a suitable Content-Range header: >>> print(_get_content_range(None, 1, 4)) @@ -723,14 +733,20 @@ def _get_content_range(start, end, total): return "bytes %s-%s/%s" % (start, end, total) -def _int_or_none(val): +def _int_or_none(val: str) -> Optional[int]: val = val.strip() if val == "": return None return int(val) -def parse_body_arguments(content_type, body, arguments, files, headers=None): +def parse_body_arguments( + content_type: str, + body: bytes, + arguments: Dict[str, List[bytes]], + files: Dict[str, List[HTTPFile]], + headers: Optional[HTTPHeaders] = None, +) -> None: """Parses a form request body. Supports ``application/x-www-form-urlencoded`` and @@ -739,20 +755,27 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None): and ``files`` parameters are dictionaries that will be updated with the parsed contents. """ - if headers and 'Content-Encoding' in headers: - gen_log.warning("Unsupported Content-Encoding: %s", - headers['Content-Encoding']) - return if content_type.startswith("application/x-www-form-urlencoded"): + if headers and "Content-Encoding" in headers: + gen_log.warning( + "Unsupported Content-Encoding: %s", headers["Content-Encoding"] + ) + return try: - uri_arguments = parse_qs_bytes(native_str(body), keep_blank_values=True) + # real charset decoding will happen in RequestHandler.decode_argument() + uri_arguments = parse_qs_bytes(body, keep_blank_values=True) except Exception as e: - gen_log.warning('Invalid x-www-form-urlencoded body: %s', e) + gen_log.warning("Invalid x-www-form-urlencoded body: %s", e) uri_arguments = {} for name, values in uri_arguments.items(): if values: arguments.setdefault(name, []).extend(values) elif content_type.startswith("multipart/form-data"): + if headers and "Content-Encoding" in headers: + gen_log.warning( + "Unsupported Content-Encoding: %s", headers["Content-Encoding"] + ) + return try: fields = content_type.split(";") for field in fields: @@ -766,12 +789,22 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None): gen_log.warning("Invalid multipart/form-data: %s", e) -def parse_multipart_form_data(boundary, data, arguments, files): +def parse_multipart_form_data( + boundary: bytes, + data: bytes, + arguments: Dict[str, List[bytes]], + files: Dict[str, List[HTTPFile]], +) -> None: """Parses a ``multipart/form-data`` body. The ``boundary`` and ``data`` parameters are both byte strings. The dictionaries given in the arguments and files parameters will be updated with the contents of the body. + + .. versionchanged:: 5.1 + + Now recognizes non-ASCII filenames in RFC 2231/5987 + (``filename*=``) format. 
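A minimal sketch of `parse_multipart_form_data` in use, adapted from the pattern in Tornado's own test suite (the boundary and payload are illustrative):

```py
from tornado.httputil import parse_multipart_form_data

data = (
    b"--1234\r\n"
    b'Content-Disposition: form-data; name="files"; filename="ab.txt"\r\n'
    b"\r\n"
    b"Foo\r\n"
    b"--1234--\r\n"
)
args = {}   # type: dict
files = {}  # type: dict
parse_multipart_form_data(b"1234", data, args, files)
print(files["files"][0]["filename"])  # ab.txt
print(files["files"][0]["body"])      # b'Foo'
```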
""" # The standard allows for the boundary to be quoted in the header, # although it's rare (it happens at least for google app engine @@ -798,21 +831,25 @@ def parse_multipart_form_data(boundary, data, arguments, files): if disposition != "form-data" or not part.endswith(b"\r\n"): gen_log.warning("Invalid multipart/form-data") continue - value = part[eoh + 4:-2] + value = part[eoh + 4 : -2] if not disp_params.get("name"): gen_log.warning("multipart/form-data value missing name") continue name = disp_params["name"] if disp_params.get("filename"): ctype = headers.get("Content-Type", "application/unknown") - files.setdefault(name, []).append(HTTPFile( # type: ignore - filename=disp_params["filename"], body=value, - content_type=ctype)) + files.setdefault(name, []).append( + HTTPFile( + filename=disp_params["filename"], body=value, content_type=ctype + ) + ) else: arguments.setdefault(name, []).append(value) -def format_timestamp(ts): +def format_timestamp( + ts: Union[int, float, tuple, time.struct_time, datetime.datetime] +) -> str: """Formats a timestamp in the format used by HTTP. The argument may be a numeric timestamp as returned by `time.time`, @@ -822,22 +859,26 @@ def format_timestamp(ts): >>> format_timestamp(1359312200) 'Sun, 27 Jan 2013 18:43:20 GMT' """ - if isinstance(ts, numbers.Real): - pass + if isinstance(ts, (int, float)): + time_num = ts elif isinstance(ts, (tuple, time.struct_time)): - ts = calendar.timegm(ts) + time_num = calendar.timegm(ts) elif isinstance(ts, datetime.datetime): - ts = calendar.timegm(ts.utctimetuple()) + time_num = calendar.timegm(ts.utctimetuple()) else: raise TypeError("unknown timestamp type: %r" % ts) - return email.utils.formatdate(ts, usegmt=True) + return email.utils.formatdate(time_num, usegmt=True) RequestStartLine = collections.namedtuple( - 'RequestStartLine', ['method', 'path', 'version']) + "RequestStartLine", ["method", "path", "version"] +) -def parse_request_start_line(line): +_http_version_re = re.compile(r"^HTTP/1\.[0-9]$") + + +def parse_request_start_line(line: str) -> RequestStartLine: """Returns a (method, path, version) tuple for an HTTP 1.x request line. The response is a `collections.namedtuple`. @@ -851,17 +892,22 @@ def parse_request_start_line(line): # https://tools.ietf.org/html/rfc7230#section-3.1.1 # invalid request-line SHOULD respond with a 400 (Bad Request) raise HTTPInputError("Malformed HTTP request line") - if not re.match(r"^HTTP/1\.[0-9]$", version): + if not _http_version_re.match(version): raise HTTPInputError( - "Malformed HTTP version in HTTP Request-Line: %r" % version) + "Malformed HTTP version in HTTP Request-Line: %r" % version + ) return RequestStartLine(method, path, version) ResponseStartLine = collections.namedtuple( - 'ResponseStartLine', ['version', 'code', 'reason']) + "ResponseStartLine", ["version", "code", "reason"] +) + + +_http_response_line_re = re.compile(r"(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)") -def parse_response_start_line(line): +def parse_response_start_line(line: str) -> ResponseStartLine: """Returns a (version, code, reason) tuple for an HTTP 1.x response line. The response is a `collections.namedtuple`. 
@@ -870,25 +916,26 @@ def parse_response_start_line(line): ResponseStartLine(version='HTTP/1.1', code=200, reason='OK') """ line = native_str(line) - match = re.match("(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)", line) + match = _http_response_line_re.match(line) if not match: raise HTTPInputError("Error parsing response start line") - return ResponseStartLine(match.group(1), int(match.group(2)), - match.group(3)) + return ResponseStartLine(match.group(1), int(match.group(2)), match.group(3)) + # _parseparam and _parse_header are copied and modified from python2.7's cgi.py # The original 2.7 version of this code did not correctly support some # combinations of semicolons and double quotes. # It has also been modified to support valueless parameters as seen in -# websocket extension negotiations. +# websocket extension negotiations, and to support non-ascii values in +# RFC 2231/5987 format. -def _parseparam(s): - while s[:1] == ';': +def _parseparam(s: str) -> Generator[str, None, None]: + while s[:1] == ";": s = s[1:] - end = s.find(';') + end = s.find(";") while end > 0 and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2: - end = s.find(';', end + 1) + end = s.find(";", end + 1) if end < 0: end = len(s) f = s[:end] @@ -896,30 +943,42 @@ def _parseparam(s): s = s[end:] -def _parse_header(line): - """Parse a Content-type like header. +def _parse_header(line: str) -> Tuple[str, Dict[str, str]]: + r"""Parse a Content-type like header. Return the main content-type and a dictionary of options. + >>> d = "form-data; foo=\"b\\\\a\\\"r\"; file*=utf-8''T%C3%A4st" + >>> ct, d = _parse_header(d) + >>> ct + 'form-data' + >>> d['file'] == r'T\u00e4st'.encode('ascii').decode('unicode_escape') + True + >>> d['foo'] + 'b\\a"r' """ - parts = _parseparam(';' + line) + parts = _parseparam(";" + line) key = next(parts) - pdict = {} + # decode_params treats first argument special, but we already stripped key + params = [("Dummy", "value")] for p in parts: - i = p.find('=') + i = p.find("=") if i >= 0: name = p[:i].strip().lower() - value = p[i + 1:].strip() - if len(value) >= 2 and value[0] == value[-1] == '"': - value = value[1:-1] - value = value.replace('\\\\', '\\').replace('\\"', '"') - pdict[name] = value - else: - pdict[p] = None + value = p[i + 1 :].strip() + params.append((name, native_str(value))) + decoded_params = email.utils.decode_params(params) + decoded_params.pop(0) # get rid of the dummy again + pdict = {} + for name, decoded_value in decoded_params: + value = email.utils.collapse_rfc2231_value(decoded_value) + if len(value) >= 2 and value[0] == '"' and value[-1] == '"': + value = value[1:-1] + pdict[name] = value return key, pdict -def _encode_header(key, pdict): +def _encode_header(key: str, pdict: Dict[str, str]) -> str: """Inverse of _parse_header. >>> _encode_header('permessage-deflate', @@ -935,33 +994,54 @@ def _encode_header(key, pdict): out.append(k) else: # TODO: quote if necessary. - out.append('%s=%s' % (k, v)) - return '; '.join(out) + out.append("%s=%s" % (k, v)) + return "; ".join(out) + + +def encode_username_password( + username: Union[str, bytes], password: Union[str, bytes] +) -> bytes: + """Encodes a username/password pair in the format used by HTTP auth. + + The return value is a byte string in the form ``username:password``. + + .. 
versionadded:: 5.1 + """ + if isinstance(username, unicode_type): + username = unicodedata.normalize("NFC", username) + if isinstance(password, unicode_type): + password = unicodedata.normalize("NFC", password) + return utf8(username) + b":" + utf8(password) def doctests(): + # type: () -> unittest.TestSuite import doctest + return doctest.DocTestSuite() -def split_host_and_port(netloc): +_netloc_re = re.compile(r"^(.+):(\d+)$") + + +def split_host_and_port(netloc: str) -> Tuple[str, Optional[int]]: """Returns ``(host, port)`` tuple from ``netloc``. Returned ``port`` will be ``None`` if not present. .. versionadded:: 4.1 """ - match = re.match(r'^(.+):(\d+)$', netloc) + match = _netloc_re.match(netloc) if match: host = match.group(1) - port = int(match.group(2)) + port = int(match.group(2)) # type: Optional[int] else: host = netloc port = None return (host, port) -def qs_to_qsl(qs): +def qs_to_qsl(qs: Dict[str, List[AnyStr]]) -> Iterable[Tuple[str, AnyStr]]: """Generator converting a result of ``parse_qs`` back to name-value pairs. .. versionadded:: 5.0 @@ -973,10 +1053,10 @@ def qs_to_qsl(qs): _OctalPatt = re.compile(r"\\[0-3][0-7][0-7]") _QuotePatt = re.compile(r"[\\].") -_nulljoin = ''.join +_nulljoin = "".join -def _unquote_cookie(str): +def _unquote_cookie(s: str) -> str: """Handle double quotes and escaping in cookie values. This method is copied verbatim from the Python 3.5 standard @@ -985,29 +1065,29 @@ def _unquote_cookie(str): """ # If there aren't any doublequotes, # then there can't be any special characters. See RFC 2109. - if str is None or len(str) < 2: - return str - if str[0] != '"' or str[-1] != '"': - return str + if s is None or len(s) < 2: + return s + if s[0] != '"' or s[-1] != '"': + return s # We have to assume that we must decode this string. # Down to work. # Remove the "s - str = str[1:-1] + s = s[1:-1] # Check for special sequences. Examples: # \012 --> \n # \" --> " # i = 0 - n = len(str) + n = len(s) res = [] while 0 <= i < n: - o_match = _OctalPatt.search(str, i) - q_match = _QuotePatt.search(str, i) - if not o_match and not q_match: # Neither matched - res.append(str[i:]) + o_match = _OctalPatt.search(s, i) + q_match = _QuotePatt.search(s, i) + if not o_match and not q_match: # Neither matched + res.append(s[i:]) break # else: j = k = -1 @@ -1015,18 +1095,18 @@ def _unquote_cookie(str): j = o_match.start(0) if q_match: k = q_match.start(0) - if q_match and (not o_match or k < j): # QuotePatt matched - res.append(str[i:k]) - res.append(str[k + 1]) + if q_match and (not o_match or k < j): # QuotePatt matched + res.append(s[i:k]) + res.append(s[k + 1]) i = k + 2 - else: # OctalPatt matched - res.append(str[i:j]) - res.append(chr(int(str[j + 1:j + 4], 8))) + else: # OctalPatt matched + res.append(s[i:j]) + res.append(chr(int(s[j + 1 : j + 4], 8))) i = j + 4 return _nulljoin(res) -def parse_cookie(cookie): +def parse_cookie(cookie: str) -> Dict[str, str]: """Parse a ``Cookie`` HTTP header into a dict of name/value pairs. This function attempts to mimic browser cookie parsing behavior; @@ -1038,13 +1118,13 @@ def parse_cookie(cookie): .. 
versionadded:: 4.4.2 """ cookiedict = {} - for chunk in cookie.split(str(';')): - if str('=') in chunk: - key, val = chunk.split(str('='), 1) + for chunk in cookie.split(str(";")): + if str("=") in chunk: + key, val = chunk.split(str("="), 1) else: # Assume an empty name per # https://bugzilla.mozilla.org/show_bug.cgi?id=169091 - key, val = str(''), chunk + key, val = str(""), chunk key, val = key.strip(), val.strip() if key or val: # unquote using Python's algorithm. diff --git a/tornado/ioloop.py b/tornado/ioloop.py index 55623f34e4..2cf884450c 100644 --- a/tornado/ioloop.py +++ b/tornado/ioloop.py @@ -15,7 +15,11 @@ """An I/O event loop for non-blocking sockets. -On Python 3, `.IOLoop` is a wrapper around the `asyncio` event loop. +In Tornado 6.0, `.IOLoop` is a wrapper around the `asyncio` event +loop, with a slightly different interface for historical reasons. +Applications can use either the `.IOLoop` interface or the underlying +`asyncio` event loop directly (unless compatibility with older +versions of Tornado is desired, in which case `.IOLoop` must be used). Typical applications will use a single `IOLoop` object, accessed via `IOLoop.current` class method. The `IOLoop.start` method (or @@ -24,73 +28,58 @@ may use more than one `IOLoop`, such as one `IOLoop` per thread, or per `unittest` case. -In addition to I/O events, the `IOLoop` can also schedule time-based -events. `IOLoop.add_timeout` is a non-blocking alternative to -`time.sleep`. - """ -from __future__ import absolute_import, division, print_function - -import collections +import asyncio +import concurrent.futures import datetime -import errno import functools -import heapq -import itertools import logging import numbers import os -import select import sys -import threading import time -import traceback import math import random -from tornado.concurrent import Future, is_future, chain_future, future_set_exc_info, future_add_done_callback # noqa: E501 -from tornado.log import app_log, gen_log -from tornado.platform.auto import set_close_exec, Waker -from tornado import stack_context -from tornado.util import ( - PY3, Configurable, errno_from_exception, timedelta_to_seconds, - TimeoutError, unicode_type, import_object, +from tornado.concurrent import ( + Future, + is_future, + chain_future, + future_set_exc_info, + future_add_done_callback, ) +from tornado.log import app_log +from tornado.util import Configurable, TimeoutError, import_object -try: - import signal -except ImportError: - signal = None +import typing +from typing import Union, Any, Type, Optional, Callable, TypeVar, Tuple, Awaitable -try: - from concurrent.futures import ThreadPoolExecutor -except ImportError: - ThreadPoolExecutor = None +if typing.TYPE_CHECKING: + from typing import Dict, List # noqa: F401 -if PY3: - import _thread as thread + from typing_extensions import Protocol else: - import thread + Protocol = object -try: - import asyncio -except ImportError: - asyncio = None +class _Selectable(Protocol): + def fileno(self) -> int: + pass + + def close(self) -> None: + pass -_POLL_TIMEOUT = 3600.0 + +_T = TypeVar("_T") +_S = TypeVar("_S", bound=_Selectable) class IOLoop(Configurable): - """A level-triggered I/O loop. + """An I/O event loop. - On Python 3, `IOLoop` is a wrapper around the `asyncio` event - loop. On Python 2, it uses ``epoll`` (Linux) or ``kqueue`` (BSD - and Mac OS X) if they are available, or else we fall back on - select(). 
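For reference, `parse_cookie` above yields a plain dict, with quoting handled by `_unquote_cookie` (the cookie string is illustrative):

```py
from tornado.httputil import parse_cookie

parse_cookie('name="value"; session=abc123')
# {'name': 'value', 'session': 'abc123'}
```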
If you are implementing a system that needs to handle - thousands of simultaneous connections, you should use a system - that supports either ``epoll`` or ``kqueue``. + As of Tornado 6.0, `IOLoop` is a wrapper around the `asyncio` event + loop. Example usage for a simple TCP server: @@ -101,25 +90,22 @@ class IOLoop(Configurable): import socket import tornado.ioloop - from tornado import gen from tornado.iostream import IOStream - @gen.coroutine - def handle_connection(connection, address): + async def handle_connection(connection, address): stream = IOStream(connection) - message = yield stream.read_until_close() + message = await stream.read_until_close() print("message from client:", message.decode().strip()) def connection_ready(sock, fd, events): while True: try: connection, address = sock.accept() - except socket.error as e: - if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN): - raise + except BlockingIOError: return connection.setblocking(0) - handle_connection(connection, address) + io_loop = tornado.ioloop.IOLoop.current() + io_loop.spawn_callback(handle_connection, connection, address) if __name__ == '__main__': sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) @@ -165,42 +151,33 @@ def connection_ready(sock, fd, events): to redundantly specify the `asyncio` event loop. """ - # Constants from the epoll module - _EPOLLIN = 0x001 - _EPOLLPRI = 0x002 - _EPOLLOUT = 0x004 - _EPOLLERR = 0x008 - _EPOLLHUP = 0x010 - _EPOLLRDHUP = 0x2000 - _EPOLLONESHOT = (1 << 30) - _EPOLLET = (1 << 31) - - # Our events map exactly to the epoll events - NONE = 0 - READ = _EPOLLIN - WRITE = _EPOLLOUT - ERROR = _EPOLLERR | _EPOLLHUP - # In Python 2, _current.instance points to the current IOLoop. - _current = threading.local() + # These constants were originally based on constants from the epoll module. + NONE = 0 + READ = 0x001 + WRITE = 0x004 + ERROR = 0x018 # In Python 3, _ioloop_for_asyncio maps from asyncio loops to IOLoops. - _ioloop_for_asyncio = dict() + _ioloop_for_asyncio = dict() # type: Dict[asyncio.AbstractEventLoop, IOLoop] @classmethod - def configure(cls, impl, **kwargs): + def configure( + cls, impl: "Union[None, str, Type[Configurable]]", **kwargs: Any + ) -> None: if asyncio is not None: from tornado.platform.asyncio import BaseAsyncIOLoop - if isinstance(impl, (str, unicode_type)): + if isinstance(impl, str): impl = import_object(impl) - if not issubclass(impl, BaseAsyncIOLoop): + if isinstance(impl, type) and not issubclass(impl, BaseAsyncIOLoop): raise RuntimeError( - "only AsyncIOLoop is allowed when asyncio is available") + "only AsyncIOLoop is allowed when asyncio is available" + ) super(IOLoop, cls).configure(impl, **kwargs) @staticmethod - def instance(): + def instance() -> "IOLoop": """Deprecated alias for `IOLoop.current()`. .. versionchanged:: 5.0 @@ -221,7 +198,7 @@ def instance(): """ return IOLoop.current() - def install(self): + def install(self) -> None: """Deprecated alias for `make_current()`. .. versionchanged:: 5.0 @@ -236,7 +213,7 @@ def install(self): self.make_current() @staticmethod - def clear_instance(): + def clear_instance() -> None: """Deprecated alias for `clear_current()`. .. 
versionchanged:: 5.0 @@ -251,8 +228,18 @@ def clear_instance(): """ IOLoop.clear_current() + @typing.overload @staticmethod - def current(instance=True): + def current() -> "IOLoop": + pass + + @typing.overload + @staticmethod + def current(instance: bool = True) -> Optional["IOLoop"]: # noqa: F811 + pass + + @staticmethod + def current(instance: bool = True) -> Optional["IOLoop"]: # noqa: F811 """Returns the current thread's `IOLoop`. If an `IOLoop` is currently running or has been marked as @@ -272,30 +259,24 @@ def current(instance=True): since even if we do not create an `IOLoop`, this method may initialize the asyncio loop. """ - if asyncio is None: - current = getattr(IOLoop._current, "instance", None) - if current is None and instance: - current = IOLoop() - if IOLoop._current.instance is not current: - raise RuntimeError("new IOLoop did not become current") - else: - try: - loop = asyncio.get_event_loop() - except (RuntimeError, AssertionError): - if not instance: - return None - raise - try: - return IOLoop._ioloop_for_asyncio[loop] - except KeyError: - if instance: - from tornado.platform.asyncio import AsyncIOMainLoop - current = AsyncIOMainLoop(make_current=True) - else: - current = None + try: + loop = asyncio.get_event_loop() + except (RuntimeError, AssertionError): + if not instance: + return None + raise + try: + return IOLoop._ioloop_for_asyncio[loop] + except KeyError: + if instance: + from tornado.platform.asyncio import AsyncIOMainLoop + + current = AsyncIOMainLoop(make_current=True) # type: Optional[IOLoop] + else: + current = None return current - def make_current(self): + def make_current(self) -> None: """Makes this the `IOLoop` for the current thread. An `IOLoop` automatically becomes current for its thread @@ -312,14 +293,10 @@ def make_current(self): This method also sets the current `asyncio` event loop. """ # The asyncio event loops override this method. - assert asyncio is None - old = getattr(IOLoop._current, "instance", None) - if old is not None: - old.clear_current() - IOLoop._current.instance = self + raise NotImplementedError() @staticmethod - def clear_current(): + def clear_current() -> None: """Clears the `IOLoop` for the current thread. Intended primarily for use by test frameworks in between tests. @@ -333,7 +310,7 @@ def clear_current(): if asyncio is None: IOLoop._current.instance = None - def _clear_current_hook(self): + def _clear_current_hook(self) -> None: """Instance method called when an IOLoop ceases to be current. May be overridden by subclasses as a counterpart to make_current. @@ -341,17 +318,16 @@ def _clear_current_hook(self): pass @classmethod - def configurable_base(cls): + def configurable_base(cls) -> Type[Configurable]: return IOLoop @classmethod - def configurable_default(cls): - if asyncio is not None: - from tornado.platform.asyncio import AsyncIOLoop - return AsyncIOLoop - return PollIOLoop + def configurable_default(cls) -> Type[Configurable]: + from tornado.platform.asyncio import AsyncIOLoop + + return AsyncIOLoop - def initialize(self, make_current=None): + def initialize(self, make_current: Optional[bool] = None) -> None: if make_current is None: if IOLoop.current(instance=False) is None: self.make_current() @@ -362,7 +338,7 @@ def initialize(self, make_current=None): raise RuntimeError("current IOLoop already exists") self.make_current() - def close(self, all_fds=False): + def close(self, all_fds: bool = False) -> None: """Closes the `IOLoop`, freeing any resources used. 
If ``all_fds`` is true, all file descriptors registered on the @@ -389,13 +365,25 @@ def close(self, all_fds=False): """ raise NotImplementedError() - def add_handler(self, fd, handler, events): + @typing.overload + def add_handler( + self, fd: int, handler: Callable[[int, int], None], events: int + ) -> None: + pass + + @typing.overload # noqa: F811 + def add_handler( + self, fd: _S, handler: Callable[[_S, int], None], events: int + ) -> None: + pass + + def add_handler( # noqa: F811 + self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int + ) -> None: """Registers the given handler to receive the given events for ``fd``. The ``fd`` argument may either be an integer file descriptor or - a file-like object with a ``fileno()`` method (and optionally a - ``close()`` method, which may be called when the `IOLoop` is shut - down). + a file-like object with a ``fileno()`` and ``close()`` method. The ``events`` argument is a bitwise or of the constants ``IOLoop.READ``, ``IOLoop.WRITE``, and ``IOLoop.ERROR``. @@ -408,7 +396,7 @@ def add_handler(self, fd, handler, events): """ raise NotImplementedError() - def update_handler(self, fd, events): + def update_handler(self, fd: Union[int, _Selectable], events: int) -> None: """Changes the events we listen for ``fd``. .. versionchanged:: 4.0 @@ -417,7 +405,7 @@ def update_handler(self, fd, events): """ raise NotImplementedError() - def remove_handler(self, fd): + def remove_handler(self, fd: Union[int, _Selectable]) -> None: """Stop listening for events on ``fd``. .. versionchanged:: 4.0 @@ -426,55 +414,7 @@ def remove_handler(self, fd): """ raise NotImplementedError() - def set_blocking_signal_threshold(self, seconds, action): - """Sends a signal if the `IOLoop` is blocked for more than - ``s`` seconds. - - Pass ``seconds=None`` to disable. Requires Python 2.6 on a unixy - platform. - - The action parameter is a Python signal handler. Read the - documentation for the `signal` module for more information. - If ``action`` is None, the process will be killed if it is - blocked for too long. - - .. deprecated:: 5.0 - - Not implemented on the `asyncio` event loop. Use the environment - variable ``PYTHONASYNCIODEBUG=1`` instead. This method will be - removed in Tornado 6.0. - """ - raise NotImplementedError() - - def set_blocking_log_threshold(self, seconds): - """Logs a stack trace if the `IOLoop` is blocked for more than - ``s`` seconds. - - Equivalent to ``set_blocking_signal_threshold(seconds, - self.log_stack)`` - - .. deprecated:: 5.0 - - Not implemented on the `asyncio` event loop. Use the environment - variable ``PYTHONASYNCIODEBUG=1`` instead. This method will be - removed in Tornado 6.0. - """ - self.set_blocking_signal_threshold(seconds, self.log_stack) - - def log_stack(self, signal, frame): - """Signal handler to log the stack trace of the current thread. - - For use with `set_blocking_signal_threshold`. - - .. deprecated:: 5.1 - - This method will be removed in Tornado 6.0. - """ - gen_log.warning('IOLoop blocked for %f seconds in\n%s', - self._blocking_signal_threshold, - ''.join(traceback.format_stack(frame))) - - def start(self): + def start(self) -> None: """Starts the I/O loop. The loop will run until one of the callbacks calls `stop()`, which @@ -482,7 +422,7 @@ def start(self): """ raise NotImplementedError() - def _setup_logging(self): + def _setup_logging(self) -> None: """The IOLoop catches and logs exceptions, so it's important that log output be visible. 
However, python's default behavior for non-root loggers (prior to python @@ -493,28 +433,21 @@ def _setup_logging(self): This method should be called from start() in subclasses. """ - if not any([logging.getLogger().handlers, - logging.getLogger('tornado').handlers, - logging.getLogger('tornado.application').handlers]): + if not any( + [ + logging.getLogger().handlers, + logging.getLogger("tornado").handlers, + logging.getLogger("tornado.application").handlers, + ] + ): logging.basicConfig() - def stop(self): + def stop(self) -> None: """Stop the I/O loop. If the event loop is not currently running, the next call to `start()` will return immediately. - To use asynchronous methods from otherwise-synchronous code (such as - unit tests), you can start and stop the event loop like this:: - - ioloop = IOLoop() - async_method(ioloop=ioloop, callback=ioloop.stop) - ioloop.start() - - ``ioloop.start()`` will return after ``async_method`` has run - its callback, whether that callback was invoked before or - after ``ioloop.start``. - Note that even after `stop` has been called, the `IOLoop` is not completely stopped until `IOLoop.start` has also returned. Some work that was scheduled before the call to `stop` may still @@ -522,13 +455,13 @@ def stop(self): """ raise NotImplementedError() - def run_sync(self, func, timeout=None): + def run_sync(self, func: Callable, timeout: Optional[float] = None) -> Any: """Starts the `IOLoop`, runs the given function, and stops the loop. - The function must return either a yieldable object or - ``None``. If the function returns a yieldable object, the - `IOLoop` will run until the yieldable is resolved (and - `run_sync()` will return the yieldable's result). If it raises + The function must return either an awaitable object or + ``None``. If the function returns an awaitable object, the + `IOLoop` will run until the awaitable is resolved (and + `run_sync()` will return the awaitable's result). If it raises an exception, the `IOLoop` will stop and the exception will be re-raised to the caller. @@ -536,73 +469,87 @@ def run_sync(self, func, timeout=None): a maximum duration for the function. If the timeout expires, a `tornado.util.TimeoutError` is raised. - This method is useful in conjunction with `tornado.gen.coroutine` - to allow asynchronous calls in a ``main()`` function:: + This method is useful to allow asynchronous calls in a + ``main()`` function:: - @gen.coroutine - def main(): + async def main(): # do stuff... if __name__ == '__main__': IOLoop.current().run_sync(main) .. versionchanged:: 4.3 - Returning a non-``None``, non-yieldable value is now an error. + Returning a non-``None``, non-awaitable value is now an error. .. versionchanged:: 5.0 If a timeout occurs, the ``func`` coroutine will be cancelled. 
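A short sketch of the timeout behavior documented above (the slow coroutine is illustrative):

```py
import asyncio

from tornado.ioloop import IOLoop
from tornado.util import TimeoutError

async def slow():
    await asyncio.sleep(10)

try:
    IOLoop.current().run_sync(slow, timeout=0.1)
except TimeoutError:
    # The pending coroutine is cancelled, per the 5.0 change above.
    print("timed out")
```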
+ """ - future_cell = [None] + future_cell = [None] # type: List[Optional[Future]] - def run(): + def run() -> None: try: result = func() if result is not None: from tornado.gen import convert_yielded + result = convert_yielded(result) except Exception: - future_cell[0] = Future() - future_set_exc_info(future_cell[0], sys.exc_info()) + fut = Future() # type: Future[Any] + future_cell[0] = fut + future_set_exc_info(fut, sys.exc_info()) else: if is_future(result): future_cell[0] = result else: - future_cell[0] = Future() - future_cell[0].set_result(result) + fut = Future() + future_cell[0] = fut + fut.set_result(result) + assert future_cell[0] is not None self.add_future(future_cell[0], lambda future: self.stop()) + self.add_callback(run) if timeout is not None: - def timeout_callback(): + + def timeout_callback() -> None: # If we can cancel the future, do so and wait on it. If not, # Just stop the loop and return with the task still pending. # (If we neither cancel nor wait for the task, a warning # will be logged). + assert future_cell[0] is not None if not future_cell[0].cancel(): self.stop() + timeout_handle = self.add_timeout(self.time() + timeout, timeout_callback) self.start() if timeout is not None: self.remove_timeout(timeout_handle) + assert future_cell[0] is not None if future_cell[0].cancelled() or not future_cell[0].done(): - raise TimeoutError('Operation timed out after %s seconds' % timeout) + raise TimeoutError("Operation timed out after %s seconds" % timeout) return future_cell[0].result() - def time(self): + def time(self) -> float: """Returns the current time according to the `IOLoop`'s clock. The return value is a floating-point number relative to an unspecified time in the past. - By default, the `IOLoop`'s time function is `time.time`. However, - it may be configured to use e.g. `time.monotonic` instead. - Calls to `add_timeout` that pass a number instead of a - `datetime.timedelta` should use this function to compute the - appropriate time, so they can work no matter what time function - is chosen. + Historically, the IOLoop could be customized to use e.g. + `time.monotonic` instead of `time.time`, but this is not + currently supported and so this method is equivalent to + `time.time`. + """ return time.time() - def add_timeout(self, deadline, callback, *args, **kwargs): + def add_timeout( + self, + deadline: Union[float, datetime.timedelta], + callback: Callable[..., None], + *args: Any, + **kwargs: Any + ) -> object: """Runs the ``callback`` at the time ``deadline`` from the I/O loop. Returns an opaque handle that may be passed to @@ -631,12 +578,15 @@ def add_timeout(self, deadline, callback, *args, **kwargs): if isinstance(deadline, numbers.Real): return self.call_at(deadline, callback, *args, **kwargs) elif isinstance(deadline, datetime.timedelta): - return self.call_at(self.time() + timedelta_to_seconds(deadline), - callback, *args, **kwargs) + return self.call_at( + self.time() + deadline.total_seconds(), callback, *args, **kwargs + ) else: raise TypeError("Unsupported deadline %r" % deadline) - def call_later(self, delay, callback, *args, **kwargs): + def call_later( + self, delay: float, callback: Callable[..., None], *args: Any, **kwargs: Any + ) -> object: """Runs the ``callback`` after ``delay`` seconds have passed. 
Returns an opaque handle that may be passed to `remove_timeout` @@ -649,7 +599,9 @@ def call_later(self, delay, callback, *args, **kwargs): """ return self.call_at(self.time() + delay, callback, *args, **kwargs) - def call_at(self, when, callback, *args, **kwargs): + def call_at( + self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any + ) -> object: """Runs the ``callback`` at the absolute time designated by ``when``. ``when`` must be a number using the same reference point as @@ -665,7 +617,7 @@ def call_at(self, when, callback, *args, **kwargs): """ return self.add_timeout(when, callback, *args, **kwargs) - def remove_timeout(self, timeout): + def remove_timeout(self, timeout: object) -> None: """Cancels a pending timeout. The argument is a handle as returned by `add_timeout`. It is @@ -674,7 +626,7 @@ def remove_timeout(self, timeout): """ raise NotImplementedError() - def add_callback(self, callback, *args, **kwargs): + def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None: """Calls the given callback on the next I/O loop iteration. It is safe to call this method from any thread at any time, @@ -689,44 +641,66 @@ def add_callback(self, callback, *args, **kwargs): """ raise NotImplementedError() - def add_callback_from_signal(self, callback, *args, **kwargs): + def add_callback_from_signal( + self, callback: Callable, *args: Any, **kwargs: Any + ) -> None: """Calls the given callback on the next I/O loop iteration. Safe for use from a Python signal handler; should not be used otherwise. - - Callbacks added with this method will be run without any - `.stack_context`, to avoid picking up the context of the function - that was interrupted by the signal. """ raise NotImplementedError() - def spawn_callback(self, callback, *args, **kwargs): + def spawn_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None: """Calls the given callback on the next IOLoop iteration. - Unlike all other callback-related methods on IOLoop, - ``spawn_callback`` does not associate the callback with its caller's - ``stack_context``, so it is suitable for fire-and-forget callbacks - that should not interfere with the caller. + As of Tornado 6.0, this method is equivalent to `add_callback`. .. versionadded:: 4.0 """ - with stack_context.NullContext(): - self.add_callback(callback, *args, **kwargs) + self.add_callback(callback, *args, **kwargs) - def add_future(self, future, callback): + def add_future( + self, + future: "Union[Future[_T], concurrent.futures.Future[_T]]", + callback: Callable[["Future[_T]"], None], + ) -> None: """Schedules a callback on the ``IOLoop`` when the given `.Future` is finished. The callback is invoked with one argument, the `.Future`. - """ - assert is_future(future) - callback = stack_context.wrap(callback) - future_add_done_callback( - future, lambda future: self.add_callback(callback, future)) - def run_in_executor(self, executor, func, *args): + This method only accepts `.Future` objects and not other + awaitables (unlike most of Tornado where the two are + interchangeable). + """ + if isinstance(future, Future): + # Note that we specifically do not want the inline behavior of + # tornado.concurrent.future_add_done_callback. We always want + # this callback scheduled on the next IOLoop iteration (which + # asyncio.Future always does). + # + # Wrap the callback in self._run_callback so we control + # the error logging (i.e. it goes to tornado.log.app_log + # instead of asyncio's log). 
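Both deadline forms accepted by `add_timeout` above, as a sketch (the callback is illustrative):

```py
import datetime

from tornado.ioloop import IOLoop

io_loop = IOLoop.current()

def tick() -> None:
    print("tick")

# An absolute deadline on the IOLoop's clock (a number)...
io_loop.add_timeout(io_loop.time() + 1.0, tick)
# ...or a relative datetime.timedelta, converted via total_seconds().
handle = io_loop.add_timeout(datetime.timedelta(seconds=1), tick)
io_loop.remove_timeout(handle)  # handles are cancellable
```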
+ future.add_done_callback( + lambda f: self._run_callback(functools.partial(callback, future)) + ) + else: + assert is_future(future) + # For concurrent futures, we use self.add_callback, so + # it's fine if future_add_done_callback inlines that call. + future_add_done_callback( + future, lambda f: self.add_callback(callback, future) + ) + + def run_in_executor( + self, + executor: Optional[concurrent.futures.Executor], + func: Callable[..., _T], + *args: Any + ) -> Awaitable[_T]: """Runs a function in a ``concurrent.futures.Executor``. If ``executor`` is ``None``, the IO loop's default executor will be used. @@ -734,38 +708,40 @@ def run_in_executor(self, executor, func, *args): .. versionadded:: 5.0 """ - if ThreadPoolExecutor is None: - raise RuntimeError( - "concurrent.futures is required to use IOLoop.run_in_executor") - if executor is None: - if not hasattr(self, '_executor'): + if not hasattr(self, "_executor"): from tornado.process import cpu_count - self._executor = ThreadPoolExecutor(max_workers=(cpu_count() * 5)) + + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=(cpu_count() * 5) + ) # type: concurrent.futures.Executor executor = self._executor c_future = executor.submit(func, *args) # Concurrent Futures are not usable with await. Wrap this in a # Tornado Future instead, using self.add_future for thread-safety. - t_future = Future() + t_future = Future() # type: Future[_T] self.add_future(c_future, lambda f: chain_future(f, t_future)) return t_future - def set_default_executor(self, executor): + def set_default_executor(self, executor: concurrent.futures.Executor) -> None: """Sets the default executor to use with :meth:`run_in_executor`. .. versionadded:: 5.0 """ self._executor = executor - def _run_callback(self, callback): + def _run_callback(self, callback: Callable[[], Any]) -> None: """Runs a callback with error handling. - For use in subclasses. + .. versionchanged:: 6.0 + + CancelledErrors are no longer logged. """ try: ret = callback() if ret is not None: from tornado import gen + # Functions that return Futures typically swallow all # exceptions and store them in the Future. If a Future # makes it out to the IOLoop, ensure its exception (if any) @@ -779,395 +755,84 @@ def _run_callback(self, callback): pass else: self.add_future(ret, self._discard_future_result) + except asyncio.CancelledError: + pass except Exception: - self.handle_callback_exception(callback) + app_log.error("Exception in callback %r", callback, exc_info=True) - def _discard_future_result(self, future): + def _discard_future_result(self, future: Future) -> None: """Avoid unhandled-exception warnings from spawned coroutines.""" future.result() - def handle_callback_exception(self, callback): - """This method is called whenever a callback run by the `IOLoop` - throws an exception. - - By default simply logs the exception as an error. Subclasses - may override this method to customize reporting of exceptions. - - The exception itself is not passed explicitly, but is available - in `sys.exc_info`. - - .. versionchanged:: 5.0 - - When the `asyncio` event loop is used (which is now the - default on Python 3), some callback errors will be handled by - `asyncio` instead of this method. - - .. deprecated: 5.1 - - Support for this method will be removed in Tornado 6.0. - """ - app_log.error("Exception in callback %r", callback, exc_info=True) - - def split_fd(self, fd): - """Returns an (fd, obj) pair from an ``fd`` parameter. 
- - We accept both raw file descriptors and file-like objects as - input to `add_handler` and related methods. When a file-like - object is passed, we must retain the object itself so we can - close it correctly when the `IOLoop` shuts down, but the - poller interfaces favor file descriptors (they will accept - file-like objects and call ``fileno()`` for you, but they - always return the descriptor itself). - - This method is provided for use by `IOLoop` subclasses and should - not generally be used by application code. - - .. versionadded:: 4.0 - """ - try: - return fd.fileno(), fd - except AttributeError: + def split_fd( + self, fd: Union[int, _Selectable] + ) -> Tuple[int, Union[int, _Selectable]]: + # """Returns an (fd, obj) pair from an ``fd`` parameter. + + # We accept both raw file descriptors and file-like objects as + # input to `add_handler` and related methods. When a file-like + # object is passed, we must retain the object itself so we can + # close it correctly when the `IOLoop` shuts down, but the + # poller interfaces favor file descriptors (they will accept + # file-like objects and call ``fileno()`` for you, but they + # always return the descriptor itself). + + # This method is provided for use by `IOLoop` subclasses and should + # not generally be used by application code. + + # .. versionadded:: 4.0 + # """ + if isinstance(fd, int): return fd, fd + return fd.fileno(), fd - def close_fd(self, fd): - """Utility method to close an ``fd``. + def close_fd(self, fd: Union[int, _Selectable]) -> None: + # """Utility method to close an ``fd``. - If ``fd`` is a file-like object, we close it directly; otherwise - we use `os.close`. + # If ``fd`` is a file-like object, we close it directly; otherwise + # we use `os.close`. - This method is provided for use by `IOLoop` subclasses (in - implementations of ``IOLoop.close(all_fds=True)`` and should - not generally be used by application code. + # This method is provided for use by `IOLoop` subclasses (in + # implementations of ``IOLoop.close(all_fds=True)`` and should + # not generally be used by application code. - .. versionadded:: 4.0 - """ + # .. versionadded:: 4.0 + # """ try: - try: - fd.close() - except AttributeError: + if isinstance(fd, int): os.close(fd) + else: + fd.close() except OSError: pass -class PollIOLoop(IOLoop): - """Base class for IOLoops built around a select-like function. - - For concrete implementations, see `tornado.platform.epoll.EPollIOLoop` - (Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or - `tornado.platform.select.SelectIOLoop` (all platforms). 
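A sketch of `run_in_executor` above, which bridges a blocking call onto the loop's default thread pool (the blocking function is illustrative):

```py
import time

from tornado.ioloop import IOLoop

def blocking_work(n: int) -> int:
    time.sleep(0.1)  # stands in for blocking I/O
    return n * 2

async def main() -> None:
    # Passing None selects the IOLoop's default ThreadPoolExecutor.
    result = await IOLoop.current().run_in_executor(None, blocking_work, 21)
    print(result)  # 42

IOLoop.current().run_sync(main)
```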
- """ - def initialize(self, impl, time_func=None, **kwargs): - super(PollIOLoop, self).initialize(**kwargs) - self._impl = impl - if hasattr(self._impl, 'fileno'): - set_close_exec(self._impl.fileno()) - self.time_func = time_func or time.time - self._handlers = {} - self._events = {} - self._callbacks = collections.deque() - self._timeouts = [] - self._cancellations = 0 - self._running = False - self._stopped = False - self._closing = False - self._thread_ident = None - self._pid = os.getpid() - self._blocking_signal_threshold = None - self._timeout_counter = itertools.count() - - # Create a pipe that we send bogus data to when we want to wake - # the I/O loop when it is idle - self._waker = Waker() - self.add_handler(self._waker.fileno(), - lambda fd, events: self._waker.consume(), - self.READ) - - @classmethod - def configurable_base(cls): - return PollIOLoop - - @classmethod - def configurable_default(cls): - if hasattr(select, "epoll"): - from tornado.platform.epoll import EPollIOLoop - return EPollIOLoop - if hasattr(select, "kqueue"): - # Python 2.6+ on BSD or Mac - from tornado.platform.kqueue import KQueueIOLoop - return KQueueIOLoop - from tornado.platform.select import SelectIOLoop - return SelectIOLoop - - def close(self, all_fds=False): - self._closing = True - self.remove_handler(self._waker.fileno()) - if all_fds: - for fd, handler in list(self._handlers.values()): - self.close_fd(fd) - self._waker.close() - self._impl.close() - self._callbacks = None - self._timeouts = None - if hasattr(self, '_executor'): - self._executor.shutdown() - - def add_handler(self, fd, handler, events): - fd, obj = self.split_fd(fd) - self._handlers[fd] = (obj, stack_context.wrap(handler)) - self._impl.register(fd, events | self.ERROR) - - def update_handler(self, fd, events): - fd, obj = self.split_fd(fd) - self._impl.modify(fd, events | self.ERROR) - - def remove_handler(self, fd): - fd, obj = self.split_fd(fd) - self._handlers.pop(fd, None) - self._events.pop(fd, None) - try: - self._impl.unregister(fd) - except Exception: - gen_log.debug("Error deleting fd from IOLoop", exc_info=True) - - def set_blocking_signal_threshold(self, seconds, action): - if not hasattr(signal, "setitimer"): - gen_log.error("set_blocking_signal_threshold requires a signal module " - "with the setitimer method") - return - self._blocking_signal_threshold = seconds - if seconds is not None: - signal.signal(signal.SIGALRM, - action if action is not None else signal.SIG_DFL) - - def start(self): - if self._running: - raise RuntimeError("IOLoop is already running") - if os.getpid() != self._pid: - raise RuntimeError("Cannot share PollIOLoops across processes") - self._setup_logging() - if self._stopped: - self._stopped = False - return - old_current = IOLoop.current(instance=False) - if old_current is not self: - self.make_current() - self._thread_ident = thread.get_ident() - self._running = True - - # signal.set_wakeup_fd closes a race condition in event loops: - # a signal may arrive at the beginning of select/poll/etc - # before it goes into its interruptible sleep, so the signal - # will be consumed without waking the select. The solution is - # for the (C, synchronous) signal handler to write to a pipe, - # which will then be seen by select. - # - # In python's signal handling semantics, this only matters on the - # main thread (fortunately, set_wakeup_fd only works on the main - # thread and will raise a ValueError otherwise). - # - # If someone has already set a wakeup fd, we don't want to - # disturb it. 
This is an issue for twisted, which does its - # SIGCHLD processing in response to its own wakeup fd being - # written to. As long as the wakeup fd is registered on the IOLoop, - # the loop will still wake up and everything should work. - old_wakeup_fd = None - if hasattr(signal, 'set_wakeup_fd') and os.name == 'posix': - # requires python 2.6+, unix. set_wakeup_fd exists but crashes - # the python process on windows. - try: - old_wakeup_fd = signal.set_wakeup_fd(self._waker.write_fileno()) - if old_wakeup_fd != -1: - # Already set, restore previous value. This is a little racy, - # but there's no clean get_wakeup_fd and in real use the - # IOLoop is just started once at the beginning. - signal.set_wakeup_fd(old_wakeup_fd) - old_wakeup_fd = None - except ValueError: - # Non-main thread, or the previous value of wakeup_fd - # is no longer valid. - old_wakeup_fd = None - - try: - while True: - # Prevent IO event starvation by delaying new callbacks - # to the next iteration of the event loop. - ncallbacks = len(self._callbacks) - - # Add any timeouts that have come due to the callback list. - # Do not run anything until we have determined which ones - # are ready, so timeouts that call add_timeout cannot - # schedule anything in this iteration. - due_timeouts = [] - if self._timeouts: - now = self.time() - while self._timeouts: - if self._timeouts[0].callback is None: - # The timeout was cancelled. Note that the - # cancellation check is repeated below for timeouts - # that are cancelled by another timeout or callback. - heapq.heappop(self._timeouts) - self._cancellations -= 1 - elif self._timeouts[0].deadline <= now: - due_timeouts.append(heapq.heappop(self._timeouts)) - else: - break - if (self._cancellations > 512 and - self._cancellations > (len(self._timeouts) >> 1)): - # Clean up the timeout queue when it gets large and it's - # more than half cancellations. - self._cancellations = 0 - self._timeouts = [x for x in self._timeouts - if x.callback is not None] - heapq.heapify(self._timeouts) - - for i in range(ncallbacks): - self._run_callback(self._callbacks.popleft()) - for timeout in due_timeouts: - if timeout.callback is not None: - self._run_callback(timeout.callback) - # Closures may be holding on to a lot of memory, so allow - # them to be freed before we go into our poll wait. - due_timeouts = timeout = None - - if self._callbacks: - # If any callbacks or timeouts called add_callback, - # we don't want to wait in poll() before we run them. - poll_timeout = 0.0 - elif self._timeouts: - # If there are any timeouts, schedule the first one. - # Use self.time() instead of 'now' to account for time - # spent running callbacks. - poll_timeout = self._timeouts[0].deadline - self.time() - poll_timeout = max(0, min(poll_timeout, _POLL_TIMEOUT)) - else: - # No timeouts and no callbacks, so use the default. - poll_timeout = _POLL_TIMEOUT - - if not self._running: - break - - if self._blocking_signal_threshold is not None: - # clear alarm so it doesn't fire while poll is waiting for - # events. 
- signal.setitimer(signal.ITIMER_REAL, 0, 0) - - try: - event_pairs = self._impl.poll(poll_timeout) - except Exception as e: - # Depending on python version and IOLoop implementation, - # different exception types may be thrown and there are - # two ways EINTR might be signaled: - # * e.errno == errno.EINTR - # * e.args is like (errno.EINTR, 'Interrupted system call') - if errno_from_exception(e) == errno.EINTR: - continue - else: - raise - - if self._blocking_signal_threshold is not None: - signal.setitimer(signal.ITIMER_REAL, - self._blocking_signal_threshold, 0) - - # Pop one fd at a time from the set of pending fds and run - # its handler. Since that handler may perform actions on - # other file descriptors, there may be reentrant calls to - # this IOLoop that modify self._events - self._events.update(event_pairs) - while self._events: - fd, events = self._events.popitem() - try: - fd_obj, handler_func = self._handlers[fd] - handler_func(fd_obj, events) - except (OSError, IOError) as e: - if errno_from_exception(e) == errno.EPIPE: - # Happens when the client closes the connection - pass - else: - self.handle_callback_exception(self._handlers.get(fd)) - except Exception: - self.handle_callback_exception(self._handlers.get(fd)) - fd_obj = handler_func = None - - finally: - # reset the stopped flag so another start/stop pair can be issued - self._stopped = False - if self._blocking_signal_threshold is not None: - signal.setitimer(signal.ITIMER_REAL, 0, 0) - if old_current is None: - IOLoop.clear_current() - elif old_current is not self: - old_current.make_current() - if old_wakeup_fd is not None: - signal.set_wakeup_fd(old_wakeup_fd) - - def stop(self): - self._running = False - self._stopped = True - self._waker.wake() - - def time(self): - return self.time_func() - - def call_at(self, deadline, callback, *args, **kwargs): - timeout = _Timeout( - deadline, - functools.partial(stack_context.wrap(callback), *args, **kwargs), - self) - heapq.heappush(self._timeouts, timeout) - return timeout - - def remove_timeout(self, timeout): - # Removing from a heap is complicated, so just leave the defunct - # timeout object in the queue (see discussion in - # http://docs.python.org/library/heapq.html). - # If this turns out to be a problem, we could add a garbage - # collection pass whenever there are too many dead timeouts. - timeout.callback = None - self._cancellations += 1 - - def add_callback(self, callback, *args, **kwargs): - if self._closing: - return - # Blindly insert into self._callbacks. This is safe even - # from signal handlers because deque.append is atomic. - self._callbacks.append(functools.partial( - stack_context.wrap(callback), *args, **kwargs)) - if thread.get_ident() != self._thread_ident: - # This will write one byte but Waker.consume() reads many - # at once, so it's ok to write even when not strictly - # necessary. - self._waker.wake() - else: - # If we're on the IOLoop's thread, we don't need to wake anyone. 
- pass - - def add_callback_from_signal(self, callback, *args, **kwargs): - with stack_context.NullContext(): - self.add_callback(callback, *args, **kwargs) - - class _Timeout(object): """An IOLoop timeout, a UNIX timestamp and a callback""" # Reduce memory overhead when there are lots of pending callbacks - __slots__ = ['deadline', 'callback', 'tdeadline'] + __slots__ = ["deadline", "callback", "tdeadline"] - def __init__(self, deadline, callback, io_loop): + def __init__( + self, deadline: float, callback: Callable[[], None], io_loop: IOLoop + ) -> None: if not isinstance(deadline, numbers.Real): raise TypeError("Unsupported deadline %r" % deadline) self.deadline = deadline self.callback = callback - self.tdeadline = (deadline, next(io_loop._timeout_counter)) + self.tdeadline = ( + deadline, + next(io_loop._timeout_counter), + ) # type: Tuple[float, int] # Comparison methods to sort by deadline, with object id as a tiebreaker # to guarantee a consistent ordering. The heapq module uses __le__ # in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons # use __lt__). - def __lt__(self, other): + def __lt__(self, other: "_Timeout") -> bool: return self.tdeadline < other.tdeadline - def __le__(self, other): + def __le__(self, other: "_Timeout") -> bool: return self.tdeadline <= other.tdeadline @@ -1197,16 +862,19 @@ class PeriodicCallback(object): .. versionchanged:: 5.1 The ``jitter`` argument is added. """ - def __init__(self, callback, callback_time, jitter=0): + + def __init__( + self, callback: Callable[[], None], callback_time: float, jitter: float = 0 + ) -> None: self.callback = callback if callback_time <= 0: raise ValueError("Periodic callback must have a positive callback_time") self.callback_time = callback_time self.jitter = jitter self._running = False - self._timeout = None + self._timeout = None # type: object - def start(self): + def start(self) -> None: """Starts the timer.""" # Looking up the IOLoop here allows to first instantiate the # PeriodicCallback in another thread, then start it using @@ -1216,36 +884,36 @@ def start(self): self._next_timeout = self.io_loop.time() self._schedule_next() - def stop(self): + def stop(self) -> None: """Stops the timer.""" self._running = False if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = None - def is_running(self): - """Return True if this `.PeriodicCallback` has been started. + def is_running(self) -> bool: + """Returns ``True`` if this `.PeriodicCallback` has been started. .. versionadded:: 4.1 """ return self._running - def _run(self): + def _run(self) -> None: if not self._running: return try: return self.callback() except Exception: - self.io_loop.handle_callback_exception(self.callback) + app_log.error("Exception in callback %r", self.callback, exc_info=True) finally: self._schedule_next() - def _schedule_next(self): + def _schedule_next(self) -> None: if self._running: self._update_next(self.io_loop.time()) self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run) - def _update_next(self, current_time): + def _update_next(self, current_time: float) -> None: callback_time_sec = self.callback_time / 1000.0 if self.jitter: # apply jitter fraction @@ -1255,8 +923,9 @@ def _update_next(self, current_time): # to the start of the next. If one call takes too long, # skip cycles to get back to a multiple of the original # schedule. 
- self._next_timeout += (math.floor((current_time - self._next_timeout) / - callback_time_sec) + 1) * callback_time_sec + self._next_timeout += ( + math.floor((current_time - self._next_timeout) / callback_time_sec) + 1 + ) * callback_time_sec else: # If the clock moved backwards, ensure we advance the next # timeout instead of recomputing the same value again. diff --git a/tornado/iostream.py b/tornado/iostream.py index 20f481d2ac..86235f4dc3 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -23,53 +23,54 @@ * `PipeIOStream`: Pipe-based IOStream implementation. """ -from __future__ import absolute_import, division, print_function - +import asyncio import collections import errno import io import numbers import os import socket +import ssl import sys import re -import warnings -from tornado.concurrent import Future +from tornado.concurrent import Future, future_set_result_unless_cancelled from tornado import ioloop -from tornado.log import gen_log, app_log +from tornado.log import gen_log from tornado.netutil import ssl_wrap_socket, _client_ssl_defaults, _server_ssl_defaults -from tornado import stack_context from tornado.util import errno_from_exception -try: - from tornado.platform.posix import _set_nonblocking -except ImportError: - _set_nonblocking = None - -try: - import ssl -except ImportError: - # ssl is not available on Google App Engine - ssl = None - -# These errnos indicate that a non-blocking operation must be retried -# at a later time. On most platforms they're the same value, but on -# some they differ. -_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN) - -if hasattr(errno, "WSAEWOULDBLOCK"): - _ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) # type: ignore +import typing +from typing import ( + Union, + Optional, + Awaitable, + Callable, + Pattern, + Any, + Dict, + TypeVar, + Tuple, +) +from types import TracebackType + +if typing.TYPE_CHECKING: + from typing import Deque, List, Type # noqa: F401 + +_IOStreamType = TypeVar("_IOStreamType", bound="IOStream") # These errnos indicate that a connection has been abruptly terminated. # They should be caught and handled less noisily than other errors. -_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, - errno.ETIMEDOUT) +_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, errno.ETIMEDOUT) if hasattr(errno, "WSAECONNRESET"): - _ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT) # type: ignore # noqa: E501 + _ERRNO_CONNRESET += ( # type: ignore + errno.WSAECONNRESET, # type: ignore + errno.WSAECONNABORTED, # type: ignore + errno.WSAETIMEDOUT, # type: ignore + ) -if sys.platform == 'darwin': +if sys.platform == "darwin": # OSX appears to have a race condition that causes send(2) to return # EPROTOTYPE if called while a socket is being torn down: # http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/ @@ -77,13 +78,7 @@ # instead of an unexpected error. _ERRNO_CONNRESET += (errno.EPROTOTYPE,) # type: ignore -# More non-portable errnos: -_ERRNO_INPROGRESS = (errno.EINPROGRESS,) - -if hasattr(errno, "WSAEINPROGRESS"): - _ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,) # type: ignore - -_WINDOWS = sys.platform.startswith('win') +_WINDOWS = sys.platform.startswith("win") class StreamClosedError(IOError): @@ -99,8 +94,9 @@ class StreamClosedError(IOError): .. versionchanged:: 4.3 Added the ``real_error`` attribute. 
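Editor's note: the reflowed `_update_next` arithmetic in `PeriodicCallback` above is easier to check in isolation. A sketch of the no-jitter branch; the helper name `next_deadline` is ours, not Tornado's.

```py
import math


def next_deadline(next_timeout, current_time, callback_time_ms):
    # Mirrors PeriodicCallback._update_next (no-jitter branch): advance by
    # whole multiples of the interval, so a slow callback skips cycles
    # instead of firing in a burst to catch up.
    interval = callback_time_ms / 1000.0
    return next_timeout + (
        math.floor((current_time - next_timeout) / interval) + 1
    ) * interval


# Scheduled every 100ms starting at t=0; the callback ran long and the
# clock is now at t=0.35, so the next run lands on 0.4, skipping 0.1-0.3.
assert next_deadline(0.0, 0.35, 100) == 0.4
```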
""" - def __init__(self, real_error=None): - super(StreamClosedError, self).__init__('Stream is closed') + + def __init__(self, real_error: Optional[BaseException] = None) -> None: + super().__init__("Stream is closed") self.real_error = real_error @@ -110,12 +106,12 @@ class UnsatisfiableReadError(Exception): Raised by ``read_until`` and ``read_until_regex`` with a ``max_bytes`` argument. """ + pass class StreamBufferFullError(Exception): - """Exception raised by `IOStream` methods when the buffer is full. - """ + """Exception raised by `IOStream` methods when the buffer is full.""" class _StreamBuffer(object): @@ -124,21 +120,23 @@ class _StreamBuffer(object): of data are encountered. """ - def __init__(self): + def __init__(self) -> None: # A sequence of (False, bytearray) and (True, memoryview) objects - self._buffers = collections.deque() + self._buffers = ( + collections.deque() + ) # type: Deque[Tuple[bool, Union[bytearray, memoryview]]] # Position in the first buffer self._first_pos = 0 self._size = 0 - def __len__(self): + def __len__(self) -> int: return self._size # Data above this size will be appended separately instead # of extending an existing bytearray _large_buf_threshold = 2048 - def append(self, data): + def append(self, data: Union[bytes, bytearray, memoryview]) -> None: """ Append the given piece of data (should be a buffer-compatible object). """ @@ -156,11 +154,11 @@ def append(self, data): if new_buf: self._buffers.append((False, bytearray(data))) else: - b += data + b += data # type: ignore self._size += size - def peek(self, size): + def peek(self, size: int) -> memoryview: """ Get a view over at most ``size`` bytes (possibly fewer) at the current buffer position. @@ -169,15 +167,15 @@ def peek(self, size): try: is_memview, b = self._buffers[0] except IndexError: - return memoryview(b'') + return memoryview(b"") pos = self._first_pos if is_memview: - return b[pos:pos + size] + return typing.cast(memoryview, b[pos : pos + size]) else: - return memoryview(b)[pos:pos + size] + return memoryview(b)[pos : pos + size] - def advance(self, size): + def advance(self, size: int) -> None: """ Advance the current buffer position by ``size`` bytes. """ @@ -200,7 +198,7 @@ def advance(self, size): # Amortized O(1) shrink for Python 2 pos += size if len(b) <= 2 * pos: - del b[:pos] + del typing.cast(bytearray, b)[:pos] pos = 0 size = 0 @@ -211,23 +209,27 @@ def advance(self, size): class BaseIOStream(object): """A utility class to write to and read from a non-blocking file or socket. - We support a non-blocking ``write()`` and a family of ``read_*()`` methods. - All of the methods take an optional ``callback`` argument and return a - `.Future` only if no callback is given. When the operation completes, - the callback will be run or the `.Future` will resolve with the data - read (or ``None`` for ``write()``). All outstanding ``Futures`` will - resolve with a `StreamClosedError` when the stream is closed; users - of the callback interface will be notified via - `.BaseIOStream.set_close_callback` instead. + We support a non-blocking ``write()`` and a family of ``read_*()`` + methods. When the operation completes, the ``Awaitable`` will resolve + with the data read (or ``None`` for ``write()``). All outstanding + ``Awaitables`` will resolve with a `StreamClosedError` when the + stream is closed; `.BaseIOStream.set_close_callback` can also be used + to be notified of a closed stream. 
When a stream is closed due to an error, the IOStream's ``error`` attribute contains the exception object. Subclasses must implement `fileno`, `close_fd`, `write_to_fd`, `read_from_fd`, and optionally `get_fd_error`. + """ - def __init__(self, max_buffer_size=None, - read_chunk_size=None, max_write_buffer_size=None): + + def __init__( + self, + max_buffer_size: Optional[int] = None, + read_chunk_size: Optional[int] = None, + max_write_buffer_size: Optional[int] = None, + ) -> None: """`BaseIOStream` constructor. :arg max_buffer_size: Maximum amount of incoming data to buffer; @@ -248,47 +250,43 @@ def __init__(self, max_buffer_size=None, self.max_buffer_size = max_buffer_size or 104857600 # A chunk size that is too close to max_buffer_size can cause # spurious failures. - self.read_chunk_size = min(read_chunk_size or 65536, - self.max_buffer_size // 2) + self.read_chunk_size = min(read_chunk_size or 65536, self.max_buffer_size // 2) self.max_write_buffer_size = max_write_buffer_size - self.error = None + self.error = None # type: Optional[BaseException] self._read_buffer = bytearray() self._read_buffer_pos = 0 self._read_buffer_size = 0 self._user_read_buffer = False - self._after_user_read_buffer = None + self._after_user_read_buffer = None # type: Optional[bytearray] self._write_buffer = _StreamBuffer() self._total_write_index = 0 self._total_write_done_index = 0 - self._read_delimiter = None - self._read_regex = None - self._read_max_bytes = None - self._read_bytes = None + self._read_delimiter = None # type: Optional[bytes] + self._read_regex = None # type: Optional[Pattern] + self._read_max_bytes = None # type: Optional[int] + self._read_bytes = None # type: Optional[int] self._read_partial = False self._read_until_close = False - self._read_callback = None - self._read_future = None - self._streaming_callback = None - self._write_callback = None - self._write_futures = collections.deque() - self._close_callback = None - self._connect_callback = None - self._connect_future = None + self._read_future = None # type: Optional[Future] + self._write_futures = ( + collections.deque() + ) # type: Deque[Tuple[int, Future[None]]] + self._close_callback = None # type: Optional[Callable[[], None]] + self._connect_future = None # type: Optional[Future[IOStream]] # _ssl_connect_future should be defined in SSLIOStream - # but it's here so we can clean it up in maybe_run_close_callback. + # but it's here so we can clean it up in _signal_closed # TODO: refactor that so subclasses can add additional futures # to be cancelled. - self._ssl_connect_future = None + self._ssl_connect_future = None # type: Optional[Future[SSLIOStream]] self._connecting = False - self._state = None - self._pending_callbacks = 0 + self._state = None # type: Optional[int] self._closed = False - def fileno(self): + def fileno(self) -> Union[int, ioloop._Selectable]: """Returns the file descriptor for this stream.""" raise NotImplementedError() - def close_fd(self): + def close_fd(self) -> None: """Closes the file underlying this stream. ``close_fd`` is called by `BaseIOStream` and should not be called @@ -296,14 +294,14 @@ def close_fd(self): """ raise NotImplementedError() - def write_to_fd(self, data): + def write_to_fd(self, data: memoryview) -> int: """Attempts to write ``data`` to the underlying file. Returns the number of bytes written. """ raise NotImplementedError() - def read_from_fd(self, buf): + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: """Attempts to read from the underlying file. 
Reads up to ``len(buf)`` bytes, storing them in the buffer. @@ -318,7 +316,7 @@ def read_from_fd(self, buf): """ raise NotImplementedError() - def get_fd_error(self): + def get_fd_error(self) -> Optional[Exception]: """Returns information about any error on the underlying file. This method is called after the `.IOLoop` has signaled an error on the @@ -328,13 +326,13 @@ def get_fd_error(self): """ return None - def read_until_regex(self, regex, callback=None, max_bytes=None): + def read_until_regex( + self, regex: bytes, max_bytes: Optional[int] = None + ) -> Awaitable[bytes]: """Asynchronously read until we have matched the given regex. The result includes the data that matches the regex and anything - that came before it. If a callback is given, it will be run - with the data as an argument; if not, this method returns a - `.Future`. + that came before it. If ``max_bytes`` is not None, the connection will be closed if more than ``max_bytes`` bytes have been read and the regex is @@ -344,13 +342,13 @@ def read_until_regex(self, regex, callback=None, max_bytes=None): Added the ``max_bytes`` argument. The ``callback`` argument is now optional and a `.Future` will be returned if it is omitted. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed - in Tornado 6.0. Use the returned `.Future` instead. + The ``callback`` argument was removed. Use the returned + `.Future` instead. """ - future = self._set_read_callback(callback) + future = self._start_read() self._read_regex = re.compile(regex) self._read_max_bytes = max_bytes try: @@ -361,19 +359,18 @@ def read_until_regex(self, regex, callback=None, max_bytes=None): self.close(exc_info=e) return future except: - if future is not None: - # Ensure that the future doesn't log an error because its - # failure was never examined. - future.add_done_callback(lambda f: f.exception()) + # Ensure that the future doesn't log an error because its + # failure was never examined. + future.add_done_callback(lambda f: f.exception()) raise return future - def read_until(self, delimiter, callback=None, max_bytes=None): + def read_until( + self, delimiter: bytes, max_bytes: Optional[int] = None + ) -> Awaitable[bytes]: """Asynchronously read until we have found the given delimiter. The result includes all the data read including the delimiter. - If a callback is given, it will be run with the data as an argument; - if not, this method returns a `.Future`. If ``max_bytes`` is not None, the connection will be closed if more than ``max_bytes`` bytes have been read and the delimiter @@ -383,12 +380,12 @@ def read_until(self, delimiter, callback=None, max_bytes=None): Added the ``max_bytes`` argument. The ``callback`` argument is now optional and a `.Future` will be returned if it is omitted. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed - in Tornado 6.0. Use the returned `.Future` instead. + The ``callback`` argument was removed. Use the returned + `.Future` instead. 
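Editor's note: with the callback gone, `read_until_regex` is consumed via ``await``. A sketch under assumed names (the helper and the 64 KiB cap are ours):

```py
from tornado.iostream import IOStream


async def read_headers(stream: IOStream) -> bytes:
    # Resolves with everything up to and including the blank line; if no
    # match is found within max_bytes, the read fails with
    # UnsatisfiableReadError and the connection is closed.
    return await stream.read_until_regex(b"\r\n\r\n", max_bytes=65536)
```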
""" - future = self._set_read_callback(callback) + future = self._start_read() self._read_delimiter = delimiter self._read_max_bytes = max_bytes try: @@ -399,58 +396,42 @@ def read_until(self, delimiter, callback=None, max_bytes=None): self.close(exc_info=e) return future except: - if future is not None: - future.add_done_callback(lambda f: f.exception()) + future.add_done_callback(lambda f: f.exception()) raise return future - def read_bytes(self, num_bytes, callback=None, streaming_callback=None, - partial=False): + def read_bytes(self, num_bytes: int, partial: bool = False) -> Awaitable[bytes]: """Asynchronously read a number of bytes. - If a ``streaming_callback`` is given, it will be called with chunks - of data as they become available, and the final result will be empty. - Otherwise, the result is all the data that was read. - If a callback is given, it will be run with the data as an argument; - if not, this method returns a `.Future`. - - If ``partial`` is true, the callback is run as soon as we have + If ``partial`` is true, data is returned as soon as we have any bytes to return (but never more than ``num_bytes``) .. versionchanged:: 4.0 Added the ``partial`` argument. The callback argument is now optional and a `.Future` will be returned if it is omitted. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` and ``streaming_callback`` arguments are - deprecated and will be removed in Tornado 6.0. Use the - returned `.Future` (and ``partial=True`` for - ``streaming_callback``) instead. + The ``callback`` and ``streaming_callback`` arguments have + been removed. Use the returned `.Future` (and + ``partial=True`` for ``streaming_callback``) instead. """ - future = self._set_read_callback(callback) + future = self._start_read() assert isinstance(num_bytes, numbers.Integral) self._read_bytes = num_bytes self._read_partial = partial - if streaming_callback is not None: - warnings.warn("streaming_callback is deprecated, use partial instead", - DeprecationWarning) - self._streaming_callback = stack_context.wrap(streaming_callback) try: self._try_inline_read() except: - if future is not None: - future.add_done_callback(lambda f: f.exception()) + future.add_done_callback(lambda f: f.exception()) raise return future - def read_into(self, buf, callback=None, partial=False): + def read_into(self, buf: bytearray, partial: bool = False) -> Awaitable[int]: """Asynchronously read a number of bytes. ``buf`` must be a writable buffer into which data will be read. - If a callback is given, it will be run with the number of read - bytes as an argument; if not, this method returns a `.Future`. If ``partial`` is true, the callback is run as soon as any bytes have been read. Otherwise, it is run when the ``buf`` has been @@ -458,24 +439,26 @@ def read_into(self, buf, callback=None, partial=False): .. versionadded:: 5.0 - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed - in Tornado 6.0. Use the returned `.Future` instead. + The ``callback`` argument was removed. Use the returned + `.Future` instead. 
""" - future = self._set_read_callback(callback) + future = self._start_read() # First copy data already in read buffer available_bytes = self._read_buffer_size n = len(buf) if available_bytes >= n: end = self._read_buffer_pos + n - buf[:] = memoryview(self._read_buffer)[self._read_buffer_pos:end] + buf[:] = memoryview(self._read_buffer)[self._read_buffer_pos : end] del self._read_buffer[:end] self._after_user_read_buffer = self._read_buffer elif available_bytes > 0: - buf[:available_bytes] = memoryview(self._read_buffer)[self._read_buffer_pos:] + buf[:available_bytes] = memoryview(self._read_buffer)[ + self._read_buffer_pos : + ] # Set up the supplied buffer as our temporary read buffer. # The original (if it had any data remaining) has been @@ -490,68 +473,45 @@ def read_into(self, buf, callback=None, partial=False): try: self._try_inline_read() except: - if future is not None: - future.add_done_callback(lambda f: f.exception()) + future.add_done_callback(lambda f: f.exception()) raise return future - def read_until_close(self, callback=None, streaming_callback=None): + def read_until_close(self) -> Awaitable[bytes]: """Asynchronously reads all data from the socket until it is closed. - If a ``streaming_callback`` is given, it will be called with chunks - of data as they become available, and the final result will be empty. - Otherwise, the result is all the data that was read. - If a callback is given, it will be run with the data as an argument; - if not, this method returns a `.Future`. - - Note that if a ``streaming_callback`` is used, data will be - read from the socket as quickly as it becomes available; there - is no way to apply backpressure or cancel the reads. If flow - control or cancellation are desired, use a loop with - `read_bytes(partial=True) <.read_bytes>` instead. + This will buffer all available data until ``max_buffer_size`` + is reached. If flow control or cancellation are desired, use a + loop with `read_bytes(partial=True) <.read_bytes>` instead. .. versionchanged:: 4.0 The callback argument is now optional and a `.Future` will be returned if it is omitted. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` and ``streaming_callback`` arguments are - deprecated and will be removed in Tornado 6.0. Use the - returned `.Future` (and `read_bytes` with ``partial=True`` - for ``streaming_callback``) instead. + The ``callback`` and ``streaming_callback`` arguments have + been removed. Use the returned `.Future` (and `read_bytes` + with ``partial=True`` for ``streaming_callback``) instead. """ - future = self._set_read_callback(callback) - if streaming_callback is not None: - warnings.warn("streaming_callback is deprecated, use read_bytes(partial=True) instead", - DeprecationWarning) - self._streaming_callback = stack_context.wrap(streaming_callback) + future = self._start_read() if self.closed(): - if self._streaming_callback is not None: - self._run_read_callback(self._read_buffer_size, True) - self._run_read_callback(self._read_buffer_size, False) + self._finish_read(self._read_buffer_size, False) return future self._read_until_close = True try: self._try_inline_read() except: - if future is not None: - future.add_done_callback(lambda f: f.exception()) + future.add_done_callback(lambda f: f.exception()) raise return future - def write(self, data, callback=None): + def write(self, data: Union[bytes, memoryview]) -> "Future[None]": """Asynchronously write the given data to this stream. 
- If ``callback`` is given, we call it when all of the buffered write - data has been successfully written to the stream. If there was - previously buffered write data and an old write callback, that - callback is simply overwritten with this new callback. - - If no ``callback`` is given, this method returns a `.Future` that - resolves (with a result of ``None``) when the write has been - completed. + This method returns a `.Future` that resolves (with a result + of ``None``) when the write has been completed. The ``data`` argument may be of type `bytes` or `memoryview`. @@ -561,28 +521,27 @@ def write(self, data, callback=None): .. versionchanged:: 4.5 Added support for `memoryview` arguments. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed - in Tornado 6.0. Use the returned `.Future` instead. + The ``callback`` argument was removed. Use the returned + `.Future` instead. """ self._check_closed() if data: - if (self.max_write_buffer_size is not None and - len(self._write_buffer) + len(data) > self.max_write_buffer_size): + if isinstance(data, memoryview): + # Make sure that ``len(data) == data.nbytes`` + data = memoryview(data).cast("B") + if ( + self.max_write_buffer_size is not None + and len(self._write_buffer) + len(data) > self.max_write_buffer_size + ): raise StreamBufferFullError("Reached maximum write buffer size") self._write_buffer.append(data) self._total_write_index += len(data) - if callback is not None: - warnings.warn("callback argument is deprecated, use returned Future instead", - DeprecationWarning) - self._write_callback = stack_context.wrap(callback) - future = None - else: - future = Future() - future.add_done_callback(lambda f: f.exception()) - self._write_futures.append((self._total_write_index, future)) + future = Future() # type: Future[None] + future.add_done_callback(lambda f: f.exception()) + self._write_futures.append((self._total_write_index, future)) if not self._connecting: self._handle_write() if self._write_buffer: @@ -590,7 +549,7 @@ def write(self, data, callback=None): self._maybe_add_error_listener() return future - def set_close_callback(self, callback): + def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None: """Call the given callback when the stream is closed. This mostly is not necessary for applications that use the @@ -600,12 +559,24 @@ def set_close_callback(self, callback): closed while no other read or write is in progress. Unlike other callback-based interfaces, ``set_close_callback`` - will not be removed in Tornado 6.0. + was not removed in Tornado 6.0. """ - self._close_callback = stack_context.wrap(callback) + self._close_callback = callback self._maybe_add_error_listener() - def close(self, exc_info=False): + def close( + self, + exc_info: Union[ + None, + bool, + BaseException, + Tuple[ + "Optional[Type[BaseException]]", + Optional[BaseException], + Optional[TracebackType], + ], + ] = False, + ) -> None: """Close this stream. 
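Editor's note: from the caller's side, the Future-only `write` looks like this (sketch; the helper name is ours). The buffer-size check happens synchronously, before anything is queued.

```py
from tornado.iostream import IOStream, StreamBufferFullError


async def send(stream: IOStream, payload: bytes) -> None:
    try:
        # Resolves with None once the data has been written out.
        await stream.write(payload)
    except StreamBufferFullError:
        # Raised eagerly if the buffered data would exceed
        # max_write_buffer_size.
        stream.close()
```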
If ``exc_info`` is true, set the ``error`` attribute to the current @@ -623,61 +594,76 @@ def close(self, exc_info=False): if any(exc_info): self.error = exc_info[1] if self._read_until_close: - if (self._streaming_callback is not None and - self._read_buffer_size): - self._run_read_callback(self._read_buffer_size, True) self._read_until_close = False - self._run_read_callback(self._read_buffer_size, False) + self._finish_read(self._read_buffer_size, False) + elif self._read_future is not None: + # resolve reads that are pending and ready to complete + try: + pos = self._find_read_pos() + except UnsatisfiableReadError: + pass + else: + if pos is not None: + self._read_from_buffer(pos) if self._state is not None: self.io_loop.remove_handler(self.fileno()) self._state = None self.close_fd() self._closed = True - self._maybe_run_close_callback() - - def _maybe_run_close_callback(self): - # If there are pending callbacks, don't run the close callback - # until they're done (see _maybe_add_error_handler) - if self.closed() and self._pending_callbacks == 0: - futures = [] - if self._read_future is not None: - futures.append(self._read_future) - self._read_future = None - futures += [future for _, future in self._write_futures] - self._write_futures.clear() - if self._connect_future is not None: - futures.append(self._connect_future) - self._connect_future = None - if self._ssl_connect_future is not None: - futures.append(self._ssl_connect_future) - self._ssl_connect_future = None - for future in futures: + self._signal_closed() + + def _signal_closed(self) -> None: + futures = [] # type: List[Future] + if self._read_future is not None: + futures.append(self._read_future) + self._read_future = None + futures += [future for _, future in self._write_futures] + self._write_futures.clear() + if self._connect_future is not None: + futures.append(self._connect_future) + self._connect_future = None + for future in futures: + if not future.done(): future.set_exception(StreamClosedError(real_error=self.error)) + # Reference the exception to silence warnings. Annoyingly, + # this raises if the future was cancelled, but just + # returns any other error. + try: future.exception() - if self._close_callback is not None: - cb = self._close_callback - self._close_callback = None - self._run_callback(cb) - # Delete any unfinished callbacks to break up reference cycles. - self._read_callback = self._write_callback = None - # Clear the buffers so they can be cleared immediately even - # if the IOStream object is kept alive by a reference cycle. - # TODO: Clear the read buffer too; it currently breaks some tests. - self._write_buffer = None - - def reading(self): - """Returns true if we are currently reading from the stream.""" - return self._read_callback is not None or self._read_future is not None - - def writing(self): - """Returns true if we are currently writing to the stream.""" + except asyncio.CancelledError: + pass + if self._ssl_connect_future is not None: + # _ssl_connect_future expects to see the real exception (typically + # an ssl.SSLError), not just StreamClosedError. 
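Editor's note: `_signal_closed` resolves all outstanding futures with `StreamClosedError` while preserving the real cause on the stream's ``error`` attribute, so application code can still report it. A sketch:

```py
from tornado.iostream import IOStream, StreamClosedError


async def read_or_report(stream: IOStream) -> None:
    try:
        await stream.read_bytes(1)
    except StreamClosedError:
        # The underlying exception (if any) survives on stream.error;
        # the Future itself only ever sees StreamClosedError.
        print("stream closed:", stream.error)
```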
+ if not self._ssl_connect_future.done(): + if self.error is not None: + self._ssl_connect_future.set_exception(self.error) + else: + self._ssl_connect_future.set_exception(StreamClosedError()) + self._ssl_connect_future.exception() + self._ssl_connect_future = None + if self._close_callback is not None: + cb = self._close_callback + self._close_callback = None + self.io_loop.add_callback(cb) + # Clear the buffers so they can be cleared immediately even + # if the IOStream object is kept alive by a reference cycle. + # TODO: Clear the read buffer too; it currently breaks some tests. + self._write_buffer = None # type: ignore + + def reading(self) -> bool: + """Returns ``True`` if we are currently reading from the stream.""" + return self._read_future is not None + + def writing(self) -> bool: + """Returns ``True`` if we are currently writing to the stream.""" return bool(self._write_buffer) - def closed(self): - """Returns true if the stream has been closed.""" + def closed(self) -> bool: + """Returns ``True`` if the stream has been closed.""" return self._closed - def set_nodelay(self, value): + def set_nodelay(self, value: bool) -> None: """Sets the no-delay flag for this stream. By default, data written to TCP streams may be held for a time @@ -692,7 +678,10 @@ def set_nodelay(self, value): """ pass - def _handle_events(self, fd, events): + def _handle_connect(self) -> None: + raise NotImplementedError() + + def _handle_events(self, fd: Union[int, ioloop._Selectable], events: int) -> None: if self.closed(): gen_log.warning("Got events for closed stream %s", fd) return @@ -732,170 +721,114 @@ def _handle_events(self, fd, events): # yet anyway, so we don't need to listen in this case. state |= self.io_loop.READ if state != self._state: - assert self._state is not None, \ - "shouldn't happen: _handle_events without self._state" + assert ( + self._state is not None + ), "shouldn't happen: _handle_events without self._state" self._state = state self.io_loop.update_handler(self.fileno(), self._state) except UnsatisfiableReadError as e: gen_log.info("Unsatisfiable read, closing connection: %s" % e) self.close(exc_info=e) except Exception as e: - gen_log.error("Uncaught exception, closing connection.", - exc_info=True) + gen_log.error("Uncaught exception, closing connection.", exc_info=True) self.close(exc_info=e) raise - def _run_callback(self, callback, *args): - def wrapper(): - self._pending_callbacks -= 1 - try: - return callback(*args) - except Exception as e: - app_log.error("Uncaught exception, closing connection.", - exc_info=True) - # Close the socket on an uncaught exception from a user callback - # (It would eventually get closed when the socket object is - # gc'd, but we don't want to rely on gc happening before we - # run out of file descriptors) - self.close(exc_info=e) - # Re-raise the exception so that IOLoop.handle_callback_exception - # can see it and log the error - raise - finally: - self._maybe_add_error_listener() - # We schedule callbacks to be run on the next IOLoop iteration - # rather than running them directly for several reasons: - # * Prevents unbounded stack growth when a callback calls an - # IOLoop operation that immediately runs another callback - # * Provides a predictable execution context for e.g. 
- # non-reentrant mutexes - # * Ensures that the try/except in wrapper() is run outside - # of the application's StackContexts - with stack_context.NullContext(): - # stack_context was already captured in callback, we don't need to - # capture it again for IOStream's wrapper. This is especially - # important if the callback was pre-wrapped before entry to - # IOStream (as in HTTPConnection._header_callback), as we could - # capture and leak the wrong context here. - self._pending_callbacks += 1 - self.io_loop.add_callback(wrapper) - - def _read_to_buffer_loop(self): + def _read_to_buffer_loop(self) -> Optional[int]: # This method is called from _handle_read and _try_inline_read. - try: - if self._read_bytes is not None: - target_bytes = self._read_bytes - elif self._read_max_bytes is not None: - target_bytes = self._read_max_bytes - elif self.reading(): - # For read_until without max_bytes, or - # read_until_close, read as much as we can before - # scanning for the delimiter. - target_bytes = None - else: - target_bytes = 0 - next_find_pos = 0 - # Pretend to have a pending callback so that an EOF in - # _read_to_buffer doesn't trigger an immediate close - # callback. At the end of this method we'll either - # establish a real pending callback via - # _read_from_buffer or run the close callback. - # - # We need two try statements here so that - # pending_callbacks is decremented before the `except` - # clause below (which calls `close` and does need to - # trigger the callback) - self._pending_callbacks += 1 - while not self.closed(): - # Read from the socket until we get EWOULDBLOCK or equivalent. - # SSL sockets do some internal buffering, and if the data is - # sitting in the SSL object's buffer select() and friends - # can't see it; the only way to find out if it's there is to - # try to read it. - if self._read_to_buffer() == 0: - break - - self._run_streaming_callback() + if self._read_bytes is not None: + target_bytes = self._read_bytes # type: Optional[int] + elif self._read_max_bytes is not None: + target_bytes = self._read_max_bytes + elif self.reading(): + # For read_until without max_bytes, or + # read_until_close, read as much as we can before + # scanning for the delimiter. + target_bytes = None + else: + target_bytes = 0 + next_find_pos = 0 + while not self.closed(): + # Read from the socket until we get EWOULDBLOCK or equivalent. + # SSL sockets do some internal buffering, and if the data is + # sitting in the SSL object's buffer select() and friends + # can't see it; the only way to find out if it's there is to + # try to read it. + if self._read_to_buffer() == 0: + break - # If we've read all the bytes we can use, break out of - # this loop. We can't just call read_from_buffer here - # because of subtle interactions with the - # pending_callback and error_listener mechanisms. - # - # If we've reached target_bytes, we know we're done. - if (target_bytes is not None and - self._read_buffer_size >= target_bytes): - break + # If we've read all the bytes we can use, break out of + # this loop. - # Otherwise, we need to call the more expensive find_read_pos. - # It's inefficient to do this on every read, so instead - # do it on the first read and whenever the read buffer - # size has doubled. - if self._read_buffer_size >= next_find_pos: - pos = self._find_read_pos() - if pos is not None: - return pos - next_find_pos = self._read_buffer_size * 2 - return self._find_read_pos() - finally: - self._pending_callbacks -= 1 + # If we've reached target_bytes, we know we're done. 
+ if target_bytes is not None and self._read_buffer_size >= target_bytes: + break - def _handle_read(self): + # Otherwise, we need to call the more expensive find_read_pos. + # It's inefficient to do this on every read, so instead + # do it on the first read and whenever the read buffer + # size has doubled. + if self._read_buffer_size >= next_find_pos: + pos = self._find_read_pos() + if pos is not None: + return pos + next_find_pos = self._read_buffer_size * 2 + return self._find_read_pos() + + def _handle_read(self) -> None: try: pos = self._read_to_buffer_loop() except UnsatisfiableReadError: raise + except asyncio.CancelledError: + raise except Exception as e: gen_log.warning("error on read: %s" % e) self.close(exc_info=e) return if pos is not None: self._read_from_buffer(pos) - return - else: - self._maybe_run_close_callback() - - def _set_read_callback(self, callback): - assert self._read_callback is None, "Already reading" - assert self._read_future is None, "Already reading" - if callback is not None: - warnings.warn("callbacks are deprecated, use returned Future instead", - DeprecationWarning) - self._read_callback = stack_context.wrap(callback) - else: - self._read_future = Future() + + def _start_read(self) -> Future: + if self._read_future is not None: + # It is an error to start a read while a prior read is unresolved. + # However, if the prior read is unresolved because the stream was + # closed without satisfying it, it's better to raise + # StreamClosedError instead of AssertionError. In particular, this + # situation occurs in harmless situations in http1connection.py and + # an AssertionError would be logged noisily. + # + # On the other hand, it is legal to start a new read while the + # stream is closed, in case the read can be satisfied from the + # read buffer. So we only want to check the closed status of the + # stream if we need to decide what kind of error to raise for + # "already reading". + # + # These conditions have proven difficult to test; we have no + # unittests that reliably verify this behavior so be careful + # when making changes here. See #2651 and #2719. + self._check_closed() + assert self._read_future is None, "Already reading" + self._read_future = Future() return self._read_future - def _run_read_callback(self, size, streaming): + def _finish_read(self, size: int, streaming: bool) -> None: if self._user_read_buffer: self._read_buffer = self._after_user_read_buffer or bytearray() self._after_user_read_buffer = None self._read_buffer_pos = 0 self._read_buffer_size = len(self._read_buffer) self._user_read_buffer = False - result = size + result = size # type: Union[int, bytes] else: result = self._consume(size) - if streaming: - callback = self._streaming_callback - else: - callback = self._read_callback - self._read_callback = self._streaming_callback = None - if self._read_future is not None: - assert callback is None - future = self._read_future - self._read_future = None - - future.set_result(result) - if callback is not None: - assert (self._read_future is None) or streaming - self._run_callback(callback, result) - else: - # If we scheduled a callback, we will add the error listener - # afterwards. If we didn't, we have to do it now. 
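Editor's note: the long comment in `_start_read` encodes a one-read-at-a-time contract; a sketch of the behavior it describes (helper name ours):

```py
from tornado.iostream import IOStream


async def safe_read(stream: IOStream) -> bytes:
    # Only one read may be outstanding: a second read_*() call raises
    # AssertionError ("Already reading"), or StreamClosedError when the
    # pending read was orphaned by a close (see #2651 and #2719).
    assert not stream.reading()
    return await stream.read_bytes(4096, partial=True)
```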
- self._maybe_add_error_listener() + if self._read_future is not None: + future = self._read_future + self._read_future = None + future_set_result_unless_cancelled(future, result) + self._maybe_add_error_listener() - def _try_inline_read(self): + def _try_inline_read(self) -> None: """Attempt to complete the current read operation from buffered data. If the read can be completed without blocking, schedules the @@ -903,32 +836,21 @@ def _try_inline_read(self): listening for reads on the socket. """ # See if we've already got the data from a previous read - self._run_streaming_callback() pos = self._find_read_pos() if pos is not None: self._read_from_buffer(pos) return self._check_closed() - try: - pos = self._read_to_buffer_loop() - except Exception: - # If there was an in _read_to_buffer, we called close() already, - # but couldn't run the close callback because of _pending_callbacks. - # Before we escape from this function, run the close callback if - # applicable. - self._maybe_run_close_callback() - raise + pos = self._read_to_buffer_loop() if pos is not None: self._read_from_buffer(pos) return - # We couldn't satisfy the read inline, so either close the stream - # or listen for new data. - if self.closed(): - self._maybe_run_close_callback() - else: + # We couldn't satisfy the read inline, so make sure we're + # listening for new data unless the stream is closed. + if not self.closed(): self._add_io_state(ioloop.IOLoop.READ) - def _read_to_buffer(self): + def _read_to_buffer(self) -> Optional[int]: """Reads from the socket and appends the result to the read buffer. Returns the number of bytes read. Returns 0 if there is nothing @@ -939,20 +861,20 @@ def _read_to_buffer(self): while True: try: if self._user_read_buffer: - buf = memoryview(self._read_buffer)[self._read_buffer_size:] + buf = memoryview(self._read_buffer)[ + self._read_buffer_size : + ] # type: Union[memoryview, bytearray] else: buf = bytearray(self.read_chunk_size) bytes_read = self.read_from_fd(buf) except (socket.error, IOError, OSError) as e: - if errno_from_exception(e) == errno.EINTR: - continue # ssl.SSLError is a subclass of socket.error if self._is_connreset(e): # Treat ECONNRESET as a connection close rather than # an error to minimize log spam (the exception will # be available on self.error for apps that care). self.close(exc_info=e) - return + return None self.close(exc_info=e) raise break @@ -967,22 +889,14 @@ def _read_to_buffer(self): finally: # Break the reference to buf so we don't waste a chunk's worth of # memory in case an exception hangs on to our stack frame. - buf = None + del buf if self._read_buffer_size > self.max_buffer_size: gen_log.error("Reached maximum read buffer size") self.close() raise StreamBufferFullError("Reached maximum read buffer size") return bytes_read - def _run_streaming_callback(self): - if self._streaming_callback is not None and self._read_buffer_size: - bytes_to_consume = self._read_buffer_size - if self._read_bytes is not None: - bytes_to_consume = min(self._read_bytes, bytes_to_consume) - self._read_bytes -= bytes_to_consume - self._run_read_callback(bytes_to_consume, True) - - def _read_from_buffer(self, pos): + def _read_from_buffer(self, pos: int) -> None: """Attempts to complete the currently-pending read from the buffer. 
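Editor's note: both read limits surface as exceptions to the caller: a ``max_bytes`` overrun raises `UnsatisfiableReadError` (closing the stream), and an overfull read buffer raises `StreamBufferFullError`. A sketch of the former, with an illustrative size:

```py
from tornado.iostream import IOStream, UnsatisfiableReadError


async def read_line(stream: IOStream) -> bytes:
    try:
        # Fails (and closes the stream) if more than 1024 bytes arrive
        # without the delimiter ever appearing.
        return await stream.read_until(b"\n", max_bytes=1024)
    except UnsatisfiableReadError:
        return b""
```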
The argument is either a position in the read buffer or None, @@ -990,18 +904,19 @@ def _read_from_buffer(self, pos): """ self._read_bytes = self._read_delimiter = self._read_regex = None self._read_partial = False - self._run_read_callback(pos, False) + self._finish_read(pos, False) - def _find_read_pos(self): + def _find_read_pos(self) -> Optional[int]: """Attempts to find a position in the read buffer that satisfies the currently-pending read. Returns a position in the buffer if the current read can be satisfied, or None if it cannot. """ - if (self._read_bytes is not None and - (self._read_buffer_size >= self._read_bytes or - (self._read_partial and self._read_buffer_size > 0))): + if self._read_bytes is not None and ( + self._read_buffer_size >= self._read_bytes + or (self._read_partial and self._read_buffer_size > 0) + ): num_bytes = min(self._read_bytes, self._read_buffer_size) return num_bytes elif self._read_delimiter is not None: @@ -1014,20 +929,18 @@ def _find_read_pos(self): # since large merges are relatively expensive and get undone in # _consume(). if self._read_buffer: - loc = self._read_buffer.find(self._read_delimiter, - self._read_buffer_pos) + loc = self._read_buffer.find( + self._read_delimiter, self._read_buffer_pos + ) if loc != -1: loc -= self._read_buffer_pos delimiter_len = len(self._read_delimiter) - self._check_max_bytes(self._read_delimiter, - loc + delimiter_len) + self._check_max_bytes(self._read_delimiter, loc + delimiter_len) return loc + delimiter_len - self._check_max_bytes(self._read_delimiter, - self._read_buffer_size) + self._check_max_bytes(self._read_delimiter, self._read_buffer_size) elif self._read_regex is not None: if self._read_buffer: - m = self._read_regex.search(self._read_buffer, - self._read_buffer_pos) + m = self._read_regex.search(self._read_buffer, self._read_buffer_pos) if m is not None: loc = m.end() - self._read_buffer_pos self._check_max_bytes(self._read_regex, loc) @@ -1035,14 +948,14 @@ def _find_read_pos(self): self._check_max_bytes(self._read_regex, self._read_buffer_size) return None - def _check_max_bytes(self, delimiter, size): - if (self._read_max_bytes is not None and - size > self._read_max_bytes): + def _check_max_bytes(self, delimiter: Union[bytes, Pattern], size: int) -> None: + if self._read_max_bytes is not None and size > self._read_max_bytes: raise UnsatisfiableReadError( - "delimiter %r not found within %d bytes" % ( - delimiter, self._read_max_bytes)) + "delimiter %r not found within %d bytes" + % (delimiter, self._read_max_bytes) + ) - def _handle_write(self): + def _handle_write(self) -> None: while True: size = len(self._write_buffer) if not size: @@ -1062,111 +975,102 @@ def _handle_write(self): break self._write_buffer.advance(num_bytes) self._total_write_done_index += num_bytes + except BlockingIOError: + break except (socket.error, IOError, OSError) as e: - if e.args[0] in _ERRNO_WOULDBLOCK: - break - else: - if not self._is_connreset(e): - # Broken pipe errors are usually caused by connection - # reset, and its better to not log EPIPE errors to - # minimize log spam - gen_log.warning("Write error on %s: %s", - self.fileno(), e) - self.close(exc_info=e) - return + if not self._is_connreset(e): + # Broken pipe errors are usually caused by connection + # reset, and its better to not log EPIPE errors to + # minimize log spam + gen_log.warning("Write error on %s: %s", self.fileno(), e) + self.close(exc_info=e) + return while self._write_futures: index, future = self._write_futures[0] if index > 
self._total_write_done_index: break self._write_futures.popleft() - future.set_result(None) - - if not len(self._write_buffer): - if self._write_callback: - callback = self._write_callback - self._write_callback = None - self._run_callback(callback) + future_set_result_unless_cancelled(future, None) - def _consume(self, loc): + def _consume(self, loc: int) -> bytes: # Consume loc bytes from the read buffer and return them if loc == 0: return b"" assert loc <= self._read_buffer_size # Slice the bytearray buffer into bytes, without intermediate copying - b = (memoryview(self._read_buffer) - [self._read_buffer_pos:self._read_buffer_pos + loc] - ).tobytes() + b = ( + memoryview(self._read_buffer)[ + self._read_buffer_pos : self._read_buffer_pos + loc + ] + ).tobytes() self._read_buffer_pos += loc self._read_buffer_size -= loc # Amortized O(1) shrink # (this heuristic is implemented natively in Python 3.4+ # but is replicated here for Python 2) if self._read_buffer_pos > self._read_buffer_size: - del self._read_buffer[:self._read_buffer_pos] + del self._read_buffer[: self._read_buffer_pos] self._read_buffer_pos = 0 return b - def _check_closed(self): + def _check_closed(self) -> None: if self.closed(): raise StreamClosedError(real_error=self.error) - def _maybe_add_error_listener(self): + def _maybe_add_error_listener(self) -> None: # This method is part of an optimization: to detect a connection that # is closed when we're not actively reading or writing, we must listen # for read events. However, it is inefficient to do this when the # connection is first established because we are going to read or write # immediately anyway. Instead, we insert checks at various times to # see if the connection is idle and add the read listener then. - if self._pending_callbacks != 0: - return if self._state is None or self._state == ioloop.IOLoop.ERROR: - if self.closed(): - self._maybe_run_close_callback() - elif (self._read_buffer_size == 0 and - self._close_callback is not None): + if ( + not self.closed() + and self._read_buffer_size == 0 + and self._close_callback is not None + ): self._add_io_state(ioloop.IOLoop.READ) - def _add_io_state(self, state): + def _add_io_state(self, state: int) -> None: """Adds `state` (IOLoop.{READ,WRITE} flags) to our event handler. Implementation notes: Reads and writes have a fast path and a slow path. The fast path reads synchronously from socket buffers, while the slow path uses `_add_io_state` to schedule - an IOLoop callback. Note that in both cases, the callback is - run asynchronously with `_run_callback`. + an IOLoop callback. To detect closed connections, we must have called `_add_io_state` at some point, but we want to delay this as much as possible so we don't have to set an `IOLoop.ERROR` listener that will be overwritten by the next slow-path - operation. As long as there are callbacks scheduled for - fast-path ops, those callbacks may do more reads. - If a sequence of fast-path ops do not end in a slow-path op, - (e.g. for an @asynchronous long-poll request), we must add - the error handler. This is done in `_run_callback` and `write` - (since the write callback is optional so we can have a - fast-path write with no `_run_callback`) + operation. If a sequence of fast-path ops do not end in a + slow-path op, (e.g. for an @asynchronous long-poll request), + we must add the error handler. + + TODO: reevaluate this now that callbacks are gone. 
+ """ if self.closed(): # connection has been closed, so there can be no future events return if self._state is None: self._state = ioloop.IOLoop.ERROR | state - with stack_context.NullContext(): - self.io_loop.add_handler( - self.fileno(), self._handle_events, self._state) + self.io_loop.add_handler(self.fileno(), self._handle_events, self._state) elif not self._state & state: self._state = self._state | state self.io_loop.update_handler(self.fileno(), self._state) - def _is_connreset(self, exc): - """Return true if exc is ECONNRESET or equivalent. + def _is_connreset(self, exc: BaseException) -> bool: + """Return ``True`` if exc is ECONNRESET or equivalent. May be overridden in subclasses. """ - return (isinstance(exc, (socket.error, IOError)) and - errno_from_exception(exc) in _ERRNO_CONNRESET) + return ( + isinstance(exc, (socket.error, IOError)) + and errno_from_exception(exc) in _ERRNO_CONNRESET + ) class IOStream(BaseIOStream): @@ -1190,24 +1094,23 @@ class IOStream(BaseIOStream): import tornado.iostream import socket - def send_request(): - stream.write(b"GET / HTTP/1.0\r\nHost: friendfeed.com\r\n\r\n") - stream.read_until(b"\r\n\r\n", on_headers) - - def on_headers(data): + async def main(): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + stream = tornado.iostream.IOStream(s) + await stream.connect(("friendfeed.com", 80)) + await stream.write(b"GET / HTTP/1.0\r\nHost: friendfeed.com\r\n\r\n") + header_data = await stream.read_until(b"\r\n\r\n") headers = {} - for line in data.split(b"\r\n"): - parts = line.split(b":") - if len(parts) == 2: - headers[parts[0].strip()] = parts[1].strip() - stream.read_bytes(int(headers[b"Content-Length"]), on_body) - - def on_body(data): - print(data) + for line in header_data.split(b"\r\n"): + parts = line.split(b":") + if len(parts) == 2: + headers[parts[0].strip()] = parts[1].strip() + body_data = await stream.read_bytes(int(headers[b"Content-Length"])) + print(body_data) stream.close() - tornado.ioloop.IOLoop.current().stop() if __name__ == '__main__': + tornado.ioloop.IOLoop.current().run_sync(main) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) stream = tornado.iostream.IOStream(s) stream.connect(("friendfeed.com", 80), send_request) @@ -1217,43 +1120,42 @@ def on_body(data): :hide: """ - def __init__(self, socket, *args, **kwargs): + + def __init__(self, socket: socket.socket, *args: Any, **kwargs: Any) -> None: self.socket = socket self.socket.setblocking(False) - super(IOStream, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) - def fileno(self): + def fileno(self) -> Union[int, ioloop._Selectable]: return self.socket - def close_fd(self): + def close_fd(self) -> None: self.socket.close() - self.socket = None + self.socket = None # type: ignore - def get_fd_error(self): - errno = self.socket.getsockopt(socket.SOL_SOCKET, - socket.SO_ERROR) + def get_fd_error(self) -> Optional[Exception]: + errno = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) return socket.error(errno, os.strerror(errno)) - def read_from_fd(self, buf): + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: try: - return self.socket.recv_into(buf) - except socket.error as e: - if e.args[0] in _ERRNO_WOULDBLOCK: - return None - else: - raise + return self.socket.recv_into(buf, len(buf)) + except BlockingIOError: + return None finally: - buf = None + del buf - def write_to_fd(self, data): + def write_to_fd(self, data: memoryview) -> int: try: - return self.socket.send(data) + return 
self.socket.send(data) # type: ignore finally: # Avoid keeping to data, which can be a memoryview. # See https://github.com/tornadoweb/tornado/pull/2008 del data - def connect(self, address, callback=None, server_hostname=None): + def connect( + self: _IOStreamType, address: Any, server_hostname: Optional[str] = None + ) -> "Future[_IOStreamType]": """Connects the socket to a remote address without blocking. May only be called if the socket passed to the constructor was @@ -1292,41 +1194,39 @@ class is recommended instead of calling this method directly. suitably-configured `ssl.SSLContext` to the `SSLIOStream` constructor to disable. - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed - in Tornado 6.0. Use the returned `.Future` instead. + The ``callback`` argument was removed. Use the returned + `.Future` instead. """ self._connecting = True - if callback is not None: - warnings.warn("callback argument is deprecated, use returned Future instead", - DeprecationWarning) - self._connect_callback = stack_context.wrap(callback) - future = None - else: - future = self._connect_future = Future() + future = Future() # type: Future[_IOStreamType] + self._connect_future = typing.cast("Future[IOStream]", future) try: self.socket.connect(address) - except socket.error as e: + except BlockingIOError: # In non-blocking mode we expect connect() to raise an # exception with EINPROGRESS or EWOULDBLOCK. - # + pass + except socket.error as e: # On freebsd, other errors such as ECONNREFUSED may be # returned immediately when attempting to connect to # localhost, so handle them the same way as an error # reported later in _handle_connect. - if (errno_from_exception(e) not in _ERRNO_INPROGRESS and - errno_from_exception(e) not in _ERRNO_WOULDBLOCK): - if future is None: - gen_log.warning("Connect error on fd %s: %s", - self.socket.fileno(), e) - self.close(exc_info=e) - return future + if future is None: + gen_log.warning("Connect error on fd %s: %s", self.socket.fileno(), e) + self.close(exc_info=e) + return future self._add_io_state(self.io_loop.WRITE) return future - def start_tls(self, server_side, ssl_options=None, server_hostname=None): + def start_tls( + self, + server_side: bool, + ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None, + server_hostname: Optional[str] = None, + ) -> Awaitable["SSLIOStream"]: """Convert this `IOStream` to an `SSLIOStream`. This enables protocols that begin in clear-text mode and @@ -1361,11 +1261,14 @@ def start_tls(self, server_side, ssl_options=None, server_hostname=None): ``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a suitably-configured `ssl.SSLContext` to disable. 
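With the callback argument gone, `connect` and `start_tls` compose naturally with ``await``. A hedged sketch of the clear-text-then-TLS upgrade this enables; the host, port, and STARTTLS exchange below are illustrative, not a real protocol implementation:

```py
# Illustrative STARTTLS-style client: connect in clear text, exchange a
# made-up upgrade command, then convert the stream with start_tls.
import socket

from tornado.ioloop import IOLoop
from tornado.iostream import IOStream

async def starttls_client():
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
    stream = IOStream(s)
    await stream.connect(("mail.example.com", 587))
    await stream.read_until(b"\r\n")   # server greeting
    await stream.write(b"STARTTLS\r\n")
    await stream.read_until(b"\r\n")   # server's go-ahead
    # The stream must be idle (no pending reads or writes) here, or
    # start_tls raises ValueError.
    ssl_stream = await stream.start_tls(
        server_side=False, server_hostname="mail.example.com"
    )
    # From here on, speak the protocol over ssl_stream.
    ssl_stream.close()

if __name__ == "__main__":
    IOLoop.current().run_sync(starttls_client)
```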
""" - if (self._read_callback or self._read_future or - self._write_callback or self._write_futures or - self._connect_callback or self._connect_future or - self._pending_callbacks or self._closed or - self._read_buffer or self._write_buffer): + if ( + self._read_future + or self._write_futures + or self._connect_future + or self._closed + or self._read_buffer + or self._write_buffer + ): raise ValueError("IOStream is not idle; cannot convert to SSL") if ssl_options is None: if server_side: @@ -1375,43 +1278,33 @@ def start_tls(self, server_side, ssl_options=None, server_hostname=None): socket = self.socket self.io_loop.remove_handler(socket) - self.socket = None - socket = ssl_wrap_socket(socket, ssl_options, - server_hostname=server_hostname, - server_side=server_side, - do_handshake_on_connect=False) + self.socket = None # type: ignore + socket = ssl_wrap_socket( + socket, + ssl_options, + server_hostname=server_hostname, + server_side=server_side, + do_handshake_on_connect=False, + ) orig_close_callback = self._close_callback self._close_callback = None - future = Future() + future = Future() # type: Future[SSLIOStream] ssl_stream = SSLIOStream(socket, ssl_options=ssl_options) - # Wrap the original close callback so we can fail our Future as well. - # If we had an "unwrap" counterpart to this method we would need - # to restore the original callback after our Future resolves - # so that repeated wrap/unwrap calls don't build up layers. - - def close_callback(): - if not future.done(): - # Note that unlike most Futures returned by IOStream, - # this one passes the underlying error through directly - # instead of wrapping everything in a StreamClosedError - # with a real_error attribute. This is because once the - # connection is established it's more helpful to raise - # the SSLError directly than to hide it behind a - # StreamClosedError (and the client is expecting SSL - # issues rather than network issues since this method is - # named start_tls). - future.set_exception(ssl_stream.error or StreamClosedError()) - if orig_close_callback is not None: - orig_close_callback() - ssl_stream.set_close_callback(close_callback) - ssl_stream._ssl_connect_callback = lambda: future.set_result(ssl_stream) + ssl_stream.set_close_callback(orig_close_callback) + ssl_stream._ssl_connect_future = future ssl_stream.max_buffer_size = self.max_buffer_size ssl_stream.read_chunk_size = self.read_chunk_size return future - def _handle_connect(self): - err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + def _handle_connect(self) -> None: + try: + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + except socket.error as e: + # Hurd doesn't allow SO_ERROR for loopback sockets because all + # errors for such sockets are reported synchronously. + if errno_from_exception(e) == errno.ENOPROTOOPT: + err = 0 if err != 0: self.error = socket.error(err, os.strerror(err)) # IOLoop implementations may vary: some of them return @@ -1419,30 +1312,32 @@ def _handle_connect(self): # in that case a connection failure would be handled by the # error path in _handle_events instead of here. 
if self._connect_future is None: - gen_log.warning("Connect error on fd %s: %s", - self.socket.fileno(), errno.errorcode[err]) + gen_log.warning( + "Connect error on fd %s: %s", + self.socket.fileno(), + errno.errorcode[err], + ) self.close() return - if self._connect_callback is not None: - callback = self._connect_callback - self._connect_callback = None - self._run_callback(callback) if self._connect_future is not None: future = self._connect_future self._connect_future = None - future.set_result(self) + future_set_result_unless_cancelled(future, self) self._connecting = False - def set_nodelay(self, value): - if (self.socket is not None and - self.socket.family in (socket.AF_INET, socket.AF_INET6)): + def set_nodelay(self, value: bool) -> None: + if self.socket is not None and self.socket.family in ( + socket.AF_INET, + socket.AF_INET6, + ): try: - self.socket.setsockopt(socket.IPPROTO_TCP, - socket.TCP_NODELAY, 1 if value else 0) + self.socket.setsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY, 1 if value else 0 + ) except socket.error as e: # Sometimes setsockopt will fail if the socket is closed # at the wrong time. This can happen with HTTPServer - # resetting the value to false between requests. + # resetting the value to ``False`` between requests. if e.errno != errno.EINVAL and not self._is_connreset(e): raise @@ -1458,18 +1353,20 @@ class SSLIOStream(IOStream): before constructing the `SSLIOStream`. Unconnected sockets will be wrapped when `IOStream.connect` is finished. """ - def __init__(self, *args, **kwargs): + + socket = None # type: ssl.SSLSocket + + def __init__(self, *args: Any, **kwargs: Any) -> None: """The ``ssl_options`` keyword argument may either be an `ssl.SSLContext` object or a dictionary of keywords arguments for `ssl.wrap_socket` """ - self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults) - super(SSLIOStream, self).__init__(*args, **kwargs) + self._ssl_options = kwargs.pop("ssl_options", _client_ssl_defaults) + super().__init__(*args, **kwargs) self._ssl_accepting = True self._handshake_reading = False self._handshake_writing = False - self._ssl_connect_callback = None - self._server_hostname = None + self._server_hostname = None # type: Optional[str] # If the socket is already connected, attempt to start the handshake. try: @@ -1482,13 +1379,13 @@ def __init__(self, *args, **kwargs): # _handle_events. 
self._add_io_state(self.io_loop.WRITE) - def reading(self): - return self._handshake_reading or super(SSLIOStream, self).reading() + def reading(self) -> bool: + return self._handshake_reading or super().reading() - def writing(self): - return self._handshake_writing or super(SSLIOStream, self).writing() + def writing(self) -> bool: + return self._handshake_writing or super().writing() - def _do_ssl_handshake(self): + def _do_ssl_handshake(self) -> None: # Based on code from test_ssl.py in the python stdlib try: self._handshake_reading = False @@ -1501,25 +1398,36 @@ def _do_ssl_handshake(self): elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: self._handshake_writing = True return - elif err.args[0] in (ssl.SSL_ERROR_EOF, - ssl.SSL_ERROR_ZERO_RETURN): + elif err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN): return self.close(exc_info=err) elif err.args[0] == ssl.SSL_ERROR_SSL: try: peer = self.socket.getpeername() except Exception: - peer = '(not connected)' - gen_log.warning("SSL Error on %s %s: %s", - self.socket.fileno(), peer, err) + peer = "(not connected)" + gen_log.warning( + "SSL Error on %s %s: %s", self.socket.fileno(), peer, err + ) return self.close(exc_info=err) raise + except ssl.CertificateError as err: + # CertificateError can happen during handshake (hostname + # verification) and should be passed to user. Starting + # in Python 3.7, this error is a subclass of SSLError + # and will be handled by the previous block instead. + return self.close(exc_info=err) except socket.error as err: # Some port scans (e.g. nmap in -sT mode) have been known # to cause do_handshake to raise EBADF and ENOTCONN, so make # those errors quiet as well. # https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0 - if (self._is_connreset(err) or - err.args[0] in (errno.EBADF, errno.ENOTCONN)): + # Errno 0 is also possible in some cases (nc -z). + # https://github.com/tornadoweb/tornado/issues/2504 + if self._is_connreset(err) or err.args[0] in ( + 0, + errno.EBADF, + errno.ENOTCONN, + ): return self.close(exc_info=err) raise except AttributeError as err: @@ -1532,20 +1440,16 @@ def _do_ssl_handshake(self): if not self._verify_cert(self.socket.getpeercert()): self.close() return - self._run_ssl_connect_callback() + self._finish_ssl_connect() - def _run_ssl_connect_callback(self): - if self._ssl_connect_callback is not None: - callback = self._ssl_connect_callback - self._ssl_connect_callback = None - self._run_callback(callback) + def _finish_ssl_connect(self) -> None: if self._ssl_connect_future is not None: future = self._ssl_connect_future self._ssl_connect_future = None - future.set_result(self) + future_set_result_unless_cancelled(future, self) - def _verify_cert(self, peercert): - """Returns True if peercert is valid according to the configured + def _verify_cert(self, peercert: Any) -> bool: + """Returns ``True`` if peercert is valid according to the configured validation mode and hostname. The ssl handshake already tested the certificate for a valid @@ -1553,7 +1457,7 @@ def _verify_cert(self, peercert): the hostname. 
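`_verify_cert` accepts two configuration shapes for ``ssl_options``. A short sketch of both; constructing streams at import time like this is for illustration only, and in practice this happens under a running IOLoop:

```py
# The two ssl_options shapes _verify_cert understands: a dict of
# ssl.wrap_socket keyword arguments (mode read from "cert_reqs"), or an
# ssl.SSLContext (verify_mode read directly from the context).
import socket
import ssl

from tornado.iostream import SSLIOStream

# Dict form: CERT_NONE disables certificate and hostname verification.
insecure = SSLIOStream(socket.socket(), ssl_options=dict(cert_reqs=ssl.CERT_NONE))

# Context form: create_default_context() gives CERT_REQUIRED plus
# hostname checking.
secure = SSLIOStream(socket.socket(), ssl_options=ssl.create_default_context())
```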
""" if isinstance(self._ssl_options, dict): - verify_mode = self._ssl_options.get('cert_reqs', ssl.CERT_NONE) + verify_mode = self._ssl_options.get("cert_reqs", ssl.CERT_NONE) elif isinstance(self._ssl_options, ssl.SSLContext): verify_mode = self._ssl_options.verify_mode assert verify_mode in (ssl.CERT_NONE, ssl.CERT_REQUIRED, ssl.CERT_OPTIONAL) @@ -1571,32 +1475,40 @@ def _verify_cert(self, peercert): else: return True - def _handle_read(self): + def _handle_read(self) -> None: if self._ssl_accepting: self._do_ssl_handshake() return - super(SSLIOStream, self)._handle_read() + super()._handle_read() - def _handle_write(self): + def _handle_write(self) -> None: if self._ssl_accepting: self._do_ssl_handshake() return - super(SSLIOStream, self)._handle_write() + super()._handle_write() - def connect(self, address, callback=None, server_hostname=None): + def connect( + self, address: Tuple, server_hostname: Optional[str] = None + ) -> "Future[SSLIOStream]": self._server_hostname = server_hostname # Ignore the result of connect(). If it fails, # wait_for_handshake will raise an error too. This is # necessary for the old semantics of the connect callback # (which takes no arguments). In 6.0 this can be refactored to # be a regular coroutine. - fut = super(SSLIOStream, self).connect(address) + # TODO: This is trickier than it looks, since if write() + # is called with a connect() pending, we want the connect + # to resolve before the write. Or do we care about this? + # (There's a test for it, but I think in practice users + # either wait for the connect before performing a write or + # they don't care about the connect Future at all) + fut = super().connect(address) fut.add_done_callback(lambda f: f.exception()) - return self.wait_for_handshake(callback) + return self.wait_for_handshake() - def _handle_connect(self): + def _handle_connect(self) -> None: # Call the superclass method to check for errors. - super(SSLIOStream, self)._handle_connect() + super()._handle_connect() if self.closed(): return # When the connection is complete, wrap the socket for SSL @@ -1611,13 +1523,17 @@ def _handle_connect(self): # wrap_socket(). self.io_loop.remove_handler(self.socket) old_state = self._state + assert old_state is not None self._state = None - self.socket = ssl_wrap_socket(self.socket, self._ssl_options, - server_hostname=self._server_hostname, - do_handshake_on_connect=False) + self.socket = ssl_wrap_socket( + self.socket, + self._ssl_options, + server_hostname=self._server_hostname, + do_handshake_on_connect=False, + ) self._add_io_state(old_state) - def wait_for_handshake(self, callback=None): + def wait_for_handshake(self) -> "Future[SSLIOStream]": """Wait for the initial SSL handshake to complete. If a ``callback`` is given, it will be called with no @@ -1636,29 +1552,22 @@ def wait_for_handshake(self, callback=None): .. versionadded:: 4.2 - .. deprecated:: 5.1 + .. versionchanged:: 6.0 - The ``callback`` argument is deprecated and will be removed - in Tornado 6.0. Use the returned `.Future` instead. + The ``callback`` argument was removed. Use the returned + `.Future` instead. 
""" - if (self._ssl_connect_callback is not None or - self._ssl_connect_future is not None): + if self._ssl_connect_future is not None: raise RuntimeError("Already waiting") - if callback is not None: - warnings.warn("callback argument is deprecated, use returned Future instead", - DeprecationWarning) - self._ssl_connect_callback = stack_context.wrap(callback) - future = None - else: - future = self._ssl_connect_future = Future() + future = self._ssl_connect_future = Future() if not self._ssl_accepting: - self._run_ssl_connect_callback() + self._finish_ssl_connect() return future - def write_to_fd(self, data): + def write_to_fd(self, data: memoryview) -> int: try: - return self.socket.send(data) + return self.socket.send(data) # type: ignore except ssl.SSLError as e: if e.args[0] == ssl.SSL_ERROR_WANT_WRITE: # In Python 3.5+, SSLSocket.send raises a WANT_WRITE error if @@ -1674,15 +1583,20 @@ def write_to_fd(self, data): # See https://github.com/tornadoweb/tornado/pull/2008 del data - def read_from_fd(self, buf): + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: try: if self._ssl_accepting: # If the handshake hasn't finished yet, there can't be anything # to read (attempting to read may or may not raise an exception # depending on the SSL version) return None + # clip buffer size at 1GB since SSL sockets only support upto 2GB + # this change in behaviour is transparent, since the function is + # already expected to (possibly) read less than the provided buffer + if len(buf) >> 30: + buf = memoryview(buf)[: 1 << 30] try: - return self.socket.recv_into(buf) + return self.socket.recv_into(buf, len(buf)) except ssl.SSLError as e: # SSLError is a subclass of socket.error, so this except # block must come first. @@ -1690,18 +1604,15 @@ def read_from_fd(self, buf): return None else: raise - except socket.error as e: - if e.args[0] in _ERRNO_WOULDBLOCK: - return None - else: - raise + except BlockingIOError: + return None finally: - buf = None + del buf - def _is_connreset(self, e): + def _is_connreset(self, e: BaseException) -> bool: if isinstance(e, ssl.SSLError) and e.args[0] == ssl.SSL_ERROR_EOF: return True - return super(SSLIOStream, self)._is_connreset(e) + return super()._is_connreset(e) class PipeIOStream(BaseIOStream): @@ -1711,30 +1622,40 @@ class PipeIOStream(BaseIOStream): by `os.pipe`) rather than an open file object. Pipes are generally one-way, so a `PipeIOStream` can be used for reading or writing but not both. + + ``PipeIOStream`` is only available on Unix-based platforms. """ - def __init__(self, fd, *args, **kwargs): + + def __init__(self, fd: int, *args: Any, **kwargs: Any) -> None: self.fd = fd self._fio = io.FileIO(self.fd, "r+") - _set_nonblocking(fd) - super(PipeIOStream, self).__init__(*args, **kwargs) - - def fileno(self): + if sys.platform == "win32": + # The form and placement of this assertion is important to mypy. + # A plain assert statement isn't recognized here. If the assertion + # were earlier it would worry that the attributes of self aren't + # set on windows. If it were missing it would complain about + # the absence of the set_blocking function. 
+ raise AssertionError("PipeIOStream is not supported on Windows") + os.set_blocking(fd, False) + super().__init__(*args, **kwargs) + + def fileno(self) -> int: return self.fd - def close_fd(self): + def close_fd(self) -> None: self._fio.close() - def write_to_fd(self, data): + def write_to_fd(self, data: memoryview) -> int: try: - return os.write(self.fd, data) + return os.write(self.fd, data) # type: ignore finally: # Avoid keeping to data, which can be a memoryview. # See https://github.com/tornadoweb/tornado/pull/2008 del data - def read_from_fd(self, buf): + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: try: - return self._fio.readinto(buf) + return self._fio.readinto(buf) # type: ignore except (IOError, OSError) as e: if errno_from_exception(e) == errno.EBADF: # If the writing half of a pipe is closed, select will @@ -1744,9 +1665,10 @@ def read_from_fd(self, buf): else: raise finally: - buf = None + del buf -def doctests(): +def doctests() -> Any: import doctest + return doctest.DocTestSuite() diff --git a/tornado/locale.py b/tornado/locale.py index d45172f3b8..533ce4d41c 100644 --- a/tornado/locale.py +++ b/tornado/locale.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Copyright 2009 Facebook # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -39,30 +37,29 @@ the `Locale.translate` method will simply return the original string. """ -from __future__ import absolute_import, division, print_function - import codecs import csv import datetime -from io import BytesIO -import numbers +import gettext +import glob import os import re from tornado import escape from tornado.log import gen_log -from tornado.util import PY3 from tornado._locale_data import LOCALE_NAMES +from typing import Iterable, Any, Union, Dict, Optional + _default_locale = "en_US" -_translations = {} # type: dict +_translations = {} # type: Dict[str, Any] _supported_locales = frozenset([_default_locale]) _use_gettext = False CONTEXT_SEPARATOR = "\x04" -def get(*locale_codes): +def get(*locale_codes: str) -> "Locale": """Returns the closest match for the given locale codes. We iterate over all given locale codes in order. If we have a tight @@ -76,7 +73,7 @@ def get(*locale_codes): return Locale.get_closest(*locale_codes) -def set_default_locale(code): +def set_default_locale(code: str) -> None: """Sets the default locale. The default locale is assumed to be the language used for all strings @@ -90,7 +87,7 @@ def set_default_locale(code): _supported_locales = frozenset(list(_translations.keys()) + [_default_locale]) -def load_translations(directory, encoding=None): +def load_translations(directory: str, encoding: Optional[str] = None) -> None: """Loads translations from CSV files in a directory. Translations are strings with optional Python-style named placeholders @@ -133,54 +130,51 @@ def load_translations(directory, encoding=None): continue locale, extension = path.split(".") if not re.match("[a-z]+(_[A-Z]+)?$", locale): - gen_log.error("Unrecognized locale %r (path: %s)", locale, - os.path.join(directory, path)) + gen_log.error( + "Unrecognized locale %r (path: %s)", + locale, + os.path.join(directory, path), + ) continue full_path = os.path.join(directory, path) if encoding is None: # Try to autodetect encoding based on the BOM. 
- with open(full_path, 'rb') as f: - data = f.read(len(codecs.BOM_UTF16_LE)) + with open(full_path, "rb") as bf: + data = bf.read(len(codecs.BOM_UTF16_LE)) if data in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE): - encoding = 'utf-16' + encoding = "utf-16" else: # utf-8-sig is "utf-8 with optional BOM". It's discouraged # in most cases but is common with CSV files because Excel # cannot read utf-8 files without a BOM. - encoding = 'utf-8-sig' - if PY3: - # python 3: csv.reader requires a file open in text mode. - # Force utf8 to avoid dependence on $LANG environment variable. - f = open(full_path, "r", encoding=encoding) - else: - # python 2: csv can only handle byte strings (in ascii-compatible - # encodings), which we decode below. Transcode everything into - # utf8 before passing it to csv.reader. - f = BytesIO() - with codecs.open(full_path, "r", encoding=encoding) as infile: - f.write(escape.utf8(infile.read())) - f.seek(0) - _translations[locale] = {} - for i, row in enumerate(csv.reader(f)): - if not row or len(row) < 2: - continue - row = [escape.to_unicode(c).strip() for c in row] - english, translation = row[:2] - if len(row) > 2: - plural = row[2] or "unknown" - else: - plural = "unknown" - if plural not in ("plural", "singular", "unknown"): - gen_log.error("Unrecognized plural indicator %r in %s line %d", - plural, path, i + 1) - continue - _translations[locale].setdefault(plural, {})[english] = translation - f.close() + encoding = "utf-8-sig" + # python 3: csv.reader requires a file open in text mode. + # Specify an encoding to avoid dependence on $LANG environment variable. + with open(full_path, encoding=encoding) as f: + _translations[locale] = {} + for i, row in enumerate(csv.reader(f)): + if not row or len(row) < 2: + continue + row = [escape.to_unicode(c).strip() for c in row] + english, translation = row[:2] + if len(row) > 2: + plural = row[2] or "unknown" + else: + plural = "unknown" + if plural not in ("plural", "singular", "unknown"): + gen_log.error( + "Unrecognized plural indicator %r in %s line %d", + plural, + path, + i + 1, + ) + continue + _translations[locale].setdefault(plural, {})[english] = translation _supported_locales = frozenset(list(_translations.keys()) + [_default_locale]) gen_log.debug("Supported locales: %s", sorted(_supported_locales)) -def load_gettext_translations(directory, domain): +def load_gettext_translations(directory: str, domain: str) -> None: """Loads translations from `gettext`'s locale tree Locale tree is similar to system's ``/usr/share/locale``, like:: @@ -201,20 +195,19 @@ def load_gettext_translations(directory, domain): msgfmt mydomain.po -o {directory}/pt_BR/LC_MESSAGES/mydomain.mo """ - import gettext global _translations global _supported_locales global _use_gettext _translations = {} - for lang in os.listdir(directory): - if lang.startswith('.'): - continue # skip .svn, etc - if os.path.isfile(os.path.join(directory, lang)): - continue + + for filename in glob.glob( + os.path.join(directory, "*", "LC_MESSAGES", domain + ".mo") + ): + lang = os.path.basename(os.path.dirname(os.path.dirname(filename))) try: - os.stat(os.path.join(directory, lang, "LC_MESSAGES", domain + ".mo")) - _translations[lang] = gettext.translation(domain, directory, - languages=[lang]) + _translations[lang] = gettext.translation( + domain, directory, languages=[lang] + ) except Exception as e: gen_log.error("Cannot load translation for '%s': %s", lang, str(e)) continue @@ -223,7 +216,7 @@ def load_gettext_translations(directory, domain): 
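For reference, the two loading paths side by side; the directory names and gettext domain below are made up:

```py
# Hedged sketch: CSV translations vs. a gettext locale tree.
import tornado.locale

# CSV: a directory of files named like locale codes, e.g.
# translations/pt_BR.csv with "english,translated[,plural]" rows.
tornado.locale.load_translations("translations")

# gettext: a tree such as myapp/locale/pt_BR/LC_MESSAGES/mydomain.mo,
# compiled with msgfmt as shown in the docstring above.
tornado.locale.load_gettext_translations("myapp/locale", "mydomain")

print(tornado.locale.get_supported_locales())
```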
gen_log.debug("Supported locales: %s", sorted(_supported_locales)) -def get_supported_locales(): +def get_supported_locales() -> Iterable[str]: """Returns a list of all the supported locale codes.""" return _supported_locales @@ -234,8 +227,11 @@ class Locale(object): After calling one of `load_translations` or `load_gettext_translations`, call `get` or `get_closest` to get a Locale object. """ + + _cache = {} # type: Dict[str, Locale] + @classmethod - def get_closest(cls, *locale_codes): + def get_closest(cls, *locale_codes: str) -> "Locale": """Returns the closest match for the given locale code.""" for code in locale_codes: if not code: @@ -253,18 +249,16 @@ def get_closest(cls, *locale_codes): return cls.get(_default_locale) @classmethod - def get(cls, code): + def get(cls, code: str) -> "Locale": """Returns the Locale for the given locale code. If it is not supported, we raise an exception. """ - if not hasattr(cls, "_cache"): - cls._cache = {} if code not in cls._cache: assert code in _supported_locales translations = _translations.get(code, None) if translations is None: - locale = CSVLocale(code, {}) + locale = CSVLocale(code, {}) # type: Locale elif _use_gettext: locale = GettextLocale(code, translations) else: @@ -272,7 +266,7 @@ def get(cls, code): cls._cache[code] = locale return cls._cache[code] - def __init__(self, code, translations): + def __init__(self, code: str) -> None: self.code = code self.name = LOCALE_NAMES.get(code, {}).get("name", u"Unknown") self.rtl = False @@ -280,19 +274,39 @@ def __init__(self, code, translations): if self.code.startswith(prefix): self.rtl = True break - self.translations = translations # Initialize strings for date formatting _ = self.translate self._months = [ - _("January"), _("February"), _("March"), _("April"), - _("May"), _("June"), _("July"), _("August"), - _("September"), _("October"), _("November"), _("December")] + _("January"), + _("February"), + _("March"), + _("April"), + _("May"), + _("June"), + _("July"), + _("August"), + _("September"), + _("October"), + _("November"), + _("December"), + ] self._weekdays = [ - _("Monday"), _("Tuesday"), _("Wednesday"), _("Thursday"), - _("Friday"), _("Saturday"), _("Sunday")] - - def translate(self, message, plural_message=None, count=None): + _("Monday"), + _("Tuesday"), + _("Wednesday"), + _("Thursday"), + _("Friday"), + _("Saturday"), + _("Sunday"), + ] + + def translate( + self, + message: str, + plural_message: Optional[str] = None, + count: Optional[int] = None, + ) -> str: """Returns the translation for the given message for this locale. If ``plural_message`` is given, you must also provide @@ -302,11 +316,23 @@ def translate(self, message, plural_message=None, count=None): """ raise NotImplementedError() - def pgettext(self, context, message, plural_message=None, count=None): + def pgettext( + self, + context: str, + message: str, + plural_message: Optional[str] = None, + count: Optional[int] = None, + ) -> str: raise NotImplementedError() - def format_date(self, date, gmt_offset=0, relative=True, shorter=False, - full_format=False): + def format_date( + self, + date: Union[int, float, datetime.datetime], + gmt_offset: int = 0, + relative: bool = True, + shorter: bool = False, + full_format: bool = False, + ) -> str: """Formats the given date (which should be GMT). By default, we return a relative time (e.g., "2 minutes ago"). You @@ -318,7 +344,7 @@ def format_date(self, date, gmt_offset=0, relative=True, shorter=False, This method is primarily intended for dates in the past. 
For dates in the future, we fall back to full format. """ - if isinstance(date, numbers.Real): + if isinstance(date, (int, float)): date = datetime.datetime.utcfromtimestamp(date) now = datetime.datetime.utcnow() if date > now: @@ -342,56 +368,66 @@ def format_date(self, date, gmt_offset=0, relative=True, shorter=False, if not full_format: if relative and days == 0: if seconds < 50: - return _("1 second ago", "%(seconds)d seconds ago", - seconds) % {"seconds": seconds} + return _("1 second ago", "%(seconds)d seconds ago", seconds) % { + "seconds": seconds + } if seconds < 50 * 60: minutes = round(seconds / 60.0) - return _("1 minute ago", "%(minutes)d minutes ago", - minutes) % {"minutes": minutes} + return _("1 minute ago", "%(minutes)d minutes ago", minutes) % { + "minutes": minutes + } hours = round(seconds / (60.0 * 60)) - return _("1 hour ago", "%(hours)d hours ago", - hours) % {"hours": hours} + return _("1 hour ago", "%(hours)d hours ago", hours) % {"hours": hours} if days == 0: format = _("%(time)s") - elif days == 1 and local_date.day == local_yesterday.day and \ - relative: - format = _("yesterday") if shorter else \ - _("yesterday at %(time)s") + elif days == 1 and local_date.day == local_yesterday.day and relative: + format = _("yesterday") if shorter else _("yesterday at %(time)s") elif days < 5: - format = _("%(weekday)s") if shorter else \ - _("%(weekday)s at %(time)s") + format = _("%(weekday)s") if shorter else _("%(weekday)s at %(time)s") elif days < 334: # 11mo, since confusing for same month last year - format = _("%(month_name)s %(day)s") if shorter else \ - _("%(month_name)s %(day)s at %(time)s") + format = ( + _("%(month_name)s %(day)s") + if shorter + else _("%(month_name)s %(day)s at %(time)s") + ) if format is None: - format = _("%(month_name)s %(day)s, %(year)s") if shorter else \ - _("%(month_name)s %(day)s, %(year)s at %(time)s") + format = ( + _("%(month_name)s %(day)s, %(year)s") + if shorter + else _("%(month_name)s %(day)s, %(year)s at %(time)s") + ) tfhour_clock = self.code not in ("en", "en_US", "zh_CN") if tfhour_clock: str_time = "%d:%02d" % (local_date.hour, local_date.minute) elif self.code == "zh_CN": str_time = "%s%d:%02d" % ( - (u'\u4e0a\u5348', u'\u4e0b\u5348')[local_date.hour >= 12], - local_date.hour % 12 or 12, local_date.minute) + (u"\u4e0a\u5348", u"\u4e0b\u5348")[local_date.hour >= 12], + local_date.hour % 12 or 12, + local_date.minute, + ) else: str_time = "%d:%02d %s" % ( - local_date.hour % 12 or 12, local_date.minute, - ("am", "pm")[local_date.hour >= 12]) + local_date.hour % 12 or 12, + local_date.minute, + ("am", "pm")[local_date.hour >= 12], + ) return format % { "month_name": self._months[local_date.month - 1], "weekday": self._weekdays[local_date.weekday()], "day": str(local_date.day), "year": str(local_date.year), - "time": str_time + "time": str_time, } - def format_day(self, date, gmt_offset=0, dow=True): + def format_day( + self, date: datetime.datetime, gmt_offset: int = 0, dow: bool = True + ) -> str: """Formats the given date as a day of week. Example: "Monday, January 22". You can remove the day of week with @@ -411,7 +447,7 @@ def format_day(self, date, gmt_offset=0, dow=True): "day": str(local_date.day), } - def list(self, parts): + def list(self, parts: Any) -> str: """Returns a comma-separated list for the given list of parts. 
The format is, e.g., "A, B and C", "A and B" or just "A" for lists @@ -422,27 +458,37 @@ def list(self, parts): return "" if len(parts) == 1: return parts[0] - comma = u' \u0648 ' if self.code.startswith("fa") else u", " + comma = u" \u0648 " if self.code.startswith("fa") else u", " return _("%(commas)s and %(last)s") % { "commas": comma.join(parts[:-1]), "last": parts[len(parts) - 1], } - def friendly_number(self, value): + def friendly_number(self, value: int) -> str: """Returns a comma-separated number for the given integer.""" if self.code not in ("en", "en_US"): return str(value) - value = str(value) + s = str(value) parts = [] - while value: - parts.append(value[-3:]) - value = value[:-3] + while s: + parts.append(s[-3:]) + s = s[:-3] return ",".join(reversed(parts)) class CSVLocale(Locale): """Locale implementation using tornado's CSV translation format.""" - def translate(self, message, plural_message=None, count=None): + + def __init__(self, code: str, translations: Dict[str, Dict[str, str]]) -> None: + self.translations = translations + super().__init__(code) + + def translate( + self, + message: str, + plural_message: Optional[str] = None, + count: Optional[int] = None, + ) -> str: if plural_message is not None: assert count is not None if count != 1: @@ -454,35 +500,47 @@ def translate(self, message, plural_message=None, count=None): message_dict = self.translations.get("unknown", {}) return message_dict.get(message, message) - def pgettext(self, context, message, plural_message=None, count=None): + def pgettext( + self, + context: str, + message: str, + plural_message: Optional[str] = None, + count: Optional[int] = None, + ) -> str: if self.translations: - gen_log.warning('pgettext is not supported by CSVLocale') + gen_log.warning("pgettext is not supported by CSVLocale") return self.translate(message, plural_message, count) class GettextLocale(Locale): """Locale implementation using the `gettext` module.""" - def __init__(self, code, translations): - try: - # python 2 - self.ngettext = translations.ungettext - self.gettext = translations.ugettext - except AttributeError: - # python 3 - self.ngettext = translations.ngettext - self.gettext = translations.gettext + + def __init__(self, code: str, translations: gettext.NullTranslations) -> None: + self.ngettext = translations.ngettext + self.gettext = translations.gettext # self.gettext must exist before __init__ is called, since it # calls into self.translate - super(GettextLocale, self).__init__(code, translations) - - def translate(self, message, plural_message=None, count=None): + super().__init__(code) + + def translate( + self, + message: str, + plural_message: Optional[str] = None, + count: Optional[int] = None, + ) -> str: if plural_message is not None: assert count is not None return self.ngettext(message, plural_message, count) else: return self.gettext(message) - def pgettext(self, context, message, plural_message=None, count=None): + def pgettext( + self, + context: str, + message: str, + plural_message: Optional[str] = None, + count: Optional[int] = None, + ) -> str: """Allows setting context for translation; accepts plural forms. 
Usage example:: @@ -504,9 +562,11 @@ def pgettext(self, context, message, plural_message=None, count=None): """ if plural_message is not None: assert count is not None - msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message), - "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message), - count) + msgs_with_ctxt = ( + "%s%s%s" % (context, CONTEXT_SEPARATOR, message), + "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message), + count, + ) result = self.ngettext(*msgs_with_ctxt) if CONTEXT_SEPARATOR in result: # Translation not found diff --git a/tornado/locks.py b/tornado/locks.py index 94adb322e1..29b6b41299 100644 --- a/tornado/locks.py +++ b/tornado/locks.py @@ -12,15 +12,20 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function - import collections -from concurrent.futures import CancelledError +import datetime +import types from tornado import gen, ioloop from tornado.concurrent import Future, future_set_result_unless_cancelled -__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock'] +from typing import Union, Optional, Type, Any, Awaitable +import typing + +if typing.TYPE_CHECKING: + from typing import Deque, Set # noqa: F401 + +__all__ = ["Condition", "Event", "Semaphore", "BoundedSemaphore", "Lock"] class _TimeoutGarbageCollector(object): @@ -32,17 +37,17 @@ class _TimeoutGarbageCollector(object): yield condition.wait(short_timeout) print('looping....') """ - def __init__(self): - self._waiters = collections.deque() # Futures. + + def __init__(self) -> None: + self._waiters = collections.deque() # type: Deque[Future] self._timeouts = 0 - def _garbage_collect(self): + def _garbage_collect(self) -> None: # Occasionally clear timed-out waiters. self._timeouts += 1 if self._timeouts > 100: self._timeouts = 0 - self._waiters = collections.deque( - w for w in self._waiters if not w.done()) + self._waiters = collections.deque(w for w in self._waiters if not w.done()) class Condition(_TimeoutGarbageCollector): @@ -61,22 +66,19 @@ class Condition(_TimeoutGarbageCollector): condition = Condition() - @gen.coroutine - def waiter(): + async def waiter(): print("I'll wait right here") - yield condition.wait() # Yield a Future. + await condition.wait() print("I'm done waiting") - @gen.coroutine - def notifier(): + async def notifier(): print("About to notify") condition.notify() print("Done notifying") - @gen.coroutine - def runner(): - # Yield two Futures; wait for waiter() and notifier() to finish. - yield [waiter(), notifier()] + async def runner(): + # Wait for waiter() and notifier() in parallel + await gen.multi([waiter(), notifier()]) IOLoop.current().run_sync(runner) @@ -93,12 +95,12 @@ def runner(): io_loop = IOLoop.current() # Wait up to 1 second for a notification. - yield condition.wait(timeout=io_loop.time() + 1) + await condition.wait(timeout=io_loop.time() + 1) ...or a `datetime.timedelta` for a timeout relative to the current time:: # Wait up to 1 second. - yield condition.wait(timeout=datetime.timedelta(seconds=1)) + await condition.wait(timeout=datetime.timedelta(seconds=1)) The method returns False if there's no notification before the deadline. @@ -108,36 +110,39 @@ def runner(): next iteration of the `.IOLoop`. 
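A small sketch of the timeout behavior described above: `Condition.wait` resolves to ``False`` on timeout rather than raising.

```py
# Condition.wait(timeout=...) resolves to False on timeout, so callers
# branch on the result instead of catching an exception.
import datetime

from tornado.ioloop import IOLoop
from tornado.locks import Condition

async def main():
    condition = Condition()
    notified = await condition.wait(timeout=datetime.timedelta(seconds=0.1))
    print("notified" if notified else "timed out")

IOLoop.current().run_sync(main)
```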
""" - def __init__(self): - super(Condition, self).__init__() + def __init__(self) -> None: + super().__init__() self.io_loop = ioloop.IOLoop.current() - def __repr__(self): - result = '<%s' % (self.__class__.__name__, ) + def __repr__(self) -> str: + result = "<%s" % (self.__class__.__name__,) if self._waiters: - result += ' waiters[%s]' % len(self._waiters) - return result + '>' + result += " waiters[%s]" % len(self._waiters) + return result + ">" - def wait(self, timeout=None): + def wait( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[bool]: """Wait for `.notify`. Returns a `.Future` that resolves ``True`` if the condition is notified, or ``False`` after a timeout. """ - waiter = Future() + waiter = Future() # type: Future[bool] self._waiters.append(waiter) if timeout: - def on_timeout(): + + def on_timeout() -> None: if not waiter.done(): future_set_result_unless_cancelled(waiter, False) self._garbage_collect() + io_loop = ioloop.IOLoop.current() timeout_handle = io_loop.add_timeout(timeout, on_timeout) - waiter.add_done_callback( - lambda _: io_loop.remove_timeout(timeout_handle)) + waiter.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle)) return waiter - def notify(self, n=1): + def notify(self, n: int = 1) -> None: """Wake ``n`` waiters.""" waiters = [] # Waiters we plan to run right now. while n and self._waiters: @@ -149,7 +154,7 @@ def notify(self, n=1): for waiter in waiters: future_set_result_unless_cancelled(waiter, True) - def notify_all(self): + def notify_all(self) -> None: """Wake all waiters.""" self.notify(len(self._waiters)) @@ -170,22 +175,19 @@ class Event(object): event = Event() - @gen.coroutine - def waiter(): + async def waiter(): print("Waiting for event") - yield event.wait() + await event.wait() print("Not waiting this time") - yield event.wait() + await event.wait() print("Done") - @gen.coroutine - def setter(): + async def setter(): print("About to set the event") event.set() - @gen.coroutine - def runner(): - yield [waiter(), setter()] + async def runner(): + await gen.multi([waiter(), setter()]) IOLoop.current().run_sync(runner) @@ -196,19 +198,22 @@ def runner(): Not waiting this time Done """ - def __init__(self): + + def __init__(self) -> None: self._value = False - self._waiters = set() + self._waiters = set() # type: Set[Future[None]] - def __repr__(self): - return '<%s %s>' % ( - self.__class__.__name__, 'set' if self.is_set() else 'clear') + def __repr__(self) -> str: + return "<%s %s>" % ( + self.__class__.__name__, + "set" if self.is_set() else "clear", + ) - def is_set(self): + def is_set(self) -> bool: """Return ``True`` if the internal flag is true.""" return self._value - def set(self): + def set(self) -> None: """Set the internal flag to ``True``. All waiters are awakened. Calling `.wait` once the flag is set will not block. @@ -220,20 +225,22 @@ def set(self): if not fut.done(): fut.set_result(None) - def clear(self): + def clear(self) -> None: """Reset the internal flag to ``False``. Calls to `.wait` will block until `.set` is called. """ self._value = False - def wait(self, timeout=None): + def wait( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[None]: """Block until the internal flag is true. - Returns a Future, which raises `tornado.util.TimeoutError` after a + Returns an awaitable, which raises `tornado.util.TimeoutError` after a timeout. 
""" - fut = Future() + fut = Future() # type: Future[None] if self._value: fut.set_result(None) return fut @@ -242,29 +249,37 @@ def wait(self, timeout=None): if timeout is None: return fut else: - timeout_fut = gen.with_timeout(timeout, fut, quiet_exceptions=(CancelledError,)) + timeout_fut = gen.with_timeout(timeout, fut) # This is a slightly clumsy workaround for the fact that # gen.with_timeout doesn't cancel its futures. Cancelling # fut will remove it from the waiters list. - timeout_fut.add_done_callback(lambda tf: fut.cancel() if not fut.done() else None) + timeout_fut.add_done_callback( + lambda tf: fut.cancel() if not fut.done() else None + ) return timeout_fut class _ReleasingContextManager(object): """Releases a Lock or Semaphore at the end of a "with" statement. - with (yield semaphore.acquire()): - pass + with (yield semaphore.acquire()): + pass - # Now semaphore.release() has been called. + # Now semaphore.release() has been called. """ - def __init__(self, obj): + + def __init__(self, obj: Any) -> None: self._obj = obj - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: "Optional[Type[BaseException]]", + exc_val: Optional[BaseException], + exc_tb: Optional[types.TracebackType], + ) -> None: self._obj.release() @@ -290,12 +305,11 @@ class Semaphore(_TimeoutGarbageCollector): # Ensure reliable doctest output: resolve Futures one at a time. futures_q = deque([Future() for _ in range(3)]) - @gen.coroutine - def simulator(futures): + async def simulator(futures): for f in futures: # simulate the asynchronous passage of time - yield gen.moment - yield gen.moment + await gen.sleep(0) + await gen.sleep(0) f.set_result(None) IOLoop.current().add_callback(simulator, list(futures_q)) @@ -311,20 +325,18 @@ def use_some_resource(): sem = Semaphore(2) - @gen.coroutine - def worker(worker_id): - yield sem.acquire() + async def worker(worker_id): + await sem.acquire() try: print("Worker %d is working" % worker_id) - yield use_some_resource() + await use_some_resource() finally: print("Worker %d is done" % worker_id) sem.release() - @gen.coroutine - def runner(): + async def runner(): # Join all workers. - yield [worker(i) for i in range(3)] + await gen.multi([worker(i) for i in range(3)]) IOLoop.current().run_sync(runner) @@ -340,47 +352,50 @@ def runner(): Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until the semaphore has been released once, by worker 0. - `.acquire` is a context manager, so ``worker`` could be written as:: + The semaphore can be used as an async context manager:: - @gen.coroutine - def worker(worker_id): - with (yield sem.acquire()): + async def worker(worker_id): + async with sem: print("Worker %d is working" % worker_id) - yield use_some_resource() + await use_some_resource() # Now the semaphore has been released. print("Worker %d is done" % worker_id) - In Python 3.5, the semaphore itself can be used as an async context - manager:: + For compatibility with older versions of Python, `.acquire` is a + context manager, so ``worker`` could also be written as:: - async def worker(worker_id): - async with sem: + @gen.coroutine + def worker(worker_id): + with (yield sem.acquire()): print("Worker %d is working" % worker_id) - await use_some_resource() + yield use_some_resource() # Now the semaphore has been released. print("Worker %d is done" % worker_id) .. versionchanged:: 4.3 Added ``async with`` support in Python 3.5. 
+ """ - def __init__(self, value=1): - super(Semaphore, self).__init__() + + def __init__(self, value: int = 1) -> None: + super().__init__() if value < 0: - raise ValueError('semaphore initial value must be >= 0') + raise ValueError("semaphore initial value must be >= 0") self._value = value - def __repr__(self): - res = super(Semaphore, self).__repr__() - extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format( - self._value) + def __repr__(self) -> str: + res = super().__repr__() + extra = ( + "locked" if self._value == 0 else "unlocked,value:{0}".format(self._value) + ) if self._waiters: - extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) - return '<{0} [{1}]>'.format(res[1:-1], extra) + extra = "{0},waiters:{1}".format(extra, len(self._waiters)) + return "<{0} [{1}]>".format(res[1:-1], extra) - def release(self): + def release(self) -> None: """Increment the counter and wake one waiter.""" self._value += 1 while self._waiters: @@ -397,42 +412,54 @@ def release(self): waiter.set_result(_ReleasingContextManager(self)) break - def acquire(self, timeout=None): - """Decrement the counter. Returns a Future. + def acquire( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[_ReleasingContextManager]: + """Decrement the counter. Returns an awaitable. - Block if the counter is zero and wait for a `.release`. The Future + Block if the counter is zero and wait for a `.release`. The awaitable raises `.TimeoutError` after the deadline. """ - waiter = Future() + waiter = Future() # type: Future[_ReleasingContextManager] if self._value > 0: self._value -= 1 waiter.set_result(_ReleasingContextManager(self)) else: self._waiters.append(waiter) if timeout: - def on_timeout(): + + def on_timeout() -> None: if not waiter.done(): waiter.set_exception(gen.TimeoutError()) self._garbage_collect() + io_loop = ioloop.IOLoop.current() timeout_handle = io_loop.add_timeout(timeout, on_timeout) waiter.add_done_callback( - lambda _: io_loop.remove_timeout(timeout_handle)) + lambda _: io_loop.remove_timeout(timeout_handle) + ) return waiter - def __enter__(self): - raise RuntimeError( - "Use Semaphore like 'with (yield semaphore.acquire())', not like" - " 'with semaphore'") - - __exit__ = __enter__ - - @gen.coroutine - def __aenter__(self): - yield self.acquire() - - @gen.coroutine - def __aexit__(self, typ, value, tb): + def __enter__(self) -> None: + raise RuntimeError("Use 'async with' instead of 'with' for Semaphore") + + def __exit__( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> None: + self.__enter__() + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[types.TracebackType], + ) -> None: self.release() @@ -444,15 +471,16 @@ class BoundedSemaphore(Semaphore): resources with limited capacity, so a semaphore released too many times is a sign of a bug. 
""" - def __init__(self, value=1): - super(BoundedSemaphore, self).__init__(value=value) + + def __init__(self, value: int = 1) -> None: + super().__init__(value=value) self._initial_value = value - def release(self): + def release(self) -> None: """Increment the counter and wake one waiter.""" if self._value >= self._initial_value: raise ValueError("Semaphore released too many times") - super(BoundedSemaphore, self).release() + super().release() class Lock(object): @@ -464,26 +492,24 @@ class Lock(object): Releasing an unlocked lock raises `RuntimeError`. - `acquire` supports the context manager protocol in all Python versions: + A Lock can be used as an async context manager with the ``async + with`` statement: - >>> from tornado import gen, locks + >>> from tornado import locks >>> lock = locks.Lock() >>> - >>> @gen.coroutine - ... def f(): - ... with (yield lock.acquire()): + >>> async def f(): + ... async with lock: ... # Do something holding the lock. ... pass ... ... # Now the lock is released. - In Python 3.5, `Lock` also supports the async context manager - protocol. Note that in this case there is no `acquire`, because - ``async with`` includes both the ``yield`` and the ``acquire`` - (just as it does with `threading.Lock`): + For compatibility with older versions of Python, the `.acquire` + method asynchronously returns a regular context manager: - >>> async def f2(): # doctest: +SKIP - ... async with lock: + >>> async def f2(): + ... with (yield lock.acquire()): ... # Do something holding the lock. ... pass ... @@ -493,23 +519,24 @@ class Lock(object): Added ``async with`` support in Python 3.5. """ - def __init__(self): + + def __init__(self) -> None: self._block = BoundedSemaphore(value=1) - def __repr__(self): - return "<%s _block=%s>" % ( - self.__class__.__name__, - self._block) + def __repr__(self) -> str: + return "<%s _block=%s>" % (self.__class__.__name__, self._block) - def acquire(self, timeout=None): - """Attempt to lock. Returns a Future. + def acquire( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[_ReleasingContextManager]: + """Attempt to lock. Returns an awaitable. - Returns a Future, which raises `tornado.util.TimeoutError` after a + Returns an awaitable, which raises `tornado.util.TimeoutError` after a timeout. """ return self._block.acquire(timeout) - def release(self): + def release(self) -> None: """Unlock. The first coroutine in line waiting for `acquire` gets the lock. 
@@ -519,18 +546,26 @@ def release(self): try: self._block.release() except ValueError: - raise RuntimeError('release unlocked lock') - - def __enter__(self): - raise RuntimeError( - "Use Lock like 'with (yield lock)', not like 'with lock'") - - __exit__ = __enter__ - - @gen.coroutine - def __aenter__(self): - yield self.acquire() - - @gen.coroutine - def __aexit__(self, typ, value, tb): + raise RuntimeError("release unlocked lock") + + def __enter__(self) -> None: + raise RuntimeError("Use `async with` instead of `with` for Lock") + + def __exit__( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[types.TracebackType], + ) -> None: + self.__enter__() + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[types.TracebackType], + ) -> None: self.release() diff --git a/tornado/log.py b/tornado/log.py index cda905c9ba..810a0373b5 100644 --- a/tornado/log.py +++ b/tornado/log.py @@ -27,8 +27,6 @@ `logging` module. For example, you may wish to send ``tornado.access`` logs to a separate file for analysis. """ -from __future__ import absolute_import, division, print_function - import logging import logging.handlers import sys @@ -37,14 +35,16 @@ from tornado.util import unicode_type, basestring_type try: - import colorama + import colorama # type: ignore except ImportError: colorama = None try: - import curses # type: ignore + import curses except ImportError: - curses = None + curses = None # type: ignore + +from typing import Dict, Any, cast, Optional # Logger objects for internal tornado use access_log = logging.getLogger("tornado.access") @@ -52,16 +52,17 @@ gen_log = logging.getLogger("tornado.general") -def _stderr_supports_color(): +def _stderr_supports_color() -> bool: try: - if hasattr(sys.stderr, 'isatty') and sys.stderr.isatty(): + if hasattr(sys.stderr, "isatty") and sys.stderr.isatty(): if curses: curses.setupterm() if curses.tigetnum("colors") > 0: return True elif colorama: - if sys.stderr is getattr(colorama.initialise, 'wrapped_stderr', - object()): + if sys.stderr is getattr( + colorama.initialise, "wrapped_stderr", object() + ): return True except Exception: # Very broad exception handling because it's always better to @@ -70,7 +71,7 @@ def _stderr_supports_color(): return False -def _safe_unicode(s): +def _safe_unicode(s: Any) -> str: try: return _unicode(s) except UnicodeDecodeError: @@ -101,18 +102,25 @@ class LogFormatter(logging.Formatter): Added support for ``colorama``. Changed the constructor signature to be compatible with `logging.config.dictConfig`. """ - DEFAULT_FORMAT = \ - '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s' - DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S' + + DEFAULT_FORMAT = "%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s" # noqa: E501 + DEFAULT_DATE_FORMAT = "%y%m%d %H:%M:%S" DEFAULT_COLORS = { logging.DEBUG: 4, # Blue logging.INFO: 2, # Green logging.WARNING: 3, # Yellow logging.ERROR: 1, # Red + logging.CRITICAL: 5, # Magenta } - def __init__(self, fmt=DEFAULT_FORMAT, datefmt=DEFAULT_DATE_FORMAT, - style='%', color=True, colors=DEFAULT_COLORS): + def __init__( + self, + fmt: str = DEFAULT_FORMAT, + datefmt: str = DEFAULT_DATE_FORMAT, + style: str = "%", + color: bool = True, + colors: Dict[int, int] = DEFAULT_COLORS, + ) -> None: r""" :arg bool color: Enables color support. 
:arg str fmt: Log message format. @@ -131,34 +139,29 @@ def __init__(self, fmt=DEFAULT_FORMAT, datefmt=DEFAULT_DATE_FORMAT, logging.Formatter.__init__(self, datefmt=datefmt) self._fmt = fmt - self._colors = {} + self._colors = {} # type: Dict[int, str] if color and _stderr_supports_color(): if curses is not None: - # The curses module has some str/bytes confusion in - # python3. Until version 3.2.3, most methods return - # bytes, but only accept strings. In addition, we want to - # output these strings with the logging module, which - # works with unicode strings. The explicit calls to - # unicode() below are harmless in python2 but will do the - # right conversion in python 3. - fg_color = (curses.tigetstr("setaf") or - curses.tigetstr("setf") or "") - if (3, 0) < sys.version_info < (3, 2, 3): - fg_color = unicode_type(fg_color, "ascii") + fg_color = curses.tigetstr("setaf") or curses.tigetstr("setf") or b"" for levelno, code in colors.items(): - self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii") + # Convert the terminal control characters from + # bytes to unicode strings for easier use with the + # logging module. + self._colors[levelno] = unicode_type( + curses.tparm(fg_color, code), "ascii" + ) self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii") else: # If curses is not present (currently we'll only get here for # colorama on windows), assume hard-coded ANSI color codes. for levelno, code in colors.items(): - self._colors[levelno] = '\033[2;3%dm' % code - self._normal = '\033[0m' + self._colors[levelno] = "\033[2;3%dm" % code + self._normal = "\033[0m" else: - self._normal = '' + self._normal = "" - def format(self, record): + def format(self, record: Any) -> str: try: message = record.getMessage() assert isinstance(message, basestring_type) # guaranteed by logging @@ -182,13 +185,13 @@ def format(self, record): except Exception as e: record.message = "Bad message (%r): %r" % (e, record.__dict__) - record.asctime = self.formatTime(record, self.datefmt) + record.asctime = self.formatTime(record, cast(str, self.datefmt)) if record.levelno in self._colors: record.color = self._colors[record.levelno] record.end_color = self._normal else: - record.color = record.end_color = '' + record.color = record.end_color = "" formatted = self._fmt % record.__dict__ @@ -200,12 +203,14 @@ def format(self, record): # each line separately so that non-utf8 bytes don't cause # all the newlines to turn into '\n'. lines = [formatted.rstrip()] - lines.extend(_safe_unicode(ln) for ln in record.exc_text.split('\n')) - formatted = '\n'.join(lines) + lines.extend(_safe_unicode(ln) for ln in record.exc_text.split("\n")) + formatted = "\n".join(lines) return formatted.replace("\n", "\n ") -def enable_pretty_logging(options=None, logger=None): +def enable_pretty_logging( + options: Any = None, logger: Optional[logging.Logger] = None +) -> None: """Turns on formatted logging output as configured. 
This is called automatically by `tornado.options.parse_command_line` @@ -213,41 +218,47 @@ def enable_pretty_logging(options=None, logger=None): """ if options is None: import tornado.options + options = tornado.options.options - if options.logging is None or options.logging.lower() == 'none': + if options.logging is None or options.logging.lower() == "none": return if logger is None: logger = logging.getLogger() logger.setLevel(getattr(logging, options.logging.upper())) if options.log_file_prefix: rotate_mode = options.log_rotate_mode - if rotate_mode == 'size': + if rotate_mode == "size": channel = logging.handlers.RotatingFileHandler( filename=options.log_file_prefix, maxBytes=options.log_file_max_size, - backupCount=options.log_file_num_backups) - elif rotate_mode == 'time': + backupCount=options.log_file_num_backups, + encoding="utf-8", + ) # type: logging.Handler + elif rotate_mode == "time": channel = logging.handlers.TimedRotatingFileHandler( filename=options.log_file_prefix, when=options.log_rotate_when, interval=options.log_rotate_interval, - backupCount=options.log_file_num_backups) + backupCount=options.log_file_num_backups, + encoding="utf-8", + ) else: - error_message = 'The value of log_rotate_mode option should be ' +\ - '"size" or "time", not "%s".' % rotate_mode + error_message = ( + "The value of log_rotate_mode option should be " + + '"size" or "time", not "%s".' % rotate_mode + ) raise ValueError(error_message) channel.setFormatter(LogFormatter(color=False)) logger.addHandler(channel) - if (options.log_to_stderr or - (options.log_to_stderr is None and not logger.handlers)): + if options.log_to_stderr or (options.log_to_stderr is None and not logger.handlers): # Set up color if we are in a tty and curses is installed channel = logging.StreamHandler() channel.setFormatter(LogFormatter()) logger.addHandler(channel) -def define_logging_options(options=None): +def define_logging_options(options: Any = None) -> None: """Add logging-related flags to ``options``. These options are present automatically on the default options instance; @@ -259,32 +270,70 @@ def define_logging_options(options=None): if options is None: # late import to prevent cycle import tornado.options + options = tornado.options.options - options.define("logging", default="info", - help=("Set the Python log level. If 'none', tornado won't touch the " - "logging configuration."), - metavar="debug|info|warning|error|none") - options.define("log_to_stderr", type=bool, default=None, - help=("Send log output to stderr (colorized if possible). " - "By default use stderr if --log_file_prefix is not set and " - "no other logging is configured.")) - options.define("log_file_prefix", type=str, default=None, metavar="PATH", - help=("Path prefix for log files. " - "Note that if you are running multiple tornado processes, " - "log_file_prefix must be different for each of them (e.g. 
" - "include the port number)")) - options.define("log_file_max_size", type=int, default=100 * 1000 * 1000, - help="max size of log files before rollover") - options.define("log_file_num_backups", type=int, default=10, - help="number of log files to keep") - - options.define("log_rotate_when", type=str, default='midnight', - help=("specify the type of TimedRotatingFileHandler interval " - "other options:('S', 'M', 'H', 'D', 'W0'-'W6')")) - options.define("log_rotate_interval", type=int, default=1, - help="The interval value of timed rotating") - - options.define("log_rotate_mode", type=str, default='size', - help="The mode of rotating files(time or size)") + options.define( + "logging", + default="info", + help=( + "Set the Python log level. If 'none', tornado won't touch the " + "logging configuration." + ), + metavar="debug|info|warning|error|none", + ) + options.define( + "log_to_stderr", + type=bool, + default=None, + help=( + "Send log output to stderr (colorized if possible). " + "By default use stderr if --log_file_prefix is not set and " + "no other logging is configured." + ), + ) + options.define( + "log_file_prefix", + type=str, + default=None, + metavar="PATH", + help=( + "Path prefix for log files. " + "Note that if you are running multiple tornado processes, " + "log_file_prefix must be different for each of them (e.g. " + "include the port number)" + ), + ) + options.define( + "log_file_max_size", + type=int, + default=100 * 1000 * 1000, + help="max size of log files before rollover", + ) + options.define( + "log_file_num_backups", type=int, default=10, help="number of log files to keep" + ) + + options.define( + "log_rotate_when", + type=str, + default="midnight", + help=( + "specify the type of TimedRotatingFileHandler interval " + "other options:('S', 'M', 'H', 'D', 'W0'-'W6')" + ), + ) + options.define( + "log_rotate_interval", + type=int, + default=1, + help="The interval value of timed rotating", + ) + + options.define( + "log_rotate_mode", + type=str, + default="size", + help="The mode of rotating files(time or size)", + ) options.add_parse_callback(lambda: enable_pretty_logging(options)) diff --git a/tornado/netutil.py b/tornado/netutil.py index 08c9d88627..f8a3038051 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -15,70 +15,51 @@ """Miscellaneous network utility code.""" -from __future__ import absolute_import, division, print_function - +import concurrent.futures import errno import os import sys import socket +import ssl import stat from tornado.concurrent import dummy_executor, run_on_executor -from tornado import gen from tornado.ioloop import IOLoop -from tornado.platform.auto import set_close_exec -from tornado.util import PY3, Configurable, errno_from_exception - -try: - import ssl -except ImportError: - # ssl is not available on Google App Engine - ssl = None - -if PY3: - xrange = range - -if ssl is not None: - # Note that the naming of ssl.Purpose is confusing; the purpose - # of a context is to authentiate the opposite side of the connection. 
- _client_ssl_defaults = ssl.create_default_context( - ssl.Purpose.SERVER_AUTH) - _server_ssl_defaults = ssl.create_default_context( - ssl.Purpose.CLIENT_AUTH) - if hasattr(ssl, 'OP_NO_COMPRESSION'): - # See netutil.ssl_options_to_context - _client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION - _server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION -else: - # Google App Engine - _client_ssl_defaults = dict(cert_reqs=None, - ca_certs=None) - _server_ssl_defaults = {} +from tornado.util import Configurable, errno_from_exception + +from typing import List, Callable, Any, Type, Dict, Union, Tuple, Awaitable, Optional + +# Note that the naming of ssl.Purpose is confusing; the purpose +# of a context is to authenticate the opposite side of the connection. +_client_ssl_defaults = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) +_server_ssl_defaults = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) +if hasattr(ssl, "OP_NO_COMPRESSION"): + # See netutil.ssl_options_to_context + _client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION + _server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION # ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode, # getaddrinfo attempts to import encodings.idna. If this is done at # module-import time, the import lock is already held by the main thread, # leading to deadlock. Avoid it by caching the idna encoder on the main # thread now. -u'foo'.encode('idna') +u"foo".encode("idna") # For undiagnosed reasons, 'latin1' codec may also need to be preloaded. -u'foo'.encode('latin1') - -# These errnos indicate that a non-blocking operation must be retried -# at a later time. On most platforms they're the same value, but on -# some they differ. -_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN) - -if hasattr(errno, "WSAEWOULDBLOCK"): - _ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) # type: ignore +u"foo".encode("latin1") # Default backlog used when calling sock.listen() _DEFAULT_BACKLOG = 128 -def bind_sockets(port, address=None, family=socket.AF_UNSPEC, - backlog=_DEFAULT_BACKLOG, flags=None, reuse_port=False): +def bind_sockets( + port: int, + address: Optional[str] = None, + family: socket.AddressFamily = socket.AF_UNSPEC, + backlog: int = _DEFAULT_BACKLOG, + flags: Optional[int] = None, + reuse_port: bool = False, +) -> List[socket.socket]: """Creates listening sockets bound to the given port and address. Returns a list of socket objects (multiple sockets are returned if @@ -118,11 +99,23 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, if flags is None: flags = socket.AI_PASSIVE bound_port = None - for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, - 0, flags)): + unique_addresses = set() # type: set + for res in sorted( + socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, 0, flags), + key=lambda x: x[0], + ): + if res in unique_addresses: + continue + + unique_addresses.add(res) + af, socktype, proto, canonname, sockaddr = res - if (sys.platform == 'darwin' and address == 'localhost' and - af == socket.AF_INET6 and sockaddr[3] != 0): + if ( + sys.platform == "darwin" + and address == "localhost" + and af == socket.AF_INET6 + and sockaddr[3] != 0 + ): # Mac OS X includes a link-local address fe80::1%lo0 in the # getaddrinfo results for 'localhost'. 
However, the firewall # doesn't understand that this is a local address and will @@ -136,9 +129,13 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, if errno_from_exception(e) == errno.EAFNOSUPPORT: continue raise - set_close_exec(sock.fileno()) - if os.name != 'nt': - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if os.name != "nt": + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + except socket.error as e: + if errno_from_exception(e) != errno.ENOPROTOOPT: + # Hurd doesn't support SO_REUSEADDR. + raise if reuse_port: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) if af == socket.AF_INET6: @@ -159,16 +156,41 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, if requested_port == 0 and bound_port is not None: sockaddr = tuple([host, bound_port] + list(sockaddr[2:])) - sock.setblocking(0) - sock.bind(sockaddr) + sock.setblocking(False) + try: + sock.bind(sockaddr) + except OSError as e: + if ( + errno_from_exception(e) == errno.EADDRNOTAVAIL + and address == "localhost" + and sockaddr[0] == "::1" + ): + # On some systems (most notably docker with default + # configurations), ipv6 is partially disabled: + # socket.has_ipv6 is true, we can create AF_INET6 + # sockets, and getaddrinfo("localhost", ..., + # AI_PASSIVE) resolves to ::1, but we get an error + # when binding. + # + # Swallow the error, but only for this specific case. + # If EADDRNOTAVAIL occurs in other situations, it + # might be a real problem like a typo in a + # configuration. + sock.close() + continue + else: + raise bound_port = sock.getsockname()[1] sock.listen(backlog) sockets.append(sock) return sockets -if hasattr(socket, 'AF_UNIX'): - def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG): +if hasattr(socket, "AF_UNIX"): + + def bind_unix_socket( + file: str, mode: int = 0o600, backlog: int = _DEFAULT_BACKLOG + ) -> socket.socket: """Creates a listening unix socket. If a socket with the given name already exists, it will be deleted. @@ -179,14 +201,17 @@ def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG): `bind_sockets`) """ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - set_close_exec(sock.fileno()) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.setblocking(0) try: - st = os.stat(file) - except OSError as err: - if errno_from_exception(err) != errno.ENOENT: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + except socket.error as e: + if errno_from_exception(e) != errno.ENOPROTOOPT: + # Hurd doesn't support SO_REUSEADDR raise + sock.setblocking(False) + try: + st = os.stat(file) + except FileNotFoundError: + pass else: if stat.S_ISSOCK(st.st_mode): os.remove(file) @@ -198,7 +223,9 @@ def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG): return sock -def add_accept_handler(sock, callback): +def add_accept_handler( + sock: socket.socket, callback: Callable[[socket.socket, Any], None] +) -> Callable[[], None]: """Adds an `.IOLoop` event handler to accept new connections on ``sock``. When a connection is accepted, ``callback(connection, address)`` will @@ -219,7 +246,7 @@ def add_accept_handler(sock, callback): io_loop = IOLoop.current() removed = [False] - def accept_handler(fd, events): + def accept_handler(fd: socket.socket, events: int) -> None: # More connections may come in while we're handling callbacks; # to prevent starvation of other tasks we must limit the number # of connections we accept at a time.
Ideally we would accept @@ -231,27 +258,24 @@ def accept_handler(fd, events): # Instead, we use the (default) listen backlog as a rough # heuristic for the number of connections we can reasonably # accept at once. - for i in xrange(_DEFAULT_BACKLOG): + for i in range(_DEFAULT_BACKLOG): if removed[0]: # The socket was probably closed return try: connection, address = sock.accept() - except socket.error as e: - # _ERRNO_WOULDBLOCK indicate we have accepted every + except BlockingIOError: + # EWOULDBLOCK indicates we have accepted every # connection that is available. - if errno_from_exception(e) in _ERRNO_WOULDBLOCK: - return + return + except ConnectionAbortedError: # ECONNABORTED indicates that there was a connection # but it was closed while still in the accept queue. # (observed on FreeBSD). - if errno_from_exception(e) == errno.ECONNABORTED: - continue - raise - set_close_exec(connection.fileno()) + continue callback(connection, address) - def remove_handler(): + def remove_handler() -> None: io_loop.remove_handler(sock) removed[0] = True @@ -259,19 +283,19 @@ def remove_handler(): return remove_handler -def is_valid_ip(ip): - """Returns true if the given string is a well-formed IP address. +def is_valid_ip(ip: str) -> bool: + """Returns ``True`` if the given string is a well-formed IP address. Supports IPv4 and IPv6. """ - if not ip or '\x00' in ip: + if not ip or "\x00" in ip: # getaddrinfo resolves empty strings to localhost, and truncates # on zero bytes. return False try: - res = socket.getaddrinfo(ip, 0, socket.AF_UNSPEC, - socket.SOCK_STREAM, - 0, socket.AI_NUMERICHOST) + res = socket.getaddrinfo( + ip, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_NUMERICHOST + ) return bool(res) except socket.gaierror as e: if e.args[0] == socket.EAI_NONAME: @@ -303,15 +327,18 @@ class method:: The default implementation has changed from `BlockingResolver` to `DefaultExecutorResolver`. """ + @classmethod - def configurable_base(cls): + def configurable_base(cls) -> Type["Resolver"]: return Resolver @classmethod - def configurable_default(cls): + def configurable_default(cls) -> Type["Resolver"]: return DefaultExecutorResolver - def resolve(self, host, port, family=socket.AF_UNSPEC, callback=None): + def resolve( + self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC + ) -> Awaitable[List[Tuple[int, Any]]]: """Resolves an address. The ``host`` argument is a string which may be a hostname or a @@ -329,13 +356,13 @@ def resolve(self, host, port, family=socket.AF_UNSPEC, callback=None): .. versionchanged:: 4.4 Standardized all implementations to raise `IOError`. - .. deprecated:: 5.1 - The ``callback`` argument is deprecated and will be removed in 6.0. + .. versionchanged:: 6.0 The ``callback`` argument was removed. Use the returned awaitable object instead. + """ raise NotImplementedError() - def close(self): + def close(self) -> None: """Closes the `Resolver`, freeing any resources used. .. versionadded:: 3.1 @@ -344,7 +371,9 @@ def close(self): pass -def _resolve_addr(host, port, family=socket.AF_UNSPEC): +def _resolve_addr( + host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC +) -> List[Tuple[int, Any]]: # On Solaris, getaddrinfo fails if the given port is not found # in /etc/services and no socket type is given, so we must pass # one here. The socket type used here doesn't seem to actually @@ -352,9 +381,9 @@ def _resolve_addr(host, port, family=socket.AF_UNSPEC): # so the addresses we return should still be usable with SOCK_DGRAM. 
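With the ``callback`` argument removed, the awaitable form of ``resolve`` is the only interface; a sketch of typical caller code, using an arbitrary host and port:

```py
from tornado.ioloop import IOLoop
from tornado.netutil import Resolver

async def main():
    resolver = Resolver()  # defaults to DefaultExecutorResolver
    # resolve() returns an awaitable list of (family, address) pairs.
    for family, address in await resolver.resolve("localhost", 80):
        # Per the comment above, these addresses should be usable with
        # SOCK_DGRAM as well as SOCK_STREAM.
        print(family, address)

IOLoop.current().run_sync(main)
```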
addrinfo = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM) results = [] - for family, socktype, proto, canonname, address in addrinfo: - results.append((family, address)) - return results + for fam, socktype, proto, canonname, address in addrinfo: + results.append((fam, address)) + return results # type: ignore class DefaultExecutorResolver(Resolver): @@ -362,11 +391,14 @@ class DefaultExecutorResolver(Resolver): .. versionadded:: 5.0 """ - @gen.coroutine - def resolve(self, host, port, family=socket.AF_UNSPEC): - result = yield IOLoop.current().run_in_executor( - None, _resolve_addr, host, port, family) - raise gen.Return(result) + + async def resolve( + self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC + ) -> List[Tuple[int, Any]]: + result = await IOLoop.current().run_in_executor( + None, _resolve_addr, host, port, family + ) + return result class ExecutorResolver(Resolver): @@ -386,7 +418,12 @@ class ExecutorResolver(Resolver): The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead of this class. """ - def initialize(self, executor=None, close_executor=True): + + def initialize( + self, + executor: Optional[concurrent.futures.Executor] = None, + close_executor: bool = True, + ) -> None: self.io_loop = IOLoop.current() if executor is not None: self.executor = executor @@ -395,13 +432,15 @@ def initialize(self, executor=None, close_executor=True): self.executor = dummy_executor self.close_executor = False - def close(self): + def close(self) -> None: if self.close_executor: self.executor.shutdown() - self.executor = None + self.executor = None # type: ignore @run_on_executor - def resolve(self, host, port, family=socket.AF_UNSPEC): + def resolve( + self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC + ) -> List[Tuple[int, Any]]: return _resolve_addr(host, port, family) @@ -415,8 +454,9 @@ class BlockingResolver(ExecutorResolver): The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead of this class. """ - def initialize(self): - super(BlockingResolver, self).initialize() + + def initialize(self) -> None: # type: ignore + super().initialize() class ThreadedResolver(ExecutorResolver): @@ -439,24 +479,25 @@ class ThreadedResolver(ExecutorResolver): The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead of this class. """ + _threadpool = None # type: ignore _threadpool_pid = None # type: int - def initialize(self, num_threads=10): + def initialize(self, num_threads: int = 10) -> None: # type: ignore threadpool = ThreadedResolver._create_threadpool(num_threads) - super(ThreadedResolver, self).initialize( - executor=threadpool, close_executor=False) + super().initialize(executor=threadpool, close_executor=False) @classmethod - def _create_threadpool(cls, num_threads): + def _create_threadpool( + cls, num_threads: int + ) -> concurrent.futures.ThreadPoolExecutor: pid = os.getpid() if cls._threadpool_pid != pid: # Threads cannot survive after a fork, so if our pid isn't what it # was when we created the pool then delete it. cls._threadpool = None if cls._threadpool is None: - from concurrent.futures import ThreadPoolExecutor - cls._threadpool = ThreadPoolExecutor(num_threads) + cls._threadpool = concurrent.futures.ThreadPoolExecutor(num_threads) cls._threadpool_pid = pid return cls._threadpool @@ -483,31 +524,37 @@ class OverrideResolver(Resolver): .. versionchanged:: 5.0 Added support for host-port-family triplets. 
""" - def initialize(self, resolver, mapping): + + def initialize(self, resolver: Resolver, mapping: dict) -> None: self.resolver = resolver self.mapping = mapping - def close(self): + def close(self) -> None: self.resolver.close() - def resolve(self, host, port, family=socket.AF_UNSPEC, *args, **kwargs): + def resolve( + self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC + ) -> Awaitable[List[Tuple[int, Any]]]: if (host, port, family) in self.mapping: host, port = self.mapping[(host, port, family)] elif (host, port) in self.mapping: host, port = self.mapping[(host, port)] elif host in self.mapping: host = self.mapping[host] - return self.resolver.resolve(host, port, family, *args, **kwargs) + return self.resolver.resolve(host, port, family) # These are the keyword arguments to ssl.wrap_socket that must be translated # to their SSLContext equivalents (the other arguments are still passed # to SSLContext.wrap_socket). -_SSL_CONTEXT_KEYWORDS = frozenset(['ssl_version', 'certfile', 'keyfile', - 'cert_reqs', 'ca_certs', 'ciphers']) +_SSL_CONTEXT_KEYWORDS = frozenset( + ["ssl_version", "certfile", "keyfile", "cert_reqs", "ca_certs", "ciphers"] +) -def ssl_options_to_context(ssl_options): +def ssl_options_to_context( + ssl_options: Union[Dict[str, Any], ssl.SSLContext] +) -> ssl.SSLContext: """Try to convert an ``ssl_options`` dictionary to an `~ssl.SSLContext` object. @@ -524,17 +571,18 @@ def ssl_options_to_context(ssl_options): assert all(k in _SSL_CONTEXT_KEYWORDS for k in ssl_options), ssl_options # Can't use create_default_context since this interface doesn't # tell us client vs server. - context = ssl.SSLContext( - ssl_options.get('ssl_version', ssl.PROTOCOL_SSLv23)) - if 'certfile' in ssl_options: - context.load_cert_chain(ssl_options['certfile'], ssl_options.get('keyfile', None)) - if 'cert_reqs' in ssl_options: - context.verify_mode = ssl_options['cert_reqs'] - if 'ca_certs' in ssl_options: - context.load_verify_locations(ssl_options['ca_certs']) - if 'ciphers' in ssl_options: - context.set_ciphers(ssl_options['ciphers']) - if hasattr(ssl, 'OP_NO_COMPRESSION'): + context = ssl.SSLContext(ssl_options.get("ssl_version", ssl.PROTOCOL_SSLv23)) + if "certfile" in ssl_options: + context.load_cert_chain( + ssl_options["certfile"], ssl_options.get("keyfile", None) + ) + if "cert_reqs" in ssl_options: + context.verify_mode = ssl_options["cert_reqs"] + if "ca_certs" in ssl_options: + context.load_verify_locations(ssl_options["ca_certs"]) + if "ciphers" in ssl_options: + context.set_ciphers(ssl_options["ciphers"]) + if hasattr(ssl, "OP_NO_COMPRESSION"): # Disable TLS compression to avoid CRIME and related attacks. # This constant depends on openssl version 1.0. # TODO: Do we need to do this ourselves or can we trust @@ -543,7 +591,12 @@ def ssl_options_to_context(ssl_options): return context -def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs): +def ssl_wrap_socket( + socket: socket.socket, + ssl_options: Union[Dict[str, Any], ssl.SSLContext], + server_hostname: Optional[str] = None, + **kwargs: Any +) -> ssl.SSLSocket: """Returns an ``ssl.SSLSocket`` wrapping the given socket. 
``ssl_options`` may be either an `ssl.SSLContext` object or a @@ -559,7 +612,6 @@ def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs): # TODO: add a unittest (python added server-side SNI support in 3.4) # In the meantime it can be manually tested with # python3 -m tornado.httpclient https://sni.velox.ch - return context.wrap_socket(socket, server_hostname=server_hostname, - **kwargs) + return context.wrap_socket(socket, server_hostname=server_hostname, **kwargs) else: return context.wrap_socket(socket, **kwargs) diff --git a/tornado/options.py b/tornado/options.py index a6f77029f7..058f88d16d 100644 --- a/tornado/options.py +++ b/tornado/options.py @@ -91,11 +91,8 @@ def start_server(): options can be defined, set, and read with any mix of the two. Dashes are typical for command-line usage while config files require underscores. - """ -from __future__ import absolute_import, division, print_function - import datetime import numbers import re @@ -105,12 +102,25 @@ def start_server(): from tornado.escape import _unicode, native_str from tornado.log import define_logging_options -from tornado import stack_context from tornado.util import basestring_type, exec_in +from typing import ( + Any, + Iterator, + Iterable, + Tuple, + Set, + Dict, + Callable, + List, + TextIO, + Optional, +) + class Error(Exception): """Exception raised by errors in the options module.""" + pass @@ -120,56 +130,61 @@ class OptionParser(object): Normally accessed via static functions in the `tornado.options` module, which reference a global instance. """ - def __init__(self): - # we have to use self.__dict__ because we override setattr. - self.__dict__['_options'] = {} - self.__dict__['_parse_callbacks'] = [] - self.define("help", type=bool, help="show this help information", - callback=self._help_callback) - def _normalize_name(self, name): - return name.replace('_', '-') - - def __getattr__(self, name): + def __init__(self) -> None: + # we have to use self.__dict__ because we override setattr. + self.__dict__["_options"] = {} + self.__dict__["_parse_callbacks"] = [] + self.define( + "help", + type=bool, + help="show this help information", + callback=self._help_callback, + ) + + def _normalize_name(self, name: str) -> str: + return name.replace("_", "-") + + def __getattr__(self, name: str) -> Any: name = self._normalize_name(name) if isinstance(self._options.get(name), _Option): return self._options[name].value() raise AttributeError("Unrecognized option %r" % name) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: name = self._normalize_name(name) if isinstance(self._options.get(name), _Option): return self._options[name].set(value) raise AttributeError("Unrecognized option %r" % name) - def __iter__(self): + def __iter__(self) -> Iterator: return (opt.name for opt in self._options.values()) - def __contains__(self, name): + def __contains__(self, name: str) -> bool: name = self._normalize_name(name) return name in self._options - def __getitem__(self, name): + def __getitem__(self, name: str) -> Any: return self.__getattr__(name) - def __setitem__(self, name, value): + def __setitem__(self, name: str, value: Any) -> None: return self.__setattr__(name, value) - def items(self): - """A sequence of (name, value) pairs. + def items(self) -> Iterable[Tuple[str, Any]]: + """An iterable of (name, value) pairs. .. 
versionadded:: 3.1 """ return [(opt.name, opt.value()) for name, opt in self._options.items()] - def groups(self): + def groups(self) -> Set[str]: """The set of option-groups created by ``define``. .. versionadded:: 3.1 """ return set(opt.group_name for opt in self._options.values()) - def group_dict(self, group): + def group_dict(self, group: str) -> Dict[str, Any]: """The names and values of options in a group. Useful for copying options into Application settings:: @@ -187,19 +202,29 @@ def group_dict(self, group): .. versionadded:: 3.1 """ return dict( - (opt.name, opt.value()) for name, opt in self._options.items() - if not group or group == opt.group_name) + (opt.name, opt.value()) + for name, opt in self._options.items() + if not group or group == opt.group_name + ) - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: """The names and values of all options. .. versionadded:: 3.1 """ - return dict( - (opt.name, opt.value()) for name, opt in self._options.items()) - - def define(self, name, default=None, type=None, help=None, metavar=None, - multiple=False, group=None, callback=None): + return dict((opt.name, opt.value()) for name, opt in self._options.items()) + + def define( + self, + name: str, + default: Any = None, + type: Optional[type] = None, + help: Optional[str] = None, + metavar: Optional[str] = None, + multiple: bool = False, + group: Optional[str] = None, + callback: Optional[Callable[[Any], None]] = None, + ) -> None: """Defines a new command line option. ``type`` can be any of `str`, `int`, `float`, `bool`, @@ -236,18 +261,27 @@ def define(self, name, default=None, type=None, help=None, metavar=None, """ normalized = self._normalize_name(name) if normalized in self._options: - raise Error("Option %r already defined in %s" % - (normalized, self._options[normalized].file_name)) + raise Error( + "Option %r already defined in %s" + % (normalized, self._options[normalized].file_name) + ) frame = sys._getframe(0) - options_file = frame.f_code.co_filename - - # Can be called directly, or through top level define() fn, in which - # case, step up above that frame to look for real caller. - if (frame.f_back.f_code.co_filename == options_file and - frame.f_back.f_code.co_name == 'define'): - frame = frame.f_back - - file_name = frame.f_back.f_code.co_filename + if frame is not None: + options_file = frame.f_code.co_filename + + # Can be called directly, or through top level define() fn, in which + # case, step up above that frame to look for real caller. 
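For context on how these definitions are normally driven, a sketch using invented option names; the global ``tornado.options.options`` instance behaves the same way:

```py
from tornado.options import OptionParser

parser = OptionParser()
parser.define("mysql_host", default="127.0.0.1:3306", group="application")
parser.define("memcache_hosts", multiple=True, group="application")

# multiple=True splits comma-separated values; integer options would
# additionally accept inclusive X:Y ranges.
parser.parse_command_line(
    ["prog", "--memcache_hosts=cache1.example.com:11011,cache2.example.com:11011"]
)
print(parser.memcache_hosts)
# ['cache1.example.com:11011', 'cache2.example.com:11011']

# Defining the same (normalized) name twice raises options.Error, and the
# message names the file recorded by the frame inspection above.
```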
+ if ( + frame.f_back is not None + and frame.f_back.f_code.co_filename == options_file + and frame.f_back.f_code.co_name == "define" + ): + frame = frame.f_back + + assert frame.f_back is not None + file_name = frame.f_back.f_code.co_filename + else: + file_name = "" if file_name == options_file: file_name = "" if type is None: @@ -256,17 +290,25 @@ def define(self, name, default=None, type=None, help=None, metavar=None, else: type = str if group: - group_name = group + group_name = group # type: Optional[str] else: group_name = file_name - option = _Option(name, file_name=file_name, - default=default, type=type, help=help, - metavar=metavar, multiple=multiple, - group_name=group_name, - callback=callback) + option = _Option( + name, + file_name=file_name, + default=default, + type=type, + help=help, + metavar=metavar, + multiple=multiple, + group_name=group_name, + callback=callback, + ) self._options[normalized] = option - def parse_command_line(self, args=None, final=True): + def parse_command_line( + self, args: Optional[List[str]] = None, final: bool = True + ) -> List[str]: """Parses all options given on the command line (defaults to `sys.argv`). @@ -290,27 +332,27 @@ def parse_command_line(self, args=None, final=True): """ if args is None: args = sys.argv - remaining = [] + remaining = [] # type: List[str] for i in range(1, len(args)): # All things after the last option are command line arguments if not args[i].startswith("-"): remaining = args[i:] break if args[i] == "--": - remaining = args[i + 1:] + remaining = args[i + 1 :] break arg = args[i].lstrip("-") name, equals, value = arg.partition("=") name = self._normalize_name(name) if name not in self._options: self.print_help() - raise Error('Unrecognized command line option: %r' % name) + raise Error("Unrecognized command line option: %r" % name) option = self._options[name] if not equals: if option.type == bool: value = "true" else: - raise Error('Option %r requires a value' % name) + raise Error("Option %r requires a value" % name) option.parse(value) if final: @@ -318,7 +360,7 @@ def parse_command_line(self, args=None, final=True): return remaining - def parse_config_file(self, path, final=True): + def parse_config_file(self, path: str, final: bool = True) -> None: """Parses and loads the config file at the given path. The config file contains Python code that will be executed (so @@ -326,18 +368,20 @@ def parse_config_file(self, path, final=True): the global namespace that matches a defined option will be used to set that option's value. - Options are not parsed from strings as they would be on the - command line; they should be set to the correct type (this - means if you have ``datetime`` or ``timedelta`` options you - will need to import those modules in the config file. + Options may either be the specified type for the option or + strings (in which case they will be parsed the same way as in + `.parse_command_line`) Example (using the options defined in the top-level docs of this module):: port = 80 mysql_host = 'mydb.example.com:3306' + # Both lists and comma-separated strings are allowed for + # multiple=True. memcache_hosts = ['cache1.example.com:11011', 'cache2.example.com:11011'] + memcache_hosts = 'cache1.example.com:11011,cache2.example.com:11011' If ``final`` is ``False``, parse callbacks will not be run. 
This is useful for applications that wish to combine configurations @@ -358,25 +402,40 @@ def parse_config_file(self, path, final=True): The special variable ``__file__`` is available inside config files, specifying the absolute path to the config file itself. + .. versionchanged:: 5.1 + Added the ability to set options via strings in config files. + """ - config = {'__file__': os.path.abspath(path)} - with open(path, 'rb') as f: + config = {"__file__": os.path.abspath(path)} + with open(path, "rb") as f: exec_in(native_str(f.read()), config, config) for name in config: normalized = self._normalize_name(name) if normalized in self._options: - self._options[normalized].set(config[name]) + option = self._options[normalized] + if option.multiple: + if not isinstance(config[name], (list, str)): + raise Error( + "Option %r is required to be a list of %s " + "or a comma-separated string" + % (option.name, option.type.__name__) + ) + + if type(config[name]) == str and option.type != str: + option.parse(config[name]) + else: + option.set(config[name]) if final: self.run_parse_callbacks() - def print_help(self, file=None): + def print_help(self, file: Optional[TextIO] = None) -> None: """Prints all the command line options to stderr (or another file).""" if file is None: file = sys.stderr print("Usage: %s [OPTIONS]" % sys.argv[0], file=file) print("\nOptions:\n", file=file) - by_group = {} + by_group = {} # type: Dict[str, List[_Option]] for option in self._options.values(): by_group.setdefault(option.group_name, []).append(option) @@ -390,30 +449,30 @@ def print_help(self, file=None): if option.metavar: prefix += "=" + option.metavar description = option.help or "" - if option.default is not None and option.default != '': + if option.default is not None and option.default != "": description += " (default %s)" % option.default lines = textwrap.wrap(description, 79 - 35) if len(prefix) > 30 or len(lines) == 0: - lines.insert(0, '') + lines.insert(0, "") print(" --%-30s %s" % (prefix, lines[0]), file=file) for line in lines[1:]: - print("%-34s %s" % (' ', line), file=file) + print("%-34s %s" % (" ", line), file=file) print(file=file) - def _help_callback(self, value): + def _help_callback(self, value: bool) -> None: if value: self.print_help() sys.exit(0) - def add_parse_callback(self, callback): + def add_parse_callback(self, callback: Callable[[], None]) -> None: """Adds a parse callback, to be invoked when option parsing is done.""" - self._parse_callbacks.append(stack_context.wrap(callback)) + self._parse_callbacks.append(callback) - def run_parse_callbacks(self): + def run_parse_callbacks(self) -> None: for callback in self._parse_callbacks: callback() - def mockable(self): + def mockable(self) -> "_Mockable": """Returns a wrapper around self that is compatible with `mock.patch `. @@ -437,38 +496,53 @@ class _Mockable(object): As of ``mock`` version 1.0.1, when an object uses ``__getattr__`` hooks instead of ``__dict__``, ``patch.__exit__`` tries to delete the attribute it set instead of setting a new one (assuming that - the object does not catpure ``__setattr__``, so the patch + the object does not capture ``__setattr__``, so the patch created a new attribute in ``__dict__``). _Mockable's getattr and setattr pass through to the underlying OptionParser, and delattr undoes the effect of a previous setattr. 
""" - def __init__(self, options): + + def __init__(self, options: OptionParser) -> None: # Modify __dict__ directly to bypass __setattr__ - self.__dict__['_options'] = options - self.__dict__['_originals'] = {} + self.__dict__["_options"] = options + self.__dict__["_originals"] = {} - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self._options, name) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: assert name not in self._originals, "don't reuse mockable objects" self._originals[name] = getattr(self._options, name) setattr(self._options, name, value) - def __delattr__(self, name): + def __delattr__(self, name: str) -> None: setattr(self._options, name, self._originals.pop(name)) class _Option(object): + # This class could almost be made generic, but the way the types + # interact with the multiple argument makes this tricky. (default + # and the callback use List[T], but type is still Type[T]). UNSET = object() - def __init__(self, name, default=None, type=basestring_type, help=None, - metavar=None, multiple=False, file_name=None, group_name=None, - callback=None): + def __init__( + self, + name: str, + default: Any = None, + type: Optional[type] = None, + help: Optional[str] = None, + metavar: Optional[str] = None, + multiple: bool = False, + file_name: Optional[str] = None, + group_name: Optional[str] = None, + callback: Optional[Callable[[Any], None]] = None, + ) -> None: if default is None and multiple: default = [] self.name = name + if type is None: + raise ValueError("type must not be None") self.type = type self.help = help self.metavar = metavar @@ -477,26 +551,28 @@ def __init__(self, name, default=None, type=basestring_type, help=None, self.group_name = group_name self.callback = callback self.default = default - self._value = _Option.UNSET + self._value = _Option.UNSET # type: Any - def value(self): + def value(self) -> Any: return self.default if self._value is _Option.UNSET else self._value - def parse(self, value): + def parse(self, value: str) -> Any: _parse = { datetime.datetime: self._parse_datetime, datetime.timedelta: self._parse_timedelta, bool: self._parse_bool, basestring_type: self._parse_string, - }.get(self.type, self.type) + }.get( + self.type, self.type + ) # type: Callable[[str], Any] if self.multiple: self._value = [] for part in value.split(","): if issubclass(self.type, numbers.Integral): # allow ranges of the form X:Y (inclusive at both ends) - lo, _, hi = part.partition(":") - lo = _parse(lo) - hi = _parse(hi) if hi else lo + lo_str, _, hi_str = part.partition(":") + lo = _parse(lo_str) + hi = _parse(hi_str) if hi_str else lo self._value.extend(range(lo, hi + 1)) else: self._value.append(_parse(part)) @@ -506,19 +582,25 @@ def parse(self, value): self.callback(self._value) return self.value() - def set(self, value): + def set(self, value: Any) -> None: if self.multiple: if not isinstance(value, list): - raise Error("Option %r is required to be a list of %s" % - (self.name, self.type.__name__)) + raise Error( + "Option %r is required to be a list of %s" + % (self.name, self.type.__name__) + ) for item in value: if item is not None and not isinstance(item, self.type): - raise Error("Option %r is required to be a list of %s" % - (self.name, self.type.__name__)) + raise Error( + "Option %r is required to be a list of %s" + % (self.name, self.type.__name__) + ) else: if value is not None and not isinstance(value, self.type): - raise Error("Option %r is required to be a %s 
(%s given)" % - (self.name, self.type.__name__, type(value))) + raise Error( + "Option %r is required to be a %s (%s given)" + % (self.name, self.type.__name__, type(value)) + ) self._value = value if self.callback is not None: self.callback(self._value) @@ -537,32 +619,33 @@ def set(self, value): "%H:%M", ] - def _parse_datetime(self, value): + def _parse_datetime(self, value: str) -> datetime.datetime: for format in self._DATETIME_FORMATS: try: return datetime.datetime.strptime(value, format) except ValueError: pass - raise Error('Unrecognized date/time format: %r' % value) + raise Error("Unrecognized date/time format: %r" % value) _TIMEDELTA_ABBREV_DICT = { - 'h': 'hours', - 'm': 'minutes', - 'min': 'minutes', - 's': 'seconds', - 'sec': 'seconds', - 'ms': 'milliseconds', - 'us': 'microseconds', - 'd': 'days', - 'w': 'weeks', + "h": "hours", + "m": "minutes", + "min": "minutes", + "s": "seconds", + "sec": "seconds", + "ms": "milliseconds", + "us": "microseconds", + "d": "days", + "w": "weeks", } - _FLOAT_PATTERN = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?' + _FLOAT_PATTERN = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?" _TIMEDELTA_PATTERN = re.compile( - r'\s*(%s)\s*(\w*)\s*' % _FLOAT_PATTERN, re.IGNORECASE) + r"\s*(%s)\s*(\w*)\s*" % _FLOAT_PATTERN, re.IGNORECASE + ) - def _parse_timedelta(self, value): + def _parse_timedelta(self, value: str) -> datetime.timedelta: try: sum = datetime.timedelta() start = 0 @@ -571,18 +654,20 @@ def _parse_timedelta(self, value): if not m: raise Exception() num = float(m.group(1)) - units = m.group(2) or 'seconds' + units = m.group(2) or "seconds" units = self._TIMEDELTA_ABBREV_DICT.get(units, units) - sum += datetime.timedelta(**{units: num}) + # This line confuses mypy when setup.py sets python_version=3.6 + # https://github.com/python/mypy/issues/9676 + sum += datetime.timedelta(**{units: num}) # type: ignore start = m.end() return sum except Exception: raise - def _parse_bool(self, value): + def _parse_bool(self, value: str) -> bool: return value.lower() not in ("false", "0", "f") - def _parse_string(self, value): + def _parse_string(self, value: str) -> str: return _unicode(value) @@ -593,18 +678,35 @@ def _parse_string(self, value): """ -def define(name, default=None, type=None, help=None, metavar=None, - multiple=False, group=None, callback=None): +def define( + name: str, + default: Any = None, + type: Optional[type] = None, + help: Optional[str] = None, + metavar: Optional[str] = None, + multiple: bool = False, + group: Optional[str] = None, + callback: Optional[Callable[[Any], None]] = None, +) -> None: """Defines an option in the global namespace. See `OptionParser.define`. """ - return options.define(name, default=default, type=type, help=help, - metavar=metavar, multiple=multiple, group=group, - callback=callback) - - -def parse_command_line(args=None, final=True): + return options.define( + name, + default=default, + type=type, + help=help, + metavar=metavar, + multiple=multiple, + group=group, + callback=callback, + ) + + +def parse_command_line( + args: Optional[List[str]] = None, final: bool = True +) -> List[str]: """Parses global options from the command line. See `OptionParser.parse_command_line`. @@ -612,7 +714,7 @@ def parse_command_line(args=None, final=True): return options.parse_command_line(args, final=final) -def parse_config_file(path, final=True): +def parse_config_file(path: str, final: bool = True) -> None: """Parses global options from a config file. See `OptionParser.parse_config_file`. 
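A sketch of the value syntax these parsers accept, with an invented option name; the whitespace caveat is inferred from ``_TIMEDELTA_PATTERN`` above:

```py
import datetime

from tornado.options import OptionParser

parser = OptionParser()
parser.define("poll_interval", type=datetime.timedelta)

parser.parse_command_line(["prog", "--poll_interval=500ms"])
print(parser.poll_interval)  # 0:00:00.500000

# Units map onto timedelta keyword arguments (h, m/min, s/sec, ms, us, d, w);
# a bare number defaults to seconds. Components appear to need separating
# whitespace: "1h 30m" parses, but "1h30m" does not, because the greedy \w*
# in _TIMEDELTA_PATTERN swallows "h30m" as a single unit.
parser.poll_interval = datetime.timedelta(minutes=2)  # set() type-checks this
```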
@@ -620,7 +722,7 @@ def parse_config_file(path, final=True): return options.parse_config_file(path, final=final) -def print_help(file=None): +def print_help(file: Optional[TextIO] = None) -> None: """Prints all the command line options to stderr (or another file). See `OptionParser.print_help`. @@ -628,7 +730,7 @@ def print_help(file=None): return options.print_help(file) -def add_parse_callback(callback): +def add_parse_callback(callback: Callable[[], None]) -> None: """Adds a parse callback, to be invoked when option parsing is done. See `OptionParser.add_parse_callback` diff --git a/tornado/platform/asyncio.py b/tornado/platform/asyncio.py index b6a490afac..292d9b66a4 100644 --- a/tornado/platform/asyncio.py +++ b/tornado/platform/asyncio.py @@ -14,29 +14,88 @@ .. note:: - Tornado requires the `~asyncio.AbstractEventLoop.add_reader` family of - methods, so it is not compatible with the `~asyncio.ProactorEventLoop` on - Windows. Use the `~asyncio.SelectorEventLoop` instead. + Tornado is designed to use a selector-based event loop. On Windows, + where a proactor-based event loop has been the default since Python 3.8, + a selector event loop is emulated by running ``select`` on a separate thread. + Configuring ``asyncio`` to use a selector event loop may improve performance + of Tornado (but may reduce performance of other ``asyncio``-based libraries + in the same process). """ -from __future__ import absolute_import, division, print_function +import asyncio +import atexit +import concurrent.futures +import errno import functools - +import select +import socket +import sys +import threading +import typing from tornado.gen import convert_yielded -from tornado.ioloop import IOLoop -from tornado import stack_context +from tornado.ioloop import IOLoop, _Selectable + +from typing import Any, TypeVar, Awaitable, Callable, Union, Optional, List, Tuple, Dict + +if typing.TYPE_CHECKING: + from typing import Set # noqa: F401 + from typing_extensions import Protocol + + class _HasFileno(Protocol): + def fileno(self) -> int: + pass + + _FileDescriptorLike = Union[int, _HasFileno] + +_T = TypeVar("_T") + + +# Collection of selector thread event loops to shut down on exit. +_selector_loops = set() # type: Set[AddThreadSelectorEventLoop] + + +def _atexit_callback() -> None: + for loop in _selector_loops: + with loop._select_cond: + loop._closing_selector = True + loop._select_cond.notify() + try: + loop._waker_w.send(b"a") + except BlockingIOError: + pass + # If we don't join our (daemon) thread here, we may get a deadlock + # during interpreter shutdown. I don't really understand why. This + # deadlock happens every time in CI (both travis and appveyor) but + # I've never been able to reproduce locally. + loop._thread.join() + _selector_loops.clear() -import asyncio + +atexit.register(_atexit_callback) class BaseAsyncIOLoop(IOLoop): - def initialize(self, asyncio_loop, **kwargs): + def initialize( # type: ignore + self, asyncio_loop: asyncio.AbstractEventLoop, **kwargs: Any + ) -> None: + # asyncio_loop is always the real underlying IOLoop. This is used in + # ioloop.py to maintain the asyncio-to-ioloop mappings. self.asyncio_loop = asyncio_loop + # selector_loop is an event loop that implements the add_reader family of + # methods. Usually the same as asyncio_loop but differs on platforms such + # as windows where the default event loop does not implement these methods. 
+ self.selector_loop = asyncio_loop + if hasattr(asyncio, "ProactorEventLoop") and isinstance( + asyncio_loop, asyncio.ProactorEventLoop # type: ignore + ): + # Ignore this line for mypy because the abstract method checker + # doesn't understand dynamic proxies. + self.selector_loop = AddThreadSelectorEventLoop(asyncio_loop) # type: ignore # Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler) - self.handlers = {} + self.handlers = {} # type: Dict[int, Tuple[Union[int, _Selectable], Callable]] # Set of fds listening for reads/writes - self.readers = set() - self.writers = set() + self.readers = set() # type: Set[int] + self.writers = set() # type: Set[int] self.closing = False # If an asyncio loop was closed through an asyncio interface # instead of IOLoop.close(), we'd never hear about it and may @@ -53,74 +112,87 @@ def initialize(self, asyncio_loop, **kwargs): if loop.is_closed(): del IOLoop._ioloop_for_asyncio[loop] IOLoop._ioloop_for_asyncio[asyncio_loop] = self - super(BaseAsyncIOLoop, self).initialize(**kwargs) - def close(self, all_fds=False): + self._thread_identity = 0 + + super().initialize(**kwargs) + + def assign_thread_identity() -> None: + self._thread_identity = threading.get_ident() + + self.add_callback(assign_thread_identity) + + def close(self, all_fds: bool = False) -> None: self.closing = True for fd in list(self.handlers): fileobj, handler_func = self.handlers[fd] self.remove_handler(fd) if all_fds: self.close_fd(fileobj) - self.asyncio_loop.close() + # Remove the mapping before closing the asyncio loop. If this + # happened in the other order, we could race against another + # initialize() call which would see the closed asyncio loop, + # assume it was closed from the asyncio side, and do this + # cleanup for us, leading to a KeyError. 
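This bookkeeping keeps each asyncio loop paired with exactly one `IOLoop` wrapper, which application code can observe; a small sketch, assuming the loop was started through Tornado:

```py
import asyncio

from tornado.ioloop import IOLoop

async def main():
    # Inside a running loop, IOLoop.current() returns the wrapper that
    # initialize() registered for the running asyncio loop.
    assert IOLoop.current().asyncio_loop is asyncio.get_event_loop()

IOLoop.current().run_sync(main)
```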
del IOLoop._ioloop_for_asyncio[self.asyncio_loop] + if self.selector_loop is not self.asyncio_loop: + self.selector_loop.close() + self.asyncio_loop.close() - def add_handler(self, fd, handler, events): + def add_handler( + self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int + ) -> None: fd, fileobj = self.split_fd(fd) if fd in self.handlers: raise ValueError("fd %s added twice" % fd) - self.handlers[fd] = (fileobj, stack_context.wrap(handler)) + self.handlers[fd] = (fileobj, handler) if events & IOLoop.READ: - self.asyncio_loop.add_reader( - fd, self._handle_events, fd, IOLoop.READ) + self.selector_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ) self.readers.add(fd) if events & IOLoop.WRITE: - self.asyncio_loop.add_writer( - fd, self._handle_events, fd, IOLoop.WRITE) + self.selector_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE) self.writers.add(fd) - def update_handler(self, fd, events): + def update_handler(self, fd: Union[int, _Selectable], events: int) -> None: fd, fileobj = self.split_fd(fd) if events & IOLoop.READ: if fd not in self.readers: - self.asyncio_loop.add_reader( - fd, self._handle_events, fd, IOLoop.READ) + self.selector_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ) self.readers.add(fd) else: if fd in self.readers: - self.asyncio_loop.remove_reader(fd) + self.selector_loop.remove_reader(fd) self.readers.remove(fd) if events & IOLoop.WRITE: if fd not in self.writers: - self.asyncio_loop.add_writer( - fd, self._handle_events, fd, IOLoop.WRITE) + self.selector_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE) self.writers.add(fd) else: if fd in self.writers: - self.asyncio_loop.remove_writer(fd) + self.selector_loop.remove_writer(fd) self.writers.remove(fd) - def remove_handler(self, fd): + def remove_handler(self, fd: Union[int, _Selectable]) -> None: fd, fileobj = self.split_fd(fd) if fd not in self.handlers: return if fd in self.readers: - self.asyncio_loop.remove_reader(fd) + self.selector_loop.remove_reader(fd) self.readers.remove(fd) if fd in self.writers: - self.asyncio_loop.remove_writer(fd) + self.selector_loop.remove_writer(fd) self.writers.remove(fd) del self.handlers[fd] - def _handle_events(self, fd, events): + def _handle_events(self, fd: int, events: int) -> None: fileobj, handler_func = self.handlers[fd] handler_func(fileobj, events) - def start(self): + def start(self) -> None: try: old_loop = asyncio.get_event_loop() except (RuntimeError, AssertionError): - old_loop = None + old_loop = None # type: ignore try: self._setup_logging() asyncio.set_event_loop(self.asyncio_loop) @@ -128,25 +200,31 @@ def start(self): finally: asyncio.set_event_loop(old_loop) - def stop(self): + def stop(self) -> None: self.asyncio_loop.stop() - def call_at(self, when, callback, *args, **kwargs): + def call_at( + self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any + ) -> object: # asyncio.call_at supports *args but not **kwargs, so bind them here. # We do not synchronize self.time and asyncio_loop.time, so # convert from absolute to relative. 
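A sketch of the deadline-based API this implements; the callback name and half-second delay are invented:

```py
from tornado.ioloop import IOLoop

def on_timeout():
    print("timer fired")
    IOLoop.current().stop()

io_loop = IOLoop.current()
# The deadline is absolute on io_loop.time(); call_at converts it to the
# relative delay that asyncio's call_later expects.
timeout_handle = io_loop.call_at(io_loop.time() + 0.5, on_timeout)
# io_loop.remove_timeout(timeout_handle) would cancel it before it fires.
io_loop.start()
```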
return self.asyncio_loop.call_later( - max(0, when - self.time()), self._run_callback, - functools.partial(stack_context.wrap(callback), *args, **kwargs)) + max(0, when - self.time()), + self._run_callback, + functools.partial(callback, *args, **kwargs), + ) - def remove_timeout(self, timeout): - timeout.cancel() + def remove_timeout(self, timeout: object) -> None: + timeout.cancel() # type: ignore - def add_callback(self, callback, *args, **kwargs): + def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None: + if threading.get_ident() == self._thread_identity: + call_soon = self.asyncio_loop.call_soon + else: + call_soon = self.asyncio_loop.call_soon_threadsafe try: - self.asyncio_loop.call_soon_threadsafe( - self._run_callback, - functools.partial(stack_context.wrap(callback), *args, **kwargs)) + call_soon(self._run_callback, functools.partial(callback, *args, **kwargs)) except RuntimeError: # "Event loop is closed". Swallow the exception for # consistency with PollIOLoop (and logical consistency @@ -154,13 +232,31 @@ def add_callback(self, callback, *args, **kwargs): # add_callback that completes without error will # eventually execute). pass + except AttributeError: + # ProactorEventLoop may raise this instead of RuntimeError + # if call_soon_threadsafe races with a call to close(). + # Swallow it too for consistency. + pass - add_callback_from_signal = add_callback + def add_callback_from_signal( + self, callback: Callable, *args: Any, **kwargs: Any + ) -> None: + try: + self.asyncio_loop.call_soon_threadsafe( + self._run_callback, functools.partial(callback, *args, **kwargs) + ) + except RuntimeError: + pass - def run_in_executor(self, executor, func, *args): + def run_in_executor( + self, + executor: Optional[concurrent.futures.Executor], + func: Callable[..., _T], + *args: Any + ) -> Awaitable[_T]: return self.asyncio_loop.run_in_executor(executor, func, *args) - def set_default_executor(self, executor): + def set_default_executor(self, executor: concurrent.futures.Executor) -> None: return self.asyncio_loop.set_default_executor(executor) @@ -178,10 +274,11 @@ class AsyncIOMainLoop(BaseAsyncIOLoop): Closing an `AsyncIOMainLoop` now closes the underlying asyncio loop. """ - def initialize(self, **kwargs): - super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(), **kwargs) - def make_current(self): + def initialize(self, **kwargs: Any) -> None: # type: ignore + super().initialize(asyncio.get_event_loop(), **kwargs) + + def make_current(self) -> None: # AsyncIOMainLoop already refers to the current asyncio loop so # nothing to do here. pass @@ -206,38 +303,39 @@ class AsyncIOLoop(BaseAsyncIOLoop): Now used automatically when appropriate; it is no longer necessary to refer to this class directly. """ - def initialize(self, **kwargs): + + def initialize(self, **kwargs: Any) -> None: # type: ignore self.is_current = False loop = asyncio.new_event_loop() try: - super(AsyncIOLoop, self).initialize(loop, **kwargs) + super().initialize(loop, **kwargs) except Exception: # If initialize() does not succeed (taking ownership of the loop), # we have to close it. 
loop.close() raise - def close(self, all_fds=False): + def close(self, all_fds: bool = False) -> None: if self.is_current: self.clear_current() - super(AsyncIOLoop, self).close(all_fds=all_fds) + super().close(all_fds=all_fds) - def make_current(self): + def make_current(self) -> None: if not self.is_current: try: self.old_asyncio = asyncio.get_event_loop() except (RuntimeError, AssertionError): - self.old_asyncio = None + self.old_asyncio = None # type: ignore self.is_current = True asyncio.set_event_loop(self.asyncio_loop) - def _clear_current_hook(self): + def _clear_current_hook(self) -> None: if self.is_current: asyncio.set_event_loop(self.old_asyncio) self.is_current = False -def to_tornado_future(asyncio_future): +def to_tornado_future(asyncio_future: asyncio.Future) -> asyncio.Future: """Convert an `asyncio.Future` to a `tornado.concurrent.Future`. .. versionadded:: 4.1 @@ -249,7 +347,7 @@ def to_tornado_future(asyncio_future): return asyncio_future -def to_asyncio_future(tornado_future): +def to_asyncio_future(tornado_future: asyncio.Future) -> asyncio.Future: """Convert a Tornado yieldable object to an `asyncio.Future`. .. versionadded:: 4.1 @@ -265,7 +363,15 @@ def to_asyncio_future(tornado_future): return convert_yielded(tornado_future) -class AnyThreadEventLoopPolicy(asyncio.DefaultEventLoopPolicy): +if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + # "Any thread" and "selector" should be orthogonal, but there's not a clean + # interface for composing policies so pick the right base. + _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore +else: + _BasePolicy = asyncio.DefaultEventLoopPolicy + + +class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore """Event loop policy that allows loop creation on any thread. The default `asyncio` event loop policy only automatically creates @@ -282,13 +388,228 @@ class AnyThreadEventLoopPolicy(asyncio.DefaultEventLoopPolicy): .. versionadded:: 5.0 """ - def get_event_loop(self): + + def get_event_loop(self) -> asyncio.AbstractEventLoop: try: return super().get_event_loop() except (RuntimeError, AssertionError): - # This was an AssertionError in python 3.4.2 (which ships with debian jessie) + # This was an AssertionError in Python 3.4.2 (which ships with Debian Jessie) # and changed to a RuntimeError in 3.4.3. # "There is no current event loop in thread %r" loop = self.new_event_loop() self.set_event_loop(loop) return loop + + +class AddThreadSelectorEventLoop(asyncio.AbstractEventLoop): + """Wrap an event loop to add implementations of the ``add_reader`` method family. + + Instances of this class start a second thread to run a selector. + This thread is completely hidden from the user; all callbacks are + run on the wrapped event loop's thread. + + This class is used automatically by Tornado; applications should not need + to refer to it directly. + + It is safe to wrap any event loop with this class, although it only makes sense + for event loops that do not implement the ``add_reader`` family of methods + themselves (i.e. ``WindowsProactorEventLoop``) + + Closing the ``AddThreadSelectorEventLoop`` also closes the wrapped event loop. + + """ + + # This class is a __getattribute__-based proxy. All attributes other than those + # in this set are proxied through to the underlying loop. 
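As an aside, the proxying described here boils down to a small pattern; a toy illustration (not Tornado's class):

```py
from typing import Any

class Proxy:
    # Names served by the proxy itself; everything else is looked
    # up on the wrapped target object.
    MY_ATTRIBUTES = {"_target", "describe"}

    def __init__(self, target: Any) -> None:
        self._target = target

    def describe(self) -> str:
        return "answered by the proxy, not the target"

    def __getattribute__(self, name: str) -> Any:
        if name in Proxy.MY_ATTRIBUTES:
            return super().__getattribute__(name)
        # Reading self._target below re-enters this method, which is
        # safe because "_target" is itself in MY_ATTRIBUTES.
        return getattr(self._target, name)

p = Proxy("hello")
print(p.describe())  # proxy's own method
print(p.upper())     # transparently forwarded to the str target
```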
+ MY_ATTRIBUTES = { + "_consume_waker", + "_select_cond", + "_select_args", + "_closing_selector", + "_thread", + "_handle_event", + "_readers", + "_real_loop", + "_start_select", + "_run_select", + "_handle_select", + "_wake_selector", + "_waker_r", + "_waker_w", + "_writers", + "add_reader", + "add_writer", + "close", + "remove_reader", + "remove_writer", + } + + def __getattribute__(self, name: str) -> Any: + if name in AddThreadSelectorEventLoop.MY_ATTRIBUTES: + return super().__getattribute__(name) + return getattr(self._real_loop, name) + + def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None: + self._real_loop = real_loop + + # Create a thread to run the select system call. We manage this thread + # manually so we can trigger a clean shutdown from an atexit hook. Note + # that due to the order of operations at shutdown, only daemon threads + # can be shut down in this way (non-daemon threads would require the + # introduction of a new hook: https://bugs.python.org/issue41962) + self._select_cond = threading.Condition() + self._select_args = ( + None + ) # type: Optional[Tuple[List[_FileDescriptorLike], List[_FileDescriptorLike]]] + self._closing_selector = False + self._thread = threading.Thread( + name="Tornado selector", + daemon=True, + target=self._run_select, + ) + self._thread.start() + # Start the select loop once the loop is started. + self._real_loop.call_soon(self._start_select) + + self._readers = {} # type: Dict[_FileDescriptorLike, Callable] + self._writers = {} # type: Dict[_FileDescriptorLike, Callable] + + # Writing to _waker_w will wake up the selector thread, which + # watches for _waker_r to be readable. + self._waker_r, self._waker_w = socket.socketpair() + self._waker_r.setblocking(False) + self._waker_w.setblocking(False) + _selector_loops.add(self) + self.add_reader(self._waker_r, self._consume_waker) + + def __del__(self) -> None: + # If the top-level application code uses asyncio interfaces to + # start and stop the event loop, no objects created in Tornado + # can get a clean shutdown notification. If we're just left to + # be GC'd, we must explicitly close our sockets to avoid + # logging warnings. + _selector_loops.discard(self) + self._waker_r.close() + self._waker_w.close() + + def close(self) -> None: + with self._select_cond: + self._closing_selector = True + self._select_cond.notify() + self._wake_selector() + self._thread.join() + _selector_loops.discard(self) + self._waker_r.close() + self._waker_w.close() + self._real_loop.close() + + def _wake_selector(self) -> None: + try: + self._waker_w.send(b"a") + except BlockingIOError: + pass + + def _consume_waker(self) -> None: + try: + self._waker_r.recv(1024) + except BlockingIOError: + pass + + def _start_select(self) -> None: + # Capture reader and writer sets here in the event loop + # thread to avoid any problems with concurrent + # modification while the select loop uses them. 
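The capture-and-hand-off between the two threads amounts to a one-slot mailbox guarded by a ``threading.Condition``; a minimal distillation (``Handoff`` is a hypothetical toy, not the Tornado code):

```py
import threading
from typing import Any, Optional

class Handoff:
    """One-slot handoff: a producer publishes a snapshot and a
    consumer blocks until one is available."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._slot = None  # type: Optional[Any]

    def publish(self, snapshot: Any) -> None:
        with self._cond:
            assert self._slot is None  # previous snapshot was consumed
            self._slot = snapshot
            self._cond.notify()

    def take(self) -> Any:
        with self._cond:
            while self._slot is None:
                self._cond.wait()
            snapshot, self._slot = self._slot, None
            return snapshot
```

In ``AddThreadSelectorEventLoop``, the event-loop thread plays the ``publish`` role (``_start_select``) and the selector thread plays ``take`` at the top of ``_run_select``.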
+ with self._select_cond: + assert self._select_args is None + self._select_args = (list(self._readers.keys()), list(self._writers.keys())) + self._select_cond.notify() + + def _run_select(self) -> None: + while True: + with self._select_cond: + while self._select_args is None and not self._closing_selector: + self._select_cond.wait() + if self._closing_selector: + return + assert self._select_args is not None + to_read, to_write = self._select_args + self._select_args = None + + # We use the simpler interface of the select module instead of + # the more stateful interface in the selectors module because + # this class is only intended for use on windows, where + # select.select is the only option. The selector interface + # does not have well-documented thread-safety semantics that + # we can rely on so ensuring proper synchronization would be + # tricky. + try: + # On windows, selecting on a socket for write will not + # return the socket when there is an error (but selecting + # for reads works). Also select for errors when selecting + # for writes, and merge the results. + # + # This pattern is also used in + # https://github.com/python/cpython/blob/v3.8.0/Lib/selectors.py#L312-L317 + rs, ws, xs = select.select(to_read, to_write, to_write) + ws = ws + xs + except OSError as e: + # After remove_reader or remove_writer is called, the file + # descriptor may subsequently be closed on the event loop + # thread. It's possible that this select thread hasn't + # gotten into the select system call by the time that + # happens in which case (at least on macOS), select may + # raise a "bad file descriptor" error. If we get that + # error, check and see if we're also being woken up by + # polling the waker alone. If we are, just return to the + # event loop and we'll get the updated set of file + # descriptors on the next iteration. Otherwise, raise the + # original error. + if e.errno == getattr(errno, "WSAENOTSOCK", errno.EBADF): + rs, _, _ = select.select([self._waker_r.fileno()], [], [], 0) + if rs: + ws = [] + else: + raise + else: + raise + self._real_loop.call_soon_threadsafe(self._handle_select, rs, ws) + + def _handle_select( + self, rs: List["_FileDescriptorLike"], ws: List["_FileDescriptorLike"] + ) -> None: + for r in rs: + self._handle_event(r, self._readers) + for w in ws: + self._handle_event(w, self._writers) + self._start_select() + + def _handle_event( + self, + fd: "_FileDescriptorLike", + cb_map: Dict["_FileDescriptorLike", Callable], + ) -> None: + try: + callback = cb_map[fd] + except KeyError: + return + callback() + + def add_reader( + self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any + ) -> None: + self._readers[fd] = functools.partial(callback, *args) + self._wake_selector() + + def add_writer( + self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any + ) -> None: + self._writers[fd] = functools.partial(callback, *args) + self._wake_selector() + + def remove_reader(self, fd: "_FileDescriptorLike") -> None: + del self._readers[fd] + self._wake_selector() + + def remove_writer(self, fd: "_FileDescriptorLike") -> None: + del self._writers[fd] + self._wake_selector() diff --git a/tornado/platform/auto.py b/tornado/platform/auto.py deleted file mode 100644 index 1a9133faf3..0000000000 --- a/tornado/platform/auto.py +++ /dev/null @@ -1,58 +0,0 @@ -# -# Copyright 2011 Facebook -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. 
You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -"""Implementation of platform-specific functionality. - -For each function or class described in `tornado.platform.interface`, -the appropriate platform-specific implementation exists in this module. -Most code that needs access to this functionality should do e.g.:: - - from tornado.platform.auto import set_close_exec -""" - -from __future__ import absolute_import, division, print_function - -import os - -if 'APPENGINE_RUNTIME' in os.environ: - from tornado.platform.common import Waker - - def set_close_exec(fd): - pass -elif os.name == 'nt': - from tornado.platform.common import Waker - from tornado.platform.windows import set_close_exec -else: - from tornado.platform.posix import set_close_exec, Waker - -try: - # monotime monkey-patches the time module to have a monotonic function - # in versions of python before 3.3. - import monotime - # Silence pyflakes warning about this unused import - monotime -except ImportError: - pass -try: - # monotonic can provide a monotonic function in versions of python before - # 3.3, too. - from monotonic import monotonic as monotonic_time -except ImportError: - try: - from time import monotonic as monotonic_time - except ImportError: - monotonic_time = None - -__all__ = ['Waker', 'set_close_exec', 'monotonic_time'] diff --git a/tornado/platform/auto.pyi b/tornado/platform/auto.pyi deleted file mode 100644 index a1c97228a4..0000000000 --- a/tornado/platform/auto.pyi +++ /dev/null @@ -1,4 +0,0 @@ -# auto.py is full of patterns mypy doesn't like, so for type checking -# purposes we replace it with interface.py. - -from .interface import * diff --git a/tornado/platform/caresresolver.py b/tornado/platform/caresresolver.py index 768cb62499..962f84f48f 100644 --- a/tornado/platform/caresresolver.py +++ b/tornado/platform/caresresolver.py @@ -1,4 +1,3 @@ -from __future__ import absolute_import, division, print_function import pycares # type: ignore import socket @@ -7,6 +6,11 @@ from tornado.ioloop import IOLoop from tornado.netutil import Resolver, is_valid_ip +import typing + +if typing.TYPE_CHECKING: + from typing import Generator, Any, List, Tuple, Dict # noqa: F401 + class CaresResolver(Resolver): """Name resolver based on the c-ares library. @@ -22,15 +26,19 @@ class CaresResolver(Resolver): .. versionchanged:: 5.0 The ``io_loop`` argument (deprecated since version 4.1) has been removed. + + .. deprecated:: 6.2 + This class is deprecated and will be removed in Tornado 7.0. Use the default + thread-based resolver instead. 
""" - def initialize(self): + + def initialize(self) -> None: self.io_loop = IOLoop.current() self.channel = pycares.Channel(sock_state_cb=self._sock_state_cb) - self.fds = {} + self.fds = {} # type: Dict[int, int] - def _sock_state_cb(self, fd, readable, writable): - state = ((IOLoop.READ if readable else 0) | - (IOLoop.WRITE if writable else 0)) + def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None: + state = (IOLoop.READ if readable else 0) | (IOLoop.WRITE if writable else 0) if not state: self.io_loop.remove_handler(fd) del self.fds[fd] @@ -41,7 +49,7 @@ def _sock_state_cb(self, fd, readable, writable): self.io_loop.add_handler(fd, self._handle_events, state) self.fds[fd] = state - def _handle_events(self, fd, events): + def _handle_events(self, fd: int, events: int) -> None: read_fd = pycares.ARES_SOCKET_BAD write_fd = pycares.ARES_SOCKET_BAD if events & IOLoop.READ: @@ -51,29 +59,35 @@ def _handle_events(self, fd, events): self.channel.process_fd(read_fd, write_fd) @gen.coroutine - def resolve(self, host, port, family=0): + def resolve( + self, host: str, port: int, family: int = 0 + ) -> "Generator[Any, Any, List[Tuple[int, Any]]]": if is_valid_ip(host): addresses = [host] else: # gethostbyname doesn't take callback as a kwarg - fut = Future() - self.channel.gethostbyname(host, family, - lambda result, error: fut.set_result((result, error))) + fut = Future() # type: Future[Tuple[Any, Any]] + self.channel.gethostbyname( + host, family, lambda result, error: fut.set_result((result, error)) + ) result, error = yield fut if error: - raise IOError('C-Ares returned error %s: %s while resolving %s' % - (error, pycares.errno.strerror(error), host)) + raise IOError( + "C-Ares returned error %s: %s while resolving %s" + % (error, pycares.errno.strerror(error), host) + ) addresses = result.addresses addrinfo = [] for address in addresses: - if '.' in address: + if "." in address: address_family = socket.AF_INET - elif ':' in address: + elif ":" in address: address_family = socket.AF_INET6 else: address_family = socket.AF_UNSPEC if family != socket.AF_UNSPEC and family != address_family: - raise IOError('Requested socket family %d but got %d' % - (family, address_family)) - addrinfo.append((address_family, (address, port))) - raise gen.Return(addrinfo) + raise IOError( + "Requested socket family %d but got %d" % (family, address_family) + ) + addrinfo.append((typing.cast(int, address_family), (address, port))) + return addrinfo diff --git a/tornado/platform/common.py b/tornado/platform/common.py deleted file mode 100644 index b597748d1f..0000000000 --- a/tornado/platform/common.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Lowest-common-denominator implementations of platform functionality.""" -from __future__ import absolute_import, division, print_function - -import errno -import socket -import time - -from tornado.platform import interface -from tornado.util import errno_from_exception - - -def try_close(f): - # Avoid issue #875 (race condition when using the file in another - # thread). - for i in range(10): - try: - f.close() - except IOError: - # Yield to another thread - time.sleep(1e-3) - else: - break - # Try a last time and let raise - f.close() - - -class Waker(interface.Waker): - """Create an OS independent asynchronous pipe. - - For use on platforms that don't have os.pipe() (or where pipes cannot - be passed to select()), but do have sockets. This includes Windows - and Jython. 
- """ - def __init__(self): - from .auto import set_close_exec - # Based on Zope select_trigger.py: - # https://github.com/zopefoundation/Zope/blob/master/src/ZServer/medusa/thread/select_trigger.py - - self.writer = socket.socket() - set_close_exec(self.writer.fileno()) - # Disable buffering -- pulling the trigger sends 1 byte, - # and we want that sent immediately, to wake up ASAP. - self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - count = 0 - while 1: - count += 1 - # Bind to a local port; for efficiency, let the OS pick - # a free port for us. - # Unfortunately, stress tests showed that we may not - # be able to connect to that port ("Address already in - # use") despite that the OS picked it. This appears - # to be a race bug in the Windows socket implementation. - # So we loop until a connect() succeeds (almost always - # on the first try). See the long thread at - # http://mail.zope.org/pipermail/zope/2005-July/160433.html - # for hideous details. - a = socket.socket() - set_close_exec(a.fileno()) - a.bind(("127.0.0.1", 0)) - a.listen(1) - connect_address = a.getsockname() # assigned (host, port) pair - try: - self.writer.connect(connect_address) - break # success - except socket.error as detail: - if (not hasattr(errno, 'WSAEADDRINUSE') or - errno_from_exception(detail) != errno.WSAEADDRINUSE): - # "Address already in use" is the only error - # I've seen on two WinXP Pro SP2 boxes, under - # Pythons 2.3.5 and 2.4.1. - raise - # (10048, 'Address already in use') - # assert count <= 2 # never triggered in Tim's tests - if count >= 10: # I've never seen it go above 2 - a.close() - self.writer.close() - raise socket.error("Cannot bind trigger!") - # Close `a` and try again. Note: I originally put a short - # sleep() here, but it didn't appear to help or hurt. - a.close() - - self.reader, addr = a.accept() - set_close_exec(self.reader.fileno()) - self.reader.setblocking(0) - self.writer.setblocking(0) - a.close() - self.reader_fd = self.reader.fileno() - - def fileno(self): - return self.reader.fileno() - - def write_fileno(self): - return self.writer.fileno() - - def wake(self): - try: - self.writer.send(b"x") - except (IOError, socket.error, ValueError): - pass - - def consume(self): - try: - while True: - result = self.reader.recv(1024) - if not result: - break - except (IOError, socket.error): - pass - - def close(self): - self.reader.close() - try_close(self.writer) diff --git a/tornado/platform/epoll.py b/tornado/platform/epoll.py deleted file mode 100644 index 4e34617406..0000000000 --- a/tornado/platform/epoll.py +++ /dev/null @@ -1,25 +0,0 @@ -# -# Copyright 2012 Facebook -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
-"""EPoll-based IOLoop implementation for Linux systems.""" -from __future__ import absolute_import, division, print_function - -import select - -from tornado.ioloop import PollIOLoop - - -class EPollIOLoop(PollIOLoop): - def initialize(self, **kwargs): - super(EPollIOLoop, self).initialize(impl=select.epoll(), **kwargs) diff --git a/tornado/platform/interface.py b/tornado/platform/interface.py deleted file mode 100644 index cac5326465..0000000000 --- a/tornado/platform/interface.py +++ /dev/null @@ -1,66 +0,0 @@ -# -# Copyright 2011 Facebook -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -"""Interfaces for platform-specific functionality. - -This module exists primarily for documentation purposes and as base classes -for other tornado.platform modules. Most code should import the appropriate -implementation from `tornado.platform.auto`. -""" - -from __future__ import absolute_import, division, print_function - - -def set_close_exec(fd): - """Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor.""" - raise NotImplementedError() - - -class Waker(object): - """A socket-like object that can wake another thread from ``select()``. - - The `~tornado.ioloop.IOLoop` will add the Waker's `fileno()` to - its ``select`` (or ``epoll`` or ``kqueue``) calls. When another - thread wants to wake up the loop, it calls `wake`. Once it has woken - up, it will call `consume` to do any necessary per-wake cleanup. When - the ``IOLoop`` is closed, it closes its waker too. - """ - def fileno(self): - """Returns the read file descriptor for this waker. - - Must be suitable for use with ``select()`` or equivalent on the - local platform. - """ - raise NotImplementedError() - - def write_fileno(self): - """Returns the write file descriptor for this waker.""" - raise NotImplementedError() - - def wake(self): - """Triggers activity on the waker's file descriptor.""" - raise NotImplementedError() - - def consume(self): - """Called after the listen has woken up to do any necessary cleanup.""" - raise NotImplementedError() - - def close(self): - """Closes the waker's file descriptor(s).""" - raise NotImplementedError() - - -def monotonic_time(): - raise NotImplementedError() diff --git a/tornado/platform/kqueue.py b/tornado/platform/kqueue.py deleted file mode 100644 index 4e0aee02ee..0000000000 --- a/tornado/platform/kqueue.py +++ /dev/null @@ -1,90 +0,0 @@ -# -# Copyright 2012 Facebook -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
-"""KQueue-based IOLoop implementation for BSD/Mac systems.""" -from __future__ import absolute_import, division, print_function - -import select - -from tornado.ioloop import IOLoop, PollIOLoop - -assert hasattr(select, 'kqueue'), 'kqueue not supported' - - -class _KQueue(object): - """A kqueue-based event loop for BSD/Mac systems.""" - def __init__(self): - self._kqueue = select.kqueue() - self._active = {} - - def fileno(self): - return self._kqueue.fileno() - - def close(self): - self._kqueue.close() - - def register(self, fd, events): - if fd in self._active: - raise IOError("fd %s already registered" % fd) - self._control(fd, events, select.KQ_EV_ADD) - self._active[fd] = events - - def modify(self, fd, events): - self.unregister(fd) - self.register(fd, events) - - def unregister(self, fd): - events = self._active.pop(fd) - self._control(fd, events, select.KQ_EV_DELETE) - - def _control(self, fd, events, flags): - kevents = [] - if events & IOLoop.WRITE: - kevents.append(select.kevent( - fd, filter=select.KQ_FILTER_WRITE, flags=flags)) - if events & IOLoop.READ: - kevents.append(select.kevent( - fd, filter=select.KQ_FILTER_READ, flags=flags)) - # Even though control() takes a list, it seems to return EINVAL - # on Mac OS X (10.6) when there is more than one event in the list. - for kevent in kevents: - self._kqueue.control([kevent], 0) - - def poll(self, timeout): - kevents = self._kqueue.control(None, 1000, timeout) - events = {} - for kevent in kevents: - fd = kevent.ident - if kevent.filter == select.KQ_FILTER_READ: - events[fd] = events.get(fd, 0) | IOLoop.READ - if kevent.filter == select.KQ_FILTER_WRITE: - if kevent.flags & select.KQ_EV_EOF: - # If an asynchronous connection is refused, kqueue - # returns a write event with the EOF flag set. - # Turn this into an error for consistency with the - # other IOLoop implementations. - # Note that for read events, EOF may be returned before - # all data has been consumed from the socket buffer, - # so we only check for EOF on write events. - events[fd] = IOLoop.ERROR - else: - events[fd] = events.get(fd, 0) | IOLoop.WRITE - if kevent.flags & select.KQ_EV_ERROR: - events[fd] = events.get(fd, 0) | IOLoop.ERROR - return events.items() - - -class KQueueIOLoop(PollIOLoop): - def initialize(self, **kwargs): - super(KQueueIOLoop, self).initialize(impl=_KQueue(), **kwargs) diff --git a/tornado/platform/posix.py b/tornado/platform/posix.py deleted file mode 100644 index 6fe1fa8372..0000000000 --- a/tornado/platform/posix.py +++ /dev/null @@ -1,69 +0,0 @@ -# -# Copyright 2011 Facebook -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
- -"""Posix implementations of platform-specific functionality.""" - -from __future__ import absolute_import, division, print_function - -import fcntl -import os - -from tornado.platform import common, interface - - -def set_close_exec(fd): - flags = fcntl.fcntl(fd, fcntl.F_GETFD) - fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) - - -def _set_nonblocking(fd): - flags = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) - - -class Waker(interface.Waker): - def __init__(self): - r, w = os.pipe() - _set_nonblocking(r) - _set_nonblocking(w) - set_close_exec(r) - set_close_exec(w) - self.reader = os.fdopen(r, "rb", 0) - self.writer = os.fdopen(w, "wb", 0) - - def fileno(self): - return self.reader.fileno() - - def write_fileno(self): - return self.writer.fileno() - - def wake(self): - try: - self.writer.write(b"x") - except (IOError, ValueError): - pass - - def consume(self): - try: - while True: - result = self.reader.read() - if not result: - break - except IOError: - pass - - def close(self): - self.reader.close() - common.try_close(self.writer) diff --git a/tornado/platform/select.py b/tornado/platform/select.py deleted file mode 100644 index 14e8a4745c..0000000000 --- a/tornado/platform/select.py +++ /dev/null @@ -1,75 +0,0 @@ -# -# Copyright 2012 Facebook -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -"""Select-based IOLoop implementation. - -Used as a fallback for systems that don't support epoll or kqueue. -""" -from __future__ import absolute_import, division, print_function - -import select - -from tornado.ioloop import IOLoop, PollIOLoop - - -class _Select(object): - """A simple, select()-based IOLoop implementation for non-Linux systems""" - def __init__(self): - self.read_fds = set() - self.write_fds = set() - self.error_fds = set() - self.fd_sets = (self.read_fds, self.write_fds, self.error_fds) - - def close(self): - pass - - def register(self, fd, events): - if fd in self.read_fds or fd in self.write_fds or fd in self.error_fds: - raise IOError("fd %s already registered" % fd) - if events & IOLoop.READ: - self.read_fds.add(fd) - if events & IOLoop.WRITE: - self.write_fds.add(fd) - if events & IOLoop.ERROR: - self.error_fds.add(fd) - # Closed connections are reported as errors by epoll and kqueue, - # but as zero-byte reads by select, so when errors are requested - # we need to listen for both read and error. 
- # self.read_fds.add(fd) - - def modify(self, fd, events): - self.unregister(fd) - self.register(fd, events) - - def unregister(self, fd): - self.read_fds.discard(fd) - self.write_fds.discard(fd) - self.error_fds.discard(fd) - - def poll(self, timeout): - readable, writeable, errors = select.select( - self.read_fds, self.write_fds, self.error_fds, timeout) - events = {} - for fd in readable: - events[fd] = events.get(fd, 0) | IOLoop.READ - for fd in writeable: - events[fd] = events.get(fd, 0) | IOLoop.WRITE - for fd in errors: - events[fd] = events.get(fd, 0) | IOLoop.ERROR - return events.items() - - -class SelectIOLoop(PollIOLoop): - def initialize(self, **kwargs): - super(SelectIOLoop, self).initialize(impl=_Select(), **kwargs) diff --git a/tornado/platform/twisted.py b/tornado/platform/twisted.py index 4ae98be976..153fe436eb 100644 --- a/tornado/platform/twisted.py +++ b/tornado/platform/twisted.py @@ -1,6 +1,3 @@ -# Author: Ovidiu Predescu -# Date: July 2011 -# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at @@ -12,505 +9,31 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -"""Bridges between the Twisted reactor and Tornado IOLoop. - -This module lets you run applications and libraries written for -Twisted in a Tornado application. It can be used in two modes, -depending on which library's underlying event loop you want to use. - -This module has been tested with Twisted versions 11.0.0 and newer. +"""Bridges between the Twisted package and Tornado. """ -from __future__ import absolute_import, division, print_function - -import datetime -import functools -import numbers import socket import sys import twisted.internet.abstract # type: ignore +import twisted.internet.asyncioreactor # type: ignore from twisted.internet.defer import Deferred # type: ignore -from twisted.internet.posixbase import PosixReactorBase # type: ignore -from twisted.internet.interfaces import IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor # type: ignore # noqa: E501 -from twisted.python import failure, log # type: ignore -from twisted.internet import error # type: ignore +from twisted.python import failure # type: ignore import twisted.names.cache # type: ignore import twisted.names.client # type: ignore import twisted.names.hosts # type: ignore import twisted.names.resolve # type: ignore -from zope.interface import implementer # type: ignore from tornado.concurrent import Future, future_set_exc_info from tornado.escape import utf8 from tornado import gen -import tornado.ioloop -from tornado.log import app_log from tornado.netutil import Resolver -from tornado.stack_context import NullContext, wrap -from tornado.ioloop import IOLoop -from tornado.util import timedelta_to_seconds - - -@implementer(IDelayedCall) -class TornadoDelayedCall(object): - """DelayedCall object for Tornado.""" - def __init__(self, reactor, seconds, f, *args, **kw): - self._reactor = reactor - self._func = functools.partial(f, *args, **kw) - self._time = self._reactor.seconds() + seconds - self._timeout = self._reactor._io_loop.add_timeout(self._time, - self._called) - self._active = True - - def _called(self): - self._active = False - self._reactor._removeDelayedCall(self) - try: - self._func() - except: - app_log.error("_called caught exception", exc_info=True) 
- - def getTime(self): - return self._time - - def cancel(self): - self._active = False - self._reactor._io_loop.remove_timeout(self._timeout) - self._reactor._removeDelayedCall(self) - - def delay(self, seconds): - self._reactor._io_loop.remove_timeout(self._timeout) - self._time += seconds - self._timeout = self._reactor._io_loop.add_timeout(self._time, - self._called) - - def reset(self, seconds): - self._reactor._io_loop.remove_timeout(self._timeout) - self._time = self._reactor.seconds() + seconds - self._timeout = self._reactor._io_loop.add_timeout(self._time, - self._called) - - def active(self): - return self._active - - -@implementer(IReactorTime, IReactorFDSet) -class TornadoReactor(PosixReactorBase): - """Twisted reactor built on the Tornado IOLoop. - - `TornadoReactor` implements the Twisted reactor interface on top of - the Tornado IOLoop. To use it, simply call `install` at the beginning - of the application:: - - import tornado.platform.twisted - tornado.platform.twisted.install() - from twisted.internet import reactor - - When the app is ready to start, call ``IOLoop.current().start()`` - instead of ``reactor.run()``. - - It is also possible to create a non-global reactor by calling - ``tornado.platform.twisted.TornadoReactor()``. However, if - the `.IOLoop` and reactor are to be short-lived (such as those used in - unit tests), additional cleanup may be required. Specifically, it is - recommended to call:: - - reactor.fireSystemEvent('shutdown') - reactor.disconnectAll() - - before closing the `.IOLoop`. - - .. versionchanged:: 5.0 - The ``io_loop`` argument (deprecated since version 4.1) has been removed. - """ - def __init__(self): - self._io_loop = tornado.ioloop.IOLoop.current() - self._readers = {} # map of reader objects to fd - self._writers = {} # map of writer objects to fd - self._fds = {} # a map of fd to a (reader, writer) tuple - self._delayedCalls = {} - PosixReactorBase.__init__(self) - self.addSystemEventTrigger('during', 'shutdown', self.crash) - - # IOLoop.start() bypasses some of the reactor initialization. - # Fire off the necessary events if they weren't already triggered - # by reactor.run(). - def start_if_necessary(): - if not self._started: - self.fireSystemEvent('startup') - self._io_loop.add_callback(start_if_necessary) - - # IReactorTime - def seconds(self): - return self._io_loop.time() - - def callLater(self, seconds, f, *args, **kw): - dc = TornadoDelayedCall(self, seconds, f, *args, **kw) - self._delayedCalls[dc] = True - return dc - - def getDelayedCalls(self): - return [x for x in self._delayedCalls if x._active] - - def _removeDelayedCall(self, dc): - if dc in self._delayedCalls: - del self._delayedCalls[dc] - - # IReactorThreads - def callFromThread(self, f, *args, **kw): - assert callable(f), "%s is not callable" % f - with NullContext(): - # This NullContext is mainly for an edge case when running - # TwistedIOLoop on top of a TornadoReactor. - # TwistedIOLoop.add_callback uses reactor.callFromThread and - # should not pick up additional StackContexts along the way. - self._io_loop.add_callback(f, *args, **kw) - - # We don't need the waker code from the super class, Tornado uses - # its own waker. 
- def installWaker(self): - pass - - def wakeUp(self): - pass - - # IReactorFDSet - def _invoke_callback(self, fd, events): - if fd not in self._fds: - return - (reader, writer) = self._fds[fd] - if reader: - err = None - if reader.fileno() == -1: - err = error.ConnectionLost() - elif events & IOLoop.READ: - err = log.callWithLogger(reader, reader.doRead) - if err is None and events & IOLoop.ERROR: - err = error.ConnectionLost() - if err is not None: - self.removeReader(reader) - reader.readConnectionLost(failure.Failure(err)) - if writer: - err = None - if writer.fileno() == -1: - err = error.ConnectionLost() - elif events & IOLoop.WRITE: - err = log.callWithLogger(writer, writer.doWrite) - if err is None and events & IOLoop.ERROR: - err = error.ConnectionLost() - if err is not None: - self.removeWriter(writer) - writer.writeConnectionLost(failure.Failure(err)) - - def addReader(self, reader): - if reader in self._readers: - # Don't add the reader if it's already there - return - fd = reader.fileno() - self._readers[reader] = fd - if fd in self._fds: - (_, writer) = self._fds[fd] - self._fds[fd] = (reader, writer) - if writer: - # We already registered this fd for write events, - # update it for read events as well. - self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE) - else: - with NullContext(): - self._fds[fd] = (reader, None) - self._io_loop.add_handler(fd, self._invoke_callback, - IOLoop.READ) - - def addWriter(self, writer): - if writer in self._writers: - return - fd = writer.fileno() - self._writers[writer] = fd - if fd in self._fds: - (reader, _) = self._fds[fd] - self._fds[fd] = (reader, writer) - if reader: - # We already registered this fd for read events, - # update it for write events as well. - self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE) - else: - with NullContext(): - self._fds[fd] = (None, writer) - self._io_loop.add_handler(fd, self._invoke_callback, - IOLoop.WRITE) - - def removeReader(self, reader): - if reader in self._readers: - fd = self._readers.pop(reader) - (_, writer) = self._fds[fd] - if writer: - # We have a writer so we need to update the IOLoop for - # write events only. - self._fds[fd] = (None, writer) - self._io_loop.update_handler(fd, IOLoop.WRITE) - else: - # Since we have no writer registered, we remove the - # entry from _fds and unregister the handler from the - # IOLoop - del self._fds[fd] - self._io_loop.remove_handler(fd) - - def removeWriter(self, writer): - if writer in self._writers: - fd = self._writers.pop(writer) - (reader, _) = self._fds[fd] - if reader: - # We have a reader so we need to update the IOLoop for - # read events only. - self._fds[fd] = (reader, None) - self._io_loop.update_handler(fd, IOLoop.READ) - else: - # Since we have no reader registered, we remove the - # entry from the _fds and unregister the handler from - # the IOLoop. - del self._fds[fd] - self._io_loop.remove_handler(fd) - - def removeAll(self): - return self._removeAll(self._readers, self._writers) - - def getReaders(self): - return self._readers.keys() - - def getWriters(self): - return self._writers.keys() - - # The following functions are mainly used in twisted-style test cases; - # it is expected that most users of the TornadoReactor will call - # IOLoop.start() instead of Reactor.run(). 
- def stop(self): - PosixReactorBase.stop(self) - fire_shutdown = functools.partial(self.fireSystemEvent, "shutdown") - self._io_loop.add_callback(fire_shutdown) - - def crash(self): - PosixReactorBase.crash(self) - self._io_loop.stop() - - def doIteration(self, delay): - raise NotImplementedError("doIteration") - - def mainLoop(self): - # Since this class is intended to be used in applications - # where the top-level event loop is ``io_loop.start()`` rather - # than ``reactor.run()``, it is implemented a little - # differently than other Twisted reactors. We override - # ``mainLoop`` instead of ``doIteration`` and must implement - # timed call functionality on top of `.IOLoop.add_timeout` - # rather than using the implementation in - # ``PosixReactorBase``. - self._io_loop.start() - - -class _TestReactor(TornadoReactor): - """Subclass of TornadoReactor for use in unittests. - - This can't go in the test.py file because of import-order dependencies - with the Twisted reactor test builder. - """ - def __init__(self): - # always use a new ioloop - IOLoop.clear_current() - IOLoop(make_current=True) - super(_TestReactor, self).__init__() - IOLoop.clear_current() - def listenTCP(self, port, factory, backlog=50, interface=''): - # default to localhost to avoid firewall prompts on the mac - if not interface: - interface = '127.0.0.1' - return super(_TestReactor, self).listenTCP( - port, factory, backlog=backlog, interface=interface) +import typing - def listenUDP(self, port, protocol, interface='', maxPacketSize=8192): - if not interface: - interface = '127.0.0.1' - return super(_TestReactor, self).listenUDP( - port, protocol, interface=interface, maxPacketSize=maxPacketSize) - - -def install(): - """Install this package as the default Twisted reactor. - - ``install()`` must be called very early in the startup process, - before most other twisted-related imports. Conversely, because it - initializes the `.IOLoop`, it cannot be called before - `.fork_processes` or multi-process `~.TCPServer.start`. These - conflicting requirements make it difficult to use `.TornadoReactor` - in multi-process mode, and an external process manager such as - ``supervisord`` is recommended instead. - - .. versionchanged:: 5.0 - The ``io_loop`` argument (deprecated since version 4.1) has been removed. - - """ - reactor = TornadoReactor() - from twisted.internet.main import installReactor # type: ignore - installReactor(reactor) - return reactor - - -@implementer(IReadDescriptor, IWriteDescriptor) -class _FD(object): - def __init__(self, fd, fileobj, handler): - self.fd = fd - self.fileobj = fileobj - self.handler = handler - self.reading = False - self.writing = False - self.lost = False - - def fileno(self): - return self.fd - - def doRead(self): - if not self.lost: - self.handler(self.fileobj, tornado.ioloop.IOLoop.READ) - - def doWrite(self): - if not self.lost: - self.handler(self.fileobj, tornado.ioloop.IOLoop.WRITE) - - def connectionLost(self, reason): - if not self.lost: - self.handler(self.fileobj, tornado.ioloop.IOLoop.ERROR) - self.lost = True - - writeConnectionLost = readConnectionLost = connectionLost - - def logPrefix(self): - return '' - - -class TwistedIOLoop(tornado.ioloop.IOLoop): - """IOLoop implementation that runs on Twisted. - - `TwistedIOLoop` implements the Tornado IOLoop interface on top of - the Twisted reactor. 
Recommended usage:: - - from tornado.platform.twisted import TwistedIOLoop - from twisted.internet import reactor - TwistedIOLoop().install() - # Set up your tornado application as usual using `IOLoop.instance` - reactor.run() - - Uses the global Twisted reactor by default. To create multiple - ``TwistedIOLoops`` in the same process, you must pass a unique reactor - when constructing each one. - - Not compatible with `tornado.process.Subprocess.set_exit_callback` - because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict - with each other. - - See also :meth:`tornado.ioloop.IOLoop.install` for general notes on - installing alternative IOLoops. - """ - def initialize(self, reactor=None, **kwargs): - super(TwistedIOLoop, self).initialize(**kwargs) - if reactor is None: - import twisted.internet.reactor # type: ignore - reactor = twisted.internet.reactor - self.reactor = reactor - self.fds = {} - - def close(self, all_fds=False): - fds = self.fds - self.reactor.removeAll() - for c in self.reactor.getDelayedCalls(): - c.cancel() - if all_fds: - for fd in fds.values(): - self.close_fd(fd.fileobj) - - def add_handler(self, fd, handler, events): - if fd in self.fds: - raise ValueError('fd %s added twice' % fd) - fd, fileobj = self.split_fd(fd) - self.fds[fd] = _FD(fd, fileobj, wrap(handler)) - if events & tornado.ioloop.IOLoop.READ: - self.fds[fd].reading = True - self.reactor.addReader(self.fds[fd]) - if events & tornado.ioloop.IOLoop.WRITE: - self.fds[fd].writing = True - self.reactor.addWriter(self.fds[fd]) - - def update_handler(self, fd, events): - fd, fileobj = self.split_fd(fd) - if events & tornado.ioloop.IOLoop.READ: - if not self.fds[fd].reading: - self.fds[fd].reading = True - self.reactor.addReader(self.fds[fd]) - else: - if self.fds[fd].reading: - self.fds[fd].reading = False - self.reactor.removeReader(self.fds[fd]) - if events & tornado.ioloop.IOLoop.WRITE: - if not self.fds[fd].writing: - self.fds[fd].writing = True - self.reactor.addWriter(self.fds[fd]) - else: - if self.fds[fd].writing: - self.fds[fd].writing = False - self.reactor.removeWriter(self.fds[fd]) - - def remove_handler(self, fd): - fd, fileobj = self.split_fd(fd) - if fd not in self.fds: - return - self.fds[fd].lost = True - if self.fds[fd].reading: - self.reactor.removeReader(self.fds[fd]) - if self.fds[fd].writing: - self.reactor.removeWriter(self.fds[fd]) - del self.fds[fd] - - def start(self): - old_current = IOLoop.current(instance=False) - try: - self._setup_logging() - self.make_current() - self.reactor.run() - finally: - if old_current is None: - IOLoop.clear_current() - else: - old_current.make_current() - - def stop(self): - self.reactor.crash() - - def add_timeout(self, deadline, callback, *args, **kwargs): - # This method could be simplified (since tornado 4.0) by - # overriding call_at instead of add_timeout, but we leave it - # for now as a test of backwards-compatibility. 
- if isinstance(deadline, numbers.Real): - delay = max(deadline - self.time(), 0) - elif isinstance(deadline, datetime.timedelta): - delay = timedelta_to_seconds(deadline) - else: - raise TypeError("Unsupported deadline %r") - return self.reactor.callLater( - delay, self._run_callback, - functools.partial(wrap(callback), *args, **kwargs)) - - def remove_timeout(self, timeout): - if timeout.active(): - timeout.cancel() - - def add_callback(self, callback, *args, **kwargs): - self.reactor.callFromThread( - self._run_callback, - functools.partial(wrap(callback), *args, **kwargs)) - - def add_callback_from_signal(self, callback, *args, **kwargs): - self.add_callback(callback, *args, **kwargs) +if typing.TYPE_CHECKING: + from typing import Generator, Any, List, Tuple # noqa: F401 class TwistedResolver(Resolver): @@ -529,21 +52,30 @@ class TwistedResolver(Resolver): .. versionchanged:: 5.0 The ``io_loop`` argument (deprecated since version 4.1) has been removed. + + .. deprecated:: 6.2 + This class is deprecated and will be removed in Tornado 7.0. Use the default + thread-based resolver instead. """ - def initialize(self): + + def initialize(self) -> None: # partial copy of twisted.names.client.createResolver, which doesn't # allow for a reactor to be passed in. - self.reactor = tornado.platform.twisted.TornadoReactor() + self.reactor = twisted.internet.asyncioreactor.AsyncioSelectorReactor() - host_resolver = twisted.names.hosts.Resolver('/etc/hosts') + host_resolver = twisted.names.hosts.Resolver("/etc/hosts") cache_resolver = twisted.names.cache.CacheResolver(reactor=self.reactor) - real_resolver = twisted.names.client.Resolver('/etc/resolv.conf', - reactor=self.reactor) + real_resolver = twisted.names.client.Resolver( + "/etc/resolv.conf", reactor=self.reactor + ) self.resolver = twisted.names.resolve.ResolverChain( - [host_resolver, cache_resolver, real_resolver]) + [host_resolver, cache_resolver, real_resolver] + ) @gen.coroutine - def resolve(self, host, port, family=0): + def resolve( + self, host: str, port: int, family: int = 0 + ) -> "Generator[Any, Any, List[Tuple[int, Any]]]": # getHostByName doesn't accept IP addresses, so if the input # looks like an IP address just return it immediately. if twisted.internet.abstract.isIPAddress(host): @@ -554,7 +86,7 @@ def resolve(self, host, port, family=0): resolved_family = socket.AF_INET6 else: deferred = self.resolver.getHostByName(utf8(host)) - fut = Future() + fut = Future() # type: Future[Any] deferred.addBoth(fut.set_result) resolved = yield fut if isinstance(resolved, failure.Failure): @@ -569,25 +101,50 @@ def resolve(self, host, port, family=0): else: resolved_family = socket.AF_UNSPEC if family != socket.AF_UNSPEC and family != resolved_family: - raise Exception('Requested socket family %d but got %d' % - (family, resolved_family)) - result = [ - (resolved_family, (resolved, port)), - ] - raise gen.Return(result) + raise Exception( + "Requested socket family %d but got %d" % (family, resolved_family) + ) + result = [(typing.cast(int, resolved_family), (resolved, port))] + return result + + +def install() -> None: + """Install ``AsyncioSelectorReactor`` as the default Twisted reactor. + + .. deprecated:: 5.1 + + This function is provided for backwards compatibility; code + that does not require compatibility with older versions of + Tornado should use + ``twisted.internet.asyncioreactor.install()`` directly. + + .. versionchanged:: 6.0.3 + In Tornado 5.x and before, this function installed a reactor + based on the Tornado ``IOLoop``. 
When that reactor + implementation was removed in Tornado 6.0.0, this function was + removed as well. It was restored in Tornado 6.0.3 using the + ``asyncio`` reactor instead. + + """ + from twisted.internet.asyncioreactor import install + + install() + + +if hasattr(gen.convert_yielded, "register"): -if hasattr(gen.convert_yielded, 'register'): @gen.convert_yielded.register(Deferred) # type: ignore - def _(d): - f = Future() + def _(d: Deferred) -> Future: + f = Future() # type: Future[Any] - def errback(failure): + def errback(failure: failure.Failure) -> None: try: failure.raiseException() # Should never happen, but just in case raise Exception("errback called without error") except: future_set_exc_info(f, sys.exc_info()) + d.addCallbacks(f.set_result, errback) return f diff --git a/tornado/platform/windows.py b/tornado/platform/windows.py deleted file mode 100644 index 4127700659..0000000000 --- a/tornado/platform/windows.py +++ /dev/null @@ -1,20 +0,0 @@ -# NOTE: win32 support is currently experimental, and not recommended -# for production use. - - -from __future__ import absolute_import, division, print_function -import ctypes # type: ignore -import ctypes.wintypes # type: ignore - -# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx -SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation -SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD) # noqa: E501 -SetHandleInformation.restype = ctypes.wintypes.BOOL - -HANDLE_FLAG_INHERIT = 0x00000001 - - -def set_close_exec(fd): - success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0) - if not success: - raise ctypes.WinError() diff --git a/tornado/process.py b/tornado/process.py index 122fd7e14b..26428feb77 100644 --- a/tornado/process.py +++ b/tornado/process.py @@ -17,10 +17,8 @@ the server into multiple processes and managing subprocesses. """ -from __future__ import absolute_import, division, print_function - -import errno import os +import multiprocessing import signal import subprocess import sys @@ -28,35 +26,26 @@ from binascii import hexlify -from tornado.concurrent import Future, future_set_result_unless_cancelled +from tornado.concurrent import ( + Future, + future_set_result_unless_cancelled, + future_set_exception_unless_cancelled, +) from tornado import ioloop from tornado.iostream import PipeIOStream from tornado.log import gen_log -from tornado.platform.auto import set_close_exec -from tornado import stack_context -from tornado.util import errno_from_exception, PY3 -try: - import multiprocessing -except ImportError: - # Multiprocessing is not available on Google App Engine. - multiprocessing = None +import typing +from typing import Optional, Any, Callable -if PY3: - long = int +if typing.TYPE_CHECKING: + from typing import List # noqa: F401 # Re-export this exception for convenience. -try: - CalledProcessError = subprocess.CalledProcessError -except AttributeError: - # The subprocess module exists in Google App Engine, but is empty. - # This module isn't very useful in that case, but it should - # at least be importable. 
- if 'APPENGINE_RUNTIME' not in os.environ: - raise +CalledProcessError = subprocess.CalledProcessError -def cpu_count(): +def cpu_count() -> int: """Returns the number of processors on this machine.""" if multiprocessing is None: return 1 @@ -65,38 +54,34 @@ def cpu_count(): except NotImplementedError: pass try: - return os.sysconf("SC_NPROCESSORS_CONF") + return os.sysconf("SC_NPROCESSORS_CONF") # type: ignore except (AttributeError, ValueError): pass gen_log.error("Could not detect number of processors; assuming 1") return 1 -def _reseed_random(): - if 'random' not in sys.modules: +def _reseed_random() -> None: + if "random" not in sys.modules: return import random + # If os.urandom is available, this method does the same thing as # random.seed (at least as of python 2.6). If os.urandom is not # available, we mix in the pid in addition to a timestamp. try: - seed = long(hexlify(os.urandom(16)), 16) + seed = int(hexlify(os.urandom(16)), 16) except NotImplementedError: seed = int(time.time() * 1000) ^ os.getpid() random.seed(seed) -def _pipe_cloexec(): - r, w = os.pipe() - set_close_exec(r) - set_close_exec(w) - return r, w - - _task_id = None -def fork_processes(num_processes, max_restarts=100): +def fork_processes( + num_processes: Optional[int], max_restarts: Optional[int] = None +) -> int: """Starts multiple worker processes. If ``num_processes`` is None or <= 0, we detect the number of cores @@ -117,10 +102,20 @@ def fork_processes(num_processes, max_restarts=100): number between 0 and ``num_processes``. Processes that exit abnormally (due to a signal or non-zero exit status) are restarted with the same id (up to ``max_restarts`` times). In the parent - process, ``fork_processes`` returns None if all child processes - have exited normally, but will otherwise only exit by throwing an - exception. + process, ``fork_processes`` calls ``sys.exit(0)`` after all child + processes have exited normally. + + max_restarts defaults to 100. + + Availability: Unix """ + if sys.platform == "win32": + # The exact form of this condition matters to mypy; it understands + # if but not assert in this context. 
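For illustration, the narrowing the comment refers to looks like this outside Tornado (``posix_only`` is a hypothetical function):

```py
import os
import sys

def posix_only() -> int:
    if sys.platform == "win32":
        # mypy recognizes this literal comparison and will not
        # type-check the POSIX-only code below when targeting win32,
        # whereas an equivalent assert would not narrow here.
        raise Exception("fork not available on windows")
    # On the narrowed (non-win32) branch, os.fork is known to exist.
    return os.fork()
```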
+ raise Exception("fork not available on windows") + if max_restarts is None: + max_restarts = 100 + global _task_id assert _task_id is None if num_processes is None or num_processes <= 0: @@ -128,7 +123,7 @@ def fork_processes(num_processes, max_restarts=100): gen_log.info("Starting %d processes", num_processes) children = {} - def start_child(i): + def start_child(i: int) -> Optional[int]: pid = os.fork() if pid == 0: # child process @@ -146,21 +141,24 @@ def start_child(i): return id num_restarts = 0 while children: - try: - pid, status = os.wait() - except OSError as e: - if errno_from_exception(e) == errno.EINTR: - continue - raise + pid, status = os.wait() if pid not in children: continue id = children.pop(pid) if os.WIFSIGNALED(status): - gen_log.warning("child %d (pid %d) killed by signal %d, restarting", - id, pid, os.WTERMSIG(status)) + gen_log.warning( + "child %d (pid %d) killed by signal %d, restarting", + id, + pid, + os.WTERMSIG(status), + ) elif os.WEXITSTATUS(status) != 0: - gen_log.warning("child %d (pid %d) exited with status %d, restarting", - id, pid, os.WEXITSTATUS(status)) + gen_log.warning( + "child %d (pid %d) exited with status %d, restarting", + id, + pid, + os.WEXITSTATUS(status), + ) else: gen_log.info("child %d (pid %d) exited normally", id, pid) continue @@ -177,7 +175,7 @@ def start_child(i): sys.exit(0) -def task_id(): +def task_id() -> Optional[int]: """Returns the current task id, if any. Returns None if this process was not created by `fork_processes`. @@ -207,32 +205,34 @@ class Subprocess(object): The ``io_loop`` argument (deprecated since version 4.1) has been removed. """ + STREAM = object() _initialized = False _waiting = {} # type: ignore + _old_sigchld = None - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.io_loop = ioloop.IOLoop.current() # All FDs we create should be closed on error; those in to_close # should be closed in the parent process on success. 
- pipe_fds = [] - to_close = [] - if kwargs.get('stdin') is Subprocess.STREAM: - in_r, in_w = _pipe_cloexec() - kwargs['stdin'] = in_r + pipe_fds = [] # type: List[int] + to_close = [] # type: List[int] + if kwargs.get("stdin") is Subprocess.STREAM: + in_r, in_w = os.pipe() + kwargs["stdin"] = in_r pipe_fds.extend((in_r, in_w)) to_close.append(in_r) self.stdin = PipeIOStream(in_w) - if kwargs.get('stdout') is Subprocess.STREAM: - out_r, out_w = _pipe_cloexec() - kwargs['stdout'] = out_w + if kwargs.get("stdout") is Subprocess.STREAM: + out_r, out_w = os.pipe() + kwargs["stdout"] = out_w pipe_fds.extend((out_r, out_w)) to_close.append(out_w) self.stdout = PipeIOStream(out_r) - if kwargs.get('stderr') is Subprocess.STREAM: - err_r, err_w = _pipe_cloexec() - kwargs['stderr'] = err_w + if kwargs.get("stderr") is Subprocess.STREAM: + err_r, err_w = os.pipe() + kwargs["stderr"] = err_w pipe_fds.extend((err_r, err_w)) to_close.append(err_w) self.stderr = PipeIOStream(err_r) @@ -244,13 +244,14 @@ def __init__(self, *args, **kwargs): raise for fd in to_close: os.close(fd) - for attr in ['stdin', 'stdout', 'stderr', 'pid']: + self.pid = self.proc.pid + for attr in ["stdin", "stdout", "stderr"]: if not hasattr(self, attr): # don't clobber streams set above setattr(self, attr, getattr(self.proc, attr)) - self._exit_callback = None - self.returncode = None + self._exit_callback = None # type: Optional[Callable[[int], None]] + self.returncode = None # type: Optional[int] - def set_exit_callback(self, callback): + def set_exit_callback(self, callback: Callable[[int], None]) -> None: """Runs ``callback`` when this process exits. The callback takes one argument, the return code of the process. @@ -264,13 +265,15 @@ def set_exit_callback(self, callback): In many cases a close callback on the stdout or stderr streams can be used as an alternative to an exit callback if the signal handler is causing a problem. + + Availability: Unix """ - self._exit_callback = stack_context.wrap(callback) + self._exit_callback = callback Subprocess.initialize() Subprocess._waiting[self.pid] = self Subprocess._try_cleanup_process(self.pid) - def wait_for_exit(self, raise_error=True): + def wait_for_exit(self, raise_error: bool = True) -> "Future[int]": """Returns a `.Future` which resolves when the process exits. Usage:: @@ -285,20 +288,25 @@ def wait_for_exit(self, raise_error=True): to suppress this behavior and return the exit status without raising. .. versionadded:: 4.2 + + Availability: Unix """ - future = Future() + future = Future() # type: Future[int] - def callback(ret): + def callback(ret: int) -> None: if ret != 0 and raise_error: # Unfortunately we don't have the original args any more. - future.set_exception(CalledProcessError(ret, None)) + future_set_exception_unless_cancelled( + future, CalledProcessError(ret, "unknown") + ) else: future_set_result_unless_cancelled(future, ret) + self.set_exit_callback(callback) return future @classmethod - def initialize(cls): + def initialize(cls) -> None: """Initializes the ``SIGCHLD`` handler. The signal handler is run on an `.IOLoop` to avoid locking issues. @@ -309,17 +317,20 @@ def initialize(cls): .. versionchanged:: 5.0 The ``io_loop`` argument (deprecated since version 4.1) has been removed. 
+ + Availability: Unix """ if cls._initialized: return io_loop = ioloop.IOLoop.current() cls._old_sigchld = signal.signal( signal.SIGCHLD, - lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup)) + lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup), + ) cls._initialized = True @classmethod - def uninitialize(cls): + def uninitialize(cls) -> None: """Removes the ``SIGCHLD`` handler.""" if not cls._initialized: return @@ -327,30 +338,31 @@ def uninitialize(cls): cls._initialized = False @classmethod - def _cleanup(cls): + def _cleanup(cls) -> None: for pid in list(cls._waiting.keys()): # make a copy cls._try_cleanup_process(pid) @classmethod - def _try_cleanup_process(cls, pid): + def _try_cleanup_process(cls, pid: int) -> None: try: - ret_pid, status = os.waitpid(pid, os.WNOHANG) - except OSError as e: - if errno_from_exception(e) == errno.ECHILD: - return + ret_pid, status = os.waitpid(pid, os.WNOHANG) # type: ignore + except ChildProcessError: + return if ret_pid == 0: return assert ret_pid == pid subproc = cls._waiting.pop(pid) - subproc.io_loop.add_callback_from_signal( - subproc._set_returncode, status) + subproc.io_loop.add_callback_from_signal(subproc._set_returncode, status) - def _set_returncode(self, status): - if os.WIFSIGNALED(status): - self.returncode = -os.WTERMSIG(status) + def _set_returncode(self, status: int) -> None: + if sys.platform == "win32": + self.returncode = -1 else: - assert os.WIFEXITED(status) - self.returncode = os.WEXITSTATUS(status) + if os.WIFSIGNALED(status): + self.returncode = -os.WTERMSIG(status) + else: + assert os.WIFEXITED(status) + self.returncode = os.WEXITSTATUS(status) # We've taken over wait() duty from the subprocess.Popen # object. If we don't inform it of the process's return code, # it will log a warning at destruction in python 3.6+. 
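For readers of this hunk, the typed `Subprocess` API above is easiest to see end-to-end from a coroutine. A minimal sketch (Unix-only, per the new ``Availability: Unix`` notes; the child command is arbitrary):

```py
import sys

from tornado.ioloop import IOLoop
from tornado.process import Subprocess


async def run_child() -> None:
    # stdout=Subprocess.STREAM wires the child's stdout to a PipeIOStream,
    # exposed as proc.stdout.
    proc = Subprocess(
        [sys.executable, "-c", "print('hello from the child')"],
        stdout=Subprocess.STREAM,
    )
    output = await proc.stdout.read_until_close()
    # wait_for_exit() resolves with the return code; with the default
    # raise_error=True, a nonzero status raises CalledProcessError instead.
    ret = await proc.wait_for_exit()
    print(output, ret)


IOLoop.current().run_sync(run_child)
```

`wait_for_exit()` goes through `set_exit_callback()`, which installs the shared SIGCHLD handler via `Subprocess.initialize()` as shown above; the `ChildProcessError` handling in `_try_cleanup_process` covers children that have already been reaped elsewhere.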
diff --git a/tornado/test/__init__.py b/tornado/py.typed similarity index 100% rename from tornado/test/__init__.py rename to tornado/py.typed diff --git a/tornado/queues.py b/tornado/queues.py index 23b8bb9caa..1e87f62e09 100644 --- a/tornado/queues.py +++ b/tornado/queues.py @@ -25,48 +25,60 @@ """ -from __future__ import absolute_import, division, print_function - import collections +import datetime import heapq from tornado import gen, ioloop from tornado.concurrent import Future, future_set_result_unless_cancelled from tornado.locks import Event -__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] +from typing import Union, TypeVar, Generic, Awaitable, Optional +import typing + +if typing.TYPE_CHECKING: + from typing import Deque, Tuple, Any # noqa: F401 + +_T = TypeVar("_T") + +__all__ = ["Queue", "PriorityQueue", "LifoQueue", "QueueFull", "QueueEmpty"] class QueueEmpty(Exception): """Raised by `.Queue.get_nowait` when the queue has no items.""" + pass class QueueFull(Exception): """Raised by `.Queue.put_nowait` when a queue is at its maximum size.""" + pass -def _set_timeout(future, timeout): +def _set_timeout( + future: Future, timeout: Union[None, float, datetime.timedelta] +) -> None: if timeout: - def on_timeout(): + + def on_timeout() -> None: if not future.done(): future.set_exception(gen.TimeoutError()) + io_loop = ioloop.IOLoop.current() timeout_handle = io_loop.add_timeout(timeout, on_timeout) - future.add_done_callback( - lambda _: io_loop.remove_timeout(timeout_handle)) + future.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle)) -class _QueueIterator(object): - def __init__(self, q): +class _QueueIterator(Generic[_T]): + def __init__(self, q: "Queue[_T]") -> None: self.q = q - def __anext__(self): + def __anext__(self) -> Awaitable[_T]: return self.q.get() -class Queue(object): +class Queue(Generic[_T]): """Coordinate producer and consumer coroutines. If maxsize is 0 (the default) the queue size is unbounded. @@ -79,28 +91,24 @@ class Queue(object): q = Queue(maxsize=2) - @gen.coroutine - def consumer(): - while True: - item = yield q.get() + async def consumer(): + async for item in q: try: print('Doing work on %s' % item) - yield gen.sleep(0.01) + await gen.sleep(0.01) finally: q.task_done() - @gen.coroutine - def producer(): + async def producer(): for item in range(5): - yield q.put(item) + await q.put(item) print('Put %s' % item) - @gen.coroutine - def main(): + async def main(): # Start consumer without waiting (since it never finishes). IOLoop.current().spawn_callback(consumer) - yield producer() # Wait for producer to put all tasks. - yield q.join() # Wait for consumer to finish all tasks. + await producer() # Wait for producer to put all tasks. + await q.join() # Wait for consumer to finish all tasks. print('Done') IOLoop.current().run_sync(main) @@ -119,11 +127,14 @@ def main(): Doing work on 4 Done - In Python 3.5, `Queue` implements the async iterator protocol, so - ``consumer()`` could be rewritten as:: - async def consumer(): - async for item in q: + In versions of Python without native coroutines (before 3.5), + ``consumer()`` could be written as:: + + @gen.coroutine + def consumer(): + while True: + item = yield q.get() try: print('Doing work on %s' % item) yield gen.sleep(0.01) @@ -134,7 +145,12 @@ async def consumer(): Added ``async for`` support in Python 3.5. """ - def __init__(self, maxsize=0): + + # Exact type depends on subclass. 
Could be another generic + # parameter and use protocols to be more precise here. + _queue = None # type: Any + + def __init__(self, maxsize: int = 0) -> None: if maxsize is None: raise TypeError("maxsize can't be None") @@ -143,31 +159,33 @@ def __init__(self, maxsize=0): self._maxsize = maxsize self._init() - self._getters = collections.deque([]) # Futures. - self._putters = collections.deque([]) # Pairs of (item, Future). + self._getters = collections.deque([]) # type: Deque[Future[_T]] + self._putters = collections.deque([]) # type: Deque[Tuple[_T, Future[None]]] self._unfinished_tasks = 0 self._finished = Event() self._finished.set() @property - def maxsize(self): + def maxsize(self) -> int: """Number of items allowed in the queue.""" return self._maxsize - def qsize(self): + def qsize(self) -> int: """Number of items in the queue.""" return len(self._queue) - def empty(self): + def empty(self) -> bool: return not self._queue - def full(self): + def full(self) -> bool: if self.maxsize == 0: return False else: return self.qsize() >= self.maxsize - def put(self, item, timeout=None): + def put( + self, item: _T, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> "Future[None]": """Put an item into the queue, perhaps waiting until there is room. Returns a Future, which raises `tornado.util.TimeoutError` after a @@ -178,7 +196,7 @@ def put(self, item, timeout=None): `datetime.timedelta` object for a deadline relative to the current time. """ - future = Future() + future = Future() # type: Future[None] try: self.put_nowait(item) except QueueFull: @@ -188,7 +206,7 @@ def put(self, item, timeout=None): future.set_result(None) return future - def put_nowait(self, item): + def put_nowait(self, item: _T) -> None: """Put an item into the queue without blocking. If no free slot is immediately available, raise `QueueFull`. @@ -204,18 +222,30 @@ def put_nowait(self, item): else: self.__put_internal(item) - def get(self, timeout=None): + def get( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[_T]: """Remove and return an item from the queue. - Returns a Future which resolves once an item is available, or raises + Returns an awaitable which resolves once an item is available, or raises `tornado.util.TimeoutError` after a timeout. ``timeout`` may be a number denoting a time (on the same scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a `datetime.timedelta` object for a deadline relative to the current time. + + .. note:: + + The ``timeout`` argument of this method differs from that + of the standard library's `queue.Queue.get`. That method + interprets numeric values as relative timeouts; this one + interprets them as absolute deadlines and requires + ``timedelta`` objects for relative timeouts (consistent + with other timeouts in Tornado). + """ - future = Future() + future = Future() # type: Future[_T] try: future.set_result(self.get_nowait()) except QueueEmpty: @@ -223,7 +253,7 @@ def get(self, timeout=None): _set_timeout(future, timeout) return future - def get_nowait(self): + def get_nowait(self) -> _T: """Remove and return an item from the queue without blocking. Return an item if one is immediately available, else raise @@ -241,7 +271,7 @@ def get_nowait(self): else: raise QueueEmpty - def task_done(self): + def task_done(self) -> None: """Indicate that a formerly enqueued task is complete. Used by queue consumers. 
For each `.get` used to fetch a task, a @@ -254,39 +284,42 @@ def task_done(self): Raises `ValueError` if called more times than `.put`. """ if self._unfinished_tasks <= 0: - raise ValueError('task_done() called too many times') + raise ValueError("task_done() called too many times") self._unfinished_tasks -= 1 if self._unfinished_tasks == 0: self._finished.set() - def join(self, timeout=None): + def join( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[None]: """Block until all items in the queue are processed. - Returns a Future, which raises `tornado.util.TimeoutError` after a + Returns an awaitable, which raises `tornado.util.TimeoutError` after a timeout. """ return self._finished.wait(timeout) - def __aiter__(self): + def __aiter__(self) -> _QueueIterator[_T]: return _QueueIterator(self) # These three are overridable in subclasses. - def _init(self): + def _init(self) -> None: self._queue = collections.deque() - def _get(self): + def _get(self) -> _T: return self._queue.popleft() - def _put(self, item): + def _put(self, item: _T) -> None: self._queue.append(item) + # End of the overridable methods. - def __put_internal(self, item): + def __put_internal(self, item: _T) -> None: self._unfinished_tasks += 1 self._finished.clear() self._put(item) - def _consume_expired(self): + def _consume_expired(self) -> None: # Remove timed-out waiters. while self._putters and self._putters[0][1].done(): self._putters.popleft() @@ -294,23 +327,22 @@ def _consume_expired(self): while self._getters and self._getters[0].done(): self._getters.popleft() - def __repr__(self): - return '<%s at %s %s>' % ( - type(self).__name__, hex(id(self)), self._format()) + def __repr__(self) -> str: + return "<%s at %s %s>" % (type(self).__name__, hex(id(self)), self._format()) - def __str__(self): - return '<%s %s>' % (type(self).__name__, self._format()) + def __str__(self) -> str: + return "<%s %s>" % (type(self).__name__, self._format()) - def _format(self): - result = 'maxsize=%r' % (self.maxsize, ) - if getattr(self, '_queue', None): - result += ' queue=%r' % self._queue + def _format(self) -> str: + result = "maxsize=%r" % (self.maxsize,) + if getattr(self, "_queue", None): + result += " queue=%r" % self._queue if self._getters: - result += ' getters[%s]' % len(self._getters) + result += " getters[%s]" % len(self._getters) if self._putters: - result += ' putters[%s]' % len(self._putters) + result += " putters[%s]" % len(self._putters) if self._unfinished_tasks: - result += ' tasks=%s' % self._unfinished_tasks + result += " tasks=%s" % self._unfinished_tasks return result @@ -338,13 +370,14 @@ class PriorityQueue(Queue): (1, 'medium-priority item') (10, 'low-priority item') """ - def _init(self): + + def _init(self) -> None: self._queue = [] - def _put(self, item): + def _put(self, item: _T) -> None: heapq.heappush(self._queue, item) - def _get(self): + def _get(self) -> _T: return heapq.heappop(self._queue) @@ -370,11 +403,12 @@ class LifoQueue(Queue): 2 3 """ - def _init(self): + + def _init(self) -> None: self._queue = [] - def _put(self, item): + def _put(self, item: _T) -> None: self._queue.append(item) - def _get(self): + def _get(self) -> _T: return self._queue.pop() diff --git a/tornado/routing.py b/tornado/routing.py index e56d1a75f9..a145d71916 100644 --- a/tornado/routing.py +++ b/tornado/routing.py @@ -142,7 +142,7 @@ def request_callable(request): router = RuleRouter([ Rule(HostMatches("example.com"), RuleRouter([ - Rule(PathMatches("/app1/.*"), 
Application([(r"/app1/handler", Handler)]))), + Rule(PathMatches("/app1/.*"), Application([(r"/app1/handler", Handler)])), ])) ]) @@ -175,8 +175,6 @@ def request_callable(request): """ -from __future__ import absolute_import, division, print_function - import re from functools import partial @@ -186,17 +184,15 @@ def request_callable(request): from tornado.log import app_log from tornado.util import basestring_type, import_object, re_unescape, unicode_type -try: - import typing # noqa -except ImportError: - pass +from typing import Any, Union, Optional, Awaitable, List, Dict, Pattern, Tuple, overload class Router(httputil.HTTPServerConnectionDelegate): """Abstract router interface.""" - def find_handler(self, request, **kwargs): - # type: (httputil.HTTPServerRequest, typing.Any)->httputil.HTTPMessageDelegate + def find_handler( + self, request: httputil.HTTPServerRequest, **kwargs: Any + ) -> Optional[httputil.HTTPMessageDelegate]: """Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate` that can serve the request. Routing implementations may pass additional kwargs to extend the routing logic. @@ -208,7 +204,9 @@ def find_handler(self, request, **kwargs): """ raise NotImplementedError() - def start_request(self, server_conn, request_conn): + def start_request( + self, server_conn: object, request_conn: httputil.HTTPConnection + ) -> httputil.HTTPMessageDelegate: return _RoutingDelegate(self, server_conn, request_conn) @@ -217,7 +215,7 @@ class ReversibleRouter(Router): and support reversing them to original urls. """ - def reverse_url(self, name, *args): + def reverse_url(self, name: str, *args: Any) -> Optional[str]: """Returns url string for a given route name and arguments or ``None`` if no match is found. @@ -229,50 +227,80 @@ def reverse_url(self, name, *args): class _RoutingDelegate(httputil.HTTPMessageDelegate): - def __init__(self, router, server_conn, request_conn): + def __init__( + self, router: Router, server_conn: object, request_conn: httputil.HTTPConnection + ) -> None: self.server_conn = server_conn self.request_conn = request_conn - self.delegate = None + self.delegate = None # type: Optional[httputil.HTTPMessageDelegate] self.router = router # type: Router - def headers_received(self, start_line, headers): + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: + assert isinstance(start_line, httputil.RequestStartLine) request = httputil.HTTPServerRequest( connection=self.request_conn, server_connection=self.server_conn, - start_line=start_line, headers=headers) + start_line=start_line, + headers=headers, + ) self.delegate = self.router.find_handler(request) if self.delegate is None: - app_log.debug("Delegate for %s %s request not found", - start_line.method, start_line.path) + app_log.debug( + "Delegate for %s %s request not found", + start_line.method, + start_line.path, + ) self.delegate = _DefaultMessageDelegate(self.request_conn) return self.delegate.headers_received(start_line, headers) - def data_received(self, chunk): + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: + assert self.delegate is not None return self.delegate.data_received(chunk) - def finish(self): + def finish(self) -> None: + assert self.delegate is not None self.delegate.finish() - def on_connection_close(self): + def on_connection_close(self) -> None: + assert self.delegate is not None self.delegate.on_connection_close() 
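The `Optional[...]` contract on `find_handler` above is the whole routing protocol: return an `HTTPMessageDelegate` for requests you recognize, or ``None`` (which `_RoutingDelegate` turns into a 404 via `_DefaultMessageDelegate`). A hypothetical custom router under those rules — ``AppOnlyRouter`` and ``HelloHandler`` are made-up names; `Application` is itself a router, so it can be delegated to:

```py
from typing import Any, Optional

from tornado.httputil import HTTPMessageDelegate, HTTPServerRequest
from tornado.routing import Router
from tornado.web import Application, RequestHandler


class HelloHandler(RequestHandler):
    def get(self) -> None:
        self.write("hello")


app = Application([(r"/hello", HelloHandler)])


class AppOnlyRouter(Router):
    def find_handler(
        self, request: HTTPServerRequest, **kwargs: Any
    ) -> Optional[HTTPMessageDelegate]:
        if request.path.startswith("/hello"):
            # Application is itself a Router, so delegate to it.
            return app.find_handler(request, **kwargs)
        return None  # falls through to the default 404 delegate
```

Passing ``AppOnlyRouter()`` to an ``HTTPServer`` works because `Router` provides `start_request` for the `HTTPServerConnectionDelegate` interface, as shown above.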
class _DefaultMessageDelegate(httputil.HTTPMessageDelegate): - def __init__(self, connection): + def __init__(self, connection: httputil.HTTPConnection) -> None: self.connection = connection - def finish(self): + def finish(self) -> None: self.connection.write_headers( - httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"), httputil.HTTPHeaders()) + httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"), + httputil.HTTPHeaders(), + ) self.connection.finish() +# _RuleList can either contain pre-constructed Rules or a sequence of +# arguments to be passed to the Rule constructor. +_RuleList = List[ + Union[ + "Rule", + List[Any], # Can't do detailed typechecking of lists. + Tuple[Union[str, "Matcher"], Any], + Tuple[Union[str, "Matcher"], Any, Dict[str, Any]], + Tuple[Union[str, "Matcher"], Any, Dict[str, Any], str], + ] +] + + class RuleRouter(Router): """Rule-based router implementation.""" - def __init__(self, rules=None): + def __init__(self, rules: Optional[_RuleList] = None) -> None: """Constructs a router from an ordered list of rules:: RuleRouter([ @@ -299,11 +327,11 @@ def __init__(self, rules=None): :arg rules: a list of `Rule` instances or tuples of `Rule` constructor arguments. """ - self.rules = [] # type: typing.List[Rule] + self.rules = [] # type: List[Rule] if rules: self.add_rules(rules) - def add_rules(self, rules): + def add_rules(self, rules: _RuleList) -> None: """Appends new rules to the router. :arg rules: a list of Rule instances (or tuples of arguments, which are @@ -319,7 +347,7 @@ def add_rules(self, rules): self.rules.append(self.process_rule(rule)) - def process_rule(self, rule): + def process_rule(self, rule: "Rule") -> "Rule": """Override this method for additional preprocessing of each rule. :arg Rule rule: a rule to be processed. @@ -327,22 +355,27 @@ def process_rule(self, rule): """ return rule - def find_handler(self, request, **kwargs): + def find_handler( + self, request: httputil.HTTPServerRequest, **kwargs: Any + ) -> Optional[httputil.HTTPMessageDelegate]: for rule in self.rules: target_params = rule.matcher.match(request) if target_params is not None: if rule.target_kwargs: - target_params['target_kwargs'] = rule.target_kwargs + target_params["target_kwargs"] = rule.target_kwargs delegate = self.get_target_delegate( - rule.target, request, **target_params) + rule.target, request, **target_params + ) if delegate is not None: return delegate return None - def get_target_delegate(self, target, request, **target_params): + def get_target_delegate( + self, target: Any, request: httputil.HTTPServerRequest, **target_params: Any + ) -> Optional[httputil.HTTPMessageDelegate]: """Returns an instance of `~.httputil.HTTPMessageDelegate` for a Rule's target. This method is called by `~.find_handler` and can be extended to provide additional target types. @@ -356,9 +389,11 @@ def get_target_delegate(self, target, request, **target_params): return target.find_handler(request, **target_params) elif isinstance(target, httputil.HTTPServerConnectionDelegate): + assert request.connection is not None return target.start_request(request.server_connection, request.connection) elif callable(target): + assert request.connection is not None return _CallableAdapter( partial(target, **target_params), request.connection ) @@ -374,23 +409,23 @@ class ReversibleRuleRouter(ReversibleRouter, RuleRouter): in a rule's matcher (see `Matcher.reverse`). 
""" - def __init__(self, rules=None): - self.named_rules = {} # type: typing.Dict[str] - super(ReversibleRuleRouter, self).__init__(rules) + def __init__(self, rules: Optional[_RuleList] = None) -> None: + self.named_rules = {} # type: Dict[str, Any] + super().__init__(rules) - def process_rule(self, rule): - rule = super(ReversibleRuleRouter, self).process_rule(rule) + def process_rule(self, rule: "Rule") -> "Rule": + rule = super().process_rule(rule) if rule.name: if rule.name in self.named_rules: app_log.warning( - "Multiple handlers named %s; replacing previous value", - rule.name) + "Multiple handlers named %s; replacing previous value", rule.name + ) self.named_rules[rule.name] = rule return rule - def reverse_url(self, name, *args): + def reverse_url(self, name: str, *args: Any) -> Optional[str]: if name in self.named_rules: return self.named_rules[name].matcher.reverse(*args) @@ -406,7 +441,13 @@ def reverse_url(self, name, *args): class Rule(object): """A routing rule.""" - def __init__(self, matcher, target, target_kwargs=None, name=None): + def __init__( + self, + matcher: "Matcher", + target: Any, + target_kwargs: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + ) -> None: """Constructs a Rule instance. :arg Matcher matcher: a `Matcher` instance used for determining @@ -433,19 +474,23 @@ def __init__(self, matcher, target, target_kwargs=None, name=None): self.target_kwargs = target_kwargs if target_kwargs else {} self.name = name - def reverse(self, *args): + def reverse(self, *args: Any) -> Optional[str]: return self.matcher.reverse(*args) - def __repr__(self): - return '%s(%r, %s, kwargs=%r, name=%r)' % \ - (self.__class__.__name__, self.matcher, - self.target, self.target_kwargs, self.name) + def __repr__(self) -> str: + return "%s(%r, %s, kwargs=%r, name=%r)" % ( + self.__class__.__name__, + self.matcher, + self.target, + self.target_kwargs, + self.name, + ) class Matcher(object): """Represents a matcher for request features.""" - def match(self, request): + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: """Matches current instance against the request. :arg httputil.HTTPServerRequest request: current HTTP request @@ -457,7 +502,7 @@ def match(self, request): ``None`` must be returned to indicate that there is no match.""" raise NotImplementedError() - def reverse(self, *args): + def reverse(self, *args: Any) -> Optional[str]: """Reconstructs full url from matcher instance and additional arguments.""" return None @@ -465,14 +510,14 @@ def reverse(self, *args): class AnyMatches(Matcher): """Matches any request.""" - def match(self, request): + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: return {} class HostMatches(Matcher): """Matches requests from hosts specified by ``host_pattern`` regex.""" - def __init__(self, host_pattern): + def __init__(self, host_pattern: Union[str, Pattern]) -> None: if isinstance(host_pattern, basestring_type): if not host_pattern.endswith("$"): host_pattern += "$" @@ -480,7 +525,7 @@ def __init__(self, host_pattern): else: self.host_pattern = host_pattern - def match(self, request): + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: if self.host_pattern.match(request.host_name): return {} @@ -492,11 +537,11 @@ class DefaultHostMatches(Matcher): Always returns no match if ``X-Real-Ip`` header is present. 
""" - def __init__(self, application, host_pattern): + def __init__(self, application: Any, host_pattern: Pattern) -> None: self.application = application self.host_pattern = host_pattern - def match(self, request): + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: # Look for default host if not behind load balancer (for debugging) if "X-Real-Ip" not in request.headers: if self.host_pattern.match(self.application.default_host): @@ -507,28 +552,30 @@ def match(self, request): class PathMatches(Matcher): """Matches requests with paths specified by ``path_pattern`` regex.""" - def __init__(self, path_pattern): + def __init__(self, path_pattern: Union[str, Pattern]) -> None: if isinstance(path_pattern, basestring_type): - if not path_pattern.endswith('$'): - path_pattern += '$' + if not path_pattern.endswith("$"): + path_pattern += "$" self.regex = re.compile(path_pattern) else: self.regex = path_pattern - assert len(self.regex.groupindex) in (0, self.regex.groups), \ - ("groups in url regexes must either be all named or all " - "positional: %r" % self.regex.pattern) + assert len(self.regex.groupindex) in (0, self.regex.groups), ( + "groups in url regexes must either be all named or all " + "positional: %r" % self.regex.pattern + ) self._path, self._group_count = self._find_groups() - def match(self, request): + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: match = self.regex.match(request.path) if match is None: return None if not self.regex.groups: return {} - path_args, path_kwargs = [], {} + path_args = [] # type: List[bytes] + path_kwargs = {} # type: Dict[str, bytes] # Pass matched groups to the handler. Since # match.groups() includes both named and @@ -536,18 +583,19 @@ def match(self, request): # or groupdict but not both. if self.regex.groupindex: path_kwargs = dict( - (str(k), _unquote_or_none(v)) - for (k, v) in match.groupdict().items()) + (str(k), _unquote_or_none(v)) for (k, v) in match.groupdict().items() + ) else: path_args = [_unquote_or_none(s) for s in match.groups()] return dict(path_args=path_args, path_kwargs=path_kwargs) - def reverse(self, *args): + def reverse(self, *args: Any) -> Optional[str]: if self._path is None: raise ValueError("Cannot reverse url regex " + self.regex.pattern) - assert len(args) == self._group_count, "required number of arguments " \ - "not found" + assert len(args) == self._group_count, ( + "required number of arguments " "not found" + ) if not len(args): return self._path converted_args = [] @@ -557,29 +605,35 @@ def reverse(self, *args): converted_args.append(url_escape(utf8(a), plus=False)) return self._path % tuple(converted_args) - def _find_groups(self): + def _find_groups(self) -> Tuple[Optional[str], Optional[int]]: """Returns a tuple (reverse string, group count) for a url. For example: Given the url pattern /([0-9]{4})/([a-z-]+)/, this method would return ('/%s/%s/', 2). """ pattern = self.regex.pattern - if pattern.startswith('^'): + if pattern.startswith("^"): pattern = pattern[1:] - if pattern.endswith('$'): + if pattern.endswith("$"): pattern = pattern[:-1] - if self.regex.groups != pattern.count('('): + if self.regex.groups != pattern.count("("): # The pattern is too complicated for our simplistic matching, # so we can't support reversing it. 
return None, None pieces = [] - for fragment in pattern.split('('): - if ')' in fragment: - paren_loc = fragment.index(')') + for fragment in pattern.split("("): + if ")" in fragment: + paren_loc = fragment.index(")") if paren_loc >= 0: - pieces.append('%s' + fragment[paren_loc + 1:]) + try: + unescaped_fragment = re_unescape(fragment[paren_loc + 1 :]) + except ValueError: + # If we can't unescape part of it, we can't + # reverse this url. + return (None, None) + pieces.append("%s" + unescaped_fragment) else: try: unescaped_fragment = re_unescape(fragment) @@ -589,7 +643,7 @@ def _find_groups(self): return (None, None) pieces.append(unescaped_fragment) - return ''.join(pieces), self.regex.groups + return "".join(pieces), self.regex.groups class URLSpec(Rule): @@ -599,7 +653,14 @@ class URLSpec(Rule): `URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for backwards compatibility. """ - def __init__(self, pattern, handler, kwargs=None, name=None): + + def __init__( + self, + pattern: Union[str, Pattern], + handler: Any, + kwargs: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + ) -> None: """Parameters: * ``pattern``: Regular expression to be matched. Any capturing @@ -617,19 +678,34 @@ def __init__(self, pattern, handler, kwargs=None, name=None): `~.web.Application.reverse_url`. """ - super(URLSpec, self).__init__(PathMatches(pattern), handler, kwargs, name) + matcher = PathMatches(pattern) + super().__init__(matcher, handler, kwargs, name) - self.regex = self.matcher.regex + self.regex = matcher.regex self.handler_class = self.target self.kwargs = kwargs - def __repr__(self): - return '%s(%r, %s, kwargs=%r, name=%r)' % \ - (self.__class__.__name__, self.regex.pattern, - self.handler_class, self.kwargs, self.name) + def __repr__(self) -> str: + return "%s(%r, %s, kwargs=%r, name=%r)" % ( + self.__class__.__name__, + self.regex.pattern, + self.handler_class, + self.kwargs, + self.name, + ) + + +@overload +def _unquote_or_none(s: str) -> bytes: + pass + + +@overload # noqa: F811 +def _unquote_or_none(s: None) -> None: + pass -def _unquote_or_none(s): +def _unquote_or_none(s: Optional[str]) -> Optional[bytes]: # noqa: F811 """None-safe wrapper around url_unescape to handle unmatched optional groups correctly. 
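The unescaping change above is easiest to check through `PathMatches.reverse`. A small sketch (the pattern is made up): escaped literals on *both* sides of a capture group now pass through `re_unescape`, where previously only the text before the first ``(`` was unescaped:

```py
from tornado.routing import PathMatches

# '\.' appears before and after the group; both must unescape to '.'.
matcher = PathMatches(r"/api\.v1/([0-9]+)/edit\.html$")
print(matcher.reverse(42))  # /api.v1/42/edit.html (not .../edit\.html)
```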
diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index 7696dd1849..1d9d14ba17 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -1,17 +1,25 @@ -from __future__ import absolute_import, division, print_function - -from tornado.escape import utf8, _unicode -from tornado import gen -from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy +from tornado.escape import _unicode +from tornado import gen, version +from tornado.httpclient import ( + HTTPResponse, + HTTPError, + AsyncHTTPClient, + main, + _RequestProxy, + HTTPRequest, +) from tornado import httputil from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters from tornado.ioloop import IOLoop -from tornado.iostream import StreamClosedError -from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults +from tornado.iostream import StreamClosedError, IOStream +from tornado.netutil import ( + Resolver, + OverrideResolver, + _client_ssl_defaults, + is_valid_ip, +) from tornado.log import gen_log -from tornado import stack_context from tornado.tcpclient import TCPClient -from tornado.util import PY3 import base64 import collections @@ -19,20 +27,18 @@ import functools import re import socket +import ssl import sys +import time from io import BytesIO +import urllib.parse +from typing import Dict, Any, Callable, Optional, Type, Union +from types import TracebackType +import typing -if PY3: - import urllib.parse as urlparse -else: - import urlparse - -try: - import ssl -except ImportError: - # ssl is not available on Google App Engine. - ssl = None +if typing.TYPE_CHECKING: + from typing import Deque, Tuple, List # noqa: F401 class HTTPTimeoutError(HTTPError): @@ -43,11 +49,12 @@ class HTTPTimeoutError(HTTPError): .. versionadded:: 5.1 """ - def __init__(self, message): - super(HTTPTimeoutError, self).__init__(599, message=message) - def __str__(self): - return self.message + def __init__(self, message: str) -> None: + super().__init__(599, message=message) + + def __str__(self) -> str: + return self.message or "Timeout" class HTTPStreamClosedError(HTTPError): @@ -61,11 +68,12 @@ class HTTPStreamClosedError(HTTPError): .. versionadded:: 5.1 """ - def __init__(self, message): - super(HTTPStreamClosedError, self).__init__(599, message=message) - def __str__(self): - return self.message + def __init__(self, message: str) -> None: + super().__init__(599, message=message) + + def __str__(self) -> str: + return self.message or "Stream closed" class SimpleAsyncHTTPClient(AsyncHTTPClient): @@ -77,10 +85,17 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): are not reused, and callers cannot select the network interface to be used. """ - def initialize(self, max_clients=10, - hostname_mapping=None, max_buffer_size=104857600, - resolver=None, defaults=None, max_header_size=None, - max_body_size=None): + + def initialize( # type: ignore + self, + max_clients: int = 10, + hostname_mapping: Optional[Dict[str, str]] = None, + max_buffer_size: int = 104857600, + resolver: Optional[Resolver] = None, + defaults: Optional[Dict[str, Any]] = None, + max_header_size: Optional[int] = None, + max_body_size: Optional[int] = None, + ) -> None: """Creates a AsyncHTTPClient. Only a single AsyncHTTPClient instance exists per IOLoop @@ -113,11 +128,17 @@ def initialize(self, max_clients=10, .. versionchanged:: 4.2 Added the ``max_body_size`` argument. 
""" - super(SimpleAsyncHTTPClient, self).initialize(defaults=defaults) + super().initialize(defaults=defaults) self.max_clients = max_clients - self.queue = collections.deque() - self.active = {} - self.waiting = {} + self.queue = ( + collections.deque() + ) # type: Deque[Tuple[object, HTTPRequest, Callable[[HTTPResponse], None]]] + self.active = ( + {} + ) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None]]] + self.waiting = ( + {} + ) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None], object]] self.max_buffer_size = max_buffer_size self.max_header_size = max_header_size self.max_body_size = max_body_size @@ -130,65 +151,86 @@ def initialize(self, max_clients=10, self.resolver = Resolver() self.own_resolver = True if hostname_mapping is not None: - self.resolver = OverrideResolver(resolver=self.resolver, - mapping=hostname_mapping) + self.resolver = OverrideResolver( + resolver=self.resolver, mapping=hostname_mapping + ) self.tcp_client = TCPClient(resolver=self.resolver) - def close(self): - super(SimpleAsyncHTTPClient, self).close() + def close(self) -> None: + super().close() if self.own_resolver: self.resolver.close() self.tcp_client.close() - def fetch_impl(self, request, callback): + def fetch_impl( + self, request: HTTPRequest, callback: Callable[[HTTPResponse], None] + ) -> None: key = object() self.queue.append((key, request, callback)) - if not len(self.active) < self.max_clients: - timeout_handle = self.io_loop.add_timeout( - self.io_loop.time() + min(request.connect_timeout, - request.request_timeout), - functools.partial(self._on_timeout, key, "in request queue")) - else: - timeout_handle = None + assert request.connect_timeout is not None + assert request.request_timeout is not None + timeout_handle = None + if len(self.active) >= self.max_clients: + timeout = ( + min(request.connect_timeout, request.request_timeout) + or request.connect_timeout + or request.request_timeout + ) # min but skip zero + if timeout: + timeout_handle = self.io_loop.add_timeout( + self.io_loop.time() + timeout, + functools.partial(self._on_timeout, key, "in request queue"), + ) self.waiting[key] = (request, callback, timeout_handle) self._process_queue() if self.queue: - gen_log.debug("max_clients limit reached, request queued. " - "%d active, %d queued requests." % ( - len(self.active), len(self.queue))) - - def _process_queue(self): - with stack_context.NullContext(): - while self.queue and len(self.active) < self.max_clients: - key, request, callback = self.queue.popleft() - if key not in self.waiting: - continue - self._remove_timeout(key) - self.active[key] = (request, callback) - release_callback = functools.partial(self._release_fetch, key) - self._handle_request(request, release_callback, callback) - - def _connection_class(self): + gen_log.debug( + "max_clients limit reached, request queued. " + "%d active, %d queued requests." 
% (len(self.active), len(self.queue)) + ) + + def _process_queue(self) -> None: + while self.queue and len(self.active) < self.max_clients: + key, request, callback = self.queue.popleft() + if key not in self.waiting: + continue + self._remove_timeout(key) + self.active[key] = (request, callback) + release_callback = functools.partial(self._release_fetch, key) + self._handle_request(request, release_callback, callback) + + def _connection_class(self) -> type: return _HTTPConnection - def _handle_request(self, request, release_callback, final_callback): + def _handle_request( + self, + request: HTTPRequest, + release_callback: Callable[[], None], + final_callback: Callable[[HTTPResponse], None], + ) -> None: self._connection_class()( - self, request, release_callback, - final_callback, self.max_buffer_size, self.tcp_client, - self.max_header_size, self.max_body_size) - - def _release_fetch(self, key): + self, + request, + release_callback, + final_callback, + self.max_buffer_size, + self.tcp_client, + self.max_header_size, + self.max_body_size, + ) + + def _release_fetch(self, key: object) -> None: del self.active[key] self._process_queue() - def _remove_timeout(self, key): + def _remove_timeout(self, key: object) -> None: if key in self.waiting: request, callback, timeout_handle = self.waiting[key] if timeout_handle is not None: self.io_loop.remove_timeout(timeout_handle) del self.waiting[key] - def _on_timeout(self, key, info=None): + def _on_timeout(self, key: object, info: Optional[str] = None) -> None: """Timeout callback of request. Construct a timeout HTTPResponse when a timeout occurs. @@ -201,20 +243,34 @@ def _on_timeout(self, key, info=None): error_message = "Timeout {0}".format(info) if info else "Timeout" timeout_response = HTTPResponse( - request, 599, error=HTTPTimeoutError(error_message), - request_time=self.io_loop.time() - request.start_time) + request, + 599, + error=HTTPTimeoutError(error_message), + request_time=self.io_loop.time() - request.start_time, + ) self.io_loop.add_callback(callback, timeout_response) del self.waiting[key] class _HTTPConnection(httputil.HTTPMessageDelegate): - _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) - - def __init__(self, client, request, release_callback, - final_callback, max_buffer_size, tcp_client, - max_header_size, max_body_size): + _SUPPORTED_METHODS = set( + ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"] + ) + + def __init__( + self, + client: Optional[SimpleAsyncHTTPClient], + request: HTTPRequest, + release_callback: Callable[[], None], + final_callback: Callable[[HTTPResponse], None], + max_buffer_size: int, + tcp_client: TCPClient, + max_header_size: int, + max_body_size: int, + ) -> None: self.io_loop = IOLoop.current() self.start_time = self.io_loop.time() + self.start_wall_time = time.time() self.client = client self.request = request self.release_callback = release_callback @@ -223,18 +279,22 @@ def __init__(self, client, request, release_callback, self.tcp_client = tcp_client self.max_header_size = max_header_size self.max_body_size = max_body_size - self.code = None - self.headers = None - self.chunks = [] + self.code = None # type: Optional[int] + self.headers = None # type: Optional[httputil.HTTPHeaders] + self.chunks = [] # type: List[bytes] self._decompressor = None # Timeout handle returned by IOLoop.add_timeout - self._timeout = None + self._timeout = None # type: object self._sockaddr = None - with stack_context.ExceptionStackContext(self._handle_exception): - 
self.parsed = urlparse.urlsplit(_unicode(self.request.url)) + IOLoop.current().add_future( + gen.convert_yielded(self.run()), lambda f: f.result() + ) + + async def run(self) -> None: + try: + self.parsed = urllib.parse.urlsplit(_unicode(self.request.url)) if self.parsed.scheme not in ("http", "https"): - raise ValueError("Unsupported url scheme: %s" % - self.request.url) + raise ValueError("Unsupported url scheme: %s" % self.request.url) # urlsplit results have hostname and port results, but they # didn't support ipv6 literals until python 2.7. netloc = self.parsed.netloc @@ -243,55 +303,186 @@ def __init__(self, client, request, release_callback, host, port = httputil.split_host_and_port(netloc) if port is None: port = 443 if self.parsed.scheme == "https" else 80 - if re.match(r'^\[.*\]$', host): + if re.match(r"^\[.*\]$", host): # raw ipv6 addresses in urls are enclosed in brackets host = host[1:-1] self.parsed_hostname = host # save final host for _on_connect - if request.allow_ipv6 is False: + if self.request.allow_ipv6 is False: af = socket.AF_INET else: af = socket.AF_UNSPEC ssl_options = self._get_ssl_options(self.parsed.scheme) - timeout = min(self.request.connect_timeout, self.request.request_timeout) + source_ip = None + if self.request.network_interface: + if is_valid_ip(self.request.network_interface): + source_ip = self.request.network_interface + else: + raise ValueError( + "Unrecognized IPv4 or IPv6 address for network_interface, got %r" + % (self.request.network_interface,) + ) + + if self.request.connect_timeout and self.request.request_timeout: + timeout = min( + self.request.connect_timeout, self.request.request_timeout + ) + elif self.request.connect_timeout: + timeout = self.request.connect_timeout + elif self.request.request_timeout: + timeout = self.request.request_timeout + else: + timeout = 0 if timeout: self._timeout = self.io_loop.add_timeout( self.start_time + timeout, - stack_context.wrap(functools.partial(self._on_timeout, "while connecting"))) - fut = self.tcp_client.connect(host, port, af=af, - ssl_options=ssl_options, - max_buffer_size=self.max_buffer_size) - fut.add_done_callback(stack_context.wrap(self._on_connect)) - - def _get_ssl_options(self, scheme): + functools.partial(self._on_timeout, "while connecting"), + ) + stream = await self.tcp_client.connect( + host, + port, + af=af, + ssl_options=ssl_options, + max_buffer_size=self.max_buffer_size, + source_ip=source_ip, + ) + + if self.final_callback is None: + # final_callback is cleared if we've hit our timeout. 
+ stream.close() + return + self.stream = stream + self.stream.set_close_callback(self.on_connection_close) + self._remove_timeout() + if self.final_callback is None: + return + if self.request.request_timeout: + self._timeout = self.io_loop.add_timeout( + self.start_time + self.request.request_timeout, + functools.partial(self._on_timeout, "during request"), + ) + if ( + self.request.method not in self._SUPPORTED_METHODS + and not self.request.allow_nonstandard_methods + ): + raise KeyError("unknown method %s" % self.request.method) + for key in ( + "proxy_host", + "proxy_port", + "proxy_username", + "proxy_password", + "proxy_auth_mode", + ): + if getattr(self.request, key, None): + raise NotImplementedError("%s not supported" % key) + if "Connection" not in self.request.headers: + self.request.headers["Connection"] = "close" + if "Host" not in self.request.headers: + if "@" in self.parsed.netloc: + self.request.headers["Host"] = self.parsed.netloc.rpartition("@")[ + -1 + ] + else: + self.request.headers["Host"] = self.parsed.netloc + username, password = None, None + if self.parsed.username is not None: + username, password = self.parsed.username, self.parsed.password + elif self.request.auth_username is not None: + username = self.request.auth_username + password = self.request.auth_password or "" + if username is not None: + assert password is not None + if self.request.auth_mode not in (None, "basic"): + raise ValueError("unsupported auth_mode %s", self.request.auth_mode) + self.request.headers["Authorization"] = "Basic " + _unicode( + base64.b64encode( + httputil.encode_username_password(username, password) + ) + ) + if self.request.user_agent: + self.request.headers["User-Agent"] = self.request.user_agent + elif self.request.headers.get("User-Agent") is None: + self.request.headers["User-Agent"] = "Tornado/{}".format(version) + if not self.request.allow_nonstandard_methods: + # Some HTTP methods nearly always have bodies while others + # almost never do. Fail in this case unless the user has + # opted out of sanity checks with allow_nonstandard_methods. + body_expected = self.request.method in ("POST", "PATCH", "PUT") + body_present = ( + self.request.body is not None + or self.request.body_producer is not None + ) + if (body_expected and not body_present) or ( + body_present and not body_expected + ): + raise ValueError( + "Body must %sbe None for method %s (unless " + "allow_nonstandard_methods is true)" + % ("not " if body_expected else "", self.request.method) + ) + if self.request.expect_100_continue: + self.request.headers["Expect"] = "100-continue" + if self.request.body is not None: + # When body_producer is used the caller is responsible for + # setting Content-Length (or else chunked encoding will be used). + self.request.headers["Content-Length"] = str(len(self.request.body)) + if ( + self.request.method == "POST" + and "Content-Type" not in self.request.headers + ): + self.request.headers[ + "Content-Type" + ] = "application/x-www-form-urlencoded" + if self.request.decompress_response: + self.request.headers["Accept-Encoding"] = "gzip" + req_path = (self.parsed.path or "/") + ( + ("?" 
+ self.parsed.query) if self.parsed.query else "" + ) + self.connection = self._create_connection(stream) + start_line = httputil.RequestStartLine(self.request.method, req_path, "") + self.connection.write_headers(start_line, self.request.headers) + if self.request.expect_100_continue: + await self.connection.read_response(self) + else: + await self._write_body(True) + except Exception: + if not self._handle_exception(*sys.exc_info()): + raise + + def _get_ssl_options( + self, scheme: str + ) -> Union[None, Dict[str, Any], ssl.SSLContext]: if scheme == "https": if self.request.ssl_options is not None: return self.request.ssl_options # If we are using the defaults, don't construct a # new SSLContext. - if (self.request.validate_cert and - self.request.ca_certs is None and - self.request.client_cert is None and - self.request.client_key is None): + if ( + self.request.validate_cert + and self.request.ca_certs is None + and self.request.client_cert is None + and self.request.client_key is None + ): return _client_ssl_defaults ssl_ctx = ssl.create_default_context( - ssl.Purpose.SERVER_AUTH, - cafile=self.request.ca_certs) + ssl.Purpose.SERVER_AUTH, cafile=self.request.ca_certs + ) if not self.request.validate_cert: ssl_ctx.check_hostname = False ssl_ctx.verify_mode = ssl.CERT_NONE if self.request.client_cert is not None: - ssl_ctx.load_cert_chain(self.request.client_cert, - self.request.client_key) - if hasattr(ssl, 'OP_NO_COMPRESSION'): + ssl_ctx.load_cert_chain( + self.request.client_cert, self.request.client_key + ) + if hasattr(ssl, "OP_NO_COMPRESSION"): # See netutil.ssl_options_to_context ssl_ctx.options |= ssl.OP_NO_COMPRESSION return ssl_ctx return None - def _on_timeout(self, info=None): + def _on_timeout(self, info: Optional[str] = None) -> None: """Timeout callback of _HTTPConnection instance. Raise a `HTTPTimeoutError` when a timeout occurs. @@ -301,147 +492,64 @@ def _on_timeout(self, info=None): self._timeout = None error_message = "Timeout {0}".format(info) if info else "Timeout" if self.final_callback is not None: - raise HTTPTimeoutError(error_message) + self._handle_exception( + HTTPTimeoutError, HTTPTimeoutError(error_message), None + ) - def _remove_timeout(self): + def _remove_timeout(self) -> None: if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = None - def _on_connect(self, stream_fut): - stream = stream_fut.result() - if self.final_callback is None: - # final_callback is cleared if we've hit our timeout. 
- stream.close() - return - self.stream = stream - self.stream.set_close_callback(self.on_connection_close) - self._remove_timeout() - if self.final_callback is None: - return - if self.request.request_timeout: - self._timeout = self.io_loop.add_timeout( - self.start_time + self.request.request_timeout, - stack_context.wrap(functools.partial(self._on_timeout, "during request"))) - if (self.request.method not in self._SUPPORTED_METHODS and - not self.request.allow_nonstandard_methods): - raise KeyError("unknown method %s" % self.request.method) - for key in ('network_interface', - 'proxy_host', 'proxy_port', - 'proxy_username', 'proxy_password', - 'proxy_auth_mode'): - if getattr(self.request, key, None): - raise NotImplementedError('%s not supported' % key) - if "Connection" not in self.request.headers: - self.request.headers["Connection"] = "close" - if "Host" not in self.request.headers: - if '@' in self.parsed.netloc: - self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1] - else: - self.request.headers["Host"] = self.parsed.netloc - username, password = None, None - if self.parsed.username is not None: - username, password = self.parsed.username, self.parsed.password - elif self.request.auth_username is not None: - username = self.request.auth_username - password = self.request.auth_password or '' - if username is not None: - if self.request.auth_mode not in (None, "basic"): - raise ValueError("unsupported auth_mode %s", - self.request.auth_mode) - auth = utf8(username) + b":" + utf8(password) - self.request.headers["Authorization"] = (b"Basic " + - base64.b64encode(auth)) - if self.request.user_agent: - self.request.headers["User-Agent"] = self.request.user_agent - if not self.request.allow_nonstandard_methods: - # Some HTTP methods nearly always have bodies while others - # almost never do. Fail in this case unless the user has - # opted out of sanity checks with allow_nonstandard_methods. - body_expected = self.request.method in ("POST", "PATCH", "PUT") - body_present = (self.request.body is not None or - self.request.body_producer is not None) - if ((body_expected and not body_present) or - (body_present and not body_expected)): - raise ValueError( - 'Body must %sbe None for method %s (unless ' - 'allow_nonstandard_methods is true)' % - ('not ' if body_expected else '', self.request.method)) - if self.request.expect_100_continue: - self.request.headers["Expect"] = "100-continue" - if self.request.body is not None: - # When body_producer is used the caller is responsible for - # setting Content-Length (or else chunked encoding will be used). - self.request.headers["Content-Length"] = str(len( - self.request.body)) - if (self.request.method == "POST" and - "Content-Type" not in self.request.headers): - self.request.headers["Content-Type"] = "application/x-www-form-urlencoded" - if self.request.decompress_response: - self.request.headers["Accept-Encoding"] = "gzip" - req_path = ((self.parsed.path or '/') + - (('?' 
+ self.parsed.query) if self.parsed.query else '')) - self.connection = self._create_connection(stream) - start_line = httputil.RequestStartLine(self.request.method, - req_path, '') - self.connection.write_headers(start_line, self.request.headers) - if self.request.expect_100_continue: - self._read_response() - else: - self._write_body(True) - - def _create_connection(self, stream): + def _create_connection(self, stream: IOStream) -> HTTP1Connection: stream.set_nodelay(True) connection = HTTP1Connection( - stream, True, + stream, + True, HTTP1ConnectionParameters( no_keep_alive=True, max_header_size=self.max_header_size, max_body_size=self.max_body_size, - decompress=self.request.decompress_response), - self._sockaddr) + decompress=bool(self.request.decompress_response), + ), + self._sockaddr, + ) return connection - def _write_body(self, start_read): + async def _write_body(self, start_read: bool) -> None: if self.request.body is not None: self.connection.write(self.request.body) elif self.request.body_producer is not None: fut = self.request.body_producer(self.connection.write) if fut is not None: - fut = gen.convert_yielded(fut) - - def on_body_written(fut): - fut.result() - self.connection.finish() - if start_read: - self._read_response() - self.io_loop.add_future(fut, on_body_written) - return + await fut self.connection.finish() if start_read: - self._read_response() - - def _read_response(self): - # Ensure that any exception raised in read_response ends up in our - # stack context. - self.io_loop.add_future( - self.connection.read_response(self), - lambda f: f.result()) + try: + await self.connection.read_response(self) + except StreamClosedError: + if not self._handle_exception(*sys.exc_info()): + raise - def _release(self): + def _release(self) -> None: if self.release_callback is not None: release_callback = self.release_callback - self.release_callback = None + self.release_callback = None # type: ignore release_callback() - def _run_callback(self, response): + def _run_callback(self, response: HTTPResponse) -> None: self._release() if self.final_callback is not None: final_callback = self.final_callback - self.final_callback = None + self.final_callback = None # type: ignore self.io_loop.add_callback(final_callback, response) - def _handle_exception(self, typ, value, tb): + def _handle_exception( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> bool: if self.final_callback: self._remove_timeout() if isinstance(value, StreamClosedError): @@ -449,9 +557,15 @@ def _handle_exception(self, typ, value, tb): value = HTTPStreamClosedError("Stream closed") else: value = value.real_error - self._run_callback(HTTPResponse(self.request, 599, error=value, - request_time=self.io_loop.time() - self.start_time, - )) + self._run_callback( + HTTPResponse( + self.request, + 599, + error=value, + request_time=self.io_loop.time() - self.start_time, + start_time=self.start_wall_time, + ) + ) if hasattr(self, "stream"): # TODO: this may cause a StreamClosedError to be raised @@ -466,7 +580,7 @@ def _handle_exception(self, typ, value, tb): # pass it along, unless it's just the stream being closed. 
return isinstance(value, StreamClosedError) - def on_connection_close(self): + def on_connection_close(self) -> None: if self.final_callback is not None: message = "Connection closed" if self.stream.error: @@ -476,9 +590,14 @@ def on_connection_close(self): except HTTPStreamClosedError: self._handle_exception(*sys.exc_info()) - def headers_received(self, first_line, headers): + async def headers_received( + self, + first_line: Union[httputil.ResponseStartLine, httputil.RequestStartLine], + headers: httputil.HTTPHeaders, + ) -> None: + assert isinstance(first_line, httputil.ResponseStartLine) if self.request.expect_100_continue and first_line.code == 100: - self._write_body(False) + await self._write_body(False) return self.code = first_line.code self.reason = first_line.reason @@ -489,48 +608,66 @@ def headers_received(self, first_line, headers): if self.request.header_callback is not None: # Reassemble the start line. - self.request.header_callback('%s %s %s\r\n' % first_line) + self.request.header_callback("%s %s %s\r\n" % first_line) for k, v in self.headers.get_all(): self.request.header_callback("%s: %s\r\n" % (k, v)) - self.request.header_callback('\r\n') - - def _should_follow_redirect(self): - return (self.request.follow_redirects and - self.request.max_redirects > 0 and - self.code in (301, 302, 303, 307, 308)) - - def finish(self): - data = b''.join(self.chunks) + self.request.header_callback("\r\n") + + def _should_follow_redirect(self) -> bool: + if self.request.follow_redirects: + assert self.request.max_redirects is not None + return ( + self.code in (301, 302, 303, 307, 308) + and self.request.max_redirects > 0 + and self.headers is not None + and self.headers.get("Location") is not None + ) + return False + + def finish(self) -> None: + assert self.code is not None + data = b"".join(self.chunks) self._remove_timeout() - original_request = getattr(self.request, "original_request", - self.request) + original_request = getattr(self.request, "original_request", self.request) if self._should_follow_redirect(): assert isinstance(self.request, _RequestProxy) + assert self.headers is not None new_request = copy.copy(self.request.request) - new_request.url = urlparse.urljoin(self.request.url, - self.headers["Location"]) + new_request.url = urllib.parse.urljoin( + self.request.url, self.headers["Location"] + ) + assert self.request.max_redirects is not None new_request.max_redirects = self.request.max_redirects - 1 del new_request.headers["Host"] - # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4 - # Client SHOULD make a GET request after a 303. - # According to the spec, 302 should be followed by the same - # method as the original request, but in practice browsers - # treat 302 the same as 303, and many servers use 302 for - # compatibility with pre-HTTP/1.1 user agents which don't - # understand the 303 status. - if self.code in (302, 303): + # https://tools.ietf.org/html/rfc7231#section-6.4 + # + # The original HTTP spec said that after a 301 or 302 + # redirect, the request method should be preserved. + # However, browsers implemented this by changing the + # method to GET, and the behavior stuck. 303 redirects + # always specified this POST-to-GET behavior, arguably + # for *all* methods, but libcurl < 7.70 only does this + # for POST, while libcurl >= 7.70 does it for other methods. 
+ if (self.code == 303 and self.request.method != "HEAD") or ( + self.code in (301, 302) and self.request.method == "POST" + ): new_request.method = "GET" - new_request.body = None - for h in ["Content-Length", "Content-Type", - "Content-Encoding", "Transfer-Encoding"]: + new_request.body = None # type: ignore + for h in [ + "Content-Length", + "Content-Type", + "Content-Encoding", + "Transfer-Encoding", + ]: try: del self.request.headers[h] except KeyError: pass - new_request.original_request = original_request + new_request.original_request = original_request # type: ignore final_callback = self.final_callback - self.final_callback = None + self.final_callback = None # type: ignore self._release() + assert self.client is not None fut = self.client.fetch(new_request, raise_error=False) fut.add_done_callback(lambda f: final_callback(f.result())) self._on_end_request() @@ -539,19 +676,23 @@ def finish(self): buffer = BytesIO() else: buffer = BytesIO(data) # TODO: don't require one big string? - response = HTTPResponse(original_request, - self.code, reason=getattr(self, 'reason', None), - headers=self.headers, - request_time=self.io_loop.time() - self.start_time, - buffer=buffer, - effective_url=self.request.url) + response = HTTPResponse( + original_request, + self.code, + reason=getattr(self, "reason", None), + headers=self.headers, + request_time=self.io_loop.time() - self.start_time, + start_time=self.start_wall_time, + buffer=buffer, + effective_url=self.request.url, + ) self._run_callback(response) self._on_end_request() - def _on_end_request(self): + def _on_end_request(self) -> None: self.stream.close() - def data_received(self, chunk): + def data_received(self, chunk: bytes) -> None: if self._should_follow_redirect(): # We're going to follow a redirect so just discard the body. return diff --git a/tornado/speedups.c b/tornado/speedups.c index b714268ab4..525d66034c 100644 --- a/tornado/speedups.c +++ b/tornado/speedups.c @@ -56,7 +56,6 @@ static PyMethodDef methods[] = { {NULL, NULL, 0, NULL} }; -#if PY_MAJOR_VERSION >= 3 static struct PyModuleDef speedupsmodule = { PyModuleDef_HEAD_INIT, "speedups", @@ -69,9 +68,3 @@ PyMODINIT_FUNC PyInit_speedups(void) { return PyModule_Create(&speedupsmodule); } -#else // Python 2.x -PyMODINIT_FUNC -initspeedups(void) { - Py_InitModule("tornado.speedups", methods); -} -#endif diff --git a/tornado/stack_context.py b/tornado/stack_context.py deleted file mode 100644 index 2f26f3845f..0000000000 --- a/tornado/stack_context.py +++ /dev/null @@ -1,389 +0,0 @@ -# -# Copyright 2010 Facebook -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -"""`StackContext` allows applications to maintain threadlocal-like state -that follows execution as it moves to other execution contexts. - -The motivating examples are to eliminate the need for explicit -``async_callback`` wrappers (as in `tornado.web.RequestHandler`), and to -allow some additional context to be kept for logging. 
- -This is slightly magic, but it's an extension of the idea that an -exception handler is a kind of stack-local state and when that stack -is suspended and resumed in a new context that state needs to be -preserved. `StackContext` shifts the burden of restoring that state -from each call site (e.g. wrapping each `.AsyncHTTPClient` callback -in ``async_callback``) to the mechanisms that transfer control from -one context to another (e.g. `.AsyncHTTPClient` itself, `.IOLoop`, -thread pools, etc). - -Example usage:: - - @contextlib.contextmanager - def die_on_error(): - try: - yield - except Exception: - logging.error("exception in asynchronous operation",exc_info=True) - sys.exit(1) - - with StackContext(die_on_error): - # Any exception thrown here *or in callback and its descendants* - # will cause the process to exit instead of spinning endlessly - # in the ioloop. - http_client.fetch(url, callback) - ioloop.start() - -Most applications shouldn't have to work with `StackContext` directly. -Here are a few rules of thumb for when it's necessary: - -* If you're writing an asynchronous library that doesn't rely on a - stack_context-aware library like `tornado.ioloop` or `tornado.iostream` - (for example, if you're writing a thread pool), use - `.stack_context.wrap()` before any asynchronous operations to capture the - stack context from where the operation was started. - -* If you're writing an asynchronous library that has some shared - resources (such as a connection pool), create those shared resources - within a ``with stack_context.NullContext():`` block. This will prevent - ``StackContexts`` from leaking from one request to another. - -* If you want to write something like an exception handler that will - persist across asynchronous calls, create a new `StackContext` (or - `ExceptionStackContext`), and make your asynchronous calls in a ``with`` - block that references your `StackContext`. -""" - -from __future__ import absolute_import, division, print_function - -import sys -import threading - -from tornado.util import raise_exc_info - - -class StackContextInconsistentError(Exception): - pass - - -class _State(threading.local): - def __init__(self): - self.contexts = (tuple(), None) - - -_state = _State() - - -class StackContext(object): - """Establishes the given context as a StackContext that will be transferred. - - Note that the parameter is a callable that returns a context - manager, not the context itself. That is, where for a - non-transferable context manager you would say:: - - with my_context(): - - StackContext takes the function itself rather than its result:: - - with StackContext(my_context): - - The result of ``with StackContext() as cb:`` is a deactivation - callback. Run this callback when the StackContext is no longer - needed to ensure that it is not propagated any further (note that - deactivating a context does not affect any instances of that - context that are currently pending). This is an advanced feature - and not necessary in most applications. - """ - def __init__(self, context_factory): - self.context_factory = context_factory - self.contexts = [] - self.active = True - - def _deactivate(self): - self.active = False - - # StackContext protocol - def enter(self): - context = self.context_factory() - self.contexts.append(context) - context.__enter__() - - def exit(self, type, value, traceback): - context = self.contexts.pop() - context.__exit__(type, value, traceback) - - # Note that some of this code is duplicated in ExceptionStackContext - # below. 
ExceptionStackContext is more common and doesn't need - # the full generality of this class. - def __enter__(self): - self.old_contexts = _state.contexts - self.new_contexts = (self.old_contexts[0] + (self,), self) - _state.contexts = self.new_contexts - - try: - self.enter() - except: - _state.contexts = self.old_contexts - raise - - return self._deactivate - - def __exit__(self, type, value, traceback): - try: - self.exit(type, value, traceback) - finally: - final_contexts = _state.contexts - _state.contexts = self.old_contexts - - # Generator coroutines and with-statements with non-local - # effects interact badly. Check here for signs of - # the stack getting out of sync. - # Note that this check comes after restoring _state.context - # so that if it fails things are left in a (relatively) - # consistent state. - if final_contexts is not self.new_contexts: - raise StackContextInconsistentError( - 'stack_context inconsistency (may be caused by yield ' - 'within a "with StackContext" block)') - - # Break up a reference to itself to allow for faster GC on CPython. - self.new_contexts = None - - -class ExceptionStackContext(object): - """Specialization of StackContext for exception handling. - - The supplied ``exception_handler`` function will be called in the - event of an uncaught exception in this context. The semantics are - similar to a try/finally clause, and intended use cases are to log - an error, close a socket, or similar cleanup actions. The - ``exc_info`` triple ``(type, value, traceback)`` will be passed to the - exception_handler function. - - If the exception handler returns true, the exception will be - consumed and will not be propagated to other exception handlers. - """ - def __init__(self, exception_handler): - self.exception_handler = exception_handler - self.active = True - - def _deactivate(self): - self.active = False - - def exit(self, type, value, traceback): - if type is not None: - return self.exception_handler(type, value, traceback) - - def __enter__(self): - self.old_contexts = _state.contexts - self.new_contexts = (self.old_contexts[0], self) - _state.contexts = self.new_contexts - - return self._deactivate - - def __exit__(self, type, value, traceback): - try: - if type is not None: - return self.exception_handler(type, value, traceback) - finally: - final_contexts = _state.contexts - _state.contexts = self.old_contexts - - if final_contexts is not self.new_contexts: - raise StackContextInconsistentError( - 'stack_context inconsistency (may be caused by yield ' - 'within a "with StackContext" block)') - - # Break up a reference to itself to allow for faster GC on CPython. - self.new_contexts = None - - -class NullContext(object): - """Resets the `StackContext`. - - Useful when creating a shared resource on demand (e.g. an - `.AsyncHTTPClient`) where the stack that caused the creating is - not relevant to future operations. 
- """ - def __enter__(self): - self.old_contexts = _state.contexts - _state.contexts = (tuple(), None) - - def __exit__(self, type, value, traceback): - _state.contexts = self.old_contexts - - -def _remove_deactivated(contexts): - """Remove deactivated handlers from the chain""" - # Clean ctx handlers - stack_contexts = tuple([h for h in contexts[0] if h.active]) - - # Find new head - head = contexts[1] - while head is not None and not head.active: - head = head.old_contexts[1] - - # Process chain - ctx = head - while ctx is not None: - parent = ctx.old_contexts[1] - - while parent is not None: - if parent.active: - break - ctx.old_contexts = parent.old_contexts - parent = parent.old_contexts[1] - - ctx = parent - - return (stack_contexts, head) - - -def wrap(fn): - """Returns a callable object that will restore the current `StackContext` - when executed. - - Use this whenever saving a callback to be executed later in a - different execution context (either in a different thread or - asynchronously in the same thread). - """ - # Check if function is already wrapped - if fn is None or hasattr(fn, '_wrapped'): - return fn - - # Capture current stack head - # TODO: Any other better way to store contexts and update them in wrapped function? - cap_contexts = [_state.contexts] - - if not cap_contexts[0][0] and not cap_contexts[0][1]: - # Fast path when there are no active contexts. - def null_wrapper(*args, **kwargs): - try: - current_state = _state.contexts - _state.contexts = cap_contexts[0] - return fn(*args, **kwargs) - finally: - _state.contexts = current_state - null_wrapper._wrapped = True - return null_wrapper - - def wrapped(*args, **kwargs): - ret = None - try: - # Capture old state - current_state = _state.contexts - - # Remove deactivated items - cap_contexts[0] = contexts = _remove_deactivated(cap_contexts[0]) - - # Force new state - _state.contexts = contexts - - # Current exception - exc = (None, None, None) - top = None - - # Apply stack contexts - last_ctx = 0 - stack = contexts[0] - - # Apply state - for n in stack: - try: - n.enter() - last_ctx += 1 - except: - # Exception happened. Record exception info and store top-most handler - exc = sys.exc_info() - top = n.old_contexts[1] - - # Execute callback if no exception happened while restoring state - if top is None: - try: - ret = fn(*args, **kwargs) - except: - exc = sys.exc_info() - top = contexts[1] - - # If there was exception, try to handle it by going through the exception chain - if top is not None: - exc = _handle_exception(top, exc) - else: - # Otherwise take shorter path and run stack contexts in reverse order - while last_ctx > 0: - last_ctx -= 1 - c = stack[last_ctx] - - try: - c.exit(*exc) - except: - exc = sys.exc_info() - top = c.old_contexts[1] - break - else: - top = None - - # If if exception happened while unrolling, take longer exception handler path - if top is not None: - exc = _handle_exception(top, exc) - - # If exception was not handled, raise it - if exc != (None, None, None): - raise_exc_info(exc) - finally: - _state.contexts = current_state - return ret - - wrapped._wrapped = True - return wrapped - - -def _handle_exception(tail, exc): - while tail is not None: - try: - if tail.exit(*exc): - exc = (None, None, None) - except: - exc = sys.exc_info() - - tail = tail.old_contexts[1] - - return exc - - -def run_with_stack_context(context, func): - """Run a coroutine ``func`` in the given `StackContext`. 
- - It is not safe to have a ``yield`` statement within a ``with StackContext`` - block, so it is difficult to use stack context with `.gen.coroutine`. - This helper function runs the function in the correct context while - keeping the ``yield`` and ``with`` statements syntactically separate. - - Example:: - - @gen.coroutine - def incorrect(): - with StackContext(ctx): - # ERROR: this will raise StackContextInconsistentError - yield other_coroutine() - - @gen.coroutine - def correct(): - yield run_with_stack_context(StackContext(ctx), other_coroutine) - - .. versionadded:: 3.1 - """ - with context: - return func() diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py index 3a1b58ca86..e2d682ea64 100644 --- a/tornado/tcpclient.py +++ b/tornado/tcpclient.py @@ -15,21 +15,21 @@ """A non-blocking TCP connection factory. """ -from __future__ import absolute_import, division, print_function import functools import socket import numbers import datetime +import ssl from tornado.concurrent import Future, future_add_done_callback from tornado.ioloop import IOLoop from tornado.iostream import IOStream from tornado import gen from tornado.netutil import Resolver -from tornado.platform.auto import set_close_exec from tornado.gen import TimeoutError -from tornado.util import timedelta_to_seconds + +from typing import Any, Union, Dict, Tuple, List, Callable, Iterator, Optional, Set _INITIAL_CONNECT_TIMEOUT = 0.3 @@ -51,20 +51,34 @@ class _Connector(object): http://tools.ietf.org/html/rfc6555 """ - def __init__(self, addrinfo, connect): + + def __init__( + self, + addrinfo: List[Tuple], + connect: Callable[ + [socket.AddressFamily, Tuple], Tuple[IOStream, "Future[IOStream]"] + ], + ) -> None: self.io_loop = IOLoop.current() self.connect = connect - self.future = Future() - self.timeout = None - self.connect_timeout = None - self.last_error = None + self.future = ( + Future() + ) # type: Future[Tuple[socket.AddressFamily, Any, IOStream]] + self.timeout = None # type: Optional[object] + self.connect_timeout = None # type: Optional[object] + self.last_error = None # type: Optional[Exception] self.remaining = len(addrinfo) self.primary_addrs, self.secondary_addrs = self.split(addrinfo) - self.streams = set() + self.streams = set() # type: Set[IOStream] @staticmethod - def split(addrinfo): + def split( + addrinfo: List[Tuple], + ) -> Tuple[ + List[Tuple[socket.AddressFamily, Tuple]], + List[Tuple[socket.AddressFamily, Tuple]], + ]: """Partition the ``addrinfo`` list by address family. Returns two lists. The first list contains the first entry from @@ -83,14 +97,18 @@ def split(addrinfo): secondary.append((af, addr)) return primary, secondary - def start(self, timeout=_INITIAL_CONNECT_TIMEOUT, connect_timeout=None): + def start( + self, + timeout: float = _INITIAL_CONNECT_TIMEOUT, + connect_timeout: Optional[Union[float, datetime.timedelta]] = None, + ) -> "Future[Tuple[socket.AddressFamily, Any, IOStream]]": self.try_connect(iter(self.primary_addrs)) self.set_timeout(timeout) if connect_timeout is not None: self.set_connect_timeout(connect_timeout) return self.future - def try_connect(self, addrs): + def try_connect(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]]) -> None: try: af, addr = next(addrs) except StopIteration: @@ -98,15 +116,23 @@ def try_connect(self, addrs): # might still be working. Send a final error on the future # only when both queues are finished. 
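As context for the happy-eyeballs logic here, `_Connector.split` partitions ``addrinfo`` so that the primary list holds the entries sharing the first entry's address family and the secondary list holds the rest; secondary addresses are only tried after the initial timeout. A standalone sketch of that partitioning:

```py
import socket

addrinfo = [
    (socket.AF_INET6, ("::1", 443)),
    (socket.AF_INET, ("127.0.0.1", 443)),
    (socket.AF_INET6, ("2001:db8::1", 443)),
]
first_family = addrinfo[0][0]
primary = [(af, addr) for af, addr in addrinfo if af == first_family]
secondary = [(af, addr) for af, addr in addrinfo if af != first_family]
assert [addr for _, addr in primary] == [("::1", 443), ("2001:db8::1", 443)]
assert [addr for _, addr in secondary] == [("127.0.0.1", 443)]
```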
if self.remaining == 0 and not self.future.done(): - self.future.set_exception(self.last_error or - IOError("connection failed")) + self.future.set_exception( + self.last_error or IOError("connection failed") + ) return stream, future = self.connect(af, addr) self.streams.add(stream) future_add_done_callback( - future, functools.partial(self.on_connect_done, addrs, af, addr)) + future, functools.partial(self.on_connect_done, addrs, af, addr) + ) - def on_connect_done(self, addrs, af, addr, future): + def on_connect_done( + self, + addrs: Iterator[Tuple[socket.AddressFamily, Tuple]], + af: socket.AddressFamily, + addr: Tuple, + future: "Future[IOStream]", + ) -> None: self.remaining -= 1 try: stream = future.result() @@ -132,35 +158,39 @@ def on_connect_done(self, addrs, af, addr, future): self.future.set_result((af, addr, stream)) self.close_streams() - def set_timeout(self, timeout): - self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout, - self.on_timeout) + def set_timeout(self, timeout: float) -> None: + self.timeout = self.io_loop.add_timeout( + self.io_loop.time() + timeout, self.on_timeout + ) - def on_timeout(self): + def on_timeout(self) -> None: self.timeout = None if not self.future.done(): self.try_connect(iter(self.secondary_addrs)) - def clear_timeout(self): + def clear_timeout(self) -> None: if self.timeout is not None: self.io_loop.remove_timeout(self.timeout) - def set_connect_timeout(self, connect_timeout): + def set_connect_timeout( + self, connect_timeout: Union[float, datetime.timedelta] + ) -> None: self.connect_timeout = self.io_loop.add_timeout( - connect_timeout, self.on_connect_timeout) + connect_timeout, self.on_connect_timeout + ) - def on_connect_timeout(self): + def on_connect_timeout(self) -> None: if not self.future.done(): self.future.set_exception(TimeoutError()) self.close_streams() - def clear_timeouts(self): + def clear_timeouts(self) -> None: if self.timeout is not None: self.io_loop.remove_timeout(self.timeout) if self.connect_timeout is not None: self.io_loop.remove_timeout(self.connect_timeout) - def close_streams(self): + def close_streams(self) -> None: for stream in self.streams: stream.close() @@ -171,7 +201,8 @@ class TCPClient(object): .. versionchanged:: 5.0 The ``io_loop`` argument (deprecated since version 4.1) has been removed. """ - def __init__(self, resolver=None): + + def __init__(self, resolver: Optional[Resolver] = None) -> None: if resolver is not None: self.resolver = resolver self._own_resolver = False @@ -179,14 +210,21 @@ def __init__(self, resolver=None): self.resolver = Resolver() self._own_resolver = True - def close(self): + def close(self) -> None: if self._own_resolver: self.resolver.close() - @gen.coroutine - def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None, - max_buffer_size=None, source_ip=None, source_port=None, - timeout=None): + async def connect( + self, + host: str, + port: int, + af: socket.AddressFamily = socket.AF_UNSPEC, + ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None, + max_buffer_size: Optional[int] = None, + source_ip: Optional[str] = None, + source_port: Optional[int] = None, + timeout: Optional[Union[float, datetime.timedelta]] = None, + ) -> IOStream: """Connect to the given host and port. 
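A minimal sketch of driving this now-native coroutine from user code (host, port, and message are placeholders):

```py
from tornado.tcpclient import TCPClient

async def echo_once(message: bytes) -> bytes:
    client = TCPClient()
    # timeout accepts a float (seconds) or a datetime.timedelta
    stream = await client.connect("127.0.0.1", 8888, timeout=5.0)
    await stream.write(message + b"\n")
    reply = await stream.read_until(b"\n")
    stream.close()
    client.close()
    return reply
```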
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if @@ -216,34 +254,50 @@ def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None, if isinstance(timeout, numbers.Real): timeout = IOLoop.current().time() + timeout elif isinstance(timeout, datetime.timedelta): - timeout = IOLoop.current().time() + timedelta_to_seconds(timeout) + timeout = IOLoop.current().time() + timeout.total_seconds() else: raise TypeError("Unsupported timeout %r" % timeout) if timeout is not None: - addrinfo = yield gen.with_timeout( - timeout, self.resolver.resolve(host, port, af)) + addrinfo = await gen.with_timeout( + timeout, self.resolver.resolve(host, port, af) + ) else: - addrinfo = yield self.resolver.resolve(host, port, af) + addrinfo = await self.resolver.resolve(host, port, af) connector = _Connector( addrinfo, - functools.partial(self._create_stream, max_buffer_size, - source_ip=source_ip, source_port=source_port) + functools.partial( + self._create_stream, + max_buffer_size, + source_ip=source_ip, + source_port=source_port, + ), ) - af, addr, stream = yield connector.start(connect_timeout=timeout) + af, addr, stream = await connector.start(connect_timeout=timeout) # TODO: For better performance we could cache the (af, addr) # information here and re-use it on subsequent connections to # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2) if ssl_options is not None: if timeout is not None: - stream = yield gen.with_timeout(timeout, stream.start_tls( - False, ssl_options=ssl_options, server_hostname=host)) + stream = await gen.with_timeout( + timeout, + stream.start_tls( + False, ssl_options=ssl_options, server_hostname=host + ), + ) else: - stream = yield stream.start_tls(False, ssl_options=ssl_options, - server_hostname=host) - raise gen.Return(stream) - - def _create_stream(self, max_buffer_size, af, addr, source_ip=None, - source_port=None): + stream = await stream.start_tls( + False, ssl_options=ssl_options, server_hostname=host + ) + return stream + + def _create_stream( + self, + max_buffer_size: int, + af: socket.AddressFamily, + addr: Tuple, + source_ip: Optional[str] = None, + source_port: Optional[int] = None, + ) -> Tuple[IOStream, "Future[IOStream]"]: # Always connect in plaintext; we'll convert to ssl if necessary # after one connection has completed. source_port_bind = source_port if isinstance(source_port, int) else 0 @@ -251,12 +305,11 @@ def _create_stream(self, max_buffer_size, af, addr, source_ip=None, if source_port_bind and not source_ip: # User required a specific port, but did not specify # a certain source IP, will bind to the default loopback. - source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1' + source_ip_bind = "::1" if af == socket.AF_INET6 else "127.0.0.1" # Trying to use the same address family as the requested af socket: # - 127.0.0.1 for IPv4 # - ::1 for IPv6 socket_obj = socket.socket(af) - set_close_exec(socket_obj.fileno()) if source_port_bind or source_ip_bind: # If the user requires binding also to a specific IP/port. try: @@ -266,11 +319,10 @@ def _create_stream(self, max_buffer_size, af, addr, source_ip=None, # Fail loudly if unable to use the IP/port. 
raise try: - stream = IOStream(socket_obj, - max_buffer_size=max_buffer_size) + stream = IOStream(socket_obj, max_buffer_size=max_buffer_size) except socket.error as e: - fu = Future() + fu = Future() # type: Future[IOStream] fu.set_exception(e) - return fu + return stream, fu else: return stream, stream.connect(addr) diff --git a/tornado/tcpserver.py b/tornado/tcpserver.py index 1adc489560..476ffc936f 100644 --- a/tornado/tcpserver.py +++ b/tornado/tcpserver.py @@ -14,11 +14,11 @@ # under the License. """A non-blocking, single-threaded TCP server.""" -from __future__ import absolute_import, division, print_function import errno import os import socket +import ssl from tornado import gen from tornado.log import app_log @@ -28,11 +28,11 @@ from tornado import process from tornado.util import errno_from_exception -try: - import ssl -except ImportError: - # ssl is not available on Google App Engine. - ssl = None +import typing +from typing import Union, Dict, Any, Iterable, Optional, Awaitable + +if typing.TYPE_CHECKING: + from typing import Callable, List # noqa: F401 class TCPServer(object): @@ -46,12 +46,11 @@ class TCPServer(object): from tornado import gen class EchoServer(TCPServer): - @gen.coroutine - def handle_stream(self, stream, address): + async def handle_stream(self, stream, address): while True: try: - data = yield stream.read_until(b"\n") - yield stream.write(data) + data = await stream.read_until(b"\n") + await stream.write(data) except StreamClosedError: break @@ -105,12 +104,17 @@ def handle_stream(self, stream, address): .. versionchanged:: 5.0 The ``io_loop`` argument has been removed. """ - def __init__(self, ssl_options=None, max_buffer_size=None, - read_chunk_size=None): + + def __init__( + self, + ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None, + max_buffer_size: Optional[int] = None, + read_chunk_size: Optional[int] = None, + ) -> None: self.ssl_options = ssl_options - self._sockets = {} # fd -> socket object - self._handlers = {} # fd -> remove_handler callable - self._pending_sockets = [] + self._sockets = {} # type: Dict[int, socket.socket] + self._handlers = {} # type: Dict[int, Callable[[], None]] + self._pending_sockets = [] # type: List[socket.socket] self._started = False self._stopped = False self.max_buffer_size = max_buffer_size @@ -122,18 +126,21 @@ def __init__(self, ssl_options=None, max_buffer_size=None, # which seems like too much work if self.ssl_options is not None and isinstance(self.ssl_options, dict): # Only certfile is required: it can contain both keys - if 'certfile' not in self.ssl_options: + if "certfile" not in self.ssl_options: raise KeyError('missing key "certfile" in ssl_options') - if not os.path.exists(self.ssl_options['certfile']): - raise ValueError('certfile "%s" does not exist' % - self.ssl_options['certfile']) - if ('keyfile' in self.ssl_options and - not os.path.exists(self.ssl_options['keyfile'])): - raise ValueError('keyfile "%s" does not exist' % - self.ssl_options['keyfile']) - - def listen(self, port, address=""): + if not os.path.exists(self.ssl_options["certfile"]): + raise ValueError( + 'certfile "%s" does not exist' % self.ssl_options["certfile"] + ) + if "keyfile" in self.ssl_options and not os.path.exists( + self.ssl_options["keyfile"] + ): + raise ValueError( + 'keyfile "%s" does not exist' % self.ssl_options["keyfile"] + ) + + def listen(self, port: int, address: str = "") -> None: """Starts accepting connections on the given port. 
This method may be called more than once to listen on multiple ports. @@ -144,7 +151,7 @@ def listen(self, port, address=""): sockets = bind_sockets(port, address=address) self.add_sockets(sockets) - def add_sockets(self, sockets): + def add_sockets(self, sockets: Iterable[socket.socket]) -> None: """Makes this server start accepting connections on the given sockets. The ``sockets`` parameter is a list of socket objects such as @@ -156,14 +163,21 @@ def add_sockets(self, sockets): for sock in sockets: self._sockets[sock.fileno()] = sock self._handlers[sock.fileno()] = add_accept_handler( - sock, self._handle_connection) + sock, self._handle_connection + ) - def add_socket(self, socket): + def add_socket(self, socket: socket.socket) -> None: """Singular version of `add_sockets`. Takes a single socket object.""" self.add_sockets([socket]) - def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128, - reuse_port=False): + def bind( + self, + port: int, + address: Optional[str] = None, + family: socket.AddressFamily = socket.AF_UNSPEC, + backlog: int = 128, + reuse_port: bool = False, + ) -> None: """Binds this server to the given port on the given address. To start the server, call `start`. If you want to run this server @@ -187,14 +201,17 @@ def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128, .. versionchanged:: 4.4 Added the ``reuse_port`` argument. """ - sockets = bind_sockets(port, address=address, family=family, - backlog=backlog, reuse_port=reuse_port) + sockets = bind_sockets( + port, address=address, family=family, backlog=backlog, reuse_port=reuse_port + ) if self._started: self.add_sockets(sockets) else: self._pending_sockets.extend(sockets) - def start(self, num_processes=1): + def start( + self, num_processes: Optional[int] = 1, max_restarts: Optional[int] = None + ) -> None: """Starts this server in the `.IOLoop`. By default, we run the server in this process and do not fork any @@ -213,16 +230,24 @@ def start(self, num_processes=1): which defaults to True when ``debug=True``). When using multiple processes, no IOLoops can be created or referenced until after the call to ``TCPServer.start(n)``. + + Values of ``num_processes`` other than 1 are not supported on Windows. + + The ``max_restarts`` argument is passed to `.fork_processes`. + + .. versionchanged:: 6.0 + + Added ``max_restarts`` argument. """ assert not self._started self._started = True if num_processes != 1: - process.fork_processes(num_processes) + process.fork_processes(num_processes, max_restarts) sockets = self._pending_sockets self._pending_sockets = [] self.add_sockets(sockets) - def stop(self): + def stop(self) -> None: """Stops listening for new connections. Requests currently in progress may still continue after the @@ -237,7 +262,9 @@ def stop(self): self._handlers.pop(fd)() sock.close() - def handle_stream(self, stream, address): + def handle_stream( + self, stream: IOStream, address: tuple + ) -> Optional[Awaitable[None]]: """Override to handle a new `.IOStream` from an incoming connection. 
This method may be a coroutine; if so any exceptions it raises @@ -254,14 +281,16 @@ def handle_stream(self, stream, address): """ raise NotImplementedError() - def _handle_connection(self, connection, address): + def _handle_connection(self, connection: socket.socket, address: Any) -> None: if self.ssl_options is not None: assert ssl, "Python 2.6+ and OpenSSL required for SSL" try: - connection = ssl_wrap_socket(connection, - self.ssl_options, - server_side=True, - do_handshake_on_connect=False) + connection = ssl_wrap_socket( + connection, + self.ssl_options, + server_side=True, + do_handshake_on_connect=False, + ) except ssl.SSLError as err: if err.args[0] == ssl.SSL_ERROR_EOF: return connection.close() @@ -284,17 +313,22 @@ def _handle_connection(self, connection, address): raise try: if self.ssl_options is not None: - stream = SSLIOStream(connection, - max_buffer_size=self.max_buffer_size, - read_chunk_size=self.read_chunk_size) + stream = SSLIOStream( + connection, + max_buffer_size=self.max_buffer_size, + read_chunk_size=self.read_chunk_size, + ) # type: IOStream else: - stream = IOStream(connection, - max_buffer_size=self.max_buffer_size, - read_chunk_size=self.read_chunk_size) + stream = IOStream( + connection, + max_buffer_size=self.max_buffer_size, + read_chunk_size=self.read_chunk_size, + ) future = self.handle_stream(stream, address) if future is not None: - IOLoop.current().add_future(gen.convert_yielded(future), - lambda f: f.result()) + IOLoop.current().add_future( + gen.convert_yielded(future), lambda f: f.result() + ) except Exception: app_log.error("Error in connection callback", exc_info=True) diff --git a/tornado/template.py b/tornado/template.py index 61b987462c..d53e977c5e 100644 --- a/tornado/template.py +++ b/tornado/template.py @@ -98,8 +98,9 @@ def add(x, y): To comment out a section so that it is omitted from the output, surround it with ``{# ... #}``. -These tags may be escaped as ``{{!``, ``{%!``, and ``{#!`` -if you need to include a literal ``{{``, ``{%``, or ``{#`` in the output. + +To include a literal ``{{``, ``{%``, or ``{#`` in the output, escape them as +``{{!``, ``{%!``, and ``{#!``, respectively. ``{% apply *function* %}...{% end %}`` @@ -195,9 +196,8 @@ class (and specifically its ``render`` method) and will not work `filter_whitespace` for available options. New in Tornado 4.3. """ -from __future__ import absolute_import, division, print_function - import datetime +from io import StringIO import linecache import os.path import posixpath @@ -206,18 +206,25 @@ class (and specifically its ``render`` method) and will not work from tornado import escape from tornado.log import app_log -from tornado.util import ObjectDict, exec_in, unicode_type, PY3 +from tornado.util import ObjectDict, exec_in, unicode_type -if PY3: - from io import StringIO -else: - from cStringIO import StringIO +from typing import Any, Union, Callable, List, Dict, Iterable, Optional, TextIO +import typing + +if typing.TYPE_CHECKING: + from typing import Tuple, ContextManager # noqa: F401 _DEFAULT_AUTOESCAPE = "xhtml_escape" -_UNSET = object() -def filter_whitespace(mode, text): +class _UnsetMarker: + pass + + +_UNSET = _UnsetMarker() + + +def filter_whitespace(mode: str, text: str) -> str: """Transform whitespace in ``text`` according to ``mode``. Available modes are: @@ -230,13 +237,13 @@ def filter_whitespace(mode, text): .. 
versionadded:: 4.3 """ - if mode == 'all': + if mode == "all": return text - elif mode == 'single': + elif mode == "single": text = re.sub(r"([\t ]+)", " ", text) text = re.sub(r"(\s*\n\s*)", "\n", text) return text - elif mode == 'oneline': + elif mode == "oneline": return re.sub(r"(\s+)", " ", text) else: raise Exception("invalid whitespace mode %s" % mode) @@ -248,12 +255,19 @@ class Template(object): We compile into Python from the given template_string. You can generate the template from variables with generate(). """ + # note that the constructor's signature is not extracted with # autodoc because _UNSET looks like garbage. When changing # this signature update website/sphinx/template.rst too. - def __init__(self, template_string, name="", loader=None, - compress_whitespace=_UNSET, autoescape=_UNSET, - whitespace=None): + def __init__( + self, + template_string: Union[str, bytes], + name: str = "", + loader: Optional["BaseLoader"] = None, + compress_whitespace: Union[bool, _UnsetMarker] = _UNSET, + autoescape: Optional[Union[str, _UnsetMarker]] = _UNSET, + whitespace: Optional[str] = None, + ) -> None: """Construct a Template. :arg str template_string: the contents of the template file. @@ -289,18 +303,18 @@ def __init__(self, template_string, name="", loader=None, else: whitespace = "all" # Validate the whitespace setting. - filter_whitespace(whitespace, '') + assert whitespace is not None + filter_whitespace(whitespace, "") - if autoescape is not _UNSET: - self.autoescape = autoescape + if not isinstance(autoescape, _UnsetMarker): + self.autoescape = autoescape # type: Optional[str] elif loader: self.autoescape = loader.autoescape else: self.autoescape = _DEFAULT_AUTOESCAPE self.namespace = loader.namespace if loader else {} - reader = _TemplateReader(name, escape.native_str(template_string), - whitespace) + reader = _TemplateReader(name, escape.native_str(template_string), whitespace) self.file = _File(self, _parse(reader, self)) self.code = self._generate_python(loader) self.loader = loader @@ -311,14 +325,16 @@ def __init__(self, template_string, name="", loader=None, # from being applied to the generated code. self.compiled = compile( escape.to_unicode(self.code), - "%s.generated.py" % self.name.replace('.', '_'), - "exec", dont_inherit=True) + "%s.generated.py" % self.name.replace(".", "_"), + "exec", + dont_inherit=True, + ) except Exception: formatted_code = _format_code(self.code).rstrip() app_log.error("%s code:\n%s", self.name, formatted_code) raise - def generate(self, **kwargs): + def generate(self, **kwargs: Any) -> bytes: """Generate this template with the given arguments.""" namespace = { "escape": escape.xhtml_escape, @@ -332,42 +348,42 @@ def generate(self, **kwargs): "_tt_string_types": (unicode_type, bytes), # __name__ and __loader__ allow the traceback mechanism to find # the generated source code. - "__name__": self.name.replace('.', '_'), + "__name__": self.name.replace(".", "_"), "__loader__": ObjectDict(get_source=lambda name: self.code), } namespace.update(self.namespace) namespace.update(kwargs) exec_in(self.compiled, namespace) - execute = namespace["_tt_execute"] + execute = typing.cast(Callable[[], bytes], namespace["_tt_execute"]) # Clear the traceback module's cache of source data now that # we've generated a new template (mainly for this module's # unittests, where different tests reuse the same name). 
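For reference, the end-to-end behavior of `Template.generate` (note it returns ``bytes``, and the default ``xhtml_escape`` autoescape applies to expression output):

```py
from tornado.template import Template

t = Template("{{ greeting }}, {{ name }}!")
assert t.generate(greeting="Hello", name="world") == b"Hello, world!"

t2 = Template("{{ value }}")
assert t2.generate(value="<b>") == b"&lt;b&gt;"  # autoescaped by default
```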
linecache.clearcache() return execute() - def _generate_python(self, loader): + def _generate_python(self, loader: Optional["BaseLoader"]) -> str: buffer = StringIO() try: # named_blocks maps from names to _NamedBlock objects - named_blocks = {} + named_blocks = {} # type: Dict[str, _NamedBlock] ancestors = self._get_ancestors(loader) ancestors.reverse() for ancestor in ancestors: ancestor.find_named_blocks(loader, named_blocks) - writer = _CodeWriter(buffer, named_blocks, loader, - ancestors[0].template) + writer = _CodeWriter(buffer, named_blocks, loader, ancestors[0].template) ancestors[0].generate(writer) return buffer.getvalue() finally: buffer.close() - def _get_ancestors(self, loader): + def _get_ancestors(self, loader: Optional["BaseLoader"]) -> List["_File"]: ancestors = [self.file] for chunk in self.file.body.chunks: if isinstance(chunk, _ExtendsBlock): if not loader: - raise ParseError("{% extends %} block found, but no " - "template loader") + raise ParseError( + "{% extends %} block found, but no " "template loader" + ) template = loader.load(chunk.name, self.name) ancestors.extend(template._get_ancestors(loader)) return ancestors @@ -380,8 +396,13 @@ class BaseLoader(object): ``{% extends %}`` and ``{% include %}``. The loader caches all templates after they are loaded the first time. """ - def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None, - whitespace=None): + + def __init__( + self, + autoescape: str = _DEFAULT_AUTOESCAPE, + namespace: Optional[Dict[str, Any]] = None, + whitespace: Optional[str] = None, + ) -> None: """Construct a template loader. :arg str autoescape: The name of a function in the template @@ -400,7 +421,7 @@ def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None, self.autoescape = autoescape self.namespace = namespace or {} self.whitespace = whitespace - self.templates = {} + self.templates = {} # type: Dict[str, Template] # self.lock protects self.templates. It's a reentrant lock # because templates may load other templates via `include` or # `extends`. Note that thanks to the GIL this code would be safe @@ -408,16 +429,16 @@ def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None, # threads tried to compile the same template simultaneously. self.lock = threading.RLock() - def reset(self): + def reset(self) -> None: """Resets the cache of compiled templates.""" with self.lock: self.templates = {} - def resolve_path(self, name, parent_path=None): + def resolve_path(self, name: str, parent_path: Optional[str] = None) -> str: """Converts a possibly-relative path to absolute (used internally).""" raise NotImplementedError() - def load(self, name, parent_path=None): + def load(self, name: str, parent_path: Optional[str] = None) -> Template: """Loads a template.""" name = self.resolve_path(name, parent_path=parent_path) with self.lock: @@ -425,29 +446,32 @@ def load(self, name, parent_path=None): self.templates[name] = self._create_template(name) return self.templates[name] - def _create_template(self, name): + def _create_template(self, name: str) -> Template: raise NotImplementedError() class Loader(BaseLoader): - """A template loader that loads from a single root directory. 
- """ - def __init__(self, root_directory, **kwargs): - super(Loader, self).__init__(**kwargs) + """A template loader that loads from a single root directory.""" + + def __init__(self, root_directory: str, **kwargs: Any) -> None: + super().__init__(**kwargs) self.root = os.path.abspath(root_directory) - def resolve_path(self, name, parent_path=None): - if parent_path and not parent_path.startswith("<") and \ - not parent_path.startswith("/") and \ - not name.startswith("/"): + def resolve_path(self, name: str, parent_path: Optional[str] = None) -> str: + if ( + parent_path + and not parent_path.startswith("<") + and not parent_path.startswith("/") + and not name.startswith("/") + ): current_path = os.path.join(self.root, parent_path) file_dir = os.path.dirname(os.path.abspath(current_path)) relative_path = os.path.abspath(os.path.join(file_dir, name)) if relative_path.startswith(self.root): - name = relative_path[len(self.root) + 1:] + name = relative_path[len(self.root) + 1 :] return name - def _create_template(self, name): + def _create_template(self, name: str) -> Template: path = os.path.join(self.root, name) with open(path, "rb") as f: template = Template(f.read(), name=name, loader=self) @@ -456,41 +480,47 @@ def _create_template(self, name): class DictLoader(BaseLoader): """A template loader that loads from a dictionary.""" - def __init__(self, dict, **kwargs): - super(DictLoader, self).__init__(**kwargs) + + def __init__(self, dict: Dict[str, str], **kwargs: Any) -> None: + super().__init__(**kwargs) self.dict = dict - def resolve_path(self, name, parent_path=None): - if parent_path and not parent_path.startswith("<") and \ - not parent_path.startswith("/") and \ - not name.startswith("/"): + def resolve_path(self, name: str, parent_path: Optional[str] = None) -> str: + if ( + parent_path + and not parent_path.startswith("<") + and not parent_path.startswith("/") + and not name.startswith("/") + ): file_dir = posixpath.dirname(parent_path) name = posixpath.normpath(posixpath.join(file_dir, name)) return name - def _create_template(self, name): + def _create_template(self, name: str) -> Template: return Template(self.dict[name], name=name, loader=self) class _Node(object): - def each_child(self): + def each_child(self) -> Iterable["_Node"]: return () - def generate(self, writer): + def generate(self, writer: "_CodeWriter") -> None: raise NotImplementedError() - def find_named_blocks(self, loader, named_blocks): + def find_named_blocks( + self, loader: Optional[BaseLoader], named_blocks: Dict[str, "_NamedBlock"] + ) -> None: for child in self.each_child(): child.find_named_blocks(loader, named_blocks) class _File(_Node): - def __init__(self, template, body): + def __init__(self, template: Template, body: "_ChunkList") -> None: self.template = template self.body = body self.line = 0 - def generate(self, writer): + def generate(self, writer: "_CodeWriter") -> None: writer.write_line("def _tt_execute():", self.line) with writer.indent(): writer.write_line("_tt_buffer = []", self.line) @@ -498,73 +528,79 @@ def generate(self, writer): self.body.generate(writer) writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line) - def each_child(self): + def each_child(self) -> Iterable["_Node"]: return (self.body,) class _ChunkList(_Node): - def __init__(self, chunks): + def __init__(self, chunks: List[_Node]) -> None: self.chunks = chunks - def generate(self, writer): + def generate(self, writer: "_CodeWriter") -> None: for chunk in self.chunks: chunk.generate(writer) - def 
each_child(self): + def each_child(self) -> Iterable["_Node"]: return self.chunks class _NamedBlock(_Node): - def __init__(self, name, body, template, line): + def __init__(self, name: str, body: _Node, template: Template, line: int) -> None: self.name = name self.body = body self.template = template self.line = line - def each_child(self): + def each_child(self) -> Iterable["_Node"]: return (self.body,) - def generate(self, writer): + def generate(self, writer: "_CodeWriter") -> None: block = writer.named_blocks[self.name] with writer.include(block.template, self.line): block.body.generate(writer) - def find_named_blocks(self, loader, named_blocks): + def find_named_blocks( + self, loader: Optional[BaseLoader], named_blocks: Dict[str, "_NamedBlock"] + ) -> None: named_blocks[self.name] = self _Node.find_named_blocks(self, loader, named_blocks) class _ExtendsBlock(_Node): - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name class _IncludeBlock(_Node): - def __init__(self, name, reader, line): + def __init__(self, name: str, reader: "_TemplateReader", line: int) -> None: self.name = name self.template_name = reader.name self.line = line - def find_named_blocks(self, loader, named_blocks): + def find_named_blocks( + self, loader: Optional[BaseLoader], named_blocks: Dict[str, _NamedBlock] + ) -> None: + assert loader is not None included = loader.load(self.name, self.template_name) included.file.find_named_blocks(loader, named_blocks) - def generate(self, writer): + def generate(self, writer: "_CodeWriter") -> None: + assert writer.loader is not None included = writer.loader.load(self.name, self.template_name) with writer.include(included, self.line): included.file.body.generate(writer) class _ApplyBlock(_Node): - def __init__(self, method, line, body=None): + def __init__(self, method: str, line: int, body: _Node) -> None: self.method = method self.line = line self.body = body - def each_child(self): + def each_child(self) -> Iterable["_Node"]: return (self.body,) - def generate(self, writer): + def generate(self, writer: "_CodeWriter") -> None: method_name = "_tt_apply%d" % writer.apply_counter writer.apply_counter += 1 writer.write_line("def %s():" % method_name, self.line) @@ -573,20 +609,21 @@ def generate(self, writer): writer.write_line("_tt_append = _tt_buffer.append", self.line) self.body.generate(writer) writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line) - writer.write_line("_tt_append(_tt_utf8(%s(%s())))" % ( - self.method, method_name), self.line) + writer.write_line( + "_tt_append(_tt_utf8(%s(%s())))" % (self.method, method_name), self.line + ) class _ControlBlock(_Node): - def __init__(self, statement, line, body=None): + def __init__(self, statement: str, line: int, body: _Node) -> None: self.statement = statement self.line = line self.body = body - def each_child(self): + def each_child(self) -> Iterable[_Node]: return (self.body,) - def generate(self, writer): + def generate(self, writer: "_CodeWriter") -> None: writer.write_line("%s:" % self.statement, self.line) with writer.indent(): self.body.generate(writer) @@ -595,57 +632,60 @@ def generate(self, writer): class _IntermediateControlBlock(_Node): - def __init__(self, statement, line): + def __init__(self, statement: str, line: int) -> None: self.statement = statement self.line = line - def generate(self, writer): + def generate(self, writer: "_CodeWriter") -> None: # In case the previous block was empty writer.write_line("pass", self.line) writer.write_line("%s:" % 
self.statement, self.line, writer.indent_size() - 1) class _Statement(_Node): - def __init__(self, statement, line): + def __init__(self, statement: str, line: int) -> None: self.statement = statement self.line = line - def generate(self, writer): + def generate(self, writer: "_CodeWriter") -> None: writer.write_line(self.statement, self.line) class _Expression(_Node): - def __init__(self, expression, line, raw=False): + def __init__(self, expression: str, line: int, raw: bool = False) -> None: self.expression = expression self.line = line self.raw = raw - def generate(self, writer): + def generate(self, writer: "_CodeWriter") -> None: writer.write_line("_tt_tmp = %s" % self.expression, self.line) - writer.write_line("if isinstance(_tt_tmp, _tt_string_types):" - " _tt_tmp = _tt_utf8(_tt_tmp)", self.line) + writer.write_line( + "if isinstance(_tt_tmp, _tt_string_types):" " _tt_tmp = _tt_utf8(_tt_tmp)", + self.line, + ) writer.write_line("else: _tt_tmp = _tt_utf8(str(_tt_tmp))", self.line) if not self.raw and writer.current_template.autoescape is not None: # In python3 functions like xhtml_escape return unicode, # so we have to convert to utf8 again. - writer.write_line("_tt_tmp = _tt_utf8(%s(_tt_tmp))" % - writer.current_template.autoescape, self.line) + writer.write_line( + "_tt_tmp = _tt_utf8(%s(_tt_tmp))" % writer.current_template.autoescape, + self.line, + ) writer.write_line("_tt_append(_tt_tmp)", self.line) class _Module(_Expression): - def __init__(self, expression, line): - super(_Module, self).__init__("_tt_modules." + expression, line, - raw=True) + def __init__(self, expression: str, line: int) -> None: + super().__init__("_tt_modules." + expression, line, raw=True) class _Text(_Node): - def __init__(self, value, line, whitespace): + def __init__(self, value: str, line: int, whitespace: str) -> None: self.value = value self.line = line self.whitespace = whitespace - def generate(self, writer): + def generate(self, writer: "_CodeWriter") -> None: value = self.value # Compress whitespace if requested, with a crude heuristic to avoid @@ -654,7 +694,7 @@ def generate(self, writer): value = filter_whitespace(self.whitespace, value) if value: - writer.write_line('_tt_append(%r)' % escape.utf8(value), self.line) + writer.write_line("_tt_append(%r)" % escape.utf8(value), self.line) class ParseError(Exception): @@ -666,75 +706,87 @@ class ParseError(Exception): .. versionchanged:: 4.3 Added ``filename`` and ``lineno`` attributes. """ - def __init__(self, message, filename=None, lineno=0): + + def __init__( + self, message: str, filename: Optional[str] = None, lineno: int = 0 + ) -> None: self.message = message # The names "filename" and "lineno" are chosen for consistency # with python SyntaxError. 
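Given the ``SyntaxError``-style attributes, a `ParseError` renders like this:

```py
from tornado.template import ParseError

err = ParseError("Missing {% end %} block for if", filename="home.html", lineno=3)
assert str(err) == "Missing {% end %} block for if at home.html:3"
```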
self.filename = filename self.lineno = lineno - def __str__(self): - return '%s at %s:%d' % (self.message, self.filename, self.lineno) + def __str__(self) -> str: + return "%s at %s:%d" % (self.message, self.filename, self.lineno) class _CodeWriter(object): - def __init__(self, file, named_blocks, loader, current_template): + def __init__( + self, + file: TextIO, + named_blocks: Dict[str, _NamedBlock], + loader: Optional[BaseLoader], + current_template: Template, + ) -> None: self.file = file self.named_blocks = named_blocks self.loader = loader self.current_template = current_template self.apply_counter = 0 - self.include_stack = [] + self.include_stack = [] # type: List[Tuple[Template, int]] self._indent = 0 - def indent_size(self): + def indent_size(self) -> int: return self._indent - def indent(self): + def indent(self) -> "ContextManager": class Indenter(object): - def __enter__(_): + def __enter__(_) -> "_CodeWriter": self._indent += 1 return self - def __exit__(_, *args): + def __exit__(_, *args: Any) -> None: assert self._indent > 0 self._indent -= 1 return Indenter() - def include(self, template, line): + def include(self, template: Template, line: int) -> "ContextManager": self.include_stack.append((self.current_template, line)) self.current_template = template class IncludeTemplate(object): - def __enter__(_): + def __enter__(_) -> "_CodeWriter": return self - def __exit__(_, *args): + def __exit__(_, *args: Any) -> None: self.current_template = self.include_stack.pop()[0] return IncludeTemplate() - def write_line(self, line, line_number, indent=None): + def write_line( + self, line: str, line_number: int, indent: Optional[int] = None + ) -> None: if indent is None: indent = self._indent - line_comment = ' # %s:%d' % (self.current_template.name, line_number) + line_comment = " # %s:%d" % (self.current_template.name, line_number) if self.include_stack: - ancestors = ["%s:%d" % (tmpl.name, lineno) - for (tmpl, lineno) in self.include_stack] - line_comment += ' (via %s)' % ', '.join(reversed(ancestors)) + ancestors = [ + "%s:%d" % (tmpl.name, lineno) for (tmpl, lineno) in self.include_stack + ] + line_comment += " (via %s)" % ", ".join(reversed(ancestors)) print(" " * indent + line + line_comment, file=self.file) class _TemplateReader(object): - def __init__(self, name, text, whitespace): + def __init__(self, name: str, text: str, whitespace: str) -> None: self.name = name self.text = text self.whitespace = whitespace self.line = 1 self.pos = 0 - def find(self, needle, start=0, end=None): + def find(self, needle: str, start: int = 0, end: Optional[int] = None) -> int: assert start >= 0, start pos = self.pos start += pos @@ -748,23 +800,23 @@ def find(self, needle, start=0, end=None): index -= pos return index - def consume(self, count=None): + def consume(self, count: Optional[int] = None) -> str: if count is None: count = len(self.text) - self.pos newpos = self.pos + count self.line += self.text.count("\n", self.pos, newpos) - s = self.text[self.pos:newpos] + s = self.text[self.pos : newpos] self.pos = newpos return s - def remaining(self): + def remaining(self) -> int: return len(self.text) - self.pos - def __len__(self): + def __len__(self) -> int: return self.remaining() - def __getitem__(self, key): - if type(key) is slice: + def __getitem__(self, key: Union[int, slice]) -> str: + if isinstance(key, slice): size = len(self) start, stop, step = key.indices(size) if start is None: @@ -779,20 +831,25 @@ def __getitem__(self, key): else: return self.text[self.pos + key] - def 
__str__(self): - return self.text[self.pos:] + def __str__(self) -> str: + return self.text[self.pos :] - def raise_parse_error(self, msg): + def raise_parse_error(self, msg: str) -> None: raise ParseError(msg, self.name, self.line) -def _format_code(code): +def _format_code(code: str) -> str: lines = code.splitlines() format = "%%%dd %%s\n" % len(repr(len(lines) + 1)) return "".join([format % (i + 1, line) for (i, line) in enumerate(lines)]) -def _parse(reader, template, in_block=None, in_loop=None): +def _parse( + reader: _TemplateReader, + template: Template, + in_block: Optional[str] = None, + in_loop: Optional[str] = None, +) -> _ChunkList: body = _ChunkList([]) while True: # Find next template directive @@ -803,9 +860,11 @@ def _parse(reader, template, in_block=None, in_loop=None): # EOF if in_block: reader.raise_parse_error( - "Missing {%% end %%} block for %s" % in_block) - body.chunks.append(_Text(reader.consume(), reader.line, - reader.whitespace)) + "Missing {%% end %%} block for %s" % in_block + ) + body.chunks.append( + _Text(reader.consume(), reader.line, reader.whitespace) + ) return body # If the first curly brace is not the start of a special token, # start searching from the character after it @@ -815,8 +874,11 @@ def _parse(reader, template, in_block=None, in_loop=None): # When there are more than 2 curlies in a row, use the # innermost ones. This is useful when generating languages # like latex where curlies are also meaningful - if (curly + 2 < reader.remaining() and - reader[curly + 1] == '{' and reader[curly + 2] == '{'): + if ( + curly + 2 < reader.remaining() + and reader[curly + 1] == "{" + and reader[curly + 2] == "{" + ): curly += 1 continue break @@ -824,8 +886,7 @@ def _parse(reader, template, in_block=None, in_loop=None): # Append any text before the special token if curly > 0: cons = reader.consume(curly) - body.chunks.append(_Text(cons, reader.line, - reader.whitespace)) + body.chunks.append(_Text(cons, reader.line, reader.whitespace)) start_brace = reader.consume(2) line = reader.line @@ -836,8 +897,7 @@ def _parse(reader, template, in_block=None, in_loop=None): # which also use double braces. 
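This is the parser half of the ``{{!`` escape documented at the top of this module; its observable behavior:

```py
from tornado.template import Template

# "{{!" consumes the "!" and emits a literal "{{"; the rest parses normally.
t = Template("{{! x }} and {{ 1 + 1 }}")
assert t.generate() == b"{{ x }} and 2"
```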
if reader.remaining() and reader[0] == "!": reader.consume(1) - body.chunks.append(_Text(start_brace, line, - reader.whitespace)) + body.chunks.append(_Text(start_brace, line, reader.whitespace)) continue # Comment @@ -884,12 +944,13 @@ def _parse(reader, template, in_block=None, in_loop=None): allowed_parents = intermediate_blocks.get(operator) if allowed_parents is not None: if not in_block: - reader.raise_parse_error("%s outside %s block" % - (operator, allowed_parents)) + reader.raise_parse_error( + "%s outside %s block" % (operator, allowed_parents) + ) if in_block not in allowed_parents: reader.raise_parse_error( - "%s block cannot be attached to %s block" % - (operator, in_block)) + "%s block cannot be attached to %s block" % (operator, in_block) + ) body.chunks.append(_IntermediateControlBlock(contents, line)) continue @@ -899,16 +960,25 @@ def _parse(reader, template, in_block=None, in_loop=None): reader.raise_parse_error("Extra {% end %} block") return body - elif operator in ("extends", "include", "set", "import", "from", - "comment", "autoescape", "whitespace", "raw", - "module"): + elif operator in ( + "extends", + "include", + "set", + "import", + "from", + "comment", + "autoescape", + "whitespace", + "raw", + "module", + ): if operator == "comment": continue if operator == "extends": suffix = suffix.strip('"').strip("'") if not suffix: reader.raise_parse_error("extends missing file path") - block = _ExtendsBlock(suffix) + block = _ExtendsBlock(suffix) # type: _Node elif operator in ("import", "from"): if not suffix: reader.raise_parse_error("import missing statement") @@ -923,7 +993,7 @@ def _parse(reader, template, in_block=None, in_loop=None): reader.raise_parse_error("set missing statement") block = _Statement(suffix, line) elif operator == "autoescape": - fn = suffix.strip() + fn = suffix.strip() # type: Optional[str] if fn == "None": fn = None template.autoescape = fn @@ -931,7 +1001,7 @@ def _parse(reader, template, in_block=None, in_loop=None): elif operator == "whitespace": mode = suffix.strip() # Validate the selected mode - filter_whitespace(mode, '') + filter_whitespace(mode, "") reader.whitespace = mode continue elif operator == "raw": @@ -967,8 +1037,9 @@ def _parse(reader, template, in_block=None, in_loop=None): elif operator in ("break", "continue"): if not in_loop: - reader.raise_parse_error("%s outside %s block" % - (operator, set(["for", "while"]))) + reader.raise_parse_error( + "%s outside %s block" % (operator, set(["for", "while"])) + ) body.chunks.append(_Statement(contents, line)) continue diff --git a/tornado/test/__main__.py b/tornado/test/__main__.py index c78478cbd3..430c895fa2 100644 --- a/tornado/test/__main__.py +++ b/tornado/test/__main__.py @@ -2,8 +2,6 @@ This only works in python 2.7+. """ -from __future__ import absolute_import, division, print_function - from tornado.test.runtests import all, main # tornado.testing.main autodiscovery relies on 'all' being present in diff --git a/tornado/test/asyncio_test.py b/tornado/test/asyncio_test.py index a7c7564964..3f9f3389a2 100644 --- a/tornado/test/asyncio_test.py +++ b/tornado/test/asyncio_test.py @@ -10,25 +10,20 @@ # License for the specific language governing permissions and limitations # under the License. 
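The rewritten tests below rely on Tornado futures being directly awaitable from asyncio; in miniature (mirroring ``test_asyncio_adapter`` further down):

```py
import asyncio

from tornado import gen
from tornado.platform.asyncio import to_asyncio_future

@gen.coroutine
def tornado_coroutine():
    yield gen.moment
    raise gen.Return(42)

async def native():
    # The adapter is optional since 5.0 (Tornado's Future *is* asyncio's),
    # but shown here for illustration.
    return await to_asyncio_future(tornado_coroutine())

assert asyncio.get_event_loop().run_until_complete(native()) == 42
```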
-from __future__ import absolute_import, division, print_function +import asyncio +import unittest from concurrent.futures import ThreadPoolExecutor from tornado import gen from tornado.ioloop import IOLoop +from tornado.platform.asyncio import ( + AsyncIOLoop, + to_asyncio_future, + AnyThreadEventLoopPolicy, +) from tornado.testing import AsyncTestCase, gen_test -from tornado.test.util import unittest, skipBefore33, skipBefore35, exec_test -try: - from tornado.platform.asyncio import asyncio -except ImportError: - asyncio = None -else: - from tornado.platform.asyncio import AsyncIOLoop, to_asyncio_future, AnyThreadEventLoopPolicy - # This is used in dynamically-evaluated code, so silence pyflakes. - to_asyncio_future - -@unittest.skipIf(asyncio is None, "asyncio module not present") class AsyncIOLoopTest(AsyncTestCase): def get_new_ioloop(self): io_loop = AsyncIOLoop() @@ -44,32 +39,28 @@ def test_asyncio_future(self): # Test that we can yield an asyncio future from a tornado coroutine. # Without 'yield from', we must wrap coroutines in ensure_future, # which was introduced during Python 3.4, deprecating the prior "async". - if hasattr(asyncio, 'ensure_future'): + if hasattr(asyncio, "ensure_future"): ensure_future = asyncio.ensure_future else: # async is a reserved word in Python 3.7 - ensure_future = getattr(asyncio, 'async') + ensure_future = getattr(asyncio, "async") x = yield ensure_future( - asyncio.get_event_loop().run_in_executor(None, lambda: 42)) + asyncio.get_event_loop().run_in_executor(None, lambda: 42) + ) self.assertEqual(x, 42) - @skipBefore33 @gen_test def test_asyncio_yield_from(self): - # Test that we can use asyncio coroutines with 'yield from' - # instead of asyncio.async(). This requires python 3.3 syntax. - namespace = exec_test(globals(), locals(), """ @gen.coroutine def f(): event_loop = asyncio.get_event_loop() x = yield from event_loop.run_in_executor(None, lambda: 42) return x - """) - result = yield namespace['f']() + + result = yield f() self.assertEqual(result, 42) - @skipBefore35 def test_asyncio_adapter(self): # This test demonstrates that when using the asyncio coroutine # runner (i.e. run_until_complete), the to_asyncio_future @@ -79,50 +70,44 @@ def test_asyncio_adapter(self): def tornado_coroutine(): yield gen.moment raise gen.Return(42) - native_coroutine_without_adapter = exec_test(globals(), locals(), """ + async def native_coroutine_without_adapter(): return await tornado_coroutine() - """)["native_coroutine_without_adapter"] - native_coroutine_with_adapter = exec_test(globals(), locals(), """ async def native_coroutine_with_adapter(): return await to_asyncio_future(tornado_coroutine()) - """)["native_coroutine_with_adapter"] # Use the adapter, but two degrees from the tornado coroutine. 
- native_coroutine_with_adapter2 = exec_test(globals(), locals(), """ async def native_coroutine_with_adapter2(): return await to_asyncio_future(native_coroutine_without_adapter()) - """)["native_coroutine_with_adapter2"] # Tornado supports native coroutines both with and without adapters - self.assertEqual( - self.io_loop.run_sync(native_coroutine_without_adapter), - 42) - self.assertEqual( - self.io_loop.run_sync(native_coroutine_with_adapter), - 42) - self.assertEqual( - self.io_loop.run_sync(native_coroutine_with_adapter2), - 42) + self.assertEqual(self.io_loop.run_sync(native_coroutine_without_adapter), 42) + self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter), 42) + self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter2), 42) # Asyncio only supports coroutines that yield asyncio-compatible # Futures (which our Future is since 5.0). self.assertEqual( asyncio.get_event_loop().run_until_complete( - native_coroutine_without_adapter()), - 42) + native_coroutine_without_adapter() + ), + 42, + ) self.assertEqual( asyncio.get_event_loop().run_until_complete( - native_coroutine_with_adapter()), - 42) + native_coroutine_with_adapter() + ), + 42, + ) self.assertEqual( asyncio.get_event_loop().run_until_complete( - native_coroutine_with_adapter2()), - 42) + native_coroutine_with_adapter2() + ), + 42, + ) -@unittest.skipIf(asyncio is None, "asyncio module not present") class LeakTest(unittest.TestCase): def setUp(self): # Trigger a cleanup of the mapping so we start with a clean slate. @@ -160,7 +145,6 @@ def test_asyncio_close_leak(self): self.assertEqual(new_count, 1) -@unittest.skipIf(asyncio is None, "asyncio module not present") class AnyThreadEventLoopPolicyTest(unittest.TestCase): def setUp(self): self.orig_policy = asyncio.get_event_loop_policy() @@ -182,22 +166,22 @@ def get_and_close_event_loop(): loop = asyncio.get_event_loop() loop.close() return loop + future = self.executor.submit(get_and_close_event_loop) return future.result() def run_policy_test(self, accessor, expected_type): # With the default policy, non-main threads don't get an event # loop. - self.assertRaises((RuntimeError, AssertionError), - self.executor.submit(accessor).result) + self.assertRaises( + (RuntimeError, AssertionError), self.executor.submit(accessor).result + ) # Set the policy and we can get a loop. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) - self.assertIsInstance( - self.executor.submit(accessor).result(), - expected_type) + self.assertIsInstance(self.executor.submit(accessor).result(), expected_type) # Clean up to silence leak warnings. Always use asyncio since # IOLoop doesn't (currently) close the underlying loop. - self.executor.submit(lambda: asyncio.get_event_loop().close()).result() + self.executor.submit(lambda: asyncio.get_event_loop().close()).result() # type: ignore def test_asyncio_accessor(self): self.run_policy_test(asyncio.get_event_loop, asyncio.AbstractEventLoop) diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py index 0d26a35afc..8de863eb21 100644 --- a/tornado/test/auth_test.py +++ b/tornado/test/auth_test.py @@ -3,139 +3,92 @@ # and ensure that it doesn't blow up (e.g. 
with unicode/bytes issues in # python 3) - -from __future__ import absolute_import, division, print_function - -import warnings +import unittest from tornado.auth import ( - AuthError, OpenIdMixin, OAuthMixin, OAuth2Mixin, - GoogleOAuth2Mixin, FacebookGraphMixin, TwitterMixin, + OpenIdMixin, + OAuthMixin, + OAuth2Mixin, + GoogleOAuth2Mixin, + FacebookGraphMixin, + TwitterMixin, ) -from tornado.concurrent import Future from tornado.escape import json_decode from tornado import gen +from tornado.httpclient import HTTPClientError from tornado.httputil import url_concat -from tornado.log import gen_log +from tornado.log import app_log from tornado.testing import AsyncHTTPTestCase, ExpectLog -from tornado.test.util import ignore_deprecation -from tornado.web import RequestHandler, Application, asynchronous, HTTPError - - -class OpenIdClientLoginHandlerLegacy(RequestHandler, OpenIdMixin): - def initialize(self, test): - self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate') +from tornado.web import RequestHandler, Application, HTTPError - @asynchronous - def get(self): - if self.get_argument('openid.mode', None): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.get_authenticated_user( - self.on_user, http_client=self.settings['http_client']) - return - res = self.authenticate_redirect() - assert isinstance(res, Future) - assert res.done() - - def on_user(self, user): - if user is None: - raise Exception("user is None") - self.finish(user) +try: + from unittest import mock +except ImportError: + mock = None # type: ignore class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin): def initialize(self, test): - self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate') + self._OPENID_ENDPOINT = test.get_url("/openid/server/authenticate") @gen.coroutine def get(self): - if self.get_argument('openid.mode', None): - user = yield self.get_authenticated_user(http_client=self.settings['http_client']) + if self.get_argument("openid.mode", None): + user = yield self.get_authenticated_user( + http_client=self.settings["http_client"] + ) if user is None: raise Exception("user is None") self.finish(user) return - res = self.authenticate_redirect() - assert isinstance(res, Future) - assert res.done() + res = self.authenticate_redirect() # type: ignore + assert res is None class OpenIdServerAuthenticateHandler(RequestHandler): def post(self): - if self.get_argument('openid.mode') != 'check_authentication': + if self.get_argument("openid.mode") != "check_authentication": raise Exception("incorrect openid.mode %r") - self.write('is_valid:true') - - -class OAuth1ClientLoginHandlerLegacy(RequestHandler, OAuthMixin): - def initialize(self, test, version): - self._OAUTH_VERSION = version - self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token') - self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize') - self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token') - - def _oauth_consumer_token(self): - return dict(key='asdf', secret='qwer') - - @asynchronous - def get(self): - if self.get_argument('oauth_token', None): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.get_authenticated_user( - self.on_user, http_client=self.settings['http_client']) - return - res = self.authorize_redirect(http_client=self.settings['http_client']) - assert isinstance(res, Future) - - def on_user(self, user): - if user is None: - raise Exception("user is None") - 
self.finish(user) - - def _oauth_get_user(self, access_token, callback): - if self.get_argument('fail_in_get_user', None): - raise Exception("failing in get_user") - if access_token != dict(key='uiop', secret='5678'): - raise Exception("incorrect access token %r" % access_token) - callback(dict(email='foo@example.com')) + self.write("is_valid:true") class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin): def initialize(self, test, version): self._OAUTH_VERSION = version - self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token') - self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize') - self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token') + self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token") + self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/oauth1/server/access_token") def _oauth_consumer_token(self): - return dict(key='asdf', secret='qwer') + return dict(key="asdf", secret="qwer") @gen.coroutine def get(self): - if self.get_argument('oauth_token', None): - user = yield self.get_authenticated_user(http_client=self.settings['http_client']) + if self.get_argument("oauth_token", None): + user = yield self.get_authenticated_user( + http_client=self.settings["http_client"] + ) if user is None: raise Exception("user is None") self.finish(user) return - yield self.authorize_redirect(http_client=self.settings['http_client']) + yield self.authorize_redirect(http_client=self.settings["http_client"]) @gen.coroutine def _oauth_get_user_future(self, access_token): - if self.get_argument('fail_in_get_user', None): + if self.get_argument("fail_in_get_user", None): raise Exception("failing in get_user") - if access_token != dict(key='uiop', secret='5678'): + if access_token != dict(key="uiop", secret="5678"): raise Exception("incorrect access token %r" % access_token) - return dict(email='foo@example.com') + return dict(email="foo@example.com") class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler): """Replaces OAuth1ClientLoginHandler's get() with a coroutine.""" + @gen.coroutine def get(self): - if self.get_argument('oauth_token', None): + if self.get_argument("oauth_token", None): # Ensure that any exceptions are set on the returned Future, # not simply thrown into the surrounding StackContext.
try: @@ -152,41 +105,41 @@ def initialize(self, version): self._OAUTH_VERSION = version def _oauth_consumer_token(self): - return dict(key='asdf', secret='qwer') + return dict(key="asdf", secret="qwer") def get(self): params = self._oauth_request_parameters( - 'http://www.example.com/api/asdf', - dict(key='uiop', secret='5678'), - parameters=dict(foo='bar')) + "http://www.example.com/api/asdf", + dict(key="uiop", secret="5678"), + parameters=dict(foo="bar"), + ) self.write(params) class OAuth1ServerRequestTokenHandler(RequestHandler): def get(self): - self.write('oauth_token=zxcv&oauth_token_secret=1234') + self.write("oauth_token=zxcv&oauth_token_secret=1234") class OAuth1ServerAccessTokenHandler(RequestHandler): def get(self): - self.write('oauth_token=uiop&oauth_token_secret=5678') + self.write("oauth_token=uiop&oauth_token_secret=5678") class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin): def initialize(self, test): - self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize') + self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth2/server/authorize") def get(self): - res = self.authorize_redirect() - assert isinstance(res, Future) - assert res.done() + res = self.authorize_redirect() # type: ignore + assert res is None class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin): def initialize(self, test): - self._OAUTH_AUTHORIZE_URL = test.get_url('/facebook/server/authorize') - self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/facebook/server/access_token') - self._FACEBOOK_BASE_URL = test.get_url('/facebook/server') + self._OAUTH_AUTHORIZE_URL = test.get_url("/facebook/server/authorize") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/facebook/server/access_token") + self._FACEBOOK_BASE_URL = test.get_url("/facebook/server") @gen.coroutine def get(self): @@ -195,13 +148,15 @@ def get(self): redirect_uri=self.request.full_url(), client_id=self.settings["facebook_api_key"], client_secret=self.settings["facebook_secret"], - code=self.get_argument("code")) + code=self.get_argument("code"), + ) self.write(user) else: - yield self.authorize_redirect( + self.authorize_redirect( redirect_uri=self.request.full_url(), client_id=self.settings["facebook_api_key"], - extra_params={"scope": "read_stream,offline_access"}) + extra_params={"scope": "read_stream,offline_access"}, + ) class FacebookServerAccessTokenHandler(RequestHandler): @@ -211,35 +166,36 @@ def get(self): class FacebookServerMeHandler(RequestHandler): def get(self): - self.write('{}') + self.write("{}") class TwitterClientHandler(RequestHandler, TwitterMixin): def initialize(self, test): - self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token') - self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/twitter/server/access_token') - self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize') - self._TWITTER_BASE_URL = test.get_url('/twitter/api') + self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/twitter/server/access_token") + self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize") + self._OAUTH_AUTHENTICATE_URL = test.get_url("/twitter/server/authenticate") + self._TWITTER_BASE_URL = test.get_url("/twitter/api") def get_auth_http_client(self): - return self.settings['http_client'] + return self.settings["http_client"] -class TwitterClientLoginHandlerLegacy(TwitterClientHandler): - @asynchronous +class TwitterClientLoginHandler(TwitterClientHandler): + @gen.coroutine def get(self): if 
self.get_argument("oauth_token", None): - self.get_authenticated_user(self.on_user) + user = yield self.get_authenticated_user() + if user is None: + raise Exception("user is None") + self.finish(user) return - self.authorize_redirect() - - def on_user(self, user): - if user is None: - raise Exception("user is None") - self.finish(user) + yield self.authorize_redirect() -class TwitterClientLoginHandler(TwitterClientHandler): +class TwitterClientAuthenticateHandler(TwitterClientHandler): + # Like TwitterClientLoginHandler, but uses authenticate_redirect + # instead of authorize_redirect. @gen.coroutine def get(self): if self.get_argument("oauth_token", None): @@ -248,21 +204,7 @@ def get(self): raise Exception("user is None") self.finish(user) return - yield self.authorize_redirect() - - -class TwitterClientLoginGenEngineHandler(TwitterClientHandler): - with ignore_deprecation(): - @asynchronous - @gen.engine - def get(self): - if self.get_argument("oauth_token", None): - user = yield self.get_authenticated_user() - self.finish(user) - else: - # Old style: with @gen.engine we can ignore the Future from - # authorize_redirect. - self.authorize_redirect() + yield self.authenticate_redirect() class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler): @@ -277,25 +219,6 @@ def get(self): yield self.authorize_redirect() -class TwitterClientShowUserHandlerLegacy(TwitterClientHandler): - with ignore_deprecation(): - @asynchronous - @gen.engine - def get(self): - # TODO: would be nice to go through the login flow instead of - # cheating with a hard-coded access token. - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - response = yield gen.Task(self.twitter_request, - '/users/show/%s' % self.get_argument('name'), - access_token=dict(key='hjkl', secret='vbnm')) - if response is None: - self.set_status(500) - self.finish('error from twitter request') - else: - self.finish(response) - - class TwitterClientShowUserHandler(TwitterClientHandler): @gen.coroutine def get(self): @@ -303,44 +226,47 @@ def get(self): # cheating with a hard-coded access token. try: response = yield self.twitter_request( - '/users/show/%s' % self.get_argument('name'), - access_token=dict(key='hjkl', secret='vbnm')) - except AuthError: + "/users/show/%s" % self.get_argument("name"), + access_token=dict(key="hjkl", secret="vbnm"), + ) + except HTTPClientError: + # TODO(bdarnell): Should we catch HTTP errors and + # transform some of them (like 403s) into AuthError? 
self.set_status(500) - self.finish('error from twitter request') + self.finish("error from twitter request") else: self.finish(response) class TwitterServerAccessTokenHandler(RequestHandler): def get(self): - self.write('oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo') + self.write("oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo") class TwitterServerShowUserHandler(RequestHandler): def get(self, screen_name): - if screen_name == 'error': + if screen_name == "error": raise HTTPError(500) - assert 'oauth_nonce' in self.request.arguments - assert 'oauth_timestamp' in self.request.arguments - assert 'oauth_signature' in self.request.arguments - assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key' - assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1' - assert self.get_argument('oauth_version') == '1.0' - assert self.get_argument('oauth_token') == 'hjkl' + assert "oauth_nonce" in self.request.arguments + assert "oauth_timestamp" in self.request.arguments + assert "oauth_signature" in self.request.arguments + assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key" + assert self.get_argument("oauth_signature_method") == "HMAC-SHA1" + assert self.get_argument("oauth_version") == "1.0" + assert self.get_argument("oauth_token") == "hjkl" self.write(dict(screen_name=screen_name, name=screen_name.capitalize())) class TwitterServerVerifyCredentialsHandler(RequestHandler): def get(self): - assert 'oauth_nonce' in self.request.arguments - assert 'oauth_timestamp' in self.request.arguments - assert 'oauth_signature' in self.request.arguments - assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key' - assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1' - assert self.get_argument('oauth_version') == '1.0' - assert self.get_argument('oauth_token') == 'hjkl' - self.write(dict(screen_name='foo', name='Foo')) + assert "oauth_nonce" in self.request.arguments + assert "oauth_timestamp" in self.request.arguments + assert "oauth_signature" in self.request.arguments + assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key" + assert self.get_argument("oauth_signature_method") == "HMAC-SHA1" + assert self.get_argument("oauth_version") == "1.0" + assert self.get_argument("oauth_token") == "hjkl" + self.write(dict(screen_name="foo", name="Foo")) class AuthTest(AsyncHTTPTestCase): @@ -348,322 +274,310 @@ def get_app(self): return Application( [ # test endpoints - ('/legacy/openid/client/login', OpenIdClientLoginHandlerLegacy, dict(test=self)), - ('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)), - ('/legacy/oauth10/client/login', OAuth1ClientLoginHandlerLegacy, - dict(test=self, version='1.0')), - ('/oauth10/client/login', OAuth1ClientLoginHandler, - dict(test=self, version='1.0')), - ('/oauth10/client/request_params', - OAuth1ClientRequestParametersHandler, - dict(version='1.0')), - ('/legacy/oauth10a/client/login', OAuth1ClientLoginHandlerLegacy, - dict(test=self, version='1.0a')), - ('/oauth10a/client/login', OAuth1ClientLoginHandler, - dict(test=self, version='1.0a')), - ('/oauth10a/client/login_coroutine', - OAuth1ClientLoginCoroutineHandler, - dict(test=self, version='1.0a')), - ('/oauth10a/client/request_params', - OAuth1ClientRequestParametersHandler, - dict(version='1.0a')), - ('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)), - - ('/facebook/client/login', FacebookClientLoginHandler, dict(test=self)), - - ('/legacy/twitter/client/login', 
TwitterClientLoginHandlerLegacy, dict(test=self)), - ('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)), - ('/twitter/client/login_gen_engine', - TwitterClientLoginGenEngineHandler, dict(test=self)), - ('/twitter/client/login_gen_coroutine', - TwitterClientLoginGenCoroutineHandler, dict(test=self)), - ('/legacy/twitter/client/show_user', - TwitterClientShowUserHandlerLegacy, dict(test=self)), - ('/twitter/client/show_user', - TwitterClientShowUserHandler, dict(test=self)), - + ("/openid/client/login", OpenIdClientLoginHandler, dict(test=self)), + ( + "/oauth10/client/login", + OAuth1ClientLoginHandler, + dict(test=self, version="1.0"), + ), + ( + "/oauth10/client/request_params", + OAuth1ClientRequestParametersHandler, + dict(version="1.0"), + ), + ( + "/oauth10a/client/login", + OAuth1ClientLoginHandler, + dict(test=self, version="1.0a"), + ), + ( + "/oauth10a/client/login_coroutine", + OAuth1ClientLoginCoroutineHandler, + dict(test=self, version="1.0a"), + ), + ( + "/oauth10a/client/request_params", + OAuth1ClientRequestParametersHandler, + dict(version="1.0a"), + ), + ("/oauth2/client/login", OAuth2ClientLoginHandler, dict(test=self)), + ("/facebook/client/login", FacebookClientLoginHandler, dict(test=self)), + ("/twitter/client/login", TwitterClientLoginHandler, dict(test=self)), + ( + "/twitter/client/authenticate", + TwitterClientAuthenticateHandler, + dict(test=self), + ), + ( + "/twitter/client/login_gen_coroutine", + TwitterClientLoginGenCoroutineHandler, + dict(test=self), + ), + ( + "/twitter/client/show_user", + TwitterClientShowUserHandler, + dict(test=self), + ), # simulated servers - ('/openid/server/authenticate', OpenIdServerAuthenticateHandler), - ('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler), - ('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler), - - ('/facebook/server/access_token', FacebookServerAccessTokenHandler), - ('/facebook/server/me', FacebookServerMeHandler), - ('/twitter/server/access_token', TwitterServerAccessTokenHandler), - (r'/twitter/api/users/show/(.*)\.json', TwitterServerShowUserHandler), - (r'/twitter/api/account/verify_credentials\.json', - TwitterServerVerifyCredentialsHandler), + ("/openid/server/authenticate", OpenIdServerAuthenticateHandler), + ("/oauth1/server/request_token", OAuth1ServerRequestTokenHandler), + ("/oauth1/server/access_token", OAuth1ServerAccessTokenHandler), + ("/facebook/server/access_token", FacebookServerAccessTokenHandler), + ("/facebook/server/me", FacebookServerMeHandler), + ("/twitter/server/access_token", TwitterServerAccessTokenHandler), + (r"/twitter/api/users/show/(.*)\.json", TwitterServerShowUserHandler), + ( + r"/twitter/api/account/verify_credentials\.json", + TwitterServerVerifyCredentialsHandler, + ), ], http_client=self.http_client, - twitter_consumer_key='test_twitter_consumer_key', - twitter_consumer_secret='test_twitter_consumer_secret', - facebook_api_key='test_facebook_api_key', - facebook_secret='test_facebook_secret') - - def test_openid_redirect_legacy(self): - response = self.fetch('/legacy/openid/client/login', follow_redirects=False) - self.assertEqual(response.code, 302) - self.assertTrue( - '/openid/server/authenticate?' 
in response.headers['Location']) - - def test_openid_get_user_legacy(self): - response = self.fetch('/legacy/openid/client/login?openid.mode=blah' - '&openid.ns.ax=http://openid.net/srv/ax/1.0' - '&openid.ax.type.email=http://axschema.org/contact/email' - '&openid.ax.value.email=foo@example.com') - response.rethrow() - parsed = json_decode(response.body) - self.assertEqual(parsed["email"], "foo@example.com") + twitter_consumer_key="test_twitter_consumer_key", + twitter_consumer_secret="test_twitter_consumer_secret", + facebook_api_key="test_facebook_api_key", + facebook_secret="test_facebook_secret", + ) def test_openid_redirect(self): - response = self.fetch('/openid/client/login', follow_redirects=False) + response = self.fetch("/openid/client/login", follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue( - '/openid/server/authenticate?' in response.headers['Location']) + self.assertTrue("/openid/server/authenticate?" in response.headers["Location"]) def test_openid_get_user(self): - response = self.fetch('/openid/client/login?openid.mode=blah' - '&openid.ns.ax=http://openid.net/srv/ax/1.0' - '&openid.ax.type.email=http://axschema.org/contact/email' - '&openid.ax.value.email=foo@example.com') + response = self.fetch( + "/openid/client/login?openid.mode=blah" + "&openid.ns.ax=http://openid.net/srv/ax/1.0" + "&openid.ax.type.email=http://axschema.org/contact/email" + "&openid.ax.value.email=foo@example.com" + ) response.rethrow() parsed = json_decode(response.body) self.assertEqual(parsed["email"], "foo@example.com") - def test_oauth10_redirect_legacy(self): - response = self.fetch('/legacy/oauth10/client/login', follow_redirects=False) - self.assertEqual(response.code, 302) - self.assertTrue(response.headers['Location'].endswith( - '/oauth1/server/authorize?oauth_token=zxcv')) - # the cookie is base64('zxcv')|base64('1234') - self.assertTrue( - '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], - response.headers['Set-Cookie']) - def test_oauth10_redirect(self): - response = self.fetch('/oauth10/client/login', follow_redirects=False) + response = self.fetch("/oauth10/client/login", follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue(response.headers['Location'].endswith( - '/oauth1/server/authorize?oauth_token=zxcv')) + self.assertTrue( + response.headers["Location"].endswith( + "/oauth1/server/authorize?oauth_token=zxcv" + ) + ) # the cookie is base64('zxcv')|base64('1234') self.assertTrue( - '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], - response.headers['Set-Cookie']) - - def test_oauth10_get_user_legacy(self): - with ignore_deprecation(): - response = self.fetch( - '/legacy/oauth10/client/login?oauth_token=zxcv', - headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) - response.rethrow() - parsed = json_decode(response.body) - self.assertEqual(parsed['email'], 'foo@example.com') - self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678')) + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) def test_oauth10_get_user(self): response = self.fetch( - '/oauth10/client/login?oauth_token=zxcv', - headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) + "/oauth10/client/login?oauth_token=zxcv", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) response.rethrow() parsed = json_decode(response.body) - self.assertEqual(parsed['email'], 'foo@example.com') - 
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678')) + self.assertEqual(parsed["email"], "foo@example.com") + self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678")) def test_oauth10_request_parameters(self): - response = self.fetch('/oauth10/client/request_params') - response.rethrow() - parsed = json_decode(response.body) - self.assertEqual(parsed['oauth_consumer_key'], 'asdf') - self.assertEqual(parsed['oauth_token'], 'uiop') - self.assertTrue('oauth_nonce' in parsed) - self.assertTrue('oauth_signature' in parsed) - - def test_oauth10a_redirect_legacy(self): - response = self.fetch('/legacy/oauth10a/client/login', follow_redirects=False) - self.assertEqual(response.code, 302) - self.assertTrue(response.headers['Location'].endswith( - '/oauth1/server/authorize?oauth_token=zxcv')) - # the cookie is base64('zxcv')|base64('1234') - self.assertTrue( - '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], - response.headers['Set-Cookie']) - - def test_oauth10a_get_user_legacy(self): - with ignore_deprecation(): - response = self.fetch( - '/legacy/oauth10a/client/login?oauth_token=zxcv', - headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) + response = self.fetch("/oauth10/client/request_params") response.rethrow() parsed = json_decode(response.body) - self.assertEqual(parsed['email'], 'foo@example.com') - self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678')) + self.assertEqual(parsed["oauth_consumer_key"], "asdf") + self.assertEqual(parsed["oauth_token"], "uiop") + self.assertTrue("oauth_nonce" in parsed) + self.assertTrue("oauth_signature" in parsed) def test_oauth10a_redirect(self): - response = self.fetch('/oauth10a/client/login', follow_redirects=False) + response = self.fetch("/oauth10a/client/login", follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue(response.headers['Location'].endswith( - '/oauth1/server/authorize?oauth_token=zxcv')) + self.assertTrue( + response.headers["Location"].endswith( + "/oauth1/server/authorize?oauth_token=zxcv" + ) + ) # the cookie is base64('zxcv')|base64('1234') self.assertTrue( - '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], - response.headers['Set-Cookie']) + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) + + @unittest.skipIf(mock is None, "mock package not present") + def test_oauth10a_redirect_error(self): + with mock.patch.object(OAuth1ServerRequestTokenHandler, "get") as get: + get.side_effect = Exception("boom") + with ExpectLog(app_log, "Uncaught exception"): + response = self.fetch("/oauth10a/client/login", follow_redirects=False) + self.assertEqual(response.code, 500) def test_oauth10a_get_user(self): response = self.fetch( - '/oauth10a/client/login?oauth_token=zxcv', - headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) + "/oauth10a/client/login?oauth_token=zxcv", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) response.rethrow() parsed = json_decode(response.body) - self.assertEqual(parsed['email'], 'foo@example.com') - self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678')) + self.assertEqual(parsed["email"], "foo@example.com") + self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678")) def test_oauth10a_request_parameters(self): - response = self.fetch('/oauth10a/client/request_params') + response = self.fetch("/oauth10a/client/request_params") 
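The `enhjdg==|MTIzNA==` literal asserted throughout these tests is nothing magic — just the request token and its secret, base64-encoded and pipe-joined, as this standalone check (not part of the patch) shows:

```py
import base64

token = base64.b64encode(b"zxcv")   # b'enhjdg=='
secret = base64.b64encode(b"1234")  # b'MTIzNA=='
assert b"|".join([token, secret]) == b"enhjdg==|MTIzNA=="
```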
response.rethrow() parsed = json_decode(response.body) - self.assertEqual(parsed['oauth_consumer_key'], 'asdf') - self.assertEqual(parsed['oauth_token'], 'uiop') - self.assertTrue('oauth_nonce' in parsed) - self.assertTrue('oauth_signature' in parsed) + self.assertEqual(parsed["oauth_consumer_key"], "asdf") + self.assertEqual(parsed["oauth_token"], "uiop") + self.assertTrue("oauth_nonce" in parsed) + self.assertTrue("oauth_signature" in parsed) def test_oauth10a_get_user_coroutine_exception(self): response = self.fetch( - '/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true', - headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) + "/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) self.assertEqual(response.code, 503) def test_oauth2_redirect(self): - response = self.fetch('/oauth2/client/login', follow_redirects=False) + response = self.fetch("/oauth2/client/login", follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue('/oauth2/server/authorize?' in response.headers['Location']) + self.assertTrue("/oauth2/server/authorize?" in response.headers["Location"]) def test_facebook_login(self): - response = self.fetch('/facebook/client/login', follow_redirects=False) + response = self.fetch("/facebook/client/login", follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue('/facebook/server/authorize?' in response.headers['Location']) - response = self.fetch('/facebook/client/login?code=1234', follow_redirects=False) + self.assertTrue("/facebook/server/authorize?" in response.headers["Location"]) + response = self.fetch( + "/facebook/client/login?code=1234", follow_redirects=False + ) self.assertEqual(response.code, 200) user = json_decode(response.body) - self.assertEqual(user['access_token'], 'asdf') - self.assertEqual(user['session_expires'], '3600') + self.assertEqual(user["access_token"], "asdf") + self.assertEqual(user["session_expires"], "3600") def base_twitter_redirect(self, url): # Same as test_oauth10a_redirect response = self.fetch(url, follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue(response.headers['Location'].endswith( - '/oauth1/server/authorize?oauth_token=zxcv')) + self.assertTrue( + response.headers["Location"].endswith( + "/oauth1/server/authorize?oauth_token=zxcv" + ) + ) # the cookie is base64('zxcv')|base64('1234') self.assertTrue( - '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], - response.headers['Set-Cookie']) - - def test_twitter_redirect_legacy(self): - self.base_twitter_redirect('/legacy/twitter/client/login') + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) def test_twitter_redirect(self): - self.base_twitter_redirect('/twitter/client/login') - - def test_twitter_redirect_gen_engine(self): - self.base_twitter_redirect('/twitter/client/login_gen_engine') + self.base_twitter_redirect("/twitter/client/login") def test_twitter_redirect_gen_coroutine(self): - self.base_twitter_redirect('/twitter/client/login_gen_coroutine') + self.base_twitter_redirect("/twitter/client/login_gen_coroutine") + + def test_twitter_authenticate_redirect(self): + response = self.fetch("/twitter/client/authenticate", follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue( + response.headers["Location"].endswith( + "/twitter/server/authenticate?oauth_token=zxcv" + ), + 
response.headers["Location"], + ) + # the cookie is base64('zxcv')|base64('1234') + self.assertTrue( + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) def test_twitter_get_user(self): response = self.fetch( - '/twitter/client/login?oauth_token=zxcv', - headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) + "/twitter/client/login?oauth_token=zxcv", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) response.rethrow() parsed = json_decode(response.body) - self.assertEqual(parsed, - {u'access_token': {u'key': u'hjkl', - u'screen_name': u'foo', - u'secret': u'vbnm'}, - u'name': u'Foo', - u'screen_name': u'foo', - u'username': u'foo'}) - - def test_twitter_show_user_legacy(self): - response = self.fetch('/legacy/twitter/client/show_user?name=somebody') - response.rethrow() - self.assertEqual(json_decode(response.body), - {'name': 'Somebody', 'screen_name': 'somebody'}) - - def test_twitter_show_user_error_legacy(self): - with ExpectLog(gen_log, 'Error response HTTP 500'): - response = self.fetch('/legacy/twitter/client/show_user?name=error') - self.assertEqual(response.code, 500) - self.assertEqual(response.body, b'error from twitter request') + self.assertEqual( + parsed, + { + u"access_token": { + u"key": u"hjkl", + u"screen_name": u"foo", + u"secret": u"vbnm", + }, + u"name": u"Foo", + u"screen_name": u"foo", + u"username": u"foo", + }, + ) def test_twitter_show_user(self): - response = self.fetch('/twitter/client/show_user?name=somebody') + response = self.fetch("/twitter/client/show_user?name=somebody") response.rethrow() - self.assertEqual(json_decode(response.body), - {'name': 'Somebody', 'screen_name': 'somebody'}) + self.assertEqual( + json_decode(response.body), {"name": "Somebody", "screen_name": "somebody"} + ) def test_twitter_show_user_error(self): - response = self.fetch('/twitter/client/show_user?name=error') + response = self.fetch("/twitter/client/show_user?name=error") self.assertEqual(response.code, 500) - self.assertEqual(response.body, b'error from twitter request') + self.assertEqual(response.body, b"error from twitter request") class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin): def initialize(self, test): self.test = test - self._OAUTH_REDIRECT_URI = test.get_url('/client/login') - self._OAUTH_AUTHORIZE_URL = test.get_url('/google/oauth2/authorize') - self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/google/oauth2/token') + self._OAUTH_REDIRECT_URI = test.get_url("/client/login") + self._OAUTH_AUTHORIZE_URL = test.get_url("/google/oauth2/authorize") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/google/oauth2/token") @gen.coroutine def get(self): - code = self.get_argument('code', None) + code = self.get_argument("code", None) if code is not None: # retrieve authenticate google user - access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI, - code) + access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI, code) user = yield self.oauth2_request( self.test.get_url("/google/oauth2/userinfo"), - access_token=access["access_token"]) + access_token=access["access_token"], + ) # return the user and access token as json user["access_token"] = access["access_token"] self.write(user) else: - yield self.authorize_redirect( + self.authorize_redirect( redirect_uri=self._OAUTH_REDIRECT_URI, - client_id=self.settings['google_oauth']['key'], - client_secret=self.settings['google_oauth']['secret'], - scope=['profile', 'email'], - response_type='code', - 
extra_params={'prompt': 'select_account'}) + client_id=self.settings["google_oauth"]["key"], + client_secret=self.settings["google_oauth"]["secret"], + scope=["profile", "email"], + response_type="code", + extra_params={"prompt": "select_account"}, + ) class GoogleOAuth2AuthorizeHandler(RequestHandler): def get(self): # issue a fake auth code and redirect to redirect_uri - code = 'fake-authorization-code' - self.redirect(url_concat(self.get_argument('redirect_uri'), - dict(code=code))) + code = "fake-authorization-code" + self.redirect(url_concat(self.get_argument("redirect_uri"), dict(code=code))) class GoogleOAuth2TokenHandler(RequestHandler): def post(self): - assert self.get_argument('code') == 'fake-authorization-code' + assert self.get_argument("code") == "fake-authorization-code" # issue a fake token - self.finish({ - 'access_token': 'fake-access-token', - 'expires_in': 'never-expires' - }) + self.finish( + {"access_token": "fake-access-token", "expires_in": "never-expires"} + ) class GoogleOAuth2UserinfoHandler(RequestHandler): def get(self): - assert self.get_argument('access_token') == 'fake-access-token' + assert self.get_argument("access_token") == "fake-access-token" # return a fake user - self.finish({ - 'name': 'Foo', - 'email': 'foo@example.com' - }) + self.finish({"name": "Foo", "email": "foo@example.com"}) class GoogleOAuth2Test(AsyncHTTPTestCase): @@ -671,22 +585,25 @@ def get_app(self): return Application( [ # test endpoints - ('/client/login', GoogleLoginHandler, dict(test=self)), - + ("/client/login", GoogleLoginHandler, dict(test=self)), # simulated google authorization server endpoints - ('/google/oauth2/authorize', GoogleOAuth2AuthorizeHandler), - ('/google/oauth2/token', GoogleOAuth2TokenHandler), - ('/google/oauth2/userinfo', GoogleOAuth2UserinfoHandler), + ("/google/oauth2/authorize", GoogleOAuth2AuthorizeHandler), + ("/google/oauth2/token", GoogleOAuth2TokenHandler), + ("/google/oauth2/userinfo", GoogleOAuth2UserinfoHandler), ], google_oauth={ - "key": 'fake_google_client_id', - "secret": 'fake_google_client_secret' - }) + "key": "fake_google_client_id", + "secret": "fake_google_client_secret", + }, + ) def test_google_login(self): - response = self.fetch('/client/login') - self.assertDictEqual({ - u'name': u'Foo', - u'email': u'foo@example.com', - u'access_token': u'fake-access-token', - }, json_decode(response.body)) + response = self.fetch("/client/login") + self.assertDictEqual( + { + u"name": u"Foo", + u"email": u"foo@example.com", + u"access_token": u"fake-access-token", + }, + json_decode(response.body), + ) diff --git a/tornado/test/autoreload_test.py b/tornado/test/autoreload_test.py index 6a9729dbbe..be481e106f 100644 --- a/tornado/test/autoreload_test.py +++ b/tornado/test/autoreload_test.py @@ -1,14 +1,31 @@ -from __future__ import absolute_import, division, print_function import os +import shutil import subprocess from subprocess import Popen import sys from tempfile import mkdtemp +import time +import unittest -from tornado.test.util import unittest +class AutoreloadTest(unittest.TestCase): + def setUp(self): + self.path = mkdtemp() + + def tearDown(self): + try: + shutil.rmtree(self.path) + except OSError: + # Windows disallows deleting files that are in use by + # another process, and even though we've waited for our + # child process below, it appears that its lock on these + # files is not guaranteed to be released by this point. + # Sleep and try again (once). 
+ time.sleep(1) + shutil.rmtree(self.path) -MAIN = """\ + def test_reload_module(self): + main = """\ import os import sys @@ -24,25 +41,87 @@ autoreload._reload() """ - -class AutoreloadTest(unittest.TestCase): - def test_reload_module(self): # Create temporary test application - path = mkdtemp() - os.mkdir(os.path.join(path, 'testapp')) - open(os.path.join(path, 'testapp/__init__.py'), 'w').close() - with open(os.path.join(path, 'testapp/__main__.py'), 'w') as f: - f.write(MAIN) + os.mkdir(os.path.join(self.path, "testapp")) + open(os.path.join(self.path, "testapp/__init__.py"), "w").close() + with open(os.path.join(self.path, "testapp/__main__.py"), "w") as f: + f.write(main) # Make sure the tornado module under test is available to the test # application pythonpath = os.getcwd() - if 'PYTHONPATH' in os.environ: - pythonpath += os.pathsep + os.environ['PYTHONPATH'] + if "PYTHONPATH" in os.environ: + pythonpath += os.pathsep + os.environ["PYTHONPATH"] p = Popen( - [sys.executable, '-m', 'testapp'], stdout=subprocess.PIPE, - cwd=path, env=dict(os.environ, PYTHONPATH=pythonpath), - universal_newlines=True) + [sys.executable, "-m", "testapp"], + stdout=subprocess.PIPE, + cwd=self.path, + env=dict(os.environ, PYTHONPATH=pythonpath), + universal_newlines=True, + ) out = p.communicate()[0] - self.assertEqual(out, 'Starting\nStarting\n') + self.assertEqual(out, "Starting\nStarting\n") + + def test_reload_wrapper_preservation(self): + # This test verifies that when `python -m tornado.autoreload` + # is used on an application that also has an internal + # autoreload, the reload wrapper is preserved on restart. + main = """\ +import os +import sys + +# This import will fail if path is not set up correctly +import testapp + +if 'tornado.autoreload' not in sys.modules: + raise Exception('started without autoreload wrapper') + +import tornado.autoreload + +print('Starting') +sys.stdout.flush() +if 'TESTAPP_STARTED' not in os.environ: + os.environ['TESTAPP_STARTED'] = '1' + # Simulate an internal autoreload (one not caused + # by the wrapper). + tornado.autoreload._reload() +else: + # Exit directly so autoreload doesn't catch it. + os._exit(0) +""" + + # Create temporary test application + os.mkdir(os.path.join(self.path, "testapp")) + init_file = os.path.join(self.path, "testapp", "__init__.py") + open(init_file, "w").close() + main_file = os.path.join(self.path, "testapp", "__main__.py") + with open(main_file, "w") as f: + f.write(main) + + # Make sure the tornado module under test is available to the test + # application + pythonpath = os.getcwd() + if "PYTHONPATH" in os.environ: + pythonpath += os.pathsep + os.environ["PYTHONPATH"] + + autoreload_proc = Popen( + [sys.executable, "-m", "tornado.autoreload", "-m", "testapp"], + stdout=subprocess.PIPE, + cwd=self.path, + env=dict(os.environ, PYTHONPATH=pythonpath), + universal_newlines=True, + ) + + # This timeout needs to be fairly generous for pypy due to jit + # warmup costs. + for i in range(40): + if autoreload_proc.poll() is not None: + break + time.sleep(0.1) + else: + autoreload_proc.kill() + raise Exception("subprocess failed to terminate") + + out = autoreload_proc.communicate()[0] + self.assertEqual(out, "Starting\n" * 2) diff --git a/tornado/test/concurrent_test.py b/tornado/test/concurrent_test.py index 1df0532cc5..b121c6971a 100644 --- a/tornado/test/concurrent_test.py +++ b/tornado/test/concurrent_test.py @@ -12,39 +12,28 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function - -import gc +from concurrent import futures import logging import re import socket -import sys -import traceback -import warnings - -from tornado.concurrent import (Future, return_future, ReturnValueIgnoredError, - run_on_executor, future_set_result_unless_cancelled) +import typing +import unittest + +from tornado.concurrent import ( + Future, + run_on_executor, + future_set_result_unless_cancelled, +) from tornado.escape import utf8, to_unicode from tornado import gen -from tornado.ioloop import IOLoop from tornado.iostream import IOStream -from tornado.log import app_log -from tornado import stack_context from tornado.tcpserver import TCPServer -from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test -from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation - - -try: - from concurrent import futures -except ImportError: - futures = None +from tornado.testing import AsyncTestCase, bind_unused_port, gen_test class MiscFutureTest(AsyncTestCase): - def test_future_set_result_unless_cancelled(self): - fut = Future() + fut = Future() # type: Future[int] future_set_result_unless_cancelled(fut, 42) self.assertEqual(fut.result(), 42) self.assertFalse(fut.cancelled()) @@ -58,193 +47,6 @@ def test_future_set_result_unless_cancelled(self): self.assertEqual(fut.result(), 42) -class ReturnFutureTest(AsyncTestCase): - @return_future - def sync_future(self, callback): - callback(42) - - @return_future - def async_future(self, callback): - self.io_loop.add_callback(callback, 42) - - @return_future - def immediate_failure(self, callback): - 1 / 0 - - @return_future - def delayed_failure(self, callback): - self.io_loop.add_callback(lambda: 1 / 0) - - @return_future - def return_value(self, callback): - # Note that the result of both running the callback and returning - # a value (or raising an exception) is unspecified; with current - # implementations the last event prior to callback resolution wins. - return 42 - - @return_future - def no_result_future(self, callback): - callback() - - def test_immediate_failure(self): - with self.assertRaises(ZeroDivisionError): - # The caller sees the error just like a normal function. - self.immediate_failure(callback=self.stop) - # The callback is not run because the function failed synchronously. - self.io_loop.add_timeout(self.io_loop.time() + 0.05, self.stop) - result = self.wait() - self.assertIs(result, None) - - def test_return_value(self): - with self.assertRaises(ReturnValueIgnoredError): - self.return_value(callback=self.stop) - - def test_callback_kw(self): - with ignore_deprecation(): - future = self.sync_future(callback=self.stop) - result = self.wait() - self.assertEqual(result, 42) - self.assertEqual(future.result(), 42) - - def test_callback_positional(self): - # When the callback is passed in positionally, future_wrap shouldn't - # add another callback in the kwargs. 
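The `@return_future` tests removed here need no replacement because a plain coroutine covers the same ground; roughly (a sketch, not part of the patch):

```py
from tornado import gen
from tornado.ioloop import IOLoop


@gen.coroutine
def sync_future():
    # What the deleted callback style expressed is now just a coroutine
    # returning a value; callers receive a Future implicitly.
    raise gen.Return(42)


assert IOLoop.current().run_sync(sync_future) == 42
```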
- with ignore_deprecation(): - future = self.sync_future(self.stop) - result = self.wait() - self.assertEqual(result, 42) - self.assertEqual(future.result(), 42) - - def test_no_callback(self): - future = self.sync_future() - self.assertEqual(future.result(), 42) - - def test_none_callback_kw(self): - # explicitly pass None as callback - future = self.sync_future(callback=None) - self.assertEqual(future.result(), 42) - - def test_none_callback_pos(self): - future = self.sync_future(None) - self.assertEqual(future.result(), 42) - - def test_async_future(self): - future = self.async_future() - self.assertFalse(future.done()) - self.io_loop.add_future(future, self.stop) - future2 = self.wait() - self.assertIs(future, future2) - self.assertEqual(future.result(), 42) - - @gen_test - def test_async_future_gen(self): - result = yield self.async_future() - self.assertEqual(result, 42) - - def test_delayed_failure(self): - future = self.delayed_failure() - self.io_loop.add_future(future, self.stop) - future2 = self.wait() - self.assertIs(future, future2) - with self.assertRaises(ZeroDivisionError): - future.result() - - def test_kw_only_callback(self): - @return_future - def f(**kwargs): - kwargs['callback'](42) - future = f() - self.assertEqual(future.result(), 42) - - def test_error_in_callback(self): - with ignore_deprecation(): - self.sync_future(callback=lambda future: 1 / 0) - # The exception gets caught by our StackContext and will be re-raised - # when we wait. - self.assertRaises(ZeroDivisionError, self.wait) - - def test_no_result_future(self): - with ignore_deprecation(): - future = self.no_result_future(self.stop) - result = self.wait() - self.assertIs(result, None) - # result of this future is undefined, but not an error - future.result() - - def test_no_result_future_callback(self): - with ignore_deprecation(): - future = self.no_result_future(callback=lambda: self.stop()) - result = self.wait() - self.assertIs(result, None) - future.result() - - @gen_test - def test_future_traceback_legacy(self): - with ignore_deprecation(): - @return_future - @gen.engine - def f(callback): - yield gen.Task(self.io_loop.add_callback) - try: - 1 / 0 - except ZeroDivisionError: - self.expected_frame = traceback.extract_tb( - sys.exc_info()[2], limit=1)[0] - raise - try: - yield f() - self.fail("didn't get expected exception") - except ZeroDivisionError: - tb = traceback.extract_tb(sys.exc_info()[2]) - self.assertIn(self.expected_frame, tb) - - @gen_test - def test_future_traceback(self): - @gen.coroutine - def f(): - yield gen.moment - try: - 1 / 0 - except ZeroDivisionError: - self.expected_frame = traceback.extract_tb( - sys.exc_info()[2], limit=1)[0] - raise - try: - yield f() - self.fail("didn't get expected exception") - except ZeroDivisionError: - tb = traceback.extract_tb(sys.exc_info()[2]) - self.assertIn(self.expected_frame, tb) - - @gen_test - def test_uncaught_exception_log(self): - if IOLoop.configured_class().__name__.endswith('AsyncIOLoop'): - # Install an exception handler that mirrors our - # non-asyncio logging behavior. 
- def exc_handler(loop, context): - app_log.error('%s: %s', context['message'], - type(context.get('exception'))) - self.io_loop.asyncio_loop.set_exception_handler(exc_handler) - - @gen.coroutine - def f(): - yield gen.moment - 1 / 0 - - g = f() - - with ExpectLog(app_log, - "(?s)Future.* exception was never retrieved:" - ".*ZeroDivisionError"): - yield gen.moment - yield gen.moment - # For some reason, TwistedIOLoop and pypy3 need a third iteration - # in order to drain references to the future - yield gen.moment - del g - gc.collect() # for PyPy - - # The following series of classes demonstrate and test various styles # of use, with and without generators and futures. @@ -271,79 +73,36 @@ def __init__(self, port): self.port = port def process_response(self, data): - status, message = re.match('(.*)\t(.*)\n', to_unicode(data)).groups() - if status == 'ok': + m = re.match("(.*)\t(.*)\n", to_unicode(data)) + if m is None: + raise Exception("did not match") + status, message = m.groups() + if status == "ok": return message else: raise CapError(message) -class ManualCapClient(BaseCapClient): - def capitalize(self, request_data, callback=None): - logging.debug("capitalize") - self.request_data = request_data - self.stream = IOStream(socket.socket()) - self.stream.connect(('127.0.0.1', self.port), - callback=self.handle_connect) - self.future = Future() - if callback is not None: - self.future.add_done_callback( - stack_context.wrap(lambda future: callback(future.result()))) - return self.future - - def handle_connect(self): - logging.debug("handle_connect") - self.stream.write(utf8(self.request_data + "\n")) - self.stream.read_until(b'\n', callback=self.handle_read) - - def handle_read(self, data): - logging.debug("handle_read") - self.stream.close() - try: - self.future.set_result(self.process_response(data)) - except CapError as e: - self.future.set_exception(e) - - -class DecoratorCapClient(BaseCapClient): - @return_future - def capitalize(self, request_data, callback): - logging.debug("capitalize") - self.request_data = request_data - self.stream = IOStream(socket.socket()) - self.stream.connect(('127.0.0.1', self.port), - callback=self.handle_connect) - self.callback = callback - - def handle_connect(self): - logging.debug("handle_connect") - self.stream.write(utf8(self.request_data + "\n")) - self.stream.read_until(b'\n', callback=self.handle_read) - - def handle_read(self, data): - logging.debug("handle_read") - self.stream.close() - self.callback(self.process_response(data)) - - class GeneratorCapClient(BaseCapClient): @gen.coroutine def capitalize(self, request_data): - logging.debug('capitalize') + logging.debug("capitalize") stream = IOStream(socket.socket()) - logging.debug('connecting') - yield stream.connect(('127.0.0.1', self.port)) - stream.write(utf8(request_data + '\n')) - logging.debug('reading') - data = yield stream.read_until(b'\n') - logging.debug('returning') + logging.debug("connecting") + yield stream.connect(("127.0.0.1", self.port)) + stream.write(utf8(request_data + "\n")) + logging.debug("reading") + data = yield stream.read_until(b"\n") + logging.debug("returning") stream.close() raise gen.Return(self.process_response(data)) class ClientTestMixin(object): + client_class = None # type: typing.Callable + def setUp(self): - super(ClientTestMixin, self).setUp() # type: ignore + super().setUp() # type: ignore self.server = CapServer() sock, port = bind_unused_port() self.server.add_sockets([sock]) @@ -351,79 +110,41 @@ def setUp(self): def tearDown(self): 
self.server.stop() - super(ClientTestMixin, self).tearDown() # type: ignore + super().tearDown() # type: ignore - def test_callback(self): - with ignore_deprecation(): - self.client.capitalize("hello", callback=self.stop) - result = self.wait() - self.assertEqual(result, "HELLO") - - def test_callback_error(self): - with ignore_deprecation(): - self.client.capitalize("HELLO", callback=self.stop) - self.assertRaisesRegexp(CapError, "already capitalized", self.wait) - - def test_future(self): + def test_future(self: typing.Any): future = self.client.capitalize("hello") self.io_loop.add_future(future, self.stop) self.wait() self.assertEqual(future.result(), "HELLO") - def test_future_error(self): + def test_future_error(self: typing.Any): future = self.client.capitalize("HELLO") self.io_loop.add_future(future, self.stop) self.wait() - self.assertRaisesRegexp(CapError, "already capitalized", future.result) + self.assertRaisesRegexp(CapError, "already capitalized", future.result) # type: ignore - def test_generator(self): + def test_generator(self: typing.Any): @gen.coroutine def f(): result = yield self.client.capitalize("hello") self.assertEqual(result, "HELLO") + self.io_loop.run_sync(f) - def test_generator_error(self): + def test_generator_error(self: typing.Any): @gen.coroutine def f(): with self.assertRaisesRegexp(CapError, "already capitalized"): yield self.client.capitalize("HELLO") - self.io_loop.run_sync(f) - - -class ManualClientTest(ClientTestMixin, AsyncTestCase): - client_class = ManualCapClient - - def setUp(self): - self.warning_catcher = warnings.catch_warnings() - self.warning_catcher.__enter__() - warnings.simplefilter('ignore', DeprecationWarning) - super(ManualClientTest, self).setUp() - - def tearDown(self): - super(ManualClientTest, self).tearDown() - self.warning_catcher.__exit__(None, None, None) - -class DecoratorClientTest(ClientTestMixin, AsyncTestCase): - client_class = DecoratorCapClient - - def setUp(self): - self.warning_catcher = warnings.catch_warnings() - self.warning_catcher.__enter__() - warnings.simplefilter('ignore', DeprecationWarning) - super(DecoratorClientTest, self).setUp() - - def tearDown(self): - super(DecoratorClientTest, self).tearDown() - self.warning_catcher.__exit__(None, None, None) + self.io_loop.run_sync(f) class GeneratorClientTest(ClientTestMixin, AsyncTestCase): client_class = GeneratorCapClient -@unittest.skipIf(futures is None, "concurrent.futures module not present") class RunOnExecutorTest(AsyncTestCase): @gen_test def test_no_calling(self): @@ -459,7 +180,7 @@ class Object(object): def __init__(self): self.__executor = futures.thread.ThreadPoolExecutor(1) - @run_on_executor(executor='_Object__executor') + @run_on_executor(executor="_Object__executor") def f(self): return 42 @@ -467,7 +188,6 @@ def f(self): answer = yield o.f() self.assertEqual(answer, 42) - @skipBefore35 @gen_test def test_async_await(self): class Object(object): @@ -479,14 +199,14 @@ def f(self): return 42 o = Object() - namespace = exec_test(globals(), locals(), """ + async def f(): answer = await o.f() return answer - """) - result = yield namespace['f']() + + result = yield f() self.assertEqual(result, 42) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tornado/test/curl_httpclient_test.py b/tornado/test/curl_httpclient_test.py index d0cfa979b4..99af293367 100644 --- a/tornado/test/curl_httpclient_test.py +++ b/tornado/test/curl_httpclient_test.py @@ -1,22 +1,16 @@ -# coding: utf-8 -from __future__ import absolute_import, 
division, print_function - from hashlib import md5 +import unittest from tornado.escape import utf8 -from tornado.httpclient import HTTPRequest, HTTPClientError -from tornado.locks import Event -from tornado.stack_context import ExceptionStackContext -from tornado.testing import AsyncHTTPTestCase, gen_test +from tornado.testing import AsyncHTTPTestCase from tornado.test import httpclient_test -from tornado.test.util import unittest, ignore_deprecation from tornado.web import Application, RequestHandler try: - import pycurl # type: ignore + import pycurl except ImportError: - pycurl = None + pycurl = None # type: ignore if pycurl is not None: from tornado.curl_httpclient import CurlAsyncHTTPClient @@ -32,42 +26,48 @@ def get_http_client(self): class DigestAuthHandler(RequestHandler): + def initialize(self, username, password): + self.username = username + self.password = password + def get(self): - realm = 'test' - opaque = 'asdf' + realm = "test" + opaque = "asdf" # Real implementations would use a random nonce. nonce = "1234" - username = 'foo' - password = 'bar' - auth_header = self.request.headers.get('Authorization', None) + auth_header = self.request.headers.get("Authorization", None) if auth_header is not None: - auth_mode, params = auth_header.split(' ', 1) - assert auth_mode == 'Digest' + auth_mode, params = auth_header.split(" ", 1) + assert auth_mode == "Digest" param_dict = {} - for pair in params.split(','): - k, v = pair.strip().split('=', 1) + for pair in params.split(","): + k, v = pair.strip().split("=", 1) if v[0] == '"' and v[-1] == '"': v = v[1:-1] param_dict[k] = v - assert param_dict['realm'] == realm - assert param_dict['opaque'] == opaque - assert param_dict['nonce'] == nonce - assert param_dict['username'] == username - assert param_dict['uri'] == self.request.path - h1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest() - h2 = md5(utf8('%s:%s' % (self.request.method, - self.request.path))).hexdigest() - digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest() - if digest == param_dict['response']: - self.write('ok') + assert param_dict["realm"] == realm + assert param_dict["opaque"] == opaque + assert param_dict["nonce"] == nonce + assert param_dict["username"] == self.username + assert param_dict["uri"] == self.request.path + h1 = md5( + utf8("%s:%s:%s" % (self.username, realm, self.password)) + ).hexdigest() + h2 = md5( + utf8("%s:%s" % (self.request.method, self.request.path)) + ).hexdigest() + digest = md5(utf8("%s:%s:%s" % (h1, nonce, h2))).hexdigest() + if digest == param_dict["response"]: + self.write("ok") else: - self.write('fail') + self.write("fail") else: self.set_status(401) - self.set_header('WWW-Authenticate', - 'Digest realm="%s", nonce="%s", opaque="%s"' % - (realm, nonce, opaque)) + self.set_header( + "WWW-Authenticate", + 'Digest realm="%s", nonce="%s", opaque="%s"' % (realm, nonce, opaque), + ) class CustomReasonHandler(RequestHandler): @@ -83,62 +83,47 @@ def get(self): @unittest.skipIf(pycurl is None, "pycurl module not present") class CurlHTTPClientTestCase(AsyncHTTPTestCase): def setUp(self): - super(CurlHTTPClientTestCase, self).setUp() + super().setUp() self.http_client = self.create_client() def get_app(self): - return Application([ - ('/digest', DigestAuthHandler), - ('/custom_reason', CustomReasonHandler), - ('/custom_fail_reason', CustomFailReasonHandler), - ]) + return Application( + [ + ("/digest", DigestAuthHandler, {"username": "foo", "password": "bar"}), + ( + "/digest_non_ascii", + DigestAuthHandler, + 
{"username": "foo", "password": "barユ£"}, + ), + ("/custom_reason", CustomReasonHandler), + ("/custom_fail_reason", CustomFailReasonHandler), + ] + ) def create_client(self, **kwargs): - return CurlAsyncHTTPClient(force_instance=True, - defaults=dict(allow_ipv6=False), - **kwargs) - - @gen_test - def test_prepare_curl_callback_stack_context(self): - exc_info = [] - error_event = Event() - - def error_handler(typ, value, tb): - exc_info.append((typ, value, tb)) - error_event.set() - return True - - with ExceptionStackContext(error_handler): - request = HTTPRequest(self.get_url('/custom_reason'), - prepare_curl_callback=lambda curl: 1 / 0) - yield [error_event.wait(), self.http_client.fetch(request)] - self.assertEqual(1, len(exc_info)) - self.assertIs(exc_info[0][0], ZeroDivisionError) + return CurlAsyncHTTPClient( + force_instance=True, defaults=dict(allow_ipv6=False), **kwargs + ) def test_digest_auth(self): - response = self.fetch('/digest', auth_mode='digest', - auth_username='foo', auth_password='bar') - self.assertEqual(response.body, b'ok') + response = self.fetch( + "/digest", auth_mode="digest", auth_username="foo", auth_password="bar" + ) + self.assertEqual(response.body, b"ok") def test_custom_reason(self): - response = self.fetch('/custom_reason') + response = self.fetch("/custom_reason") self.assertEqual(response.reason, "Custom reason") def test_fail_custom_reason(self): - response = self.fetch('/custom_fail_reason') + response = self.fetch("/custom_fail_reason") self.assertEqual(str(response.error), "HTTP 400: Custom reason") - def test_failed_setup(self): - self.http_client = self.create_client(max_clients=1) - for i in range(5): - with ignore_deprecation(): - response = self.fetch(u'/ユニコード') - self.assertIsNot(response.error, None) - - with self.assertRaises((UnicodeEncodeError, HTTPClientError)): - # This raises UnicodeDecodeError on py3 and - # HTTPClientError(404) on py2. The main motivation of - # this test is to ensure that the UnicodeEncodeError - # during the setup phase doesn't lead the request to - # be dropped on the floor. 
- response = self.fetch(u'/ユニコード', raise_error=True) + def test_digest_auth_non_ascii(self): + response = self.fetch( + "/digest_non_ascii", + auth_mode="digest", + auth_username="foo", + auth_password="barユ£", + ) + self.assertEqual(response.body, b"ok") diff --git a/tornado/test/escape_test.py b/tornado/test/escape_test.py index f2f2902a0a..d8f95e426e 100644 --- a/tornado/test/escape_test.py +++ b/tornado/test/escape_test.py @@ -1,139 +1,212 @@ -from __future__ import absolute_import, division, print_function +import unittest import tornado.escape from tornado.escape import ( - utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, - to_unicode, json_decode, json_encode, squeeze, recursive_unicode, + utf8, + xhtml_escape, + xhtml_unescape, + url_escape, + url_unescape, + to_unicode, + json_decode, + json_encode, + squeeze, + recursive_unicode, ) from tornado.util import unicode_type -from tornado.test.util import unittest + +from typing import List, Tuple, Union, Dict, Any # noqa: F401 linkify_tests = [ # (input, linkify_kwargs, expected_output) - - ("hello http://world.com/!", {}, - u'hello http://world.com/!'), - - ("hello http://world.com/with?param=true&stuff=yes", {}, - u'hello http://world.com/with?param=true&stuff=yes'), # noqa: E501 - + ( + "hello http://world.com/!", + {}, + u'hello http://world.com/!', + ), + ( + "hello http://world.com/with?param=true&stuff=yes", + {}, + u'hello http://world.com/with?param=true&stuff=yes', # noqa: E501 + ), # an opened paren followed by many chars killed Gruber's regex - ("http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", {}, - u'http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'), # noqa: E501 - + ( + "http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + {}, + u'http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', # noqa: E501 + ), # as did too many dots at the end - ("http://url.com/withmany.......................................", {}, - u'http://url.com/withmany.......................................'), # noqa: E501 - - ("http://url.com/withmany((((((((((((((((((((((((((((((((((a)", {}, - u'http://url.com/withmany((((((((((((((((((((((((((((((((((a)'), # noqa: E501 - + ( + "http://url.com/withmany.......................................", + {}, + u'http://url.com/withmany.......................................', # noqa: E501 + ), + ( + "http://url.com/withmany((((((((((((((((((((((((((((((((((a)", + {}, + u'http://url.com/withmany((((((((((((((((((((((((((((((((((a)', # noqa: E501 + ), # some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls # plus a fex extras (such as multiple parentheses). 
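As an editor's aside (a minimal sketch, not part of the patch), the two pathological inputs called out in the comments above can be checked directly against the public tornado.escape.linkify API:

```py
from tornado.escape import linkify

# An opened paren followed by many characters used to hang Gruber's
# liberal URL regex; linkify terminates and still emits a link.
print(linkify("http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"))

# Many trailing dots likewise used to trigger pathological backtracking;
# they are handled without blowing up.
print(linkify("http://url.com/withmany......................................."))
```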
- ("http://foo.com/blah_blah", {}, - u'http://foo.com/blah_blah'), - - ("http://foo.com/blah_blah/", {}, - u'http://foo.com/blah_blah/'), - - ("(Something like http://foo.com/blah_blah)", {}, - u'(Something like http://foo.com/blah_blah)'), - - ("http://foo.com/blah_blah_(wikipedia)", {}, - u'http://foo.com/blah_blah_(wikipedia)'), - - ("http://foo.com/blah_(blah)_(wikipedia)_blah", {}, - u'http://foo.com/blah_(blah)_(wikipedia)_blah'), # noqa: E501 - - ("(Something like http://foo.com/blah_blah_(wikipedia))", {}, - u'(Something like http://foo.com/blah_blah_(wikipedia))'), # noqa: E501 - - ("http://foo.com/blah_blah.", {}, - u'http://foo.com/blah_blah.'), - - ("http://foo.com/blah_blah/.", {}, - u'http://foo.com/blah_blah/.'), - - ("", {}, - u'<http://foo.com/blah_blah>'), - - ("", {}, - u'<http://foo.com/blah_blah/>'), - - ("http://foo.com/blah_blah,", {}, - u'http://foo.com/blah_blah,'), - - ("http://www.example.com/wpstyle/?p=364.", {}, - u'http://www.example.com/wpstyle/?p=364.'), - - ("rdar://1234", - {"permitted_protocols": ["http", "rdar"]}, - u'rdar://1234'), - - ("rdar:/1234", - {"permitted_protocols": ["rdar"]}, - u'rdar:/1234'), - - ("http://userid:password@example.com:8080", {}, - u'http://userid:password@example.com:8080'), # noqa: E501 - - ("http://userid@example.com", {}, - u'http://userid@example.com'), - - ("http://userid@example.com:8080", {}, - u'http://userid@example.com:8080'), - - ("http://userid:password@example.com", {}, - u'http://userid:password@example.com'), - - ("message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e", - {"permitted_protocols": ["http", "message"]}, - u'' - u'message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e'), - - (u"http://\u27a1.ws/\u4a39", {}, - u'http://\u27a1.ws/\u4a39'), - - ("http://example.com", {}, - u'<tag>http://example.com</tag>'), - - ("Just a www.example.com link.", {}, - u'Just a www.example.com link.'), - - ("Just a www.example.com link.", - {"require_protocol": True}, - u'Just a www.example.com link.'), - - ("A http://reallylong.com/link/that/exceedsthelenglimit.html", - {"require_protocol": True, "shorten": True}, - u'A http://reallylong.com/link...'), # noqa: E501 - - ("A http://reallylongdomainnamethatwillbetoolong.com/hi!", - {"shorten": True}, - u'A http://reallylongdomainnametha...!'), # noqa: E501 - - ("A file:///passwords.txt and http://web.com link", {}, - u'A file:///passwords.txt and http://web.com link'), - - ("A file:///passwords.txt and http://web.com link", - {"permitted_protocols": ["file"]}, - u'A file:///passwords.txt and http://web.com link'), - - ("www.external-link.com", - {"extra_params": 'rel="nofollow" class="external"'}, - u'www.external-link.com'), # noqa: E501 - - ("www.external-link.com and www.internal-link.com/blogs extra", - {"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'}, # noqa: E501 - u'www.external-link.com' # noqa: E501 - u' and www.internal-link.com/blogs extra'), # noqa: E501 - - ("www.external-link.com", - {"extra_params": lambda href: ' rel="nofollow" class="external" '}, - u'www.external-link.com'), # noqa: E501 -] + ( + "http://foo.com/blah_blah", + {}, + u'http://foo.com/blah_blah', + ), + ( + "http://foo.com/blah_blah/", + {}, + u'http://foo.com/blah_blah/', + ), + ( + "(Something like http://foo.com/blah_blah)", + {}, + u'(Something like http://foo.com/blah_blah)', + ), + ( + "http://foo.com/blah_blah_(wikipedia)", + {}, + 
u'http://foo.com/blah_blah_(wikipedia)', + ), + ( + "http://foo.com/blah_(blah)_(wikipedia)_blah", + {}, + u'http://foo.com/blah_(blah)_(wikipedia)_blah', # noqa: E501 + ), + ( + "(Something like http://foo.com/blah_blah_(wikipedia))", + {}, + u'(Something like http://foo.com/blah_blah_(wikipedia))', # noqa: E501 + ), + ( + "http://foo.com/blah_blah.", + {}, + u'http://foo.com/blah_blah.', + ), + ( + "http://foo.com/blah_blah/.", + {}, + u'http://foo.com/blah_blah/.', + ), + ( + "", + {}, + u'<http://foo.com/blah_blah>', + ), + ( + "", + {}, + u'<http://foo.com/blah_blah/>', + ), + ( + "http://foo.com/blah_blah,", + {}, + u'http://foo.com/blah_blah,', + ), + ( + "http://www.example.com/wpstyle/?p=364.", + {}, + u'http://www.example.com/wpstyle/?p=364.', # noqa: E501 + ), + ( + "rdar://1234", + {"permitted_protocols": ["http", "rdar"]}, + u'rdar://1234', + ), + ( + "rdar:/1234", + {"permitted_protocols": ["rdar"]}, + u'rdar:/1234', + ), + ( + "http://userid:password@example.com:8080", + {}, + u'http://userid:password@example.com:8080', # noqa: E501 + ), + ( + "http://userid@example.com", + {}, + u'http://userid@example.com', + ), + ( + "http://userid@example.com:8080", + {}, + u'http://userid@example.com:8080', + ), + ( + "http://userid:password@example.com", + {}, + u'http://userid:password@example.com', + ), + ( + "message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e", + {"permitted_protocols": ["http", "message"]}, + u'' + u"message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e", + ), + ( + u"http://\u27a1.ws/\u4a39", + {}, + u'http://\u27a1.ws/\u4a39', + ), + ( + "http://example.com", + {}, + u'<tag>http://example.com</tag>', + ), + ( + "Just a www.example.com link.", + {}, + u'Just a www.example.com link.', + ), + ( + "Just a www.example.com link.", + {"require_protocol": True}, + u"Just a www.example.com link.", + ), + ( + "A http://reallylong.com/link/that/exceedsthelenglimit.html", + {"require_protocol": True, "shorten": True}, + u'A http://reallylong.com/link...', # noqa: E501 + ), + ( + "A http://reallylongdomainnamethatwillbetoolong.com/hi!", + {"shorten": True}, + u'A http://reallylongdomainnametha...!', # noqa: E501 + ), + ( + "A file:///passwords.txt and http://web.com link", + {}, + u'A file:///passwords.txt and http://web.com link', + ), + ( + "A file:///passwords.txt and http://web.com link", + {"permitted_protocols": ["file"]}, + u'A file:///passwords.txt and http://web.com link', + ), + ( + "www.external-link.com", + {"extra_params": 'rel="nofollow" class="external"'}, + u'www.external-link.com', # noqa: E501 + ), + ( + "www.external-link.com and www.internal-link.com/blogs extra", + { + "extra_params": lambda href: 'class="internal"' + if href.startswith("http://www.internal-link.com") + else 'rel="nofollow" class="external"' + }, + u'www.external-link.com' # noqa: E501 + u' and www.internal-link.com/blogs extra', # noqa: E501 + ), + ( + "www.external-link.com", + {"extra_params": lambda href: ' rel="nofollow" class="external" '}, + u'www.external-link.com', # noqa: E501 + ), +] # type: List[Tuple[Union[str, bytes], Dict[str, Any], str]] class EscapeTestCase(unittest.TestCase): @@ -147,26 +220,24 @@ def test_xhtml_escape(self): ("", "<foo>"), (u"", u"<foo>"), (b"", b"<foo>"), - ("<>&\"'", "<>&"'"), ("&", "&amp;"), - (u"<\u00e9>", u"<\u00e9>"), (b"<\xc3\xa9>", b"<\xc3\xa9>"), - ] + ] # type: List[Tuple[Union[str, bytes], Union[str, bytes]]] for unescaped, escaped in tests: self.assertEqual(utf8(xhtml_escape(unescaped)), utf8(escaped)) 
self.assertEqual(utf8(unescaped), utf8(xhtml_unescape(escaped))) def test_xhtml_unescape_numeric(self): tests = [ - ('foo bar', 'foo bar'), - ('foo bar', 'foo bar'), - ('foo bar', 'foo bar'), - ('foo઼bar', u'foo\u0abcbar'), - ('foo&#xyz;bar', 'foo&#xyz;bar'), # invalid encoding - ('foo&#;bar', 'foo&#;bar'), # invalid encoding - ('foo&#x;bar', 'foo&#x;bar'), # invalid encoding + ("foo bar", "foo bar"), + ("foo bar", "foo bar"), + ("foo bar", "foo bar"), + ("foo઼bar", u"foo\u0abcbar"), + ("foo&#xyz;bar", "foo&#xyz;bar"), # invalid encoding + ("foo&#;bar", "foo&#;bar"), # invalid encoding + ("foo&#x;bar", "foo&#x;bar"), # invalid encoding ] for escaped, unescaped in tests: self.assertEqual(unescaped, xhtml_unescape(escaped)) @@ -174,20 +245,19 @@ def test_xhtml_unescape_numeric(self): def test_url_escape_unicode(self): tests = [ # byte strings are passed through as-is - (u'\u00e9'.encode('utf8'), '%C3%A9'), - (u'\u00e9'.encode('latin1'), '%E9'), - + (u"\u00e9".encode("utf8"), "%C3%A9"), + (u"\u00e9".encode("latin1"), "%E9"), # unicode strings become utf8 - (u'\u00e9', '%C3%A9'), - ] + (u"\u00e9", "%C3%A9"), + ] # type: List[Tuple[Union[str, bytes], str]] for unescaped, escaped in tests: self.assertEqual(url_escape(unescaped), escaped) def test_url_unescape_unicode(self): tests = [ - ('%C3%A9', u'\u00e9', 'utf8'), - ('%C3%A9', u'\u00c3\u00a9', 'latin1'), - ('%C3%A9', utf8(u'\u00e9'), None), + ("%C3%A9", u"\u00e9", "utf8"), + ("%C3%A9", u"\u00c3\u00a9", "latin1"), + ("%C3%A9", utf8(u"\u00e9"), None), ] for escaped, unescaped, encoding in tests: # input strings to url_unescape should only contain ascii @@ -197,17 +267,17 @@ def test_url_unescape_unicode(self): self.assertEqual(url_unescape(utf8(escaped), encoding), unescaped) def test_url_escape_quote_plus(self): - unescaped = '+ #%' - plus_escaped = '%2B+%23%25' - escaped = '%2B%20%23%25' + unescaped = "+ #%" + plus_escaped = "%2B+%23%25" + escaped = "%2B%20%23%25" self.assertEqual(url_escape(unescaped), plus_escaped) self.assertEqual(url_escape(unescaped, plus=False), escaped) self.assertEqual(url_unescape(plus_escaped), unescaped) self.assertEqual(url_unescape(escaped, plus=False), unescaped) - self.assertEqual(url_unescape(plus_escaped, encoding=None), - utf8(unescaped)) - self.assertEqual(url_unescape(escaped, encoding=None, plus=False), - utf8(unescaped)) + self.assertEqual(url_unescape(plus_escaped, encoding=None), utf8(unescaped)) + self.assertEqual( + url_unescape(escaped, encoding=None, plus=False), utf8(unescaped) + ) def test_escape_return_types(self): # On python2 the escape methods should generally return the same @@ -234,17 +304,19 @@ def test_json_encode(self): self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9") def test_squeeze(self): - self.assertEqual(squeeze(u'sequences of whitespace chars'), - u'sequences of whitespace chars') + self.assertEqual( + squeeze(u"sequences of whitespace chars"), + u"sequences of whitespace chars", + ) def test_recursive_unicode(self): tests = { - 'dict': {b"foo": b"bar"}, - 'list': [b"foo", b"bar"], - 'tuple': (b"foo", b"bar"), - 'bytes': b"foo" + "dict": {b"foo": b"bar"}, + "list": [b"foo", b"bar"], + "tuple": (b"foo", b"bar"), + "bytes": b"foo", } - self.assertEqual(recursive_unicode(tests['dict']), {u"foo": u"bar"}) - self.assertEqual(recursive_unicode(tests['list']), [u"foo", u"bar"]) - self.assertEqual(recursive_unicode(tests['tuple']), (u"foo", u"bar")) - self.assertEqual(recursive_unicode(tests['bytes']), u"foo") + self.assertEqual(recursive_unicode(tests["dict"]), {u"foo": 
u"bar"}) + self.assertEqual(recursive_unicode(tests["list"]), [u"foo", u"bar"]) + self.assertEqual(recursive_unicode(tests["tuple"]), (u"foo", u"bar")) + self.assertEqual(recursive_unicode(tests["bytes"]), u"foo") diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py index 8b2f62b972..73c3387803 100644 --- a/tornado/test/gen_test.py +++ b/tornado/test/gen_test.py @@ -1,675 +1,32 @@ -from __future__ import absolute_import, division, print_function - +import asyncio +from concurrent import futures import gc -import contextlib import datetime -import functools import platform import sys -import textwrap import time import weakref -import warnings +import unittest -from tornado.concurrent import return_future, Future -from tornado.escape import url_escape -from tornado.httpclient import AsyncHTTPClient -from tornado.ioloop import IOLoop +from tornado.concurrent import Future from tornado.log import app_log -from tornado import stack_context from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test -from tornado.test.util import unittest, skipOnTravis, skipBefore33, skipBefore35, skipNotCPython, exec_test, ignore_deprecation # noqa: E501 -from tornado.web import Application, RequestHandler, asynchronous, HTTPError +from tornado.test.util import skipOnTravis, skipNotCPython +from tornado.web import Application, RequestHandler, HTTPError from tornado import gen try: - from concurrent import futures -except ImportError: - futures = None - -try: - import asyncio + import contextvars except ImportError: - asyncio = None - - -class GenEngineTest(AsyncTestCase): - def setUp(self): - self.warning_catcher = warnings.catch_warnings() - self.warning_catcher.__enter__() - warnings.simplefilter('ignore', DeprecationWarning) - super(GenEngineTest, self).setUp() - self.named_contexts = [] - - def tearDown(self): - super(GenEngineTest, self).tearDown() - self.warning_catcher.__exit__(None, None, None) - - def named_context(self, name): - @contextlib.contextmanager - def context(): - self.named_contexts.append(name) - try: - yield - finally: - self.assertEqual(self.named_contexts.pop(), name) - return context - - def run_gen(self, f): - f() - return self.wait() - - def delay_callback(self, iterations, callback, arg): - """Runs callback(arg) after a number of IOLoop iterations.""" - if iterations == 0: - callback(arg) - else: - self.io_loop.add_callback(functools.partial( - self.delay_callback, iterations - 1, callback, arg)) - - @return_future - def async_future(self, result, callback): - self.io_loop.add_callback(callback, result) - - @gen.coroutine - def async_exception(self, e): - yield gen.moment - raise e - - def test_no_yield(self): - @gen.engine - def f(): - self.stop() - self.run_gen(f) - - def test_inline_cb(self): - @gen.engine - def f(): - (yield gen.Callback("k1"))() - res = yield gen.Wait("k1") - self.assertTrue(res is None) - self.stop() - self.run_gen(f) - - def test_ioloop_cb(self): - @gen.engine - def f(): - self.io_loop.add_callback((yield gen.Callback("k1"))) - yield gen.Wait("k1") - self.stop() - self.run_gen(f) - - def test_exception_phase1(self): - @gen.engine - def f(): - 1 / 0 - self.assertRaises(ZeroDivisionError, self.run_gen, f) - - def test_exception_phase2(self): - @gen.engine - def f(): - self.io_loop.add_callback((yield gen.Callback("k1"))) - yield gen.Wait("k1") - 1 / 0 - self.assertRaises(ZeroDivisionError, self.run_gen, f) - - def test_exception_in_task_phase1(self): - def fail_task(callback): - 1 / 0 - - @gen.engine - def f(): - try: - yield 
gen.Task(fail_task) - raise Exception("did not get expected exception") - except ZeroDivisionError: - self.stop() - self.run_gen(f) - - def test_exception_in_task_phase2(self): - # This is the case that requires the use of stack_context in gen.engine - def fail_task(callback): - self.io_loop.add_callback(lambda: 1 / 0) - - @gen.engine - def f(): - try: - yield gen.Task(fail_task) - raise Exception("did not get expected exception") - except ZeroDivisionError: - self.stop() - self.run_gen(f) - - def test_with_arg(self): - @gen.engine - def f(): - (yield gen.Callback("k1"))(42) - res = yield gen.Wait("k1") - self.assertEqual(42, res) - self.stop() - self.run_gen(f) - - def test_with_arg_tuple(self): - @gen.engine - def f(): - (yield gen.Callback((1, 2)))((3, 4)) - res = yield gen.Wait((1, 2)) - self.assertEqual((3, 4), res) - self.stop() - self.run_gen(f) - - def test_key_reuse(self): - @gen.engine - def f(): - yield gen.Callback("k1") - yield gen.Callback("k1") - self.stop() - self.assertRaises(gen.KeyReuseError, self.run_gen, f) - - def test_key_reuse_tuple(self): - @gen.engine - def f(): - yield gen.Callback((1, 2)) - yield gen.Callback((1, 2)) - self.stop() - self.assertRaises(gen.KeyReuseError, self.run_gen, f) - - def test_key_mismatch(self): - @gen.engine - def f(): - yield gen.Callback("k1") - yield gen.Wait("k2") - self.stop() - self.assertRaises(gen.UnknownKeyError, self.run_gen, f) - - def test_key_mismatch_tuple(self): - @gen.engine - def f(): - yield gen.Callback((1, 2)) - yield gen.Wait((2, 3)) - self.stop() - self.assertRaises(gen.UnknownKeyError, self.run_gen, f) - - def test_leaked_callback(self): - @gen.engine - def f(): - yield gen.Callback("k1") - self.stop() - self.assertRaises(gen.LeakedCallbackError, self.run_gen, f) - - def test_leaked_callback_tuple(self): - @gen.engine - def f(): - yield gen.Callback((1, 2)) - self.stop() - self.assertRaises(gen.LeakedCallbackError, self.run_gen, f) - - def test_parallel_callback(self): - @gen.engine - def f(): - for k in range(3): - self.io_loop.add_callback((yield gen.Callback(k))) - yield gen.Wait(1) - self.io_loop.add_callback((yield gen.Callback(3))) - yield gen.Wait(0) - yield gen.Wait(3) - yield gen.Wait(2) - self.stop() - self.run_gen(f) - - def test_bogus_yield(self): - @gen.engine - def f(): - yield 42 - self.assertRaises(gen.BadYieldError, self.run_gen, f) - - def test_bogus_yield_tuple(self): - @gen.engine - def f(): - yield (1, 2) - self.assertRaises(gen.BadYieldError, self.run_gen, f) - - def test_reuse(self): - @gen.engine - def f(): - self.io_loop.add_callback((yield gen.Callback(0))) - yield gen.Wait(0) - self.stop() - self.run_gen(f) - self.run_gen(f) - - def test_task(self): - @gen.engine - def f(): - yield gen.Task(self.io_loop.add_callback) - self.stop() - self.run_gen(f) - - def test_wait_all(self): - @gen.engine - def f(): - (yield gen.Callback("k1"))("v1") - (yield gen.Callback("k2"))("v2") - results = yield gen.WaitAll(["k1", "k2"]) - self.assertEqual(results, ["v1", "v2"]) - self.stop() - self.run_gen(f) - - def test_exception_in_yield(self): - @gen.engine - def f(): - try: - yield gen.Wait("k1") - raise Exception("did not get expected exception") - except gen.UnknownKeyError: - pass - self.stop() - self.run_gen(f) - - def test_resume_after_exception_in_yield(self): - @gen.engine - def f(): - try: - yield gen.Wait("k1") - raise Exception("did not get expected exception") - except gen.UnknownKeyError: - pass - (yield gen.Callback("k2"))("v2") - self.assertEqual((yield gen.Wait("k2")), "v2") - self.stop() - 
self.run_gen(f) - - def test_orphaned_callback(self): - @gen.engine - def f(): - self.orphaned_callback = yield gen.Callback(1) - try: - self.run_gen(f) - raise Exception("did not get expected exception") - except gen.LeakedCallbackError: - pass - self.orphaned_callback() - - def test_none(self): - @gen.engine - def f(): - yield None - self.stop() - self.run_gen(f) - - def test_multi(self): - @gen.engine - def f(): - (yield gen.Callback("k1"))("v1") - (yield gen.Callback("k2"))("v2") - results = yield [gen.Wait("k1"), gen.Wait("k2")] - self.assertEqual(results, ["v1", "v2"]) - self.stop() - self.run_gen(f) - - def test_multi_dict(self): - @gen.engine - def f(): - (yield gen.Callback("k1"))("v1") - (yield gen.Callback("k2"))("v2") - results = yield dict(foo=gen.Wait("k1"), bar=gen.Wait("k2")) - self.assertEqual(results, dict(foo="v1", bar="v2")) - self.stop() - self.run_gen(f) - - # The following tests explicitly run with both gen.Multi - # and gen.multi_future (Task returns a Future, so it can be used - # with either). - def test_multi_yieldpoint_delayed(self): - @gen.engine - def f(): - # callbacks run at different times - responses = yield gen.Multi([ - gen.Task(self.delay_callback, 3, arg="v1"), - gen.Task(self.delay_callback, 1, arg="v2"), - ]) - self.assertEqual(responses, ["v1", "v2"]) - self.stop() - self.run_gen(f) - - def test_multi_yieldpoint_dict_delayed(self): - @gen.engine - def f(): - # callbacks run at different times - responses = yield gen.Multi(dict( - foo=gen.Task(self.delay_callback, 3, arg="v1"), - bar=gen.Task(self.delay_callback, 1, arg="v2"), - )) - self.assertEqual(responses, dict(foo="v1", bar="v2")) - self.stop() - self.run_gen(f) - - def test_multi_future_delayed(self): - @gen.engine - def f(): - # callbacks run at different times - responses = yield gen.multi_future([ - gen.Task(self.delay_callback, 3, arg="v1"), - gen.Task(self.delay_callback, 1, arg="v2"), - ]) - self.assertEqual(responses, ["v1", "v2"]) - self.stop() - self.run_gen(f) - - def test_multi_future_dict_delayed(self): - @gen.engine - def f(): - # callbacks run at different times - responses = yield gen.multi_future(dict( - foo=gen.Task(self.delay_callback, 3, arg="v1"), - bar=gen.Task(self.delay_callback, 1, arg="v2"), - )) - self.assertEqual(responses, dict(foo="v1", bar="v2")) - self.stop() - self.run_gen(f) - - @skipOnTravis - @gen_test - def test_multi_performance(self): - # Yielding a list used to have quadratic performance; make - # sure a large list stays reasonable. On my laptop a list of - # 2000 used to take 1.8s, now it takes 0.12. - start = time.time() - yield [gen.Task(self.io_loop.add_callback) for i in range(2000)] - end = time.time() - self.assertLess(end - start, 1.0) - - @gen_test - def test_multi_empty(self): - # Empty lists or dicts should return the same type. 
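A minimal sketch (editor's illustration, assuming current Tornado): the empty-multi behavior covered by the deleted gen.engine test still holds for gen.coroutine, where yielding an empty list or dict resolves immediately to the same type:

```py
from tornado import gen
from tornado.ioloop import IOLoop

@gen.coroutine
def f():
    x = yield []  # resolves immediately to a list
    y = yield {}  # resolves immediately to a dict
    raise gen.Return((x, y))

assert IOLoop.current().run_sync(f) == ([], {})
```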
- x = yield [] - self.assertTrue(isinstance(x, list)) - y = yield {} - self.assertTrue(isinstance(y, dict)) - - @gen_test - def test_multi_mixed_types(self): - # A YieldPoint (Wait) and Future (Task) can be combined - # (and use the YieldPoint codepath) - (yield gen.Callback("k1"))("v1") - responses = yield [gen.Wait("k1"), - gen.Task(self.delay_callback, 3, arg="v2")] - self.assertEqual(responses, ["v1", "v2"]) - - @gen_test - def test_future(self): - result = yield self.async_future(1) - self.assertEqual(result, 1) - - @gen_test - def test_multi_future(self): - results = yield [self.async_future(1), self.async_future(2)] - self.assertEqual(results, [1, 2]) - - @gen_test - def test_multi_future_duplicate(self): - f = self.async_future(2) - results = yield [self.async_future(1), f, self.async_future(3), f] - self.assertEqual(results, [1, 2, 3, 2]) - - @gen_test - def test_multi_dict_future(self): - results = yield dict(foo=self.async_future(1), bar=self.async_future(2)) - self.assertEqual(results, dict(foo=1, bar=2)) - - @gen_test - def test_multi_exceptions(self): - with ExpectLog(app_log, "Multiple exceptions in yield list"): - with self.assertRaises(RuntimeError) as cm: - yield gen.Multi([self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))]) - self.assertEqual(str(cm.exception), "error 1") - - # With only one exception, no error is logged. - with self.assertRaises(RuntimeError): - yield gen.Multi([self.async_exception(RuntimeError("error 1")), - self.async_future(2)]) - - # Exception logging may be explicitly quieted. - with self.assertRaises(RuntimeError): - yield gen.Multi([self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))], - quiet_exceptions=RuntimeError) - - @gen_test - def test_multi_future_exceptions(self): - with ExpectLog(app_log, "Multiple exceptions in yield list"): - with self.assertRaises(RuntimeError) as cm: - yield [self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))] - self.assertEqual(str(cm.exception), "error 1") - - # With only one exception, no error is logged. - with self.assertRaises(RuntimeError): - yield [self.async_exception(RuntimeError("error 1")), - self.async_future(2)] - - # Exception logging may be explicitly quieted. 
- with self.assertRaises(RuntimeError): - yield gen.multi_future( - [self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))], - quiet_exceptions=RuntimeError) - - def test_arguments(self): - @gen.engine - def f(): - (yield gen.Callback("noargs"))() - self.assertEqual((yield gen.Wait("noargs")), None) - (yield gen.Callback("1arg"))(42) - self.assertEqual((yield gen.Wait("1arg")), 42) - - (yield gen.Callback("kwargs"))(value=42) - result = yield gen.Wait("kwargs") - self.assertTrue(isinstance(result, gen.Arguments)) - self.assertEqual(((), dict(value=42)), result) - self.assertEqual(dict(value=42), result.kwargs) - - (yield gen.Callback("2args"))(42, 43) - result = yield gen.Wait("2args") - self.assertTrue(isinstance(result, gen.Arguments)) - self.assertEqual(((42, 43), {}), result) - self.assertEqual((42, 43), result.args) - - def task_func(callback): - callback(None, error="foo") - result = yield gen.Task(task_func) - self.assertTrue(isinstance(result, gen.Arguments)) - self.assertEqual(((None,), dict(error="foo")), result) - - self.stop() - self.run_gen(f) - - def test_stack_context_leak(self): - # regression test: repeated invocations of a gen-based - # function should not result in accumulated stack_contexts - def _stack_depth(): - head = stack_context._state.contexts[1] - length = 0 - - while head is not None: - length += 1 - head = head.old_contexts[1] - - return length - - @gen.engine - def inner(callback): - yield gen.Task(self.io_loop.add_callback) - callback() - - @gen.engine - def outer(): - for i in range(10): - yield gen.Task(inner) - - stack_increase = _stack_depth() - initial_stack_depth - self.assertTrue(stack_increase <= 2) - self.stop() - initial_stack_depth = _stack_depth() - self.run_gen(outer) - - def test_stack_context_leak_exception(self): - # same as previous, but with a function that exits with an exception - @gen.engine - def inner(callback): - yield gen.Task(self.io_loop.add_callback) - 1 / 0 - - @gen.engine - def outer(): - for i in range(10): - try: - yield gen.Task(inner) - except ZeroDivisionError: - pass - stack_increase = len(stack_context._state.contexts) - initial_stack_depth - self.assertTrue(stack_increase <= 2) - self.stop() - initial_stack_depth = len(stack_context._state.contexts) - self.run_gen(outer) - - def function_with_stack_context(self, callback): - # Technically this function should stack_context.wrap its callback - # upon entry. However, it is very common for this step to be - # omitted. - def step2(): - self.assertEqual(self.named_contexts, ['a']) - self.io_loop.add_callback(callback) - - with stack_context.StackContext(self.named_context('a')): - self.io_loop.add_callback(step2) - - @gen_test - def test_wait_transfer_stack_context(self): - # Wait should not pick up contexts from where callback was invoked, - # even if that function improperly fails to wrap its callback. - cb = yield gen.Callback('k1') - self.function_with_stack_context(cb) - self.assertEqual(self.named_contexts, []) - yield gen.Wait('k1') - self.assertEqual(self.named_contexts, []) - - @gen_test - def test_task_transfer_stack_context(self): - yield gen.Task(self.function_with_stack_context) - self.assertEqual(self.named_contexts, []) - - def test_raise_after_stop(self): - # This pattern will be used in the following tests so make sure - # the exception propagates as expected. 
- @gen.engine - def f(): - self.stop() - 1 / 0 - - with self.assertRaises(ZeroDivisionError): - self.run_gen(f) - - def test_sync_raise_return(self): - # gen.Return is allowed in @gen.engine, but it may not be used - # to return a value. - @gen.engine - def f(): - self.stop(42) - raise gen.Return() - - result = self.run_gen(f) - self.assertEqual(result, 42) - - def test_async_raise_return(self): - @gen.engine - def f(): - yield gen.Task(self.io_loop.add_callback) - self.stop(42) - raise gen.Return() - - result = self.run_gen(f) - self.assertEqual(result, 42) - - def test_sync_raise_return_value(self): - @gen.engine - def f(): - raise gen.Return(42) - - with self.assertRaises(gen.ReturnValueIgnoredError): - self.run_gen(f) - - def test_sync_raise_return_value_tuple(self): - @gen.engine - def f(): - raise gen.Return((1, 2)) - - with self.assertRaises(gen.ReturnValueIgnoredError): - self.run_gen(f) - - def test_async_raise_return_value(self): - @gen.engine - def f(): - yield gen.Task(self.io_loop.add_callback) - raise gen.Return(42) - - with self.assertRaises(gen.ReturnValueIgnoredError): - self.run_gen(f) - - def test_async_raise_return_value_tuple(self): - @gen.engine - def f(): - yield gen.Task(self.io_loop.add_callback) - raise gen.Return((1, 2)) - - with self.assertRaises(gen.ReturnValueIgnoredError): - self.run_gen(f) - - def test_return_value(self): - # It is an error to apply @gen.engine to a function that returns - # a value. - @gen.engine - def f(): - return 42 - - with self.assertRaises(gen.ReturnValueIgnoredError): - self.run_gen(f) - - def test_return_value_tuple(self): - # It is an error to apply @gen.engine to a function that returns - # a value. - @gen.engine - def f(): - return (1, 2) - - with self.assertRaises(gen.ReturnValueIgnoredError): - self.run_gen(f) + contextvars = None # type: ignore - @skipNotCPython - def test_task_refcounting(self): - # On CPython, tasks and their arguments should be released immediately - # without waiting for garbage collection. - @gen.engine - def f(): - class Foo(object): - pass - arg = Foo() - self.arg_ref = weakref.ref(arg) - task = gen.Task(self.io_loop.add_callback, arg=arg) - self.task_ref = weakref.ref(task) - yield task - self.stop() +import typing - self.run_gen(f) - self.assertIs(self.arg_ref(), None) - self.assertIs(self.task_ref(), None) +if typing.TYPE_CHECKING: + from typing import List, Optional # noqa: F401 -# GenBasicTest duplicates the non-deprecated portions of GenEngineTest -# with gen.coroutine to ensure we don't lose coverage when gen.engine -# goes away. 
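For orientation, a minimal sketch (editor's illustration, not part of the patch; names are illustrative) of the gen.coroutine style that replaces the removed gen.engine/Callback/Wait machinery: the decorated function returns a Future instead of firing callbacks.

```py
from tornado import gen
from tornado.ioloop import IOLoop

@gen.coroutine
def add_one(x):
    yield gen.moment         # cooperatively yield to the event loop once
    raise gen.Return(x + 1)  # equivalent to `return x + 1` on Python 3

assert IOLoop.current().run_sync(lambda: add_one(41)) == 42
```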
class GenBasicTest(AsyncTestCase): @gen.coroutine def delay(self, iterations, arg): @@ -678,9 +35,10 @@ def delay(self, iterations, arg): yield gen.moment raise gen.Return(arg) - @return_future - def async_future(self, result, callback): - self.io_loop.add_callback(callback, result) + @gen.coroutine + def async_future(self, result): + yield gen.moment + return result @gen.coroutine def async_exception(self, e): @@ -696,12 +54,14 @@ def test_no_yield(self): @gen.coroutine def f(): pass + self.io_loop.run_sync(f) def test_exception_phase1(self): @gen.coroutine def f(): 1 / 0 + self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f) def test_exception_phase2(self): @@ -709,24 +69,28 @@ def f(): yield gen.moment 1 / 0 + self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f) def test_bogus_yield(self): @gen.coroutine def f(): yield 42 + self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f) def test_bogus_yield_tuple(self): @gen.coroutine def f(): yield (1, 2) + self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f) def test_reuse(self): @gen.coroutine def f(): yield gen.moment + self.io_loop.run_sync(f) self.io_loop.run_sync(f) @@ -734,6 +98,7 @@ def test_none(self): @gen.coroutine def f(): yield None + self.io_loop.run_sync(f) def test_multi(self): @@ -741,6 +106,7 @@ def test_multi(self): def f(): results = yield [self.add_one_async(1), self.add_one_async(2)] self.assertEqual(results, [2, 3]) + self.io_loop.run_sync(f) def test_multi_dict(self): @@ -748,28 +114,29 @@ def test_multi_dict(self): def f(): results = yield dict(foo=self.add_one_async(1), bar=self.add_one_async(2)) self.assertEqual(results, dict(foo=2, bar=3)) + self.io_loop.run_sync(f) def test_multi_delayed(self): @gen.coroutine def f(): # callbacks run at different times - responses = yield gen.multi_future([ - self.delay(3, "v1"), - self.delay(1, "v2"), - ]) + responses = yield gen.multi_future( + [self.delay(3, "v1"), self.delay(1, "v2")] + ) self.assertEqual(responses, ["v1", "v2"]) + self.io_loop.run_sync(f) def test_multi_dict_delayed(self): @gen.coroutine def f(): # callbacks run at different times - responses = yield gen.multi_future(dict( - foo=self.delay(3, "v1"), - bar=self.delay(1, "v2"), - )) + responses = yield gen.multi_future( + dict(foo=self.delay(3, "v1"), bar=self.delay(1, "v2")) + ) self.assertEqual(responses, dict(foo="v1", bar="v2")) + self.io_loop.run_sync(f) @skipOnTravis @@ -803,6 +170,8 @@ def test_multi_future(self): @gen_test def test_multi_future_duplicate(self): + # Note that this doesn't work with native coroutines, only with + # decorated coroutines. f = self.async_future(2) results = yield [self.async_future(1), f, self.async_future(3), f] self.assertEqual(results, [1, 2, 3, 2]) @@ -816,40 +185,53 @@ def test_multi_dict_future(self): def test_multi_exceptions(self): with ExpectLog(app_log, "Multiple exceptions in yield list"): with self.assertRaises(RuntimeError) as cm: - yield gen.Multi([self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))]) + yield gen.Multi( + [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ] + ) self.assertEqual(str(cm.exception), "error 1") # With only one exception, no error is logged.
with self.assertRaises(RuntimeError): - yield gen.Multi([self.async_exception(RuntimeError("error 1")), - self.async_future(2)]) + yield gen.Multi( + [self.async_exception(RuntimeError("error 1")), self.async_future(2)] + ) # Exception logging may be explicitly quieted. with self.assertRaises(RuntimeError): - yield gen.Multi([self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))], - quiet_exceptions=RuntimeError) + yield gen.Multi( + [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ], + quiet_exceptions=RuntimeError, + ) @gen_test def test_multi_future_exceptions(self): with ExpectLog(app_log, "Multiple exceptions in yield list"): with self.assertRaises(RuntimeError) as cm: - yield [self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))] + yield [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ] self.assertEqual(str(cm.exception), "error 1") # With only one exception, no error is logged. with self.assertRaises(RuntimeError): - yield [self.async_exception(RuntimeError("error 1")), - self.async_future(2)] + yield [self.async_exception(RuntimeError("error 1")), self.async_future(2)] # Exception logging may be explicitly quieted. with self.assertRaises(RuntimeError): yield gen.multi_future( - [self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))], - quiet_exceptions=RuntimeError) + [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ], + quiet_exceptions=RuntimeError, + ) def test_sync_raise_return(self): @gen.coroutine @@ -903,10 +285,10 @@ def setUp(self): # so we need explicit checks here to make sure the tests run all # the way through. self.finished = False - super(GenCoroutineTest, self).setUp() + super().setUp() def tearDown(self): - super(GenCoroutineTest, self).tearDown() + super().tearDown() assert self.finished def test_attributes(self): @@ -918,7 +300,7 @@ def f(): coro = gen.coroutine(f) self.assertEqual(coro.__name__, f.__name__) self.assertEqual(coro.__module__, f.__module__) - self.assertIs(coro.__wrapped__, f) + self.assertIs(coro.__wrapped__, f) # type: ignore def test_is_coroutine_function(self): self.finished = True @@ -936,6 +318,7 @@ def test_sync_gen_return(self): @gen.coroutine def f(): raise gen.Return(42) + result = yield f() self.assertEqual(result, 42) self.finished = True @@ -946,6 +329,7 @@ def test_async_gen_return(self): def f(): yield gen.moment raise gen.Return(42) + result = yield f() self.assertEqual(result, 42) self.finished = True @@ -955,41 +339,37 @@ def test_sync_return(self): @gen.coroutine def f(): return 42 + result = yield f() self.assertEqual(result, 42) self.finished = True - @skipBefore33 @gen_test def test_async_return(self): - namespace = exec_test(globals(), locals(), """ @gen.coroutine def f(): yield gen.moment return 42 - """) - result = yield namespace['f']() + + result = yield f() self.assertEqual(result, 42) self.finished = True - @skipBefore33 @gen_test def test_async_early_return(self): # A yield statement exists but is not executed, which means # this function "returns" via an exception. This exception # doesn't happen before the exception handling is set up. 
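The comment above relies on plain generator semantics; a minimal sketch (editor's illustration) of how a return before the first yield travels as StopIteration rather than firing at call time:

```py
def g():
    if True:
        return 42
    yield  # unreachable, but makes g() a generator function

it = g()  # calling g() raises nothing yet
try:
    next(it)  # the return value arrives via StopIteration
except StopIteration as e:
    assert e.value == 42
```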
- namespace = exec_test(globals(), locals(), """ @gen.coroutine def f(): if True: return 42 yield gen.Task(self.io_loop.add_callback) - """) - result = yield namespace['f']() + + result = yield f() self.assertEqual(result, 42) self.finished = True - @skipBefore35 @gen_test def test_async_await(self): @gen.coroutine @@ -1000,82 +380,53 @@ def f1(): # This test verifies that an async function can await a # yield-based gen.coroutine, and that a gen.coroutine # (the test method itself) can yield an async function. - namespace = exec_test(globals(), locals(), """ async def f2(): result = await f1() return result - """) - result = yield namespace['f2']() + + result = yield f2() self.assertEqual(result, 42) self.finished = True - @skipBefore35 @gen_test def test_asyncio_sleep_zero(self): # asyncio.sleep(0) turns into a special case (equivalent to # `yield None`) - namespace = exec_test(globals(), locals(), """ async def f(): import asyncio + await asyncio.sleep(0) return 42 - """) - result = yield namespace['f']() + + result = yield f() self.assertEqual(result, 42) self.finished = True - @skipBefore35 @gen_test def test_async_await_mixed_multi_native_future(self): @gen.coroutine def f1(): yield gen.moment - namespace = exec_test(globals(), locals(), """ async def f2(): await f1() return 42 - """) @gen.coroutine def f3(): yield gen.moment raise gen.Return(43) - results = yield [namespace['f2'](), f3()] - self.assertEqual(results, [42, 43]) - self.finished = True - - @skipBefore35 - @gen_test - def test_async_await_mixed_multi_native_yieldpoint(self): - namespace = exec_test(globals(), locals(), """ - async def f1(): - await gen.Task(self.io_loop.add_callback) - return 42 - """) - - @gen.coroutine - def f2(): - yield gen.Task(self.io_loop.add_callback) - raise gen.Return(43) - - with ignore_deprecation(): - f2(callback=(yield gen.Callback('cb'))) - results = yield [namespace['f1'](), gen.Wait('cb')] + results = yield [f2(), f3()] self.assertEqual(results, [42, 43]) self.finished = True - @skipBefore35 @gen_test def test_async_with_timeout(self): - namespace = exec_test(globals(), locals(), """ async def f1(): return 42 - """) - result = yield gen.with_timeout(datetime.timedelta(hours=1), - namespace['f1']()) + result = yield gen.with_timeout(datetime.timedelta(hours=1), f1()) self.assertEqual(result, 42) self.finished = True @@ -1084,6 +435,7 @@ def test_sync_return_no_value(self): @gen.coroutine def f(): return + result = yield f() self.assertEqual(result, None) self.finished = True @@ -1095,6 +447,7 @@ def test_async_return_no_value(self): def f(): yield gen.moment return + result = yield f() self.assertEqual(result, None) self.finished = True @@ -1104,6 +457,7 @@ def test_sync_raise(self): @gen.coroutine def f(): 1 / 0 + # The exception is raised when the future is yielded # (or equivalently when its result method is called), # not when the function itself is called). 
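A minimal sketch (editor's illustration, assuming Tornado's public gen API; the name boom is illustrative) of the lazy error delivery described above: a decorated coroutine that fails synchronously still returns a Future, and the exception surfaces only when the Future is consumed.

```py
from tornado import gen

@gen.coroutine
def boom():
    1 / 0

fut = boom()  # no exception raised at call time
try:
    fut.result()  # ZeroDivisionError surfaces here
except ZeroDivisionError:
    pass
```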
@@ -1118,21 +472,12 @@ def test_async_raise(self): def f(): yield gen.moment 1 / 0 + future = f() with self.assertRaises(ZeroDivisionError): yield future self.finished = True - @gen_test - def test_pass_callback(self): - with ignore_deprecation(): - @gen.coroutine - def f(): - raise gen.Return(42) - result = yield gen.Task(f) - self.assertEqual(result, 42) - self.finished = True - @gen_test def test_replace_yieldpoint_exception(self): # Test exception handling: a coroutine can catch one exception @@ -1172,51 +517,6 @@ def f2(): self.assertEqual(result, 42) self.finished = True - @gen_test - def test_replace_context_exception(self): - with ignore_deprecation(): - # Test exception handling: exceptions thrown into the stack context - # can be caught and replaced. - # Note that this test and the following are for behavior that is - # not really supported any more: coroutines no longer create a - # stack context automatically; but one is created after the first - # YieldPoint (i.e. not a Future). - @gen.coroutine - def f2(): - (yield gen.Callback(1))() - yield gen.Wait(1) - self.io_loop.add_callback(lambda: 1 / 0) - try: - yield gen.Task(self.io_loop.add_timeout, - self.io_loop.time() + 10) - except ZeroDivisionError: - raise KeyError() - - future = f2() - with self.assertRaises(KeyError): - yield future - self.finished = True - - @gen_test - def test_swallow_context_exception(self): - with ignore_deprecation(): - # Test exception handling: exceptions thrown into the stack context - # can be caught and ignored. - @gen.coroutine - def f2(): - (yield gen.Callback(1))() - yield gen.Wait(1) - self.io_loop.add_callback(lambda: 1 / 0) - try: - yield gen.Task(self.io_loop.add_timeout, - self.io_loop.time() + 10) - except ZeroDivisionError: - raise gen.Return(42) - - result = yield f2() - self.assertEqual(result, 42) - self.finished = True - @gen_test def test_moment(self): calls = [] @@ -1226,29 +526,29 @@ def f(name, yieldable): for i in range(5): calls.append(name) yield yieldable + # First, confirm the behavior without moment: each coroutine # monopolizes the event loop until it finishes. - immediate = Future() + immediate = Future() # type: Future[None] immediate.set_result(None) - yield [f('a', immediate), f('b', immediate)] - self.assertEqual(''.join(calls), 'aaaaabbbbb') + yield [f("a", immediate), f("b", immediate)] + self.assertEqual("".join(calls), "aaaaabbbbb") # With moment, they take turns. calls = [] - yield [f('a', gen.moment), f('b', gen.moment)] - self.assertEqual(''.join(calls), 'ababababab') + yield [f("a", gen.moment), f("b", gen.moment)] + self.assertEqual("".join(calls), "ababababab") self.finished = True calls = [] - yield [f('a', gen.moment), f('b', immediate)] - self.assertEqual(''.join(calls), 'abbbbbaaaa') + yield [f("a", gen.moment), f("b", immediate)] + self.assertEqual("".join(calls), "abbbbbaaaa") @gen_test def test_sleep(self): yield gen.sleep(0.01) self.finished = True - @skipBefore33 @gen_test def test_py3_leak_exception_context(self): class LeakedException(Exception): @@ -1273,8 +573,9 @@ def inner(iteration): self.finished = True @skipNotCPython - @unittest.skipIf((3,) < sys.version_info < (3, 6), - "asyncio.Future has reference cycles") + @unittest.skipIf( + (3,) < sys.version_info < (3, 6), "asyncio.Future has reference cycles" + ) def test_coroutine_refcounting(self): # On CPython, tasks and their arguments should be released immediately # without waiting for garbage collection. 
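The refcounting tests lean on a CPython-specific property; a minimal standalone sketch (editor's illustration) of that property:

```py
import weakref

class Foo:
    pass

obj = Foo()
ref = weakref.ref(obj)
del obj
# On CPython, dropping the last reference frees the object immediately;
# no cycle collection (gc.collect) is required.
assert ref() is None
```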
@@ -1282,10 +583,15 @@ def test_coroutine_refcounting(self): def inner(): class Foo(object): pass + local_var = Foo() self.local_ref = weakref.ref(local_var) - yield gen.coroutine(lambda: None)() - raise ValueError('Some error') + + def dummy(): + pass + + yield gen.coroutine(dummy)() + raise ValueError("Some error") @gen.coroutine def inner2(): @@ -1299,8 +605,6 @@ def inner2(): self.assertIs(self.local_ref(), None) self.finished = True - @unittest.skipIf(sys.version_info < (3,), - "test only relevant with asyncio Futures") def test_asyncio_future_debug_info(self): self.finished = True # Enable debug mode @@ -1315,12 +619,10 @@ def f(): self.assertIsInstance(coro, asyncio.Future) # We expect the coroutine repr() to show the place where # it was instantiated - expected = ("created at %s:%d" - % (__file__, f.__code__.co_firstlineno + 3)) + expected = "created at %s:%d" % (__file__, f.__code__.co_firstlineno + 3) actual = repr(coro) self.assertIn(expected, actual) - @unittest.skipIf(asyncio is None, "asyncio module not present") @gen_test def test_asyncio_gather(self): # This demonstrates that tornado coroutines can be understood @@ -1335,27 +637,6 @@ def f(): self.finished = True -class GenSequenceHandler(RequestHandler): - with ignore_deprecation(): - @asynchronous - @gen.engine - def get(self): - # The outer ignore_deprecation applies at definition time. - # We need another for serving time. - with ignore_deprecation(): - self.io_loop = self.request.connection.stream.io_loop - self.io_loop.add_callback((yield gen.Callback("k1"))) - yield gen.Wait("k1") - self.write("1") - self.io_loop.add_callback((yield gen.Callback("k2"))) - yield gen.Wait("k2") - self.write("2") - # reuse an old key - self.io_loop.add_callback((yield gen.Callback("k1"))) - yield gen.Wait("k1") - self.finish("3") - - class GenCoroutineSequenceHandler(RequestHandler): @gen.coroutine def get(self): @@ -1368,7 +649,6 @@ def get(self): class GenCoroutineUnfinishedSequenceHandler(RequestHandler): - @asynchronous @gen.coroutine def get(self): yield gen.moment @@ -1380,66 +660,21 @@ def get(self): self.write("3") -class GenTaskHandler(RequestHandler): - @gen.coroutine - def get(self): - client = AsyncHTTPClient() - with ignore_deprecation(): - response = yield gen.Task(client.fetch, self.get_argument('url')) - response.rethrow() - self.finish(b"got response: " + response.body) - - -class GenExceptionHandler(RequestHandler): - with ignore_deprecation(): - @asynchronous - @gen.engine - def get(self): - # This test depends on the order of the two decorators. - io_loop = self.request.connection.stream.io_loop - yield gen.Task(io_loop.add_callback) - raise Exception("oops") - - -class GenCoroutineExceptionHandler(RequestHandler): - @gen.coroutine - def get(self): - # This test depends on the order of the two decorators. - io_loop = self.request.connection.stream.io_loop - yield gen.Task(io_loop.add_callback) - raise Exception("oops") - - -class GenYieldExceptionHandler(RequestHandler): - @gen.coroutine - def get(self): - io_loop = self.request.connection.stream.io_loop - # Test the interaction of the two stack_contexts. - with ignore_deprecation(): - def fail_task(callback): - io_loop.add_callback(lambda: 1 / 0) - try: - yield gen.Task(fail_task) - raise Exception("did not get expected exception") - except ZeroDivisionError: - self.finish('ok') - - # "Undecorated" here refers to the absence of @asynchronous. 
class UndecoratedCoroutinesHandler(RequestHandler): @gen.coroutine def prepare(self): - self.chunks = [] + self.chunks = [] # type: List[str] yield gen.moment - self.chunks.append('1') + self.chunks.append("1") @gen.coroutine def get(self): - self.chunks.append('2') + self.chunks.append("2") yield gen.moment - self.chunks.append('3') + self.chunks.append("3") yield gen.moment - self.write(''.join(self.chunks)) + self.write("".join(self.chunks)) class AsyncPrepareErrorHandler(RequestHandler): @@ -1449,161 +684,133 @@ def prepare(self): raise HTTPError(403) def get(self): - self.finish('ok') + self.finish("ok") class NativeCoroutineHandler(RequestHandler): - if sys.version_info > (3, 5): - exec(textwrap.dedent(""" - async def get(self): - import asyncio - await asyncio.sleep(0) - self.write("ok") - """)) + async def get(self): + await asyncio.sleep(0) + self.write("ok") class GenWebTest(AsyncHTTPTestCase): def get_app(self): - return Application([ - ('/sequence', GenSequenceHandler), - ('/coroutine_sequence', GenCoroutineSequenceHandler), - ('/coroutine_unfinished_sequence', - GenCoroutineUnfinishedSequenceHandler), - ('/task', GenTaskHandler), - ('/exception', GenExceptionHandler), - ('/coroutine_exception', GenCoroutineExceptionHandler), - ('/yield_exception', GenYieldExceptionHandler), - ('/undecorated_coroutine', UndecoratedCoroutinesHandler), - ('/async_prepare_error', AsyncPrepareErrorHandler), - ('/native_coroutine', NativeCoroutineHandler), - ]) - - def test_sequence_handler(self): - response = self.fetch('/sequence') - self.assertEqual(response.body, b"123") + return Application( + [ + ("/coroutine_sequence", GenCoroutineSequenceHandler), + ( + "/coroutine_unfinished_sequence", + GenCoroutineUnfinishedSequenceHandler, + ), + ("/undecorated_coroutine", UndecoratedCoroutinesHandler), + ("/async_prepare_error", AsyncPrepareErrorHandler), + ("/native_coroutine", NativeCoroutineHandler), + ] + ) def test_coroutine_sequence_handler(self): - response = self.fetch('/coroutine_sequence') + response = self.fetch("/coroutine_sequence") self.assertEqual(response.body, b"123") def test_coroutine_unfinished_sequence_handler(self): - response = self.fetch('/coroutine_unfinished_sequence') + response = self.fetch("/coroutine_unfinished_sequence") self.assertEqual(response.body, b"123") - def test_task_handler(self): - response = self.fetch('/task?url=%s' % url_escape(self.get_url('/sequence'))) - self.assertEqual(response.body, b"got response: 123") - - def test_exception_handler(self): - # Make sure we get an error and not a timeout - with ExpectLog(app_log, "Uncaught exception GET /exception"): - response = self.fetch('/exception') - self.assertEqual(500, response.code) - - def test_coroutine_exception_handler(self): - # Make sure we get an error and not a timeout - with ExpectLog(app_log, "Uncaught exception GET /coroutine_exception"): - response = self.fetch('/coroutine_exception') - self.assertEqual(500, response.code) - - def test_yield_exception_handler(self): - response = self.fetch('/yield_exception') - self.assertEqual(response.body, b'ok') - def test_undecorated_coroutines(self): - response = self.fetch('/undecorated_coroutine') - self.assertEqual(response.body, b'123') + response = self.fetch("/undecorated_coroutine") + self.assertEqual(response.body, b"123") def test_async_prepare_error_handler(self): - response = self.fetch('/async_prepare_error') + response = self.fetch("/async_prepare_error") self.assertEqual(response.code, 403) - @skipBefore35 def 
test_native_coroutine_handler(self): - response = self.fetch('/native_coroutine') + response = self.fetch("/native_coroutine") self.assertEqual(response.code, 200) - self.assertEqual(response.body, b'ok') + self.assertEqual(response.body, b"ok") class WithTimeoutTest(AsyncTestCase): @gen_test def test_timeout(self): with self.assertRaises(gen.TimeoutError): - yield gen.with_timeout(datetime.timedelta(seconds=0.1), - Future()) + yield gen.with_timeout(datetime.timedelta(seconds=0.1), Future()) @gen_test def test_completes_before_timeout(self): - future = Future() - self.io_loop.add_timeout(datetime.timedelta(seconds=0.1), - lambda: future.set_result('asdf')) - result = yield gen.with_timeout(datetime.timedelta(seconds=3600), - future) - self.assertEqual(result, 'asdf') + future = Future() # type: Future[str] + self.io_loop.add_timeout( + datetime.timedelta(seconds=0.1), lambda: future.set_result("asdf") + ) + result = yield gen.with_timeout(datetime.timedelta(seconds=3600), future) + self.assertEqual(result, "asdf") @gen_test def test_fails_before_timeout(self): - future = Future() + future = Future() # type: Future[str] self.io_loop.add_timeout( datetime.timedelta(seconds=0.1), - lambda: future.set_exception(ZeroDivisionError())) + lambda: future.set_exception(ZeroDivisionError()), + ) with self.assertRaises(ZeroDivisionError): - yield gen.with_timeout(datetime.timedelta(seconds=3600), - future) + yield gen.with_timeout(datetime.timedelta(seconds=3600), future) @gen_test def test_already_resolved(self): - future = Future() - future.set_result('asdf') - result = yield gen.with_timeout(datetime.timedelta(seconds=3600), - future) - self.assertEqual(result, 'asdf') + future = Future() # type: Future[str] + future.set_result("asdf") + result = yield gen.with_timeout(datetime.timedelta(seconds=3600), future) + self.assertEqual(result, "asdf") - @unittest.skipIf(futures is None, 'futures module not present') @gen_test def test_timeout_concurrent_future(self): # A concurrent future that does not resolve before the timeout. with futures.ThreadPoolExecutor(1) as executor: with self.assertRaises(gen.TimeoutError): - yield gen.with_timeout(self.io_loop.time(), - executor.submit(time.sleep, 0.1)) + yield gen.with_timeout( + self.io_loop.time(), executor.submit(time.sleep, 0.1) + ) - @unittest.skipIf(futures is None, 'futures module not present') @gen_test def test_completed_concurrent_future(self): # A concurrent future that is resolved before we even submit it # to with_timeout. with futures.ThreadPoolExecutor(1) as executor: - f = executor.submit(lambda: None) + + def dummy(): + pass + + f = executor.submit(dummy) f.result() # wait for completion yield gen.with_timeout(datetime.timedelta(seconds=3600), f) - @unittest.skipIf(futures is None, 'futures module not present') @gen_test def test_normal_concurrent_future(self): # A concurrent future that resolves while waiting for the timeout.
with futures.ThreadPoolExecutor(1) as executor: - yield gen.with_timeout(datetime.timedelta(seconds=3600), - executor.submit(lambda: time.sleep(0.01))) + yield gen.with_timeout( + datetime.timedelta(seconds=3600), + executor.submit(lambda: time.sleep(0.01)), + ) class WaitIteratorTest(AsyncTestCase): @gen_test def test_empty_iterator(self): g = gen.WaitIterator() - self.assertTrue(g.done(), 'empty generator iterated') + self.assertTrue(g.done(), "empty generator iterated") with self.assertRaises(ValueError): - g = gen.WaitIterator(False, bar=False) + g = gen.WaitIterator(Future(), bar=Future()) self.assertEqual(g.current_index, None, "bad nil current index") self.assertEqual(g.current_future, None, "bad nil current future") @gen_test def test_already_done(self): - f1 = Future() - f2 = Future() - f3 = Future() + f1 = Future() # type: Future[int] + f2 = Future() # type: Future[int] + f3 = Future() # type: Future[int] f1.set_result(24) f2.set_result(42) f3.set_result(84) @@ -1636,14 +843,17 @@ def test_already_done(self): while not dg.done(): dr = yield dg.next() if dg.current_index == "f1": - self.assertTrue(dg.current_future == f1 and dr == 24, - "WaitIterator dict status incorrect") + self.assertTrue( + dg.current_future == f1 and dr == 24, + "WaitIterator dict status incorrect", + ) elif dg.current_index == "f2": - self.assertTrue(dg.current_future == f2 and dr == 42, - "WaitIterator dict status incorrect") + self.assertTrue( + dg.current_future == f2 and dr == 42, + "WaitIterator dict status incorrect", + ) else: - self.fail("got bad WaitIterator index {}".format( - dg.current_index)) + self.fail("got bad WaitIterator index {}".format(dg.current_index)) i += 1 @@ -1664,7 +874,7 @@ def finish_coroutines(self, iteration, futures): @gen_test def test_iterator(self): - futures = [Future(), Future(), Future(), Future()] + futures = [Future(), Future(), Future(), Future()] # type: List[Future[int]] self.finish_coroutines(0, futures) @@ -1675,39 +885,36 @@ def test_iterator(self): try: r = yield g.next() except ZeroDivisionError: - self.assertIs(g.current_future, futures[0], - 'exception future invalid') + self.assertIs(g.current_future, futures[0], "exception future invalid") else: if i == 0: - self.assertEqual(r, 24, 'iterator value incorrect') - self.assertEqual(g.current_index, 2, 'wrong index') + self.assertEqual(r, 24, "iterator value incorrect") + self.assertEqual(g.current_index, 2, "wrong index") elif i == 2: - self.assertEqual(r, 42, 'iterator value incorrect') - self.assertEqual(g.current_index, 1, 'wrong index') + self.assertEqual(r, 42, "iterator value incorrect") + self.assertEqual(g.current_index, 1, "wrong index") elif i == 3: - self.assertEqual(r, 84, 'iterator value incorrect') - self.assertEqual(g.current_index, 3, 'wrong index') + self.assertEqual(r, 84, "iterator value incorrect") + self.assertEqual(g.current_index, 3, "wrong index") i += 1 - @skipBefore35 @gen_test def test_iterator_async_await(self): # Recreate the previous test with py35 syntax. It's a little clunky # because of the way the previous test handles an exception on # a single iteration. 
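For orientation, a minimal sketch (editor's illustration; main is an illustrative name) of the async-for form of WaitIterator exercised by the rewritten test; results arrive in completion order, not submission order:

```py
from tornado import gen
from tornado.ioloop import IOLoop

async def main():
    async for result in gen.WaitIterator(gen.sleep(0.02), gen.sleep(0.01)):
        print("one future finished:", result)

IOLoop.current().run_sync(main)
```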
- futures = [Future(), Future(), Future(), Future()] + futures = [Future(), Future(), Future(), Future()] # type: List[Future[int]] self.finish_coroutines(0, futures) self.finished = False - namespace = exec_test(globals(), locals(), """ async def f(): i = 0 g = gen.WaitIterator(*futures) try: async for r in g: if i == 0: - self.assertEqual(r, 24, 'iterator value incorrect') - self.assertEqual(g.current_index, 2, 'wrong index') + self.assertEqual(r, 24, "iterator value incorrect") + self.assertEqual(g.current_index, 2, "wrong index") else: raise Exception("expected exception on iteration 1") i += 1 @@ -1715,17 +922,17 @@ async def f(): i += 1 async for r in g: if i == 2: - self.assertEqual(r, 42, 'iterator value incorrect') - self.assertEqual(g.current_index, 1, 'wrong index') + self.assertEqual(r, 42, "iterator value incorrect") + self.assertEqual(g.current_index, 1, "wrong index") elif i == 3: - self.assertEqual(r, 84, 'iterator value incorrect') - self.assertEqual(g.current_index, 3, 'wrong index') + self.assertEqual(r, 84, "iterator value incorrect") + self.assertEqual(g.current_index, 3, "wrong index") else: raise Exception("didn't expect iteration %d" % i) i += 1 self.finished = True - """) - yield namespace['f']() + + yield f() self.assertTrue(self.finished) @gen_test @@ -1734,46 +941,40 @@ def test_no_ref(self): # WaitIterator itself, only the Future it returns. Since # WaitIterator uses weak references internally to improve GC # performance, this used to cause problems. - yield gen.with_timeout(datetime.timedelta(seconds=0.1), - gen.WaitIterator(gen.sleep(0)).next()) + yield gen.with_timeout( + datetime.timedelta(seconds=0.1), gen.WaitIterator(gen.sleep(0)).next() + ) class RunnerGCTest(AsyncTestCase): def is_pypy3(self): - return (platform.python_implementation() == 'PyPy' and - sys.version_info > (3,)) + return platform.python_implementation() == "PyPy" and sys.version_info > (3,) @gen_test def test_gc(self): - # Github issue 1769: Runner objects can get GCed unexpectedly + # GitHub issue 1769: Runner objects can get GCed unexpectedly # while their future is alive. - weakref_scope = [None] + weakref_scope = [None] # type: List[Optional[weakref.ReferenceType]] def callback(): gc.collect(2) - weakref_scope[0]().set_result(123) + weakref_scope[0]().set_result(123) # type: ignore @gen.coroutine def tester(): - fut = Future() + fut = Future() # type: Future[int] weakref_scope[0] = weakref.ref(fut) self.io_loop.add_callback(callback) yield fut - yield gen.with_timeout( - datetime.timedelta(seconds=0.2), - tester() - ) + yield gen.with_timeout(datetime.timedelta(seconds=0.2), tester()) def test_gc_infinite_coro(self): - # Github issue 2229: suspended coroutines should be GCed when + # GitHub issue 2229: suspended coroutines should be GCed when # their loop is closed, even if they're involved in a reference # cycle. 
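The scenario this test covers can be reproduced with plain asyncio. Below is a minimal standalone sketch of the failure mode, using illustrative names (`infinite`, `_refcycle`) rather than anything from the patch: a suspended coroutine that is only reachable through a reference cycle should still have its finalizer run once the loop is closed and the cycle collector frees it.

```py
import asyncio
import gc

async def infinite():
    try:
        while True:
            await asyncio.sleep(3600)
    finally:
        print("coroutine finalizer ran")  # should fire during gc.collect()

loop = asyncio.new_event_loop()
task = loop.create_task(infinite())
task._refcycle = task  # tie the task into a cycle, as the test does
loop.run_until_complete(asyncio.sleep(0.1))  # let the coroutine start
del task
loop.close()  # drops the loop's own references to the pending task
gc.collect()  # cycle collector frees the task; the finally block runs
```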
- if IOLoop.configured_class().__name__.endswith('TwistedIOLoop'): - raise unittest.SkipTest("Test may fail on TwistedIOLoop") - loop = self.get_new_ioloop() - result = [] + result = [] # type: List[Optional[bool]] wfut = [] @gen.coroutine @@ -1789,7 +990,7 @@ def infinite_coro(): @gen.coroutine def do_something(): fut = infinite_coro() - fut._refcycle = fut + fut._refcycle = fut # type: ignore wfut.append(weakref.ref(fut)) yield gen.sleep(0.2) @@ -1804,12 +1005,10 @@ def do_something(): # coroutine finalizer was called (not on PyPy3 apparently) self.assertIs(result[-1], None) - @skipBefore35 def test_gc_infinite_async_await(self): # Same as test_gc_infinite_coro, but with a `async def` function import asyncio - namespace = exec_test(globals(), locals(), """ async def infinite_coro(result): try: while True: @@ -1818,22 +1017,20 @@ async def infinite_coro(result): finally: # coroutine finalizer result.append(None) - """) - infinite_coro = namespace['infinite_coro'] loop = self.get_new_ioloop() - result = [] + result = [] # type: List[Optional[bool]] wfut = [] @gen.coroutine def do_something(): fut = asyncio.get_event_loop().create_task(infinite_coro(result)) - fut._refcycle = fut + fut._refcycle = fut # type: ignore wfut.append(weakref.ref(fut)) yield gen.sleep(0.2) loop.run_sync(do_something) - with ExpectLog('asyncio', "Task was destroyed but it is pending"): + with ExpectLog("asyncio", "Task was destroyed but it is pending"): loop.close() gc.collect() # Future was collected @@ -1857,5 +1054,66 @@ def wait_a_moment(): self.assertEqual(result, [None, None]) -if __name__ == '__main__': +if contextvars is not None: + ctx_var = contextvars.ContextVar("ctx_var") # type: contextvars.ContextVar[int] + + +@unittest.skipIf(contextvars is None, "contextvars module not present") +class ContextVarsTest(AsyncTestCase): + async def native_root(self, x): + ctx_var.set(x) + await self.inner(x) + + @gen.coroutine + def gen_root(self, x): + ctx_var.set(x) + yield + yield self.inner(x) + + async def inner(self, x): + self.assertEqual(ctx_var.get(), x) + await self.gen_inner(x) + self.assertEqual(ctx_var.get(), x) + + # IOLoop.run_in_executor doesn't automatically copy context + ctx = contextvars.copy_context() + await self.io_loop.run_in_executor(None, lambda: ctx.run(self.thread_inner, x)) + self.assertEqual(ctx_var.get(), x) + + # Neither does asyncio's run_in_executor. + await asyncio.get_event_loop().run_in_executor( + None, lambda: ctx.run(self.thread_inner, x) + ) + self.assertEqual(ctx_var.get(), x) + + @gen.coroutine + def gen_inner(self, x): + self.assertEqual(ctx_var.get(), x) + yield + self.assertEqual(ctx_var.get(), x) + + def thread_inner(self, x): + self.assertEqual(ctx_var.get(), x) + + @gen_test + def test_propagate(self): + # Verify that context vars get propagated across various + # combinations of native and decorated coroutines. + yield [ + self.native_root(1), + self.native_root(2), + self.gen_root(3), + self.gen_root(4), + ] + + @gen_test + def test_reset(self): + token = ctx_var.set(1) + yield + # reset asserts that we are still at the same level of the context tree, + # so we must make sure that we maintain that property across yield. 
+ ctx_var.reset(token) + + +if __name__ == "__main__": unittest.main() diff --git a/tornado/test/gettext_translations/extract_me.py b/tornado/test/gettext_translations/extract_me.py index 283c13f413..08b29bc53c 100644 --- a/tornado/test/gettext_translations/extract_me.py +++ b/tornado/test/gettext_translations/extract_me.py @@ -8,7 +8,6 @@ # 3) msgfmt tornado_test.po -o tornado_test.mo # 4) Put the file in the proper location: $LANG/LC_MESSAGES -from __future__ import absolute_import, division, print_function _("school") pgettext("law", "right") pgettext("good", "right") diff --git a/tornado/test/http1connection_test.py b/tornado/test/http1connection_test.py index 8aaaaf35b7..d21d506228 100644 --- a/tornado/test/http1connection_test.py +++ b/tornado/test/http1connection_test.py @@ -1,6 +1,5 @@ -from __future__ import absolute_import, division, print_function - import socket +import typing from tornado.http1connection import HTTP1Connection from tornado.httputil import HTTPMessageDelegate @@ -11,8 +10,10 @@ class HTTP1ConnectionTest(AsyncTestCase): + code = None # type: typing.Optional[int] + def setUp(self): - super(HTTP1ConnectionTest, self).setUp() + super().setUp() self.asyncSetUp() @gen_test @@ -28,8 +29,7 @@ def accept_callback(conn, addr): add_accept_handler(listener, accept_callback) self.client_stream = IOStream(socket.socket()) self.addCleanup(self.client_stream.close) - yield [self.client_stream.connect(('127.0.0.1', port)), - event.wait()] + yield [self.client_stream.connect(("127.0.0.1", port)), event.wait()] self.io_loop.remove_handler(listener) listener.close() @@ -58,4 +58,4 @@ def finish(self): yield conn.read_response(Delegate()) yield event.wait() self.assertEqual(self.code, 200) - self.assertEqual(b''.join(body), b'hello') + self.assertEqual(b"".join(body), b"hello") diff --git a/tornado/test/httpclient_test.py b/tornado/test/httpclient_test.py index 851f126885..fd9a978640 100644 --- a/tornado/test/httpclient_test.py +++ b/tornado/test/httpclient_test.py @@ -1,25 +1,34 @@ -from __future__ import absolute_import, division, print_function - import base64 import binascii from contextlib import closing import copy -import sys +import gzip import threading import datetime from io import BytesIO +import subprocess +import sys +import time +import typing # noqa: F401 +import unicodedata +import unittest -from tornado.escape import utf8, native_str +from tornado.escape import utf8, native_str, to_unicode from tornado import gen -from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient +from tornado.httpclient import ( + HTTPRequest, + HTTPResponse, + _RequestProxy, + HTTPError, + HTTPClient, +) from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop from tornado.iostream import IOStream -from tornado.log import gen_log +from tornado.log import gen_log, app_log from tornado import netutil -from tornado.stack_context import ExceptionStackContext, NullContext from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog -from tornado.test.util import unittest, skipOnTravis, ignore_deprecation +from tornado.test.util import skipOnTravis from tornado.web import Application, RequestHandler, url from tornado.httputil import format_timestamp, HTTPHeaders @@ -33,8 +42,10 @@ def get(self): class PostHandler(RequestHandler): def post(self): - self.finish("Post arg1: %s, arg2: %s" % ( - self.get_argument("arg1"), self.get_argument("arg2"))) + self.finish( + "Post arg1: %s, arg2: %s" + % 
(self.get_argument("arg1"), self.get_argument("arg2")) + ) class PutHandler(RequestHandler): @@ -45,9 +56,17 @@ def put(self): class RedirectHandler(RequestHandler): def prepare(self): - self.write('redirects can have bodies too') - self.redirect(self.get_argument("url"), - status=int(self.get_argument("status", "302"))) + self.write("redirects can have bodies too") + self.redirect( + self.get_argument("url"), status=int(self.get_argument("status", "302")) + ) + + +class RedirectWithoutLocationHandler(RequestHandler): + def prepare(self): + # For testing error handling of a redirect with no location header. + self.set_status(301) + self.finish() class ChunkHandler(RequestHandler): @@ -81,44 +100,57 @@ def post(self): class UserAgentHandler(RequestHandler): def get(self): - self.write(self.request.headers.get('User-Agent', 'User agent not set')) + self.write(self.request.headers.get("User-Agent", "User agent not set")) class ContentLength304Handler(RequestHandler): def get(self): self.set_status(304) - self.set_header('Content-Length', 42) + self.set_header("Content-Length", 42) - def _clear_headers_for_304(self): + def _clear_representation_headers(self): # Tornado strips content-length from 304 responses, but here we # want to simulate servers that include the headers anyway. pass class PatchHandler(RequestHandler): - def patch(self): "Return the request payload - so we can check it is being kept" self.write(self.request.body) class AllMethodsHandler(RequestHandler): - SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',) + SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ("OTHER",) # type: ignore def method(self): + assert self.request.method is not None self.write(self.request.method) - get = post = put = delete = options = patch = other = method + get = head = post = put = delete = options = patch = other = method # type: ignore class SetHeaderHandler(RequestHandler): def get(self): # Use get_arguments for keys to get strings, but # request.arguments for values to get bytes. - for k, v in zip(self.get_arguments('k'), - self.request.arguments['v']): + for k, v in zip(self.get_arguments("k"), self.request.arguments["v"]): self.set_header(k, v) + +class InvalidGzipHandler(RequestHandler): + def get(self): + # set Content-Encoding manually to avoid automatic gzip encoding + self.set_header("Content-Type", "text/plain") + self.set_header("Content-Encoding", "gzip") + # Triggering the potential bug seems to depend on input length. + # This length is taken from the bad-response example reported in + # https://github.com/tornadoweb/tornado/pull/2875 (uncompressed). + body = "".join("Hello World {}\n".format(i) for i in range(9000))[:149051] + body = gzip.compress(body.encode(), compresslevel=6) + b"\00" + self.write(body) + + # These tests end up getting run redundantly: once here with the default # HTTPClient implementation, and then again in each implementation's own # test suite. 
@@ -126,25 +158,30 @@ def get(self): class HTTPClientCommonTestCase(AsyncHTTPTestCase): def get_app(self): - return Application([ - url("/hello", HelloWorldHandler), - url("/post", PostHandler), - url("/put", PutHandler), - url("/redirect", RedirectHandler), - url("/chunk", ChunkHandler), - url("/auth", AuthHandler), - url("/countdown/([0-9]+)", CountdownHandler, name="countdown"), - url("/echopost", EchoPostHandler), - url("/user_agent", UserAgentHandler), - url("/304_with_content_length", ContentLength304Handler), - url("/all_methods", AllMethodsHandler), - url('/patch', PatchHandler), - url('/set_header', SetHeaderHandler), - ], gzip=True) + return Application( + [ + url("/hello", HelloWorldHandler), + url("/post", PostHandler), + url("/put", PutHandler), + url("/redirect", RedirectHandler), + url("/redirect_without_location", RedirectWithoutLocationHandler), + url("/chunk", ChunkHandler), + url("/auth", AuthHandler), + url("/countdown/([0-9]+)", CountdownHandler, name="countdown"), + url("/echopost", EchoPostHandler), + url("/user_agent", UserAgentHandler), + url("/304_with_content_length", ContentLength304Handler), + url("/all_methods", AllMethodsHandler), + url("/patch", PatchHandler), + url("/set_header", SetHeaderHandler), + url("/invalid_gzip", InvalidGzipHandler), + ], + gzip=True, + ) def test_patch_receives_payload(self): body = b"some patch data" - response = self.fetch("/patch", method='PATCH', body=body) + response = self.fetch("/patch", method="PATCH", body=body) self.assertEqual(response.code, 200) self.assertEqual(response.body, body) @@ -154,6 +191,7 @@ def test_hello_world(self): self.assertEqual(response.code, 200) self.assertEqual(response.headers["Content-Type"], "text/plain") self.assertEqual(response.body, b"Hello world!") + assert response.request_time is not None self.assertEqual(int(response.request_time), 0) response = self.fetch("/hello?name=Ben") @@ -161,16 +199,14 @@ def test_hello_world(self): def test_streaming_callback(self): # streaming_callback is also tested in test_chunked - chunks = [] - response = self.fetch("/hello", - streaming_callback=chunks.append) + chunks = [] # type: typing.List[bytes] + response = self.fetch("/hello", streaming_callback=chunks.append) # with streaming_callback, data goes to the callback and not response.body self.assertEqual(chunks, [b"Hello world!"]) self.assertFalse(response.body) def test_post(self): - response = self.fetch("/post", method="POST", - body="arg1=foo&arg2=bar") + response = self.fetch("/post", method="POST", body="arg1=foo&arg2=bar") self.assertEqual(response.code, 200) self.assertEqual(response.body, b"Post arg1: foo, arg2: bar") @@ -178,9 +214,8 @@ def test_chunked(self): response = self.fetch("/chunk") self.assertEqual(response.body, b"asdfqwer") - chunks = [] - response = self.fetch("/chunk", - streaming_callback=chunks.append) + chunks = [] # type: typing.List[bytes] + response = self.fetch("/chunk", streaming_callback=chunks.append) self.assertEqual(chunks, [b"asdf", b"qwer"]) self.assertFalse(response.body) @@ -189,6 +224,7 @@ def test_chunked_close(self): # over several ioloop iterations, but the connection is already closed. sock, port = bind_unused_port() with closing(sock): + @gen.coroutine def accept_callback(conn, address): # fake an HTTP server using chunked encoding where the final chunks @@ -197,7 +233,8 @@ def accept_callback(conn, address): request_data = yield stream.read_until(b"\r\n\r\n") if b"HTTP/1." 
not in request_data: self.skipTest("requires HTTP/1.x") - yield stream.write(b"""\ + yield stream.write( + b"""\ HTTP/1.1 200 OK Transfer-Encoding: chunked @@ -207,55 +244,66 @@ def accept_callback(conn, address): 2 0 -""".replace(b"\n", b"\r\n")) +""".replace( + b"\n", b"\r\n" + ) + ) stream.close() - netutil.add_accept_handler(sock, accept_callback) + + netutil.add_accept_handler(sock, accept_callback) # type: ignore resp = self.fetch("http://127.0.0.1:%d/" % port) resp.rethrow() self.assertEqual(resp.body, b"12") self.io_loop.remove_handler(sock.fileno()) - def test_streaming_stack_context(self): - chunks = [] - exc_info = [] - - def error_handler(typ, value, tb): - exc_info.append((typ, value, tb)) - return True - - def streaming_cb(chunk): - chunks.append(chunk) - if chunk == b'qwer': - 1 / 0 - - with ExceptionStackContext(error_handler): - self.fetch('/chunk', streaming_callback=streaming_cb) - - self.assertEqual(chunks, [b'asdf', b'qwer']) - self.assertEqual(1, len(exc_info)) - self.assertIs(exc_info[0][0], ZeroDivisionError) - def test_basic_auth(self): - self.assertEqual(self.fetch("/auth", auth_username="Aladdin", - auth_password="open sesame").body, - b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") + # This test data appears in section 2 of RFC 7617. + self.assertEqual( + self.fetch( + "/auth", auth_username="Aladdin", auth_password="open sesame" + ).body, + b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", + ) def test_basic_auth_explicit_mode(self): - self.assertEqual(self.fetch("/auth", auth_username="Aladdin", - auth_password="open sesame", - auth_mode="basic").body, - b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") + self.assertEqual( + self.fetch( + "/auth", + auth_username="Aladdin", + auth_password="open sesame", + auth_mode="basic", + ).body, + b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", + ) + + def test_basic_auth_unicode(self): + # This test data appears in section 2.1 of RFC 7617. + self.assertEqual( + self.fetch("/auth", auth_username="test", auth_password="123£").body, + b"Basic dGVzdDoxMjPCow==", + ) + + # The standard mandates NFC. Give it a decomposed username + # and ensure it is normalized to composed form. + username = unicodedata.normalize("NFD", u"josé") + self.assertEqual( + self.fetch("/auth", auth_username=username, auth_password="səcrət").body, + b"Basic am9zw6k6c8mZY3LJmXQ=", + ) def test_unsupported_auth_mode(self): # curl and simple clients handle errors a bit differently; the # important thing is that they don't fall back to basic auth # on an unknown mode. with ExpectLog(gen_log, "uncaught exception", required=False): - with self.assertRaises((ValueError, HTTPError)): - self.fetch("/auth", auth_username="Aladdin", - auth_password="open sesame", - auth_mode="asdf", - raise_error=True) + with self.assertRaises((ValueError, HTTPError)): # type: ignore + self.fetch( + "/auth", + auth_username="Aladdin", + auth_password="open sesame", + auth_mode="asdf", + raise_error=True, + ) def test_follow_redirect(self): response = self.fetch("/countdown/2", follow_redirects=False) @@ -267,34 +315,97 @@ def test_follow_redirect(self): self.assertTrue(response.effective_url.endswith("/countdown/0")) self.assertEqual(b"Zero", response.body) + def test_redirect_without_location(self): + response = self.fetch("/redirect_without_location", follow_redirects=True) + # If there is no location header, the redirect response should + # just be returned as-is. (This should arguably raise an + # error, but libcurl doesn't treat this as an error, so we + # don't either). 
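The redirect semantics exercised by the tests in this region can be summarized as one small pure function. This is a sketch under illustrative names, not Tornado's implementation: a 3xx without a Location header cannot be followed (matching the comment above), 301/302/303 rewrite POST to GET, and 307/308 preserve the method.

```py
import urllib.parse

def follow_redirect(code, method, headers, current_url):
    """Return (method, url) for the next request, or None to stop."""
    if code not in (301, 302, 303, 307, 308):
        return None
    location = headers.get("Location")
    if location is None:
        return None  # no Location: hand the 3xx back as-is, like libcurl
    if code in (301, 302, 303) and method == "POST":
        method = "GET"  # legacy codes (and 303) rewrite POST to GET
    return method, urllib.parse.urljoin(current_url, location)

assert follow_redirect(301, "POST", {"Location": "/put"}, "http://h/r") == (
    "GET",
    "http://h/put",
)
assert follow_redirect(307, "PUT", {"Location": "/put"}, "http://h/r") == (
    "PUT",
    "http://h/put",
)
assert follow_redirect(301, "GET", {}, "http://h/r") is None
```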
+ self.assertEqual(301, response.code) + + def test_redirect_put_with_body(self): + response = self.fetch( + "/redirect?url=/put&status=307", method="PUT", body="hello" + ) + self.assertEqual(response.body, b"Put body: hello") + + def test_redirect_put_without_body(self): + # This "without body" edge case is similar to what happens with body_producer. + response = self.fetch( + "/redirect?url=/put&status=307", + method="PUT", + allow_nonstandard_methods=True, + ) + self.assertEqual(response.body, b"Put body: ") + + def test_method_after_redirect(self): + # Legacy redirect codes (301, 302, 303) convert POST requests to GET. + for status in [301, 302, 303]: + url = "/redirect?url=/all_methods&status=%d" % status + resp = self.fetch(url, method="POST", body=b"") + self.assertEqual(b"GET", resp.body) + + # Other methods are left alone, except for 303 redirect, depending on client + for method in ["GET", "OPTIONS", "PUT", "DELETE"]: + resp = self.fetch(url, method=method, allow_nonstandard_methods=True) + if status in [301, 302]: + self.assertEqual(utf8(method), resp.body) + else: + self.assertIn(resp.body, [utf8(method), b"GET"]) + + # HEAD is different so check it separately. + resp = self.fetch(url, method="HEAD") + self.assertEqual(200, resp.code) + self.assertEqual(b"", resp.body) + + # Newer redirects always preserve the original method. + for status in [307, 308]: + url = "/redirect?url=/all_methods&status=%d" % status + for method in ["GET", "OPTIONS", "POST", "PUT", "DELETE"]: + resp = self.fetch(url, method=method, allow_nonstandard_methods=True) + self.assertEqual(method, to_unicode(resp.body)) + resp = self.fetch(url, method="HEAD") + self.assertEqual(200, resp.code) + self.assertEqual(b"", resp.body) + def test_credentials_in_url(self): url = self.get_url("/auth").replace("http://", "http://me:secret@") response = self.fetch(url) - self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"), - response.body) + self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"), response.body) def test_body_encoding(self): unicode_body = u"\xe9" byte_body = binascii.a2b_hex(b"e9") # unicode string in body gets converted to utf8 - response = self.fetch("/echopost", method="POST", body=unicode_body, - headers={"Content-Type": "application/blah"}) + response = self.fetch( + "/echopost", + method="POST", + body=unicode_body, + headers={"Content-Type": "application/blah"}, + ) self.assertEqual(response.headers["Content-Length"], "2") self.assertEqual(response.body, utf8(unicode_body)) # byte strings pass through directly - response = self.fetch("/echopost", method="POST", - body=byte_body, - headers={"Content-Type": "application/blah"}) + response = self.fetch( + "/echopost", + method="POST", + body=byte_body, + headers={"Content-Type": "application/blah"}, + ) self.assertEqual(response.headers["Content-Length"], "1") self.assertEqual(response.body, byte_body) # Mixing unicode in headers and byte string bodies shouldn't # break anything - response = self.fetch("/echopost", method="POST", body=byte_body, - headers={"Content-Type": "application/blah"}, - user_agent=u"foo") + response = self.fetch( + "/echopost", + method="POST", + body=byte_body, + headers={"Content-Type": "application/blah"}, + user_agent=u"foo", + ) self.assertEqual(response.headers["Content-Length"], "1") self.assertEqual(response.body, byte_body) @@ -305,59 +416,75 @@ def test_types(self): self.assertEqual(type(response.code), int) self.assertEqual(type(response.effective_url), str) + def test_gzip(self): + # All the tests in this file
should be using gzip, but this test + # ensures that it is in fact getting compressed, and also tests + # the httpclient's decompress=False option. + # Setting Accept-Encoding manually bypasses the client's + # decompression so we can see the raw data. + response = self.fetch( + "/chunk", decompress_response=False, headers={"Accept-Encoding": "gzip"} + ) + self.assertEqual(response.headers["Content-Encoding"], "gzip") + self.assertNotEqual(response.body, b"asdfqwer") + # Our test data gets bigger when gzipped. Oops. :) + # Chunked encoding bypasses the MIN_LENGTH check. + self.assertEqual(len(response.body), 34) + f = gzip.GzipFile(mode="r", fileobj=response.buffer) + self.assertEqual(f.read(), b"asdfqwer") + + def test_invalid_gzip(self): + # test if client hangs on tricky invalid gzip + # curl/simple httpclient have different behavior (exception, logging) + with ExpectLog( + app_log, "(Uncaught exception|Exception in callback)", required=False + ): + try: + response = self.fetch("/invalid_gzip") + self.assertEqual(response.code, 200) + self.assertEqual(response.body[:14], b"Hello World 0\n") + except HTTPError: + pass # acceptable + def test_header_callback(self): first_line = [] headers = {} chunks = [] def header_callback(header_line): - if header_line.startswith('HTTP/1.1 101'): + if header_line.startswith("HTTP/1.1 101"): # Upgrading to HTTP/2 pass - elif header_line.startswith('HTTP/'): + elif header_line.startswith("HTTP/"): first_line.append(header_line) - elif header_line != '\r\n': - k, v = header_line.split(':', 1) + elif header_line != "\r\n": + k, v = header_line.split(":", 1) headers[k.lower()] = v.strip() def streaming_callback(chunk): # All header callbacks are run before any streaming callbacks, # so the header data is available to process the data as it # comes in. 
- self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8') + self.assertEqual(headers["content-type"], "text/html; charset=UTF-8") chunks.append(chunk) - self.fetch('/chunk', header_callback=header_callback, - streaming_callback=streaming_callback) + self.fetch( + "/chunk", + header_callback=header_callback, + streaming_callback=streaming_callback, + ) self.assertEqual(len(first_line), 1, first_line) - self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n') - self.assertEqual(chunks, [b'asdf', b'qwer']) - - def test_header_callback_stack_context(self): - exc_info = [] - - def error_handler(typ, value, tb): - exc_info.append((typ, value, tb)) - return True - - def header_callback(header_line): - if header_line.lower().startswith('content-type:'): - 1 / 0 - - with ExceptionStackContext(error_handler): - self.fetch('/chunk', header_callback=header_callback) - self.assertEqual(len(exc_info), 1) - self.assertIs(exc_info[0][0], ZeroDivisionError) + self.assertRegexpMatches(first_line[0], "HTTP/[0-9]\\.[0-9] 200.*\r\n") + self.assertEqual(chunks, [b"asdf", b"qwer"]) @gen_test def test_configure_defaults(self): - defaults = dict(user_agent='TestDefaultUserAgent', allow_ipv6=False) + defaults = dict(user_agent="TestDefaultUserAgent", allow_ipv6=False) # Construct a new instance of the configured client class - client = self.http_client.__class__(force_instance=True, - defaults=defaults) + client = self.http_client.__class__(force_instance=True, defaults=defaults) try: - response = yield client.fetch(self.get_url('/user_agent')) - self.assertEqual(response.body, b'TestDefaultUserAgent') + response = yield client.fetch(self.get_url("/user_agent")) + self.assertEqual(response.body, b"TestDefaultUserAgent") finally: client.close() @@ -369,84 +496,75 @@ def test_header_types(self): for value in [u"MyUserAgent", b"MyUserAgent"]: for container in [dict, HTTPHeaders]: headers = container() - headers['User-Agent'] = value - resp = self.fetch('/user_agent', headers=headers) + headers["User-Agent"] = value + resp = self.fetch("/user_agent", headers=headers) self.assertEqual( - resp.body, b"MyUserAgent", - "response=%r, value=%r, container=%r" % - (resp.body, value, container)) + resp.body, + b"MyUserAgent", + "response=%r, value=%r, container=%r" + % (resp.body, value, container), + ) def test_multi_line_headers(self): # Multi-line http headers are rare but rfc-allowed # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 sock, port = bind_unused_port() with closing(sock): + @gen.coroutine def accept_callback(conn, address): stream = IOStream(conn) request_data = yield stream.read_until(b"\r\n\r\n") if b"HTTP/1." 
not in request_data: self.skipTest("requires HTTP/1.x") - yield stream.write(b"""\ + yield stream.write( + b"""\ HTTP/1.1 200 OK X-XSS-Protection: 1; \tmode=block -""".replace(b"\n", b"\r\n")) +""".replace( + b"\n", b"\r\n" + ) + ) stream.close() - netutil.add_accept_handler(sock, accept_callback) - resp = self.fetch("http://127.0.0.1:%d/" % port) - resp.rethrow() - self.assertEqual(resp.headers['X-XSS-Protection'], "1; mode=block") - self.io_loop.remove_handler(sock.fileno()) + netutil.add_accept_handler(sock, accept_callback) # type: ignore + try: + resp = self.fetch("http://127.0.0.1:%d/" % port) + resp.rethrow() + self.assertEqual(resp.headers["X-XSS-Protection"], "1; mode=block") + finally: + self.io_loop.remove_handler(sock.fileno()) def test_304_with_content_length(self): # According to the spec 304 responses SHOULD NOT include # Content-Length or other entity headers, but some servers do it # anyway. # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5 - response = self.fetch('/304_with_content_length') + response = self.fetch("/304_with_content_length") self.assertEqual(response.code, 304) - self.assertEqual(response.headers['Content-Length'], '42') - - def test_final_callback_stack_context(self): - # The final callback should be run outside of the httpclient's - # stack_context. We want to ensure that there is not stack_context - # between the user's callback and the IOLoop, so monkey-patch - # IOLoop.handle_callback_exception and disable the test harness's - # context with a NullContext. - # Note that this does not apply to secondary callbacks (header - # and streaming_callback), as errors there must be seen as errors - # by the http client so it can clean up the connection. - exc_info = [] - - def handle_callback_exception(callback): - exc_info.append(sys.exc_info()) - self.stop() - self.io_loop.handle_callback_exception = handle_callback_exception - with NullContext(): - with ignore_deprecation(): - self.http_client.fetch(self.get_url('/hello'), - lambda response: 1 / 0) - self.wait() - self.assertEqual(exc_info[0][0], ZeroDivisionError) + self.assertEqual(response.headers["Content-Length"], "42") @gen_test def test_future_interface(self): - response = yield self.http_client.fetch(self.get_url('/hello')) - self.assertEqual(response.body, b'Hello world!') + response = yield self.http_client.fetch(self.get_url("/hello")) + self.assertEqual(response.body, b"Hello world!") @gen_test def test_future_http_error(self): with self.assertRaises(HTTPError) as context: - yield self.http_client.fetch(self.get_url('/notfound')) + yield self.http_client.fetch(self.get_url("/notfound")) + assert context.exception is not None + assert context.exception.response is not None self.assertEqual(context.exception.code, 404) self.assertEqual(context.exception.response.code, 404) @gen_test def test_future_http_error_no_raise(self): - response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False) + response = yield self.http_client.fetch( + self.get_url("/notfound"), raise_error=False + ) self.assertEqual(response.code, 404) @gen_test @@ -455,48 +573,69 @@ def test_reuse_request_from_response(self): # a _RequestProxy. # This test uses self.http_client.fetch because self.fetch calls # self.get_url on the input unconditionally. 
- url = self.get_url('/hello') + url = self.get_url("/hello") response = yield self.http_client.fetch(url) self.assertEqual(response.request.url, url) self.assertTrue(isinstance(response.request, HTTPRequest)) response2 = yield self.http_client.fetch(response.request) - self.assertEqual(response2.body, b'Hello world!') + self.assertEqual(response2.body, b"Hello world!") + + @gen_test + def test_bind_source_ip(self): + url = self.get_url("/hello") + request = HTTPRequest(url, network_interface="127.0.0.1") + response = yield self.http_client.fetch(request) + self.assertEqual(response.code, 200) + + with self.assertRaises((ValueError, HTTPError)) as context: # type: ignore + request = HTTPRequest(url, network_interface="not-interface-or-ip") + yield self.http_client.fetch(request) + self.assertIn("not-interface-or-ip", str(context.exception)) def test_all_methods(self): - for method in ['GET', 'DELETE', 'OPTIONS']: - response = self.fetch('/all_methods', method=method) + for method in ["GET", "DELETE", "OPTIONS"]: + response = self.fetch("/all_methods", method=method) self.assertEqual(response.body, utf8(method)) - for method in ['POST', 'PUT', 'PATCH']: - response = self.fetch('/all_methods', method=method, body=b'') + for method in ["POST", "PUT", "PATCH"]: + response = self.fetch("/all_methods", method=method, body=b"") self.assertEqual(response.body, utf8(method)) - response = self.fetch('/all_methods', method='HEAD') - self.assertEqual(response.body, b'') - response = self.fetch('/all_methods', method='OTHER', - allow_nonstandard_methods=True) - self.assertEqual(response.body, b'OTHER') + response = self.fetch("/all_methods", method="HEAD") + self.assertEqual(response.body, b"") + response = self.fetch( + "/all_methods", method="OTHER", allow_nonstandard_methods=True + ) + self.assertEqual(response.body, b"OTHER") def test_body_sanity_checks(self): # These methods require a body. - for method in ('POST', 'PUT', 'PATCH'): + for method in ("POST", "PUT", "PATCH"): with self.assertRaises(ValueError) as context: - self.fetch('/all_methods', method=method, raise_error=True) - self.assertIn('must not be None', str(context.exception)) + self.fetch("/all_methods", method=method, raise_error=True) + self.assertIn("must not be None", str(context.exception)) - resp = self.fetch('/all_methods', method=method, - allow_nonstandard_methods=True) + resp = self.fetch( + "/all_methods", method=method, allow_nonstandard_methods=True + ) self.assertEqual(resp.code, 200) # These methods don't allow a body. - for method in ('GET', 'DELETE', 'OPTIONS'): + for method in ("GET", "DELETE", "OPTIONS"): with self.assertRaises(ValueError) as context: - self.fetch('/all_methods', method=method, body=b'asdf', raise_error=True) - self.assertIn('must be None', str(context.exception)) + self.fetch( + "/all_methods", method=method, body=b"asdf", raise_error=True + ) + self.assertIn("must be None", str(context.exception)) # In most cases this can be overridden, but curl_httpclient # does not allow body with a GET at all. 
- if method != 'GET': - self.fetch('/all_methods', method=method, body=b'asdf', - allow_nonstandard_methods=True, raise_error=True) + if method != "GET": + self.fetch( + "/all_methods", + method=method, + body=b"asdf", + allow_nonstandard_methods=True, + raise_error=True, + ) self.assertEqual(resp.code, 200) # This test causes odd failures with the combination of @@ -516,8 +655,9 @@ def test_body_sanity_checks(self): # self.assertEqual(response.body, b"Post arg1: foo, arg2: bar") def test_put_307(self): - response = self.fetch("/redirect?status=307&url=/put", - method="PUT", body=b"hello") + response = self.fetch( + "/redirect?status=307&url=/put", method="PUT", body=b"hello" + ) response.rethrow() self.assertEqual(response.body, b"Put body: hello") @@ -527,69 +667,111 @@ def test_non_ascii_header(self): response.rethrow() self.assertEqual(response.headers["Foo"], native_str(u"\u00e9")) + def test_response_times(self): + # A few simple sanity checks of the response time fields to + # make sure they're using the right basis (between the + # wall-time and monotonic clocks). + start_time = time.time() + response = self.fetch("/hello") + response.rethrow() + self.assertGreaterEqual(response.request_time, 0) + self.assertLess(response.request_time, 1.0) + # A very crude check to make sure that start_time is based on + # wall time and not the monotonic clock. + assert response.start_time is not None + self.assertLess(abs(response.start_time - start_time), 1.0) + + for k, v in response.time_info.items(): + self.assertTrue(0 <= v < 1.0, "time_info[%s] out of bounds: %s" % (k, v)) + + def test_zero_timeout(self): + response = self.fetch("/hello", connect_timeout=0) + self.assertEqual(response.code, 200) + + response = self.fetch("/hello", request_timeout=0) + self.assertEqual(response.code, 200) + + response = self.fetch("/hello", connect_timeout=0, request_timeout=0) + self.assertEqual(response.code, 200) + + @gen_test + def test_error_after_cancel(self): + fut = self.http_client.fetch(self.get_url("/404")) + self.assertTrue(fut.cancel()) + with ExpectLog(app_log, "Exception after Future was cancelled") as el: + # We can't wait on the cancelled Future any more, so just + # let the IOLoop run until the exception gets logged (or + # not, in which case we exit the loop and ExpectLog will + # raise). 
+ for i in range(100): + yield gen.sleep(0.01) + if el.logged_stack: + break + class RequestProxyTest(unittest.TestCase): def test_request_set(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/', - user_agent='foo'), - dict()) - self.assertEqual(proxy.user_agent, 'foo') + proxy = _RequestProxy( + HTTPRequest("http://example.com/", user_agent="foo"), dict() + ) + self.assertEqual(proxy.user_agent, "foo") def test_default_set(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/'), - dict(network_interface='foo')) - self.assertEqual(proxy.network_interface, 'foo') + proxy = _RequestProxy( + HTTPRequest("http://example.com/"), dict(network_interface="foo") + ) + self.assertEqual(proxy.network_interface, "foo") def test_both_set(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/', - proxy_host='foo'), - dict(proxy_host='bar')) - self.assertEqual(proxy.proxy_host, 'foo') + proxy = _RequestProxy( + HTTPRequest("http://example.com/", proxy_host="foo"), dict(proxy_host="bar") + ) + self.assertEqual(proxy.proxy_host, "foo") def test_neither_set(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/'), - dict()) + proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict()) self.assertIs(proxy.auth_username, None) def test_bad_attribute(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/'), - dict()) + proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict()) with self.assertRaises(AttributeError): proxy.foo def test_defaults_none(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/'), None) + proxy = _RequestProxy(HTTPRequest("http://example.com/"), None) self.assertIs(proxy.auth_username, None) class HTTPResponseTestCase(unittest.TestCase): def test_str(self): - response = HTTPResponse(HTTPRequest('http://example.com'), - 200, headers={}, buffer=BytesIO()) + response = HTTPResponse( # type: ignore + HTTPRequest("http://example.com"), 200, buffer=BytesIO() + ) s = str(response) - self.assertTrue(s.startswith('HTTPResponse(')) - self.assertIn('code=200', s) + self.assertTrue(s.startswith("HTTPResponse(")) + self.assertIn("code=200", s) class SyncHTTPClientTest(unittest.TestCase): def setUp(self): - if IOLoop.configured_class().__name__ == 'TwistedIOLoop': - # TwistedIOLoop only supports the global reactor, so we can't have - # separate IOLoops for client and server threads. - raise unittest.SkipTest( - 'Sync HTTPClient not compatible with TwistedIOLoop') self.server_ioloop = IOLoop() + event = threading.Event() @gen.coroutine def init_server(): sock, self.port = bind_unused_port() - app = Application([('/', HelloWorldHandler)]) + app = Application([("/", HelloWorldHandler)]) self.server = HTTPServer(app) self.server.add_socket(sock) - self.server_ioloop.run_sync(init_server) + event.set() - self.server_thread = threading.Thread(target=self.server_ioloop.start) + def start(): + self.server_ioloop.run_sync(init_server) + self.server_ioloop.start() + + self.server_thread = threading.Thread(target=start) self.server_thread.start() + event.wait() self.http_client = HTTPClient() @@ -604,61 +786,95 @@ def stop_server(): @gen.coroutine def slow_stop(): + yield self.server.close_all_connections() # The number of iterations is difficult to predict. Typically, # one is sufficient, although sometimes it needs more. 
for i in range(5): yield self.server_ioloop.stop() + self.server_ioloop.add_callback(slow_stop) + self.server_ioloop.add_callback(stop_server) self.server_thread.join() self.http_client.close() self.server_ioloop.close(all_fds=True) def get_url(self, path): - return 'http://127.0.0.1:%d%s' % (self.port, path) + return "http://127.0.0.1:%d%s" % (self.port, path) def test_sync_client(self): - response = self.http_client.fetch(self.get_url('/')) - self.assertEqual(b'Hello world!', response.body) + response = self.http_client.fetch(self.get_url("/")) + self.assertEqual(b"Hello world!", response.body) def test_sync_client_error(self): # Synchronous HTTPClient raises errors directly; no need for # response.rethrow() with self.assertRaises(HTTPError) as assertion: - self.http_client.fetch(self.get_url('/notfound')) + self.http_client.fetch(self.get_url("/notfound")) self.assertEqual(assertion.exception.code, 404) +class SyncHTTPClientSubprocessTest(unittest.TestCase): + def test_destructor_log(self): + # Regression test for + # https://github.com/tornadoweb/tornado/issues/2539 + # + # In the past, the following program would log an + # "inconsistent AsyncHTTPClient cache" error from a destructor + # when the process is shutting down. The shutdown process is + # subtle and I don't fully understand it; the failure does not + # manifest if that lambda isn't there or is a simpler object + # like an int (nor does it manifest in the tornado test suite + # as a whole, which is why we use this subprocess). + proc = subprocess.run( + [ + sys.executable, + "-c", + "from tornado.httpclient import HTTPClient; f = lambda: None; c = HTTPClient()", + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + check=True, + timeout=5, + ) + if proc.stdout: + print("STDOUT:") + print(to_unicode(proc.stdout)) + if proc.stdout: + self.fail("subprocess produced unexpected output") + + class HTTPRequestTestCase(unittest.TestCase): def test_headers(self): - request = HTTPRequest('http://example.com', headers={'foo': 'bar'}) - self.assertEqual(request.headers, {'foo': 'bar'}) + request = HTTPRequest("http://example.com", headers={"foo": "bar"}) + self.assertEqual(request.headers, {"foo": "bar"}) def test_headers_setter(self): - request = HTTPRequest('http://example.com') - request.headers = {'bar': 'baz'} - self.assertEqual(request.headers, {'bar': 'baz'}) + request = HTTPRequest("http://example.com") + request.headers = {"bar": "baz"} # type: ignore + self.assertEqual(request.headers, {"bar": "baz"}) def test_null_headers_setter(self): - request = HTTPRequest('http://example.com') - request.headers = None + request = HTTPRequest("http://example.com") + request.headers = None # type: ignore self.assertEqual(request.headers, {}) def test_body(self): - request = HTTPRequest('http://example.com', body='foo') - self.assertEqual(request.body, utf8('foo')) + request = HTTPRequest("http://example.com", body="foo") + self.assertEqual(request.body, utf8("foo")) def test_body_setter(self): - request = HTTPRequest('http://example.com') - request.body = 'foo' - self.assertEqual(request.body, utf8('foo')) + request = HTTPRequest("http://example.com") + request.body = "foo" # type: ignore + self.assertEqual(request.body, utf8("foo")) def test_if_modified_since(self): http_date = datetime.datetime.utcnow() - request = HTTPRequest('http://example.com', if_modified_since=http_date) - self.assertEqual(request.headers, - {'If-Modified-Since': format_timestamp(http_date)}) + request = HTTPRequest("http://example.com", 
if_modified_since=http_date) + self.assertEqual( + request.headers, {"If-Modified-Since": format_timestamp(http_date)} + ) class HTTPErrorTestCase(unittest.TestCase): @@ -674,7 +890,7 @@ def test_plain_error(self): self.assertEqual(repr(e), "HTTP 403: Forbidden") def test_error_with_response(self): - resp = HTTPResponse(HTTPRequest('http://example.com/'), 403) + resp = HTTPResponse(HTTPRequest("http://example.com/"), 403) with self.assertRaises(HTTPError) as cm: resp.rethrow() e = cm.exception diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 4bca757a66..614dec7b8f 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -1,35 +1,58 @@ -from __future__ import absolute_import, division, print_function - from tornado import gen, netutil -from tornado.concurrent import Future -from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str +from tornado.escape import ( + json_decode, + json_encode, + utf8, + _unicode, + recursive_unicode, + native_str, +) from tornado.http1connection import HTTP1Connection from tornado.httpclient import HTTPError from tornado.httpserver import HTTPServer -from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine # noqa: E501 +from tornado.httputil import ( + HTTPHeaders, + HTTPMessageDelegate, + HTTPServerConnectionDelegate, + ResponseStartLine, +) from tornado.iostream import IOStream from tornado.locks import Event from tornado.log import gen_log from tornado.netutil import ssl_options_to_context from tornado.simple_httpclient import SimpleAsyncHTTPClient -from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test # noqa: E501 -from tornado.test.util import unittest, skipOnTravis, ignore_deprecation -from tornado.web import Application, RequestHandler, asynchronous, stream_request_body +from tornado.testing import ( + AsyncHTTPTestCase, + AsyncHTTPSTestCase, + AsyncTestCase, + ExpectLog, + gen_test, +) +from tornado.test.util import skipOnTravis +from tornado.web import Application, RequestHandler, stream_request_body from contextlib import closing import datetime import gzip +import logging import os import shutil import socket import ssl import sys import tempfile +import unittest +import urllib.parse from io import BytesIO +import typing + +if typing.TYPE_CHECKING: + from typing import Dict, List # noqa: F401 + -def read_stream_body(stream, callback): - """Reads an HTTP response from `stream` and runs callback with its +async def read_stream_body(stream): + """Reads an HTTP response from `stream` and returns a tuple of its start_line, headers and body.""" chunks = [] @@ -42,15 +65,19 @@ def data_received(self, chunk): chunks.append(chunk) def finish(self): - conn.detach() - callback((self.start_line, self.headers, b''.join(chunks))) + conn.detach() # type: ignore + conn = HTTP1Connection(stream, True) - conn.read_response(Delegate()) + delegate = Delegate() + await conn.read_response(delegate) + return delegate.start_line, delegate.headers, b"".join(chunks) class HandlerBaseTestCase(AsyncHTTPTestCase): + Handler = None + def get_app(self): - return Application([('/', self.__class__.Handler)]) + return Application([("/", self.__class__.Handler)]) def fetch_json(self, *args, **kwargs): response = self.fetch(*args, **kwargs) @@ -77,55 +104,58 @@ def post(self): # introduced in python3.2, it was present but undocumented in # python 2.7 skipIfOldSSL = unittest.skipIf( 
- getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0), - "old version of ssl module and/or openssl") + getattr(ssl, "OPENSSL_VERSION_INFO", (0, 0)) < (1, 0), + "old version of ssl module and/or openssl", +) class BaseSSLTest(AsyncHTTPSTestCase): def get_app(self): - return Application([('/', HelloWorldRequestHandler, - dict(protocol="https"))]) + return Application([("/", HelloWorldRequestHandler, dict(protocol="https"))]) class SSLTestMixin(object): def get_ssl_options(self): - return dict(ssl_version=self.get_ssl_version(), # type: ignore - **AsyncHTTPSTestCase.get_ssl_options()) + return dict( + ssl_version=self.get_ssl_version(), + **AsyncHTTPSTestCase.default_ssl_options() + ) def get_ssl_version(self): raise NotImplementedError() - def test_ssl(self): - response = self.fetch('/') + def test_ssl(self: typing.Any): + response = self.fetch("/") self.assertEqual(response.body, b"Hello world") - def test_large_post(self): - response = self.fetch('/', - method='POST', - body='A' * 5000) + def test_large_post(self: typing.Any): + response = self.fetch("/", method="POST", body="A" * 5000) self.assertEqual(response.body, b"Got 5000 bytes in POST") - def test_non_ssl_request(self): + def test_non_ssl_request(self: typing.Any): # Make sure the server closes the connection when it gets a non-ssl # connection, rather than waiting for a timeout or otherwise # misbehaving. - with ExpectLog(gen_log, '(SSL Error|uncaught exception)'): - with ExpectLog(gen_log, 'Uncaught exception', required=False): - with self.assertRaises((IOError, HTTPError)): + with ExpectLog(gen_log, "(SSL Error|uncaught exception)"): + with ExpectLog(gen_log, "Uncaught exception", required=False): + with self.assertRaises((IOError, HTTPError)): # type: ignore self.fetch( - self.get_url("/").replace('https:', 'http:'), + self.get_url("/").replace("https:", "http:"), request_timeout=3600, connect_timeout=3600, - raise_error=True) + raise_error=True, + ) - def test_error_logging(self): + def test_error_logging(self: typing.Any): # No stack traces are logged for SSL errors. - with ExpectLog(gen_log, 'SSL Error') as expect_log: - with self.assertRaises((IOError, HTTPError)): - self.fetch(self.get_url("/").replace("https:", "http:"), - raise_error=True) + with ExpectLog(gen_log, "SSL Error") as expect_log: + with self.assertRaises((IOError, HTTPError)): # type: ignore + self.fetch( + self.get_url("/").replace("https:", "http:"), raise_error=True + ) self.assertFalse(expect_log.logged_stack) + # Python's SSL implementation differs significantly between versions. 
# For example, SSLv3 and TLSv1 throw an exception if you try to read # from the socket before the handshake is complete, but the default @@ -151,8 +181,7 @@ def get_ssl_version(self): class SSLContextTest(BaseSSLTest, SSLTestMixin): def get_ssl_options(self): - context = ssl_options_to_context( - AsyncHTTPSTestCase.get_ssl_options(self)) + context = ssl_options_to_context(AsyncHTTPSTestCase.get_ssl_options(self)) assert isinstance(context, ssl.SSLContext) return context @@ -160,85 +189,108 @@ def get_ssl_options(self): class BadSSLOptionsTest(unittest.TestCase): def test_missing_arguments(self): application = Application() - self.assertRaises(KeyError, HTTPServer, application, ssl_options={ - "keyfile": "/__missing__.crt", - }) + self.assertRaises( + KeyError, + HTTPServer, + application, + ssl_options={"keyfile": "/__missing__.crt"}, + ) def test_missing_key(self): """A missing SSL key should cause an immediate exception.""" application = Application() module_dir = os.path.dirname(__file__) - existing_certificate = os.path.join(module_dir, 'test.crt') - existing_key = os.path.join(module_dir, 'test.key') - - self.assertRaises((ValueError, IOError), - HTTPServer, application, ssl_options={ - "certfile": "/__mising__.crt", - }) - self.assertRaises((ValueError, IOError), - HTTPServer, application, ssl_options={ - "certfile": existing_certificate, - "keyfile": "/__missing__.key" - }) + existing_certificate = os.path.join(module_dir, "test.crt") + existing_key = os.path.join(module_dir, "test.key") + + self.assertRaises( + (ValueError, IOError), + HTTPServer, + application, + ssl_options={"certfile": "/__missing__.crt"}, + ) + self.assertRaises( + (ValueError, IOError), + HTTPServer, + application, + ssl_options={ + "certfile": existing_certificate, + "keyfile": "/__missing__.key", + }, + ) # This actually works because both files exist - HTTPServer(application, ssl_options={ - "certfile": existing_certificate, - "keyfile": existing_key, - }) + HTTPServer( + application, + ssl_options={"certfile": existing_certificate, "keyfile": existing_key}, + ) class MultipartTestHandler(RequestHandler): def post(self): - self.finish({"header": self.request.headers["X-Header-Encoding-Test"], - "argument": self.get_argument("argument"), - "filename": self.request.files["files"][0].filename, - "filebody": _unicode(self.request.files["files"][0]["body"]), - }) + self.finish( + { + "header": self.request.headers["X-Header-Encoding-Test"], + "argument": self.get_argument("argument"), + "filename": self.request.files["files"][0].filename, + "filebody": _unicode(self.request.files["files"][0]["body"]), + } + ) # This test is also called from wsgi_test class HTTPConnectionTest(AsyncHTTPTestCase): def get_handlers(self): - return [("/multipart", MultipartTestHandler), - ("/hello", HelloWorldRequestHandler)] + return [ + ("/multipart", MultipartTestHandler), + ("/hello", HelloWorldRequestHandler), + ] def get_app(self): return Application(self.get_handlers()) def raw_fetch(self, headers, body, newline=b"\r\n"): with closing(IOStream(socket.socket())) as stream: - with ignore_deprecation(): - stream.connect(('127.0.0.1', self.get_http_port()), self.stop) - self.wait() + self.io_loop.run_sync( + lambda: stream.connect(("127.0.0.1", self.get_http_port())) + ) stream.write( - newline.join(headers + - [utf8("Content-Length: %d" % len(body))]) + - newline + newline + body) - read_stream_body(stream, self.stop) - start_line, headers, body = self.wait() + newline.join(headers + [utf8("Content-Length: %d" % len(body))]) +
newline + + newline + + body + ) + start_line, headers, body = self.io_loop.run_sync( + lambda: read_stream_body(stream) + ) return body def test_multipart_form(self): # Encodings here are tricky: Headers are latin1, bodies can be # anything (we use utf8 by default). - response = self.raw_fetch([ - b"POST /multipart HTTP/1.0", - b"Content-Type: multipart/form-data; boundary=1234567890", - b"X-Header-encoding-test: \xe9", - ], - b"\r\n".join([ - b"Content-Disposition: form-data; name=argument", - b"", - u"\u00e1".encode("utf-8"), - b"--1234567890", - u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"), - b"", - u"\u00fa".encode("utf-8"), - b"--1234567890--", - b"", - ])) + response = self.raw_fetch( + [ + b"POST /multipart HTTP/1.0", + b"Content-Type: multipart/form-data; boundary=1234567890", + b"X-Header-encoding-test: \xe9", + ], + b"\r\n".join( + [ + b"Content-Disposition: form-data; name=argument", + b"", + u"\u00e1".encode("utf-8"), + b"--1234567890", + u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode( + "utf8" + ), + b"", + u"\u00fa".encode("utf-8"), + b"--1234567890--", + b"", + ] + ), + ) data = json_decode(response) self.assertEqual(u"\u00e9", data["header"]) self.assertEqual(u"\u00e1", data["argument"]) @@ -248,9 +300,8 @@ def test_multipart_form(self): def test_newlines(self): # We support both CRLF and bare LF as line separators. for newline in (b"\r\n", b"\n"): - response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"", - newline=newline) - self.assertEqual(response, b'Hello world') + response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"", newline=newline) + self.assertEqual(response, b"Hello world") @gen_test def test_100_continue(self): @@ -259,19 +310,24 @@ def test_100_continue(self): # headers, and then the real response after the body. 
stream = IOStream(socket.socket()) yield stream.connect(("127.0.0.1", self.get_http_port())) - yield stream.write(b"\r\n".join([ - b"POST /hello HTTP/1.1", - b"Content-Length: 1024", - b"Expect: 100-continue", - b"Connection: close", - b"\r\n"])) + yield stream.write( + b"\r\n".join( + [ + b"POST /hello HTTP/1.1", + b"Content-Length: 1024", + b"Expect: 100-continue", + b"Connection: close", + b"\r\n", + ] + ) + ) data = yield stream.read_until(b"\r\n\r\n") self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data) stream.write(b"a" * 1024) first_line = yield stream.read_until(b"\r\n") self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line) header_data = yield stream.read_until(b"\r\n\r\n") - headers = HTTPHeaders.parse(native_str(header_data.decode('latin1'))) + headers = HTTPHeaders.parse(native_str(header_data.decode("latin1"))) body = yield stream.read_bytes(int(headers["Content-Length"])) self.assertEqual(body, b"Got 1024 bytes in POST") stream.close() @@ -287,32 +343,34 @@ def post(self): class TypeCheckHandler(RequestHandler): def prepare(self): - self.errors = {} + self.errors = {} # type: Dict[str, str] fields = [ - ('method', str), - ('uri', str), - ('version', str), - ('remote_ip', str), - ('protocol', str), - ('host', str), - ('path', str), - ('query', str), + ("method", str), + ("uri", str), + ("version", str), + ("remote_ip", str), + ("protocol", str), + ("host", str), + ("path", str), + ("query", str), ] for field, expected_type in fields: self.check_type(field, getattr(self.request, field), expected_type) - self.check_type('header_key', list(self.request.headers.keys())[0], str) - self.check_type('header_value', list(self.request.headers.values())[0], str) + self.check_type("header_key", list(self.request.headers.keys())[0], str) + self.check_type("header_value", list(self.request.headers.values())[0], str) - self.check_type('cookie_key', list(self.request.cookies.keys())[0], str) - self.check_type('cookie_value', list(self.request.cookies.values())[0].value, str) + self.check_type("cookie_key", list(self.request.cookies.keys())[0], str) + self.check_type( + "cookie_value", list(self.request.cookies.values())[0].value, str + ) # secure cookies - self.check_type('arg_key', list(self.request.arguments.keys())[0], str) - self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes) + self.check_type("arg_key", list(self.request.arguments.keys())[0], str) + self.check_type("arg_value", list(self.request.arguments.values())[0][0], bytes) def post(self): - self.check_type('body', self.request.body, bytes) + self.check_type("body", self.request.body, bytes) self.write(self.errors) def get(self): @@ -321,16 +379,33 @@ def get(self): def check_type(self, name, obj, expected_type): actual_type = type(obj) if expected_type != actual_type: - self.errors[name] = "expected %s, got %s" % (expected_type, - actual_type) + self.errors[name] = "expected %s, got %s" % (expected_type, actual_type) + + +class PostEchoHandler(RequestHandler): + def post(self, *path_args): + self.write(dict(echo=self.get_argument("data"))) + + +class PostEchoGBKHandler(PostEchoHandler): + def decode_argument(self, value, name=None): + try: + return value.decode("gbk") + except Exception: + raise HTTPError(400, "invalid gbk bytes: %r" % value) class HTTPServerTest(AsyncHTTPTestCase): def get_app(self): - return Application([("/echo", EchoHandler), - ("/typecheck", TypeCheckHandler), - ("//doubleslash", EchoHandler), - ]) + return Application( + [ + ("/echo", EchoHandler), + 
("/typecheck", TypeCheckHandler), + ("//doubleslash", EchoHandler), + ("/post_utf8", PostEchoHandler), + ("/post_gbk", PostEchoGBKHandler), + ] + ) def test_query_string_encoding(self): response = self.fetch("/echo?foo=%C3%A9") @@ -353,7 +428,9 @@ def test_types(self): data = json_decode(response.body) self.assertEqual(data, {}) - response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers) + response = self.fetch( + "/typecheck", method="POST", body="foo=bar", headers=headers + ) data = json_decode(response.body) self.assertEqual(data, {}) @@ -365,36 +442,38 @@ def test_double_slash(self): self.assertEqual(200, response.code) self.assertEqual(json_decode(response.body), {}) - def test_malformed_body(self): - # parse_qs is pretty forgiving, but it will fail on python 3 - # if the data is not utf8. On python 2 parse_qs will work, - # but then the recursive_unicode call in EchoHandler will - # fail. - if str is bytes: - return - with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'): - response = self.fetch( - '/echo', method="POST", - headers={'Content-Type': 'application/x-www-form-urlencoded'}, - body=b'\xe9') - self.assertEqual(200, response.code) - self.assertEqual(b'{}', response.body) + def test_post_encodings(self): + headers = {"Content-Type": "application/x-www-form-urlencoded"} + uni_text = "chinese: \u5f20\u4e09" + for enc in ("utf8", "gbk"): + for quote in (True, False): + with self.subTest(enc=enc, quote=quote): + bin_text = uni_text.encode(enc) + if quote: + bin_text = urllib.parse.quote(bin_text).encode("ascii") + response = self.fetch( + "/post_" + enc, + method="POST", + headers=headers, + body=(b"data=" + bin_text), + ) + self.assertEqual(json_decode(response.body), {"echo": uni_text}) class HTTPServerRawTest(AsyncHTTPTestCase): def get_app(self): - return Application([ - ('/echo', EchoHandler), - ]) + return Application([("/echo", EchoHandler)]) def setUp(self): - super(HTTPServerRawTest, self).setUp() + super().setUp() self.stream = IOStream(socket.socket()) - self.io_loop.run_sync(lambda: self.stream.connect(('127.0.0.1', self.get_http_port()))) + self.io_loop.run_sync( + lambda: self.stream.connect(("127.0.0.1", self.get_http_port())) + ) def tearDown(self): self.stream.close() - super(HTTPServerRawTest, self).tearDown() + super().tearDown() def test_empty_request(self): self.stream.close() @@ -402,34 +481,38 @@ def test_empty_request(self): self.wait() def test_malformed_first_line_response(self): - with ExpectLog(gen_log, '.*Malformed HTTP request line'): - self.stream.write(b'asdf\r\n\r\n') - read_stream_body(self.stream, self.stop) - start_line, headers, response = self.wait() - self.assertEqual('HTTP/1.1', start_line.version) + with ExpectLog(gen_log, ".*Malformed HTTP request line", level=logging.INFO): + self.stream.write(b"asdf\r\n\r\n") + start_line, headers, response = self.io_loop.run_sync( + lambda: read_stream_body(self.stream) + ) + self.assertEqual("HTTP/1.1", start_line.version) self.assertEqual(400, start_line.code) - self.assertEqual('Bad Request', start_line.reason) + self.assertEqual("Bad Request", start_line.reason) def test_malformed_first_line_log(self): - with ExpectLog(gen_log, '.*Malformed HTTP request line'): - self.stream.write(b'asdf\r\n\r\n') + with ExpectLog(gen_log, ".*Malformed HTTP request line", level=logging.INFO): + self.stream.write(b"asdf\r\n\r\n") # TODO: need an async version of ExpectLog so we don't need # hard-coded timeouts here. 
- self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), - self.stop) + self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop) self.wait() def test_malformed_headers(self): - with ExpectLog(gen_log, '.*Malformed HTTP message.*no colon in header line'): - self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n') - self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), - self.stop) + with ExpectLog( + gen_log, + ".*Malformed HTTP message.*no colon in header line", + level=logging.INFO, + ): + self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n") + self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop) self.wait() def test_chunked_request_body(self): # Chunked requests are not widely supported and we don't have a way # to generate them in AsyncHTTPClient, but HTTPServer will read them. - self.stream.write(b"""\ + self.stream.write( + b"""\ POST /echo HTTP/1.1 Transfer-Encoding: chunked Content-Type: application/x-www-form-urlencoded @@ -440,15 +523,20 @@ def test_chunked_request_body(self): bar 0 -""".replace(b"\n", b"\r\n")) - read_stream_body(self.stream, self.stop) - start_line, headers, response = self.wait() - self.assertEqual(json_decode(response), {u'foo': [u'bar']}) +""".replace( + b"\n", b"\r\n" + ) + ) + start_line, headers, response = self.io_loop.run_sync( + lambda: read_stream_body(self.stream) + ) + self.assertEqual(json_decode(response), {u"foo": [u"bar"]}) def test_chunked_request_uppercase(self): # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is # case-insensitive. - self.stream.write(b"""\ + self.stream.write( + b"""\ POST /echo HTTP/1.1 Transfer-Encoding: Chunked Content-Type: application/x-www-form-urlencoded @@ -459,118 +547,136 @@ def test_chunked_request_uppercase(self): bar 0 -""".replace(b"\n", b"\r\n")) - read_stream_body(self.stream, self.stop) - start_line, headers, response = self.wait() - self.assertEqual(json_decode(response), {u'foo': [u'bar']}) +""".replace( + b"\n", b"\r\n" + ) + ) + start_line, headers, response = self.io_loop.run_sync( + lambda: read_stream_body(self.stream) + ) + self.assertEqual(json_decode(response), {u"foo": [u"bar"]}) @gen_test def test_invalid_content_length(self): - with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'): - self.stream.write(b"""\ + with ExpectLog( + gen_log, ".*Only integer Content-Length is allowed", level=logging.INFO + ): + self.stream.write( + b"""\ POST /echo HTTP/1.1 Content-Length: foo bar -""".replace(b"\n", b"\r\n")) +""".replace( + b"\n", b"\r\n" + ) + ) yield self.stream.read_until_close() class XHeaderTest(HandlerBaseTestCase): class Handler(RequestHandler): def get(self): - self.set_header('request-version', self.request.version) - self.write(dict(remote_ip=self.request.remote_ip, - remote_protocol=self.request.protocol)) + self.set_header("request-version", self.request.version) + self.write( + dict( + remote_ip=self.request.remote_ip, + remote_protocol=self.request.protocol, + ) + ) def get_httpserver_options(self): - return dict(xheaders=True, trusted_downstream=['5.5.5.5']) + return dict(xheaders=True, trusted_downstream=["5.5.5.5"]) def test_ip_headers(self): self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1") valid_ipv4 = {"X-Real-IP": "4.4.4.4"} self.assertEqual( - self.fetch_json("/", headers=valid_ipv4)["remote_ip"], - "4.4.4.4") + self.fetch_json("/", headers=valid_ipv4)["remote_ip"], "4.4.4.4" + ) valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4"} self.assertEqual( - self.fetch_json("/", 
                            headers=valid_ipv4_list)["remote_ip"],
-            "4.4.4.4")
+            self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"], "4.4.4.4"
+        )

         valid_ipv6 = {"X-Real-IP": "2620:0:1cfe:face:b00c::3"}
         self.assertEqual(
             self.fetch_json("/", headers=valid_ipv6)["remote_ip"],
-            "2620:0:1cfe:face:b00c::3")
+            "2620:0:1cfe:face:b00c::3",
+        )

         valid_ipv6_list = {"X-Forwarded-For": "::1, 2620:0:1cfe:face:b00c::3"}
         self.assertEqual(
             self.fetch_json("/", headers=valid_ipv6_list)["remote_ip"],
-            "2620:0:1cfe:face:b00c::3")
+            "2620:0:1cfe:face:b00c::3",
+        )

         invalid_chars = {"X-Real-IP": "4.4.4.4<script>"}
-        return ''.join('<script src="' + escape.xhtml_escape(p) +
-                       '" type="text/javascript"></script>'
-                       for p in paths)
+        return "".join(
+            '<script src="' + escape.xhtml_escape(p) + '" type="text/javascript"></script>'
+            for p in paths
+        )

-    def render_embed_js(self, js_embed):
+    def render_embed_js(self, js_embed: Iterable[bytes]) -> bytes:
         """Default method used to render the final embedded js for the
         rendered webpage.

         Override this method in a sub-classed controller to change the output.
         """
-        return b'<script type="text/javascript">\n//<![CDATA[\n' + \
-            b'\n'.join(js_embed) + b'\n//]]>\n</script>'
+        return (
+            b'<script type="text/javascript">\n//<![CDATA[\n'
+            + b"\n".join(js_embed)
+            + b"\n//]]>\n</script>"
+        )

-    def render_linked_css(self, css_files):
+    def render_linked_css(self, css_files: Iterable[str]) -> str:
         """Default method used to render the final css links for the
         rendered webpage.

         Override this method in a sub-classed controller to change the output.
         """
         paths = []
-        unique_paths = set()
+        unique_paths = set()  # type: Set[str]

         for path in css_files:
             if not is_absolute(path):
@@ -852,20 +967,21 @@ def render_linked_css(self, css_files):
                 paths.append(path)
                 unique_paths.add(path)

-        return ''.join('<link href="' + escape.xhtml_escape(p) + '" '
-                       'type="text/css" rel="stylesheet"/>'
-                       for p in paths)
+        return "".join(
+            '<link href="' + escape.xhtml_escape(p) + '" type="text/css" rel="stylesheet"/>'
+            for p in paths
+        )

-    def render_embed_css(self, css_embed):
+    def render_embed_css(self, css_embed: Iterable[bytes]) -> bytes:
         """Default method used to render the final embedded css for the
         rendered webpage.

         Override this method in a sub-classed controller to change the output.
         """
-        return b'<style type="text/css">\n' + b'\n'.join(css_embed) + b'\n</style>'
+        return b'<style type="text/css">\n' + b"\n".join(css_embed) + b"\n</style>"

-    def render_string(self, template_name, **kwargs):
+    def render_string(self, template_name: str, **kwargs: Any) -> bytes:
         """Generate the given template with the given arguments.

         We return the generated byte string (in utf8). To generate and
@@ -876,8 +992,9 @@ def render_string(self, template_name, **kwargs):
         if not template_path:
             frame = sys._getframe(0)
             web_file = frame.f_code.co_filename
-            while frame.f_code.co_filename == web_file:
+            while frame.f_code.co_filename == web_file and frame.f_back is not None:
                 frame = frame.f_back
+            assert frame.f_code.co_filename is not None
             template_path = os.path.dirname(frame.f_code.co_filename)
         with RequestHandler._template_loader_lock:
             if template_path not in RequestHandler._template_loaders:
@@ -890,7 +1007,7 @@ def render_string(self, template_name, **kwargs):
         namespace.update(kwargs)
         return t.generate(**namespace)

-    def get_template_namespace(self):
+    def get_template_namespace(self) -> Dict[str, Any]:
         """Returns a dictionary to be used as the default template namespace.

         May be overridden by subclasses to add or modify values.
@@ -908,12 +1025,12 @@ def get_template_namespace(self):
             pgettext=self.locale.pgettext,
             static_url=self.static_url,
             xsrf_form_html=self.xsrf_form_html,
-            reverse_url=self.reverse_url
+            reverse_url=self.reverse_url,
         )
         namespace.update(self.ui)
         return namespace

-    def create_template_loader(self, template_path):
+    def create_template_loader(self, template_path: str) -> template.BaseLoader:
         """Returns a new template loader for the given path.

         May be overridden by subclasses.
By default returns a @@ -934,30 +1051,33 @@ def create_template_loader(self, template_path): kwargs["whitespace"] = settings["template_whitespace"] return template.Loader(template_path, **kwargs) - def flush(self, include_footers=False, callback=None): + def flush(self, include_footers: bool = False) -> "Future[None]": """Flushes the current output buffer to the network. - The ``callback`` argument, if given, can be used for flow control: - it will be run when all flushed data has been written to the socket. - Note that only one flush callback can be outstanding at a time; - if another flush occurs before the previous flush's callback - has been run, the previous callback will be discarded. - .. versionchanged:: 4.0 Now returns a `.Future` if no callback is given. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. """ + assert self.request.connection is not None chunk = b"".join(self._write_buffer) self._write_buffer = [] if not self._headers_written: self._headers_written = True for transform in self._transforms: - self._status_code, self._headers, chunk = \ - transform.transform_first_chunk( - self._status_code, self._headers, - chunk, include_footers) + assert chunk is not None + ( + self._status_code, + self._headers, + chunk, + ) = transform.transform_first_chunk( + self._status_code, self._headers, chunk, include_footers + ) # Ignore the chunk and only write the headers for HEAD requests if self.request.method == "HEAD": - chunk = None + chunk = b"" # Finalize the cookie headers (which have been stored in a side # object so an outgoing cookie could be overwritten before it @@ -966,24 +1086,36 @@ def flush(self, include_footers=False, callback=None): for cookie in self._new_cookie.values(): self.add_header("Set-Cookie", cookie.OutputString(None)) - start_line = httputil.ResponseStartLine('', - self._status_code, - self._reason) + start_line = httputil.ResponseStartLine("", self._status_code, self._reason) return self.request.connection.write_headers( - start_line, self._headers, chunk, callback=callback) + start_line, self._headers, chunk + ) else: for transform in self._transforms: chunk = transform.transform_chunk(chunk, include_footers) # Ignore the chunk and only write the headers for HEAD requests if self.request.method != "HEAD": - return self.request.connection.write(chunk, callback=callback) + return self.request.connection.write(chunk) else: - future = Future() + future = Future() # type: Future[None] future.set_result(None) return future - def finish(self, chunk=None): - """Finishes this response, ending the HTTP request.""" + def finish(self, chunk: Optional[Union[str, bytes, dict]] = None) -> "Future[None]": + """Finishes this response, ending the HTTP request. + + Passing a ``chunk`` to ``finish()`` is equivalent to passing that + chunk to ``write()`` and then calling ``finish()`` with no arguments. + + Returns a `.Future` which may optionally be awaited to track the sending + of the response to the client. This `.Future` resolves when all the response + data has been sent, and raises an error if the connection is closed before all + data can be sent. + + .. versionchanged:: 5.1 + + Now returns a `.Future` instead of ``None``. + """ if self._finished: raise RuntimeError("finish() called twice") @@ -993,41 +1125,60 @@ def finish(self, chunk=None): # Automatically support ETags and add the Content-Length header if # we have not flushed any content yet. 
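With the ``callback`` argument gone, the `.Future` returned by ``flush()`` is the only flow-control mechanism: awaiting it pauses the producer until the chunk has been written to the connection. A minimal sketch (the handler name is illustrative):

```py
from tornado.web import RequestHandler


class StreamingHandler(RequestHandler):
    async def get(self):
        for i in range(10):
            self.write("chunk %d\n" % i)
            # Replaces the removed callback argument: resume only after
            # this chunk has been flushed out to the socket.
            await self.flush()
```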
if not self._headers_written: - if (self._status_code == 200 and - self.request.method in ("GET", "HEAD") and - "Etag" not in self._headers): + if ( + self._status_code == 200 + and self.request.method in ("GET", "HEAD") + and "Etag" not in self._headers + ): self.set_etag_header() if self.check_etag_header(): self._write_buffer = [] self.set_status(304) - if (self._status_code in (204, 304) or - (self._status_code >= 100 and self._status_code < 200)): - assert not self._write_buffer, "Cannot send body with %s" % self._status_code - self._clear_headers_for_304() + if self._status_code in (204, 304) or (100 <= self._status_code < 200): + assert not self._write_buffer, ( + "Cannot send body with %s" % self._status_code + ) + self._clear_representation_headers() elif "Content-Length" not in self._headers: content_length = sum(len(part) for part in self._write_buffer) self.set_header("Content-Length", content_length) - if hasattr(self.request, "connection"): - # Now that the request is finished, clear the callback we - # set on the HTTPConnection (which would otherwise prevent the - # garbage collection of the RequestHandler when there - # are keepalive connections) - self.request.connection.set_close_callback(None) + assert self.request.connection is not None + # Now that the request is finished, clear the callback we + # set on the HTTPConnection (which would otherwise prevent the + # garbage collection of the RequestHandler when there + # are keepalive connections) + self.request.connection.set_close_callback(None) # type: ignore - self.flush(include_footers=True) + future = self.flush(include_footers=True) self.request.connection.finish() self._log() self._finished = True self.on_finish() self._break_cycles() + return future - def _break_cycles(self): + def detach(self) -> iostream.IOStream: + """Take control of the underlying stream. + + Returns the underlying `.IOStream` object and stops all + further HTTP processing. Intended for implementing protocols + like websockets that tunnel over an HTTP handshake. + + This method is only supported when HTTP/1.1 is used. + + .. versionadded:: 5.1 + """ + self._finished = True + # TODO: add detach to HTTPConnection? + return self.request.connection.detach() # type: ignore + + def _break_cycles(self) -> None: # Break up a reference cycle between this handler and the # _ui_module closures to allow for faster GC on CPython. - self.ui = None + self.ui = None # type: ignore - def send_error(self, status_code=500, **kwargs): + def send_error(self, status_code: int = 500, **kwargs: Any) -> None: """Sends the given HTTP error code to the browser. If `flush()` has already been called, it is not possible to send @@ -1048,14 +1199,13 @@ def send_error(self, status_code=500, **kwargs): try: self.finish() except Exception: - gen_log.error("Failed to flush partial response", - exc_info=True) + gen_log.error("Failed to flush partial response", exc_info=True) return self.clear() - reason = kwargs.get('reason') - if 'exc_info' in kwargs: - exception = kwargs['exc_info'][1] + reason = kwargs.get("reason") + if "exc_info" in kwargs: + exception = kwargs["exc_info"][1] if isinstance(exception, HTTPError) and exception.reason: reason = exception.reason self.set_status(status_code, reason=reason) @@ -1066,7 +1216,7 @@ def send_error(self, status_code=500, **kwargs): if not self._finished: self.finish() - def write_error(self, status_code, **kwargs): + def write_error(self, status_code: int, **kwargs: Any) -> None: """Override to implement custom error pages. 
        ``write_error`` may call `write`, `render`, `set_header`, etc
@@ -1080,19 +1230,19 @@ def write_error(self, status_code, **kwargs):
         """
         if self.settings.get("serve_traceback") and "exc_info" in kwargs:
             # in debug mode, try to send a traceback
-            self.set_header('Content-Type', 'text/plain')
+            self.set_header("Content-Type", "text/plain")
             for line in traceback.format_exception(*kwargs["exc_info"]):
                 self.write(line)
             self.finish()
         else:
-            self.finish("<html><title>%(code)d: %(message)s</title>"
-                        "<body>%(code)d: %(message)s</body></html>" % {
-                            "code": status_code,
-                            "message": self._reason,
-                        })
+            self.finish(
+                "<html><title>%(code)d: %(message)s</title>"
+                "<body>%(code)d: %(message)s</body></html>"
+                % {"code": status_code, "message": self._reason}
+            )

     @property
-    def locale(self):
+    def locale(self) -> tornado.locale.Locale:
         """The locale for the current session.

         Determined by either `get_user_locale`, which you can override to
@@ -1104,17 +1254,19 @@ def locale(self):
            Added a property setter.
         """
         if not hasattr(self, "_locale"):
-            self._locale = self.get_user_locale()
-            if not self._locale:
+            loc = self.get_user_locale()
+            if loc is not None:
+                self._locale = loc
+            else:
                 self._locale = self.get_browser_locale()
             assert self._locale
         return self._locale

     @locale.setter
-    def locale(self, value):
+    def locale(self, value: tornado.locale.Locale) -> None:
         self._locale = value

-    def get_user_locale(self):
+    def get_user_locale(self) -> Optional[tornado.locale.Locale]:
         """Override to determine the locale from the authenticated user.

         If None is returned, we fall back to `get_browser_locale()`.
@@ -1124,7 +1276,7 @@ def get_user_locale(self):
         """
         return None

-    def get_browser_locale(self, default="en_US"):
+    def get_browser_locale(self, default: str = "en_US") -> tornado.locale.Locale:
         """Determines the user's locale from ``Accept-Language`` header.

         See http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.4
@@ -1144,12 +1296,12 @@ def get_browser_locale(self, default="en_US"):
                     locales.append((parts[0], score))
         if locales:
             locales.sort(key=lambda pair: pair[1], reverse=True)
-            codes = [l[0] for l in locales]
+            codes = [loc[0] for loc in locales]
             return locale.get(*codes)
         return locale.get(default)

     @property
-    def current_user(self):
+    def current_user(self) -> Any:
         """The authenticated user for this request.

         This is set in one of two ways:
@@ -1185,17 +1337,17 @@ def prepare(self):
         return self._current_user

     @current_user.setter
-    def current_user(self, value):
+    def current_user(self, value: Any) -> None:
         self._current_user = value

-    def get_current_user(self):
+    def get_current_user(self) -> Any:
         """Override to determine the current user from, e.g., a cookie.

         This method may not be a coroutine.
         """
         return None

-    def get_login_url(self):
+    def get_login_url(self) -> str:
         """Override to customize the login URL based on the request.

         By default, we use the ``login_url`` application setting.
@@ -1203,7 +1355,7 @@ def get_login_url(self):
         self.require_setting("login_url", "@tornado.web.authenticated")
         return self.application.settings["login_url"]

-    def get_template_path(self):
+    def get_template_path(self) -> Optional[str]:
         """Override to customize template path for each handler.

         By default, we use the ``template_path`` application setting.
@@ -1212,7 +1364,7 @@ def get_template_path(self):
         return self.application.settings.get("template_path")

     @property
-    def xsrf_token(self):
+    def xsrf_token(self) -> bytes:
         """The XSRF-prevention token for the current user/session.
To prevent cross-site request forgery, we set an '_xsrf' cookie @@ -1252,22 +1404,23 @@ def xsrf_token(self): self._xsrf_token = binascii.b2a_hex(token) elif output_version == 2: mask = os.urandom(4) - self._xsrf_token = b"|".join([ - b"2", - binascii.b2a_hex(mask), - binascii.b2a_hex(_websocket_mask(mask, token)), - utf8(str(int(timestamp)))]) + self._xsrf_token = b"|".join( + [ + b"2", + binascii.b2a_hex(mask), + binascii.b2a_hex(_websocket_mask(mask, token)), + utf8(str(int(timestamp))), + ] + ) else: - raise ValueError("unknown xsrf cookie version %d", - output_version) + raise ValueError("unknown xsrf cookie version %d", output_version) if version is None: - expires_days = 30 if self.current_user else None - self.set_cookie("_xsrf", self._xsrf_token, - expires_days=expires_days, - **cookie_kwargs) + if self.current_user and "expires_days" not in cookie_kwargs: + cookie_kwargs["expires_days"] = 30 + self.set_cookie("_xsrf", self._xsrf_token, **cookie_kwargs) return self._xsrf_token - def _get_raw_xsrf_token(self): + def _get_raw_xsrf_token(self) -> Tuple[Optional[int], bytes, float]: """Read or generate the xsrf token in its raw form. The raw_xsrf_token is a tuple containing: @@ -1278,7 +1431,7 @@ def _get_raw_xsrf_token(self): * timestamp: the time this token was generated (will not be accurate for version 1 cookies) """ - if not hasattr(self, '_raw_xsrf_token'): + if not hasattr(self, "_raw_xsrf_token"): cookie = self.get_cookie("_xsrf") if cookie: version, token, timestamp = self._decode_xsrf_token(cookie) @@ -1288,10 +1441,14 @@ def _get_raw_xsrf_token(self): version = None token = os.urandom(16) timestamp = time.time() + assert token is not None + assert timestamp is not None self._raw_xsrf_token = (version, token, timestamp) return self._raw_xsrf_token - def _decode_xsrf_token(self, cookie): + def _decode_xsrf_token( + self, cookie: str + ) -> Tuple[Optional[int], Optional[bytes], Optional[float]]: """Convert a cookie string into a the tuple form returned by _get_raw_xsrf_token. """ @@ -1302,12 +1459,11 @@ def _decode_xsrf_token(self, cookie): if m: version = int(m.group(1)) if version == 2: - _, mask, masked_token, timestamp = cookie.split("|") + _, mask_str, masked_token, timestamp_str = cookie.split("|") - mask = binascii.a2b_hex(utf8(mask)) - token = _websocket_mask( - mask, binascii.a2b_hex(utf8(masked_token))) - timestamp = int(timestamp) + mask = binascii.a2b_hex(utf8(mask_str)) + token = _websocket_mask(mask, binascii.a2b_hex(utf8(masked_token))) + timestamp = int(timestamp_str) return version, token, timestamp else: # Treat unknown versions as not present instead of failing. @@ -1323,11 +1479,10 @@ def _decode_xsrf_token(self, cookie): return (version, token, timestamp) except Exception: # Catch exceptions and return nothing instead of failing. - gen_log.debug("Uncaught exception in _decode_xsrf_token", - exc_info=True) + gen_log.debug("Uncaught exception in _decode_xsrf_token", exc_info=True) return None, None, None - def check_xsrf_cookie(self): + def check_xsrf_cookie(self) -> None: """Verifies that the ``_xsrf`` cookie matches the ``_xsrf`` argument. To prevent cross-site request forgery, we set an ``_xsrf`` @@ -1341,30 +1496,31 @@ def check_xsrf_cookie(self): See http://en.wikipedia.org/wiki/Cross-site_request_forgery - Prior to release 1.1.1, this check was ignored if the HTTP header - ``X-Requested-With: XMLHTTPRequest`` was present. This exception - has been shown to be insecure and has been removed. 
        For more
-        information please see
-        http://www.djangoproject.com/weblog/2011/feb/08/security/
-        http://weblog.rubyonrails.org/2011/2/8/csrf-protection-bypass-in-ruby-on-rails
-
         .. versionchanged:: 3.2.2
            Added support for cookie version 2.  Both versions 1 and 2 are
            supported.
         """
-        token = (self.get_argument("_xsrf", None) or
-                 self.request.headers.get("X-Xsrftoken") or
-                 self.request.headers.get("X-Csrftoken"))
+        # Prior to release 1.1.1, this check was ignored if the HTTP header
+        # ``X-Requested-With: XMLHTTPRequest`` was present.  This exception
+        # has been shown to be insecure and has been removed.  For more
+        # information please see
+        # http://www.djangoproject.com/weblog/2011/feb/08/security/
+        # http://weblog.rubyonrails.org/2011/2/8/csrf-protection-bypass-in-ruby-on-rails
+        token = (
+            self.get_argument("_xsrf", None)
+            or self.request.headers.get("X-Xsrftoken")
+            or self.request.headers.get("X-Csrftoken")
+        )
         if not token:
             raise HTTPError(403, "'_xsrf' argument missing from POST")
         _, token, _ = self._decode_xsrf_token(token)
         _, expected_token, _ = self._get_raw_xsrf_token()
         if not token:
             raise HTTPError(403, "'_xsrf' argument has invalid format")
-        if not _time_independent_equals(utf8(token), utf8(expected_token)):
+        if not hmac.compare_digest(utf8(token), utf8(expected_token)):
             raise HTTPError(403, "XSRF cookie does not match POST argument")

-    def xsrf_form_html(self):
+    def xsrf_form_html(self) -> str:
         """An HTML ``<input/>`` element to be included with all POST forms.

         It defines the ``_xsrf`` input value, which we check on all POST
@@ -1377,10 +1533,15 @@ def xsrf_form_html(self):

         See `check_xsrf_cookie()` above for more information.
         """
-        return '<input type="hidden" name="_xsrf" value="' + \
-            escape.xhtml_escape(self.xsrf_token) + '"/>'
+        return (
+            '<input type="hidden" name="_xsrf" value="'
+            + escape.xhtml_escape(self.xsrf_token)
+            + '"/>'
+        )

-    def static_url(self, path, include_host=None, **kwargs):
+    def static_url(
+        self, path: str, include_host: Optional[bool] = None, **kwargs: Any
+    ) -> str:
         """Returns a static URL for the given relative static file path.

         This method requires you set the ``static_path`` setting in your
@@ -1402,8 +1563,9 @@ def static_url(self, path, include_host=None, **kwargs):
         """
         self.require_setting("static_path", "static_url")
-        get_url = self.settings.get("static_handler_class",
-                                    StaticFileHandler).make_static_url
+        get_url = self.settings.get(
+            "static_handler_class", StaticFileHandler
+        ).make_static_url

         if include_host is None:
             include_host = getattr(self, "include_host", False)
@@ -1415,17 +1577,19 @@ def static_url(self, path, include_host=None, **kwargs):

         return base + get_url(self.settings, path, **kwargs)

-    def require_setting(self, name, feature="this feature"):
+    def require_setting(self, name: str, feature: str = "this feature") -> None:
         """Raises an exception if the given app setting is not defined."""
         if not self.application.settings.get(name):
-            raise Exception("You must define the '%s' setting in your "
-                            "application to use %s" % (name, feature))
+            raise Exception(
+                "You must define the '%s' setting in your "
+                "application to use %s" % (name, feature)
+            )

-    def reverse_url(self, name, *args):
+    def reverse_url(self, name: str, *args: Any) -> str:
         """Alias for `Application.reverse_url`."""
         return self.application.reverse_url(name, *args)

-    def compute_etag(self):
+    def compute_etag(self) -> Optional[str]:
         """Computes the etag header to be used for this request.

         By default uses a hash of the content written so far.
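To see the default etag machinery end to end: the handler hashes the body, and a repeat request echoing that value in ``If-None-Match`` gets an empty 304. A sketch only; the URL and port are assumptions:

```py
from tornado.httpclient import HTTPClient

client = HTTPClient()
first = client.fetch("http://127.0.0.1:8888/page")
etag = first.headers["Etag"]
# Replaying the Etag in If-None-Match should yield an empty 304.
second = client.fetch(
    "http://127.0.0.1:8888/page",
    headers={"If-None-Match": etag},
    raise_error=False,
)
assert second.code == 304
```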
@@ -1438,7 +1602,7 @@ def compute_etag(self): hasher.update(part) return '"%s"' % hasher.hexdigest() - def set_etag_header(self): + def set_etag_header(self) -> None: """Sets the response's Etag header using ``self.compute_etag()``. Note: no header will be set if ``compute_etag()`` returns ``None``. @@ -1449,7 +1613,7 @@ def set_etag_header(self): if etag is not None: self.set_header("Etag", etag) - def check_etag_header(self): + def check_etag_header(self) -> bool: """Checks the ``Etag`` header against requests's ``If-None-Match``. Returns ``True`` if the request's Etag matches and a 304 should be @@ -1470,19 +1634,18 @@ def check_etag_header(self): # Find all weak and strong etag values from If-None-Match header # because RFC 7232 allows multiple etag values in a single header. etags = re.findall( - br'\*|(?:W/)?"[^"]*"', - utf8(self.request.headers.get("If-None-Match", "")) + br'\*|(?:W/)?"[^"]*"', utf8(self.request.headers.get("If-None-Match", "")) ) if not computed_etag or not etags: return False match = False - if etags[0] == b'*': + if etags[0] == b"*": match = True else: # Use a weak comparison when comparing entity-tags. - def val(x): - return x[2:] if x.startswith(b'W/') else x + def val(x: bytes) -> bytes: + return x[2:] if x.startswith(b"W/") else x for etag in etags: if val(etag) == val(computed_etag): @@ -1490,36 +1653,34 @@ def val(x): break return match - def _stack_context_handle_exception(self, type, value, traceback): - try: - # For historical reasons _handle_request_exception only takes - # the exception value instead of the full triple, - # so re-raise the exception to ensure that it's in - # sys.exc_info() - raise_exc_info((type, value, traceback)) - except Exception: - self._handle_request_exception(value) - return True - - @gen.coroutine - def _execute(self, transforms, *args, **kwargs): + async def _execute( + self, transforms: List["OutputTransform"], *args: bytes, **kwargs: bytes + ) -> None: """Executes this request with the given output transforms.""" self._transforms = transforms try: if self.request.method not in self.SUPPORTED_METHODS: raise HTTPError(405) self.path_args = [self.decode_argument(arg) for arg in args] - self.path_kwargs = dict((k, self.decode_argument(v, name=k)) - for (k, v) in kwargs.items()) + self.path_kwargs = dict( + (k, self.decode_argument(v, name=k)) for (k, v) in kwargs.items() + ) # If XSRF cookies are turned on, reject form submissions without # the proper cookie - if self.request.method not in ("GET", "HEAD", "OPTIONS") and \ - self.application.settings.get("xsrf_cookies"): + if ( + self.request.method + not in ( + "GET", + "HEAD", + "OPTIONS", + ) + and self.application.settings.get("xsrf_cookies") + ): self.check_xsrf_cookie() result = self.prepare() if result is not None: - result = yield result + result = await result if self._prepared_future is not None: # Tell the Application we've finished with prepare() # and are ready for the body to arrive. @@ -1533,14 +1694,14 @@ def _execute(self, transforms, *args, **kwargs): # result; the data has been passed to self.data_received # instead. 
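For contrast with the buffered path above, a handler under `@stream_request_body` receives the body incrementally through ``data_received`` while ``self.request.body`` stays unset. A minimal sketch (the handler name is illustrative):

```py
from tornado.web import RequestHandler, stream_request_body


@stream_request_body
class UploadHandler(RequestHandler):
    def prepare(self):
        self.bytes_read = 0

    def data_received(self, chunk):
        # Called zero or more times as body chunks arrive; may also be
        # a coroutine if backpressure is needed.
        self.bytes_read += len(chunk)

    def post(self):
        self.write("received %d bytes" % self.bytes_read)
```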
try: - yield self.request.body + await self.request._body_future except iostream.StreamClosedError: return method = getattr(self, self.request.method.lower()) result = method(*self.path_args, **self.path_kwargs) if result is not None: - result = yield result + result = await result if self._auto_finish and not self._finished: self.finish() except Exception as e: @@ -1551,21 +1712,22 @@ def _execute(self, transforms, *args, **kwargs): finally: # Unset result to avoid circular references result = None - if (self._prepared_future is not None and - not self._prepared_future.done()): + if self._prepared_future is not None and not self._prepared_future.done(): # In case we failed before setting _prepared_future, do it # now (to unblock the HTTP server). Note that this is not # in a finally block to avoid GC issues prior to Python 3.4. self._prepared_future.set_result(None) - def data_received(self, chunk): + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: """Implement this method to handle streamed request data. Requires the `.stream_request_body` decorator. + + May be a coroutine for flow control. """ raise NotImplementedError() - def _log(self): + def _log(self) -> None: """Logs the current request. Sort of deprecated since this functionality was moved to the @@ -1574,11 +1736,14 @@ def _log(self): """ self.application.log_request(self) - def _request_summary(self): - return "%s %s (%s)" % (self.request.method, self.request.uri, - self.request.remote_ip) + def _request_summary(self) -> str: + return "%s %s (%s)" % ( + self.request.method, + self.request.uri, + self.request.remote_ip, + ) - def _handle_request_exception(self, e): + def _handle_request_exception(self, e: BaseException) -> None: if isinstance(e, Finish): # Not an error; just finish the request without logging. if not self._finished: @@ -1600,7 +1765,12 @@ def _handle_request_exception(self, e): else: self.send_error(500, exc_info=sys.exc_info()) - def log_exception(self, typ, value, tb): + def log_exception( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: """Override to customize logging of uncaught exceptions. 
By default logs instances of `HTTPError` as warnings without @@ -1613,116 +1783,42 @@ def log_exception(self, typ, value, tb): if isinstance(value, HTTPError): if value.log_message: format = "%d %s: " + value.log_message - args = ([value.status_code, self._request_summary()] + - list(value.args)) + args = [value.status_code, self._request_summary()] + list(value.args) gen_log.warning(format, *args) else: - app_log.error("Uncaught exception %s\n%r", self._request_summary(), - self.request, exc_info=(typ, value, tb)) - - def _ui_module(self, name, module): - def render(*args, **kwargs): + app_log.error( + "Uncaught exception %s\n%r", + self._request_summary(), + self.request, + exc_info=(typ, value, tb), # type: ignore + ) + + def _ui_module(self, name: str, module: Type["UIModule"]) -> Callable[..., str]: + def render(*args, **kwargs) -> str: # type: ignore if not hasattr(self, "_active_modules"): - self._active_modules = {} + self._active_modules = {} # type: Dict[str, UIModule] if name not in self._active_modules: self._active_modules[name] = module(self) rendered = self._active_modules[name].render(*args, **kwargs) return rendered + return render - def _ui_method(self, method): + def _ui_method(self, method: Callable[..., str]) -> Callable[..., str]: return lambda *args, **kwargs: method(self, *args, **kwargs) - def _clear_headers_for_304(self): - # 304 responses should not contain entity headers (defined in - # http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.1) + def _clear_representation_headers(self) -> None: + # 304 responses should not contain representation metadata + # headers (defined in + # https://tools.ietf.org/html/rfc7231#section-3.1) # not explicitly allowed by - # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5 - headers = ["Allow", "Content-Encoding", "Content-Language", - "Content-Length", "Content-MD5", "Content-Range", - "Content-Type", "Last-Modified"] + # https://tools.ietf.org/html/rfc7232#section-4.1 + headers = ["Content-Encoding", "Content-Language", "Content-Type"] for h in headers: self.clear_header(h) -def asynchronous(method): - """Wrap request handler methods with this if they are asynchronous. - - This decorator is for callback-style asynchronous methods; for - coroutines, use the ``@gen.coroutine`` decorator without - ``@asynchronous``. (It is legal for legacy reasons to use the two - decorators together provided ``@asynchronous`` is first, but - ``@asynchronous`` will be ignored in this case) - - This decorator should only be applied to the :ref:`HTTP verb - methods `; its behavior is undefined for any other method. - This decorator does not *make* a method asynchronous; it tells - the framework that the method *is* asynchronous. For this decorator - to be useful the method must (at least sometimes) do something - asynchronous. - - If this decorator is given, the response is not finished when the - method returns. It is up to the request handler to call - `self.finish() ` to finish the HTTP - request. Without this decorator, the request is automatically - finished when the ``get()`` or ``post()`` method returns. Example: - - .. testcode:: - - class MyRequestHandler(RequestHandler): - @asynchronous - def get(self): - http = httpclient.AsyncHTTPClient() - http.fetch("http://friendfeed.com/", self._on_download) - - def _on_download(self, response): - self.write("Downloaded!") - self.finish() - - .. testoutput:: - :hide: - - .. versionchanged:: 3.1 - The ability to use ``@gen.coroutine`` without ``@asynchronous``. - - .. 
versionchanged:: 4.3 Returning anything but ``None`` or a - yieldable object from a method decorated with ``@asynchronous`` - is an error. Such return values were previously ignored silently. - """ - # Delay the IOLoop import because it's not available on app engine. - from tornado.ioloop import IOLoop - - @functools.wraps(method) - def wrapper(self, *args, **kwargs): - self._auto_finish = False - with stack_context.ExceptionStackContext( - self._stack_context_handle_exception): - result = method(self, *args, **kwargs) - if result is not None: - result = gen.convert_yielded(result) - - # If @asynchronous is used with @gen.coroutine, (but - # not @gen.engine), we can automatically finish the - # request when the future resolves. Additionally, - # the Future will swallow any exceptions so we need - # to throw them back out to the stack context to finish - # the request. - def future_complete(f): - f.result() - if not self._finished: - self.finish() - IOLoop.current().add_future(result, future_complete) - # Once we have done this, hide the Future from our - # caller (i.e. RequestHandler._when_complete), which - # would otherwise set up its own callback and - # exception handler (resulting in exceptions being - # logged twice). - return None - return result - return wrapper - - -def stream_request_body(cls): +def stream_request_body(cls: Type[RequestHandler]) -> Type[RequestHandler]: """Apply to `RequestHandler` subclasses to enable streaming body support. This decorator implies the following changes: @@ -1749,21 +1845,26 @@ def stream_request_body(cls): return cls -def _has_stream_request_body(cls): +def _has_stream_request_body(cls: Type[RequestHandler]) -> bool: if not issubclass(cls, RequestHandler): raise TypeError("expected subclass of RequestHandler, got %r", cls) - return getattr(cls, '_stream_request_body', False) + return cls._stream_request_body -def removeslash(method): +def removeslash( + method: Callable[..., Optional[Awaitable[None]]] +) -> Callable[..., Optional[Awaitable[None]]]: """Use this decorator to remove trailing slashes from the request path. For example, a request to ``/foo/`` would redirect to ``/foo`` with this decorator. Your request handler mapping should use a regular expression like ``r'/foo/*'`` in conjunction with using the decorator. """ + @functools.wraps(method) - def wrapper(self, *args, **kwargs): + def wrapper( # type: ignore + self: RequestHandler, *args, **kwargs + ) -> Optional[Awaitable[None]]: if self.request.path.endswith("/"): if self.request.method in ("GET", "HEAD"): uri = self.request.path.rstrip("/") @@ -1771,31 +1872,38 @@ def wrapper(self, *args, **kwargs): if self.request.query: uri += "?" + self.request.query self.redirect(uri, permanent=True) - return + return None else: raise HTTPError(404) return method(self, *args, **kwargs) + return wrapper -def addslash(method): +def addslash( + method: Callable[..., Optional[Awaitable[None]]] +) -> Callable[..., Optional[Awaitable[None]]]: """Use this decorator to add a missing trailing slash to the request path. For example, a request to ``/foo`` would redirect to ``/foo/`` with this decorator. Your request handler mapping should use a regular expression like ``r'/foo/?'`` in conjunction with using the decorator. 
""" + @functools.wraps(method) - def wrapper(self, *args, **kwargs): + def wrapper( # type: ignore + self: RequestHandler, *args, **kwargs + ) -> Optional[Awaitable[None]]: if not self.request.path.endswith("/"): if self.request.method in ("GET", "HEAD"): uri = self.request.path + "/" if self.request.query: uri += "?" + self.request.query self.redirect(uri, permanent=True) - return + return None raise HTTPError(404) return method(self, *args, **kwargs) + return wrapper @@ -1810,28 +1918,36 @@ class _ApplicationRouter(ReversibleRuleRouter): `_ApplicationRouter` instance. """ - def __init__(self, application, rules=None): + def __init__( + self, application: "Application", rules: Optional[_RuleList] = None + ) -> None: assert isinstance(application, Application) self.application = application - super(_ApplicationRouter, self).__init__(rules) + super().__init__(rules) - def process_rule(self, rule): - rule = super(_ApplicationRouter, self).process_rule(rule) + def process_rule(self, rule: Rule) -> Rule: + rule = super().process_rule(rule) if isinstance(rule.target, (list, tuple)): - rule.target = _ApplicationRouter(self.application, rule.target) + rule.target = _ApplicationRouter( + self.application, rule.target # type: ignore + ) return rule - def get_target_delegate(self, target, request, **target_params): + def get_target_delegate( + self, target: Any, request: httputil.HTTPServerRequest, **target_params: Any + ) -> Optional[httputil.HTTPMessageDelegate]: if isclass(target) and issubclass(target, RequestHandler): - return self.application.get_handler_delegate(request, target, **target_params) + return self.application.get_handler_delegate( + request, target, **target_params + ) - return super(_ApplicationRouter, self).get_target_delegate(target, request, **target_params) + return super().get_target_delegate(target, request, **target_params) class Application(ReversibleRouter): - """A collection of request handlers that make up a web application. + r"""A collection of request handlers that make up a web application. Instances of this class are callable and can be passed directly to HTTPServer to serve the application:: @@ -1895,7 +2011,7 @@ class Application(ReversibleRouter): Applications that do not use TLS may be vulnerable to :ref:`DNS rebinding ` attacks. This attack is especially - relevant to applications that only listen on ``127.0.0.1` or + relevant to applications that only listen on ``127.0.0.1`` or other private networks. Appropriate host patterns must be used (instead of the default of ``r'.*'``) to prevent this risk. The ``default_host`` argument must not be used in applications that @@ -1913,54 +2029,64 @@ class Application(ReversibleRouter): Integration with the new `tornado.routing` module. 
""" - def __init__(self, handlers=None, default_host=None, transforms=None, - **settings): + + def __init__( + self, + handlers: Optional[_RuleList] = None, + default_host: Optional[str] = None, + transforms: Optional[List[Type["OutputTransform"]]] = None, + **settings: Any + ) -> None: if transforms is None: - self.transforms = [] + self.transforms = [] # type: List[Type[OutputTransform]] if settings.get("compress_response") or settings.get("gzip"): self.transforms.append(GZipContentEncoding) else: self.transforms = transforms self.default_host = default_host self.settings = settings - self.ui_modules = {'linkify': _linkify, - 'xsrf_form_html': _xsrf_form_html, - 'Template': TemplateModule, - } - self.ui_methods = {} + self.ui_modules = { + "linkify": _linkify, + "xsrf_form_html": _xsrf_form_html, + "Template": TemplateModule, + } + self.ui_methods = {} # type: Dict[str, Callable[..., str]] self._load_ui_modules(settings.get("ui_modules", {})) self._load_ui_methods(settings.get("ui_methods", {})) if self.settings.get("static_path"): path = self.settings["static_path"] handlers = list(handlers or []) - static_url_prefix = settings.get("static_url_prefix", - "/static/") - static_handler_class = settings.get("static_handler_class", - StaticFileHandler) + static_url_prefix = settings.get("static_url_prefix", "/static/") + static_handler_class = settings.get( + "static_handler_class", StaticFileHandler + ) static_handler_args = settings.get("static_handler_args", {}) - static_handler_args['path'] = path - for pattern in [re.escape(static_url_prefix) + r"(.*)", - r"/(favicon\.ico)", r"/(robots\.txt)"]: - handlers.insert(0, (pattern, static_handler_class, - static_handler_args)) - - if self.settings.get('debug'): - self.settings.setdefault('autoreload', True) - self.settings.setdefault('compiled_template_cache', False) - self.settings.setdefault('static_hash_cache', False) - self.settings.setdefault('serve_traceback', True) + static_handler_args["path"] = path + for pattern in [ + re.escape(static_url_prefix) + r"(.*)", + r"/(favicon\.ico)", + r"/(robots\.txt)", + ]: + handlers.insert(0, (pattern, static_handler_class, static_handler_args)) + + if self.settings.get("debug"): + self.settings.setdefault("autoreload", True) + self.settings.setdefault("compiled_template_cache", False) + self.settings.setdefault("static_hash_cache", False) + self.settings.setdefault("serve_traceback", True) self.wildcard_router = _ApplicationRouter(self, handlers) - self.default_router = _ApplicationRouter(self, [ - Rule(AnyMatches(), self.wildcard_router) - ]) + self.default_router = _ApplicationRouter( + self, [Rule(AnyMatches(), self.wildcard_router)] + ) # Automatically reload modified modules - if self.settings.get('autoreload'): + if self.settings.get("autoreload"): from tornado import autoreload + autoreload.start() - def listen(self, port, address="", **kwargs): + def listen(self, port: int, address: str = "", **kwargs: Any) -> HTTPServer: """Starts an HTTP server for this application on the given port. This is a convenience alias for creating an `.HTTPServer` @@ -1979,14 +2105,11 @@ def listen(self, port, address="", **kwargs): .. versionchanged:: 4.3 Now returns the `.HTTPServer` object. 
""" - # import is here rather than top level because HTTPServer - # is not importable on appengine - from tornado.httpserver import HTTPServer server = HTTPServer(self, **kwargs) server.listen(port, address) return server - def add_handlers(self, host_pattern, host_handlers): + def add_handlers(self, host_pattern: str, host_handlers: _RuleList) -> None: """Appends the given handlers to our handler list. Host patterns are processed sequentially in the order they were @@ -1998,31 +2121,31 @@ def add_handlers(self, host_pattern, host_handlers): self.default_router.rules.insert(-1, rule) if self.default_host is not None: - self.wildcard_router.add_rules([( - DefaultHostMatches(self, host_matcher.host_pattern), - host_handlers - )]) + self.wildcard_router.add_rules( + [(DefaultHostMatches(self, host_matcher.host_pattern), host_handlers)] + ) - def add_transform(self, transform_class): + def add_transform(self, transform_class: Type["OutputTransform"]) -> None: self.transforms.append(transform_class) - def _load_ui_methods(self, methods): + def _load_ui_methods(self, methods: Any) -> None: if isinstance(methods, types.ModuleType): - self._load_ui_methods(dict((n, getattr(methods, n)) - for n in dir(methods))) + self._load_ui_methods(dict((n, getattr(methods, n)) for n in dir(methods))) elif isinstance(methods, list): for m in methods: self._load_ui_methods(m) else: for name, fn in methods.items(): - if not name.startswith("_") and hasattr(fn, "__call__") \ - and name[0].lower() == name[0]: + if ( + not name.startswith("_") + and hasattr(fn, "__call__") + and name[0].lower() == name[0] + ): self.ui_methods[name] = fn - def _load_ui_modules(self, modules): + def _load_ui_modules(self, modules: Any) -> None: if isinstance(modules, types.ModuleType): - self._load_ui_modules(dict((n, getattr(modules, n)) - for n in dir(modules))) + self._load_ui_modules(dict((n, getattr(modules, n)) for n in dir(modules))) elif isinstance(modules, list): for m in modules: self._load_ui_modules(m) @@ -2035,27 +2158,37 @@ def _load_ui_modules(self, modules): except TypeError: pass - def __call__(self, request): + def __call__( + self, request: httputil.HTTPServerRequest + ) -> Optional[Awaitable[None]]: # Legacy HTTPServer interface dispatcher = self.find_handler(request) return dispatcher.execute() - def find_handler(self, request, **kwargs): + def find_handler( + self, request: httputil.HTTPServerRequest, **kwargs: Any + ) -> "_HandlerDelegate": route = self.default_router.find_handler(request) if route is not None: - return route + return cast("_HandlerDelegate", route) - if self.settings.get('default_handler_class'): + if self.settings.get("default_handler_class"): return self.get_handler_delegate( request, - self.settings['default_handler_class'], - self.settings.get('default_handler_args', {})) - - return self.get_handler_delegate( - request, ErrorHandler, {'status_code': 404}) - - def get_handler_delegate(self, request, target_class, target_kwargs=None, - path_args=None, path_kwargs=None): + self.settings["default_handler_class"], + self.settings.get("default_handler_args", {}), + ) + + return self.get_handler_delegate(request, ErrorHandler, {"status_code": 404}) + + def get_handler_delegate( + self, + request: httputil.HTTPServerRequest, + target_class: Type[RequestHandler], + target_kwargs: Optional[Dict[str, Any]] = None, + path_args: Optional[List[bytes]] = None, + path_kwargs: Optional[Dict[str, bytes]] = None, + ) -> "_HandlerDelegate": """Returns `~.httputil.HTTPMessageDelegate` that can serve a request 
for application and `RequestHandler` subclass. @@ -2067,9 +2200,10 @@ def get_handler_delegate(self, request, target_class, target_kwargs=None, :arg dict path_kwargs: keyword arguments for ``target_class`` HTTP method. """ return _HandlerDelegate( - self, request, target_class, target_kwargs, path_args, path_kwargs) + self, request, target_class, target_kwargs, path_args, path_kwargs + ) - def reverse_url(self, name, *args): + def reverse_url(self, name: str, *args: Any) -> str: """Returns a URL path for handler named ``name`` The handler must be added to the application as a named `URLSpec`. @@ -2084,7 +2218,7 @@ def reverse_url(self, name, *args): raise KeyError("%s not found in named urls" % name) - def log_request(self, handler): + def log_request(self, handler: RequestHandler) -> None: """Writes a completed HTTP request to the logs. By default writes to the python root logger. To change @@ -2102,13 +2236,24 @@ def log_request(self, handler): else: log_method = access_log.error request_time = 1000.0 * handler.request.request_time() - log_method("%d %s %.2fms", handler.get_status(), - handler._request_summary(), request_time) + log_method( + "%d %s %.2fms", + handler.get_status(), + handler._request_summary(), + request_time, + ) class _HandlerDelegate(httputil.HTTPMessageDelegate): - def __init__(self, application, request, handler_class, handler_kwargs, - path_args, path_kwargs): + def __init__( + self, + application: Application, + request: httputil.HTTPServerRequest, + handler_class: Type[RequestHandler], + handler_kwargs: Optional[Dict[str, Any]], + path_args: Optional[List[bytes]], + path_kwargs: Optional[Dict[str, bytes]], + ) -> None: self.application = application self.connection = request.connection self.request = request @@ -2116,35 +2261,41 @@ def __init__(self, application, request, handler_class, handler_kwargs, self.handler_kwargs = handler_kwargs or {} self.path_args = path_args or [] self.path_kwargs = path_kwargs or {} - self.chunks = [] + self.chunks = [] # type: List[bytes] self.stream_request_body = _has_stream_request_body(self.handler_class) - def headers_received(self, start_line, headers): + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: if self.stream_request_body: - self.request.body = Future() + self.request._body_future = Future() return self.execute() + return None - def data_received(self, data): + def data_received(self, data: bytes) -> Optional[Awaitable[None]]: if self.stream_request_body: return self.handler.data_received(data) else: self.chunks.append(data) + return None - def finish(self): + def finish(self) -> None: if self.stream_request_body: - future_set_result_unless_cancelled(self.request.body, None) + future_set_result_unless_cancelled(self.request._body_future, None) else: - self.request.body = b''.join(self.chunks) + self.request.body = b"".join(self.chunks) self.request._parse_body() self.execute() - def on_connection_close(self): + def on_connection_close(self) -> None: if self.stream_request_body: self.handler.on_connection_close() else: - self.chunks = None + self.chunks = None # type: ignore - def execute(self): + def execute(self) -> Optional[Awaitable[None]]: # If template cache is disabled (usually in the debug mode), # re-compile templates and reload static files on every # request so you don't need to restart to see changes @@ -2152,11 +2303,12 @@ def execute(self): with RequestHandler._template_loader_lock: 
for loader in RequestHandler._template_loaders.values(): loader.reset() - if not self.application.settings.get('static_hash_cache', True): + if not self.application.settings.get("static_hash_cache", True): StaticFileHandler.reset() - self.handler = self.handler_class(self.application, self.request, - **self.handler_kwargs) + self.handler = self.handler_class( + self.application, self.request, **self.handler_kwargs + ) transforms = [t(self.request) for t in self.application.transforms] if self.stream_request_body: @@ -2168,8 +2320,10 @@ def execute(self): # except handler, and we cannot easily access the IOLoop here to # call add_future (because of the requirement to remain compatible # with WSGI) - self.handler._execute(transforms, *self.path_args, - **self.path_kwargs) + fut = gen.convert_yielded( + self.handler._execute(transforms, *self.path_args, **self.path_kwargs) + ) + fut.add_done_callback(lambda f: f.result()) # If we are streaming the request body, then execute() is finished # when the handler has prepared to receive the body. If not, # it doesn't matter when execute() finishes (so we return None) @@ -2198,18 +2352,26 @@ class HTTPError(Exception): determined automatically from ``status_code``, but can be used to use a non-standard numeric code. """ - def __init__(self, status_code=500, log_message=None, *args, **kwargs): + + def __init__( + self, + status_code: int = 500, + log_message: Optional[str] = None, + *args: Any, + **kwargs: Any + ) -> None: self.status_code = status_code self.log_message = log_message self.args = args - self.reason = kwargs.get('reason', None) + self.reason = kwargs.get("reason", None) if log_message and not args: - self.log_message = log_message.replace('%', '%%') + self.log_message = log_message.replace("%", "%%") - def __str__(self): + def __str__(self) -> str: message = "HTTP %d: %s" % ( self.status_code, - self.reason or httputil.responses.get(self.status_code, 'Unknown')) + self.reason or httputil.responses.get(self.status_code, "Unknown"), + ) if self.log_message: return message + " (" + (self.log_message % self.args) + ")" else: @@ -2240,6 +2402,7 @@ class Finish(Exception): Arguments passed to ``Finish()`` will be passed on to `RequestHandler.finish`. """ + pass @@ -2251,21 +2414,22 @@ class MissingArgumentError(HTTPError): .. versionadded:: 3.1 """ - def __init__(self, arg_name): - super(MissingArgumentError, self).__init__( - 400, 'Missing argument %s' % arg_name) + + def __init__(self, arg_name: str) -> None: + super().__init__(400, "Missing argument %s" % arg_name) self.arg_name = arg_name class ErrorHandler(RequestHandler): """Generates an error response with ``status_code`` for all requests.""" - def initialize(self, status_code): + + def initialize(self, status_code: int) -> None: self.set_status(status_code) - def prepare(self): + def prepare(self) -> None: raise HTTPError(self._status_code) - def check_xsrf_cookie(self): + def check_xsrf_cookie(self) -> None: # POSTs to an ErrorHandler don't actually have side effects, # so we don't need to check the xsrf token. This allows POSTs # to the wrong url to return a 404 instead of 403. @@ -2304,15 +2468,19 @@ class RedirectHandler(RequestHandler): If any query arguments are present, they will be copied to the destination URL. 
""" - def initialize(self, url, permanent=True): + + def initialize(self, url: str, permanent: bool = True) -> None: self._url = url self._permanent = permanent - def get(self, *args): - to_url = self._url.format(*args) + def get(self, *args: Any, **kwargs: Any) -> None: + to_url = self._url.format(*args, **kwargs) if self.request.query_arguments: + # TODO: figure out typing for the next line. to_url = httputil.url_concat( - to_url, list(httputil.qs_to_qsl(self.request.query_arguments))) + to_url, + list(httputil.qs_to_qsl(self.request.query_arguments)), # type: ignore + ) self.redirect(to_url, permanent=self._permanent) @@ -2382,31 +2550,30 @@ class method. Instance methods may use the attributes ``self.path`` .. versionchanged:: 3.1 Many of the methods for subclasses were added in Tornado 3.1. """ + CACHE_MAX_AGE = 86400 * 365 * 10 # 10 years - _static_hashes = {} # type: typing.Dict + _static_hashes = {} # type: Dict[str, Optional[str]] _lock = threading.Lock() # protects _static_hashes - def initialize(self, path, default_filename=None): + def initialize(self, path: str, default_filename: Optional[str] = None) -> None: self.root = path self.default_filename = default_filename @classmethod - def reset(cls): + def reset(cls) -> None: with cls._lock: cls._static_hashes = {} - def head(self, path): + def head(self, path: str) -> Awaitable[None]: return self.get(path, include_body=False) - @gen.coroutine - def get(self, path, include_body=True): + async def get(self, path: str, include_body: bool = True) -> None: # Set up our path instance variables. self.path = self.parse_url_path(path) del path # make sure we don't refer to path instead of self.path again absolute_path = self.get_absolute_path(self.root, self.path) - self.absolute_path = self.validate_absolute_path( - self.root, absolute_path) + self.absolute_path = self.validate_absolute_path(self.root, absolute_path) if self.absolute_path is None: return @@ -2427,16 +2594,24 @@ def get(self, path, include_body=True): size = self.get_content_size() if request_range: start, end = request_range - if (start is not None and start >= size) or end == 0: + if start is not None and start < 0: + start += size + if start < 0: + start = 0 + if ( + start is not None + and (start >= size or (end is not None and start >= end)) + ) or end == 0: # As per RFC 2616 14.35.1, a range is not satisfiable only: if # the first requested byte is equal to or greater than the - # content, or when a suffix with length 0 is specified + # content, or when a suffix with length 0 is specified. + # https://tools.ietf.org/html/rfc7233#section-2.1 + # A byte-range-spec is invalid if the last-byte-pos value is present + # and less than the first-byte-pos. self.set_status(416) # Range Not Satisfiable self.set_header("Content-Type", "text/plain") - self.set_header("Content-Range", "bytes */%s" % (size, )) + self.set_header("Content-Range", "bytes */%s" % (size,)) return - if start is not None and start < 0: - start += size if end is not None and end > size: # Clients sometimes blindly use a large range to limit their # download size; cap the endpoint at the actual file size. @@ -2447,8 +2622,9 @@ def get(self, path, include_body=True): # ``Range: bytes=0-``. 
if size != (end or size) - (start or 0): self.set_status(206) # Partial Content - self.set_header("Content-Range", - httputil._get_content_range(start, end, size)) + self.set_header( + "Content-Range", httputil._get_content_range(start, end, size) + ) else: start = end = None @@ -2469,13 +2645,13 @@ def get(self, path, include_body=True): for chunk in content: try: self.write(chunk) - yield self.flush() + await self.flush() except iostream.StreamClosedError: return else: assert self.request.method == "HEAD" - def compute_etag(self): + def compute_etag(self) -> Optional[str]: """Sets the ``Etag`` header based on static url version. This allows efficient ``If-None-Match`` checks against cached @@ -2484,12 +2660,13 @@ def compute_etag(self): .. versionadded:: 3.1 """ + assert self.absolute_path is not None version_hash = self._get_cached_version(self.absolute_path) if not version_hash: return None - return '"%s"' % (version_hash, ) + return '"%s"' % (version_hash,) - def set_headers(self): + def set_headers(self) -> None: """Sets the content and caching headers on the response. .. versionadded:: 3.1 @@ -2504,22 +2681,23 @@ def set_headers(self): if content_type: self.set_header("Content-Type", content_type) - cache_time = self.get_cache_time(self.path, self.modified, - content_type) + cache_time = self.get_cache_time(self.path, self.modified, content_type) if cache_time > 0: - self.set_header("Expires", datetime.datetime.utcnow() + - datetime.timedelta(seconds=cache_time)) + self.set_header( + "Expires", + datetime.datetime.utcnow() + datetime.timedelta(seconds=cache_time), + ) self.set_header("Cache-Control", "max-age=" + str(cache_time)) self.set_extra_headers(self.path) - def should_return_304(self): + def should_return_304(self) -> bool: """Returns True if the headers indicate that we should return 304. .. versionadded:: 3.1 """ # If client sent If-None-Match, use it, ignore If-Modified-Since - if self.request.headers.get('If-None-Match'): + if self.request.headers.get("If-None-Match"): return self.check_etag_header() # Check the If-Modified-Since, and don't send the result if the @@ -2529,13 +2707,14 @@ def should_return_304(self): date_tuple = email.utils.parsedate(ims_value) if date_tuple is not None: if_since = datetime.datetime(*date_tuple[:6]) + assert self.modified is not None if if_since >= self.modified: return True return False @classmethod - def get_absolute_path(cls, root, path): + def get_absolute_path(cls, root: str, path: str) -> str: """Returns the absolute location of ``path`` relative to ``root``. ``root`` is the path configured for this `StaticFileHandler` @@ -2551,7 +2730,7 @@ def get_absolute_path(cls, root, path): abspath = os.path.abspath(os.path.join(root, path)) return abspath - def validate_absolute_path(self, root, absolute_path): + def validate_absolute_path(self, root: str, absolute_path: str) -> Optional[str]: """Validate and return the absolute path. ``root`` is the configured path for the `StaticFileHandler`, @@ -2586,16 +2765,14 @@ def validate_absolute_path(self, root, absolute_path): # The trailing slash also needs to be temporarily added back # the requested path so a request to root/ will match. 
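``should_return_304`` above gives ``If-None-Match`` priority and only then consults ``If-Modified-Since``; the date comparison it performs is reproduced here in isolation (``not_modified`` is a hypothetical helper, not handler API):

```py
import datetime
import email.utils

def not_modified(ims_header, modified):
    # Parse the RFC 2822 date and compare whole seconds, as the
    # handler does; an unparseable header never matches.
    date_tuple = email.utils.parsedate(ims_header)
    if date_tuple is None:
        return False
    return datetime.datetime(*date_tuple[:6]) >= modified

print(not_modified("Mon, 01 Jan 2018 00:00:00 GMT",
                   datetime.datetime(2017, 12, 31, 23, 59, 59)))  # True
```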
if not (absolute_path + os.path.sep).startswith(root): - raise HTTPError(403, "%s is not in root static directory", - self.path) - if (os.path.isdir(absolute_path) and - self.default_filename is not None): + raise HTTPError(403, "%s is not in root static directory", self.path) + if os.path.isdir(absolute_path) and self.default_filename is not None: # need to look at the request.path here for when path is empty # but there is some prefix to the path that was already # trimmed by the routing if not self.request.path.endswith("/"): self.redirect(self.request.path + "/", permanent=True) - return + return None absolute_path = os.path.join(absolute_path, self.default_filename) if not os.path.exists(absolute_path): raise HTTPError(404) @@ -2604,7 +2781,9 @@ def validate_absolute_path(self, root, absolute_path): return absolute_path @classmethod - def get_content(cls, abspath, start=None, end=None): + def get_content( + cls, abspath: str, start: Optional[int] = None, end: Optional[int] = None + ) -> Generator[bytes, None, None]: """Retrieve the content of the requested resource which is located at the given absolute path. @@ -2623,7 +2802,7 @@ def get_content(cls, abspath, start=None, end=None): if start is not None: file.seek(start) if end is not None: - remaining = end - (start or 0) + remaining = end - (start or 0) # type: Optional[int] else: remaining = None while True: @@ -2641,16 +2820,16 @@ def get_content(cls, abspath, start=None, end=None): return @classmethod - def get_content_version(cls, abspath): + def get_content_version(cls, abspath: str) -> str: """Returns a version string for the resource at the given path. This class method may be overridden by subclasses. The - default implementation is a hash of the file's contents. + default implementation is a SHA-512 hash of the file's contents. .. versionadded:: 3.1 """ data = cls.get_content(abspath) - hasher = hashlib.md5() + hasher = hashlib.sha512() if isinstance(data, bytes): hasher.update(data) else: @@ -2658,12 +2837,13 @@ def get_content_version(cls, abspath): hasher.update(chunk) return hasher.hexdigest() - def _stat(self): - if not hasattr(self, '_stat_result'): + def _stat(self) -> os.stat_result: + assert self.absolute_path is not None + if not hasattr(self, "_stat_result"): self._stat_result = os.stat(self.absolute_path) return self._stat_result - def get_content_size(self): + def get_content_size(self) -> int: """Retrieve the total size of the resource at the given path. This method may be overridden by subclasses. @@ -2675,9 +2855,9 @@ def get_content_size(self): partial results are requested. """ stat_result = self._stat() - return stat_result[stat.ST_SIZE] + return stat_result.st_size - def get_modified_time(self): + def get_modified_time(self) -> Optional[datetime.datetime]: """Returns the time that ``self.absolute_path`` was last modified. May be overridden in subclasses. Should return a `~datetime.datetime` @@ -2686,15 +2866,23 @@ def get_modified_time(self): .. versionadded:: 3.1 """ stat_result = self._stat() - modified = datetime.datetime.utcfromtimestamp( - stat_result[stat.ST_MTIME]) + # NOTE: Historically, this used stat_result[stat.ST_MTIME], + # which truncates the fractional portion of the timestamp. It + # was changed from that form to stat_result.st_mtime to + # satisfy mypy (which disallows the bracket operator), but the + # latter form returns a float instead of an int. 
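The generator in ``get_content`` above seeks once, then yields chunks bounded by the remaining byte budget. The same shape, decoupled from the handler and with an exclusive ``end`` pre-clamped by the caller (names are illustrative):

```py
from io import BytesIO

def read_range(fp, start=None, end=None, chunk_size=64 * 1024):
    if start is not None:
        fp.seek(start)
    remaining = end - (start or 0) if end is not None else None
    while True:
        want = chunk_size if remaining is None else min(chunk_size, remaining)
        chunk = fp.read(want)
        if not chunk:
            return
        if remaining is not None:
            remaining -= len(chunk)
        yield chunk

assert b"".join(read_range(BytesIO(b"0123456789"), 2, 6)) == b"2345"
```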
For + # consistency with the past (and because we have a unit test + # that relies on this), we truncate the float here, although + # I'm not sure that's the right thing to do. + modified = datetime.datetime.utcfromtimestamp(int(stat_result.st_mtime)) return modified - def get_content_type(self): + def get_content_type(self) -> str: """Returns the ``Content-Type`` header to be used for this request. .. versionadded:: 3.1 """ + assert self.absolute_path is not None mime_type, encoding = mimetypes.guess_type(self.absolute_path) # per RFC 6713, use the appropriate type for a gzip compressed file if encoding == "gzip": @@ -2710,11 +2898,13 @@ def get_content_type(self): else: return "application/octet-stream" - def set_extra_headers(self, path): + def set_extra_headers(self, path: str) -> None: """For subclass to add extra headers to the response""" pass - def get_cache_time(self, path, modified, mime_type): + def get_cache_time( + self, path: str, modified: Optional[datetime.datetime], mime_type: str + ) -> int: """Override to customize cache control behavior. Return a positive number of seconds to make the result @@ -2728,7 +2918,9 @@ def get_cache_time(self, path, modified, mime_type): return self.CACHE_MAX_AGE if "v" in self.request.arguments else 0 @classmethod - def make_static_url(cls, settings, path, include_version=True): + def make_static_url( + cls, settings: Dict[str, Any], path: str, include_version: bool = True + ) -> str: """Constructs a versioned url for the given path. This method may be overridden in subclasses (but note that it @@ -2747,7 +2939,7 @@ def make_static_url(cls, settings, path, include_version=True): file corresponding to the given ``path``. """ - url = settings.get('static_url_prefix', '/static/') + path + url = settings.get("static_url_prefix", "/static/") + path if not include_version: return url @@ -2755,9 +2947,9 @@ def make_static_url(cls, settings, path, include_version=True): if not version_hash: return url - return '%s?v=%s' % (url, version_hash) + return "%s?v=%s" % (url, version_hash) - def parse_url_path(self, url_path): + def parse_url_path(self, url_path: str) -> str: """Converts a static URL path into a filesystem path. ``url_path`` is the path component of the URL with @@ -2771,7 +2963,7 @@ def parse_url_path(self, url_path): return url_path @classmethod - def get_version(cls, settings, path): + def get_version(cls, settings: Dict[str, Any], path: str) -> Optional[str]: """Generate the version string to be used in static URLs. ``settings`` is the `Application.settings` dictionary and ``path`` @@ -2784,11 +2976,11 @@ def get_version(cls, settings, path): `get_content_version` is now preferred as it allows the base class to handle caching of the result. 
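Tying ``make_static_url`` and ``get_cache_time`` together: the versioned URL carries ``?v=<hash>``, and that same ``v`` argument is what unlocks the ten-year ``CACHE_MAX_AGE``. A rough sketch of the policy, with a ``version_for`` lookup standing in for ``_get_cached_version``:

```py
def make_static_url(settings, path, version_for):
    url = settings.get("static_url_prefix", "/static/") + path
    v = version_for(path)
    return "%s?v=%s" % (url, v) if v else url

def cache_seconds(request_arguments):
    # Versioned resources are treated as immutable.
    return 86400 * 365 * 10 if "v" in request_arguments else 0

print(make_static_url({}, "app.css", lambda p: "abc123"))  # /static/app.css?v=abc123
print(cache_seconds({"v": ["abc123"]}))                    # 315360000
```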
""" - abs_path = cls.get_absolute_path(settings['static_path'], path) + abs_path = cls.get_absolute_path(settings["static_path"], path) return cls._get_cached_version(abs_path) @classmethod - def _get_cached_version(cls, abs_path): + def _get_cached_version(cls, abs_path: str) -> Optional[str]: with cls._lock: hashes = cls._static_hashes if abs_path not in hashes: @@ -2819,10 +3011,13 @@ class FallbackHandler(RequestHandler): (r".*", FallbackHandler, dict(fallback=wsgi_app), ]) """ - def initialize(self, fallback): + + def initialize( + self, fallback: Callable[[httputil.HTTPServerRequest], None] + ) -> None: self.fallback = fallback - def prepare(self): + def prepare(self) -> None: self.fallback(self.request) self._finished = True self.on_finish() @@ -2835,14 +3030,20 @@ class OutputTransform(object): or interact with them directly; the framework chooses which transforms (if any) to apply. """ - def __init__(self, request): + + def __init__(self, request: httputil.HTTPServerRequest) -> None: pass - def transform_first_chunk(self, status_code, headers, chunk, finishing): - # type: (int, httputil.HTTPHeaders, bytes, bool) -> typing.Tuple[int, httputil.HTTPHeaders, bytes] # noqa: E501 + def transform_first_chunk( + self, + status_code: int, + headers: httputil.HTTPHeaders, + chunk: bytes, + finishing: bool, + ) -> Tuple[int, httputil.HTTPHeaders, bytes]: return status_code, headers, chunk - def transform_chunk(self, chunk, finishing): + def transform_chunk(self, chunk: bytes, finishing: bool) -> bytes: return chunk @@ -2856,12 +3057,20 @@ class GZipContentEncoding(OutputTransform): of just a whitelist. (the whitelist is still used for certain non-text mime types). """ + # Whitelist of compressible mime types (in addition to any types # beginning with "text/"). - CONTENT_TYPES = set(["application/javascript", "application/x-javascript", - "application/xml", "application/atom+xml", - "application/json", "application/xhtml+xml", - "image/svg+xml"]) + CONTENT_TYPES = set( + [ + "application/javascript", + "application/x-javascript", + "application/xml", + "application/atom+xml", + "application/json", + "application/xhtml+xml", + "image/svg+xml", + ] + ) # Python's GzipFile defaults to level 9, while most other gzip # tools (including gzip itself) default to 6, which is probably a # better CPU/size tradeoff. @@ -2873,29 +3082,37 @@ class GZipContentEncoding(OutputTransform): # regardless of size. MIN_LENGTH = 1024 - def __init__(self, request): + def __init__(self, request: httputil.HTTPServerRequest) -> None: self._gzipping = "gzip" in request.headers.get("Accept-Encoding", "") - def _compressible_type(self, ctype): - return ctype.startswith('text/') or ctype in self.CONTENT_TYPES + def _compressible_type(self, ctype: str) -> bool: + return ctype.startswith("text/") or ctype in self.CONTENT_TYPES - def transform_first_chunk(self, status_code, headers, chunk, finishing): - # type: (int, httputil.HTTPHeaders, bytes, bool) -> typing.Tuple[int, httputil.HTTPHeaders, bytes] # noqa: E501 + def transform_first_chunk( + self, + status_code: int, + headers: httputil.HTTPHeaders, + chunk: bytes, + finishing: bool, + ) -> Tuple[int, httputil.HTTPHeaders, bytes]: # TODO: can/should this type be inherited from the superclass? 
- if 'Vary' in headers: - headers['Vary'] += ', Accept-Encoding' + if "Vary" in headers: + headers["Vary"] += ", Accept-Encoding" else: - headers['Vary'] = 'Accept-Encoding' + headers["Vary"] = "Accept-Encoding" if self._gzipping: ctype = _unicode(headers.get("Content-Type", "")).split(";")[0] - self._gzipping = self._compressible_type(ctype) and \ - (not finishing or len(chunk) >= self.MIN_LENGTH) and \ - ("Content-Encoding" not in headers) + self._gzipping = ( + self._compressible_type(ctype) + and (not finishing or len(chunk) >= self.MIN_LENGTH) + and ("Content-Encoding" not in headers) + ) if self._gzipping: headers["Content-Encoding"] = "gzip" self._gzip_value = BytesIO() - self._gzip_file = gzip.GzipFile(mode="w", fileobj=self._gzip_value, - compresslevel=self.GZIP_LEVEL) + self._gzip_file = gzip.GzipFile( + mode="w", fileobj=self._gzip_value, compresslevel=self.GZIP_LEVEL + ) chunk = self.transform_chunk(chunk, finishing) if "Content-Length" in headers: # The original content length is no longer correct. @@ -2908,7 +3125,7 @@ def transform_first_chunk(self, status_code, headers, chunk, finishing): del headers["Content-Length"] return status_code, headers, chunk - def transform_chunk(self, chunk, finishing): + def transform_chunk(self, chunk: bytes, finishing: bool) -> bytes: if self._gzipping: self._gzip_file.write(chunk) if finishing: @@ -2921,7 +3138,9 @@ def transform_chunk(self, chunk, finishing): return chunk -def authenticated(method): +def authenticated( + method: Callable[..., Optional[Awaitable[None]]] +) -> Callable[..., Optional[Awaitable[None]]]: """Decorate methods with this to require that the user be logged in. If the user is not logged in, they will be redirected to the configured @@ -2932,22 +3151,27 @@ def authenticated(method): will add a `next` parameter so the login page knows where to send you once you're logged in. """ + @functools.wraps(method) - def wrapper(self, *args, **kwargs): + def wrapper( # type: ignore + self: RequestHandler, *args, **kwargs + ) -> Optional[Awaitable[None]]: if not self.current_user: if self.request.method in ("GET", "HEAD"): url = self.get_login_url() if "?" not in url: - if urlparse.urlsplit(url).scheme: + if urllib.parse.urlsplit(url).scheme: # if login url is absolute, make next absolute too next_url = self.request.full_url() else: + assert self.request.uri is not None next_url = self.request.uri url += "?" + urlencode(dict(next=next_url)) self.redirect(url) - return + return None raise HTTPError(403) return method(self, *args, **kwargs) + return wrapper @@ -2960,26 +3184,27 @@ class UIModule(object): Subclasses of UIModule must override the `render` method. """ - def __init__(self, handler): + + def __init__(self, handler: RequestHandler) -> None: self.handler = handler self.request = handler.request self.ui = handler.ui self.locale = handler.locale @property - def current_user(self): + def current_user(self) -> Any: return self.handler.current_user - def render(self, *args, **kwargs): + def render(self, *args: Any, **kwargs: Any) -> str: """Override in subclasses to return this module's output.""" raise NotImplementedError() - def embedded_javascript(self): + def embedded_javascript(self) -> Optional[str]: """Override to return a JavaScript string to be embedded in the page.""" return None - def javascript_files(self): + def javascript_files(self) -> Optional[Iterable[str]]: """Override to return a list of JavaScript files needed by this module. 
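Condensed, the gate in ``transform_first_chunk`` above is: the client accepts gzip, the type is ``text/*`` or whitelisted, nothing upstream already set ``Content-Encoding``, and a response that completes in its first chunk is at least ``MIN_LENGTH`` bytes. A self-contained sketch of that decision (``maybe_gzip`` and its whitelist subset are illustrative):

```py
import gzip
from io import BytesIO

COMPRESSIBLE = {"application/json", "application/xml"}  # subset of the whitelist

def maybe_gzip(accept_encoding, content_type, chunk, finishing, min_length=1024):
    ctype = content_type.split(";")[0]
    compressible = ctype.startswith("text/") or ctype in COMPRESSIBLE
    if "gzip" not in accept_encoding or not compressible:
        return chunk
    if finishing and len(chunk) < min_length:
        return chunk  # too small to be worth the CPU and header overhead
    buf = BytesIO()
    with gzip.GzipFile(mode="w", fileobj=buf, compresslevel=6) as f:
        f.write(chunk)
    return buf.getvalue()
```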
 If the return values are relative paths, they will be passed to
@@ -2987,12 +3212,12 @@ def javascript_files(self):
         """
         return None

-    def embedded_css(self):
+    def embedded_css(self) -> Optional[str]:
         """Override to return a CSS string
         that will be embedded in the page."""
         return None

-    def css_files(self):
+    def css_files(self) -> Optional[Iterable[str]]:
         """Override to return a list of CSS files required by this module.

         If the return values are relative paths, they will be passed to
@@ -3000,30 +3225,30 @@ def css_files(self):
         """
         return None

-    def html_head(self):
+    def html_head(self) -> Optional[str]:
         """Override to return an HTML string that will be put in the
         <head/> element.
         """
         return None

-    def html_body(self):
+    def html_body(self) -> Optional[str]:
         """Override to return an HTML string that will be put at the end of
         the <body/> element.
         """
         return None

-    def render_string(self, path, **kwargs):
+    def render_string(self, path: str, **kwargs: Any) -> bytes:
         """Renders a template and returns it as a string."""
         return self.handler.render_string(path, **kwargs)


 class _linkify(UIModule):
-    def render(self, text, **kwargs):
+    def render(self, text: str, **kwargs: Any) -> str:  # type: ignore
         return escape.linkify(text, **kwargs)


 class _xsrf_form_html(UIModule):
-    def render(self):
+    def render(self) -> str:  # type: ignore
         return self.handler.xsrf_form_html()


@@ -3035,39 +3260,42 @@ class TemplateModule(UIModule):
     Template()) instead of inheriting the outer template's namespace.

     Templates rendered through this module also get access to UIModule's
-    automatic javascript/css features.  Simply call set_resources
+    automatic JavaScript/CSS features.  Simply call set_resources
     inside the template and give it keyword arguments corresponding to
     the methods on UIModule: {{ set_resources(js_files=static_url("my.js")) }}
     Note that these resources are output once per template file, not once
     per instantiation of the template, so they must not depend on any
     arguments to the template.
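A typical ``UIModule`` subclass overrides ``render`` plus whichever asset hooks it needs; relative paths returned from ``javascript_files``/``css_files`` are resolved through ``static_url``. A hypothetical module, registered via ``Application(..., ui_modules={"Entry": Entry})`` and invoked in a template as ``{% module Entry(entry) %}``:

```py
import tornado.web

class Entry(tornado.web.UIModule):
    def render(self, entry):
        # render_string renders the module's own template.
        return self.render_string("modules/entry.html", entry=entry)

    def javascript_files(self):
        return ["js/entry.js"]  # relative, so passed through static_url()

    def embedded_css(self):
        return ".entry { margin: 1em 0; }"
```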
""" - def __init__(self, handler): - super(TemplateModule, self).__init__(handler) + + def __init__(self, handler: RequestHandler) -> None: + super().__init__(handler) # keep resources in both a list and a dict to preserve order - self._resource_list = [] - self._resource_dict = {} + self._resource_list = [] # type: List[Dict[str, Any]] + self._resource_dict = {} # type: Dict[str, Dict[str, Any]] - def render(self, path, **kwargs): - def set_resources(**kwargs): + def render(self, path: str, **kwargs: Any) -> bytes: # type: ignore + def set_resources(**kwargs) -> str: # type: ignore if path not in self._resource_dict: self._resource_list.append(kwargs) self._resource_dict[path] = kwargs else: if self._resource_dict[path] != kwargs: - raise ValueError("set_resources called with different " - "resources for the same template") + raise ValueError( + "set_resources called with different " + "resources for the same template" + ) return "" - return self.render_string(path, set_resources=set_resources, - **kwargs) - def _get_resources(self, key): + return self.render_string(path, set_resources=set_resources, **kwargs) + + def _get_resources(self, key: str) -> Iterable[str]: return (r[key] for r in self._resource_list if key in r) - def embedded_javascript(self): + def embedded_javascript(self) -> str: return "\n".join(self._get_resources("embedded_javascript")) - def javascript_files(self): + def javascript_files(self) -> Iterable[str]: result = [] for f in self._get_resources("javascript_files"): if isinstance(f, (unicode_type, bytes)): @@ -3076,10 +3304,10 @@ def javascript_files(self): result.extend(f) return result - def embedded_css(self): + def embedded_css(self) -> str: return "\n".join(self._get_resources("embedded_css")) - def css_files(self): + def css_files(self) -> Iterable[str]: result = [] for f in self._get_resources("css_files"): if isinstance(f, (unicode_type, bytes)): @@ -3088,47 +3316,40 @@ def css_files(self): result.extend(f) return result - def html_head(self): + def html_head(self) -> str: return "".join(self._get_resources("html_head")) - def html_body(self): + def html_body(self) -> str: return "".join(self._get_resources("html_body")) class _UIModuleNamespace(object): """Lazy namespace which creates UIModule proxies bound to a handler.""" - def __init__(self, handler, ui_modules): + + def __init__( + self, handler: RequestHandler, ui_modules: Dict[str, Type[UIModule]] + ) -> None: self.handler = handler self.ui_modules = ui_modules - def __getitem__(self, key): + def __getitem__(self, key: str) -> Callable[..., str]: return self.handler._ui_module(key, self.ui_modules[key]) - def __getattr__(self, key): + def __getattr__(self, key: str) -> Callable[..., str]: try: return self[key] except KeyError as e: raise AttributeError(str(e)) -if hasattr(hmac, 'compare_digest'): # python 3.3 - _time_independent_equals = hmac.compare_digest -else: - def _time_independent_equals(a, b): - if len(a) != len(b): - return False - result = 0 - if isinstance(a[0], int): # python3 byte strings - for x, y in zip(a, b): - result |= x ^ y - else: # python2 - for x, y in zip(a, b): - result |= ord(x) ^ ord(y) - return result == 0 - - -def create_signed_value(secret, name, value, version=None, clock=None, - key_version=None): +def create_signed_value( + secret: _CookieSecretTypes, + name: str, + value: Union[str, bytes], + version: Optional[int] = None, + clock: Optional[Callable[[], float]] = None, + key_version: Optional[int] = None, +) -> bytes: if version is None: version = 
DEFAULT_SIGNED_VALUE_VERSION if clock is None: @@ -3137,6 +3358,7 @@ def create_signed_value(secret, name, value, version=None, clock=None, timestamp = utf8(str(int(clock()))) value = base64.b64encode(utf8(value)) if version == 1: + assert not isinstance(secret, dict) signature = _create_signature_v1(secret, name, value, timestamp) value = b"|".join([value, timestamp, signature]) return value @@ -3155,19 +3377,25 @@ def create_signed_value(secret, name, value, version=None, clock=None, # - name (not encoded; assumed to be ~alphanumeric) # - value (base64-encoded) # - signature (hex-encoded; no length prefix) - def format_field(s): + def format_field(s: Union[str, bytes]) -> bytes: return utf8("%d:" % len(s)) + utf8(s) - to_sign = b"|".join([ - b"2", - format_field(str(key_version or 0)), - format_field(timestamp), - format_field(name), - format_field(value), - b'']) + + to_sign = b"|".join( + [ + b"2", + format_field(str(key_version or 0)), + format_field(timestamp), + format_field(name), + format_field(value), + b"", + ] + ) if isinstance(secret, dict): - assert key_version is not None, 'Key version must be set when sign key dict is used' - assert version >= 2, 'Version must be at least 2 for key version support' + assert ( + key_version is not None + ), "Key version must be set when sign key dict is used" + assert version >= 2, "Version must be at least 2 for key version support" secret = secret[key_version] signature = _create_signature_v2(secret, to_sign) @@ -3181,7 +3409,7 @@ def format_field(s): _signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$") -def _get_version(value): +def _get_version(value: bytes) -> int: # Figures out what version value is. Version 1 did not include an # explicit version field and started with arbitrary base64 data, # which makes this tricky. @@ -3204,8 +3432,14 @@ def _get_version(value): return version -def decode_signed_value(secret, name, value, max_age_days=31, - clock=None, min_version=None): +def decode_signed_value( + secret: _CookieSecretTypes, + name: str, + value: Union[None, str, bytes], + max_age_days: float = 31, + clock: Optional[Callable[[], float]] = None, + min_version: Optional[int] = None, +) -> Optional[bytes]: if clock is None: clock = time.time if min_version is None: @@ -3221,21 +3455,26 @@ def decode_signed_value(secret, name, value, max_age_days=31, if version < min_version: return None if version == 1: - return _decode_signed_value_v1(secret, name, value, - max_age_days, clock) + assert not isinstance(secret, dict) + return _decode_signed_value_v1(secret, name, value, max_age_days, clock) elif version == 2: - return _decode_signed_value_v2(secret, name, value, - max_age_days, clock) + return _decode_signed_value_v2(secret, name, value, max_age_days, clock) else: return None -def _decode_signed_value_v1(secret, name, value, max_age_days, clock): +def _decode_signed_value_v1( + secret: Union[str, bytes], + name: str, + value: bytes, + max_age_days: float, + clock: Callable[[], float], +) -> Optional[bytes]: parts = utf8(value).split(b"|") if len(parts) != 3: return None signature = _create_signature_v1(secret, name, parts[0], parts[1]) - if not _time_independent_equals(parts[2], signature): + if not hmac.compare_digest(parts[2], signature): gen_log.warning("Invalid cookie signature %r", value) return None timestamp = int(parts[1]) @@ -3248,8 +3487,7 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock): # digits from the payload to the timestamp without altering the # signature. 
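For reference, the version-2 framing built by ``format_field`` above: a ``2|`` prefix, length-prefixed key-version, timestamp, name, and base64 value fields, an empty trailing field, then the hex HMAC-SHA256 of everything before the signature (later compared with ``hmac.compare_digest``). A minimal sketch of the layout, not the library API:

```py
import base64
import hashlib
import hmac
import time

def sign_v2(secret: bytes, name: str, value: bytes, key_version: int = 0) -> bytes:
    def field(s: bytes) -> bytes:
        return b"%d:%s" % (len(s), s)

    to_sign = b"|".join([
        b"2",
        field(b"%d" % key_version),
        field(b"%d" % int(time.time())),
        field(name.encode()),
        field(base64.b64encode(value)),
        b"",  # leaves the trailing "|" before the signature
    ])
    sig = hmac.new(secret, to_sign, hashlib.sha256).hexdigest().encode()
    return to_sign + sig

print(sign_v2(b"secret", "session", b"hello"))
# b'2|1:0|10:<timestamp>|7:session|8:aGVsbG8=|<64 hex digits>'
```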
For backwards compatibility, sanity-check timestamp # here instead of modifying _cookie_signature. - gen_log.warning("Cookie timestamp in future; possible tampering %r", - value) + gen_log.warning("Cookie timestamp in future; possible tampering %r", value) return None if parts[1].startswith(b"0"): gen_log.warning("Tampered cookie %r", value) @@ -3260,16 +3498,16 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock): return None -def _decode_fields_v2(value): - def _consume_field(s): - length, _, rest = s.partition(b':') +def _decode_fields_v2(value: bytes) -> Tuple[int, bytes, bytes, bytes, bytes]: + def _consume_field(s: bytes) -> Tuple[bytes, bytes]: + length, _, rest = s.partition(b":") n = int(length) field_value = rest[:n] # In python 3, indexing bytes returns small integers; we must # use a slice to get a byte string as in python 2. - if rest[n:n + 1] != b'|': + if rest[n : n + 1] != b"|": raise ValueError("malformed v2 signed value field") - rest = rest[n + 1:] + rest = rest[n + 1 :] return field_value, rest rest = value[2:] # remove version number @@ -3280,12 +3518,24 @@ def _consume_field(s): return int(key_version), timestamp, name_field, value_field, passed_sig -def _decode_signed_value_v2(secret, name, value, max_age_days, clock): +def _decode_signed_value_v2( + secret: _CookieSecretTypes, + name: str, + value: bytes, + max_age_days: float, + clock: Callable[[], float], +) -> Optional[bytes]: try: - key_version, timestamp, name_field, value_field, passed_sig = _decode_fields_v2(value) + ( + key_version, + timestamp_bytes, + name_field, + value_field, + passed_sig, + ) = _decode_fields_v2(value) except ValueError: return None - signed_string = value[:-len(passed_sig)] + signed_string = value[: -len(passed_sig)] if isinstance(secret, dict): try: @@ -3294,11 +3544,11 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock): return None expected_sig = _create_signature_v2(secret, signed_string) - if not _time_independent_equals(passed_sig, expected_sig): + if not hmac.compare_digest(passed_sig, expected_sig): return None if name_field != utf8(name): return None - timestamp = int(timestamp) + timestamp = int(timestamp_bytes) if timestamp < clock() - max_age_days * 86400: # The signature has expired. return None @@ -3308,7 +3558,7 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock): return None -def get_signature_key_version(value): +def get_signature_key_version(value: Union[str, bytes]) -> Optional[int]: value = utf8(value) version = _get_version(value) if version < 2: @@ -3321,18 +3571,18 @@ def get_signature_key_version(value): return key_version -def _create_signature_v1(secret, *parts): +def _create_signature_v1(secret: Union[str, bytes], *parts: Union[str, bytes]) -> bytes: hash = hmac.new(utf8(secret), digestmod=hashlib.sha1) for part in parts: hash.update(utf8(part)) return utf8(hash.hexdigest()) -def _create_signature_v2(secret, s): +def _create_signature_v2(secret: Union[str, bytes], s: bytes) -> bytes: hash = hmac.new(utf8(secret), digestmod=hashlib.sha256) hash.update(utf8(s)) return utf8(hash.hexdigest()) -def is_absolute(path): +def is_absolute(path: str) -> bool: return any(path.startswith(x) for x in ["/", "http:", "https:"]) diff --git a/tornado/websocket.py b/tornado/websocket.py index 0507a92c67..ec4e1fb85b 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -16,32 +16,90 @@ Removed support for the draft 76 protocol version. 
""" -from __future__ import absolute_import, division, print_function - +import abc +import asyncio import base64 import hashlib import os +import sys import struct import tornado.escape import tornado.web +from urllib.parse import urlparse import zlib from tornado.concurrent import Future, future_set_result_unless_cancelled from tornado.escape import utf8, native_str, to_unicode from tornado import gen, httpclient, httputil from tornado.ioloop import IOLoop, PeriodicCallback -from tornado.iostream import StreamClosedError +from tornado.iostream import StreamClosedError, IOStream from tornado.log import gen_log, app_log from tornado import simple_httpclient from tornado.queues import Queue from tornado.tcpclient import TCPClient -from tornado.util import _websocket_mask, PY3 - -if PY3: - from urllib.parse import urlparse # py2 - xrange = range -else: - from urlparse import urlparse # py3 +from tornado.util import _websocket_mask + +from typing import ( + TYPE_CHECKING, + cast, + Any, + Optional, + Dict, + Union, + List, + Awaitable, + Callable, + Tuple, + Type, +) +from types import TracebackType + +if TYPE_CHECKING: + from typing_extensions import Protocol + + # The zlib compressor types aren't actually exposed anywhere + # publicly, so declare protocols for the portions we use. + class _Compressor(Protocol): + def compress(self, data: bytes) -> bytes: + pass + + def flush(self, mode: int) -> bytes: + pass + + class _Decompressor(Protocol): + unconsumed_tail = b"" # type: bytes + + def decompress(self, data: bytes, max_length: int) -> bytes: + pass + + class _WebSocketDelegate(Protocol): + # The common base interface implemented by WebSocketHandler on + # the server side and WebSocketClientConnection on the client + # side. + def on_ws_connection_close( + self, close_code: Optional[int] = None, close_reason: Optional[str] = None + ) -> None: + pass + + def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]: + pass + + def on_ping(self, data: bytes) -> None: + pass + + def on_pong(self, data: bytes) -> None: + pass + + def log_exception( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + pass + + +_default_max_message_size = 10 * 1024 * 1024 class WebSocketError(Exception): @@ -53,9 +111,28 @@ class WebSocketClosedError(WebSocketError): .. versionadded:: 3.2 """ + pass +class _DecompressTooLargeError(Exception): + pass + + +class _WebSocketParams(object): + def __init__( + self, + ping_interval: Optional[float] = None, + ping_timeout: Optional[float] = None, + max_message_size: int = _default_max_message_size, + compression_options: Optional[Dict[str, Any]] = None, + ) -> None: + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.max_message_size = max_message_size + self.compression_options = compression_options + + class WebSocketHandler(tornado.web.RequestHandler): """Subclass this class to create a basic WebSocket handler. @@ -113,7 +190,7 @@ def on_close(self): Web browsers allow any site to open a websocket connection to any other, instead of using the same-origin policy that governs other network - access from javascript. This can be surprising and is a potential + access from JavaScript. 
This can be surprising and is a potential security hole, so since Tornado 4.0 `WebSocketHandler` requires applications that wish to receive cross-origin websockets to opt in by overriding the `~WebSocketHandler.check_origin` method (see that @@ -137,23 +214,27 @@ def on_close(self): Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and ``websocket_max_message_size``. """ - def __init__(self, application, request, **kwargs): - super(WebSocketHandler, self).__init__(application, request, **kwargs) - self.ws_connection = None - self.close_code = None - self.close_reason = None - self.stream = None + + def __init__( + self, + application: tornado.web.Application, + request: httputil.HTTPServerRequest, + **kwargs: Any + ) -> None: + super().__init__(application, request, **kwargs) + self.ws_connection = None # type: Optional[WebSocketProtocol] + self.close_code = None # type: Optional[int] + self.close_reason = None # type: Optional[str] self._on_close_called = False - @tornado.web.asynchronous - def get(self, *args, **kwargs): + async def get(self, *args: Any, **kwargs: Any) -> None: self.open_args = args self.open_kwargs = kwargs # Upgrade header should be present and should be equal to WebSocket - if self.request.headers.get("Upgrade", "").lower() != 'websocket': + if self.request.headers.get("Upgrade", "").lower() != "websocket": self.set_status(400) - log_msg = "Can \"Upgrade\" only to \"WebSocket\"." + log_msg = 'Can "Upgrade" only to "WebSocket".' self.finish(log_msg) gen_log.debug(log_msg) return @@ -162,11 +243,12 @@ def get(self, *args, **kwargs): # Some proxy servers/load balancers # might mess with it. headers = self.request.headers - connection = map(lambda s: s.strip().lower(), - headers.get("Connection", "").split(",")) - if 'upgrade' not in connection: + connection = map( + lambda s: s.strip().lower(), headers.get("Connection", "").split(",") + ) + if "upgrade" not in connection: self.set_status(400) - log_msg = "\"Connection\" must be \"Upgrade\"." + log_msg = '"Connection" must be "Upgrade".' self.finish(log_msg) gen_log.debug(log_msg) return @@ -192,32 +274,29 @@ def get(self, *args, **kwargs): self.ws_connection = self.get_websocket_protocol() if self.ws_connection: - self.ws_connection.accept_connection() + await self.ws_connection.accept_connection(self) else: self.set_status(426, "Upgrade Required") self.set_header("Sec-WebSocket-Version", "7, 8, 13") - self.finish() - - stream = None @property - def ping_interval(self): + def ping_interval(self) -> Optional[float]: """The interval for websocket keep-alive pings. Set websocket_ping_interval = 0 to disable pings. """ - return self.settings.get('websocket_ping_interval', None) + return self.settings.get("websocket_ping_interval", None) @property - def ping_timeout(self): + def ping_timeout(self) -> Optional[float]: """If no ping is received in this many seconds, close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). Default is max of 3 pings or 30 seconds. """ - return self.settings.get('websocket_ping_timeout', None) + return self.settings.get("websocket_ping_timeout", None) @property - def max_message_size(self): + def max_message_size(self) -> int: """Maximum allowed message size. If the remote peer sends a message larger than this, the connection @@ -225,9 +304,13 @@ def max_message_size(self): Default is 10MiB. 
""" - return self.settings.get('websocket_max_message_size', None) + return self.settings.get( + "websocket_max_message_size", _default_max_message_size + ) - def write_message(self, message, binary=False): + def write_message( + self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False + ) -> "Future[None]": """Sends the given message to the client of this Web Socket. The message may be either a string or a dict (which will be @@ -249,26 +332,47 @@ def write_message(self, message, binary=False): Consistently raises `WebSocketClosedError`. Previously could sometimes raise `.StreamClosedError`. """ - if self.ws_connection is None: + if self.ws_connection is None or self.ws_connection.is_closing(): raise WebSocketClosedError() if isinstance(message, dict): message = tornado.escape.json_encode(message) return self.ws_connection.write_message(message, binary=binary) - def select_subprotocol(self, subprotocols): - """Invoked when a new WebSocket requests specific subprotocols. + def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]: + """Override to implement subprotocol negotiation. ``subprotocols`` is a list of strings identifying the subprotocols proposed by the client. This method may be overridden to return one of those strings to select it, or - ``None`` to not select a subprotocol. Failure to select a - subprotocol does not automatically abort the connection, - although clients may close the connection if none of their - proposed subprotocols was selected. + ``None`` to not select a subprotocol. + + Failure to select a subprotocol does not automatically abort + the connection, although clients may close the connection if + none of their proposed subprotocols was selected. + + The list may be empty, in which case this method must return + None. This method is always called exactly once even if no + subprotocols were proposed so that the handler can be advised + of this fact. + + .. versionchanged:: 5.1 + + Previously, this method was called with a list containing + an empty string instead of an empty list if no subprotocols + were proposed by the client. """ return None - def get_compression_options(self): + @property + def selected_subprotocol(self) -> Optional[str]: + """The subprotocol returned by `select_subprotocol`. + + .. versionadded:: 5.1 + """ + assert self.ws_connection is not None + return self.ws_connection.selected_subprotocol + + def get_compression_options(self) -> Optional[Dict[str, Any]]: """Override to return compression options for the connection. If this method returns None (the default), compression will @@ -292,16 +396,23 @@ def get_compression_options(self): # TODO: Add wbits option. return None - def open(self, *args, **kwargs): + def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: """Invoked when a new WebSocket is opened. The arguments to `open` are extracted from the `tornado.web.URLSpec` regular expression, just like the arguments to `tornado.web.RequestHandler.get`. + + `open` may be a coroutine. `on_message` will not be called until + `open` has returned. + + .. versionchanged:: 5.1 + + ``open`` may be a coroutine. """ pass - def on_message(self, message): + def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]: """Handle incoming messages on the WebSocket This method must be overridden. 
@@ -312,7 +423,7 @@ def on_message(self, message):
         """
         raise NotImplementedError

-    def ping(self, data=b''):
+    def ping(self, data: Union[str, bytes] = b"") -> None:
         """Send ping frame to the remote end.

         The data argument allows a small amount of data (up to 125
@@ -329,19 +440,19 @@ def ping(self, data=b''):
         """
         data = utf8(data)
-        if self.ws_connection is None:
+        if self.ws_connection is None or self.ws_connection.is_closing():
             raise WebSocketClosedError()
         self.ws_connection.write_ping(data)

-    def on_pong(self, data):
+    def on_pong(self, data: bytes) -> None:
         """Invoked when the response to a ping frame is received."""
         pass

-    def on_ping(self, data):
+    def on_ping(self, data: bytes) -> None:
         """Invoked when a ping frame is received."""
         pass

-    def on_close(self):
+    def on_close(self) -> None:
         """Invoked when the WebSocket is closed.

         If the connection was closed cleanly and a status code or reason
@@ -354,7 +465,7 @@ def on_close(self):
         """
         pass

-    def close(self, code=None, reason=None):
+    def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
         """Closes this Web Socket.

         Once the close handshake is successful the socket will be closed.
@@ -374,7 +485,7 @@ def close(self, code=None, reason=None):
         self.ws_connection.close(code, reason)
         self.ws_connection = None

-    def check_origin(self, origin):
+    def check_origin(self, origin: str) -> bool:
         """Override to enable support for allowing alternate origins.

         The ``origin`` argument is the value of the ``Origin`` HTTP
@@ -384,9 +495,9 @@ def check_origin(self, origin):
         implement WebSockets support this header, and non-browser clients
         do not have the same cross-site security concerns).

-        Should return True to accept the request or False to reject it.
-        By default, rejects all requests with an origin on a host other
-        than this one.
+        Should return ``True`` to accept the request or ``False`` to
+        reject it. By default, rejects all requests with an origin on
+        a host other than this one.

         This is a security protection against cross site scripting attacks on
         browsers, since WebSockets are allowed to bypass the usual same-origin
@@ -406,7 +517,7 @@ def check_origin(self, origin):
         for more.

         To accept all cross-origin traffic (which was the default prior to
-        Tornado 4.0), simply override this method to always return true::
+        Tornado 4.0), simply override this method to always return ``True``::

             def check_origin(self, origin):
                 return True
@@ -430,7 +541,7 @@ def check_origin(self, origin):
         # Check to see that origin matches host directly, including ports
         return origin == host

-    def set_nodelay(self, value):
+    def set_nodelay(self, value: bool) -> None:
         """Set the no-delay flag for this stream.

         By default, small messages may be delayed and/or combined to minimize
@@ -444,9 +555,10 @@ def set_nodelay(self, value):
         ..
versionadded:: 3.1 """ - self.stream.set_nodelay(value) + assert self.ws_connection is not None + self.ws_connection.set_nodelay(value) - def on_connection_close(self): + def on_connection_close(self) -> None: if self.ws_connection: self.ws_connection.on_connection_close() self.ws_connection = None @@ -455,55 +567,65 @@ def on_connection_close(self): self.on_close() self._break_cycles() - def _break_cycles(self): + def on_ws_connection_close( + self, close_code: Optional[int] = None, close_reason: Optional[str] = None + ) -> None: + self.close_code = close_code + self.close_reason = close_reason + self.on_connection_close() + + def _break_cycles(self) -> None: # WebSocketHandlers call finish() early, but we don't want to # break up reference cycles (which makes it impossible to call # self.render_string) until after we've really closed the # connection (if it was established in the first place, # indicated by status code 101). if self.get_status() != 101 or self._on_close_called: - super(WebSocketHandler, self)._break_cycles() + super()._break_cycles() - def send_error(self, *args, **kwargs): - if self.stream is None: - super(WebSocketHandler, self).send_error(*args, **kwargs) - else: - # If we get an uncaught exception during the handshake, - # we have no choice but to abruptly close the connection. - # TODO: for uncaught exceptions after the handshake, - # we can close the connection more gracefully. - self.stream.close() - - def get_websocket_protocol(self): + def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]: websocket_version = self.request.headers.get("Sec-WebSocket-Version") if websocket_version in ("7", "8", "13"): - return WebSocketProtocol13( - self, compression_options=self.get_compression_options()) + params = _WebSocketParams( + ping_interval=self.ping_interval, + ping_timeout=self.ping_timeout, + max_message_size=self.max_message_size, + compression_options=self.get_compression_options(), + ) + return WebSocketProtocol13(self, False, params) + return None - def _attach_stream(self): - self.stream = self.request.connection.detach() - self.stream.set_close_callback(self.on_connection_close) + def _detach_stream(self) -> IOStream: # disable non-WS methods - for method in ["write", "redirect", "set_header", "set_cookie", - "set_status", "flush", "finish"]: + for method in [ + "write", + "redirect", + "set_header", + "set_cookie", + "set_status", + "flush", + "finish", + ]: setattr(self, method, _raise_not_supported_for_websockets) + return self.detach() -def _raise_not_supported_for_websockets(*args, **kwargs): +def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None: raise RuntimeError("Method not supported for Web Sockets") -class WebSocketProtocol(object): - """Base class for WebSocket protocol versions. - """ - def __init__(self, handler): +class WebSocketProtocol(abc.ABC): + """Base class for WebSocket protocol versions.""" + + def __init__(self, handler: "_WebSocketDelegate") -> None: self.handler = handler - self.request = handler.request - self.stream = handler.stream + self.stream = None # type: Optional[IOStream] self.client_terminated = False self.server_terminated = False - def _run_callback(self, callback, *args, **kwargs): + def _run_callback( + self, callback: Callable, *args: Any, **kwargs: Any + ) -> "Optional[Future[Any]]": """Runs the given callback with exception handling. If the callback is a coroutine, returns its Future. 
On error, aborts the @@ -512,82 +634,161 @@ def _run_callback(self, callback, *args, **kwargs): try: result = callback(*args, **kwargs) except Exception: - app_log.error("Uncaught exception in %s", - getattr(self.request, 'path', None), exc_info=True) + self.handler.log_exception(*sys.exc_info()) self._abort() + return None else: if result is not None: result = gen.convert_yielded(result) + assert self.stream is not None self.stream.io_loop.add_future(result, lambda f: f.result()) return result - def on_connection_close(self): + def on_connection_close(self) -> None: self._abort() - def _abort(self): + def _abort(self) -> None: """Instantly aborts the WebSocket connection by closing the socket""" self.client_terminated = True self.server_terminated = True - self.stream.close() # forcibly tear down the connection + if self.stream is not None: + self.stream.close() # forcibly tear down the connection self.close() # let the subclass cleanup + @abc.abstractmethod + def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def is_closing(self) -> bool: + raise NotImplementedError() + + @abc.abstractmethod + async def accept_connection(self, handler: WebSocketHandler) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def write_message( + self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False + ) -> "Future[None]": + raise NotImplementedError() + + @property + @abc.abstractmethod + def selected_subprotocol(self) -> Optional[str]: + raise NotImplementedError() + + @abc.abstractmethod + def write_ping(self, data: bytes) -> None: + raise NotImplementedError() + + # The entry points below are used by WebSocketClientConnection, + # which was introduced after we only supported a single version of + # WebSocketProtocol. The WebSocketProtocol/WebSocketProtocol13 + # boundary is currently pretty ad-hoc. + @abc.abstractmethod + def _process_server_headers( + self, key: Union[str, bytes], headers: httputil.HTTPHeaders + ) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def start_pinging(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + async def _receive_frame_loop(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def set_nodelay(self, x: bool) -> None: + raise NotImplementedError() + class _PerMessageDeflateCompressor(object): - def __init__(self, persistent, max_wbits, compression_options=None): + def __init__( + self, + persistent: bool, + max_wbits: Optional[int], + compression_options: Optional[Dict[str, Any]] = None, + ) -> None: if max_wbits is None: max_wbits = zlib.MAX_WBITS # There is no symbolic constant for the minimum wbits value. 
if not (8 <= max_wbits <= zlib.MAX_WBITS): - raise ValueError("Invalid max_wbits value %r; allowed range 8-%d", - max_wbits, zlib.MAX_WBITS) + raise ValueError( + "Invalid max_wbits value %r; allowed range 8-%d", + max_wbits, + zlib.MAX_WBITS, + ) self._max_wbits = max_wbits - if compression_options is None or 'compression_level' not in compression_options: + if ( + compression_options is None + or "compression_level" not in compression_options + ): self._compression_level = tornado.web.GZipContentEncoding.GZIP_LEVEL else: - self._compression_level = compression_options['compression_level'] + self._compression_level = compression_options["compression_level"] - if compression_options is None or 'mem_level' not in compression_options: + if compression_options is None or "mem_level" not in compression_options: self._mem_level = 8 else: - self._mem_level = compression_options['mem_level'] + self._mem_level = compression_options["mem_level"] if persistent: - self._compressor = self._create_compressor() + self._compressor = self._create_compressor() # type: Optional[_Compressor] else: self._compressor = None - def _create_compressor(self): - return zlib.compressobj(self._compression_level, - zlib.DEFLATED, -self._max_wbits, self._mem_level) + def _create_compressor(self) -> "_Compressor": + return zlib.compressobj( + self._compression_level, zlib.DEFLATED, -self._max_wbits, self._mem_level + ) - def compress(self, data): + def compress(self, data: bytes) -> bytes: compressor = self._compressor or self._create_compressor() - data = (compressor.compress(data) + - compressor.flush(zlib.Z_SYNC_FLUSH)) - assert data.endswith(b'\x00\x00\xff\xff') + data = compressor.compress(data) + compressor.flush(zlib.Z_SYNC_FLUSH) + assert data.endswith(b"\x00\x00\xff\xff") return data[:-4] class _PerMessageDeflateDecompressor(object): - def __init__(self, persistent, max_wbits, compression_options=None): + def __init__( + self, + persistent: bool, + max_wbits: Optional[int], + max_message_size: int, + compression_options: Optional[Dict[str, Any]] = None, + ) -> None: + self._max_message_size = max_message_size if max_wbits is None: max_wbits = zlib.MAX_WBITS if not (8 <= max_wbits <= zlib.MAX_WBITS): - raise ValueError("Invalid max_wbits value %r; allowed range 8-%d", - max_wbits, zlib.MAX_WBITS) + raise ValueError( + "Invalid max_wbits value %r; allowed range 8-%d", + max_wbits, + zlib.MAX_WBITS, + ) self._max_wbits = max_wbits if persistent: - self._decompressor = self._create_decompressor() + self._decompressor = ( + self._create_decompressor() + ) # type: Optional[_Decompressor] else: self._decompressor = None - def _create_decompressor(self): + def _create_decompressor(self) -> "_Decompressor": return zlib.decompressobj(-self._max_wbits) - def decompress(self, data): + def decompress(self, data: bytes) -> bytes: decompressor = self._decompressor or self._create_decompressor() - return decompressor.decompress(data + b'\x00\x00\xff\xff') + result = decompressor.decompress( + data + b"\x00\x00\xff\xff", self._max_message_size + ) + if decompressor.unconsumed_tail: + raise _DecompressTooLargeError() + return result class WebSocketProtocol13(WebSocketProtocol): @@ -596,30 +797,38 @@ class WebSocketProtocol13(WebSocketProtocol): This class supports versions 7 and 8 of the protocol in addition to the final version 13. """ + # Bit masks for the first byte of a frame. 
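Both classes above rely on the same zlib trick: a raw deflate stream (negative wbits) is sync-flushed per message, the four-byte ``00 00 ff ff`` tail is stripped on the wire (RFC 7692), and re-appended before inflating; the ``max_length``/``unconsumed_tail`` pair is what turns an oversized message into ``_DecompressTooLargeError``. The round trip in isolation:

```py
import zlib

def pm_deflate(data):
    c = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS)
    out = c.compress(data) + c.flush(zlib.Z_SYNC_FLUSH)
    assert out.endswith(b"\x00\x00\xff\xff")
    return out[:-4]  # marker is implicit on the wire

def pm_inflate(data, max_size=1 << 20):
    d = zlib.decompressobj(-zlib.MAX_WBITS)
    out = d.decompress(data + b"\x00\x00\xff\xff", max_size)
    if d.unconsumed_tail:  # same guard as the decompressor above
        raise ValueError("message too large")
    return out

assert pm_inflate(pm_deflate(b"hello")) == b"hello"
```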
    FIN = 0x80
     RSV1 = 0x40
     RSV2 = 0x20
     RSV3 = 0x10
     RSV_MASK = RSV1 | RSV2 | RSV3
-    OPCODE_MASK = 0x0f
+    OPCODE_MASK = 0x0F
+
+    stream = None  # type: IOStream

-    def __init__(self, handler, mask_outgoing=False,
-                 compression_options=None):
+    def __init__(
+        self,
+        handler: "_WebSocketDelegate",
+        mask_outgoing: bool,
+        params: _WebSocketParams,
+    ) -> None:
         WebSocketProtocol.__init__(self, handler)
         self.mask_outgoing = mask_outgoing
+        self.params = params
         self._final_frame = False
         self._frame_opcode = None
         self._masked_frame = None
-        self._frame_mask = None
+        self._frame_mask = None  # type: Optional[bytes]
         self._frame_length = None
-        self._fragmented_message_buffer = None
+        self._fragmented_message_buffer = None  # type: Optional[bytes]
         self._fragmented_message_opcode = None
-        self._waiting = None
-        self._compression_options = compression_options
-        self._decompressor = None
-        self._compressor = None
-        self._frame_compressed = None
+        self._waiting = None  # type: object
+        self._compression_options = params.compression_options
+        self._decompressor = None  # type: Optional[_PerMessageDeflateDecompressor]
+        self._compressor = None  # type: Optional[_PerMessageDeflateCompressor]
+        self._frame_compressed = None  # type: Optional[bool]
         # The total uncompressed size of all messages received or sent.
         # Unicode messages are encoded to utf8.
         # Only for testing; subject to change.
@@ -629,40 +838,53 @@ def __init__(self, handler, mask_outgoing=False,
         # the effect of compression, frame overhead, and control frames.
         self._wire_bytes_in = 0
         self._wire_bytes_out = 0
-        self.ping_callback = None
-        self.last_ping = 0
-        self.last_pong = 0
+        self.ping_callback = None  # type: Optional[PeriodicCallback]
+        self.last_ping = 0.0
+        self.last_pong = 0.0
+        self.close_code = None  # type: Optional[int]
+        self.close_reason = None  # type: Optional[str]
+
+    # Use a property for this to satisfy the abc.
+    @property
+    def selected_subprotocol(self) -> Optional[str]:
+        return self._selected_subprotocol
+
+    @selected_subprotocol.setter
+    def selected_subprotocol(self, value: Optional[str]) -> None:
+        self._selected_subprotocol = value

-    def accept_connection(self):
+    async def accept_connection(self, handler: WebSocketHandler) -> None:
         try:
-            self._handle_websocket_headers()
+            self._handle_websocket_headers(handler)
         except ValueError:
-            self.handler.set_status(400)
+            handler.set_status(400)
             log_msg = "Missing/Invalid WebSocket headers"
-            self.handler.finish(log_msg)
+            handler.finish(log_msg)
             gen_log.debug(log_msg)
             return

         try:
-            self._accept_connection()
+            await self._accept_connection(handler)
+        except asyncio.CancelledError:
+            self._abort()
+            return
         except ValueError:
-            gen_log.debug("Malformed WebSocket request received",
-                          exc_info=True)
+            gen_log.debug("Malformed WebSocket request received", exc_info=True)
             self._abort()
             return

-    def _handle_websocket_headers(self):
+    def _handle_websocket_headers(self, handler: WebSocketHandler) -> None:
         """Verifies all invariant and required headers.

         If a header is missing or has an incorrect value, a ValueError
         will be raised.
         """
         fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
-        if not all(map(lambda f: self.request.headers.get(f), fields)):
+        if not all(map(lambda f: handler.request.headers.get(f), fields)):
             raise ValueError("Missing/Invalid WebSocket headers")

     @staticmethod
-    def compute_accept_value(key):
+    def compute_accept_value(key: Union[str, bytes]) -> str:
         """Computes the value for the Sec-WebSocket-Accept header,
         given the value for Sec-WebSocket-Key.
""" @@ -671,105 +893,143 @@ def compute_accept_value(key): sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value return native_str(base64.b64encode(sha1.digest())) - def _challenge_response(self): + def _challenge_response(self, handler: WebSocketHandler) -> str: return WebSocketProtocol13.compute_accept_value( - self.request.headers.get("Sec-Websocket-Key")) + cast(str, handler.request.headers.get("Sec-Websocket-Key")) + ) - def _accept_connection(self): - subprotocols = [s.strip() for s in self.request.headers.get_list("Sec-WebSocket-Protocol")] - if subprotocols: - selected = self.handler.select_subprotocol(subprotocols) - if selected: - assert selected in subprotocols - self.handler.set_header("Sec-WebSocket-Protocol", selected) + async def _accept_connection(self, handler: WebSocketHandler) -> None: + subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol") + if subprotocol_header: + subprotocols = [s.strip() for s in subprotocol_header.split(",")] + else: + subprotocols = [] + self.selected_subprotocol = handler.select_subprotocol(subprotocols) + if self.selected_subprotocol: + assert self.selected_subprotocol in subprotocols + handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol) - extensions = self._parse_extensions_header(self.request.headers) + extensions = self._parse_extensions_header(handler.request.headers) for ext in extensions: - if (ext[0] == 'permessage-deflate' and - self._compression_options is not None): + if ext[0] == "permessage-deflate" and self._compression_options is not None: # TODO: negotiate parameters if compression_options # specifies limits. - self._create_compressors('server', ext[1], self._compression_options) - if ('client_max_window_bits' in ext[1] and - ext[1]['client_max_window_bits'] is None): + self._create_compressors("server", ext[1], self._compression_options) + if ( + "client_max_window_bits" in ext[1] + and ext[1]["client_max_window_bits"] is None + ): # Don't echo an offered client_max_window_bits # parameter with no value. 
- del ext[1]['client_max_window_bits'] - self.handler.set_header("Sec-WebSocket-Extensions", - httputil._encode_header( - 'permessage-deflate', ext[1])) + del ext[1]["client_max_window_bits"] + handler.set_header( + "Sec-WebSocket-Extensions", + httputil._encode_header("permessage-deflate", ext[1]), + ) break - self.handler.clear_header("Content-Type") - self.handler.set_status(101) - self.handler.set_header("Upgrade", "websocket") - self.handler.set_header("Connection", "Upgrade") - self.handler.set_header("Sec-WebSocket-Accept", self._challenge_response()) - self.handler.finish() + handler.clear_header("Content-Type") + handler.set_status(101) + handler.set_header("Upgrade", "websocket") + handler.set_header("Connection", "Upgrade") + handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler)) + handler.finish() - self.handler._attach_stream() - self.stream = self.handler.stream + self.stream = handler._detach_stream() self.start_pinging() - self._run_callback(self.handler.open, *self.handler.open_args, - **self.handler.open_kwargs) - IOLoop.current().add_callback(self._receive_frame_loop) + try: + open_result = handler.open(*handler.open_args, **handler.open_kwargs) + if open_result is not None: + await open_result + except Exception: + handler.log_exception(*sys.exc_info()) + self._abort() + return - def _parse_extensions_header(self, headers): - extensions = headers.get("Sec-WebSocket-Extensions", '') + await self._receive_frame_loop() + + def _parse_extensions_header( + self, headers: httputil.HTTPHeaders + ) -> List[Tuple[str, Dict[str, str]]]: + extensions = headers.get("Sec-WebSocket-Extensions", "") if extensions: - return [httputil._parse_header(e.strip()) - for e in extensions.split(',')] + return [httputil._parse_header(e.strip()) for e in extensions.split(",")] return [] - def _process_server_headers(self, key, headers): + def _process_server_headers( + self, key: Union[str, bytes], headers: httputil.HTTPHeaders + ) -> None: """Process the headers sent by the server to this client connection. 'key' is the websocket handshake challenge/response key. """ - assert headers['Upgrade'].lower() == 'websocket' - assert headers['Connection'].lower() == 'upgrade' + assert headers["Upgrade"].lower() == "websocket" + assert headers["Connection"].lower() == "upgrade" accept = self.compute_accept_value(key) - assert headers['Sec-Websocket-Accept'] == accept + assert headers["Sec-Websocket-Accept"] == accept extensions = self._parse_extensions_header(headers) for ext in extensions: - if (ext[0] == 'permessage-deflate' and - self._compression_options is not None): - self._create_compressors('client', ext[1]) + if ext[0] == "permessage-deflate" and self._compression_options is not None: + self._create_compressors("client", ext[1]) else: raise ValueError("unsupported extension %r", ext) - def _get_compressor_options(self, side, agreed_parameters, compression_options=None): + self.selected_subprotocol = headers.get("Sec-WebSocket-Protocol", None) + + def _get_compressor_options( + self, + side: str, + agreed_parameters: Dict[str, Any], + compression_options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: """Converts a websocket agreed_parameters set to keyword arguments for our compressor objects. 
""" options = dict( - persistent=(side + '_no_context_takeover') not in agreed_parameters) - wbits_header = agreed_parameters.get(side + '_max_window_bits', None) + persistent=(side + "_no_context_takeover") not in agreed_parameters + ) # type: Dict[str, Any] + wbits_header = agreed_parameters.get(side + "_max_window_bits", None) if wbits_header is None: - options['max_wbits'] = zlib.MAX_WBITS + options["max_wbits"] = zlib.MAX_WBITS else: - options['max_wbits'] = int(wbits_header) - options['compression_options'] = compression_options + options["max_wbits"] = int(wbits_header) + options["compression_options"] = compression_options return options - def _create_compressors(self, side, agreed_parameters, compression_options=None): + def _create_compressors( + self, + side: str, + agreed_parameters: Dict[str, Any], + compression_options: Optional[Dict[str, Any]] = None, + ) -> None: # TODO: handle invalid parameters gracefully - allowed_keys = set(['server_no_context_takeover', - 'client_no_context_takeover', - 'server_max_window_bits', - 'client_max_window_bits']) + allowed_keys = set( + [ + "server_no_context_takeover", + "client_no_context_takeover", + "server_max_window_bits", + "client_max_window_bits", + ] + ) for key in agreed_parameters: if key not in allowed_keys: raise ValueError("unsupported compression parameter %r" % key) - other_side = 'client' if (side == 'server') else 'server' + other_side = "client" if (side == "server") else "server" self._compressor = _PerMessageDeflateCompressor( - **self._get_compressor_options(side, agreed_parameters, compression_options)) + **self._get_compressor_options(side, agreed_parameters, compression_options) + ) self._decompressor = _PerMessageDeflateDecompressor( - **self._get_compressor_options(other_side, agreed_parameters, compression_options)) - - def _write_frame(self, fin, opcode, data, flags=0): + max_message_size=self.params.max_message_size, + **self._get_compressor_options( + other_side, agreed_parameters, compression_options + ) + ) + + def _write_frame( + self, fin: bool, opcode: int, data: bytes, flags: int = 0 + ) -> "Future[None]": data_len = len(data) if opcode & 0x8: # All control frames MUST have a payload length of 125 @@ -800,12 +1060,16 @@ def _write_frame(self, fin, opcode, data, flags=0): self._wire_bytes_out += len(frame) return self.stream.write(frame) - def write_message(self, message, binary=False): + def write_message( + self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False + ) -> "Future[None]": """Sends the given message to the client of this Web Socket.""" if binary: opcode = 0x2 else: opcode = 0x1 + if isinstance(message, dict): + message = tornado.escape.json_encode(message) message = tornado.escape.utf8(message) assert isinstance(message, bytes) self._message_bytes_out += len(message) @@ -823,35 +1087,35 @@ def write_message(self, message, binary=False): except StreamClosedError: raise WebSocketClosedError() - @gen.coroutine - def wrapper(): + async def wrapper() -> None: try: - yield fut + await fut except StreamClosedError: raise WebSocketClosedError() - return wrapper() - def write_ping(self, data): + return asyncio.ensure_future(wrapper()) + + def write_ping(self, data: bytes) -> None: """Send ping frame.""" assert isinstance(data, bytes) self._write_frame(True, 0x9, data) - @gen.coroutine - def _receive_frame_loop(self): + async def _receive_frame_loop(self) -> None: try: while not self.client_terminated: - yield self._receive_frame() + await self._receive_frame() except StreamClosedError: 
self._abort() + self.handler.on_ws_connection_close(self.close_code, self.close_reason) - def _read_bytes(self, n): + async def _read_bytes(self, n: int) -> bytes: + data = await self.stream.read_bytes(n) self._wire_bytes_in += n - return self.stream.read_bytes(n) + return data - @gen.coroutine - def _receive_frame(self): + async def _receive_frame(self) -> None: # Read the frame header. - data = yield self._read_bytes(2) + data = await self._read_bytes(2) header, mask_payloadlen = struct.unpack("BB", data) is_final_frame = header & self.FIN reserved_bits = header & self.RSV_MASK @@ -868,7 +1132,7 @@ def _receive_frame(self): self._abort() return is_masked = bool(mask_payloadlen & 0x80) - payloadlen = mask_payloadlen & 0x7f + payloadlen = mask_payloadlen & 0x7F # Parse and validate the length. if opcode_is_control and payloadlen >= 126: @@ -878,24 +1142,25 @@ def _receive_frame(self): if payloadlen < 126: self._frame_length = payloadlen elif payloadlen == 126: - data = yield self._read_bytes(2) + data = await self._read_bytes(2) payloadlen = struct.unpack("!H", data)[0] elif payloadlen == 127: - data = yield self._read_bytes(8) + data = await self._read_bytes(8) payloadlen = struct.unpack("!Q", data)[0] new_len = payloadlen if self._fragmented_message_buffer is not None: new_len += len(self._fragmented_message_buffer) - if new_len > (self.handler.max_message_size or 10 * 1024 * 1024): + if new_len > self.params.max_message_size: self.close(1009, "message too big") self._abort() return # Read the payload, unmasking if necessary. if is_masked: - self._frame_mask = yield self._read_bytes(4) - data = yield self._read_bytes(payloadlen) + self._frame_mask = await self._read_bytes(4) + data = await self._read_bytes(payloadlen) if is_masked: + assert self._frame_mask is not None data = _websocket_mask(self._frame_mask, data) # Decide what to do with this frame. @@ -929,15 +1194,21 @@ def _receive_frame(self): if is_final_frame: handled_future = self._handle_message(opcode, data) if handled_future is not None: - yield handled_future + await handled_future - def _handle_message(self, opcode, data): + def _handle_message(self, opcode: int, data: bytes) -> "Optional[Future[None]]": """Execute on_message, returning its Future if it is a coroutine.""" if self.client_terminated: - return + return None if self._frame_compressed: - data = self._decompressor.decompress(data) + assert self._decompressor is not None + try: + data = self._decompressor.decompress(data) + except _DecompressTooLargeError: + self.close(1009, "message too big after decompression") + self._abort() + return None if opcode == 0x1: # UTF-8 data @@ -946,7 +1217,7 @@ def _handle_message(self, opcode, data): decoded = data.decode("utf-8") except UnicodeDecodeError: self._abort() - return + return None return self._run_callback(self.handler.on_message, decoded) elif opcode == 0x2: # Binary data @@ -956,11 +1227,11 @@ def _handle_message(self, opcode, data): # Close self.client_terminated = True if len(data) >= 2: - self.handler.close_code = struct.unpack('>H', data[:2])[0] + self.close_code = struct.unpack(">H", data[:2])[0] if len(data) > 2: - self.handler.close_reason = to_unicode(data[2:]) + self.close_reason = to_unicode(data[2:]) # Echo the received close code, if any (RFC 6455 section 5.5.1). 
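For readers tracing `_receive_frame`, the same header arithmetic can be exercised in isolation. A standalone sketch (not part of the patch) that builds a masked text frame by hand and parses it back using the constants above (FIN bit `0x80`, opcode mask `0xF`, mask bit `0x80`, 7-bit length field `0x7F`):

```py
import struct

mask = b"\x01\x02\x03\x04"
payload = b"Hi"
masked = bytes(b ^ mask[i % 4] for i, b in enumerate(payload))
# Byte 1: FIN=1, opcode=0x1 (text). Byte 2: MASK=1, 7-bit payload length.
frame = struct.pack("BB", 0x80 | 0x1, 0x80 | len(payload)) + mask + masked

header, mask_payloadlen = struct.unpack("BB", frame[:2])
is_final_frame = bool(header & 0x80)
opcode = header & 0xF
is_masked = bool(mask_payloadlen & 0x80)
payloadlen = mask_payloadlen & 0x7F
# Lengths 126 and 127 would signal the 2- and 8-byte extended length
# fields read with "!H" and "!Q" above; up to 125 fits in the header.
assert payloadlen < 126
recv_mask, data = frame[2:6], frame[6:]
unmasked = bytes(b ^ recv_mask[i % 4] for i, b in enumerate(data))
assert (is_final_frame, opcode, is_masked, unmasked) == (True, 0x1, True, b"Hi")
```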
- self.close(self.handler.close_code) + self.close(self.close_code) elif opcode == 0x9: # Ping try: @@ -974,17 +1245,18 @@ def _handle_message(self, opcode, data): return self._run_callback(self.handler.on_pong, data) else: self._abort() + return None - def close(self, code=None, reason=None): + def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: """Closes the WebSocket connection.""" if not self.server_terminated: if not self.stream.closed(): if code is None and reason is not None: code = 1000 # "normal closure" status code if code is None: - close_data = b'' + close_data = b"" else: - close_data = struct.pack('>H', code) + close_data = struct.pack(">H", code) if reason is not None: close_data += utf8(reason) try: @@ -1001,36 +1273,52 @@ def close(self, code=None, reason=None): # Give the client a few seconds to complete a clean shutdown, # otherwise just close the connection. self._waiting = self.stream.io_loop.add_timeout( - self.stream.io_loop.time() + 5, self._abort) + self.stream.io_loop.time() + 5, self._abort + ) + if self.ping_callback: + self.ping_callback.stop() + self.ping_callback = None + + def is_closing(self) -> bool: + """Return ``True`` if this connection is closing. + + The connection is considered closing if either side has + initiated its closing handshake or if the stream has been + shut down uncleanly. + """ + return self.stream.closed() or self.client_terminated or self.server_terminated @property - def ping_interval(self): - interval = self.handler.ping_interval + def ping_interval(self) -> Optional[float]: + interval = self.params.ping_interval if interval is not None: return interval return 0 @property - def ping_timeout(self): - timeout = self.handler.ping_timeout + def ping_timeout(self) -> Optional[float]: + timeout = self.params.ping_timeout if timeout is not None: return timeout + assert self.ping_interval is not None return max(3 * self.ping_interval, 30) - def start_pinging(self): + def start_pinging(self) -> None: """Start sending periodic pings to keep the connection alive""" + assert self.ping_interval is not None if self.ping_interval > 0: self.last_ping = self.last_pong = IOLoop.current().time() self.ping_callback = PeriodicCallback( - self.periodic_ping, self.ping_interval * 1000) + self.periodic_ping, self.ping_interval * 1000 + ) self.ping_callback.start() - def periodic_ping(self): + def periodic_ping(self) -> None: """Send a ping to keep the websocket alive Called periodically if the websocket_ping_interval is set and non-zero. """ - if self.stream.closed() and self.ping_callback is not None: + if self.is_closing() and self.ping_callback is not None: self.ping_callback.stop() return @@ -1040,14 +1328,21 @@ def periodic_ping(self): now = IOLoop.current().time() since_last_pong = now - self.last_pong since_last_ping = now - self.last_ping - if (since_last_ping < 2 * self.ping_interval and - since_last_pong > self.ping_timeout): + assert self.ping_interval is not None + assert self.ping_timeout is not None + if ( + since_last_ping < 2 * self.ping_interval + and since_last_pong > self.ping_timeout + ): self.close() return - self.write_ping(b'') + self.write_ping(b"") self.last_ping = now + def set_nodelay(self, x: bool) -> None: + self.stream.set_nodelay(x) + class WebSocketClientConnection(simple_httpclient._HTTPConnection): """WebSocket client connection. 
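The keepalive machinery above now reads its configuration from `self.params` rather than reaching into the handler. As a rough sketch of where those values originate in user code (the handler, URL, and interval values are illustrative):

```py
import tornado.web
import tornado.websocket
from tornado.websocket import websocket_connect


class EchoHandler(tornado.websocket.WebSocketHandler):
    def on_message(self, message):
        self.write_message(message)


# Server side: these settings become params.ping_interval and
# params.ping_timeout; when the timeout is unset, the ping_timeout
# property above falls back to max(3 * ping_interval, 30).
app = tornado.web.Application(
    [(r"/ws", EchoHandler)],
    websocket_ping_interval=10,
    websocket_ping_timeout=30,
)


# Client side: the equivalent knobs are plain keyword arguments.
async def connect():
    return await websocket_connect(
        "ws://localhost:8888/ws", ping_interval=10, ping_timeout=30
    )
```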
@@ -1055,44 +1350,71 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): This class should not be instantiated directly; use the `websocket_connect` function instead. """ - def __init__(self, request, on_message_callback=None, - compression_options=None, ping_interval=None, ping_timeout=None, - max_message_size=None): - self.compression_options = compression_options - self.connect_future = Future() - self.protocol = None - self.read_queue = Queue(1) + + protocol = None # type: WebSocketProtocol + + def __init__( + self, + request: httpclient.HTTPRequest, + on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None, + compression_options: Optional[Dict[str, Any]] = None, + ping_interval: Optional[float] = None, + ping_timeout: Optional[float] = None, + max_message_size: int = _default_max_message_size, + subprotocols: Optional[List[str]] = [], + ) -> None: + self.connect_future = Future() # type: Future[WebSocketClientConnection] + self.read_queue = Queue(1) # type: Queue[Union[None, str, bytes]] self.key = base64.b64encode(os.urandom(16)) self._on_message_callback = on_message_callback - self.close_code = self.close_reason = None - self.ping_interval = ping_interval - self.ping_timeout = ping_timeout - self.max_message_size = max_message_size - - scheme, sep, rest = request.url.partition(':') - scheme = {'ws': 'http', 'wss': 'https'}[scheme] + self.close_code = None # type: Optional[int] + self.close_reason = None # type: Optional[str] + self.params = _WebSocketParams( + ping_interval=ping_interval, + ping_timeout=ping_timeout, + max_message_size=max_message_size, + compression_options=compression_options, + ) + + scheme, sep, rest = request.url.partition(":") + scheme = {"ws": "http", "wss": "https"}[scheme] request.url = scheme + sep + rest - request.headers.update({ - 'Upgrade': 'websocket', - 'Connection': 'Upgrade', - 'Sec-WebSocket-Key': self.key, - 'Sec-WebSocket-Version': '13', - }) - if self.compression_options is not None: + request.headers.update( + { + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": self.key, + "Sec-WebSocket-Version": "13", + } + ) + if subprotocols is not None: + request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols) + if compression_options is not None: # Always offer to let the server set our max_wbits (and even though # we don't offer it, we will accept a client_no_context_takeover # from the server). # TODO: set server parameters for deflate extension # if requested in self.compression_options. - request.headers['Sec-WebSocket-Extensions'] = ( - 'permessage-deflate; client_max_window_bits') + request.headers[ + "Sec-WebSocket-Extensions" + ] = "permessage-deflate; client_max_window_bits" - self.tcp_client = TCPClient() - super(WebSocketClientConnection, self).__init__( - None, request, lambda: None, self._on_http_response, - 104857600, self.tcp_client, 65536, 104857600) + # Websocket connection is currently unable to follow redirects + request.follow_redirects = False - def close(self, code=None, reason=None): + self.tcp_client = TCPClient() + super().__init__( + None, + request, + lambda: None, + self._on_http_response, + 104857600, + self.tcp_client, + 65536, + 104857600, + ) + + def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: """Closes the websocket connection. 
``code`` and ``reason`` are documented under @@ -1106,49 +1428,64 @@ def close(self, code=None, reason=None): """ if self.protocol is not None: self.protocol.close(code, reason) - self.protocol = None + self.protocol = None # type: ignore - def on_connection_close(self): + def on_connection_close(self) -> None: if not self.connect_future.done(): self.connect_future.set_exception(StreamClosedError()) - self.on_message(None) + self._on_message(None) self.tcp_client.close() - super(WebSocketClientConnection, self).on_connection_close() + super().on_connection_close() + + def on_ws_connection_close( + self, close_code: Optional[int] = None, close_reason: Optional[str] = None + ) -> None: + self.close_code = close_code + self.close_reason = close_reason + self.on_connection_close() - def _on_http_response(self, response): + def _on_http_response(self, response: httpclient.HTTPResponse) -> None: if not self.connect_future.done(): if response.error: self.connect_future.set_exception(response.error) else: - self.connect_future.set_exception(WebSocketError( - "Non-websocket response")) - - def headers_received(self, start_line, headers): + self.connect_future.set_exception( + WebSocketError("Non-websocket response") + ) + + async def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> None: + assert isinstance(start_line, httputil.ResponseStartLine) if start_line.code != 101: - return super(WebSocketClientConnection, self).headers_received( - start_line, headers) + await super().headers_received(start_line, headers) + return + + if self._timeout is not None: + self.io_loop.remove_timeout(self._timeout) + self._timeout = None self.headers = headers self.protocol = self.get_websocket_protocol() self.protocol._process_server_headers(self.key, self.headers) - self.protocol.start_pinging() - IOLoop.current().add_callback(self.protocol._receive_frame_loop) + self.protocol.stream = self.connection.detach() - if self._timeout is not None: - self.io_loop.remove_timeout(self._timeout) - self._timeout = None + IOLoop.current().add_callback(self.protocol._receive_frame_loop) + self.protocol.start_pinging() - self.stream = self.connection.detach() - self.stream.set_close_callback(self.on_connection_close) # Once we've taken over the connection, clear the final callback # we set on the http request. This deactivates the error handling # in simple_httpclient that would otherwise interfere with our # ability to see exceptions. - self.final_callback = None + self.final_callback = None # type: ignore future_set_result_unless_cancelled(self.connect_future, self) - def write_message(self, message, binary=False): + def write_message( + self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False + ) -> "Future[None]": """Sends a message to the WebSocket server. If the stream is closed, raises `WebSocketClosedError`. @@ -1160,7 +1497,10 @@ def write_message(self, message, binary=False): """ return self.protocol.write_message(message, binary=binary) - def read_message(self, callback=None): + def read_message( + self, + callback: Optional[Callable[["Future[Union[None, str, bytes]]"], None]] = None, + ) -> Awaitable[Union[None, str, bytes]]: """Reads a message from the WebSocket server. If on_message_callback was specified at WebSocket @@ -1172,18 +1512,24 @@ def read_message(self, callback=None): ready. 
""" - future = self.read_queue.get() + awaitable = self.read_queue.get() if callback is not None: - self.io_loop.add_future(future, callback) - return future + self.io_loop.add_future(asyncio.ensure_future(awaitable), callback) + return awaitable - def on_message(self, message): + def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]: + return self._on_message(message) + + def _on_message( + self, message: Union[None, str, bytes] + ) -> Optional[Awaitable[None]]: if self._on_message_callback: self._on_message_callback(message) + return None else: return self.read_queue.put(message) - def ping(self, data=b''): + def ping(self, data: bytes = b"") -> None: """Send ping frame to the remote end. The data argument allows a small amount of data (up to 125 @@ -1202,21 +1548,45 @@ def ping(self, data=b''): raise WebSocketClosedError() self.protocol.write_ping(data) - def on_pong(self, data): + def on_pong(self, data: bytes) -> None: pass - def on_ping(self, data): + def on_ping(self, data: bytes) -> None: pass - def get_websocket_protocol(self): - return WebSocketProtocol13(self, mask_outgoing=True, - compression_options=self.compression_options) + def get_websocket_protocol(self) -> WebSocketProtocol: + return WebSocketProtocol13(self, mask_outgoing=True, params=self.params) + @property + def selected_subprotocol(self) -> Optional[str]: + """The subprotocol selected by the server. -def websocket_connect(url, callback=None, connect_timeout=None, - on_message_callback=None, compression_options=None, - ping_interval=None, ping_timeout=None, - max_message_size=None): + .. versionadded:: 5.1 + """ + return self.protocol.selected_subprotocol + + def log_exception( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + assert typ is not None + assert value is not None + app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb)) + + +def websocket_connect( + url: Union[str, httpclient.HTTPRequest], + callback: Optional[Callable[["Future[WebSocketClientConnection]"], None]] = None, + connect_timeout: Optional[float] = None, + on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None, + compression_options: Optional[Dict[str, Any]] = None, + ping_interval: Optional[float] = None, + ping_timeout: Optional[float] = None, + max_message_size: int = _default_max_message_size, + subprotocols: Optional[List[str]] = None, +) -> "Awaitable[WebSocketClientConnection]": """Client-side websocket support. Takes a url and returns a Future whose result is a @@ -1239,6 +1609,11 @@ def websocket_connect(url, callback=None, connect_timeout=None, ``websocket_connect``. In both styles, a message of ``None`` indicates that the connection has been closed. + ``subprotocols`` may be a list of strings specifying proposed + subprotocols. The selected protocol may be found on the + ``selected_subprotocol`` attribute of the connection object + when the connection is complete. + .. versionchanged:: 3.2 Also accepts ``HTTPRequest`` objects in place of urls. @@ -1251,6 +1626,9 @@ def websocket_connect(url, callback=None, connect_timeout=None, .. versionchanged:: 5.0 The ``io_loop`` argument (deprecated since version 4.1) has been removed. + + .. versionchanged:: 5.1 + Added the ``subprotocols`` argument. 
""" if isinstance(url, httpclient.HTTPRequest): assert connect_timeout is None @@ -1260,14 +1638,19 @@ def websocket_connect(url, callback=None, connect_timeout=None, request.headers = httputil.HTTPHeaders(request.headers) else: request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout) - request = httpclient._RequestProxy( - request, httpclient.HTTPRequest._DEFAULTS) - conn = WebSocketClientConnection(request, - on_message_callback=on_message_callback, - compression_options=compression_options, - ping_interval=ping_interval, - ping_timeout=ping_timeout, - max_message_size=max_message_size) + request = cast( + httpclient.HTTPRequest, + httpclient._RequestProxy(request, httpclient.HTTPRequest._DEFAULTS), + ) + conn = WebSocketClientConnection( + request, + on_message_callback=on_message_callback, + compression_options=compression_options, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + max_message_size=max_message_size, + subprotocols=subprotocols, + ) if callback is not None: IOLoop.current().add_future(conn.connect_future, callback) return conn.connect_future diff --git a/tornado/wsgi.py b/tornado/wsgi.py index 22be7a8972..55ece9a22b 100644 --- a/tornado/wsgi.py +++ b/tornado/wsgi.py @@ -16,215 +16,41 @@ """WSGI support for the Tornado web framework. WSGI is the Python standard for web servers, and allows for interoperability -between Tornado and other Python web frameworks and servers. This module -provides WSGI support in two ways: +between Tornado and other Python web frameworks and servers. -* `WSGIAdapter` converts a `tornado.web.Application` to the WSGI application - interface. This is useful for running a Tornado app on another - HTTP server, such as Google App Engine. See the `WSGIAdapter` class - documentation for limitations that apply. -* `WSGIContainer` lets you run other WSGI applications and frameworks on the - Tornado HTTP server. For example, with this class you can mix Django - and Tornado handlers in a single server. -""" +This module provides WSGI support via the `WSGIContainer` class, which +makes it possible to run applications using other WSGI frameworks on +the Tornado HTTP server. The reverse is not supported; the Tornado +`.Application` and `.RequestHandler` classes are designed for use with +the Tornado `.HTTPServer` and cannot be used in a generic WSGI +container. -from __future__ import absolute_import, division, print_function +""" import sys from io import BytesIO import tornado -from tornado.concurrent import Future from tornado import escape from tornado import httputil from tornado.log import access_log -from tornado import web -from tornado.escape import native_str -from tornado.util import unicode_type, PY3 +from typing import List, Tuple, Optional, Callable, Any, Dict, Text +from types import TracebackType +import typing + +if typing.TYPE_CHECKING: + from typing import Type # noqa: F401 + from wsgiref.types import WSGIApplication as WSGIAppType # noqa: F401 -if PY3: - import urllib.parse as urllib_parse # py3 -else: - import urllib as urllib_parse # PEP 3333 specifies that WSGI on python 3 generally deals with byte strings # that are smuggled inside objects of type unicode (via the latin1 encoding). -# These functions are like those in the tornado.escape module, but defined -# here to minimize the temptation to use them in non-wsgi contexts. 
-if str is unicode_type: - def to_wsgi_str(s): - assert isinstance(s, bytes) - return s.decode('latin1') - - def from_wsgi_str(s): - assert isinstance(s, str) - return s.encode('latin1') -else: - def to_wsgi_str(s): - assert isinstance(s, bytes) - return s - - def from_wsgi_str(s): - assert isinstance(s, str) - return s - - -class WSGIApplication(web.Application): - """A WSGI equivalent of `tornado.web.Application`. - - .. deprecated:: 4.0 - - Use a regular `.Application` and wrap it in `WSGIAdapter` instead. - """ - def __call__(self, environ, start_response): - return WSGIAdapter(self)(environ, start_response) - - -# WSGI has no facilities for flow control, so just return an already-done -# Future when the interface requires it. -_dummy_future = Future() -_dummy_future.set_result(None) - - -class _WSGIConnection(httputil.HTTPConnection): - def __init__(self, method, start_response, context): - self.method = method - self.start_response = start_response - self.context = context - self._write_buffer = [] - self._finished = False - self._expected_content_remaining = None - self._error = None - - def set_close_callback(self, callback): - # WSGI has no facility for detecting a closed connection mid-request, - # so we can simply ignore the callback. - pass - - def write_headers(self, start_line, headers, chunk=None, callback=None): - if self.method == 'HEAD': - self._expected_content_remaining = 0 - elif 'Content-Length' in headers: - self._expected_content_remaining = int(headers['Content-Length']) - else: - self._expected_content_remaining = None - self.start_response( - '%s %s' % (start_line.code, start_line.reason), - [(native_str(k), native_str(v)) for (k, v) in headers.get_all()]) - if chunk is not None: - self.write(chunk, callback) - elif callback is not None: - callback() - return _dummy_future - - def write(self, chunk, callback=None): - if self._expected_content_remaining is not None: - self._expected_content_remaining -= len(chunk) - if self._expected_content_remaining < 0: - self._error = httputil.HTTPOutputError( - "Tried to write more data than Content-Length") - raise self._error - self._write_buffer.append(chunk) - if callback is not None: - callback() - return _dummy_future - - def finish(self): - if (self._expected_content_remaining is not None and - self._expected_content_remaining != 0): - self._error = httputil.HTTPOutputError( - "Tried to write %d bytes less than Content-Length" % - self._expected_content_remaining) - raise self._error - self._finished = True - - -class _WSGIRequestContext(object): - def __init__(self, remote_ip, protocol): - self.remote_ip = remote_ip - self.protocol = protocol - - def __str__(self): - return self.remote_ip - - -class WSGIAdapter(object): - """Converts a `tornado.web.Application` instance into a WSGI application. - - Example usage:: - - import tornado.web - import tornado.wsgi - import wsgiref.simple_server - - class MainHandler(tornado.web.RequestHandler): - def get(self): - self.write("Hello, world") - - if __name__ == "__main__": - application = tornado.web.Application([ - (r"/", MainHandler), - ]) - wsgi_app = tornado.wsgi.WSGIAdapter(application) - server = wsgiref.simple_server.make_server('', 8888, wsgi_app) - server.serve_forever() - - See the `appengine demo - `_ - for an example of using this module to run a Tornado app on Google - App Engine. - - In WSGI mode asynchronous methods are not supported. This means - that it is not possible to use `.AsyncHTTPClient`, or the - `tornado.auth` or `tornado.websocket` modules. - - .. 
versionadded:: 4.0 - """ - def __init__(self, application): - if isinstance(application, WSGIApplication): - self.application = lambda request: web.Application.__call__( - application, request) - else: - self.application = application - - def __call__(self, environ, start_response): - method = environ["REQUEST_METHOD"] - uri = urllib_parse.quote(from_wsgi_str(environ.get("SCRIPT_NAME", ""))) - uri += urllib_parse.quote(from_wsgi_str(environ.get("PATH_INFO", ""))) - if environ.get("QUERY_STRING"): - uri += "?" + environ["QUERY_STRING"] - headers = httputil.HTTPHeaders() - if environ.get("CONTENT_TYPE"): - headers["Content-Type"] = environ["CONTENT_TYPE"] - if environ.get("CONTENT_LENGTH"): - headers["Content-Length"] = environ["CONTENT_LENGTH"] - for key in environ: - if key.startswith("HTTP_"): - headers[key[5:].replace("_", "-")] = environ[key] - if headers.get("Content-Length"): - body = environ["wsgi.input"].read( - int(headers["Content-Length"])) - else: - body = b"" - protocol = environ["wsgi.url_scheme"] - remote_ip = environ.get("REMOTE_ADDR", "") - if environ.get("HTTP_HOST"): - host = environ["HTTP_HOST"] - else: - host = environ["SERVER_NAME"] - connection = _WSGIConnection(method, start_response, - _WSGIRequestContext(remote_ip, protocol)) - request = httputil.HTTPServerRequest( - method, uri, "HTTP/1.1", headers=headers, body=body, - host=host, connection=connection) - request._parse_body() - self.application(request) - if connection._error: - raise connection._error - if not connection._finished: - raise Exception("request did not finish synchronously") - return connection._write_buffer +# This function is like those in the tornado.escape module, but defined +# here to minimize the temptation to use it in non-wsgi contexts. +def to_wsgi_str(s: bytes) -> str: + assert isinstance(s, bytes) + return s.decode("latin1") class WSGIContainer(object): @@ -247,7 +73,7 @@ def simple_app(environ, start_response): status = "200 OK" response_headers = [("Content-type", "text/plain")] start_response(status, response_headers) - return ["Hello world!\n"] + return [b"Hello world!\n"] container = tornado.wsgi.WSGIContainer(simple_app) http_server = tornado.httpserver.HTTPServer(container) @@ -261,31 +87,44 @@ def simple_app(environ, start_response): Tornado and WSGI apps in the same server. See https://github.com/bdarnell/django-tornado-demo for a complete example. 
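Behind the docstring example, `WSGIContainer.__call__` (below) drives the standard WSGI calling convention: invoke the app with an environ dict and a `start_response` callable, then join whatever iterable comes back. A stripped-down model of that handshake with no Tornado imports (the demo app is invented):

```py
def demo_app(environ, start_response):
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"Hello ", b"world!\n"]


data = {}      # filled in by start_response, as in WSGIContainer.__call__
response = []  # chunks, including any passed to the legacy write() callable


def start_response(status, headers, exc_info=None):
    data["status"] = status
    data["headers"] = headers
    return response.append


body_iter = demo_app({"REQUEST_METHOD": "GET"}, start_response)
try:
    response.extend(body_iter)
finally:
    # PEP 3333 requires close() to be called if the app provides one.
    if hasattr(body_iter, "close"):
        body_iter.close()

assert data["status"] == "200 OK"
assert b"".join(response) == b"Hello world!\n"
```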
""" - def __init__(self, wsgi_application): - self.wsgi_application = wsgi_application - def __call__(self, request): - data = {} - response = [] + def __init__(self, wsgi_application: "WSGIAppType") -> None: + self.wsgi_application = wsgi_application - def start_response(status, response_headers, exc_info=None): + def __call__(self, request: httputil.HTTPServerRequest) -> None: + data = {} # type: Dict[str, Any] + response = [] # type: List[bytes] + + def start_response( + status: str, + headers: List[Tuple[str, str]], + exc_info: Optional[ + Tuple[ + "Optional[Type[BaseException]]", + Optional[BaseException], + Optional[TracebackType], + ] + ] = None, + ) -> Callable[[bytes], Any]: data["status"] = status - data["headers"] = response_headers + data["headers"] = headers return response.append + app_response = self.wsgi_application( - WSGIContainer.environ(request), start_response) + WSGIContainer.environ(request), start_response + ) try: response.extend(app_response) body = b"".join(response) finally: if hasattr(app_response, "close"): - app_response.close() + app_response.close() # type: ignore if not data: raise Exception("WSGI app did not call start_response") - status_code, reason = data["status"].split(' ', 1) - status_code = int(status_code) - headers = data["headers"] + status_code_str, reason = data["status"].split(" ", 1) + status_code = int(status_code_str) + headers = data["headers"] # type: List[Tuple[str, str]] header_set = set(k.lower() for (k, v) in headers) body = escape.utf8(body) if status_code != 304: @@ -300,14 +139,14 @@ def start_response(status, response_headers, exc_info=None): header_obj = httputil.HTTPHeaders() for key, value in headers: header_obj.add(key, value) + assert request.connection is not None request.connection.write_headers(start_line, header_obj, chunk=body) request.connection.finish() self._log(status_code, request) @staticmethod - def environ(request): - """Converts a `tornado.httputil.HTTPServerRequest` to a WSGI environment. - """ + def environ(request: httputil.HTTPServerRequest) -> Dict[Text, Any]: + """Converts a `tornado.httputil.HTTPServerRequest` to a WSGI environment.""" hostport = request.host.split(":") if len(hostport) == 2: host = hostport[0] @@ -318,8 +157,9 @@ def environ(request): environ = { "REQUEST_METHOD": request.method, "SCRIPT_NAME": "", - "PATH_INFO": to_wsgi_str(escape.url_unescape( - request.path, encoding=None, plus=False)), + "PATH_INFO": to_wsgi_str( + escape.url_unescape(request.path, encoding=None, plus=False) + ), "QUERY_STRING": request.query, "REMOTE_ADDR": request.remote_ip, "SERVER_NAME": host, @@ -341,7 +181,7 @@ def environ(request): environ["HTTP_" + key.replace("-", "_").upper()] = value return environ - def _log(self, status_code, request): + def _log(self, status_code: int, request: httputil.HTTPServerRequest) -> None: if status_code < 400: log_method = access_log.info elif status_code < 500: @@ -349,8 +189,9 @@ def _log(self, status_code, request): else: log_method = access_log.error request_time = 1000.0 * request.request_time() - summary = request.method + " " + request.uri + " (" + \ - request.remote_ip + ")" + assert request.method is not None + assert request.uri is not None + summary = request.method + " " + request.uri + " (" + request.remote_ip + ")" log_method("%d %s %.2fms", status_code, summary, request_time) diff --git a/tox.ini b/tox.ini index ddc6af28a5..6c267598d3 100644 --- a/tox.ini +++ b/tox.ini @@ -12,118 +12,85 @@ # libcurl. 
[tox] envlist = - # Basic configurations: Run the tests in both minimal installations - # and with all optional dependencies. - # (pypy3 doesn't have any optional deps yet) - {py27,pypy,py34,py35,py36,pypy3}, - {py27,pypy,py34,py35,py36}-full, + # Basic configurations: Run the tests for each python version. + py36-full,py37-full,py38-full,py39-full,pypy3-full - # Also run the tests with each possible replacement of a default - # component. Run each test on both python 2 and 3 where possible. - # (Only one 2.x and one 3.x unless there are known differences). - # py2 and py3 are aliases for py27-full and py35-full. + # Build and test the docs with sphinx. + docs - # Alternate HTTP clients. - {py2,py3}-curl, + # Run the linters. + lint - # Alternate IOLoops. - py2-select, - py2-full-twisted, - py2-twistedlayered, - - # Alternate Resolvers. - {py2,py3}-full-caresresolver, - - # Other configurations; see comments below. - py2-monotonic, - {py2,py3}-opt, - py3-{lang_c,lang_utf8}, - py2-locale, - {py27,py3}-unittest2, - - # Ensure the sphinx build has no errors or warnings - py3-sphinx-docs, - # Run the doctests via sphinx (which covers things not run - # in the regular test suite and vice versa) - {py2,py3}-sphinx-doctest, - - py3-lint +# Allow shell commands in tests +whitelist_externals = /bin/sh [testenv] -# Most of these are defaults, but if you specify any you can't fall back -# defaults for the others. basepython = - py27: python2.7 - py34: python3.4 - py35: python3.5 + py3: python3 py36: python3.6 - pypy: pypy + py37: python3.7 + py38: python3.8 + py39: python3.9 pypy3: pypy3 - py2: python2.7 - py3: python3.6 + # In theory, it doesn't matter which python version is used here. + # In practice, things like changes to the ast module can alter + # the outputs of the tools (especially where exactly the + # linter warning-suppression comments go), so we specify a + # python version for these builds. + docs: python3.8 + lint: python3.8 deps = - # unittest2 doesn't add anything we need on 2.7+, but we should ensure that - # its existence doesn't break anything due to conditional imports. - py27-unittest2: unittest2 - py3-unittest2: unittest2py3k - # cpython-only deps: pycurl installs but curl_httpclient doesn't work; - # twisted mostly works but is a bit flaky under pypy. - {py27,py34,py35,py36}-full: pycurl - {py2,py3}: pycurl>=7.19.3.1 - # twisted is cpython only. - {py27,py34,py35,py36}-full: twisted - {py2,py3}: twisted - {py2,py3,py27,py34,py35,py36}-full: pycares - # mock became standard in py33 - {py2,py27,pypy,pypy3}-full: mock - # singledispatch became standard in py34. - {py2,py27,pypy}-full: singledispatch - py2-monotonic: monotonic - sphinx: sphinx - sphinx: sphinx_rtd_theme - lint: flake8 + full: pycurl + full: twisted + full: pycares + docs: -r{toxinidir}/docs/requirements.txt + lint: -r{toxinidir}/maint/requirements.txt setenv = - # The extension is mandatory on cpython. - {py2,py27,py3,py34,py35,py36}: TORNADO_EXTENSION=1 - # In python 3, opening files in text mode uses a - # system-dependent encoding by default. Run the tests with "C" - # (ascii) and "utf-8" locales to ensure we don't have hidden - # dependencies on this setting. - lang_c: LANG=C - lang_utf8: LANG=en_US.utf-8 - # tox's parser chokes if all the setenv entries are conditional. 
- DUMMY=dummy - {py2,py27,py3,py34,py35,py36}-no-ext: TORNADO_EXTENSION=0 + # Treat the extension as mandatory in testing (but not on pypy) + {py3,py36,py37,py38,py39}: TORNADO_EXTENSION=1 + # CI workers are often overloaded and can cause our tests to exceed + # the default timeout of 5s. + ASYNC_TEST_TIMEOUT=25 + # Treat warnings as errors by default. We have a whitelist of + # allowed warnings in runtests.py, but we want to be strict + # about any import-time warnings before that setup code is + # reached. Note that syntax warnings are only reported in + # -opt builds because regular builds reuse pycs created + # during sdist installation (and it doesn't seem to be + # possible to set environment variables during that phase of + # tox). + {py3,py36,py37,py38,py39,pypy3}: PYTHONWARNINGS=error:::tornado + # All non-comment lines but the last must end in a backslash. # Tox filters line-by-line based on the environment name. commands = - python \ # py3*: -b turns on an extra warning when calling # str(bytes), and -bb makes it an error. - {py3,py34,py35,py36,pypy3}: -bb \ + python -bb -m tornado.test {posargs:} # Python's optimized mode disables the assert statement, so # run the tests in this mode to ensure we haven't fallen into # the trap of relying on an assertion's side effects or using # them for things that should be runtime errors. - opt: -O \ - -m tornado.test.runtests \ + full: python -O -m tornado.test + # In python 3, opening files in text mode uses a + # system-dependent encoding by default. Run the tests with "C" + # (ascii) and "utf-8" locales to ensure we don't have hidden + # dependencies on this setting. + full: sh -c 'LANG=C python -m tornado.test' + full: sh -c 'LANG=en_US.utf-8 python -m tornado.test' # Note that httpclient_test is always run with both client # implementations; this flag controls which client all the # other tests use. - curl: --httpclient=tornado.curl_httpclient.CurlAsyncHTTPClient \ - poll: --ioloop=tornado.ioloop.PollIOLoop \ - select: --ioloop=tornado.platform.select.SelectIOLoop \ - twisted: --ioloop=tornado.platform.twisted.TwistedIOLoop \ - twistedlayered: --ioloop=tornado.test.twisted_test.LayeredTwistedIOLoop --resolver=tornado.platform.twisted.TwistedResolver \ - caresresolver: --resolver=tornado.platform.caresresolver.CaresResolver \ - threadedresolver: --resolver=tornado.netutil.ThreadedResolver \ - monotonic: --ioloop=tornado.ioloop.PollIOLoop --ioloop_time_monotonic \ - # Test with a non-english locale to uncover str/bytes mixing issues. - locale: --locale=zh_TW \ - {posargs:} + full: python -m tornado.test --httpclient=tornado.curl_httpclient.CurlAsyncHTTPClient + full: python -m tornado.test --resolver=tornado.platform.caresresolver.CaresResolver + # Run the tests once from the source directory to detect issues + # involving relative __file__ paths; see + # https://github.com/tornadoweb/tornado/issues/1780 + full: sh -c '(cd {toxinidir} && unset TORNADO_EXTENSION && python -m tornado.test)' + # python will import relative to the current working directory by default, # so cd into the tox working directory to avoid picking up the working @@ -131,31 +98,28 @@ commands = changedir = {toxworkdir} # tox 1.6 passes --pre to pip by default, which currently has problems -# installing pycurl and monotime (https://github.com/pypa/pip/issues/1405). +# installing pycurl (https://github.com/pypa/pip/issues/1405). # Remove it (it's not a part of {opts}) to only install real releases. 
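The comment above about optimized mode is easy to demonstrate. A tiny self-contained sketch of the trap the `python -O -m tornado.test` run exists to catch:

```py
import subprocess
import sys

# Under -O the assert statement (and any side effect buried inside it)
# is compiled away entirely, so checks that live in asserts pass
# vacuously. Running the suite once with -O surfaces that.
snippet = "assert print('side effect') or True"
subprocess.run([sys.executable, "-c", snippet], check=True)        # prints
subprocess.run([sys.executable, "-O", "-c", snippet], check=True)  # silent
```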
install_command = pip install {opts} {packages} -[testenv:py3-sphinx-docs] +[testenv:docs] changedir = docs # For some reason the extension fails to load in this configuration, # but it's not really needed for docs anyway. setenv = TORNADO_EXTENSION=0 commands = + # Build the docs sphinx-build -q -E -n -W -b html . {envtmpdir}/html + # Run the doctests. No -W for doctests because that disallows tests + # with empty output. + sphinx-build -q -E -n -b doctest . {envtmpdir}/doctest -[testenv:py2-sphinx-doctest] -changedir = docs -setenv = TORNADO_EXTENSION=0 -# No -W for doctests because that disallows tests with empty output. -commands = - sphinx-build -q -E -n -b doctest . {envtmpdir}/doctest - -[testenv:py3-sphinx-doctest] -changedir = docs -setenv = TORNADO_EXTENSION=0 +[testenv:lint] commands = - sphinx-build -q -E -n -b doctest . {envtmpdir}/doctest - -[testenv:py3-lint] -commands = flake8 {posargs:} + flake8 {posargs:} + black --check --diff {posargs:tornado demos} + # Many syscalls are defined differently on linux and windows, + # so we have to typecheck both. + mypy --platform linux {posargs:tornado} + mypy --platform windows {posargs:tornado} changedir = {toxinidir}
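The doubled mypy invocation in the lint env deserves a gloss: mypy follows `sys.platform` branches, so each platform run typechecks code the other cannot see. A small illustration of the kind of code that motivates checking both (the stdlib modules are real; the snippet itself is only illustrative):

```py
import sys

if sys.platform == "win32":
    # Windows-only module: only analyzed when mypy targets windows.
    import msvcrt
    lock = msvcrt.locking
else:
    # Unix-only module: only analyzed when mypy targets linux.
    import fcntl
    lock = fcntl.flock
```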