-
+
diff --git a/demos/webspider/webspider.py b/demos/webspider/webspider.py
index dd8e6b385b..16c3840fa7 100755
--- a/demos/webspider/webspider.py
+++ b/demos/webspider/webspider.py
@@ -1,42 +1,29 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
import time
from datetime import timedelta
-try:
- from HTMLParser import HTMLParser
- from urlparse import urljoin, urldefrag
-except ImportError:
- from html.parser import HTMLParser
- from urllib.parse import urljoin, urldefrag
+from html.parser import HTMLParser
+from urllib.parse import urljoin, urldefrag
-from tornado import httpclient, gen, ioloop, queues
+from tornado import gen, httpclient, ioloop, queues
-base_url = 'http://www.tornadoweb.org/en/stable/'
+base_url = "http://www.tornadoweb.org/en/stable/"
concurrency = 10
-@gen.coroutine
-def get_links_from_url(url):
+async def get_links_from_url(url):
"""Download the page at `url` and parse it for links.
Returned links have had the fragment after `#` removed, and have been made
absolute so, e.g. the URL 'gen.html#tornado.gen.coroutine' becomes
'http://www.tornadoweb.org/en/stable/gen.html'.
"""
- try:
- response = yield httpclient.AsyncHTTPClient().fetch(url)
- print('fetched %s' % url)
+ response = await httpclient.AsyncHTTPClient().fetch(url)
+ print("fetched %s" % url)
- html = response.body if isinstance(response.body, str) \
- else response.body.decode(errors='ignore')
- urls = [urljoin(url, remove_fragment(new_url))
- for new_url in get_links(html)]
- except Exception as e:
- print('Exception: %s %s' % (e, url))
- raise gen.Return([])
-
- raise gen.Return(urls)
+ html = response.body.decode(errors="ignore")
+ return [urljoin(url, remove_fragment(new_url)) for new_url in get_links(html)]
def remove_fragment(url):
@@ -51,8 +38,8 @@ def __init__(self):
self.urls = []
def handle_starttag(self, tag, attrs):
- href = dict(attrs).get('href')
- if href and tag == 'a':
+ href = dict(attrs).get("href")
+ if href and tag == "a":
self.urls.append(href)
url_seeker = URLSeeker()
@@ -60,48 +47,52 @@ def handle_starttag(self, tag, attrs):
return url_seeker.urls
-@gen.coroutine
-def main():
+async def main():
q = queues.Queue()
start = time.time()
- fetching, fetched = set(), set()
-
- @gen.coroutine
- def fetch_url():
- current_url = yield q.get()
- try:
- if current_url in fetching:
- return
+ fetching, fetched, dead = set(), set(), set()
- print('fetching %s' % current_url)
- fetching.add(current_url)
- urls = yield get_links_from_url(current_url)
- fetched.add(current_url)
+ async def fetch_url(current_url):
+ if current_url in fetching:
+ return
- for new_url in urls:
- # Only follow links beneath the base URL
- if new_url.startswith(base_url):
- yield q.put(new_url)
+ print("fetching %s" % current_url)
+ fetching.add(current_url)
+ urls = await get_links_from_url(current_url)
+ fetched.add(current_url)
- finally:
- q.task_done()
+ for new_url in urls:
+ # Only follow links beneath the base URL
+ if new_url.startswith(base_url):
+ await q.put(new_url)
- @gen.coroutine
- def worker():
- while True:
- yield fetch_url()
+ async def worker():
+ async for url in q:
+ if url is None:
+ return
+ try:
+ await fetch_url(url)
+ except Exception as e:
+ print("Exception: %s %s" % (e, url))
+ dead.add(url)
+ finally:
+ q.task_done()
- q.put(base_url)
+ await q.put(base_url)
# Start workers, then wait for the work queue to be empty.
+ workers = gen.multi([worker() for _ in range(concurrency)])
+ await q.join(timeout=timedelta(seconds=300))
+ assert fetching == (fetched | dead)
+ print("Done in %d seconds, fetched %s URLs." % (time.time() - start, len(fetched)))
+ print("Unable to fetch %s URLS." % len(dead))
+
+ # Signal all the workers to exit.
for _ in range(concurrency):
- worker()
- yield q.join(timeout=timedelta(seconds=300))
- assert fetching == fetched
- print('Done in %d seconds, fetched %s URLs.' % (
- time.time() - start, len(fetched)))
+ await q.put(None)
+ await workers
-if __name__ == '__main__':
+if __name__ == "__main__":
io_loop = ioloop.IOLoop.current()
io_loop.run_sync(main)
diff --git a/docs/caresresolver.rst b/docs/caresresolver.rst
index b5d6ddd101..4e0058eac0 100644
--- a/docs/caresresolver.rst
+++ b/docs/caresresolver.rst
@@ -18,3 +18,7 @@ wrapper ``pycares``).
so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is
the default for ``tornado.simple_httpclient``, but other libraries
may default to ``AF_UNSPEC``.
+
+ .. deprecated:: 6.2
+ This class is deprecated and will be removed in Tornado 7.0. Use the default
+ thread-based resolver instead.
diff --git a/docs/concurrent.rst b/docs/concurrent.rst
index 5904a60b89..f7a855a38f 100644
--- a/docs/concurrent.rst
+++ b/docs/concurrent.rst
@@ -11,9 +11,7 @@
.. class:: Future
- ``tornado.concurrent.Future`` is an alias for `asyncio.Future`
- on Python 3. On Python 2, it provides an equivalent
- implementation.
+ ``tornado.concurrent.Future`` is an alias for `asyncio.Future`.
In Tornado, the main way in which applications interact with
``Future`` objects is by ``awaiting`` or ``yielding`` them in
diff --git a/docs/conf.py b/docs/conf.py
index 39345d282f..efa1c01d03 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,14 +1,15 @@
-# Ensure we get the local copy of tornado instead of what's on the standard path
import os
+import sphinx.errors
import sys
-import time
+
+# Ensure we get the local copy of tornado instead of what's on the standard path
sys.path.insert(0, os.path.abspath(".."))
import tornado
master_doc = "index"
project = "Tornado"
-copyright = "2009-%s, The Tornado Authors" % time.strftime("%Y")
+copyright = "The Tornado Authors"
version = release = tornado.version
@@ -18,10 +19,11 @@
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.viewcode",
+ "sphinxcontrib.asyncio",
]
-primary_domain = 'py'
-default_role = 'py:obj'
+primary_domain = "py"
+default_role = "py:obj"
autodoc_member_order = "bysource"
autoclass_content = "both"
@@ -37,22 +39,18 @@
"tornado.platform.asyncio",
"tornado.platform.caresresolver",
"tornado.platform.twisted",
+ "tornado.simple_httpclient",
]
# I wish this could go in a per-module file...
coverage_ignore_classes = [
# tornado.gen
"Runner",
-
- # tornado.ioloop
- "PollIOLoop",
-
# tornado.web
"ChunkedTransferEncoding",
"GZipContentEncoding",
"OutputTransform",
"TemplateModule",
"url",
-
# tornado.websocket
"WebSocketProtocol",
"WebSocketProtocol13",
@@ -63,32 +61,83 @@
# various modules
"doctests",
"main",
-
# tornado.escape
# parse_qs_bytes should probably be documented but it's complicated by
# having different implementations between py2 and py3.
"parse_qs_bytes",
-
# tornado.gen
"Multi",
]
-html_favicon = 'favicon.ico'
+html_favicon = "favicon.ico"
latex_documents = [
- ('index', 'tornado.tex', 'Tornado Documentation', 'The Tornado Authors', 'manual', False),
+ (
+ "index",
+ "tornado.tex",
+ "Tornado Documentation",
+ "The Tornado Authors",
+ "manual",
+ False,
+ )
]
-intersphinx_mapping = {
- 'python': ('https://docs.python.org/3.6/', None),
-}
+intersphinx_mapping = {"python": ("https://docs.python.org/3/", None)}
-on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
+on_rtd = os.environ.get("READTHEDOCS", None) == "True"
# On RTD we can't import sphinx_rtd_theme, but it will be applied by
# default anyway. This block will use the same theme when building locally
# as on RTD.
if not on_rtd:
import sphinx_rtd_theme
- html_theme = 'sphinx_rtd_theme'
+
+ html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
+
+# Suppress warnings about "class reference target not found" for these types.
+# In most cases these types come from type annotations and are for mypy's use.
+missing_references = {
+ # Generic type variables; nothing to link to.
+ "_IOStreamType",
+ "_S",
+ "_T",
+ # Standard library types which are defined in one module and documented
+ # in another. We could probably remap them to their proper location if
+ # there's not an upstream fix in python and/or sphinx.
+ "_asyncio.Future",
+ "_io.BytesIO",
+ "asyncio.AbstractEventLoop.run_forever",
+ "asyncio.events.AbstractEventLoop",
+ "concurrent.futures._base.Executor",
+ "concurrent.futures._base.Future",
+ "futures.Future",
+ "socket.socket",
+ "TextIO",
+ # Other stuff. I'm not sure why some of these are showing up, but
+ # I'm just listing everything here to avoid blocking the upgrade of sphinx.
+ "Future",
+ "httputil.HTTPServerConnectionDelegate",
+ "httputil.HTTPServerRequest",
+ "OutputTransform",
+ "Pattern",
+ "RAISE",
+ "Rule",
+ "tornado.ioloop._Selectable",
+ "tornado.locks._ReleasingContextManager",
+ "tornado.options._Mockable",
+ "tornado.web._ArgDefaultMarker",
+ "tornado.web._HandlerDelegate",
+ "traceback",
+ "WSGIAppType",
+ "Yieldable",
+}
+
+
+def missing_reference_handler(app, env, node, contnode):
+ if node["reftarget"] in missing_references:
+ raise sphinx.errors.NoUri
+
+
+def setup(app):
+ app.connect("missing-reference", missing_reference_handler)
diff --git a/docs/escape.rst b/docs/escape.rst
index 54f1ca9d2d..2a03eddb38 100644
--- a/docs/escape.rst
+++ b/docs/escape.rst
@@ -17,19 +17,15 @@
Byte/unicode conversions
------------------------
- These functions are used extensively within Tornado itself,
- but should not be directly needed by most applications. Note that
- much of the complexity of these functions comes from the fact that
- Tornado supports both Python 2 and Python 3.
.. autofunction:: utf8
.. autofunction:: to_unicode
.. function:: native_str
+ .. function:: to_basestring
- Converts a byte or unicode string into type `str`. Equivalent to
- `utf8` on Python 2 and `to_unicode` on Python 3.
-
- .. autofunction:: to_basestring
+ Converts a byte or unicode string into type `str`. These functions
+ were used to help transition from Python 2 to Python 3 but are now
+ deprecated aliases for `to_unicode`.
.. autofunction:: recursive_unicode
diff --git a/docs/faq.rst b/docs/faq.rst
index d45173c86b..1628073a32 100644
--- a/docs/faq.rst
+++ b/docs/faq.rst
@@ -28,14 +28,13 @@ No matter what the real code is doing, to achieve concurrency blocking
code must be replaced with non-blocking equivalents. This means one of three things:
1. *Find a coroutine-friendly equivalent.* For `time.sleep`, use
- `tornado.gen.sleep` instead::
+ `tornado.gen.sleep` (or `asyncio.sleep`) instead::
class CoroutineSleepHandler(RequestHandler):
- @gen.coroutine
- def get(self):
+ async def get(self):
for i in range(5):
print(i)
- yield gen.sleep(1)
+ await gen.sleep(1)
When this option is available, it is usually the best approach.
See the `Tornado wiki <https://github.com/tornadoweb/tornado/wiki/Links>`_
@@ -44,16 +43,16 @@ code must be replaced with non-blocking equivalents. This means one of three thi
2. *Find a callback-based equivalent.* Similar to the first option,
callback-based libraries are available for many tasks, although they
are slightly more complicated to use than a library designed for
- coroutines. These are typically used with `tornado.gen.Task` as an
- adapter::
+ coroutines. Adapt the callback-based function into a future::
class CoroutineTimeoutHandler(RequestHandler):
- @gen.coroutine
- def get(self):
+ async def get(self):
io_loop = IOLoop.current()
for i in range(5):
print(i)
- yield gen.Task(io_loop.add_timeout, io_loop.time() + 1)
+ f = tornado.concurrent.Future()
+ do_something_with_callback(f.set_result)
+ result = await f
Again, the
`Tornado wiki <https://github.com/tornadoweb/tornado/wiki/Links>`_
@@ -65,21 +64,18 @@ code must be replaced with non-blocking equivalents. This means one of three thi
that can be used for any blocking function whether an asynchronous
counterpart exists or not::
- executor = concurrent.futures.ThreadPoolExecutor(8)
-
class ThreadPoolHandler(RequestHandler):
- @gen.coroutine
- def get(self):
+ async def get(self):
for i in range(5):
print(i)
- yield executor.submit(time.sleep, 1)
+ await IOLoop.current().run_in_executor(None, time.sleep, 1)
See the :doc:`Asynchronous I/O <guide/async>` chapter of the Tornado
user's guide for more on blocking and asynchronous functions.
-My code is asynchronous, but it's not running in parallel in two browser tabs.
-------------------------------------------------------------------------------
+My code is asynchronous. Why is it not running in parallel in two browser tabs?
+-------------------------------------------------------------------------------
Even when a handler is asynchronous and non-blocking, it can be surprisingly
tricky to verify this. Browsers will recognize that you are trying to
diff --git a/docs/gen.rst b/docs/gen.rst
index 0f0bfdeefb..4cb5a4f434 100644
--- a/docs/gen.rst
+++ b/docs/gen.rst
@@ -13,26 +13,21 @@
.. autofunction:: coroutine
- .. autofunction:: engine
+ .. autoexception:: Return
Utility functions
-----------------
- .. autoexception:: Return
-
- .. autofunction:: with_timeout
+ .. autofunction:: with_timeout(timeout: Union[float, datetime.timedelta], future: Yieldable, quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ())
.. autofunction:: sleep
- .. autodata:: moment
- :annotation:
-
.. autoclass:: WaitIterator
:members:
- .. autofunction:: multi
+ .. autofunction:: multi(Union[List[Yieldable], Dict[Any, Yieldable]], quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ())
- .. autofunction:: multi_future
+ .. autofunction:: multi_future(Union[List[Yieldable], Dict[Any, Yieldable]], quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ())
.. autofunction:: convert_yielded
@@ -40,37 +35,5 @@
.. autofunction:: is_coroutine_function
- Legacy interface
- ----------------
-
- Before support for `Futures <.Future>` was introduced in Tornado 3.0,
- coroutines used subclasses of `YieldPoint` in their ``yield`` expressions.
- These classes are still supported but should generally not be used
- except for compatibility with older interfaces. None of these classes
- are compatible with native (``await``-based) coroutines.
-
- .. autoclass:: YieldPoint
- :members:
-
- .. autoclass:: Callback
-
- .. autoclass:: Wait
-
- .. autoclass:: WaitAll
-
- .. autoclass:: MultiYieldPoint
-
- .. autofunction:: Task
-
- .. class:: Arguments
-
- The result of a `Task` or `Wait` whose callback had more than one
- argument (or keyword arguments).
-
- The `Arguments` object is a `collections.namedtuple` and can be
- used either as a tuple ``(args, kwargs)`` or an object with attributes
- ``args`` and ``kwargs``.
-
- .. deprecated:: 5.1
-
- This class will be removed in 6.0.
+ .. autodata:: moment
+ :annotation:
diff --git a/docs/guide/async.rst b/docs/guide/async.rst
index 60f8a23b34..1b5526ffc5 100644
--- a/docs/guide/async.rst
+++ b/docs/guide/async.rst
@@ -53,6 +53,12 @@ transparent to its callers (systems like `gevent
comparable to asynchronous systems, but they do not actually make
things asynchronous).
+Asynchronous operations in Tornado generally return placeholder
+objects (``Futures``), with the exception of some low-level components
+like the `.IOLoop` that use callbacks. ``Futures`` are usually
+transformed into their result with the ``await`` or ``yield``
+keywords.
+
Examples
~~~~~~~~
@@ -70,65 +76,59 @@ Here is a sample synchronous function:
.. testoutput::
:hide:
-And here is the same function rewritten to be asynchronous with a
-callback argument:
+And here is the same function rewritten asynchronously as a native coroutine:
.. testcode::
- from tornado.httpclient import AsyncHTTPClient
+ from tornado.httpclient import AsyncHTTPClient
- def asynchronous_fetch(url, callback):
- http_client = AsyncHTTPClient()
- def handle_response(response):
- callback(response.body)
- http_client.fetch(url, callback=handle_response)
+ async def asynchronous_fetch(url):
+ http_client = AsyncHTTPClient()
+ response = await http_client.fetch(url)
+ return response.body
.. testoutput::
:hide:
-And again with a `.Future` instead of a callback:
+Or for compatibility with older versions of Python, using the `tornado.gen` module:
-.. testcode::
+.. testcode::
- from tornado.concurrent import Future
+ from tornado.httpclient import AsyncHTTPClient
+ from tornado import gen
- def async_fetch_future(url):
+ @gen.coroutine
+ def async_fetch_gen(url):
http_client = AsyncHTTPClient()
- my_future = Future()
- fetch_future = http_client.fetch(url)
- fetch_future.add_done_callback(
- lambda f: my_future.set_result(f.result()))
- return my_future
-
-.. testoutput::
- :hide:
+ response = yield http_client.fetch(url)
+ raise gen.Return(response.body)
-The raw `.Future` version is more complex, but ``Futures`` are
-nonetheless recommended practice in Tornado because they have two
-major advantages. Error handling is more consistent since the
-``Future.result`` method can simply raise an exception (as opposed to
-the ad-hoc error handling common in callback-oriented interfaces), and
-``Futures`` lend themselves well to use with coroutines. Coroutines
-will be discussed in depth in the next section of this guide. Here is
-the coroutine version of our sample function, which is very similar to
-the original synchronous version:
+Coroutines are a little magical, but what they do internally is something like this:
.. testcode::
- from tornado import gen
+ from tornado.concurrent import Future
- @gen.coroutine
- def fetch_coroutine(url):
+ def async_fetch_manual(url):
http_client = AsyncHTTPClient()
- response = yield http_client.fetch(url)
- raise gen.Return(response.body)
+ my_future = Future()
+ fetch_future = http_client.fetch(url)
+ def on_fetch(f):
+ my_future.set_result(f.result().body)
+ fetch_future.add_done_callback(on_fetch)
+ return my_future
.. testoutput::
:hide:
-The statement ``raise gen.Return(response.body)`` is an artifact of
-Python 2, in which generators aren't allowed to return
-values. To overcome this, Tornado coroutines raise a special kind of
-exception called a `.Return`. The coroutine catches this exception and
-treats it like a returned value. In Python 3.3 and later, a ``return
-response.body`` achieves the same result.
+Notice that the coroutine returns its `.Future` before the fetch is
+done. This is what makes coroutines *asynchronous*.
+
+Anything you can do with coroutines you can also do by passing
+callback objects around, but coroutines provide an important
+simplification by letting you organize your code in the same way you
+would if it were synchronous. This is especially important for error
+handling, since ``try``/``except`` blocks work as you would expect in
+coroutines while this is difficult to achieve with callbacks.
+Coroutines will be discussed in depth in the next section of this
+guide.
diff --git a/docs/guide/coroutines.rst b/docs/guide/coroutines.rst
index 8089c4f886..795631bf74 100644
--- a/docs/guide/coroutines.rst
+++ b/docs/guide/coroutines.rst
@@ -6,9 +6,9 @@ Coroutines
from tornado import gen
**Coroutines** are the recommended way to write asynchronous code in
-Tornado. Coroutines use the Python ``yield`` keyword to suspend and
-resume execution instead of a chain of callbacks (cooperative
-lightweight threads as seen in frameworks like `gevent
+Tornado. Coroutines use the Python ``await`` or ``yield`` keyword to
+suspend and resume execution instead of a chain of callbacks
+(cooperative lightweight threads as seen in frameworks like `gevent
<http://www.gevent.org>`_ are sometimes called coroutines as well, but
in Tornado all coroutines use explicit context switches and are called
as asynchronous functions).
@@ -21,53 +21,80 @@ happen.
Example::
- from tornado import gen
-
- @gen.coroutine
- def fetch_coroutine(url):
+ async def fetch_coroutine(url):
http_client = AsyncHTTPClient()
- response = yield http_client.fetch(url)
- # In Python versions prior to 3.3, returning a value from
- # a generator is not allowed and you must use
- # raise gen.Return(response.body)
- # instead.
+ response = await http_client.fetch(url)
return response.body
.. _native_coroutines:
-Python 3.5: ``async`` and ``await``
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Native vs decorated coroutines
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Python 3.5 introduces the ``async`` and ``await`` keywords (functions
-using these keywords are also called "native coroutines"). Starting in
-Tornado 4.3, you can use them in place of most ``yield``-based
-coroutines (see the following paragraphs for limitations). Simply use
-``async def foo()`` in place of a function definition with the
-``@gen.coroutine`` decorator, and ``await`` in place of yield. The
-rest of this document still uses the ``yield`` style for compatibility
-with older versions of Python, but ``async`` and ``await`` will run
-faster when they are available::
+Python 3.5 introduced the ``async`` and ``await`` keywords (functions
+using these keywords are also called "native coroutines"). For
+compatibility with older versions of Python, you can use "decorated"
+or "yield-based" coroutines using the `tornado.gen.coroutine`
+decorator.
- async def fetch_coroutine(url):
- http_client = AsyncHTTPClient()
- response = await http_client.fetch(url)
- return response.body
+Native coroutines are the recommended form whenever possible. Only use
+decorated coroutines when compatibility with older versions of Python
+is required. Examples in the Tornado documentation will generally use
+the native form.
-The ``await`` keyword is less versatile than the ``yield`` keyword.
-For example, in a ``yield``-based coroutine you can yield a list of
-``Futures``, while in a native coroutine you must wrap the list in
-`tornado.gen.multi`. This also eliminates the integration with
-`concurrent.futures`. You can use `tornado.gen.convert_yielded`
-to convert anything that would work with ``yield`` into a form that
-will work with ``await``::
+Translation between the two forms is generally straightforward::
- async def f():
- executor = concurrent.futures.ThreadPoolExecutor()
- await tornado.gen.convert_yielded(executor.submit(g))
+ # Decorated: # Native:
+
+ # Normal function declaration
+ # with decorator # "async def" keywords
+ @gen.coroutine
+ def a(): async def a():
+ # "yield" all async funcs # "await" all async funcs
+ b = yield c() b = await c()
+ # "return" and "yield"
+ # cannot be mixed in
+ # Python 2, so raise a
+ # special exception. # Return normally
+ raise gen.Return(b) return b
+
+Other differences between the two forms of coroutine are outlined below.
+
+- Native coroutines:
+
+ - are generally faster.
+ - can use ``async for`` and ``async with``
+ statements which make some patterns much simpler.
+ - do not run at all unless you ``await`` or
+ ``yield`` them. Decorated coroutines can start running "in the
+ background" as soon as they are called. Note that for both kinds of
+ coroutines it is important to use ``await`` or ``yield`` so that
+ any exceptions have somewhere to go.
+
+- Decorated coroutines:
+
+ - have additional integration with the
+ `concurrent.futures` package, allowing the result of
+ ``executor.submit`` to be yielded directly. For native coroutines,
+ use `.IOLoop.run_in_executor` instead.
+ - support some shorthand for waiting on multiple
+ objects by yielding a list or dict. Use `tornado.gen.multi` to do
+ this in native coroutines.
+ - can support integration with other packages
+ including Twisted via a registry of conversion functions.
+ To access this functionality in native coroutines, use
+ `tornado.gen.convert_yielded`.
+ - always return a `.Future` object. Native
+ coroutines return an *awaitable* object that is not a `.Future`. In
+ Tornado the two are mostly interchangeable.
How it works
~~~~~~~~~~~~
+This section explains the operation of decorated coroutines. Native
+coroutines are conceptually similar, but a little more complicated
+because of the extra integration with the Python runtime.
+
A function containing ``yield`` is a **generator**. All generators
are asynchronous; when called they return a generator object instead
of running to completion. The ``@gen.coroutine`` decorator
@@ -97,12 +124,11 @@ How to call a coroutine
~~~~~~~~~~~~~~~~~~~~~~~
Coroutines do not raise exceptions in the normal way: any exception
-they raise will be trapped in the `.Future` until it is yielded. This
-means it is important to call coroutines in the right way, or you may
-have errors that go unnoticed::
+they raise will be trapped in the awaitable object until it is
+yielded. This means it is important to call coroutines in the right
+way, or you may have errors that go unnoticed::
- @gen.coroutine
- def divide(x, y):
+ async def divide(x, y):
return x / y
def bad_call():
@@ -111,17 +137,16 @@ have errors that go unnoticed::
divide(1, 0)
In nearly all cases, any function that calls a coroutine must be a
-coroutine itself, and use the ``yield`` keyword in the call. When you
-are overriding a method defined in a superclass, consult the
-documentation to see if coroutines are allowed (the documentation
-should say that the method "may be a coroutine" or "may return a
-`.Future`")::
-
- @gen.coroutine
- def good_call():
- # yield will unwrap the Future returned by divide() and raise
+coroutine itself, and use the ``await`` or ``yield`` keyword in the
+call. When you are overriding a method defined in a superclass,
+consult the documentation to see if coroutines are allowed (the
+documentation should say that the method "may be a coroutine" or "may
+return a `.Future`")::
+
+ async def good_call():
+ # await will unwrap the object returned by divide() and raise
# the exception.
- yield divide(1, 0)
+ await divide(1, 0)
Sometimes you may want to "fire and forget" a coroutine without waiting
for its result. In this case it is recommended to use `.IOLoop.spawn_callback`,
@@ -156,49 +181,68 @@ The simplest way to call a blocking function from a coroutine is to
use `.IOLoop.run_in_executor`, which returns
``Futures`` that are compatible with coroutines::
- @gen.coroutine
- def call_blocking():
- yield IOLoop.current().run_in_executor(blocking_func, args)
+ async def call_blocking():
+ await IOLoop.current().run_in_executor(None, blocking_func, args)
Parallelism
^^^^^^^^^^^
-The `.coroutine` decorator recognizes lists and dicts whose values are
+The `.multi` function accepts lists and dicts whose values are
``Futures``, and waits for all of those ``Futures`` in parallel:
.. testcode::
- @gen.coroutine
- def parallel_fetch(url1, url2):
- resp1, resp2 = yield [http_client.fetch(url1),
- http_client.fetch(url2)]
+ from tornado.gen import multi
- @gen.coroutine
- def parallel_fetch_many(urls):
- responses = yield [http_client.fetch(url) for url in urls]
+ async def parallel_fetch(url1, url2):
+ resp1, resp2 = await multi([http_client.fetch(url1),
+ http_client.fetch(url2)])
+
+ async def parallel_fetch_many(urls):
+ responses = await multi([http_client.fetch(url) for url in urls])
# responses is a list of HTTPResponses in the same order
- @gen.coroutine
- def parallel_fetch_dict(urls):
- responses = yield {url: http_client.fetch(url)
- for url in urls}
+ async def parallel_fetch_dict(urls):
+ responses = await multi({url: http_client.fetch(url)
+ for url in urls})
# responses is a dict {url: HTTPResponse}
.. testoutput::
:hide:
-Lists and dicts must be wrapped in `tornado.gen.multi` for use with
-``await``::
+In decorated coroutines, it is possible to ``yield`` the list or dict directly::
- async def parallel_fetch(url1, url2):
- resp1, resp2 = await gen.multi([http_client.fetch(url1),
- http_client.fetch(url2)])
+ @gen.coroutine
+ def parallel_fetch_decorated(url1, url2):
+ resp1, resp2 = yield [http_client.fetch(url1),
+ http_client.fetch(url2)]
Interleaving
^^^^^^^^^^^^
Sometimes it is useful to save a `.Future` instead of yielding it
-immediately, so you can start another operation before waiting:
+immediately, so you can start another operation before waiting.
+
+.. testcode::
+
+ from tornado.gen import convert_yielded
+
+ async def get(self):
+ # convert_yielded() starts the native coroutine in the background.
+ # This is equivalent to asyncio.ensure_future() (both work in Tornado).
+ fetch_future = convert_yielded(self.fetch_next_chunk())
+ while True:
+ chunk = await fetch_future
+ if chunk is None: break
+ self.write(chunk)
+ fetch_future = convert_yielded(self.fetch_next_chunk())
+ await self.flush()
+
+.. testoutput::
+ :hide:
+
+This is a little easier to do with decorated coroutines, because they
+start immediately when called:
.. testcode::
@@ -215,12 +259,6 @@ immediately, so you can start another operation before waiting:
.. testoutput::
:hide:
-This pattern is most usable with ``@gen.coroutine``. If
-``fetch_next_chunk()`` uses ``async def``, then it must be called as
-``fetch_future =
-tornado.gen.convert_yielded(self.fetch_next_chunk())`` to start the
-background processing.
-
Looping
^^^^^^^
@@ -247,11 +285,10 @@ Running in the background
coroutine can contain a ``while True:`` loop and use
`tornado.gen.sleep`::
- @gen.coroutine
- def minute_loop():
+ async def minute_loop():
while True:
- yield do_something()
- yield gen.sleep(60)
+ await do_something()
+ await gen.sleep(60)
# Coroutines that loop forever are generally started with
# spawn_callback().
@@ -262,9 +299,8 @@ previous loop runs every ``60+N`` seconds, where ``N`` is the running
time of ``do_something()``. To run exactly every 60 seconds, use the
interleaving pattern from above::
- @gen.coroutine
- def minute_loop2():
+ async def minute_loop2():
while True:
nxt = gen.sleep(60) # Start the clock.
- yield do_something() # Run while the clock is ticking.
- yield nxt # Wait for the timer to run out.
+ await do_something() # Run while the clock is ticking.
+ await nxt # Wait for the timer to run out.
diff --git a/docs/guide/intro.rst b/docs/guide/intro.rst
index d17e74420f..8d87ba62b2 100644
--- a/docs/guide/intro.rst
+++ b/docs/guide/intro.rst
@@ -20,12 +20,13 @@ Tornado can be roughly divided into four major components:
components and can also be used to implement other protocols.
* A coroutine library (`tornado.gen`) which allows asynchronous
code to be written in a more straightforward way than chaining
- callbacks.
+ callbacks. This is similar to the native coroutine feature introduced
+ in Python 3.5 (``async def``). Native coroutines are recommended
+ in place of the `tornado.gen` module when available.
The Tornado web framework and HTTP server together offer a full-stack
alternative to `WSGI `_.
-While it is possible to use the Tornado web framework in a WSGI
-container (`.WSGIAdapter`), or use the Tornado HTTP server as a
-container for other WSGI frameworks (`.WSGIContainer`), each of these
-combinations has limitations and to take full advantage of Tornado you
-will need to use the Tornado's web framework and HTTP server together.
+While it is possible to use the Tornado HTTP server as a container for
+other WSGI frameworks (`.WSGIContainer`), this combination has
+limitations and to take full advantage of Tornado you will need to use
+Tornado's web framework and HTTP server together.
diff --git a/docs/guide/running.rst b/docs/guide/running.rst
index e7c5b3494a..8cf34f0502 100644
--- a/docs/guide/running.rst
+++ b/docs/guide/running.rst
@@ -22,8 +22,9 @@ configuring a WSGI container to find your application, you write a
Configure your operating system or process manager to run this program to
start the server. Please note that it may be necessary to increase the number
of open files per process (to avoid "Too many open files"-Error).
-To raise this limit (setting it to 50000 for example) you can use the ulimit command,
-modify /etc/security/limits.conf or setting ``minfds`` in your supervisord config.
+To raise this limit (setting it to 50000 for example) you can use the
+``ulimit`` command, modify ``/etc/security/limits.conf`` or set
+``minfds`` in your `supervisord <http://supervisord.org/>`_ config.
Processes and ports
~~~~~~~~~~~~~~~~~~~
@@ -33,8 +34,9 @@ multiple Python processes to take full advantage of multi-CPU machines.
Typically it is best to run one process per CPU.
Tornado includes a built-in multi-process mode to start several
-processes at once. This requires a slight alteration to the standard
-main function:
+processes at once (note that multi-process mode does not work on
+Windows). This requires a slight alteration to the standard main
+function:
.. testcode::
@@ -50,8 +52,8 @@ main function:
This is the easiest way to start multiple processes and have them all
share the same port, although it has some limitations. First, each
-child process will have its own IOLoop, so it is important that
-nothing touch the global IOLoop instance (even indirectly) before the
+child process will have its own ``IOLoop``, so it is important that
+nothing touches the global ``IOLoop`` instance (even indirectly) before the
fork. Second, it is difficult to do zero-downtime updates in this model.
Finally, since all the processes share the same port it is more difficult
to monitor them individually.
@@ -67,10 +69,10 @@ to present a single address to outside visitors.
Running behind a load balancer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-When running behind a load balancer like nginx, it is recommended to
-pass ``xheaders=True`` to the `.HTTPServer` constructor. This will tell
-Tornado to use headers like ``X-Real-IP`` to get the user's IP address
-instead of attributing all traffic to the balancer's IP address.
+When running behind a load balancer like `nginx `_,
+it is recommended to pass ``xheaders=True`` to the `.HTTPServer` constructor.
+This will tell Tornado to use headers like ``X-Real-IP`` to get the user's
+IP address instead of attributing all traffic to the balancer's IP address.
This is a barebones nginx config file that is structurally similar to
the one we use at FriendFeed. It assumes nginx and the Tornado servers
@@ -169,7 +171,7 @@ You can serve static files from Tornado by specifying the
], **settings)
This setting will automatically make all requests that start with
-``/static/`` serve from that static directory, e.g.,
+``/static/`` serve from that static directory, e.g.
``http://localhost:8888/static/foo.png`` will serve the file
``foo.png`` from the specified static directory. We also automatically
serve ``/robots.txt`` and ``/favicon.ico`` from the static directory
@@ -248,7 +250,7 @@ individual flag takes precedence):
server down in a way that debug mode cannot currently recover from.
* ``compiled_template_cache=False``: Templates will not be cached.
* ``static_hash_cache=False``: Static file hashes (used by the
- ``static_url`` function) will not be cached
+ ``static_url`` function) will not be cached.
* ``serve_traceback=True``: When an exception in a `.RequestHandler`
is not caught, an error page including a stack trace will be
generated.
@@ -273,41 +275,3 @@ On some platforms (including Windows and Mac OSX prior to 10.6), the
process cannot be updated "in-place", so when a code change is
detected the old server exits and a new one starts. This has been
known to confuse some IDEs.
-
-
-WSGI and Google App Engine
-~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Tornado is normally intended to be run on its own, without a WSGI
-container. However, in some environments (such as Google App Engine),
-only WSGI is allowed and applications cannot run their own servers.
-In this case Tornado supports a limited mode of operation that does
-not support asynchronous operation but allows a subset of Tornado's
-functionality in a WSGI-only environment. The features that are
-not allowed in WSGI mode include coroutines, the ``@asynchronous``
-decorator, `.AsyncHTTPClient`, the ``auth`` module, and WebSockets.
-
-You can convert a Tornado `.Application` to a WSGI application
-with `tornado.wsgi.WSGIAdapter`. In this example, configure
-your WSGI container to find the ``application`` object:
-
-.. testcode::
-
- import tornado.web
- import tornado.wsgi
-
- class MainHandler(tornado.web.RequestHandler):
- def get(self):
- self.write("Hello, world")
-
- tornado_app = tornado.web.Application([
- (r"/", MainHandler),
- ])
- application = tornado.wsgi.WSGIAdapter(tornado_app)
-
-.. testoutput::
- :hide:
-
-See the `appengine example application
-`_ for a
-full-featured AppEngine app built on Tornado.
diff --git a/docs/guide/security.rst b/docs/guide/security.rst
index f71f323aed..b65cd3f370 100644
--- a/docs/guide/security.rst
+++ b/docs/guide/security.rst
@@ -169,7 +169,7 @@ not be appropriate for non-browser-based login schemes.
Check out the `Tornado Blog example application
`_ for a
complete example that uses authentication (and stores user data in a
-MySQL database).
+PostgreSQL database).
Third party authentication
~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -189,15 +189,14 @@ the Google credentials in a cookie for later access:
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleOAuth2Mixin):
- @tornado.gen.coroutine
- def get(self):
+ async def get(self):
if self.get_argument('code', False):
- user = yield self.get_authenticated_user(
+ user = await self.get_authenticated_user(
redirect_uri='http://your.site.com/auth/google',
code=self.get_argument('code'))
# Save the user with e.g. set_secure_cookie
else:
- yield self.authorize_redirect(
+ await self.authorize_redirect(
redirect_uri='http://your.site.com/auth/google',
client_id=self.settings['google_oauth']['key'],
scope=['profile', 'email'],
@@ -280,7 +279,7 @@ all requests::
For ``PUT`` and ``DELETE`` requests (as well as ``POST`` requests that
do not use form-encoded arguments), the XSRF token may also be passed
via an HTTP header named ``X-XSRFToken``. The XSRF cookie is normally
-set when ``xsrf_form_html`` is used, but in a pure-Javascript application
+set when ``xsrf_form_html`` is used, but in a pure-JavaScript application
that does not use any regular forms you may need to access
``self.xsrf_token`` manually (just reading the property is enough to
set the cookie as a side effect).
diff --git a/docs/guide/structure.rst b/docs/guide/structure.rst
index 0f20d33864..407edf414b 100644
--- a/docs/guide/structure.rst
+++ b/docs/guide/structure.rst
@@ -178,7 +178,7 @@ In addition to ``get()``/``post()``/etc, certain other methods in
necessary. On every request, the following sequence of calls takes
place:
-1. A new `.RequestHandler` object is created on each request
+1. A new `.RequestHandler` object is created on each request.
2. `~.RequestHandler.initialize()` is called with the initialization
arguments from the `.Application` configuration. ``initialize``
should typically just save the arguments passed into member
@@ -193,9 +193,8 @@ place:
etc. If the URL regular expression contains capturing groups, they
are passed as arguments to this method.
5. When the request is finished, `~.RequestHandler.on_finish()` is
- called. For synchronous handlers this is immediately after
- ``get()`` (etc) return; for asynchronous handlers it is after the
- call to `~.RequestHandler.finish()`.
+ called. This is generally after ``get()`` or another HTTP method
+ returns.
All methods designed to be overridden are noted as such in the
`.RequestHandler` documentation. Some of the most commonly
@@ -207,12 +206,12 @@ overridden methods include:
disconnects; applications may choose to detect this case and halt
further processing. Note that there is no guarantee that a closed
connection can be detected promptly.
-- `~.RequestHandler.get_current_user` - see :ref:`user-authentication`
+- `~.RequestHandler.get_current_user` - see :ref:`user-authentication`.
- `~.RequestHandler.get_user_locale` - returns `.Locale` object to use
- for the current user
+ for the current user.
- `~.RequestHandler.set_default_headers` - may be used to set
additional headers on the response (such as a custom ``Server``
- header)
+ header).
Error Handling
~~~~~~~~~~~~~~
@@ -261,7 +260,7 @@ redirect users elsewhere. There is also an optional parameter
considered permanent. The default value of ``permanent`` is
``False``, which generates a ``302 Found`` HTTP response code and is
appropriate for things like redirecting users after successful
-``POST`` requests. If ``permanent`` is true, the ``301 Moved
+``POST`` requests. If ``permanent`` is ``True``, the ``301 Moved
Permanently`` HTTP response code is used, which is useful for
e.g. redirecting to a canonical URL for a page in an SEO-friendly
manner.
@@ -295,65 +294,18 @@ To send a temporary redirect with a `.RedirectHandler`, add
Asynchronous handlers
~~~~~~~~~~~~~~~~~~~~~
-Tornado handlers are synchronous by default: when the
-``get()``/``post()`` method returns, the request is considered
-finished and the response is sent. Since all other requests are
-blocked while one handler is running, any long-running handler should
-be made asynchronous so it can call its slow operations in a
-non-blocking way. This topic is covered in more detail in
-:doc:`async`; this section is about the particulars of
-asynchronous techniques in `.RequestHandler` subclasses.
-
-The simplest way to make a handler asynchronous is to use the
-`.coroutine` decorator or ``async def``. This allows you to perform
-non-blocking I/O with the ``yield`` or ``await`` keywords, and no
-response will be sent until the coroutine has returned. See
-:doc:`coroutines` for more details.
-
-In some cases, coroutines may be less convenient than a
-callback-oriented style, in which case the `.tornado.web.asynchronous`
-decorator can be used instead. When this decorator is used the response
-is not automatically sent; instead the request will be kept open until
-some callback calls `.RequestHandler.finish`. It is up to the application
-to ensure that this method is called, or else the user's browser will
-simply hang.
-
-Here is an example that makes a call to the FriendFeed API using
-Tornado's built-in `.AsyncHTTPClient`:
+Certain handler methods (including ``prepare()`` and the HTTP verb
+methods ``get()``/``post()``/etc) may be overridden as coroutines to
+make the handler asynchronous.
-.. testcode::
-
- class MainHandler(tornado.web.RequestHandler):
- @tornado.web.asynchronous
- def get(self):
- http = tornado.httpclient.AsyncHTTPClient()
- http.fetch("http://friendfeed-api.com/v2/feed/bret",
- callback=self.on_response)
-
- def on_response(self, response):
- if response.error: raise tornado.web.HTTPError(500)
- json = tornado.escape.json_decode(response.body)
- self.write("Fetched " + str(len(json["entries"])) + " entries "
- "from the FriendFeed API")
- self.finish()
-
-.. testoutput::
- :hide:
-
-When ``get()`` returns, the request has not finished. When the HTTP
-client eventually calls ``on_response()``, the request is still open,
-and the response is finally flushed to the client with the call to
-``self.finish()``.
-
-For comparison, here is the same example using a coroutine:
+For example, here is a simple handler using a coroutine:
.. testcode::
class MainHandler(tornado.web.RequestHandler):
- @tornado.gen.coroutine
- def get(self):
+ async def get(self):
http = tornado.httpclient.AsyncHTTPClient()
- response = yield http.fetch("http://friendfeed-api.com/v2/feed/bret")
+ response = await http.fetch("http://friendfeed-api.com/v2/feed/bret")
json = tornado.escape.json_decode(response.body)
self.write("Fetched " + str(len(json["entries"])) + " entries "
"from the FriendFeed API")
diff --git a/docs/guide/templates.rst b/docs/guide/templates.rst
index 8755d25e13..61ce753e6a 100644
--- a/docs/guide/templates.rst
+++ b/docs/guide/templates.rst
@@ -66,9 +66,9 @@ directory as your Python file, you could render this template with:
:hide:
Tornado templates support *control statements* and *expressions*.
-Control statements are surrounded by ``{%`` and ``%}``, e.g.,
+Control statements are surrounded by ``{%`` and ``%}``, e.g.
``{% if len(items) > 2 %}``. Expressions are surrounded by ``{{`` and
-``}}``, e.g., ``{{ items[0] }}``.
+``}}``, e.g. ``{{ items[0] }}``.
Control statements more or less map exactly to Python statements. We
support ``if``, ``for``, ``while``, and ``try``, all of which are
@@ -78,7 +78,7 @@ detail in the documentation for the `tornado.template`.
Expressions can be any Python expression, including function calls.
Template code is executed in a namespace that includes the following
-objects and functions (Note that this list applies to templates
+objects and functions. (Note that this list applies to templates
rendered using `.RequestHandler.render` and
`~.RequestHandler.render_string`. If you're using the
`tornado.template` module directly outside of a `.RequestHandler` many
@@ -132,11 +132,12 @@ instead of ``None``.
Note that while Tornado's automatic escaping is helpful in avoiding
XSS vulnerabilities, it is not sufficient in all cases. Expressions
-that appear in certain locations, such as in Javascript or CSS, may need
+that appear in certain locations, such as in JavaScript or CSS, may need
additional escaping. Additionally, either care must be taken to always
use double quotes and `.xhtml_escape` in HTML attributes that may contain
untrusted content, or a separate escaping function must be used for
-attributes (see e.g. http://wonko.com/post/html-escaping)
+attributes (see e.g.
+`this blog post `_).
Internationalization
~~~~~~~~~~~~~~~~~~~~
@@ -212,7 +213,7 @@ formats: the ``.mo`` format used by `gettext` and related tools, and a
simple ``.csv`` format. An application will generally call either
`tornado.locale.load_translations` or
`tornado.locale.load_gettext_translations` once at startup; see those
-methods for more details on the supported formats..
+methods for more details on the supported formats.
You can get the list of supported locales in your application with
`tornado.locale.get_supported_locales()`. The user's locale is chosen
@@ -234,7 +235,7 @@ packaged with their own CSS and JavaScript.
For example, if you are implementing a blog, and you want to have blog
entries appear on both the blog home page and on each blog entry page,
you can make an ``Entry`` module to render them on both pages. First,
-create a Python module for your UI modules, e.g., ``uimodules.py``::
+create a Python module for your UI modules, e.g. ``uimodules.py``::
class Entry(tornado.web.UIModule):
def render(self, entry, show_comments=False):
diff --git a/docs/httpclient.rst b/docs/httpclient.rst
index cf3bc8e715..178dc1480a 100644
--- a/docs/httpclient.rst
+++ b/docs/httpclient.rst
@@ -47,7 +47,9 @@ Implementations
~~~~~~~~~~~~~~~
.. automodule:: tornado.simple_httpclient
- :members:
+
+ .. autoclass:: SimpleAsyncHTTPClient
+ :members:
.. module:: tornado.curl_httpclient
diff --git a/docs/httpserver.rst b/docs/httpserver.rst
index 88c74376bd..74d411ddd0 100644
--- a/docs/httpserver.rst
+++ b/docs/httpserver.rst
@@ -5,5 +5,8 @@
HTTP Server
-----------
- .. autoclass:: HTTPServer
+ .. autoclass:: HTTPServer(request_callback: Union[httputil.HTTPServerConnectionDelegate, Callable[[httputil.HTTPServerRequest], None]], no_keep_alive: bool = False, xheaders: bool = False, ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None, protocol: Optional[str] = None, decompress_request: bool = False, chunk_size: Optional[int] = None, max_header_size: Optional[int] = None, idle_connection_timeout: Optional[float] = None, body_timeout: Optional[float] = None, max_body_size: Optional[int] = None, max_buffer_size: Optional[int] = None, trusted_downstream: Optional[List[str]] = None)
:members:
+
+ The public interface of this class is mostly inherited from
+ `.TCPServer` and is documented under that class.
diff --git a/docs/index.rst b/docs/index.rst
index 0892e92c22..6f59bd7085 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -9,21 +9,21 @@
.. |Tornado Web Server| image:: tornado.png
:alt: Tornado Web Server
-`Tornado `_ is a Python web framework and
+`Tornado `_ is a Python web framework and
asynchronous networking library, originally developed at `FriendFeed
-`_. By using non-blocking network I/O, Tornado
+`_. By using non-blocking network I/O, Tornado
can scale to tens of thousands of open connections, making it ideal for
-`long polling `_,
-`WebSockets `_, and other
+`long polling `_,
+`WebSockets `_, and other
applications that require a long-lived connection to each user.
Quick links
-----------
* Current version: |version| (`download from PyPI `_, :doc:`release notes `)
-* `Source (github) `_
-* Mailing lists: `discussion `_ and `announcements `_
-* `Stack Overflow `_
+* `Source (GitHub) `_
+* Mailing lists: `discussion `_ and `announcements `_
+* `Stack Overflow `_
* `Wiki `_
Hello, world
@@ -73,6 +73,14 @@ that the function passed to ``run_in_executor`` should avoid
referencing any Tornado objects. ``run_in_executor`` is the
recommended way to interact with blocking code.
+``asyncio`` Integration
+-----------------------
+
+Tornado is integrated with the standard library `asyncio` module and
+shares the same event loop (by default since Tornado 5.0). In general,
+libraries designed for use with `asyncio` can be mixed freely with
+Tornado.
+
Installation
------------
@@ -81,42 +89,36 @@ Installation
pip install tornado
-Tornado is listed in `PyPI `_ and
+Tornado is listed in `PyPI `_ and
can be installed with ``pip``. Note that the source distribution
includes demo applications that are not present when Tornado is
installed in this way, so you may wish to download a copy of the
source tarball or clone the `git repository
`_ as well.
-**Prerequisites**: Tornado runs on Python 2.7, and 3.4+.
-The updates to the `ssl` module in Python 2.7.9 are required
-(in some distributions, these updates may be available in
-older python versions). In addition to the requirements
-which will be installed automatically by ``pip`` or ``setup.py install``,
-the following optional packages may be useful:
+**Prerequisites**: Tornado 6.0 requires Python 3.6 or newer (See
+`Tornado 5.1 `_ if
+compatibility with Python 2.7 is required). The following optional
+packages may be useful:
-* `pycurl `_ is used by the optional
+* `pycurl `_ is used by the optional
``tornado.curl_httpclient``. Libcurl version 7.22 or higher is required.
-* `Twisted `_ may be used with the classes in
+* `Twisted `_ may be used with the classes in
`tornado.platform.twisted`.
-* `pycares `_ is an alternative
+* `pycares `_ is an alternative
non-blocking DNS resolver that can be used when threads are not
appropriate.
-* `monotonic `_ or `Monotime
- `_ add support for a
- monotonic clock, which improves reliability in environments where
- clock adjustments are frequent. No longer needed in Python 3.
-
-**Platforms**: Tornado should run on any Unix-like platform, although
-for the best performance and scalability only Linux (with ``epoll``)
-and BSD (with ``kqueue``) are recommended for production deployment
-(even though Mac OS X is derived from BSD and supports kqueue, its
-networking performance is generally poor so it is recommended only for
-development use). Tornado will also run on Windows, although this
-configuration is not officially supported and is recommended only for
-development use. Without reworking Tornado IOLoop interface, it's not
-possible to add a native Tornado Windows IOLoop implementation or
-leverage Windows' IOCP support from frameworks like AsyncIO or Twisted.
+
+**Platforms**: Tornado is designed for Unix-like platforms, with best
+performance and scalability on systems supporting ``epoll`` (Linux),
+``kqueue`` (BSD/macOS), or ``/dev/poll`` (Solaris).
+
+Tornado will also run on Windows, although this configuration is not
+officially supported or recommended for production use. Some features
+are missing on Windows (including multi-process mode) and scalability
+is limited (even though Tornado is built on ``asyncio``, which
+supports Windows, Tornado does not use the APIs that are necessary for
+scalable networking on Windows).
Documentation
-------------
@@ -145,17 +147,17 @@ Discussion and support
----------------------
You can discuss Tornado on `the Tornado developer mailing list
-`_, and report bugs on
+`_, and report bugs on
the `GitHub issue tracker
`_. Links to additional
resources can be found on the `Tornado wiki
`_. New releases are
announced on the `announcements mailing list
-`_.
+`_.
Tornado is available under
the `Apache License, Version 2.0
`_.
This web site and all documentation is licensed under `Creative
-Commons 3.0 `_.
+Commons 3.0 `_.
diff --git a/docs/ioloop.rst b/docs/ioloop.rst
index 2c5fdc2cd2..5b748d3695 100644
--- a/docs/ioloop.rst
+++ b/docs/ioloop.rst
@@ -45,18 +45,3 @@
.. automethod:: IOLoop.time
.. autoclass:: PeriodicCallback
:members:
-
- Debugging and error handling
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
- .. automethod:: IOLoop.handle_callback_exception
- .. automethod:: IOLoop.set_blocking_signal_threshold
- .. automethod:: IOLoop.set_blocking_log_threshold
- .. automethod:: IOLoop.log_stack
-
- Methods for subclasses
- ^^^^^^^^^^^^^^^^^^^^^^
-
- .. automethod:: IOLoop.initialize
- .. automethod:: IOLoop.close_fd
- .. automethod:: IOLoop.split_fd
diff --git a/docs/locks.rst b/docs/locks.rst
index 9f991880c3..df30351cea 100644
--- a/docs/locks.rst
+++ b/docs/locks.rst
@@ -10,9 +10,10 @@ similar to those provided in the standard library's `asyncio package
.. warning::
- Note that these primitives are not actually thread-safe and cannot be used in
- place of those from the standard library--they are meant to coordinate Tornado
- coroutines in a single-threaded app, not to protect shared objects in a
+ Note that these primitives are not actually thread-safe and cannot
+ be used in place of those from the standard library's `threading`
+ module--they are meant to coordinate Tornado coroutines in a
+ single-threaded app, not to protect shared objects in a
multithreaded app.
.. automodule:: tornado.locks
diff --git a/docs/releases.rst b/docs/releases.rst
index 6f87edc30b..a478821d78 100644
--- a/docs/releases.rst
+++ b/docs/releases.rst
@@ -4,6 +4,14 @@ Release notes
.. toctree::
:maxdepth: 2
+ releases/v6.1.0
+ releases/v6.0.4
+ releases/v6.0.3
+ releases/v6.0.2
+ releases/v6.0.1
+ releases/v6.0.0
+ releases/v5.1.1
+ releases/v5.1.0
releases/v5.0.2
releases/v5.0.1
releases/v5.0.0
diff --git a/docs/releases/v2.2.0.rst b/docs/releases/v2.2.0.rst
index a3298c557c..922c6d2ff5 100644
--- a/docs/releases/v2.2.0.rst
+++ b/docs/releases/v2.2.0.rst
@@ -56,7 +56,7 @@ Backwards-incompatible changes
* ``IOStream.write`` now works correctly when given an empty string.
* ``IOStream.read_until`` (and ``read_until_regex``) now perform better
- when there is a lot of buffered data, which improves peformance of
+ when there is a lot of buffered data, which improves performance of
``SimpleAsyncHTTPClient`` when downloading files with lots of
chunks.
* `.SSLIOStream` now works correctly when ``ssl_version`` is set to
diff --git a/docs/releases/v2.3.0.rst b/docs/releases/v2.3.0.rst
index d24f46c547..5be231d032 100644
--- a/docs/releases/v2.3.0.rst
+++ b/docs/releases/v2.3.0.rst
@@ -102,9 +102,9 @@ Other modules
function is called repeatedly.
* `tornado.locale.get_supported_locales` no longer takes a meaningless
``cls`` argument.
-* `.StackContext` instances now have a deactivation callback that can be
+* ``StackContext`` instances now have a deactivation callback that can be
used to prevent further propagation.
* `tornado.testing.AsyncTestCase.wait` now resets its timeout on each call.
-* `tornado.wsgi.WSGIApplication` now parses arguments correctly on Python 3.
+* ``tornado.wsgi.WSGIApplication`` now parses arguments correctly on Python 3.
* Exception handling on Python 3 has been improved; previously some exceptions
such as `UnicodeDecodeError` would generate ``TypeErrors``
diff --git a/docs/releases/v2.4.0.rst b/docs/releases/v2.4.0.rst
index bbf07bf82e..5bbff30bcb 100644
--- a/docs/releases/v2.4.0.rst
+++ b/docs/releases/v2.4.0.rst
@@ -57,7 +57,7 @@ HTTP clients
* New method `.RequestHandler.get_template_namespace` can be overridden to
add additional variables without modifying keyword arguments to
``render_string``.
-* `.RequestHandler.add_header` now works with `.WSGIApplication`.
+* `.RequestHandler.add_header` now works with ``WSGIApplication``.
* `.RequestHandler.get_secure_cookie` now handles a potential error case.
* ``RequestHandler.__init__`` now calls ``super().__init__`` to ensure that
all constructors are called when multiple inheritance is used.
diff --git a/docs/releases/v2.4.1.rst b/docs/releases/v2.4.1.rst
index 22b09ca428..82eabc94e0 100644
--- a/docs/releases/v2.4.1.rst
+++ b/docs/releases/v2.4.1.rst
@@ -7,7 +7,7 @@ Nov 24, 2012
Bug fixes
~~~~~~~~~
-* Fixed a memory leak in `tornado.stack_context` that was especially likely
+* Fixed a memory leak in ``tornado.stack_context`` that was especially likely
with long-running ``@gen.engine`` functions.
* `tornado.auth.TwitterMixin` now works on Python 3.
* Fixed a bug in which ``IOStream.read_until_close`` with a streaming callback
diff --git a/docs/releases/v3.0.0.rst b/docs/releases/v3.0.0.rst
index a1d34e12f2..53c9771f30 100644
--- a/docs/releases/v3.0.0.rst
+++ b/docs/releases/v3.0.0.rst
@@ -10,7 +10,7 @@ Highlights
* The ``callback`` argument to many asynchronous methods is now
optional, and these methods return a `.Future`. The `tornado.gen`
module now understands ``Futures``, and these methods can be used
- directly without a `.gen.Task` wrapper.
+ directly without a ``gen.Task`` wrapper.
* New function `.IOLoop.current` returns the `.IOLoop` that is running
on the current thread (as opposed to `.IOLoop.instance`, which
returns a specific thread's (usually the main thread's) IOLoop.
@@ -136,7 +136,7 @@ Multiple modules
calling a callback you return a value with ``raise
gen.Return(value)`` (or simply ``return value`` in Python 3.3).
* Generators may now yield `.Future` objects.
-* Callbacks produced by `.gen.Callback` and `.gen.Task` are now automatically
+* Callbacks produced by ``gen.Callback`` and ``gen.Task`` are now automatically
stack-context-wrapped, to minimize the risk of context leaks when used
with asynchronous functions that don't do their own wrapping.
* Fixed a memory leak involving generators, `.RequestHandler.flush`,
@@ -167,7 +167,7 @@ Multiple modules
when instantiating an implementation subclass directly.
* Secondary `.AsyncHTTPClient` callbacks (``streaming_callback``,
``header_callback``, and ``prepare_curl_callback``) now respect
- `.StackContext`.
+ ``StackContext``.
`tornado.httpserver`
~~~~~~~~~~~~~~~~~~~~
@@ -311,9 +311,9 @@ Multiple modules
`tornado.platform.twisted`
~~~~~~~~~~~~~~~~~~~~~~~~~~
-* New class `tornado.platform.twisted.TwistedIOLoop` allows Tornado
+* New class ``tornado.platform.twisted.TwistedIOLoop`` allows Tornado
code to be run on the Twisted reactor (as opposed to the existing
- `.TornadoReactor`, which bridges the gap in the other direction).
+ ``TornadoReactor``, which bridges the gap in the other direction).
* New class `tornado.platform.twisted.TwistedResolver` is an asynchronous
implementation of the `.Resolver` interface.
@@ -343,10 +343,10 @@ Multiple modules
* Fixed a bug in which ``SimpleAsyncHTTPClient`` callbacks were being run in the
client's ``stack_context``.
-`tornado.stack_context`
-~~~~~~~~~~~~~~~~~~~~~~~
+``tornado.stack_context``
+~~~~~~~~~~~~~~~~~~~~~~~~~
-* `.stack_context.wrap` now runs the wrapped callback in a more consistent
+* ``stack_context.wrap`` now runs the wrapped callback in a more consistent
environment by recreating contexts even if they already exist on the
stack.
* Fixed a bug in which stack contexts could leak from one callback
diff --git a/docs/releases/v3.0.2.rst b/docs/releases/v3.0.2.rst
index 7eac09dba0..70e7d52b3e 100644
--- a/docs/releases/v3.0.2.rst
+++ b/docs/releases/v3.0.2.rst
@@ -9,4 +9,4 @@ Jun 2, 2013
June 11 `_. It also now uses HTTPS
when talking to Twitter.
* Fixed a potential memory leak with a long chain of `.gen.coroutine`
- or `.gen.engine` functions.
+ or ``gen.engine`` functions.
diff --git a/docs/releases/v3.1.0.rst b/docs/releases/v3.1.0.rst
index edbf39dba6..b4ae0e12ee 100644
--- a/docs/releases/v3.1.0.rst
+++ b/docs/releases/v3.1.0.rst
@@ -28,7 +28,7 @@ Multiple modules
are asynchronous in `.OAuthMixin` and derived classes, although they
do not take a callback. The `.Future` these methods return must be
yielded if they are called from a function decorated with `.gen.coroutine`
- (but not `.gen.engine`).
+ (but not ``gen.engine``).
* `.TwitterMixin` now uses ``/account/verify_credentials`` to get information
about the logged-in user, which is more robust against changing screen
names.
@@ -147,11 +147,11 @@ Multiple modules
* `.Subprocess.set_exit_callback` now works for subprocesses created
without an explicit ``io_loop`` parameter.
-`tornado.stack_context`
-~~~~~~~~~~~~~~~~~~~~~~~
+``tornado.stack_context``
+~~~~~~~~~~~~~~~~~~~~~~~~~
-* `tornado.stack_context` has been rewritten and is now much faster.
-* New function `.run_with_stack_context` facilitates the use of stack
+* ``tornado.stack_context`` has been rewritten and is now much faster.
+* New function ``run_with_stack_context`` facilitates the use of stack
contexts with coroutines.
`tornado.tcpserver`
@@ -170,7 +170,7 @@ Multiple modules
~~~~~~~~~~~~~~~~~
* `tornado.testing.AsyncTestCase.wait` now raises the correct exception
- when it has been modified by `tornado.stack_context`.
+ when it has been modified by ``tornado.stack_context``.
* `tornado.testing.gen_test` can now be called as ``@gen_test(timeout=60)``
to give some tests a longer timeout than others.
* The environment variable ``ASYNC_TEST_TIMEOUT`` can now be set to
@@ -210,11 +210,11 @@ Multiple modules
instead of being turned into spaces.
* `.RequestHandler.send_error` will now only be called once per request,
even if multiple exceptions are caught by the stack context.
-* The `tornado.web.asynchronous` decorator is no longer necessary for
+* The ``tornado.web.asynchronous`` decorator is no longer necessary for
methods that return a `.Future` (i.e. those that use the `.gen.coroutine`
- or `.return_future` decorators)
+  or ``return_future`` decorators).
* `.RequestHandler.prepare` may now be asynchronous if it returns a
- `.Future`. The `~tornado.web.asynchronous` decorator is not used with
+ `.Future`. The ``tornado.web.asynchronous`` decorator is not used with
``prepare``; one of the `.Future`-related decorators should be used instead.
* ``RequestHandler.current_user`` may now be assigned to normally.
* `.RequestHandler.redirect` no longer silently strips control characters
diff --git a/docs/releases/v3.2.0.rst b/docs/releases/v3.2.0.rst
index 09057030ac..f0be99961a 100644
--- a/docs/releases/v3.2.0.rst
+++ b/docs/releases/v3.2.0.rst
@@ -79,7 +79,7 @@ New modules
`tornado.ioloop`
~~~~~~~~~~~~~~~~
-* `.IOLoop` now uses `~.IOLoop.handle_callback_exception` consistently for
+* `.IOLoop` now uses ``IOLoop.handle_callback_exception`` consistently for
error logging.
* `.IOLoop` now frees callback objects earlier, reducing memory usage
while idle.
diff --git a/docs/releases/v4.0.0.rst b/docs/releases/v4.0.0.rst
index dd60ea8c3a..6b470d3bc1 100644
--- a/docs/releases/v4.0.0.rst
+++ b/docs/releases/v4.0.0.rst
@@ -96,14 +96,14 @@ Other notes
will be created on demand when needed.
* The internals of the `tornado.gen` module have been rewritten to
improve performance when using ``Futures``, at the expense of some
- performance degradation for the older `.YieldPoint` interfaces.
+ performance degradation for the older ``YieldPoint`` interfaces.
* New function `.with_timeout` wraps a `.Future` and raises an exception
if it doesn't complete in a given amount of time.
* New object `.moment` can be yielded to allow the IOLoop to run for
one iteration before resuming.
-* `.Task` is now a function returning a `.Future` instead of a `.YieldPoint`
+* ``Task`` is now a function returning a `.Future` instead of a ``YieldPoint``
subclass. This change should be transparent to application code, but
- allows `.Task` to take advantage of the newly-optimized `.Future`
+ allows ``Task`` to take advantage of the newly-optimized `.Future`
handling.
`tornado.http1connection`
@@ -135,7 +135,7 @@ Other notes
* The ``connection`` attribute of `.HTTPServerRequest` is now documented
for public use; applications are expected to write their responses
via the `.HTTPConnection` interface.
-* The `.HTTPServerRequest.write` and `.HTTPServerRequest.finish` methods
+* The ``HTTPServerRequest.write`` and ``HTTPServerRequest.finish`` methods
are now deprecated. (`.RequestHandler.write` and `.RequestHandler.finish`
are *not* deprecated; this only applies to the methods on
`.HTTPServerRequest`)
@@ -233,7 +233,7 @@ Other notes
`tornado.platform.twisted`
~~~~~~~~~~~~~~~~~~~~~~~~~~
-* `.TwistedIOLoop` now works on Python 3.3+ (with Twisted 14.0.0+).
+* ``TwistedIOLoop`` now works on Python 3.3+ (with Twisted 14.0.0+).
``tornado.simple_httpclient``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -251,8 +251,8 @@ Other notes
* ``simple_httpclient`` now raises the original exception (e.g. an `IOError`)
in more cases, instead of converting everything to ``HTTPError``.
-`tornado.stack_context`
-~~~~~~~~~~~~~~~~~~~~~~~
+``tornado.stack_context``
+~~~~~~~~~~~~~~~~~~~~~~~~~
* The stack context system now has less performance overhead when no
stack contexts are active.
@@ -324,8 +324,8 @@ Other notes
`tornado.wsgi`
~~~~~~~~~~~~~~
-* New class `.WSGIAdapter` supports running a Tornado `.Application` on
+* New class ``WSGIAdapter`` supports running a Tornado `.Application` on
a WSGI server in a way that is more compatible with Tornado's non-WSGI
- `.HTTPServer`. `.WSGIApplication` is deprecated in favor of using
- `.WSGIAdapter` with a regular `.Application`.
-* `.WSGIAdapter` now supports gzipped output.
+ `.HTTPServer`. ``WSGIApplication`` is deprecated in favor of using
+ ``WSGIAdapter`` with a regular `.Application`.
+* ``WSGIAdapter`` now supports gzipped output.
diff --git a/docs/releases/v4.1.0.rst b/docs/releases/v4.1.0.rst
index 29ad19146a..74cd30a49f 100644
--- a/docs/releases/v4.1.0.rst
+++ b/docs/releases/v4.1.0.rst
@@ -60,7 +60,7 @@ Backwards-compatibility notes
yieldable in coroutines.
* New function `tornado.gen.sleep` is a coroutine-friendly
analogue to `time.sleep`.
-* `.gen.engine` now correctly captures the stack context for its callbacks.
+* ``gen.engine`` now correctly captures the stack context for its callbacks.
`tornado.httpclient`
~~~~~~~~~~~~~~~~~~~~
@@ -167,7 +167,7 @@ Backwards-compatibility notes
`tornado.web`
~~~~~~~~~~~~~
-* The `.asynchronous` decorator now understands `concurrent.futures.Future`
+* The ``asynchronous`` decorator now understands `concurrent.futures.Future`
in addition to `tornado.concurrent.Future`.
* `.StaticFileHandler` no longer logs a stack trace if the connection is
closed while sending the file.
diff --git a/docs/releases/v4.2.0.rst b/docs/releases/v4.2.0.rst
index 93493ee113..bacfb13a05 100644
--- a/docs/releases/v4.2.0.rst
+++ b/docs/releases/v4.2.0.rst
@@ -162,7 +162,7 @@ Then the Tornado equivalent is::
* The `.IOLoop` constructor now has a ``make_current`` keyword argument
to control whether the new `.IOLoop` becomes `.IOLoop.current()`.
* Third-party implementations of `.IOLoop` should accept ``**kwargs``
- in their `~.IOLoop.initialize` methods and pass them to the superclass
+ in their ``IOLoop.initialize`` methods and pass them to the superclass
implementation.
* `.PeriodicCallback` is now more efficient when the clock jumps forward
by a large amount.
diff --git a/docs/releases/v4.3.0.rst b/docs/releases/v4.3.0.rst
index e6ea1589c7..b19b297c1b 100644
--- a/docs/releases/v4.3.0.rst
+++ b/docs/releases/v4.3.0.rst
@@ -12,7 +12,7 @@ Highlights
Inside a function defined with ``async def``, use ``await`` instead of
``yield`` to wait on an asynchronous operation. Coroutines defined with
async/await will be faster than those defined with ``@gen.coroutine`` and
- ``yield``, but do not support some features including `.Callback`/`.Wait` or
+ ``yield``, but do not support some features including ``Callback``/``Wait`` or
the ability to yield a Twisted ``Deferred``. See :ref:`the users'
   guide <coroutines>` for more.
* The async/await keywords are also available when compiling with Cython in
diff --git a/docs/releases/v4.4.0.rst b/docs/releases/v4.4.0.rst
index fa7c81702a..5ac3018b9f 100644
--- a/docs/releases/v4.4.0.rst
+++ b/docs/releases/v4.4.0.rst
@@ -26,7 +26,7 @@ General
~~~~~~~~~~~~~
* `.with_timeout` now accepts any yieldable object (except
- `.YieldPoint`), not just `tornado.concurrent.Future`.
+ ``YieldPoint``), not just `tornado.concurrent.Future`.
`tornado.httpclient`
~~~~~~~~~~~~~~~~~~~~
diff --git a/docs/releases/v4.5.0.rst b/docs/releases/v4.5.0.rst
index 9cdf0ad5c2..5a4ce9e258 100644
--- a/docs/releases/v4.5.0.rst
+++ b/docs/releases/v4.5.0.rst
@@ -58,7 +58,7 @@ General changes
- Fixed an issue in which a generator object could be garbage
   collected prematurely (most often when weak references are used).
- New function `.is_coroutine_function` identifies functions wrapped
- by `.coroutine` or `.engine`.
+ by `.coroutine` or ``engine``.
``tornado.http1connection``
~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/docs/releases/v5.0.0.rst b/docs/releases/v5.0.0.rst
index 7f1325b8de..dd0bd02439 100644
--- a/docs/releases/v5.0.0.rst
+++ b/docs/releases/v5.0.0.rst
@@ -193,8 +193,8 @@ Other notes
- The ``io_loop`` argument to `.PeriodicCallback` has been removed.
- It is now possible to create a `.PeriodicCallback` in one thread
and start it in another without passing an explicit event loop.
-- The `.IOLoop.set_blocking_signal_threshold` and
- `.IOLoop.set_blocking_log_threshold` methods are deprecated because
+- The ``IOLoop.set_blocking_signal_threshold`` and
+ ``IOLoop.set_blocking_log_threshold`` methods are deprecated because
   they are not implemented for the `asyncio` event loop. Use the
``PYTHONASYNCIODEBUG=1`` environment variable instead.
- `.IOLoop.clear_current` now works if it is called before any
@@ -257,8 +257,8 @@ Other notes
`tornado.platform.twisted`
~~~~~~~~~~~~~~~~~~~~~~~~~~
-- The ``io_loop`` arguments to `.TornadoReactor`, `.TwistedResolver`,
- and `tornado.platform.twisted.install` have been removed.
+- The ``io_loop`` arguments to ``TornadoReactor``, `.TwistedResolver`,
+ and ``tornado.platform.twisted.install`` have been removed.
`tornado.process`
~~~~~~~~~~~~~~~~~
diff --git a/docs/releases/v5.1.0.rst b/docs/releases/v5.1.0.rst
new file mode 100644
index 0000000000..00def8f38c
--- /dev/null
+++ b/docs/releases/v5.1.0.rst
@@ -0,0 +1,195 @@
+What's new in Tornado 5.1
+=========================
+
+July 12, 2018
+-------------
+
+Deprecation notice
+~~~~~~~~~~~~~~~~~~
+
+- Tornado 6.0 will drop support for Python 2.7 and 3.4. The minimum
+ supported Python version will be 3.5.2.
+- The ``tornado.stack_context`` module is deprecated and will be removed
+ in Tornado 6.0. The reason for this is that it is not feasible to
+ provide this module's semantics in the presence of ``async def``
+ native coroutines. ``ExceptionStackContext`` is mainly obsolete
+ thanks to coroutines. ``StackContext`` lacks a direct replacement
+ although the new ``contextvars`` package (in the Python standard
+ library beginning in Python 3.7) may be an alternative.
+- Callback-oriented code often relies on ``ExceptionStackContext`` to
+ handle errors and prevent leaked connections. In order to avoid the
+ risk of silently introducing subtle leaks (and to consolidate all of
+ Tornado's interfaces behind the coroutine pattern), ``callback``
+ arguments throughout the package are deprecated and will be removed
+ in version 6.0. All functions that had a ``callback`` argument
+ removed now return a `.Future` which should be used instead.
+- Where possible, deprecation warnings are emitted when any of these
+ deprecated interfaces is used. However, Python does not display
+ deprecation warnings by default. To prepare your application for
+ Tornado 6.0, run Python with the ``-Wd`` argument or set the
+ environment variable ``PYTHONWARNINGS`` to ``d``. If your
+ application runs on Python 3 without deprecation warnings, it should
+ be able to move to Tornado 6.0 without disruption.
+
+`tornado.auth`
+~~~~~~~~~~~~~~
+
+- `.OAuthMixin._oauth_get_user_future` may now be a native coroutine.
+- All ``callback`` arguments in this package are deprecated and will
+ be removed in 6.0. Use the coroutine interfaces instead.
+- The ``OAuthMixin._oauth_get_user`` method is deprecated and will be removed in
+ 6.0. Override `~.OAuthMixin._oauth_get_user_future` instead.
+
+`tornado.autoreload`
+~~~~~~~~~~~~~~~~~~~~
+
+- The command-line autoreload wrapper is now preserved if an internal
+ autoreload fires.
+- The command-line wrapper no longer starts duplicated processes on Windows
+ when combined with internal autoreload.
+
+`tornado.concurrent`
+~~~~~~~~~~~~~~~~~~~~
+
+- `.run_on_executor` now returns `.Future` objects that are compatible
+ with ``await``.
+- The ``callback`` argument to `.run_on_executor` is deprecated and will
+ be removed in 6.0.
+- ``return_future`` is deprecated and will be removed in 6.0.
+
+`tornado.gen`
+~~~~~~~~~~~~~
+
+- Some older portions of this module are deprecated and will be removed
+ in 6.0. This includes ``engine``, ``YieldPoint``, ``Callback``,
+ ``Wait``, ``WaitAll``, ``MultiYieldPoint``, and ``Task``.
+- Functions decorated with ``@gen.coroutine`` will no longer accept
+ ``callback`` arguments in 6.0.
+
+`tornado.httpclient`
+~~~~~~~~~~~~~~~~~~~~
+
+- The behavior of ``raise_error=False`` is changing in 6.0. Currently
+ it suppresses all errors; in 6.0 it will only suppress the errors
+ raised due to completed responses with non-200 status codes.
+- The ``callback`` argument to `.AsyncHTTPClient.fetch` is deprecated
+ and will be removed in 6.0.
+- `tornado.httpclient.HTTPError` has been renamed to
+ `.HTTPClientError` to avoid ambiguity in code that also has to deal
+ with `tornado.web.HTTPError`. The old name remains as an alias.
+- ``tornado.curl_httpclient`` now supports non-ASCII characters in
+ username and password arguments.
+- ``HTTPResponse.request_time`` now behaves consistently across
+ ``simple_httpclient`` and ``curl_httpclient``, excluding time spent
+ in the ``max_clients`` queue in both cases (previously this time was
+ included in ``simple_httpclient`` but excluded in
+ ``curl_httpclient``). In both cases the time is now computed using
+ a monotonic clock where available.
+- `.HTTPResponse` now has a ``start_time`` attribute recording a
+ wall-clock (`time.time`) timestamp at which the request started
+ (after leaving the ``max_clients`` queue if applicable).
+
+`tornado.httputil`
+~~~~~~~~~~~~~~~~~~
+
+- `.parse_multipart_form_data` now recognizes non-ASCII filenames in
+ RFC 2231/5987 (``filename*=``) format.
+- ``HTTPServerRequest.write`` is deprecated and will be removed in 6.0. Use
+ the methods of ``request.connection`` instead.
+- Malformed HTTP headers are now logged less noisily.
+
+`tornado.ioloop`
+~~~~~~~~~~~~~~~~
+
+- `.PeriodicCallback` now supports a ``jitter`` argument to randomly
+ vary the timeout.
+- ``IOLoop.set_blocking_signal_threshold``,
+ ``IOLoop.set_blocking_log_threshold``, ``IOLoop.log_stack``,
+ and ``IOLoop.handle_callback_exception`` are deprecated and will
+ be removed in 6.0.
+- Fixed a `KeyError` in `.IOLoop.close` when `.IOLoop` objects are
+ being opened and closed in multiple threads.
+
+`tornado.iostream`
+~~~~~~~~~~~~~~~~~~
+
+- All ``callback`` arguments in this module are deprecated except for
+ `.BaseIOStream.set_close_callback`. They will be removed in 6.0.
+- ``streaming_callback`` arguments to `.BaseIOStream.read_bytes` and
+ `.BaseIOStream.read_until_close` are deprecated and will be removed
+ in 6.0.
+
+`tornado.netutil`
+~~~~~~~~~~~~~~~~~
+
+- Improved compatibility with GNU Hurd.
+
+`tornado.options`
+~~~~~~~~~~~~~~~~~
+
+- `tornado.options.parse_config_file` now allows setting options to
+ strings (which will be parsed the same way as
+ `tornado.options.parse_command_line`) in addition to the specified
+ type for the option.
+
+`tornado.platform.twisted`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- ``TornadoReactor`` and ``TwistedIOLoop`` are deprecated and will be
+ removed in 6.0. Instead, Tornado will always use the asyncio event loop
+ and twisted can be configured to do so as well.
+
+``tornado.stack_context``
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- The ``tornado.stack_context`` module is deprecated and will be removed
+ in 6.0.
+
+`tornado.testing`
+~~~~~~~~~~~~~~~~~
+
+- `.AsyncHTTPTestCase.fetch` now takes a ``raise_error`` argument.
+ This argument has the same semantics as `.AsyncHTTPClient.fetch`,
+ but defaults to false because tests often need to deal with non-200
+ responses (and for backwards-compatibility).
+- The `.AsyncTestCase.stop` and `.AsyncTestCase.wait` methods are
+ deprecated.
+
+`tornado.web`
+~~~~~~~~~~~~~
+
+- New method `.RequestHandler.detach` can be used from methods
+ that are not decorated with ``@asynchronous`` (the decorator
+  was required to use ``self.request.connection.detach()``).
+- `.RequestHandler.finish` and `.RequestHandler.render` now return
+ ``Futures`` that can be used to wait for the last part of the
+ response to be sent to the client.
+- `.FallbackHandler` now calls ``on_finish`` for the benefit of
+ subclasses that may have overridden it.
+- The ``asynchronous`` decorator is deprecated and will be removed in 6.0.
+- The ``callback`` argument to `.RequestHandler.flush` is deprecated
+ and will be removed in 6.0.
+
+
+`tornado.websocket`
+~~~~~~~~~~~~~~~~~~~
+
+- When compression is enabled, memory limits now apply to the
+ post-decompression size of the data, protecting against DoS attacks.
+- `.websocket_connect` now supports subprotocols.
+- `.WebSocketHandler` and `.WebSocketClientConnection` now have
+ ``selected_subprotocol`` attributes to see the subprotocol in use.
+- The `.WebSocketHandler.select_subprotocol` method is now called with
+ an empty list instead of a list containing an empty string if no
+ subprotocols were requested by the client.
+- `.WebSocketHandler.open` may now be a coroutine.
+- The ``data`` argument to `.WebSocketHandler.ping` is now optional.
+- Client-side websocket connections no longer buffer more than one
+ message in memory at a time.
+- Exception logging now uses `.RequestHandler.log_exception`.
+
+`tornado.wsgi`
+~~~~~~~~~~~~~~
+
+- ``WSGIApplication`` and ``WSGIAdapter`` are deprecated and will be removed
+ in Tornado 6.0.
diff --git a/docs/releases/v5.1.1.rst b/docs/releases/v5.1.1.rst
new file mode 100644
index 0000000000..7fc4fb881a
--- /dev/null
+++ b/docs/releases/v5.1.1.rst
@@ -0,0 +1,14 @@
+What's new in Tornado 5.1.1
+===========================
+
+Sep 16, 2018
+------------
+
+Bug fixes
+~~~~~~~~~
+
+- Fixed a case in which the `.Future` returned by
+ `.RequestHandler.finish` could fail to resolve.
+- The `.TwitterMixin.authenticate_redirect` method works again.
+- Improved error handling in the `tornado.auth` module, fixing hanging
+ requests when a network or other error occurs.
diff --git a/docs/releases/v6.0.0.rst b/docs/releases/v6.0.0.rst
new file mode 100644
index 0000000000..d3d2dfbc0b
--- /dev/null
+++ b/docs/releases/v6.0.0.rst
@@ -0,0 +1,162 @@
+What's new in Tornado 6.0
+=========================
+
+Mar 1, 2019
+-----------
+
+Backwards-incompatible changes
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- Python 2.7 and 3.4 are no longer supported; the minimum supported
+ Python version is 3.5.2.
+- APIs deprecated in Tornado 5.1 have been removed. This includes the
+ ``tornado.stack_context`` module and most ``callback`` arguments
+ throughout the package. All removed APIs emitted
+ `DeprecationWarning` when used in Tornado 5.1, so running your
+ application with the ``-Wd`` Python command-line flag or the
+ environment variable ``PYTHONWARNINGS=d`` should tell you whether
+ your application is ready to move to Tornado 6.0.
+- ``WebSocketHandler.get`` is now a coroutine and must be called
+ accordingly in any subclasses that override this method (but note
+ that overriding ``get`` is not recommended; either ``prepare`` or
+ ``open`` should be used instead).
+
+General changes
+~~~~~~~~~~~~~~~
+
+- Tornado now includes type annotations compatible with ``mypy``.
+ These annotations will be used when type-checking your application
+ with ``mypy``, and may be usable in editors and other tools.
+- Tornado now uses native coroutines internally, improving performance.
+
+`tornado.auth`
+~~~~~~~~~~~~~~
+
+- All ``callback`` arguments in this package have been removed. Use
+ the coroutine interfaces instead.
+- The ``OAuthMixin._oauth_get_user`` method has been removed.
+ Override `~.OAuthMixin._oauth_get_user_future` instead.
+
+`tornado.concurrent`
+~~~~~~~~~~~~~~~~~~~~
+
+- The ``callback`` argument to `.run_on_executor` has been removed.
+- ``return_future`` has been removed.
+
+`tornado.gen`
+~~~~~~~~~~~~~
+
+- Some older portions of this module have been removed. This includes
+ ``engine``, ``YieldPoint``, ``Callback``, ``Wait``, ``WaitAll``,
+ ``MultiYieldPoint``, and ``Task``.
+- Functions decorated with ``@gen.coroutine`` no longer accept
+ ``callback`` arguments.
+
+`tornado.httpclient`
+~~~~~~~~~~~~~~~~~~~~
+
+- The behavior of ``raise_error=False`` has changed. Now only
+ suppresses the errors raised due to completed responses with non-200
+ status codes (previously it suppressed all errors).
+- The ``callback`` argument to `.AsyncHTTPClient.fetch` has been removed.
+
+`tornado.httputil`
+~~~~~~~~~~~~~~~~~~
+
+- ``HTTPServerRequest.write`` has been removed. Use the methods of
+ ``request.connection`` instead.
+- Unrecognized ``Content-Encoding`` values now log warnings only for
+ content types that we would otherwise attempt to parse.
+
+`tornado.ioloop`
+~~~~~~~~~~~~~~~~
+
+- ``IOLoop.set_blocking_signal_threshold``,
+ ``IOLoop.set_blocking_log_threshold``, ``IOLoop.log_stack``,
+ and ``IOLoop.handle_callback_exception`` have been removed.
+- Improved performance of `.IOLoop.add_callback`.
+
+`tornado.iostream`
+~~~~~~~~~~~~~~~~~~
+
+- All ``callback`` arguments in this module have been removed except
+ for `.BaseIOStream.set_close_callback`.
+- ``streaming_callback`` arguments to `.BaseIOStream.read_bytes` and
+ `.BaseIOStream.read_until_close` have been removed.
+- Eliminated unnecessary logging of "Errno 0".
+
+`tornado.log`
+~~~~~~~~~~~~~
+
+- Log files opened by this module are now explicitly set to UTF-8 encoding.
+
+`tornado.netutil`
+~~~~~~~~~~~~~~~~~
+
+- The results of ``getaddrinfo`` are now sorted by address family to
+ avoid partial failures and deadlocks.
+
+`tornado.platform.twisted`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- ``TornadoReactor`` and ``TwistedIOLoop`` have been removed.
+
+``tornado.simple_httpclient``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- The default HTTP client now supports the ``network_interface``
+ request argument to specify the source IP for the connection.
+- If a server returns a 3xx response code without a ``Location``
+ header, the response is raised or returned directly instead of
+ trying and failing to follow the redirect.
+- When following redirects, methods other than ``POST`` will no longer
+ be transformed into ``GET`` requests. 301 (permanent) redirects are
+ now treated the same way as 302 (temporary) and 303 (see other)
+ redirects in this respect.
+- Following redirects now works with ``body_producer``.
+
+``tornado.stack_context``
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- The ``tornado.stack_context`` module has been removed.
+
+`tornado.tcpserver`
+~~~~~~~~~~~~~~~~~~~
+
+- `.TCPServer.start` now supports a ``max_restarts`` argument (same as
+ `.fork_processes`).
+
+`tornado.testing`
+~~~~~~~~~~~~~~~~~
+
+- `.AsyncHTTPTestCase` now drops all references to the `.Application`
+ during ``tearDown``, allowing its memory to be reclaimed sooner.
+- `.AsyncTestCase` now cancels all pending coroutines in ``tearDown``,
+ in an effort to reduce warnings from the python runtime about
+ coroutines that were not awaited. Note that this may cause
+ ``asyncio.CancelledError`` to be logged in other places. Coroutines
+ that expect to be running at test shutdown may need to catch this
+ exception.
+
+`tornado.web`
+~~~~~~~~~~~~~
+
+- The ``asynchronous`` decorator has been removed.
+- The ``callback`` argument to `.RequestHandler.flush` has been removed.
+- `.StaticFileHandler` now supports large negative values for the
+ ``Range`` header and returns an appropriate error for ``end >
+ start``.
+- It is now possible to set ``expires_days`` in ``xsrf_cookie_kwargs``.
+
+`tornado.websocket`
+~~~~~~~~~~~~~~~~~~~
+
+- Pings and other messages sent while the connection is closing are
+ now silently dropped instead of logging exceptions.
+- Errors raised by ``open()`` are now caught correctly when this method
+ is a coroutine.
+
+`tornado.wsgi`
+~~~~~~~~~~~~~~
+
+- ``WSGIApplication`` and ``WSGIAdapter`` have been removed.
diff --git a/docs/releases/v6.0.1.rst b/docs/releases/v6.0.1.rst
new file mode 100644
index 0000000000..c9da7507e6
--- /dev/null
+++ b/docs/releases/v6.0.1.rst
@@ -0,0 +1,11 @@
+What's new in Tornado 6.0.1
+===========================
+
+Mar 3, 2019
+-----------
+
+Bug fixes
+~~~~~~~~~
+
+- Fixed issues with type annotations that caused errors while
+ importing Tornado on Python 3.5.2.
diff --git a/docs/releases/v6.0.2.rst b/docs/releases/v6.0.2.rst
new file mode 100644
index 0000000000..3d394a3edc
--- /dev/null
+++ b/docs/releases/v6.0.2.rst
@@ -0,0 +1,13 @@
+What's new in Tornado 6.0.2
+===========================
+
+Mar 23, 2019
+------------
+
+Bug fixes
+~~~~~~~~~
+
+- `.WebSocketHandler.set_nodelay` works again.
+- Accessing ``HTTPResponse.body`` now returns an empty byte string
+ instead of raising ``ValueError`` for error responses that don't
+ have a body (it returned None in this case in Tornado 5).
diff --git a/docs/releases/v6.0.3.rst b/docs/releases/v6.0.3.rst
new file mode 100644
index 0000000000..c112a0286d
--- /dev/null
+++ b/docs/releases/v6.0.3.rst
@@ -0,0 +1,14 @@
+What's new in Tornado 6.0.3
+===========================
+
+Jun 22, 2019
+------------
+
+Bug fixes
+~~~~~~~~~
+
+- `.gen.with_timeout` always treats ``asyncio.CancelledError`` as a
+ ``quiet_exception`` (this improves compatibility with Python 3.8,
+ which changed ``CancelledError`` to a ``BaseException``).
+- ``IOStream`` now checks for closed streams earlier, avoiding
+ spurious logged errors in some situations (mainly with websockets).
diff --git a/docs/releases/v6.0.4.rst b/docs/releases/v6.0.4.rst
new file mode 100644
index 0000000000..f9864bff4a
--- /dev/null
+++ b/docs/releases/v6.0.4.rst
@@ -0,0 +1,21 @@
+What's new in Tornado 6.0.4
+===========================
+
+Mar 3, 2020
+-----------
+
+General changes
+~~~~~~~~~~~~~~~
+
+- Binary wheels are now available for Python 3.8 on Windows. Note that it is
+ still necessary to use
+ ``asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())`` for
+ this platform/version.
+
+Bug fixes
+~~~~~~~~~
+
+- Fixed an issue in `.IOStream` (introduced in 6.0.0) that resulted in
+ ``StreamClosedError`` being incorrectly raised if a stream is closed mid-read
+ but there is enough buffered data to satisfy the read.
+- `.AnyThreadEventLoopPolicy` now always uses the selector event loop on Windows.
\ No newline at end of file
diff --git a/docs/releases/v6.1.0.rst b/docs/releases/v6.1.0.rst
new file mode 100644
index 0000000000..7de6350ab5
--- /dev/null
+++ b/docs/releases/v6.1.0.rst
@@ -0,0 +1,106 @@
+What's new in Tornado 6.1.0
+===========================
+
+Oct 30, 2020
+------------
+
+Deprecation notice
+~~~~~~~~~~~~~~~~~~
+
+- This is the last release of Tornado to support Python 3.5. Future versions
+ will require Python 3.6 or newer.
+
+General changes
+~~~~~~~~~~~~~~~
+
+- Windows support has been improved. Tornado is now compatible with the proactor
+ event loop (which became the default in Python 3.8) by automatically falling
+ back to running a selector in a second thread. This means that it is no longer
+ necessary to explicitly configure a selector event loop, although doing so may
+ improve performance. This does not change the fact that Tornado is significantly
+ less scalable on Windows than on other platforms.
+- Binary wheels are now provided for Windows, MacOS, and Linux (amd64 and arm64).
+
+`tornado.gen`
+~~~~~~~~~~~~~
+
+- `.coroutine` now has better support for the Python 3.7+ ``contextvars`` module.
+ In particular, the ``ContextVar.reset`` method is now supported.
+
+`tornado.http1connection`
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- ``HEAD`` requests to handlers that used chunked encoding no longer produce malformed output.
+- Certain kinds of malformed ``gzip`` data no longer cause an infinite loop.
+
+`tornado.httpclient`
+~~~~~~~~~~~~~~~~~~~~
+
+- Setting ``decompress_response=False`` now works correctly with
+ ``curl_httpclient``.
+- Mixing requests with and without proxies works correctly in ``curl_httpclient``
+ (assuming the version of pycurl is recent enough).
+- A default ``User-Agent`` of ``Tornado/$VERSION`` is now used if the
+ ``user_agent`` parameter is not specified.
+- After a 303 redirect, ``tornado.simple_httpclient`` always uses ``GET``.
+ Previously this would use ``GET`` if the original request was a ``POST`` and
+ would otherwise reuse the original request method. For ``curl_httpclient``, the
+ behavior depends on the version of ``libcurl`` (with the most recent versions
+ using ``GET`` after 303 regardless of the original method).
+- Setting ``request_timeout`` and/or ``connect_timeout`` to zero is now supported
+ to disable the timeout.
+
+`tornado.httputil`
+~~~~~~~~~~~~~~~~~~
+
+- Header parsing is now faster.
+- `.parse_body_arguments` now accepts incompletely-escaped non-ASCII inputs.
+
+`tornado.iostream`
+~~~~~~~~~~~~~~~~~~
+
+- `ssl.CertificateError` during the SSL handshake is now handled correctly.
+- Reads that are resolved while the stream is closing are now handled correctly.
+
+`tornado.log`
+~~~~~~~~~~~~~
+
+- When colored logging is enabled, ``logging.CRITICAL`` messages are now
+ recognized and colored magenta.
+
+`tornado.netutil`
+~~~~~~~~~~~~~~~~~
+
+- ``EADDRNOTAVAIL`` is now ignored when binding to ``localhost`` with IPv6. This
+ error is common in docker.
+
+`tornado.platform.asyncio`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- `.AnyThreadEventLoopPolicy` now also configures a selector event loop for
+  these threads (the proactor event loop only works on the main thread).
+
+``tornado.platform.auto``
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- The ``set_close_exec`` function has been removed.
+
+`tornado.testing`
+~~~~~~~~~~~~~~~~~
+
+- `.ExpectLog` now has a ``level`` argument to ensure that the given log level
+ is enabled.
+
+`tornado.web`
+~~~~~~~~~~~~~
+
+- ``RedirectHandler.get`` now accepts keyword arguments.
+- When sending 304 responses, more headers (including ``Allow``) are now preserved.
+- ``reverse_url`` correctly handles escaped characters in the regex route.
+- Default ``Etag`` headers are now generated with SHA-512 instead of MD5.
+
+`tornado.websocket`
+~~~~~~~~~~~~~~~~~~~
+
+- The ``ping_interval`` timer is now stopped when the connection is closed.
+- `.websocket_connect` now raises an error when it encounters a redirect instead of hanging.
diff --git a/docs/requirements.in b/docs/requirements.in
new file mode 100644
index 0000000000..334534d739
--- /dev/null
+++ b/docs/requirements.in
@@ -0,0 +1,3 @@
+sphinx
+sphinxcontrib-asyncio
+sphinx_rtd_theme
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 082785c979..ab250c658e 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1 +1,26 @@
-Twisted
+alabaster==0.7.12
+babel==2.8.0
+certifi==2020.6.20
+chardet==3.0.4
+docutils==0.16
+idna==2.10
+imagesize==1.2.0
+jinja2==2.11.3
+markupsafe==1.1.1
+packaging==20.4
+pygments==2.7.4
+pyparsing==2.4.7
+pytz==2020.1
+requests==2.24.0
+six==1.15.0
+snowballstemmer==2.0.0
+sphinx-rtd-theme==0.5.0
+sphinx==3.2.1
+sphinxcontrib-applehelp==1.0.2
+sphinxcontrib-asyncio==0.3.0
+sphinxcontrib-devhelp==1.0.2
+sphinxcontrib-htmlhelp==1.0.3
+sphinxcontrib-jsmath==1.0.1
+sphinxcontrib-qthelp==1.0.3
+sphinxcontrib-serializinghtml==1.1.4
+urllib3==1.25.11
diff --git a/docs/stack_context.rst b/docs/stack_context.rst
deleted file mode 100644
index 489a37fdc5..0000000000
--- a/docs/stack_context.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-``tornado.stack_context`` --- Exception handling across asynchronous callbacks
-==============================================================================
-
-.. automodule:: tornado.stack_context
- :members:
diff --git a/docs/twisted.rst b/docs/twisted.rst
index a6e1946f11..5d8fe8fbc8 100644
--- a/docs/twisted.rst
+++ b/docs/twisted.rst
@@ -1,24 +1,60 @@
``tornado.platform.twisted`` --- Bridges between Twisted and Tornado
-========================================================================
+====================================================================
-.. automodule:: tornado.platform.twisted
+.. module:: tornado.platform.twisted
- Twisted on Tornado
- ------------------
+.. deprecated:: 6.0
- .. autoclass:: TornadoReactor
- :members:
+ This module is no longer recommended for new code. Instead of using
+ direct integration between Tornado and Twisted, new applications should
+ rely on the integration with ``asyncio`` provided by both packages.
- .. autofunction:: install
+Importing this module has the side effect of registering Twisted's ``Deferred``
+class with Tornado's ``@gen.coroutine`` so that ``Deferred`` objects can be
+used with ``yield`` in coroutines using this decorator (importing this module has
+no effect on native coroutines using ``async def``).
- Tornado on Twisted
- ------------------
+.. function:: install()
- .. autoclass:: TwistedIOLoop
- :members:
+ Install ``AsyncioSelectorReactor`` as the default Twisted reactor.
- Twisted DNS resolver
- --------------------
+ .. deprecated:: 5.1
+
+ This function is provided for backwards compatibility; code
+ that does not require compatibility with older versions of
+ Tornado should use
+ ``twisted.internet.asyncioreactor.install()`` directly.
+
+ .. versionchanged:: 6.0.3
+
+ In Tornado 5.x and before, this function installed a reactor
+ based on the Tornado ``IOLoop``. When that reactor
+ implementation was removed in Tornado 6.0.0, this function was
+ removed as well. It was restored in Tornado 6.0.3 using the
+ ``asyncio`` reactor instead.
+
+Twisted DNS resolver
+--------------------
+
+.. class:: TwistedResolver
+
+ Twisted-based asynchronous resolver.
+
+ This is a non-blocking and non-threaded resolver. It is
+ recommended only when threads cannot be used, since it has
+ limitations compared to the standard ``getaddrinfo``-based
+ `~tornado.netutil.Resolver` and
+ `~tornado.netutil.DefaultExecutorResolver`. Specifically, it returns at
+ most one result, and arguments other than ``host`` and ``family``
+ are ignored. It may fail to resolve when ``family`` is not
+ ``socket.AF_UNSPEC``.
+
+ Requires Twisted 12.1 or newer.
+
+ .. versionchanged:: 5.0
+ The ``io_loop`` argument (deprecated since version 4.1) has been removed.
+
+ .. deprecated:: 6.2
+ This class is deprecated and will be removed in Tornado 7.0. Use the default
+ thread-based resolver instead.
- .. autoclass:: TwistedResolver
- :members:
diff --git a/docs/utilities.rst b/docs/utilities.rst
index 55536626ed..4c6edf586a 100644
--- a/docs/utilities.rst
+++ b/docs/utilities.rst
@@ -7,6 +7,5 @@ Utilities
concurrent
log
options
- stack_context
testing
util
diff --git a/docs/web.rst b/docs/web.rst
index 77f2fff4c6..720d75678e 100644
--- a/docs/web.rst
+++ b/docs/web.rst
@@ -9,7 +9,7 @@
Request handlers
----------------
- .. autoclass:: RequestHandler
+ .. autoclass:: RequestHandler(...)
Entry points
^^^^^^^^^^^^
@@ -21,14 +21,14 @@
.. _verbs:
Implement any of the following methods (collectively known as the
- HTTP verb methods) to handle the corresponding HTTP method.
- These methods can be made asynchronous with one of the following
- decorators: `.gen.coroutine`, `.return_future`, or `asynchronous`.
+ HTTP verb methods) to handle the corresponding HTTP method. These
+ methods can be made asynchronous with the ``async def`` keyword or
+ `.gen.coroutine` decorator.
The arguments to these methods come from the `.URLSpec`: Any
capturing groups in the regular expression become arguments to the
HTTP verb methods (keyword arguments if the group is named,
- positional arguments if its unnamed).
+ positional arguments if it's unnamed).
To support a method not on this list, override the class variable
``SUPPORTED_METHODS``::
@@ -50,11 +50,24 @@
Input
^^^^^
- .. automethod:: RequestHandler.get_argument
+ The ``argument`` methods provide support for HTML form-style
+ arguments. These methods are available in both singular and plural
+ forms because HTML forms are ambiguous and do not distinguish
+ between a singular argument and a list containing one entry. If you
+ wish to use other formats for arguments (for example, JSON), parse
+ ``self.request.body`` yourself::
+
+ def prepare(self):
+ if self.request.headers['Content-Type'] == 'application/x-json':
+ self.args = json_decode(self.request.body)
+ # Access self.args directly instead of using self.get_argument.
+
+
+ .. automethod:: RequestHandler.get_argument(name: str, default: Union[None, str, RAISE] = RAISE, strip: bool = True) -> Optional[str]
.. automethod:: RequestHandler.get_arguments
- .. automethod:: RequestHandler.get_query_argument
+ .. automethod:: RequestHandler.get_query_argument(name: str, default: Union[None, str, RAISE] = RAISE, strip: bool = True) -> Optional[str]
.. automethod:: RequestHandler.get_query_arguments
- .. automethod:: RequestHandler.get_body_argument
+ .. automethod:: RequestHandler.get_body_argument(name: str, default: Union[None, str, RAISE] = RAISE, strip: bool = True) -> Optional[str]
.. automethod:: RequestHandler.get_body_arguments
.. automethod:: RequestHandler.decode_argument
.. attribute:: RequestHandler.request
@@ -125,6 +138,7 @@
.. automethod:: RequestHandler.compute_etag
.. automethod:: RequestHandler.create_template_loader
.. autoattribute:: RequestHandler.current_user
+ .. automethod:: RequestHandler.detach
.. automethod:: RequestHandler.get_browser_locale
.. automethod:: RequestHandler.get_current_user
.. automethod:: RequestHandler.get_login_url
@@ -145,9 +159,9 @@
Application configuration
- -----------------------------
- .. autoclass:: Application
- :members:
+ -------------------------
+
+ .. autoclass:: Application(handlers: Optional[List[Union[Rule, Tuple]]] = None, default_host: Optional[str] = None, transforms: Optional[List[Type[OutputTransform]]] = None, **settings)
.. attribute:: settings
@@ -184,7 +198,7 @@
`RequestHandler` object). The default implementation
writes to the `logging` module's root logger. May also be
customized by overriding `Application.log_request`.
- * ``serve_traceback``: If true, the default error page
+ * ``serve_traceback``: If ``True``, the default error page
will include the traceback of the error. This option is new in
Tornado 3.2; previously this functionality was controlled by
the ``debug`` setting.
@@ -211,7 +225,7 @@
* ``login_url``: The `authenticated` decorator will redirect
to this url if the user is not logged in. Can be further
customized by overriding `RequestHandler.get_login_url`
- * ``xsrf_cookies``: If true, :ref:`xsrf` will be enabled.
+ * ``xsrf_cookies``: If ``True``, :ref:`xsrf` will be enabled.
* ``xsrf_cookie_version``: Controls the version of new XSRF
cookies produced by this server. Should generally be left
at the default (which will always be the highest supported
@@ -265,13 +279,18 @@
should be a dictionary of keyword arguments to be passed to the
handler's ``initialize`` method.
+ .. automethod:: Application.listen
+ .. automethod:: Application.add_handlers(handlers: List[Union[Rule, Tuple]])
+ .. automethod:: Application.get_handler_delegate
+ .. automethod:: Application.reverse_url
+ .. automethod:: Application.log_request
+
.. autoclass:: URLSpec
The ``URLSpec`` class is also available under the name ``tornado.web.url``.
Decorators
----------
- .. autofunction:: asynchronous
.. autofunction:: authenticated
.. autofunction:: addslash
.. autofunction:: removeslash
diff --git a/docs/websocket.rst b/docs/websocket.rst
index 96255589be..76bc05227e 100644
--- a/docs/websocket.rst
+++ b/docs/websocket.rst
@@ -16,6 +16,7 @@
.. automethod:: WebSocketHandler.on_message
.. automethod:: WebSocketHandler.on_close
.. automethod:: WebSocketHandler.select_subprotocol
+ .. autoattribute:: WebSocketHandler.selected_subprotocol
.. automethod:: WebSocketHandler.on_ping
Output
diff --git a/docs/wsgi.rst b/docs/wsgi.rst
index a54b7aaae7..75d544aad0 100644
--- a/docs/wsgi.rst
+++ b/docs/wsgi.rst
@@ -3,17 +3,5 @@
.. automodule:: tornado.wsgi
- Running Tornado apps on WSGI servers
- ------------------------------------
-
- .. autoclass:: WSGIAdapter
- :members:
-
- .. autoclass:: WSGIApplication
- :members:
-
- Running WSGI apps on Tornado servers
- ------------------------------------
-
.. autoclass:: WSGIContainer
:members:
diff --git a/maint/README b/maint/README
index 9a9122b3b0..2ea722be39 100644
--- a/maint/README
+++ b/maint/README
@@ -1,3 +1,3 @@
This directory contains tools and scripts that are used in the development
-and maintainance of Tornado itself, but are probably not of interest to
+and maintenance of Tornado itself, but are probably not of interest to
Tornado users.
diff --git a/demos/benchmark/benchmark.py b/maint/benchmark/benchmark.py
similarity index 100%
rename from demos/benchmark/benchmark.py
rename to maint/benchmark/benchmark.py
diff --git a/demos/benchmark/chunk_benchmark.py b/maint/benchmark/chunk_benchmark.py
similarity index 100%
rename from demos/benchmark/chunk_benchmark.py
rename to maint/benchmark/chunk_benchmark.py
diff --git a/demos/benchmark/gen_benchmark.py b/maint/benchmark/gen_benchmark.py
similarity index 100%
rename from demos/benchmark/gen_benchmark.py
rename to maint/benchmark/gen_benchmark.py
diff --git a/maint/benchmark/parsing_benchmark.py b/maint/benchmark/parsing_benchmark.py
new file mode 100644
index 0000000000..d0bfcc8950
--- /dev/null
+++ b/maint/benchmark/parsing_benchmark.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+import re
+import timeit
+from enum import Enum
+from typing import Callable
+
+from tornado.httputil import HTTPHeaders
+from tornado.options import define, options, parse_command_line
+
+
+define("benchmark", type=str)
+define("num_runs", type=int, default=1)
+
+
+_CRLF_RE = re.compile(r"\r?\n")
+_TEST_HEADERS = (
+ "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,"
+ "image/apng,*/*;q=0.8,application/signed-exchange;v=b3\r\n"
+ "Accept-Encoding: gzip, deflate, br\r\n"
+ "Accept-Language: ru-RU,ru;q=0.9,en-US;q=0.8,en;q=0.7\r\n"
+ "Cache-Control: max-age=0\r\n"
+ "Connection: keep-alive\r\n"
+ "Host: example.com\r\n"
+ "Upgrade-Insecure-Requests: 1\r\n"
+ "User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
+ "(KHTML, like Gecko) Chrome/73.0.3683.103 Safari/537.36\r\n"
+)
+
+
+def headers_split_re(headers: str) -> None:
+ for line in _CRLF_RE.split(headers):
+ pass
+
+
+def headers_split_simple(headers: str) -> None:
+ for line in headers.split("\n"):
+ if line.endswith("\r"):
+ line = line[:-1]
+
+
+def headers_parse_re(headers: str) -> HTTPHeaders:
+ h = HTTPHeaders()
+ for line in _CRLF_RE.split(headers):
+ if line:
+ h.parse_line(line)
+ return h
+
+
+def headers_parse_simple(headers: str) -> HTTPHeaders:
+ h = HTTPHeaders()
+ for line in headers.split("\n"):
+ if line.endswith("\r"):
+ line = line[:-1]
+ if line:
+ h.parse_line(line)
+ return h
+
+
+def run_headers_split():
+ regex_time = timeit.timeit(lambda: headers_split_re(_TEST_HEADERS), number=100000)
+ print("regex", regex_time)
+
+ simple_time = timeit.timeit(
+ lambda: headers_split_simple(_TEST_HEADERS), number=100000
+ )
+ print("str.split", simple_time)
+
+ print("speedup", regex_time / simple_time)
+
+
+def run_headers_full():
+ regex_time = timeit.timeit(lambda: headers_parse_re(_TEST_HEADERS), number=10000)
+ print("regex", regex_time)
+
+ simple_time = timeit.timeit(
+ lambda: headers_parse_simple(_TEST_HEADERS), number=10000
+ )
+ print("str.split", simple_time)
+
+ print("speedup", regex_time / simple_time)
+
+
+class Benchmark(Enum):
+ def __new__(cls, arg_value: str, func: Callable[[], None]):
+ member = object.__new__(cls)
+ member._value_ = arg_value
+ member.func = func
+ return member
+
+ HEADERS_SPLIT = ("headers-split", run_headers_split)
+ HEADERS_FULL = ("headers-full", run_headers_full)
+
+
+def main():
+ parse_command_line()
+
+ try:
+ func = Benchmark(options.benchmark).func
+ except ValueError:
+ known_benchmarks = [benchmark.value for benchmark in Benchmark]
+ print(
+ "Unknown benchmark: '{}', supported values are: {}"
+ .format(options.benchmark, ", ".join(known_benchmarks))
+ )
+ return
+
+ for _ in range(options.num_runs):
+ func()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/demos/benchmark/template_benchmark.py b/maint/benchmark/template_benchmark.py
similarity index 100%
rename from demos/benchmark/template_benchmark.py
rename to maint/benchmark/template_benchmark.py
diff --git a/maint/circlerefs/circlerefs.py b/maint/circlerefs/circlerefs.py
index 5cc4e1f6de..bd8214aa82 100755
--- a/maint/circlerefs/circlerefs.py
+++ b/maint/circlerefs/circlerefs.py
@@ -7,7 +7,6 @@
increases memory footprint and CPU overhead, so we try to eliminate
circular references created by normal operation.
"""
-from __future__ import print_function
import gc
import traceback
diff --git a/maint/requirements.in b/maint/requirements.in
index eeb2f4d6a3..3f704d00cc 100644
--- a/maint/requirements.in
+++ b/maint/requirements.in
@@ -1,27 +1,13 @@
# Requirements for tools used in the development of tornado.
-# This list is for python 3.5; for 2.7 add:
-# - futures
-# - mock
#
-# Use virtualenv instead of venv; tox seems to get confused otherwise.
+# This mainly contains tools that should be installed for editor integration.
+# Other tools we use are installed only via tox or CI scripts.
#
-# maint/requirements.txt contains the pinned versions of all direct and
-# indirect dependencies; this file only contains direct dependencies
-# and is useful for upgrading.
+# This is a manual recreation of the lockfile pattern: maint/requirements.txt
+# is the lockfile, and maint/requirements.in is the input file containing only
+# direct dependencies.
-# Tornado's optional dependencies
-Twisted
-pycares
-pycurl
-
-# Other useful tools
-Sphinx
-autopep8
-coverage
+black
flake8
-pep8
-pyflakes
-sphinx-rtd-theme
+mypy
tox
-twine
-virtualenv
diff --git a/maint/requirements.txt b/maint/requirements.txt
index eaa02ea57b..0ef29bc3bc 100644
--- a/maint/requirements.txt
+++ b/maint/requirements.txt
@@ -1,44 +1,34 @@
-alabaster==0.7.10
-attrs==17.4.0
-Automat==0.6.0
-autopep8==1.3.4
-Babel==2.5.3
-certifi==2018.1.18
-chardet==3.0.4
-constantly==15.1.0
-coverage==4.5.1
-docutils==0.14
-flake8==3.5.0
-hyperlink==18.0.0
-idna==2.6
-imagesize==1.0.0
-incremental==17.5.0
-Jinja2==2.10
-MarkupSafe==1.0
+# Requirements for tools used in the development of tornado.
+#
+# This mainly contains tools that should be installed for editor integration.
+# Other tools we use are installed only via tox or CI scripts.
+# This is a manual recreation of the lockfile pattern: maint/requirements.txt
+# is the lockfile, and maint/requirements.in is the input file containing only
+# direct dependencies.
+
+black==20.8b1
+flake8==3.8.4
+mypy==0.790
+tox==3.20.1
+## The following requirements were added by pip freeze:
+appdirs==1.4.4
+attrs==20.2.0
+click==7.1.2
+colorama==0.4.4
+distlib==0.3.1
+filelock==3.0.12
mccabe==0.6.1
-packaging==17.1
-pep8==1.7.1
-pkginfo==1.4.2
-pluggy==0.6.0
-py==1.5.2
-pycares==2.3.0
-pycodestyle==2.3.1
-pycurl==7.43.0.1
-pyflakes==1.6.0
-Pygments==2.2.0
-pyparsing==2.2.0
-pytz==2018.3
-requests==2.18.4
-requests-toolbelt==0.8.0
-six==1.11.0
-snowballstemmer==1.2.1
-Sphinx==1.7.1
-sphinx-rtd-theme==0.2.4
-sphinxcontrib-websupport==1.0.1
-tox==2.9.1
-tqdm==4.19.8
-twine==1.10.0
-Twisted==17.9.0
-urllib3==1.22
-virtualenv==15.1.0
-zope.interface==4.4.3
+mypy-extensions==0.4.3
+packaging==20.4
+pathspec==0.8.0
+pluggy==0.13.1
+py==1.9.0
+pycodestyle==2.6.0
+pyflakes==2.2.0
+pyparsing==2.4.7
+regex==2020.10.28
+six==1.15.0
+toml==0.10.1
+typed-ast==1.4.1
+typing-extensions==3.7.4.3
+virtualenv==20.1.0
diff --git a/maint/scripts/test_resolvers.py b/maint/scripts/test_resolvers.py
index 2a466c1ac9..82dec30e66 100755
--- a/maint/scripts/test_resolvers.py
+++ b/maint/scripts/test_resolvers.py
@@ -1,6 +1,4 @@
#!/usr/bin/env python
-from __future__ import print_function
-
import pprint
import socket
diff --git a/maint/test/appengine/README b/maint/test/appengine/README
deleted file mode 100644
index 8d534f28d1..0000000000
--- a/maint/test/appengine/README
+++ /dev/null
@@ -1,8 +0,0 @@
-Unit test support for app engine. Currently very limited as most of
-our tests depend on direct network access, but these tests ensure that the
-modules that are supposed to work on app engine don't depend on any
-forbidden modules.
-
-The code lives in maint/appengine/common, but should be run from the py25
-or py27 subdirectories (which contain an app.yaml and a bunch of symlinks).
-runtests.py is the entry point; cgi_runtests.py is used internally.
diff --git a/maint/test/appengine/common/cgi_runtests.py b/maint/test/appengine/common/cgi_runtests.py
deleted file mode 100755
index f28aa6ca01..0000000000
--- a/maint/test/appengine/common/cgi_runtests.py
+++ /dev/null
@@ -1,59 +0,0 @@
-#!/usr/bin/env python
-from __future__ import absolute_import, division, print_function
-
-import sys
-import unittest
-
-# Most of our tests depend on IOLoop, which is not usable on app engine.
-# Run the tests that work, and check that everything else is at least
-# importable (via tornado.test.import_test)
-TEST_MODULES = [
- 'tornado.httputil.doctests',
- 'tornado.iostream.doctests',
- 'tornado.util.doctests',
- #'tornado.test.auth_test',
- #'tornado.test.concurrent_test',
- #'tornado.test.curl_httpclient_test',
- 'tornado.test.escape_test',
- #'tornado.test.gen_test',
- #'tornado.test.httpclient_test',
- #'tornado.test.httpserver_test',
- 'tornado.test.httputil_test',
- 'tornado.test.import_test',
- #'tornado.test.ioloop_test',
- #'tornado.test.iostream_test',
- 'tornado.test.locale_test',
- #'tornado.test.netutil_test',
- #'tornado.test.log_test',
- 'tornado.test.options_test',
- #'tornado.test.process_test',
- #'tornado.test.simple_httpclient_test',
- #'tornado.test.stack_context_test',
- 'tornado.test.template_test',
- #'tornado.test.testing_test',
- #'tornado.test.twisted_test',
- 'tornado.test.util_test',
- #'tornado.test.web_test',
- #'tornado.test.websocket_test',
- #'tornado.test.wsgi_test',
-]
-
-
-def all():
- return unittest.defaultTestLoader.loadTestsFromNames(TEST_MODULES)
-
-
-def main():
- print("Content-Type: text/plain\r\n\r\n", end="")
-
- try:
- unittest.main(defaultTest='all', argv=sys.argv[:1])
- except SystemExit as e:
- if e.code == 0:
- print("PASS")
- else:
- raise
-
-
-if __name__ == '__main__':
- main()
diff --git a/maint/test/appengine/common/runtests.py b/maint/test/appengine/common/runtests.py
deleted file mode 100755
index ca7abe119f..0000000000
--- a/maint/test/appengine/common/runtests.py
+++ /dev/null
@@ -1,58 +0,0 @@
-#!/usr/bin/env python
-from __future__ import absolute_import, division, print_function
-
-import contextlib
-import errno
-import os
-import random
-import signal
-import socket
-import subprocess
-import sys
-import time
-import urllib2
-
-try:
- xrange
-except NameError:
- xrange = range
-
-if __name__ == "__main__":
- tornado_root = os.path.abspath(os.path.join(os.path.dirname(__file__),
- '../../..'))
- # dev_appserver doesn't seem to set SO_REUSEADDR
- port = random.randrange(10000, 11000)
- # does dev_appserver.py ever live anywhere but /usr/local/bin?
- proc = subprocess.Popen([sys.executable,
- "/usr/local/bin/dev_appserver.py",
- os.path.dirname(os.path.abspath(__file__)),
- "--port=%d" % port,
- "--skip_sdk_update_check",
- ],
- cwd=tornado_root)
-
- try:
- for i in xrange(50):
- with contextlib.closing(socket.socket()) as sock:
- err = sock.connect_ex(('localhost', port))
- if err == 0:
- break
- elif err != errno.ECONNREFUSED:
- raise Exception("Got unexpected socket error %d" % err)
- time.sleep(0.1)
- else:
- raise Exception("Server didn't start listening")
-
- resp = urllib2.urlopen("http://localhost:%d/" % port)
- print(resp.read())
- finally:
- # dev_appserver sometimes ignores SIGTERM (especially on 2.5),
- # so try a few times to kill it.
- for sig in [signal.SIGTERM, signal.SIGTERM, signal.SIGKILL]:
- os.kill(proc.pid, sig)
- res = os.waitpid(proc.pid, os.WNOHANG)
- if res != (0, 0):
- break
- time.sleep(0.1)
- else:
- os.waitpid(proc.pid, 0)
diff --git a/maint/test/appengine/py27/app.yaml b/maint/test/appengine/py27/app.yaml
deleted file mode 100644
index e5dea072da..0000000000
--- a/maint/test/appengine/py27/app.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-application: tornado-tests-appengine27
-version: 1
-runtime: python27
-threadsafe: false
-api_version: 1
-
-handlers:
-- url: /
- script: cgi_runtests.py
\ No newline at end of file
diff --git a/maint/test/appengine/py27/cgi_runtests.py b/maint/test/appengine/py27/cgi_runtests.py
deleted file mode 120000
index a9fc90e99c..0000000000
--- a/maint/test/appengine/py27/cgi_runtests.py
+++ /dev/null
@@ -1 +0,0 @@
-../common/cgi_runtests.py
\ No newline at end of file
diff --git a/maint/test/appengine/py27/runtests.py b/maint/test/appengine/py27/runtests.py
deleted file mode 120000
index 2cce26b0fb..0000000000
--- a/maint/test/appengine/py27/runtests.py
+++ /dev/null
@@ -1 +0,0 @@
-../common/runtests.py
\ No newline at end of file
diff --git a/maint/test/appengine/py27/tornado b/maint/test/appengine/py27/tornado
deleted file mode 120000
index d4f6cc317d..0000000000
--- a/maint/test/appengine/py27/tornado
+++ /dev/null
@@ -1 +0,0 @@
-../../../../tornado
\ No newline at end of file
diff --git a/maint/test/appengine/setup.py b/maint/test/appengine/setup.py
deleted file mode 100644
index 5d2d3141d2..0000000000
--- a/maint/test/appengine/setup.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Dummy setup file to make tox happy. In the appengine world things aren't
-# installed through setup.py
-import distutils.core
-distutils.core.setup()
diff --git a/maint/test/appengine/tox.ini b/maint/test/appengine/tox.ini
deleted file mode 100644
index ca7a861aee..0000000000
--- a/maint/test/appengine/tox.ini
+++ /dev/null
@@ -1,15 +0,0 @@
-# App Engine tests require the SDK to be installed separately.
-# Version 1.6.1 or newer is required (older versions don't work when
-# python is run from a virtualenv)
-#
-# These are currently excluded from the main tox.ini because their
-# logs are spammy and they're a little flaky.
-[tox]
-envlist = py27-appengine
-
-[testenv]
-changedir = {toxworkdir}
-
-[testenv:py27-appengine]
-basepython = python2.7
-commands = python {toxinidir}/py27/runtests.py {posargs:}
diff --git a/maint/test/cython/tox.ini b/maint/test/cython/tox.ini
index bbf8f15748..c79ab7db5e 100644
--- a/maint/test/cython/tox.ini
+++ b/maint/test/cython/tox.ini
@@ -1,6 +1,6 @@
[tox]
# This currently segfaults on pypy.
-envlist = py27,py35,py36
+envlist = py27,py36
[testenv]
deps =
@@ -13,5 +13,4 @@ commands = python -m unittest cythonapp_test
# defaults for the others.
basepython =
py27: python2.7
- py35: python3.5
py36: python3.6
diff --git a/maint/test/mypy/.gitignore b/maint/test/mypy/.gitignore
new file mode 100644
index 0000000000..dc3112749e
--- /dev/null
+++ b/maint/test/mypy/.gitignore
@@ -0,0 +1 @@
+UNKNOWN.egg-info
diff --git a/maint/test/mypy/bad.py b/maint/test/mypy/bad.py
new file mode 100644
index 0000000000..3e6b6342e2
--- /dev/null
+++ b/maint/test/mypy/bad.py
@@ -0,0 +1,6 @@
+from tornado.web import RequestHandler
+
+
+class MyHandler(RequestHandler):
+ def get(self) -> str: # Deliberate type error
+ return "foo"
diff --git a/maint/test/mypy/good.py b/maint/test/mypy/good.py
new file mode 100644
index 0000000000..5ee2d3ddcb
--- /dev/null
+++ b/maint/test/mypy/good.py
@@ -0,0 +1,11 @@
+from tornado import gen
+from tornado.web import RequestHandler
+
+
+class MyHandler(RequestHandler):
+ def get(self) -> None:
+ self.write("foo")
+
+ async def post(self) -> None:
+ await gen.sleep(1)
+ self.write("foo")
diff --git a/maint/test/mypy/setup.py b/maint/test/mypy/setup.py
new file mode 100644
index 0000000000..606849326a
--- /dev/null
+++ b/maint/test/mypy/setup.py
@@ -0,0 +1,3 @@
+from setuptools import setup
+
+setup()
diff --git a/maint/test/mypy/tox.ini b/maint/test/mypy/tox.ini
new file mode 100644
index 0000000000..42235252d6
--- /dev/null
+++ b/maint/test/mypy/tox.ini
@@ -0,0 +1,14 @@
+# Test that the py.typed marker file is respected and client
+# application code can be typechecked using tornado's published
+# annotations.
+[tox]
+envlist = py37
+
+[testenv]
+deps =
+ ../../..
+ mypy
+whitelist_externals = /bin/sh
+commands =
+ mypy good.py
+ /bin/sh -c '! mypy bad.py'
diff --git a/maint/test/pyuv/tox.ini b/maint/test/pyuv/tox.ini
deleted file mode 100644
index 8b6d569a7a..0000000000
--- a/maint/test/pyuv/tox.ini
+++ /dev/null
@@ -1,13 +0,0 @@
-[tox]
-envlist = py27
-setupdir = ../../..
-
-[testenv]
-commands =
- python -m tornado.test.runtests --ioloop=tornaduv.UVLoop {posargs:}
-# twisted tests don't work on pyuv IOLoop currently.
-deps =
- pyuv
- tornaduv
- futures
- mock
diff --git a/maint/test/redbot/red_test.py b/maint/test/redbot/red_test.py
index 055b479565..ac4b5ad25b 100755
--- a/maint/test/redbot/red_test.py
+++ b/maint/test/redbot/red_test.py
@@ -248,7 +248,7 @@ def get_app(self):
return Application(self.get_handlers(), gzip=True, **self.get_app_kwargs())
def get_allowed_errors(self):
- return super(GzipHTTPTest, self).get_allowed_errors() + [
+ return super().get_allowed_errors() + [
# TODO: The Etag is supposed to change when Content-Encoding is
# used. This should be fixed, but it's difficult to do with the
# way GZipContentEncoding fits into the pipeline, and in practice
diff --git a/maint/test/websocket/fuzzingclient.json b/maint/test/websocket/fuzzingclient.json
index 7b4cb318cf..2ac091f37a 100644
--- a/maint/test/websocket/fuzzingclient.json
+++ b/maint/test/websocket/fuzzingclient.json
@@ -1,17 +1,42 @@
{
- "options": {"failByDrop": false},
- "outdir": "./reports/servers",
-
- "servers": [
- {"agent": "Tornado/py27", "url": "ws://localhost:9001",
- "options": {"version": 18}},
- {"agent": "Tornado/py35", "url": "ws://localhost:9002",
- "options": {"version": 18}},
- {"agent": "Tornado/pypy", "url": "ws://localhost:9003",
- "options": {"version": 18}}
- ],
-
- "cases": ["*"],
- "exclude-cases": ["9.*", "12.*.1","12.2.*", "12.3.*", "12.4.*", "12.5.*", "13.*.1"],
- "exclude-agent-cases": {}
-}
+ "options": {
+ "failByDrop": false
+ },
+ "outdir": "./reports/servers",
+ "servers": [
+ {
+ "agent": "Tornado/py27",
+ "url": "ws://localhost:9001",
+ "options": {
+ "version": 18
+ }
+ },
+ {
+ "agent": "Tornado/py39",
+ "url": "ws://localhost:9002",
+ "options": {
+ "version": 18
+ }
+ },
+ {
+ "agent": "Tornado/pypy",
+ "url": "ws://localhost:9003",
+ "options": {
+ "version": 18
+ }
+ }
+ ],
+ "cases": [
+ "*"
+ ],
+ "exclude-cases": [
+ "9.*",
+ "12.*.1",
+ "12.2.*",
+ "12.3.*",
+ "12.4.*",
+ "12.5.*",
+ "13.*.1"
+ ],
+ "exclude-agent-cases": {}
+}
\ No newline at end of file
diff --git a/maint/test/websocket/run-client.sh b/maint/test/websocket/run-client.sh
index bd35f4dca8..f32e72aff9 100755
--- a/maint/test/websocket/run-client.sh
+++ b/maint/test/websocket/run-client.sh
@@ -10,7 +10,7 @@ FUZZING_SERVER_PID=$!
sleep 1
.tox/py27/bin/python client.py --name='Tornado/py27'
-.tox/py35/bin/python client.py --name='Tornado/py35'
+.tox/py39/bin/python client.py --name='Tornado/py39'
.tox/pypy/bin/python client.py --name='Tornado/pypy'
kill $FUZZING_SERVER_PID
diff --git a/maint/test/websocket/run-server.sh b/maint/test/websocket/run-server.sh
index 2a83871366..401795a005 100755
--- a/maint/test/websocket/run-server.sh
+++ b/maint/test/websocket/run-server.sh
@@ -15,8 +15,8 @@ tox
.tox/py27/bin/python server.py --port=9001 &
PY27_SERVER_PID=$!
-.tox/py35/bin/python server.py --port=9002 &
-PY35_SERVER_PID=$!
+.tox/py39/bin/python server.py --port=9002 &
+PY39_SERVER_PID=$!
.tox/pypy/bin/python server.py --port=9003 &
PYPY_SERVER_PID=$!
@@ -26,7 +26,7 @@ sleep 1
.tox/py27/bin/wstest -m fuzzingclient
kill $PY27_SERVER_PID
-kill $PY35_SERVER_PID
+kill $PY39_SERVER_PID
kill $PYPY_SERVER_PID
wait
diff --git a/maint/test/websocket/tox.ini b/maint/test/websocket/tox.ini
index 289d127b10..7c4b72ebc6 100644
--- a/maint/test/websocket/tox.ini
+++ b/maint/test/websocket/tox.ini
@@ -2,7 +2,7 @@
# to install autobahn and build the speedups module.
# See run.sh for the real test runner.
[tox]
-envlist = py27, py35, pypy
+envlist = py27, py39, pypy
setupdir=../../..
[testenv]
diff --git a/maint/vm/windows/bootstrap.py b/maint/vm/windows/bootstrap.py
index 281bf573cf..9bfb5c7230 100755
--- a/maint/vm/windows/bootstrap.py
+++ b/maint/vm/windows/bootstrap.py
@@ -21,7 +21,6 @@
To run under cygwin (which must be installed separately), run
cd /cygdrive/e; python -m tornado.test.runtests
"""
-from __future__ import absolute_import, division, print_function
import os
import subprocess
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000..a365d2c33c
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,15 @@
+[metadata]
+license_file = LICENSE
+
+[mypy]
+python_version = 3.6
+no_implicit_optional = True
+
+[mypy-tornado.*,tornado.platform.*]
+disallow_untyped_defs = True
+
+# It's generally too tedious to require type annotations in tests, but
+# we do want to type check them as much as type inference allows.
+[mypy-tornado.test.*]
+disallow_untyped_defs = False
+check_untyped_defs = True
diff --git a/setup.py b/setup.py
index be3e3d1b08..7c9b35c5af 100644
--- a/setup.py
+++ b/setup.py
@@ -13,9 +13,10 @@
# License for the specific language governing permissions and limitations
# under the License.
+# type: ignore
+
import os
import platform
-import ssl
import sys
import warnings
@@ -80,11 +81,16 @@ def run(self):
build_ext.run(self)
except Exception:
e = sys.exc_info()[1]
- sys.stdout.write('%s\n' % str(e))
- warnings.warn(self.warning_message % ("Extension modules",
- "There was an issue with "
- "your platform configuration"
- " - see above."))
+ sys.stdout.write("%s\n" % str(e))
+ warnings.warn(
+ self.warning_message
+ % (
+ "Extension modules",
+ "There was an issue with "
+ "your platform configuration"
+ " - see above.",
+ )
+ )
def build_extension(self, ext):
name = ext.name
@@ -92,60 +98,48 @@ def build_extension(self, ext):
build_ext.build_extension(self, ext)
except Exception:
e = sys.exc_info()[1]
- sys.stdout.write('%s\n' % str(e))
- warnings.warn(self.warning_message % ("The %s extension "
- "module" % (name,),
- "The output above "
- "this warning shows how "
- "the compilation "
- "failed."))
+ sys.stdout.write("%s\n" % str(e))
+ warnings.warn(
+ self.warning_message
+ % (
+ "The %s extension " "module" % (name,),
+ "The output above "
+ "this warning shows how "
+ "the compilation "
+ "failed.",
+ )
+ )
kwargs = {}
-version = "5.1.dev1"
+with open("tornado/__init__.py") as f:
+ ns = {}
+ exec(f.read(), ns)
+ version = ns["version"]
-with open('README.rst') as f:
- kwargs['long_description'] = f.read()
+with open("README.rst") as f:
+ kwargs["long_description"] = f.read()
-if (platform.python_implementation() == 'CPython' and
- os.environ.get('TORNADO_EXTENSION') != '0'):
+if (
+ platform.python_implementation() == "CPython"
+ and os.environ.get("TORNADO_EXTENSION") != "0"
+):
# This extension builds and works on pypy as well, although pypy's jit
# produces equivalent performance.
- kwargs['ext_modules'] = [
- Extension('tornado.speedups',
- sources=['tornado/speedups.c']),
+ kwargs["ext_modules"] = [
+ Extension("tornado.speedups", sources=["tornado/speedups.c"])
]
- if os.environ.get('TORNADO_EXTENSION') != '1':
+ if os.environ.get("TORNADO_EXTENSION") != "1":
# Unless the user has specified that the extension is mandatory,
# fall back to the pure-python implementation on any build failure.
- kwargs['cmdclass'] = {'build_ext': custom_build_ext}
+ kwargs["cmdclass"] = {"build_ext": custom_build_ext}
if setuptools is not None:
- # If setuptools is not available, you're on your own for dependencies.
- install_requires = []
- if sys.version_info < (3, 2):
- install_requires.append('futures')
- if sys.version_info < (3, 4):
- install_requires.append('singledispatch')
- if sys.version_info < (3, 5):
- install_requires.append('backports_abc>=0.4')
- kwargs['install_requires'] = install_requires
-
- python_requires = '>= 2.7, !=3.0.*, !=3.1.*, !=3.2.*, != 3.3.*'
- kwargs['python_requires'] = python_requires
-
-# Verify that the SSL module has all the modern upgrades. Check for several
-# names individually since they were introduced at different versions,
-# although they should all be present by Python 3.4 or 2.7.9.
-if (not hasattr(ssl, 'SSLContext') or
- not hasattr(ssl, 'create_default_context') or
- not hasattr(ssl, 'match_hostname')):
- raise ImportError("Tornado requires an up-to-date SSL module. This means "
- "Python 2.7.9+ or 3.4+ (although some distributions have "
- "backported the necessary changes to older versions).")
+ python_requires = ">= 3.6"
+ kwargs["python_requires"] = python_requires
setup(
name="tornado",
@@ -155,12 +149,15 @@ def build_extension(self, ext):
# data files need to be listed both here (which determines what gets
# installed) and in MANIFEST.in (which determines what gets included
# in the sdist tarball)
+ "tornado": ["py.typed"],
"tornado.test": [
"README",
"csv_translations/fr_FR.csv",
"gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo",
"gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po",
"options_test.cfg",
+ "options_test_types.cfg",
+ "options_test_types_str.cfg",
"static/robots.txt",
"static/sample.xml",
"static/sample.xml.gz",
@@ -176,18 +173,19 @@ def build_extension(self, ext):
author_email="python-tornado@googlegroups.com",
url="http://www.tornadoweb.org/",
license="http://www.apache.org/licenses/LICENSE-2.0",
- description=("Tornado is a Python web framework and asynchronous networking library,"
- " originally developed at FriendFeed."),
+ description=(
+ "Tornado is a Python web framework and asynchronous networking library,"
+ " originally developed at FriendFeed."
+ ),
classifiers=[
- 'License :: OSI Approved :: Apache Software License',
- 'Programming Language :: Python :: 2',
- 'Programming Language :: Python :: 2.7',
- 'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.4',
- 'Programming Language :: Python :: 3.5',
- 'Programming Language :: Python :: 3.6',
- 'Programming Language :: Python :: Implementation :: CPython',
- 'Programming Language :: Python :: Implementation :: PyPy',
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.6",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: Implementation :: CPython",
+ "Programming Language :: Python :: Implementation :: PyPy",
],
**kwargs
)
diff --git a/tornado/__init__.py b/tornado/__init__.py
index 9a9d2202ef..7c889e2d5a 100644
--- a/tornado/__init__.py
+++ b/tornado/__init__.py
@@ -15,8 +15,6 @@
"""The Tornado web server and tools."""
-from __future__ import absolute_import, division, print_function
-
# version is a human-readable version number.
# version_info is a four-tuple for programmatic comparison. The first
@@ -24,5 +22,5 @@
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
-version = "5.1.dev1"
-version_info = (5, 1, 0, -100)
+version = "6.2.dev1"
+version_info = (6, 2, 0, -100)
diff --git a/tornado/_locale_data.py b/tornado/_locale_data.py
index a2c503907d..c706230ee5 100644
--- a/tornado/_locale_data.py
+++ b/tornado/_locale_data.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@@ -16,8 +14,6 @@
"""Data used by the tornado.locale module."""
-from __future__ import absolute_import, division, print_function
-
LOCALE_NAMES = {
"af_ZA": {"name_en": u"Afrikaans", "name": u"Afrikaans"},
"am_ET": {"name_en": u"Amharic", "name": u"አማርኛ"},
diff --git a/tornado/auth.py b/tornado/auth.py
index 0069efcb48..d1cf29b39d 100644
--- a/tornado/auth.py
+++ b/tornado/auth.py
@@ -37,15 +37,14 @@
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleOAuth2Mixin):
- @tornado.gen.coroutine
- def get(self):
+ async def get(self):
if self.get_argument('code', False):
- user = yield self.get_authenticated_user(
+ user = await self.get_authenticated_user(
redirect_uri='http://your.site.com/auth/google',
code=self.get_argument('code'))
# Save the user with e.g. set_secure_cookie
else:
- yield self.authorize_redirect(
+ self.authorize_redirect(
redirect_uri='http://your.site.com/auth/google',
client_id=self.settings['google_oauth']['key'],
scope=['profile', 'email'],
@@ -55,93 +54,29 @@ def get(self):
.. testoutput::
:hide:
-
-.. versionchanged:: 4.0
- All of the callback interfaces in this module are now guaranteed
- to run their callback with an argument of ``None`` on error.
- Previously some functions would do this while others would simply
- terminate the request on their own. This change also ensures that
- errors are more consistently reported through the ``Future`` interfaces.
"""
-from __future__ import absolute_import, division, print_function
-
import base64
import binascii
-import functools
import hashlib
import hmac
import time
+import urllib.parse
import uuid
-import warnings
-from tornado.concurrent import (Future, return_future, chain_future,
- future_set_exc_info,
- future_set_result_unless_cancelled)
-from tornado import gen
from tornado import httpclient
from tornado import escape
from tornado.httputil import url_concat
-from tornado.log import gen_log
-from tornado.stack_context import ExceptionStackContext
-from tornado.util import unicode_type, ArgReplacer, PY3
+from tornado.util import unicode_type
+from tornado.web import RequestHandler
-if PY3:
- import urllib.parse as urlparse
- import urllib.parse as urllib_parse
- long = int
-else:
- import urlparse
- import urllib as urllib_parse
+from typing import List, Any, Dict, cast, Iterable, Union, Optional
class AuthError(Exception):
pass
-def _auth_future_to_callback(callback, future):
- try:
- result = future.result()
- except AuthError as e:
- gen_log.warning(str(e))
- result = None
- callback(result)
-
-
-def _auth_return_future(f):
- """Similar to tornado.concurrent.return_future, but uses the auth
- module's legacy callback interface.
-
- Note that when using this decorator the ``callback`` parameter
- inside the function will actually be a future.
-
- .. deprecated:: 5.1
- Will be removed in 6.0.
- """
- replacer = ArgReplacer(f, 'callback')
-
- @functools.wraps(f)
- def wrapper(*args, **kwargs):
- future = Future()
- callback, args, kwargs = replacer.replace(future, args, kwargs)
- if callback is not None:
- warnings.warn("callback arguments are deprecated, use the returned Future instead",
- DeprecationWarning)
- future.add_done_callback(
- functools.partial(_auth_future_to_callback, callback))
-
- def handle_exception(typ, value, tb):
- if future.done():
- return False
- else:
- future_set_exc_info(future, (typ, value, tb))
- return True
- with ExceptionStackContext(handle_exception):
- f(*args, **kwargs)
- return future
- return wrapper
-
-
class OpenIdMixin(object):
"""Abstract implementation of OpenID and Attribute Exchange.
@@ -149,10 +84,12 @@ class OpenIdMixin(object):
* ``_OPENID_ENDPOINT``: the identity provider's URI.
"""
- @return_future
- def authenticate_redirect(self, callback_uri=None,
- ax_attrs=["name", "email", "language", "username"],
- callback=None):
+
+ def authenticate_redirect(
+ self,
+ callback_uri: Optional[str] = None,
+ ax_attrs: List[str] = ["name", "email", "language", "username"],
+ ) -> None:
"""Redirects to the authentication URL for this service.
After authentication, the service will redirect back to the given
@@ -163,24 +100,22 @@ def authenticate_redirect(self, callback_uri=None,
all those attributes for your app, you can request fewer with
the ax_attrs keyword argument.
- .. versionchanged:: 3.1
- Returns a `.Future` and takes an optional callback. These are
- not strictly necessary as this method is synchronous,
- but they are supplied for consistency with
- `OAuthMixin.authorize_redirect`.
-
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument and returned awaitable will be removed
- in Tornado 6.0; this will be an ordinary synchronous function.
+ The ``callback`` argument was removed and this method no
+ longer returns an awaitable object. It is now an ordinary
+ synchronous function.
"""
- callback_uri = callback_uri or self.request.uri
+ handler = cast(RequestHandler, self)
+ callback_uri = callback_uri or handler.request.uri
+ assert callback_uri is not None
args = self._openid_args(callback_uri, ax_attrs=ax_attrs)
- self.redirect(self._OPENID_ENDPOINT + "?" + urllib_parse.urlencode(args))
- callback()
+ endpoint = self._OPENID_ENDPOINT # type: ignore
+ handler.redirect(endpoint + "?" + urllib.parse.urlencode(args))
- @_auth_return_future
- def get_authenticated_user(self, callback, http_client=None):
+ async def get_authenticated_user(
+ self, http_client: Optional[httpclient.AsyncHTTPClient] = None
+ ) -> Dict[str, Any]:
"""Fetches the authenticated user data upon redirect.
This method should be called by the handler that receives the
@@ -191,51 +126,60 @@ def get_authenticated_user(self, callback, http_client=None):
The result of this method will generally be used to set a cookie.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed in 6.0.
- Use the returned awaitable object instead.
+ The ``callback`` argument was removed. Use the returned
+ awaitable object instead.
"""
+ handler = cast(RequestHandler, self)
# Verify the OpenID response via direct request to the OP
- args = dict((k, v[-1]) for k, v in self.request.arguments.items())
+ args = dict(
+ (k, v[-1]) for k, v in handler.request.arguments.items()
+ ) # type: Dict[str, Union[str, bytes]]
args["openid.mode"] = u"check_authentication"
- url = self._OPENID_ENDPOINT
+ url = self._OPENID_ENDPOINT # type: ignore
if http_client is None:
http_client = self.get_auth_http_client()
- fut = http_client.fetch(url, method="POST", body=urllib_parse.urlencode(args))
- fut.add_done_callback(functools.partial(
- self._on_authentication_verified, callback))
-
- def _openid_args(self, callback_uri, ax_attrs=[], oauth_scope=None):
- url = urlparse.urljoin(self.request.full_url(), callback_uri)
+ resp = await http_client.fetch(
+ url, method="POST", body=urllib.parse.urlencode(args)
+ )
+ return self._on_authentication_verified(resp)
+
+ def _openid_args(
+ self,
+ callback_uri: str,
+ ax_attrs: Iterable[str] = [],
+ oauth_scope: Optional[str] = None,
+ ) -> Dict[str, str]:
+ handler = cast(RequestHandler, self)
+ url = urllib.parse.urljoin(handler.request.full_url(), callback_uri)
args = {
"openid.ns": "http://specs.openid.net/auth/2.0",
- "openid.claimed_id":
- "http://specs.openid.net/auth/2.0/identifier_select",
- "openid.identity":
- "http://specs.openid.net/auth/2.0/identifier_select",
+ "openid.claimed_id": "http://specs.openid.net/auth/2.0/identifier_select",
+ "openid.identity": "http://specs.openid.net/auth/2.0/identifier_select",
"openid.return_to": url,
- "openid.realm": urlparse.urljoin(url, '/'),
+ "openid.realm": urllib.parse.urljoin(url, "/"),
"openid.mode": "checkid_setup",
}
if ax_attrs:
- args.update({
- "openid.ns.ax": "http://openid.net/srv/ax/1.0",
- "openid.ax.mode": "fetch_request",
- })
+ args.update(
+ {
+ "openid.ns.ax": "http://openid.net/srv/ax/1.0",
+ "openid.ax.mode": "fetch_request",
+ }
+ )
ax_attrs = set(ax_attrs)
- required = []
+ required = [] # type: List[str]
if "name" in ax_attrs:
ax_attrs -= set(["name", "firstname", "fullname", "lastname"])
required += ["firstname", "fullname", "lastname"]
- args.update({
- "openid.ax.type.firstname":
- "http://axschema.org/namePerson/first",
- "openid.ax.type.fullname":
- "http://axschema.org/namePerson",
- "openid.ax.type.lastname":
- "http://axschema.org/namePerson/last",
- })
+ args.update(
+ {
+ "openid.ax.type.firstname": "http://axschema.org/namePerson/first",
+ "openid.ax.type.fullname": "http://axschema.org/namePerson",
+ "openid.ax.type.lastname": "http://axschema.org/namePerson/last",
+ }
+ )
known_attrs = {
"email": "http://axschema.org/contact/email",
"language": "http://axschema.org/pref/language",
@@ -246,47 +190,45 @@ def _openid_args(self, callback_uri, ax_attrs=[], oauth_scope=None):
required.append(name)
args["openid.ax.required"] = ",".join(required)
if oauth_scope:
- args.update({
- "openid.ns.oauth":
- "http://specs.openid.net/extensions/oauth/1.0",
- "openid.oauth.consumer": self.request.host.split(":")[0],
- "openid.oauth.scope": oauth_scope,
- })
+ args.update(
+ {
+ "openid.ns.oauth": "http://specs.openid.net/extensions/oauth/1.0",
+ "openid.oauth.consumer": handler.request.host.split(":")[0],
+ "openid.oauth.scope": oauth_scope,
+ }
+ )
return args
- def _on_authentication_verified(self, future, response_fut):
- try:
- response = response_fut.result()
- except Exception as e:
- future.set_exception(AuthError(
- "Error response %s" % e))
- return
+ def _on_authentication_verified(
+ self, response: httpclient.HTTPResponse
+ ) -> Dict[str, Any]:
+ handler = cast(RequestHandler, self)
if b"is_valid:true" not in response.body:
- future.set_exception(AuthError(
- "Invalid OpenID response: %s" % response.body))
- return
+ raise AuthError("Invalid OpenID response: %r" % response.body)
# Make sure we got back at least an email from attribute exchange
ax_ns = None
- for name in self.request.arguments:
- if name.startswith("openid.ns.") and \
- self.get_argument(name) == u"http://openid.net/srv/ax/1.0":
- ax_ns = name[10:]
+ for key in handler.request.arguments:
+ if (
+ key.startswith("openid.ns.")
+ and handler.get_argument(key) == u"http://openid.net/srv/ax/1.0"
+ ):
+ ax_ns = key[10:]
break
- def get_ax_arg(uri):
+ def get_ax_arg(uri: str) -> str:
if not ax_ns:
return u""
prefix = "openid." + ax_ns + ".type."
ax_name = None
- for name in self.request.arguments.keys():
- if self.get_argument(name) == uri and name.startswith(prefix):
- part = name[len(prefix):]
+ for name in handler.request.arguments.keys():
+ if handler.get_argument(name) == uri and name.startswith(prefix):
+ part = name[len(prefix) :]
ax_name = "openid." + ax_ns + ".value." + part
break
if not ax_name:
return u""
- return self.get_argument(ax_name, u"")
+ return handler.get_argument(ax_name, u"")
email = get_ax_arg("http://axschema.org/contact/email")
name = get_ax_arg("http://axschema.org/namePerson")
@@ -314,12 +256,12 @@ def get_ax_arg(uri):
user["locale"] = locale
if username:
user["username"] = username
- claimed_id = self.get_argument("openid.claimed_id", None)
+ claimed_id = handler.get_argument("openid.claimed_id", None)
if claimed_id:
user["claimed_id"] = claimed_id
- future_set_result_unless_cancelled(future, user)
+ return user
- def get_auth_http_client(self):
+ def get_auth_http_client(self) -> httpclient.AsyncHTTPClient:
"""Returns the `.AsyncHTTPClient` instance to be used for auth requests.
May be overridden by subclasses to use an HTTP client other than
@@ -344,9 +286,13 @@ class OAuthMixin(object):
Subclasses must also override the `_oauth_get_user_future` and
`_oauth_consumer_token` methods.
"""
- @return_future
- def authorize_redirect(self, callback_uri=None, extra_params=None,
- http_client=None, callback=None):
+
+ async def authorize_redirect(
+ self,
+ callback_uri: Optional[str] = None,
+ extra_params: Optional[Dict[str, Any]] = None,
+ http_client: Optional[httpclient.AsyncHTTPClient] = None,
+ ) -> None:
"""Redirects the user to obtain OAuth authorization for this service.
The ``callback_uri`` may be omitted if you have previously
@@ -368,35 +314,31 @@ def authorize_redirect(self, callback_uri=None, extra_params=None,
Now returns a `.Future` and takes an optional callback, for
compatibility with `.gen.coroutine`.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed in 6.0.
- Use the returned awaitable object instead.
+ The ``callback`` argument was removed. Use the returned
+ awaitable object instead.
"""
if callback_uri and getattr(self, "_OAUTH_NO_CALLBACKS", False):
raise Exception("This service does not support oauth_callback")
if http_client is None:
http_client = self.get_auth_http_client()
+ assert http_client is not None
if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
- fut = http_client.fetch(
- self._oauth_request_token_url(callback_uri=callback_uri,
- extra_params=extra_params))
- fut.add_done_callback(functools.partial(
- self._on_request_token,
- self._OAUTH_AUTHORIZE_URL,
- callback_uri,
- callback))
+ response = await http_client.fetch(
+ self._oauth_request_token_url(
+ callback_uri=callback_uri, extra_params=extra_params
+ )
+ )
else:
- fut = http_client.fetch(self._oauth_request_token_url())
- fut.add_done_callback(
- functools.partial(
- self._on_request_token, self._OAUTH_AUTHORIZE_URL,
- callback_uri,
- callback))
-
- @_auth_return_future
- def get_authenticated_user(self, callback, http_client=None):
+ response = await http_client.fetch(self._oauth_request_token_url())
+ url = self._OAUTH_AUTHORIZE_URL # type: ignore
+ self._on_request_token(url, callback_uri, response)
+
+ async def get_authenticated_user(
+ self, http_client: Optional[httpclient.AsyncHTTPClient] = None
+ ) -> Dict[str, Any]:
"""Gets the OAuth authorized user and access token.
This method should be called from the handler for your
@@ -407,37 +349,47 @@ def get_authenticated_user(self, callback, http_client=None):
also contain other fields such as ``name``, depending on the service
used.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed in 6.0.
- Use the returned awaitable object instead.
+ The ``callback`` argument was removed. Use the returned
+ awaitable object instead.
"""
- future = callback
- request_key = escape.utf8(self.get_argument("oauth_token"))
- oauth_verifier = self.get_argument("oauth_verifier", None)
- request_cookie = self.get_cookie("_oauth_request_token")
+ handler = cast(RequestHandler, self)
+ request_key = escape.utf8(handler.get_argument("oauth_token"))
+ oauth_verifier = handler.get_argument("oauth_verifier", None)
+ request_cookie = handler.get_cookie("_oauth_request_token")
if not request_cookie:
- future.set_exception(AuthError(
- "Missing OAuth request token cookie"))
- return
- self.clear_cookie("_oauth_request_token")
+ raise AuthError("Missing OAuth request token cookie")
+ handler.clear_cookie("_oauth_request_token")
cookie_key, cookie_secret = [
- base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")]
+ base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")
+ ]
if cookie_key != request_key:
- future.set_exception(AuthError(
- "Request token does not match cookie"))
- return
- token = dict(key=cookie_key, secret=cookie_secret)
+ raise AuthError("Request token does not match cookie")
+ token = dict(
+ key=cookie_key, secret=cookie_secret
+ ) # type: Dict[str, Union[str, bytes]]
if oauth_verifier:
token["verifier"] = oauth_verifier
if http_client is None:
http_client = self.get_auth_http_client()
- fut = http_client.fetch(self._oauth_access_token_url(token))
- fut.add_done_callback(functools.partial(self._on_access_token, callback))
-
- def _oauth_request_token_url(self, callback_uri=None, extra_params=None):
+ assert http_client is not None
+ response = await http_client.fetch(self._oauth_access_token_url(token))
+ access_token = _oauth_parse_response(response.body)
+ user = await self._oauth_get_user_future(access_token)
+ if not user:
+ raise AuthError("Error getting user")
+ user["access_token"] = access_token
+ return user
+
+ def _oauth_request_token_url(
+ self,
+ callback_uri: Optional[str] = None,
+ extra_params: Optional[Dict[str, Any]] = None,
+ ) -> str:
+ handler = cast(RequestHandler, self)
consumer_token = self._oauth_consumer_token()
- url = self._OAUTH_REQUEST_TOKEN_URL
+ url = self._OAUTH_REQUEST_TOKEN_URL # type: ignore
args = dict(
oauth_consumer_key=escape.to_basestring(consumer_token["key"]),
oauth_signature_method="HMAC-SHA1",
@@ -449,8 +401,9 @@ def _oauth_request_token_url(self, callback_uri=None, extra_params=None):
if callback_uri == "oob":
args["oauth_callback"] = "oob"
elif callback_uri:
- args["oauth_callback"] = urlparse.urljoin(
- self.request.full_url(), callback_uri)
+ args["oauth_callback"] = urllib.parse.urljoin(
+ handler.request.full_url(), callback_uri
+ )
if extra_params:
args.update(extra_params)
signature = _oauth10a_signature(consumer_token, "GET", url, args)
@@ -458,32 +411,35 @@ def _oauth_request_token_url(self, callback_uri=None, extra_params=None):
signature = _oauth_signature(consumer_token, "GET", url, args)
args["oauth_signature"] = signature
- return url + "?" + urllib_parse.urlencode(args)
-
- def _on_request_token(self, authorize_url, callback_uri, callback,
- response_fut):
- try:
- response = response_fut.result()
- except Exception as e:
- raise Exception("Could not get request token: %s" % e)
+ return url + "?" + urllib.parse.urlencode(args)
+
+ def _on_request_token(
+ self,
+ authorize_url: str,
+ callback_uri: Optional[str],
+ response: httpclient.HTTPResponse,
+ ) -> None:
+ handler = cast(RequestHandler, self)
request_token = _oauth_parse_response(response.body)
- data = (base64.b64encode(escape.utf8(request_token["key"])) + b"|" +
- base64.b64encode(escape.utf8(request_token["secret"])))
- self.set_cookie("_oauth_request_token", data)
+ data = (
+ base64.b64encode(escape.utf8(request_token["key"]))
+ + b"|"
+ + base64.b64encode(escape.utf8(request_token["secret"]))
+ )
+ handler.set_cookie("_oauth_request_token", data)
args = dict(oauth_token=request_token["key"])
if callback_uri == "oob":
- self.finish(authorize_url + "?" + urllib_parse.urlencode(args))
- callback()
+ handler.finish(authorize_url + "?" + urllib.parse.urlencode(args))
return
elif callback_uri:
- args["oauth_callback"] = urlparse.urljoin(
- self.request.full_url(), callback_uri)
- self.redirect(authorize_url + "?" + urllib_parse.urlencode(args))
- callback()
+ args["oauth_callback"] = urllib.parse.urljoin(
+ handler.request.full_url(), callback_uri
+ )
+ handler.redirect(authorize_url + "?" + urllib.parse.urlencode(args))
- def _oauth_access_token_url(self, request_token):
+ def _oauth_access_token_url(self, request_token: Dict[str, Any]) -> str:
consumer_token = self._oauth_consumer_token()
- url = self._OAUTH_ACCESS_TOKEN_URL
+ url = self._OAUTH_ACCESS_TOKEN_URL # type: ignore
args = dict(
oauth_consumer_key=escape.to_basestring(consumer_token["key"]),
oauth_token=escape.to_basestring(request_token["key"]),
@@ -496,41 +452,31 @@ def _oauth_access_token_url(self, request_token):
args["oauth_verifier"] = request_token["verifier"]
if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
- signature = _oauth10a_signature(consumer_token, "GET", url, args,
- request_token)
+ signature = _oauth10a_signature(
+ consumer_token, "GET", url, args, request_token
+ )
else:
- signature = _oauth_signature(consumer_token, "GET", url, args,
- request_token)
+ signature = _oauth_signature(
+ consumer_token, "GET", url, args, request_token
+ )
args["oauth_signature"] = signature
- return url + "?" + urllib_parse.urlencode(args)
-
- def _on_access_token(self, future, response_fut):
- try:
- response = response_fut.result()
- except Exception:
- future.set_exception(AuthError("Could not fetch access token"))
- return
+ return url + "?" + urllib.parse.urlencode(args)
- access_token = _oauth_parse_response(response.body)
- fut = self._oauth_get_user_future(access_token)
- fut = gen.convert_yielded(fut)
- fut.add_done_callback(
- functools.partial(self._on_oauth_get_user, access_token, future))
-
- def _oauth_consumer_token(self):
+ def _oauth_consumer_token(self) -> Dict[str, Any]:
"""Subclasses must override this to return their OAuth consumer keys.
The return value should be a `dict` with keys ``key`` and ``secret``.
"""
raise NotImplementedError()
- @return_future
- def _oauth_get_user_future(self, access_token, callback):
+ async def _oauth_get_user_future(
+ self, access_token: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Subclasses must override this to get basic information about the
user.
- Should return a `.Future` whose result is a dictionary
+ Should be a coroutine whose result is a dictionary
containing information about the user, which may have been
retrieved by using ``access_token`` to make a request to the
service.
@@ -538,40 +484,23 @@ def _oauth_get_user_future(self, access_token, callback):
The access token will be added to the returned dictionary to make
the result of `get_authenticated_user`.
- For backwards compatibility, the callback-based ``_oauth_get_user``
- method is also supported.
-
.. versionchanged:: 5.1
Subclasses may also define this method with ``async def``.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``_oauth_get_user`` fallback is deprecated and support for it
- will be removed in 6.0.
+ A synchronous fallback to ``_oauth_get_user`` was removed.
"""
- warnings.warn("_oauth_get_user is deprecated, override _oauth_get_user_future instead",
- DeprecationWarning)
- # By default, call the old-style _oauth_get_user, but new code
- # should override this method instead.
- self._oauth_get_user(access_token, callback)
-
- def _oauth_get_user(self, access_token, callback):
raise NotImplementedError()
- def _on_oauth_get_user(self, access_token, future, user_future):
- if user_future.exception() is not None:
- future.set_exception(user_future.exception())
- return
- user = user_future.result()
- if not user:
- future.set_exception(AuthError("Error getting user"))
- return
- user["access_token"] = access_token
- future_set_result_unless_cancelled(future, user)
-
- def _oauth_request_parameters(self, url, access_token, parameters={},
- method="GET"):
+ def _oauth_request_parameters(
+ self,
+ url: str,
+ access_token: Dict[str, Any],
+ parameters: Dict[str, Any] = {},
+ method: str = "GET",
+ ) -> Dict[str, Any]:
"""Returns the OAuth parameters as a dict for the given request.
parameters should include all POST arguments and query string arguments
@@ -590,15 +519,17 @@ def _oauth_request_parameters(self, url, access_token, parameters={},
args.update(base_args)
args.update(parameters)
if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
- signature = _oauth10a_signature(consumer_token, method, url, args,
- access_token)
+ signature = _oauth10a_signature(
+ consumer_token, method, url, args, access_token
+ )
else:
- signature = _oauth_signature(consumer_token, method, url, args,
- access_token)
+ signature = _oauth_signature(
+ consumer_token, method, url, args, access_token
+ )
base_args["oauth_signature"] = escape.to_basestring(signature)
return base_args
- def get_auth_http_client(self):
+ def get_auth_http_client(self) -> httpclient.AsyncHTTPClient:
"""Returns the `.AsyncHTTPClient` instance to be used for auth requests.
May be overridden by subclasses to use an HTTP client other than
@@ -618,10 +549,16 @@ class OAuth2Mixin(object):
* ``_OAUTH_AUTHORIZE_URL``: The service's authorization url.
* ``_OAUTH_ACCESS_TOKEN_URL``: The service's access token url.
"""
- @return_future
- def authorize_redirect(self, redirect_uri=None, client_id=None,
- client_secret=None, extra_params=None,
- callback=None, scope=None, response_type="code"):
+
+ def authorize_redirect(
+ self,
+ redirect_uri: Optional[str] = None,
+ client_id: Optional[str] = None,
+ client_secret: Optional[str] = None,
+ extra_params: Optional[Dict[str, Any]] = None,
+ scope: Optional[List[str]] = None,
+ response_type: str = "code",
+ ) -> None:
"""Redirects the user to obtain OAuth authorization for this service.
Some providers require that you register a redirect URL with
@@ -630,47 +567,53 @@ def authorize_redirect(self, redirect_uri=None, client_id=None,
``get_authenticated_user`` in the handler for your
redirect URL to complete the authorization process.
- .. versionchanged:: 3.1
- Returns a `.Future` and takes an optional callback. These are
- not strictly necessary as this method is synchronous,
- but they are supplied for consistency with
- `OAuthMixin.authorize_redirect`.
-
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument and returned awaitable will be removed
- in Tornado 6.0; this will be an ordinary synchronous function.
+ The ``callback`` argument and returned awaitable were removed;
+ this is now an ordinary synchronous function.
"""
- args = {
- "redirect_uri": redirect_uri,
- "client_id": client_id,
- "response_type": response_type
- }
+ handler = cast(RequestHandler, self)
+ args = {"response_type": response_type}
+ if redirect_uri is not None:
+ args["redirect_uri"] = redirect_uri
+ if client_id is not None:
+ args["client_id"] = client_id
if extra_params:
args.update(extra_params)
if scope:
- args['scope'] = ' '.join(scope)
- self.redirect(
- url_concat(self._OAUTH_AUTHORIZE_URL, args))
- callback()
-
- def _oauth_request_token_url(self, redirect_uri=None, client_id=None,
- client_secret=None, code=None,
- extra_params=None):
- url = self._OAUTH_ACCESS_TOKEN_URL
- args = dict(
- redirect_uri=redirect_uri,
- code=code,
- client_id=client_id,
- client_secret=client_secret,
- )
+ args["scope"] = " ".join(scope)
+ url = self._OAUTH_AUTHORIZE_URL # type: ignore
+ handler.redirect(url_concat(url, args))
+
+ def _oauth_request_token_url(
+ self,
+ redirect_uri: Optional[str] = None,
+ client_id: Optional[str] = None,
+ client_secret: Optional[str] = None,
+ code: Optional[str] = None,
+ extra_params: Optional[Dict[str, Any]] = None,
+ ) -> str:
+ url = self._OAUTH_ACCESS_TOKEN_URL # type: ignore
+ args = {} # type: Dict[str, str]
+ if redirect_uri is not None:
+ args["redirect_uri"] = redirect_uri
+ if code is not None:
+ args["code"] = code
+ if client_id is not None:
+ args["client_id"] = client_id
+ if client_secret is not None:
+ args["client_secret"] = client_secret
if extra_params:
args.update(extra_params)
return url_concat(url, args)
- @_auth_return_future
- def oauth2_request(self, url, callback, access_token=None,
- post_args=None, **args):
+ async def oauth2_request(
+ self,
+ url: str,
+ access_token: Optional[str] = None,
+ post_args: Optional[Dict[str, Any]] = None,
+ **args: Any
+ ) -> Any:
"""Fetches the given URL auth an OAuth2 access token.
If the request is a POST, ``post_args`` should be provided. Query
@@ -683,16 +626,15 @@ def oauth2_request(self, url, callback, access_token=None,
class MainHandler(tornado.web.RequestHandler,
tornado.auth.FacebookGraphMixin):
@tornado.web.authenticated
- @tornado.gen.coroutine
- def get(self):
- new_entry = yield self.oauth2_request(
+ async def get(self):
+ new_entry = await self.oauth2_request(
"https://graph.facebook.com/me/feed",
post_args={"message": "I am posting from my Tornado application!"},
access_token=self.current_user["access_token"])
if not new_entry:
# Call failed; perhaps missing permission?
- yield self.authorize_redirect()
+ self.authorize_redirect()
return
self.finish("Posted a message!")
@@ -701,10 +643,9 @@ def get(self):
.. versionadded:: 4.3
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed in 6.0.
- Use the returned awaitable object instead.
+ The ``callback`` argument was removed. Use the returned awaitable object instead.
"""
all_args = {}
if access_token:
@@ -712,25 +653,17 @@ def get(self):
all_args.update(args)
if all_args:
- url += "?" + urllib_parse.urlencode(all_args)
- callback = functools.partial(self._on_oauth2_request, callback)
+ url += "?" + urllib.parse.urlencode(all_args)
http = self.get_auth_http_client()
if post_args is not None:
- fut = http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args))
+ response = await http.fetch(
+ url, method="POST", body=urllib.parse.urlencode(post_args)
+ )
else:
- fut = http.fetch(url)
- fut.add_done_callback(callback)
-
- def _on_oauth2_request(self, future, response_fut):
- try:
- response = response_fut.result()
- except Exception as e:
- future.set_exception(AuthError("Error response %s" % e))
- return
+ response = await http.fetch(url)
+ return escape.json_decode(response.body)
- future_set_result_unless_cancelled(future, escape.json_decode(response.body))
-
- def get_auth_http_client(self):
+ def get_auth_http_client(self) -> httpclient.AsyncHTTPClient:
"""Returns the `.AsyncHTTPClient` instance to be used for auth requests.
May be overridden by subclasses to use an HTTP client other than
@@ -758,13 +691,12 @@ class TwitterMixin(OAuthMixin):
class TwitterLoginHandler(tornado.web.RequestHandler,
tornado.auth.TwitterMixin):
- @tornado.gen.coroutine
- def get(self):
+ async def get(self):
if self.get_argument("oauth_token", None):
- user = yield self.get_authenticated_user()
+ user = await self.get_authenticated_user()
# Save the user using e.g. set_secure_cookie()
else:
- yield self.authorize_redirect()
+ await self.authorize_redirect()
.. testoutput::
:hide:
@@ -774,6 +706,7 @@ def get(self):
and all of the custom Twitter user attributes described at
https://dev.twitter.com/docs/api/1.1/get/users/show
"""
+
_OAUTH_REQUEST_TOKEN_URL = "https://api.twitter.com/oauth/request_token"
_OAUTH_ACCESS_TOKEN_URL = "https://api.twitter.com/oauth/access_token"
_OAUTH_AUTHORIZE_URL = "https://api.twitter.com/oauth/authorize"
@@ -781,8 +714,7 @@ def get(self):
_OAUTH_NO_CALLBACKS = False
_TWITTER_BASE_URL = "https://api.twitter.com/1.1"
- @return_future
- def authenticate_redirect(self, callback_uri=None, callback=None):
+ async def authenticate_redirect(self, callback_uri: Optional[str] = None) -> None:
"""Just like `~OAuthMixin.authorize_redirect`, but
auto-redirects if authorized.
@@ -793,20 +725,24 @@ def authenticate_redirect(self, callback_uri=None, callback=None):
Now returns a `.Future` and takes an optional callback, for
compatibility with `.gen.coroutine`.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed in 6.0.
- Use the returned awaitable object instead.
+ The ``callback`` argument was removed. Use the returned
+ awaitable object instead.
"""
http = self.get_auth_http_client()
- http.fetch(self._oauth_request_token_url(callback_uri=callback_uri),
- functools.partial(
- self._on_request_token, self._OAUTH_AUTHENTICATE_URL,
- None, callback))
-
- @_auth_return_future
- def twitter_request(self, path, callback=None, access_token=None,
- post_args=None, **args):
+ response = await http.fetch(
+ self._oauth_request_token_url(callback_uri=callback_uri)
+ )
+ self._on_request_token(self._OAUTH_AUTHENTICATE_URL, None, response)
+
+ async def twitter_request(
+ self,
+ path: str,
+ access_token: Dict[str, Any],
+ post_args: Optional[Dict[str, Any]] = None,
+ **args: Any
+ ) -> Any:
"""Fetches the given API path, e.g., ``statuses/user_timeline/btaylor``
The path should not include the format or API version number.
@@ -829,27 +765,26 @@ def twitter_request(self, path, callback=None, access_token=None,
class MainHandler(tornado.web.RequestHandler,
tornado.auth.TwitterMixin):
@tornado.web.authenticated
- @tornado.gen.coroutine
- def get(self):
- new_entry = yield self.twitter_request(
+ async def get(self):
+ new_entry = await self.twitter_request(
"/statuses/update",
post_args={"status": "Testing Tornado Web Server"},
access_token=self.current_user["access_token"])
if not new_entry:
# Call failed; perhaps missing permission?
- yield self.authorize_redirect()
+ await self.authorize_redirect()
return
self.finish("Posted a message!")
.. testoutput::
:hide:
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed in 6.0.
- Use the returned awaitable object instead.
+ The ``callback`` argument was removed. Use the returned
+ awaitable object instead.
"""
- if path.startswith('http:') or path.startswith('https:'):
+ if path.startswith("http:") or path.startswith("https:"):
# Raw urls are useful for e.g. search which doesn't follow the
# usual pattern: http://search.twitter.com/search.json
url = path
@@ -862,42 +797,38 @@ def get(self):
all_args.update(post_args or {})
method = "POST" if post_args is not None else "GET"
oauth = self._oauth_request_parameters(
- url, access_token, all_args, method=method)
+ url, access_token, all_args, method=method
+ )
args.update(oauth)
if args:
- url += "?" + urllib_parse.urlencode(args)
+ url += "?" + urllib.parse.urlencode(args)
http = self.get_auth_http_client()
- http_callback = functools.partial(self._on_twitter_request, callback, url)
if post_args is not None:
- fut = http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args))
+ response = await http.fetch(
+ url, method="POST", body=urllib.parse.urlencode(post_args)
+ )
else:
- fut = http.fetch(url)
- fut.add_done_callback(http_callback)
-
- def _on_twitter_request(self, future, url, response_fut):
- try:
- response = response_fut.result()
- except Exception as e:
- future.set_exception(AuthError(
- "Error response %s fetching %s" % (e, url)))
- return
- future_set_result_unless_cancelled(future, escape.json_decode(response.body))
+ response = await http.fetch(url)
+ return escape.json_decode(response.body)
- def _oauth_consumer_token(self):
- self.require_setting("twitter_consumer_key", "Twitter OAuth")
- self.require_setting("twitter_consumer_secret", "Twitter OAuth")
+ def _oauth_consumer_token(self) -> Dict[str, Any]:
+ handler = cast(RequestHandler, self)
+ handler.require_setting("twitter_consumer_key", "Twitter OAuth")
+ handler.require_setting("twitter_consumer_secret", "Twitter OAuth")
return dict(
- key=self.settings["twitter_consumer_key"],
- secret=self.settings["twitter_consumer_secret"])
-
- @gen.coroutine
- def _oauth_get_user_future(self, access_token):
- user = yield self.twitter_request(
- "/account/verify_credentials",
- access_token=access_token)
+ key=handler.settings["twitter_consumer_key"],
+ secret=handler.settings["twitter_consumer_secret"],
+ )
+
+ async def _oauth_get_user_future(
+ self, access_token: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ user = await self.twitter_request(
+ "/account/verify_credentials", access_token=access_token
+ )
if user:
user["username"] = user["screen_name"]
- raise gen.Return(user)
+ return user
class GoogleOAuth2Mixin(OAuth2Mixin):
@@ -908,24 +839,25 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
* Go to the Google Dev Console at http://console.developers.google.com
* Select a project, or create a new one.
- * In the sidebar on the left, select APIs & Auth.
- * In the list of APIs, find the Google+ API service and set it to ON.
* In the sidebar on the left, select Credentials.
- * In the OAuth section of the page, select Create New Client ID.
- * Set the Redirect URI to point to your auth handler
+ * Click CREATE CREDENTIALS and click OAuth client ID.
+ * Under Application type, select Web application.
+ * Name OAuth 2.0 client and click Create.
* Copy the "Client secret" and "Client ID" to the application settings as
- {"google_oauth": {"key": CLIENT_ID, "secret": CLIENT_SECRET}}
+ ``{"google_oauth": {"key": CLIENT_ID, "secret": CLIENT_SECRET}}``
.. versionadded:: 3.2
"""
+
_OAUTH_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_OAUTH_ACCESS_TOKEN_URL = "https://www.googleapis.com/oauth2/v4/token"
_OAUTH_USERINFO_URL = "https://www.googleapis.com/oauth2/v1/userinfo"
_OAUTH_NO_CALLBACKS = False
- _OAUTH_SETTINGS_KEY = 'google_oauth'
+ _OAUTH_SETTINGS_KEY = "google_oauth"
- @_auth_return_future
- def get_authenticated_user(self, redirect_uri, code, callback):
+ async def get_authenticated_user(
+ self, redirect_uri: str, code: str
+ ) -> Dict[str, Any]:
"""Handles the login for the Google user, returning an access token.
The result is a dictionary containing an ``access_token`` field
@@ -942,19 +874,18 @@ def get_authenticated_user(self, redirect_uri, code, callback):
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleOAuth2Mixin):
- @tornado.gen.coroutine
- def get(self):
+ async def get(self):
if self.get_argument('code', False):
- access = yield self.get_authenticated_user(
+ access = await self.get_authenticated_user(
redirect_uri='http://your.site.com/auth/google',
code=self.get_argument('code'))
- user = yield self.oauth2_request(
+ user = await self.oauth2_request(
"https://www.googleapis.com/oauth2/v1/userinfo",
access_token=access["access_token"])
# Save the user and access token with
# e.g. set_secure_cookie.
else:
- yield self.authorize_redirect(
+ self.authorize_redirect(
redirect_uri='http://your.site.com/auth/google',
client_id=self.settings['google_oauth']['key'],
scope=['profile', 'email'],
@@ -964,48 +895,47 @@ def get(self):
.. testoutput::
:hide:
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed in 6.0.
- Use the returned awaitable object instead.
+ The ``callback`` argument was removed. Use the returned awaitable object instead.
""" # noqa: E501
+ handler = cast(RequestHandler, self)
http = self.get_auth_http_client()
- body = urllib_parse.urlencode({
- "redirect_uri": redirect_uri,
- "code": code,
- "client_id": self.settings[self._OAUTH_SETTINGS_KEY]['key'],
- "client_secret": self.settings[self._OAUTH_SETTINGS_KEY]['secret'],
- "grant_type": "authorization_code",
- })
-
- fut = http.fetch(self._OAUTH_ACCESS_TOKEN_URL,
- method="POST",
- headers={'Content-Type': 'application/x-www-form-urlencoded'},
- body=body)
- fut.add_done_callback(functools.partial(self._on_access_token, callback))
-
- def _on_access_token(self, future, response_fut):
- """Callback function for the exchange to the access token."""
- try:
- response = response_fut.result()
- except Exception as e:
- future.set_exception(AuthError('Google auth error: %s' % str(e)))
- return
+ body = urllib.parse.urlencode(
+ {
+ "redirect_uri": redirect_uri,
+ "code": code,
+ "client_id": handler.settings[self._OAUTH_SETTINGS_KEY]["key"],
+ "client_secret": handler.settings[self._OAUTH_SETTINGS_KEY]["secret"],
+ "grant_type": "authorization_code",
+ }
+ )
- args = escape.json_decode(response.body)
- future_set_result_unless_cancelled(future, args)
+ response = await http.fetch(
+ self._OAUTH_ACCESS_TOKEN_URL,
+ method="POST",
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ body=body,
+ )
+ return escape.json_decode(response.body)
class FacebookGraphMixin(OAuth2Mixin):
"""Facebook authentication using the new Graph API and OAuth2."""
+
_OAUTH_ACCESS_TOKEN_URL = "https://graph.facebook.com/oauth/access_token?"
_OAUTH_AUTHORIZE_URL = "https://www.facebook.com/dialog/oauth?"
_OAUTH_NO_CALLBACKS = False
_FACEBOOK_BASE_URL = "https://graph.facebook.com"
- @_auth_return_future
- def get_authenticated_user(self, redirect_uri, client_id, client_secret,
- code, callback, extra_fields=None):
+ async def get_authenticated_user(
+ self,
+ redirect_uri: str,
+ client_id: str,
+ client_secret: str,
+ code: str,
+ extra_fields: Optional[Dict[str, Any]] = None,
+ ) -> Optional[Dict[str, Any]]:
"""Handles the login for the Facebook user, returning a user object.
Example usage:
@@ -1014,17 +944,16 @@ def get_authenticated_user(self, redirect_uri, client_id, client_secret,
class FacebookGraphLoginHandler(tornado.web.RequestHandler,
tornado.auth.FacebookGraphMixin):
- @tornado.gen.coroutine
- def get(self):
+ async def get(self):
if self.get_argument("code", False):
- user = yield self.get_authenticated_user(
+ user = await self.get_authenticated_user(
redirect_uri='/auth/facebookgraph/',
client_id=self.settings["facebook_api_key"],
client_secret=self.settings["facebook_secret"],
code=self.get_argument("code"))
# Save the user with e.g. set_secure_cookie
else:
- yield self.authorize_redirect(
+ self.authorize_redirect(
redirect_uri='/auth/facebookgraph/',
client_id=self.settings["facebook_api_key"],
extra_params={"scope": "read_stream,offline_access"})
@@ -1048,10 +977,9 @@ def get(self):
The ``session_expires`` field was updated to support changes made to the
Facebook API in March 2017.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed in 6.0.
- Use the returned awaitable object instead.
+ The ``callback`` argument was removed. Use the returned awaitable object instead.
"""
http = self.get_auth_http_client()
args = {
@@ -1061,42 +989,35 @@ def get(self):
"client_secret": client_secret,
}
- fields = set(['id', 'name', 'first_name', 'last_name',
- 'locale', 'picture', 'link'])
+ fields = set(
+ ["id", "name", "first_name", "last_name", "locale", "picture", "link"]
+ )
if extra_fields:
fields.update(extra_fields)
- fut = http.fetch(self._oauth_request_token_url(**args))
- fut.add_done_callback(functools.partial(self._on_access_token, redirect_uri, client_id,
- client_secret, callback, fields))
-
- @gen.coroutine
- def _on_access_token(self, redirect_uri, client_id, client_secret,
- future, fields, response_fut):
- try:
- response = response_fut.result()
- except Exception as e:
- future.set_exception(AuthError('Facebook auth error: %s' % str(e)))
- return
-
+ response = await http.fetch(
+ self._oauth_request_token_url(**args) # type: ignore
+ )
args = escape.json_decode(response.body)
session = {
"access_token": args.get("access_token"),
- "expires_in": args.get("expires_in")
+ "expires_in": args.get("expires_in"),
}
+ assert session["access_token"] is not None
- user = yield self.facebook_request(
+ user = await self.facebook_request(
path="/me",
access_token=session["access_token"],
- appsecret_proof=hmac.new(key=client_secret.encode('utf8'),
- msg=session["access_token"].encode('utf8'),
- digestmod=hashlib.sha256).hexdigest(),
- fields=",".join(fields)
+ appsecret_proof=hmac.new(
+ key=client_secret.encode("utf8"),
+ msg=session["access_token"].encode("utf8"),
+ digestmod=hashlib.sha256,
+ ).hexdigest(),
+ fields=",".join(fields),
)
if user is None:
- future_set_result_unless_cancelled(future, None)
- return
+ return None
fieldmap = {}
for field in fields:
@@ -1106,13 +1027,21 @@ def _on_access_token(self, redirect_uri, client_id, client_secret,
# older versions in which the server used url-encoding and
# this code simply returned the string verbatim.
# This should change in Tornado 5.0.
- fieldmap.update({"access_token": session["access_token"],
- "session_expires": str(session.get("expires_in"))})
- future_set_result_unless_cancelled(future, fieldmap)
-
- @_auth_return_future
- def facebook_request(self, path, callback, access_token=None,
- post_args=None, **args):
+ fieldmap.update(
+ {
+ "access_token": session["access_token"],
+ "session_expires": str(session.get("expires_in")),
+ }
+ )
+ return fieldmap
+
+ async def facebook_request(
+ self,
+ path: str,
+ access_token: Optional[str] = None,
+ post_args: Optional[Dict[str, Any]] = None,
+ **args: Any
+ ) -> Any:
"""Fetches the given relative API path, e.g., "/btaylor/picture"
If the request is a POST, ``post_args`` should be provided. Query
@@ -1134,16 +1063,15 @@ def facebook_request(self, path, callback, access_token=None,
class MainHandler(tornado.web.RequestHandler,
tornado.auth.FacebookGraphMixin):
@tornado.web.authenticated
- @tornado.gen.coroutine
- def get(self):
- new_entry = yield self.facebook_request(
+ async def get(self):
+ new_entry = await self.facebook_request(
"/me/feed",
post_args={"message": "I am posting from my Tornado application!"},
access_token=self.current_user["access_token"])
if not new_entry:
# Call failed; perhaps missing permission?
- yield self.authorize_redirect()
+ self.authorize_redirect()
return
self.finish("Posted a message!")
@@ -1160,35 +1088,39 @@ def get(self):
.. versionchanged:: 3.1
Added the ability to override ``self._FACEBOOK_BASE_URL``.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed in 6.0.
- Use the returned awaitable object instead.
+ The ``callback`` argument was removed. Use the returned awaitable object instead.
"""
url = self._FACEBOOK_BASE_URL + path
- # Thanks to the _auth_return_future decorator, our "callback"
- # argument is a Future, which we cannot pass as a callback to
- # oauth2_request. Instead, have oauth2_request return a
- # future and chain them together.
- oauth_future = self.oauth2_request(url, access_token=access_token,
- post_args=post_args, **args)
- chain_future(oauth_future, callback)
+ return await self.oauth2_request(
+ url, access_token=access_token, post_args=post_args, **args
+ )
-def _oauth_signature(consumer_token, method, url, parameters={}, token=None):
+def _oauth_signature(
+ consumer_token: Dict[str, Any],
+ method: str,
+ url: str,
+ parameters: Dict[str, Any] = {},
+ token: Optional[Dict[str, Any]] = None,
+) -> bytes:
"""Calculates the HMAC-SHA1 OAuth signature for the given request.
See http://oauth.net/core/1.0/#signing_process
"""
- parts = urlparse.urlparse(url)
+ parts = urllib.parse.urlparse(url)
scheme, netloc, path = parts[:3]
normalized_url = scheme.lower() + "://" + netloc.lower() + path
base_elems = []
base_elems.append(method.upper())
base_elems.append(normalized_url)
- base_elems.append("&".join("%s=%s" % (k, _oauth_escape(str(v)))
- for k, v in sorted(parameters.items())))
+ base_elems.append(
+ "&".join(
+ "%s=%s" % (k, _oauth_escape(str(v))) for k, v in sorted(parameters.items())
+ )
+ )
base_string = "&".join(_oauth_escape(e) for e in base_elems)
key_elems = [escape.utf8(consumer_token["secret"])]
@@ -1199,42 +1131,53 @@ def _oauth_signature(consumer_token, method, url, parameters={}, token=None):
return binascii.b2a_base64(hash.digest())[:-1]
-def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None):
+def _oauth10a_signature(
+ consumer_token: Dict[str, Any],
+ method: str,
+ url: str,
+ parameters: Dict[str, Any] = {},
+ token: Optional[Dict[str, Any]] = None,
+) -> bytes:
"""Calculates the HMAC-SHA1 OAuth 1.0a signature for the given request.
See http://oauth.net/core/1.0a/#signing_process
"""
- parts = urlparse.urlparse(url)
+ parts = urllib.parse.urlparse(url)
scheme, netloc, path = parts[:3]
normalized_url = scheme.lower() + "://" + netloc.lower() + path
base_elems = []
base_elems.append(method.upper())
base_elems.append(normalized_url)
- base_elems.append("&".join("%s=%s" % (k, _oauth_escape(str(v)))
- for k, v in sorted(parameters.items())))
+ base_elems.append(
+ "&".join(
+ "%s=%s" % (k, _oauth_escape(str(v))) for k, v in sorted(parameters.items())
+ )
+ )
base_string = "&".join(_oauth_escape(e) for e in base_elems)
- key_elems = [escape.utf8(urllib_parse.quote(consumer_token["secret"], safe='~'))]
- key_elems.append(escape.utf8(urllib_parse.quote(token["secret"], safe='~') if token else ""))
+ key_elems = [escape.utf8(urllib.parse.quote(consumer_token["secret"], safe="~"))]
+ key_elems.append(
+ escape.utf8(urllib.parse.quote(token["secret"], safe="~") if token else "")
+ )
key = b"&".join(key_elems)
hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
return binascii.b2a_base64(hash.digest())[:-1]
-def _oauth_escape(val):
+def _oauth_escape(val: Union[str, bytes]) -> str:
if isinstance(val, unicode_type):
val = val.encode("utf-8")
- return urllib_parse.quote(val, safe="~")
+ return urllib.parse.quote(val, safe="~")
-def _oauth_parse_response(body):
+def _oauth_parse_response(body: bytes) -> Dict[str, Any]:
# I can't find an officially-defined encoding for oauth responses and
# have never seen anyone use non-ascii. Leave the response in a byte
# string for python 2, and use utf8 on python 3.
- body = escape.native_str(body)
- p = urlparse.parse_qs(body, keep_blank_values=False)
+ body_str = escape.native_str(body)
+ p = urllib.parse.parse_qs(body_str, keep_blank_values=False)
token = dict(key=p["oauth_token"][0], secret=p["oauth_token_secret"][0])
# Add the extra parameters the Provider included to the token
diff --git a/tornado/autoreload.py b/tornado/autoreload.py
index 2f91127031..db47262be2 100644
--- a/tornado/autoreload.py
+++ b/tornado/autoreload.py
@@ -33,9 +33,8 @@
other import-time failures, while debug mode catches changes once
the server has started.
-This module depends on `.IOLoop`, so it will not work in WSGI applications
-and Google App Engine. It also will not work correctly when `.HTTPServer`'s
-multi-process mode is used.
+This module will not work correctly when `.HTTPServer`'s multi-process
+mode is used.
Reloading loses any Python interpreter command-line arguments (e.g. ``-u``)
because it re-executes Python using ``sys.executable`` and ``sys.argv``.
@@ -44,8 +43,6 @@
"""
-from __future__ import absolute_import, division, print_function
-
import os
import sys
@@ -96,20 +93,29 @@
try:
import signal
except ImportError:
- signal = None
+ signal = None # type: ignore
+
+import typing
+from typing import Callable, Dict
+
+if typing.TYPE_CHECKING:
+ from typing import List, Optional, Union # noqa: F401
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
-_has_execv = sys.platform != 'win32'
+_has_execv = sys.platform != "win32"
_watched_files = set()
_reload_hooks = []
_reload_attempted = False
_io_loops = weakref.WeakKeyDictionary() # type: ignore
+_autoreload_is_main = False
+_original_argv = None # type: Optional[List[str]]
+_original_spec = None
-def start(check_time=500):
+def start(check_time: int = 500) -> None:
"""Begins watching source files for changes.
.. versionchanged:: 5.0
@@ -121,13 +127,13 @@ def start(check_time=500):
_io_loops[io_loop] = True
if len(_io_loops) > 1:
gen_log.warning("tornado.autoreload started more than once in the same process")
- modify_times = {}
+ modify_times = {} # type: Dict[str, float]
callback = functools.partial(_reload_on_update, modify_times)
scheduler = ioloop.PeriodicCallback(callback, check_time)
scheduler.start()
-def wait():
+def wait() -> None:
"""Wait for a watched file to change, then restart the process.
Intended to be used at the end of scripts like unit test runners,
@@ -139,7 +145,7 @@ def wait():
io_loop.start()
-def watch(filename):
+def watch(filename: str) -> None:
"""Add a file to the watch list.
All imported modules are watched by default.
@@ -147,18 +153,17 @@ def watch(filename):
_watched_files.add(filename)
-def add_reload_hook(fn):
+def add_reload_hook(fn: Callable[[], None]) -> None:
"""Add a function to be called before reloading the process.
Note that for open file and socket handles it is generally
preferable to set the ``FD_CLOEXEC`` flag (using `fcntl` or
- ``tornado.platform.auto.set_close_exec``) instead
- of using a reload hook to close them.
+ `os.set_inheritable`) instead of using a reload hook to close them.
"""
_reload_hooks.append(fn)
-def _reload_on_update(modify_times):
+def _reload_on_update(modify_times: Dict[str, float]) -> None:
if _reload_attempted:
# We already tried to reload and it didn't work, so don't try again.
return
@@ -184,7 +189,7 @@ def _reload_on_update(modify_times):
_check_file(modify_times, path)
-def _check_file(modify_times, path):
+def _check_file(modify_times: Dict[str, float], path: str) -> None:
try:
modified = os.stat(path).st_mtime
except Exception:
@@ -197,12 +202,12 @@ def _check_file(modify_times, path):
_reload()
-def _reload():
+def _reload() -> None:
global _reload_attempted
_reload_attempted = True
for fn in _reload_hooks:
fn()
- if hasattr(signal, "setitimer"):
+ if sys.platform != "win32":
# Clear the alarm signal set by
# ioloop.set_blocking_log_threshold so it doesn't fire
# after the exec.
@@ -214,19 +219,24 @@ def _reload():
# __spec__ is not available (Python < 3.4), check instead if
# sys.path[0] is an empty string and add the current directory to
# $PYTHONPATH.
- spec = getattr(sys.modules['__main__'], '__spec__', None)
- if spec:
- argv = ['-m', spec.name] + sys.argv[1:]
+ if _autoreload_is_main:
+ assert _original_argv is not None
+ spec = _original_spec
+ argv = _original_argv
else:
+ spec = getattr(sys.modules["__main__"], "__spec__", None)
argv = sys.argv
- path_prefix = '.' + os.pathsep
- if (sys.path[0] == '' and
- not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
- os.environ["PYTHONPATH"] = (path_prefix +
- os.environ.get("PYTHONPATH", ""))
+ if spec:
+ argv = ["-m", spec.name] + argv[1:]
+ else:
+ path_prefix = "." + os.pathsep
+ if sys.path[0] == "" and not os.environ.get("PYTHONPATH", "").startswith(
+ path_prefix
+ ):
+ os.environ["PYTHONPATH"] = path_prefix + os.environ.get("PYTHONPATH", "")
if not _has_execv:
subprocess.Popen([sys.executable] + argv)
- sys.exit(0)
+ os._exit(0)
else:
try:
os.execv(sys.executable, [sys.executable] + argv)
@@ -242,7 +252,9 @@ def _reload():
# Unfortunately the errno returned in this case does not
# appear to be consistent, so we can't easily check for
# this error specifically.
- os.spawnv(os.P_NOWAIT, sys.executable, [sys.executable] + argv)
+ os.spawnv(
+ os.P_NOWAIT, sys.executable, [sys.executable] + argv # type: ignore
+ )
# At this point the IOLoop has been closed and finally
# blocks will experience errors if we allow the stack to
# unwind, so just exit uncleanly.
@@ -256,7 +268,7 @@ def _reload():
"""
-def main():
+def main() -> None:
"""Command-line wrapper to re-run a script whenever its source changes.
Scripts may be specified by filename or module name::
@@ -269,7 +281,18 @@ def main():
can catch import-time problems like syntax errors that would otherwise
prevent the script from reaching its call to `wait`.
"""
+ # Remember that we were launched with autoreload as main.
+ # The main module can be tricky; set the variables both in our globals
+ # (which may be __main__) and the real importable version.
+ import tornado.autoreload
+
+ global _autoreload_is_main
+ global _original_argv, _original_spec
+ tornado.autoreload._autoreload_is_main = _autoreload_is_main = True
original_argv = sys.argv
+ tornado.autoreload._original_argv = _original_argv = original_argv
+ original_spec = getattr(sys.modules["__main__"], "__spec__", None)
+ tornado.autoreload._original_spec = _original_spec = original_spec
sys.argv = sys.argv[:]
if len(sys.argv) >= 3 and sys.argv[1] == "-m":
mode = "module"
@@ -286,6 +309,7 @@ def main():
try:
if mode == "module":
import runpy
+
runpy.run_module(module, run_name="__main__", alter_sys=True)
elif mode == "script":
with open(script) as f:
@@ -316,19 +340,20 @@ def main():
# SyntaxErrors are special: their innermost stack frame is fake
# so extract_tb won't see it and we have to get the filename
# from the exception object.
- watch(e.filename)
+ if e.filename is not None:
+ watch(e.filename)
else:
logging.basicConfig()
gen_log.info("Script exited normally")
# restore sys.argv so subsequent executions will include autoreload
sys.argv = original_argv
- if mode == 'module':
+ if mode == "module":
# runpy did a fake import of the module as __main__, but now it's
# no longer in sys.modules. Figure out where it is and watch it.
loader = pkgutil.get_loader(module)
if loader is not None:
- watch(loader.get_filename())
+ watch(loader.get_filename()) # type: ignore
wait()
diff --git a/tornado/concurrent.py b/tornado/concurrent.py
index 850766818d..6e05346b56 100644
--- a/tornado/concurrent.py
+++ b/tornado/concurrent.py
@@ -14,393 +14,67 @@
# under the License.
"""Utilities for working with ``Future`` objects.
-``Futures`` are a pattern for concurrent programming introduced in
-Python 3.2 in the `concurrent.futures` package, and also adopted (in a
-slightly different form) in Python 3.4's `asyncio` package. This
-package defines a ``Future`` class that is an alias for `asyncio.Future`
-when available, and a compatible implementation for older versions of
-Python. It also includes some utility functions for interacting with
-``Future`` objects.
-
-While this package is an important part of Tornado's internal
+Tornado previously provided its own ``Future`` class, but now uses
+`asyncio.Future`. This module contains utility functions for working
+with `asyncio.Future` in a way that is backwards-compatible with
+Tornado's old ``Future`` implementation.
+
+While this module is an important part of Tornado's internal
implementation, applications rarely need to interact with it
directly.
+
"""
-from __future__ import absolute_import, division, print_function
+import asyncio
+from concurrent import futures
import functools
-import platform
-import textwrap
-import traceback
import sys
-import warnings
+import types
from tornado.log import app_log
-from tornado.stack_context import ExceptionStackContext, wrap
-from tornado.util import raise_exc_info, ArgReplacer, is_finalizing
-
-try:
- from concurrent import futures
-except ImportError:
- futures = None
-
-try:
- import asyncio
-except ImportError:
- asyncio = None
-
-try:
- import typing
-except ImportError:
- typing = None
+import typing
+from typing import Any, Callable, Optional, Tuple, Union
-# Can the garbage collector handle cycles that include __del__ methods?
-# This is true in cpython beginning with version 3.4 (PEP 442).
-_GC_CYCLE_FINALIZERS = (platform.python_implementation() == 'CPython' and
- sys.version_info >= (3, 4))
+_T = typing.TypeVar("_T")
class ReturnValueIgnoredError(Exception):
+ # No longer used; was previously used by @return_future
pass
-# This class and associated code in the future object is derived
-# from the Trollius project, a backport of asyncio to Python 2.x - 3.x
-
-
-class _TracebackLogger(object):
- """Helper to log a traceback upon destruction if not cleared.
-
- This solves a nasty problem with Futures and Tasks that have an
- exception set: if nobody asks for the exception, the exception is
- never logged. This violates the Zen of Python: 'Errors should
- never pass silently. Unless explicitly silenced.'
-
- However, we don't want to log the exception as soon as
- set_exception() is called: if the calling code is written
- properly, it will get the exception and handle it properly. But
- we *do* want to log it if result() or exception() was never called
- -- otherwise developers waste a lot of time wondering why their
- buggy code fails silently.
-
- An earlier attempt added a __del__() method to the Future class
- itself, but this backfired because the presence of __del__()
- prevents garbage collection from breaking cycles. A way out of
- this catch-22 is to avoid having a __del__() method on the Future
- class itself, but instead to have a reference to a helper object
- with a __del__() method that logs the traceback, where we ensure
- that the helper object doesn't participate in cycles, and only the
- Future has a reference to it.
-
- The helper object is added when set_exception() is called. When
- the Future is collected, and the helper is present, the helper
- object is also collected, and its __del__() method will log the
- traceback. When the Future's result() or exception() method is
- called (and a helper object is present), it removes the the helper
- object, after calling its clear() method to prevent it from
- logging.
-
- One downside is that we do a fair amount of work to extract the
- traceback from the exception, even when it is never logged. It
- would seem cheaper to just store the exception object, but that
- references the traceback, which references stack frames, which may
- reference the Future, which references the _TracebackLogger, and
- then the _TracebackLogger would be included in a cycle, which is
- what we're trying to avoid! As an optimization, we don't
- immediately format the exception; we only do the work when
- activate() is called, which call is delayed until after all the
- Future's callbacks have run. Since usually a Future has at least
- one callback (typically set by 'yield From') and usually that
- callback extracts the callback, thereby removing the need to
- format the exception.
-
- PS. I don't claim credit for this solution. I first heard of it
- in a discussion about closing files when they are collected.
- """
-
- __slots__ = ('exc_info', 'formatted_tb')
-
- def __init__(self, exc_info):
- self.exc_info = exc_info
- self.formatted_tb = None
-
- def activate(self):
- exc_info = self.exc_info
- if exc_info is not None:
- self.exc_info = None
- self.formatted_tb = traceback.format_exception(*exc_info)
-
- def clear(self):
- self.exc_info = None
- self.formatted_tb = None
-
- def __del__(self, is_finalizing=is_finalizing):
- if not is_finalizing() and self.formatted_tb:
- app_log.error('Future exception was never retrieved: %s',
- ''.join(self.formatted_tb).rstrip())
-
-
-class Future(object):
- """Placeholder for an asynchronous result.
-
- A ``Future`` encapsulates the result of an asynchronous
- operation. In synchronous applications ``Futures`` are used
- to wait for the result from a thread or process pool; in
- Tornado they are normally used with `.IOLoop.add_future` or by
- yielding them in a `.gen.coroutine`.
-
- `tornado.concurrent.Future` is an alias for `asyncio.Future` when
- that package is available (Python 3.4+). Unlike
- `concurrent.futures.Future`, the ``Futures`` used by Tornado and
- `asyncio` are not thread-safe (and therefore faster for use with
- single-threaded event loops).
-
- In addition to ``exception`` and ``set_exception``, Tornado's
- ``Future`` implementation supports storing an ``exc_info`` triple
- to support better tracebacks on Python 2. To set an ``exc_info``
- triple, use `future_set_exc_info`, and to retrieve one, call
- `result()` (which will raise it).
-
- .. versionchanged:: 4.0
- `tornado.concurrent.Future` is always a thread-unsafe ``Future``
- with support for the ``exc_info`` methods. Previously it would
- be an alias for the thread-safe `concurrent.futures.Future`
- if that package was available and fall back to the thread-unsafe
- implementation if it was not.
-
- .. versionchanged:: 4.1
- If a `.Future` contains an error but that error is never observed
- (by calling ``result()``, ``exception()``, or ``exc_info()``),
- a stack trace will be logged when the `.Future` is garbage collected.
- This normally indicates an error in the application, but in cases
- where it results in undesired logging it may be necessary to
- suppress the logging by ensuring that the exception is observed:
- ``f.add_done_callback(lambda f: f.exception())``.
-
- .. versionchanged:: 5.0
-
- This class was previoiusly available under the name
- ``TracebackFuture``. This name, which was deprecated since
- version 4.0, has been removed. When `asyncio` is available
- ``tornado.concurrent.Future`` is now an alias for
- `asyncio.Future`. Like `asyncio.Future`, callbacks are now
- always scheduled on the `.IOLoop` and are never run
- synchronously.
-
- """
- def __init__(self):
- self._done = False
- self._result = None
- self._exc_info = None
-
- self._log_traceback = False # Used for Python >= 3.4
- self._tb_logger = None # Used for Python <= 3.3
-
- self._callbacks = []
-
- # Implement the Python 3.5 Awaitable protocol if possible
- # (we can't use return and yield together until py33).
- if sys.version_info >= (3, 3):
- exec(textwrap.dedent("""
- def __await__(self):
- return (yield self)
- """))
- else:
- # Py2-compatible version for use with cython.
- def __await__(self):
- result = yield self
- # StopIteration doesn't take args before py33,
- # but Cython recognizes the args tuple.
- e = StopIteration()
- e.args = (result,)
- raise e
-
- def cancel(self):
- """Cancel the operation, if possible.
-
- Tornado ``Futures`` do not support cancellation, so this method always
- returns False.
- """
- return False
-
- def cancelled(self):
- """Returns True if the operation has been cancelled.
-
- Tornado ``Futures`` do not support cancellation, so this method
- always returns False.
- """
- return False
-
- def running(self):
- """Returns True if this operation is currently running."""
- return not self._done
-
- def done(self):
- """Returns True if the future has finished running."""
- return self._done
-
- def _clear_tb_log(self):
- self._log_traceback = False
- if self._tb_logger is not None:
- self._tb_logger.clear()
- self._tb_logger = None
-
- def result(self, timeout=None):
- """If the operation succeeded, return its result. If it failed,
- re-raise its exception.
-
- This method takes a ``timeout`` argument for compatibility with
- `concurrent.futures.Future` but it is an error to call it
- before the `Future` is done, so the ``timeout`` is never used.
- """
- self._clear_tb_log()
- if self._result is not None:
- return self._result
- if self._exc_info is not None:
- try:
- raise_exc_info(self._exc_info)
- finally:
- self = None
- self._check_done()
- return self._result
-
- def exception(self, timeout=None):
- """If the operation raised an exception, return the `Exception`
- object. Otherwise returns None.
-
- This method takes a ``timeout`` argument for compatibility with
- `concurrent.futures.Future` but it is an error to call it
- before the `Future` is done, so the ``timeout`` is never used.
- """
- self._clear_tb_log()
- if self._exc_info is not None:
- return self._exc_info[1]
- else:
- self._check_done()
- return None
-
- def add_done_callback(self, fn):
- """Attaches the given callback to the `Future`.
-
- It will be invoked with the `Future` as its argument when the Future
- has finished running and its result is available. In Tornado
- consider using `.IOLoop.add_future` instead of calling
- `add_done_callback` directly.
- """
- if self._done:
- from tornado.ioloop import IOLoop
- IOLoop.current().add_callback(fn, self)
- else:
- self._callbacks.append(fn)
-
- def set_result(self, result):
- """Sets the result of a ``Future``.
-
- It is undefined to call any of the ``set`` methods more than once
- on the same object.
- """
- self._result = result
- self._set_done()
- def set_exception(self, exception):
- """Sets the exception of a ``Future.``"""
- self.set_exc_info(
- (exception.__class__,
- exception,
- getattr(exception, '__traceback__', None)))
+Future = asyncio.Future
- def exc_info(self):
- """Returns a tuple in the same format as `sys.exc_info` or None.
+FUTURES = (futures.Future, Future)
- .. versionadded:: 4.0
- """
- self._clear_tb_log()
- return self._exc_info
- def set_exc_info(self, exc_info):
- """Sets the exception information of a ``Future.``
-
- Preserves tracebacks on Python 2.
-
- .. versionadded:: 4.0
- """
- self._exc_info = exc_info
- self._log_traceback = True
- if not _GC_CYCLE_FINALIZERS:
- self._tb_logger = _TracebackLogger(exc_info)
-
- try:
- self._set_done()
- finally:
- # Activate the logger after all callbacks have had a
- # chance to call result() or exception().
- if self._log_traceback and self._tb_logger is not None:
- self._tb_logger.activate()
- self._exc_info = exc_info
-
- def _check_done(self):
- if not self._done:
- raise Exception("DummyFuture does not support blocking for results")
-
- def _set_done(self):
- self._done = True
- if self._callbacks:
- from tornado.ioloop import IOLoop
- loop = IOLoop.current()
- for cb in self._callbacks:
- loop.add_callback(cb, self)
- self._callbacks = None
-
- # On Python 3.3 or older, objects with a destructor part of a reference
- # cycle are never destroyed. It's no longer the case on Python 3.4 thanks to
- # the PEP 442.
- if _GC_CYCLE_FINALIZERS:
- def __del__(self, is_finalizing=is_finalizing):
- if is_finalizing() or not self._log_traceback:
- # set_exception() was not called, or result() or exception()
- # has consumed the exception
- return
-
- tb = traceback.format_exception(*self._exc_info)
-
- app_log.error('Future %r exception was never retrieved: %s',
- self, ''.join(tb).rstrip())
-
-
-if asyncio is not None:
- Future = asyncio.Future # noqa
-
-if futures is None:
- FUTURES = Future # type: typing.Union[type, typing.Tuple[type, ...]]
-else:
- FUTURES = (futures.Future, Future)
-
-
-def is_future(x):
+def is_future(x: Any) -> bool:
return isinstance(x, FUTURES)
-class DummyExecutor(object):
- def submit(self, fn, *args, **kwargs):
- future = Future()
+class DummyExecutor(futures.Executor):
+ def submit(
+ self, fn: Callable[..., _T], *args: Any, **kwargs: Any
+ ) -> "futures.Future[_T]":
+ future = futures.Future() # type: futures.Future[_T]
try:
future_set_result_unless_cancelled(future, fn(*args, **kwargs))
except Exception:
future_set_exc_info(future, sys.exc_info())
return future
- def shutdown(self, wait=True):
+ def shutdown(self, wait: bool = True) -> None:
pass
dummy_executor = DummyExecutor()
-def run_on_executor(*args, **kwargs):
+def run_on_executor(*args: Any, **kwargs: Any) -> Callable:
"""Decorator to run a synchronous method asynchronously on an executor.
- The decorated method may be called with a ``callback`` keyword
- argument and returns a future.
+ Returns a future.
The executor to be used is determined by the ``executor``
attributes of ``self``. To use a different attribute name, pass a
@@ -432,24 +106,25 @@ def foo(self):
The ``callback`` argument is deprecated and will be removed in
6.0. The decorator itself is discouraged in new code but will
not be removed in 6.0.
+
+ .. versionchanged:: 6.0
+
+ The ``callback`` argument was removed.
"""
- def run_on_executor_decorator(fn):
+ # Fully type-checking decorators is tricky, and this one is
+ # discouraged anyway so it doesn't have all the generic magic.
+ def run_on_executor_decorator(fn: Callable) -> Callable[..., Future]:
executor = kwargs.get("executor", "executor")
@functools.wraps(fn)
- def wrapper(self, *args, **kwargs):
- callback = kwargs.pop("callback", None)
- async_future = Future()
+ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Future:
+ async_future = Future() # type: Future
conc_future = getattr(self, executor).submit(fn, self, *args, **kwargs)
chain_future(conc_future, async_future)
- if callback:
- warnings.warn("callback arguments are deprecated, use the returned Future instead",
- DeprecationWarning)
- from tornado.ioloop import IOLoop
- IOLoop.current().add_future(
- async_future, lambda future: callback(future.result()))
return async_future
+
return wrapper
+
if args and kwargs:
raise ValueError("cannot combine positional and keyword args")
if len(args) == 1:
@@ -462,110 +137,7 @@ def wrapper(self, *args, **kwargs):
_NO_RESULT = object()
-def return_future(f):
- """Decorator to make a function that returns via callback return a
- `Future`.
-
- This decorator was provided to ease the transition from
- callback-oriented code to coroutines. It is not recommended for
- new code.
-
- The wrapped function should take a ``callback`` keyword argument
- and invoke it with one argument when it has finished. To signal failure,
- the function can simply raise an exception (which will be
- captured by the `.StackContext` and passed along to the ``Future``).
-
- From the caller's perspective, the callback argument is optional.
- If one is given, it will be invoked when the function is complete
- with ``Future.result()`` as an argument. If the function fails, the
- callback will not be run and an exception will be raised into the
- surrounding `.StackContext`.
-
- If no callback is given, the caller should use the ``Future`` to
- wait for the function to complete (perhaps by yielding it in a
- `.gen.engine` function, or passing it to `.IOLoop.add_future`).
-
- Usage:
-
- .. testcode::
-
- @return_future
- def future_func(arg1, arg2, callback):
- # Do stuff (possibly asynchronous)
- callback(result)
-
- @gen.engine
- def caller(callback):
- yield future_func(arg1, arg2)
- callback()
-
- ..
-
- Note that ``@return_future`` and ``@gen.engine`` can be applied to the
- same function, provided ``@return_future`` appears first. However,
- consider using ``@gen.coroutine`` instead of this combination.
-
- .. versionchanged:: 5.1
-
- Now raises a `.DeprecationWarning` if a callback argument is passed to
- the decorated function and deprecation warnings are enabled.
-
- .. deprecated:: 5.1
-
- New code should use coroutines directly instead of wrapping
- callback-based code with this decorator.
- """
- replacer = ArgReplacer(f, 'callback')
-
- @functools.wraps(f)
- def wrapper(*args, **kwargs):
- future = Future()
- callback, args, kwargs = replacer.replace(
- lambda value=_NO_RESULT: future_set_result_unless_cancelled(future, value),
- args, kwargs)
-
- def handle_error(typ, value, tb):
- future_set_exc_info(future, (typ, value, tb))
- return True
- exc_info = None
- with ExceptionStackContext(handle_error):
- try:
- result = f(*args, **kwargs)
- if result is not None:
- raise ReturnValueIgnoredError(
- "@return_future should not be used with functions "
- "that return values")
- except:
- exc_info = sys.exc_info()
- raise
- if exc_info is not None:
- # If the initial synchronous part of f() raised an exception,
- # go ahead and raise it to the caller directly without waiting
- # for them to inspect the Future.
- future.result()
-
- # If the caller passed in a callback, schedule it to be called
- # when the future resolves. It is important that this happens
- # just before we return the future, or else we risk confusing
- # stack contexts with multiple exceptions (one here with the
- # immediate exception, and again when the future resolves and
- # the callback triggers its exception by calling future.result()).
- if callback is not None:
- warnings.warn("callback arguments are deprecated, use the returned Future instead",
- DeprecationWarning)
-
- def run_callback(future):
- result = future.result()
- if result is _NO_RESULT:
- callback()
- else:
- callback(future.result())
- future_add_done_callback(future, wrap(run_callback))
- return future
- return wrapper
-
-
-def chain_future(a, b):
+def chain_future(a: "Future[_T]", b: "Future[_T]") -> None:
"""Chain two futures together so that when one completes, so does the other.
The result (success or failure) of ``a`` will be copied to ``b``, unless
@@ -577,29 +149,35 @@ def chain_future(a, b):
`concurrent.futures.Future`.
"""
- def copy(future):
+
+ def copy(future: "Future[_T]") -> None:
assert future is a
if b.done():
return
- if (hasattr(a, 'exc_info') and
- a.exc_info() is not None):
- future_set_exc_info(b, a.exc_info())
- elif a.exception() is not None:
- b.set_exception(a.exception())
+ if hasattr(a, "exc_info") and a.exc_info() is not None: # type: ignore
+ future_set_exc_info(b, a.exc_info()) # type: ignore
else:
- b.set_result(a.result())
+ a_exc = a.exception()
+ if a_exc is not None:
+ b.set_exception(a_exc)
+ else:
+ b.set_result(a.result())
+
if isinstance(a, Future):
future_add_done_callback(a, copy)
else:
# concurrent.futures.Future
from tornado.ioloop import IOLoop
+
IOLoop.current().add_future(a, copy)
-def future_set_result_unless_cancelled(future, value):
+def future_set_result_unless_cancelled(
+ future: "Union[futures.Future[_T], Future[_T]]", value: _T
+) -> None:
"""Set the given ``value`` as the `Future`'s result, if not cancelled.
- Avoids asyncio.InvalidStateError when calling set_result() on
+ Avoids ``asyncio.InvalidStateError`` when calling ``set_result()`` on
a cancelled `asyncio.Future`.
.. versionadded:: 5.0
@@ -608,23 +186,69 @@ def future_set_result_unless_cancelled(future, value):
future.set_result(value)
-def future_set_exc_info(future, exc_info):
+def future_set_exception_unless_cancelled(
+ future: "Union[futures.Future[_T], Future[_T]]", exc: BaseException
+) -> None:
+ """Set the given ``exc`` as the `Future`'s exception.
+
+ If the Future is already cancelled, logs the exception instead. If
+ this logging is not desired, the caller should explicitly check
+ the state of the Future and call ``Future.set_exception`` instead of
+ this wrapper.
+
+ Avoids ``asyncio.InvalidStateError`` when calling ``set_exception()`` on
+ a cancelled `asyncio.Future`.
+
+ .. versionadded:: 6.0
+
+ """
+ if not future.cancelled():
+ future.set_exception(exc)
+ else:
+ app_log.error("Exception after Future was cancelled", exc_info=exc)
+
+
+def future_set_exc_info(
+ future: "Union[futures.Future[_T], Future[_T]]",
+ exc_info: Tuple[
+ Optional[type], Optional[BaseException], Optional[types.TracebackType]
+ ],
+) -> None:
"""Set the given ``exc_info`` as the `Future`'s exception.
- Understands both `asyncio.Future` and Tornado's extensions to
- enable better tracebacks on Python 2.
+ Understands both `asyncio.Future` and the extensions in older
+ versions of Tornado to enable better tracebacks on Python 2.
.. versionadded:: 5.0
+
+ .. versionchanged:: 6.0
+
+ If the future is already cancelled, this function is a no-op.
+ (previously ``asyncio.InvalidStateError`` would be raised)
+
"""
- if hasattr(future, 'set_exc_info'):
- # Tornado's Future
- future.set_exc_info(exc_info)
- else:
- # asyncio.Future
- future.set_exception(exc_info[1])
+ if exc_info[1] is None:
+ raise Exception("future_set_exc_info called with no exception")
+ future_set_exception_unless_cancelled(future, exc_info[1])
+
+
+@typing.overload
+def future_add_done_callback(
+ future: "futures.Future[_T]", callback: Callable[["futures.Future[_T]"], None]
+) -> None:
+ pass
+
+
+@typing.overload # noqa: F811
+def future_add_done_callback(
+ future: "Future[_T]", callback: Callable[["Future[_T]"], None]
+) -> None:
+ pass
-def future_add_done_callback(future, callback):
+def future_add_done_callback( # noqa: F811
+ future: "Union[futures.Future[_T], Future[_T]]", callback: Callable[..., None]
+) -> None:
"""Arrange to call ``callback`` when ``future`` is complete.
``callback`` is invoked with one argument, the ``future``.
diff --git a/tornado/curl_httpclient.py b/tornado/curl_httpclient.py
index 54fc5b36da..3860944956 100644
--- a/tornado/curl_httpclient.py
+++ b/tornado/curl_httpclient.py
@@ -15,44 +15,60 @@
"""Non-blocking HTTP client implementation using pycurl."""
-from __future__ import absolute_import, division, print_function
-
import collections
import functools
import logging
-import pycurl # type: ignore
+import pycurl
import threading
import time
from io import BytesIO
from tornado import httputil
from tornado import ioloop
-from tornado import stack_context
from tornado.escape import utf8, native_str
-from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
+from tornado.httpclient import (
+ HTTPRequest,
+ HTTPResponse,
+ HTTPError,
+ AsyncHTTPClient,
+ main,
+)
+from tornado.log import app_log
+
+from typing import Dict, Any, Callable, Union, Tuple, Optional
+import typing
+
+if typing.TYPE_CHECKING:
+ from typing import Deque # noqa: F401
-curl_log = logging.getLogger('tornado.curl_httpclient')
+curl_log = logging.getLogger("tornado.curl_httpclient")
class CurlAsyncHTTPClient(AsyncHTTPClient):
- def initialize(self, max_clients=10, defaults=None):
- super(CurlAsyncHTTPClient, self).initialize(defaults=defaults)
- self._multi = pycurl.CurlMulti()
+ def initialize( # type: ignore
+ self, max_clients: int = 10, defaults: Optional[Dict[str, Any]] = None
+ ) -> None:
+ super().initialize(defaults=defaults)
+ # Typeshed is incomplete for CurlMulti, so just use Any for now.
+ self._multi = pycurl.CurlMulti() # type: Any
self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout)
self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket)
self._curls = [self._curl_create() for i in range(max_clients)]
self._free_list = self._curls[:]
- self._requests = collections.deque()
- self._fds = {}
- self._timeout = None
+ self._requests = (
+ collections.deque()
+ ) # type: Deque[Tuple[HTTPRequest, Callable[[HTTPResponse], None], float]]
+ self._fds = {} # type: Dict[int, int]
+ self._timeout = None # type: Optional[object]
# libcurl has bugs that sometimes cause it to not report all
# relevant file descriptors and timeouts to TIMERFUNCTION/
# SOCKETFUNCTION. Mitigate the effects of such bugs by
# forcing a periodic scan of all active requests.
self._force_timeout_callback = ioloop.PeriodicCallback(
- self._handle_force_timeout, 1000)
+ self._handle_force_timeout, 1000
+ )
self._force_timeout_callback.start()
# Work around a bug in libcurl 7.29.0: Some fields in the curl
@@ -64,27 +80,29 @@ def initialize(self, max_clients=10, defaults=None):
self._multi.add_handle(dummy_curl_handle)
self._multi.remove_handle(dummy_curl_handle)
- def close(self):
+ def close(self) -> None:
self._force_timeout_callback.stop()
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
for curl in self._curls:
curl.close()
self._multi.close()
- super(CurlAsyncHTTPClient, self).close()
+ super().close()
# Set below properties to None to reduce the reference count of current
# instance, because those properties hold some methods of current
# instance that will case circular reference.
- self._force_timeout_callback = None
+ self._force_timeout_callback = None # type: ignore
self._multi = None
- def fetch_impl(self, request, callback):
- self._requests.append((request, callback))
+ def fetch_impl(
+ self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]
+ ) -> None:
+ self._requests.append((request, callback, self.io_loop.time()))
self._process_queue()
self._set_timeout(0)
- def _handle_socket(self, event, fd, multi, data):
+ def _handle_socket(self, event: int, fd: int, multi: Any, data: bytes) -> None:
"""Called by libcurl when it wants to change the file descriptors
it cares about.
"""
@@ -92,7 +110,7 @@ def _handle_socket(self, event, fd, multi, data):
pycurl.POLL_NONE: ioloop.IOLoop.NONE,
pycurl.POLL_IN: ioloop.IOLoop.READ,
pycurl.POLL_OUT: ioloop.IOLoop.WRITE,
- pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE
+ pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE,
}
if event == pycurl.POLL_REMOVE:
if fd in self._fds:
@@ -110,18 +128,18 @@ def _handle_socket(self, event, fd, multi, data):
# instead of update.
if fd in self._fds:
self.io_loop.remove_handler(fd)
- self.io_loop.add_handler(fd, self._handle_events,
- ioloop_event)
+ self.io_loop.add_handler(fd, self._handle_events, ioloop_event)
self._fds[fd] = ioloop_event
- def _set_timeout(self, msecs):
+ def _set_timeout(self, msecs: int) -> None:
"""Called by libcurl to schedule a timeout."""
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = self.io_loop.add_timeout(
- self.io_loop.time() + msecs / 1000.0, self._handle_timeout)
+ self.io_loop.time() + msecs / 1000.0, self._handle_timeout
+ )
- def _handle_events(self, fd, events):
+ def _handle_events(self, fd: int, events: int) -> None:
"""Called by IOLoop when there is activity on one of our
file descriptors.
"""
@@ -139,19 +157,17 @@ def _handle_events(self, fd, events):
break
self._finish_pending_requests()
- def _handle_timeout(self):
+ def _handle_timeout(self) -> None:
"""Called by IOLoop when the requested timeout has passed."""
- with stack_context.NullContext():
- self._timeout = None
- while True:
- try:
- ret, num_handles = self._multi.socket_action(
- pycurl.SOCKET_TIMEOUT, 0)
- except pycurl.error as e:
- ret = e.args[0]
- if ret != pycurl.E_CALL_MULTI_PERFORM:
- break
- self._finish_pending_requests()
+ self._timeout = None
+ while True:
+ try:
+ ret, num_handles = self._multi.socket_action(pycurl.SOCKET_TIMEOUT, 0)
+ except pycurl.error as e:
+ ret = e.args[0]
+ if ret != pycurl.E_CALL_MULTI_PERFORM:
+ break
+ self._finish_pending_requests()
# In theory, we shouldn't have to do this because curl will
# call _set_timeout whenever the timeout changes. However,
@@ -170,21 +186,20 @@ def _handle_timeout(self):
if new_timeout >= 0:
self._set_timeout(new_timeout)
- def _handle_force_timeout(self):
+ def _handle_force_timeout(self) -> None:
"""Called by IOLoop periodically to ask libcurl to process any
events it may have forgotten about.
"""
- with stack_context.NullContext():
- while True:
- try:
- ret, num_handles = self._multi.socket_all()
- except pycurl.error as e:
- ret = e.args[0]
- if ret != pycurl.E_CALL_MULTI_PERFORM:
- break
- self._finish_pending_requests()
-
- def _finish_pending_requests(self):
+ while True:
+ try:
+ ret, num_handles = self._multi.socket_all()
+ except pycurl.error as e:
+ ret = e.args[0]
+ if ret != pycurl.E_CALL_MULTI_PERFORM:
+ break
+ self._finish_pending_requests()
+
+ def _finish_pending_requests(self) -> None:
"""Process any requests that were completed by the last
call to multi.socket_action.
"""
@@ -198,53 +213,62 @@ def _finish_pending_requests(self):
break
self._process_queue()
- def _process_queue(self):
- with stack_context.NullContext():
- while True:
- started = 0
- while self._free_list and self._requests:
- started += 1
- curl = self._free_list.pop()
- (request, callback) = self._requests.popleft()
- curl.info = {
- "headers": httputil.HTTPHeaders(),
- "buffer": BytesIO(),
- "request": request,
- "callback": callback,
- "curl_start_time": time.time(),
- }
- try:
- self._curl_setup_request(
- curl, request, curl.info["buffer"],
- curl.info["headers"])
- except Exception as e:
- # If there was an error in setup, pass it on
- # to the callback. Note that allowing the
- # error to escape here will appear to work
- # most of the time since we are still in the
- # caller's original stack frame, but when
- # _process_queue() is called from
- # _finish_pending_requests the exceptions have
- # nowhere to go.
- self._free_list.append(curl)
- callback(HTTPResponse(
- request=request,
- code=599,
- error=e))
- else:
- self._multi.add_handle(curl)
-
- if not started:
- break
-
- def _finish(self, curl, curl_error=None, curl_message=None):
- info = curl.info
- curl.info = None
+ def _process_queue(self) -> None:
+ while True:
+ started = 0
+ while self._free_list and self._requests:
+ started += 1
+ curl = self._free_list.pop()
+ (request, callback, queue_start_time) = self._requests.popleft()
+ # TODO: Don't smuggle extra data on an attribute of the Curl object.
+ curl.info = { # type: ignore
+ "headers": httputil.HTTPHeaders(),
+ "buffer": BytesIO(),
+ "request": request,
+ "callback": callback,
+ "queue_start_time": queue_start_time,
+ "curl_start_time": time.time(),
+ "curl_start_ioloop_time": self.io_loop.current().time(), # type: ignore
+ }
+ try:
+ self._curl_setup_request(
+ curl,
+ request,
+ curl.info["buffer"], # type: ignore
+ curl.info["headers"], # type: ignore
+ )
+ except Exception as e:
+ # If there was an error in setup, pass it on
+ # to the callback. Note that allowing the
+ # error to escape here will appear to work
+ # most of the time since we are still in the
+ # caller's original stack frame, but when
+ # _process_queue() is called from
+ # _finish_pending_requests the exceptions have
+ # nowhere to go.
+ self._free_list.append(curl)
+ callback(HTTPResponse(request=request, code=599, error=e))
+ else:
+ self._multi.add_handle(curl)
+
+ if not started:
+ break
+
+ def _finish(
+ self,
+ curl: pycurl.Curl,
+ curl_error: Optional[int] = None,
+ curl_message: Optional[str] = None,
+ ) -> None:
+ info = curl.info # type: ignore
+ curl.info = None # type: ignore
self._multi.remove_handle(curl)
self._free_list.append(curl)
buffer = info["buffer"]
if curl_error:
- error = CurlError(curl_error, curl_message)
+ assert curl_message is not None
+ error = CurlError(curl_error, curl_message) # type: Optional[CurlError]
+ assert error is not None
code = error.code
effective_url = None
buffer.close()
@@ -257,7 +281,7 @@ def _finish(self, curl, curl_error=None, curl_message=None):
# the various curl timings are documented at
# http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html
time_info = dict(
- queue=info["curl_start_time"] - info["request"].start_time,
+ queue=info["curl_start_ioloop_time"] - info["queue_start_time"],
namelookup=curl.getinfo(pycurl.NAMELOOKUP_TIME),
connect=curl.getinfo(pycurl.CONNECT_TIME),
appconnect=curl.getinfo(pycurl.APPCONNECT_TIME),
@@ -267,29 +291,45 @@ def _finish(self, curl, curl_error=None, curl_message=None):
redirect=curl.getinfo(pycurl.REDIRECT_TIME),
)
try:
- info["callback"](HTTPResponse(
- request=info["request"], code=code, headers=info["headers"],
- buffer=buffer, effective_url=effective_url, error=error,
- reason=info['headers'].get("X-Http-Reason", None),
- request_time=time.time() - info["curl_start_time"],
- time_info=time_info))
+ info["callback"](
+ HTTPResponse(
+ request=info["request"],
+ code=code,
+ headers=info["headers"],
+ buffer=buffer,
+ effective_url=effective_url,
+ error=error,
+ reason=info["headers"].get("X-Http-Reason", None),
+ request_time=self.io_loop.time() - info["curl_start_ioloop_time"],
+ start_time=info["curl_start_time"],
+ time_info=time_info,
+ )
+ )
except Exception:
self.handle_callback_exception(info["callback"])
- def handle_callback_exception(self, callback):
- self.io_loop.handle_callback_exception(callback)
+ def handle_callback_exception(self, callback: Any) -> None:
+ app_log.error("Exception in callback %r", callback, exc_info=True)
- def _curl_create(self):
+ def _curl_create(self) -> pycurl.Curl:
curl = pycurl.Curl()
if curl_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
- if hasattr(pycurl, 'PROTOCOLS'): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12)
+ if hasattr(
+ pycurl, "PROTOCOLS"
+ ): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12)
curl.setopt(pycurl.PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
return curl
- def _curl_setup_request(self, curl, request, buffer, headers):
+ def _curl_setup_request(
+ self,
+ curl: pycurl.Curl,
+ request: HTTPRequest,
+ buffer: BytesIO,
+ headers: httputil.HTTPHeaders,
+ ) -> None:
curl.setopt(pycurl.URL, native_str(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
@@ -307,32 +347,35 @@ def _curl_setup_request(self, curl, request, buffer, headers):
if "Pragma" not in request.headers:
request.headers["Pragma"] = ""
- curl.setopt(pycurl.HTTPHEADER,
- ["%s: %s" % (native_str(k), native_str(v))
- for k, v in request.headers.get_all()])
+ curl.setopt(
+ pycurl.HTTPHEADER,
+ [
+ "%s: %s" % (native_str(k), native_str(v))
+ for k, v in request.headers.get_all()
+ ],
+ )
- curl.setopt(pycurl.HEADERFUNCTION,
- functools.partial(self._curl_header_callback,
- headers, request.header_callback))
+ curl.setopt(
+ pycurl.HEADERFUNCTION,
+ functools.partial(
+ self._curl_header_callback, headers, request.header_callback
+ ),
+ )
if request.streaming_callback:
- def write_function(chunk):
- self.io_loop.add_callback(request.streaming_callback, chunk)
+
+ def write_function(b: Union[bytes, bytearray]) -> int:
+ assert request.streaming_callback is not None
+ self.io_loop.add_callback(request.streaming_callback, b)
+ return len(b)
+
else:
- write_function = buffer.write
- if bytes is str: # py2
- curl.setopt(pycurl.WRITEFUNCTION, write_function)
- else: # py3
- # Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
- # a fork/port. That version has a bug in which it passes unicode
- # strings instead of bytes to the WRITEFUNCTION. This means that
- # if you use a WRITEFUNCTION (which tornado always does), you cannot
- # download arbitrary binary data. This needs to be fixed in the
- # ported pycurl package, but in the meantime this lambda will
- # make it work for downloading (utf8) text.
- curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s)))
+ write_function = buffer.write # type: ignore
+ curl.setopt(pycurl.WRITEFUNCTION, write_function)
curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
+ assert request.connect_timeout is not None
curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
+ assert request.request_timeout is not None
curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
if request.user_agent:
curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
@@ -343,25 +386,30 @@ def write_function(chunk):
if request.decompress_response:
curl.setopt(pycurl.ENCODING, "gzip,deflate")
else:
- curl.setopt(pycurl.ENCODING, "none")
+ curl.setopt(pycurl.ENCODING, None)
if request.proxy_host and request.proxy_port:
curl.setopt(pycurl.PROXY, request.proxy_host)
curl.setopt(pycurl.PROXYPORT, request.proxy_port)
if request.proxy_username:
- credentials = '%s:%s' % (request.proxy_username,
- request.proxy_password)
+ assert request.proxy_password is not None
+ credentials = httputil.encode_username_password(
+ request.proxy_username, request.proxy_password
+ )
curl.setopt(pycurl.PROXYUSERPWD, credentials)
- if (request.proxy_auth_mode is None or
- request.proxy_auth_mode == "basic"):
+ if request.proxy_auth_mode is None or request.proxy_auth_mode == "basic":
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC)
elif request.proxy_auth_mode == "digest":
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError(
- "Unsupported proxy_auth_mode %s" % request.proxy_auth_mode)
+ "Unsupported proxy_auth_mode %s" % request.proxy_auth_mode
+ )
else:
- curl.setopt(pycurl.PROXY, '')
+ try:
+ curl.unsetopt(pycurl.PROXY)
+ except TypeError: # not supported, disable proxy
+ curl.setopt(pycurl.PROXY, "")
curl.unsetopt(pycurl.PROXYUSERPWD)
if request.validate_cert:
curl.setopt(pycurl.SSL_VERIFYPEER, 1)
@@ -404,7 +452,7 @@ def write_function(chunk):
elif request.allow_nonstandard_methods or request.method in custom_methods:
curl.setopt(pycurl.CUSTOMREQUEST, request.method)
else:
- raise KeyError('unknown method ' + request.method)
+ raise KeyError("unknown method " + request.method)
body_expected = request.method in ("POST", "PATCH", "PUT")
body_present = request.body is not None
@@ -412,12 +460,14 @@ def write_function(chunk):
# Some HTTP methods nearly always have bodies while others
# almost never do. Fail in this case unless the user has
# opted out of sanity checks with allow_nonstandard_methods.
- if ((body_expected and not body_present) or
- (body_present and not body_expected)):
+ if (body_expected and not body_present) or (
+ body_present and not body_expected
+ ):
raise ValueError(
- 'Body must %sbe None for method %s (unless '
- 'allow_nonstandard_methods is true)' %
- ('not ' if body_expected else '', request.method))
+ "Body must %sbe None for method %s (unless "
+ "allow_nonstandard_methods is true)"
+ % ("not " if body_expected else "", request.method)
+ )
if body_expected or body_present:
if request.method == "GET":
@@ -426,23 +476,23 @@ def write_function(chunk):
# unless we use CUSTOMREQUEST). While the spec doesn't
# forbid clients from sending a body, it arguably
# disallows the server from doing anything with them.
- raise ValueError('Body must be None for GET request')
- request_buffer = BytesIO(utf8(request.body or ''))
+ raise ValueError("Body must be None for GET request")
+ request_buffer = BytesIO(utf8(request.body or ""))
- def ioctl(cmd):
- if cmd == curl.IOCMD_RESTARTREAD:
+ def ioctl(cmd: int) -> None:
+ if cmd == curl.IOCMD_RESTARTREAD: # type: ignore
request_buffer.seek(0)
+
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
if request.method == "POST":
- curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or ''))
+ curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or ""))
else:
curl.setopt(pycurl.UPLOAD, True)
- curl.setopt(pycurl.INFILESIZE, len(request.body or ''))
+ curl.setopt(pycurl.INFILESIZE, len(request.body or ""))
if request.auth_username is not None:
- userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')
-
+ assert request.auth_password is not None
if request.auth_mode is None or request.auth_mode == "basic":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
elif request.auth_mode == "digest":
@@ -450,9 +500,16 @@ def ioctl(cmd):
else:
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
- curl.setopt(pycurl.USERPWD, native_str(userpwd))
- curl_log.debug("%s %s (username: %r)", request.method, request.url,
- request.auth_username)
+ userpwd = httputil.encode_username_password(
+ request.auth_username, request.auth_password
+ )
+ curl.setopt(pycurl.USERPWD, userpwd)
+ curl_log.debug(
+ "%s %s (username: %r)",
+ request.method,
+ request.url,
+ request.auth_username,
+ )
else:
curl.unsetopt(pycurl.USERPWD)
curl_log.debug("%s %s", request.method, request.url)
@@ -466,7 +523,7 @@ def ioctl(cmd):
if request.ssl_options is not None:
raise ValueError("ssl_options not supported in curl_httpclient")
- if threading.activeCount() > 1:
+ if threading.active_count() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
# of disabling DNS timeouts in some environments (when libcurl is
@@ -479,8 +536,13 @@ def ioctl(cmd):
if request.prepare_curl_callback is not None:
request.prepare_curl_callback(curl)
- def _curl_header_callback(self, headers, header_callback, header_line):
- header_line = native_str(header_line.decode('latin1'))
+ def _curl_header_callback(
+ self,
+ headers: httputil.HTTPHeaders,
+ header_callback: Callable[[str], None],
+ header_line_bytes: bytes,
+ ) -> None:
+ header_line = native_str(header_line_bytes.decode("latin1"))
if header_callback is not None:
self.io_loop.add_callback(header_callback, header_line)
# header_line as returned by curl includes the end-of-line characters.
@@ -497,21 +559,21 @@ def _curl_header_callback(self, headers, header_callback, header_line):
return
headers.parse_line(header_line)
- def _curl_debug(self, debug_type, debug_msg):
- debug_types = ('I', '<', '>', '<', '>')
+ def _curl_debug(self, debug_type: int, debug_msg: str) -> None:
+ debug_types = ("I", "<", ">", "<", ">")
if debug_type == 0:
debug_msg = native_str(debug_msg)
- curl_log.debug('%s', debug_msg.strip())
+ curl_log.debug("%s", debug_msg.strip())
elif debug_type in (1, 2):
debug_msg = native_str(debug_msg)
for line in debug_msg.splitlines():
- curl_log.debug('%s %s', debug_types[debug_type], line)
+ curl_log.debug("%s %s", debug_types[debug_type], line)
elif debug_type == 4:
- curl_log.debug('%s %r', debug_types[debug_type], debug_msg)
+ curl_log.debug("%s %r", debug_types[debug_type], debug_msg)
class CurlError(HTTPError):
- def __init__(self, errno, message):
+ def __init__(self, errno: int, message: str) -> None:
HTTPError.__init__(self, 599, message)
self.errno = errno
diff --git a/tornado/escape.py b/tornado/escape.py
index a79ece66ce..3cf7ff2e4a 100644
--- a/tornado/escape.py
+++ b/tornado/escape.py
@@ -19,35 +19,28 @@
have crept in over time.
"""
-from __future__ import absolute_import, division, print_function
-
+import html.entities
import json
import re
+import urllib.parse
-from tornado.util import PY3, unicode_type, basestring_type
-
-if PY3:
- from urllib.parse import parse_qs as _parse_qs
- import html.entities as htmlentitydefs
- import urllib.parse as urllib_parse
- unichr = chr
-else:
- from urlparse import parse_qs as _parse_qs
- import htmlentitydefs
- import urllib as urllib_parse
-
-try:
- import typing # noqa
-except ImportError:
- pass
+from tornado.util import unicode_type
+import typing
+from typing import Union, Any, Optional, Dict, List, Callable
-_XHTML_ESCAPE_RE = re.compile('[&<>"\']')
-_XHTML_ESCAPE_DICT = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;',
-                      '\'': '&#39;'}
+_XHTML_ESCAPE_RE = re.compile("[&<>\"']")
+_XHTML_ESCAPE_DICT = {
+    "&": "&amp;",
+    "<": "&lt;",
+    ">": "&gt;",
+    '"': "&quot;",
+    "'": "&#39;",
+}
-def xhtml_escape(value):
+
+def xhtml_escape(value: Union[str, bytes]) -> str:
"""Escapes a string so it is valid within HTML or XML.
Escapes the characters ``<``, ``>``, ``"``, ``'``, and ``&``.
@@ -58,11 +51,12 @@ def xhtml_escape(value):
Added the single quote to the list of escaped characters.
"""
- return _XHTML_ESCAPE_RE.sub(lambda match: _XHTML_ESCAPE_DICT[match.group(0)],
- to_basestring(value))
+ return _XHTML_ESCAPE_RE.sub(
+ lambda match: _XHTML_ESCAPE_DICT[match.group(0)], to_basestring(value)
+ )
-def xhtml_unescape(value):
+def xhtml_unescape(value: Union[str, bytes]) -> str:
"""Un-escapes an XML-escaped string."""
return re.sub(r"&(#?)(\w+?);", _convert_entity, _unicode(value))
@@ -70,28 +64,31 @@ def xhtml_unescape(value):
# The fact that json_encode wraps json.dumps is an implementation detail.
# Please see https://github.com/tornadoweb/tornado/pull/706
# before sending a pull request that adds **kwargs to this function.
-def json_encode(value):
+def json_encode(value: Any) -> str:
"""JSON-encodes the given Python object."""
# JSON permits but does not require forward slashes to be escaped.
# This is useful when json data is emitted in a <script> tag
# in HTML, as it prevents </script> tags from prematurely terminating
- # the javascript. Some json libraries do this escaping by default,
+ # the JavaScript. Some json libraries do this escaping by default,
# although python's standard library does not, so we do it here.
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
return json.dumps(value).replace("</", "<\\/")
-def json_decode(value):
- """Returns Python objects for the given JSON string."""
+def json_decode(value: Union[str, bytes]) -> Any:
+ """Returns Python objects for the given JSON string.
+
+ Supports both `str` and `bytes` inputs.
+ """
return json.loads(to_basestring(value))
-def squeeze(value):
+def squeeze(value: str) -> str:
"""Replace all sequences of whitespace chars with a single space."""
return re.sub(r"[\x00-\x20]+", " ", value).strip()
-def url_escape(value, plus=True):
+def url_escape(value: Union[str, bytes], plus: bool = True) -> str:
"""Returns a URL-encoded version of the given value.
If ``plus`` is true (the default), spaces will be represented
@@ -102,89 +99,93 @@ def url_escape(value, plus=True):
.. versionadded:: 3.1
The ``plus`` argument
"""
- quote = urllib_parse.quote_plus if plus else urllib_parse.quote
+ quote = urllib.parse.quote_plus if plus else urllib.parse.quote
return quote(utf8(value))
-# python 3 changed things around enough that we need two separate
-# implementations of url_unescape. We also need our own implementation
-# of parse_qs since python 3's version insists on decoding everything.
-if not PY3:
- def url_unescape(value, encoding='utf-8', plus=True):
- """Decodes the given value from a URL.
+@typing.overload
+def url_unescape(value: Union[str, bytes], encoding: None, plus: bool = True) -> bytes:
+ pass
- The argument may be either a byte or unicode string.
- If encoding is None, the result will be a byte string. Otherwise,
- the result is a unicode string in the specified encoding.
+@typing.overload # noqa: F811
+def url_unescape(
+ value: Union[str, bytes], encoding: str = "utf-8", plus: bool = True
+) -> str:
+ pass
- If ``plus`` is true (the default), plus signs will be interpreted
- as spaces (literal plus signs must be represented as "%2B"). This
- is appropriate for query strings and form-encoded values but not
- for the path component of a URL. Note that this default is the
- reverse of Python's urllib module.
- .. versionadded:: 3.1
- The ``plus`` argument
- """
- unquote = (urllib_parse.unquote_plus if plus else urllib_parse.unquote)
- if encoding is None:
- return unquote(utf8(value))
- else:
- return unicode_type(unquote(utf8(value)), encoding)
-
- parse_qs_bytes = _parse_qs
-else:
- def url_unescape(value, encoding='utf-8', plus=True):
- """Decodes the given value from a URL.
-
- The argument may be either a byte or unicode string.
-
- If encoding is None, the result will be a byte string. Otherwise,
- the result is a unicode string in the specified encoding.
-
- If ``plus`` is true (the default), plus signs will be interpreted
- as spaces (literal plus signs must be represented as "%2B"). This
- is appropriate for query strings and form-encoded values but not
- for the path component of a URL. Note that this default is the
- reverse of Python's urllib module.
-
- .. versionadded:: 3.1
- The ``plus`` argument
- """
- if encoding is None:
- if plus:
- # unquote_to_bytes doesn't have a _plus variant
- value = to_basestring(value).replace('+', ' ')
- return urllib_parse.unquote_to_bytes(value)
- else:
- unquote = (urllib_parse.unquote_plus if plus
- else urllib_parse.unquote)
- return unquote(to_basestring(value), encoding=encoding)
-
- def parse_qs_bytes(qs, keep_blank_values=False, strict_parsing=False):
- """Parses a query string like urlparse.parse_qs, but returns the
- values as byte strings.
-
- Keys still become type str (interpreted as latin1 in python3!)
- because it's too painful to keep them as byte strings in
- python3 and in practice they're nearly always ascii anyway.
- """
- # This is gross, but python3 doesn't give us another way.
- # Latin1 is the universal donor of character encodings.
- result = _parse_qs(qs, keep_blank_values, strict_parsing,
- encoding='latin1', errors='strict')
- encoded = {}
- for k, v in result.items():
- encoded[k] = [i.encode('latin1') for i in v]
- return encoded
+def url_unescape( # noqa: F811
+ value: Union[str, bytes], encoding: Optional[str] = "utf-8", plus: bool = True
+) -> Union[str, bytes]:
+ """Decodes the given value from a URL.
+
+ The argument may be either a byte or unicode string.
+
+ If encoding is None, the result will be a byte string. Otherwise,
+ the result is a unicode string in the specified encoding.
+
+ If ``plus`` is true (the default), plus signs will be interpreted
+ as spaces (literal plus signs must be represented as "%2B"). This
+ is appropriate for query strings and form-encoded values but not
+ for the path component of a URL. Note that this default is the
+ reverse of Python's urllib module.
+
+ .. versionadded:: 3.1
+ The ``plus`` argument
+ """
+ if encoding is None:
+ if plus:
+ # unquote_to_bytes doesn't have a _plus variant
+ value = to_basestring(value).replace("+", " ")
+ return urllib.parse.unquote_to_bytes(value)
+ else:
+ unquote = urllib.parse.unquote_plus if plus else urllib.parse.unquote
+ return unquote(to_basestring(value), encoding=encoding)
+
+
+def parse_qs_bytes(
+ qs: Union[str, bytes], keep_blank_values: bool = False, strict_parsing: bool = False
+) -> Dict[str, List[bytes]]:
+ """Parses a query string like urlparse.parse_qs,
+ but takes bytes and returns the values as byte strings.
+
+ Keys still become type str (interpreted as latin1 in python3!)
+ because it's too painful to keep them as byte strings in
+ python3 and in practice they're nearly always ascii anyway.
+ """
+ # This is gross, but python3 doesn't give us another way.
+ # Latin1 is the universal donor of character encodings.
+ if isinstance(qs, bytes):
+ qs = qs.decode("latin1")
+ result = urllib.parse.parse_qs(
+ qs, keep_blank_values, strict_parsing, encoding="latin1", errors="strict"
+ )
+ encoded = {}
+ for k, v in result.items():
+ encoded[k] = [i.encode("latin1") for i in v]
+ return encoded
_UTF8_TYPES = (bytes, type(None))
-def utf8(value):
- # type: (typing.Union[bytes,unicode_type,None])->typing.Union[bytes,None]
+@typing.overload
+def utf8(value: bytes) -> bytes:
+ pass
+
+
+@typing.overload # noqa: F811
+def utf8(value: str) -> bytes:
+ pass
+
+
+@typing.overload # noqa: F811
+def utf8(value: None) -> None:
+ pass
+
+
+def utf8(value: Union[None, str, bytes]) -> Optional[bytes]: # noqa: F811
"""Converts a string argument to a byte string.
If the argument is already a byte string or None, it is returned unchanged.
@@ -193,16 +194,29 @@ def utf8(value):
if isinstance(value, _UTF8_TYPES):
return value
if not isinstance(value, unicode_type):
- raise TypeError(
- "Expected bytes, unicode, or None; got %r" % type(value)
- )
+ raise TypeError("Expected bytes, unicode, or None; got %r" % type(value))
return value.encode("utf-8")
_TO_UNICODE_TYPES = (unicode_type, type(None))
-def to_unicode(value):
+@typing.overload
+def to_unicode(value: str) -> str:
+ pass
+
+
+@typing.overload # noqa: F811
+def to_unicode(value: bytes) -> str:
+ pass
+
+
+@typing.overload # noqa: F811
+def to_unicode(value: None) -> None:
+ pass
+
+
+def to_unicode(value: Union[None, str, bytes]) -> Optional[str]: # noqa: F811
"""Converts a string argument to a unicode string.
If the argument is already a unicode string or None, it is returned
@@ -211,9 +225,7 @@ def to_unicode(value):
if isinstance(value, _TO_UNICODE_TYPES):
return value
if not isinstance(value, bytes):
- raise TypeError(
- "Expected bytes, unicode, or None; got %r" % type(value)
- )
+ raise TypeError("Expected bytes, unicode, or None; got %r" % type(value))
return value.decode("utf-8")
@@ -223,39 +235,19 @@ def to_unicode(value):
# When dealing with the standard library across python 2 and 3 it is
# sometimes useful to have a direct conversion to the native string type
-if str is unicode_type:
- native_str = to_unicode
-else:
- native_str = utf8
-
-_BASESTRING_TYPES = (basestring_type, type(None))
+native_str = to_unicode
+to_basestring = to_unicode
-def to_basestring(value):
- """Converts a string argument to a subclass of basestring.
-
- In python2, byte and unicode strings are mostly interchangeable,
- so functions that deal with a user-supplied argument in combination
- with ascii string constants can use either and should return the type
- the user supplied. In python3, the two types are not interchangeable,
- so this method is needed to convert byte strings to unicode.
- """
- if isinstance(value, _BASESTRING_TYPES):
- return value
- if not isinstance(value, bytes):
- raise TypeError(
- "Expected bytes, unicode, or None; got %r" % type(value)
- )
- return value.decode("utf-8")
-
-
-def recursive_unicode(obj):
+def recursive_unicode(obj: Any) -> Any:
"""Walks a simple data structure, converting byte strings to unicode.
Supports lists, tuples, and dictionaries.
"""
if isinstance(obj, dict):
- return dict((recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items())
+ return dict(
+ (recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items()
+ )
elif isinstance(obj, list):
return list(recursive_unicode(i) for i in obj)
elif isinstance(obj, tuple):
@@ -273,13 +265,20 @@ def recursive_unicode(obj):
# This regex should avoid those problems.
# Use to_unicode instead of tornado.util.u - we don't want backslashes getting
# processed as escapes.
-_URL_RE = re.compile(to_unicode(
-    r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&amp;|&quot;)*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&amp;|&quot;)*\)))+)""" # noqa: E501
-))
-
-
-def linkify(text, shorten=False, extra_params="",
- require_protocol=False, permitted_protocols=["http", "https"]):
+_URL_RE = re.compile(
+ to_unicode(
+        r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&amp;|&quot;)*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&amp;|&quot;)*\)))+)""" # noqa: E501
+ )
+)
+
+
+def linkify(
+ text: Union[str, bytes],
+ shorten: bool = False,
+ extra_params: Union[str, Callable[[str], str]] = "",
+ require_protocol: bool = False,
+ permitted_protocols: List[str] = ["http", "https"],
+) -> str:
"""Converts plain text into HTML with links.
For example: ``linkify("Hello http://tornadoweb.org!")`` would return
@@ -312,7 +311,7 @@ def extra_params_cb(url):
if extra_params and not callable(extra_params):
extra_params = " " + extra_params.strip()
- def make_link(m):
+ def make_link(m: typing.Match) -> str:
url = m.group(1)
proto = m.group(2)
if require_protocol and not proto:
@@ -323,7 +322,7 @@ def make_link(m):
href = m.group(1)
if not proto:
- href = "http://" + href # no proto specified, use http
+ href = "http://" + href # no proto specified, use http
if callable(extra_params):
params = " " + extra_params(href).strip()
@@ -345,14 +344,18 @@ def make_link(m):
# The path is usually not that interesting once shortened
# (no more slug, etc), so it really just provides a little
# extra indication of shortening.
- url = url[:proto_len] + parts[0] + "/" + \
- parts[1][:8].split('?')[0].split('.')[0]
+ url = (
+ url[:proto_len]
+ + parts[0]
+ + "/"
+ + parts[1][:8].split("?")[0].split(".")[0]
+ )
if len(url) > max_len * 1.5: # still too long
url = url[:max_len]
if url != before_clip:
- amp = url.rfind('&')
+ amp = url.rfind("&")
# avoid splitting html char entities
if amp > max_len - 5:
url = url[:amp]
@@ -374,13 +377,13 @@ def make_link(m):
return _URL_RE.sub(make_link, text)
-def _convert_entity(m):
+def _convert_entity(m: typing.Match) -> str:
if m.group(1) == "#":
try:
- if m.group(2)[:1].lower() == 'x':
- return unichr(int(m.group(2)[1:], 16))
+ if m.group(2)[:1].lower() == "x":
+ return chr(int(m.group(2)[1:], 16))
else:
- return unichr(int(m.group(2)))
+ return chr(int(m.group(2)))
except ValueError:
return "%s;" % m.group(2)
try:
@@ -389,10 +392,10 @@ def _convert_entity(m):
return "&%s;" % m.group(2)
-def _build_unicode_map():
+def _build_unicode_map() -> Dict[str, str]:
unicode_map = {}
- for name, value in htmlentitydefs.name2codepoint.items():
- unicode_map[name] = unichr(value)
+ for name, value in html.entities.name2codepoint.items():
+ unicode_map[name] = chr(value)
return unicode_map
diff --git a/tornado/gen.py b/tornado/gen.py
index cc59402e5e..a6370259b3 100644
--- a/tornado/gen.py
+++ b/tornado/gen.py
@@ -17,25 +17,7 @@
technically asynchronous, but it is written as a single generator
instead of a collection of separate functions.
-For example, the following asynchronous handler:
-
-.. testcode::
-
- class AsyncHandler(RequestHandler):
- @asynchronous
- def get(self):
- http_client = AsyncHTTPClient()
- http_client.fetch("http://example.com",
- callback=self.on_fetch)
-
- def on_fetch(self, response):
- do_something_with_response(response)
- self.render("template.html")
-
-.. testoutput::
- :hide:
-
-could be written with ``gen`` as:
+For example, here's a coroutine-based handler:
.. testcode::
@@ -50,12 +32,12 @@ def get(self):
.. testoutput::
:hide:
-Most asynchronous functions in Tornado return a `.Future`;
-yielding this object returns its ``Future.result``.
+Asynchronous functions in Tornado return an ``Awaitable`` or `.Future`;
+yielding this object returns its result.
-You can also yield a list or dict of ``Futures``, which will be
-started at the same time and run in parallel; a list or dict of results will
-be returned when they are all finished:
+You can also yield a list or dict of other yieldable objects, which
+will be started at the same time and run in parallel; a list or dict
+of results will be returned when they are all finished:
.. testcode::
@@ -72,13 +54,9 @@ def get(self):
.. testoutput::
:hide:
-If the `~functools.singledispatch` library is available (standard in
-Python 3.4, available via the `singledispatch
-<https://pypi.python.org/pypi/singledispatch>`_ package on older
-versions), additional types of objects may be yielded. Tornado includes
-support for ``asyncio.Future`` and Twisted's ``Deferred`` class when
-``tornado.platform.asyncio`` and ``tornado.platform.twisted`` are imported.
-See the `convert_yielded` function to extend this mechanism.
+If ``tornado.platform.twisted`` is imported, it is also possible to
+yield Twisted's ``Deferred`` objects. See the `convert_yielded`
+function to extend this mechanism.
.. versionchanged:: 3.2
Dict support added.
@@ -88,64 +66,46 @@ def get(self):
via ``singledispatch``.
"""
-from __future__ import absolute_import, division, print_function
-
+import asyncio
+import builtins
import collections
+from collections.abc import Generator
+import concurrent.futures
+import datetime
import functools
-import itertools
-import os
+from functools import singledispatch
+from inspect import isawaitable
import sys
import types
-import warnings
-from tornado.concurrent import (Future, is_future, chain_future, future_set_exc_info,
- future_add_done_callback, future_set_result_unless_cancelled)
+from tornado.concurrent import (
+ Future,
+ is_future,
+ chain_future,
+ future_set_exc_info,
+ future_add_done_callback,
+ future_set_result_unless_cancelled,
+)
from tornado.ioloop import IOLoop
from tornado.log import app_log
-from tornado import stack_context
-from tornado.util import PY3, raise_exc_info, TimeoutError
+from tornado.util import TimeoutError
try:
- try:
- # py34+
- from functools import singledispatch # type: ignore
- except ImportError:
- from singledispatch import singledispatch # backport
+ import contextvars
except ImportError:
- # In most cases, singledispatch is required (to avoid
- # difficult-to-diagnose problems in which the functionality
- # available differs depending on which invisble packages are
- # installed). However, in Google App Engine third-party
- # dependencies are more trouble so we allow this module to be
- # imported without it.
- if 'APPENGINE_RUNTIME' not in os.environ:
- raise
- singledispatch = None
+ contextvars = None # type: ignore
-try:
- try:
- # py35+
- from collections.abc import Generator as GeneratorType # type: ignore
- except ImportError:
- from backports_abc import Generator as GeneratorType # type: ignore
+import typing
+from typing import Union, Any, Callable, List, Type, Tuple, Awaitable, Dict, overload
- try:
- # py35+
- from inspect import isawaitable # type: ignore
- except ImportError:
- from backports_abc import isawaitable
-except ImportError:
- if 'APPENGINE_RUNTIME' not in os.environ:
- raise
- from types import GeneratorType
+if typing.TYPE_CHECKING:
+ from typing import Sequence, Deque, Optional, Set, Iterable # noqa: F401
- def isawaitable(x): # type: ignore
- return False
+_T = typing.TypeVar("_T")
-if PY3:
- import builtins
-else:
- import __builtin__ as builtins
+_Yieldable = Union[
+ None, Awaitable, List[Awaitable], Dict[Any, Awaitable], concurrent.futures.Future
+]
class KeyReuseError(Exception):
@@ -168,7 +128,7 @@ class ReturnValueIgnoredError(Exception):
pass
-def _value_from_stopiteration(e):
+def _value_from_stopiteration(e: Union[StopIteration, "Return"]) -> Any:
try:
# StopIteration has a value attribute beginning in py33.
# So does our Return class.
@@ -183,8 +143,8 @@ def _value_from_stopiteration(e):
return None
-def _create_future():
- future = Future()
+def _create_future() -> Future:
+ future = Future() # type: Future
# Fixup asyncio debug info by removing extraneous stack entries
source_traceback = getattr(future, "_source_traceback", ())
while source_traceback:
@@ -198,68 +158,32 @@ def _create_future():
return future
-def engine(func):
- """Callback-oriented decorator for asynchronous generators.
-
- This is an older interface; for new code that does not need to be
- compatible with versions of Tornado older than 3.0 the
- `coroutine` decorator is recommended instead.
+def _fake_ctx_run(f: Callable[..., _T], *args: Any, **kw: Any) -> _T:
+ return f(*args, **kw)
- This decorator is similar to `coroutine`, except it does not
- return a `.Future` and the ``callback`` argument is not treated
- specially.
- In most cases, functions decorated with `engine` should take
- a ``callback`` argument and invoke it with their result when
- they are finished. One notable exception is the
- `~tornado.web.RequestHandler` :ref:`HTTP verb methods <verbs>`,
- which use ``self.finish()`` in place of a callback argument.
+@overload
+def coroutine(
+ func: Callable[..., "Generator[Any, Any, _T]"]
+) -> Callable[..., "Future[_T]"]:
+ ...
- .. deprecated:: 5.1
- This decorator will be removed in 6.0. Use `coroutine` or
- ``async def`` instead.
- """
- warnings.warn("gen.engine is deprecated, use gen.coroutine or async def instead",
- DeprecationWarning)
- func = _make_coroutine_wrapper(func, replace_callback=False)
-
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- future = func(*args, **kwargs)
-
- def final_callback(future):
- if future.result() is not None:
- raise ReturnValueIgnoredError(
- "@gen.engine functions cannot return values: %r" %
- (future.result(),))
- # The engine interface doesn't give us any way to return
- # errors but to raise them into the stack context.
- # Save the stack context here to use when the Future has resolved.
- future_add_done_callback(future, stack_context.wrap(final_callback))
- return wrapper
+@overload
+def coroutine(func: Callable[..., _T]) -> Callable[..., "Future[_T]"]:
+ ...
-def coroutine(func):
+def coroutine(
+ func: Union[Callable[..., "Generator[Any, Any, _T]"], Callable[..., _T]]
+) -> Callable[..., "Future[_T]"]:
"""Decorator for asynchronous generators.
- Any generator that yields objects from this module must be wrapped
- in either this decorator or `engine`.
-
- Coroutines may "return" by raising the special exception
- `Return(value) `. In Python 3.3+, it is also possible for
- the function to simply use the ``return value`` statement (prior to
- Python 3.3 generators were not allowed to also return values).
- In all versions of Python a coroutine that simply wishes to exit
- early may use the ``return`` statement without a value.
+ For compatibility with older versions of Python, coroutines may
+ also "return" by raising the special exception `Return(value)
+    <Return>`.
- Functions with this decorator return a `.Future`. Additionally,
- they may be called with a ``callback`` keyword argument, which
- will be invoked with the future's result when it resolves. If the
- coroutine fails, the callback will not be run and an exception
- will be raised into the surrounding `.StackContext`. The
- ``callback`` argument is not visible inside the decorated
- function; it is handled by the decorator itself.
+ Functions with this decorator return a `.Future`.
.. warning::
@@ -271,40 +195,25 @@ def coroutine(func):
`.IOLoop.run_sync` for top-level calls, or passing the `.Future`
to `.IOLoop.add_future`.
- .. deprecated:: 5.1
-
- The ``callback`` argument is deprecated and will be removed in 6.0.
- Use the returned awaitable object instead.
- """
- return _make_coroutine_wrapper(func, replace_callback=True)
-
+ .. versionchanged:: 6.0
-def _make_coroutine_wrapper(func, replace_callback):
- """The inner workings of ``@gen.coroutine`` and ``@gen.engine``.
+ The ``callback`` argument was removed. Use the returned
+ awaitable object instead.
- The two decorators differ in their treatment of the ``callback``
- argument, so we cannot simply implement ``@engine`` in terms of
- ``@coroutine``.
"""
- # On Python 3.5, set the coroutine flag on our generator, to allow it
- # to be used with 'await'.
- wrapped = func
- if hasattr(types, 'coroutine'):
- func = types.coroutine(func)
- @functools.wraps(wrapped)
+ @functools.wraps(func)
def wrapper(*args, **kwargs):
+ # type: (*Any, **Any) -> Future[_T]
+ # This function is type-annotated with a comment to work around
+ # https://bitbucket.org/pypy/pypy/issues/2868/segfault-with-args-type-annotation-in
future = _create_future()
-
- if replace_callback and 'callback' in kwargs:
- warnings.warn("callback arguments are deprecated, use the returned Future instead",
- DeprecationWarning, stacklevel=2)
- callback = kwargs.pop('callback')
- IOLoop.current().add_future(
- future, lambda future: callback(future.result()))
-
+ if contextvars is not None:
+ ctx_run = contextvars.copy_context().run # type: Callable
+ else:
+ ctx_run = _fake_ctx_run
try:
- result = func(*args, **kwargs)
+ result = ctx_run(func, *args, **kwargs)
except (Return, StopIteration) as e:
result = _value_from_stopiteration(e)
except Exception:
@@ -313,25 +222,20 @@ def wrapper(*args, **kwargs):
return future
finally:
# Avoid circular references
- future = None
+ future = None # type: ignore
else:
- if isinstance(result, GeneratorType):
+ if isinstance(result, Generator):
# Inline the first iteration of Runner.run. This lets us
# avoid the cost of creating a Runner when the coroutine
# never actually yields, which in turn allows us to
# use "optional" coroutines in critical path code without
# performance penalty for the synchronous case.
try:
- orig_stack_contexts = stack_context._state.contexts
- yielded = next(result)
- if stack_context._state.contexts is not orig_stack_contexts:
- yielded = _create_future()
- yielded.set_exception(
- stack_context.StackContextInconsistentError(
- 'stack_context inconsistency (probably caused '
- 'by yield within a "with StackContext" block)'))
+ yielded = ctx_run(next, result)
except (StopIteration, Return) as e:
- future_set_result_unless_cancelled(future, _value_from_stopiteration(e))
+ future_set_result_unless_cancelled(
+ future, _value_from_stopiteration(e)
+ )
except Exception:
future_set_exc_info(future, sys.exc_info())
else:
@@ -342,8 +246,8 @@ def wrapper(*args, **kwargs):
# We do this by exploiting the public API
# add_done_callback() instead of putting a private
# attribute on the Future.
- # (Github issues #1769, #2229).
- runner = Runner(result, future, yielded)
+ # (GitHub issues #1769, #2229).
+ runner = Runner(ctx_run, result, future, yielded)
future.add_done_callback(lambda _: runner)
yielded = None
try:
@@ -357,22 +261,22 @@ def wrapper(*args, **kwargs):
# benchmarks (relative to the refcount-based scheme
# used in the absence of cycles). We can avoid the
# cycle by clearing the local variable after we return it.
- future = None
+ future = None # type: ignore
future_set_result_unless_cancelled(future, result)
return future
- wrapper.__wrapped__ = wrapped
- wrapper.__tornado_coroutine__ = True
+ wrapper.__wrapped__ = func # type: ignore
+ wrapper.__tornado_coroutine__ = True # type: ignore
return wrapper
-def is_coroutine_function(func):
+def is_coroutine_function(func: Any) -> bool:
"""Return whether *func* is a coroutine function, i.e. a function
wrapped with `~.gen.coroutine`.
.. versionadded:: 4.5
"""
- return getattr(func, '__tornado_coroutine__', False)
+ return getattr(func, "__tornado_coroutine__", False)
class Return(Exception):
@@ -395,30 +299,32 @@ def fetch_json(url):
but it is never necessary to ``raise gen.Return()``. The ``return``
statement can be used with no arguments instead.
"""
- def __init__(self, value=None):
- super(Return, self).__init__()
+
+ def __init__(self, value: Any = None) -> None:
+ super().__init__()
self.value = value
# Cython recognizes subclasses of StopIteration with a .args tuple.
self.args = (value,)
class WaitIterator(object):
- """Provides an iterator to yield the results of futures as they finish.
+ """Provides an iterator to yield the results of awaitables as they finish.
- Yielding a set of futures like this:
+ Yielding a set of awaitables like this:
- ``results = yield [future1, future2]``
+ ``results = yield [awaitable1, awaitable2]``
- pauses the coroutine until both ``future1`` and ``future2``
+ pauses the coroutine until both ``awaitable1`` and ``awaitable2``
return, and then restarts the coroutine with the results of both
- futures. If either future is an exception, the expression will
- raise that exception and all the results will be lost.
+ awaitables. If either awaitable raises an exception, the
+ expression will raise that exception and all the results will be
+ lost.
- If you need to get the result of each future as soon as possible,
- or if you need the result of some futures even if others produce
+ If you need to get the result of each awaitable as soon as possible,
+ or if you need the result of some awaitables even if others produce
errors, you can use ``WaitIterator``::
- wait_iterator = gen.WaitIterator(future1, future2)
+ wait_iterator = gen.WaitIterator(awaitable1, awaitable2)
while not wait_iterator.done():
try:
result = yield wait_iterator.next()
@@ -434,7 +340,7 @@ class WaitIterator(object):
input arguments*. If you need to know which future produced the
current result, you can use the attributes
``WaitIterator.current_future``, or ``WaitIterator.current_index``
- to get the index of the future from the input list. (if keyword
+ to get the index of the awaitable from the input list. (if keyword
arguments were used in the construction of the `WaitIterator`,
``current_index`` will use the corresponding keyword).
@@ -455,26 +361,29 @@ class WaitIterator(object):
Added ``async for`` support in Python 3.5.
"""
- def __init__(self, *args, **kwargs):
+
+ _unfinished = {} # type: Dict[Future, Union[int, str]]
+
+ def __init__(self, *args: Future, **kwargs: Future) -> None:
if args and kwargs:
- raise ValueError(
- "You must provide args or kwargs, not both")
+ raise ValueError("You must provide args or kwargs, not both")
if kwargs:
self._unfinished = dict((f, k) for (k, f) in kwargs.items())
- futures = list(kwargs.values())
+ futures = list(kwargs.values()) # type: Sequence[Future]
else:
self._unfinished = dict((f, i) for (i, f) in enumerate(args))
futures = args
- self._finished = collections.deque()
- self.current_index = self.current_future = None
- self._running_future = None
+ self._finished = collections.deque() # type: Deque[Future]
+ self.current_index = None # type: Optional[Union[str, int]]
+ self.current_future = None # type: Optional[Future]
+ self._running_future = None # type: Optional[Future]
for future in futures:
future_add_done_callback(future, self._done_callback)
- def done(self):
+ def done(self) -> bool:
"""Returns True if this iterator has no more results."""
if self._finished or self._unfinished:
return False
@@ -482,7 +391,7 @@ def done(self):
self.current_index = self.current_future = None
return True
- def next(self):
+ def next(self) -> Future:
"""Returns a `.Future` that will yield the next available result.
Note that this `.Future` will not be the same object as any of
@@ -491,233 +400,45 @@ def next(self):
self._running_future = Future()
if self._finished:
- self._return_result(self._finished.popleft())
+ return self._return_result(self._finished.popleft())
return self._running_future
- def _done_callback(self, done):
+ def _done_callback(self, done: Future) -> None:
if self._running_future and not self._running_future.done():
self._return_result(done)
else:
self._finished.append(done)
- def _return_result(self, done):
+ def _return_result(self, done: Future) -> Future:
"""Called set the returned future's state that of the future
we yielded, and set the current future for the iterator.
"""
+ if self._running_future is None:
+ raise Exception("no future is running")
chain_future(done, self._running_future)
+ res = self._running_future
+ self._running_future = None
self.current_future = done
self.current_index = self._unfinished.pop(done)
- def __aiter__(self):
+ return res
+
+ def __aiter__(self) -> typing.AsyncIterator:
return self
- def __anext__(self):
+ def __anext__(self) -> Future:
if self.done():
# Lookup by name to silence pyflakes on older versions.
- raise getattr(builtins, 'StopAsyncIteration')()
+ raise getattr(builtins, "StopAsyncIteration")()
return self.next()
-class YieldPoint(object):
- """Base class for objects that may be yielded from the generator.
-
- .. deprecated:: 4.0
- Use `Futures <.Future>` instead. This class and all its subclasses
- will be removed in 6.0
- """
- def __init__(self):
- warnings.warn("YieldPoint is deprecated, use Futures instead",
- DeprecationWarning)
-
- def start(self, runner):
- """Called by the runner after the generator has yielded.
-
- No other methods will be called on this object before ``start``.
- """
- raise NotImplementedError()
-
- def is_ready(self):
- """Called by the runner to determine whether to resume the generator.
-
- Returns a boolean; may be called more than once.
- """
- raise NotImplementedError()
-
- def get_result(self):
- """Returns the value to use as the result of the yield expression.
-
- This method will only be called once, and only after `is_ready`
- has returned true.
- """
- raise NotImplementedError()
-
-
-class Callback(YieldPoint):
- """Returns a callable object that will allow a matching `Wait` to proceed.
-
- The key may be any value suitable for use as a dictionary key, and is
- used to match ``Callbacks`` to their corresponding ``Waits``. The key
- must be unique among outstanding callbacks within a single run of the
- generator function, but may be reused across different runs of the same
- function (so constants generally work fine).
-
- The callback may be called with zero or one arguments; if an argument
- is given it will be returned by `Wait`.
-
- .. deprecated:: 4.0
- Use `Futures <.Future>` instead. This class will be removed in 6.0.
- """
- def __init__(self, key):
- warnings.warn("gen.Callback is deprecated, use Futures instead",
- DeprecationWarning)
- self.key = key
-
- def start(self, runner):
- self.runner = runner
- runner.register_callback(self.key)
-
- def is_ready(self):
- return True
-
- def get_result(self):
- return self.runner.result_callback(self.key)
-
-
-class Wait(YieldPoint):
- """Returns the argument passed to the result of a previous `Callback`.
-
- .. deprecated:: 4.0
- Use `Futures <.Future>` instead. This class will be removed in 6.0.
- """
- def __init__(self, key):
- warnings.warn("gen.Wait is deprecated, use Futures instead",
- DeprecationWarning)
- self.key = key
-
- def start(self, runner):
- self.runner = runner
-
- def is_ready(self):
- return self.runner.is_ready(self.key)
-
- def get_result(self):
- return self.runner.pop_result(self.key)
-
-
-class WaitAll(YieldPoint):
- """Returns the results of multiple previous `Callbacks `.
-
- The argument is a sequence of `Callback` keys, and the result is
- a list of results in the same order.
-
- `WaitAll` is equivalent to yielding a list of `Wait` objects.
-
- .. deprecated:: 4.0
- Use `Futures <.Future>` instead. This class will be removed in 6.0.
- """
- def __init__(self, keys):
- warnings.warn("gen.WaitAll is deprecated, use gen.multi instead",
- DeprecationWarning)
- self.keys = keys
-
- def start(self, runner):
- self.runner = runner
-
- def is_ready(self):
- return all(self.runner.is_ready(key) for key in self.keys)
-
- def get_result(self):
- return [self.runner.pop_result(key) for key in self.keys]
-
-
-def Task(func, *args, **kwargs):
- """Adapts a callback-based asynchronous function for use in coroutines.
-
- Takes a function (and optional additional arguments) and runs it with
- those arguments plus a ``callback`` keyword argument. The argument passed
- to the callback is returned as the result of the yield expression.
-
- .. versionchanged:: 4.0
- ``gen.Task`` is now a function that returns a `.Future`, instead of
- a subclass of `YieldPoint`. It still behaves the same way when
- yielded.
-
- .. deprecated:: 5.1
- This function is deprecated and will be removed in 6.0.
- """
- warnings.warn("gen.Task is deprecated, use Futures instead",
- DeprecationWarning)
- future = _create_future()
-
- def handle_exception(typ, value, tb):
- if future.done():
- return False
- future_set_exc_info(future, (typ, value, tb))
- return True
-
- def set_result(result):
- if future.done():
- return
- future_set_result_unless_cancelled(future, result)
- with stack_context.ExceptionStackContext(handle_exception):
- func(*args, callback=_argument_adapter(set_result), **kwargs)
- return future
-
-
-class YieldFuture(YieldPoint):
- def __init__(self, future):
- """Adapts a `.Future` to the `YieldPoint` interface.
-
- .. versionchanged:: 5.0
- The ``io_loop`` argument (deprecated since version 4.1) has been removed.
-
- .. deprecated:: 5.1
- This class will be removed in 6.0.
- """
- warnings.warn("YieldFuture is deprecated, use Futures instead",
- DeprecationWarning)
- self.future = future
- self.io_loop = IOLoop.current()
-
- def start(self, runner):
- if not self.future.done():
- self.runner = runner
- self.key = object()
- runner.register_callback(self.key)
- self.io_loop.add_future(self.future, runner.result_callback(self.key))
- else:
- self.runner = None
- self.result_fn = self.future.result
-
- def is_ready(self):
- if self.runner is not None:
- return self.runner.is_ready(self.key)
- else:
- return True
-
- def get_result(self):
- if self.runner is not None:
- return self.runner.pop_result(self.key).result()
- else:
- return self.result_fn()
-
-
-def _contains_yieldpoint(children):
- """Returns True if ``children`` contains any YieldPoints.
-
- ``children`` may be a dict or a list, as used by `MultiYieldPoint`
- and `multi_future`.
- """
- if isinstance(children, dict):
- return any(isinstance(i, YieldPoint) for i in children.values())
- if isinstance(children, list):
- return any(isinstance(i, YieldPoint) for i in children)
- return False
-
-
-def multi(children, quiet_exceptions=()):
+def multi(
+ children: Union[List[_Yieldable], Dict[Any, _Yieldable]],
+ quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (),
+) -> "Union[Future[List], Future[Dict]]":
"""Runs multiple asynchronous operations in parallel.
``children`` may either be a list or a dict whose values are
@@ -738,11 +459,6 @@ def multi(children, quiet_exceptions=()):
one. All others will be logged, unless they are of types
contained in the ``quiet_exceptions`` argument.
- If any of the inputs are `YieldPoints `, the returned
- yieldable object is a `YieldPoint`. Otherwise, returns a `.Future`.
- This means that the result of `multi` can be used in a native
- coroutine if and only if all of its children can be.
-
In a ``yield``-based coroutine, it is not normally necessary to
call this function directly, since the coroutine runner will
do it automatically when a list or dict is yielded. However,
@@ -764,91 +480,22 @@ def multi(children, quiet_exceptions=()):
.. versionchanged:: 4.3
Replaced the class ``Multi`` and the function ``multi_future``
with a unified function ``multi``. Added support for yieldables
- other than `YieldPoint` and `.Future`.
+ other than ``YieldPoint`` and `.Future`.
"""
- if _contains_yieldpoint(children):
- return MultiYieldPoint(children, quiet_exceptions=quiet_exceptions)
- else:
- return multi_future(children, quiet_exceptions=quiet_exceptions)
+ return multi_future(children, quiet_exceptions=quiet_exceptions)
Multi = multi
-class MultiYieldPoint(YieldPoint):
- """Runs multiple asynchronous operations in parallel.
-
- This class is similar to `multi`, but it always creates a stack
- context even when no children require it. It is not compatible with
- native coroutines.
-
- .. versionchanged:: 4.2
- If multiple ``YieldPoints`` fail, any exceptions after the first
- (which is raised) will be logged. Added the ``quiet_exceptions``
- argument to suppress this logging for selected exception types.
-
- .. versionchanged:: 4.3
- Renamed from ``Multi`` to ``MultiYieldPoint``. The name ``Multi``
- remains as an alias for the equivalent `multi` function.
-
- .. deprecated:: 4.3
- Use `multi` instead. This class will be removed in 6.0.
- """
- def __init__(self, children, quiet_exceptions=()):
- warnings.warn("MultiYieldPoint is deprecated, use Futures instead",
- DeprecationWarning)
- self.keys = None
- if isinstance(children, dict):
- self.keys = list(children.keys())
- children = children.values()
- self.children = []
- for i in children:
- if not isinstance(i, YieldPoint):
- i = convert_yielded(i)
- if is_future(i):
- i = YieldFuture(i)
- self.children.append(i)
- assert all(isinstance(i, YieldPoint) for i in self.children)
- self.unfinished_children = set(self.children)
- self.quiet_exceptions = quiet_exceptions
-
- def start(self, runner):
- for i in self.children:
- i.start(runner)
-
- def is_ready(self):
- finished = list(itertools.takewhile(
- lambda i: i.is_ready(), self.unfinished_children))
- self.unfinished_children.difference_update(finished)
- return not self.unfinished_children
-
- def get_result(self):
- result_list = []
- exc_info = None
- for f in self.children:
- try:
- result_list.append(f.get_result())
- except Exception as e:
- if exc_info is None:
- exc_info = sys.exc_info()
- else:
- if not isinstance(e, self.quiet_exceptions):
- app_log.error("Multiple exceptions in yield list",
- exc_info=True)
- if exc_info is not None:
- raise_exc_info(exc_info)
- if self.keys is not None:
- return dict(zip(self.keys, result_list))
- else:
- return list(result_list)
-
-
-def multi_future(children, quiet_exceptions=()):
+def multi_future(
+ children: Union[List[_Yieldable], Dict[Any, _Yieldable]],
+ quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (),
+) -> "Union[Future[List], Future[Dict]]":
"""Wait for multiple asynchronous futures in parallel.
- This function is similar to `multi`, but does not support
- `YieldPoints `.
+ Since Tornado 6.0, this function is exactly the same as `multi`.
.. versionadded:: 4.0
@@ -861,49 +508,51 @@ def multi_future(children, quiet_exceptions=()):
Use `multi` instead.
"""
if isinstance(children, dict):
- keys = list(children.keys())
- children = children.values()
+ keys = list(children.keys()) # type: Optional[List]
+ children_seq = children.values() # type: Iterable
else:
keys = None
- children = list(map(convert_yielded, children))
- assert all(is_future(i) or isinstance(i, _NullFuture) for i in children)
- unfinished_children = set(children)
+ children_seq = children
+ children_futs = list(map(convert_yielded, children_seq))
+ assert all(is_future(i) or isinstance(i, _NullFuture) for i in children_futs)
+ unfinished_children = set(children_futs)
future = _create_future()
- if not children:
- future_set_result_unless_cancelled(future,
- {} if keys is not None else [])
+ if not children_futs:
+ future_set_result_unless_cancelled(future, {} if keys is not None else [])
- def callback(f):
- unfinished_children.remove(f)
+ def callback(fut: Future) -> None:
+ unfinished_children.remove(fut)
if not unfinished_children:
result_list = []
- for f in children:
+ for f in children_futs:
try:
result_list.append(f.result())
except Exception as e:
if future.done():
if not isinstance(e, quiet_exceptions):
- app_log.error("Multiple exceptions in yield list",
- exc_info=True)
+ app_log.error(
+ "Multiple exceptions in yield list", exc_info=True
+ )
else:
future_set_exc_info(future, sys.exc_info())
if not future.done():
if keys is not None:
- future_set_result_unless_cancelled(future,
- dict(zip(keys, result_list)))
+ future_set_result_unless_cancelled(
+ future, dict(zip(keys, result_list))
+ )
else:
future_set_result_unless_cancelled(future, result_list)
- listening = set()
- for f in children:
+ listening = set() # type: Set[Future]
+ for f in children_futs:
if f not in listening:
listening.add(f)
future_add_done_callback(f, callback)
return future
-def maybe_future(x):
+def maybe_future(x: Any) -> Future:
"""Converts ``x`` into a `.Future`.
If ``x`` is already a `.Future`, it is simply returned; otherwise
@@ -924,7 +573,11 @@ def maybe_future(x):
return fut
-def with_timeout(timeout, future, quiet_exceptions=()):
+def with_timeout(
+ timeout: Union[float, datetime.timedelta],
+ future: _Yieldable,
+ quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (),
+) -> Future:
"""Wraps a `.Future` (or other yieldable object) in a timeout.
Raises `tornado.util.TimeoutError` if the input future does not
@@ -933,10 +586,9 @@ def with_timeout(timeout, future, quiet_exceptions=()):
an absolute time relative to `.IOLoop.time`)
If the wrapped `.Future` fails after it has timed out, the exception
- will be logged unless it is of a type contained in ``quiet_exceptions``
- (which may be an exception type or a sequence of types).
-
- Does not support `YieldPoint` subclasses.
+ will be logged unless it is either of a type contained in
+ ``quiet_exceptions`` (which may be an exception type or a sequence of
+ types), or an ``asyncio.CancelledError``.
The wrapped `.Future` is not canceled when the timeout expires,
permitting it to be reused. `asyncio.wait_for` is similar to this
@@ -951,50 +603,55 @@ def with_timeout(timeout, future, quiet_exceptions=()):
.. versionchanged:: 4.4
Added support for yieldable objects other than `.Future`.
+ .. versionchanged:: 6.0.3
+ ``asyncio.CancelledError`` is now always considered "quiet".
+
"""
- # TODO: allow YieldPoints in addition to other yieldables?
- # Tricky to do with stack_context semantics.
- #
# It's tempting to optimize this by cancelling the input future on timeout
# instead of creating a new one, but A) we can't know if we are the only
# one waiting on the input future, so cancelling it might disrupt other
# callers and B) concurrent futures can only be cancelled while they are
# in the queue, so cancellation cannot reliably bound our waiting time.
- future = convert_yielded(future)
+ future_converted = convert_yielded(future)
result = _create_future()
- chain_future(future, result)
+ chain_future(future_converted, result)
io_loop = IOLoop.current()
- def error_callback(future):
+ def error_callback(future: Future) -> None:
try:
future.result()
+ except asyncio.CancelledError:
+ pass
except Exception as e:
if not isinstance(e, quiet_exceptions):
- app_log.error("Exception in Future %r after timeout",
- future, exc_info=True)
+ app_log.error(
+ "Exception in Future %r after timeout", future, exc_info=True
+ )
- def timeout_callback():
+ def timeout_callback() -> None:
if not result.done():
result.set_exception(TimeoutError("Timeout"))
# In case the wrapped future goes on to fail, log it.
- future_add_done_callback(future, error_callback)
- timeout_handle = io_loop.add_timeout(
- timeout, timeout_callback)
- if isinstance(future, Future):
+ future_add_done_callback(future_converted, error_callback)
+
+ timeout_handle = io_loop.add_timeout(timeout, timeout_callback)
+ if isinstance(future_converted, Future):
# We know this future will resolve on the IOLoop, so we don't
# need the extra thread-safety of IOLoop.add_future (and we also
# don't care about StackContext here.
future_add_done_callback(
- future, lambda future: io_loop.remove_timeout(timeout_handle))
+ future_converted, lambda future: io_loop.remove_timeout(timeout_handle)
+ )
else:
# concurrent.futures.Futures may resolve on any thread, so we
# need to route them back to the IOLoop.
io_loop.add_future(
- future, lambda future: io_loop.remove_timeout(timeout_handle))
+ future_converted, lambda future: io_loop.remove_timeout(timeout_handle)
+ )
return result
-def sleep(duration):
+def sleep(duration: float) -> "Future[None]":
"""Return a `.Future` that resolves after the given number of seconds.
When used with ``yield`` in a coroutine, this is a non-blocking
@@ -1009,8 +666,9 @@ def sleep(duration):
.. versionadded:: 4.1
"""
f = _create_future()
- IOLoop.current().call_later(duration,
- lambda: future_set_result_unless_cancelled(f, None))
+ IOLoop.current().call_later(
+ duration, lambda: future_set_result_unless_cancelled(f, None)
+ )
return f
@@ -1019,22 +677,28 @@ class _NullFuture(object):
It's not actually a `Future` to avoid depending on a particular event loop.
Handled as a special case in the coroutine runner.
+
+ We lie and tell the type checker that a _NullFuture is a Future so
+ we don't have to leak _NullFuture into lots of public APIs. But
+ this means that the type checker can't warn us when we're passing
+ a _NullFuture into a code path that doesn't understand what to do
+ with it.
"""
- def result(self):
+
+ def result(self) -> None:
return None
- def done(self):
+ def done(self) -> bool:
return True
# _null_future is used as a dummy value in the coroutine runner. It differs
# from moment in that moment always adds a delay of one IOLoop iteration
# while _null_future is processed as soon as possible.
-_null_future = _NullFuture()
+_null_future = typing.cast(Future, _NullFuture())
-moment = _NullFuture()
-moment.__doc__ = \
- """A special object which may be yielded to allow the IOLoop to run for
+moment = typing.cast(Future, _NullFuture())
+moment.__doc__ = """A special object which may be yielded to allow the IOLoop to run for
one iteration.
This is not needed in normal use but it can be helpful in long-running
@@ -1042,6 +706,9 @@ def done(self):
Usage: ``yield gen.moment``
+In native coroutines, the equivalent of ``yield gen.moment`` is
+``await asyncio.sleep(0)``.
+
.. versionadded:: 4.0
.. deprecated:: 4.5
@@ -1051,68 +718,33 @@ def done(self):
class Runner(object):
- """Internal implementation of `tornado.gen.engine`.
+ """Internal implementation of `tornado.gen.coroutine`.
Maintains information about pending callbacks and their results.
The results of the generator are stored in ``result_future`` (a
`.Future`)
"""
- def __init__(self, gen, result_future, first_yielded):
+
+ def __init__(
+ self,
+ ctx_run: Callable,
+ gen: "Generator[_Yieldable, Any, _T]",
+ result_future: "Future[_T]",
+ first_yielded: _Yieldable,
+ ) -> None:
+ self.ctx_run = ctx_run
self.gen = gen
self.result_future = result_future
- self.future = _null_future
- self.yield_point = None
- self.pending_callbacks = None
- self.results = None
+ self.future = _null_future # type: Union[None, Future]
self.running = False
self.finished = False
- self.had_exception = False
self.io_loop = IOLoop.current()
- # For efficiency, we do not create a stack context until we
- # reach a YieldPoint (stack contexts are required for the historical
- # semantics of YieldPoints, but not for Futures). When we have
- # done so, this field will be set and must be called at the end
- # of the coroutine.
- self.stack_context_deactivate = None
if self.handle_yield(first_yielded):
- gen = result_future = first_yielded = None
- self.run()
-
- def register_callback(self, key):
- """Adds ``key`` to the list of callbacks."""
- if self.pending_callbacks is None:
- # Lazily initialize the old-style YieldPoint data structures.
- self.pending_callbacks = set()
- self.results = {}
- if key in self.pending_callbacks:
- raise KeyReuseError("key %r is already pending" % (key,))
- self.pending_callbacks.add(key)
-
- def is_ready(self, key):
- """Returns true if a result is available for ``key``."""
- if self.pending_callbacks is None or key not in self.pending_callbacks:
- raise UnknownKeyError("key %r is not pending" % (key,))
- return key in self.results
-
- def set_result(self, key, result):
- """Sets the result for ``key`` and attempts to resume the generator."""
- self.results[key] = result
- if self.yield_point is not None and self.yield_point.is_ready():
- try:
- future_set_result_unless_cancelled(self.future,
- self.yield_point.get_result())
- except:
- future_set_exc_info(self.future, sys.exc_info())
- self.yield_point = None
- self.run()
-
- def pop_result(self, key):
- """Returns the result for ``key`` and unregisters it."""
- self.pending_callbacks.remove(key)
- return self.results.pop(key)
-
- def run(self):
+ gen = result_future = first_yielded = None # type: ignore
+ self.ctx_run(self.run)
+
+ def run(self) -> None:
"""Starts or resumes the generator, running until it reaches a
yield point that is not ready.
"""
@@ -1122,23 +754,23 @@ def run(self):
self.running = True
while True:
future = self.future
+ if future is None:
+ raise Exception("No pending future")
if not future.done():
return
self.future = None
try:
- orig_stack_contexts = stack_context._state.contexts
exc_info = None
try:
value = future.result()
except Exception:
- self.had_exception = True
exc_info = sys.exc_info()
future = None
if exc_info is not None:
try:
- yielded = self.gen.throw(*exc_info)
+ yielded = self.gen.throw(*exc_info) # type: ignore
finally:
# Break up a reference to itself
# for faster GC on CPython.
@@ -1146,33 +778,19 @@ def run(self):
else:
yielded = self.gen.send(value)
- if stack_context._state.contexts is not orig_stack_contexts:
- self.gen.throw(
- stack_context.StackContextInconsistentError(
- 'stack_context inconsistency (probably caused '
- 'by yield within a "with StackContext" block)'))
except (StopIteration, Return) as e:
self.finished = True
self.future = _null_future
- if self.pending_callbacks and not self.had_exception:
- # If we ran cleanly without waiting on all callbacks
- # raise an error (really more of a warning). If we
- # had an exception then some callbacks may have been
- # orphaned, so skip the check in that case.
- raise LeakedCallbackError(
- "finished without waiting for callbacks %r" %
- self.pending_callbacks)
- future_set_result_unless_cancelled(self.result_future,
- _value_from_stopiteration(e))
- self.result_future = None
- self._deactivate_stack_context()
+ future_set_result_unless_cancelled(
+ self.result_future, _value_from_stopiteration(e)
+ )
+ self.result_future = None # type: ignore
return
except Exception:
self.finished = True
self.future = _null_future
future_set_exc_info(self.result_future, sys.exc_info())
- self.result_future = None
- self._deactivate_stack_context()
+ self.result_future = None # type: ignore
return
if not self.handle_yield(yielded):
return
@@ -1180,164 +798,56 @@ def run(self):
finally:
self.running = False
- def handle_yield(self, yielded):
- # Lists containing YieldPoints require stack contexts;
- # other lists are handled in convert_yielded.
- if _contains_yieldpoint(yielded):
- yielded = multi(yielded)
-
- if isinstance(yielded, YieldPoint):
- # YieldPoints are too closely coupled to the Runner to go
- # through the generic convert_yielded mechanism.
+ def handle_yield(self, yielded: _Yieldable) -> bool:
+ try:
+ self.future = convert_yielded(yielded)
+ except BadYieldError:
self.future = Future()
-
- def start_yield_point():
- try:
- yielded.start(self)
- if yielded.is_ready():
- future_set_result_unless_cancelled(self.future, yielded.get_result())
- else:
- self.yield_point = yielded
- except Exception:
- self.future = Future()
- future_set_exc_info(self.future, sys.exc_info())
-
- if self.stack_context_deactivate is None:
- # Start a stack context if this is the first
- # YieldPoint we've seen.
- with stack_context.ExceptionStackContext(
- self.handle_exception) as deactivate:
- self.stack_context_deactivate = deactivate
-
- def cb():
- start_yield_point()
- self.run()
- self.io_loop.add_callback(cb)
- return False
- else:
- start_yield_point()
- else:
- try:
- self.future = convert_yielded(yielded)
- except BadYieldError:
- self.future = Future()
- future_set_exc_info(self.future, sys.exc_info())
+ future_set_exc_info(self.future, sys.exc_info())
if self.future is moment:
- self.io_loop.add_callback(self.run)
+ self.io_loop.add_callback(self.ctx_run, self.run)
return False
+ elif self.future is None:
+ raise Exception("no pending future")
elif not self.future.done():
- def inner(f):
+
+ def inner(f: Any) -> None:
# Break a reference cycle to speed GC.
- f = None # noqa
- self.run()
- self.io_loop.add_future(
- self.future, inner)
+ f = None # noqa: F841
+ self.ctx_run(self.run)
+
+ self.io_loop.add_future(self.future, inner)
return False
return True
- def result_callback(self, key):
- return stack_context.wrap(_argument_adapter(
- functools.partial(self.set_result, key)))
-
- def handle_exception(self, typ, value, tb):
+ def handle_exception(
+ self, typ: Type[Exception], value: Exception, tb: types.TracebackType
+ ) -> bool:
if not self.running and not self.finished:
self.future = Future()
future_set_exc_info(self.future, (typ, value, tb))
- self.run()
+ self.ctx_run(self.run)
return True
else:
return False
- def _deactivate_stack_context(self):
- if self.stack_context_deactivate is not None:
- self.stack_context_deactivate()
- self.stack_context_deactivate = None
-
-
-Arguments = collections.namedtuple('Arguments', ['args', 'kwargs'])
-
-
-def _argument_adapter(callback):
- """Returns a function that when invoked runs ``callback`` with one arg.
-
- If the function returned by this function is called with exactly
- one argument, that argument is passed to ``callback``. Otherwise
- the args tuple and kwargs dict are wrapped in an `Arguments` object.
- """
- def wrapper(*args, **kwargs):
- if kwargs or len(args) > 1:
- callback(Arguments(args, kwargs))
- elif args:
- callback(args[0])
- else:
- callback(None)
- return wrapper
-
# Convert Awaitables into Futures.
try:
- import asyncio
-except ImportError:
- # Py2-compatible version for use with Cython.
- # Copied from PEP 380.
- @coroutine
- def _wrap_awaitable(x):
- if hasattr(x, '__await__'):
- _i = x.__await__()
- else:
- _i = iter(x)
- try:
- _y = next(_i)
- except StopIteration as _e:
- _r = _value_from_stopiteration(_e)
- else:
- while 1:
- try:
- _s = yield _y
- except GeneratorExit as _e:
- try:
- _m = _i.close
- except AttributeError:
- pass
- else:
- _m()
- raise _e
- except BaseException as _e:
- _x = sys.exc_info()
- try:
- _m = _i.throw
- except AttributeError:
- raise _e
- else:
- try:
- _y = _m(*_x)
- except StopIteration as _e:
- _r = _value_from_stopiteration(_e)
- break
- else:
- try:
- if _s is None:
- _y = next(_i)
- else:
- _y = _i.send(_s)
- except StopIteration as _e:
- _r = _value_from_stopiteration(_e)
- break
- raise Return(_r)
-else:
- try:
- _wrap_awaitable = asyncio.ensure_future
- except AttributeError:
- # asyncio.ensure_future was introduced in Python 3.4.4, but
- # Debian jessie still ships with 3.4.2 so try the old name.
- _wrap_awaitable = getattr(asyncio, 'async')
+ _wrap_awaitable = asyncio.ensure_future
+except AttributeError:
+ # asyncio.ensure_future was introduced in Python 3.4.4, but
+ # Debian jessie still ships with 3.4.2 so try the old name.
+ _wrap_awaitable = getattr(asyncio, "async")
-def convert_yielded(yielded):
+def convert_yielded(yielded: _Yieldable) -> Future:
"""Convert a yielded object into a `.Future`.
- The default implementation accepts lists, dictionaries, and Futures.
+ The default implementation accepts lists, dictionaries, and
+ Futures. This has the side effect of starting any coroutines that
+ did not start themselves, similar to `asyncio.ensure_future`.
If the `~functools.singledispatch` library is available, this function
may be extended to support additional types. For example::
@@ -1347,21 +857,20 @@ def _(asyncio_future):
return tornado.platform.asyncio.to_tornado_future(asyncio_future)
.. versionadded:: 4.1
+
"""
- # Lists and dicts containing YieldPoints were handled earlier.
if yielded is None or yielded is moment:
return moment
elif yielded is _null_future:
return _null_future
elif isinstance(yielded, (list, dict)):
- return multi(yielded)
+ return multi(yielded) # type: ignore
elif is_future(yielded):
- return yielded
+ return typing.cast(Future, yielded)
elif isawaitable(yielded):
- return _wrap_awaitable(yielded)
+ return _wrap_awaitable(yielded) # type: ignore
else:
raise BadYieldError("yielded unknown object %r" % (yielded,))
-if singledispatch is not None:
- convert_yielded = singledispatch(convert_yielded)
+convert_yielded = singledispatch(convert_yielded)
diff --git a/tornado/http1connection.py b/tornado/http1connection.py
index 1c5eadf84c..72088d6023 100644
--- a/tornado/http1connection.py
+++ b/tornado/http1connection.py
@@ -18,23 +18,29 @@
.. versionadded:: 4.0
"""
-from __future__ import absolute_import, division, print_function
-
+import asyncio
+import logging
import re
+import types
-from tornado.concurrent import (Future, future_add_done_callback,
- future_set_result_unless_cancelled)
+from tornado.concurrent import (
+ Future,
+ future_add_done_callback,
+ future_set_result_unless_cancelled,
+)
from tornado.escape import native_str, utf8
from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado.log import gen_log, app_log
-from tornado import stack_context
-from tornado.util import GzipDecompressor, PY3
+from tornado.util import GzipDecompressor
+
+
+from typing import cast, Optional, Type, Awaitable, Callable, Union, Tuple
class _QuietException(Exception):
- def __init__(self):
+ def __init__(self) -> None:
pass
@@ -43,24 +49,38 @@ class _ExceptionLoggingContext(object):
log any exceptions with the given logger. Any exceptions caught are
converted to _QuietException
"""
- def __init__(self, logger):
+
+ def __init__(self, logger: logging.Logger) -> None:
self.logger = logger
- def __enter__(self):
+ def __enter__(self) -> None:
pass
- def __exit__(self, typ, value, tb):
+ def __exit__(
+ self,
+ typ: "Optional[Type[BaseException]]",
+ value: Optional[BaseException],
+ tb: types.TracebackType,
+ ) -> None:
if value is not None:
+ assert typ is not None
self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
raise _QuietException
class HTTP1ConnectionParameters(object):
- """Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.
- """
- def __init__(self, no_keep_alive=False, chunk_size=None,
- max_header_size=None, header_timeout=None, max_body_size=None,
- body_timeout=None, decompress=False):
+ """Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`."""
+
+ def __init__(
+ self,
+ no_keep_alive: bool = False,
+ chunk_size: Optional[int] = None,
+ max_header_size: Optional[int] = None,
+ header_timeout: Optional[float] = None,
+ max_body_size: Optional[int] = None,
+ body_timeout: Optional[float] = None,
+ decompress: bool = False,
+ ) -> None:
"""
:arg bool no_keep_alive: If true, always close the connection after
one request.
@@ -87,7 +107,14 @@ class HTTP1Connection(httputil.HTTPConnection):
This class can be on its own for clients, or via `HTTP1ServerConnection`
for servers.
"""
- def __init__(self, stream, is_client, params=None, context=None):
+
+ def __init__(
+ self,
+ stream: iostream.IOStream,
+ is_client: bool,
+ params: Optional[HTTP1ConnectionParameters] = None,
+ context: Optional[object] = None,
+ ) -> None:
"""
:arg stream: an `.IOStream`
:arg bool is_client: client or server
@@ -104,8 +131,11 @@ def __init__(self, stream, is_client, params=None, context=None):
self.no_keep_alive = params.no_keep_alive
# The body limits can be altered by the delegate, so save them
# here instead of just referencing self.params later.
- self._max_body_size = (self.params.max_body_size or
- self.stream.max_buffer_size)
+ self._max_body_size = (
+ self.params.max_body_size
+ if self.params.max_body_size is not None
+ else self.stream.max_buffer_size
+ )
self._body_timeout = self.params.body_timeout
# _write_finished is set to True when finish() has been called,
# i.e. there will be no more data sent. Data may still be in the
@@ -115,7 +145,7 @@ def __init__(self, stream, is_client, params=None, context=None):
self._read_finished = False
# _finish_future resolves when all data has been written and flushed
# to the IOStream.
- self._finish_future = Future()
+ self._finish_future = Future() # type: Future[None]
# If true, the connection should be closed after this request
# (after the response has been written in the server side,
# and after it has been read in the client)
@@ -124,18 +154,18 @@ def __init__(self, stream, is_client, params=None, context=None):
# Save the start lines after we read or write them; they
# affect later processing (e.g. 304 responses and HEAD methods
# have content-length but no bodies)
- self._request_start_line = None
- self._response_start_line = None
- self._request_headers = None
+ self._request_start_line = None # type: Optional[httputil.RequestStartLine]
+ self._response_start_line = None # type: Optional[httputil.ResponseStartLine]
+ self._request_headers = None # type: Optional[httputil.HTTPHeaders]
# True if we are writing output with chunked encoding.
- self._chunking_output = None
+ self._chunking_output = False
# While reading a body with a content-length, this is the
# amount left to read.
- self._expected_content_remaining = None
+ self._expected_content_remaining = None # type: Optional[int]
# A Future for our outgoing writes, returned by IOStream.write.
- self._pending_write = None
+ self._pending_write = None # type: Optional[Future[None]]
- def read_response(self, delegate):
+ def read_response(self, delegate: httputil.HTTPMessageDelegate) -> Awaitable[bool]:
"""Read a single HTTP response.
Typical client-mode usage is to write a request using `write_headers`,
@@ -143,55 +173,64 @@ def read_response(self, delegate):
:arg delegate: a `.HTTPMessageDelegate`
- Returns a `.Future` that resolves to None after the full response has
- been read.
+ Returns a `.Future` that resolves to a bool after the full response has
+ been read. The result is true if the stream is still open.
"""
if self.params.decompress:
delegate = _GzipMessageDelegate(delegate, self.params.chunk_size)
return self._read_message(delegate)
- @gen.coroutine
- def _read_message(self, delegate):
+ async def _read_message(self, delegate: httputil.HTTPMessageDelegate) -> bool:
need_delegate_close = False
try:
header_future = self.stream.read_until_regex(
- b"\r?\n\r?\n",
- max_bytes=self.params.max_header_size)
+ b"\r?\n\r?\n", max_bytes=self.params.max_header_size
+ )
if self.params.header_timeout is None:
- header_data = yield header_future
+ header_data = await header_future
else:
try:
- header_data = yield gen.with_timeout(
+ header_data = await gen.with_timeout(
self.stream.io_loop.time() + self.params.header_timeout,
header_future,
- quiet_exceptions=iostream.StreamClosedError)
+ quiet_exceptions=iostream.StreamClosedError,
+ )
except gen.TimeoutError:
self.close()
- raise gen.Return(False)
- start_line, headers = self._parse_headers(header_data)
+ return False
+ start_line_str, headers = self._parse_headers(header_data)
if self.is_client:
- start_line = httputil.parse_response_start_line(start_line)
- self._response_start_line = start_line
+ resp_start_line = httputil.parse_response_start_line(start_line_str)
+ self._response_start_line = resp_start_line
+ start_line = (
+ resp_start_line
+ ) # type: Union[httputil.RequestStartLine, httputil.ResponseStartLine]
+ # TODO: this will need to change to support client-side keepalive
+ self._disconnect_on_finish = False
else:
- start_line = httputil.parse_request_start_line(start_line)
- self._request_start_line = start_line
+ req_start_line = httputil.parse_request_start_line(start_line_str)
+ self._request_start_line = req_start_line
self._request_headers = headers
-
- self._disconnect_on_finish = not self._can_keep_alive(
- start_line, headers)
+ start_line = req_start_line
+ self._disconnect_on_finish = not self._can_keep_alive(
+ req_start_line, headers
+ )
need_delegate_close = True
with _ExceptionLoggingContext(app_log):
- header_future = delegate.headers_received(start_line, headers)
- if header_future is not None:
- yield header_future
+ header_recv_future = delegate.headers_received(start_line, headers)
+ if header_recv_future is not None:
+ await header_recv_future
if self.stream is None:
# We've been detached.
need_delegate_close = False
- raise gen.Return(False)
+ return False
skip_body = False
if self.is_client:
- if (self._request_start_line is not None and
- self._request_start_line.method == 'HEAD'):
+ assert isinstance(start_line, httputil.ResponseStartLine)
+ if (
+ self._request_start_line is not None
+ and self._request_start_line.method == "HEAD"
+ ):
skip_body = True
code = start_line.code
if code == 304:
@@ -199,37 +238,37 @@ def _read_message(self, delegate):
# but do not actually have a body.
# http://tools.ietf.org/html/rfc7230#section-3.3
skip_body = True
- if code >= 100 and code < 200:
+ if 100 <= code < 200:
# 1xx responses should never indicate the presence of
# a body.
- if ('Content-Length' in headers or
- 'Transfer-Encoding' in headers):
+ if "Content-Length" in headers or "Transfer-Encoding" in headers:
raise httputil.HTTPInputError(
- "Response code %d cannot have body" % code)
+ "Response code %d cannot have body" % code
+ )
# TODO: client delegates will get headers_received twice
# in the case of a 100-continue. Document or change?
- yield self._read_message(delegate)
+ await self._read_message(delegate)
else:
- if (headers.get("Expect") == "100-continue" and
- not self._write_finished):
+ if headers.get("Expect") == "100-continue" and not self._write_finished:
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
if not skip_body:
body_future = self._read_body(
- start_line.code if self.is_client else 0, headers, delegate)
+ resp_start_line.code if self.is_client else 0, headers, delegate
+ )
if body_future is not None:
if self._body_timeout is None:
- yield body_future
+ await body_future
else:
try:
- yield gen.with_timeout(
+ await gen.with_timeout(
self.stream.io_loop.time() + self._body_timeout,
body_future,
- quiet_exceptions=iostream.StreamClosedError)
+ quiet_exceptions=iostream.StreamClosedError,
+ )
except gen.TimeoutError:
- gen_log.info("Timeout reading body from %s",
- self.context)
+ gen_log.info("Timeout reading body from %s", self.context)
self.stream.close()
- raise gen.Return(False)
+ return False
self._read_finished = True
if not self._write_finished or self.is_client:
need_delegate_close = False
@@ -238,57 +277,58 @@ def _read_message(self, delegate):
# If we're waiting for the application to produce an asynchronous
# response, and we're not detached, register a close callback
# on the stream (we didn't need one while we were reading)
- if (not self._finish_future.done() and
- self.stream is not None and
- not self.stream.closed()):
+ if (
+ not self._finish_future.done()
+ and self.stream is not None
+ and not self.stream.closed()
+ ):
self.stream.set_close_callback(self._on_connection_close)
- yield self._finish_future
+ await self._finish_future
if self.is_client and self._disconnect_on_finish:
self.close()
if self.stream is None:
- raise gen.Return(False)
+ return False
except httputil.HTTPInputError as e:
- gen_log.info("Malformed HTTP message from %s: %s",
- self.context, e)
+ gen_log.info("Malformed HTTP message from %s: %s", self.context, e)
if not self.is_client:
- yield self.stream.write(b'HTTP/1.1 400 Bad Request\r\n\r\n')
+ await self.stream.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
self.close()
- raise gen.Return(False)
+ return False
finally:
if need_delegate_close:
with _ExceptionLoggingContext(app_log):
delegate.on_connection_close()
- header_future = None
+ header_future = None # type: ignore
self._clear_callbacks()
- raise gen.Return(True)
+ return True
- def _clear_callbacks(self):
+ def _clear_callbacks(self) -> None:
"""Clears the callback attributes.
This allows the request handler to be garbage collected more
quickly in CPython by breaking up reference cycles.
"""
self._write_callback = None
- self._write_future = None
- self._close_callback = None
+ self._write_future = None # type: Optional[Future[None]]
+ self._close_callback = None # type: Optional[Callable[[], None]]
if self.stream is not None:
self.stream.set_close_callback(None)
- def set_close_callback(self, callback):
+ def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None:
"""Sets a callback that will be run when the connection is closed.
Note that this callback is slightly different from
`.HTTPMessageDelegate.on_connection_close`: The
`.HTTPMessageDelegate` method is called when the connection is
- closed while recieving a message. This callback is used when
+ closed while receiving a message. This callback is used when
there is not an active delegate (for example, on the server
side this callback is used if the client closes the connection
after sending its request but before receiving all the
response.
"""
- self._close_callback = stack_context.wrap(callback)
+ self._close_callback = callback
- def _on_connection_close(self):
+ def _on_connection_close(self) -> None:
# Note that this callback is only registered on the IOStream
# when we have finished reading the request and are waiting for
# the application to produce its response.
@@ -300,14 +340,14 @@ def _on_connection_close(self):
future_set_result_unless_cancelled(self._finish_future, None)
self._clear_callbacks()
- def close(self):
+ def close(self) -> None:
if self.stream is not None:
self.stream.close()
self._clear_callbacks()
if not self._finish_future.done():
future_set_result_unless_cancelled(self._finish_future, None)
- def detach(self):
+ def detach(self) -> iostream.IOStream:
"""Take control of the underlying stream.
Returns the underlying `.IOStream` object and stops all further
@@ -317,108 +357,127 @@ def detach(self):
"""
self._clear_callbacks()
stream = self.stream
- self.stream = None
+ self.stream = None # type: ignore
if not self._finish_future.done():
future_set_result_unless_cancelled(self._finish_future, None)
return stream
- def set_body_timeout(self, timeout):
+ def set_body_timeout(self, timeout: float) -> None:
"""Sets the body timeout for a single request.
Overrides the value from `.HTTP1ConnectionParameters`.
"""
self._body_timeout = timeout
- def set_max_body_size(self, max_body_size):
+ def set_max_body_size(self, max_body_size: int) -> None:
"""Sets the body size limit for a single request.
Overrides the value from `.HTTP1ConnectionParameters`.
"""
self._max_body_size = max_body_size
- def write_headers(self, start_line, headers, chunk=None, callback=None):
+ def write_headers(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ chunk: Optional[bytes] = None,
+ ) -> "Future[None]":
"""Implements `.HTTPConnection.write_headers`."""
lines = []
if self.is_client:
+ assert isinstance(start_line, httputil.RequestStartLine)
self._request_start_line = start_line
- lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1])))
+ lines.append(utf8("%s %s HTTP/1.1" % (start_line[0], start_line[1])))
# Client requests with a non-empty body must have either a
# Content-Length or a Transfer-Encoding.
self._chunking_output = (
- start_line.method in ('POST', 'PUT', 'PATCH') and
- 'Content-Length' not in headers and
- 'Transfer-Encoding' not in headers)
+ start_line.method in ("POST", "PUT", "PATCH")
+ and "Content-Length" not in headers
+ and (
+ "Transfer-Encoding" not in headers
+ or headers["Transfer-Encoding"] == "chunked"
+ )
+ )
else:
+ assert isinstance(start_line, httputil.ResponseStartLine)
+ assert self._request_start_line is not None
+ assert self._request_headers is not None
self._response_start_line = start_line
- lines.append(utf8('HTTP/1.1 %d %s' % (start_line[1], start_line[2])))
+ lines.append(utf8("HTTP/1.1 %d %s" % (start_line[1], start_line[2])))
self._chunking_output = (
# TODO: should this use
# self._request_start_line.version or
# start_line.version?
- self._request_start_line.version == 'HTTP/1.1' and
+ self._request_start_line.version == "HTTP/1.1"
+ # Omit payload header field for HEAD request.
+ and self._request_start_line.method != "HEAD"
# 1xx, 204 and 304 responses have no body (not even a zero-length
# body), and so should not have either Content-Length or
# Transfer-Encoding headers.
- start_line.code not in (204, 304) and
- (start_line.code < 100 or start_line.code >= 200) and
+ and start_line.code not in (204, 304)
+ and (start_line.code < 100 or start_line.code >= 200)
# No need to chunk the output if a Content-Length is specified.
- 'Content-Length' not in headers and
+ and "Content-Length" not in headers
# Applications are discouraged from touching Transfer-Encoding,
# but if they do, leave it alone.
- 'Transfer-Encoding' not in headers)
+ and "Transfer-Encoding" not in headers
+ )
# If connection to a 1.1 client will be closed, inform client
- if (self._request_start_line.version == 'HTTP/1.1' and self._disconnect_on_finish):
- headers['Connection'] = 'close'
+ if (
+ self._request_start_line.version == "HTTP/1.1"
+ and self._disconnect_on_finish
+ ):
+ headers["Connection"] = "close"
# If a 1.0 client asked for keep-alive, add the header.
- if (self._request_start_line.version == 'HTTP/1.0' and
- self._request_headers.get('Connection', '').lower() == 'keep-alive'):
- headers['Connection'] = 'Keep-Alive'
+ if (
+ self._request_start_line.version == "HTTP/1.0"
+ and self._request_headers.get("Connection", "").lower() == "keep-alive"
+ ):
+ headers["Connection"] = "Keep-Alive"
if self._chunking_output:
- headers['Transfer-Encoding'] = 'chunked'
- if (not self.is_client and
- (self._request_start_line.method == 'HEAD' or
- start_line.code == 304)):
+ headers["Transfer-Encoding"] = "chunked"
+ if not self.is_client and (
+ self._request_start_line.method == "HEAD"
+ or cast(httputil.ResponseStartLine, start_line).code == 304
+ ):
self._expected_content_remaining = 0
- elif 'Content-Length' in headers:
- self._expected_content_remaining = int(headers['Content-Length'])
+ elif "Content-Length" in headers:
+ self._expected_content_remaining = int(headers["Content-Length"])
else:
self._expected_content_remaining = None
# TODO: headers are supposed to be of type str, but we still have some
# cases that let bytes slip through. Remove these native_str calls when those
# are fixed.
- header_lines = (native_str(n) + ": " + native_str(v) for n, v in headers.get_all())
- if PY3:
- lines.extend(l.encode('latin1') for l in header_lines)
- else:
- lines.extend(header_lines)
+ header_lines = (
+ native_str(n) + ": " + native_str(v) for n, v in headers.get_all()
+ )
+ lines.extend(line.encode("latin1") for line in header_lines)
for line in lines:
- if b'\n' in line:
- raise ValueError('Newline in header: ' + repr(line))
+ if b"\n" in line:
+ raise ValueError("Newline in header: " + repr(line))
future = None
if self.stream.closed():
future = self._write_future = Future()
future.set_exception(iostream.StreamClosedError())
future.exception()
else:
- if callback is not None:
- self._write_callback = stack_context.wrap(callback)
- else:
- future = self._write_future = Future()
+ future = self._write_future = Future()
data = b"\r\n".join(lines) + b"\r\n\r\n"
if chunk:
data += self._format_chunk(chunk)
self._pending_write = self.stream.write(data)
- self._pending_write.add_done_callback(self._on_write_complete)
+ future_add_done_callback(self._pending_write, self._on_write_complete)
return future
- def _format_chunk(self, chunk):
+ def _format_chunk(self, chunk: bytes) -> bytes:
if self._expected_content_remaining is not None:
self._expected_content_remaining -= len(chunk)
if self._expected_content_remaining < 0:
# Close the stream now to stop further framing errors.
self.stream.close()
raise httputil.HTTPOutputError(
- "Tried to write more data than Content-Length")
+ "Tried to write more data than Content-Length"
+ )
if self._chunking_output and chunk:
# Don't write out empty chunks because that means END-OF-STREAM
# with chunked encoding
@@ -426,7 +485,7 @@ def _format_chunk(self, chunk):
else:
return chunk
- def write(self, chunk, callback=None):
+ def write(self, chunk: bytes) -> "Future[None]":
"""Implements `.HTTPConnection.write`.
For backwards compatibility it is allowed but deprecated to
@@ -439,23 +498,23 @@ def write(self, chunk, callback=None):
self._write_future.set_exception(iostream.StreamClosedError())
self._write_future.exception()
else:
- if callback is not None:
- self._write_callback = stack_context.wrap(callback)
- else:
- future = self._write_future = Future()
+ future = self._write_future = Future()
self._pending_write = self.stream.write(self._format_chunk(chunk))
- self._pending_write.add_done_callback(self._on_write_complete)
+ future_add_done_callback(self._pending_write, self._on_write_complete)
return future
- def finish(self):
+ def finish(self) -> None:
"""Implements `.HTTPConnection.finish`."""
- if (self._expected_content_remaining is not None and
- self._expected_content_remaining != 0 and
- not self.stream.closed()):
+ if (
+ self._expected_content_remaining is not None
+ and self._expected_content_remaining != 0
+ and not self.stream.closed()
+ ):
self.stream.close()
raise httputil.HTTPOutputError(
- "Tried to write %d bytes less than Content-Length" %
- self._expected_content_remaining)
+ "Tried to write %d bytes less than Content-Length"
+ % self._expected_content_remaining
+ )
if self._chunking_output:
if not self.stream.closed():
self._pending_write = self.stream.write(b"0\r\n\r\n")
@@ -476,7 +535,7 @@ def finish(self):
else:
future_add_done_callback(self._pending_write, self._finish_request)
- def _on_write_complete(self, future):
+ def _on_write_complete(self, future: "Future[None]") -> None:
exc = future.exception()
if exc is not None and not isinstance(exc, iostream.StreamClosedError):
future.result()
@@ -489,7 +548,9 @@ def _on_write_complete(self, future):
self._write_future = None
future_set_result_unless_cancelled(future, None)
- def _can_keep_alive(self, start_line, headers):
+ def _can_keep_alive(
+ self, start_line: httputil.RequestStartLine, headers: httputil.HTTPHeaders
+ ) -> bool:
if self.params.no_keep_alive:
return False
connection_header = headers.get("Connection")
@@ -497,15 +558,17 @@ def _can_keep_alive(self, start_line, headers):
connection_header = connection_header.lower()
if start_line.version == "HTTP/1.1":
return connection_header != "close"
- elif ("Content-Length" in headers or
- headers.get("Transfer-Encoding", "").lower() == "chunked" or
- getattr(start_line, 'method', None) in ("HEAD", "GET")):
+ elif (
+ "Content-Length" in headers
+ or headers.get("Transfer-Encoding", "").lower() == "chunked"
+ or getattr(start_line, "method", None) in ("HEAD", "GET")
+ ):
# start_line may be a request or response start line; only
# the former has a method attribute.
return connection_header == "keep-alive"
return False
- def _finish_request(self, future):
+ def _finish_request(self, future: "Optional[Future[None]]") -> None:
self._clear_callbacks()
if not self.is_client and self._disconnect_on_finish:
self.close()
@@ -516,45 +579,54 @@ def _finish_request(self, future):
if not self._finish_future.done():
future_set_result_unless_cancelled(self._finish_future, None)
- def _parse_headers(self, data):
+ def _parse_headers(self, data: bytes) -> Tuple[str, httputil.HTTPHeaders]:
# The lstrip removes newlines that some implementations sometimes
# insert between messages of a reused connection. Per RFC 7230,
# we SHOULD ignore at least one empty line before the request.
# http://tools.ietf.org/html/rfc7230#section-3.5
- data = native_str(data.decode('latin1')).lstrip("\r\n")
+ data_str = native_str(data.decode("latin1")).lstrip("\r\n")
# RFC 7230 section allows for both CRLF and bare LF.
- eol = data.find("\n")
- start_line = data[:eol].rstrip("\r")
- headers = httputil.HTTPHeaders.parse(data[eol:])
+ eol = data_str.find("\n")
+ start_line = data_str[:eol].rstrip("\r")
+ headers = httputil.HTTPHeaders.parse(data_str[eol:])
return start_line, headers
- def _read_body(self, code, headers, delegate):
+ def _read_body(
+ self,
+ code: int,
+ headers: httputil.HTTPHeaders,
+ delegate: httputil.HTTPMessageDelegate,
+ ) -> Optional[Awaitable[None]]:
if "Content-Length" in headers:
if "Transfer-Encoding" in headers:
# Response cannot contain both Content-Length and
# Transfer-Encoding headers.
# http://tools.ietf.org/html/rfc7230#section-3.3.3
raise httputil.HTTPInputError(
- "Response with both Transfer-Encoding and Content-Length")
+ "Response with both Transfer-Encoding and Content-Length"
+ )
if "," in headers["Content-Length"]:
# Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
- pieces = re.split(r',\s*', headers["Content-Length"])
+ pieces = re.split(r",\s*", headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise httputil.HTTPInputError(
- "Multiple unequal Content-Lengths: %r" %
- headers["Content-Length"])
+ "Multiple unequal Content-Lengths: %r"
+ % headers["Content-Length"]
+ )
headers["Content-Length"] = pieces[0]
try:
- content_length = int(headers["Content-Length"])
+ content_length = int(headers["Content-Length"]) # type: Optional[int]
except ValueError:
# Handles non-integer Content-Length value.
raise httputil.HTTPInputError(
- "Only integer Content-Length is allowed: %s" % headers["Content-Length"])
+ "Only integer Content-Length is allowed: %s"
+ % headers["Content-Length"]
+ )
- if content_length > self._max_body_size:
+ if cast(int, content_length) > self._max_body_size:
raise httputil.HTTPInputError("Content-Length too long")
else:
content_length = None
@@ -563,10 +635,10 @@ def _read_body(self, code, headers, delegate):
# This response code is not allowed to have a non-empty body,
# and has an implicit length of zero instead of read-until-close.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
- if ("Transfer-Encoding" in headers or
- content_length not in (None, 0)):
+ if "Transfer-Encoding" in headers or content_length not in (None, 0):
raise httputil.HTTPInputError(
- "Response with code %d should not have body" % code)
+ "Response with code %d should not have body" % code
+ )
content_length = 0
if content_length is not None:
@@ -577,110 +649,133 @@ def _read_body(self, code, headers, delegate):
return self._read_body_until_close(delegate)
return None
- @gen.coroutine
- def _read_fixed_body(self, content_length, delegate):
+ async def _read_fixed_body(
+ self, content_length: int, delegate: httputil.HTTPMessageDelegate
+ ) -> None:
while content_length > 0:
- body = yield self.stream.read_bytes(
- min(self.params.chunk_size, content_length), partial=True)
+ body = await self.stream.read_bytes(
+ min(self.params.chunk_size, content_length), partial=True
+ )
content_length -= len(body)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
ret = delegate.data_received(body)
if ret is not None:
- yield ret
+ await ret
- @gen.coroutine
- def _read_chunked_body(self, delegate):
+ async def _read_chunked_body(self, delegate: httputil.HTTPMessageDelegate) -> None:
# TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
total_size = 0
while True:
- chunk_len = yield self.stream.read_until(b"\r\n", max_bytes=64)
- chunk_len = int(chunk_len.strip(), 16)
+ chunk_len_str = await self.stream.read_until(b"\r\n", max_bytes=64)
+ chunk_len = int(chunk_len_str.strip(), 16)
if chunk_len == 0:
- crlf = yield self.stream.read_bytes(2)
- if crlf != b'\r\n':
- raise httputil.HTTPInputError("improperly terminated chunked request")
+ crlf = await self.stream.read_bytes(2)
+ if crlf != b"\r\n":
+ raise httputil.HTTPInputError(
+ "improperly terminated chunked request"
+ )
return
total_size += chunk_len
if total_size > self._max_body_size:
raise httputil.HTTPInputError("chunked body too large")
bytes_to_read = chunk_len
while bytes_to_read:
- chunk = yield self.stream.read_bytes(
- min(bytes_to_read, self.params.chunk_size), partial=True)
+ chunk = await self.stream.read_bytes(
+ min(bytes_to_read, self.params.chunk_size), partial=True
+ )
bytes_to_read -= len(chunk)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
ret = delegate.data_received(chunk)
if ret is not None:
- yield ret
+ await ret
# chunk ends with \r\n
- crlf = yield self.stream.read_bytes(2)
+ crlf = await self.stream.read_bytes(2)
assert crlf == b"\r\n"
- @gen.coroutine
- def _read_body_until_close(self, delegate):
- body = yield self.stream.read_until_close()
+ async def _read_body_until_close(
+ self, delegate: httputil.HTTPMessageDelegate
+ ) -> None:
+ body = await self.stream.read_until_close()
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
- delegate.data_received(body)
+ ret = delegate.data_received(body)
+ if ret is not None:
+ await ret
class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
- """Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``.
- """
- def __init__(self, delegate, chunk_size):
+ """Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``."""
+
+ def __init__(self, delegate: httputil.HTTPMessageDelegate, chunk_size: int) -> None:
self._delegate = delegate
self._chunk_size = chunk_size
- self._decompressor = None
+ self._decompressor = None # type: Optional[GzipDecompressor]
- def headers_received(self, start_line, headers):
+ def headers_received(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> Optional[Awaitable[None]]:
if headers.get("Content-Encoding") == "gzip":
self._decompressor = GzipDecompressor()
# Downstream delegates will only see uncompressed data,
# so rename the content-encoding header.
# (but note that curl_httpclient doesn't do this).
- headers.add("X-Consumed-Content-Encoding",
- headers["Content-Encoding"])
+ headers.add("X-Consumed-Content-Encoding", headers["Content-Encoding"])
del headers["Content-Encoding"]
return self._delegate.headers_received(start_line, headers)
- @gen.coroutine
- def data_received(self, chunk):
+ async def data_received(self, chunk: bytes) -> None:
if self._decompressor:
compressed_data = chunk
while compressed_data:
decompressed = self._decompressor.decompress(
- compressed_data, self._chunk_size)
+ compressed_data, self._chunk_size
+ )
if decompressed:
ret = self._delegate.data_received(decompressed)
if ret is not None:
- yield ret
+ await ret
compressed_data = self._decompressor.unconsumed_tail
+ if compressed_data and not decompressed:
+ raise httputil.HTTPInputError(
+ "encountered unconsumed gzip data without making progress"
+ )
else:
ret = self._delegate.data_received(chunk)
if ret is not None:
- yield ret
+ await ret
- def finish(self):
+ def finish(self) -> None:
if self._decompressor is not None:
tail = self._decompressor.flush()
if tail:
- # I believe the tail will always be empty (i.e.
- # decompress will return all it can). The purpose
- # of the flush call is to detect errors such
- # as truncated input. But in case it ever returns
- # anything, treat it as an extra chunk
- self._delegate.data_received(tail)
+ # The tail should always be empty: decompress returned
+ # all that it can in data_received and the only
+ # purpose of the flush call is to detect errors such
+ # as truncated input. If we did legitimately get a new
+ # chunk at this point we'd need to change the
+ # interface to make finish() a coroutine.
+ raise ValueError(
+ "decompressor.flush returned data; possible truncated input"
+ )
return self._delegate.finish()
- def on_connection_close(self):
+ def on_connection_close(self) -> None:
return self._delegate.on_connection_close()
class HTTP1ServerConnection(object):
"""An HTTP/1.x server."""
- def __init__(self, stream, params=None, context=None):
+
+ def __init__(
+ self,
+ stream: iostream.IOStream,
+ params: Optional[HTTP1ConnectionParameters] = None,
+ context: Optional[object] = None,
+ ) -> None:
"""
:arg stream: an `.IOStream`
:arg params: a `.HTTP1ConnectionParameters` or None
@@ -692,10 +787,9 @@ def __init__(self, stream, params=None, context=None):
params = HTTP1ConnectionParameters()
self.params = params
self.context = context
- self._serving_future = None
+ self._serving_future = None # type: Optional[Future[None]]
- @gen.coroutine
- def close(self):
+ async def close(self) -> None:
"""Closes the connection.
Returns a `.Future` that resolves after the serving loop has exited.
@@ -703,33 +797,37 @@ def close(self):
self.stream.close()
# Block until the serving loop is done, but ignore any exceptions
# (start_serving is already responsible for logging them).
+ assert self._serving_future is not None
try:
- yield self._serving_future
+ await self._serving_future
except Exception:
pass
- def start_serving(self, delegate):
+ def start_serving(self, delegate: httputil.HTTPServerConnectionDelegate) -> None:
"""Starts serving requests on this connection.
:arg delegate: a `.HTTPServerConnectionDelegate`
"""
assert isinstance(delegate, httputil.HTTPServerConnectionDelegate)
- self._serving_future = self._server_request_loop(delegate)
+ fut = gen.convert_yielded(self._server_request_loop(delegate))
+ self._serving_future = fut
# Register the future on the IOLoop so its errors get logged.
- self.stream.io_loop.add_future(self._serving_future,
- lambda f: f.result())
+ self.stream.io_loop.add_future(fut, lambda f: f.result())
- @gen.coroutine
- def _server_request_loop(self, delegate):
+ async def _server_request_loop(
+ self, delegate: httputil.HTTPServerConnectionDelegate
+ ) -> None:
try:
while True:
- conn = HTTP1Connection(self.stream, False,
- self.params, self.context)
+ conn = HTTP1Connection(self.stream, False, self.params, self.context)
request_delegate = delegate.start_request(self, conn)
try:
- ret = yield conn.read_response(request_delegate)
- except (iostream.StreamClosedError,
- iostream.UnsatisfiableReadError):
+ ret = await conn.read_response(request_delegate)
+ except (
+ iostream.StreamClosedError,
+ iostream.UnsatisfiableReadError,
+ asyncio.CancelledError,
+ ):
return
except _QuietException:
# This exception was already logged.
@@ -741,6 +839,6 @@ def _server_request_loop(self, delegate):
return
if not ret:
return
- yield gen.moment
+ await asyncio.sleep(0)
finally:
delegate.on_close(self)
diff --git a/tornado/httpclient.py b/tornado/httpclient.py
index f0a2df8871..3011c371b8 100644
--- a/tornado/httpclient.py
+++ b/tornado/httpclient.py
@@ -20,8 +20,6 @@
* ``curl_httpclient`` is faster.
-* ``curl_httpclient`` was the default prior to Tornado 2.0.
-
Note that if you are using ``curl_httpclient``, it is highly
recommended that you use a recent version of ``libcurl`` and
``pycurl``. Currently the minimum supported version of libcurl is
@@ -38,19 +36,25 @@
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
"""
-from __future__ import absolute_import, division, print_function
-
+import datetime
import functools
+from io import BytesIO
+import ssl
import time
-import warnings
import weakref
-from tornado.concurrent import Future, future_set_result_unless_cancelled
+from tornado.concurrent import (
+ Future,
+ future_set_result_unless_cancelled,
+ future_set_exception_unless_cancelled,
+)
from tornado.escape import utf8, native_str
-from tornado import gen, httputil, stack_context
+from tornado import gen, httputil
from tornado.ioloop import IOLoop
from tornado.util import Configurable
+from typing import Type, Any, Union, Dict, Callable, Optional, cast
+
class HTTPClient(object):
"""A blocking HTTP client.
@@ -81,7 +85,12 @@ class HTTPClient(object):
Use `AsyncHTTPClient` instead.
"""
- def __init__(self, async_client_class=None, **kwargs):
+
+ def __init__(
+ self,
+ async_client_class: "Optional[Type[AsyncHTTPClient]]" = None,
+ **kwargs: Any
+ ) -> None:
# Initialize self._closed at the beginning of the constructor
# so that an exception raised here doesn't lead to confusing
# failures in __del__.
@@ -89,23 +98,30 @@ def __init__(self, async_client_class=None, **kwargs):
self._io_loop = IOLoop(make_current=False)
if async_client_class is None:
async_client_class = AsyncHTTPClient
+
# Create the client while our IOLoop is "current", without
# clobbering the thread's real current IOLoop (if any).
- self._async_client = self._io_loop.run_sync(
- gen.coroutine(lambda: async_client_class(**kwargs)))
+ async def make_client() -> "AsyncHTTPClient":
+ await gen.sleep(0)
+ assert async_client_class is not None
+ return async_client_class(**kwargs)
+
+ self._async_client = self._io_loop.run_sync(make_client)
self._closed = False
- def __del__(self):
+ def __del__(self) -> None:
self.close()
- def close(self):
+ def close(self) -> None:
"""Closes the HTTPClient, freeing any resources used."""
if not self._closed:
self._async_client.close()
self._io_loop.close()
self._closed = True
- def fetch(self, request, **kwargs):
+ def fetch(
+ self, request: Union["HTTPRequest", str], **kwargs: Any
+ ) -> "HTTPResponse":
"""Executes a request, returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
@@ -115,8 +131,9 @@ def fetch(self, request, **kwargs):
If an error occurs during the fetch, we raise an `HTTPError` unless
the ``raise_error`` keyword argument is set to False.
"""
- response = self._io_loop.run_sync(functools.partial(
- self._async_client.fetch, request, **kwargs))
+ response = self._io_loop.run_sync(
+ functools.partial(self._async_client.fetch, request, **kwargs)
+ )
return response
@@ -125,15 +142,15 @@ class AsyncHTTPClient(Configurable):
Example usage::
- def handle_response(response):
- if response.error:
- print("Error: %s" % response.error)
+ async def f():
+ http_client = AsyncHTTPClient()
+ try:
+ response = await http_client.fetch("http://www.google.com")
+ except Exception as e:
+ print("Error: %s" % e)
else:
print(response.body)
- http_client = AsyncHTTPClient()
- http_client.fetch("http://www.google.com/", handle_response)
-
The constructor for this class is magic in several respects: It
actually creates an instance of an implementation-specific
subclass, and instances are reused as a kind of pseudo-singleton
@@ -158,23 +175,27 @@ def handle_response(response):
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
+
+ _instance_cache = None # type: Dict[IOLoop, AsyncHTTPClient]
+
@classmethod
- def configurable_base(cls):
+ def configurable_base(cls) -> Type[Configurable]:
return AsyncHTTPClient
@classmethod
- def configurable_default(cls):
+ def configurable_default(cls) -> Type[Configurable]:
from tornado.simple_httpclient import SimpleAsyncHTTPClient
+
return SimpleAsyncHTTPClient
@classmethod
- def _async_clients(cls):
- attr_name = '_async_client_dict_' + cls.__name__
+ def _async_clients(cls) -> Dict[IOLoop, "AsyncHTTPClient"]:
+ attr_name = "_async_client_dict_" + cls.__name__
if not hasattr(cls, attr_name):
setattr(cls, attr_name, weakref.WeakKeyDictionary())
return getattr(cls, attr_name)
- def __new__(cls, force_instance=False, **kwargs):
+ def __new__(cls, force_instance: bool = False, **kwargs: Any) -> "AsyncHTTPClient":
io_loop = IOLoop.current()
if force_instance:
instance_cache = None
@@ -182,7 +203,7 @@ def __new__(cls, force_instance=False, **kwargs):
instance_cache = cls._async_clients()
if instance_cache is not None and io_loop in instance_cache:
return instance_cache[io_loop]
- instance = super(AsyncHTTPClient, cls).__new__(cls, **kwargs)
+ instance = super(AsyncHTTPClient, cls).__new__(cls, **kwargs) # type: ignore
# Make sure the instance knows which cache to remove itself from.
# It can't simply call _async_clients() because we may be in
# __new__(AsyncHTTPClient) but instance.__class__ may be
@@ -192,14 +213,14 @@ def __new__(cls, force_instance=False, **kwargs):
instance_cache[instance.io_loop] = instance
return instance
- def initialize(self, defaults=None):
+ def initialize(self, defaults: Optional[Dict[str, Any]] = None) -> None:
self.io_loop = IOLoop.current()
self.defaults = dict(HTTPRequest._DEFAULTS)
if defaults is not None:
self.defaults.update(defaults)
self._closed = False
- def close(self):
+ def close(self) -> None:
"""Destroys this HTTP client, freeing any file descriptors used.
This method is **not needed in normal use** due to the way
@@ -216,11 +237,21 @@ def close(self):
return
self._closed = True
if self._instance_cache is not None:
- if self._instance_cache.get(self.io_loop) is not self:
+ cached_val = self._instance_cache.pop(self.io_loop, None)
+ # If there's an object other than self in the instance
+ # cache for our IOLoop, something has gotten mixed up. A
+ # value of None appears to be possible when this is called
+ # from a destructor (HTTPClient.__del__) as the weakref
+ # gets cleared before the destructor runs.
+ if cached_val is not None and cached_val is not self:
raise RuntimeError("inconsistent AsyncHTTPClient cache")
- del self._instance_cache[self.io_loop]
- def fetch(self, request, callback=None, raise_error=True, **kwargs):
+ def fetch(
+ self,
+ request: Union[str, "HTTPRequest"],
+ raise_error: bool = True,
+ **kwargs: Any
+ ) -> "Future[HTTPResponse]":
"""Executes a request, asynchronously returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
@@ -240,17 +271,14 @@ def fetch(self, request, callback=None, raise_error=True, **kwargs):
Instead, you must check the response's ``error`` attribute or
call its `~HTTPResponse.rethrow` method.
- .. deprecated:: 5.1
-
- The ``callback`` argument is deprecated and will be removed
- in 6.0. Use the returned `.Future` instead.
+ .. versionchanged:: 6.0
- The ``raise_error=False`` argument currently suppresses
- *all* errors, encapsulating them in `HTTPResponse` objects
- with a 599 response code. This will change in Tornado 6.0:
- ``raise_error=False`` will only affect the `HTTPError`
- raised when a non-200 response code is used.
+ The ``callback`` argument was removed. Use the returned
+ `.Future` instead.
+ The ``raise_error=False`` argument only affects the
+ `HTTPError` raised when a non-200 response code is used,
+ instead of suppressing all errors.
"""
if self._closed:
raise RuntimeError("fetch() called on closed AsyncHTTPClient")
@@ -258,49 +286,35 @@ def fetch(self, request, callback=None, raise_error=True, **kwargs):
request = HTTPRequest(url=request, **kwargs)
else:
if kwargs:
- raise ValueError("kwargs can't be used if request is an HTTPRequest object")
+ raise ValueError(
+ "kwargs can't be used if request is an HTTPRequest object"
+ )
# We may modify this (to add Host, Accept-Encoding, etc),
# so make sure we don't modify the caller's object. This is also
# where normal dicts get converted to HTTPHeaders objects.
request.headers = httputil.HTTPHeaders(request.headers)
- request = _RequestProxy(request, self.defaults)
- future = Future()
- if callback is not None:
- warnings.warn("callback arguments are deprecated, use the returned Future instead",
- DeprecationWarning)
- callback = stack_context.wrap(callback)
-
- def handle_future(future):
- exc = future.exception()
- if isinstance(exc, HTTPError) and exc.response is not None:
- response = exc.response
- elif exc is not None:
- response = HTTPResponse(
- request, 599, error=exc,
- request_time=time.time() - request.start_time)
- else:
- response = future.result()
- self.io_loop.add_callback(callback, response)
- future.add_done_callback(handle_future)
-
- def handle_response(response):
- if raise_error and response.error:
- if isinstance(response.error, HTTPError):
- response.error.response = response
- future.set_exception(response.error)
- else:
- if response.error and not response._error_is_response_code:
- warnings.warn("raise_error=False will allow '%s' to be raised in the future" %
- response.error, DeprecationWarning)
- future_set_result_unless_cancelled(future, response)
- self.fetch_impl(request, handle_response)
+ request_proxy = _RequestProxy(request, self.defaults)
+ future = Future() # type: Future[HTTPResponse]
+
+ def handle_response(response: "HTTPResponse") -> None:
+ if response.error:
+ if raise_error or not response._error_is_response_code:
+ future_set_exception_unless_cancelled(future, response.error)
+ return
+ future_set_result_unless_cancelled(future, response)
+
+ self.fetch_impl(cast(HTTPRequest, request_proxy), handle_response)
return future
- def fetch_impl(self, request, callback):
+ def fetch_impl(
+ self, request: "HTTPRequest", callback: Callable[["HTTPResponse"], None]
+ ) -> None:
raise NotImplementedError()
@classmethod
- def configure(cls, impl, **kwargs):
+ def configure(
+ cls, impl: "Union[None, str, Type[Configurable]]", **kwargs: Any
+ ) -> None:
"""Configures the `AsyncHTTPClient` subclass to use.
``AsyncHTTPClient()`` actually creates an instance of a subclass.
@@ -325,6 +339,8 @@ def configure(cls, impl, **kwargs):
class HTTPRequest(object):
"""HTTP client request object."""
+ _headers = None # type: Union[Dict[str, str], httputil.HTTPHeaders]
+
# Default values for HTTPRequest parameters.
# Merged with the values on the request object by AsyncHTTPClient
# implementations.
@@ -334,24 +350,49 @@ class HTTPRequest(object):
follow_redirects=True,
max_redirects=5,
decompress_response=True,
- proxy_password='',
+ proxy_password="",
allow_nonstandard_methods=False,
- validate_cert=True)
-
- def __init__(self, url, method="GET", headers=None, body=None,
- auth_username=None, auth_password=None, auth_mode=None,
- connect_timeout=None, request_timeout=None,
- if_modified_since=None, follow_redirects=None,
- max_redirects=None, user_agent=None, use_gzip=None,
- network_interface=None, streaming_callback=None,
- header_callback=None, prepare_curl_callback=None,
- proxy_host=None, proxy_port=None, proxy_username=None,
- proxy_password=None, proxy_auth_mode=None,
- allow_nonstandard_methods=None, validate_cert=None,
- ca_certs=None, allow_ipv6=None, client_key=None,
- client_cert=None, body_producer=None,
- expect_100_continue=False, decompress_response=None,
- ssl_options=None):
+ validate_cert=True,
+ )
+
+ def __init__(
+ self,
+ url: str,
+ method: str = "GET",
+ headers: Optional[Union[Dict[str, str], httputil.HTTPHeaders]] = None,
+ body: Optional[Union[bytes, str]] = None,
+ auth_username: Optional[str] = None,
+ auth_password: Optional[str] = None,
+ auth_mode: Optional[str] = None,
+ connect_timeout: Optional[float] = None,
+ request_timeout: Optional[float] = None,
+ if_modified_since: Optional[Union[float, datetime.datetime]] = None,
+ follow_redirects: Optional[bool] = None,
+ max_redirects: Optional[int] = None,
+ user_agent: Optional[str] = None,
+ use_gzip: Optional[bool] = None,
+ network_interface: Optional[str] = None,
+ streaming_callback: Optional[Callable[[bytes], None]] = None,
+ header_callback: Optional[Callable[[str], None]] = None,
+ prepare_curl_callback: Optional[Callable[[Any], None]] = None,
+ proxy_host: Optional[str] = None,
+ proxy_port: Optional[int] = None,
+ proxy_username: Optional[str] = None,
+ proxy_password: Optional[str] = None,
+ proxy_auth_mode: Optional[str] = None,
+ allow_nonstandard_methods: Optional[bool] = None,
+ validate_cert: Optional[bool] = None,
+ ca_certs: Optional[str] = None,
+ allow_ipv6: Optional[bool] = None,
+ client_key: Optional[str] = None,
+ client_cert: Optional[str] = None,
+ body_producer: Optional[
+ Callable[[Callable[[bytes], None]], "Future[None]"]
+ ] = None,
+ expect_100_continue: bool = False,
+ decompress_response: Optional[bool] = None,
+ ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None,
+ ) -> None:
r"""All parameters except ``url`` are optional.
:arg str url: URL to fetch
@@ -360,7 +401,9 @@ def __init__(self, url, method="GET", headers=None, body=None,
:type headers: `~tornado.httputil.HTTPHeaders` or `dict`
:arg body: HTTP request body as a string (byte or unicode; if unicode
the utf-8 encoding will be used)
- :arg body_producer: Callable used for lazy/asynchronous request bodies.
+ :type body: `str` or `bytes`
+ :arg collections.abc.Callable body_producer: Callable used for
+ lazy/asynchronous request bodies.
It is called with one argument, a ``write`` function, and should
return a `.Future`. It should call the write function with new
data as it becomes available. The write function returns a
@@ -378,9 +421,9 @@ def __init__(self, url, method="GET", headers=None, body=None,
supports "basic" and "digest"; ``simple_httpclient`` only supports
"basic"
:arg float connect_timeout: Timeout for initial connection in seconds,
- default 20 seconds
+ default 20 seconds (0 means no timeout)
:arg float request_timeout: Timeout for entire request in seconds,
- default 20 seconds
+ default 20 seconds (0 means no timeout)
:arg if_modified_since: Timestamp for ``If-Modified-Since`` header
:type if_modified_since: `datetime` or `float`
:arg bool follow_redirects: Should redirects be followed automatically
@@ -392,8 +435,8 @@ def __init__(self, url, method="GET", headers=None, body=None,
New in Tornado 4.0.
:arg bool use_gzip: Deprecated alias for ``decompress_response``
since Tornado 4.0.
- :arg str network_interface: Network interface to use for request.
- ``curl_httpclient`` only; see note below.
+ :arg str network_interface: Network interface or source IP to use for request.
+ See ``curl_httpclient`` note below.
:arg collections.abc.Callable streaming_callback: If set, ``streaming_callback`` will
be run with each chunk of data as it is received, and
``HTTPResponse.body`` and ``HTTPResponse.buffer`` will be empty in
@@ -433,11 +476,11 @@ def __init__(self, url, method="GET", headers=None, body=None,
``simple_httpclient`` (unsupported by ``curl_httpclient``).
Overrides ``validate_cert``, ``ca_certs``, ``client_key``,
and ``client_cert``.
- :arg bool allow_ipv6: Use IPv6 when available? Default is true.
+ :arg bool allow_ipv6: Use IPv6 when available? Default is True.
:arg bool expect_100_continue: If true, send the
``Expect: 100-continue`` header and wait for a continue response
before sending the request body. Only supported with
- simple_httpclient.
+ ``simple_httpclient``.
.. note::
@@ -465,10 +508,11 @@ def __init__(self, url, method="GET", headers=None, body=None,
"""
# Note that some of these attributes go through property setters
# defined below.
- self.headers = headers
+ self.headers = headers # type: ignore
if if_modified_since:
self.headers["If-Modified-Since"] = httputil.format_timestamp(
- if_modified_since)
+ if_modified_since
+ )
self.proxy_host = proxy_host
self.proxy_port = proxy_port
self.proxy_username = proxy_username
@@ -476,7 +520,7 @@ def __init__(self, url, method="GET", headers=None, body=None,
self.proxy_auth_mode = proxy_auth_mode
self.url = url
self.method = method
- self.body = body
+ self.body = body # type: ignore
self.body_producer = body_producer
self.auth_username = auth_username
self.auth_password = auth_password
@@ -487,7 +531,7 @@ def __init__(self, url, method="GET", headers=None, body=None,
self.max_redirects = max_redirects
self.user_agent = user_agent
if decompress_response is not None:
- self.decompress_response = decompress_response
+ self.decompress_response = decompress_response # type: Optional[bool]
else:
self.decompress_response = use_gzip
self.network_interface = network_interface
@@ -505,90 +549,96 @@ def __init__(self, url, method="GET", headers=None, body=None,
self.start_time = time.time()
@property
- def headers(self):
- return self._headers
+ def headers(self) -> httputil.HTTPHeaders:
+ # TODO: headers may actually be a plain dict until fairly late in
+ # the process (AsyncHTTPClient.fetch), but practically speaking,
+ # whenever the property is used they're already HTTPHeaders.
+ return self._headers # type: ignore
@headers.setter
- def headers(self, value):
+ def headers(self, value: Union[Dict[str, str], httputil.HTTPHeaders]) -> None:
if value is None:
self._headers = httputil.HTTPHeaders()
else:
- self._headers = value
+ self._headers = value # type: ignore
@property
- def body(self):
+ def body(self) -> bytes:
return self._body
@body.setter
- def body(self, value):
+ def body(self, value: Union[bytes, str]) -> None:
self._body = utf8(value)
- @property
- def body_producer(self):
- return self._body_producer
-
- @body_producer.setter
- def body_producer(self, value):
- self._body_producer = stack_context.wrap(value)
-
- @property
- def streaming_callback(self):
- return self._streaming_callback
-
- @streaming_callback.setter
- def streaming_callback(self, value):
- self._streaming_callback = stack_context.wrap(value)
-
- @property
- def header_callback(self):
- return self._header_callback
-
- @header_callback.setter
- def header_callback(self, value):
- self._header_callback = stack_context.wrap(value)
-
- @property
- def prepare_curl_callback(self):
- return self._prepare_curl_callback
-
- @prepare_curl_callback.setter
- def prepare_curl_callback(self, value):
- self._prepare_curl_callback = stack_context.wrap(value)
-
class HTTPResponse(object):
"""HTTP Response object.
Attributes:
- * request: HTTPRequest object
+ * ``request``: HTTPRequest object
- * code: numeric HTTP status code, e.g. 200 or 404
+ * ``code``: numeric HTTP status code, e.g. 200 or 404
- * reason: human-readable reason phrase describing the status code
+ * ``reason``: human-readable reason phrase describing the status code
- * headers: `tornado.httputil.HTTPHeaders` object
+ * ``headers``: `tornado.httputil.HTTPHeaders` object
- * effective_url: final location of the resource after following any
+ * ``effective_url``: final location of the resource after following any
redirects
- * buffer: ``cStringIO`` object for response body
+ * ``buffer``: ``cStringIO`` object for response body
+
+ * ``body``: response body as bytes (created on demand from ``self.buffer``)
- * body: response body as bytes (created on demand from ``self.buffer``)
+ * ``error``: Exception object, if any
- * error: Exception object, if any
+ * ``request_time``: seconds from request start to finish. Includes all
+ network operations from DNS resolution to receiving the last byte of
+ data. Does not include time spent in the queue (due to the
+ ``max_clients`` option). If redirects were followed, only includes
+ the final request.
- * request_time: seconds from request start to finish
+ * ``start_time``: Time at which the HTTP operation started, based on
+ `time.time` (not the monotonic clock used by `.IOLoop.time`). May
+ be ``None`` if the request timed out while in the queue.
- * time_info: dictionary of diagnostic timing information from the request.
- Available data are subject to change, but currently uses timings
+ * ``time_info``: dictionary of diagnostic timing information from the
+ request. Available data are subject to change, but currently uses timings
available from http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html,
plus ``queue``, which is the delay (if any) introduced by waiting for
a slot under `AsyncHTTPClient`'s ``max_clients`` setting.
+
+ .. versionadded:: 5.1
+
+ Added the ``start_time`` attribute.
+
+ .. versionchanged:: 5.1
+
+ The ``request_time`` attribute previously included time spent in the queue
+ for ``simple_httpclient``, but not in ``curl_httpclient``. Now queueing time
+ is excluded in both implementations. ``request_time`` is now more accurate for
+ ``curl_httpclient`` because it uses a monotonic clock when available.
"""
- def __init__(self, request, code, headers=None, buffer=None,
- effective_url=None, error=None, request_time=None,
- time_info=None, reason=None):
+
+ # I'm not sure why these don't get type-inferred from the references in __init__.
+ error = None # type: Optional[BaseException]
+ _error_is_response_code = False
+ request = None # type: HTTPRequest
+
+ def __init__(
+ self,
+ request: HTTPRequest,
+ code: int,
+ headers: Optional[httputil.HTTPHeaders] = None,
+ buffer: Optional[BytesIO] = None,
+ effective_url: Optional[str] = None,
+ error: Optional[BaseException] = None,
+ request_time: Optional[float] = None,
+ time_info: Optional[Dict[str, float]] = None,
+ reason: Optional[str] = None,
+ start_time: Optional[float] = None,
+ ) -> None:
if isinstance(request, _RequestProxy):
self.request = request.request
else:
@@ -600,7 +650,7 @@ def __init__(self, request, code, headers=None, buffer=None,
else:
self.headers = httputil.HTTPHeaders()
self.buffer = buffer
- self._body = None
+ self._body = None # type: Optional[bytes]
if effective_url is None:
self.effective_url = request.url
else:
@@ -609,30 +659,30 @@ def __init__(self, request, code, headers=None, buffer=None,
if error is None:
if self.code < 200 or self.code >= 300:
self._error_is_response_code = True
- self.error = HTTPError(self.code, message=self.reason,
- response=self)
+ self.error = HTTPError(self.code, message=self.reason, response=self)
else:
self.error = None
else:
self.error = error
+ self.start_time = start_time
self.request_time = request_time
self.time_info = time_info or {}
@property
- def body(self):
+ def body(self) -> bytes:
if self.buffer is None:
- return None
+ return b""
elif self._body is None:
self._body = self.buffer.getvalue()
return self._body
- def rethrow(self):
+ def rethrow(self) -> None:
"""If there was an error on the request, raise an `HTTPError`."""
if self.error:
raise self.error
- def __repr__(self):
+ def __repr__(self) -> str:
args = ",".join("%s=%r" % i for i in sorted(self.__dict__.items()))
return "%s(%s)" % (self.__class__.__name__, args)
@@ -657,13 +707,19 @@ class HTTPClientError(Exception):
`tornado.web.HTTPError`. The name ``tornado.httpclient.HTTPError`` remains
as an alias.
"""
- def __init__(self, code, message=None, response=None):
+
+ def __init__(
+ self,
+ code: int,
+ message: Optional[str] = None,
+ response: Optional[HTTPResponse] = None,
+ ) -> None:
self.code = code
self.message = message or httputil.responses.get(code, "Unknown")
self.response = response
- super(HTTPClientError, self).__init__(code, message, response)
+ super().__init__(code, message, response)
- def __str__(self):
+ def __str__(self) -> str:
return "HTTP %d: %s" % (self.code, self.message)
# There is a cyclic reference between self and self.response,
@@ -681,11 +737,14 @@ class _RequestProxy(object):
Used internally by AsyncHTTPClient implementations.
"""
- def __init__(self, request, defaults):
+
+ def __init__(
+ self, request: HTTPRequest, defaults: Optional[Dict[str, Any]]
+ ) -> None:
self.request = request
self.defaults = defaults
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
request_attr = getattr(self.request, name)
if request_attr is not None:
return request_attr
@@ -695,8 +754,9 @@ def __getattr__(self, name):
return None
-def main():
+def main() -> None:
from tornado.options import define, options, parse_command_line
+
define("print_headers", type=bool, default=False)
define("print_body", type=bool, default=True)
define("follow_redirects", type=bool, default=True)
@@ -707,12 +767,13 @@ def main():
client = HTTPClient()
for arg in args:
try:
- response = client.fetch(arg,
- follow_redirects=options.follow_redirects,
- validate_cert=options.validate_cert,
- proxy_host=options.proxy_host,
- proxy_port=options.proxy_port,
- )
+ response = client.fetch(
+ arg,
+ follow_redirects=options.follow_redirects,
+ validate_cert=options.validate_cert,
+ proxy_host=options.proxy_host,
+ proxy_port=options.proxy_port,
+ )
except HTTPError as e:
if e.response is not None:
response = e.response
diff --git a/tornado/httpserver.py b/tornado/httpserver.py
index 3498d71fb6..cd4a468120 100644
--- a/tornado/httpserver.py
+++ b/tornado/httpserver.py
@@ -25,22 +25,25 @@ class except to start a server at the beginning of the process
to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias.
"""
-from __future__ import absolute_import, division, print_function
-
import socket
+import ssl
from tornado.escape import native_str
from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters
-from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado import netutil
from tornado.tcpserver import TCPServer
from tornado.util import Configurable
+import typing
+from typing import Union, Any, Dict, Callable, List, Type, Tuple, Optional, Awaitable
+
+if typing.TYPE_CHECKING:
+ from typing import Set # noqa: F401
+
-class HTTPServer(TCPServer, Configurable,
- httputil.HTTPServerConnectionDelegate):
+class HTTPServer(TCPServer, Configurable, httputil.HTTPServerConnectionDelegate):
r"""A non-blocking, single-threaded HTTP server.
A server is defined by a subclass of `.HTTPServerConnectionDelegate`,
@@ -137,7 +140,8 @@ class HTTPServer(TCPServer, Configurable,
.. versionchanged:: 5.0
The ``io_loop`` argument has been removed.
"""
- def __init__(self, *args, **kwargs):
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
# Ignore args to __init__; real initialization belongs in
# initialize since we're Configurable. (there's something
# weird in initialization order between this class,
@@ -145,13 +149,29 @@ def __init__(self, *args, **kwargs):
# completely)
pass
- def initialize(self, request_callback, no_keep_alive=False,
- xheaders=False, ssl_options=None, protocol=None,
- decompress_request=False,
- chunk_size=None, max_header_size=None,
- idle_connection_timeout=None, body_timeout=None,
- max_body_size=None, max_buffer_size=None,
- trusted_downstream=None):
+ def initialize(
+ self,
+ request_callback: Union[
+ httputil.HTTPServerConnectionDelegate,
+ Callable[[httputil.HTTPServerRequest], None],
+ ],
+ no_keep_alive: bool = False,
+ xheaders: bool = False,
+ ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None,
+ protocol: Optional[str] = None,
+ decompress_request: bool = False,
+ chunk_size: Optional[int] = None,
+ max_header_size: Optional[int] = None,
+ idle_connection_timeout: Optional[float] = None,
+ body_timeout: Optional[float] = None,
+ max_body_size: Optional[int] = None,
+ max_buffer_size: Optional[int] = None,
+ trusted_downstream: Optional[List[str]] = None,
+ ) -> None:
+ # This method's signature is not extracted with autodoc
+ # because we want its arguments to appear on the class
+ # constructor. When changing this signature, also update the
+ # copy in httpserver.rst.
self.request_callback = request_callback
self.xheaders = xheaders
self.protocol = protocol
@@ -162,38 +182,55 @@ def initialize(self, request_callback, no_keep_alive=False,
header_timeout=idle_connection_timeout or 3600,
max_body_size=max_body_size,
body_timeout=body_timeout,
- no_keep_alive=no_keep_alive)
- TCPServer.__init__(self, ssl_options=ssl_options,
- max_buffer_size=max_buffer_size,
- read_chunk_size=chunk_size)
- self._connections = set()
+ no_keep_alive=no_keep_alive,
+ )
+ TCPServer.__init__(
+ self,
+ ssl_options=ssl_options,
+ max_buffer_size=max_buffer_size,
+ read_chunk_size=chunk_size,
+ )
+ self._connections = set() # type: Set[HTTP1ServerConnection]
self.trusted_downstream = trusted_downstream
@classmethod
- def configurable_base(cls):
+ def configurable_base(cls) -> Type[Configurable]:
return HTTPServer
@classmethod
- def configurable_default(cls):
+ def configurable_default(cls) -> Type[Configurable]:
return HTTPServer
- @gen.coroutine
- def close_all_connections(self):
+ async def close_all_connections(self) -> None:
+ """Close all open connections and asynchronously wait for them to finish.
+
+ This method is used in combination with `~.TCPServer.stop` to
+ support clean shutdowns (especially for unittests). Typical
+ usage would call ``stop()`` first to stop accepting new
+ connections, then ``await close_all_connections()`` to wait for
+ existing connections to finish.
+
+ This method does not currently close open websocket connections.
+
+ Note that this method is a coroutine and must be called with ``await``.
+
+ """
while self._connections:
# Peek at an arbitrary element of the set
conn = next(iter(self._connections))
- yield conn.close()
-
- def handle_stream(self, stream, address):
- context = _HTTPRequestContext(stream, address,
- self.protocol,
- self.trusted_downstream)
- conn = HTTP1ServerConnection(
- stream, self.conn_params, context)
+ await conn.close()
+
+ def handle_stream(self, stream: iostream.IOStream, address: Tuple) -> None:
+ context = _HTTPRequestContext(
+ stream, address, self.protocol, self.trusted_downstream
+ )
+ conn = HTTP1ServerConnection(stream, self.conn_params, context)
self._connections.add(conn)
conn.start_serving(self)
- def start_request(self, server_conn, request_conn):
+ def start_request(
+ self, server_conn: object, request_conn: httputil.HTTPConnection
+ ) -> httputil.HTTPMessageDelegate:
if isinstance(self.request_callback, httputil.HTTPServerConnectionDelegate):
delegate = self.request_callback.start_request(server_conn, request_conn)
else:
@@ -204,37 +241,56 @@ def start_request(self, server_conn, request_conn):
return delegate
- def on_close(self, server_conn):
- self._connections.remove(server_conn)
+ def on_close(self, server_conn: object) -> None:
+ self._connections.remove(typing.cast(HTTP1ServerConnection, server_conn))
class _CallableAdapter(httputil.HTTPMessageDelegate):
- def __init__(self, request_callback, request_conn):
+ def __init__(
+ self,
+ request_callback: Callable[[httputil.HTTPServerRequest], None],
+ request_conn: httputil.HTTPConnection,
+ ) -> None:
self.connection = request_conn
self.request_callback = request_callback
- self.request = None
+ self.request = None # type: Optional[httputil.HTTPServerRequest]
self.delegate = None
- self._chunks = []
+ self._chunks = [] # type: List[bytes]
- def headers_received(self, start_line, headers):
+ def headers_received(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> Optional[Awaitable[None]]:
self.request = httputil.HTTPServerRequest(
- connection=self.connection, start_line=start_line,
- headers=headers)
+ connection=self.connection,
+ start_line=typing.cast(httputil.RequestStartLine, start_line),
+ headers=headers,
+ )
+ return None
- def data_received(self, chunk):
+ def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
self._chunks.append(chunk)
+ return None
- def finish(self):
- self.request.body = b''.join(self._chunks)
+ def finish(self) -> None:
+ assert self.request is not None
+ self.request.body = b"".join(self._chunks)
self.request._parse_body()
self.request_callback(self.request)
- def on_connection_close(self):
- self._chunks = None
+ def on_connection_close(self) -> None:
+ del self._chunks
class _HTTPRequestContext(object):
- def __init__(self, stream, address, protocol, trusted_downstream=None):
+ def __init__(
+ self,
+ stream: iostream.IOStream,
+ address: Tuple,
+ protocol: Optional[str],
+ trusted_downstream: Optional[List[str]] = None,
+ ) -> None:
self.address = address
# Save the socket's address family now so we know how to
# interpret self.address even after the stream is closed
@@ -244,12 +300,14 @@ def __init__(self, stream, address, protocol, trusted_downstream=None):
else:
self.address_family = None
# In HTTPServerRequest we want an IP, not a full socket address.
- if (self.address_family in (socket.AF_INET, socket.AF_INET6) and
- address is not None):
+ if (
+ self.address_family in (socket.AF_INET, socket.AF_INET6)
+ and address is not None
+ ):
self.remote_ip = address[0]
else:
# Unix (or other) socket; fake the remote address.
- self.remote_ip = '0.0.0.0'
+ self.remote_ip = "0.0.0.0"
if protocol:
self.protocol = protocol
elif isinstance(stream, iostream.SSLIOStream):
@@ -260,7 +318,7 @@ def __init__(self, stream, address, protocol, trusted_downstream=None):
self._orig_protocol = self.protocol
self.trusted_downstream = set(trusted_downstream or [])
- def __str__(self):
+ def __str__(self) -> str:
if self.address_family in (socket.AF_INET, socket.AF_INET6):
return self.remote_ip
elif isinstance(self.address, bytes):
@@ -271,12 +329,12 @@ def __str__(self):
else:
return str(self.address)
- def _apply_xheaders(self, headers):
+ def _apply_xheaders(self, headers: httputil.HTTPHeaders) -> None:
"""Rewrite the ``remote_ip`` and ``protocol`` fields."""
# Squid uses X-Forwarded-For, others use X-Real-Ip
ip = headers.get("X-Forwarded-For", self.remote_ip)
# Skip trusted downstream hosts in X-Forwarded-For list
- for ip in (cand.strip() for cand in reversed(ip.split(','))):
+ for ip in (cand.strip() for cand in reversed(ip.split(","))):
if ip not in self.trusted_downstream:
break
ip = headers.get("X-Real-Ip", ip)
@@ -284,16 +342,16 @@ def _apply_xheaders(self, headers):
self.remote_ip = ip
# AWS uses X-Forwarded-Proto
proto_header = headers.get(
- "X-Scheme", headers.get("X-Forwarded-Proto",
- self.protocol))
+ "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol)
+ )
if proto_header:
# use only the last proto entry if there is more than one
- # TODO: support trusting mutiple layers of proxied protocol
- proto_header = proto_header.split(',')[-1].strip()
+ # TODO: support trusting multiple layers of proxied protocol
+ proto_header = proto_header.split(",")[-1].strip()
if proto_header in ("http", "https"):
self.protocol = proto_header
- def _unapply_xheaders(self):
+ def _unapply_xheaders(self) -> None:
"""Undo changes from `_apply_xheaders`.
Xheaders are per-request so they should not leak to the next
@@ -304,27 +362,37 @@ def _unapply_xheaders(self):
class _ProxyAdapter(httputil.HTTPMessageDelegate):
- def __init__(self, delegate, request_conn):
+ def __init__(
+ self,
+ delegate: httputil.HTTPMessageDelegate,
+ request_conn: httputil.HTTPConnection,
+ ) -> None:
self.connection = request_conn
self.delegate = delegate
- def headers_received(self, start_line, headers):
- self.connection.context._apply_xheaders(headers)
+ def headers_received(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> Optional[Awaitable[None]]:
+ # TODO: either make context an official part of the
+ # HTTPConnection interface or figure out some other way to do this.
+ self.connection.context._apply_xheaders(headers) # type: ignore
return self.delegate.headers_received(start_line, headers)
- def data_received(self, chunk):
+ def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
return self.delegate.data_received(chunk)
- def finish(self):
+ def finish(self) -> None:
self.delegate.finish()
self._cleanup()
- def on_connection_close(self):
+ def on_connection_close(self) -> None:
self.delegate.on_connection_close()
self._cleanup()
- def _cleanup(self):
- self.connection.context._unapply_xheaders()
+ def _cleanup(self) -> None:
+ self.connection.context._unapply_xheaders() # type: ignore
HTTPRequest = httputil.HTTPServerRequest
diff --git a/tornado/httputil.py b/tornado/httputil.py
index 9c607b8c85..c0c57e6e95 100644
--- a/tornado/httputil.py
+++ b/tornado/httputil.py
@@ -19,91 +19,61 @@
via `tornado.web.RequestHandler.request`.
"""
-from __future__ import absolute_import, division, print_function
-
import calendar
-import collections
+import collections.abc
import copy
import datetime
import email.utils
-import numbers
+from functools import lru_cache
+from http.client import responses
+import http.cookies
import re
+from ssl import SSLError
import time
-import warnings
+import unicodedata
+from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl
from tornado.escape import native_str, parse_qs_bytes, utf8
from tornado.log import gen_log
-from tornado.util import ObjectDict, PY3
-
-if PY3:
- import http.cookies as Cookie
- from http.client import responses
- from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl
-else:
- import Cookie
- from httplib import responses
- from urllib import urlencode
- from urlparse import urlparse, urlunparse, parse_qsl
+from tornado.util import ObjectDict, unicode_type
# responses is unused in this file, but we re-export it to other files.
# Reference it so pyflakes doesn't complain.
responses
-try:
- from ssl import SSLError
-except ImportError:
- # ssl is unavailable on app engine.
- class _SSLError(Exception):
- pass
- # Hack around a mypy limitation. We can't simply put "type: ignore"
- # on the class definition itself; must go through an assignment.
- SSLError = _SSLError # type: ignore
-
-try:
- import typing # noqa: F401
-except ImportError:
- pass
-
-
-# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
-# terminator and ignore any preceding CR.
-_CRLF_RE = re.compile(r'\r?\n')
-
-
-class _NormalizedHeaderCache(dict):
- """Dynamic cached mapping of header names to Http-Header-Case.
-
- Implemented as a dict subclass so that cache hits are as fast as a
- normal dict lookup, without the overhead of a python function
- call.
-
- >>> normalized_headers = _NormalizedHeaderCache(10)
- >>> normalized_headers["coNtent-TYPE"]
+import typing
+from typing import (
+ Tuple,
+ Iterable,
+ List,
+ Mapping,
+ Iterator,
+ Dict,
+ Union,
+ Optional,
+ Awaitable,
+ Generator,
+ AnyStr,
+)
+
+if typing.TYPE_CHECKING:
+ from typing import Deque # noqa: F401
+ from asyncio import Future # noqa: F401
+ import unittest # noqa: F401
+
+
+@lru_cache(1000)
+def _normalize_header(name: str) -> str:
+ """Map a header name to Http-Header-Case.
+
+ >>> _normalize_header("coNtent-TYPE")
'Content-Type'
"""
- def __init__(self, size):
- super(_NormalizedHeaderCache, self).__init__()
- self.size = size
- self.queue = collections.deque()
-
- def __missing__(self, key):
- normalized = "-".join([w.capitalize() for w in key.split("-")])
- self[key] = normalized
- self.queue.append(key)
- if len(self.queue) > self.size:
- # Limit the size of the cache. LRU would be better, but this
- # simpler approach should be fine. In Python 2.7+ we could
- # use OrderedDict (or in 3.2+, @functools.lru_cache).
- old_key = self.queue.popleft()
- del self[old_key]
- return normalized
+ return "-".join([w.capitalize() for w in name.split("-")])
-_normalized_headers = _NormalizedHeaderCache(1000)
-
-
-class HTTPHeaders(collections.MutableMapping):
+class HTTPHeaders(collections.abc.MutableMapping):
"""A dictionary that maintains ``Http-Header-Case`` for all keys.
Supports multiple values per key via a pair of new methods,
@@ -131,12 +101,28 @@ class HTTPHeaders(collections.MutableMapping):
Set-Cookie: A=B
Set-Cookie: C=D
"""
- def __init__(self, *args, **kwargs):
+
+ @typing.overload
+ def __init__(self, __arg: Mapping[str, List[str]]) -> None:
+ pass
+
+ @typing.overload # noqa: F811
+ def __init__(self, __arg: Mapping[str, str]) -> None:
+ pass
+
+ @typing.overload # noqa: F811
+ def __init__(self, *args: Tuple[str, str]) -> None:
+ pass
+
+ @typing.overload # noqa: F811
+ def __init__(self, **kwargs: str) -> None:
+ pass
+
+ def __init__(self, *args: typing.Any, **kwargs: str) -> None: # noqa: F811
self._dict = {} # type: typing.Dict[str, str]
self._as_list = {} # type: typing.Dict[str, typing.List[str]]
- self._last_key = None
- if (len(args) == 1 and len(kwargs) == 0 and
- isinstance(args[0], HTTPHeaders)):
+ self._last_key = None # type: Optional[str]
+ if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], HTTPHeaders):
# Copy constructor
for k, v in args[0].get_all():
self.add(k, v)
@@ -146,25 +132,24 @@ def __init__(self, *args, **kwargs):
# new public methods
- def add(self, name, value):
- # type: (str, str) -> None
+ def add(self, name: str, value: str) -> None:
"""Adds a new value for the given key."""
- norm_name = _normalized_headers[name]
+ norm_name = _normalize_header(name)
self._last_key = norm_name
if norm_name in self:
- self._dict[norm_name] = (native_str(self[norm_name]) + ',' +
- native_str(value))
+ self._dict[norm_name] = (
+ native_str(self[norm_name]) + "," + native_str(value)
+ )
self._as_list[norm_name].append(value)
else:
self[norm_name] = value
- def get_list(self, name):
+ def get_list(self, name: str) -> List[str]:
"""Returns all values for the given header as a list."""
- norm_name = _normalized_headers[name]
+ norm_name = _normalize_header(name)
return self._as_list.get(norm_name, [])
- def get_all(self):
- # type: () -> typing.Iterable[typing.Tuple[str, str]]
+ def get_all(self) -> Iterable[Tuple[str, str]]:
"""Returns an iterable of all (name, value) pairs.
If a header has multiple values, multiple pairs will be
@@ -174,7 +159,7 @@ def get_all(self):
for value in values:
yield (name, value)
- def parse_line(self, line):
+ def parse_line(self, line: str) -> None:
"""Updates the dictionary with a single header line.
>>> h = HTTPHeaders()
@@ -186,7 +171,7 @@ def parse_line(self, line):
# continuation of a multi-line header
if self._last_key is None:
raise HTTPInputError("first header line cannot start with whitespace")
- new_part = ' ' + line.lstrip()
+ new_part = " " + line.lstrip()
self._as_list[self._last_key][-1] += new_part
self._dict[self._last_key] += new_part
else:
@@ -197,7 +182,7 @@ def parse_line(self, line):
self.add(name, value.strip())
@classmethod
- def parse(cls, headers):
+ def parse(cls, headers: str) -> "HTTPHeaders":
"""Returns a dictionary from HTTP header text.
>>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n")
@@ -211,34 +196,37 @@ def parse(cls, headers):
"""
h = cls()
- for line in _CRLF_RE.split(headers):
+ # RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
+ # terminator and ignore any preceding CR.
+ for line in headers.split("\n"):
+ if line.endswith("\r"):
+ line = line[:-1]
if line:
h.parse_line(line)
return h
# MutableMapping abstract method implementations.
- def __setitem__(self, name, value):
- norm_name = _normalized_headers[name]
+ def __setitem__(self, name: str, value: str) -> None:
+ norm_name = _normalize_header(name)
self._dict[norm_name] = value
self._as_list[norm_name] = [value]
- def __getitem__(self, name):
- # type: (str) -> str
- return self._dict[_normalized_headers[name]]
+ def __getitem__(self, name: str) -> str:
+ return self._dict[_normalize_header(name)]
- def __delitem__(self, name):
- norm_name = _normalized_headers[name]
+ def __delitem__(self, name: str) -> None:
+ norm_name = _normalize_header(name)
del self._dict[norm_name]
del self._as_list[norm_name]
- def __len__(self):
+ def __len__(self) -> int:
return len(self._dict)
- def __iter__(self):
+ def __iter__(self) -> Iterator[typing.Any]:
return iter(self._dict)
- def copy(self):
+ def copy(self) -> "HTTPHeaders":
# defined in dict but not in MutableMapping.
return HTTPHeaders(self)
@@ -247,7 +235,7 @@ def copy(self):
# the appearance that HTTPHeaders is a single container.
__copy__ = copy
- def __str__(self):
+ def __str__(self) -> str:
lines = []
for name, value in self.get_all():
lines.append("%s: %s\n" % (name, value))
@@ -348,9 +336,26 @@ class HTTPServerRequest(object):
.. versionchanged:: 4.0
Moved from ``tornado.httpserver.HTTPRequest``.
"""
- def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
- body=None, host=None, files=None, connection=None,
- start_line=None, server_connection=None):
+
+ path = None # type: str
+ query = None # type: str
+
+ # HACK: Used for stream_request_body
+ _body_future = None # type: Future[None]
+
+ def __init__(
+ self,
+ method: Optional[str] = None,
+ uri: Optional[str] = None,
+ version: str = "HTTP/1.0",
+ headers: Optional[HTTPHeaders] = None,
+ body: Optional[bytes] = None,
+ host: Optional[str] = None,
+ files: Optional[Dict[str, List["HTTPFile"]]] = None,
+ connection: Optional["HTTPConnection"] = None,
+ start_line: Optional["RequestStartLine"] = None,
+ server_connection: Optional[object] = None,
+ ) -> None:
if start_line is not None:
method, uri, version = start_line
self.method = method
@@ -360,9 +365,9 @@ def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
self.body = body or b""
# set remote IP and protocol
- context = getattr(connection, 'context', None)
- self.remote_ip = getattr(context, 'remote_ip', None)
- self.protocol = getattr(context, 'protocol', "http")
+ context = getattr(connection, "context", None)
+ self.remote_ip = getattr(context, "remote_ip", None)
+ self.protocol = getattr(context, "protocol", "http")
self.host = host or self.headers.get("Host") or "127.0.0.1"
self.host_name = split_host_and_port(self.host.lower())[0]
@@ -372,31 +377,19 @@ def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
self._start_time = time.time()
self._finish_time = None
- self.path, sep, self.query = uri.partition('?')
+ if uri is not None:
+ self.path, sep, self.query = uri.partition("?")
self.arguments = parse_qs_bytes(self.query, keep_blank_values=True)
self.query_arguments = copy.deepcopy(self.arguments)
- self.body_arguments = {}
-
- def supports_http_1_1(self):
- """Returns True if this request supports HTTP/1.1 semantics.
-
- .. deprecated:: 4.0
-
- Applications are less likely to need this information with
- the introduction of `.HTTPConnection`. If you still need
- it, access the ``version`` attribute directly. This method
- will be removed in Tornado 6.0.
-
- """
- warnings.warn("supports_http_1_1() is deprecated, use request.version instead",
- DeprecationWarning)
- return self.version == "HTTP/1.1"
+ self.body_arguments = {} # type: Dict[str, List[bytes]]
@property
- def cookies(self):
- """A dictionary of Cookie.Morsel objects."""
+ def cookies(self) -> Dict[str, http.cookies.Morsel]:
+ """A dictionary of ``http.cookies.Morsel`` objects."""
if not hasattr(self, "_cookies"):
- self._cookies = Cookie.SimpleCookie()
+ self._cookies = (
+ http.cookies.SimpleCookie()
+ ) # type: http.cookies.SimpleCookie
if "Cookie" in self.headers:
try:
parsed = parse_cookie(self.headers["Cookie"])
@@ -413,44 +406,20 @@ def cookies(self):
pass
return self._cookies
- def write(self, chunk, callback=None):
- """Writes the given chunk to the response stream.
-
- .. deprecated:: 4.0
- Use ``request.connection`` and the `.HTTPConnection` methods
- to write the response. This method will be removed in Tornado 6.0.
- """
- warnings.warn("req.write deprecated, use req.connection.write and write_headers instead",
- DeprecationWarning)
- assert isinstance(chunk, bytes)
- assert self.version.startswith("HTTP/1."), \
- "deprecated interface only supported in HTTP/1.x"
- self.connection.write(chunk, callback=callback)
-
- def finish(self):
- """Finishes this HTTP request on the open connection.
-
- .. deprecated:: 4.0
- Use ``request.connection`` and the `.HTTPConnection` methods
- to write the response. This method will be removed in Tornado 6.0.
- """
- warnings.warn("req.finish deprecated, use req.connection.finish instead",
- DeprecationWarning)
- self.connection.finish()
- self._finish_time = time.time()
-
- def full_url(self):
+ def full_url(self) -> str:
"""Reconstructs the full URL for this request."""
return self.protocol + "://" + self.host + self.uri
- def request_time(self):
+ def request_time(self) -> float:
"""Returns the amount of time it took for this request to execute."""
if self._finish_time is None:
return time.time() - self._start_time
else:
return self._finish_time - self._start_time
- def get_ssl_certificate(self, binary_form=False):
+ def get_ssl_certificate(
+ self, binary_form: bool = False
+ ) -> Union[None, Dict, bytes]:
"""Returns the client's SSL certificate, if any.
To use client certificates, the HTTPServer's
@@ -470,21 +439,28 @@ def get_ssl_certificate(self, binary_form=False):
http://docs.python.org/library/ssl.html#sslsocket-objects
"""
try:
- return self.connection.stream.socket.getpeercert(
- binary_form=binary_form)
+ if self.connection is None:
+ return None
+ # TODO: add a method to HTTPConnection for this so it can work with HTTP/2
+ return self.connection.stream.socket.getpeercert( # type: ignore
+ binary_form=binary_form
+ )
except SSLError:
return None
- def _parse_body(self):
+ def _parse_body(self) -> None:
parse_body_arguments(
- self.headers.get("Content-Type", ""), self.body,
- self.body_arguments, self.files,
- self.headers)
+ self.headers.get("Content-Type", ""),
+ self.body,
+ self.body_arguments,
+ self.files,
+ self.headers,
+ )
for k, v in self.body_arguments.items():
self.arguments.setdefault(k, []).extend(v)
- def __repr__(self):
+ def __repr__(self) -> str:
attrs = ("protocol", "host", "method", "uri", "version", "remote_ip")
args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs])
return "%s(%s)" % (self.__class__.__name__, args)
@@ -496,6 +472,7 @@ class HTTPInputError(Exception):
.. versionadded:: 4.0
"""
+
pass
@@ -504,6 +481,7 @@ class HTTPOutputError(Exception):
.. versionadded:: 4.0
"""
+
pass
@@ -512,7 +490,10 @@ class HTTPServerConnectionDelegate(object):
.. versionadded:: 4.0
"""
- def start_request(self, server_conn, request_conn):
+
+ def start_request(
+ self, server_conn: object, request_conn: "HTTPConnection"
+ ) -> "HTTPMessageDelegate":
"""This method is called by the server when a new request has started.
:arg server_conn: is an opaque object representing the long-lived
@@ -524,7 +505,7 @@ def start_request(self, server_conn, request_conn):
"""
raise NotImplementedError()
- def on_close(self, server_conn):
+ def on_close(self, server_conn: object) -> None:
"""This method is called when a connection has been closed.
:arg server_conn: is a server connection that has previously been
@@ -538,7 +519,13 @@ class HTTPMessageDelegate(object):
.. versionadded:: 4.0
"""
- def headers_received(self, start_line, headers):
+
+ # TODO: genericize this class to avoid exposing the Union.
+ def headers_received(
+ self,
+ start_line: Union["RequestStartLine", "ResponseStartLine"],
+ headers: HTTPHeaders,
+ ) -> Optional[Awaitable[None]]:
"""Called when the HTTP headers have been received and parsed.
:arg start_line: a `.RequestStartLine` or `.ResponseStartLine`
@@ -553,18 +540,18 @@ def headers_received(self, start_line, headers):
"""
pass
- def data_received(self, chunk):
+ def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
"""Called when a chunk of data has been received.
May return a `.Future` for flow control.
"""
pass
- def finish(self):
+ def finish(self) -> None:
"""Called after the last chunk of data has been received."""
pass
- def on_connection_close(self):
+ def on_connection_close(self) -> None:
"""Called if the connection is closed without finishing the request.
If ``headers_received`` is called, either ``finish`` or
@@ -578,7 +565,13 @@ class HTTPConnection(object):
.. versionadded:: 4.0
"""
- def write_headers(self, start_line, headers, chunk=None, callback=None):
+
+ def write_headers(
+ self,
+ start_line: Union["RequestStartLine", "ResponseStartLine"],
+ headers: HTTPHeaders,
+ chunk: Optional[bytes] = None,
+ ) -> "Future[None]":
"""Write an HTTP header block.
:arg start_line: a `.RequestStartLine` or `.ResponseStartLine`.
@@ -586,29 +579,39 @@ def write_headers(self, start_line, headers, chunk=None, callback=None):
:arg chunk: the first (optional) chunk of data. This is an optimization
so that small responses can be written in the same call as their
headers.
- :arg callback: a callback to be run when the write is complete.
The ``version`` field of ``start_line`` is ignored.
- Returns a `.Future` if no callback is given.
+ Returns a future for flow control.
+
+ .. versionchanged:: 6.0
+
+ The ``callback`` argument was removed.
"""
raise NotImplementedError()
- def write(self, chunk, callback=None):
+ def write(self, chunk: bytes) -> "Future[None]":
"""Writes a chunk of body data.
- The callback will be run when the write is complete. If no callback
- is given, returns a Future.
+ Returns a future for flow control.
+
+ .. versionchanged:: 6.0
+
+ The ``callback`` argument was removed.
"""
raise NotImplementedError()
- def finish(self):
- """Indicates that the last body data has been written.
- """
+ def finish(self) -> None:
+ """Indicates that the last body data has been written."""
raise NotImplementedError()
-def url_concat(url, args):
+def url_concat(
+ url: str,
+ args: Union[
+ None, Dict[str, str], List[Tuple[str, str]], Tuple[Tuple[str, str], ...]
+ ],
+) -> str:
"""Concatenate url and arguments regardless of whether
url has existing query parameters.
@@ -633,16 +636,20 @@ def url_concat(url, args):
parsed_query.extend(args)
else:
err = "'args' parameter should be dict, list or tuple. Not {0}".format(
- type(args))
+ type(args)
+ )
raise TypeError(err)
final_query = urlencode(parsed_query)
- url = urlunparse((
- parsed_url[0],
- parsed_url[1],
- parsed_url[2],
- parsed_url[3],
- final_query,
- parsed_url[5]))
+ url = urlunparse(
+ (
+ parsed_url[0],
+ parsed_url[1],
+ parsed_url[2],
+ parsed_url[3],
+ final_query,
+ parsed_url[5],
+ )
+ )
return url
@@ -656,10 +663,13 @@ class HTTPFile(ObjectDict):
* ``body``
* ``content_type``
"""
+
pass
-def _parse_request_range(range_header):
+def _parse_request_range(
+ range_header: str,
+) -> Optional[Tuple[Optional[int], Optional[int]]]:
"""Parses a Range header.
Returns either ``None`` or tuple ``(start, end)``.
@@ -708,7 +718,7 @@ def _parse_request_range(range_header):
return (start, end)
-def _get_content_range(start, end, total):
+def _get_content_range(start: Optional[int], end: Optional[int], total: int) -> str:
"""Returns a suitable Content-Range header:
>>> print(_get_content_range(None, 1, 4))
@@ -723,14 +733,20 @@ def _get_content_range(start, end, total):
return "bytes %s-%s/%s" % (start, end, total)
-def _int_or_none(val):
+def _int_or_none(val: str) -> Optional[int]:
val = val.strip()
if val == "":
return None
return int(val)
-def parse_body_arguments(content_type, body, arguments, files, headers=None):
+def parse_body_arguments(
+ content_type: str,
+ body: bytes,
+ arguments: Dict[str, List[bytes]],
+ files: Dict[str, List[HTTPFile]],
+ headers: Optional[HTTPHeaders] = None,
+) -> None:
"""Parses a form request body.
Supports ``application/x-www-form-urlencoded`` and
@@ -739,20 +755,27 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None):
and ``files`` parameters are dictionaries that will be updated
with the parsed contents.
"""
- if headers and 'Content-Encoding' in headers:
- gen_log.warning("Unsupported Content-Encoding: %s",
- headers['Content-Encoding'])
- return
if content_type.startswith("application/x-www-form-urlencoded"):
+ if headers and "Content-Encoding" in headers:
+ gen_log.warning(
+ "Unsupported Content-Encoding: %s", headers["Content-Encoding"]
+ )
+ return
try:
- uri_arguments = parse_qs_bytes(native_str(body), keep_blank_values=True)
+ # real charset decoding will happen in RequestHandler.decode_argument()
+ uri_arguments = parse_qs_bytes(body, keep_blank_values=True)
except Exception as e:
- gen_log.warning('Invalid x-www-form-urlencoded body: %s', e)
+ gen_log.warning("Invalid x-www-form-urlencoded body: %s", e)
uri_arguments = {}
for name, values in uri_arguments.items():
if values:
arguments.setdefault(name, []).extend(values)
elif content_type.startswith("multipart/form-data"):
+ if headers and "Content-Encoding" in headers:
+ gen_log.warning(
+ "Unsupported Content-Encoding: %s", headers["Content-Encoding"]
+ )
+ return
try:
fields = content_type.split(";")
for field in fields:
@@ -766,12 +789,22 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None):
gen_log.warning("Invalid multipart/form-data: %s", e)
-def parse_multipart_form_data(boundary, data, arguments, files):
+def parse_multipart_form_data(
+ boundary: bytes,
+ data: bytes,
+ arguments: Dict[str, List[bytes]],
+ files: Dict[str, List[HTTPFile]],
+) -> None:
"""Parses a ``multipart/form-data`` body.
The ``boundary`` and ``data`` parameters are both byte strings.
The dictionaries given in the arguments and files parameters
will be updated with the contents of the body.
+
+ .. versionchanged:: 5.1
+
+ Now recognizes non-ASCII filenames in RFC 2231/5987
+ (``filename*=``) format.
"""
# The standard allows for the boundary to be quoted in the header,
# although it's rare (it happens at least for google app engine
@@ -798,21 +831,25 @@ def parse_multipart_form_data(boundary, data, arguments, files):
if disposition != "form-data" or not part.endswith(b"\r\n"):
gen_log.warning("Invalid multipart/form-data")
continue
- value = part[eoh + 4:-2]
+ value = part[eoh + 4 : -2]
if not disp_params.get("name"):
gen_log.warning("multipart/form-data value missing name")
continue
name = disp_params["name"]
if disp_params.get("filename"):
ctype = headers.get("Content-Type", "application/unknown")
- files.setdefault(name, []).append(HTTPFile( # type: ignore
- filename=disp_params["filename"], body=value,
- content_type=ctype))
+ files.setdefault(name, []).append(
+ HTTPFile(
+ filename=disp_params["filename"], body=value, content_type=ctype
+ )
+ )
else:
arguments.setdefault(name, []).append(value)
-def format_timestamp(ts):
+def format_timestamp(
+ ts: Union[int, float, tuple, time.struct_time, datetime.datetime]
+) -> str:
"""Formats a timestamp in the format used by HTTP.
The argument may be a numeric timestamp as returned by `time.time`,
@@ -822,22 +859,26 @@ def format_timestamp(ts):
>>> format_timestamp(1359312200)
'Sun, 27 Jan 2013 18:43:20 GMT'
"""
- if isinstance(ts, numbers.Real):
- pass
+ if isinstance(ts, (int, float)):
+ time_num = ts
elif isinstance(ts, (tuple, time.struct_time)):
- ts = calendar.timegm(ts)
+ time_num = calendar.timegm(ts)
elif isinstance(ts, datetime.datetime):
- ts = calendar.timegm(ts.utctimetuple())
+ time_num = calendar.timegm(ts.utctimetuple())
else:
raise TypeError("unknown timestamp type: %r" % ts)
- return email.utils.formatdate(ts, usegmt=True)
+ return email.utils.formatdate(time_num, usegmt=True)
RequestStartLine = collections.namedtuple(
- 'RequestStartLine', ['method', 'path', 'version'])
+ "RequestStartLine", ["method", "path", "version"]
+)
-def parse_request_start_line(line):
+_http_version_re = re.compile(r"^HTTP/1\.[0-9]$")
+
+
+def parse_request_start_line(line: str) -> RequestStartLine:
"""Returns a (method, path, version) tuple for an HTTP 1.x request line.
The response is a `collections.namedtuple`.
@@ -851,17 +892,22 @@ def parse_request_start_line(line):
# https://tools.ietf.org/html/rfc7230#section-3.1.1
# invalid request-line SHOULD respond with a 400 (Bad Request)
raise HTTPInputError("Malformed HTTP request line")
- if not re.match(r"^HTTP/1\.[0-9]$", version):
+ if not _http_version_re.match(version):
raise HTTPInputError(
- "Malformed HTTP version in HTTP Request-Line: %r" % version)
+ "Malformed HTTP version in HTTP Request-Line: %r" % version
+ )
return RequestStartLine(method, path, version)
ResponseStartLine = collections.namedtuple(
- 'ResponseStartLine', ['version', 'code', 'reason'])
+ "ResponseStartLine", ["version", "code", "reason"]
+)
+
+
+_http_response_line_re = re.compile(r"(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)")
-def parse_response_start_line(line):
+def parse_response_start_line(line: str) -> ResponseStartLine:
"""Returns a (version, code, reason) tuple for an HTTP 1.x response line.
The response is a `collections.namedtuple`.
@@ -870,25 +916,26 @@ def parse_response_start_line(line):
ResponseStartLine(version='HTTP/1.1', code=200, reason='OK')
"""
line = native_str(line)
- match = re.match("(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)", line)
+ match = _http_response_line_re.match(line)
if not match:
raise HTTPInputError("Error parsing response start line")
- return ResponseStartLine(match.group(1), int(match.group(2)),
- match.group(3))
+ return ResponseStartLine(match.group(1), int(match.group(2)), match.group(3))
+
# _parseparam and _parse_header are copied and modified from python2.7's cgi.py
# The original 2.7 version of this code did not correctly support some
# combinations of semicolons and double quotes.
# It has also been modified to support valueless parameters as seen in
-# websocket extension negotiations.
+# websocket extension negotiations, and to support non-ascii values in
+# RFC 2231/5987 format.
-def _parseparam(s):
- while s[:1] == ';':
+def _parseparam(s: str) -> Generator[str, None, None]:
+ while s[:1] == ";":
s = s[1:]
- end = s.find(';')
+ end = s.find(";")
while end > 0 and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2:
- end = s.find(';', end + 1)
+ end = s.find(";", end + 1)
if end < 0:
end = len(s)
f = s[:end]
@@ -896,30 +943,42 @@ def _parseparam(s):
s = s[end:]
-def _parse_header(line):
- """Parse a Content-type like header.
+def _parse_header(line: str) -> Tuple[str, Dict[str, str]]:
+ r"""Parse a Content-type like header.
Return the main content-type and a dictionary of options.
+ >>> d = "form-data; foo=\"b\\\\a\\\"r\"; file*=utf-8''T%C3%A4st"
+ >>> ct, d = _parse_header(d)
+ >>> ct
+ 'form-data'
+ >>> d['file'] == r'T\u00e4st'.encode('ascii').decode('unicode_escape')
+ True
+ >>> d['foo']
+ 'b\\a"r'
"""
- parts = _parseparam(';' + line)
+ parts = _parseparam(";" + line)
key = next(parts)
- pdict = {}
+ # decode_params treats first argument special, but we already stripped key
+ params = [("Dummy", "value")]
for p in parts:
- i = p.find('=')
+ i = p.find("=")
if i >= 0:
name = p[:i].strip().lower()
- value = p[i + 1:].strip()
- if len(value) >= 2 and value[0] == value[-1] == '"':
- value = value[1:-1]
- value = value.replace('\\\\', '\\').replace('\\"', '"')
- pdict[name] = value
- else:
- pdict[p] = None
+ value = p[i + 1 :].strip()
+ params.append((name, native_str(value)))
+ decoded_params = email.utils.decode_params(params)
+ decoded_params.pop(0) # get rid of the dummy again
+ pdict = {}
+ for name, decoded_value in decoded_params:
+ value = email.utils.collapse_rfc2231_value(decoded_value)
+ if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
+ value = value[1:-1]
+ pdict[name] = value
return key, pdict
-def _encode_header(key, pdict):
+def _encode_header(key: str, pdict: Dict[str, str]) -> str:
"""Inverse of _parse_header.
>>> _encode_header('permessage-deflate',
@@ -935,33 +994,54 @@ def _encode_header(key, pdict):
out.append(k)
else:
# TODO: quote if necessary.
- out.append('%s=%s' % (k, v))
- return '; '.join(out)
+ out.append("%s=%s" % (k, v))
+ return "; ".join(out)
+
+
+def encode_username_password(
+ username: Union[str, bytes], password: Union[str, bytes]
+) -> bytes:
+ """Encodes a username/password pair in the format used by HTTP auth.
+
+ The return value is a byte string in the form ``username:password``.
+
+ .. versionadded:: 5.1
+ """
+ if isinstance(username, unicode_type):
+ username = unicodedata.normalize("NFC", username)
+ if isinstance(password, unicode_type):
+ password = unicodedata.normalize("NFC", password)
+ return utf8(username) + b":" + utf8(password)
def doctests():
+ # type: () -> unittest.TestSuite
import doctest
+
return doctest.DocTestSuite()
-def split_host_and_port(netloc):
+_netloc_re = re.compile(r"^(.+):(\d+)$")
+
+
+def split_host_and_port(netloc: str) -> Tuple[str, Optional[int]]:
"""Returns ``(host, port)`` tuple from ``netloc``.
Returned ``port`` will be ``None`` if not present.
.. versionadded:: 4.1
"""
- match = re.match(r'^(.+):(\d+)$', netloc)
+ match = _netloc_re.match(netloc)
if match:
host = match.group(1)
- port = int(match.group(2))
+ port = int(match.group(2)) # type: Optional[int]
else:
host = netloc
port = None
return (host, port)
-def qs_to_qsl(qs):
+def qs_to_qsl(qs: Dict[str, List[AnyStr]]) -> Iterable[Tuple[str, AnyStr]]:
"""Generator converting a result of ``parse_qs`` back to name-value pairs.
.. versionadded:: 5.0
@@ -973,10 +1053,10 @@ def qs_to_qsl(qs):
_OctalPatt = re.compile(r"\\[0-3][0-7][0-7]")
_QuotePatt = re.compile(r"[\\].")
-_nulljoin = ''.join
+_nulljoin = "".join
-def _unquote_cookie(str):
+def _unquote_cookie(s: str) -> str:
"""Handle double quotes and escaping in cookie values.
This method is copied verbatim from the Python 3.5 standard
@@ -985,29 +1065,29 @@ def _unquote_cookie(str):
"""
# If there aren't any doublequotes,
# then there can't be any special characters. See RFC 2109.
- if str is None or len(str) < 2:
- return str
- if str[0] != '"' or str[-1] != '"':
- return str
+ if s is None or len(s) < 2:
+ return s
+ if s[0] != '"' or s[-1] != '"':
+ return s
# We have to assume that we must decode this string.
# Down to work.
# Remove the "s
- str = str[1:-1]
+ s = s[1:-1]
# Check for special sequences. Examples:
# \012 --> \n
# \" --> "
#
i = 0
- n = len(str)
+ n = len(s)
res = []
while 0 <= i < n:
- o_match = _OctalPatt.search(str, i)
- q_match = _QuotePatt.search(str, i)
- if not o_match and not q_match: # Neither matched
- res.append(str[i:])
+ o_match = _OctalPatt.search(s, i)
+ q_match = _QuotePatt.search(s, i)
+ if not o_match and not q_match: # Neither matched
+ res.append(s[i:])
break
# else:
j = k = -1
@@ -1015,18 +1095,18 @@ def _unquote_cookie(str):
j = o_match.start(0)
if q_match:
k = q_match.start(0)
- if q_match and (not o_match or k < j): # QuotePatt matched
- res.append(str[i:k])
- res.append(str[k + 1])
+ if q_match and (not o_match or k < j): # QuotePatt matched
+ res.append(s[i:k])
+ res.append(s[k + 1])
i = k + 2
- else: # OctalPatt matched
- res.append(str[i:j])
- res.append(chr(int(str[j + 1:j + 4], 8)))
+ else: # OctalPatt matched
+ res.append(s[i:j])
+ res.append(chr(int(s[j + 1 : j + 4], 8)))
i = j + 4
return _nulljoin(res)
-def parse_cookie(cookie):
+def parse_cookie(cookie: str) -> Dict[str, str]:
"""Parse a ``Cookie`` HTTP header into a dict of name/value pairs.
This function attempts to mimic browser cookie parsing behavior;
@@ -1038,13 +1118,13 @@ def parse_cookie(cookie):
.. versionadded:: 4.4.2
"""
cookiedict = {}
- for chunk in cookie.split(str(';')):
- if str('=') in chunk:
- key, val = chunk.split(str('='), 1)
+ for chunk in cookie.split(str(";")):
+ if str("=") in chunk:
+ key, val = chunk.split(str("="), 1)
else:
# Assume an empty name per
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
- key, val = str(''), chunk
+ key, val = str(""), chunk
key, val = key.strip(), val.strip()
if key or val:
# unquote using Python's algorithm.
diff --git a/tornado/ioloop.py b/tornado/ioloop.py
index 55623f34e4..2cf884450c 100644
--- a/tornado/ioloop.py
+++ b/tornado/ioloop.py
@@ -15,7 +15,11 @@
"""An I/O event loop for non-blocking sockets.
-On Python 3, `.IOLoop` is a wrapper around the `asyncio` event loop.
+In Tornado 6.0, `.IOLoop` is a wrapper around the `asyncio` event
+loop, with a slightly different interface for historical reasons.
+Applications can use either the `.IOLoop` interface or the underlying
+`asyncio` event loop directly (unless compatibility with older
+versions of Tornado is desired, in which case `.IOLoop` must be used).
Typical applications will use a single `IOLoop` object, accessed via
`IOLoop.current` class method. The `IOLoop.start` method (or
@@ -24,73 +28,58 @@
may use more than one `IOLoop`, such as one `IOLoop` per thread, or
per `unittest` case.
-In addition to I/O events, the `IOLoop` can also schedule time-based
-events. `IOLoop.add_timeout` is a non-blocking alternative to
-`time.sleep`.
-
"""
-from __future__ import absolute_import, division, print_function
-
-import collections
+import asyncio
+import concurrent.futures
import datetime
-import errno
import functools
-import heapq
-import itertools
import logging
import numbers
import os
-import select
import sys
-import threading
import time
-import traceback
import math
import random
-from tornado.concurrent import Future, is_future, chain_future, future_set_exc_info, future_add_done_callback # noqa: E501
-from tornado.log import app_log, gen_log
-from tornado.platform.auto import set_close_exec, Waker
-from tornado import stack_context
-from tornado.util import (
- PY3, Configurable, errno_from_exception, timedelta_to_seconds,
- TimeoutError, unicode_type, import_object,
+from tornado.concurrent import (
+ Future,
+ is_future,
+ chain_future,
+ future_set_exc_info,
+ future_add_done_callback,
)
+from tornado.log import app_log
+from tornado.util import Configurable, TimeoutError, import_object
-try:
- import signal
-except ImportError:
- signal = None
+import typing
+from typing import Union, Any, Type, Optional, Callable, TypeVar, Tuple, Awaitable
-try:
- from concurrent.futures import ThreadPoolExecutor
-except ImportError:
- ThreadPoolExecutor = None
+if typing.TYPE_CHECKING:
+ from typing import Dict, List # noqa: F401
-if PY3:
- import _thread as thread
+ from typing_extensions import Protocol
else:
- import thread
+ Protocol = object
-try:
- import asyncio
-except ImportError:
- asyncio = None
+class _Selectable(Protocol):
+ def fileno(self) -> int:
+ pass
+
+ def close(self) -> None:
+ pass
-_POLL_TIMEOUT = 3600.0
+
+_T = TypeVar("_T")
+_S = TypeVar("_S", bound=_Selectable)
class IOLoop(Configurable):
- """A level-triggered I/O loop.
+ """An I/O event loop.
- On Python 3, `IOLoop` is a wrapper around the `asyncio` event
- loop. On Python 2, it uses ``epoll`` (Linux) or ``kqueue`` (BSD
- and Mac OS X) if they are available, or else we fall back on
- select(). If you are implementing a system that needs to handle
- thousands of simultaneous connections, you should use a system
- that supports either ``epoll`` or ``kqueue``.
+ As of Tornado 6.0, `IOLoop` is a wrapper around the `asyncio` event
+ loop.
Example usage for a simple TCP server:
@@ -101,25 +90,22 @@ class IOLoop(Configurable):
import socket
import tornado.ioloop
- from tornado import gen
from tornado.iostream import IOStream
- @gen.coroutine
- def handle_connection(connection, address):
+ async def handle_connection(connection, address):
stream = IOStream(connection)
- message = yield stream.read_until_close()
+ message = await stream.read_until_close()
print("message from client:", message.decode().strip())
def connection_ready(sock, fd, events):
while True:
try:
connection, address = sock.accept()
- except socket.error as e:
- if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN):
- raise
+ except BlockingIOError:
return
connection.setblocking(0)
- handle_connection(connection, address)
+ io_loop = tornado.ioloop.IOLoop.current()
+ io_loop.spawn_callback(handle_connection, connection, address)
if __name__ == '__main__':
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
@@ -165,42 +151,33 @@ def connection_ready(sock, fd, events):
to redundantly specify the `asyncio` event loop.
"""
- # Constants from the epoll module
- _EPOLLIN = 0x001
- _EPOLLPRI = 0x002
- _EPOLLOUT = 0x004
- _EPOLLERR = 0x008
- _EPOLLHUP = 0x010
- _EPOLLRDHUP = 0x2000
- _EPOLLONESHOT = (1 << 30)
- _EPOLLET = (1 << 31)
-
- # Our events map exactly to the epoll events
- NONE = 0
- READ = _EPOLLIN
- WRITE = _EPOLLOUT
- ERROR = _EPOLLERR | _EPOLLHUP
- # In Python 2, _current.instance points to the current IOLoop.
- _current = threading.local()
+ # These constants were originally based on constants from the epoll module.
+ NONE = 0
+ READ = 0x001
+ WRITE = 0x004
+ ERROR = 0x018
# In Python 3, _ioloop_for_asyncio maps from asyncio loops to IOLoops.
- _ioloop_for_asyncio = dict()
+ _ioloop_for_asyncio = dict() # type: Dict[asyncio.AbstractEventLoop, IOLoop]
@classmethod
- def configure(cls, impl, **kwargs):
+ def configure(
+ cls, impl: "Union[None, str, Type[Configurable]]", **kwargs: Any
+ ) -> None:
if asyncio is not None:
from tornado.platform.asyncio import BaseAsyncIOLoop
- if isinstance(impl, (str, unicode_type)):
+ if isinstance(impl, str):
impl = import_object(impl)
- if not issubclass(impl, BaseAsyncIOLoop):
+ if isinstance(impl, type) and not issubclass(impl, BaseAsyncIOLoop):
raise RuntimeError(
- "only AsyncIOLoop is allowed when asyncio is available")
+ "only AsyncIOLoop is allowed when asyncio is available"
+ )
super(IOLoop, cls).configure(impl, **kwargs)
@staticmethod
- def instance():
+ def instance() -> "IOLoop":
"""Deprecated alias for `IOLoop.current()`.
.. versionchanged:: 5.0
@@ -221,7 +198,7 @@ def instance():
"""
return IOLoop.current()
- def install(self):
+ def install(self) -> None:
"""Deprecated alias for `make_current()`.
.. versionchanged:: 5.0
@@ -236,7 +213,7 @@ def install(self):
self.make_current()
@staticmethod
- def clear_instance():
+ def clear_instance() -> None:
"""Deprecated alias for `clear_current()`.
.. versionchanged:: 5.0
@@ -251,8 +228,18 @@ def clear_instance():
"""
IOLoop.clear_current()
+ @typing.overload
@staticmethod
- def current(instance=True):
+ def current() -> "IOLoop":
+ pass
+
+ @typing.overload
+ @staticmethod
+ def current(instance: bool = True) -> Optional["IOLoop"]: # noqa: F811
+ pass
+
+ @staticmethod
+ def current(instance: bool = True) -> Optional["IOLoop"]: # noqa: F811
"""Returns the current thread's `IOLoop`.
If an `IOLoop` is currently running or has been marked as
@@ -272,30 +259,24 @@ def current(instance=True):
since even if we do not create an `IOLoop`, this method
may initialize the asyncio loop.
"""
- if asyncio is None:
- current = getattr(IOLoop._current, "instance", None)
- if current is None and instance:
- current = IOLoop()
- if IOLoop._current.instance is not current:
- raise RuntimeError("new IOLoop did not become current")
- else:
- try:
- loop = asyncio.get_event_loop()
- except (RuntimeError, AssertionError):
- if not instance:
- return None
- raise
- try:
- return IOLoop._ioloop_for_asyncio[loop]
- except KeyError:
- if instance:
- from tornado.platform.asyncio import AsyncIOMainLoop
- current = AsyncIOMainLoop(make_current=True)
- else:
- current = None
+ try:
+ loop = asyncio.get_event_loop()
+ except (RuntimeError, AssertionError):
+ if not instance:
+ return None
+ raise
+ try:
+ return IOLoop._ioloop_for_asyncio[loop]
+ except KeyError:
+ if instance:
+ from tornado.platform.asyncio import AsyncIOMainLoop
+
+ current = AsyncIOMainLoop(make_current=True) # type: Optional[IOLoop]
+ else:
+ current = None
return current
- def make_current(self):
+ def make_current(self) -> None:
"""Makes this the `IOLoop` for the current thread.
An `IOLoop` automatically becomes current for its thread
@@ -312,14 +293,10 @@ def make_current(self):
This method also sets the current `asyncio` event loop.
"""
# The asyncio event loops override this method.
- assert asyncio is None
- old = getattr(IOLoop._current, "instance", None)
- if old is not None:
- old.clear_current()
- IOLoop._current.instance = self
+ raise NotImplementedError()
@staticmethod
- def clear_current():
+ def clear_current() -> None:
"""Clears the `IOLoop` for the current thread.
Intended primarily for use by test frameworks in between tests.
@@ -333,7 +310,7 @@ def clear_current():
if asyncio is None:
IOLoop._current.instance = None
- def _clear_current_hook(self):
+ def _clear_current_hook(self) -> None:
"""Instance method called when an IOLoop ceases to be current.
May be overridden by subclasses as a counterpart to make_current.
@@ -341,17 +318,16 @@ def _clear_current_hook(self):
pass
@classmethod
- def configurable_base(cls):
+ def configurable_base(cls) -> Type[Configurable]:
return IOLoop
@classmethod
- def configurable_default(cls):
- if asyncio is not None:
- from tornado.platform.asyncio import AsyncIOLoop
- return AsyncIOLoop
- return PollIOLoop
+ def configurable_default(cls) -> Type[Configurable]:
+ from tornado.platform.asyncio import AsyncIOLoop
+
+ return AsyncIOLoop
- def initialize(self, make_current=None):
+ def initialize(self, make_current: Optional[bool] = None) -> None:
if make_current is None:
if IOLoop.current(instance=False) is None:
self.make_current()
@@ -362,7 +338,7 @@ def initialize(self, make_current=None):
raise RuntimeError("current IOLoop already exists")
self.make_current()
- def close(self, all_fds=False):
+ def close(self, all_fds: bool = False) -> None:
"""Closes the `IOLoop`, freeing any resources used.
If ``all_fds`` is true, all file descriptors registered on the
@@ -389,13 +365,25 @@ def close(self, all_fds=False):
"""
raise NotImplementedError()
- def add_handler(self, fd, handler, events):
+ @typing.overload
+ def add_handler(
+ self, fd: int, handler: Callable[[int, int], None], events: int
+ ) -> None:
+ pass
+
+ @typing.overload # noqa: F811
+ def add_handler(
+ self, fd: _S, handler: Callable[[_S, int], None], events: int
+ ) -> None:
+ pass
+
+ def add_handler( # noqa: F811
+ self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int
+ ) -> None:
"""Registers the given handler to receive the given events for ``fd``.
The ``fd`` argument may either be an integer file descriptor or
- a file-like object with a ``fileno()`` method (and optionally a
- ``close()`` method, which may be called when the `IOLoop` is shut
- down).
+ a file-like object with a ``fileno()`` and ``close()`` method.
The ``events`` argument is a bitwise or of the constants
``IOLoop.READ``, ``IOLoop.WRITE``, and ``IOLoop.ERROR``.
@@ -408,7 +396,7 @@ def add_handler(self, fd, handler, events):
"""
raise NotImplementedError()
- def update_handler(self, fd, events):
+ def update_handler(self, fd: Union[int, _Selectable], events: int) -> None:
"""Changes the events we listen for ``fd``.
.. versionchanged:: 4.0
@@ -417,7 +405,7 @@ def update_handler(self, fd, events):
"""
raise NotImplementedError()
- def remove_handler(self, fd):
+ def remove_handler(self, fd: Union[int, _Selectable]) -> None:
"""Stop listening for events on ``fd``.
.. versionchanged:: 4.0
@@ -426,55 +414,7 @@ def remove_handler(self, fd):
"""
raise NotImplementedError()
- def set_blocking_signal_threshold(self, seconds, action):
- """Sends a signal if the `IOLoop` is blocked for more than
- ``s`` seconds.
-
- Pass ``seconds=None`` to disable. Requires Python 2.6 on a unixy
- platform.
-
- The action parameter is a Python signal handler. Read the
- documentation for the `signal` module for more information.
- If ``action`` is None, the process will be killed if it is
- blocked for too long.
-
- .. deprecated:: 5.0
-
- Not implemented on the `asyncio` event loop. Use the environment
- variable ``PYTHONASYNCIODEBUG=1`` instead. This method will be
- removed in Tornado 6.0.
- """
- raise NotImplementedError()
-
- def set_blocking_log_threshold(self, seconds):
- """Logs a stack trace if the `IOLoop` is blocked for more than
- ``s`` seconds.
-
- Equivalent to ``set_blocking_signal_threshold(seconds,
- self.log_stack)``
-
- .. deprecated:: 5.0
-
- Not implemented on the `asyncio` event loop. Use the environment
- variable ``PYTHONASYNCIODEBUG=1`` instead. This method will be
- removed in Tornado 6.0.
- """
- self.set_blocking_signal_threshold(seconds, self.log_stack)
-
- def log_stack(self, signal, frame):
- """Signal handler to log the stack trace of the current thread.
-
- For use with `set_blocking_signal_threshold`.
-
- .. deprecated:: 5.1
-
- This method will be removed in Tornado 6.0.
- """
- gen_log.warning('IOLoop blocked for %f seconds in\n%s',
- self._blocking_signal_threshold,
- ''.join(traceback.format_stack(frame)))
-
- def start(self):
+ def start(self) -> None:
"""Starts the I/O loop.
The loop will run until one of the callbacks calls `stop()`, which
@@ -482,7 +422,7 @@ def start(self):
"""
raise NotImplementedError()
- def _setup_logging(self):
+ def _setup_logging(self) -> None:
"""The IOLoop catches and logs exceptions, so it's
important that log output be visible. However, python's
default behavior for non-root loggers (prior to python
@@ -493,28 +433,21 @@ def _setup_logging(self):
This method should be called from start() in subclasses.
"""
- if not any([logging.getLogger().handlers,
- logging.getLogger('tornado').handlers,
- logging.getLogger('tornado.application').handlers]):
+ if not any(
+ [
+ logging.getLogger().handlers,
+ logging.getLogger("tornado").handlers,
+ logging.getLogger("tornado.application").handlers,
+ ]
+ ):
logging.basicConfig()
- def stop(self):
+ def stop(self) -> None:
"""Stop the I/O loop.
If the event loop is not currently running, the next call to `start()`
will return immediately.
- To use asynchronous methods from otherwise-synchronous code (such as
- unit tests), you can start and stop the event loop like this::
-
- ioloop = IOLoop()
- async_method(ioloop=ioloop, callback=ioloop.stop)
- ioloop.start()
-
- ``ioloop.start()`` will return after ``async_method`` has run
- its callback, whether that callback was invoked before or
- after ``ioloop.start``.
-
Note that even after `stop` has been called, the `IOLoop` is not
completely stopped until `IOLoop.start` has also returned.
Some work that was scheduled before the call to `stop` may still
@@ -522,13 +455,13 @@ def stop(self):
"""
raise NotImplementedError()
- def run_sync(self, func, timeout=None):
+ def run_sync(self, func: Callable, timeout: Optional[float] = None) -> Any:
"""Starts the `IOLoop`, runs the given function, and stops the loop.
- The function must return either a yieldable object or
- ``None``. If the function returns a yieldable object, the
- `IOLoop` will run until the yieldable is resolved (and
- `run_sync()` will return the yieldable's result). If it raises
+ The function must return either an awaitable object or
+ ``None``. If the function returns an awaitable object, the
+ `IOLoop` will run until the awaitable is resolved (and
+ `run_sync()` will return the awaitable's result). If it raises
an exception, the `IOLoop` will stop and the exception will be
re-raised to the caller.
@@ -536,73 +469,87 @@ def run_sync(self, func, timeout=None):
a maximum duration for the function. If the timeout expires,
a `tornado.util.TimeoutError` is raised.
- This method is useful in conjunction with `tornado.gen.coroutine`
- to allow asynchronous calls in a ``main()`` function::
+ This method is useful to allow asynchronous calls in a
+ ``main()`` function::
- @gen.coroutine
- def main():
+ async def main():
# do stuff...
if __name__ == '__main__':
IOLoop.current().run_sync(main)
.. versionchanged:: 4.3
- Returning a non-``None``, non-yieldable value is now an error.
+ Returning a non-``None``, non-awaitable value is now an error.
.. versionchanged:: 5.0
If a timeout occurs, the ``func`` coroutine will be cancelled.
+
"""
- future_cell = [None]
+ future_cell = [None] # type: List[Optional[Future]]
- def run():
+ def run() -> None:
try:
result = func()
if result is not None:
from tornado.gen import convert_yielded
+
result = convert_yielded(result)
except Exception:
- future_cell[0] = Future()
- future_set_exc_info(future_cell[0], sys.exc_info())
+ fut = Future() # type: Future[Any]
+ future_cell[0] = fut
+ future_set_exc_info(fut, sys.exc_info())
else:
if is_future(result):
future_cell[0] = result
else:
- future_cell[0] = Future()
- future_cell[0].set_result(result)
+ fut = Future()
+ future_cell[0] = fut
+ fut.set_result(result)
+ assert future_cell[0] is not None
self.add_future(future_cell[0], lambda future: self.stop())
+
self.add_callback(run)
if timeout is not None:
- def timeout_callback():
+
+ def timeout_callback() -> None:
# If we can cancel the future, do so and wait on it. If not,
# Just stop the loop and return with the task still pending.
# (If we neither cancel nor wait for the task, a warning
# will be logged).
+ assert future_cell[0] is not None
if not future_cell[0].cancel():
self.stop()
+
timeout_handle = self.add_timeout(self.time() + timeout, timeout_callback)
self.start()
if timeout is not None:
self.remove_timeout(timeout_handle)
+ assert future_cell[0] is not None
if future_cell[0].cancelled() or not future_cell[0].done():
- raise TimeoutError('Operation timed out after %s seconds' % timeout)
+ raise TimeoutError("Operation timed out after %s seconds" % timeout)
return future_cell[0].result()
- def time(self):
+ def time(self) -> float:
"""Returns the current time according to the `IOLoop`'s clock.
The return value is a floating-point number relative to an
unspecified time in the past.
- By default, the `IOLoop`'s time function is `time.time`. However,
- it may be configured to use e.g. `time.monotonic` instead.
- Calls to `add_timeout` that pass a number instead of a
- `datetime.timedelta` should use this function to compute the
- appropriate time, so they can work no matter what time function
- is chosen.
+ Historically, the IOLoop could be customized to use e.g.
+ `time.monotonic` instead of `time.time`, but this is not
+ currently supported and so this method is equivalent to
+ `time.time`.
+
"""
return time.time()
- def add_timeout(self, deadline, callback, *args, **kwargs):
+ def add_timeout(
+ self,
+ deadline: Union[float, datetime.timedelta],
+ callback: Callable[..., None],
+ *args: Any,
+ **kwargs: Any
+ ) -> object:
"""Runs the ``callback`` at the time ``deadline`` from the I/O loop.
Returns an opaque handle that may be passed to
@@ -631,12 +578,15 @@ def add_timeout(self, deadline, callback, *args, **kwargs):
if isinstance(deadline, numbers.Real):
return self.call_at(deadline, callback, *args, **kwargs)
elif isinstance(deadline, datetime.timedelta):
- return self.call_at(self.time() + timedelta_to_seconds(deadline),
- callback, *args, **kwargs)
+ return self.call_at(
+ self.time() + deadline.total_seconds(), callback, *args, **kwargs
+ )
else:
raise TypeError("Unsupported deadline %r" % deadline)
- def call_later(self, delay, callback, *args, **kwargs):
+ def call_later(
+ self, delay: float, callback: Callable[..., None], *args: Any, **kwargs: Any
+ ) -> object:
"""Runs the ``callback`` after ``delay`` seconds have passed.
Returns an opaque handle that may be passed to `remove_timeout`
@@ -649,7 +599,9 @@ def call_later(self, delay, callback, *args, **kwargs):
"""
return self.call_at(self.time() + delay, callback, *args, **kwargs)
- def call_at(self, when, callback, *args, **kwargs):
+ def call_at(
+ self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any
+ ) -> object:
"""Runs the ``callback`` at the absolute time designated by ``when``.
``when`` must be a number using the same reference point as
@@ -665,7 +617,7 @@ def call_at(self, when, callback, *args, **kwargs):
"""
return self.add_timeout(when, callback, *args, **kwargs)
- def remove_timeout(self, timeout):
+ def remove_timeout(self, timeout: object) -> None:
"""Cancels a pending timeout.
The argument is a handle as returned by `add_timeout`. It is
@@ -674,7 +626,7 @@ def remove_timeout(self, timeout):
"""
raise NotImplementedError()
- def add_callback(self, callback, *args, **kwargs):
+ def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None:
"""Calls the given callback on the next I/O loop iteration.
It is safe to call this method from any thread at any time,
@@ -689,44 +641,66 @@ def add_callback(self, callback, *args, **kwargs):
"""
raise NotImplementedError()
- def add_callback_from_signal(self, callback, *args, **kwargs):
+ def add_callback_from_signal(
+ self, callback: Callable, *args: Any, **kwargs: Any
+ ) -> None:
"""Calls the given callback on the next I/O loop iteration.
Safe for use from a Python signal handler; should not be used
otherwise.
-
- Callbacks added with this method will be run without any
- `.stack_context`, to avoid picking up the context of the function
- that was interrupted by the signal.
"""
raise NotImplementedError()
- def spawn_callback(self, callback, *args, **kwargs):
+ def spawn_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None:
"""Calls the given callback on the next IOLoop iteration.
- Unlike all other callback-related methods on IOLoop,
- ``spawn_callback`` does not associate the callback with its caller's
- ``stack_context``, so it is suitable for fire-and-forget callbacks
- that should not interfere with the caller.
+ As of Tornado 6.0, this method is equivalent to `add_callback`.
.. versionadded:: 4.0
"""
- with stack_context.NullContext():
- self.add_callback(callback, *args, **kwargs)
+ self.add_callback(callback, *args, **kwargs)
- def add_future(self, future, callback):
+ def add_future(
+ self,
+ future: "Union[Future[_T], concurrent.futures.Future[_T]]",
+ callback: Callable[["Future[_T]"], None],
+ ) -> None:
"""Schedules a callback on the ``IOLoop`` when the given
`.Future` is finished.
The callback is invoked with one argument, the
`.Future`.
- """
- assert is_future(future)
- callback = stack_context.wrap(callback)
- future_add_done_callback(
- future, lambda future: self.add_callback(callback, future))
- def run_in_executor(self, executor, func, *args):
+ This method only accepts `.Future` objects and not other
+ awaitables (unlike most of Tornado where the two are
+ interchangeable).
+ """
+ if isinstance(future, Future):
+ # Note that we specifically do not want the inline behavior of
+ # tornado.concurrent.future_add_done_callback. We always want
+ # this callback scheduled on the next IOLoop iteration (which
+ # asyncio.Future always does).
+ #
+ # Wrap the callback in self._run_callback so we control
+ # the error logging (i.e. it goes to tornado.log.app_log
+ # instead of asyncio's log).
+ future.add_done_callback(
+ lambda f: self._run_callback(functools.partial(callback, future))
+ )
+ else:
+ assert is_future(future)
+ # For concurrent futures, we use self.add_callback, so
+ # it's fine if future_add_done_callback inlines that call.
+ future_add_done_callback(
+ future, lambda f: self.add_callback(callback, future)
+ )
+
+ def run_in_executor(
+ self,
+ executor: Optional[concurrent.futures.Executor],
+ func: Callable[..., _T],
+ *args: Any
+ ) -> Awaitable[_T]:
"""Runs a function in a ``concurrent.futures.Executor``. If
``executor`` is ``None``, the IO loop's default executor will be used.
@@ -734,38 +708,40 @@ def run_in_executor(self, executor, func, *args):
.. versionadded:: 5.0
"""
- if ThreadPoolExecutor is None:
- raise RuntimeError(
- "concurrent.futures is required to use IOLoop.run_in_executor")
-
if executor is None:
- if not hasattr(self, '_executor'):
+ if not hasattr(self, "_executor"):
from tornado.process import cpu_count
- self._executor = ThreadPoolExecutor(max_workers=(cpu_count() * 5))
+
+ self._executor = concurrent.futures.ThreadPoolExecutor(
+ max_workers=(cpu_count() * 5)
+ ) # type: concurrent.futures.Executor
executor = self._executor
c_future = executor.submit(func, *args)
# Concurrent Futures are not usable with await. Wrap this in a
# Tornado Future instead, using self.add_future for thread-safety.
- t_future = Future()
+ t_future = Future() # type: Future[_T]
self.add_future(c_future, lambda f: chain_future(f, t_future))
return t_future
- def set_default_executor(self, executor):
+ def set_default_executor(self, executor: concurrent.futures.Executor) -> None:
"""Sets the default executor to use with :meth:`run_in_executor`.
.. versionadded:: 5.0
"""
self._executor = executor
- def _run_callback(self, callback):
+ def _run_callback(self, callback: Callable[[], Any]) -> None:
"""Runs a callback with error handling.
- For use in subclasses.
+ .. versionchanged:: 6.0
+
+ CancelledErrors are no longer logged.
"""
try:
ret = callback()
if ret is not None:
from tornado import gen
+
# Functions that return Futures typically swallow all
# exceptions and store them in the Future. If a Future
# makes it out to the IOLoop, ensure its exception (if any)
@@ -779,395 +755,84 @@ def _run_callback(self, callback):
pass
else:
self.add_future(ret, self._discard_future_result)
+ except asyncio.CancelledError:
+ pass
except Exception:
- self.handle_callback_exception(callback)
+ app_log.error("Exception in callback %r", callback, exc_info=True)
- def _discard_future_result(self, future):
+ def _discard_future_result(self, future: Future) -> None:
"""Avoid unhandled-exception warnings from spawned coroutines."""
future.result()
- def handle_callback_exception(self, callback):
- """This method is called whenever a callback run by the `IOLoop`
- throws an exception.
-
- By default simply logs the exception as an error. Subclasses
- may override this method to customize reporting of exceptions.
-
- The exception itself is not passed explicitly, but is available
- in `sys.exc_info`.
-
- .. versionchanged:: 5.0
-
- When the `asyncio` event loop is used (which is now the
- default on Python 3), some callback errors will be handled by
- `asyncio` instead of this method.
-
- .. deprecated: 5.1
-
- Support for this method will be removed in Tornado 6.0.
- """
- app_log.error("Exception in callback %r", callback, exc_info=True)
-
- def split_fd(self, fd):
- """Returns an (fd, obj) pair from an ``fd`` parameter.
-
- We accept both raw file descriptors and file-like objects as
- input to `add_handler` and related methods. When a file-like
- object is passed, we must retain the object itself so we can
- close it correctly when the `IOLoop` shuts down, but the
- poller interfaces favor file descriptors (they will accept
- file-like objects and call ``fileno()`` for you, but they
- always return the descriptor itself).
-
- This method is provided for use by `IOLoop` subclasses and should
- not generally be used by application code.
-
- .. versionadded:: 4.0
- """
- try:
- return fd.fileno(), fd
- except AttributeError:
+ def split_fd(
+ self, fd: Union[int, _Selectable]
+ ) -> Tuple[int, Union[int, _Selectable]]:
+ # """Returns an (fd, obj) pair from an ``fd`` parameter.
+
+ # We accept both raw file descriptors and file-like objects as
+ # input to `add_handler` and related methods. When a file-like
+ # object is passed, we must retain the object itself so we can
+ # close it correctly when the `IOLoop` shuts down, but the
+ # poller interfaces favor file descriptors (they will accept
+ # file-like objects and call ``fileno()`` for you, but they
+ # always return the descriptor itself).
+
+ # This method is provided for use by `IOLoop` subclasses and should
+ # not generally be used by application code.
+
+ # .. versionadded:: 4.0
+ # """
+ if isinstance(fd, int):
return fd, fd
+ return fd.fileno(), fd
- def close_fd(self, fd):
- """Utility method to close an ``fd``.
+ def close_fd(self, fd: Union[int, _Selectable]) -> None:
+ # """Utility method to close an ``fd``.
- If ``fd`` is a file-like object, we close it directly; otherwise
- we use `os.close`.
+ # If ``fd`` is a file-like object, we close it directly; otherwise
+ # we use `os.close`.
- This method is provided for use by `IOLoop` subclasses (in
- implementations of ``IOLoop.close(all_fds=True)`` and should
- not generally be used by application code.
+ # This method is provided for use by `IOLoop` subclasses (in
+ # implementations of ``IOLoop.close(all_fds=True)`` and should
+ # not generally be used by application code.
- .. versionadded:: 4.0
- """
+ # .. versionadded:: 4.0
+ # """
try:
- try:
- fd.close()
- except AttributeError:
+ if isinstance(fd, int):
os.close(fd)
+ else:
+ fd.close()
except OSError:
pass
-class PollIOLoop(IOLoop):
- """Base class for IOLoops built around a select-like function.
-
- For concrete implementations, see `tornado.platform.epoll.EPollIOLoop`
- (Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or
- `tornado.platform.select.SelectIOLoop` (all platforms).
- """
- def initialize(self, impl, time_func=None, **kwargs):
- super(PollIOLoop, self).initialize(**kwargs)
- self._impl = impl
- if hasattr(self._impl, 'fileno'):
- set_close_exec(self._impl.fileno())
- self.time_func = time_func or time.time
- self._handlers = {}
- self._events = {}
- self._callbacks = collections.deque()
- self._timeouts = []
- self._cancellations = 0
- self._running = False
- self._stopped = False
- self._closing = False
- self._thread_ident = None
- self._pid = os.getpid()
- self._blocking_signal_threshold = None
- self._timeout_counter = itertools.count()
-
- # Create a pipe that we send bogus data to when we want to wake
- # the I/O loop when it is idle
- self._waker = Waker()
- self.add_handler(self._waker.fileno(),
- lambda fd, events: self._waker.consume(),
- self.READ)
-
- @classmethod
- def configurable_base(cls):
- return PollIOLoop
-
- @classmethod
- def configurable_default(cls):
- if hasattr(select, "epoll"):
- from tornado.platform.epoll import EPollIOLoop
- return EPollIOLoop
- if hasattr(select, "kqueue"):
- # Python 2.6+ on BSD or Mac
- from tornado.platform.kqueue import KQueueIOLoop
- return KQueueIOLoop
- from tornado.platform.select import SelectIOLoop
- return SelectIOLoop
-
- def close(self, all_fds=False):
- self._closing = True
- self.remove_handler(self._waker.fileno())
- if all_fds:
- for fd, handler in list(self._handlers.values()):
- self.close_fd(fd)
- self._waker.close()
- self._impl.close()
- self._callbacks = None
- self._timeouts = None
- if hasattr(self, '_executor'):
- self._executor.shutdown()
-
- def add_handler(self, fd, handler, events):
- fd, obj = self.split_fd(fd)
- self._handlers[fd] = (obj, stack_context.wrap(handler))
- self._impl.register(fd, events | self.ERROR)
-
- def update_handler(self, fd, events):
- fd, obj = self.split_fd(fd)
- self._impl.modify(fd, events | self.ERROR)
-
- def remove_handler(self, fd):
- fd, obj = self.split_fd(fd)
- self._handlers.pop(fd, None)
- self._events.pop(fd, None)
- try:
- self._impl.unregister(fd)
- except Exception:
- gen_log.debug("Error deleting fd from IOLoop", exc_info=True)
-
- def set_blocking_signal_threshold(self, seconds, action):
- if not hasattr(signal, "setitimer"):
- gen_log.error("set_blocking_signal_threshold requires a signal module "
- "with the setitimer method")
- return
- self._blocking_signal_threshold = seconds
- if seconds is not None:
- signal.signal(signal.SIGALRM,
- action if action is not None else signal.SIG_DFL)
-
- def start(self):
- if self._running:
- raise RuntimeError("IOLoop is already running")
- if os.getpid() != self._pid:
- raise RuntimeError("Cannot share PollIOLoops across processes")
- self._setup_logging()
- if self._stopped:
- self._stopped = False
- return
- old_current = IOLoop.current(instance=False)
- if old_current is not self:
- self.make_current()
- self._thread_ident = thread.get_ident()
- self._running = True
-
- # signal.set_wakeup_fd closes a race condition in event loops:
- # a signal may arrive at the beginning of select/poll/etc
- # before it goes into its interruptible sleep, so the signal
- # will be consumed without waking the select. The solution is
- # for the (C, synchronous) signal handler to write to a pipe,
- # which will then be seen by select.
- #
- # In python's signal handling semantics, this only matters on the
- # main thread (fortunately, set_wakeup_fd only works on the main
- # thread and will raise a ValueError otherwise).
- #
- # If someone has already set a wakeup fd, we don't want to
- # disturb it. This is an issue for twisted, which does its
- # SIGCHLD processing in response to its own wakeup fd being
- # written to. As long as the wakeup fd is registered on the IOLoop,
- # the loop will still wake up and everything should work.
- old_wakeup_fd = None
- if hasattr(signal, 'set_wakeup_fd') and os.name == 'posix':
- # requires python 2.6+, unix. set_wakeup_fd exists but crashes
- # the python process on windows.
- try:
- old_wakeup_fd = signal.set_wakeup_fd(self._waker.write_fileno())
- if old_wakeup_fd != -1:
- # Already set, restore previous value. This is a little racy,
- # but there's no clean get_wakeup_fd and in real use the
- # IOLoop is just started once at the beginning.
- signal.set_wakeup_fd(old_wakeup_fd)
- old_wakeup_fd = None
- except ValueError:
- # Non-main thread, or the previous value of wakeup_fd
- # is no longer valid.
- old_wakeup_fd = None
-
- try:
- while True:
- # Prevent IO event starvation by delaying new callbacks
- # to the next iteration of the event loop.
- ncallbacks = len(self._callbacks)
-
- # Add any timeouts that have come due to the callback list.
- # Do not run anything until we have determined which ones
- # are ready, so timeouts that call add_timeout cannot
- # schedule anything in this iteration.
- due_timeouts = []
- if self._timeouts:
- now = self.time()
- while self._timeouts:
- if self._timeouts[0].callback is None:
- # The timeout was cancelled. Note that the
- # cancellation check is repeated below for timeouts
- # that are cancelled by another timeout or callback.
- heapq.heappop(self._timeouts)
- self._cancellations -= 1
- elif self._timeouts[0].deadline <= now:
- due_timeouts.append(heapq.heappop(self._timeouts))
- else:
- break
- if (self._cancellations > 512 and
- self._cancellations > (len(self._timeouts) >> 1)):
- # Clean up the timeout queue when it gets large and it's
- # more than half cancellations.
- self._cancellations = 0
- self._timeouts = [x for x in self._timeouts
- if x.callback is not None]
- heapq.heapify(self._timeouts)
-
- for i in range(ncallbacks):
- self._run_callback(self._callbacks.popleft())
- for timeout in due_timeouts:
- if timeout.callback is not None:
- self._run_callback(timeout.callback)
- # Closures may be holding on to a lot of memory, so allow
- # them to be freed before we go into our poll wait.
- due_timeouts = timeout = None
-
- if self._callbacks:
- # If any callbacks or timeouts called add_callback,
- # we don't want to wait in poll() before we run them.
- poll_timeout = 0.0
- elif self._timeouts:
- # If there are any timeouts, schedule the first one.
- # Use self.time() instead of 'now' to account for time
- # spent running callbacks.
- poll_timeout = self._timeouts[0].deadline - self.time()
- poll_timeout = max(0, min(poll_timeout, _POLL_TIMEOUT))
- else:
- # No timeouts and no callbacks, so use the default.
- poll_timeout = _POLL_TIMEOUT
-
- if not self._running:
- break
-
- if self._blocking_signal_threshold is not None:
- # clear alarm so it doesn't fire while poll is waiting for
- # events.
- signal.setitimer(signal.ITIMER_REAL, 0, 0)
-
- try:
- event_pairs = self._impl.poll(poll_timeout)
- except Exception as e:
- # Depending on python version and IOLoop implementation,
- # different exception types may be thrown and there are
- # two ways EINTR might be signaled:
- # * e.errno == errno.EINTR
- # * e.args is like (errno.EINTR, 'Interrupted system call')
- if errno_from_exception(e) == errno.EINTR:
- continue
- else:
- raise
-
- if self._blocking_signal_threshold is not None:
- signal.setitimer(signal.ITIMER_REAL,
- self._blocking_signal_threshold, 0)
-
- # Pop one fd at a time from the set of pending fds and run
- # its handler. Since that handler may perform actions on
- # other file descriptors, there may be reentrant calls to
- # this IOLoop that modify self._events
- self._events.update(event_pairs)
- while self._events:
- fd, events = self._events.popitem()
- try:
- fd_obj, handler_func = self._handlers[fd]
- handler_func(fd_obj, events)
- except (OSError, IOError) as e:
- if errno_from_exception(e) == errno.EPIPE:
- # Happens when the client closes the connection
- pass
- else:
- self.handle_callback_exception(self._handlers.get(fd))
- except Exception:
- self.handle_callback_exception(self._handlers.get(fd))
- fd_obj = handler_func = None
-
- finally:
- # reset the stopped flag so another start/stop pair can be issued
- self._stopped = False
- if self._blocking_signal_threshold is not None:
- signal.setitimer(signal.ITIMER_REAL, 0, 0)
- if old_current is None:
- IOLoop.clear_current()
- elif old_current is not self:
- old_current.make_current()
- if old_wakeup_fd is not None:
- signal.set_wakeup_fd(old_wakeup_fd)
-
- def stop(self):
- self._running = False
- self._stopped = True
- self._waker.wake()
-
- def time(self):
- return self.time_func()
-
- def call_at(self, deadline, callback, *args, **kwargs):
- timeout = _Timeout(
- deadline,
- functools.partial(stack_context.wrap(callback), *args, **kwargs),
- self)
- heapq.heappush(self._timeouts, timeout)
- return timeout
-
- def remove_timeout(self, timeout):
- # Removing from a heap is complicated, so just leave the defunct
- # timeout object in the queue (see discussion in
- # http://docs.python.org/library/heapq.html).
- # If this turns out to be a problem, we could add a garbage
- # collection pass whenever there are too many dead timeouts.
- timeout.callback = None
- self._cancellations += 1
-
- def add_callback(self, callback, *args, **kwargs):
- if self._closing:
- return
- # Blindly insert into self._callbacks. This is safe even
- # from signal handlers because deque.append is atomic.
- self._callbacks.append(functools.partial(
- stack_context.wrap(callback), *args, **kwargs))
- if thread.get_ident() != self._thread_ident:
- # This will write one byte but Waker.consume() reads many
- # at once, so it's ok to write even when not strictly
- # necessary.
- self._waker.wake()
- else:
- # If we're on the IOLoop's thread, we don't need to wake anyone.
- pass
-
- def add_callback_from_signal(self, callback, *args, **kwargs):
- with stack_context.NullContext():
- self.add_callback(callback, *args, **kwargs)
-
-
class _Timeout(object):
"""An IOLoop timeout, a UNIX timestamp and a callback"""
# Reduce memory overhead when there are lots of pending callbacks
- __slots__ = ['deadline', 'callback', 'tdeadline']
+ __slots__ = ["deadline", "callback", "tdeadline"]
- def __init__(self, deadline, callback, io_loop):
+ def __init__(
+ self, deadline: float, callback: Callable[[], None], io_loop: IOLoop
+ ) -> None:
if not isinstance(deadline, numbers.Real):
raise TypeError("Unsupported deadline %r" % deadline)
self.deadline = deadline
self.callback = callback
- self.tdeadline = (deadline, next(io_loop._timeout_counter))
+ self.tdeadline = (
+ deadline,
+ next(io_loop._timeout_counter),
+ ) # type: Tuple[float, int]
# Comparison methods to sort by deadline, with object id as a tiebreaker
# to guarantee a consistent ordering. The heapq module uses __le__
# in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
# use __lt__).
- def __lt__(self, other):
+ def __lt__(self, other: "_Timeout") -> bool:
return self.tdeadline < other.tdeadline
- def __le__(self, other):
+ def __le__(self, other: "_Timeout") -> bool:
return self.tdeadline <= other.tdeadline
@@ -1197,16 +862,19 @@ class PeriodicCallback(object):
.. versionchanged:: 5.1
The ``jitter`` argument is added.
"""
- def __init__(self, callback, callback_time, jitter=0):
+
+ def __init__(
+ self, callback: Callable[[], None], callback_time: float, jitter: float = 0
+ ) -> None:
self.callback = callback
if callback_time <= 0:
raise ValueError("Periodic callback must have a positive callback_time")
self.callback_time = callback_time
self.jitter = jitter
self._running = False
- self._timeout = None
+ self._timeout = None # type: object
- def start(self):
+ def start(self) -> None:
"""Starts the timer."""
# Looking up the IOLoop here allows to first instantiate the
# PeriodicCallback in another thread, then start it using
@@ -1216,36 +884,36 @@ def start(self):
self._next_timeout = self.io_loop.time()
self._schedule_next()
- def stop(self):
+ def stop(self) -> None:
"""Stops the timer."""
self._running = False
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
- def is_running(self):
- """Return True if this `.PeriodicCallback` has been started.
+ def is_running(self) -> bool:
+ """Returns ``True`` if this `.PeriodicCallback` has been started.
.. versionadded:: 4.1
"""
return self._running
- def _run(self):
+ def _run(self) -> None:
if not self._running:
return
try:
return self.callback()
except Exception:
- self.io_loop.handle_callback_exception(self.callback)
+ app_log.error("Exception in callback %r", self.callback, exc_info=True)
finally:
self._schedule_next()
- def _schedule_next(self):
+ def _schedule_next(self) -> None:
if self._running:
self._update_next(self.io_loop.time())
self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)
- def _update_next(self, current_time):
+ def _update_next(self, current_time: float) -> None:
callback_time_sec = self.callback_time / 1000.0
if self.jitter:
# apply jitter fraction
@@ -1255,8 +923,9 @@ def _update_next(self, current_time):
# to the start of the next. If one call takes too long,
# skip cycles to get back to a multiple of the original
# schedule.
- self._next_timeout += (math.floor((current_time - self._next_timeout) /
- callback_time_sec) + 1) * callback_time_sec
+ self._next_timeout += (
+ math.floor((current_time - self._next_timeout) / callback_time_sec) + 1
+ ) * callback_time_sec
else:
# If the clock moved backwards, ensure we advance the next
# timeout instead of recomputing the same value again.
diff --git a/tornado/iostream.py b/tornado/iostream.py
index 20f481d2ac..86235f4dc3 100644
--- a/tornado/iostream.py
+++ b/tornado/iostream.py
@@ -23,53 +23,54 @@
* `PipeIOStream`: Pipe-based IOStream implementation.
"""
-from __future__ import absolute_import, division, print_function
-
+import asyncio
import collections
import errno
import io
import numbers
import os
import socket
+import ssl
import sys
import re
-import warnings
-from tornado.concurrent import Future
+from tornado.concurrent import Future, future_set_result_unless_cancelled
from tornado import ioloop
-from tornado.log import gen_log, app_log
+from tornado.log import gen_log
from tornado.netutil import ssl_wrap_socket, _client_ssl_defaults, _server_ssl_defaults
-from tornado import stack_context
from tornado.util import errno_from_exception
-try:
- from tornado.platform.posix import _set_nonblocking
-except ImportError:
- _set_nonblocking = None
-
-try:
- import ssl
-except ImportError:
- # ssl is not available on Google App Engine
- ssl = None
-
-# These errnos indicate that a non-blocking operation must be retried
-# at a later time. On most platforms they're the same value, but on
-# some they differ.
-_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
-
-if hasattr(errno, "WSAEWOULDBLOCK"):
- _ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) # type: ignore
+import typing
+from typing import (
+ Union,
+ Optional,
+ Awaitable,
+ Callable,
+ Pattern,
+ Any,
+ Dict,
+ TypeVar,
+ Tuple,
+)
+from types import TracebackType
+
+if typing.TYPE_CHECKING:
+ from typing import Deque, List, Type # noqa: F401
+
+_IOStreamType = TypeVar("_IOStreamType", bound="IOStream")
# These errnos indicate that a connection has been abruptly terminated.
# They should be caught and handled less noisily than other errors.
-_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
- errno.ETIMEDOUT)
+_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, errno.ETIMEDOUT)
if hasattr(errno, "WSAECONNRESET"):
- _ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT) # type: ignore # noqa: E501
+ _ERRNO_CONNRESET += ( # type: ignore
+ errno.WSAECONNRESET, # type: ignore
+ errno.WSAECONNABORTED, # type: ignore
+ errno.WSAETIMEDOUT, # type: ignore
+ )
-if sys.platform == 'darwin':
+if sys.platform == "darwin":
# OSX appears to have a race condition that causes send(2) to return
# EPROTOTYPE if called while a socket is being torn down:
# http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
@@ -77,13 +78,7 @@
# instead of an unexpected error.
_ERRNO_CONNRESET += (errno.EPROTOTYPE,) # type: ignore
-# More non-portable errnos:
-_ERRNO_INPROGRESS = (errno.EINPROGRESS,)
-
-if hasattr(errno, "WSAEINPROGRESS"):
- _ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,) # type: ignore
-
-_WINDOWS = sys.platform.startswith('win')
+_WINDOWS = sys.platform.startswith("win")
class StreamClosedError(IOError):
@@ -99,8 +94,9 @@ class StreamClosedError(IOError):
.. versionchanged:: 4.3
Added the ``real_error`` attribute.
"""
- def __init__(self, real_error=None):
- super(StreamClosedError, self).__init__('Stream is closed')
+
+ def __init__(self, real_error: Optional[BaseException] = None) -> None:
+ super().__init__("Stream is closed")
self.real_error = real_error
@@ -110,12 +106,12 @@ class UnsatisfiableReadError(Exception):
Raised by ``read_until`` and ``read_until_regex`` with a ``max_bytes``
argument.
"""
+
pass
class StreamBufferFullError(Exception):
- """Exception raised by `IOStream` methods when the buffer is full.
- """
+ """Exception raised by `IOStream` methods when the buffer is full."""
class _StreamBuffer(object):
@@ -124,21 +120,23 @@ class _StreamBuffer(object):
of data are encountered.
"""
- def __init__(self):
+ def __init__(self) -> None:
# A sequence of (False, bytearray) and (True, memoryview) objects
- self._buffers = collections.deque()
+ self._buffers = (
+ collections.deque()
+ ) # type: Deque[Tuple[bool, Union[bytearray, memoryview]]]
# Position in the first buffer
self._first_pos = 0
self._size = 0
- def __len__(self):
+ def __len__(self) -> int:
return self._size
# Data above this size will be appended separately instead
# of extending an existing bytearray
_large_buf_threshold = 2048
- def append(self, data):
+ def append(self, data: Union[bytes, bytearray, memoryview]) -> None:
"""
Append the given piece of data (should be a buffer-compatible object).
"""
@@ -156,11 +154,11 @@ def append(self, data):
if new_buf:
self._buffers.append((False, bytearray(data)))
else:
- b += data
+ b += data # type: ignore
self._size += size
- def peek(self, size):
+ def peek(self, size: int) -> memoryview:
"""
Get a view over at most ``size`` bytes (possibly fewer) at the
current buffer position.
@@ -169,15 +167,15 @@ def peek(self, size):
try:
is_memview, b = self._buffers[0]
except IndexError:
- return memoryview(b'')
+ return memoryview(b"")
pos = self._first_pos
if is_memview:
- return b[pos:pos + size]
+ return typing.cast(memoryview, b[pos : pos + size])
else:
- return memoryview(b)[pos:pos + size]
+ return memoryview(b)[pos : pos + size]
- def advance(self, size):
+ def advance(self, size: int) -> None:
"""
Advance the current buffer position by ``size`` bytes.
"""
@@ -200,7 +198,7 @@ def advance(self, size):
# Amortized O(1) shrink for Python 2
pos += size
if len(b) <= 2 * pos:
- del b[:pos]
+ del typing.cast(bytearray, b)[:pos]
pos = 0
size = 0
@@ -211,23 +209,27 @@ def advance(self, size):
class BaseIOStream(object):
"""A utility class to write to and read from a non-blocking file or socket.
- We support a non-blocking ``write()`` and a family of ``read_*()`` methods.
- All of the methods take an optional ``callback`` argument and return a
- `.Future` only if no callback is given. When the operation completes,
- the callback will be run or the `.Future` will resolve with the data
- read (or ``None`` for ``write()``). All outstanding ``Futures`` will
- resolve with a `StreamClosedError` when the stream is closed; users
- of the callback interface will be notified via
- `.BaseIOStream.set_close_callback` instead.
+ We support a non-blocking ``write()`` and a family of ``read_*()``
+ methods. When the operation completes, the ``Awaitable`` will resolve
+ with the data read (or ``None`` for ``write()``). All outstanding
+ ``Awaitables`` will resolve with a `StreamClosedError` when the
+ stream is closed; `.BaseIOStream.set_close_callback` can also be used
+ to be notified of a closed stream.
When a stream is closed due to an error, the IOStream's ``error``
attribute contains the exception object.
Subclasses must implement `fileno`, `close_fd`, `write_to_fd`,
`read_from_fd`, and optionally `get_fd_error`.
+
"""
- def __init__(self, max_buffer_size=None,
- read_chunk_size=None, max_write_buffer_size=None):
+
+ def __init__(
+ self,
+ max_buffer_size: Optional[int] = None,
+ read_chunk_size: Optional[int] = None,
+ max_write_buffer_size: Optional[int] = None,
+ ) -> None:
"""`BaseIOStream` constructor.
:arg max_buffer_size: Maximum amount of incoming data to buffer;
@@ -248,47 +250,43 @@ def __init__(self, max_buffer_size=None,
self.max_buffer_size = max_buffer_size or 104857600
# A chunk size that is too close to max_buffer_size can cause
# spurious failures.
- self.read_chunk_size = min(read_chunk_size or 65536,
- self.max_buffer_size // 2)
+ self.read_chunk_size = min(read_chunk_size or 65536, self.max_buffer_size // 2)
self.max_write_buffer_size = max_write_buffer_size
- self.error = None
+ self.error = None # type: Optional[BaseException]
self._read_buffer = bytearray()
self._read_buffer_pos = 0
self._read_buffer_size = 0
self._user_read_buffer = False
- self._after_user_read_buffer = None
+ self._after_user_read_buffer = None # type: Optional[bytearray]
self._write_buffer = _StreamBuffer()
self._total_write_index = 0
self._total_write_done_index = 0
- self._read_delimiter = None
- self._read_regex = None
- self._read_max_bytes = None
- self._read_bytes = None
+ self._read_delimiter = None # type: Optional[bytes]
+ self._read_regex = None # type: Optional[Pattern]
+ self._read_max_bytes = None # type: Optional[int]
+ self._read_bytes = None # type: Optional[int]
self._read_partial = False
self._read_until_close = False
- self._read_callback = None
- self._read_future = None
- self._streaming_callback = None
- self._write_callback = None
- self._write_futures = collections.deque()
- self._close_callback = None
- self._connect_callback = None
- self._connect_future = None
+ self._read_future = None # type: Optional[Future]
+ self._write_futures = (
+ collections.deque()
+ ) # type: Deque[Tuple[int, Future[None]]]
+ self._close_callback = None # type: Optional[Callable[[], None]]
+ self._connect_future = None # type: Optional[Future[IOStream]]
# _ssl_connect_future should be defined in SSLIOStream
- # but it's here so we can clean it up in maybe_run_close_callback.
+ # but it's here so we can clean it up in _signal_closed
# TODO: refactor that so subclasses can add additional futures
# to be cancelled.
- self._ssl_connect_future = None
+ self._ssl_connect_future = None # type: Optional[Future[SSLIOStream]]
self._connecting = False
- self._state = None
- self._pending_callbacks = 0
+ self._state = None # type: Optional[int]
self._closed = False
- def fileno(self):
+ def fileno(self) -> Union[int, ioloop._Selectable]:
"""Returns the file descriptor for this stream."""
raise NotImplementedError()
- def close_fd(self):
+ def close_fd(self) -> None:
"""Closes the file underlying this stream.
``close_fd`` is called by `BaseIOStream` and should not be called
@@ -296,14 +294,14 @@ def close_fd(self):
"""
raise NotImplementedError()
- def write_to_fd(self, data):
+ def write_to_fd(self, data: memoryview) -> int:
"""Attempts to write ``data`` to the underlying file.
Returns the number of bytes written.
"""
raise NotImplementedError()
- def read_from_fd(self, buf):
+ def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]:
"""Attempts to read from the underlying file.
Reads up to ``len(buf)`` bytes, storing them in the buffer.
@@ -318,7 +316,7 @@ def read_from_fd(self, buf):
"""
raise NotImplementedError()
- def get_fd_error(self):
+ def get_fd_error(self) -> Optional[Exception]:
"""Returns information about any error on the underlying file.
This method is called after the `.IOLoop` has signaled an error on the
@@ -328,13 +326,13 @@ def get_fd_error(self):
"""
return None
- def read_until_regex(self, regex, callback=None, max_bytes=None):
+ def read_until_regex(
+ self, regex: bytes, max_bytes: Optional[int] = None
+ ) -> Awaitable[bytes]:
"""Asynchronously read until we have matched the given regex.
The result includes the data that matches the regex and anything
- that came before it. If a callback is given, it will be run
- with the data as an argument; if not, this method returns a
- `.Future`.
+ that came before it.
If ``max_bytes`` is not None, the connection will be closed
if more than ``max_bytes`` bytes have been read and the regex is
@@ -344,13 +342,13 @@ def read_until_regex(self, regex, callback=None, max_bytes=None):
Added the ``max_bytes`` argument. The ``callback`` argument is
now optional and a `.Future` will be returned if it is omitted.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed
- in Tornado 6.0. Use the returned `.Future` instead.
+ The ``callback`` argument was removed. Use the returned
+ `.Future` instead.
"""
- future = self._set_read_callback(callback)
+ future = self._start_read()
self._read_regex = re.compile(regex)
self._read_max_bytes = max_bytes
try:
@@ -361,19 +359,18 @@ def read_until_regex(self, regex, callback=None, max_bytes=None):
self.close(exc_info=e)
return future
except:
- if future is not None:
- # Ensure that the future doesn't log an error because its
- # failure was never examined.
- future.add_done_callback(lambda f: f.exception())
+ # Ensure that the future doesn't log an error because its
+ # failure was never examined.
+ future.add_done_callback(lambda f: f.exception())
raise
return future
- def read_until(self, delimiter, callback=None, max_bytes=None):
+ def read_until(
+ self, delimiter: bytes, max_bytes: Optional[int] = None
+ ) -> Awaitable[bytes]:
"""Asynchronously read until we have found the given delimiter.
The result includes all the data read including the delimiter.
- If a callback is given, it will be run with the data as an argument;
- if not, this method returns a `.Future`.
If ``max_bytes`` is not None, the connection will be closed
if more than ``max_bytes`` bytes have been read and the delimiter
@@ -383,12 +380,12 @@ def read_until(self, delimiter, callback=None, max_bytes=None):
Added the ``max_bytes`` argument. The ``callback`` argument is
now optional and a `.Future` will be returned if it is omitted.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed
- in Tornado 6.0. Use the returned `.Future` instead.
+ The ``callback`` argument was removed. Use the returned
+ `.Future` instead.
"""
- future = self._set_read_callback(callback)
+ future = self._start_read()
self._read_delimiter = delimiter
self._read_max_bytes = max_bytes
try:
@@ -399,58 +396,42 @@ def read_until(self, delimiter, callback=None, max_bytes=None):
self.close(exc_info=e)
return future
except:
- if future is not None:
- future.add_done_callback(lambda f: f.exception())
+ future.add_done_callback(lambda f: f.exception())
raise
return future
- def read_bytes(self, num_bytes, callback=None, streaming_callback=None,
- partial=False):
+ def read_bytes(self, num_bytes: int, partial: bool = False) -> Awaitable[bytes]:
"""Asynchronously read a number of bytes.
- If a ``streaming_callback`` is given, it will be called with chunks
- of data as they become available, and the final result will be empty.
- Otherwise, the result is all the data that was read.
- If a callback is given, it will be run with the data as an argument;
- if not, this method returns a `.Future`.
-
- If ``partial`` is true, the callback is run as soon as we have
+ If ``partial`` is true, data is returned as soon as we have
any bytes to return (but never more than ``num_bytes``)
.. versionchanged:: 4.0
Added the ``partial`` argument. The callback argument is now
optional and a `.Future` will be returned if it is omitted.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` and ``streaming_callback`` arguments are
- deprecated and will be removed in Tornado 6.0. Use the
- returned `.Future` (and ``partial=True`` for
- ``streaming_callback``) instead.
+ The ``callback`` and ``streaming_callback`` arguments have
+ been removed. Use the returned `.Future` (and
+ ``partial=True`` for ``streaming_callback``) instead.
"""
- future = self._set_read_callback(callback)
+ future = self._start_read()
assert isinstance(num_bytes, numbers.Integral)
self._read_bytes = num_bytes
self._read_partial = partial
- if streaming_callback is not None:
- warnings.warn("streaming_callback is deprecated, use partial instead",
- DeprecationWarning)
- self._streaming_callback = stack_context.wrap(streaming_callback)
try:
self._try_inline_read()
except:
- if future is not None:
- future.add_done_callback(lambda f: f.exception())
+ future.add_done_callback(lambda f: f.exception())
raise
return future
- def read_into(self, buf, callback=None, partial=False):
+ def read_into(self, buf: bytearray, partial: bool = False) -> Awaitable[int]:
"""Asynchronously read a number of bytes.
``buf`` must be a writable buffer into which data will be read.
- If a callback is given, it will be run with the number of read
- bytes as an argument; if not, this method returns a `.Future`.
If ``partial`` is true, the callback is run as soon as any bytes
have been read. Otherwise, it is run when the ``buf`` has been
@@ -458,24 +439,26 @@ def read_into(self, buf, callback=None, partial=False):
.. versionadded:: 5.0
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed
- in Tornado 6.0. Use the returned `.Future` instead.
+ The ``callback`` argument was removed. Use the returned
+ `.Future` instead.
"""
- future = self._set_read_callback(callback)
+ future = self._start_read()
# First copy data already in read buffer
available_bytes = self._read_buffer_size
n = len(buf)
if available_bytes >= n:
end = self._read_buffer_pos + n
- buf[:] = memoryview(self._read_buffer)[self._read_buffer_pos:end]
+ buf[:] = memoryview(self._read_buffer)[self._read_buffer_pos : end]
del self._read_buffer[:end]
self._after_user_read_buffer = self._read_buffer
elif available_bytes > 0:
- buf[:available_bytes] = memoryview(self._read_buffer)[self._read_buffer_pos:]
+ buf[:available_bytes] = memoryview(self._read_buffer)[
+ self._read_buffer_pos :
+ ]
# Set up the supplied buffer as our temporary read buffer.
# The original (if it had any data remaining) has been
@@ -490,68 +473,45 @@ def read_into(self, buf, callback=None, partial=False):
try:
self._try_inline_read()
except:
- if future is not None:
- future.add_done_callback(lambda f: f.exception())
+ future.add_done_callback(lambda f: f.exception())
raise
return future
- def read_until_close(self, callback=None, streaming_callback=None):
+ def read_until_close(self) -> Awaitable[bytes]:
"""Asynchronously reads all data from the socket until it is closed.
- If a ``streaming_callback`` is given, it will be called with chunks
- of data as they become available, and the final result will be empty.
- Otherwise, the result is all the data that was read.
- If a callback is given, it will be run with the data as an argument;
- if not, this method returns a `.Future`.
-
- Note that if a ``streaming_callback`` is used, data will be
- read from the socket as quickly as it becomes available; there
- is no way to apply backpressure or cancel the reads. If flow
- control or cancellation are desired, use a loop with
- `read_bytes(partial=True) <.read_bytes>` instead.
+ This will buffer all available data until ``max_buffer_size``
+ is reached. If flow control or cancellation are desired, use a
+ loop with `read_bytes(partial=True) <.read_bytes>` instead.
.. versionchanged:: 4.0
The callback argument is now optional and a `.Future` will
be returned if it is omitted.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` and ``streaming_callback`` arguments are
- deprecated and will be removed in Tornado 6.0. Use the
- returned `.Future` (and `read_bytes` with ``partial=True``
- for ``streaming_callback``) instead.
+ The ``callback`` and ``streaming_callback`` arguments have
+ been removed. Use the returned `.Future` (and `read_bytes`
+ with ``partial=True`` for ``streaming_callback``) instead.
"""
- future = self._set_read_callback(callback)
- if streaming_callback is not None:
- warnings.warn("streaming_callback is deprecated, use read_bytes(partial=True) instead",
- DeprecationWarning)
- self._streaming_callback = stack_context.wrap(streaming_callback)
+ future = self._start_read()
if self.closed():
- if self._streaming_callback is not None:
- self._run_read_callback(self._read_buffer_size, True)
- self._run_read_callback(self._read_buffer_size, False)
+ self._finish_read(self._read_buffer_size, False)
return future
self._read_until_close = True
try:
self._try_inline_read()
except:
- if future is not None:
- future.add_done_callback(lambda f: f.exception())
+ future.add_done_callback(lambda f: f.exception())
raise
return future
- def write(self, data, callback=None):
+ def write(self, data: Union[bytes, memoryview]) -> "Future[None]":
"""Asynchronously write the given data to this stream.
- If ``callback`` is given, we call it when all of the buffered write
- data has been successfully written to the stream. If there was
- previously buffered write data and an old write callback, that
- callback is simply overwritten with this new callback.
-
- If no ``callback`` is given, this method returns a `.Future` that
- resolves (with a result of ``None``) when the write has been
- completed.
+ This method returns a `.Future` that resolves (with a result
+ of ``None``) when the write has been completed.
The ``data`` argument may be of type `bytes` or `memoryview`.
@@ -561,28 +521,27 @@ def write(self, data, callback=None):
.. versionchanged:: 4.5
Added support for `memoryview` arguments.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed
- in Tornado 6.0. Use the returned `.Future` instead.
+ The ``callback`` argument was removed. Use the returned
+ `.Future` instead.
"""
self._check_closed()
if data:
- if (self.max_write_buffer_size is not None and
- len(self._write_buffer) + len(data) > self.max_write_buffer_size):
+ if isinstance(data, memoryview):
+ # Make sure that ``len(data) == data.nbytes``
+ data = memoryview(data).cast("B")
+ if (
+ self.max_write_buffer_size is not None
+ and len(self._write_buffer) + len(data) > self.max_write_buffer_size
+ ):
raise StreamBufferFullError("Reached maximum write buffer size")
self._write_buffer.append(data)
self._total_write_index += len(data)
- if callback is not None:
- warnings.warn("callback argument is deprecated, use returned Future instead",
- DeprecationWarning)
- self._write_callback = stack_context.wrap(callback)
- future = None
- else:
- future = Future()
- future.add_done_callback(lambda f: f.exception())
- self._write_futures.append((self._total_write_index, future))
+ future = Future() # type: Future[None]
+ future.add_done_callback(lambda f: f.exception())
+ self._write_futures.append((self._total_write_index, future))
if not self._connecting:
self._handle_write()
if self._write_buffer:
@@ -590,7 +549,7 @@ def write(self, data, callback=None):
self._maybe_add_error_listener()
return future
- def set_close_callback(self, callback):
+ def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None:
"""Call the given callback when the stream is closed.
This mostly is not necessary for applications that use the
@@ -600,12 +559,24 @@ def set_close_callback(self, callback):
closed while no other read or write is in progress.
Unlike other callback-based interfaces, ``set_close_callback``
- will not be removed in Tornado 6.0.
+ was not removed in Tornado 6.0.
"""
- self._close_callback = stack_context.wrap(callback)
+ self._close_callback = callback
self._maybe_add_error_listener()
- def close(self, exc_info=False):
+ def close(
+ self,
+ exc_info: Union[
+ None,
+ bool,
+ BaseException,
+ Tuple[
+ "Optional[Type[BaseException]]",
+ Optional[BaseException],
+ Optional[TracebackType],
+ ],
+ ] = False,
+ ) -> None:
"""Close this stream.
If ``exc_info`` is true, set the ``error`` attribute to the current
@@ -623,61 +594,76 @@ def close(self, exc_info=False):
if any(exc_info):
self.error = exc_info[1]
if self._read_until_close:
- if (self._streaming_callback is not None and
- self._read_buffer_size):
- self._run_read_callback(self._read_buffer_size, True)
self._read_until_close = False
- self._run_read_callback(self._read_buffer_size, False)
+ self._finish_read(self._read_buffer_size, False)
+ elif self._read_future is not None:
+ # resolve reads that are pending and ready to complete
+ try:
+ pos = self._find_read_pos()
+ except UnsatisfiableReadError:
+ pass
+ else:
+ if pos is not None:
+ self._read_from_buffer(pos)
if self._state is not None:
self.io_loop.remove_handler(self.fileno())
self._state = None
self.close_fd()
self._closed = True
- self._maybe_run_close_callback()
-
- def _maybe_run_close_callback(self):
- # If there are pending callbacks, don't run the close callback
- # until they're done (see _maybe_add_error_handler)
- if self.closed() and self._pending_callbacks == 0:
- futures = []
- if self._read_future is not None:
- futures.append(self._read_future)
- self._read_future = None
- futures += [future for _, future in self._write_futures]
- self._write_futures.clear()
- if self._connect_future is not None:
- futures.append(self._connect_future)
- self._connect_future = None
- if self._ssl_connect_future is not None:
- futures.append(self._ssl_connect_future)
- self._ssl_connect_future = None
- for future in futures:
+ self._signal_closed()
+
+ def _signal_closed(self) -> None:
+ futures = [] # type: List[Future]
+ if self._read_future is not None:
+ futures.append(self._read_future)
+ self._read_future = None
+ futures += [future for _, future in self._write_futures]
+ self._write_futures.clear()
+ if self._connect_future is not None:
+ futures.append(self._connect_future)
+ self._connect_future = None
+ for future in futures:
+ if not future.done():
future.set_exception(StreamClosedError(real_error=self.error))
+ # Reference the exception to silence warnings. Annoyingly,
+ # this raises if the future was cancelled, but just
+ # returns any other error.
+ try:
future.exception()
- if self._close_callback is not None:
- cb = self._close_callback
- self._close_callback = None
- self._run_callback(cb)
- # Delete any unfinished callbacks to break up reference cycles.
- self._read_callback = self._write_callback = None
- # Clear the buffers so they can be cleared immediately even
- # if the IOStream object is kept alive by a reference cycle.
- # TODO: Clear the read buffer too; it currently breaks some tests.
- self._write_buffer = None
-
- def reading(self):
- """Returns true if we are currently reading from the stream."""
- return self._read_callback is not None or self._read_future is not None
-
- def writing(self):
- """Returns true if we are currently writing to the stream."""
+ except asyncio.CancelledError:
+ pass
+ if self._ssl_connect_future is not None:
+ # _ssl_connect_future expects to see the real exception (typically
+ # an ssl.SSLError), not just StreamClosedError.
+ if not self._ssl_connect_future.done():
+ if self.error is not None:
+ self._ssl_connect_future.set_exception(self.error)
+ else:
+ self._ssl_connect_future.set_exception(StreamClosedError())
+ self._ssl_connect_future.exception()
+ self._ssl_connect_future = None
+ if self._close_callback is not None:
+ cb = self._close_callback
+ self._close_callback = None
+ self.io_loop.add_callback(cb)
+ # Clear the buffers so they can be cleared immediately even
+ # if the IOStream object is kept alive by a reference cycle.
+ # TODO: Clear the read buffer too; it currently breaks some tests.
+ self._write_buffer = None # type: ignore
+
+ def reading(self) -> bool:
+ """Returns ``True`` if we are currently reading from the stream."""
+ return self._read_future is not None
+
+ def writing(self) -> bool:
+ """Returns ``True`` if we are currently writing to the stream."""
return bool(self._write_buffer)
- def closed(self):
- """Returns true if the stream has been closed."""
+ def closed(self) -> bool:
+ """Returns ``True`` if the stream has been closed."""
return self._closed
- def set_nodelay(self, value):
+ def set_nodelay(self, value: bool) -> None:
"""Sets the no-delay flag for this stream.
By default, data written to TCP streams may be held for a time
@@ -692,7 +678,10 @@ def set_nodelay(self, value):
"""
pass
- def _handle_events(self, fd, events):
+ def _handle_connect(self) -> None:
+ raise NotImplementedError()
+
+ def _handle_events(self, fd: Union[int, ioloop._Selectable], events: int) -> None:
if self.closed():
gen_log.warning("Got events for closed stream %s", fd)
return
@@ -732,170 +721,114 @@ def _handle_events(self, fd, events):
# yet anyway, so we don't need to listen in this case.
state |= self.io_loop.READ
if state != self._state:
- assert self._state is not None, \
- "shouldn't happen: _handle_events without self._state"
+ assert (
+ self._state is not None
+ ), "shouldn't happen: _handle_events without self._state"
self._state = state
self.io_loop.update_handler(self.fileno(), self._state)
except UnsatisfiableReadError as e:
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
self.close(exc_info=e)
except Exception as e:
- gen_log.error("Uncaught exception, closing connection.",
- exc_info=True)
+ gen_log.error("Uncaught exception, closing connection.", exc_info=True)
self.close(exc_info=e)
raise
- def _run_callback(self, callback, *args):
- def wrapper():
- self._pending_callbacks -= 1
- try:
- return callback(*args)
- except Exception as e:
- app_log.error("Uncaught exception, closing connection.",
- exc_info=True)
- # Close the socket on an uncaught exception from a user callback
- # (It would eventually get closed when the socket object is
- # gc'd, but we don't want to rely on gc happening before we
- # run out of file descriptors)
- self.close(exc_info=e)
- # Re-raise the exception so that IOLoop.handle_callback_exception
- # can see it and log the error
- raise
- finally:
- self._maybe_add_error_listener()
- # We schedule callbacks to be run on the next IOLoop iteration
- # rather than running them directly for several reasons:
- # * Prevents unbounded stack growth when a callback calls an
- # IOLoop operation that immediately runs another callback
- # * Provides a predictable execution context for e.g.
- # non-reentrant mutexes
- # * Ensures that the try/except in wrapper() is run outside
- # of the application's StackContexts
- with stack_context.NullContext():
- # stack_context was already captured in callback, we don't need to
- # capture it again for IOStream's wrapper. This is especially
- # important if the callback was pre-wrapped before entry to
- # IOStream (as in HTTPConnection._header_callback), as we could
- # capture and leak the wrong context here.
- self._pending_callbacks += 1
- self.io_loop.add_callback(wrapper)
-
- def _read_to_buffer_loop(self):
+ def _read_to_buffer_loop(self) -> Optional[int]:
# This method is called from _handle_read and _try_inline_read.
- try:
- if self._read_bytes is not None:
- target_bytes = self._read_bytes
- elif self._read_max_bytes is not None:
- target_bytes = self._read_max_bytes
- elif self.reading():
- # For read_until without max_bytes, or
- # read_until_close, read as much as we can before
- # scanning for the delimiter.
- target_bytes = None
- else:
- target_bytes = 0
- next_find_pos = 0
- # Pretend to have a pending callback so that an EOF in
- # _read_to_buffer doesn't trigger an immediate close
- # callback. At the end of this method we'll either
- # establish a real pending callback via
- # _read_from_buffer or run the close callback.
- #
- # We need two try statements here so that
- # pending_callbacks is decremented before the `except`
- # clause below (which calls `close` and does need to
- # trigger the callback)
- self._pending_callbacks += 1
- while not self.closed():
- # Read from the socket until we get EWOULDBLOCK or equivalent.
- # SSL sockets do some internal buffering, and if the data is
- # sitting in the SSL object's buffer select() and friends
- # can't see it; the only way to find out if it's there is to
- # try to read it.
- if self._read_to_buffer() == 0:
- break
-
- self._run_streaming_callback()
+ if self._read_bytes is not None:
+ target_bytes = self._read_bytes # type: Optional[int]
+ elif self._read_max_bytes is not None:
+ target_bytes = self._read_max_bytes
+ elif self.reading():
+ # For read_until without max_bytes, or
+ # read_until_close, read as much as we can before
+ # scanning for the delimiter.
+ target_bytes = None
+ else:
+ target_bytes = 0
+ next_find_pos = 0
+ while not self.closed():
+ # Read from the socket until we get EWOULDBLOCK or equivalent.
+ # SSL sockets do some internal buffering, and if the data is
+ # sitting in the SSL object's buffer select() and friends
+ # can't see it; the only way to find out if it's there is to
+ # try to read it.
+ if self._read_to_buffer() == 0:
+ break
- # If we've read all the bytes we can use, break out of
- # this loop. We can't just call read_from_buffer here
- # because of subtle interactions with the
- # pending_callback and error_listener mechanisms.
- #
- # If we've reached target_bytes, we know we're done.
- if (target_bytes is not None and
- self._read_buffer_size >= target_bytes):
- break
+ # If we've read all the bytes we can use, break out of
+ # this loop.
- # Otherwise, we need to call the more expensive find_read_pos.
- # It's inefficient to do this on every read, so instead
- # do it on the first read and whenever the read buffer
- # size has doubled.
- if self._read_buffer_size >= next_find_pos:
- pos = self._find_read_pos()
- if pos is not None:
- return pos
- next_find_pos = self._read_buffer_size * 2
- return self._find_read_pos()
- finally:
- self._pending_callbacks -= 1
+ # If we've reached target_bytes, we know we're done.
+ if target_bytes is not None and self._read_buffer_size >= target_bytes:
+ break
- def _handle_read(self):
+ # Otherwise, we need to call the more expensive find_read_pos.
+ # It's inefficient to do this on every read, so instead
+ # do it on the first read and whenever the read buffer
+ # size has doubled.
+ if self._read_buffer_size >= next_find_pos:
+ pos = self._find_read_pos()
+ if pos is not None:
+ return pos
+ next_find_pos = self._read_buffer_size * 2
+ return self._find_read_pos()
+
+ def _handle_read(self) -> None:
try:
pos = self._read_to_buffer_loop()
except UnsatisfiableReadError:
raise
+ except asyncio.CancelledError:
+ raise
except Exception as e:
gen_log.warning("error on read: %s" % e)
self.close(exc_info=e)
return
if pos is not None:
self._read_from_buffer(pos)
- return
- else:
- self._maybe_run_close_callback()
-
- def _set_read_callback(self, callback):
- assert self._read_callback is None, "Already reading"
- assert self._read_future is None, "Already reading"
- if callback is not None:
- warnings.warn("callbacks are deprecated, use returned Future instead",
- DeprecationWarning)
- self._read_callback = stack_context.wrap(callback)
- else:
- self._read_future = Future()
+
+ def _start_read(self) -> Future:
+ if self._read_future is not None:
+ # It is an error to start a read while a prior read is unresolved.
+ # However, if the prior read is unresolved because the stream was
+ # closed without satisfying it, it's better to raise
+ # StreamClosedError instead of AssertionError. In particular, this
+ # situation occurs in harmless situations in http1connection.py and
+ # an AssertionError would be logged noisily.
+ #
+ # On the other hand, it is legal to start a new read while the
+ # stream is closed, in case the read can be satisfied from the
+ # read buffer. So we only want to check the closed status of the
+ # stream if we need to decide what kind of error to raise for
+ # "already reading".
+ #
+ # These conditions have proven difficult to test; we have no
+ # unittests that reliably verify this behavior so be careful
+ # when making changes here. See #2651 and #2719.
+ self._check_closed()
+ assert self._read_future is None, "Already reading"
+ self._read_future = Future()
return self._read_future
- def _run_read_callback(self, size, streaming):
+ def _finish_read(self, size: int, streaming: bool) -> None:
if self._user_read_buffer:
self._read_buffer = self._after_user_read_buffer or bytearray()
self._after_user_read_buffer = None
self._read_buffer_pos = 0
self._read_buffer_size = len(self._read_buffer)
self._user_read_buffer = False
- result = size
+ result = size # type: Union[int, bytes]
else:
result = self._consume(size)
- if streaming:
- callback = self._streaming_callback
- else:
- callback = self._read_callback
- self._read_callback = self._streaming_callback = None
- if self._read_future is not None:
- assert callback is None
- future = self._read_future
- self._read_future = None
-
- future.set_result(result)
- if callback is not None:
- assert (self._read_future is None) or streaming
- self._run_callback(callback, result)
- else:
- # If we scheduled a callback, we will add the error listener
- # afterwards. If we didn't, we have to do it now.
- self._maybe_add_error_listener()
+ if self._read_future is not None:
+ future = self._read_future
+ self._read_future = None
+ future_set_result_unless_cancelled(future, result)
+ self._maybe_add_error_listener()
- def _try_inline_read(self):
+ def _try_inline_read(self) -> None:
"""Attempt to complete the current read operation from buffered data.
If the read can be completed without blocking, schedules the
@@ -903,32 +836,21 @@ def _try_inline_read(self):
listening for reads on the socket.
"""
# See if we've already got the data from a previous read
- self._run_streaming_callback()
pos = self._find_read_pos()
if pos is not None:
self._read_from_buffer(pos)
return
self._check_closed()
- try:
- pos = self._read_to_buffer_loop()
- except Exception:
- # If there was an in _read_to_buffer, we called close() already,
- # but couldn't run the close callback because of _pending_callbacks.
- # Before we escape from this function, run the close callback if
- # applicable.
- self._maybe_run_close_callback()
- raise
+ pos = self._read_to_buffer_loop()
if pos is not None:
self._read_from_buffer(pos)
return
- # We couldn't satisfy the read inline, so either close the stream
- # or listen for new data.
- if self.closed():
- self._maybe_run_close_callback()
- else:
+ # We couldn't satisfy the read inline, so make sure we're
+ # listening for new data unless the stream is closed.
+ if not self.closed():
self._add_io_state(ioloop.IOLoop.READ)
- def _read_to_buffer(self):
+ def _read_to_buffer(self) -> Optional[int]:
"""Reads from the socket and appends the result to the read buffer.
Returns the number of bytes read. Returns 0 if there is nothing
@@ -939,20 +861,20 @@ def _read_to_buffer(self):
while True:
try:
if self._user_read_buffer:
- buf = memoryview(self._read_buffer)[self._read_buffer_size:]
+ buf = memoryview(self._read_buffer)[
+ self._read_buffer_size :
+ ] # type: Union[memoryview, bytearray]
else:
buf = bytearray(self.read_chunk_size)
bytes_read = self.read_from_fd(buf)
except (socket.error, IOError, OSError) as e:
- if errno_from_exception(e) == errno.EINTR:
- continue
# ssl.SSLError is a subclass of socket.error
if self._is_connreset(e):
# Treat ECONNRESET as a connection close rather than
# an error to minimize log spam (the exception will
# be available on self.error for apps that care).
self.close(exc_info=e)
- return
+ return None
self.close(exc_info=e)
raise
break
@@ -967,22 +889,14 @@ def _read_to_buffer(self):
finally:
# Break the reference to buf so we don't waste a chunk's worth of
# memory in case an exception hangs on to our stack frame.
- buf = None
+ del buf
if self._read_buffer_size > self.max_buffer_size:
gen_log.error("Reached maximum read buffer size")
self.close()
raise StreamBufferFullError("Reached maximum read buffer size")
return bytes_read
- def _run_streaming_callback(self):
- if self._streaming_callback is not None and self._read_buffer_size:
- bytes_to_consume = self._read_buffer_size
- if self._read_bytes is not None:
- bytes_to_consume = min(self._read_bytes, bytes_to_consume)
- self._read_bytes -= bytes_to_consume
- self._run_read_callback(bytes_to_consume, True)
-
- def _read_from_buffer(self, pos):
+ def _read_from_buffer(self, pos: int) -> None:
"""Attempts to complete the currently-pending read from the buffer.
The argument is either a position in the read buffer or None,
@@ -990,18 +904,19 @@ def _read_from_buffer(self, pos):
"""
self._read_bytes = self._read_delimiter = self._read_regex = None
self._read_partial = False
- self._run_read_callback(pos, False)
+ self._finish_read(pos, False)
- def _find_read_pos(self):
+ def _find_read_pos(self) -> Optional[int]:
"""Attempts to find a position in the read buffer that satisfies
the currently-pending read.
Returns a position in the buffer if the current read can be satisfied,
or None if it cannot.
"""
- if (self._read_bytes is not None and
- (self._read_buffer_size >= self._read_bytes or
- (self._read_partial and self._read_buffer_size > 0))):
+ if self._read_bytes is not None and (
+ self._read_buffer_size >= self._read_bytes
+ or (self._read_partial and self._read_buffer_size > 0)
+ ):
num_bytes = min(self._read_bytes, self._read_buffer_size)
return num_bytes
elif self._read_delimiter is not None:
@@ -1014,20 +929,18 @@ def _find_read_pos(self):
# since large merges are relatively expensive and get undone in
# _consume().
if self._read_buffer:
- loc = self._read_buffer.find(self._read_delimiter,
- self._read_buffer_pos)
+ loc = self._read_buffer.find(
+ self._read_delimiter, self._read_buffer_pos
+ )
if loc != -1:
loc -= self._read_buffer_pos
delimiter_len = len(self._read_delimiter)
- self._check_max_bytes(self._read_delimiter,
- loc + delimiter_len)
+ self._check_max_bytes(self._read_delimiter, loc + delimiter_len)
return loc + delimiter_len
- self._check_max_bytes(self._read_delimiter,
- self._read_buffer_size)
+ self._check_max_bytes(self._read_delimiter, self._read_buffer_size)
elif self._read_regex is not None:
if self._read_buffer:
- m = self._read_regex.search(self._read_buffer,
- self._read_buffer_pos)
+ m = self._read_regex.search(self._read_buffer, self._read_buffer_pos)
if m is not None:
loc = m.end() - self._read_buffer_pos
self._check_max_bytes(self._read_regex, loc)
@@ -1035,14 +948,14 @@ def _find_read_pos(self):
self._check_max_bytes(self._read_regex, self._read_buffer_size)
return None
- def _check_max_bytes(self, delimiter, size):
- if (self._read_max_bytes is not None and
- size > self._read_max_bytes):
+ def _check_max_bytes(self, delimiter: Union[bytes, Pattern], size: int) -> None:
+ if self._read_max_bytes is not None and size > self._read_max_bytes:
raise UnsatisfiableReadError(
- "delimiter %r not found within %d bytes" % (
- delimiter, self._read_max_bytes))
+ "delimiter %r not found within %d bytes"
+ % (delimiter, self._read_max_bytes)
+ )
- def _handle_write(self):
+ def _handle_write(self) -> None:
while True:
size = len(self._write_buffer)
if not size:
@@ -1062,111 +975,102 @@ def _handle_write(self):
break
self._write_buffer.advance(num_bytes)
self._total_write_done_index += num_bytes
+ except BlockingIOError:
+ break
except (socket.error, IOError, OSError) as e:
- if e.args[0] in _ERRNO_WOULDBLOCK:
- break
- else:
- if not self._is_connreset(e):
- # Broken pipe errors are usually caused by connection
- # reset, and its better to not log EPIPE errors to
- # minimize log spam
- gen_log.warning("Write error on %s: %s",
- self.fileno(), e)
- self.close(exc_info=e)
- return
+ if not self._is_connreset(e):
+ # Broken pipe errors are usually caused by connection
+                # reset, and it's better to not log EPIPE errors to
+ # minimize log spam
+ gen_log.warning("Write error on %s: %s", self.fileno(), e)
+ self.close(exc_info=e)
+ return
while self._write_futures:
index, future = self._write_futures[0]
if index > self._total_write_done_index:
break
self._write_futures.popleft()
- future.set_result(None)
-
- if not len(self._write_buffer):
- if self._write_callback:
- callback = self._write_callback
- self._write_callback = None
- self._run_callback(callback)
+ future_set_result_unless_cancelled(future, None)
- def _consume(self, loc):
+ def _consume(self, loc: int) -> bytes:
# Consume loc bytes from the read buffer and return them
if loc == 0:
return b""
assert loc <= self._read_buffer_size
# Slice the bytearray buffer into bytes, without intermediate copying
- b = (memoryview(self._read_buffer)
- [self._read_buffer_pos:self._read_buffer_pos + loc]
- ).tobytes()
+ b = (
+ memoryview(self._read_buffer)[
+ self._read_buffer_pos : self._read_buffer_pos + loc
+ ]
+ ).tobytes()
self._read_buffer_pos += loc
self._read_buffer_size -= loc
# Amortized O(1) shrink
# (this heuristic is implemented natively in Python 3.4+
# but is replicated here for Python 2)
if self._read_buffer_pos > self._read_buffer_size:
- del self._read_buffer[:self._read_buffer_pos]
+ del self._read_buffer[: self._read_buffer_pos]
self._read_buffer_pos = 0
return b
- def _check_closed(self):
+ def _check_closed(self) -> None:
if self.closed():
raise StreamClosedError(real_error=self.error)
- def _maybe_add_error_listener(self):
+ def _maybe_add_error_listener(self) -> None:
# This method is part of an optimization: to detect a connection that
# is closed when we're not actively reading or writing, we must listen
# for read events. However, it is inefficient to do this when the
# connection is first established because we are going to read or write
# immediately anyway. Instead, we insert checks at various times to
# see if the connection is idle and add the read listener then.
- if self._pending_callbacks != 0:
- return
if self._state is None or self._state == ioloop.IOLoop.ERROR:
- if self.closed():
- self._maybe_run_close_callback()
- elif (self._read_buffer_size == 0 and
- self._close_callback is not None):
+ if (
+ not self.closed()
+ and self._read_buffer_size == 0
+ and self._close_callback is not None
+ ):
self._add_io_state(ioloop.IOLoop.READ)
- def _add_io_state(self, state):
+ def _add_io_state(self, state: int) -> None:
"""Adds `state` (IOLoop.{READ,WRITE} flags) to our event handler.
Implementation notes: Reads and writes have a fast path and a
slow path. The fast path reads synchronously from socket
buffers, while the slow path uses `_add_io_state` to schedule
- an IOLoop callback. Note that in both cases, the callback is
- run asynchronously with `_run_callback`.
+ an IOLoop callback.
To detect closed connections, we must have called
`_add_io_state` at some point, but we want to delay this as
much as possible so we don't have to set an `IOLoop.ERROR`
listener that will be overwritten by the next slow-path
- operation. As long as there are callbacks scheduled for
- fast-path ops, those callbacks may do more reads.
- If a sequence of fast-path ops do not end in a slow-path op,
- (e.g. for an @asynchronous long-poll request), we must add
- the error handler. This is done in `_run_callback` and `write`
- (since the write callback is optional so we can have a
- fast-path write with no `_run_callback`)
+        operation. If a sequence of fast-path ops does not end in a
+        slow-path op (e.g. for an @asynchronous long-poll request),
+ we must add the error handler.
+
+ TODO: reevaluate this now that callbacks are gone.
+
"""
if self.closed():
# connection has been closed, so there can be no future events
return
if self._state is None:
self._state = ioloop.IOLoop.ERROR | state
- with stack_context.NullContext():
- self.io_loop.add_handler(
- self.fileno(), self._handle_events, self._state)
+ self.io_loop.add_handler(self.fileno(), self._handle_events, self._state)
elif not self._state & state:
self._state = self._state | state
self.io_loop.update_handler(self.fileno(), self._state)
- def _is_connreset(self, exc):
- """Return true if exc is ECONNRESET or equivalent.
+ def _is_connreset(self, exc: BaseException) -> bool:
+ """Return ``True`` if exc is ECONNRESET or equivalent.
May be overridden in subclasses.
"""
- return (isinstance(exc, (socket.error, IOError)) and
- errno_from_exception(exc) in _ERRNO_CONNRESET)
+ return (
+ isinstance(exc, (socket.error, IOError))
+ and errno_from_exception(exc) in _ERRNO_CONNRESET
+ )
class IOStream(BaseIOStream):
@@ -1190,24 +1094,23 @@ class IOStream(BaseIOStream):
import tornado.iostream
import socket
- def send_request():
- stream.write(b"GET / HTTP/1.0\r\nHost: friendfeed.com\r\n\r\n")
- stream.read_until(b"\r\n\r\n", on_headers)
-
- def on_headers(data):
+ async def main():
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+ stream = tornado.iostream.IOStream(s)
+ await stream.connect(("friendfeed.com", 80))
+ await stream.write(b"GET / HTTP/1.0\r\nHost: friendfeed.com\r\n\r\n")
+ header_data = await stream.read_until(b"\r\n\r\n")
headers = {}
- for line in data.split(b"\r\n"):
- parts = line.split(b":")
- if len(parts) == 2:
- headers[parts[0].strip()] = parts[1].strip()
- stream.read_bytes(int(headers[b"Content-Length"]), on_body)
-
- def on_body(data):
- print(data)
+ for line in header_data.split(b"\r\n"):
+ parts = line.split(b":")
+ if len(parts) == 2:
+ headers[parts[0].strip()] = parts[1].strip()
+ body_data = await stream.read_bytes(int(headers[b"Content-Length"]))
+ print(body_data)
stream.close()
- tornado.ioloop.IOLoop.current().stop()
if __name__ == '__main__':
+ tornado.ioloop.IOLoop.current().run_sync(main)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
stream = tornado.iostream.IOStream(s)
stream.connect(("friendfeed.com", 80), send_request)
@@ -1217,43 +1120,42 @@ def on_body(data):
:hide:
"""
- def __init__(self, socket, *args, **kwargs):
+
+ def __init__(self, socket: socket.socket, *args: Any, **kwargs: Any) -> None:
self.socket = socket
self.socket.setblocking(False)
- super(IOStream, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
- def fileno(self):
+ def fileno(self) -> Union[int, ioloop._Selectable]:
return self.socket
- def close_fd(self):
+ def close_fd(self) -> None:
self.socket.close()
- self.socket = None
+ self.socket = None # type: ignore
- def get_fd_error(self):
- errno = self.socket.getsockopt(socket.SOL_SOCKET,
- socket.SO_ERROR)
+ def get_fd_error(self) -> Optional[Exception]:
+ errno = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
return socket.error(errno, os.strerror(errno))
- def read_from_fd(self, buf):
+ def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]:
try:
- return self.socket.recv_into(buf)
- except socket.error as e:
- if e.args[0] in _ERRNO_WOULDBLOCK:
- return None
- else:
- raise
+ return self.socket.recv_into(buf, len(buf))
+ except BlockingIOError:
+ return None
finally:
- buf = None
+ del buf
- def write_to_fd(self, data):
+ def write_to_fd(self, data: memoryview) -> int:
try:
- return self.socket.send(data)
+ return self.socket.send(data) # type: ignore
finally:
# Avoid keeping to data, which can be a memoryview.
# See https://github.com/tornadoweb/tornado/pull/2008
del data
- def connect(self, address, callback=None, server_hostname=None):
+ def connect(
+ self: _IOStreamType, address: Any, server_hostname: Optional[str] = None
+ ) -> "Future[_IOStreamType]":
"""Connects the socket to a remote address without blocking.
May only be called if the socket passed to the constructor was
@@ -1292,41 +1194,39 @@ class is recommended instead of calling this method directly.
suitably-configured `ssl.SSLContext` to the
`SSLIOStream` constructor to disable.
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed
- in Tornado 6.0. Use the returned `.Future` instead.
+ The ``callback`` argument was removed. Use the returned
+ `.Future` instead.
"""
self._connecting = True
- if callback is not None:
- warnings.warn("callback argument is deprecated, use returned Future instead",
- DeprecationWarning)
- self._connect_callback = stack_context.wrap(callback)
- future = None
- else:
- future = self._connect_future = Future()
+ future = Future() # type: Future[_IOStreamType]
+ self._connect_future = typing.cast("Future[IOStream]", future)
try:
self.socket.connect(address)
- except socket.error as e:
+ except BlockingIOError:
# In non-blocking mode we expect connect() to raise an
# exception with EINPROGRESS or EWOULDBLOCK.
- #
+ pass
+ except socket.error as e:
# On freebsd, other errors such as ECONNREFUSED may be
# returned immediately when attempting to connect to
# localhost, so handle them the same way as an error
# reported later in _handle_connect.
- if (errno_from_exception(e) not in _ERRNO_INPROGRESS and
- errno_from_exception(e) not in _ERRNO_WOULDBLOCK):
- if future is None:
- gen_log.warning("Connect error on fd %s: %s",
- self.socket.fileno(), e)
- self.close(exc_info=e)
- return future
+ if future is None:
+ gen_log.warning("Connect error on fd %s: %s", self.socket.fileno(), e)
+ self.close(exc_info=e)
+ return future
self._add_io_state(self.io_loop.WRITE)
return future
- def start_tls(self, server_side, ssl_options=None, server_hostname=None):
+ def start_tls(
+ self,
+ server_side: bool,
+ ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None,
+ server_hostname: Optional[str] = None,
+ ) -> Awaitable["SSLIOStream"]:
"""Convert this `IOStream` to an `SSLIOStream`.
This enables protocols that begin in clear-text mode and
@@ -1361,11 +1261,14 @@ def start_tls(self, server_side, ssl_options=None, server_hostname=None):
``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a
suitably-configured `ssl.SSLContext` to disable.
"""
- if (self._read_callback or self._read_future or
- self._write_callback or self._write_futures or
- self._connect_callback or self._connect_future or
- self._pending_callbacks or self._closed or
- self._read_buffer or self._write_buffer):
+ if (
+ self._read_future
+ or self._write_futures
+ or self._connect_future
+ or self._closed
+ or self._read_buffer
+ or self._write_buffer
+ ):
raise ValueError("IOStream is not idle; cannot convert to SSL")
if ssl_options is None:
if server_side:
@@ -1375,43 +1278,33 @@ def start_tls(self, server_side, ssl_options=None, server_hostname=None):
socket = self.socket
self.io_loop.remove_handler(socket)
- self.socket = None
- socket = ssl_wrap_socket(socket, ssl_options,
- server_hostname=server_hostname,
- server_side=server_side,
- do_handshake_on_connect=False)
+ self.socket = None # type: ignore
+ socket = ssl_wrap_socket(
+ socket,
+ ssl_options,
+ server_hostname=server_hostname,
+ server_side=server_side,
+ do_handshake_on_connect=False,
+ )
orig_close_callback = self._close_callback
self._close_callback = None
- future = Future()
+ future = Future() # type: Future[SSLIOStream]
ssl_stream = SSLIOStream(socket, ssl_options=ssl_options)
- # Wrap the original close callback so we can fail our Future as well.
- # If we had an "unwrap" counterpart to this method we would need
- # to restore the original callback after our Future resolves
- # so that repeated wrap/unwrap calls don't build up layers.
-
- def close_callback():
- if not future.done():
- # Note that unlike most Futures returned by IOStream,
- # this one passes the underlying error through directly
- # instead of wrapping everything in a StreamClosedError
- # with a real_error attribute. This is because once the
- # connection is established it's more helpful to raise
- # the SSLError directly than to hide it behind a
- # StreamClosedError (and the client is expecting SSL
- # issues rather than network issues since this method is
- # named start_tls).
- future.set_exception(ssl_stream.error or StreamClosedError())
- if orig_close_callback is not None:
- orig_close_callback()
- ssl_stream.set_close_callback(close_callback)
- ssl_stream._ssl_connect_callback = lambda: future.set_result(ssl_stream)
+ ssl_stream.set_close_callback(orig_close_callback)
+ ssl_stream._ssl_connect_future = future
ssl_stream.max_buffer_size = self.max_buffer_size
ssl_stream.read_chunk_size = self.read_chunk_size
return future
- def _handle_connect(self):
- err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+ def _handle_connect(self) -> None:
+ try:
+ err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+ except socket.error as e:
+ # Hurd doesn't allow SO_ERROR for loopback sockets because all
+ # errors for such sockets are reported synchronously.
+ if errno_from_exception(e) == errno.ENOPROTOOPT:
+ err = 0
if err != 0:
self.error = socket.error(err, os.strerror(err))
# IOLoop implementations may vary: some of them return
@@ -1419,30 +1312,32 @@ def _handle_connect(self):
# in that case a connection failure would be handled by the
# error path in _handle_events instead of here.
if self._connect_future is None:
- gen_log.warning("Connect error on fd %s: %s",
- self.socket.fileno(), errno.errorcode[err])
+ gen_log.warning(
+ "Connect error on fd %s: %s",
+ self.socket.fileno(),
+ errno.errorcode[err],
+ )
self.close()
return
- if self._connect_callback is not None:
- callback = self._connect_callback
- self._connect_callback = None
- self._run_callback(callback)
if self._connect_future is not None:
future = self._connect_future
self._connect_future = None
- future.set_result(self)
+ future_set_result_unless_cancelled(future, self)
self._connecting = False
- def set_nodelay(self, value):
- if (self.socket is not None and
- self.socket.family in (socket.AF_INET, socket.AF_INET6)):
+ def set_nodelay(self, value: bool) -> None:
+ if self.socket is not None and self.socket.family in (
+ socket.AF_INET,
+ socket.AF_INET6,
+ ):
try:
- self.socket.setsockopt(socket.IPPROTO_TCP,
- socket.TCP_NODELAY, 1 if value else 0)
+ self.socket.setsockopt(
+ socket.IPPROTO_TCP, socket.TCP_NODELAY, 1 if value else 0
+ )
except socket.error as e:
# Sometimes setsockopt will fail if the socket is closed
# at the wrong time. This can happen with HTTPServer
- # resetting the value to false between requests.
+ # resetting the value to ``False`` between requests.
if e.errno != errno.EINVAL and not self._is_connreset(e):
raise
@@ -1458,18 +1353,20 @@ class SSLIOStream(IOStream):
before constructing the `SSLIOStream`. Unconnected sockets will be
wrapped when `IOStream.connect` is finished.
"""
- def __init__(self, *args, **kwargs):
+
+ socket = None # type: ssl.SSLSocket
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
"""The ``ssl_options`` keyword argument may either be an
`ssl.SSLContext` object or a dictionary of keywords arguments
for `ssl.wrap_socket`
"""
- self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults)
- super(SSLIOStream, self).__init__(*args, **kwargs)
+ self._ssl_options = kwargs.pop("ssl_options", _client_ssl_defaults)
+ super().__init__(*args, **kwargs)
self._ssl_accepting = True
self._handshake_reading = False
self._handshake_writing = False
- self._ssl_connect_callback = None
- self._server_hostname = None
+ self._server_hostname = None # type: Optional[str]
# If the socket is already connected, attempt to start the handshake.
try:
@@ -1482,13 +1379,13 @@ def __init__(self, *args, **kwargs):
# _handle_events.
self._add_io_state(self.io_loop.WRITE)
- def reading(self):
- return self._handshake_reading or super(SSLIOStream, self).reading()
+ def reading(self) -> bool:
+ return self._handshake_reading or super().reading()
- def writing(self):
- return self._handshake_writing or super(SSLIOStream, self).writing()
+ def writing(self) -> bool:
+ return self._handshake_writing or super().writing()
- def _do_ssl_handshake(self):
+ def _do_ssl_handshake(self) -> None:
# Based on code from test_ssl.py in the python stdlib
try:
self._handshake_reading = False
@@ -1501,25 +1398,36 @@ def _do_ssl_handshake(self):
elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
self._handshake_writing = True
return
- elif err.args[0] in (ssl.SSL_ERROR_EOF,
- ssl.SSL_ERROR_ZERO_RETURN):
+ elif err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN):
return self.close(exc_info=err)
elif err.args[0] == ssl.SSL_ERROR_SSL:
try:
peer = self.socket.getpeername()
except Exception:
- peer = '(not connected)'
- gen_log.warning("SSL Error on %s %s: %s",
- self.socket.fileno(), peer, err)
+ peer = "(not connected)"
+ gen_log.warning(
+ "SSL Error on %s %s: %s", self.socket.fileno(), peer, err
+ )
return self.close(exc_info=err)
raise
+ except ssl.CertificateError as err:
+ # CertificateError can happen during handshake (hostname
+ # verification) and should be passed to user. Starting
+ # in Python 3.7, this error is a subclass of SSLError
+ # and will be handled by the previous block instead.
+ return self.close(exc_info=err)
except socket.error as err:
# Some port scans (e.g. nmap in -sT mode) have been known
# to cause do_handshake to raise EBADF and ENOTCONN, so make
# those errors quiet as well.
# https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0
- if (self._is_connreset(err) or
- err.args[0] in (errno.EBADF, errno.ENOTCONN)):
+ # Errno 0 is also possible in some cases (nc -z).
+ # https://github.com/tornadoweb/tornado/issues/2504
+ if self._is_connreset(err) or err.args[0] in (
+ 0,
+ errno.EBADF,
+ errno.ENOTCONN,
+ ):
return self.close(exc_info=err)
raise
except AttributeError as err:
@@ -1532,20 +1440,16 @@ def _do_ssl_handshake(self):
if not self._verify_cert(self.socket.getpeercert()):
self.close()
return
- self._run_ssl_connect_callback()
+ self._finish_ssl_connect()
- def _run_ssl_connect_callback(self):
- if self._ssl_connect_callback is not None:
- callback = self._ssl_connect_callback
- self._ssl_connect_callback = None
- self._run_callback(callback)
+ def _finish_ssl_connect(self) -> None:
if self._ssl_connect_future is not None:
future = self._ssl_connect_future
self._ssl_connect_future = None
- future.set_result(self)
+ future_set_result_unless_cancelled(future, self)
- def _verify_cert(self, peercert):
- """Returns True if peercert is valid according to the configured
+ def _verify_cert(self, peercert: Any) -> bool:
+ """Returns ``True`` if peercert is valid according to the configured
validation mode and hostname.
The ssl handshake already tested the certificate for a valid
@@ -1553,7 +1457,7 @@ def _verify_cert(self, peercert):
the hostname.
"""
if isinstance(self._ssl_options, dict):
- verify_mode = self._ssl_options.get('cert_reqs', ssl.CERT_NONE)
+ verify_mode = self._ssl_options.get("cert_reqs", ssl.CERT_NONE)
elif isinstance(self._ssl_options, ssl.SSLContext):
verify_mode = self._ssl_options.verify_mode
assert verify_mode in (ssl.CERT_NONE, ssl.CERT_REQUIRED, ssl.CERT_OPTIONAL)
@@ -1571,32 +1475,40 @@ def _verify_cert(self, peercert):
else:
return True
- def _handle_read(self):
+ def _handle_read(self) -> None:
if self._ssl_accepting:
self._do_ssl_handshake()
return
- super(SSLIOStream, self)._handle_read()
+ super()._handle_read()
- def _handle_write(self):
+ def _handle_write(self) -> None:
if self._ssl_accepting:
self._do_ssl_handshake()
return
- super(SSLIOStream, self)._handle_write()
+ super()._handle_write()
- def connect(self, address, callback=None, server_hostname=None):
+ def connect(
+ self, address: Tuple, server_hostname: Optional[str] = None
+ ) -> "Future[SSLIOStream]":
self._server_hostname = server_hostname
# Ignore the result of connect(). If it fails,
# wait_for_handshake will raise an error too. This is
# necessary for the old semantics of the connect callback
# (which takes no arguments). In 6.0 this can be refactored to
# be a regular coroutine.
- fut = super(SSLIOStream, self).connect(address)
+ # TODO: This is trickier than it looks, since if write()
+ # is called with a connect() pending, we want the connect
+ # to resolve before the write. Or do we care about this?
+ # (There's a test for it, but I think in practice users
+ # either wait for the connect before performing a write or
+ # they don't care about the connect Future at all)
+ fut = super().connect(address)
fut.add_done_callback(lambda f: f.exception())
- return self.wait_for_handshake(callback)
+ return self.wait_for_handshake()
- def _handle_connect(self):
+ def _handle_connect(self) -> None:
# Call the superclass method to check for errors.
- super(SSLIOStream, self)._handle_connect()
+ super()._handle_connect()
if self.closed():
return
# When the connection is complete, wrap the socket for SSL
@@ -1611,13 +1523,17 @@ def _handle_connect(self):
# wrap_socket().
self.io_loop.remove_handler(self.socket)
old_state = self._state
+ assert old_state is not None
self._state = None
- self.socket = ssl_wrap_socket(self.socket, self._ssl_options,
- server_hostname=self._server_hostname,
- do_handshake_on_connect=False)
+ self.socket = ssl_wrap_socket(
+ self.socket,
+ self._ssl_options,
+ server_hostname=self._server_hostname,
+ do_handshake_on_connect=False,
+ )
self._add_io_state(old_state)
- def wait_for_handshake(self, callback=None):
+ def wait_for_handshake(self) -> "Future[SSLIOStream]":
"""Wait for the initial SSL handshake to complete.
If a ``callback`` is given, it will be called with no
@@ -1636,29 +1552,22 @@ def wait_for_handshake(self, callback=None):
.. versionadded:: 4.2
- .. deprecated:: 5.1
+ .. versionchanged:: 6.0
- The ``callback`` argument is deprecated and will be removed
- in Tornado 6.0. Use the returned `.Future` instead.
+ The ``callback`` argument was removed. Use the returned
+ `.Future` instead.
"""
- if (self._ssl_connect_callback is not None or
- self._ssl_connect_future is not None):
+ if self._ssl_connect_future is not None:
raise RuntimeError("Already waiting")
- if callback is not None:
- warnings.warn("callback argument is deprecated, use returned Future instead",
- DeprecationWarning)
- self._ssl_connect_callback = stack_context.wrap(callback)
- future = None
- else:
- future = self._ssl_connect_future = Future()
+ future = self._ssl_connect_future = Future()
if not self._ssl_accepting:
- self._run_ssl_connect_callback()
+ self._finish_ssl_connect()
return future
- def write_to_fd(self, data):
+ def write_to_fd(self, data: memoryview) -> int:
try:
- return self.socket.send(data)
+ return self.socket.send(data) # type: ignore
except ssl.SSLError as e:
if e.args[0] == ssl.SSL_ERROR_WANT_WRITE:
# In Python 3.5+, SSLSocket.send raises a WANT_WRITE error if
@@ -1674,15 +1583,20 @@ def write_to_fd(self, data):
# See https://github.com/tornadoweb/tornado/pull/2008
del data
- def read_from_fd(self, buf):
+ def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]:
try:
if self._ssl_accepting:
# If the handshake hasn't finished yet, there can't be anything
# to read (attempting to read may or may not raise an exception
# depending on the SSL version)
return None
+            # clip buffer size at 1GB since SSL sockets only support up to 2GB
+ # this change in behaviour is transparent, since the function is
+ # already expected to (possibly) read less than the provided buffer
+ if len(buf) >> 30:
+ buf = memoryview(buf)[: 1 << 30]
try:
- return self.socket.recv_into(buf)
+ return self.socket.recv_into(buf, len(buf))
except ssl.SSLError as e:
# SSLError is a subclass of socket.error, so this except
# block must come first.
@@ -1690,18 +1604,15 @@ def read_from_fd(self, buf):
return None
else:
raise
- except socket.error as e:
- if e.args[0] in _ERRNO_WOULDBLOCK:
- return None
- else:
- raise
+ except BlockingIOError:
+ return None
finally:
- buf = None
+ del buf
- def _is_connreset(self, e):
+ def _is_connreset(self, e: BaseException) -> bool:
if isinstance(e, ssl.SSLError) and e.args[0] == ssl.SSL_ERROR_EOF:
return True
- return super(SSLIOStream, self)._is_connreset(e)
+ return super()._is_connreset(e)
class PipeIOStream(BaseIOStream):
@@ -1711,30 +1622,40 @@ class PipeIOStream(BaseIOStream):
by `os.pipe`) rather than an open file object. Pipes are generally
one-way, so a `PipeIOStream` can be used for reading or writing but not
both.
+
+ ``PipeIOStream`` is only available on Unix-based platforms.
"""
- def __init__(self, fd, *args, **kwargs):
+
+ def __init__(self, fd: int, *args: Any, **kwargs: Any) -> None:
self.fd = fd
self._fio = io.FileIO(self.fd, "r+")
- _set_nonblocking(fd)
- super(PipeIOStream, self).__init__(*args, **kwargs)
-
- def fileno(self):
+ if sys.platform == "win32":
+ # The form and placement of this assertion is important to mypy.
+ # A plain assert statement isn't recognized here. If the assertion
+ # were earlier it would worry that the attributes of self aren't
+ # set on windows. If it were missing it would complain about
+ # the absence of the set_blocking function.
+ raise AssertionError("PipeIOStream is not supported on Windows")
+ os.set_blocking(fd, False)
+ super().__init__(*args, **kwargs)
+
+ def fileno(self) -> int:
return self.fd
- def close_fd(self):
+ def close_fd(self) -> None:
self._fio.close()
- def write_to_fd(self, data):
+ def write_to_fd(self, data: memoryview) -> int:
try:
- return os.write(self.fd, data)
+ return os.write(self.fd, data) # type: ignore
finally:
# Avoid keeping to data, which can be a memoryview.
# See https://github.com/tornadoweb/tornado/pull/2008
del data
- def read_from_fd(self, buf):
+ def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]:
try:
- return self._fio.readinto(buf)
+ return self._fio.readinto(buf) # type: ignore
except (IOError, OSError) as e:
if errno_from_exception(e) == errno.EBADF:
# If the writing half of a pipe is closed, select will
@@ -1744,9 +1665,10 @@ def read_from_fd(self, buf):
else:
raise
finally:
- buf = None
+ del buf
-def doctests():
+def doctests() -> Any:
import doctest
+
return doctest.DocTestSuite()
diff --git a/tornado/locale.py b/tornado/locale.py
index d45172f3b8..533ce4d41c 100644
--- a/tornado/locale.py
+++ b/tornado/locale.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@@ -39,30 +37,29 @@
the `Locale.translate` method will simply return the original string.
"""
-from __future__ import absolute_import, division, print_function
-
import codecs
import csv
import datetime
-from io import BytesIO
-import numbers
+import gettext
+import glob
import os
import re
from tornado import escape
from tornado.log import gen_log
-from tornado.util import PY3
from tornado._locale_data import LOCALE_NAMES
+from typing import Iterable, Any, Union, Dict, Optional
+
_default_locale = "en_US"
-_translations = {} # type: dict
+_translations = {} # type: Dict[str, Any]
_supported_locales = frozenset([_default_locale])
_use_gettext = False
CONTEXT_SEPARATOR = "\x04"
-def get(*locale_codes):
+def get(*locale_codes: str) -> "Locale":
"""Returns the closest match for the given locale codes.
We iterate over all given locale codes in order. If we have a tight
@@ -76,7 +73,7 @@ def get(*locale_codes):
return Locale.get_closest(*locale_codes)
-def set_default_locale(code):
+def set_default_locale(code: str) -> None:
"""Sets the default locale.
The default locale is assumed to be the language used for all strings
@@ -90,7 +87,7 @@ def set_default_locale(code):
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
-def load_translations(directory, encoding=None):
+def load_translations(directory: str, encoding: Optional[str] = None) -> None:
"""Loads translations from CSV files in a directory.
Translations are strings with optional Python-style named placeholders
@@ -133,54 +130,51 @@ def load_translations(directory, encoding=None):
continue
locale, extension = path.split(".")
if not re.match("[a-z]+(_[A-Z]+)?$", locale):
- gen_log.error("Unrecognized locale %r (path: %s)", locale,
- os.path.join(directory, path))
+ gen_log.error(
+ "Unrecognized locale %r (path: %s)",
+ locale,
+ os.path.join(directory, path),
+ )
continue
full_path = os.path.join(directory, path)
if encoding is None:
# Try to autodetect encoding based on the BOM.
- with open(full_path, 'rb') as f:
- data = f.read(len(codecs.BOM_UTF16_LE))
+ with open(full_path, "rb") as bf:
+ data = bf.read(len(codecs.BOM_UTF16_LE))
if data in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
- encoding = 'utf-16'
+ encoding = "utf-16"
else:
# utf-8-sig is "utf-8 with optional BOM". It's discouraged
# in most cases but is common with CSV files because Excel
# cannot read utf-8 files without a BOM.
- encoding = 'utf-8-sig'
- if PY3:
- # python 3: csv.reader requires a file open in text mode.
- # Force utf8 to avoid dependence on $LANG environment variable.
- f = open(full_path, "r", encoding=encoding)
- else:
- # python 2: csv can only handle byte strings (in ascii-compatible
- # encodings), which we decode below. Transcode everything into
- # utf8 before passing it to csv.reader.
- f = BytesIO()
- with codecs.open(full_path, "r", encoding=encoding) as infile:
- f.write(escape.utf8(infile.read()))
- f.seek(0)
- _translations[locale] = {}
- for i, row in enumerate(csv.reader(f)):
- if not row or len(row) < 2:
- continue
- row = [escape.to_unicode(c).strip() for c in row]
- english, translation = row[:2]
- if len(row) > 2:
- plural = row[2] or "unknown"
- else:
- plural = "unknown"
- if plural not in ("plural", "singular", "unknown"):
- gen_log.error("Unrecognized plural indicator %r in %s line %d",
- plural, path, i + 1)
- continue
- _translations[locale].setdefault(plural, {})[english] = translation
- f.close()
+ encoding = "utf-8-sig"
+ # python 3: csv.reader requires a file open in text mode.
+ # Specify an encoding to avoid dependence on $LANG environment variable.
+ with open(full_path, encoding=encoding) as f:
+ _translations[locale] = {}
+ for i, row in enumerate(csv.reader(f)):
+ if not row or len(row) < 2:
+ continue
+ row = [escape.to_unicode(c).strip() for c in row]
+ english, translation = row[:2]
+ if len(row) > 2:
+ plural = row[2] or "unknown"
+ else:
+ plural = "unknown"
+ if plural not in ("plural", "singular", "unknown"):
+ gen_log.error(
+ "Unrecognized plural indicator %r in %s line %d",
+ plural,
+ path,
+ i + 1,
+ )
+ continue
+ _translations[locale].setdefault(plural, {})[english] = translation
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
-def load_gettext_translations(directory, domain):
+def load_gettext_translations(directory: str, domain: str) -> None:
"""Loads translations from `gettext`'s locale tree
Locale tree is similar to system's ``/usr/share/locale``, like::
@@ -201,20 +195,19 @@ def load_gettext_translations(directory, domain):
msgfmt mydomain.po -o {directory}/pt_BR/LC_MESSAGES/mydomain.mo
"""
- import gettext
global _translations
global _supported_locales
global _use_gettext
_translations = {}
- for lang in os.listdir(directory):
- if lang.startswith('.'):
- continue # skip .svn, etc
- if os.path.isfile(os.path.join(directory, lang)):
- continue
+
+ for filename in glob.glob(
+ os.path.join(directory, "*", "LC_MESSAGES", domain + ".mo")
+ ):
+ lang = os.path.basename(os.path.dirname(os.path.dirname(filename)))
try:
- os.stat(os.path.join(directory, lang, "LC_MESSAGES", domain + ".mo"))
- _translations[lang] = gettext.translation(domain, directory,
- languages=[lang])
+ _translations[lang] = gettext.translation(
+ domain, directory, languages=[lang]
+ )
except Exception as e:
gen_log.error("Cannot load translation for '%s': %s", lang, str(e))
continue
@@ -223,7 +216,7 @@ def load_gettext_translations(directory, domain):
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
-def get_supported_locales():
+def get_supported_locales() -> Iterable[str]:
"""Returns a list of all the supported locale codes."""
return _supported_locales
@@ -234,8 +227,11 @@ class Locale(object):
After calling one of `load_translations` or `load_gettext_translations`,
call `get` or `get_closest` to get a Locale object.
"""
+
+ _cache = {} # type: Dict[str, Locale]
+
@classmethod
- def get_closest(cls, *locale_codes):
+ def get_closest(cls, *locale_codes: str) -> "Locale":
"""Returns the closest match for the given locale code."""
for code in locale_codes:
if not code:
@@ -253,18 +249,16 @@ def get_closest(cls, *locale_codes):
return cls.get(_default_locale)
@classmethod
- def get(cls, code):
+ def get(cls, code: str) -> "Locale":
"""Returns the Locale for the given locale code.
If it is not supported, we raise an exception.
"""
- if not hasattr(cls, "_cache"):
- cls._cache = {}
if code not in cls._cache:
assert code in _supported_locales
translations = _translations.get(code, None)
if translations is None:
- locale = CSVLocale(code, {})
+ locale = CSVLocale(code, {}) # type: Locale
elif _use_gettext:
locale = GettextLocale(code, translations)
else:
@@ -272,7 +266,7 @@ def get(cls, code):
cls._cache[code] = locale
return cls._cache[code]
- def __init__(self, code, translations):
+ def __init__(self, code: str) -> None:
self.code = code
self.name = LOCALE_NAMES.get(code, {}).get("name", u"Unknown")
self.rtl = False
@@ -280,19 +274,39 @@ def __init__(self, code, translations):
if self.code.startswith(prefix):
self.rtl = True
break
- self.translations = translations
# Initialize strings for date formatting
_ = self.translate
self._months = [
- _("January"), _("February"), _("March"), _("April"),
- _("May"), _("June"), _("July"), _("August"),
- _("September"), _("October"), _("November"), _("December")]
+ _("January"),
+ _("February"),
+ _("March"),
+ _("April"),
+ _("May"),
+ _("June"),
+ _("July"),
+ _("August"),
+ _("September"),
+ _("October"),
+ _("November"),
+ _("December"),
+ ]
self._weekdays = [
- _("Monday"), _("Tuesday"), _("Wednesday"), _("Thursday"),
- _("Friday"), _("Saturday"), _("Sunday")]
-
- def translate(self, message, plural_message=None, count=None):
+ _("Monday"),
+ _("Tuesday"),
+ _("Wednesday"),
+ _("Thursday"),
+ _("Friday"),
+ _("Saturday"),
+ _("Sunday"),
+ ]
+
+ def translate(
+ self,
+ message: str,
+ plural_message: Optional[str] = None,
+ count: Optional[int] = None,
+ ) -> str:
"""Returns the translation for the given message for this locale.
If ``plural_message`` is given, you must also provide
@@ -302,11 +316,23 @@ def translate(self, message, plural_message=None, count=None):
"""
raise NotImplementedError()
- def pgettext(self, context, message, plural_message=None, count=None):
+ def pgettext(
+ self,
+ context: str,
+ message: str,
+ plural_message: Optional[str] = None,
+ count: Optional[int] = None,
+ ) -> str:
raise NotImplementedError()
- def format_date(self, date, gmt_offset=0, relative=True, shorter=False,
- full_format=False):
+ def format_date(
+ self,
+ date: Union[int, float, datetime.datetime],
+ gmt_offset: int = 0,
+ relative: bool = True,
+ shorter: bool = False,
+ full_format: bool = False,
+ ) -> str:
"""Formats the given date (which should be GMT).
By default, we return a relative time (e.g., "2 minutes ago"). You
@@ -318,7 +344,7 @@ def format_date(self, date, gmt_offset=0, relative=True, shorter=False,
This method is primarily intended for dates in the past.
For dates in the future, we fall back to full format.
"""
- if isinstance(date, numbers.Real):
+ if isinstance(date, (int, float)):
date = datetime.datetime.utcfromtimestamp(date)
now = datetime.datetime.utcnow()
if date > now:
@@ -342,56 +368,66 @@ def format_date(self, date, gmt_offset=0, relative=True, shorter=False,
if not full_format:
if relative and days == 0:
if seconds < 50:
- return _("1 second ago", "%(seconds)d seconds ago",
- seconds) % {"seconds": seconds}
+ return _("1 second ago", "%(seconds)d seconds ago", seconds) % {
+ "seconds": seconds
+ }
if seconds < 50 * 60:
minutes = round(seconds / 60.0)
- return _("1 minute ago", "%(minutes)d minutes ago",
- minutes) % {"minutes": minutes}
+ return _("1 minute ago", "%(minutes)d minutes ago", minutes) % {
+ "minutes": minutes
+ }
hours = round(seconds / (60.0 * 60))
- return _("1 hour ago", "%(hours)d hours ago",
- hours) % {"hours": hours}
+ return _("1 hour ago", "%(hours)d hours ago", hours) % {"hours": hours}
if days == 0:
format = _("%(time)s")
- elif days == 1 and local_date.day == local_yesterday.day and \
- relative:
- format = _("yesterday") if shorter else \
- _("yesterday at %(time)s")
+ elif days == 1 and local_date.day == local_yesterday.day and relative:
+ format = _("yesterday") if shorter else _("yesterday at %(time)s")
elif days < 5:
- format = _("%(weekday)s") if shorter else \
- _("%(weekday)s at %(time)s")
+ format = _("%(weekday)s") if shorter else _("%(weekday)s at %(time)s")
elif days < 334: # 11mo, since confusing for same month last year
- format = _("%(month_name)s %(day)s") if shorter else \
- _("%(month_name)s %(day)s at %(time)s")
+ format = (
+ _("%(month_name)s %(day)s")
+ if shorter
+ else _("%(month_name)s %(day)s at %(time)s")
+ )
if format is None:
- format = _("%(month_name)s %(day)s, %(year)s") if shorter else \
- _("%(month_name)s %(day)s, %(year)s at %(time)s")
+ format = (
+ _("%(month_name)s %(day)s, %(year)s")
+ if shorter
+ else _("%(month_name)s %(day)s, %(year)s at %(time)s")
+ )
tfhour_clock = self.code not in ("en", "en_US", "zh_CN")
if tfhour_clock:
str_time = "%d:%02d" % (local_date.hour, local_date.minute)
elif self.code == "zh_CN":
str_time = "%s%d:%02d" % (
- (u'\u4e0a\u5348', u'\u4e0b\u5348')[local_date.hour >= 12],
- local_date.hour % 12 or 12, local_date.minute)
+ (u"\u4e0a\u5348", u"\u4e0b\u5348")[local_date.hour >= 12],
+ local_date.hour % 12 or 12,
+ local_date.minute,
+ )
else:
str_time = "%d:%02d %s" % (
- local_date.hour % 12 or 12, local_date.minute,
- ("am", "pm")[local_date.hour >= 12])
+ local_date.hour % 12 or 12,
+ local_date.minute,
+ ("am", "pm")[local_date.hour >= 12],
+ )
return format % {
"month_name": self._months[local_date.month - 1],
"weekday": self._weekdays[local_date.weekday()],
"day": str(local_date.day),
"year": str(local_date.year),
- "time": str_time
+ "time": str_time,
}
- def format_day(self, date, gmt_offset=0, dow=True):
+ def format_day(
+ self, date: datetime.datetime, gmt_offset: int = 0, dow: bool = True
+ ) -> bool:
"""Formats the given date as a day of week.
Example: "Monday, January 22". You can remove the day of week with
@@ -411,7 +447,7 @@ def format_day(self, date, gmt_offset=0, dow=True):
"day": str(local_date.day),
}
- def list(self, parts):
+ def list(self, parts: Any) -> str:
"""Returns a comma-separated list for the given list of parts.
The format is, e.g., "A, B and C", "A and B" or just "A" for lists
@@ -422,27 +458,37 @@ def list(self, parts):
return ""
if len(parts) == 1:
return parts[0]
- comma = u' \u0648 ' if self.code.startswith("fa") else u", "
+ comma = u" \u0648 " if self.code.startswith("fa") else u", "
return _("%(commas)s and %(last)s") % {
"commas": comma.join(parts[:-1]),
"last": parts[len(parts) - 1],
}
- def friendly_number(self, value):
+ def friendly_number(self, value: int) -> str:
"""Returns a comma-separated number for the given integer."""
if self.code not in ("en", "en_US"):
return str(value)
- value = str(value)
+ s = str(value)
parts = []
- while value:
- parts.append(value[-3:])
- value = value[:-3]
+ while s:
+ parts.append(s[-3:])
+ s = s[:-3]
return ",".join(reversed(parts))
class CSVLocale(Locale):
"""Locale implementation using tornado's CSV translation format."""
- def translate(self, message, plural_message=None, count=None):
+
+ def __init__(self, code: str, translations: Dict[str, Dict[str, str]]) -> None:
+ self.translations = translations
+ super().__init__(code)
+
+ def translate(
+ self,
+ message: str,
+ plural_message: Optional[str] = None,
+ count: Optional[int] = None,
+ ) -> str:
if plural_message is not None:
assert count is not None
if count != 1:
@@ -454,35 +500,47 @@ def translate(self, message, plural_message=None, count=None):
message_dict = self.translations.get("unknown", {})
return message_dict.get(message, message)
- def pgettext(self, context, message, plural_message=None, count=None):
+ def pgettext(
+ self,
+ context: str,
+ message: str,
+ plural_message: Optional[str] = None,
+ count: Optional[int] = None,
+ ) -> str:
if self.translations:
- gen_log.warning('pgettext is not supported by CSVLocale')
+ gen_log.warning("pgettext is not supported by CSVLocale")
return self.translate(message, plural_message, count)
class GettextLocale(Locale):
"""Locale implementation using the `gettext` module."""
- def __init__(self, code, translations):
- try:
- # python 2
- self.ngettext = translations.ungettext
- self.gettext = translations.ugettext
- except AttributeError:
- # python 3
- self.ngettext = translations.ngettext
- self.gettext = translations.gettext
+
+ def __init__(self, code: str, translations: gettext.NullTranslations) -> None:
+ self.ngettext = translations.ngettext
+ self.gettext = translations.gettext
# self.gettext must exist before __init__ is called, since it
# calls into self.translate
- super(GettextLocale, self).__init__(code, translations)
-
- def translate(self, message, plural_message=None, count=None):
+ super().__init__(code)
+
+ def translate(
+ self,
+ message: str,
+ plural_message: Optional[str] = None,
+ count: Optional[int] = None,
+ ) -> str:
if plural_message is not None:
assert count is not None
return self.ngettext(message, plural_message, count)
else:
return self.gettext(message)
- def pgettext(self, context, message, plural_message=None, count=None):
+ def pgettext(
+ self,
+ context: str,
+ message: str,
+ plural_message: Optional[str] = None,
+ count: Optional[int] = None,
+ ) -> str:
"""Allows to set context for translation, accepts plural forms.
Usage example::
@@ -504,9 +562,11 @@ def pgettext(self, context, message, plural_message=None, count=None):
"""
if plural_message is not None:
assert count is not None
- msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message),
- "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message),
- count)
+ msgs_with_ctxt = (
+ "%s%s%s" % (context, CONTEXT_SEPARATOR, message),
+ "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message),
+ count,
+ )
result = self.ngettext(*msgs_with_ctxt)
if CONTEXT_SEPARATOR in result:
# Translation not found
diff --git a/tornado/locks.py b/tornado/locks.py
index 94adb322e1..29b6b41299 100644
--- a/tornado/locks.py
+++ b/tornado/locks.py
@@ -12,15 +12,20 @@
# License for the specific language governing permissions and limitations
# under the License.
-from __future__ import absolute_import, division, print_function
-
import collections
-from concurrent.futures import CancelledError
+import datetime
+import types
from tornado import gen, ioloop
from tornado.concurrent import Future, future_set_result_unless_cancelled
-__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
+from typing import Union, Optional, Type, Any, Awaitable
+import typing
+
+if typing.TYPE_CHECKING:
+ from typing import Deque, Set # noqa: F401
+
+__all__ = ["Condition", "Event", "Semaphore", "BoundedSemaphore", "Lock"]
class _TimeoutGarbageCollector(object):
@@ -32,17 +37,17 @@ class _TimeoutGarbageCollector(object):
yield condition.wait(short_timeout)
print('looping....')
"""
- def __init__(self):
- self._waiters = collections.deque() # Futures.
+
+ def __init__(self) -> None:
+ self._waiters = collections.deque() # type: Deque[Future]
self._timeouts = 0
- def _garbage_collect(self):
+ def _garbage_collect(self) -> None:
# Occasionally clear timed-out waiters.
self._timeouts += 1
if self._timeouts > 100:
self._timeouts = 0
- self._waiters = collections.deque(
- w for w in self._waiters if not w.done())
+ self._waiters = collections.deque(w for w in self._waiters if not w.done())
class Condition(_TimeoutGarbageCollector):
@@ -61,22 +66,19 @@ class Condition(_TimeoutGarbageCollector):
condition = Condition()
- @gen.coroutine
- def waiter():
+ async def waiter():
print("I'll wait right here")
- yield condition.wait() # Yield a Future.
+ await condition.wait()
print("I'm done waiting")
- @gen.coroutine
- def notifier():
+ async def notifier():
print("About to notify")
condition.notify()
print("Done notifying")
- @gen.coroutine
- def runner():
- # Yield two Futures; wait for waiter() and notifier() to finish.
- yield [waiter(), notifier()]
+ async def runner():
+ # Wait for waiter() and notifier() in parallel
+ await gen.multi([waiter(), notifier()])
IOLoop.current().run_sync(runner)
@@ -93,12 +95,12 @@ def runner():
io_loop = IOLoop.current()
# Wait up to 1 second for a notification.
- yield condition.wait(timeout=io_loop.time() + 1)
+ await condition.wait(timeout=io_loop.time() + 1)
...or a `datetime.timedelta` for a timeout relative to the current time::
# Wait up to 1 second.
- yield condition.wait(timeout=datetime.timedelta(seconds=1))
+ await condition.wait(timeout=datetime.timedelta(seconds=1))
The method returns False if there's no notification before the deadline.
@@ -108,36 +110,39 @@ def runner():
next iteration of the `.IOLoop`.
"""
- def __init__(self):
- super(Condition, self).__init__()
+ def __init__(self) -> None:
+ super().__init__()
self.io_loop = ioloop.IOLoop.current()
- def __repr__(self):
- result = '<%s' % (self.__class__.__name__, )
+ def __repr__(self) -> str:
+ result = "<%s" % (self.__class__.__name__,)
if self._waiters:
- result += ' waiters[%s]' % len(self._waiters)
- return result + '>'
+ result += " waiters[%s]" % len(self._waiters)
+ return result + ">"
- def wait(self, timeout=None):
+ def wait(
+ self, timeout: Optional[Union[float, datetime.timedelta]] = None
+ ) -> Awaitable[bool]:
"""Wait for `.notify`.
Returns a `.Future` that resolves ``True`` if the condition is notified,
or ``False`` after a timeout.
"""
- waiter = Future()
+ waiter = Future() # type: Future[bool]
self._waiters.append(waiter)
if timeout:
- def on_timeout():
+
+ def on_timeout() -> None:
if not waiter.done():
future_set_result_unless_cancelled(waiter, False)
self._garbage_collect()
+
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
- waiter.add_done_callback(
- lambda _: io_loop.remove_timeout(timeout_handle))
+ waiter.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle))
return waiter
- def notify(self, n=1):
+ def notify(self, n: int = 1) -> None:
"""Wake ``n`` waiters."""
waiters = [] # Waiters we plan to run right now.
while n and self._waiters:
@@ -149,7 +154,7 @@ def notify(self, n=1):
for waiter in waiters:
future_set_result_unless_cancelled(waiter, True)
- def notify_all(self):
+ def notify_all(self) -> None:
"""Wake all waiters."""
self.notify(len(self._waiters))
@@ -170,22 +175,19 @@ class Event(object):
event = Event()
- @gen.coroutine
- def waiter():
+ async def waiter():
print("Waiting for event")
- yield event.wait()
+ await event.wait()
print("Not waiting this time")
- yield event.wait()
+ await event.wait()
print("Done")
- @gen.coroutine
- def setter():
+ async def setter():
print("About to set the event")
event.set()
- @gen.coroutine
- def runner():
- yield [waiter(), setter()]
+ async def runner():
+ await gen.multi([waiter(), setter()])
IOLoop.current().run_sync(runner)
@@ -196,19 +198,22 @@ def runner():
Not waiting this time
Done
"""
- def __init__(self):
+
+ def __init__(self) -> None:
self._value = False
- self._waiters = set()
+ self._waiters = set() # type: Set[Future[None]]
- def __repr__(self):
- return '<%s %s>' % (
- self.__class__.__name__, 'set' if self.is_set() else 'clear')
+ def __repr__(self) -> str:
+ return "<%s %s>" % (
+ self.__class__.__name__,
+ "set" if self.is_set() else "clear",
+ )
- def is_set(self):
+ def is_set(self) -> bool:
"""Return ``True`` if the internal flag is true."""
return self._value
- def set(self):
+ def set(self) -> None:
"""Set the internal flag to ``True``. All waiters are awakened.
Calling `.wait` once the flag is set will not block.
@@ -220,20 +225,22 @@ def set(self):
if not fut.done():
fut.set_result(None)
- def clear(self):
+ def clear(self) -> None:
"""Reset the internal flag to ``False``.
Calls to `.wait` will block until `.set` is called.
"""
self._value = False
- def wait(self, timeout=None):
+ def wait(
+ self, timeout: Optional[Union[float, datetime.timedelta]] = None
+ ) -> Awaitable[None]:
"""Block until the internal flag is true.
- Returns a Future, which raises `tornado.util.TimeoutError` after a
+ Returns an awaitable, which raises `tornado.util.TimeoutError` after a
timeout.
"""
- fut = Future()
+ fut = Future() # type: Future[None]
if self._value:
fut.set_result(None)
return fut
@@ -242,29 +249,37 @@ def wait(self, timeout=None):
if timeout is None:
return fut
else:
- timeout_fut = gen.with_timeout(timeout, fut, quiet_exceptions=(CancelledError,))
+ timeout_fut = gen.with_timeout(timeout, fut)
# This is a slightly clumsy workaround for the fact that
# gen.with_timeout doesn't cancel its futures. Cancelling
# fut will remove it from the waiters list.
- timeout_fut.add_done_callback(lambda tf: fut.cancel() if not fut.done() else None)
+ timeout_fut.add_done_callback(
+ lambda tf: fut.cancel() if not fut.done() else None
+ )
return timeout_fut
class _ReleasingContextManager(object):
"""Releases a Lock or Semaphore at the end of a "with" statement.
- with (yield semaphore.acquire()):
- pass
+ with (yield semaphore.acquire()):
+ pass
- # Now semaphore.release() has been called.
+ # Now semaphore.release() has been called.
"""
- def __init__(self, obj):
+
+ def __init__(self, obj: Any) -> None:
self._obj = obj
- def __enter__(self):
+ def __enter__(self) -> None:
pass
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: "Optional[Type[BaseException]]",
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[types.TracebackType],
+ ) -> None:
self._obj.release()
@@ -290,12 +305,11 @@ class Semaphore(_TimeoutGarbageCollector):
# Ensure reliable doctest output: resolve Futures one at a time.
futures_q = deque([Future() for _ in range(3)])
- @gen.coroutine
- def simulator(futures):
+ async def simulator(futures):
for f in futures:
# simulate the asynchronous passage of time
- yield gen.moment
- yield gen.moment
+ await gen.sleep(0)
+ await gen.sleep(0)
f.set_result(None)
IOLoop.current().add_callback(simulator, list(futures_q))
@@ -311,20 +325,18 @@ def use_some_resource():
sem = Semaphore(2)
- @gen.coroutine
- def worker(worker_id):
- yield sem.acquire()
+ async def worker(worker_id):
+ await sem.acquire()
try:
print("Worker %d is working" % worker_id)
- yield use_some_resource()
+ await use_some_resource()
finally:
print("Worker %d is done" % worker_id)
sem.release()
- @gen.coroutine
- def runner():
+ async def runner():
# Join all workers.
- yield [worker(i) for i in range(3)]
+ await gen.multi([worker(i) for i in range(3)])
IOLoop.current().run_sync(runner)
@@ -340,47 +352,50 @@ def runner():
Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until
the semaphore has been released once, by worker 0.
- `.acquire` is a context manager, so ``worker`` could be written as::
+ The semaphore can be used as an async context manager::
- @gen.coroutine
- def worker(worker_id):
- with (yield sem.acquire()):
+ async def worker(worker_id):
+ async with sem:
print("Worker %d is working" % worker_id)
- yield use_some_resource()
+ await use_some_resource()
# Now the semaphore has been released.
print("Worker %d is done" % worker_id)
- In Python 3.5, the semaphore itself can be used as an async context
- manager::
+ For compatibility with older versions of Python, `.acquire` is a
+ context manager, so ``worker`` could also be written as::
- async def worker(worker_id):
- async with sem:
+ @gen.coroutine
+ def worker(worker_id):
+ with (yield sem.acquire()):
print("Worker %d is working" % worker_id)
- await use_some_resource()
+ yield use_some_resource()
# Now the semaphore has been released.
print("Worker %d is done" % worker_id)
.. versionchanged:: 4.3
Added ``async with`` support in Python 3.5.
+
"""
- def __init__(self, value=1):
- super(Semaphore, self).__init__()
+
+ def __init__(self, value: int = 1) -> None:
+ super().__init__()
if value < 0:
- raise ValueError('semaphore initial value must be >= 0')
+ raise ValueError("semaphore initial value must be >= 0")
self._value = value
- def __repr__(self):
- res = super(Semaphore, self).__repr__()
- extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format(
- self._value)
+ def __repr__(self) -> str:
+ res = super().__repr__()
+ extra = (
+ "locked" if self._value == 0 else "unlocked,value:{0}".format(self._value)
+ )
if self._waiters:
- extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
- return '<{0} [{1}]>'.format(res[1:-1], extra)
+ extra = "{0},waiters:{1}".format(extra, len(self._waiters))
+ return "<{0} [{1}]>".format(res[1:-1], extra)
- def release(self):
+ def release(self) -> None:
"""Increment the counter and wake one waiter."""
self._value += 1
while self._waiters:
@@ -397,42 +412,54 @@ def release(self):
waiter.set_result(_ReleasingContextManager(self))
break
- def acquire(self, timeout=None):
- """Decrement the counter. Returns a Future.
+ def acquire(
+ self, timeout: Optional[Union[float, datetime.timedelta]] = None
+ ) -> Awaitable[_ReleasingContextManager]:
+ """Decrement the counter. Returns an awaitable.
- Block if the counter is zero and wait for a `.release`. The Future
+ Block if the counter is zero and wait for a `.release`. The awaitable
raises `.TimeoutError` after the deadline.
"""
- waiter = Future()
+ waiter = Future() # type: Future[_ReleasingContextManager]
if self._value > 0:
self._value -= 1
waiter.set_result(_ReleasingContextManager(self))
else:
self._waiters.append(waiter)
if timeout:
- def on_timeout():
+
+ def on_timeout() -> None:
if not waiter.done():
waiter.set_exception(gen.TimeoutError())
self._garbage_collect()
+
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
waiter.add_done_callback(
- lambda _: io_loop.remove_timeout(timeout_handle))
+ lambda _: io_loop.remove_timeout(timeout_handle)
+ )
return waiter
- def __enter__(self):
- raise RuntimeError(
- "Use Semaphore like 'with (yield semaphore.acquire())', not like"
- " 'with semaphore'")
-
- __exit__ = __enter__
-
- @gen.coroutine
- def __aenter__(self):
- yield self.acquire()
-
- @gen.coroutine
- def __aexit__(self, typ, value, tb):
+ def __enter__(self) -> None:
+ raise RuntimeError("Use 'async with' instead of 'with' for Semaphore")
+
+ def __exit__(
+ self,
+ typ: "Optional[Type[BaseException]]",
+ value: Optional[BaseException],
+ traceback: Optional[types.TracebackType],
+ ) -> None:
+ self.__enter__()
+
+ async def __aenter__(self) -> None:
+ await self.acquire()
+
+ async def __aexit__(
+ self,
+ typ: "Optional[Type[BaseException]]",
+ value: Optional[BaseException],
+ tb: Optional[types.TracebackType],
+ ) -> None:
self.release()
@@ -444,15 +471,16 @@ class BoundedSemaphore(Semaphore):
resources with limited capacity, so a semaphore released too many times
is a sign of a bug.
"""
- def __init__(self, value=1):
- super(BoundedSemaphore, self).__init__(value=value)
+
+ def __init__(self, value: int = 1) -> None:
+ super().__init__(value=value)
self._initial_value = value
- def release(self):
+ def release(self) -> None:
"""Increment the counter and wake one waiter."""
if self._value >= self._initial_value:
raise ValueError("Semaphore released too many times")
- super(BoundedSemaphore, self).release()
+ super().release()
class Lock(object):
@@ -464,26 +492,24 @@ class Lock(object):
Releasing an unlocked lock raises `RuntimeError`.
- `acquire` supports the context manager protocol in all Python versions:
+ A Lock can be used as an async context manager with the ``async
+ with`` statement:
- >>> from tornado import gen, locks
+ >>> from tornado import locks
>>> lock = locks.Lock()
>>>
- >>> @gen.coroutine
- ... def f():
- ... with (yield lock.acquire()):
+ >>> async def f():
+ ... async with lock:
... # Do something holding the lock.
... pass
...
... # Now the lock is released.
- In Python 3.5, `Lock` also supports the async context manager
- protocol. Note that in this case there is no `acquire`, because
- ``async with`` includes both the ``yield`` and the ``acquire``
- (just as it does with `threading.Lock`):
+ For compatibility with older versions of Python, the `.acquire`
+ method asynchronously returns a regular context manager:
- >>> async def f2(): # doctest: +SKIP
- ... async with lock:
+ >>> async def f2():
+ ... with (yield lock.acquire()):
... # Do something holding the lock.
... pass
...
@@ -493,23 +519,24 @@ class Lock(object):
Added ``async with`` support in Python 3.5.
"""
- def __init__(self):
+
+ def __init__(self) -> None:
self._block = BoundedSemaphore(value=1)
- def __repr__(self):
- return "<%s _block=%s>" % (
- self.__class__.__name__,
- self._block)
+ def __repr__(self) -> str:
+ return "<%s _block=%s>" % (self.__class__.__name__, self._block)
- def acquire(self, timeout=None):
- """Attempt to lock. Returns a Future.
+ def acquire(
+ self, timeout: Optional[Union[float, datetime.timedelta]] = None
+ ) -> Awaitable[_ReleasingContextManager]:
+ """Attempt to lock. Returns an awaitable.
- Returns a Future, which raises `tornado.util.TimeoutError` after a
+ Returns an awaitable, which raises `tornado.util.TimeoutError` after a
timeout.
"""
return self._block.acquire(timeout)
- def release(self):
+ def release(self) -> None:
"""Unlock.
The first coroutine in line waiting for `acquire` gets the lock.
@@ -519,18 +546,26 @@ def release(self):
try:
self._block.release()
except ValueError:
- raise RuntimeError('release unlocked lock')
-
- def __enter__(self):
- raise RuntimeError(
- "Use Lock like 'with (yield lock)', not like 'with lock'")
-
- __exit__ = __enter__
-
- @gen.coroutine
- def __aenter__(self):
- yield self.acquire()
-
- @gen.coroutine
- def __aexit__(self, typ, value, tb):
+ raise RuntimeError("release unlocked lock")
+
+ def __enter__(self) -> None:
+ raise RuntimeError("Use `async with` instead of `with` for Lock")
+
+ def __exit__(
+ self,
+ typ: "Optional[Type[BaseException]]",
+ value: Optional[BaseException],
+ tb: Optional[types.TracebackType],
+ ) -> None:
+ self.__enter__()
+
+ async def __aenter__(self) -> None:
+ await self.acquire()
+
+ async def __aexit__(
+ self,
+ typ: "Optional[Type[BaseException]]",
+ value: Optional[BaseException],
+ tb: Optional[types.TracebackType],
+ ) -> None:
self.release()
diff --git a/tornado/log.py b/tornado/log.py
index cda905c9ba..810a0373b5 100644
--- a/tornado/log.py
+++ b/tornado/log.py
@@ -27,8 +27,6 @@
`logging` module. For example, you may wish to send ``tornado.access`` logs
to a separate file for analysis.
"""
-from __future__ import absolute_import, division, print_function
-
import logging
import logging.handlers
import sys
@@ -37,14 +35,16 @@
from tornado.util import unicode_type, basestring_type
try:
- import colorama
+ import colorama # type: ignore
except ImportError:
colorama = None
try:
- import curses # type: ignore
+ import curses
except ImportError:
- curses = None
+ curses = None # type: ignore
+
+from typing import Dict, Any, cast, Optional
# Logger objects for internal tornado use
access_log = logging.getLogger("tornado.access")
@@ -52,16 +52,17 @@
gen_log = logging.getLogger("tornado.general")
-def _stderr_supports_color():
+def _stderr_supports_color() -> bool:
try:
- if hasattr(sys.stderr, 'isatty') and sys.stderr.isatty():
+ if hasattr(sys.stderr, "isatty") and sys.stderr.isatty():
if curses:
curses.setupterm()
if curses.tigetnum("colors") > 0:
return True
elif colorama:
- if sys.stderr is getattr(colorama.initialise, 'wrapped_stderr',
- object()):
+ if sys.stderr is getattr(
+ colorama.initialise, "wrapped_stderr", object()
+ ):
return True
except Exception:
# Very broad exception handling because it's always better to
@@ -70,7 +71,7 @@ def _stderr_supports_color():
return False
-def _safe_unicode(s):
+def _safe_unicode(s: Any) -> str:
try:
return _unicode(s)
except UnicodeDecodeError:
@@ -101,18 +102,25 @@ class LogFormatter(logging.Formatter):
Added support for ``colorama``. Changed the constructor
signature to be compatible with `logging.config.dictConfig`.
"""
- DEFAULT_FORMAT = \
- '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s'
- DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S'
+
+ DEFAULT_FORMAT = "%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s" # noqa: E501
+ DEFAULT_DATE_FORMAT = "%y%m%d %H:%M:%S"
DEFAULT_COLORS = {
logging.DEBUG: 4, # Blue
logging.INFO: 2, # Green
logging.WARNING: 3, # Yellow
logging.ERROR: 1, # Red
+ logging.CRITICAL: 5, # Magenta
}
- def __init__(self, fmt=DEFAULT_FORMAT, datefmt=DEFAULT_DATE_FORMAT,
- style='%', color=True, colors=DEFAULT_COLORS):
+ def __init__(
+ self,
+ fmt: str = DEFAULT_FORMAT,
+ datefmt: str = DEFAULT_DATE_FORMAT,
+ style: str = "%",
+ color: bool = True,
+ colors: Dict[int, int] = DEFAULT_COLORS,
+ ) -> None:
r"""
:arg bool color: Enables color support.
:arg str fmt: Log message format.
@@ -131,34 +139,29 @@ def __init__(self, fmt=DEFAULT_FORMAT, datefmt=DEFAULT_DATE_FORMAT,
logging.Formatter.__init__(self, datefmt=datefmt)
self._fmt = fmt
- self._colors = {}
+ self._colors = {} # type: Dict[int, str]
if color and _stderr_supports_color():
if curses is not None:
- # The curses module has some str/bytes confusion in
- # python3. Until version 3.2.3, most methods return
- # bytes, but only accept strings. In addition, we want to
- # output these strings with the logging module, which
- # works with unicode strings. The explicit calls to
- # unicode() below are harmless in python2 but will do the
- # right conversion in python 3.
- fg_color = (curses.tigetstr("setaf") or
- curses.tigetstr("setf") or "")
- if (3, 0) < sys.version_info < (3, 2, 3):
- fg_color = unicode_type(fg_color, "ascii")
+ fg_color = curses.tigetstr("setaf") or curses.tigetstr("setf") or b""
for levelno, code in colors.items():
- self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii")
+ # Convert the terminal control characters from
+ # bytes to unicode strings for easier use with the
+ # logging module.
+ self._colors[levelno] = unicode_type(
+ curses.tparm(fg_color, code), "ascii"
+ )
self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
else:
# If curses is not present (currently we'll only get here for
# colorama on windows), assume hard-coded ANSI color codes.
for levelno, code in colors.items():
- self._colors[levelno] = '\033[2;3%dm' % code
- self._normal = '\033[0m'
+ self._colors[levelno] = "\033[2;3%dm" % code
+ self._normal = "\033[0m"
else:
- self._normal = ''
+ self._normal = ""
- def format(self, record):
+ def format(self, record: Any) -> str:
try:
message = record.getMessage()
assert isinstance(message, basestring_type) # guaranteed by logging
@@ -182,13 +185,13 @@ def format(self, record):
except Exception as e:
record.message = "Bad message (%r): %r" % (e, record.__dict__)
- record.asctime = self.formatTime(record, self.datefmt)
+ record.asctime = self.formatTime(record, cast(str, self.datefmt))
if record.levelno in self._colors:
record.color = self._colors[record.levelno]
record.end_color = self._normal
else:
- record.color = record.end_color = ''
+ record.color = record.end_color = ""
formatted = self._fmt % record.__dict__
@@ -200,12 +203,14 @@ def format(self, record):
# each line separately so that non-utf8 bytes don't cause
# all the newlines to turn into '\n'.
lines = [formatted.rstrip()]
- lines.extend(_safe_unicode(ln) for ln in record.exc_text.split('\n'))
- formatted = '\n'.join(lines)
+ lines.extend(_safe_unicode(ln) for ln in record.exc_text.split("\n"))
+ formatted = "\n".join(lines)
return formatted.replace("\n", "\n ")
-def enable_pretty_logging(options=None, logger=None):
+def enable_pretty_logging(
+ options: Any = None, logger: Optional[logging.Logger] = None
+) -> None:
"""Turns on formatted logging output as configured.
This is called automatically by `tornado.options.parse_command_line`
@@ -213,41 +218,47 @@ def enable_pretty_logging(options=None, logger=None):
"""
if options is None:
import tornado.options
+
options = tornado.options.options
- if options.logging is None or options.logging.lower() == 'none':
+ if options.logging is None or options.logging.lower() == "none":
return
if logger is None:
logger = logging.getLogger()
logger.setLevel(getattr(logging, options.logging.upper()))
if options.log_file_prefix:
rotate_mode = options.log_rotate_mode
- if rotate_mode == 'size':
+ if rotate_mode == "size":
channel = logging.handlers.RotatingFileHandler(
filename=options.log_file_prefix,
maxBytes=options.log_file_max_size,
- backupCount=options.log_file_num_backups)
- elif rotate_mode == 'time':
+ backupCount=options.log_file_num_backups,
+ encoding="utf-8",
+ ) # type: logging.Handler
+ elif rotate_mode == "time":
channel = logging.handlers.TimedRotatingFileHandler(
filename=options.log_file_prefix,
when=options.log_rotate_when,
interval=options.log_rotate_interval,
- backupCount=options.log_file_num_backups)
+ backupCount=options.log_file_num_backups,
+ encoding="utf-8",
+ )
else:
- error_message = 'The value of log_rotate_mode option should be ' +\
- '"size" or "time", not "%s".' % rotate_mode
+ error_message = (
+ "The value of log_rotate_mode option should be "
+ + '"size" or "time", not "%s".' % rotate_mode
+ )
raise ValueError(error_message)
channel.setFormatter(LogFormatter(color=False))
logger.addHandler(channel)
- if (options.log_to_stderr or
- (options.log_to_stderr is None and not logger.handlers)):
+ if options.log_to_stderr or (options.log_to_stderr is None and not logger.handlers):
# Set up color if we are in a tty and curses is installed
channel = logging.StreamHandler()
channel.setFormatter(LogFormatter())
logger.addHandler(channel)
-def define_logging_options(options=None):
+def define_logging_options(options: Any = None) -> None:
"""Add logging-related flags to ``options``.
These options are present automatically on the default options instance;
@@ -259,32 +270,70 @@ def define_logging_options(options=None):
if options is None:
# late import to prevent cycle
import tornado.options
+
options = tornado.options.options
- options.define("logging", default="info",
- help=("Set the Python log level. If 'none', tornado won't touch the "
- "logging configuration."),
- metavar="debug|info|warning|error|none")
- options.define("log_to_stderr", type=bool, default=None,
- help=("Send log output to stderr (colorized if possible). "
- "By default use stderr if --log_file_prefix is not set and "
- "no other logging is configured."))
- options.define("log_file_prefix", type=str, default=None, metavar="PATH",
- help=("Path prefix for log files. "
- "Note that if you are running multiple tornado processes, "
- "log_file_prefix must be different for each of them (e.g. "
- "include the port number)"))
- options.define("log_file_max_size", type=int, default=100 * 1000 * 1000,
- help="max size of log files before rollover")
- options.define("log_file_num_backups", type=int, default=10,
- help="number of log files to keep")
-
- options.define("log_rotate_when", type=str, default='midnight',
- help=("specify the type of TimedRotatingFileHandler interval "
- "other options:('S', 'M', 'H', 'D', 'W0'-'W6')"))
- options.define("log_rotate_interval", type=int, default=1,
- help="The interval value of timed rotating")
-
- options.define("log_rotate_mode", type=str, default='size',
- help="The mode of rotating files(time or size)")
+ options.define(
+ "logging",
+ default="info",
+ help=(
+ "Set the Python log level. If 'none', tornado won't touch the "
+ "logging configuration."
+ ),
+ metavar="debug|info|warning|error|none",
+ )
+ options.define(
+ "log_to_stderr",
+ type=bool,
+ default=None,
+ help=(
+ "Send log output to stderr (colorized if possible). "
+ "By default use stderr if --log_file_prefix is not set and "
+ "no other logging is configured."
+ ),
+ )
+ options.define(
+ "log_file_prefix",
+ type=str,
+ default=None,
+ metavar="PATH",
+ help=(
+ "Path prefix for log files. "
+ "Note that if you are running multiple tornado processes, "
+ "log_file_prefix must be different for each of them (e.g. "
+ "include the port number)"
+ ),
+ )
+ options.define(
+ "log_file_max_size",
+ type=int,
+ default=100 * 1000 * 1000,
+ help="max size of log files before rollover",
+ )
+ options.define(
+ "log_file_num_backups", type=int, default=10, help="number of log files to keep"
+ )
+
+ options.define(
+ "log_rotate_when",
+ type=str,
+ default="midnight",
+ help=(
+ "specify the type of TimedRotatingFileHandler interval "
+ "other options:('S', 'M', 'H', 'D', 'W0'-'W6')"
+ ),
+ )
+ options.define(
+ "log_rotate_interval",
+ type=int,
+ default=1,
+ help="The interval value of timed rotating",
+ )
+
+ options.define(
+ "log_rotate_mode",
+ type=str,
+ default="size",
+ help="The mode of rotating files(time or size)",
+ )
options.add_parse_callback(lambda: enable_pretty_logging(options))
diff --git a/tornado/netutil.py b/tornado/netutil.py
index 08c9d88627..f8a3038051 100644
--- a/tornado/netutil.py
+++ b/tornado/netutil.py
@@ -15,70 +15,51 @@
"""Miscellaneous network utility code."""
-from __future__ import absolute_import, division, print_function
-
+import concurrent.futures
import errno
import os
import sys
import socket
+import ssl
import stat
from tornado.concurrent import dummy_executor, run_on_executor
-from tornado import gen
from tornado.ioloop import IOLoop
-from tornado.platform.auto import set_close_exec
-from tornado.util import PY3, Configurable, errno_from_exception
-
-try:
- import ssl
-except ImportError:
- # ssl is not available on Google App Engine
- ssl = None
-
-if PY3:
- xrange = range
-
-if ssl is not None:
- # Note that the naming of ssl.Purpose is confusing; the purpose
- # of a context is to authentiate the opposite side of the connection.
- _client_ssl_defaults = ssl.create_default_context(
- ssl.Purpose.SERVER_AUTH)
- _server_ssl_defaults = ssl.create_default_context(
- ssl.Purpose.CLIENT_AUTH)
- if hasattr(ssl, 'OP_NO_COMPRESSION'):
- # See netutil.ssl_options_to_context
- _client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
- _server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
-else:
- # Google App Engine
- _client_ssl_defaults = dict(cert_reqs=None,
- ca_certs=None)
- _server_ssl_defaults = {}
+from tornado.util import Configurable, errno_from_exception
+
+from typing import List, Callable, Any, Type, Dict, Union, Tuple, Awaitable, Optional
+
+# Note that the naming of ssl.Purpose is confusing; the purpose
+# of a context is to authenticate the opposite side of the connection.
+_client_ssl_defaults = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
+_server_ssl_defaults = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+if hasattr(ssl, "OP_NO_COMPRESSION"):
+ # See netutil.ssl_options_to_context
+ _client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
+ _server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
# ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode,
# getaddrinfo attempts to import encodings.idna. If this is done at
# module-import time, the import lock is already held by the main thread,
# leading to deadlock. Avoid it by caching the idna encoder on the main
# thread now.
-u'foo'.encode('idna')
+u"foo".encode("idna")
# For undiagnosed reasons, 'latin1' codec may also need to be preloaded.
-u'foo'.encode('latin1')
-
-# These errnos indicate that a non-blocking operation must be retried
-# at a later time. On most platforms they're the same value, but on
-# some they differ.
-_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
-
-if hasattr(errno, "WSAEWOULDBLOCK"):
- _ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) # type: ignore
+u"foo".encode("latin1")
# Default backlog used when calling sock.listen()
_DEFAULT_BACKLOG = 128
-def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
- backlog=_DEFAULT_BACKLOG, flags=None, reuse_port=False):
+def bind_sockets(
+ port: int,
+ address: Optional[str] = None,
+ family: socket.AddressFamily = socket.AF_UNSPEC,
+ backlog: int = _DEFAULT_BACKLOG,
+ flags: Optional[int] = None,
+ reuse_port: bool = False,
+) -> List[socket.socket]:
"""Creates listening sockets bound to the given port and address.
Returns a list of socket objects (multiple sockets are returned if
@@ -118,11 +99,23 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
if flags is None:
flags = socket.AI_PASSIVE
bound_port = None
- for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
- 0, flags)):
+ unique_addresses = set() # type: set
+ for res in sorted(
+ socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, 0, flags),
+ key=lambda x: x[0],
+ ):
+ if res in unique_addresses:
+ continue
+
+ unique_addresses.add(res)
+
af, socktype, proto, canonname, sockaddr = res
- if (sys.platform == 'darwin' and address == 'localhost' and
- af == socket.AF_INET6 and sockaddr[3] != 0):
+ if (
+ sys.platform == "darwin"
+ and address == "localhost"
+ and af == socket.AF_INET6
+ and sockaddr[3] != 0
+ ):
# Mac OS X includes a link-local address fe80::1%lo0 in the
# getaddrinfo results for 'localhost'. However, the firewall
# doesn't understand that this is a local address and will
@@ -136,9 +129,13 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
if errno_from_exception(e) == errno.EAFNOSUPPORT:
continue
raise
- set_close_exec(sock.fileno())
- if os.name != 'nt':
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ if os.name != "nt":
+ try:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ except socket.error as e:
+ if errno_from_exception(e) != errno.ENOPROTOOPT:
+ # Hurd doesn't support SO_REUSEADDR.
+ raise
if reuse_port:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
if af == socket.AF_INET6:
@@ -159,16 +156,41 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
if requested_port == 0 and bound_port is not None:
sockaddr = tuple([host, bound_port] + list(sockaddr[2:]))
- sock.setblocking(0)
- sock.bind(sockaddr)
+ sock.setblocking(False)
+ try:
+ sock.bind(sockaddr)
+ except OSError as e:
+ if (
+ errno_from_exception(e) == errno.EADDRNOTAVAIL
+ and address == "localhost"
+ and sockaddr[0] == "::1"
+ ):
+ # On some systems (most notably docker with default
+ # configurations), ipv6 is partially disabled:
+ # socket.has_ipv6 is true, we can create AF_INET6
+ # sockets, and getaddrinfo("localhost", ...,
+ # AF_PASSIVE) resolves to ::1, but we get an error
+ # when binding.
+ #
+ # Swallow the error, but only for this specific case.
+ # If EADDRNOTAVAIL occurs in other situations, it
+ # might be a real problem like a typo in a
+ # configuration.
+ sock.close()
+ continue
+ else:
+ raise
bound_port = sock.getsockname()[1]
sock.listen(backlog)
sockets.append(sock)
return sockets
-if hasattr(socket, 'AF_UNIX'):
- def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
+if hasattr(socket, "AF_UNIX"):
+
+ def bind_unix_socket(
+ file: str, mode: int = 0o600, backlog: int = _DEFAULT_BACKLOG
+ ) -> socket.socket:
"""Creates a listening unix socket.
If a socket with the given name already exists, it will be deleted.
@@ -179,14 +201,17 @@ def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
`bind_sockets`)
"""
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- set_close_exec(sock.fileno())
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- sock.setblocking(0)
try:
- st = os.stat(file)
- except OSError as err:
- if errno_from_exception(err) != errno.ENOENT:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ except socket.error as e:
+ if errno_from_exception(e) != errno.ENOPROTOOPT:
+ # Hurd doesn't support SO_REUSEADDR
raise
+ sock.setblocking(False)
+ try:
+ st = os.stat(file)
+ except FileNotFoundError:
+ pass
else:
if stat.S_ISSOCK(st.st_mode):
os.remove(file)
@@ -198,7 +223,9 @@ def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
return sock
-def add_accept_handler(sock, callback):
+def add_accept_handler(
+ sock: socket.socket, callback: Callable[[socket.socket, Any], None]
+) -> Callable[[], None]:
"""Adds an `.IOLoop` event handler to accept new connections on ``sock``.
When a connection is accepted, ``callback(connection, address)`` will
@@ -219,7 +246,7 @@ def add_accept_handler(sock, callback):
io_loop = IOLoop.current()
removed = [False]
- def accept_handler(fd, events):
+ def accept_handler(fd: socket.socket, events: int) -> None:
# More connections may come in while we're handling callbacks;
# to prevent starvation of other tasks we must limit the number
# of connections we accept at a time. Ideally we would accept
@@ -231,27 +258,24 @@ def accept_handler(fd, events):
# Instead, we use the (default) listen backlog as a rough
# heuristic for the number of connections we can reasonably
# accept at once.
- for i in xrange(_DEFAULT_BACKLOG):
+ for i in range(_DEFAULT_BACKLOG):
if removed[0]:
# The socket was probably closed
return
try:
connection, address = sock.accept()
- except socket.error as e:
- # _ERRNO_WOULDBLOCK indicate we have accepted every
+ except BlockingIOError:
+ # EWOULDBLOCK indicates we have accepted every
# connection that is available.
- if errno_from_exception(e) in _ERRNO_WOULDBLOCK:
- return
+ return
+ except ConnectionAbortedError:
# ECONNABORTED indicates that there was a connection
# but it was closed while still in the accept queue.
# (observed on FreeBSD).
- if errno_from_exception(e) == errno.ECONNABORTED:
- continue
- raise
- set_close_exec(connection.fileno())
+ continue
callback(connection, address)
- def remove_handler():
+ def remove_handler() -> None:
io_loop.remove_handler(sock)
removed[0] = True
@@ -259,19 +283,19 @@ def remove_handler():
return remove_handler
-def is_valid_ip(ip):
- """Returns true if the given string is a well-formed IP address.
+def is_valid_ip(ip: str) -> bool:
+ """Returns ``True`` if the given string is a well-formed IP address.
Supports IPv4 and IPv6.
"""
- if not ip or '\x00' in ip:
+ if not ip or "\x00" in ip:
# getaddrinfo resolves empty strings to localhost, and truncates
# on zero bytes.
return False
try:
- res = socket.getaddrinfo(ip, 0, socket.AF_UNSPEC,
- socket.SOCK_STREAM,
- 0, socket.AI_NUMERICHOST)
+ res = socket.getaddrinfo(
+ ip, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_NUMERICHOST
+ )
return bool(res)
except socket.gaierror as e:
if e.args[0] == socket.EAI_NONAME:
@@ -303,15 +327,18 @@ class method::
The default implementation has changed from `BlockingResolver` to
`DefaultExecutorResolver`.
"""
+
@classmethod
- def configurable_base(cls):
+ def configurable_base(cls) -> Type["Resolver"]:
return Resolver
@classmethod
- def configurable_default(cls):
+ def configurable_default(cls) -> Type["Resolver"]:
return DefaultExecutorResolver
- def resolve(self, host, port, family=socket.AF_UNSPEC, callback=None):
+ def resolve(
+ self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
+ ) -> Awaitable[List[Tuple[int, Any]]]:
"""Resolves an address.
The ``host`` argument is a string which may be a hostname or a
@@ -329,13 +356,13 @@ def resolve(self, host, port, family=socket.AF_UNSPEC, callback=None):
.. versionchanged:: 4.4
Standardized all implementations to raise `IOError`.
- .. deprecated:: 5.1
- The ``callback`` argument is deprecated and will be removed in 6.0.
+ .. versionchanged:: 6.0 The ``callback`` argument was removed.
Use the returned awaitable object instead.
+
"""
raise NotImplementedError()
- def close(self):
+ def close(self) -> None:
"""Closes the `Resolver`, freeing any resources used.
.. versionadded:: 3.1
@@ -344,7 +371,9 @@ def close(self):
pass
-def _resolve_addr(host, port, family=socket.AF_UNSPEC):
+def _resolve_addr(
+ host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
+) -> List[Tuple[int, Any]]:
# On Solaris, getaddrinfo fails if the given port is not found
# in /etc/services and no socket type is given, so we must pass
# one here. The socket type used here doesn't seem to actually
@@ -352,9 +381,9 @@ def _resolve_addr(host, port, family=socket.AF_UNSPEC):
# so the addresses we return should still be usable with SOCK_DGRAM.
addrinfo = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)
results = []
- for family, socktype, proto, canonname, address in addrinfo:
- results.append((family, address))
- return results
+ for fam, socktype, proto, canonname, address in addrinfo:
+ results.append((fam, address))
+ return results # type: ignore
class DefaultExecutorResolver(Resolver):
@@ -362,11 +391,14 @@ class DefaultExecutorResolver(Resolver):
.. versionadded:: 5.0
"""
- @gen.coroutine
- def resolve(self, host, port, family=socket.AF_UNSPEC):
- result = yield IOLoop.current().run_in_executor(
- None, _resolve_addr, host, port, family)
- raise gen.Return(result)
+
+ async def resolve(
+ self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
+ ) -> List[Tuple[int, Any]]:
+ result = await IOLoop.current().run_in_executor(
+ None, _resolve_addr, host, port, family
+ )
+ return result
class ExecutorResolver(Resolver):
@@ -386,7 +418,12 @@ class ExecutorResolver(Resolver):
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
- def initialize(self, executor=None, close_executor=True):
+
+ def initialize(
+ self,
+ executor: Optional[concurrent.futures.Executor] = None,
+ close_executor: bool = True,
+ ) -> None:
self.io_loop = IOLoop.current()
if executor is not None:
self.executor = executor
@@ -395,13 +432,15 @@ def initialize(self, executor=None, close_executor=True):
self.executor = dummy_executor
self.close_executor = False
- def close(self):
+ def close(self) -> None:
if self.close_executor:
self.executor.shutdown()
- self.executor = None
+ self.executor = None # type: ignore
@run_on_executor
- def resolve(self, host, port, family=socket.AF_UNSPEC):
+ def resolve(
+ self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
+ ) -> List[Tuple[int, Any]]:
return _resolve_addr(host, port, family)
@@ -415,8 +454,9 @@ class BlockingResolver(ExecutorResolver):
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
- def initialize(self):
- super(BlockingResolver, self).initialize()
+
+ def initialize(self) -> None: # type: ignore
+ super().initialize()
class ThreadedResolver(ExecutorResolver):
@@ -439,24 +479,25 @@ class ThreadedResolver(ExecutorResolver):
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
+
_threadpool = None # type: ignore
_threadpool_pid = None # type: int
- def initialize(self, num_threads=10):
+ def initialize(self, num_threads: int = 10) -> None: # type: ignore
threadpool = ThreadedResolver._create_threadpool(num_threads)
- super(ThreadedResolver, self).initialize(
- executor=threadpool, close_executor=False)
+ super().initialize(executor=threadpool, close_executor=False)
@classmethod
- def _create_threadpool(cls, num_threads):
+ def _create_threadpool(
+ cls, num_threads: int
+ ) -> concurrent.futures.ThreadPoolExecutor:
pid = os.getpid()
if cls._threadpool_pid != pid:
# Threads cannot survive after a fork, so if our pid isn't what it
# was when we created the pool then delete it.
cls._threadpool = None
if cls._threadpool is None:
- from concurrent.futures import ThreadPoolExecutor
- cls._threadpool = ThreadPoolExecutor(num_threads)
+ cls._threadpool = concurrent.futures.ThreadPoolExecutor(num_threads)
cls._threadpool_pid = pid
return cls._threadpool
@@ -483,31 +524,37 @@ class OverrideResolver(Resolver):
.. versionchanged:: 5.0
Added support for host-port-family triplets.
"""
- def initialize(self, resolver, mapping):
+
+ def initialize(self, resolver: Resolver, mapping: dict) -> None:
self.resolver = resolver
self.mapping = mapping
- def close(self):
+ def close(self) -> None:
self.resolver.close()
- def resolve(self, host, port, family=socket.AF_UNSPEC, *args, **kwargs):
+ def resolve(
+ self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
+ ) -> Awaitable[List[Tuple[int, Any]]]:
if (host, port, family) in self.mapping:
host, port = self.mapping[(host, port, family)]
elif (host, port) in self.mapping:
host, port = self.mapping[(host, port)]
elif host in self.mapping:
host = self.mapping[host]
- return self.resolver.resolve(host, port, family, *args, **kwargs)
+ return self.resolver.resolve(host, port, family)
# These are the keyword arguments to ssl.wrap_socket that must be translated
# to their SSLContext equivalents (the other arguments are still passed
# to SSLContext.wrap_socket).
-_SSL_CONTEXT_KEYWORDS = frozenset(['ssl_version', 'certfile', 'keyfile',
- 'cert_reqs', 'ca_certs', 'ciphers'])
+_SSL_CONTEXT_KEYWORDS = frozenset(
+ ["ssl_version", "certfile", "keyfile", "cert_reqs", "ca_certs", "ciphers"]
+)
-def ssl_options_to_context(ssl_options):
+def ssl_options_to_context(
+ ssl_options: Union[Dict[str, Any], ssl.SSLContext]
+) -> ssl.SSLContext:
"""Try to convert an ``ssl_options`` dictionary to an
`~ssl.SSLContext` object.
@@ -524,17 +571,18 @@ def ssl_options_to_context(ssl_options):
assert all(k in _SSL_CONTEXT_KEYWORDS for k in ssl_options), ssl_options
# Can't use create_default_context since this interface doesn't
# tell us client vs server.
- context = ssl.SSLContext(
- ssl_options.get('ssl_version', ssl.PROTOCOL_SSLv23))
- if 'certfile' in ssl_options:
- context.load_cert_chain(ssl_options['certfile'], ssl_options.get('keyfile', None))
- if 'cert_reqs' in ssl_options:
- context.verify_mode = ssl_options['cert_reqs']
- if 'ca_certs' in ssl_options:
- context.load_verify_locations(ssl_options['ca_certs'])
- if 'ciphers' in ssl_options:
- context.set_ciphers(ssl_options['ciphers'])
- if hasattr(ssl, 'OP_NO_COMPRESSION'):
+ context = ssl.SSLContext(ssl_options.get("ssl_version", ssl.PROTOCOL_SSLv23))
+ if "certfile" in ssl_options:
+ context.load_cert_chain(
+ ssl_options["certfile"], ssl_options.get("keyfile", None)
+ )
+ if "cert_reqs" in ssl_options:
+ context.verify_mode = ssl_options["cert_reqs"]
+ if "ca_certs" in ssl_options:
+ context.load_verify_locations(ssl_options["ca_certs"])
+ if "ciphers" in ssl_options:
+ context.set_ciphers(ssl_options["ciphers"])
+ if hasattr(ssl, "OP_NO_COMPRESSION"):
# Disable TLS compression to avoid CRIME and related attacks.
# This constant depends on openssl version 1.0.
# TODO: Do we need to do this ourselves or can we trust
@@ -543,7 +591,12 @@ def ssl_options_to_context(ssl_options):
return context
-def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
+def ssl_wrap_socket(
+ socket: socket.socket,
+ ssl_options: Union[Dict[str, Any], ssl.SSLContext],
+ server_hostname: Optional[str] = None,
+ **kwargs: Any
+) -> ssl.SSLSocket:
"""Returns an ``ssl.SSLSocket`` wrapping the given socket.
``ssl_options`` may be either an `ssl.SSLContext` object or a
@@ -559,7 +612,6 @@ def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
# TODO: add a unittest (python added server-side SNI support in 3.4)
# In the meantime it can be manually tested with
# python3 -m tornado.httpclient https://sni.velox.ch
- return context.wrap_socket(socket, server_hostname=server_hostname,
- **kwargs)
+ return context.wrap_socket(socket, server_hostname=server_hostname, **kwargs)
else:
return context.wrap_socket(socket, **kwargs)
diff --git a/tornado/options.py b/tornado/options.py
index a6f77029f7..058f88d16d 100644
--- a/tornado/options.py
+++ b/tornado/options.py
@@ -91,11 +91,8 @@ def start_server():
options can be defined, set, and read with any mix of the two.
Dashes are typical for command-line usage while config files require
underscores.
-
"""
-from __future__ import absolute_import, division, print_function
-
import datetime
import numbers
import re
@@ -105,12 +102,25 @@ def start_server():
from tornado.escape import _unicode, native_str
from tornado.log import define_logging_options
-from tornado import stack_context
from tornado.util import basestring_type, exec_in
+from typing import (
+ Any,
+ Iterator,
+ Iterable,
+ Tuple,
+ Set,
+ Dict,
+ Callable,
+ List,
+ TextIO,
+ Optional,
+)
+
class Error(Exception):
"""Exception raised by errors in the options module."""
+
pass
@@ -120,56 +130,61 @@ class OptionParser(object):
Normally accessed via static functions in the `tornado.options` module,
which reference a global instance.
"""
- def __init__(self):
- # we have to use self.__dict__ because we override setattr.
- self.__dict__['_options'] = {}
- self.__dict__['_parse_callbacks'] = []
- self.define("help", type=bool, help="show this help information",
- callback=self._help_callback)
- def _normalize_name(self, name):
- return name.replace('_', '-')
-
- def __getattr__(self, name):
+ def __init__(self) -> None:
+ # we have to use self.__dict__ because we override setattr.
+ self.__dict__["_options"] = {}
+ self.__dict__["_parse_callbacks"] = []
+ self.define(
+ "help",
+ type=bool,
+ help="show this help information",
+ callback=self._help_callback,
+ )
+
+ def _normalize_name(self, name: str) -> str:
+ return name.replace("_", "-")
+
+ def __getattr__(self, name: str) -> Any:
name = self._normalize_name(name)
if isinstance(self._options.get(name), _Option):
return self._options[name].value()
raise AttributeError("Unrecognized option %r" % name)
- def __setattr__(self, name, value):
+ def __setattr__(self, name: str, value: Any) -> None:
name = self._normalize_name(name)
if isinstance(self._options.get(name), _Option):
return self._options[name].set(value)
raise AttributeError("Unrecognized option %r" % name)
- def __iter__(self):
+ def __iter__(self) -> Iterator:
return (opt.name for opt in self._options.values())
- def __contains__(self, name):
+ def __contains__(self, name: str) -> bool:
name = self._normalize_name(name)
return name in self._options
- def __getitem__(self, name):
+ def __getitem__(self, name: str) -> Any:
return self.__getattr__(name)
- def __setitem__(self, name, value):
+ def __setitem__(self, name: str, value: Any) -> None:
return self.__setattr__(name, value)
- def items(self):
- """A sequence of (name, value) pairs.
+ def items(self) -> Iterable[Tuple[str, Any]]:
+ """An iterable of (name, value) pairs.
.. versionadded:: 3.1
"""
return [(opt.name, opt.value()) for name, opt in self._options.items()]
- def groups(self):
+ def groups(self) -> Set[str]:
"""The set of option-groups created by ``define``.
.. versionadded:: 3.1
"""
return set(opt.group_name for opt in self._options.values())
- def group_dict(self, group):
+ def group_dict(self, group: str) -> Dict[str, Any]:
"""The names and values of options in a group.
Useful for copying options into Application settings::
@@ -187,19 +202,29 @@ def group_dict(self, group):
.. versionadded:: 3.1
"""
return dict(
- (opt.name, opt.value()) for name, opt in self._options.items()
- if not group or group == opt.group_name)
+ (opt.name, opt.value())
+ for name, opt in self._options.items()
+ if not group or group == opt.group_name
+ )
- def as_dict(self):
+ def as_dict(self) -> Dict[str, Any]:
"""The names and values of all options.
.. versionadded:: 3.1
"""
- return dict(
- (opt.name, opt.value()) for name, opt in self._options.items())
-
- def define(self, name, default=None, type=None, help=None, metavar=None,
- multiple=False, group=None, callback=None):
+ return dict((opt.name, opt.value()) for name, opt in self._options.items())
+
+ def define(
+ self,
+ name: str,
+ default: Any = None,
+ type: Optional[type] = None,
+ help: Optional[str] = None,
+ metavar: Optional[str] = None,
+ multiple: bool = False,
+ group: Optional[str] = None,
+ callback: Optional[Callable[[Any], None]] = None,
+ ) -> None:
"""Defines a new command line option.
``type`` can be any of `str`, `int`, `float`, `bool`,
@@ -236,18 +261,27 @@ def define(self, name, default=None, type=None, help=None, metavar=None,
"""
normalized = self._normalize_name(name)
if normalized in self._options:
- raise Error("Option %r already defined in %s" %
- (normalized, self._options[normalized].file_name))
+ raise Error(
+ "Option %r already defined in %s"
+ % (normalized, self._options[normalized].file_name)
+ )
frame = sys._getframe(0)
- options_file = frame.f_code.co_filename
-
- # Can be called directly, or through top level define() fn, in which
- # case, step up above that frame to look for real caller.
- if (frame.f_back.f_code.co_filename == options_file and
- frame.f_back.f_code.co_name == 'define'):
- frame = frame.f_back
-
- file_name = frame.f_back.f_code.co_filename
+ if frame is not None:
+ options_file = frame.f_code.co_filename
+
+ # Can be called directly, or through top level define() fn, in which
+ # case, step up above that frame to look for real caller.
+ if (
+ frame.f_back is not None
+ and frame.f_back.f_code.co_filename == options_file
+ and frame.f_back.f_code.co_name == "define"
+ ):
+ frame = frame.f_back
+
+ assert frame.f_back is not None
+ file_name = frame.f_back.f_code.co_filename
+ else:
+ file_name = ""
if file_name == options_file:
file_name = ""
if type is None:
@@ -256,17 +290,25 @@ def define(self, name, default=None, type=None, help=None, metavar=None,
else:
type = str
if group:
- group_name = group
+ group_name = group # type: Optional[str]
else:
group_name = file_name
- option = _Option(name, file_name=file_name,
- default=default, type=type, help=help,
- metavar=metavar, multiple=multiple,
- group_name=group_name,
- callback=callback)
+ option = _Option(
+ name,
+ file_name=file_name,
+ default=default,
+ type=type,
+ help=help,
+ metavar=metavar,
+ multiple=multiple,
+ group_name=group_name,
+ callback=callback,
+ )
self._options[normalized] = option
- def parse_command_line(self, args=None, final=True):
+ def parse_command_line(
+ self, args: Optional[List[str]] = None, final: bool = True
+ ) -> List[str]:
"""Parses all options given on the command line (defaults to
`sys.argv`).
@@ -290,27 +332,27 @@ def parse_command_line(self, args=None, final=True):
"""
if args is None:
args = sys.argv
- remaining = []
+ remaining = [] # type: List[str]
for i in range(1, len(args)):
# All things after the last option are command line arguments
if not args[i].startswith("-"):
remaining = args[i:]
break
if args[i] == "--":
- remaining = args[i + 1:]
+ remaining = args[i + 1 :]
break
arg = args[i].lstrip("-")
name, equals, value = arg.partition("=")
name = self._normalize_name(name)
if name not in self._options:
self.print_help()
- raise Error('Unrecognized command line option: %r' % name)
+ raise Error("Unrecognized command line option: %r" % name)
option = self._options[name]
if not equals:
if option.type == bool:
value = "true"
else:
- raise Error('Option %r requires a value' % name)
+ raise Error("Option %r requires a value" % name)
option.parse(value)
if final:
@@ -318,7 +360,7 @@ def parse_command_line(self, args=None, final=True):
return remaining
- def parse_config_file(self, path, final=True):
+ def parse_config_file(self, path: str, final: bool = True) -> None:
"""Parses and loads the config file at the given path.
The config file contains Python code that will be executed (so
@@ -326,18 +368,20 @@ def parse_config_file(self, path, final=True):
the global namespace that matches a defined option will be
used to set that option's value.
- Options are not parsed from strings as they would be on the
- command line; they should be set to the correct type (this
- means if you have ``datetime`` or ``timedelta`` options you
- will need to import those modules in the config file.
+ Options may either be the specified type for the option or
+ strings (in which case they will be parsed the same way as in
+ `.parse_command_line`)
Example (using the options defined in the top-level docs of
this module)::
port = 80
mysql_host = 'mydb.example.com:3306'
+ # Both lists and comma-separated strings are allowed for
+ # multiple=True.
memcache_hosts = ['cache1.example.com:11011',
'cache2.example.com:11011']
+ memcache_hosts = 'cache1.example.com:11011,cache2.example.com:11011'
If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations
@@ -358,25 +402,40 @@ def parse_config_file(self, path, final=True):
The special variable ``__file__`` is available inside config
files, specifying the absolute path to the config file itself.
+ .. versionchanged:: 5.1
+ Added the ability to set options via strings in config files.
+
"""
- config = {'__file__': os.path.abspath(path)}
- with open(path, 'rb') as f:
+ config = {"__file__": os.path.abspath(path)}
+ with open(path, "rb") as f:
exec_in(native_str(f.read()), config, config)
for name in config:
normalized = self._normalize_name(name)
if normalized in self._options:
- self._options[normalized].set(config[name])
+ option = self._options[normalized]
+ if option.multiple:
+ if not isinstance(config[name], (list, str)):
+ raise Error(
+ "Option %r is required to be a list of %s "
+ "or a comma-separated string"
+ % (option.name, option.type.__name__)
+ )
+
+ if type(config[name]) == str and option.type != str:
+ option.parse(config[name])
+ else:
+ option.set(config[name])
if final:
self.run_parse_callbacks()
- def print_help(self, file=None):
+ def print_help(self, file: Optional[TextIO] = None) -> None:
"""Prints all the command line options to stderr (or another file)."""
if file is None:
file = sys.stderr
print("Usage: %s [OPTIONS]" % sys.argv[0], file=file)
print("\nOptions:\n", file=file)
- by_group = {}
+ by_group = {} # type: Dict[str, List[_Option]]
for option in self._options.values():
by_group.setdefault(option.group_name, []).append(option)
@@ -390,30 +449,30 @@ def print_help(self, file=None):
if option.metavar:
prefix += "=" + option.metavar
description = option.help or ""
- if option.default is not None and option.default != '':
+ if option.default is not None and option.default != "":
description += " (default %s)" % option.default
lines = textwrap.wrap(description, 79 - 35)
if len(prefix) > 30 or len(lines) == 0:
- lines.insert(0, '')
+ lines.insert(0, "")
print(" --%-30s %s" % (prefix, lines[0]), file=file)
for line in lines[1:]:
- print("%-34s %s" % (' ', line), file=file)
+ print("%-34s %s" % (" ", line), file=file)
print(file=file)
- def _help_callback(self, value):
+ def _help_callback(self, value: bool) -> None:
if value:
self.print_help()
sys.exit(0)
- def add_parse_callback(self, callback):
+ def add_parse_callback(self, callback: Callable[[], None]) -> None:
"""Adds a parse callback, to be invoked when option parsing is done."""
- self._parse_callbacks.append(stack_context.wrap(callback))
+ self._parse_callbacks.append(callback)
- def run_parse_callbacks(self):
+ def run_parse_callbacks(self) -> None:
for callback in self._parse_callbacks:
callback()
- def mockable(self):
+ def mockable(self) -> "_Mockable":
"""Returns a wrapper around self that is compatible with
`mock.patch `.
@@ -437,38 +496,53 @@ class _Mockable(object):
As of ``mock`` version 1.0.1, when an object uses ``__getattr__``
hooks instead of ``__dict__``, ``patch.__exit__`` tries to delete
the attribute it set instead of setting a new one (assuming that
- the object does not catpure ``__setattr__``, so the patch
+ the object does not capture ``__setattr__``, so the patch
created a new attribute in ``__dict__``).
_Mockable's getattr and setattr pass through to the underlying
OptionParser, and delattr undoes the effect of a previous setattr.
"""
- def __init__(self, options):
+
+ def __init__(self, options: OptionParser) -> None:
# Modify __dict__ directly to bypass __setattr__
- self.__dict__['_options'] = options
- self.__dict__['_originals'] = {}
+ self.__dict__["_options"] = options
+ self.__dict__["_originals"] = {}
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
return getattr(self._options, name)
- def __setattr__(self, name, value):
+ def __setattr__(self, name: str, value: Any) -> None:
assert name not in self._originals, "don't reuse mockable objects"
self._originals[name] = getattr(self._options, name)
setattr(self._options, name, value)
- def __delattr__(self, name):
+ def __delattr__(self, name: str) -> None:
setattr(self._options, name, self._originals.pop(name))
class _Option(object):
+ # This class could almost be made generic, but the way the types
+ # interact with the multiple argument makes this tricky. (default
+ # and the callback use List[T], but type is still Type[T]).
UNSET = object()
- def __init__(self, name, default=None, type=basestring_type, help=None,
- metavar=None, multiple=False, file_name=None, group_name=None,
- callback=None):
+ def __init__(
+ self,
+ name: str,
+ default: Any = None,
+ type: Optional[type] = None,
+ help: Optional[str] = None,
+ metavar: Optional[str] = None,
+ multiple: bool = False,
+ file_name: Optional[str] = None,
+ group_name: Optional[str] = None,
+ callback: Optional[Callable[[Any], None]] = None,
+ ) -> None:
if default is None and multiple:
default = []
self.name = name
+ if type is None:
+ raise ValueError("type must not be None")
self.type = type
self.help = help
self.metavar = metavar
@@ -477,26 +551,28 @@ def __init__(self, name, default=None, type=basestring_type, help=None,
self.group_name = group_name
self.callback = callback
self.default = default
- self._value = _Option.UNSET
+ self._value = _Option.UNSET # type: Any
- def value(self):
+ def value(self) -> Any:
return self.default if self._value is _Option.UNSET else self._value
- def parse(self, value):
+ def parse(self, value: str) -> Any:
_parse = {
datetime.datetime: self._parse_datetime,
datetime.timedelta: self._parse_timedelta,
bool: self._parse_bool,
basestring_type: self._parse_string,
- }.get(self.type, self.type)
+ }.get(
+ self.type, self.type
+ ) # type: Callable[[str], Any]
if self.multiple:
self._value = []
for part in value.split(","):
if issubclass(self.type, numbers.Integral):
# allow ranges of the form X:Y (inclusive at both ends)
- lo, _, hi = part.partition(":")
- lo = _parse(lo)
- hi = _parse(hi) if hi else lo
+ lo_str, _, hi_str = part.partition(":")
+ lo = _parse(lo_str)
+ hi = _parse(hi_str) if hi_str else lo
self._value.extend(range(lo, hi + 1))
else:
self._value.append(_parse(part))
@@ -506,19 +582,25 @@ def parse(self, value):
self.callback(self._value)
return self.value()
- def set(self, value):
+ def set(self, value: Any) -> None:
if self.multiple:
if not isinstance(value, list):
- raise Error("Option %r is required to be a list of %s" %
- (self.name, self.type.__name__))
+ raise Error(
+ "Option %r is required to be a list of %s"
+ % (self.name, self.type.__name__)
+ )
for item in value:
if item is not None and not isinstance(item, self.type):
- raise Error("Option %r is required to be a list of %s" %
- (self.name, self.type.__name__))
+ raise Error(
+ "Option %r is required to be a list of %s"
+ % (self.name, self.type.__name__)
+ )
else:
if value is not None and not isinstance(value, self.type):
- raise Error("Option %r is required to be a %s (%s given)" %
- (self.name, self.type.__name__, type(value)))
+ raise Error(
+ "Option %r is required to be a %s (%s given)"
+ % (self.name, self.type.__name__, type(value))
+ )
self._value = value
if self.callback is not None:
self.callback(self._value)
@@ -537,32 +619,33 @@ def set(self, value):
"%H:%M",
]
- def _parse_datetime(self, value):
+ def _parse_datetime(self, value: str) -> datetime.datetime:
for format in self._DATETIME_FORMATS:
try:
return datetime.datetime.strptime(value, format)
except ValueError:
pass
- raise Error('Unrecognized date/time format: %r' % value)
+ raise Error("Unrecognized date/time format: %r" % value)
_TIMEDELTA_ABBREV_DICT = {
- 'h': 'hours',
- 'm': 'minutes',
- 'min': 'minutes',
- 's': 'seconds',
- 'sec': 'seconds',
- 'ms': 'milliseconds',
- 'us': 'microseconds',
- 'd': 'days',
- 'w': 'weeks',
+ "h": "hours",
+ "m": "minutes",
+ "min": "minutes",
+ "s": "seconds",
+ "sec": "seconds",
+ "ms": "milliseconds",
+ "us": "microseconds",
+ "d": "days",
+ "w": "weeks",
}
- _FLOAT_PATTERN = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
+ _FLOAT_PATTERN = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
_TIMEDELTA_PATTERN = re.compile(
- r'\s*(%s)\s*(\w*)\s*' % _FLOAT_PATTERN, re.IGNORECASE)
+ r"\s*(%s)\s*(\w*)\s*" % _FLOAT_PATTERN, re.IGNORECASE
+ )
- def _parse_timedelta(self, value):
+ def _parse_timedelta(self, value: str) -> datetime.timedelta:
try:
sum = datetime.timedelta()
start = 0
@@ -571,18 +654,20 @@ def _parse_timedelta(self, value):
if not m:
raise Exception()
num = float(m.group(1))
- units = m.group(2) or 'seconds'
+ units = m.group(2) or "seconds"
units = self._TIMEDELTA_ABBREV_DICT.get(units, units)
- sum += datetime.timedelta(**{units: num})
+ # This line confuses mypy when setup.py sets python_version=3.6
+ # https://github.com/python/mypy/issues/9676
+ sum += datetime.timedelta(**{units: num}) # type: ignore
start = m.end()
return sum
except Exception:
raise
- def _parse_bool(self, value):
+ def _parse_bool(self, value: str) -> bool:
return value.lower() not in ("false", "0", "f")
- def _parse_string(self, value):
+ def _parse_string(self, value: str) -> str:
return _unicode(value)
@@ -593,18 +678,35 @@ def _parse_string(self, value):
"""
-def define(name, default=None, type=None, help=None, metavar=None,
- multiple=False, group=None, callback=None):
+def define(
+ name: str,
+ default: Any = None,
+ type: Optional[type] = None,
+ help: Optional[str] = None,
+ metavar: Optional[str] = None,
+ multiple: bool = False,
+ group: Optional[str] = None,
+ callback: Optional[Callable[[Any], None]] = None,
+) -> None:
"""Defines an option in the global namespace.
See `OptionParser.define`.
"""
- return options.define(name, default=default, type=type, help=help,
- metavar=metavar, multiple=multiple, group=group,
- callback=callback)
-
-
-def parse_command_line(args=None, final=True):
+ return options.define(
+ name,
+ default=default,
+ type=type,
+ help=help,
+ metavar=metavar,
+ multiple=multiple,
+ group=group,
+ callback=callback,
+ )
+
+
+def parse_command_line(
+ args: Optional[List[str]] = None, final: bool = True
+) -> List[str]:
"""Parses global options from the command line.
See `OptionParser.parse_command_line`.
@@ -612,7 +714,7 @@ def parse_command_line(args=None, final=True):
return options.parse_command_line(args, final=final)
-def parse_config_file(path, final=True):
+def parse_config_file(path: str, final: bool = True) -> None:
"""Parses global options from a config file.
See `OptionParser.parse_config_file`.
@@ -620,7 +722,7 @@ def parse_config_file(path, final=True):
return options.parse_config_file(path, final=final)
-def print_help(file=None):
+def print_help(file: Optional[TextIO] = None) -> None:
"""Prints all the command line options to stderr (or another file).
See `OptionParser.print_help`.
@@ -628,7 +730,7 @@ def print_help(file=None):
return options.print_help(file)
-def add_parse_callback(callback):
+def add_parse_callback(callback: Callable[[], None]) -> None:
"""Adds a parse callback, to be invoked when option parsing is done.
See `OptionParser.add_parse_callback`
diff --git a/tornado/platform/asyncio.py b/tornado/platform/asyncio.py
index b6a490afac..292d9b66a4 100644
--- a/tornado/platform/asyncio.py
+++ b/tornado/platform/asyncio.py
@@ -14,29 +14,88 @@
.. note::
- Tornado requires the `~asyncio.AbstractEventLoop.add_reader` family of
- methods, so it is not compatible with the `~asyncio.ProactorEventLoop` on
- Windows. Use the `~asyncio.SelectorEventLoop` instead.
+ Tornado is designed to use a selector-based event loop. On Windows,
+ where a proactor-based event loop has been the default since Python 3.8,
+ a selector event loop is emulated by running ``select`` on a separate thread.
+ Configuring ``asyncio`` to use a selector event loop may improve performance
+ of Tornado (but may reduce performance of other ``asyncio``-based libraries
+ in the same process).
"""
-from __future__ import absolute_import, division, print_function
+import asyncio
+import atexit
+import concurrent.futures
+import errno
import functools
-
+import select
+import socket
+import sys
+import threading
+import typing
from tornado.gen import convert_yielded
-from tornado.ioloop import IOLoop
-from tornado import stack_context
+from tornado.ioloop import IOLoop, _Selectable
+
+from typing import Any, TypeVar, Awaitable, Callable, Union, Optional, List, Tuple, Dict
+
+if typing.TYPE_CHECKING:
+ from typing import Set # noqa: F401
+ from typing_extensions import Protocol
+
+ class _HasFileno(Protocol):
+ def fileno(self) -> int:
+ pass
+
+ _FileDescriptorLike = Union[int, _HasFileno]
+
+_T = TypeVar("_T")
+
+
+# Collection of selector thread event loops to shut down on exit.
+_selector_loops = set() # type: Set[AddThreadSelectorEventLoop]
+
+
+def _atexit_callback() -> None:
+ for loop in _selector_loops:
+ with loop._select_cond:
+ loop._closing_selector = True
+ loop._select_cond.notify()
+ try:
+ loop._waker_w.send(b"a")
+ except BlockingIOError:
+ pass
+ # If we don't join our (daemon) thread here, we may get a deadlock
+ # during interpreter shutdown. I don't really understand why. This
+ # deadlock happens every time in CI (both travis and appveyor) but
+ # I've never been able to reproduce locally.
+ loop._thread.join()
+ _selector_loops.clear()
-import asyncio
+
+atexit.register(_atexit_callback)
class BaseAsyncIOLoop(IOLoop):
- def initialize(self, asyncio_loop, **kwargs):
+ def initialize( # type: ignore
+ self, asyncio_loop: asyncio.AbstractEventLoop, **kwargs: Any
+ ) -> None:
+ # asyncio_loop is always the real underlying IOLoop. This is used in
+ # ioloop.py to maintain the asyncio-to-ioloop mappings.
self.asyncio_loop = asyncio_loop
+ # selector_loop is an event loop that implements the add_reader family of
+ # methods. Usually the same as asyncio_loop but differs on platforms such
+ # as windows where the default event loop does not implement these methods.
+ self.selector_loop = asyncio_loop
+ if hasattr(asyncio, "ProactorEventLoop") and isinstance(
+ asyncio_loop, asyncio.ProactorEventLoop # type: ignore
+ ):
+ # Ignore this line for mypy because the abstract method checker
+ # doesn't understand dynamic proxies.
+ self.selector_loop = AddThreadSelectorEventLoop(asyncio_loop) # type: ignore
# Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler)
- self.handlers = {}
+ self.handlers = {} # type: Dict[int, Tuple[Union[int, _Selectable], Callable]]
# Set of fds listening for reads/writes
- self.readers = set()
- self.writers = set()
+ self.readers = set() # type: Set[int]
+ self.writers = set() # type: Set[int]
self.closing = False
# If an asyncio loop was closed through an asyncio interface
# instead of IOLoop.close(), we'd never hear about it and may
@@ -53,74 +112,87 @@ def initialize(self, asyncio_loop, **kwargs):
if loop.is_closed():
del IOLoop._ioloop_for_asyncio[loop]
IOLoop._ioloop_for_asyncio[asyncio_loop] = self
- super(BaseAsyncIOLoop, self).initialize(**kwargs)
- def close(self, all_fds=False):
+ self._thread_identity = 0
+
+ super().initialize(**kwargs)
+
+ def assign_thread_identity() -> None:
+ self._thread_identity = threading.get_ident()
+
+ self.add_callback(assign_thread_identity)
+
+ def close(self, all_fds: bool = False) -> None:
self.closing = True
for fd in list(self.handlers):
fileobj, handler_func = self.handlers[fd]
self.remove_handler(fd)
if all_fds:
self.close_fd(fileobj)
- self.asyncio_loop.close()
+ # Remove the mapping before closing the asyncio loop. If this
+ # happened in the other order, we could race against another
+ # initialize() call which would see the closed asyncio loop,
+ # assume it was closed from the asyncio side, and do this
+ # cleanup for us, leading to a KeyError.
del IOLoop._ioloop_for_asyncio[self.asyncio_loop]
+ if self.selector_loop is not self.asyncio_loop:
+ self.selector_loop.close()
+ self.asyncio_loop.close()
- def add_handler(self, fd, handler, events):
+ def add_handler(
+ self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int
+ ) -> None:
fd, fileobj = self.split_fd(fd)
if fd in self.handlers:
raise ValueError("fd %s added twice" % fd)
- self.handlers[fd] = (fileobj, stack_context.wrap(handler))
+ self.handlers[fd] = (fileobj, handler)
if events & IOLoop.READ:
- self.asyncio_loop.add_reader(
- fd, self._handle_events, fd, IOLoop.READ)
+ self.selector_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ)
self.readers.add(fd)
if events & IOLoop.WRITE:
- self.asyncio_loop.add_writer(
- fd, self._handle_events, fd, IOLoop.WRITE)
+ self.selector_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE)
self.writers.add(fd)
- def update_handler(self, fd, events):
+ def update_handler(self, fd: Union[int, _Selectable], events: int) -> None:
fd, fileobj = self.split_fd(fd)
if events & IOLoop.READ:
if fd not in self.readers:
- self.asyncio_loop.add_reader(
- fd, self._handle_events, fd, IOLoop.READ)
+ self.selector_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ)
self.readers.add(fd)
else:
if fd in self.readers:
- self.asyncio_loop.remove_reader(fd)
+ self.selector_loop.remove_reader(fd)
self.readers.remove(fd)
if events & IOLoop.WRITE:
if fd not in self.writers:
- self.asyncio_loop.add_writer(
- fd, self._handle_events, fd, IOLoop.WRITE)
+ self.selector_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE)
self.writers.add(fd)
else:
if fd in self.writers:
- self.asyncio_loop.remove_writer(fd)
+ self.selector_loop.remove_writer(fd)
self.writers.remove(fd)
- def remove_handler(self, fd):
+ def remove_handler(self, fd: Union[int, _Selectable]) -> None:
fd, fileobj = self.split_fd(fd)
if fd not in self.handlers:
return
if fd in self.readers:
- self.asyncio_loop.remove_reader(fd)
+ self.selector_loop.remove_reader(fd)
self.readers.remove(fd)
if fd in self.writers:
- self.asyncio_loop.remove_writer(fd)
+ self.selector_loop.remove_writer(fd)
self.writers.remove(fd)
del self.handlers[fd]
- def _handle_events(self, fd, events):
+ def _handle_events(self, fd: int, events: int) -> None:
fileobj, handler_func = self.handlers[fd]
handler_func(fileobj, events)
- def start(self):
+ def start(self) -> None:
try:
old_loop = asyncio.get_event_loop()
except (RuntimeError, AssertionError):
- old_loop = None
+ old_loop = None # type: ignore
try:
self._setup_logging()
asyncio.set_event_loop(self.asyncio_loop)
@@ -128,25 +200,31 @@ def start(self):
finally:
asyncio.set_event_loop(old_loop)
- def stop(self):
+ def stop(self) -> None:
self.asyncio_loop.stop()
- def call_at(self, when, callback, *args, **kwargs):
+ def call_at(
+ self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any
+ ) -> object:
# asyncio.call_at supports *args but not **kwargs, so bind them here.
# We do not synchronize self.time and asyncio_loop.time, so
# convert from absolute to relative.
return self.asyncio_loop.call_later(
- max(0, when - self.time()), self._run_callback,
- functools.partial(stack_context.wrap(callback), *args, **kwargs))
+ max(0, when - self.time()),
+ self._run_callback,
+ functools.partial(callback, *args, **kwargs),
+ )
- def remove_timeout(self, timeout):
- timeout.cancel()
+ def remove_timeout(self, timeout: object) -> None:
+ timeout.cancel() # type: ignore
- def add_callback(self, callback, *args, **kwargs):
+ def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None:
+ if threading.get_ident() == self._thread_identity:
+ call_soon = self.asyncio_loop.call_soon
+ else:
+ call_soon = self.asyncio_loop.call_soon_threadsafe
try:
- self.asyncio_loop.call_soon_threadsafe(
- self._run_callback,
- functools.partial(stack_context.wrap(callback), *args, **kwargs))
+ call_soon(self._run_callback, functools.partial(callback, *args, **kwargs))
except RuntimeError:
# "Event loop is closed". Swallow the exception for
# consistency with PollIOLoop (and logical consistency
@@ -154,13 +232,31 @@ def add_callback(self, callback, *args, **kwargs):
# add_callback that completes without error will
# eventually execute).
pass
+ except AttributeError:
+ # ProactorEventLoop may raise this instead of RuntimeError
+ # if call_soon_threadsafe races with a call to close().
+ # Swallow it too for consistency.
+ pass
- add_callback_from_signal = add_callback
+ def add_callback_from_signal(
+ self, callback: Callable, *args: Any, **kwargs: Any
+ ) -> None:
+ try:
+ self.asyncio_loop.call_soon_threadsafe(
+ self._run_callback, functools.partial(callback, *args, **kwargs)
+ )
+ except RuntimeError:
+ pass
- def run_in_executor(self, executor, func, *args):
+ def run_in_executor(
+ self,
+ executor: Optional[concurrent.futures.Executor],
+ func: Callable[..., _T],
+ *args: Any
+ ) -> Awaitable[_T]:
return self.asyncio_loop.run_in_executor(executor, func, *args)
- def set_default_executor(self, executor):
+ def set_default_executor(self, executor: concurrent.futures.Executor) -> None:
return self.asyncio_loop.set_default_executor(executor)
@@ -178,10 +274,11 @@ class AsyncIOMainLoop(BaseAsyncIOLoop):
Closing an `AsyncIOMainLoop` now closes the underlying asyncio loop.
"""
- def initialize(self, **kwargs):
- super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(), **kwargs)
- def make_current(self):
+ def initialize(self, **kwargs: Any) -> None: # type: ignore
+ super().initialize(asyncio.get_event_loop(), **kwargs)
+
+ def make_current(self) -> None:
# AsyncIOMainLoop already refers to the current asyncio loop so
# nothing to do here.
pass
@@ -206,38 +303,39 @@ class AsyncIOLoop(BaseAsyncIOLoop):
Now used automatically when appropriate; it is no longer necessary
to refer to this class directly.
"""
- def initialize(self, **kwargs):
+
+ def initialize(self, **kwargs: Any) -> None: # type: ignore
self.is_current = False
loop = asyncio.new_event_loop()
try:
- super(AsyncIOLoop, self).initialize(loop, **kwargs)
+ super().initialize(loop, **kwargs)
except Exception:
# If initialize() does not succeed (taking ownership of the loop),
# we have to close it.
loop.close()
raise
- def close(self, all_fds=False):
+ def close(self, all_fds: bool = False) -> None:
if self.is_current:
self.clear_current()
- super(AsyncIOLoop, self).close(all_fds=all_fds)
+ super().close(all_fds=all_fds)
- def make_current(self):
+ def make_current(self) -> None:
if not self.is_current:
try:
self.old_asyncio = asyncio.get_event_loop()
except (RuntimeError, AssertionError):
- self.old_asyncio = None
+ self.old_asyncio = None # type: ignore
self.is_current = True
asyncio.set_event_loop(self.asyncio_loop)
- def _clear_current_hook(self):
+ def _clear_current_hook(self) -> None:
if self.is_current:
asyncio.set_event_loop(self.old_asyncio)
self.is_current = False
-def to_tornado_future(asyncio_future):
+def to_tornado_future(asyncio_future: asyncio.Future) -> asyncio.Future:
"""Convert an `asyncio.Future` to a `tornado.concurrent.Future`.
.. versionadded:: 4.1
@@ -249,7 +347,7 @@ def to_tornado_future(asyncio_future):
return asyncio_future
-def to_asyncio_future(tornado_future):
+def to_asyncio_future(tornado_future: asyncio.Future) -> asyncio.Future:
"""Convert a Tornado yieldable object to an `asyncio.Future`.
.. versionadded:: 4.1
@@ -265,7 +363,15 @@ def to_asyncio_future(tornado_future):
return convert_yielded(tornado_future)
-class AnyThreadEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
+if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
+ # "Any thread" and "selector" should be orthogonal, but there's not a clean
+ # interface for composing policies so pick the right base.
+ _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
+else:
+ _BasePolicy = asyncio.DefaultEventLoopPolicy
+
+
+class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
"""Event loop policy that allows loop creation on any thread.
The default `asyncio` event loop policy only automatically creates
@@ -282,13 +388,228 @@ class AnyThreadEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
.. versionadded:: 5.0
"""
- def get_event_loop(self):
+
+ def get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
return super().get_event_loop()
except (RuntimeError, AssertionError):
- # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
+ # This was an AssertionError in Python 3.4.2 (which ships with Debian Jessie)
# and changed to a RuntimeError in 3.4.3.
# "There is no current event loop in thread %r"
loop = self.new_event_loop()
self.set_event_loop(loop)
return loop
+
+
+class AddThreadSelectorEventLoop(asyncio.AbstractEventLoop):
+ """Wrap an event loop to add implementations of the ``add_reader`` method family.
+
+ Instances of this class start a second thread to run a selector.
+ This thread is completely hidden from the user; all callbacks are
+ run on the wrapped event loop's thread.
+
+ This class is used automatically by Tornado; applications should not need
+ to refer to it directly.
+
+ It is safe to wrap any event loop with this class, although it only makes sense
+ for event loops that do not implement the ``add_reader`` family of methods
+ themselves (i.e. ``WindowsProactorEventLoop``)
+
+ Closing the ``AddThreadSelectorEventLoop`` also closes the wrapped event loop.
+
+ """
+
+ # This class is a __getattribute__-based proxy. All attributes other than those
+ # in this set are proxied through to the underlying loop.
+ MY_ATTRIBUTES = {
+ "_consume_waker",
+ "_select_cond",
+ "_select_args",
+ "_closing_selector",
+ "_thread",
+ "_handle_event",
+ "_readers",
+ "_real_loop",
+ "_start_select",
+ "_run_select",
+ "_handle_select",
+ "_wake_selector",
+ "_waker_r",
+ "_waker_w",
+ "_writers",
+ "add_reader",
+ "add_writer",
+ "close",
+ "remove_reader",
+ "remove_writer",
+ }
+
+ def __getattribute__(self, name: str) -> Any:
+ if name in AddThreadSelectorEventLoop.MY_ATTRIBUTES:
+ return super().__getattribute__(name)
+ return getattr(self._real_loop, name)
+
+ def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None:
+ self._real_loop = real_loop
+
+ # Create a thread to run the select system call. We manage this thread
+ # manually so we can trigger a clean shutdown from an atexit hook. Note
+ # that due to the order of operations at shutdown, only daemon threads
+ # can be shut down in this way (non-daemon threads would require the
+ # introduction of a new hook: https://bugs.python.org/issue41962)
+ self._select_cond = threading.Condition()
+ self._select_args = (
+ None
+ ) # type: Optional[Tuple[List[_FileDescriptorLike], List[_FileDescriptorLike]]]
+ self._closing_selector = False
+ self._thread = threading.Thread(
+ name="Tornado selector",
+ daemon=True,
+ target=self._run_select,
+ )
+ self._thread.start()
+ # Start the select loop once the loop is started.
+ self._real_loop.call_soon(self._start_select)
+
+ self._readers = {} # type: Dict[_FileDescriptorLike, Callable]
+ self._writers = {} # type: Dict[_FileDescriptorLike, Callable]
+
+ # Writing to _waker_w will wake up the selector thread, which
+ # watches for _waker_r to be readable.
+ self._waker_r, self._waker_w = socket.socketpair()
+ self._waker_r.setblocking(False)
+ self._waker_w.setblocking(False)
+ _selector_loops.add(self)
+ self.add_reader(self._waker_r, self._consume_waker)
+
+ def __del__(self) -> None:
+ # If the top-level application code uses asyncio interfaces to
+ # start and stop the event loop, no objects created in Tornado
+ # can get a clean shutdown notification. If we're just left to
+ # be GC'd, we must explicitly close our sockets to avoid
+ # logging warnings.
+ _selector_loops.discard(self)
+ self._waker_r.close()
+ self._waker_w.close()
+
+ def close(self) -> None:
+ with self._select_cond:
+ self._closing_selector = True
+ self._select_cond.notify()
+ self._wake_selector()
+ self._thread.join()
+ _selector_loops.discard(self)
+ self._waker_r.close()
+ self._waker_w.close()
+ self._real_loop.close()
+
+ def _wake_selector(self) -> None:
+ try:
+ self._waker_w.send(b"a")
+ except BlockingIOError:
+ pass
+
+ def _consume_waker(self) -> None:
+ try:
+ self._waker_r.recv(1024)
+ except BlockingIOError:
+ pass
+
+ def _start_select(self) -> None:
+ # Capture reader and writer sets here in the event loop
+ # thread to avoid any problems with concurrent
+ # modification while the select loop uses them.
+ with self._select_cond:
+ assert self._select_args is None
+ self._select_args = (list(self._readers.keys()), list(self._writers.keys()))
+ self._select_cond.notify()
+
+ def _run_select(self) -> None:
+ while True:
+ with self._select_cond:
+ while self._select_args is None and not self._closing_selector:
+ self._select_cond.wait()
+ if self._closing_selector:
+ return
+ assert self._select_args is not None
+ to_read, to_write = self._select_args
+ self._select_args = None
+
+ # We use the simpler interface of the select module instead of
+ # the more stateful interface in the selectors module because
+ # this class is only intended for use on windows, where
+ # select.select is the only option. The selector interface
+ # does not have well-documented thread-safety semantics that
+ # we can rely on so ensuring proper synchronization would be
+ # tricky.
+ try:
+ # On windows, selecting on a socket for write will not
+ # return the socket when there is an error (but selecting
+ # for reads works). Also select for errors when selecting
+ # for writes, and merge the results.
+ #
+ # This pattern is also used in
+ # https://github.com/python/cpython/blob/v3.8.0/Lib/selectors.py#L312-L317
+ rs, ws, xs = select.select(to_read, to_write, to_write)
+ ws = ws + xs
+ except OSError as e:
+ # After remove_reader or remove_writer is called, the file
+ # descriptor may subsequently be closed on the event loop
+ # thread. It's possible that this select thread hasn't
+ # gotten into the select system call by the time that
+ # happens in which case (at least on macOS), select may
+ # raise a "bad file descriptor" error. If we get that
+ # error, check and see if we're also being woken up by
+ # polling the waker alone. If we are, just return to the
+ # event loop and we'll get the updated set of file
+ # descriptors on the next iteration. Otherwise, raise the
+ # original error.
+ if e.errno == getattr(errno, "WSAENOTSOCK", errno.EBADF):
+ rs, _, _ = select.select([self._waker_r.fileno()], [], [], 0)
+ if rs:
+ ws = []
+ else:
+ raise
+ else:
+ raise
+ self._real_loop.call_soon_threadsafe(self._handle_select, rs, ws)
+
+ def _handle_select(
+ self, rs: List["_FileDescriptorLike"], ws: List["_FileDescriptorLike"]
+ ) -> None:
+ for r in rs:
+ self._handle_event(r, self._readers)
+ for w in ws:
+ self._handle_event(w, self._writers)
+ self._start_select()
+
+ def _handle_event(
+ self,
+ fd: "_FileDescriptorLike",
+ cb_map: Dict["_FileDescriptorLike", Callable],
+ ) -> None:
+ try:
+ callback = cb_map[fd]
+ except KeyError:
+ return
+ callback()
+
+ def add_reader(
+ self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any
+ ) -> None:
+ self._readers[fd] = functools.partial(callback, *args)
+ self._wake_selector()
+
+ def add_writer(
+ self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any
+ ) -> None:
+ self._writers[fd] = functools.partial(callback, *args)
+ self._wake_selector()
+
+ def remove_reader(self, fd: "_FileDescriptorLike") -> None:
+ del self._readers[fd]
+ self._wake_selector()
+
+ def remove_writer(self, fd: "_FileDescriptorLike") -> None:
+ del self._writers[fd]
+ self._wake_selector()
diff --git a/tornado/platform/auto.py b/tornado/platform/auto.py
deleted file mode 100644
index 1a9133faf3..0000000000
--- a/tornado/platform/auto.py
+++ /dev/null
@@ -1,58 +0,0 @@
-#
-# Copyright 2011 Facebook
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""Implementation of platform-specific functionality.
-
-For each function or class described in `tornado.platform.interface`,
-the appropriate platform-specific implementation exists in this module.
-Most code that needs access to this functionality should do e.g.::
-
- from tornado.platform.auto import set_close_exec
-"""
-
-from __future__ import absolute_import, division, print_function
-
-import os
-
-if 'APPENGINE_RUNTIME' in os.environ:
- from tornado.platform.common import Waker
-
- def set_close_exec(fd):
- pass
-elif os.name == 'nt':
- from tornado.platform.common import Waker
- from tornado.platform.windows import set_close_exec
-else:
- from tornado.platform.posix import set_close_exec, Waker
-
-try:
- # monotime monkey-patches the time module to have a monotonic function
- # in versions of python before 3.3.
- import monotime
- # Silence pyflakes warning about this unused import
- monotime
-except ImportError:
- pass
-try:
- # monotonic can provide a monotonic function in versions of python before
- # 3.3, too.
- from monotonic import monotonic as monotonic_time
-except ImportError:
- try:
- from time import monotonic as monotonic_time
- except ImportError:
- monotonic_time = None
-
-__all__ = ['Waker', 'set_close_exec', 'monotonic_time']
diff --git a/tornado/platform/auto.pyi b/tornado/platform/auto.pyi
deleted file mode 100644
index a1c97228a4..0000000000
--- a/tornado/platform/auto.pyi
+++ /dev/null
@@ -1,4 +0,0 @@
-# auto.py is full of patterns mypy doesn't like, so for type checking
-# purposes we replace it with interface.py.
-
-from .interface import *
diff --git a/tornado/platform/caresresolver.py b/tornado/platform/caresresolver.py
index 768cb62499..962f84f48f 100644
--- a/tornado/platform/caresresolver.py
+++ b/tornado/platform/caresresolver.py
@@ -1,4 +1,3 @@
-from __future__ import absolute_import, division, print_function
import pycares # type: ignore
import socket
@@ -7,6 +6,11 @@
from tornado.ioloop import IOLoop
from tornado.netutil import Resolver, is_valid_ip
+import typing
+
+if typing.TYPE_CHECKING:
+ from typing import Generator, Any, List, Tuple, Dict # noqa: F401
+
class CaresResolver(Resolver):
"""Name resolver based on the c-ares library.
@@ -22,15 +26,19 @@ class CaresResolver(Resolver):
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
+
+ .. deprecated:: 6.2
+ This class is deprecated and will be removed in Tornado 7.0. Use the default
+ thread-based resolver instead.
"""
- def initialize(self):
+
+ def initialize(self) -> None:
self.io_loop = IOLoop.current()
self.channel = pycares.Channel(sock_state_cb=self._sock_state_cb)
- self.fds = {}
+ self.fds = {} # type: Dict[int, int]
- def _sock_state_cb(self, fd, readable, writable):
- state = ((IOLoop.READ if readable else 0) |
- (IOLoop.WRITE if writable else 0))
+ def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
+ state = (IOLoop.READ if readable else 0) | (IOLoop.WRITE if writable else 0)
if not state:
self.io_loop.remove_handler(fd)
del self.fds[fd]
@@ -41,7 +49,7 @@ def _sock_state_cb(self, fd, readable, writable):
self.io_loop.add_handler(fd, self._handle_events, state)
self.fds[fd] = state
- def _handle_events(self, fd, events):
+ def _handle_events(self, fd: int, events: int) -> None:
read_fd = pycares.ARES_SOCKET_BAD
write_fd = pycares.ARES_SOCKET_BAD
if events & IOLoop.READ:
@@ -51,29 +59,35 @@ def _handle_events(self, fd, events):
self.channel.process_fd(read_fd, write_fd)
@gen.coroutine
- def resolve(self, host, port, family=0):
+ def resolve(
+ self, host: str, port: int, family: int = 0
+ ) -> "Generator[Any, Any, List[Tuple[int, Any]]]":
if is_valid_ip(host):
addresses = [host]
else:
# gethostbyname doesn't take callback as a kwarg
- fut = Future()
- self.channel.gethostbyname(host, family,
- lambda result, error: fut.set_result((result, error)))
+ fut = Future() # type: Future[Tuple[Any, Any]]
+ self.channel.gethostbyname(
+ host, family, lambda result, error: fut.set_result((result, error))
+ )
result, error = yield fut
if error:
- raise IOError('C-Ares returned error %s: %s while resolving %s' %
- (error, pycares.errno.strerror(error), host))
+ raise IOError(
+ "C-Ares returned error %s: %s while resolving %s"
+ % (error, pycares.errno.strerror(error), host)
+ )
addresses = result.addresses
addrinfo = []
for address in addresses:
- if '.' in address:
+ if "." in address:
address_family = socket.AF_INET
- elif ':' in address:
+ elif ":" in address:
address_family = socket.AF_INET6
else:
address_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != address_family:
- raise IOError('Requested socket family %d but got %d' %
- (family, address_family))
- addrinfo.append((address_family, (address, port)))
- raise gen.Return(addrinfo)
+ raise IOError(
+ "Requested socket family %d but got %d" % (family, address_family)
+ )
+ addrinfo.append((typing.cast(int, address_family), (address, port)))
+ return addrinfo
diff --git a/tornado/platform/common.py b/tornado/platform/common.py
deleted file mode 100644
index b597748d1f..0000000000
--- a/tornado/platform/common.py
+++ /dev/null
@@ -1,113 +0,0 @@
-"""Lowest-common-denominator implementations of platform functionality."""
-from __future__ import absolute_import, division, print_function
-
-import errno
-import socket
-import time
-
-from tornado.platform import interface
-from tornado.util import errno_from_exception
-
-
-def try_close(f):
- # Avoid issue #875 (race condition when using the file in another
- # thread).
- for i in range(10):
- try:
- f.close()
- except IOError:
- # Yield to another thread
- time.sleep(1e-3)
- else:
- break
- # Try a last time and let raise
- f.close()
-
-
-class Waker(interface.Waker):
- """Create an OS independent asynchronous pipe.
-
- For use on platforms that don't have os.pipe() (or where pipes cannot
- be passed to select()), but do have sockets. This includes Windows
- and Jython.
- """
- def __init__(self):
- from .auto import set_close_exec
- # Based on Zope select_trigger.py:
- # https://github.com/zopefoundation/Zope/blob/master/src/ZServer/medusa/thread/select_trigger.py
-
- self.writer = socket.socket()
- set_close_exec(self.writer.fileno())
- # Disable buffering -- pulling the trigger sends 1 byte,
- # and we want that sent immediately, to wake up ASAP.
- self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-
- count = 0
- while 1:
- count += 1
- # Bind to a local port; for efficiency, let the OS pick
- # a free port for us.
- # Unfortunately, stress tests showed that we may not
- # be able to connect to that port ("Address already in
- # use") despite that the OS picked it. This appears
- # to be a race bug in the Windows socket implementation.
- # So we loop until a connect() succeeds (almost always
- # on the first try). See the long thread at
- # http://mail.zope.org/pipermail/zope/2005-July/160433.html
- # for hideous details.
- a = socket.socket()
- set_close_exec(a.fileno())
- a.bind(("127.0.0.1", 0))
- a.listen(1)
- connect_address = a.getsockname() # assigned (host, port) pair
- try:
- self.writer.connect(connect_address)
- break # success
- except socket.error as detail:
- if (not hasattr(errno, 'WSAEADDRINUSE') or
- errno_from_exception(detail) != errno.WSAEADDRINUSE):
- # "Address already in use" is the only error
- # I've seen on two WinXP Pro SP2 boxes, under
- # Pythons 2.3.5 and 2.4.1.
- raise
- # (10048, 'Address already in use')
- # assert count <= 2 # never triggered in Tim's tests
- if count >= 10: # I've never seen it go above 2
- a.close()
- self.writer.close()
- raise socket.error("Cannot bind trigger!")
- # Close `a` and try again. Note: I originally put a short
- # sleep() here, but it didn't appear to help or hurt.
- a.close()
-
- self.reader, addr = a.accept()
- set_close_exec(self.reader.fileno())
- self.reader.setblocking(0)
- self.writer.setblocking(0)
- a.close()
- self.reader_fd = self.reader.fileno()
-
- def fileno(self):
- return self.reader.fileno()
-
- def write_fileno(self):
- return self.writer.fileno()
-
- def wake(self):
- try:
- self.writer.send(b"x")
- except (IOError, socket.error, ValueError):
- pass
-
- def consume(self):
- try:
- while True:
- result = self.reader.recv(1024)
- if not result:
- break
- except (IOError, socket.error):
- pass
-
- def close(self):
- self.reader.close()
- try_close(self.writer)
diff --git a/tornado/platform/epoll.py b/tornado/platform/epoll.py
deleted file mode 100644
index 4e34617406..0000000000
--- a/tornado/platform/epoll.py
+++ /dev/null
@@ -1,25 +0,0 @@
-#
-# Copyright 2012 Facebook
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-"""EPoll-based IOLoop implementation for Linux systems."""
-from __future__ import absolute_import, division, print_function
-
-import select
-
-from tornado.ioloop import PollIOLoop
-
-
-class EPollIOLoop(PollIOLoop):
- def initialize(self, **kwargs):
- super(EPollIOLoop, self).initialize(impl=select.epoll(), **kwargs)
diff --git a/tornado/platform/interface.py b/tornado/platform/interface.py
deleted file mode 100644
index cac5326465..0000000000
--- a/tornado/platform/interface.py
+++ /dev/null
@@ -1,66 +0,0 @@
-#
-# Copyright 2011 Facebook
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""Interfaces for platform-specific functionality.
-
-This module exists primarily for documentation purposes and as base classes
-for other tornado.platform modules. Most code should import the appropriate
-implementation from `tornado.platform.auto`.
-"""
-
-from __future__ import absolute_import, division, print_function
-
-
-def set_close_exec(fd):
- """Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor."""
- raise NotImplementedError()
-
-
-class Waker(object):
- """A socket-like object that can wake another thread from ``select()``.
-
- The `~tornado.ioloop.IOLoop` will add the Waker's `fileno()` to
- its ``select`` (or ``epoll`` or ``kqueue``) calls. When another
- thread wants to wake up the loop, it calls `wake`. Once it has woken
- up, it will call `consume` to do any necessary per-wake cleanup. When
- the ``IOLoop`` is closed, it closes its waker too.
- """
- def fileno(self):
- """Returns the read file descriptor for this waker.
-
- Must be suitable for use with ``select()`` or equivalent on the
- local platform.
- """
- raise NotImplementedError()
-
- def write_fileno(self):
- """Returns the write file descriptor for this waker."""
- raise NotImplementedError()
-
- def wake(self):
- """Triggers activity on the waker's file descriptor."""
- raise NotImplementedError()
-
- def consume(self):
- """Called after the listen has woken up to do any necessary cleanup."""
- raise NotImplementedError()
-
- def close(self):
- """Closes the waker's file descriptor(s)."""
- raise NotImplementedError()
-
-
-def monotonic_time():
- raise NotImplementedError()
diff --git a/tornado/platform/kqueue.py b/tornado/platform/kqueue.py
deleted file mode 100644
index 4e0aee02ee..0000000000
--- a/tornado/platform/kqueue.py
+++ /dev/null
@@ -1,90 +0,0 @@
-#
-# Copyright 2012 Facebook
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-"""KQueue-based IOLoop implementation for BSD/Mac systems."""
-from __future__ import absolute_import, division, print_function
-
-import select
-
-from tornado.ioloop import IOLoop, PollIOLoop
-
-assert hasattr(select, 'kqueue'), 'kqueue not supported'
-
-
-class _KQueue(object):
- """A kqueue-based event loop for BSD/Mac systems."""
- def __init__(self):
- self._kqueue = select.kqueue()
- self._active = {}
-
- def fileno(self):
- return self._kqueue.fileno()
-
- def close(self):
- self._kqueue.close()
-
- def register(self, fd, events):
- if fd in self._active:
- raise IOError("fd %s already registered" % fd)
- self._control(fd, events, select.KQ_EV_ADD)
- self._active[fd] = events
-
- def modify(self, fd, events):
- self.unregister(fd)
- self.register(fd, events)
-
- def unregister(self, fd):
- events = self._active.pop(fd)
- self._control(fd, events, select.KQ_EV_DELETE)
-
- def _control(self, fd, events, flags):
- kevents = []
- if events & IOLoop.WRITE:
- kevents.append(select.kevent(
- fd, filter=select.KQ_FILTER_WRITE, flags=flags))
- if events & IOLoop.READ:
- kevents.append(select.kevent(
- fd, filter=select.KQ_FILTER_READ, flags=flags))
- # Even though control() takes a list, it seems to return EINVAL
- # on Mac OS X (10.6) when there is more than one event in the list.
- for kevent in kevents:
- self._kqueue.control([kevent], 0)
-
- def poll(self, timeout):
- kevents = self._kqueue.control(None, 1000, timeout)
- events = {}
- for kevent in kevents:
- fd = kevent.ident
- if kevent.filter == select.KQ_FILTER_READ:
- events[fd] = events.get(fd, 0) | IOLoop.READ
- if kevent.filter == select.KQ_FILTER_WRITE:
- if kevent.flags & select.KQ_EV_EOF:
- # If an asynchronous connection is refused, kqueue
- # returns a write event with the EOF flag set.
- # Turn this into an error for consistency with the
- # other IOLoop implementations.
- # Note that for read events, EOF may be returned before
- # all data has been consumed from the socket buffer,
- # so we only check for EOF on write events.
- events[fd] = IOLoop.ERROR
- else:
- events[fd] = events.get(fd, 0) | IOLoop.WRITE
- if kevent.flags & select.KQ_EV_ERROR:
- events[fd] = events.get(fd, 0) | IOLoop.ERROR
- return events.items()
-
-
-class KQueueIOLoop(PollIOLoop):
- def initialize(self, **kwargs):
- super(KQueueIOLoop, self).initialize(impl=_KQueue(), **kwargs)
diff --git a/tornado/platform/posix.py b/tornado/platform/posix.py
deleted file mode 100644
index 6fe1fa8372..0000000000
--- a/tornado/platform/posix.py
+++ /dev/null
@@ -1,69 +0,0 @@
-#
-# Copyright 2011 Facebook
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""Posix implementations of platform-specific functionality."""
-
-from __future__ import absolute_import, division, print_function
-
-import fcntl
-import os
-
-from tornado.platform import common, interface
-
-
-def set_close_exec(fd):
- flags = fcntl.fcntl(fd, fcntl.F_GETFD)
- fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
-
-
-def _set_nonblocking(fd):
- flags = fcntl.fcntl(fd, fcntl.F_GETFL)
- fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
-
-
-class Waker(interface.Waker):
- def __init__(self):
- r, w = os.pipe()
- _set_nonblocking(r)
- _set_nonblocking(w)
- set_close_exec(r)
- set_close_exec(w)
- self.reader = os.fdopen(r, "rb", 0)
- self.writer = os.fdopen(w, "wb", 0)
-
- def fileno(self):
- return self.reader.fileno()
-
- def write_fileno(self):
- return self.writer.fileno()
-
- def wake(self):
- try:
- self.writer.write(b"x")
- except (IOError, ValueError):
- pass
-
- def consume(self):
- try:
- while True:
- result = self.reader.read()
- if not result:
- break
- except IOError:
- pass
-
- def close(self):
- self.reader.close()
- common.try_close(self.writer)
diff --git a/tornado/platform/select.py b/tornado/platform/select.py
deleted file mode 100644
index 14e8a4745c..0000000000
--- a/tornado/platform/select.py
+++ /dev/null
@@ -1,75 +0,0 @@
-#
-# Copyright 2012 Facebook
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-"""Select-based IOLoop implementation.
-
-Used as a fallback for systems that don't support epoll or kqueue.
-"""
-from __future__ import absolute_import, division, print_function
-
-import select
-
-from tornado.ioloop import IOLoop, PollIOLoop
-
-
-class _Select(object):
- """A simple, select()-based IOLoop implementation for non-Linux systems"""
- def __init__(self):
- self.read_fds = set()
- self.write_fds = set()
- self.error_fds = set()
- self.fd_sets = (self.read_fds, self.write_fds, self.error_fds)
-
- def close(self):
- pass
-
- def register(self, fd, events):
- if fd in self.read_fds or fd in self.write_fds or fd in self.error_fds:
- raise IOError("fd %s already registered" % fd)
- if events & IOLoop.READ:
- self.read_fds.add(fd)
- if events & IOLoop.WRITE:
- self.write_fds.add(fd)
- if events & IOLoop.ERROR:
- self.error_fds.add(fd)
- # Closed connections are reported as errors by epoll and kqueue,
- # but as zero-byte reads by select, so when errors are requested
- # we need to listen for both read and error.
- # self.read_fds.add(fd)
-
- def modify(self, fd, events):
- self.unregister(fd)
- self.register(fd, events)
-
- def unregister(self, fd):
- self.read_fds.discard(fd)
- self.write_fds.discard(fd)
- self.error_fds.discard(fd)
-
- def poll(self, timeout):
- readable, writeable, errors = select.select(
- self.read_fds, self.write_fds, self.error_fds, timeout)
- events = {}
- for fd in readable:
- events[fd] = events.get(fd, 0) | IOLoop.READ
- for fd in writeable:
- events[fd] = events.get(fd, 0) | IOLoop.WRITE
- for fd in errors:
- events[fd] = events.get(fd, 0) | IOLoop.ERROR
- return events.items()
-
-
-class SelectIOLoop(PollIOLoop):
- def initialize(self, **kwargs):
- super(SelectIOLoop, self).initialize(impl=_Select(), **kwargs)
diff --git a/tornado/platform/twisted.py b/tornado/platform/twisted.py
index 4ae98be976..153fe436eb 100644
--- a/tornado/platform/twisted.py
+++ b/tornado/platform/twisted.py
@@ -1,6 +1,3 @@
-# Author: Ovidiu Predescu
-# Date: July 2011
-#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
@@ -12,505 +9,31 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
-"""Bridges between the Twisted reactor and Tornado IOLoop.
-
-This module lets you run applications and libraries written for
-Twisted in a Tornado application. It can be used in two modes,
-depending on which library's underlying event loop you want to use.
-
-This module has been tested with Twisted versions 11.0.0 and newer.
+"""Bridges between the Twisted package and Tornado.
"""
-from __future__ import absolute_import, division, print_function
-
-import datetime
-import functools
-import numbers
import socket
import sys
import twisted.internet.abstract # type: ignore
+import twisted.internet.asyncioreactor # type: ignore
from twisted.internet.defer import Deferred # type: ignore
-from twisted.internet.posixbase import PosixReactorBase # type: ignore
-from twisted.internet.interfaces import IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor # type: ignore # noqa: E501
-from twisted.python import failure, log # type: ignore
-from twisted.internet import error # type: ignore
+from twisted.python import failure # type: ignore
import twisted.names.cache # type: ignore
import twisted.names.client # type: ignore
import twisted.names.hosts # type: ignore
import twisted.names.resolve # type: ignore
-from zope.interface import implementer # type: ignore
from tornado.concurrent import Future, future_set_exc_info
from tornado.escape import utf8
from tornado import gen
-import tornado.ioloop
-from tornado.log import app_log
from tornado.netutil import Resolver
-from tornado.stack_context import NullContext, wrap
-from tornado.ioloop import IOLoop
-from tornado.util import timedelta_to_seconds
-
-
-@implementer(IDelayedCall)
-class TornadoDelayedCall(object):
- """DelayedCall object for Tornado."""
- def __init__(self, reactor, seconds, f, *args, **kw):
- self._reactor = reactor
- self._func = functools.partial(f, *args, **kw)
- self._time = self._reactor.seconds() + seconds
- self._timeout = self._reactor._io_loop.add_timeout(self._time,
- self._called)
- self._active = True
-
- def _called(self):
- self._active = False
- self._reactor._removeDelayedCall(self)
- try:
- self._func()
- except:
- app_log.error("_called caught exception", exc_info=True)
-
- def getTime(self):
- return self._time
-
- def cancel(self):
- self._active = False
- self._reactor._io_loop.remove_timeout(self._timeout)
- self._reactor._removeDelayedCall(self)
-
- def delay(self, seconds):
- self._reactor._io_loop.remove_timeout(self._timeout)
- self._time += seconds
- self._timeout = self._reactor._io_loop.add_timeout(self._time,
- self._called)
-
- def reset(self, seconds):
- self._reactor._io_loop.remove_timeout(self._timeout)
- self._time = self._reactor.seconds() + seconds
- self._timeout = self._reactor._io_loop.add_timeout(self._time,
- self._called)
-
- def active(self):
- return self._active
-
-
-@implementer(IReactorTime, IReactorFDSet)
-class TornadoReactor(PosixReactorBase):
- """Twisted reactor built on the Tornado IOLoop.
-
- `TornadoReactor` implements the Twisted reactor interface on top of
- the Tornado IOLoop. To use it, simply call `install` at the beginning
- of the application::
-
- import tornado.platform.twisted
- tornado.platform.twisted.install()
- from twisted.internet import reactor
-
- When the app is ready to start, call ``IOLoop.current().start()``
- instead of ``reactor.run()``.
-
- It is also possible to create a non-global reactor by calling
- ``tornado.platform.twisted.TornadoReactor()``. However, if
- the `.IOLoop` and reactor are to be short-lived (such as those used in
- unit tests), additional cleanup may be required. Specifically, it is
- recommended to call::
-
- reactor.fireSystemEvent('shutdown')
- reactor.disconnectAll()
-
- before closing the `.IOLoop`.
-
- .. versionchanged:: 5.0
- The ``io_loop`` argument (deprecated since version 4.1) has been removed.
- """
- def __init__(self):
- self._io_loop = tornado.ioloop.IOLoop.current()
- self._readers = {} # map of reader objects to fd
- self._writers = {} # map of writer objects to fd
- self._fds = {} # a map of fd to a (reader, writer) tuple
- self._delayedCalls = {}
- PosixReactorBase.__init__(self)
- self.addSystemEventTrigger('during', 'shutdown', self.crash)
-
- # IOLoop.start() bypasses some of the reactor initialization.
- # Fire off the necessary events if they weren't already triggered
- # by reactor.run().
- def start_if_necessary():
- if not self._started:
- self.fireSystemEvent('startup')
- self._io_loop.add_callback(start_if_necessary)
-
- # IReactorTime
- def seconds(self):
- return self._io_loop.time()
-
- def callLater(self, seconds, f, *args, **kw):
- dc = TornadoDelayedCall(self, seconds, f, *args, **kw)
- self._delayedCalls[dc] = True
- return dc
-
- def getDelayedCalls(self):
- return [x for x in self._delayedCalls if x._active]
-
- def _removeDelayedCall(self, dc):
- if dc in self._delayedCalls:
- del self._delayedCalls[dc]
-
- # IReactorThreads
- def callFromThread(self, f, *args, **kw):
- assert callable(f), "%s is not callable" % f
- with NullContext():
- # This NullContext is mainly for an edge case when running
- # TwistedIOLoop on top of a TornadoReactor.
- # TwistedIOLoop.add_callback uses reactor.callFromThread and
- # should not pick up additional StackContexts along the way.
- self._io_loop.add_callback(f, *args, **kw)
-
- # We don't need the waker code from the super class, Tornado uses
- # its own waker.
- def installWaker(self):
- pass
-
- def wakeUp(self):
- pass
-
- # IReactorFDSet
- def _invoke_callback(self, fd, events):
- if fd not in self._fds:
- return
- (reader, writer) = self._fds[fd]
- if reader:
- err = None
- if reader.fileno() == -1:
- err = error.ConnectionLost()
- elif events & IOLoop.READ:
- err = log.callWithLogger(reader, reader.doRead)
- if err is None and events & IOLoop.ERROR:
- err = error.ConnectionLost()
- if err is not None:
- self.removeReader(reader)
- reader.readConnectionLost(failure.Failure(err))
- if writer:
- err = None
- if writer.fileno() == -1:
- err = error.ConnectionLost()
- elif events & IOLoop.WRITE:
- err = log.callWithLogger(writer, writer.doWrite)
- if err is None and events & IOLoop.ERROR:
- err = error.ConnectionLost()
- if err is not None:
- self.removeWriter(writer)
- writer.writeConnectionLost(failure.Failure(err))
-
- def addReader(self, reader):
- if reader in self._readers:
- # Don't add the reader if it's already there
- return
- fd = reader.fileno()
- self._readers[reader] = fd
- if fd in self._fds:
- (_, writer) = self._fds[fd]
- self._fds[fd] = (reader, writer)
- if writer:
- # We already registered this fd for write events,
- # update it for read events as well.
- self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
- else:
- with NullContext():
- self._fds[fd] = (reader, None)
- self._io_loop.add_handler(fd, self._invoke_callback,
- IOLoop.READ)
-
- def addWriter(self, writer):
- if writer in self._writers:
- return
- fd = writer.fileno()
- self._writers[writer] = fd
- if fd in self._fds:
- (reader, _) = self._fds[fd]
- self._fds[fd] = (reader, writer)
- if reader:
- # We already registered this fd for read events,
- # update it for write events as well.
- self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
- else:
- with NullContext():
- self._fds[fd] = (None, writer)
- self._io_loop.add_handler(fd, self._invoke_callback,
- IOLoop.WRITE)
-
- def removeReader(self, reader):
- if reader in self._readers:
- fd = self._readers.pop(reader)
- (_, writer) = self._fds[fd]
- if writer:
- # We have a writer so we need to update the IOLoop for
- # write events only.
- self._fds[fd] = (None, writer)
- self._io_loop.update_handler(fd, IOLoop.WRITE)
- else:
- # Since we have no writer registered, we remove the
- # entry from _fds and unregister the handler from the
- # IOLoop
- del self._fds[fd]
- self._io_loop.remove_handler(fd)
-
- def removeWriter(self, writer):
- if writer in self._writers:
- fd = self._writers.pop(writer)
- (reader, _) = self._fds[fd]
- if reader:
- # We have a reader so we need to update the IOLoop for
- # read events only.
- self._fds[fd] = (reader, None)
- self._io_loop.update_handler(fd, IOLoop.READ)
- else:
- # Since we have no reader registered, we remove the
- # entry from the _fds and unregister the handler from
- # the IOLoop.
- del self._fds[fd]
- self._io_loop.remove_handler(fd)
-
- def removeAll(self):
- return self._removeAll(self._readers, self._writers)
-
- def getReaders(self):
- return self._readers.keys()
-
- def getWriters(self):
- return self._writers.keys()
-
- # The following functions are mainly used in twisted-style test cases;
- # it is expected that most users of the TornadoReactor will call
- # IOLoop.start() instead of Reactor.run().
- def stop(self):
- PosixReactorBase.stop(self)
- fire_shutdown = functools.partial(self.fireSystemEvent, "shutdown")
- self._io_loop.add_callback(fire_shutdown)
-
- def crash(self):
- PosixReactorBase.crash(self)
- self._io_loop.stop()
-
- def doIteration(self, delay):
- raise NotImplementedError("doIteration")
-
- def mainLoop(self):
- # Since this class is intended to be used in applications
- # where the top-level event loop is ``io_loop.start()`` rather
- # than ``reactor.run()``, it is implemented a little
- # differently than other Twisted reactors. We override
- # ``mainLoop`` instead of ``doIteration`` and must implement
- # timed call functionality on top of `.IOLoop.add_timeout`
- # rather than using the implementation in
- # ``PosixReactorBase``.
- self._io_loop.start()
-
-
-class _TestReactor(TornadoReactor):
- """Subclass of TornadoReactor for use in unittests.
-
- This can't go in the test.py file because of import-order dependencies
- with the Twisted reactor test builder.
- """
- def __init__(self):
- # always use a new ioloop
- IOLoop.clear_current()
- IOLoop(make_current=True)
- super(_TestReactor, self).__init__()
- IOLoop.clear_current()
- def listenTCP(self, port, factory, backlog=50, interface=''):
- # default to localhost to avoid firewall prompts on the mac
- if not interface:
- interface = '127.0.0.1'
- return super(_TestReactor, self).listenTCP(
- port, factory, backlog=backlog, interface=interface)
+import typing
- def listenUDP(self, port, protocol, interface='', maxPacketSize=8192):
- if not interface:
- interface = '127.0.0.1'
- return super(_TestReactor, self).listenUDP(
- port, protocol, interface=interface, maxPacketSize=maxPacketSize)
-
-
-def install():
- """Install this package as the default Twisted reactor.
-
- ``install()`` must be called very early in the startup process,
- before most other twisted-related imports. Conversely, because it
- initializes the `.IOLoop`, it cannot be called before
- `.fork_processes` or multi-process `~.TCPServer.start`. These
- conflicting requirements make it difficult to use `.TornadoReactor`
- in multi-process mode, and an external process manager such as
- ``supervisord`` is recommended instead.
-
- .. versionchanged:: 5.0
- The ``io_loop`` argument (deprecated since version 4.1) has been removed.
-
- """
- reactor = TornadoReactor()
- from twisted.internet.main import installReactor # type: ignore
- installReactor(reactor)
- return reactor
-
-
-@implementer(IReadDescriptor, IWriteDescriptor)
-class _FD(object):
- def __init__(self, fd, fileobj, handler):
- self.fd = fd
- self.fileobj = fileobj
- self.handler = handler
- self.reading = False
- self.writing = False
- self.lost = False
-
- def fileno(self):
- return self.fd
-
- def doRead(self):
- if not self.lost:
- self.handler(self.fileobj, tornado.ioloop.IOLoop.READ)
-
- def doWrite(self):
- if not self.lost:
- self.handler(self.fileobj, tornado.ioloop.IOLoop.WRITE)
-
- def connectionLost(self, reason):
- if not self.lost:
- self.handler(self.fileobj, tornado.ioloop.IOLoop.ERROR)
- self.lost = True
-
- writeConnectionLost = readConnectionLost = connectionLost
-
- def logPrefix(self):
- return ''
-
-
-class TwistedIOLoop(tornado.ioloop.IOLoop):
- """IOLoop implementation that runs on Twisted.
-
- `TwistedIOLoop` implements the Tornado IOLoop interface on top of
- the Twisted reactor. Recommended usage::
-
- from tornado.platform.twisted import TwistedIOLoop
- from twisted.internet import reactor
- TwistedIOLoop().install()
- # Set up your tornado application as usual using `IOLoop.instance`
- reactor.run()
-
- Uses the global Twisted reactor by default. To create multiple
- ``TwistedIOLoops`` in the same process, you must pass a unique reactor
- when constructing each one.
-
- Not compatible with `tornado.process.Subprocess.set_exit_callback`
- because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
- with each other.
-
- See also :meth:`tornado.ioloop.IOLoop.install` for general notes on
- installing alternative IOLoops.
- """
- def initialize(self, reactor=None, **kwargs):
- super(TwistedIOLoop, self).initialize(**kwargs)
- if reactor is None:
- import twisted.internet.reactor # type: ignore
- reactor = twisted.internet.reactor
- self.reactor = reactor
- self.fds = {}
-
- def close(self, all_fds=False):
- fds = self.fds
- self.reactor.removeAll()
- for c in self.reactor.getDelayedCalls():
- c.cancel()
- if all_fds:
- for fd in fds.values():
- self.close_fd(fd.fileobj)
-
- def add_handler(self, fd, handler, events):
- if fd in self.fds:
- raise ValueError('fd %s added twice' % fd)
- fd, fileobj = self.split_fd(fd)
- self.fds[fd] = _FD(fd, fileobj, wrap(handler))
- if events & tornado.ioloop.IOLoop.READ:
- self.fds[fd].reading = True
- self.reactor.addReader(self.fds[fd])
- if events & tornado.ioloop.IOLoop.WRITE:
- self.fds[fd].writing = True
- self.reactor.addWriter(self.fds[fd])
-
- def update_handler(self, fd, events):
- fd, fileobj = self.split_fd(fd)
- if events & tornado.ioloop.IOLoop.READ:
- if not self.fds[fd].reading:
- self.fds[fd].reading = True
- self.reactor.addReader(self.fds[fd])
- else:
- if self.fds[fd].reading:
- self.fds[fd].reading = False
- self.reactor.removeReader(self.fds[fd])
- if events & tornado.ioloop.IOLoop.WRITE:
- if not self.fds[fd].writing:
- self.fds[fd].writing = True
- self.reactor.addWriter(self.fds[fd])
- else:
- if self.fds[fd].writing:
- self.fds[fd].writing = False
- self.reactor.removeWriter(self.fds[fd])
-
- def remove_handler(self, fd):
- fd, fileobj = self.split_fd(fd)
- if fd not in self.fds:
- return
- self.fds[fd].lost = True
- if self.fds[fd].reading:
- self.reactor.removeReader(self.fds[fd])
- if self.fds[fd].writing:
- self.reactor.removeWriter(self.fds[fd])
- del self.fds[fd]
-
- def start(self):
- old_current = IOLoop.current(instance=False)
- try:
- self._setup_logging()
- self.make_current()
- self.reactor.run()
- finally:
- if old_current is None:
- IOLoop.clear_current()
- else:
- old_current.make_current()
-
- def stop(self):
- self.reactor.crash()
-
- def add_timeout(self, deadline, callback, *args, **kwargs):
- # This method could be simplified (since tornado 4.0) by
- # overriding call_at instead of add_timeout, but we leave it
- # for now as a test of backwards-compatibility.
- if isinstance(deadline, numbers.Real):
- delay = max(deadline - self.time(), 0)
- elif isinstance(deadline, datetime.timedelta):
- delay = timedelta_to_seconds(deadline)
- else:
- raise TypeError("Unsupported deadline %r")
- return self.reactor.callLater(
- delay, self._run_callback,
- functools.partial(wrap(callback), *args, **kwargs))
-
- def remove_timeout(self, timeout):
- if timeout.active():
- timeout.cancel()
-
- def add_callback(self, callback, *args, **kwargs):
- self.reactor.callFromThread(
- self._run_callback,
- functools.partial(wrap(callback), *args, **kwargs))
-
- def add_callback_from_signal(self, callback, *args, **kwargs):
- self.add_callback(callback, *args, **kwargs)
+if typing.TYPE_CHECKING:
+ from typing import Generator, Any, List, Tuple # noqa: F401
class TwistedResolver(Resolver):
@@ -529,21 +52,30 @@ class TwistedResolver(Resolver):
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
+
+ .. deprecated:: 6.2
+ This class is deprecated and will be removed in Tornado 7.0. Use the default
+ thread-based resolver instead.
"""
- def initialize(self):
+
+ def initialize(self) -> None:
# partial copy of twisted.names.client.createResolver, which doesn't
# allow for a reactor to be passed in.
- self.reactor = tornado.platform.twisted.TornadoReactor()
+ self.reactor = twisted.internet.asyncioreactor.AsyncioSelectorReactor()
- host_resolver = twisted.names.hosts.Resolver('/etc/hosts')
+ host_resolver = twisted.names.hosts.Resolver("/etc/hosts")
cache_resolver = twisted.names.cache.CacheResolver(reactor=self.reactor)
- real_resolver = twisted.names.client.Resolver('/etc/resolv.conf',
- reactor=self.reactor)
+ real_resolver = twisted.names.client.Resolver(
+ "/etc/resolv.conf", reactor=self.reactor
+ )
self.resolver = twisted.names.resolve.ResolverChain(
- [host_resolver, cache_resolver, real_resolver])
+ [host_resolver, cache_resolver, real_resolver]
+ )
@gen.coroutine
- def resolve(self, host, port, family=0):
+ def resolve(
+ self, host: str, port: int, family: int = 0
+ ) -> "Generator[Any, Any, List[Tuple[int, Any]]]":
# getHostByName doesn't accept IP addresses, so if the input
# looks like an IP address just return it immediately.
if twisted.internet.abstract.isIPAddress(host):
@@ -554,7 +86,7 @@ def resolve(self, host, port, family=0):
resolved_family = socket.AF_INET6
else:
deferred = self.resolver.getHostByName(utf8(host))
- fut = Future()
+ fut = Future() # type: Future[Any]
deferred.addBoth(fut.set_result)
resolved = yield fut
if isinstance(resolved, failure.Failure):
@@ -569,25 +101,50 @@ def resolve(self, host, port, family=0):
else:
resolved_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != resolved_family:
- raise Exception('Requested socket family %d but got %d' %
- (family, resolved_family))
- result = [
- (resolved_family, (resolved, port)),
- ]
- raise gen.Return(result)
+ raise Exception(
+ "Requested socket family %d but got %d" % (family, resolved_family)
+ )
+ result = [(typing.cast(int, resolved_family), (resolved, port))]
+ return result
+
+
+def install() -> None:
+ """Install ``AsyncioSelectorReactor`` as the default Twisted reactor.
+
+ .. deprecated:: 5.1
+
+ This function is provided for backwards compatibility; code
+ that does not require compatibility with older versions of
+ Tornado should use
+ ``twisted.internet.asyncioreactor.install()`` directly.
+
+ .. versionchanged:: 6.0.3
+ In Tornado 5.x and before, this function installed a reactor
+ based on the Tornado ``IOLoop``. When that reactor
+ implementation was removed in Tornado 6.0.0, this function was
+ removed as well. It was restored in Tornado 6.0.3 using the
+ ``asyncio`` reactor instead.
+
+ """
+ from twisted.internet.asyncioreactor import install
+
+ install()
+
+
+if hasattr(gen.convert_yielded, "register"):
-if hasattr(gen.convert_yielded, 'register'):
@gen.convert_yielded.register(Deferred) # type: ignore
- def _(d):
- f = Future()
+ def _(d: Deferred) -> Future:
+ f = Future() # type: Future[Any]
- def errback(failure):
+ def errback(failure: failure.Failure) -> None:
try:
failure.raiseException()
# Should never happen, but just in case
raise Exception("errback called without error")
except:
future_set_exc_info(f, sys.exc_info())
+
d.addCallbacks(f.set_result, errback)
return f
diff --git a/tornado/platform/windows.py b/tornado/platform/windows.py
deleted file mode 100644
index 4127700659..0000000000
--- a/tornado/platform/windows.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# NOTE: win32 support is currently experimental, and not recommended
-# for production use.
-
-
-from __future__ import absolute_import, division, print_function
-import ctypes # type: ignore
-import ctypes.wintypes # type: ignore
-
-# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
-SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
-SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD) # noqa: E501
-SetHandleInformation.restype = ctypes.wintypes.BOOL
-
-HANDLE_FLAG_INHERIT = 0x00000001
-
-
-def set_close_exec(fd):
- success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
- if not success:
- raise ctypes.WinError()
diff --git a/tornado/process.py b/tornado/process.py
index 122fd7e14b..26428feb77 100644
--- a/tornado/process.py
+++ b/tornado/process.py
@@ -17,10 +17,8 @@
the server into multiple processes and managing subprocesses.
"""
-from __future__ import absolute_import, division, print_function
-
-import errno
import os
+import multiprocessing
import signal
import subprocess
import sys
@@ -28,35 +26,26 @@
from binascii import hexlify
-from tornado.concurrent import Future, future_set_result_unless_cancelled
+from tornado.concurrent import (
+ Future,
+ future_set_result_unless_cancelled,
+ future_set_exception_unless_cancelled,
+)
from tornado import ioloop
from tornado.iostream import PipeIOStream
from tornado.log import gen_log
-from tornado.platform.auto import set_close_exec
-from tornado import stack_context
-from tornado.util import errno_from_exception, PY3
-try:
- import multiprocessing
-except ImportError:
- # Multiprocessing is not available on Google App Engine.
- multiprocessing = None
+import typing
+from typing import Optional, Any, Callable
-if PY3:
- long = int
+if typing.TYPE_CHECKING:
+ from typing import List # noqa: F401
# Re-export this exception for convenience.
-try:
- CalledProcessError = subprocess.CalledProcessError
-except AttributeError:
- # The subprocess module exists in Google App Engine, but is empty.
- # This module isn't very useful in that case, but it should
- # at least be importable.
- if 'APPENGINE_RUNTIME' not in os.environ:
- raise
+CalledProcessError = subprocess.CalledProcessError
-def cpu_count():
+def cpu_count() -> int:
"""Returns the number of processors on this machine."""
if multiprocessing is None:
return 1
@@ -65,38 +54,34 @@ def cpu_count():
except NotImplementedError:
pass
try:
- return os.sysconf("SC_NPROCESSORS_CONF")
+ return os.sysconf("SC_NPROCESSORS_CONF") # type: ignore
except (AttributeError, ValueError):
pass
gen_log.error("Could not detect number of processors; assuming 1")
return 1
-def _reseed_random():
- if 'random' not in sys.modules:
+def _reseed_random() -> None:
+ if "random" not in sys.modules:
return
import random
+
# If os.urandom is available, this method does the same thing as
# random.seed (at least as of python 2.6). If os.urandom is not
# available, we mix in the pid in addition to a timestamp.
try:
- seed = long(hexlify(os.urandom(16)), 16)
+ seed = int(hexlify(os.urandom(16)), 16)
except NotImplementedError:
seed = int(time.time() * 1000) ^ os.getpid()
random.seed(seed)
-def _pipe_cloexec():
- r, w = os.pipe()
- set_close_exec(r)
- set_close_exec(w)
- return r, w
-
-
_task_id = None
-def fork_processes(num_processes, max_restarts=100):
+def fork_processes(
+ num_processes: Optional[int], max_restarts: Optional[int] = None
+) -> int:
"""Starts multiple worker processes.
If ``num_processes`` is None or <= 0, we detect the number of cores
@@ -117,10 +102,20 @@ def fork_processes(num_processes, max_restarts=100):
number between 0 and ``num_processes``. Processes that exit
abnormally (due to a signal or non-zero exit status) are restarted
with the same id (up to ``max_restarts`` times). In the parent
- process, ``fork_processes`` returns None if all child processes
- have exited normally, but will otherwise only exit by throwing an
- exception.
+ process, ``fork_processes`` calls ``sys.exit(0)`` after all child
+ processes have exited normally.
+
+ max_restarts defaults to 100.
+
+ Availability: Unix
"""
+ if sys.platform == "win32":
+ # The exact form of this condition matters to mypy; it understands
+ # if but not assert in this context.
+ raise Exception("fork not available on windows")
+ if max_restarts is None:
+ max_restarts = 100
+
global _task_id
assert _task_id is None
if num_processes is None or num_processes <= 0:
@@ -128,7 +123,7 @@ def fork_processes(num_processes, max_restarts=100):
gen_log.info("Starting %d processes", num_processes)
children = {}
- def start_child(i):
+ def start_child(i: int) -> Optional[int]:
pid = os.fork()
if pid == 0:
# child process
@@ -146,21 +141,24 @@ def start_child(i):
return id
num_restarts = 0
while children:
- try:
- pid, status = os.wait()
- except OSError as e:
- if errno_from_exception(e) == errno.EINTR:
- continue
- raise
+ pid, status = os.wait()
if pid not in children:
continue
id = children.pop(pid)
if os.WIFSIGNALED(status):
- gen_log.warning("child %d (pid %d) killed by signal %d, restarting",
- id, pid, os.WTERMSIG(status))
+ gen_log.warning(
+ "child %d (pid %d) killed by signal %d, restarting",
+ id,
+ pid,
+ os.WTERMSIG(status),
+ )
elif os.WEXITSTATUS(status) != 0:
- gen_log.warning("child %d (pid %d) exited with status %d, restarting",
- id, pid, os.WEXITSTATUS(status))
+ gen_log.warning(
+ "child %d (pid %d) exited with status %d, restarting",
+ id,
+ pid,
+ os.WEXITSTATUS(status),
+ )
else:
gen_log.info("child %d (pid %d) exited normally", id, pid)
continue
@@ -177,7 +175,7 @@ def start_child(i):
sys.exit(0)
-def task_id():
+def task_id() -> Optional[int]:
"""Returns the current task id, if any.
Returns None if this process was not created by `fork_processes`.
@@ -207,32 +205,34 @@ class Subprocess(object):
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
+
STREAM = object()
_initialized = False
_waiting = {} # type: ignore
+ _old_sigchld = None
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.io_loop = ioloop.IOLoop.current()
# All FDs we create should be closed on error; those in to_close
# should be closed in the parent process on success.
- pipe_fds = []
- to_close = []
- if kwargs.get('stdin') is Subprocess.STREAM:
- in_r, in_w = _pipe_cloexec()
- kwargs['stdin'] = in_r
+ pipe_fds = [] # type: List[int]
+ to_close = [] # type: List[int]
+ if kwargs.get("stdin") is Subprocess.STREAM:
+ in_r, in_w = os.pipe()
+ kwargs["stdin"] = in_r
pipe_fds.extend((in_r, in_w))
to_close.append(in_r)
self.stdin = PipeIOStream(in_w)
- if kwargs.get('stdout') is Subprocess.STREAM:
- out_r, out_w = _pipe_cloexec()
- kwargs['stdout'] = out_w
+ if kwargs.get("stdout") is Subprocess.STREAM:
+ out_r, out_w = os.pipe()
+ kwargs["stdout"] = out_w
pipe_fds.extend((out_r, out_w))
to_close.append(out_w)
self.stdout = PipeIOStream(out_r)
- if kwargs.get('stderr') is Subprocess.STREAM:
- err_r, err_w = _pipe_cloexec()
- kwargs['stderr'] = err_w
+ if kwargs.get("stderr") is Subprocess.STREAM:
+ err_r, err_w = os.pipe()
+ kwargs["stderr"] = err_w
pipe_fds.extend((err_r, err_w))
to_close.append(err_w)
self.stderr = PipeIOStream(err_r)
@@ -244,13 +244,14 @@ def __init__(self, *args, **kwargs):
raise
for fd in to_close:
os.close(fd)
- for attr in ['stdin', 'stdout', 'stderr', 'pid']:
+ self.pid = self.proc.pid
+ for attr in ["stdin", "stdout", "stderr"]:
if not hasattr(self, attr): # don't clobber streams set above
setattr(self, attr, getattr(self.proc, attr))
- self._exit_callback = None
- self.returncode = None
+ self._exit_callback = None # type: Optional[Callable[[int], None]]
+ self.returncode = None # type: Optional[int]
- def set_exit_callback(self, callback):
+ def set_exit_callback(self, callback: Callable[[int], None]) -> None:
"""Runs ``callback`` when this process exits.
The callback takes one argument, the return code of the process.
@@ -264,13 +265,15 @@ def set_exit_callback(self, callback):
In many cases a close callback on the stdout or stderr streams
can be used as an alternative to an exit callback if the
signal handler is causing a problem.
+
+ Availability: Unix
"""
- self._exit_callback = stack_context.wrap(callback)
+ self._exit_callback = callback
Subprocess.initialize()
Subprocess._waiting[self.pid] = self
Subprocess._try_cleanup_process(self.pid)
- def wait_for_exit(self, raise_error=True):
+ def wait_for_exit(self, raise_error: bool = True) -> "Future[int]":
"""Returns a `.Future` which resolves when the process exits.
Usage::
@@ -285,20 +288,25 @@ def wait_for_exit(self, raise_error=True):
to suppress this behavior and return the exit status without raising.
.. versionadded:: 4.2
+
+ Availability: Unix
"""
- future = Future()
+ future = Future() # type: Future[int]
- def callback(ret):
+ def callback(ret: int) -> None:
if ret != 0 and raise_error:
# Unfortunately we don't have the original args any more.
- future.set_exception(CalledProcessError(ret, None))
+ future_set_exception_unless_cancelled(
+ future, CalledProcessError(ret, "unknown")
+ )
else:
future_set_result_unless_cancelled(future, ret)
+
self.set_exit_callback(callback)
return future
@classmethod
- def initialize(cls):
+ def initialize(cls) -> None:
"""Initializes the ``SIGCHLD`` handler.
The signal handler is run on an `.IOLoop` to avoid locking issues.
@@ -309,17 +317,20 @@ def initialize(cls):
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been
removed.
+
+ Availability: Unix
"""
if cls._initialized:
return
io_loop = ioloop.IOLoop.current()
cls._old_sigchld = signal.signal(
signal.SIGCHLD,
- lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup))
+ lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup),
+ )
cls._initialized = True
@classmethod
- def uninitialize(cls):
+ def uninitialize(cls) -> None:
"""Removes the ``SIGCHLD`` handler."""
if not cls._initialized:
return
@@ -327,30 +338,31 @@ def uninitialize(cls):
cls._initialized = False
@classmethod
- def _cleanup(cls):
+ def _cleanup(cls) -> None:
for pid in list(cls._waiting.keys()): # make a copy
cls._try_cleanup_process(pid)
@classmethod
- def _try_cleanup_process(cls, pid):
+ def _try_cleanup_process(cls, pid: int) -> None:
try:
- ret_pid, status = os.waitpid(pid, os.WNOHANG)
- except OSError as e:
- if errno_from_exception(e) == errno.ECHILD:
- return
+ ret_pid, status = os.waitpid(pid, os.WNOHANG) # type: ignore
+ except ChildProcessError:
+ return
if ret_pid == 0:
return
assert ret_pid == pid
subproc = cls._waiting.pop(pid)
- subproc.io_loop.add_callback_from_signal(
- subproc._set_returncode, status)
+ subproc.io_loop.add_callback_from_signal(subproc._set_returncode, status)
- def _set_returncode(self, status):
- if os.WIFSIGNALED(status):
- self.returncode = -os.WTERMSIG(status)
+ def _set_returncode(self, status: int) -> None:
+ if sys.platform == "win32":
+ self.returncode = -1
else:
- assert os.WIFEXITED(status)
- self.returncode = os.WEXITSTATUS(status)
+ if os.WIFSIGNALED(status):
+ self.returncode = -os.WTERMSIG(status)
+ else:
+ assert os.WIFEXITED(status)
+ self.returncode = os.WEXITSTATUS(status)
# We've taken over wait() duty from the subprocess.Popen
# object. If we don't inform it of the process's return code,
# it will log a warning at destruction in python 3.6+.
diff --git a/tornado/test/__init__.py b/tornado/py.typed
similarity index 100%
rename from tornado/test/__init__.py
rename to tornado/py.typed
diff --git a/tornado/queues.py b/tornado/queues.py
index 23b8bb9caa..1e87f62e09 100644
--- a/tornado/queues.py
+++ b/tornado/queues.py
@@ -25,48 +25,60 @@
"""
-from __future__ import absolute_import, division, print_function
-
import collections
+import datetime
import heapq
from tornado import gen, ioloop
from tornado.concurrent import Future, future_set_result_unless_cancelled
from tornado.locks import Event
-__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
+from typing import Union, TypeVar, Generic, Awaitable, Optional
+import typing
+
+if typing.TYPE_CHECKING:
+ from typing import Deque, Tuple, Any # noqa: F401
+
+_T = TypeVar("_T")
+
+__all__ = ["Queue", "PriorityQueue", "LifoQueue", "QueueFull", "QueueEmpty"]
class QueueEmpty(Exception):
"""Raised by `.Queue.get_nowait` when the queue has no items."""
+
pass
class QueueFull(Exception):
"""Raised by `.Queue.put_nowait` when a queue is at its maximum size."""
+
pass
-def _set_timeout(future, timeout):
+def _set_timeout(
+ future: Future, timeout: Union[None, float, datetime.timedelta]
+) -> None:
if timeout:
- def on_timeout():
+
+ def on_timeout() -> None:
if not future.done():
future.set_exception(gen.TimeoutError())
+
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
- future.add_done_callback(
- lambda _: io_loop.remove_timeout(timeout_handle))
+ future.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle))
-class _QueueIterator(object):
- def __init__(self, q):
+class _QueueIterator(Generic[_T]):
+ def __init__(self, q: "Queue[_T]") -> None:
self.q = q
- def __anext__(self):
+ def __anext__(self) -> Awaitable[_T]:
return self.q.get()
-class Queue(object):
+class Queue(Generic[_T]):
"""Coordinate producer and consumer coroutines.
If maxsize is 0 (the default) the queue size is unbounded.
@@ -79,28 +91,24 @@ class Queue(object):
q = Queue(maxsize=2)
- @gen.coroutine
- def consumer():
- while True:
- item = yield q.get()
+ async def consumer():
+ async for item in q:
try:
print('Doing work on %s' % item)
- yield gen.sleep(0.01)
+ await gen.sleep(0.01)
finally:
q.task_done()
- @gen.coroutine
- def producer():
+ async def producer():
for item in range(5):
- yield q.put(item)
+ await q.put(item)
print('Put %s' % item)
- @gen.coroutine
- def main():
+ async def main():
# Start consumer without waiting (since it never finishes).
IOLoop.current().spawn_callback(consumer)
- yield producer() # Wait for producer to put all tasks.
- yield q.join() # Wait for consumer to finish all tasks.
+ await producer() # Wait for producer to put all tasks.
+ await q.join() # Wait for consumer to finish all tasks.
print('Done')
IOLoop.current().run_sync(main)
@@ -119,11 +127,14 @@ def main():
Doing work on 4
Done
- In Python 3.5, `Queue` implements the async iterator protocol, so
- ``consumer()`` could be rewritten as::
- async def consumer():
- async for item in q:
+ In versions of Python without native coroutines (before 3.5),
+ ``consumer()`` could be written as::
+
+ @gen.coroutine
+ def consumer():
+ while True:
+ item = yield q.get()
try:
print('Doing work on %s' % item)
yield gen.sleep(0.01)
@@ -134,7 +145,12 @@ async def consumer():
Added ``async for`` support in Python 3.5.
"""
- def __init__(self, maxsize=0):
+
+ # Exact type depends on subclass. Could be another generic
+ # parameter and use protocols to be more precise here.
+ _queue = None # type: Any
+
+ def __init__(self, maxsize: int = 0) -> None:
if maxsize is None:
raise TypeError("maxsize can't be None")
@@ -143,31 +159,33 @@ def __init__(self, maxsize=0):
self._maxsize = maxsize
self._init()
- self._getters = collections.deque([]) # Futures.
- self._putters = collections.deque([]) # Pairs of (item, Future).
+ self._getters = collections.deque([]) # type: Deque[Future[_T]]
+ self._putters = collections.deque([]) # type: Deque[Tuple[_T, Future[None]]]
self._unfinished_tasks = 0
self._finished = Event()
self._finished.set()
@property
- def maxsize(self):
+ def maxsize(self) -> int:
"""Number of items allowed in the queue."""
return self._maxsize
- def qsize(self):
+ def qsize(self) -> int:
"""Number of items in the queue."""
return len(self._queue)
- def empty(self):
+ def empty(self) -> bool:
return not self._queue
- def full(self):
+ def full(self) -> bool:
if self.maxsize == 0:
return False
else:
return self.qsize() >= self.maxsize
- def put(self, item, timeout=None):
+ def put(
+ self, item: _T, timeout: Optional[Union[float, datetime.timedelta]] = None
+ ) -> "Future[None]":
"""Put an item into the queue, perhaps waiting until there is room.
Returns a Future, which raises `tornado.util.TimeoutError` after a
@@ -178,7 +196,7 @@ def put(self, item, timeout=None):
`datetime.timedelta` object for a deadline relative to the
current time.
"""
- future = Future()
+ future = Future() # type: Future[None]
try:
self.put_nowait(item)
except QueueFull:
@@ -188,7 +206,7 @@ def put(self, item, timeout=None):
future.set_result(None)
return future
- def put_nowait(self, item):
+ def put_nowait(self, item: _T) -> None:
"""Put an item into the queue without blocking.
If no free slot is immediately available, raise `QueueFull`.
@@ -204,18 +222,30 @@ def put_nowait(self, item):
else:
self.__put_internal(item)
- def get(self, timeout=None):
+ def get(
+ self, timeout: Optional[Union[float, datetime.timedelta]] = None
+ ) -> Awaitable[_T]:
"""Remove and return an item from the queue.
- Returns a Future which resolves once an item is available, or raises
+ Returns an awaitable which resolves once an item is available, or raises
`tornado.util.TimeoutError` after a timeout.
``timeout`` may be a number denoting a time (on the same
scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
`datetime.timedelta` object for a deadline relative to the
current time.
+
+ .. note::
+
+ The ``timeout`` argument of this method differs from that
+ of the standard library's `queue.Queue.get`. That method
+ interprets numeric values as relative timeouts; this one
+ interprets them as absolute deadlines and requires
+ ``timedelta`` objects for relative timeouts (consistent
+ with other timeouts in Tornado).
+
"""
- future = Future()
+ future = Future() # type: Future[_T]
try:
future.set_result(self.get_nowait())
except QueueEmpty:
@@ -223,7 +253,7 @@ def get(self, timeout=None):
_set_timeout(future, timeout)
return future
- def get_nowait(self):
+ def get_nowait(self) -> _T:
"""Remove and return an item from the queue without blocking.
Return an item if one is immediately available, else raise
@@ -241,7 +271,7 @@ def get_nowait(self):
else:
raise QueueEmpty
- def task_done(self):
+ def task_done(self) -> None:
"""Indicate that a formerly enqueued task is complete.
Used by queue consumers. For each `.get` used to fetch a task, a
@@ -254,39 +284,42 @@ def task_done(self):
Raises `ValueError` if called more times than `.put`.
"""
if self._unfinished_tasks <= 0:
- raise ValueError('task_done() called too many times')
+ raise ValueError("task_done() called too many times")
self._unfinished_tasks -= 1
if self._unfinished_tasks == 0:
self._finished.set()
- def join(self, timeout=None):
+ def join(
+ self, timeout: Optional[Union[float, datetime.timedelta]] = None
+ ) -> Awaitable[None]:
"""Block until all items in the queue are processed.
- Returns a Future, which raises `tornado.util.TimeoutError` after a
+ Returns an awaitable, which raises `tornado.util.TimeoutError` after a
timeout.
"""
return self._finished.wait(timeout)
- def __aiter__(self):
+ def __aiter__(self) -> _QueueIterator[_T]:
return _QueueIterator(self)
# These three are overridable in subclasses.
- def _init(self):
+ def _init(self) -> None:
self._queue = collections.deque()
- def _get(self):
+ def _get(self) -> _T:
return self._queue.popleft()
- def _put(self, item):
+ def _put(self, item: _T) -> None:
self._queue.append(item)
+
# End of the overridable methods.
- def __put_internal(self, item):
+ def __put_internal(self, item: _T) -> None:
self._unfinished_tasks += 1
self._finished.clear()
self._put(item)
- def _consume_expired(self):
+ def _consume_expired(self) -> None:
# Remove timed-out waiters.
while self._putters and self._putters[0][1].done():
self._putters.popleft()
@@ -294,23 +327,22 @@ def _consume_expired(self):
while self._getters and self._getters[0].done():
self._getters.popleft()
- def __repr__(self):
- return '<%s at %s %s>' % (
- type(self).__name__, hex(id(self)), self._format())
+ def __repr__(self) -> str:
+ return "<%s at %s %s>" % (type(self).__name__, hex(id(self)), self._format())
- def __str__(self):
- return '<%s %s>' % (type(self).__name__, self._format())
+ def __str__(self) -> str:
+ return "<%s %s>" % (type(self).__name__, self._format())
- def _format(self):
- result = 'maxsize=%r' % (self.maxsize, )
- if getattr(self, '_queue', None):
- result += ' queue=%r' % self._queue
+ def _format(self) -> str:
+ result = "maxsize=%r" % (self.maxsize,)
+ if getattr(self, "_queue", None):
+ result += " queue=%r" % self._queue
if self._getters:
- result += ' getters[%s]' % len(self._getters)
+ result += " getters[%s]" % len(self._getters)
if self._putters:
- result += ' putters[%s]' % len(self._putters)
+ result += " putters[%s]" % len(self._putters)
if self._unfinished_tasks:
- result += ' tasks=%s' % self._unfinished_tasks
+ result += " tasks=%s" % self._unfinished_tasks
return result
@@ -338,13 +370,14 @@ class PriorityQueue(Queue):
(1, 'medium-priority item')
(10, 'low-priority item')
"""
- def _init(self):
+
+ def _init(self) -> None:
self._queue = []
- def _put(self, item):
+ def _put(self, item: _T) -> None:
heapq.heappush(self._queue, item)
- def _get(self):
+ def _get(self) -> _T:
return heapq.heappop(self._queue)
@@ -370,11 +403,12 @@ class LifoQueue(Queue):
2
3
"""
- def _init(self):
+
+ def _init(self) -> None:
self._queue = []
- def _put(self, item):
+ def _put(self, item: _T) -> None:
self._queue.append(item)
- def _get(self):
+ def _get(self) -> _T:
return self._queue.pop()
diff --git a/tornado/routing.py b/tornado/routing.py
index e56d1a75f9..a145d71916 100644
--- a/tornado/routing.py
+++ b/tornado/routing.py
@@ -142,7 +142,7 @@ def request_callable(request):
router = RuleRouter([
Rule(HostMatches("example.com"), RuleRouter([
- Rule(PathMatches("/app1/.*"), Application([(r"/app1/handler", Handler)]))),
+ Rule(PathMatches("/app1/.*"), Application([(r"/app1/handler", Handler)])),
]))
])
@@ -175,8 +175,6 @@ def request_callable(request):
"""
-from __future__ import absolute_import, division, print_function
-
import re
from functools import partial
@@ -186,17 +184,15 @@ def request_callable(request):
from tornado.log import app_log
from tornado.util import basestring_type, import_object, re_unescape, unicode_type
-try:
- import typing # noqa
-except ImportError:
- pass
+from typing import Any, Union, Optional, Awaitable, List, Dict, Pattern, Tuple, overload
class Router(httputil.HTTPServerConnectionDelegate):
"""Abstract router interface."""
- def find_handler(self, request, **kwargs):
- # type: (httputil.HTTPServerRequest, typing.Any)->httputil.HTTPMessageDelegate
+ def find_handler(
+ self, request: httputil.HTTPServerRequest, **kwargs: Any
+ ) -> Optional[httputil.HTTPMessageDelegate]:
"""Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate`
that can serve the request.
Routing implementations may pass additional kwargs to extend the routing logic.
@@ -208,7 +204,9 @@ def find_handler(self, request, **kwargs):
"""
raise NotImplementedError()
- def start_request(self, server_conn, request_conn):
+ def start_request(
+ self, server_conn: object, request_conn: httputil.HTTPConnection
+ ) -> httputil.HTTPMessageDelegate:
return _RoutingDelegate(self, server_conn, request_conn)
@@ -217,7 +215,7 @@ class ReversibleRouter(Router):
and support reversing them to original urls.
"""
- def reverse_url(self, name, *args):
+ def reverse_url(self, name: str, *args: Any) -> Optional[str]:
"""Returns url string for a given route name and arguments
or ``None`` if no match is found.
@@ -229,50 +227,80 @@ def reverse_url(self, name, *args):
class _RoutingDelegate(httputil.HTTPMessageDelegate):
- def __init__(self, router, server_conn, request_conn):
+ def __init__(
+ self, router: Router, server_conn: object, request_conn: httputil.HTTPConnection
+ ) -> None:
self.server_conn = server_conn
self.request_conn = request_conn
- self.delegate = None
+ self.delegate = None # type: Optional[httputil.HTTPMessageDelegate]
self.router = router # type: Router
- def headers_received(self, start_line, headers):
+ def headers_received(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> Optional[Awaitable[None]]:
+ assert isinstance(start_line, httputil.RequestStartLine)
request = httputil.HTTPServerRequest(
connection=self.request_conn,
server_connection=self.server_conn,
- start_line=start_line, headers=headers)
+ start_line=start_line,
+ headers=headers,
+ )
self.delegate = self.router.find_handler(request)
if self.delegate is None:
- app_log.debug("Delegate for %s %s request not found",
- start_line.method, start_line.path)
+ app_log.debug(
+ "Delegate for %s %s request not found",
+ start_line.method,
+ start_line.path,
+ )
self.delegate = _DefaultMessageDelegate(self.request_conn)
return self.delegate.headers_received(start_line, headers)
- def data_received(self, chunk):
+ def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
+ assert self.delegate is not None
return self.delegate.data_received(chunk)
- def finish(self):
+ def finish(self) -> None:
+ assert self.delegate is not None
self.delegate.finish()
- def on_connection_close(self):
+ def on_connection_close(self) -> None:
+ assert self.delegate is not None
self.delegate.on_connection_close()
class _DefaultMessageDelegate(httputil.HTTPMessageDelegate):
- def __init__(self, connection):
+ def __init__(self, connection: httputil.HTTPConnection) -> None:
self.connection = connection
- def finish(self):
+ def finish(self) -> None:
self.connection.write_headers(
- httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"), httputil.HTTPHeaders())
+ httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"),
+ httputil.HTTPHeaders(),
+ )
self.connection.finish()
+# _RuleList can either contain pre-constructed Rules or a sequence of
+# arguments to be passed to the Rule constructor.
+_RuleList = List[
+ Union[
+ "Rule",
+ List[Any], # Can't do detailed typechecking of lists.
+ Tuple[Union[str, "Matcher"], Any],
+ Tuple[Union[str, "Matcher"], Any, Dict[str, Any]],
+ Tuple[Union[str, "Matcher"], Any, Dict[str, Any], str],
+ ]
+]
+
+
class RuleRouter(Router):
"""Rule-based router implementation."""
- def __init__(self, rules=None):
+ def __init__(self, rules: Optional[_RuleList] = None) -> None:
"""Constructs a router from an ordered list of rules::
RuleRouter([
@@ -299,11 +327,11 @@ def __init__(self, rules=None):
:arg rules: a list of `Rule` instances or tuples of `Rule`
constructor arguments.
"""
- self.rules = [] # type: typing.List[Rule]
+ self.rules = [] # type: List[Rule]
if rules:
self.add_rules(rules)
- def add_rules(self, rules):
+ def add_rules(self, rules: _RuleList) -> None:
"""Appends new rules to the router.
:arg rules: a list of Rule instances (or tuples of arguments, which are
@@ -319,7 +347,7 @@ def add_rules(self, rules):
self.rules.append(self.process_rule(rule))
- def process_rule(self, rule):
+ def process_rule(self, rule: "Rule") -> "Rule":
"""Override this method for additional preprocessing of each rule.
:arg Rule rule: a rule to be processed.
@@ -327,22 +355,27 @@ def process_rule(self, rule):
"""
return rule
- def find_handler(self, request, **kwargs):
+ def find_handler(
+ self, request: httputil.HTTPServerRequest, **kwargs: Any
+ ) -> Optional[httputil.HTTPMessageDelegate]:
for rule in self.rules:
target_params = rule.matcher.match(request)
if target_params is not None:
if rule.target_kwargs:
- target_params['target_kwargs'] = rule.target_kwargs
+ target_params["target_kwargs"] = rule.target_kwargs
delegate = self.get_target_delegate(
- rule.target, request, **target_params)
+ rule.target, request, **target_params
+ )
if delegate is not None:
return delegate
return None
- def get_target_delegate(self, target, request, **target_params):
+ def get_target_delegate(
+ self, target: Any, request: httputil.HTTPServerRequest, **target_params: Any
+ ) -> Optional[httputil.HTTPMessageDelegate]:
"""Returns an instance of `~.httputil.HTTPMessageDelegate` for a
Rule's target. This method is called by `~.find_handler` and can be
extended to provide additional target types.
@@ -356,9 +389,11 @@ def get_target_delegate(self, target, request, **target_params):
return target.find_handler(request, **target_params)
elif isinstance(target, httputil.HTTPServerConnectionDelegate):
+ assert request.connection is not None
return target.start_request(request.server_connection, request.connection)
elif callable(target):
+ assert request.connection is not None
return _CallableAdapter(
partial(target, **target_params), request.connection
)
@@ -374,23 +409,23 @@ class ReversibleRuleRouter(ReversibleRouter, RuleRouter):
in a rule's matcher (see `Matcher.reverse`).
"""
- def __init__(self, rules=None):
- self.named_rules = {} # type: typing.Dict[str]
- super(ReversibleRuleRouter, self).__init__(rules)
+ def __init__(self, rules: Optional[_RuleList] = None) -> None:
+ self.named_rules = {} # type: Dict[str, Any]
+ super().__init__(rules)
- def process_rule(self, rule):
- rule = super(ReversibleRuleRouter, self).process_rule(rule)
+ def process_rule(self, rule: "Rule") -> "Rule":
+ rule = super().process_rule(rule)
if rule.name:
if rule.name in self.named_rules:
app_log.warning(
- "Multiple handlers named %s; replacing previous value",
- rule.name)
+ "Multiple handlers named %s; replacing previous value", rule.name
+ )
self.named_rules[rule.name] = rule
return rule
- def reverse_url(self, name, *args):
+ def reverse_url(self, name: str, *args: Any) -> Optional[str]:
if name in self.named_rules:
return self.named_rules[name].matcher.reverse(*args)
@@ -406,7 +441,13 @@ def reverse_url(self, name, *args):
class Rule(object):
"""A routing rule."""
- def __init__(self, matcher, target, target_kwargs=None, name=None):
+ def __init__(
+ self,
+ matcher: "Matcher",
+ target: Any,
+ target_kwargs: Optional[Dict[str, Any]] = None,
+ name: Optional[str] = None,
+ ) -> None:
"""Constructs a Rule instance.
:arg Matcher matcher: a `Matcher` instance used for determining
@@ -433,19 +474,23 @@ def __init__(self, matcher, target, target_kwargs=None, name=None):
self.target_kwargs = target_kwargs if target_kwargs else {}
self.name = name
- def reverse(self, *args):
+ def reverse(self, *args: Any) -> Optional[str]:
return self.matcher.reverse(*args)
- def __repr__(self):
- return '%s(%r, %s, kwargs=%r, name=%r)' % \
- (self.__class__.__name__, self.matcher,
- self.target, self.target_kwargs, self.name)
+ def __repr__(self) -> str:
+ return "%s(%r, %s, kwargs=%r, name=%r)" % (
+ self.__class__.__name__,
+ self.matcher,
+ self.target,
+ self.target_kwargs,
+ self.name,
+ )
class Matcher(object):
"""Represents a matcher for request features."""
- def match(self, request):
+ def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]:
"""Matches current instance against the request.
:arg httputil.HTTPServerRequest request: current HTTP request
@@ -457,7 +502,7 @@ def match(self, request):
``None`` must be returned to indicate that there is no match."""
raise NotImplementedError()
- def reverse(self, *args):
+ def reverse(self, *args: Any) -> Optional[str]:
"""Reconstructs full url from matcher instance and additional arguments."""
return None
@@ -465,14 +510,14 @@ def reverse(self, *args):
class AnyMatches(Matcher):
"""Matches any request."""
- def match(self, request):
+ def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]:
return {}
class HostMatches(Matcher):
"""Matches requests from hosts specified by ``host_pattern`` regex."""
- def __init__(self, host_pattern):
+ def __init__(self, host_pattern: Union[str, Pattern]) -> None:
if isinstance(host_pattern, basestring_type):
if not host_pattern.endswith("$"):
host_pattern += "$"
@@ -480,7 +525,7 @@ def __init__(self, host_pattern):
else:
self.host_pattern = host_pattern
- def match(self, request):
+ def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]:
if self.host_pattern.match(request.host_name):
return {}
@@ -492,11 +537,11 @@ class DefaultHostMatches(Matcher):
Always returns no match if ``X-Real-Ip`` header is present.
"""
- def __init__(self, application, host_pattern):
+ def __init__(self, application: Any, host_pattern: Pattern) -> None:
self.application = application
self.host_pattern = host_pattern
- def match(self, request):
+ def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]:
# Look for default host if not behind load balancer (for debugging)
if "X-Real-Ip" not in request.headers:
if self.host_pattern.match(self.application.default_host):
@@ -507,28 +552,30 @@ def match(self, request):
class PathMatches(Matcher):
"""Matches requests with paths specified by ``path_pattern`` regex."""
- def __init__(self, path_pattern):
+ def __init__(self, path_pattern: Union[str, Pattern]) -> None:
if isinstance(path_pattern, basestring_type):
- if not path_pattern.endswith('$'):
- path_pattern += '$'
+ if not path_pattern.endswith("$"):
+ path_pattern += "$"
self.regex = re.compile(path_pattern)
else:
self.regex = path_pattern
- assert len(self.regex.groupindex) in (0, self.regex.groups), \
- ("groups in url regexes must either be all named or all "
- "positional: %r" % self.regex.pattern)
+ assert len(self.regex.groupindex) in (0, self.regex.groups), (
+ "groups in url regexes must either be all named or all "
+ "positional: %r" % self.regex.pattern
+ )
self._path, self._group_count = self._find_groups()
- def match(self, request):
+ def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]:
match = self.regex.match(request.path)
if match is None:
return None
if not self.regex.groups:
return {}
- path_args, path_kwargs = [], {}
+ path_args = [] # type: List[bytes]
+ path_kwargs = {} # type: Dict[str, bytes]
# Pass matched groups to the handler. Since
# match.groups() includes both named and
@@ -536,18 +583,19 @@ def match(self, request):
# or groupdict but not both.
if self.regex.groupindex:
path_kwargs = dict(
- (str(k), _unquote_or_none(v))
- for (k, v) in match.groupdict().items())
+ (str(k), _unquote_or_none(v)) for (k, v) in match.groupdict().items()
+ )
else:
path_args = [_unquote_or_none(s) for s in match.groups()]
return dict(path_args=path_args, path_kwargs=path_kwargs)
- def reverse(self, *args):
+ def reverse(self, *args: Any) -> Optional[str]:
if self._path is None:
raise ValueError("Cannot reverse url regex " + self.regex.pattern)
- assert len(args) == self._group_count, "required number of arguments " \
- "not found"
+ assert len(args) == self._group_count, (
+ "required number of arguments " "not found"
+ )
if not len(args):
return self._path
converted_args = []
@@ -557,29 +605,35 @@ def reverse(self, *args):
converted_args.append(url_escape(utf8(a), plus=False))
return self._path % tuple(converted_args)
- def _find_groups(self):
+ def _find_groups(self) -> Tuple[Optional[str], Optional[int]]:
"""Returns a tuple (reverse string, group count) for a url.
For example: Given the url pattern /([0-9]{4})/([a-z-]+)/, this method
would return ('/%s/%s/', 2).
"""
pattern = self.regex.pattern
- if pattern.startswith('^'):
+ if pattern.startswith("^"):
pattern = pattern[1:]
- if pattern.endswith('$'):
+ if pattern.endswith("$"):
pattern = pattern[:-1]
- if self.regex.groups != pattern.count('('):
+ if self.regex.groups != pattern.count("("):
# The pattern is too complicated for our simplistic matching,
# so we can't support reversing it.
return None, None
pieces = []
- for fragment in pattern.split('('):
- if ')' in fragment:
- paren_loc = fragment.index(')')
+ for fragment in pattern.split("("):
+ if ")" in fragment:
+ paren_loc = fragment.index(")")
if paren_loc >= 0:
- pieces.append('%s' + fragment[paren_loc + 1:])
+ try:
+ unescaped_fragment = re_unescape(fragment[paren_loc + 1 :])
+ except ValueError:
+ # If we can't unescape part of it, we can't
+ # reverse this url.
+ return (None, None)
+ pieces.append("%s" + unescaped_fragment)
else:
try:
unescaped_fragment = re_unescape(fragment)
@@ -589,7 +643,7 @@ def _find_groups(self):
return (None, None)
pieces.append(unescaped_fragment)
- return ''.join(pieces), self.regex.groups
+ return "".join(pieces), self.regex.groups
class URLSpec(Rule):
@@ -599,7 +653,14 @@ class URLSpec(Rule):
`URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for
backwards compatibility.
"""
- def __init__(self, pattern, handler, kwargs=None, name=None):
+
+ def __init__(
+ self,
+ pattern: Union[str, Pattern],
+ handler: Any,
+ kwargs: Optional[Dict[str, Any]] = None,
+ name: Optional[str] = None,
+ ) -> None:
"""Parameters:
* ``pattern``: Regular expression to be matched. Any capturing
@@ -617,19 +678,34 @@ def __init__(self, pattern, handler, kwargs=None, name=None):
`~.web.Application.reverse_url`.
"""
- super(URLSpec, self).__init__(PathMatches(pattern), handler, kwargs, name)
+ matcher = PathMatches(pattern)
+ super().__init__(matcher, handler, kwargs, name)
- self.regex = self.matcher.regex
+ self.regex = matcher.regex
self.handler_class = self.target
self.kwargs = kwargs
- def __repr__(self):
- return '%s(%r, %s, kwargs=%r, name=%r)' % \
- (self.__class__.__name__, self.regex.pattern,
- self.handler_class, self.kwargs, self.name)
+ def __repr__(self) -> str:
+ return "%s(%r, %s, kwargs=%r, name=%r)" % (
+ self.__class__.__name__,
+ self.regex.pattern,
+ self.handler_class,
+ self.kwargs,
+ self.name,
+ )
+
+
+@overload
+def _unquote_or_none(s: str) -> bytes:
+ pass
+
+
+@overload # noqa: F811
+def _unquote_or_none(s: None) -> None:
+ pass
-def _unquote_or_none(s):
+def _unquote_or_none(s: Optional[str]) -> Optional[bytes]: # noqa: F811
"""None-safe wrapper around url_unescape to handle unmatched optional
groups correctly.
diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py
index 7696dd1849..1d9d14ba17 100644
--- a/tornado/simple_httpclient.py
+++ b/tornado/simple_httpclient.py
@@ -1,17 +1,25 @@
-from __future__ import absolute_import, division, print_function
-
-from tornado.escape import utf8, _unicode
-from tornado import gen
-from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
+from tornado.escape import _unicode
+from tornado import gen, version
+from tornado.httpclient import (
+ HTTPResponse,
+ HTTPError,
+ AsyncHTTPClient,
+ main,
+ _RequestProxy,
+ HTTPRequest,
+)
from tornado import httputil
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
from tornado.ioloop import IOLoop
-from tornado.iostream import StreamClosedError
-from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
+from tornado.iostream import StreamClosedError, IOStream
+from tornado.netutil import (
+ Resolver,
+ OverrideResolver,
+ _client_ssl_defaults,
+ is_valid_ip,
+)
from tornado.log import gen_log
-from tornado import stack_context
from tornado.tcpclient import TCPClient
-from tornado.util import PY3
import base64
import collections
@@ -19,20 +27,18 @@
import functools
import re
import socket
+import ssl
import sys
+import time
from io import BytesIO
+import urllib.parse
+from typing import Dict, Any, Callable, Optional, Type, Union
+from types import TracebackType
+import typing
-if PY3:
- import urllib.parse as urlparse
-else:
- import urlparse
-
-try:
- import ssl
-except ImportError:
- # ssl is not available on Google App Engine.
- ssl = None
+if typing.TYPE_CHECKING:
+ from typing import Deque, Tuple, List # noqa: F401
class HTTPTimeoutError(HTTPError):
@@ -43,11 +49,12 @@ class HTTPTimeoutError(HTTPError):
.. versionadded:: 5.1
"""
- def __init__(self, message):
- super(HTTPTimeoutError, self).__init__(599, message=message)
- def __str__(self):
- return self.message
+ def __init__(self, message: str) -> None:
+ super().__init__(599, message=message)
+
+ def __str__(self) -> str:
+ return self.message or "Timeout"
class HTTPStreamClosedError(HTTPError):
@@ -61,11 +68,12 @@ class HTTPStreamClosedError(HTTPError):
.. versionadded:: 5.1
"""
- def __init__(self, message):
- super(HTTPStreamClosedError, self).__init__(599, message=message)
- def __str__(self):
- return self.message
+ def __init__(self, message: str) -> None:
+ super().__init__(599, message=message)
+
+ def __str__(self) -> str:
+ return self.message or "Stream closed"
class SimpleAsyncHTTPClient(AsyncHTTPClient):
@@ -77,10 +85,17 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
are not reused, and callers cannot select the network interface to be
used.
"""
- def initialize(self, max_clients=10,
- hostname_mapping=None, max_buffer_size=104857600,
- resolver=None, defaults=None, max_header_size=None,
- max_body_size=None):
+
+ def initialize( # type: ignore
+ self,
+ max_clients: int = 10,
+ hostname_mapping: Optional[Dict[str, str]] = None,
+ max_buffer_size: int = 104857600,
+ resolver: Optional[Resolver] = None,
+ defaults: Optional[Dict[str, Any]] = None,
+ max_header_size: Optional[int] = None,
+ max_body_size: Optional[int] = None,
+ ) -> None:
"""Creates a AsyncHTTPClient.
Only a single AsyncHTTPClient instance exists per IOLoop
@@ -113,11 +128,17 @@ def initialize(self, max_clients=10,
.. versionchanged:: 4.2
Added the ``max_body_size`` argument.
"""
- super(SimpleAsyncHTTPClient, self).initialize(defaults=defaults)
+ super().initialize(defaults=defaults)
self.max_clients = max_clients
- self.queue = collections.deque()
- self.active = {}
- self.waiting = {}
+ self.queue = (
+ collections.deque()
+ ) # type: Deque[Tuple[object, HTTPRequest, Callable[[HTTPResponse], None]]]
+ self.active = (
+ {}
+ ) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None]]]
+ self.waiting = (
+ {}
+ ) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None], object]]
self.max_buffer_size = max_buffer_size
self.max_header_size = max_header_size
self.max_body_size = max_body_size
@@ -130,65 +151,86 @@ def initialize(self, max_clients=10,
self.resolver = Resolver()
self.own_resolver = True
if hostname_mapping is not None:
- self.resolver = OverrideResolver(resolver=self.resolver,
- mapping=hostname_mapping)
+ self.resolver = OverrideResolver(
+ resolver=self.resolver, mapping=hostname_mapping
+ )
self.tcp_client = TCPClient(resolver=self.resolver)
- def close(self):
- super(SimpleAsyncHTTPClient, self).close()
+ def close(self) -> None:
+ super().close()
if self.own_resolver:
self.resolver.close()
self.tcp_client.close()
- def fetch_impl(self, request, callback):
+ def fetch_impl(
+ self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]
+ ) -> None:
key = object()
self.queue.append((key, request, callback))
- if not len(self.active) < self.max_clients:
- timeout_handle = self.io_loop.add_timeout(
- self.io_loop.time() + min(request.connect_timeout,
- request.request_timeout),
- functools.partial(self._on_timeout, key, "in request queue"))
- else:
- timeout_handle = None
+ assert request.connect_timeout is not None
+ assert request.request_timeout is not None
+ timeout_handle = None
+ if len(self.active) >= self.max_clients:
+ timeout = (
+ min(request.connect_timeout, request.request_timeout)
+ or request.connect_timeout
+ or request.request_timeout
+ ) # min but skip zero
+ if timeout:
+ timeout_handle = self.io_loop.add_timeout(
+ self.io_loop.time() + timeout,
+ functools.partial(self._on_timeout, key, "in request queue"),
+ )
self.waiting[key] = (request, callback, timeout_handle)
self._process_queue()
if self.queue:
- gen_log.debug("max_clients limit reached, request queued. "
- "%d active, %d queued requests." % (
- len(self.active), len(self.queue)))
-
- def _process_queue(self):
- with stack_context.NullContext():
- while self.queue and len(self.active) < self.max_clients:
- key, request, callback = self.queue.popleft()
- if key not in self.waiting:
- continue
- self._remove_timeout(key)
- self.active[key] = (request, callback)
- release_callback = functools.partial(self._release_fetch, key)
- self._handle_request(request, release_callback, callback)
-
- def _connection_class(self):
+ gen_log.debug(
+ "max_clients limit reached, request queued. "
+ "%d active, %d queued requests." % (len(self.active), len(self.queue))
+ )
+
+ def _process_queue(self) -> None:
+ while self.queue and len(self.active) < self.max_clients:
+ key, request, callback = self.queue.popleft()
+ if key not in self.waiting:
+ continue
+ self._remove_timeout(key)
+ self.active[key] = (request, callback)
+ release_callback = functools.partial(self._release_fetch, key)
+ self._handle_request(request, release_callback, callback)
+
+ def _connection_class(self) -> type:
return _HTTPConnection
- def _handle_request(self, request, release_callback, final_callback):
+ def _handle_request(
+ self,
+ request: HTTPRequest,
+ release_callback: Callable[[], None],
+ final_callback: Callable[[HTTPResponse], None],
+ ) -> None:
self._connection_class()(
- self, request, release_callback,
- final_callback, self.max_buffer_size, self.tcp_client,
- self.max_header_size, self.max_body_size)
-
- def _release_fetch(self, key):
+ self,
+ request,
+ release_callback,
+ final_callback,
+ self.max_buffer_size,
+ self.tcp_client,
+ self.max_header_size,
+ self.max_body_size,
+ )
+
+ def _release_fetch(self, key: object) -> None:
del self.active[key]
self._process_queue()
- def _remove_timeout(self, key):
+ def _remove_timeout(self, key: object) -> None:
if key in self.waiting:
request, callback, timeout_handle = self.waiting[key]
if timeout_handle is not None:
self.io_loop.remove_timeout(timeout_handle)
del self.waiting[key]
- def _on_timeout(self, key, info=None):
+ def _on_timeout(self, key: object, info: Optional[str] = None) -> None:
"""Timeout callback of request.
Construct a timeout HTTPResponse when a timeout occurs.
@@ -201,20 +243,34 @@ def _on_timeout(self, key, info=None):
error_message = "Timeout {0}".format(info) if info else "Timeout"
timeout_response = HTTPResponse(
- request, 599, error=HTTPTimeoutError(error_message),
- request_time=self.io_loop.time() - request.start_time)
+ request,
+ 599,
+ error=HTTPTimeoutError(error_message),
+ request_time=self.io_loop.time() - request.start_time,
+ )
self.io_loop.add_callback(callback, timeout_response)
del self.waiting[key]
class _HTTPConnection(httputil.HTTPMessageDelegate):
- _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
-
- def __init__(self, client, request, release_callback,
- final_callback, max_buffer_size, tcp_client,
- max_header_size, max_body_size):
+ _SUPPORTED_METHODS = set(
+ ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
+ )
+
+ def __init__(
+ self,
+ client: Optional[SimpleAsyncHTTPClient],
+ request: HTTPRequest,
+ release_callback: Callable[[], None],
+ final_callback: Callable[[HTTPResponse], None],
+ max_buffer_size: int,
+ tcp_client: TCPClient,
+ max_header_size: int,
+ max_body_size: int,
+ ) -> None:
self.io_loop = IOLoop.current()
self.start_time = self.io_loop.time()
+ self.start_wall_time = time.time()
self.client = client
self.request = request
self.release_callback = release_callback
@@ -223,18 +279,22 @@ def __init__(self, client, request, release_callback,
self.tcp_client = tcp_client
self.max_header_size = max_header_size
self.max_body_size = max_body_size
- self.code = None
- self.headers = None
- self.chunks = []
+ self.code = None # type: Optional[int]
+ self.headers = None # type: Optional[httputil.HTTPHeaders]
+ self.chunks = [] # type: List[bytes]
self._decompressor = None
# Timeout handle returned by IOLoop.add_timeout
- self._timeout = None
+ self._timeout = None # type: object
self._sockaddr = None
- with stack_context.ExceptionStackContext(self._handle_exception):
- self.parsed = urlparse.urlsplit(_unicode(self.request.url))
+ IOLoop.current().add_future(
+ gen.convert_yielded(self.run()), lambda f: f.result()
+ )
+
+ async def run(self) -> None:
+ try:
+ self.parsed = urllib.parse.urlsplit(_unicode(self.request.url))
if self.parsed.scheme not in ("http", "https"):
- raise ValueError("Unsupported url scheme: %s" %
- self.request.url)
+ raise ValueError("Unsupported url scheme: %s" % self.request.url)
# urlsplit results have hostname and port results, but they
# didn't support ipv6 literals until python 2.7.
netloc = self.parsed.netloc
@@ -243,55 +303,186 @@ def __init__(self, client, request, release_callback,
host, port = httputil.split_host_and_port(netloc)
if port is None:
port = 443 if self.parsed.scheme == "https" else 80
- if re.match(r'^\[.*\]$', host):
+ if re.match(r"^\[.*\]$", host):
# raw ipv6 addresses in urls are enclosed in brackets
host = host[1:-1]
self.parsed_hostname = host # save final host for _on_connect
- if request.allow_ipv6 is False:
+ if self.request.allow_ipv6 is False:
af = socket.AF_INET
else:
af = socket.AF_UNSPEC
ssl_options = self._get_ssl_options(self.parsed.scheme)
- timeout = min(self.request.connect_timeout, self.request.request_timeout)
+ source_ip = None
+ if self.request.network_interface:
+ if is_valid_ip(self.request.network_interface):
+ source_ip = self.request.network_interface
+ else:
+ raise ValueError(
+ "Unrecognized IPv4 or IPv6 address for network_interface, got %r"
+ % (self.request.network_interface,)
+ )
+
+ if self.request.connect_timeout and self.request.request_timeout:
+ timeout = min(
+ self.request.connect_timeout, self.request.request_timeout
+ )
+ elif self.request.connect_timeout:
+ timeout = self.request.connect_timeout
+ elif self.request.request_timeout:
+ timeout = self.request.request_timeout
+ else:
+ timeout = 0
if timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + timeout,
- stack_context.wrap(functools.partial(self._on_timeout, "while connecting")))
- fut = self.tcp_client.connect(host, port, af=af,
- ssl_options=ssl_options,
- max_buffer_size=self.max_buffer_size)
- fut.add_done_callback(stack_context.wrap(self._on_connect))
-
- def _get_ssl_options(self, scheme):
+ functools.partial(self._on_timeout, "while connecting"),
+ )
+ stream = await self.tcp_client.connect(
+ host,
+ port,
+ af=af,
+ ssl_options=ssl_options,
+ max_buffer_size=self.max_buffer_size,
+ source_ip=source_ip,
+ )
+
+ if self.final_callback is None:
+ # final_callback is cleared if we've hit our timeout.
+ stream.close()
+ return
+ self.stream = stream
+ self.stream.set_close_callback(self.on_connection_close)
+ self._remove_timeout()
+ if self.final_callback is None:
+ return
+ if self.request.request_timeout:
+ self._timeout = self.io_loop.add_timeout(
+ self.start_time + self.request.request_timeout,
+ functools.partial(self._on_timeout, "during request"),
+ )
+ if (
+ self.request.method not in self._SUPPORTED_METHODS
+ and not self.request.allow_nonstandard_methods
+ ):
+ raise KeyError("unknown method %s" % self.request.method)
+ for key in (
+ "proxy_host",
+ "proxy_port",
+ "proxy_username",
+ "proxy_password",
+ "proxy_auth_mode",
+ ):
+ if getattr(self.request, key, None):
+ raise NotImplementedError("%s not supported" % key)
+ if "Connection" not in self.request.headers:
+ self.request.headers["Connection"] = "close"
+ if "Host" not in self.request.headers:
+ if "@" in self.parsed.netloc:
+ self.request.headers["Host"] = self.parsed.netloc.rpartition("@")[
+ -1
+ ]
+ else:
+ self.request.headers["Host"] = self.parsed.netloc
+ username, password = None, None
+ if self.parsed.username is not None:
+ username, password = self.parsed.username, self.parsed.password
+ elif self.request.auth_username is not None:
+ username = self.request.auth_username
+ password = self.request.auth_password or ""
+ if username is not None:
+ assert password is not None
+ if self.request.auth_mode not in (None, "basic"):
+ raise ValueError("unsupported auth_mode %s", self.request.auth_mode)
+ self.request.headers["Authorization"] = "Basic " + _unicode(
+ base64.b64encode(
+ httputil.encode_username_password(username, password)
+ )
+ )
+ if self.request.user_agent:
+ self.request.headers["User-Agent"] = self.request.user_agent
+ elif self.request.headers.get("User-Agent") is None:
+ self.request.headers["User-Agent"] = "Tornado/{}".format(version)
+ if not self.request.allow_nonstandard_methods:
+ # Some HTTP methods nearly always have bodies while others
+ # almost never do. Fail in this case unless the user has
+ # opted out of sanity checks with allow_nonstandard_methods.
+ body_expected = self.request.method in ("POST", "PATCH", "PUT")
+ body_present = (
+ self.request.body is not None
+ or self.request.body_producer is not None
+ )
+ if (body_expected and not body_present) or (
+ body_present and not body_expected
+ ):
+ raise ValueError(
+ "Body must %sbe None for method %s (unless "
+ "allow_nonstandard_methods is true)"
+ % ("not " if body_expected else "", self.request.method)
+ )
+ if self.request.expect_100_continue:
+ self.request.headers["Expect"] = "100-continue"
+ if self.request.body is not None:
+ # When body_producer is used the caller is responsible for
+ # setting Content-Length (or else chunked encoding will be used).
+ self.request.headers["Content-Length"] = str(len(self.request.body))
+ if (
+ self.request.method == "POST"
+ and "Content-Type" not in self.request.headers
+ ):
+ self.request.headers[
+ "Content-Type"
+ ] = "application/x-www-form-urlencoded"
+ if self.request.decompress_response:
+ self.request.headers["Accept-Encoding"] = "gzip"
+ req_path = (self.parsed.path or "/") + (
+ ("?" + self.parsed.query) if self.parsed.query else ""
+ )
+ self.connection = self._create_connection(stream)
+ start_line = httputil.RequestStartLine(self.request.method, req_path, "")
+ self.connection.write_headers(start_line, self.request.headers)
+ if self.request.expect_100_continue:
+ await self.connection.read_response(self)
+ else:
+ await self._write_body(True)
+ except Exception:
+ if not self._handle_exception(*sys.exc_info()):
+ raise
+
+ def _get_ssl_options(
+ self, scheme: str
+ ) -> Union[None, Dict[str, Any], ssl.SSLContext]:
if scheme == "https":
if self.request.ssl_options is not None:
return self.request.ssl_options
# If we are using the defaults, don't construct a
# new SSLContext.
- if (self.request.validate_cert and
- self.request.ca_certs is None and
- self.request.client_cert is None and
- self.request.client_key is None):
+ if (
+ self.request.validate_cert
+ and self.request.ca_certs is None
+ and self.request.client_cert is None
+ and self.request.client_key is None
+ ):
return _client_ssl_defaults
ssl_ctx = ssl.create_default_context(
- ssl.Purpose.SERVER_AUTH,
- cafile=self.request.ca_certs)
+ ssl.Purpose.SERVER_AUTH, cafile=self.request.ca_certs
+ )
if not self.request.validate_cert:
ssl_ctx.check_hostname = False
ssl_ctx.verify_mode = ssl.CERT_NONE
if self.request.client_cert is not None:
- ssl_ctx.load_cert_chain(self.request.client_cert,
- self.request.client_key)
- if hasattr(ssl, 'OP_NO_COMPRESSION'):
+ ssl_ctx.load_cert_chain(
+ self.request.client_cert, self.request.client_key
+ )
+ if hasattr(ssl, "OP_NO_COMPRESSION"):
# See netutil.ssl_options_to_context
ssl_ctx.options |= ssl.OP_NO_COMPRESSION
return ssl_ctx
return None
- def _on_timeout(self, info=None):
+ def _on_timeout(self, info: Optional[str] = None) -> None:
"""Timeout callback of _HTTPConnection instance.
Raise a `HTTPTimeoutError` when a timeout occurs.
@@ -301,147 +492,64 @@ def _on_timeout(self, info=None):
self._timeout = None
error_message = "Timeout {0}".format(info) if info else "Timeout"
if self.final_callback is not None:
- raise HTTPTimeoutError(error_message)
+ self._handle_exception(
+ HTTPTimeoutError, HTTPTimeoutError(error_message), None
+ )
- def _remove_timeout(self):
+ def _remove_timeout(self) -> None:
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
- def _on_connect(self, stream_fut):
- stream = stream_fut.result()
- if self.final_callback is None:
- # final_callback is cleared if we've hit our timeout.
- stream.close()
- return
- self.stream = stream
- self.stream.set_close_callback(self.on_connection_close)
- self._remove_timeout()
- if self.final_callback is None:
- return
- if self.request.request_timeout:
- self._timeout = self.io_loop.add_timeout(
- self.start_time + self.request.request_timeout,
- stack_context.wrap(functools.partial(self._on_timeout, "during request")))
- if (self.request.method not in self._SUPPORTED_METHODS and
- not self.request.allow_nonstandard_methods):
- raise KeyError("unknown method %s" % self.request.method)
- for key in ('network_interface',
- 'proxy_host', 'proxy_port',
- 'proxy_username', 'proxy_password',
- 'proxy_auth_mode'):
- if getattr(self.request, key, None):
- raise NotImplementedError('%s not supported' % key)
- if "Connection" not in self.request.headers:
- self.request.headers["Connection"] = "close"
- if "Host" not in self.request.headers:
- if '@' in self.parsed.netloc:
- self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1]
- else:
- self.request.headers["Host"] = self.parsed.netloc
- username, password = None, None
- if self.parsed.username is not None:
- username, password = self.parsed.username, self.parsed.password
- elif self.request.auth_username is not None:
- username = self.request.auth_username
- password = self.request.auth_password or ''
- if username is not None:
- if self.request.auth_mode not in (None, "basic"):
- raise ValueError("unsupported auth_mode %s",
- self.request.auth_mode)
- auth = utf8(username) + b":" + utf8(password)
- self.request.headers["Authorization"] = (b"Basic " +
- base64.b64encode(auth))
- if self.request.user_agent:
- self.request.headers["User-Agent"] = self.request.user_agent
- if not self.request.allow_nonstandard_methods:
- # Some HTTP methods nearly always have bodies while others
- # almost never do. Fail in this case unless the user has
- # opted out of sanity checks with allow_nonstandard_methods.
- body_expected = self.request.method in ("POST", "PATCH", "PUT")
- body_present = (self.request.body is not None or
- self.request.body_producer is not None)
- if ((body_expected and not body_present) or
- (body_present and not body_expected)):
- raise ValueError(
- 'Body must %sbe None for method %s (unless '
- 'allow_nonstandard_methods is true)' %
- ('not ' if body_expected else '', self.request.method))
- if self.request.expect_100_continue:
- self.request.headers["Expect"] = "100-continue"
- if self.request.body is not None:
- # When body_producer is used the caller is responsible for
- # setting Content-Length (or else chunked encoding will be used).
- self.request.headers["Content-Length"] = str(len(
- self.request.body))
- if (self.request.method == "POST" and
- "Content-Type" not in self.request.headers):
- self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
- if self.request.decompress_response:
- self.request.headers["Accept-Encoding"] = "gzip"
- req_path = ((self.parsed.path or '/') +
- (('?' + self.parsed.query) if self.parsed.query else ''))
- self.connection = self._create_connection(stream)
- start_line = httputil.RequestStartLine(self.request.method,
- req_path, '')
- self.connection.write_headers(start_line, self.request.headers)
- if self.request.expect_100_continue:
- self._read_response()
- else:
- self._write_body(True)
-
- def _create_connection(self, stream):
+ def _create_connection(self, stream: IOStream) -> HTTP1Connection:
stream.set_nodelay(True)
connection = HTTP1Connection(
- stream, True,
+ stream,
+ True,
HTTP1ConnectionParameters(
no_keep_alive=True,
max_header_size=self.max_header_size,
max_body_size=self.max_body_size,
- decompress=self.request.decompress_response),
- self._sockaddr)
+ decompress=bool(self.request.decompress_response),
+ ),
+ self._sockaddr,
+ )
return connection
- def _write_body(self, start_read):
+ async def _write_body(self, start_read: bool) -> None:
if self.request.body is not None:
self.connection.write(self.request.body)
elif self.request.body_producer is not None:
fut = self.request.body_producer(self.connection.write)
if fut is not None:
- fut = gen.convert_yielded(fut)
-
- def on_body_written(fut):
- fut.result()
- self.connection.finish()
- if start_read:
- self._read_response()
- self.io_loop.add_future(fut, on_body_written)
- return
+ await fut
self.connection.finish()
if start_read:
- self._read_response()
-
- def _read_response(self):
- # Ensure that any exception raised in read_response ends up in our
- # stack context.
- self.io_loop.add_future(
- self.connection.read_response(self),
- lambda f: f.result())
+ try:
+ await self.connection.read_response(self)
+ except StreamClosedError:
+ if not self._handle_exception(*sys.exc_info()):
+ raise
- def _release(self):
+ def _release(self) -> None:
if self.release_callback is not None:
release_callback = self.release_callback
- self.release_callback = None
+ self.release_callback = None # type: ignore
release_callback()
- def _run_callback(self, response):
+ def _run_callback(self, response: HTTPResponse) -> None:
self._release()
if self.final_callback is not None:
final_callback = self.final_callback
- self.final_callback = None
+ self.final_callback = None # type: ignore
self.io_loop.add_callback(final_callback, response)
- def _handle_exception(self, typ, value, tb):
+ def _handle_exception(
+ self,
+ typ: "Optional[Type[BaseException]]",
+ value: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> bool:
if self.final_callback:
self._remove_timeout()
if isinstance(value, StreamClosedError):
@@ -449,9 +557,15 @@ def _handle_exception(self, typ, value, tb):
value = HTTPStreamClosedError("Stream closed")
else:
value = value.real_error
- self._run_callback(HTTPResponse(self.request, 599, error=value,
- request_time=self.io_loop.time() - self.start_time,
- ))
+ self._run_callback(
+ HTTPResponse(
+ self.request,
+ 599,
+ error=value,
+ request_time=self.io_loop.time() - self.start_time,
+ start_time=self.start_wall_time,
+ )
+ )
if hasattr(self, "stream"):
# TODO: this may cause a StreamClosedError to be raised
@@ -466,7 +580,7 @@ def _handle_exception(self, typ, value, tb):
# pass it along, unless it's just the stream being closed.
return isinstance(value, StreamClosedError)
- def on_connection_close(self):
+ def on_connection_close(self) -> None:
if self.final_callback is not None:
message = "Connection closed"
if self.stream.error:
@@ -476,9 +590,14 @@ def on_connection_close(self):
except HTTPStreamClosedError:
self._handle_exception(*sys.exc_info())
- def headers_received(self, first_line, headers):
+ async def headers_received(
+ self,
+ first_line: Union[httputil.ResponseStartLine, httputil.RequestStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> None:
+ assert isinstance(first_line, httputil.ResponseStartLine)
if self.request.expect_100_continue and first_line.code == 100:
- self._write_body(False)
+ await self._write_body(False)
return
self.code = first_line.code
self.reason = first_line.reason
@@ -489,48 +608,66 @@ def headers_received(self, first_line, headers):
if self.request.header_callback is not None:
# Reassemble the start line.
- self.request.header_callback('%s %s %s\r\n' % first_line)
+ self.request.header_callback("%s %s %s\r\n" % first_line)
for k, v in self.headers.get_all():
self.request.header_callback("%s: %s\r\n" % (k, v))
- self.request.header_callback('\r\n')
-
- def _should_follow_redirect(self):
- return (self.request.follow_redirects and
- self.request.max_redirects > 0 and
- self.code in (301, 302, 303, 307, 308))
-
- def finish(self):
- data = b''.join(self.chunks)
+ self.request.header_callback("\r\n")
+
+ def _should_follow_redirect(self) -> bool:
+ if self.request.follow_redirects:
+ assert self.request.max_redirects is not None
+ return (
+ self.code in (301, 302, 303, 307, 308)
+ and self.request.max_redirects > 0
+ and self.headers is not None
+ and self.headers.get("Location") is not None
+ )
+ return False
+
+ def finish(self) -> None:
+ assert self.code is not None
+ data = b"".join(self.chunks)
self._remove_timeout()
- original_request = getattr(self.request, "original_request",
- self.request)
+ original_request = getattr(self.request, "original_request", self.request)
if self._should_follow_redirect():
assert isinstance(self.request, _RequestProxy)
+ assert self.headers is not None
new_request = copy.copy(self.request.request)
- new_request.url = urlparse.urljoin(self.request.url,
- self.headers["Location"])
+ new_request.url = urllib.parse.urljoin(
+ self.request.url, self.headers["Location"]
+ )
+ assert self.request.max_redirects is not None
new_request.max_redirects = self.request.max_redirects - 1
del new_request.headers["Host"]
- # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
- # Client SHOULD make a GET request after a 303.
- # According to the spec, 302 should be followed by the same
- # method as the original request, but in practice browsers
- # treat 302 the same as 303, and many servers use 302 for
- # compatibility with pre-HTTP/1.1 user agents which don't
- # understand the 303 status.
- if self.code in (302, 303):
+ # https://tools.ietf.org/html/rfc7231#section-6.4
+ #
+ # The original HTTP spec said that after a 301 or 302
+ # redirect, the request method should be preserved.
+ # However, browsers implemented this by changing the
+ # method to GET, and the behavior stuck. 303 redirects
+ # always specified this POST-to-GET behavior, arguably
+ # for *all* methods, but libcurl < 7.70 only does this
+ # for POST, while libcurl >= 7.70 does it for other methods.
+ if (self.code == 303 and self.request.method != "HEAD") or (
+ self.code in (301, 302) and self.request.method == "POST"
+ ):
new_request.method = "GET"
- new_request.body = None
- for h in ["Content-Length", "Content-Type",
- "Content-Encoding", "Transfer-Encoding"]:
+ new_request.body = None # type: ignore
+ for h in [
+ "Content-Length",
+ "Content-Type",
+ "Content-Encoding",
+ "Transfer-Encoding",
+ ]:
try:
del self.request.headers[h]
except KeyError:
pass
- new_request.original_request = original_request
+ new_request.original_request = original_request # type: ignore
final_callback = self.final_callback
- self.final_callback = None
+ self.final_callback = None # type: ignore
self._release()
+ assert self.client is not None
fut = self.client.fetch(new_request, raise_error=False)
fut.add_done_callback(lambda f: final_callback(f.result()))
self._on_end_request()
@@ -539,19 +676,23 @@ def finish(self):
buffer = BytesIO()
else:
buffer = BytesIO(data) # TODO: don't require one big string?
- response = HTTPResponse(original_request,
- self.code, reason=getattr(self, 'reason', None),
- headers=self.headers,
- request_time=self.io_loop.time() - self.start_time,
- buffer=buffer,
- effective_url=self.request.url)
+ response = HTTPResponse(
+ original_request,
+ self.code,
+ reason=getattr(self, "reason", None),
+ headers=self.headers,
+ request_time=self.io_loop.time() - self.start_time,
+ start_time=self.start_wall_time,
+ buffer=buffer,
+ effective_url=self.request.url,
+ )
self._run_callback(response)
self._on_end_request()
- def _on_end_request(self):
+ def _on_end_request(self) -> None:
self.stream.close()
- def data_received(self, chunk):
+ def data_received(self, chunk: bytes) -> None:
if self._should_follow_redirect():
# We're going to follow a redirect so just discard the body.
return
diff --git a/tornado/speedups.c b/tornado/speedups.c
index b714268ab4..525d66034c 100644
--- a/tornado/speedups.c
+++ b/tornado/speedups.c
@@ -56,7 +56,6 @@ static PyMethodDef methods[] = {
{NULL, NULL, 0, NULL}
};
-#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef speedupsmodule = {
PyModuleDef_HEAD_INIT,
"speedups",
@@ -69,9 +68,3 @@ PyMODINIT_FUNC
PyInit_speedups(void) {
return PyModule_Create(&speedupsmodule);
}
-#else // Python 2.x
-PyMODINIT_FUNC
-initspeedups(void) {
- Py_InitModule("tornado.speedups", methods);
-}
-#endif
diff --git a/tornado/stack_context.py b/tornado/stack_context.py
deleted file mode 100644
index 2f26f3845f..0000000000
--- a/tornado/stack_context.py
+++ /dev/null
@@ -1,389 +0,0 @@
-#
-# Copyright 2010 Facebook
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""`StackContext` allows applications to maintain threadlocal-like state
-that follows execution as it moves to other execution contexts.
-
-The motivating examples are to eliminate the need for explicit
-``async_callback`` wrappers (as in `tornado.web.RequestHandler`), and to
-allow some additional context to be kept for logging.
-
-This is slightly magic, but it's an extension of the idea that an
-exception handler is a kind of stack-local state and when that stack
-is suspended and resumed in a new context that state needs to be
-preserved. `StackContext` shifts the burden of restoring that state
-from each call site (e.g. wrapping each `.AsyncHTTPClient` callback
-in ``async_callback``) to the mechanisms that transfer control from
-one context to another (e.g. `.AsyncHTTPClient` itself, `.IOLoop`,
-thread pools, etc).
-
-Example usage::
-
- @contextlib.contextmanager
- def die_on_error():
- try:
- yield
- except Exception:
- logging.error("exception in asynchronous operation",exc_info=True)
- sys.exit(1)
-
- with StackContext(die_on_error):
- # Any exception thrown here *or in callback and its descendants*
- # will cause the process to exit instead of spinning endlessly
- # in the ioloop.
- http_client.fetch(url, callback)
- ioloop.start()
-
-Most applications shouldn't have to work with `StackContext` directly.
-Here are a few rules of thumb for when it's necessary:
-
-* If you're writing an asynchronous library that doesn't rely on a
- stack_context-aware library like `tornado.ioloop` or `tornado.iostream`
- (for example, if you're writing a thread pool), use
- `.stack_context.wrap()` before any asynchronous operations to capture the
- stack context from where the operation was started.
-
-* If you're writing an asynchronous library that has some shared
- resources (such as a connection pool), create those shared resources
- within a ``with stack_context.NullContext():`` block. This will prevent
- ``StackContexts`` from leaking from one request to another.
-
-* If you want to write something like an exception handler that will
- persist across asynchronous calls, create a new `StackContext` (or
- `ExceptionStackContext`), and make your asynchronous calls in a ``with``
- block that references your `StackContext`.
-"""
-
-from __future__ import absolute_import, division, print_function
-
-import sys
-import threading
-
-from tornado.util import raise_exc_info
-
-
-class StackContextInconsistentError(Exception):
- pass
-
-
-class _State(threading.local):
- def __init__(self):
- self.contexts = (tuple(), None)
-
-
-_state = _State()
-
-
-class StackContext(object):
- """Establishes the given context as a StackContext that will be transferred.
-
- Note that the parameter is a callable that returns a context
- manager, not the context itself. That is, where for a
- non-transferable context manager you would say::
-
- with my_context():
-
- StackContext takes the function itself rather than its result::
-
- with StackContext(my_context):
-
- The result of ``with StackContext() as cb:`` is a deactivation
- callback. Run this callback when the StackContext is no longer
- needed to ensure that it is not propagated any further (note that
- deactivating a context does not affect any instances of that
- context that are currently pending). This is an advanced feature
- and not necessary in most applications.
- """
- def __init__(self, context_factory):
- self.context_factory = context_factory
- self.contexts = []
- self.active = True
-
- def _deactivate(self):
- self.active = False
-
- # StackContext protocol
- def enter(self):
- context = self.context_factory()
- self.contexts.append(context)
- context.__enter__()
-
- def exit(self, type, value, traceback):
- context = self.contexts.pop()
- context.__exit__(type, value, traceback)
-
- # Note that some of this code is duplicated in ExceptionStackContext
- # below. ExceptionStackContext is more common and doesn't need
- # the full generality of this class.
- def __enter__(self):
- self.old_contexts = _state.contexts
- self.new_contexts = (self.old_contexts[0] + (self,), self)
- _state.contexts = self.new_contexts
-
- try:
- self.enter()
- except:
- _state.contexts = self.old_contexts
- raise
-
- return self._deactivate
-
- def __exit__(self, type, value, traceback):
- try:
- self.exit(type, value, traceback)
- finally:
- final_contexts = _state.contexts
- _state.contexts = self.old_contexts
-
- # Generator coroutines and with-statements with non-local
- # effects interact badly. Check here for signs of
- # the stack getting out of sync.
- # Note that this check comes after restoring _state.context
- # so that if it fails things are left in a (relatively)
- # consistent state.
- if final_contexts is not self.new_contexts:
- raise StackContextInconsistentError(
- 'stack_context inconsistency (may be caused by yield '
- 'within a "with StackContext" block)')
-
- # Break up a reference to itself to allow for faster GC on CPython.
- self.new_contexts = None
-
-
-class ExceptionStackContext(object):
- """Specialization of StackContext for exception handling.
-
- The supplied ``exception_handler`` function will be called in the
- event of an uncaught exception in this context. The semantics are
- similar to a try/finally clause, and intended use cases are to log
- an error, close a socket, or similar cleanup actions. The
- ``exc_info`` triple ``(type, value, traceback)`` will be passed to the
- exception_handler function.
-
- If the exception handler returns true, the exception will be
- consumed and will not be propagated to other exception handlers.
- """
- def __init__(self, exception_handler):
- self.exception_handler = exception_handler
- self.active = True
-
- def _deactivate(self):
- self.active = False
-
- def exit(self, type, value, traceback):
- if type is not None:
- return self.exception_handler(type, value, traceback)
-
- def __enter__(self):
- self.old_contexts = _state.contexts
- self.new_contexts = (self.old_contexts[0], self)
- _state.contexts = self.new_contexts
-
- return self._deactivate
-
- def __exit__(self, type, value, traceback):
- try:
- if type is not None:
- return self.exception_handler(type, value, traceback)
- finally:
- final_contexts = _state.contexts
- _state.contexts = self.old_contexts
-
- if final_contexts is not self.new_contexts:
- raise StackContextInconsistentError(
- 'stack_context inconsistency (may be caused by yield '
- 'within a "with StackContext" block)')
-
- # Break up a reference to itself to allow for faster GC on CPython.
- self.new_contexts = None
-
-
-class NullContext(object):
- """Resets the `StackContext`.
-
- Useful when creating a shared resource on demand (e.g. an
- `.AsyncHTTPClient`) where the stack that caused the creating is
- not relevant to future operations.
- """
- def __enter__(self):
- self.old_contexts = _state.contexts
- _state.contexts = (tuple(), None)
-
- def __exit__(self, type, value, traceback):
- _state.contexts = self.old_contexts
-
-
-def _remove_deactivated(contexts):
- """Remove deactivated handlers from the chain"""
- # Clean ctx handlers
- stack_contexts = tuple([h for h in contexts[0] if h.active])
-
- # Find new head
- head = contexts[1]
- while head is not None and not head.active:
- head = head.old_contexts[1]
-
- # Process chain
- ctx = head
- while ctx is not None:
- parent = ctx.old_contexts[1]
-
- while parent is not None:
- if parent.active:
- break
- ctx.old_contexts = parent.old_contexts
- parent = parent.old_contexts[1]
-
- ctx = parent
-
- return (stack_contexts, head)
-
-
-def wrap(fn):
- """Returns a callable object that will restore the current `StackContext`
- when executed.
-
- Use this whenever saving a callback to be executed later in a
- different execution context (either in a different thread or
- asynchronously in the same thread).
- """
- # Check if function is already wrapped
- if fn is None or hasattr(fn, '_wrapped'):
- return fn
-
- # Capture current stack head
- # TODO: Any other better way to store contexts and update them in wrapped function?
- cap_contexts = [_state.contexts]
-
- if not cap_contexts[0][0] and not cap_contexts[0][1]:
- # Fast path when there are no active contexts.
- def null_wrapper(*args, **kwargs):
- try:
- current_state = _state.contexts
- _state.contexts = cap_contexts[0]
- return fn(*args, **kwargs)
- finally:
- _state.contexts = current_state
- null_wrapper._wrapped = True
- return null_wrapper
-
- def wrapped(*args, **kwargs):
- ret = None
- try:
- # Capture old state
- current_state = _state.contexts
-
- # Remove deactivated items
- cap_contexts[0] = contexts = _remove_deactivated(cap_contexts[0])
-
- # Force new state
- _state.contexts = contexts
-
- # Current exception
- exc = (None, None, None)
- top = None
-
- # Apply stack contexts
- last_ctx = 0
- stack = contexts[0]
-
- # Apply state
- for n in stack:
- try:
- n.enter()
- last_ctx += 1
- except:
- # Exception happened. Record exception info and store top-most handler
- exc = sys.exc_info()
- top = n.old_contexts[1]
-
- # Execute callback if no exception happened while restoring state
- if top is None:
- try:
- ret = fn(*args, **kwargs)
- except:
- exc = sys.exc_info()
- top = contexts[1]
-
- # If there was exception, try to handle it by going through the exception chain
- if top is not None:
- exc = _handle_exception(top, exc)
- else:
- # Otherwise take shorter path and run stack contexts in reverse order
- while last_ctx > 0:
- last_ctx -= 1
- c = stack[last_ctx]
-
- try:
- c.exit(*exc)
- except:
- exc = sys.exc_info()
- top = c.old_contexts[1]
- break
- else:
- top = None
-
- # If if exception happened while unrolling, take longer exception handler path
- if top is not None:
- exc = _handle_exception(top, exc)
-
- # If exception was not handled, raise it
- if exc != (None, None, None):
- raise_exc_info(exc)
- finally:
- _state.contexts = current_state
- return ret
-
- wrapped._wrapped = True
- return wrapped
-
-
-def _handle_exception(tail, exc):
- while tail is not None:
- try:
- if tail.exit(*exc):
- exc = (None, None, None)
- except:
- exc = sys.exc_info()
-
- tail = tail.old_contexts[1]
-
- return exc
-
-
-def run_with_stack_context(context, func):
- """Run a coroutine ``func`` in the given `StackContext`.
-
- It is not safe to have a ``yield`` statement within a ``with StackContext``
- block, so it is difficult to use stack context with `.gen.coroutine`.
- This helper function runs the function in the correct context while
- keeping the ``yield`` and ``with`` statements syntactically separate.
-
- Example::
-
- @gen.coroutine
- def incorrect():
- with StackContext(ctx):
- # ERROR: this will raise StackContextInconsistentError
- yield other_coroutine()
-
- @gen.coroutine
- def correct():
- yield run_with_stack_context(StackContext(ctx), other_coroutine)
-
- .. versionadded:: 3.1
- """
- with context:
- return func()
diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py
index 3a1b58ca86..e2d682ea64 100644
--- a/tornado/tcpclient.py
+++ b/tornado/tcpclient.py
@@ -15,21 +15,21 @@
"""A non-blocking TCP connection factory.
"""
-from __future__ import absolute_import, division, print_function
import functools
import socket
import numbers
import datetime
+import ssl
from tornado.concurrent import Future, future_add_done_callback
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado import gen
from tornado.netutil import Resolver
-from tornado.platform.auto import set_close_exec
from tornado.gen import TimeoutError
-from tornado.util import timedelta_to_seconds
+
+from typing import Any, Union, Dict, Tuple, List, Callable, Iterator, Optional, Set
_INITIAL_CONNECT_TIMEOUT = 0.3
@@ -51,20 +51,34 @@ class _Connector(object):
http://tools.ietf.org/html/rfc6555
"""
- def __init__(self, addrinfo, connect):
+
+ def __init__(
+ self,
+ addrinfo: List[Tuple],
+ connect: Callable[
+ [socket.AddressFamily, Tuple], Tuple[IOStream, "Future[IOStream]"]
+ ],
+ ) -> None:
self.io_loop = IOLoop.current()
self.connect = connect
- self.future = Future()
- self.timeout = None
- self.connect_timeout = None
- self.last_error = None
+ self.future = (
+ Future()
+ ) # type: Future[Tuple[socket.AddressFamily, Any, IOStream]]
+ self.timeout = None # type: Optional[object]
+ self.connect_timeout = None # type: Optional[object]
+ self.last_error = None # type: Optional[Exception]
self.remaining = len(addrinfo)
self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
- self.streams = set()
+ self.streams = set() # type: Set[IOStream]
@staticmethod
- def split(addrinfo):
+ def split(
+ addrinfo: List[Tuple],
+ ) -> Tuple[
+ List[Tuple[socket.AddressFamily, Tuple]],
+ List[Tuple[socket.AddressFamily, Tuple]],
+ ]:
"""Partition the ``addrinfo`` list by address family.
Returns two lists. The first list contains the first entry from
@@ -83,14 +97,18 @@ def split(addrinfo):
secondary.append((af, addr))
return primary, secondary
- def start(self, timeout=_INITIAL_CONNECT_TIMEOUT, connect_timeout=None):
+ def start(
+ self,
+ timeout: float = _INITIAL_CONNECT_TIMEOUT,
+ connect_timeout: Optional[Union[float, datetime.timedelta]] = None,
+ ) -> "Future[Tuple[socket.AddressFamily, Any, IOStream]]":
self.try_connect(iter(self.primary_addrs))
self.set_timeout(timeout)
if connect_timeout is not None:
self.set_connect_timeout(connect_timeout)
return self.future
- def try_connect(self, addrs):
+ def try_connect(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]]) -> None:
try:
af, addr = next(addrs)
except StopIteration:
@@ -98,15 +116,23 @@ def try_connect(self, addrs):
# might still be working. Send a final error on the future
# only when both queues are finished.
if self.remaining == 0 and not self.future.done():
- self.future.set_exception(self.last_error or
- IOError("connection failed"))
+ self.future.set_exception(
+ self.last_error or IOError("connection failed")
+ )
return
stream, future = self.connect(af, addr)
self.streams.add(stream)
future_add_done_callback(
- future, functools.partial(self.on_connect_done, addrs, af, addr))
+ future, functools.partial(self.on_connect_done, addrs, af, addr)
+ )
- def on_connect_done(self, addrs, af, addr, future):
+ def on_connect_done(
+ self,
+ addrs: Iterator[Tuple[socket.AddressFamily, Tuple]],
+ af: socket.AddressFamily,
+ addr: Tuple,
+ future: "Future[IOStream]",
+ ) -> None:
self.remaining -= 1
try:
stream = future.result()
@@ -132,35 +158,39 @@ def on_connect_done(self, addrs, af, addr, future):
self.future.set_result((af, addr, stream))
self.close_streams()
- def set_timeout(self, timeout):
- self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
- self.on_timeout)
+ def set_timeout(self, timeout: float) -> None:
+ self.timeout = self.io_loop.add_timeout(
+ self.io_loop.time() + timeout, self.on_timeout
+ )
- def on_timeout(self):
+ def on_timeout(self) -> None:
self.timeout = None
if not self.future.done():
self.try_connect(iter(self.secondary_addrs))
- def clear_timeout(self):
+ def clear_timeout(self) -> None:
if self.timeout is not None:
self.io_loop.remove_timeout(self.timeout)
- def set_connect_timeout(self, connect_timeout):
+ def set_connect_timeout(
+ self, connect_timeout: Union[float, datetime.timedelta]
+ ) -> None:
self.connect_timeout = self.io_loop.add_timeout(
- connect_timeout, self.on_connect_timeout)
+ connect_timeout, self.on_connect_timeout
+ )
- def on_connect_timeout(self):
+ def on_connect_timeout(self) -> None:
if not self.future.done():
self.future.set_exception(TimeoutError())
self.close_streams()
- def clear_timeouts(self):
+ def clear_timeouts(self) -> None:
if self.timeout is not None:
self.io_loop.remove_timeout(self.timeout)
if self.connect_timeout is not None:
self.io_loop.remove_timeout(self.connect_timeout)
- def close_streams(self):
+ def close_streams(self) -> None:
for stream in self.streams:
stream.close()
@@ -171,7 +201,8 @@ class TCPClient(object):
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
- def __init__(self, resolver=None):
+
+ def __init__(self, resolver: Optional[Resolver] = None) -> None:
if resolver is not None:
self.resolver = resolver
self._own_resolver = False
@@ -179,14 +210,21 @@ def __init__(self, resolver=None):
self.resolver = Resolver()
self._own_resolver = True
- def close(self):
+ def close(self) -> None:
if self._own_resolver:
self.resolver.close()
- @gen.coroutine
- def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
- max_buffer_size=None, source_ip=None, source_port=None,
- timeout=None):
+ async def connect(
+ self,
+ host: str,
+ port: int,
+ af: socket.AddressFamily = socket.AF_UNSPEC,
+ ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None,
+ max_buffer_size: Optional[int] = None,
+ source_ip: Optional[str] = None,
+ source_port: Optional[int] = None,
+ timeout: Optional[Union[float, datetime.timedelta]] = None,
+ ) -> IOStream:
"""Connect to the given host and port.
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
@@ -216,34 +254,50 @@ def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
if isinstance(timeout, numbers.Real):
timeout = IOLoop.current().time() + timeout
elif isinstance(timeout, datetime.timedelta):
- timeout = IOLoop.current().time() + timedelta_to_seconds(timeout)
+ timeout = IOLoop.current().time() + timeout.total_seconds()
else:
raise TypeError("Unsupported timeout %r" % timeout)
if timeout is not None:
- addrinfo = yield gen.with_timeout(
- timeout, self.resolver.resolve(host, port, af))
+ addrinfo = await gen.with_timeout(
+ timeout, self.resolver.resolve(host, port, af)
+ )
else:
- addrinfo = yield self.resolver.resolve(host, port, af)
+ addrinfo = await self.resolver.resolve(host, port, af)
connector = _Connector(
addrinfo,
- functools.partial(self._create_stream, max_buffer_size,
- source_ip=source_ip, source_port=source_port)
+ functools.partial(
+ self._create_stream,
+ max_buffer_size,
+ source_ip=source_ip,
+ source_port=source_port,
+ ),
)
- af, addr, stream = yield connector.start(connect_timeout=timeout)
+ af, addr, stream = await connector.start(connect_timeout=timeout)
# TODO: For better performance we could cache the (af, addr)
# information here and re-use it on subsequent connections to
# the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
if ssl_options is not None:
if timeout is not None:
- stream = yield gen.with_timeout(timeout, stream.start_tls(
- False, ssl_options=ssl_options, server_hostname=host))
+ stream = await gen.with_timeout(
+ timeout,
+ stream.start_tls(
+ False, ssl_options=ssl_options, server_hostname=host
+ ),
+ )
else:
- stream = yield stream.start_tls(False, ssl_options=ssl_options,
- server_hostname=host)
- raise gen.Return(stream)
-
- def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
- source_port=None):
+ stream = await stream.start_tls(
+ False, ssl_options=ssl_options, server_hostname=host
+ )
+ return stream
+
+ def _create_stream(
+ self,
+ max_buffer_size: int,
+ af: socket.AddressFamily,
+ addr: Tuple,
+ source_ip: Optional[str] = None,
+ source_port: Optional[int] = None,
+ ) -> Tuple[IOStream, "Future[IOStream]"]:
# Always connect in plaintext; we'll convert to ssl if necessary
# after one connection has completed.
source_port_bind = source_port if isinstance(source_port, int) else 0
@@ -251,12 +305,11 @@ def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
if source_port_bind and not source_ip:
# User required a specific port, but did not specify
# a certain source IP, will bind to the default loopback.
- source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1'
+ source_ip_bind = "::1" if af == socket.AF_INET6 else "127.0.0.1"
# Trying to use the same address family as the requested af socket:
# - 127.0.0.1 for IPv4
# - ::1 for IPv6
socket_obj = socket.socket(af)
- set_close_exec(socket_obj.fileno())
if source_port_bind or source_ip_bind:
# If the user requires binding also to a specific IP/port.
try:
@@ -266,11 +319,10 @@ def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
# Fail loudly if unable to use the IP/port.
raise
try:
- stream = IOStream(socket_obj,
- max_buffer_size=max_buffer_size)
+ stream = IOStream(socket_obj, max_buffer_size=max_buffer_size)
except socket.error as e:
- fu = Future()
+ fu = Future() # type: Future[IOStream]
fu.set_exception(e)
- return fu
+ return stream, fu
else:
return stream, stream.connect(addr)
diff --git a/tornado/tcpserver.py b/tornado/tcpserver.py
index 1adc489560..476ffc936f 100644
--- a/tornado/tcpserver.py
+++ b/tornado/tcpserver.py
@@ -14,11 +14,11 @@
# under the License.
"""A non-blocking, single-threaded TCP server."""
-from __future__ import absolute_import, division, print_function
import errno
import os
import socket
+import ssl
from tornado import gen
from tornado.log import app_log
@@ -28,11 +28,11 @@
from tornado import process
from tornado.util import errno_from_exception
-try:
- import ssl
-except ImportError:
- # ssl is not available on Google App Engine.
- ssl = None
+import typing
+from typing import Union, Dict, Any, Iterable, Optional, Awaitable
+
+if typing.TYPE_CHECKING:
+ from typing import Callable, List # noqa: F401
class TCPServer(object):
@@ -46,12 +46,11 @@ class TCPServer(object):
from tornado import gen
class EchoServer(TCPServer):
- @gen.coroutine
- def handle_stream(self, stream, address):
+ async def handle_stream(self, stream, address):
while True:
try:
- data = yield stream.read_until(b"\n")
- yield stream.write(data)
+ data = await stream.read_until(b"\n")
+ await stream.write(data)
except StreamClosedError:
break
@@ -105,12 +104,17 @@ def handle_stream(self, stream, address):
.. versionchanged:: 5.0
The ``io_loop`` argument has been removed.
"""
- def __init__(self, ssl_options=None, max_buffer_size=None,
- read_chunk_size=None):
+
+ def __init__(
+ self,
+ ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None,
+ max_buffer_size: Optional[int] = None,
+ read_chunk_size: Optional[int] = None,
+ ) -> None:
self.ssl_options = ssl_options
- self._sockets = {} # fd -> socket object
- self._handlers = {} # fd -> remove_handler callable
- self._pending_sockets = []
+ self._sockets = {} # type: Dict[int, socket.socket]
+ self._handlers = {} # type: Dict[int, Callable[[], None]]
+ self._pending_sockets = [] # type: List[socket.socket]
self._started = False
self._stopped = False
self.max_buffer_size = max_buffer_size
@@ -122,18 +126,21 @@ def __init__(self, ssl_options=None, max_buffer_size=None,
# which seems like too much work
if self.ssl_options is not None and isinstance(self.ssl_options, dict):
# Only certfile is required: it can contain both keys
- if 'certfile' not in self.ssl_options:
+ if "certfile" not in self.ssl_options:
raise KeyError('missing key "certfile" in ssl_options')
- if not os.path.exists(self.ssl_options['certfile']):
- raise ValueError('certfile "%s" does not exist' %
- self.ssl_options['certfile'])
- if ('keyfile' in self.ssl_options and
- not os.path.exists(self.ssl_options['keyfile'])):
- raise ValueError('keyfile "%s" does not exist' %
- self.ssl_options['keyfile'])
-
- def listen(self, port, address=""):
+ if not os.path.exists(self.ssl_options["certfile"]):
+ raise ValueError(
+ 'certfile "%s" does not exist' % self.ssl_options["certfile"]
+ )
+ if "keyfile" in self.ssl_options and not os.path.exists(
+ self.ssl_options["keyfile"]
+ ):
+ raise ValueError(
+ 'keyfile "%s" does not exist' % self.ssl_options["keyfile"]
+ )
+
+ def listen(self, port: int, address: str = "") -> None:
"""Starts accepting connections on the given port.
This method may be called more than once to listen on multiple ports.
@@ -144,7 +151,7 @@ def listen(self, port, address=""):
sockets = bind_sockets(port, address=address)
self.add_sockets(sockets)
- def add_sockets(self, sockets):
+ def add_sockets(self, sockets: Iterable[socket.socket]) -> None:
"""Makes this server start accepting connections on the given sockets.
The ``sockets`` parameter is a list of socket objects such as
@@ -156,14 +163,21 @@ def add_sockets(self, sockets):
for sock in sockets:
self._sockets[sock.fileno()] = sock
self._handlers[sock.fileno()] = add_accept_handler(
- sock, self._handle_connection)
+ sock, self._handle_connection
+ )
- def add_socket(self, socket):
+ def add_socket(self, socket: socket.socket) -> None:
"""Singular version of `add_sockets`. Takes a single socket object."""
self.add_sockets([socket])
- def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128,
- reuse_port=False):
+ def bind(
+ self,
+ port: int,
+ address: Optional[str] = None,
+ family: socket.AddressFamily = socket.AF_UNSPEC,
+ backlog: int = 128,
+ reuse_port: bool = False,
+ ) -> None:
"""Binds this server to the given port on the given address.
To start the server, call `start`. If you want to run this server
@@ -187,14 +201,17 @@ def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128,
.. versionchanged:: 4.4
Added the ``reuse_port`` argument.
"""
- sockets = bind_sockets(port, address=address, family=family,
- backlog=backlog, reuse_port=reuse_port)
+ sockets = bind_sockets(
+ port, address=address, family=family, backlog=backlog, reuse_port=reuse_port
+ )
if self._started:
self.add_sockets(sockets)
else:
self._pending_sockets.extend(sockets)
- def start(self, num_processes=1):
+ def start(
+ self, num_processes: Optional[int] = 1, max_restarts: Optional[int] = None
+ ) -> None:
"""Starts this server in the `.IOLoop`.
By default, we run the server in this process and do not fork any
@@ -213,16 +230,24 @@ def start(self, num_processes=1):
which defaults to True when ``debug=True``).
When using multiple processes, no IOLoops can be created or
referenced until after the call to ``TCPServer.start(n)``.
+
+ Values of ``num_processes`` other than 1 are not supported on Windows.
+
+ The ``max_restarts`` argument is passed to `.fork_processes`.
+
+ .. versionchanged:: 6.0
+
+ Added ``max_restarts`` argument.
"""
assert not self._started
self._started = True
if num_processes != 1:
- process.fork_processes(num_processes)
+ process.fork_processes(num_processes, max_restarts)
sockets = self._pending_sockets
self._pending_sockets = []
self.add_sockets(sockets)
- def stop(self):
+ def stop(self) -> None:
"""Stops listening for new connections.
Requests currently in progress may still continue after the
@@ -237,7 +262,9 @@ def stop(self):
self._handlers.pop(fd)()
sock.close()
- def handle_stream(self, stream, address):
+ def handle_stream(
+ self, stream: IOStream, address: tuple
+ ) -> Optional[Awaitable[None]]:
"""Override to handle a new `.IOStream` from an incoming connection.
This method may be a coroutine; if so any exceptions it raises
@@ -254,14 +281,16 @@ def handle_stream(self, stream, address):
"""
raise NotImplementedError()
- def _handle_connection(self, connection, address):
+ def _handle_connection(self, connection: socket.socket, address: Any) -> None:
if self.ssl_options is not None:
assert ssl, "Python 2.6+ and OpenSSL required for SSL"
try:
- connection = ssl_wrap_socket(connection,
- self.ssl_options,
- server_side=True,
- do_handshake_on_connect=False)
+ connection = ssl_wrap_socket(
+ connection,
+ self.ssl_options,
+ server_side=True,
+ do_handshake_on_connect=False,
+ )
except ssl.SSLError as err:
if err.args[0] == ssl.SSL_ERROR_EOF:
return connection.close()
@@ -284,17 +313,22 @@ def _handle_connection(self, connection, address):
raise
try:
if self.ssl_options is not None:
- stream = SSLIOStream(connection,
- max_buffer_size=self.max_buffer_size,
- read_chunk_size=self.read_chunk_size)
+ stream = SSLIOStream(
+ connection,
+ max_buffer_size=self.max_buffer_size,
+ read_chunk_size=self.read_chunk_size,
+ ) # type: IOStream
else:
- stream = IOStream(connection,
- max_buffer_size=self.max_buffer_size,
- read_chunk_size=self.read_chunk_size)
+ stream = IOStream(
+ connection,
+ max_buffer_size=self.max_buffer_size,
+ read_chunk_size=self.read_chunk_size,
+ )
future = self.handle_stream(stream, address)
if future is not None:
- IOLoop.current().add_future(gen.convert_yielded(future),
- lambda f: f.result())
+ IOLoop.current().add_future(
+ gen.convert_yielded(future), lambda f: f.result()
+ )
except Exception:
app_log.error("Error in connection callback", exc_info=True)
diff --git a/tornado/template.py b/tornado/template.py
index 61b987462c..d53e977c5e 100644
--- a/tornado/template.py
+++ b/tornado/template.py
@@ -98,8 +98,9 @@ def add(x, y):
To comment out a section so that it is omitted from the output, surround it
with ``{# ... #}``.
-These tags may be escaped as ``{{!``, ``{%!``, and ``{#!``
-if you need to include a literal ``{{``, ``{%``, or ``{#`` in the output.
+
+To include a literal ``{{``, ``{%``, or ``{#`` in the output, escape them as
+``{{!``, ``{%!``, and ``{#!``, respectively.
``{% apply *function* %}...{% end %}``
@@ -195,9 +196,8 @@ class (and specifically its ``render`` method) and will not work
`filter_whitespace` for available options. New in Tornado 4.3.
"""
-from __future__ import absolute_import, division, print_function
-
import datetime
+from io import StringIO
import linecache
import os.path
import posixpath
@@ -206,18 +206,25 @@ class (and specifically its ``render`` method) and will not work
from tornado import escape
from tornado.log import app_log
-from tornado.util import ObjectDict, exec_in, unicode_type, PY3
+from tornado.util import ObjectDict, exec_in, unicode_type
-if PY3:
- from io import StringIO
-else:
- from cStringIO import StringIO
+from typing import Any, Union, Callable, List, Dict, Iterable, Optional, TextIO
+import typing
+
+if typing.TYPE_CHECKING:
+ from typing import Tuple, ContextManager # noqa: F401
_DEFAULT_AUTOESCAPE = "xhtml_escape"
-_UNSET = object()
-def filter_whitespace(mode, text):
+class _UnsetMarker:
+ pass
+
+
+_UNSET = _UnsetMarker()
+
+
+def filter_whitespace(mode: str, text: str) -> str:
"""Transform whitespace in ``text`` according to ``mode``.
Available modes are:
@@ -230,13 +237,13 @@ def filter_whitespace(mode, text):
.. versionadded:: 4.3
"""
- if mode == 'all':
+ if mode == "all":
return text
- elif mode == 'single':
+ elif mode == "single":
text = re.sub(r"([\t ]+)", " ", text)
text = re.sub(r"(\s*\n\s*)", "\n", text)
return text
- elif mode == 'oneline':
+ elif mode == "oneline":
return re.sub(r"(\s+)", " ", text)
else:
raise Exception("invalid whitespace mode %s" % mode)
@@ -248,12 +255,19 @@ class Template(object):
We compile into Python from the given template_string. You can generate
the template from variables with generate().
"""
+
# note that the constructor's signature is not extracted with
# autodoc because _UNSET looks like garbage. When changing
# this signature update website/sphinx/template.rst too.
- def __init__(self, template_string, name="", loader=None,
- compress_whitespace=_UNSET, autoescape=_UNSET,
- whitespace=None):
+ def __init__(
+ self,
+ template_string: Union[str, bytes],
+ name: str = "",
+ loader: Optional["BaseLoader"] = None,
+ compress_whitespace: Union[bool, _UnsetMarker] = _UNSET,
+ autoescape: Optional[Union[str, _UnsetMarker]] = _UNSET,
+ whitespace: Optional[str] = None,
+ ) -> None:
"""Construct a Template.
:arg str template_string: the contents of the template file.
@@ -289,18 +303,18 @@ def __init__(self, template_string, name="", loader=None,
else:
whitespace = "all"
# Validate the whitespace setting.
- filter_whitespace(whitespace, '')
+ assert whitespace is not None
+ filter_whitespace(whitespace, "")
- if autoescape is not _UNSET:
- self.autoescape = autoescape
+ if not isinstance(autoescape, _UnsetMarker):
+ self.autoescape = autoescape # type: Optional[str]
elif loader:
self.autoescape = loader.autoescape
else:
self.autoescape = _DEFAULT_AUTOESCAPE
self.namespace = loader.namespace if loader else {}
- reader = _TemplateReader(name, escape.native_str(template_string),
- whitespace)
+ reader = _TemplateReader(name, escape.native_str(template_string), whitespace)
self.file = _File(self, _parse(reader, self))
self.code = self._generate_python(loader)
self.loader = loader
@@ -311,14 +325,16 @@ def __init__(self, template_string, name="", loader=None,
# from being applied to the generated code.
self.compiled = compile(
escape.to_unicode(self.code),
- "%s.generated.py" % self.name.replace('.', '_'),
- "exec", dont_inherit=True)
+ "%s.generated.py" % self.name.replace(".", "_"),
+ "exec",
+ dont_inherit=True,
+ )
except Exception:
formatted_code = _format_code(self.code).rstrip()
app_log.error("%s code:\n%s", self.name, formatted_code)
raise
- def generate(self, **kwargs):
+ def generate(self, **kwargs: Any) -> bytes:
"""Generate this template with the given arguments."""
namespace = {
"escape": escape.xhtml_escape,
@@ -332,42 +348,42 @@ def generate(self, **kwargs):
"_tt_string_types": (unicode_type, bytes),
# __name__ and __loader__ allow the traceback mechanism to find
# the generated source code.
- "__name__": self.name.replace('.', '_'),
+ "__name__": self.name.replace(".", "_"),
"__loader__": ObjectDict(get_source=lambda name: self.code),
}
namespace.update(self.namespace)
namespace.update(kwargs)
exec_in(self.compiled, namespace)
- execute = namespace["_tt_execute"]
+ execute = typing.cast(Callable[[], bytes], namespace["_tt_execute"])
# Clear the traceback module's cache of source data now that
# we've generated a new template (mainly for this module's
# unittests, where different tests reuse the same name).
linecache.clearcache()
return execute()
- def _generate_python(self, loader):
+ def _generate_python(self, loader: Optional["BaseLoader"]) -> str:
buffer = StringIO()
try:
# named_blocks maps from names to _NamedBlock objects
- named_blocks = {}
+ named_blocks = {} # type: Dict[str, _NamedBlock]
ancestors = self._get_ancestors(loader)
ancestors.reverse()
for ancestor in ancestors:
ancestor.find_named_blocks(loader, named_blocks)
- writer = _CodeWriter(buffer, named_blocks, loader,
- ancestors[0].template)
+ writer = _CodeWriter(buffer, named_blocks, loader, ancestors[0].template)
ancestors[0].generate(writer)
return buffer.getvalue()
finally:
buffer.close()
- def _get_ancestors(self, loader):
+ def _get_ancestors(self, loader: Optional["BaseLoader"]) -> List["_File"]:
ancestors = [self.file]
for chunk in self.file.body.chunks:
if isinstance(chunk, _ExtendsBlock):
if not loader:
- raise ParseError("{% extends %} block found, but no "
- "template loader")
+ raise ParseError(
+ "{% extends %} block found, but no " "template loader"
+ )
template = loader.load(chunk.name, self.name)
ancestors.extend(template._get_ancestors(loader))
return ancestors
@@ -380,8 +396,13 @@ class BaseLoader(object):
``{% extends %}`` and ``{% include %}``. The loader caches all
templates after they are loaded the first time.
"""
- def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None,
- whitespace=None):
+
+ def __init__(
+ self,
+ autoescape: str = _DEFAULT_AUTOESCAPE,
+ namespace: Optional[Dict[str, Any]] = None,
+ whitespace: Optional[str] = None,
+ ) -> None:
"""Construct a template loader.
:arg str autoescape: The name of a function in the template
@@ -400,7 +421,7 @@ def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None,
self.autoescape = autoescape
self.namespace = namespace or {}
self.whitespace = whitespace
- self.templates = {}
+ self.templates = {} # type: Dict[str, Template]
# self.lock protects self.templates. It's a reentrant lock
# because templates may load other templates via `include` or
# `extends`. Note that thanks to the GIL this code would be safe
@@ -408,16 +429,16 @@ def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None,
# threads tried to compile the same template simultaneously.
self.lock = threading.RLock()
- def reset(self):
+ def reset(self) -> None:
"""Resets the cache of compiled templates."""
with self.lock:
self.templates = {}
- def resolve_path(self, name, parent_path=None):
+ def resolve_path(self, name: str, parent_path: Optional[str] = None) -> str:
"""Converts a possibly-relative path to absolute (used internally)."""
raise NotImplementedError()
- def load(self, name, parent_path=None):
+ def load(self, name: str, parent_path: Optional[str] = None) -> Template:
"""Loads a template."""
name = self.resolve_path(name, parent_path=parent_path)
with self.lock:
@@ -425,29 +446,32 @@ def load(self, name, parent_path=None):
self.templates[name] = self._create_template(name)
return self.templates[name]
- def _create_template(self, name):
+ def _create_template(self, name: str) -> Template:
raise NotImplementedError()
class Loader(BaseLoader):
- """A template loader that loads from a single root directory.
- """
- def __init__(self, root_directory, **kwargs):
- super(Loader, self).__init__(**kwargs)
+ """A template loader that loads from a single root directory."""
+
+ def __init__(self, root_directory: str, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
self.root = os.path.abspath(root_directory)
- def resolve_path(self, name, parent_path=None):
- if parent_path and not parent_path.startswith("<") and \
- not parent_path.startswith("/") and \
- not name.startswith("/"):
+ def resolve_path(self, name: str, parent_path: Optional[str] = None) -> str:
+ if (
+ parent_path
+ and not parent_path.startswith("<")
+ and not parent_path.startswith("/")
+ and not name.startswith("/")
+ ):
current_path = os.path.join(self.root, parent_path)
file_dir = os.path.dirname(os.path.abspath(current_path))
relative_path = os.path.abspath(os.path.join(file_dir, name))
if relative_path.startswith(self.root):
- name = relative_path[len(self.root) + 1:]
+ name = relative_path[len(self.root) + 1 :]
return name
- def _create_template(self, name):
+ def _create_template(self, name: str) -> Template:
path = os.path.join(self.root, name)
with open(path, "rb") as f:
template = Template(f.read(), name=name, loader=self)
@@ -456,41 +480,47 @@ def _create_template(self, name):
class DictLoader(BaseLoader):
"""A template loader that loads from a dictionary."""
- def __init__(self, dict, **kwargs):
- super(DictLoader, self).__init__(**kwargs)
+
+ def __init__(self, dict: Dict[str, str], **kwargs: Any) -> None:
+ super().__init__(**kwargs)
self.dict = dict
- def resolve_path(self, name, parent_path=None):
- if parent_path and not parent_path.startswith("<") and \
- not parent_path.startswith("/") and \
- not name.startswith("/"):
+ def resolve_path(self, name: str, parent_path: Optional[str] = None) -> str:
+ if (
+ parent_path
+ and not parent_path.startswith("<")
+ and not parent_path.startswith("/")
+ and not name.startswith("/")
+ ):
file_dir = posixpath.dirname(parent_path)
name = posixpath.normpath(posixpath.join(file_dir, name))
return name
- def _create_template(self, name):
+ def _create_template(self, name: str) -> Template:
return Template(self.dict[name], name=name, loader=self)
class _Node(object):
- def each_child(self):
+ def each_child(self) -> Iterable["_Node"]:
return ()
- def generate(self, writer):
+ def generate(self, writer: "_CodeWriter") -> None:
raise NotImplementedError()
- def find_named_blocks(self, loader, named_blocks):
+ def find_named_blocks(
+ self, loader: Optional[BaseLoader], named_blocks: Dict[str, "_NamedBlock"]
+ ) -> None:
for child in self.each_child():
child.find_named_blocks(loader, named_blocks)
class _File(_Node):
- def __init__(self, template, body):
+ def __init__(self, template: Template, body: "_ChunkList") -> None:
self.template = template
self.body = body
self.line = 0
- def generate(self, writer):
+ def generate(self, writer: "_CodeWriter") -> None:
writer.write_line("def _tt_execute():", self.line)
with writer.indent():
writer.write_line("_tt_buffer = []", self.line)
@@ -498,73 +528,79 @@ def generate(self, writer):
self.body.generate(writer)
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
- def each_child(self):
+ def each_child(self) -> Iterable["_Node"]:
return (self.body,)
class _ChunkList(_Node):
- def __init__(self, chunks):
+ def __init__(self, chunks: List[_Node]) -> None:
self.chunks = chunks
- def generate(self, writer):
+ def generate(self, writer: "_CodeWriter") -> None:
for chunk in self.chunks:
chunk.generate(writer)
- def each_child(self):
+ def each_child(self) -> Iterable["_Node"]:
return self.chunks
class _NamedBlock(_Node):
- def __init__(self, name, body, template, line):
+ def __init__(self, name: str, body: _Node, template: Template, line: int) -> None:
self.name = name
self.body = body
self.template = template
self.line = line
- def each_child(self):
+ def each_child(self) -> Iterable["_Node"]:
return (self.body,)
- def generate(self, writer):
+ def generate(self, writer: "_CodeWriter") -> None:
block = writer.named_blocks[self.name]
with writer.include(block.template, self.line):
block.body.generate(writer)
- def find_named_blocks(self, loader, named_blocks):
+ def find_named_blocks(
+ self, loader: Optional[BaseLoader], named_blocks: Dict[str, "_NamedBlock"]
+ ) -> None:
named_blocks[self.name] = self
_Node.find_named_blocks(self, loader, named_blocks)
class _ExtendsBlock(_Node):
- def __init__(self, name):
+ def __init__(self, name: str) -> None:
self.name = name
class _IncludeBlock(_Node):
- def __init__(self, name, reader, line):
+ def __init__(self, name: str, reader: "_TemplateReader", line: int) -> None:
self.name = name
self.template_name = reader.name
self.line = line
- def find_named_blocks(self, loader, named_blocks):
+ def find_named_blocks(
+ self, loader: Optional[BaseLoader], named_blocks: Dict[str, _NamedBlock]
+ ) -> None:
+ assert loader is not None
included = loader.load(self.name, self.template_name)
included.file.find_named_blocks(loader, named_blocks)
- def generate(self, writer):
+ def generate(self, writer: "_CodeWriter") -> None:
+ assert writer.loader is not None
included = writer.loader.load(self.name, self.template_name)
with writer.include(included, self.line):
included.file.body.generate(writer)
class _ApplyBlock(_Node):
- def __init__(self, method, line, body=None):
+ def __init__(self, method: str, line: int, body: _Node) -> None:
self.method = method
self.line = line
self.body = body
- def each_child(self):
+ def each_child(self) -> Iterable["_Node"]:
return (self.body,)
- def generate(self, writer):
+ def generate(self, writer: "_CodeWriter") -> None:
method_name = "_tt_apply%d" % writer.apply_counter
writer.apply_counter += 1
writer.write_line("def %s():" % method_name, self.line)
@@ -573,20 +609,21 @@ def generate(self, writer):
writer.write_line("_tt_append = _tt_buffer.append", self.line)
self.body.generate(writer)
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
- writer.write_line("_tt_append(_tt_utf8(%s(%s())))" % (
- self.method, method_name), self.line)
+ writer.write_line(
+ "_tt_append(_tt_utf8(%s(%s())))" % (self.method, method_name), self.line
+ )
class _ControlBlock(_Node):
- def __init__(self, statement, line, body=None):
+ def __init__(self, statement: str, line: int, body: _Node) -> None:
self.statement = statement
self.line = line
self.body = body
- def each_child(self):
+ def each_child(self) -> Iterable[_Node]:
return (self.body,)
- def generate(self, writer):
+ def generate(self, writer: "_CodeWriter") -> None:
writer.write_line("%s:" % self.statement, self.line)
with writer.indent():
self.body.generate(writer)
@@ -595,57 +632,60 @@ def generate(self, writer):
class _IntermediateControlBlock(_Node):
- def __init__(self, statement, line):
+ def __init__(self, statement: str, line: int) -> None:
self.statement = statement
self.line = line
- def generate(self, writer):
+ def generate(self, writer: "_CodeWriter") -> None:
# In case the previous block was empty
writer.write_line("pass", self.line)
writer.write_line("%s:" % self.statement, self.line, writer.indent_size() - 1)
class _Statement(_Node):
- def __init__(self, statement, line):
+ def __init__(self, statement: str, line: int) -> None:
self.statement = statement
self.line = line
- def generate(self, writer):
+ def generate(self, writer: "_CodeWriter") -> None:
writer.write_line(self.statement, self.line)
class _Expression(_Node):
- def __init__(self, expression, line, raw=False):
+ def __init__(self, expression: str, line: int, raw: bool = False) -> None:
self.expression = expression
self.line = line
self.raw = raw
- def generate(self, writer):
+ def generate(self, writer: "_CodeWriter") -> None:
writer.write_line("_tt_tmp = %s" % self.expression, self.line)
- writer.write_line("if isinstance(_tt_tmp, _tt_string_types):"
- " _tt_tmp = _tt_utf8(_tt_tmp)", self.line)
+ writer.write_line(
+ "if isinstance(_tt_tmp, _tt_string_types):" " _tt_tmp = _tt_utf8(_tt_tmp)",
+ self.line,
+ )
writer.write_line("else: _tt_tmp = _tt_utf8(str(_tt_tmp))", self.line)
if not self.raw and writer.current_template.autoescape is not None:
# In python3 functions like xhtml_escape return unicode,
# so we have to convert to utf8 again.
- writer.write_line("_tt_tmp = _tt_utf8(%s(_tt_tmp))" %
- writer.current_template.autoescape, self.line)
+ writer.write_line(
+ "_tt_tmp = _tt_utf8(%s(_tt_tmp))" % writer.current_template.autoescape,
+ self.line,
+ )
writer.write_line("_tt_append(_tt_tmp)", self.line)
class _Module(_Expression):
- def __init__(self, expression, line):
- super(_Module, self).__init__("_tt_modules." + expression, line,
- raw=True)
+ def __init__(self, expression: str, line: int) -> None:
+ super().__init__("_tt_modules." + expression, line, raw=True)
class _Text(_Node):
- def __init__(self, value, line, whitespace):
+ def __init__(self, value: str, line: int, whitespace: str) -> None:
self.value = value
self.line = line
self.whitespace = whitespace
- def generate(self, writer):
+ def generate(self, writer: "_CodeWriter") -> None:
value = self.value
# Compress whitespace if requested, with a crude heuristic to avoid
@@ -654,7 +694,7 @@ def generate(self, writer):
value = filter_whitespace(self.whitespace, value)
if value:
- writer.write_line('_tt_append(%r)' % escape.utf8(value), self.line)
+ writer.write_line("_tt_append(%r)" % escape.utf8(value), self.line)
class ParseError(Exception):
@@ -666,75 +706,87 @@ class ParseError(Exception):
.. versionchanged:: 4.3
Added ``filename`` and ``lineno`` attributes.
"""
- def __init__(self, message, filename=None, lineno=0):
+
+ def __init__(
+ self, message: str, filename: Optional[str] = None, lineno: int = 0
+ ) -> None:
self.message = message
# The names "filename" and "lineno" are chosen for consistency
# with python SyntaxError.
self.filename = filename
self.lineno = lineno
- def __str__(self):
- return '%s at %s:%d' % (self.message, self.filename, self.lineno)
+ def __str__(self) -> str:
+ return "%s at %s:%d" % (self.message, self.filename, self.lineno)
class _CodeWriter(object):
- def __init__(self, file, named_blocks, loader, current_template):
+ def __init__(
+ self,
+ file: TextIO,
+ named_blocks: Dict[str, _NamedBlock],
+ loader: Optional[BaseLoader],
+ current_template: Template,
+ ) -> None:
self.file = file
self.named_blocks = named_blocks
self.loader = loader
self.current_template = current_template
self.apply_counter = 0
- self.include_stack = []
+ self.include_stack = [] # type: List[Tuple[Template, int]]
self._indent = 0
- def indent_size(self):
+ def indent_size(self) -> int:
return self._indent
- def indent(self):
+ def indent(self) -> "ContextManager":
class Indenter(object):
- def __enter__(_):
+ def __enter__(_) -> "_CodeWriter":
self._indent += 1
return self
- def __exit__(_, *args):
+ def __exit__(_, *args: Any) -> None:
assert self._indent > 0
self._indent -= 1
return Indenter()
- def include(self, template, line):
+ def include(self, template: Template, line: int) -> "ContextManager":
self.include_stack.append((self.current_template, line))
self.current_template = template
class IncludeTemplate(object):
- def __enter__(_):
+ def __enter__(_) -> "_CodeWriter":
return self
- def __exit__(_, *args):
+ def __exit__(_, *args: Any) -> None:
self.current_template = self.include_stack.pop()[0]
return IncludeTemplate()
- def write_line(self, line, line_number, indent=None):
+ def write_line(
+ self, line: str, line_number: int, indent: Optional[int] = None
+ ) -> None:
if indent is None:
indent = self._indent
- line_comment = ' # %s:%d' % (self.current_template.name, line_number)
+ line_comment = " # %s:%d" % (self.current_template.name, line_number)
if self.include_stack:
- ancestors = ["%s:%d" % (tmpl.name, lineno)
- for (tmpl, lineno) in self.include_stack]
- line_comment += ' (via %s)' % ', '.join(reversed(ancestors))
+ ancestors = [
+ "%s:%d" % (tmpl.name, lineno) for (tmpl, lineno) in self.include_stack
+ ]
+ line_comment += " (via %s)" % ", ".join(reversed(ancestors))
print(" " * indent + line + line_comment, file=self.file)
class _TemplateReader(object):
- def __init__(self, name, text, whitespace):
+ def __init__(self, name: str, text: str, whitespace: str) -> None:
self.name = name
self.text = text
self.whitespace = whitespace
self.line = 1
self.pos = 0
- def find(self, needle, start=0, end=None):
+ def find(self, needle: str, start: int = 0, end: Optional[int] = None) -> int:
assert start >= 0, start
pos = self.pos
start += pos
@@ -748,23 +800,23 @@ def find(self, needle, start=0, end=None):
index -= pos
return index
- def consume(self, count=None):
+ def consume(self, count: Optional[int] = None) -> str:
if count is None:
count = len(self.text) - self.pos
newpos = self.pos + count
self.line += self.text.count("\n", self.pos, newpos)
- s = self.text[self.pos:newpos]
+ s = self.text[self.pos : newpos]
self.pos = newpos
return s
- def remaining(self):
+ def remaining(self) -> int:
return len(self.text) - self.pos
- def __len__(self):
+ def __len__(self) -> int:
return self.remaining()
- def __getitem__(self, key):
- if type(key) is slice:
+ def __getitem__(self, key: Union[int, slice]) -> str:
+ if isinstance(key, slice):
size = len(self)
start, stop, step = key.indices(size)
if start is None:
@@ -779,20 +831,25 @@ def __getitem__(self, key):
else:
return self.text[self.pos + key]
- def __str__(self):
- return self.text[self.pos:]
+ def __str__(self) -> str:
+ return self.text[self.pos :]
- def raise_parse_error(self, msg):
+ def raise_parse_error(self, msg: str) -> None:
raise ParseError(msg, self.name, self.line)
-def _format_code(code):
+def _format_code(code: str) -> str:
lines = code.splitlines()
format = "%%%dd %%s\n" % len(repr(len(lines) + 1))
return "".join([format % (i + 1, line) for (i, line) in enumerate(lines)])
-def _parse(reader, template, in_block=None, in_loop=None):
+def _parse(
+ reader: _TemplateReader,
+ template: Template,
+ in_block: Optional[str] = None,
+ in_loop: Optional[str] = None,
+) -> _ChunkList:
body = _ChunkList([])
while True:
# Find next template directive
@@ -803,9 +860,11 @@ def _parse(reader, template, in_block=None, in_loop=None):
# EOF
if in_block:
reader.raise_parse_error(
- "Missing {%% end %%} block for %s" % in_block)
- body.chunks.append(_Text(reader.consume(), reader.line,
- reader.whitespace))
+ "Missing {%% end %%} block for %s" % in_block
+ )
+ body.chunks.append(
+ _Text(reader.consume(), reader.line, reader.whitespace)
+ )
return body
# If the first curly brace is not the start of a special token,
# start searching from the character after it
@@ -815,8 +874,11 @@ def _parse(reader, template, in_block=None, in_loop=None):
# When there are more than 2 curlies in a row, use the
# innermost ones. This is useful when generating languages
# like latex where curlies are also meaningful
- if (curly + 2 < reader.remaining() and
- reader[curly + 1] == '{' and reader[curly + 2] == '{'):
+ if (
+ curly + 2 < reader.remaining()
+ and reader[curly + 1] == "{"
+ and reader[curly + 2] == "{"
+ ):
curly += 1
continue
break
@@ -824,8 +886,7 @@ def _parse(reader, template, in_block=None, in_loop=None):
# Append any text before the special token
if curly > 0:
cons = reader.consume(curly)
- body.chunks.append(_Text(cons, reader.line,
- reader.whitespace))
+ body.chunks.append(_Text(cons, reader.line, reader.whitespace))
start_brace = reader.consume(2)
line = reader.line
@@ -836,8 +897,7 @@ def _parse(reader, template, in_block=None, in_loop=None):
# which also use double braces.
if reader.remaining() and reader[0] == "!":
reader.consume(1)
- body.chunks.append(_Text(start_brace, line,
- reader.whitespace))
+ body.chunks.append(_Text(start_brace, line, reader.whitespace))
continue
# Comment
@@ -884,12 +944,13 @@ def _parse(reader, template, in_block=None, in_loop=None):
allowed_parents = intermediate_blocks.get(operator)
if allowed_parents is not None:
if not in_block:
- reader.raise_parse_error("%s outside %s block" %
- (operator, allowed_parents))
+ reader.raise_parse_error(
+ "%s outside %s block" % (operator, allowed_parents)
+ )
if in_block not in allowed_parents:
reader.raise_parse_error(
- "%s block cannot be attached to %s block" %
- (operator, in_block))
+ "%s block cannot be attached to %s block" % (operator, in_block)
+ )
body.chunks.append(_IntermediateControlBlock(contents, line))
continue
@@ -899,16 +960,25 @@ def _parse(reader, template, in_block=None, in_loop=None):
reader.raise_parse_error("Extra {% end %} block")
return body
- elif operator in ("extends", "include", "set", "import", "from",
- "comment", "autoescape", "whitespace", "raw",
- "module"):
+ elif operator in (
+ "extends",
+ "include",
+ "set",
+ "import",
+ "from",
+ "comment",
+ "autoescape",
+ "whitespace",
+ "raw",
+ "module",
+ ):
if operator == "comment":
continue
if operator == "extends":
suffix = suffix.strip('"').strip("'")
if not suffix:
reader.raise_parse_error("extends missing file path")
- block = _ExtendsBlock(suffix)
+ block = _ExtendsBlock(suffix) # type: _Node
elif operator in ("import", "from"):
if not suffix:
reader.raise_parse_error("import missing statement")
@@ -923,7 +993,7 @@ def _parse(reader, template, in_block=None, in_loop=None):
reader.raise_parse_error("set missing statement")
block = _Statement(suffix, line)
elif operator == "autoescape":
- fn = suffix.strip()
+ fn = suffix.strip() # type: Optional[str]
if fn == "None":
fn = None
template.autoescape = fn
@@ -931,7 +1001,7 @@ def _parse(reader, template, in_block=None, in_loop=None):
elif operator == "whitespace":
mode = suffix.strip()
# Validate the selected mode
- filter_whitespace(mode, '')
+ filter_whitespace(mode, "")
reader.whitespace = mode
continue
elif operator == "raw":
@@ -967,8 +1037,9 @@ def _parse(reader, template, in_block=None, in_loop=None):
elif operator in ("break", "continue"):
if not in_loop:
- reader.raise_parse_error("%s outside %s block" %
- (operator, set(["for", "while"])))
+ reader.raise_parse_error(
+ "%s outside %s block" % (operator, set(["for", "while"]))
+ )
body.chunks.append(_Statement(contents, line))
continue
diff --git a/tornado/test/__main__.py b/tornado/test/__main__.py
index c78478cbd3..430c895fa2 100644
--- a/tornado/test/__main__.py
+++ b/tornado/test/__main__.py
@@ -2,8 +2,6 @@
This only works in python 2.7+.
"""
-from __future__ import absolute_import, division, print_function
-
from tornado.test.runtests import all, main
# tornado.testing.main autodiscovery relies on 'all' being present in
diff --git a/tornado/test/asyncio_test.py b/tornado/test/asyncio_test.py
index a7c7564964..3f9f3389a2 100644
--- a/tornado/test/asyncio_test.py
+++ b/tornado/test/asyncio_test.py
@@ -10,25 +10,20 @@
# License for the specific language governing permissions and limitations
# under the License.
-from __future__ import absolute_import, division, print_function
+import asyncio
+import unittest
from concurrent.futures import ThreadPoolExecutor
from tornado import gen
from tornado.ioloop import IOLoop
+from tornado.platform.asyncio import (
+ AsyncIOLoop,
+ to_asyncio_future,
+ AnyThreadEventLoopPolicy,
+)
from tornado.testing import AsyncTestCase, gen_test
-from tornado.test.util import unittest, skipBefore33, skipBefore35, exec_test
-try:
- from tornado.platform.asyncio import asyncio
-except ImportError:
- asyncio = None
-else:
- from tornado.platform.asyncio import AsyncIOLoop, to_asyncio_future, AnyThreadEventLoopPolicy
- # This is used in dynamically-evaluated code, so silence pyflakes.
- to_asyncio_future
-
-@unittest.skipIf(asyncio is None, "asyncio module not present")
class AsyncIOLoopTest(AsyncTestCase):
def get_new_ioloop(self):
io_loop = AsyncIOLoop()
@@ -44,32 +39,28 @@ def test_asyncio_future(self):
# Test that we can yield an asyncio future from a tornado coroutine.
# Without 'yield from', we must wrap coroutines in ensure_future,
# which was introduced during Python 3.4, deprecating the prior "async".
- if hasattr(asyncio, 'ensure_future'):
+ if hasattr(asyncio, "ensure_future"):
ensure_future = asyncio.ensure_future
else:
# async is a reserved word in Python 3.7
- ensure_future = getattr(asyncio, 'async')
+ ensure_future = getattr(asyncio, "async")
x = yield ensure_future(
- asyncio.get_event_loop().run_in_executor(None, lambda: 42))
+ asyncio.get_event_loop().run_in_executor(None, lambda: 42)
+ )
self.assertEqual(x, 42)
- @skipBefore33
@gen_test
def test_asyncio_yield_from(self):
- # Test that we can use asyncio coroutines with 'yield from'
- # instead of asyncio.async(). This requires python 3.3 syntax.
- namespace = exec_test(globals(), locals(), """
@gen.coroutine
def f():
event_loop = asyncio.get_event_loop()
x = yield from event_loop.run_in_executor(None, lambda: 42)
return x
- """)
- result = yield namespace['f']()
+
+ result = yield f()
self.assertEqual(result, 42)
- @skipBefore35
def test_asyncio_adapter(self):
# This test demonstrates that when using the asyncio coroutine
# runner (i.e. run_until_complete), the to_asyncio_future
@@ -79,50 +70,44 @@ def test_asyncio_adapter(self):
def tornado_coroutine():
yield gen.moment
raise gen.Return(42)
- native_coroutine_without_adapter = exec_test(globals(), locals(), """
+
async def native_coroutine_without_adapter():
return await tornado_coroutine()
- """)["native_coroutine_without_adapter"]
- native_coroutine_with_adapter = exec_test(globals(), locals(), """
async def native_coroutine_with_adapter():
return await to_asyncio_future(tornado_coroutine())
- """)["native_coroutine_with_adapter"]
# Use the adapter, but two degrees from the tornado coroutine.
- native_coroutine_with_adapter2 = exec_test(globals(), locals(), """
async def native_coroutine_with_adapter2():
return await to_asyncio_future(native_coroutine_without_adapter())
- """)["native_coroutine_with_adapter2"]
# Tornado supports native coroutines both with and without adapters
- self.assertEqual(
- self.io_loop.run_sync(native_coroutine_without_adapter),
- 42)
- self.assertEqual(
- self.io_loop.run_sync(native_coroutine_with_adapter),
- 42)
- self.assertEqual(
- self.io_loop.run_sync(native_coroutine_with_adapter2),
- 42)
+ self.assertEqual(self.io_loop.run_sync(native_coroutine_without_adapter), 42)
+ self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter), 42)
+ self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter2), 42)
# Asyncio only supports coroutines that yield asyncio-compatible
# Futures (which our Future is since 5.0).
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
- native_coroutine_without_adapter()),
- 42)
+ native_coroutine_without_adapter()
+ ),
+ 42,
+ )
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
- native_coroutine_with_adapter()),
- 42)
+ native_coroutine_with_adapter()
+ ),
+ 42,
+ )
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
- native_coroutine_with_adapter2()),
- 42)
+ native_coroutine_with_adapter2()
+ ),
+ 42,
+ )
-@unittest.skipIf(asyncio is None, "asyncio module not present")
class LeakTest(unittest.TestCase):
def setUp(self):
# Trigger a cleanup of the mapping so we start with a clean slate.
@@ -160,7 +145,6 @@ def test_asyncio_close_leak(self):
self.assertEqual(new_count, 1)
-@unittest.skipIf(asyncio is None, "asyncio module not present")
class AnyThreadEventLoopPolicyTest(unittest.TestCase):
def setUp(self):
self.orig_policy = asyncio.get_event_loop_policy()
@@ -182,22 +166,22 @@ def get_and_close_event_loop():
loop = asyncio.get_event_loop()
loop.close()
return loop
+
future = self.executor.submit(get_and_close_event_loop)
return future.result()
def run_policy_test(self, accessor, expected_type):
# With the default policy, non-main threads don't get an event
# loop.
- self.assertRaises((RuntimeError, AssertionError),
- self.executor.submit(accessor).result)
+ self.assertRaises(
+ (RuntimeError, AssertionError), self.executor.submit(accessor).result
+ )
# Set the policy and we can get a loop.
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
- self.assertIsInstance(
- self.executor.submit(accessor).result(),
- expected_type)
+ self.assertIsInstance(self.executor.submit(accessor).result(), expected_type)
# Clean up to silence leak warnings. Always use asyncio since
# IOLoop doesn't (currently) close the underlying loop.
- self.executor.submit(lambda: asyncio.get_event_loop().close()).result()
+ self.executor.submit(lambda: asyncio.get_event_loop().close()).result() # type: ignore
def test_asyncio_accessor(self):
self.run_policy_test(asyncio.get_event_loop, asyncio.AbstractEventLoop)
diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py
index 0d26a35afc..8de863eb21 100644
--- a/tornado/test/auth_test.py
+++ b/tornado/test/auth_test.py
@@ -3,139 +3,92 @@
# and ensure that it doesn't blow up (e.g. with unicode/bytes issues in
# python 3)
-
-from __future__ import absolute_import, division, print_function
-
-import warnings
+import unittest
from tornado.auth import (
- AuthError, OpenIdMixin, OAuthMixin, OAuth2Mixin,
- GoogleOAuth2Mixin, FacebookGraphMixin, TwitterMixin,
+ OpenIdMixin,
+ OAuthMixin,
+ OAuth2Mixin,
+ GoogleOAuth2Mixin,
+ FacebookGraphMixin,
+ TwitterMixin,
)
-from tornado.concurrent import Future
from tornado.escape import json_decode
from tornado import gen
+from tornado.httpclient import HTTPClientError
from tornado.httputil import url_concat
-from tornado.log import gen_log
+from tornado.log import app_log
from tornado.testing import AsyncHTTPTestCase, ExpectLog
-from tornado.test.util import ignore_deprecation
-from tornado.web import RequestHandler, Application, asynchronous, HTTPError
-
-
-class OpenIdClientLoginHandlerLegacy(RequestHandler, OpenIdMixin):
- def initialize(self, test):
- self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
+from tornado.web import RequestHandler, Application, HTTPError
- @asynchronous
- def get(self):
- if self.get_argument('openid.mode', None):
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', DeprecationWarning)
- self.get_authenticated_user(
- self.on_user, http_client=self.settings['http_client'])
- return
- res = self.authenticate_redirect()
- assert isinstance(res, Future)
- assert res.done()
-
- def on_user(self, user):
- if user is None:
- raise Exception("user is None")
- self.finish(user)
+try:
+ from unittest import mock
+except ImportError:
+ mock = None # type: ignore
class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
def initialize(self, test):
- self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
+ self._OPENID_ENDPOINT = test.get_url("/openid/server/authenticate")
@gen.coroutine
def get(self):
- if self.get_argument('openid.mode', None):
- user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
+ if self.get_argument("openid.mode", None):
+ user = yield self.get_authenticated_user(
+ http_client=self.settings["http_client"]
+ )
if user is None:
raise Exception("user is None")
self.finish(user)
return
- res = self.authenticate_redirect()
- assert isinstance(res, Future)
- assert res.done()
+ res = self.authenticate_redirect() # type: ignore
+ assert res is None
class OpenIdServerAuthenticateHandler(RequestHandler):
def post(self):
- if self.get_argument('openid.mode') != 'check_authentication':
+ if self.get_argument("openid.mode") != "check_authentication":
raise Exception("incorrect openid.mode %r")
- self.write('is_valid:true')
-
-
-class OAuth1ClientLoginHandlerLegacy(RequestHandler, OAuthMixin):
- def initialize(self, test, version):
- self._OAUTH_VERSION = version
- self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
- self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
- self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
-
- def _oauth_consumer_token(self):
- return dict(key='asdf', secret='qwer')
-
- @asynchronous
- def get(self):
- if self.get_argument('oauth_token', None):
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', DeprecationWarning)
- self.get_authenticated_user(
- self.on_user, http_client=self.settings['http_client'])
- return
- res = self.authorize_redirect(http_client=self.settings['http_client'])
- assert isinstance(res, Future)
-
- def on_user(self, user):
- if user is None:
- raise Exception("user is None")
- self.finish(user)
-
- def _oauth_get_user(self, access_token, callback):
- if self.get_argument('fail_in_get_user', None):
- raise Exception("failing in get_user")
- if access_token != dict(key='uiop', secret='5678'):
- raise Exception("incorrect access token %r" % access_token)
- callback(dict(email='foo@example.com'))
+ self.write("is_valid:true")
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
def initialize(self, test, version):
self._OAUTH_VERSION = version
- self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
- self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
- self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
+ self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token")
+ self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize")
+ self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/oauth1/server/access_token")
def _oauth_consumer_token(self):
- return dict(key='asdf', secret='qwer')
+ return dict(key="asdf", secret="qwer")
@gen.coroutine
def get(self):
- if self.get_argument('oauth_token', None):
- user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
+ if self.get_argument("oauth_token", None):
+ user = yield self.get_authenticated_user(
+ http_client=self.settings["http_client"]
+ )
if user is None:
raise Exception("user is None")
self.finish(user)
return
- yield self.authorize_redirect(http_client=self.settings['http_client'])
+ yield self.authorize_redirect(http_client=self.settings["http_client"])
@gen.coroutine
def _oauth_get_user_future(self, access_token):
- if self.get_argument('fail_in_get_user', None):
+ if self.get_argument("fail_in_get_user", None):
raise Exception("failing in get_user")
- if access_token != dict(key='uiop', secret='5678'):
+ if access_token != dict(key="uiop", secret="5678"):
raise Exception("incorrect access token %r" % access_token)
- return dict(email='foo@example.com')
+ return dict(email="foo@example.com")
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
"""Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
+
@gen.coroutine
def get(self):
- if self.get_argument('oauth_token', None):
+ if self.get_argument("oauth_token", None):
# Ensure that any exceptions are set on the returned Future,
# not simply thrown into the surrounding StackContext.
try:
@@ -152,41 +105,41 @@ def initialize(self, version):
self._OAUTH_VERSION = version
def _oauth_consumer_token(self):
- return dict(key='asdf', secret='qwer')
+ return dict(key="asdf", secret="qwer")
def get(self):
params = self._oauth_request_parameters(
- 'http://www.example.com/api/asdf',
- dict(key='uiop', secret='5678'),
- parameters=dict(foo='bar'))
+ "http://www.example.com/api/asdf",
+ dict(key="uiop", secret="5678"),
+ parameters=dict(foo="bar"),
+ )
self.write(params)
class OAuth1ServerRequestTokenHandler(RequestHandler):
def get(self):
- self.write('oauth_token=zxcv&oauth_token_secret=1234')
+ self.write("oauth_token=zxcv&oauth_token_secret=1234")
class OAuth1ServerAccessTokenHandler(RequestHandler):
def get(self):
- self.write('oauth_token=uiop&oauth_token_secret=5678')
+ self.write("oauth_token=uiop&oauth_token_secret=5678")
class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
def initialize(self, test):
- self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize')
+ self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth2/server/authorize")
def get(self):
- res = self.authorize_redirect()
- assert isinstance(res, Future)
- assert res.done()
+ res = self.authorize_redirect() # type: ignore
+ assert res is None
class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin):
def initialize(self, test):
- self._OAUTH_AUTHORIZE_URL = test.get_url('/facebook/server/authorize')
- self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/facebook/server/access_token')
- self._FACEBOOK_BASE_URL = test.get_url('/facebook/server')
+ self._OAUTH_AUTHORIZE_URL = test.get_url("/facebook/server/authorize")
+ self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/facebook/server/access_token")
+ self._FACEBOOK_BASE_URL = test.get_url("/facebook/server")
@gen.coroutine
def get(self):
@@ -195,13 +148,15 @@ def get(self):
redirect_uri=self.request.full_url(),
client_id=self.settings["facebook_api_key"],
client_secret=self.settings["facebook_secret"],
- code=self.get_argument("code"))
+ code=self.get_argument("code"),
+ )
self.write(user)
else:
- yield self.authorize_redirect(
+ self.authorize_redirect(
redirect_uri=self.request.full_url(),
client_id=self.settings["facebook_api_key"],
- extra_params={"scope": "read_stream,offline_access"})
+ extra_params={"scope": "read_stream,offline_access"},
+ )
class FacebookServerAccessTokenHandler(RequestHandler):
@@ -211,35 +166,36 @@ def get(self):
class FacebookServerMeHandler(RequestHandler):
def get(self):
- self.write('{}')
+ self.write("{}")
class TwitterClientHandler(RequestHandler, TwitterMixin):
def initialize(self, test):
- self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
- self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/twitter/server/access_token')
- self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
- self._TWITTER_BASE_URL = test.get_url('/twitter/api')
+ self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token")
+ self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/twitter/server/access_token")
+ self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize")
+ self._OAUTH_AUTHENTICATE_URL = test.get_url("/twitter/server/authenticate")
+ self._TWITTER_BASE_URL = test.get_url("/twitter/api")
def get_auth_http_client(self):
- return self.settings['http_client']
+ return self.settings["http_client"]
-class TwitterClientLoginHandlerLegacy(TwitterClientHandler):
- @asynchronous
+class TwitterClientLoginHandler(TwitterClientHandler):
+ @gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
- self.get_authenticated_user(self.on_user)
+ user = yield self.get_authenticated_user()
+ if user is None:
+ raise Exception("user is None")
+ self.finish(user)
return
- self.authorize_redirect()
-
- def on_user(self, user):
- if user is None:
- raise Exception("user is None")
- self.finish(user)
+ yield self.authorize_redirect()
-class TwitterClientLoginHandler(TwitterClientHandler):
+class TwitterClientAuthenticateHandler(TwitterClientHandler):
+ # Like TwitterClientLoginHandler, but uses authenticate_redirect
+ # instead of authorize_redirect.
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
@@ -248,21 +204,7 @@ def get(self):
raise Exception("user is None")
self.finish(user)
return
- yield self.authorize_redirect()
-
-
-class TwitterClientLoginGenEngineHandler(TwitterClientHandler):
- with ignore_deprecation():
- @asynchronous
- @gen.engine
- def get(self):
- if self.get_argument("oauth_token", None):
- user = yield self.get_authenticated_user()
- self.finish(user)
- else:
- # Old style: with @gen.engine we can ignore the Future from
- # authorize_redirect.
- self.authorize_redirect()
+ yield self.authenticate_redirect()
class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler):
@@ -277,25 +219,6 @@ def get(self):
yield self.authorize_redirect()
-class TwitterClientShowUserHandlerLegacy(TwitterClientHandler):
- with ignore_deprecation():
- @asynchronous
- @gen.engine
- def get(self):
- # TODO: would be nice to go through the login flow instead of
- # cheating with a hard-coded access token.
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', DeprecationWarning)
- response = yield gen.Task(self.twitter_request,
- '/users/show/%s' % self.get_argument('name'),
- access_token=dict(key='hjkl', secret='vbnm'))
- if response is None:
- self.set_status(500)
- self.finish('error from twitter request')
- else:
- self.finish(response)
-
-
class TwitterClientShowUserHandler(TwitterClientHandler):
@gen.coroutine
def get(self):
@@ -303,44 +226,47 @@ def get(self):
# cheating with a hard-coded access token.
try:
response = yield self.twitter_request(
- '/users/show/%s' % self.get_argument('name'),
- access_token=dict(key='hjkl', secret='vbnm'))
- except AuthError:
+ "/users/show/%s" % self.get_argument("name"),
+ access_token=dict(key="hjkl", secret="vbnm"),
+ )
+ except HTTPClientError:
+ # TODO(bdarnell): Should we catch HTTP errors and
+ # transform some of them (like 403s) into AuthError?
self.set_status(500)
- self.finish('error from twitter request')
+ self.finish("error from twitter request")
else:
self.finish(response)
class TwitterServerAccessTokenHandler(RequestHandler):
def get(self):
- self.write('oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo')
+ self.write("oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo")
class TwitterServerShowUserHandler(RequestHandler):
def get(self, screen_name):
- if screen_name == 'error':
+ if screen_name == "error":
raise HTTPError(500)
- assert 'oauth_nonce' in self.request.arguments
- assert 'oauth_timestamp' in self.request.arguments
- assert 'oauth_signature' in self.request.arguments
- assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
- assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
- assert self.get_argument('oauth_version') == '1.0'
- assert self.get_argument('oauth_token') == 'hjkl'
+ assert "oauth_nonce" in self.request.arguments
+ assert "oauth_timestamp" in self.request.arguments
+ assert "oauth_signature" in self.request.arguments
+ assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key"
+ assert self.get_argument("oauth_signature_method") == "HMAC-SHA1"
+ assert self.get_argument("oauth_version") == "1.0"
+ assert self.get_argument("oauth_token") == "hjkl"
self.write(dict(screen_name=screen_name, name=screen_name.capitalize()))
class TwitterServerVerifyCredentialsHandler(RequestHandler):
def get(self):
- assert 'oauth_nonce' in self.request.arguments
- assert 'oauth_timestamp' in self.request.arguments
- assert 'oauth_signature' in self.request.arguments
- assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
- assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
- assert self.get_argument('oauth_version') == '1.0'
- assert self.get_argument('oauth_token') == 'hjkl'
- self.write(dict(screen_name='foo', name='Foo'))
+ assert "oauth_nonce" in self.request.arguments
+ assert "oauth_timestamp" in self.request.arguments
+ assert "oauth_signature" in self.request.arguments
+ assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key"
+ assert self.get_argument("oauth_signature_method") == "HMAC-SHA1"
+ assert self.get_argument("oauth_version") == "1.0"
+ assert self.get_argument("oauth_token") == "hjkl"
+ self.write(dict(screen_name="foo", name="Foo"))
class AuthTest(AsyncHTTPTestCase):
@@ -348,322 +274,310 @@ def get_app(self):
return Application(
[
# test endpoints
- ('/legacy/openid/client/login', OpenIdClientLoginHandlerLegacy, dict(test=self)),
- ('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)),
- ('/legacy/oauth10/client/login', OAuth1ClientLoginHandlerLegacy,
- dict(test=self, version='1.0')),
- ('/oauth10/client/login', OAuth1ClientLoginHandler,
- dict(test=self, version='1.0')),
- ('/oauth10/client/request_params',
- OAuth1ClientRequestParametersHandler,
- dict(version='1.0')),
- ('/legacy/oauth10a/client/login', OAuth1ClientLoginHandlerLegacy,
- dict(test=self, version='1.0a')),
- ('/oauth10a/client/login', OAuth1ClientLoginHandler,
- dict(test=self, version='1.0a')),
- ('/oauth10a/client/login_coroutine',
- OAuth1ClientLoginCoroutineHandler,
- dict(test=self, version='1.0a')),
- ('/oauth10a/client/request_params',
- OAuth1ClientRequestParametersHandler,
- dict(version='1.0a')),
- ('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)),
-
- ('/facebook/client/login', FacebookClientLoginHandler, dict(test=self)),
-
- ('/legacy/twitter/client/login', TwitterClientLoginHandlerLegacy, dict(test=self)),
- ('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)),
- ('/twitter/client/login_gen_engine',
- TwitterClientLoginGenEngineHandler, dict(test=self)),
- ('/twitter/client/login_gen_coroutine',
- TwitterClientLoginGenCoroutineHandler, dict(test=self)),
- ('/legacy/twitter/client/show_user',
- TwitterClientShowUserHandlerLegacy, dict(test=self)),
- ('/twitter/client/show_user',
- TwitterClientShowUserHandler, dict(test=self)),
-
+ ("/openid/client/login", OpenIdClientLoginHandler, dict(test=self)),
+ (
+ "/oauth10/client/login",
+ OAuth1ClientLoginHandler,
+ dict(test=self, version="1.0"),
+ ),
+ (
+ "/oauth10/client/request_params",
+ OAuth1ClientRequestParametersHandler,
+ dict(version="1.0"),
+ ),
+ (
+ "/oauth10a/client/login",
+ OAuth1ClientLoginHandler,
+ dict(test=self, version="1.0a"),
+ ),
+ (
+ "/oauth10a/client/login_coroutine",
+ OAuth1ClientLoginCoroutineHandler,
+ dict(test=self, version="1.0a"),
+ ),
+ (
+ "/oauth10a/client/request_params",
+ OAuth1ClientRequestParametersHandler,
+ dict(version="1.0a"),
+ ),
+ ("/oauth2/client/login", OAuth2ClientLoginHandler, dict(test=self)),
+ ("/facebook/client/login", FacebookClientLoginHandler, dict(test=self)),
+ ("/twitter/client/login", TwitterClientLoginHandler, dict(test=self)),
+ (
+ "/twitter/client/authenticate",
+ TwitterClientAuthenticateHandler,
+ dict(test=self),
+ ),
+ (
+ "/twitter/client/login_gen_coroutine",
+ TwitterClientLoginGenCoroutineHandler,
+ dict(test=self),
+ ),
+ (
+ "/twitter/client/show_user",
+ TwitterClientShowUserHandler,
+ dict(test=self),
+ ),
# simulated servers
- ('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
- ('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler),
- ('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler),
-
- ('/facebook/server/access_token', FacebookServerAccessTokenHandler),
- ('/facebook/server/me', FacebookServerMeHandler),
- ('/twitter/server/access_token', TwitterServerAccessTokenHandler),
- (r'/twitter/api/users/show/(.*)\.json', TwitterServerShowUserHandler),
- (r'/twitter/api/account/verify_credentials\.json',
- TwitterServerVerifyCredentialsHandler),
+ ("/openid/server/authenticate", OpenIdServerAuthenticateHandler),
+ ("/oauth1/server/request_token", OAuth1ServerRequestTokenHandler),
+ ("/oauth1/server/access_token", OAuth1ServerAccessTokenHandler),
+ ("/facebook/server/access_token", FacebookServerAccessTokenHandler),
+ ("/facebook/server/me", FacebookServerMeHandler),
+ ("/twitter/server/access_token", TwitterServerAccessTokenHandler),
+ (r"/twitter/api/users/show/(.*)\.json", TwitterServerShowUserHandler),
+ (
+ r"/twitter/api/account/verify_credentials\.json",
+ TwitterServerVerifyCredentialsHandler,
+ ),
],
http_client=self.http_client,
- twitter_consumer_key='test_twitter_consumer_key',
- twitter_consumer_secret='test_twitter_consumer_secret',
- facebook_api_key='test_facebook_api_key',
- facebook_secret='test_facebook_secret')
-
- def test_openid_redirect_legacy(self):
- response = self.fetch('/legacy/openid/client/login', follow_redirects=False)
- self.assertEqual(response.code, 302)
- self.assertTrue(
- '/openid/server/authenticate?' in response.headers['Location'])
-
- def test_openid_get_user_legacy(self):
- response = self.fetch('/legacy/openid/client/login?openid.mode=blah'
- '&openid.ns.ax=http://openid.net/srv/ax/1.0'
- '&openid.ax.type.email=http://axschema.org/contact/email'
- '&openid.ax.value.email=foo@example.com')
- response.rethrow()
- parsed = json_decode(response.body)
- self.assertEqual(parsed["email"], "foo@example.com")
+ twitter_consumer_key="test_twitter_consumer_key",
+ twitter_consumer_secret="test_twitter_consumer_secret",
+ facebook_api_key="test_facebook_api_key",
+ facebook_secret="test_facebook_secret",
+ )
def test_openid_redirect(self):
- response = self.fetch('/openid/client/login', follow_redirects=False)
+ response = self.fetch("/openid/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue(
- '/openid/server/authenticate?' in response.headers['Location'])
+ self.assertTrue("/openid/server/authenticate?" in response.headers["Location"])
def test_openid_get_user(self):
- response = self.fetch('/openid/client/login?openid.mode=blah'
- '&openid.ns.ax=http://openid.net/srv/ax/1.0'
- '&openid.ax.type.email=http://axschema.org/contact/email'
- '&openid.ax.value.email=foo@example.com')
+ response = self.fetch(
+ "/openid/client/login?openid.mode=blah"
+ "&openid.ns.ax=http://openid.net/srv/ax/1.0"
+ "&openid.ax.type.email=http://axschema.org/contact/email"
+ "&openid.ax.value.email=foo@example.com"
+ )
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
- def test_oauth10_redirect_legacy(self):
- response = self.fetch('/legacy/oauth10/client/login', follow_redirects=False)
- self.assertEqual(response.code, 302)
- self.assertTrue(response.headers['Location'].endswith(
- '/oauth1/server/authorize?oauth_token=zxcv'))
- # the cookie is base64('zxcv')|base64('1234')
- self.assertTrue(
- '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
- response.headers['Set-Cookie'])
-
def test_oauth10_redirect(self):
- response = self.fetch('/oauth10/client/login', follow_redirects=False)
+ response = self.fetch("/oauth10/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue(response.headers['Location'].endswith(
- '/oauth1/server/authorize?oauth_token=zxcv'))
+ self.assertTrue(
+ response.headers["Location"].endswith(
+ "/oauth1/server/authorize?oauth_token=zxcv"
+ )
+ )
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
- '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
- response.headers['Set-Cookie'])
-
- def test_oauth10_get_user_legacy(self):
- with ignore_deprecation():
- response = self.fetch(
- '/legacy/oauth10/client/login?oauth_token=zxcv',
- headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
- response.rethrow()
- parsed = json_decode(response.body)
- self.assertEqual(parsed['email'], 'foo@example.com')
- self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
+ '_oauth_request_token="enhjdg==|MTIzNA=="'
+ in response.headers["Set-Cookie"],
+ response.headers["Set-Cookie"],
+ )
def test_oauth10_get_user(self):
response = self.fetch(
- '/oauth10/client/login?oauth_token=zxcv',
- headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
+ "/oauth10/client/login?oauth_token=zxcv",
+ headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
+ )
response.rethrow()
parsed = json_decode(response.body)
- self.assertEqual(parsed['email'], 'foo@example.com')
- self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
+ self.assertEqual(parsed["email"], "foo@example.com")
+ self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678"))
def test_oauth10_request_parameters(self):
- response = self.fetch('/oauth10/client/request_params')
- response.rethrow()
- parsed = json_decode(response.body)
- self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
- self.assertEqual(parsed['oauth_token'], 'uiop')
- self.assertTrue('oauth_nonce' in parsed)
- self.assertTrue('oauth_signature' in parsed)
-
- def test_oauth10a_redirect_legacy(self):
- response = self.fetch('/legacy/oauth10a/client/login', follow_redirects=False)
- self.assertEqual(response.code, 302)
- self.assertTrue(response.headers['Location'].endswith(
- '/oauth1/server/authorize?oauth_token=zxcv'))
- # the cookie is base64('zxcv')|base64('1234')
- self.assertTrue(
- '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
- response.headers['Set-Cookie'])
-
- def test_oauth10a_get_user_legacy(self):
- with ignore_deprecation():
- response = self.fetch(
- '/legacy/oauth10a/client/login?oauth_token=zxcv',
- headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
+ response = self.fetch("/oauth10/client/request_params")
response.rethrow()
parsed = json_decode(response.body)
- self.assertEqual(parsed['email'], 'foo@example.com')
- self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
+ self.assertEqual(parsed["oauth_consumer_key"], "asdf")
+ self.assertEqual(parsed["oauth_token"], "uiop")
+ self.assertTrue("oauth_nonce" in parsed)
+ self.assertTrue("oauth_signature" in parsed)
def test_oauth10a_redirect(self):
- response = self.fetch('/oauth10a/client/login', follow_redirects=False)
+ response = self.fetch("/oauth10a/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue(response.headers['Location'].endswith(
- '/oauth1/server/authorize?oauth_token=zxcv'))
+ self.assertTrue(
+ response.headers["Location"].endswith(
+ "/oauth1/server/authorize?oauth_token=zxcv"
+ )
+ )
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
- '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
- response.headers['Set-Cookie'])
+ '_oauth_request_token="enhjdg==|MTIzNA=="'
+ in response.headers["Set-Cookie"],
+ response.headers["Set-Cookie"],
+ )
+
+ @unittest.skipIf(mock is None, "mock package not present")
+ def test_oauth10a_redirect_error(self):
+ with mock.patch.object(OAuth1ServerRequestTokenHandler, "get") as get:
+ get.side_effect = Exception("boom")
+ with ExpectLog(app_log, "Uncaught exception"):
+ response = self.fetch("/oauth10a/client/login", follow_redirects=False)
+ self.assertEqual(response.code, 500)
def test_oauth10a_get_user(self):
response = self.fetch(
- '/oauth10a/client/login?oauth_token=zxcv',
- headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
+ "/oauth10a/client/login?oauth_token=zxcv",
+ headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
+ )
response.rethrow()
parsed = json_decode(response.body)
- self.assertEqual(parsed['email'], 'foo@example.com')
- self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
+ self.assertEqual(parsed["email"], "foo@example.com")
+ self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678"))
def test_oauth10a_request_parameters(self):
- response = self.fetch('/oauth10a/client/request_params')
+ response = self.fetch("/oauth10a/client/request_params")
response.rethrow()
parsed = json_decode(response.body)
- self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
- self.assertEqual(parsed['oauth_token'], 'uiop')
- self.assertTrue('oauth_nonce' in parsed)
- self.assertTrue('oauth_signature' in parsed)
+ self.assertEqual(parsed["oauth_consumer_key"], "asdf")
+ self.assertEqual(parsed["oauth_token"], "uiop")
+ self.assertTrue("oauth_nonce" in parsed)
+ self.assertTrue("oauth_signature" in parsed)
def test_oauth10a_get_user_coroutine_exception(self):
response = self.fetch(
- '/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true',
- headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
+ "/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true",
+ headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
+ )
self.assertEqual(response.code, 503)
def test_oauth2_redirect(self):
- response = self.fetch('/oauth2/client/login', follow_redirects=False)
+ response = self.fetch("/oauth2/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue('/oauth2/server/authorize?' in response.headers['Location'])
+ self.assertTrue("/oauth2/server/authorize?" in response.headers["Location"])
def test_facebook_login(self):
- response = self.fetch('/facebook/client/login', follow_redirects=False)
+ response = self.fetch("/facebook/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue('/facebook/server/authorize?' in response.headers['Location'])
- response = self.fetch('/facebook/client/login?code=1234', follow_redirects=False)
+ self.assertTrue("/facebook/server/authorize?" in response.headers["Location"])
+ response = self.fetch(
+ "/facebook/client/login?code=1234", follow_redirects=False
+ )
self.assertEqual(response.code, 200)
user = json_decode(response.body)
- self.assertEqual(user['access_token'], 'asdf')
- self.assertEqual(user['session_expires'], '3600')
+ self.assertEqual(user["access_token"], "asdf")
+ self.assertEqual(user["session_expires"], "3600")
def base_twitter_redirect(self, url):
# Same as test_oauth10a_redirect
response = self.fetch(url, follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue(response.headers['Location'].endswith(
- '/oauth1/server/authorize?oauth_token=zxcv'))
+ self.assertTrue(
+ response.headers["Location"].endswith(
+ "/oauth1/server/authorize?oauth_token=zxcv"
+ )
+ )
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
- '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
- response.headers['Set-Cookie'])
-
- def test_twitter_redirect_legacy(self):
- self.base_twitter_redirect('/legacy/twitter/client/login')
+ '_oauth_request_token="enhjdg==|MTIzNA=="'
+ in response.headers["Set-Cookie"],
+ response.headers["Set-Cookie"],
+ )
def test_twitter_redirect(self):
- self.base_twitter_redirect('/twitter/client/login')
-
- def test_twitter_redirect_gen_engine(self):
- self.base_twitter_redirect('/twitter/client/login_gen_engine')
+ self.base_twitter_redirect("/twitter/client/login")
def test_twitter_redirect_gen_coroutine(self):
- self.base_twitter_redirect('/twitter/client/login_gen_coroutine')
+ self.base_twitter_redirect("/twitter/client/login_gen_coroutine")
+
+ def test_twitter_authenticate_redirect(self):
+ response = self.fetch("/twitter/client/authenticate", follow_redirects=False)
+ self.assertEqual(response.code, 302)
+ self.assertTrue(
+ response.headers["Location"].endswith(
+ "/twitter/server/authenticate?oauth_token=zxcv"
+ ),
+ response.headers["Location"],
+ )
+ # the cookie is base64('zxcv')|base64('1234')
+ self.assertTrue(
+ '_oauth_request_token="enhjdg==|MTIzNA=="'
+ in response.headers["Set-Cookie"],
+ response.headers["Set-Cookie"],
+ )
def test_twitter_get_user(self):
response = self.fetch(
- '/twitter/client/login?oauth_token=zxcv',
- headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
+ "/twitter/client/login?oauth_token=zxcv",
+ headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
+ )
response.rethrow()
parsed = json_decode(response.body)
- self.assertEqual(parsed,
- {u'access_token': {u'key': u'hjkl',
- u'screen_name': u'foo',
- u'secret': u'vbnm'},
- u'name': u'Foo',
- u'screen_name': u'foo',
- u'username': u'foo'})
-
- def test_twitter_show_user_legacy(self):
- response = self.fetch('/legacy/twitter/client/show_user?name=somebody')
- response.rethrow()
- self.assertEqual(json_decode(response.body),
- {'name': 'Somebody', 'screen_name': 'somebody'})
-
- def test_twitter_show_user_error_legacy(self):
- with ExpectLog(gen_log, 'Error response HTTP 500'):
- response = self.fetch('/legacy/twitter/client/show_user?name=error')
- self.assertEqual(response.code, 500)
- self.assertEqual(response.body, b'error from twitter request')
+ self.assertEqual(
+ parsed,
+ {
+ u"access_token": {
+ u"key": u"hjkl",
+ u"screen_name": u"foo",
+ u"secret": u"vbnm",
+ },
+ u"name": u"Foo",
+ u"screen_name": u"foo",
+ u"username": u"foo",
+ },
+ )
def test_twitter_show_user(self):
- response = self.fetch('/twitter/client/show_user?name=somebody')
+ response = self.fetch("/twitter/client/show_user?name=somebody")
response.rethrow()
- self.assertEqual(json_decode(response.body),
- {'name': 'Somebody', 'screen_name': 'somebody'})
+ self.assertEqual(
+ json_decode(response.body), {"name": "Somebody", "screen_name": "somebody"}
+ )
def test_twitter_show_user_error(self):
- response = self.fetch('/twitter/client/show_user?name=error')
+ response = self.fetch("/twitter/client/show_user?name=error")
self.assertEqual(response.code, 500)
- self.assertEqual(response.body, b'error from twitter request')
+ self.assertEqual(response.body, b"error from twitter request")
class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin):
def initialize(self, test):
self.test = test
- self._OAUTH_REDIRECT_URI = test.get_url('/client/login')
- self._OAUTH_AUTHORIZE_URL = test.get_url('/google/oauth2/authorize')
- self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/google/oauth2/token')
+ self._OAUTH_REDIRECT_URI = test.get_url("/client/login")
+ self._OAUTH_AUTHORIZE_URL = test.get_url("/google/oauth2/authorize")
+ self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/google/oauth2/token")
@gen.coroutine
def get(self):
- code = self.get_argument('code', None)
+ code = self.get_argument("code", None)
if code is not None:
# retrieve authenticate google user
- access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI,
- code)
+ access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI, code)
user = yield self.oauth2_request(
self.test.get_url("/google/oauth2/userinfo"),
- access_token=access["access_token"])
+ access_token=access["access_token"],
+ )
# return the user and access token as json
user["access_token"] = access["access_token"]
self.write(user)
else:
- yield self.authorize_redirect(
+ self.authorize_redirect(
redirect_uri=self._OAUTH_REDIRECT_URI,
- client_id=self.settings['google_oauth']['key'],
- client_secret=self.settings['google_oauth']['secret'],
- scope=['profile', 'email'],
- response_type='code',
- extra_params={'prompt': 'select_account'})
+ client_id=self.settings["google_oauth"]["key"],
+ client_secret=self.settings["google_oauth"]["secret"],
+ scope=["profile", "email"],
+ response_type="code",
+ extra_params={"prompt": "select_account"},
+ )
class GoogleOAuth2AuthorizeHandler(RequestHandler):
def get(self):
# issue a fake auth code and redirect to redirect_uri
- code = 'fake-authorization-code'
- self.redirect(url_concat(self.get_argument('redirect_uri'),
- dict(code=code)))
+ code = "fake-authorization-code"
+ self.redirect(url_concat(self.get_argument("redirect_uri"), dict(code=code)))
class GoogleOAuth2TokenHandler(RequestHandler):
def post(self):
- assert self.get_argument('code') == 'fake-authorization-code'
+ assert self.get_argument("code") == "fake-authorization-code"
# issue a fake token
- self.finish({
- 'access_token': 'fake-access-token',
- 'expires_in': 'never-expires'
- })
+ self.finish(
+ {"access_token": "fake-access-token", "expires_in": "never-expires"}
+ )
class GoogleOAuth2UserinfoHandler(RequestHandler):
def get(self):
- assert self.get_argument('access_token') == 'fake-access-token'
+ assert self.get_argument("access_token") == "fake-access-token"
# return a fake user
- self.finish({
- 'name': 'Foo',
- 'email': 'foo@example.com'
- })
+ self.finish({"name": "Foo", "email": "foo@example.com"})
class GoogleOAuth2Test(AsyncHTTPTestCase):
@@ -671,22 +585,25 @@ def get_app(self):
return Application(
[
# test endpoints
- ('/client/login', GoogleLoginHandler, dict(test=self)),
-
+ ("/client/login", GoogleLoginHandler, dict(test=self)),
# simulated google authorization server endpoints
- ('/google/oauth2/authorize', GoogleOAuth2AuthorizeHandler),
- ('/google/oauth2/token', GoogleOAuth2TokenHandler),
- ('/google/oauth2/userinfo', GoogleOAuth2UserinfoHandler),
+ ("/google/oauth2/authorize", GoogleOAuth2AuthorizeHandler),
+ ("/google/oauth2/token", GoogleOAuth2TokenHandler),
+ ("/google/oauth2/userinfo", GoogleOAuth2UserinfoHandler),
],
google_oauth={
- "key": 'fake_google_client_id',
- "secret": 'fake_google_client_secret'
- })
+ "key": "fake_google_client_id",
+ "secret": "fake_google_client_secret",
+ },
+ )
def test_google_login(self):
- response = self.fetch('/client/login')
- self.assertDictEqual({
- u'name': u'Foo',
- u'email': u'foo@example.com',
- u'access_token': u'fake-access-token',
- }, json_decode(response.body))
+ response = self.fetch("/client/login")
+ self.assertDictEqual(
+ {
+ u"name": u"Foo",
+ u"email": u"foo@example.com",
+ u"access_token": u"fake-access-token",
+ },
+ json_decode(response.body),
+ )
diff --git a/tornado/test/autoreload_test.py b/tornado/test/autoreload_test.py
index 6a9729dbbe..be481e106f 100644
--- a/tornado/test/autoreload_test.py
+++ b/tornado/test/autoreload_test.py
@@ -1,14 +1,31 @@
-from __future__ import absolute_import, division, print_function
import os
+import shutil
import subprocess
from subprocess import Popen
import sys
from tempfile import mkdtemp
+import time
+import unittest
-from tornado.test.util import unittest
+class AutoreloadTest(unittest.TestCase):
+ def setUp(self):
+ self.path = mkdtemp()
+
+ def tearDown(self):
+ try:
+ shutil.rmtree(self.path)
+ except OSError:
+ # Windows disallows deleting files that are in use by
+ # another process, and even though we've waited for our
+ # child process below, it appears that its lock on these
+ # files is not guaranteed to be released by this point.
+ # Sleep and try again (once).
+ time.sleep(1)
+ shutil.rmtree(self.path)
-MAIN = """\
+ def test_reload_module(self):
+ main = """\
import os
import sys
@@ -24,25 +41,87 @@
autoreload._reload()
"""
-
-class AutoreloadTest(unittest.TestCase):
- def test_reload_module(self):
# Create temporary test application
- path = mkdtemp()
- os.mkdir(os.path.join(path, 'testapp'))
- open(os.path.join(path, 'testapp/__init__.py'), 'w').close()
- with open(os.path.join(path, 'testapp/__main__.py'), 'w') as f:
- f.write(MAIN)
+ os.mkdir(os.path.join(self.path, "testapp"))
+ open(os.path.join(self.path, "testapp/__init__.py"), "w").close()
+ with open(os.path.join(self.path, "testapp/__main__.py"), "w") as f:
+ f.write(main)
# Make sure the tornado module under test is available to the test
# application
pythonpath = os.getcwd()
- if 'PYTHONPATH' in os.environ:
- pythonpath += os.pathsep + os.environ['PYTHONPATH']
+ if "PYTHONPATH" in os.environ:
+ pythonpath += os.pathsep + os.environ["PYTHONPATH"]
p = Popen(
- [sys.executable, '-m', 'testapp'], stdout=subprocess.PIPE,
- cwd=path, env=dict(os.environ, PYTHONPATH=pythonpath),
- universal_newlines=True)
+ [sys.executable, "-m", "testapp"],
+ stdout=subprocess.PIPE,
+ cwd=self.path,
+ env=dict(os.environ, PYTHONPATH=pythonpath),
+ universal_newlines=True,
+ )
out = p.communicate()[0]
- self.assertEqual(out, 'Starting\nStarting\n')
+ self.assertEqual(out, "Starting\nStarting\n")
+
+ def test_reload_wrapper_preservation(self):
+ # This test verifies that when `python -m tornado.autoreload`
+ # is used on an application that also has an internal
+ # autoreload, the reload wrapper is preserved on restart.
+ main = """\
+import os
+import sys
+
+# This import will fail if path is not set up correctly
+import testapp
+
+if 'tornado.autoreload' not in sys.modules:
+ raise Exception('started without autoreload wrapper')
+
+import tornado.autoreload
+
+print('Starting')
+sys.stdout.flush()
+if 'TESTAPP_STARTED' not in os.environ:
+ os.environ['TESTAPP_STARTED'] = '1'
+ # Simulate an internal autoreload (one not caused
+ # by the wrapper).
+ tornado.autoreload._reload()
+else:
+ # Exit directly so autoreload doesn't catch it.
+ os._exit(0)
+"""
+
+ # Create temporary test application
+ os.mkdir(os.path.join(self.path, "testapp"))
+ init_file = os.path.join(self.path, "testapp", "__init__.py")
+ open(init_file, "w").close()
+ main_file = os.path.join(self.path, "testapp", "__main__.py")
+ with open(main_file, "w") as f:
+ f.write(main)
+
+ # Make sure the tornado module under test is available to the test
+ # application
+ pythonpath = os.getcwd()
+ if "PYTHONPATH" in os.environ:
+ pythonpath += os.pathsep + os.environ["PYTHONPATH"]
+
+ autoreload_proc = Popen(
+ [sys.executable, "-m", "tornado.autoreload", "-m", "testapp"],
+ stdout=subprocess.PIPE,
+ cwd=self.path,
+ env=dict(os.environ, PYTHONPATH=pythonpath),
+ universal_newlines=True,
+ )
+
+ # This timeout needs to be fairly generous for pypy due to jit
+ # warmup costs.
+ for i in range(40):
+ if autoreload_proc.poll() is not None:
+ break
+ time.sleep(0.1)
+ else:
+ autoreload_proc.kill()
+ raise Exception("subprocess failed to terminate")
+
+ out = autoreload_proc.communicate()[0]
+ self.assertEqual(out, "Starting\n" * 2)
diff --git a/tornado/test/concurrent_test.py b/tornado/test/concurrent_test.py
index 1df0532cc5..b121c6971a 100644
--- a/tornado/test/concurrent_test.py
+++ b/tornado/test/concurrent_test.py
@@ -12,39 +12,28 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
-from __future__ import absolute_import, division, print_function
-
-import gc
+from concurrent import futures
import logging
import re
import socket
-import sys
-import traceback
-import warnings
-
-from tornado.concurrent import (Future, return_future, ReturnValueIgnoredError,
- run_on_executor, future_set_result_unless_cancelled)
+import typing
+import unittest
+
+from tornado.concurrent import (
+ Future,
+ run_on_executor,
+ future_set_result_unless_cancelled,
+)
from tornado.escape import utf8, to_unicode
from tornado import gen
-from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
-from tornado.log import app_log
-from tornado import stack_context
from tornado.tcpserver import TCPServer
-from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
-from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation
-
-
-try:
- from concurrent import futures
-except ImportError:
- futures = None
+from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
class MiscFutureTest(AsyncTestCase):
-
def test_future_set_result_unless_cancelled(self):
- fut = Future()
+ fut = Future() # type: Future[int]
future_set_result_unless_cancelled(fut, 42)
self.assertEqual(fut.result(), 42)
self.assertFalse(fut.cancelled())
@@ -58,193 +47,6 @@ def test_future_set_result_unless_cancelled(self):
self.assertEqual(fut.result(), 42)
-class ReturnFutureTest(AsyncTestCase):
- @return_future
- def sync_future(self, callback):
- callback(42)
-
- @return_future
- def async_future(self, callback):
- self.io_loop.add_callback(callback, 42)
-
- @return_future
- def immediate_failure(self, callback):
- 1 / 0
-
- @return_future
- def delayed_failure(self, callback):
- self.io_loop.add_callback(lambda: 1 / 0)
-
- @return_future
- def return_value(self, callback):
- # Note that the result of both running the callback and returning
- # a value (or raising an exception) is unspecified; with current
- # implementations the last event prior to callback resolution wins.
- return 42
-
- @return_future
- def no_result_future(self, callback):
- callback()
-
- def test_immediate_failure(self):
- with self.assertRaises(ZeroDivisionError):
- # The caller sees the error just like a normal function.
- self.immediate_failure(callback=self.stop)
- # The callback is not run because the function failed synchronously.
- self.io_loop.add_timeout(self.io_loop.time() + 0.05, self.stop)
- result = self.wait()
- self.assertIs(result, None)
-
- def test_return_value(self):
- with self.assertRaises(ReturnValueIgnoredError):
- self.return_value(callback=self.stop)
-
- def test_callback_kw(self):
- with ignore_deprecation():
- future = self.sync_future(callback=self.stop)
- result = self.wait()
- self.assertEqual(result, 42)
- self.assertEqual(future.result(), 42)
-
- def test_callback_positional(self):
- # When the callback is passed in positionally, future_wrap shouldn't
- # add another callback in the kwargs.
- with ignore_deprecation():
- future = self.sync_future(self.stop)
- result = self.wait()
- self.assertEqual(result, 42)
- self.assertEqual(future.result(), 42)
-
- def test_no_callback(self):
- future = self.sync_future()
- self.assertEqual(future.result(), 42)
-
- def test_none_callback_kw(self):
- # explicitly pass None as callback
- future = self.sync_future(callback=None)
- self.assertEqual(future.result(), 42)
-
- def test_none_callback_pos(self):
- future = self.sync_future(None)
- self.assertEqual(future.result(), 42)
-
- def test_async_future(self):
- future = self.async_future()
- self.assertFalse(future.done())
- self.io_loop.add_future(future, self.stop)
- future2 = self.wait()
- self.assertIs(future, future2)
- self.assertEqual(future.result(), 42)
-
- @gen_test
- def test_async_future_gen(self):
- result = yield self.async_future()
- self.assertEqual(result, 42)
-
- def test_delayed_failure(self):
- future = self.delayed_failure()
- self.io_loop.add_future(future, self.stop)
- future2 = self.wait()
- self.assertIs(future, future2)
- with self.assertRaises(ZeroDivisionError):
- future.result()
-
- def test_kw_only_callback(self):
- @return_future
- def f(**kwargs):
- kwargs['callback'](42)
- future = f()
- self.assertEqual(future.result(), 42)
-
- def test_error_in_callback(self):
- with ignore_deprecation():
- self.sync_future(callback=lambda future: 1 / 0)
- # The exception gets caught by our StackContext and will be re-raised
- # when we wait.
- self.assertRaises(ZeroDivisionError, self.wait)
-
- def test_no_result_future(self):
- with ignore_deprecation():
- future = self.no_result_future(self.stop)
- result = self.wait()
- self.assertIs(result, None)
- # result of this future is undefined, but not an error
- future.result()
-
- def test_no_result_future_callback(self):
- with ignore_deprecation():
- future = self.no_result_future(callback=lambda: self.stop())
- result = self.wait()
- self.assertIs(result, None)
- future.result()
-
- @gen_test
- def test_future_traceback_legacy(self):
- with ignore_deprecation():
- @return_future
- @gen.engine
- def f(callback):
- yield gen.Task(self.io_loop.add_callback)
- try:
- 1 / 0
- except ZeroDivisionError:
- self.expected_frame = traceback.extract_tb(
- sys.exc_info()[2], limit=1)[0]
- raise
- try:
- yield f()
- self.fail("didn't get expected exception")
- except ZeroDivisionError:
- tb = traceback.extract_tb(sys.exc_info()[2])
- self.assertIn(self.expected_frame, tb)
-
- @gen_test
- def test_future_traceback(self):
- @gen.coroutine
- def f():
- yield gen.moment
- try:
- 1 / 0
- except ZeroDivisionError:
- self.expected_frame = traceback.extract_tb(
- sys.exc_info()[2], limit=1)[0]
- raise
- try:
- yield f()
- self.fail("didn't get expected exception")
- except ZeroDivisionError:
- tb = traceback.extract_tb(sys.exc_info()[2])
- self.assertIn(self.expected_frame, tb)
-
- @gen_test
- def test_uncaught_exception_log(self):
- if IOLoop.configured_class().__name__.endswith('AsyncIOLoop'):
- # Install an exception handler that mirrors our
- # non-asyncio logging behavior.
- def exc_handler(loop, context):
- app_log.error('%s: %s', context['message'],
- type(context.get('exception')))
- self.io_loop.asyncio_loop.set_exception_handler(exc_handler)
-
- @gen.coroutine
- def f():
- yield gen.moment
- 1 / 0
-
- g = f()
-
- with ExpectLog(app_log,
- "(?s)Future.* exception was never retrieved:"
- ".*ZeroDivisionError"):
- yield gen.moment
- yield gen.moment
- # For some reason, TwistedIOLoop and pypy3 need a third iteration
- # in order to drain references to the future
- yield gen.moment
- del g
- gc.collect() # for PyPy
-
-
# The following series of classes demonstrate and test various styles
# of use, with and without generators and futures.
@@ -271,79 +73,36 @@ def __init__(self, port):
self.port = port
def process_response(self, data):
- status, message = re.match('(.*)\t(.*)\n', to_unicode(data)).groups()
- if status == 'ok':
+ m = re.match("(.*)\t(.*)\n", to_unicode(data))
+ if m is None:
+ raise Exception("did not match")
+ status, message = m.groups()
+ if status == "ok":
return message
else:
raise CapError(message)
-class ManualCapClient(BaseCapClient):
- def capitalize(self, request_data, callback=None):
- logging.debug("capitalize")
- self.request_data = request_data
- self.stream = IOStream(socket.socket())
- self.stream.connect(('127.0.0.1', self.port),
- callback=self.handle_connect)
- self.future = Future()
- if callback is not None:
- self.future.add_done_callback(
- stack_context.wrap(lambda future: callback(future.result())))
- return self.future
-
- def handle_connect(self):
- logging.debug("handle_connect")
- self.stream.write(utf8(self.request_data + "\n"))
- self.stream.read_until(b'\n', callback=self.handle_read)
-
- def handle_read(self, data):
- logging.debug("handle_read")
- self.stream.close()
- try:
- self.future.set_result(self.process_response(data))
- except CapError as e:
- self.future.set_exception(e)
-
-
-class DecoratorCapClient(BaseCapClient):
- @return_future
- def capitalize(self, request_data, callback):
- logging.debug("capitalize")
- self.request_data = request_data
- self.stream = IOStream(socket.socket())
- self.stream.connect(('127.0.0.1', self.port),
- callback=self.handle_connect)
- self.callback = callback
-
- def handle_connect(self):
- logging.debug("handle_connect")
- self.stream.write(utf8(self.request_data + "\n"))
- self.stream.read_until(b'\n', callback=self.handle_read)
-
- def handle_read(self, data):
- logging.debug("handle_read")
- self.stream.close()
- self.callback(self.process_response(data))
-
-
class GeneratorCapClient(BaseCapClient):
@gen.coroutine
def capitalize(self, request_data):
- logging.debug('capitalize')
+ logging.debug("capitalize")
stream = IOStream(socket.socket())
- logging.debug('connecting')
- yield stream.connect(('127.0.0.1', self.port))
- stream.write(utf8(request_data + '\n'))
- logging.debug('reading')
- data = yield stream.read_until(b'\n')
- logging.debug('returning')
+ logging.debug("connecting")
+ yield stream.connect(("127.0.0.1", self.port))
+ stream.write(utf8(request_data + "\n"))
+ logging.debug("reading")
+ data = yield stream.read_until(b"\n")
+ logging.debug("returning")
stream.close()
raise gen.Return(self.process_response(data))
class ClientTestMixin(object):
+ client_class = None # type: typing.Callable
+
def setUp(self):
- super(ClientTestMixin, self).setUp() # type: ignore
+ super().setUp() # type: ignore
self.server = CapServer()
sock, port = bind_unused_port()
self.server.add_sockets([sock])
@@ -351,79 +110,41 @@ def setUp(self):
def tearDown(self):
self.server.stop()
- super(ClientTestMixin, self).tearDown() # type: ignore
+ super().tearDown() # type: ignore
- def test_callback(self):
- with ignore_deprecation():
- self.client.capitalize("hello", callback=self.stop)
- result = self.wait()
- self.assertEqual(result, "HELLO")
-
- def test_callback_error(self):
- with ignore_deprecation():
- self.client.capitalize("HELLO", callback=self.stop)
- self.assertRaisesRegexp(CapError, "already capitalized", self.wait)
-
- def test_future(self):
+ def test_future(self: typing.Any):
future = self.client.capitalize("hello")
self.io_loop.add_future(future, self.stop)
self.wait()
self.assertEqual(future.result(), "HELLO")
- def test_future_error(self):
+ def test_future_error(self: typing.Any):
future = self.client.capitalize("HELLO")
self.io_loop.add_future(future, self.stop)
self.wait()
- self.assertRaisesRegexp(CapError, "already capitalized", future.result)
+ self.assertRaisesRegexp(CapError, "already capitalized", future.result) # type: ignore
- def test_generator(self):
+ def test_generator(self: typing.Any):
@gen.coroutine
def f():
result = yield self.client.capitalize("hello")
self.assertEqual(result, "HELLO")
+
self.io_loop.run_sync(f)
- def test_generator_error(self):
+ def test_generator_error(self: typing.Any):
@gen.coroutine
def f():
with self.assertRaisesRegexp(CapError, "already capitalized"):
yield self.client.capitalize("HELLO")
- self.io_loop.run_sync(f)
-
-
-class ManualClientTest(ClientTestMixin, AsyncTestCase):
- client_class = ManualCapClient
-
- def setUp(self):
- self.warning_catcher = warnings.catch_warnings()
- self.warning_catcher.__enter__()
- warnings.simplefilter('ignore', DeprecationWarning)
- super(ManualClientTest, self).setUp()
-
- def tearDown(self):
- super(ManualClientTest, self).tearDown()
- self.warning_catcher.__exit__(None, None, None)
-
-class DecoratorClientTest(ClientTestMixin, AsyncTestCase):
- client_class = DecoratorCapClient
-
- def setUp(self):
- self.warning_catcher = warnings.catch_warnings()
- self.warning_catcher.__enter__()
- warnings.simplefilter('ignore', DeprecationWarning)
- super(DecoratorClientTest, self).setUp()
-
- def tearDown(self):
- super(DecoratorClientTest, self).tearDown()
- self.warning_catcher.__exit__(None, None, None)
+ self.io_loop.run_sync(f)
class GeneratorClientTest(ClientTestMixin, AsyncTestCase):
client_class = GeneratorCapClient
-@unittest.skipIf(futures is None, "concurrent.futures module not present")
class RunOnExecutorTest(AsyncTestCase):
@gen_test
def test_no_calling(self):
@@ -459,7 +180,7 @@ class Object(object):
def __init__(self):
self.__executor = futures.thread.ThreadPoolExecutor(1)
- @run_on_executor(executor='_Object__executor')
+ @run_on_executor(executor="_Object__executor")
def f(self):
return 42
@@ -467,7 +188,6 @@ def f(self):
answer = yield o.f()
self.assertEqual(answer, 42)
- @skipBefore35
@gen_test
def test_async_await(self):
class Object(object):
@@ -479,14 +199,14 @@ def f(self):
return 42
o = Object()
- namespace = exec_test(globals(), locals(), """
+
async def f():
answer = await o.f()
return answer
- """)
- result = yield namespace['f']()
+
+ result = yield f()
self.assertEqual(result, 42)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tornado/test/curl_httpclient_test.py b/tornado/test/curl_httpclient_test.py
index d0cfa979b4..99af293367 100644
--- a/tornado/test/curl_httpclient_test.py
+++ b/tornado/test/curl_httpclient_test.py
@@ -1,22 +1,16 @@
-# coding: utf-8
-from __future__ import absolute_import, division, print_function
-
from hashlib import md5
+import unittest
from tornado.escape import utf8
-from tornado.httpclient import HTTPRequest, HTTPClientError
-from tornado.locks import Event
-from tornado.stack_context import ExceptionStackContext
-from tornado.testing import AsyncHTTPTestCase, gen_test
+from tornado.testing import AsyncHTTPTestCase
from tornado.test import httpclient_test
-from tornado.test.util import unittest, ignore_deprecation
from tornado.web import Application, RequestHandler
try:
- import pycurl # type: ignore
+ import pycurl
except ImportError:
- pycurl = None
+ pycurl = None # type: ignore
if pycurl is not None:
from tornado.curl_httpclient import CurlAsyncHTTPClient
@@ -32,42 +26,48 @@ def get_http_client(self):
class DigestAuthHandler(RequestHandler):
+ def initialize(self, username, password):
+ self.username = username
+ self.password = password
+
def get(self):
- realm = 'test'
- opaque = 'asdf'
+ realm = "test"
+ opaque = "asdf"
# Real implementations would use a random nonce.
nonce = "1234"
- username = 'foo'
- password = 'bar'
- auth_header = self.request.headers.get('Authorization', None)
+ auth_header = self.request.headers.get("Authorization", None)
if auth_header is not None:
- auth_mode, params = auth_header.split(' ', 1)
- assert auth_mode == 'Digest'
+ auth_mode, params = auth_header.split(" ", 1)
+ assert auth_mode == "Digest"
param_dict = {}
- for pair in params.split(','):
- k, v = pair.strip().split('=', 1)
+ for pair in params.split(","):
+ k, v = pair.strip().split("=", 1)
if v[0] == '"' and v[-1] == '"':
v = v[1:-1]
param_dict[k] = v
- assert param_dict['realm'] == realm
- assert param_dict['opaque'] == opaque
- assert param_dict['nonce'] == nonce
- assert param_dict['username'] == username
- assert param_dict['uri'] == self.request.path
- h1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest()
- h2 = md5(utf8('%s:%s' % (self.request.method,
- self.request.path))).hexdigest()
- digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
- if digest == param_dict['response']:
- self.write('ok')
+ assert param_dict["realm"] == realm
+ assert param_dict["opaque"] == opaque
+ assert param_dict["nonce"] == nonce
+ assert param_dict["username"] == self.username
+ assert param_dict["uri"] == self.request.path
+ h1 = md5(
+ utf8("%s:%s:%s" % (self.username, realm, self.password))
+ ).hexdigest()
+ h2 = md5(
+ utf8("%s:%s" % (self.request.method, self.request.path))
+ ).hexdigest()
+ digest = md5(utf8("%s:%s:%s" % (h1, nonce, h2))).hexdigest()
+ if digest == param_dict["response"]:
+ self.write("ok")
else:
- self.write('fail')
+ self.write("fail")
else:
self.set_status(401)
- self.set_header('WWW-Authenticate',
- 'Digest realm="%s", nonce="%s", opaque="%s"' %
- (realm, nonce, opaque))
+ self.set_header(
+ "WWW-Authenticate",
+ 'Digest realm="%s", nonce="%s", opaque="%s"' % (realm, nonce, opaque),
+ )
class CustomReasonHandler(RequestHandler):
@@ -83,62 +83,47 @@ def get(self):
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
def setUp(self):
- super(CurlHTTPClientTestCase, self).setUp()
+ super().setUp()
self.http_client = self.create_client()
def get_app(self):
- return Application([
- ('/digest', DigestAuthHandler),
- ('/custom_reason', CustomReasonHandler),
- ('/custom_fail_reason', CustomFailReasonHandler),
- ])
+ return Application(
+ [
+ ("/digest", DigestAuthHandler, {"username": "foo", "password": "bar"}),
+ (
+ "/digest_non_ascii",
+ DigestAuthHandler,
+ {"username": "foo", "password": "barユ£"},
+ ),
+ ("/custom_reason", CustomReasonHandler),
+ ("/custom_fail_reason", CustomFailReasonHandler),
+ ]
+ )
def create_client(self, **kwargs):
- return CurlAsyncHTTPClient(force_instance=True,
- defaults=dict(allow_ipv6=False),
- **kwargs)
-
- @gen_test
- def test_prepare_curl_callback_stack_context(self):
- exc_info = []
- error_event = Event()
-
- def error_handler(typ, value, tb):
- exc_info.append((typ, value, tb))
- error_event.set()
- return True
-
- with ExceptionStackContext(error_handler):
- request = HTTPRequest(self.get_url('/custom_reason'),
- prepare_curl_callback=lambda curl: 1 / 0)
- yield [error_event.wait(), self.http_client.fetch(request)]
- self.assertEqual(1, len(exc_info))
- self.assertIs(exc_info[0][0], ZeroDivisionError)
+ return CurlAsyncHTTPClient(
+ force_instance=True, defaults=dict(allow_ipv6=False), **kwargs
+ )
def test_digest_auth(self):
- response = self.fetch('/digest', auth_mode='digest',
- auth_username='foo', auth_password='bar')
- self.assertEqual(response.body, b'ok')
+ response = self.fetch(
+ "/digest", auth_mode="digest", auth_username="foo", auth_password="bar"
+ )
+ self.assertEqual(response.body, b"ok")
def test_custom_reason(self):
- response = self.fetch('/custom_reason')
+ response = self.fetch("/custom_reason")
self.assertEqual(response.reason, "Custom reason")
def test_fail_custom_reason(self):
- response = self.fetch('/custom_fail_reason')
+ response = self.fetch("/custom_fail_reason")
self.assertEqual(str(response.error), "HTTP 400: Custom reason")
- def test_failed_setup(self):
- self.http_client = self.create_client(max_clients=1)
- for i in range(5):
- with ignore_deprecation():
- response = self.fetch(u'/ユニコード')
- self.assertIsNot(response.error, None)
-
- with self.assertRaises((UnicodeEncodeError, HTTPClientError)):
- # This raises UnicodeDecodeError on py3 and
- # HTTPClientError(404) on py2. The main motivation of
- # this test is to ensure that the UnicodeEncodeError
- # during the setup phase doesn't lead the request to
- # be dropped on the floor.
- response = self.fetch(u'/ユニコード', raise_error=True)
+ def test_digest_auth_non_ascii(self):
+ response = self.fetch(
+ "/digest_non_ascii",
+ auth_mode="digest",
+ auth_username="foo",
+ auth_password="barユ£",
+ )
+ self.assertEqual(response.body, b"ok")
diff --git a/tornado/test/escape_test.py b/tornado/test/escape_test.py
index f2f2902a0a..d8f95e426e 100644
--- a/tornado/test/escape_test.py
+++ b/tornado/test/escape_test.py
@@ -1,139 +1,212 @@
-from __future__ import absolute_import, division, print_function
+import unittest
import tornado.escape
from tornado.escape import (
- utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape,
- to_unicode, json_decode, json_encode, squeeze, recursive_unicode,
+ utf8,
+ xhtml_escape,
+ xhtml_unescape,
+ url_escape,
+ url_unescape,
+ to_unicode,
+ json_decode,
+ json_encode,
+ squeeze,
+ recursive_unicode,
)
from tornado.util import unicode_type
-from tornado.test.util import unittest
+
+from typing import List, Tuple, Union, Dict, Any # noqa: F401
linkify_tests = [
# (input, linkify_kwargs, expected_output)
-
- ("hello http://world.com/!", {},
- u'hello http://world.com/ !'),
-
- ("hello http://world.com/with?param=true&stuff=yes", {},
- u'hello http://world.com/with?param=true&stuff=yes '), # noqa: E501
-
+ (
+ "hello http://world.com/!",
+ {},
+ u'hello http://world.com/ !',
+ ),
+ (
+ "hello http://world.com/with?param=true&stuff=yes",
+ {},
+ u'hello http://world.com/with?param=true&stuff=yes ', # noqa: E501
+ ),
# an opened paren followed by many chars killed Gruber's regex
- ("http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", {},
- u'http://url.com/w (aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'), # noqa: E501
-
+ (
+ "http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
+ {},
+ u'http://url.com/w (aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', # noqa: E501
+ ),
# as did too many dots at the end
- ("http://url.com/withmany.......................................", {},
- u'http://url.com/withmany .......................................'), # noqa: E501
-
- ("http://url.com/withmany((((((((((((((((((((((((((((((((((a)", {},
- u'http://url.com/withmany ((((((((((((((((((((((((((((((((((a)'), # noqa: E501
-
+ (
+ "http://url.com/withmany.......................................",
+ {},
+ u'http://url.com/withmany .......................................', # noqa: E501
+ ),
+ (
+ "http://url.com/withmany((((((((((((((((((((((((((((((((((a)",
+ {},
+ u'http://url.com/withmany ((((((((((((((((((((((((((((((((((a)', # noqa: E501
+ ),
# some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls
# plus a fex extras (such as multiple parentheses).
- ("http://foo.com/blah_blah", {},
- u'http://foo.com/blah_blah '),
-
- ("http://foo.com/blah_blah/", {},
- u'http://foo.com/blah_blah/ '),
-
- ("(Something like http://foo.com/blah_blah)", {},
- u'(Something like http://foo.com/blah_blah )'),
-
- ("http://foo.com/blah_blah_(wikipedia)", {},
- u'http://foo.com/blah_blah_(wikipedia) '),
-
- ("http://foo.com/blah_(blah)_(wikipedia)_blah", {},
- u'http://foo.com/blah_(blah)_(wikipedia)_blah '), # noqa: E501
-
- ("(Something like http://foo.com/blah_blah_(wikipedia))", {},
- u'(Something like http://foo.com/blah_blah_(wikipedia) )'), # noqa: E501
-
- ("http://foo.com/blah_blah.", {},
- u'http://foo.com/blah_blah .'),
-
- ("http://foo.com/blah_blah/.", {},
- u'http://foo.com/blah_blah/ .'),
-
- ("", {},
- u'<http://foo.com/blah_blah >'),
-
- (" ", {},
- u'<http://foo.com/blah_blah/ >'),
-
- ("http://foo.com/blah_blah,", {},
- u'http://foo.com/blah_blah ,'),
-
- ("http://www.example.com/wpstyle/?p=364.", {},
- u'http://www.example.com/wpstyle/?p=364 .'),
-
- ("rdar://1234",
- {"permitted_protocols": ["http", "rdar"]},
- u'rdar://1234 '),
-
- ("rdar:/1234",
- {"permitted_protocols": ["rdar"]},
- u'rdar:/1234 '),
-
- ("http://userid:password@example.com:8080", {},
- u'http://userid:password@example.com:8080 '), # noqa: E501
-
- ("http://userid@example.com", {},
- u'http://userid@example.com '),
-
- ("http://userid@example.com:8080", {},
- u'http://userid@example.com:8080 '),
-
- ("http://userid:password@example.com", {},
- u'http://userid:password@example.com '),
-
- ("message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
- {"permitted_protocols": ["http", "message"]},
- u''
- u'message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e '),
-
- (u"http://\u27a1.ws/\u4a39", {},
- u'http://\u27a1.ws/\u4a39 '),
-
- ("http://example.com ", {},
- u'<tag>http://example.com </tag>'),
-
- ("Just a www.example.com link.", {},
- u'Just a www.example.com link.'),
-
- ("Just a www.example.com link.",
- {"require_protocol": True},
- u'Just a www.example.com link.'),
-
- ("A http://reallylong.com/link/that/exceedsthelenglimit.html",
- {"require_protocol": True, "shorten": True},
- u'A http://reallylong.com/link... '), # noqa: E501
-
- ("A http://reallylongdomainnamethatwillbetoolong.com/hi!",
- {"shorten": True},
- u'A http://reallylongdomainnametha... !'), # noqa: E501
-
- ("A file:///passwords.txt and http://web.com link", {},
- u'A file:///passwords.txt and http://web.com link'),
-
- ("A file:///passwords.txt and http://web.com link",
- {"permitted_protocols": ["file"]},
- u'A file:///passwords.txt and http://web.com link'),
-
- ("www.external-link.com",
- {"extra_params": 'rel="nofollow" class="external"'},
- u'www.external-link.com '), # noqa: E501
-
- ("www.external-link.com and www.internal-link.com/blogs extra",
- {"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'}, # noqa: E501
- u'www.external-link.com ' # noqa: E501
- u' and www.internal-link.com/blogs extra'), # noqa: E501
-
- ("www.external-link.com",
- {"extra_params": lambda href: ' rel="nofollow" class="external" '},
- u'www.external-link.com '), # noqa: E501
-]
+ (
+ "http://foo.com/blah_blah",
+ {},
+ u'http://foo.com/blah_blah ',
+ ),
+ (
+ "http://foo.com/blah_blah/",
+ {},
+ u'http://foo.com/blah_blah/ ',
+ ),
+ (
+ "(Something like http://foo.com/blah_blah)",
+ {},
+ u'(Something like http://foo.com/blah_blah )',
+ ),
+ (
+ "http://foo.com/blah_blah_(wikipedia)",
+ {},
+ u'http://foo.com/blah_blah_(wikipedia) ',
+ ),
+ (
+ "http://foo.com/blah_(blah)_(wikipedia)_blah",
+ {},
+ u'http://foo.com/blah_(blah)_(wikipedia)_blah ', # noqa: E501
+ ),
+ (
+ "(Something like http://foo.com/blah_blah_(wikipedia))",
+ {},
+ u'(Something like http://foo.com/blah_blah_(wikipedia) )', # noqa: E501
+ ),
+ (
+ "http://foo.com/blah_blah.",
+ {},
+ u'http://foo.com/blah_blah .',
+ ),
+ (
+ "http://foo.com/blah_blah/.",
+ {},
+ u'http://foo.com/blah_blah/ .',
+ ),
+ (
+ "",
+ {},
+ u'<http://foo.com/blah_blah >',
+ ),
+ (
+ " ",
+ {},
+ u'<http://foo.com/blah_blah/ >',
+ ),
+ (
+ "http://foo.com/blah_blah,",
+ {},
+ u'http://foo.com/blah_blah ,',
+ ),
+ (
+ "http://www.example.com/wpstyle/?p=364.",
+ {},
+ u'http://www.example.com/wpstyle/?p=364 .', # noqa: E501
+ ),
+ (
+ "rdar://1234",
+ {"permitted_protocols": ["http", "rdar"]},
+ u'rdar://1234 ',
+ ),
+ (
+ "rdar:/1234",
+ {"permitted_protocols": ["rdar"]},
+ u'rdar:/1234 ',
+ ),
+ (
+ "http://userid:password@example.com:8080",
+ {},
+ u'http://userid:password@example.com:8080 ', # noqa: E501
+ ),
+ (
+ "http://userid@example.com",
+ {},
+ u'http://userid@example.com ',
+ ),
+ (
+ "http://userid@example.com:8080",
+ {},
+ u'http://userid@example.com:8080 ',
+ ),
+ (
+ "http://userid:password@example.com",
+ {},
+ u'http://userid:password@example.com ',
+ ),
+ (
+ "message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
+ {"permitted_protocols": ["http", "message"]},
+ u''
+ u"message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e ",
+ ),
+ (
+ u"http://\u27a1.ws/\u4a39",
+ {},
+ u'http://\u27a1.ws/\u4a39 ',
+ ),
+ (
+ "http://example.com ",
+ {},
+ u'<tag>http://example.com </tag>',
+ ),
+ (
+ "Just a www.example.com link.",
+ {},
+ u'Just a www.example.com link.',
+ ),
+ (
+ "Just a www.example.com link.",
+ {"require_protocol": True},
+ u"Just a www.example.com link.",
+ ),
+ (
+ "A http://reallylong.com/link/that/exceedsthelenglimit.html",
+ {"require_protocol": True, "shorten": True},
+ u'A http://reallylong.com/link... ', # noqa: E501
+ ),
+ (
+ "A http://reallylongdomainnamethatwillbetoolong.com/hi!",
+ {"shorten": True},
+ u'A http://reallylongdomainnametha... !', # noqa: E501
+ ),
+ (
+ "A file:///passwords.txt and http://web.com link",
+ {},
+ u'A file:///passwords.txt and http://web.com link',
+ ),
+ (
+ "A file:///passwords.txt and http://web.com link",
+ {"permitted_protocols": ["file"]},
+ u'A file:///passwords.txt and http://web.com link',
+ ),
+ (
+ "www.external-link.com",
+ {"extra_params": 'rel="nofollow" class="external"'},
+ u'www.external-link.com ', # noqa: E501
+ ),
+ (
+ "www.external-link.com and www.internal-link.com/blogs extra",
+ {
+ "extra_params": lambda href: 'class="internal"'
+ if href.startswith("http://www.internal-link.com")
+ else 'rel="nofollow" class="external"'
+ },
+ u'www.external-link.com ' # noqa: E501
+ u' and www.internal-link.com/blogs extra', # noqa: E501
+ ),
+ (
+ "www.external-link.com",
+ {"extra_params": lambda href: ' rel="nofollow" class="external" '},
+ u'www.external-link.com ', # noqa: E501
+ ),
+] # type: List[Tuple[Union[str, bytes], Dict[str, Any], str]]
class EscapeTestCase(unittest.TestCase):
@@ -147,26 +220,24 @@ def test_xhtml_escape(self):
("", "<foo>"),
(u"", u"<foo>"),
(b"", b"<foo>"),
-
("<>&\"'", "<>&"'"),
("&", "&"),
-
(u"<\u00e9>", u"<\u00e9>"),
(b"<\xc3\xa9>", b"<\xc3\xa9>"),
- ]
+ ] # type: List[Tuple[Union[str, bytes], Union[str, bytes]]]
for unescaped, escaped in tests:
self.assertEqual(utf8(xhtml_escape(unescaped)), utf8(escaped))
self.assertEqual(utf8(unescaped), utf8(xhtml_unescape(escaped)))
def test_xhtml_unescape_numeric(self):
tests = [
- ('foo bar', 'foo bar'),
- ('foo bar', 'foo bar'),
- ('foo bar', 'foo bar'),
- ('foo઼bar', u'foo\u0abcbar'),
- ('fooyz;bar', 'fooyz;bar'), # invalid encoding
- ('foobar', 'foobar'), # invalid encoding
- ('foobar', 'foobar'), # invalid encoding
+ ("foo bar", "foo bar"),
+ ("foo bar", "foo bar"),
+ ("foo bar", "foo bar"),
+ ("foo઼bar", u"foo\u0abcbar"),
+ ("fooyz;bar", "fooyz;bar"), # invalid encoding
+ ("foobar", "foobar"), # invalid encoding
+ ("foobar", "foobar"), # invalid encoding
]
for escaped, unescaped in tests:
self.assertEqual(unescaped, xhtml_unescape(escaped))
@@ -174,20 +245,19 @@ def test_xhtml_unescape_numeric(self):
def test_url_escape_unicode(self):
tests = [
# byte strings are passed through as-is
- (u'\u00e9'.encode('utf8'), '%C3%A9'),
- (u'\u00e9'.encode('latin1'), '%E9'),
-
+ (u"\u00e9".encode("utf8"), "%C3%A9"),
+ (u"\u00e9".encode("latin1"), "%E9"),
# unicode strings become utf8
- (u'\u00e9', '%C3%A9'),
- ]
+ (u"\u00e9", "%C3%A9"),
+ ] # type: List[Tuple[Union[str, bytes], str]]
for unescaped, escaped in tests:
self.assertEqual(url_escape(unescaped), escaped)
def test_url_unescape_unicode(self):
tests = [
- ('%C3%A9', u'\u00e9', 'utf8'),
- ('%C3%A9', u'\u00c3\u00a9', 'latin1'),
- ('%C3%A9', utf8(u'\u00e9'), None),
+ ("%C3%A9", u"\u00e9", "utf8"),
+ ("%C3%A9", u"\u00c3\u00a9", "latin1"),
+ ("%C3%A9", utf8(u"\u00e9"), None),
]
for escaped, unescaped, encoding in tests:
# input strings to url_unescape should only contain ascii
@@ -197,17 +267,17 @@ def test_url_unescape_unicode(self):
self.assertEqual(url_unescape(utf8(escaped), encoding), unescaped)
def test_url_escape_quote_plus(self):
- unescaped = '+ #%'
- plus_escaped = '%2B+%23%25'
- escaped = '%2B%20%23%25'
+ unescaped = "+ #%"
+ plus_escaped = "%2B+%23%25"
+ escaped = "%2B%20%23%25"
self.assertEqual(url_escape(unescaped), plus_escaped)
self.assertEqual(url_escape(unescaped, plus=False), escaped)
self.assertEqual(url_unescape(plus_escaped), unescaped)
self.assertEqual(url_unescape(escaped, plus=False), unescaped)
- self.assertEqual(url_unescape(plus_escaped, encoding=None),
- utf8(unescaped))
- self.assertEqual(url_unescape(escaped, encoding=None, plus=False),
- utf8(unescaped))
+ self.assertEqual(url_unescape(plus_escaped, encoding=None), utf8(unescaped))
+ self.assertEqual(
+ url_unescape(escaped, encoding=None, plus=False), utf8(unescaped)
+ )
def test_escape_return_types(self):
# On python2 the escape methods should generally return the same
@@ -234,17 +304,19 @@ def test_json_encode(self):
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
def test_squeeze(self):
- self.assertEqual(squeeze(u'sequences of whitespace chars'),
- u'sequences of whitespace chars')
+ self.assertEqual(
+ squeeze(u"sequences of whitespace chars"),
+ u"sequences of whitespace chars",
+ )
def test_recursive_unicode(self):
tests = {
- 'dict': {b"foo": b"bar"},
- 'list': [b"foo", b"bar"],
- 'tuple': (b"foo", b"bar"),
- 'bytes': b"foo"
+ "dict": {b"foo": b"bar"},
+ "list": [b"foo", b"bar"],
+ "tuple": (b"foo", b"bar"),
+ "bytes": b"foo",
}
- self.assertEqual(recursive_unicode(tests['dict']), {u"foo": u"bar"})
- self.assertEqual(recursive_unicode(tests['list']), [u"foo", u"bar"])
- self.assertEqual(recursive_unicode(tests['tuple']), (u"foo", u"bar"))
- self.assertEqual(recursive_unicode(tests['bytes']), u"foo")
+ self.assertEqual(recursive_unicode(tests["dict"]), {u"foo": u"bar"})
+ self.assertEqual(recursive_unicode(tests["list"]), [u"foo", u"bar"])
+ self.assertEqual(recursive_unicode(tests["tuple"]), (u"foo", u"bar"))
+ self.assertEqual(recursive_unicode(tests["bytes"]), u"foo")
diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py
index 8b2f62b972..73c3387803 100644
--- a/tornado/test/gen_test.py
+++ b/tornado/test/gen_test.py
@@ -1,675 +1,32 @@
-from __future__ import absolute_import, division, print_function
-
+import asyncio
+from concurrent import futures
import gc
-import contextlib
import datetime
-import functools
import platform
import sys
-import textwrap
import time
import weakref
-import warnings
+import unittest
-from tornado.concurrent import return_future, Future
-from tornado.escape import url_escape
-from tornado.httpclient import AsyncHTTPClient
-from tornado.ioloop import IOLoop
+from tornado.concurrent import Future
from tornado.log import app_log
-from tornado import stack_context
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
-from tornado.test.util import unittest, skipOnTravis, skipBefore33, skipBefore35, skipNotCPython, exec_test, ignore_deprecation # noqa: E501
-from tornado.web import Application, RequestHandler, asynchronous, HTTPError
+from tornado.test.util import skipOnTravis, skipNotCPython
+from tornado.web import Application, RequestHandler, HTTPError
from tornado import gen
try:
- from concurrent import futures
-except ImportError:
- futures = None
-
-try:
- import asyncio
+ import contextvars
except ImportError:
- asyncio = None
-
-
-class GenEngineTest(AsyncTestCase):
- def setUp(self):
- self.warning_catcher = warnings.catch_warnings()
- self.warning_catcher.__enter__()
- warnings.simplefilter('ignore', DeprecationWarning)
- super(GenEngineTest, self).setUp()
- self.named_contexts = []
-
- def tearDown(self):
- super(GenEngineTest, self).tearDown()
- self.warning_catcher.__exit__(None, None, None)
-
- def named_context(self, name):
- @contextlib.contextmanager
- def context():
- self.named_contexts.append(name)
- try:
- yield
- finally:
- self.assertEqual(self.named_contexts.pop(), name)
- return context
-
- def run_gen(self, f):
- f()
- return self.wait()
-
- def delay_callback(self, iterations, callback, arg):
- """Runs callback(arg) after a number of IOLoop iterations."""
- if iterations == 0:
- callback(arg)
- else:
- self.io_loop.add_callback(functools.partial(
- self.delay_callback, iterations - 1, callback, arg))
-
- @return_future
- def async_future(self, result, callback):
- self.io_loop.add_callback(callback, result)
-
- @gen.coroutine
- def async_exception(self, e):
- yield gen.moment
- raise e
-
- def test_no_yield(self):
- @gen.engine
- def f():
- self.stop()
- self.run_gen(f)
-
- def test_inline_cb(self):
- @gen.engine
- def f():
- (yield gen.Callback("k1"))()
- res = yield gen.Wait("k1")
- self.assertTrue(res is None)
- self.stop()
- self.run_gen(f)
-
- def test_ioloop_cb(self):
- @gen.engine
- def f():
- self.io_loop.add_callback((yield gen.Callback("k1")))
- yield gen.Wait("k1")
- self.stop()
- self.run_gen(f)
-
- def test_exception_phase1(self):
- @gen.engine
- def f():
- 1 / 0
- self.assertRaises(ZeroDivisionError, self.run_gen, f)
-
- def test_exception_phase2(self):
- @gen.engine
- def f():
- self.io_loop.add_callback((yield gen.Callback("k1")))
- yield gen.Wait("k1")
- 1 / 0
- self.assertRaises(ZeroDivisionError, self.run_gen, f)
-
- def test_exception_in_task_phase1(self):
- def fail_task(callback):
- 1 / 0
-
- @gen.engine
- def f():
- try:
- yield gen.Task(fail_task)
- raise Exception("did not get expected exception")
- except ZeroDivisionError:
- self.stop()
- self.run_gen(f)
-
- def test_exception_in_task_phase2(self):
- # This is the case that requires the use of stack_context in gen.engine
- def fail_task(callback):
- self.io_loop.add_callback(lambda: 1 / 0)
-
- @gen.engine
- def f():
- try:
- yield gen.Task(fail_task)
- raise Exception("did not get expected exception")
- except ZeroDivisionError:
- self.stop()
- self.run_gen(f)
-
- def test_with_arg(self):
- @gen.engine
- def f():
- (yield gen.Callback("k1"))(42)
- res = yield gen.Wait("k1")
- self.assertEqual(42, res)
- self.stop()
- self.run_gen(f)
-
- def test_with_arg_tuple(self):
- @gen.engine
- def f():
- (yield gen.Callback((1, 2)))((3, 4))
- res = yield gen.Wait((1, 2))
- self.assertEqual((3, 4), res)
- self.stop()
- self.run_gen(f)
-
- def test_key_reuse(self):
- @gen.engine
- def f():
- yield gen.Callback("k1")
- yield gen.Callback("k1")
- self.stop()
- self.assertRaises(gen.KeyReuseError, self.run_gen, f)
-
- def test_key_reuse_tuple(self):
- @gen.engine
- def f():
- yield gen.Callback((1, 2))
- yield gen.Callback((1, 2))
- self.stop()
- self.assertRaises(gen.KeyReuseError, self.run_gen, f)
-
- def test_key_mismatch(self):
- @gen.engine
- def f():
- yield gen.Callback("k1")
- yield gen.Wait("k2")
- self.stop()
- self.assertRaises(gen.UnknownKeyError, self.run_gen, f)
-
- def test_key_mismatch_tuple(self):
- @gen.engine
- def f():
- yield gen.Callback((1, 2))
- yield gen.Wait((2, 3))
- self.stop()
- self.assertRaises(gen.UnknownKeyError, self.run_gen, f)
-
- def test_leaked_callback(self):
- @gen.engine
- def f():
- yield gen.Callback("k1")
- self.stop()
- self.assertRaises(gen.LeakedCallbackError, self.run_gen, f)
-
- def test_leaked_callback_tuple(self):
- @gen.engine
- def f():
- yield gen.Callback((1, 2))
- self.stop()
- self.assertRaises(gen.LeakedCallbackError, self.run_gen, f)
-
- def test_parallel_callback(self):
- @gen.engine
- def f():
- for k in range(3):
- self.io_loop.add_callback((yield gen.Callback(k)))
- yield gen.Wait(1)
- self.io_loop.add_callback((yield gen.Callback(3)))
- yield gen.Wait(0)
- yield gen.Wait(3)
- yield gen.Wait(2)
- self.stop()
- self.run_gen(f)
-
- def test_bogus_yield(self):
- @gen.engine
- def f():
- yield 42
- self.assertRaises(gen.BadYieldError, self.run_gen, f)
-
- def test_bogus_yield_tuple(self):
- @gen.engine
- def f():
- yield (1, 2)
- self.assertRaises(gen.BadYieldError, self.run_gen, f)
-
- def test_reuse(self):
- @gen.engine
- def f():
- self.io_loop.add_callback((yield gen.Callback(0)))
- yield gen.Wait(0)
- self.stop()
- self.run_gen(f)
- self.run_gen(f)
-
- def test_task(self):
- @gen.engine
- def f():
- yield gen.Task(self.io_loop.add_callback)
- self.stop()
- self.run_gen(f)
-
- def test_wait_all(self):
- @gen.engine
- def f():
- (yield gen.Callback("k1"))("v1")
- (yield gen.Callback("k2"))("v2")
- results = yield gen.WaitAll(["k1", "k2"])
- self.assertEqual(results, ["v1", "v2"])
- self.stop()
- self.run_gen(f)
-
- def test_exception_in_yield(self):
- @gen.engine
- def f():
- try:
- yield gen.Wait("k1")
- raise Exception("did not get expected exception")
- except gen.UnknownKeyError:
- pass
- self.stop()
- self.run_gen(f)
-
- def test_resume_after_exception_in_yield(self):
- @gen.engine
- def f():
- try:
- yield gen.Wait("k1")
- raise Exception("did not get expected exception")
- except gen.UnknownKeyError:
- pass
- (yield gen.Callback("k2"))("v2")
- self.assertEqual((yield gen.Wait("k2")), "v2")
- self.stop()
- self.run_gen(f)
-
- def test_orphaned_callback(self):
- @gen.engine
- def f():
- self.orphaned_callback = yield gen.Callback(1)
- try:
- self.run_gen(f)
- raise Exception("did not get expected exception")
- except gen.LeakedCallbackError:
- pass
- self.orphaned_callback()
-
- def test_none(self):
- @gen.engine
- def f():
- yield None
- self.stop()
- self.run_gen(f)
-
- def test_multi(self):
- @gen.engine
- def f():
- (yield gen.Callback("k1"))("v1")
- (yield gen.Callback("k2"))("v2")
- results = yield [gen.Wait("k1"), gen.Wait("k2")]
- self.assertEqual(results, ["v1", "v2"])
- self.stop()
- self.run_gen(f)
-
- def test_multi_dict(self):
- @gen.engine
- def f():
- (yield gen.Callback("k1"))("v1")
- (yield gen.Callback("k2"))("v2")
- results = yield dict(foo=gen.Wait("k1"), bar=gen.Wait("k2"))
- self.assertEqual(results, dict(foo="v1", bar="v2"))
- self.stop()
- self.run_gen(f)
-
- # The following tests explicitly run with both gen.Multi
- # and gen.multi_future (Task returns a Future, so it can be used
- # with either).
- def test_multi_yieldpoint_delayed(self):
- @gen.engine
- def f():
- # callbacks run at different times
- responses = yield gen.Multi([
- gen.Task(self.delay_callback, 3, arg="v1"),
- gen.Task(self.delay_callback, 1, arg="v2"),
- ])
- self.assertEqual(responses, ["v1", "v2"])
- self.stop()
- self.run_gen(f)
-
- def test_multi_yieldpoint_dict_delayed(self):
- @gen.engine
- def f():
- # callbacks run at different times
- responses = yield gen.Multi(dict(
- foo=gen.Task(self.delay_callback, 3, arg="v1"),
- bar=gen.Task(self.delay_callback, 1, arg="v2"),
- ))
- self.assertEqual(responses, dict(foo="v1", bar="v2"))
- self.stop()
- self.run_gen(f)
-
- def test_multi_future_delayed(self):
- @gen.engine
- def f():
- # callbacks run at different times
- responses = yield gen.multi_future([
- gen.Task(self.delay_callback, 3, arg="v1"),
- gen.Task(self.delay_callback, 1, arg="v2"),
- ])
- self.assertEqual(responses, ["v1", "v2"])
- self.stop()
- self.run_gen(f)
-
- def test_multi_future_dict_delayed(self):
- @gen.engine
- def f():
- # callbacks run at different times
- responses = yield gen.multi_future(dict(
- foo=gen.Task(self.delay_callback, 3, arg="v1"),
- bar=gen.Task(self.delay_callback, 1, arg="v2"),
- ))
- self.assertEqual(responses, dict(foo="v1", bar="v2"))
- self.stop()
- self.run_gen(f)
-
- @skipOnTravis
- @gen_test
- def test_multi_performance(self):
- # Yielding a list used to have quadratic performance; make
- # sure a large list stays reasonable. On my laptop a list of
- # 2000 used to take 1.8s, now it takes 0.12.
- start = time.time()
- yield [gen.Task(self.io_loop.add_callback) for i in range(2000)]
- end = time.time()
- self.assertLess(end - start, 1.0)
-
- @gen_test
- def test_multi_empty(self):
- # Empty lists or dicts should return the same type.
- x = yield []
- self.assertTrue(isinstance(x, list))
- y = yield {}
- self.assertTrue(isinstance(y, dict))
-
- @gen_test
- def test_multi_mixed_types(self):
- # A YieldPoint (Wait) and Future (Task) can be combined
- # (and use the YieldPoint codepath)
- (yield gen.Callback("k1"))("v1")
- responses = yield [gen.Wait("k1"),
- gen.Task(self.delay_callback, 3, arg="v2")]
- self.assertEqual(responses, ["v1", "v2"])
-
- @gen_test
- def test_future(self):
- result = yield self.async_future(1)
- self.assertEqual(result, 1)
-
- @gen_test
- def test_multi_future(self):
- results = yield [self.async_future(1), self.async_future(2)]
- self.assertEqual(results, [1, 2])
-
- @gen_test
- def test_multi_future_duplicate(self):
- f = self.async_future(2)
- results = yield [self.async_future(1), f, self.async_future(3), f]
- self.assertEqual(results, [1, 2, 3, 2])
-
- @gen_test
- def test_multi_dict_future(self):
- results = yield dict(foo=self.async_future(1), bar=self.async_future(2))
- self.assertEqual(results, dict(foo=1, bar=2))
-
- @gen_test
- def test_multi_exceptions(self):
- with ExpectLog(app_log, "Multiple exceptions in yield list"):
- with self.assertRaises(RuntimeError) as cm:
- yield gen.Multi([self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))])
- self.assertEqual(str(cm.exception), "error 1")
-
- # With only one exception, no error is logged.
- with self.assertRaises(RuntimeError):
- yield gen.Multi([self.async_exception(RuntimeError("error 1")),
- self.async_future(2)])
-
- # Exception logging may be explicitly quieted.
- with self.assertRaises(RuntimeError):
- yield gen.Multi([self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))],
- quiet_exceptions=RuntimeError)
-
- @gen_test
- def test_multi_future_exceptions(self):
- with ExpectLog(app_log, "Multiple exceptions in yield list"):
- with self.assertRaises(RuntimeError) as cm:
- yield [self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))]
- self.assertEqual(str(cm.exception), "error 1")
-
- # With only one exception, no error is logged.
- with self.assertRaises(RuntimeError):
- yield [self.async_exception(RuntimeError("error 1")),
- self.async_future(2)]
-
- # Exception logging may be explicitly quieted.
- with self.assertRaises(RuntimeError):
- yield gen.multi_future(
- [self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))],
- quiet_exceptions=RuntimeError)
-
- def test_arguments(self):
- @gen.engine
- def f():
- (yield gen.Callback("noargs"))()
- self.assertEqual((yield gen.Wait("noargs")), None)
- (yield gen.Callback("1arg"))(42)
- self.assertEqual((yield gen.Wait("1arg")), 42)
-
- (yield gen.Callback("kwargs"))(value=42)
- result = yield gen.Wait("kwargs")
- self.assertTrue(isinstance(result, gen.Arguments))
- self.assertEqual(((), dict(value=42)), result)
- self.assertEqual(dict(value=42), result.kwargs)
-
- (yield gen.Callback("2args"))(42, 43)
- result = yield gen.Wait("2args")
- self.assertTrue(isinstance(result, gen.Arguments))
- self.assertEqual(((42, 43), {}), result)
- self.assertEqual((42, 43), result.args)
-
- def task_func(callback):
- callback(None, error="foo")
- result = yield gen.Task(task_func)
- self.assertTrue(isinstance(result, gen.Arguments))
- self.assertEqual(((None,), dict(error="foo")), result)
-
- self.stop()
- self.run_gen(f)
-
- def test_stack_context_leak(self):
- # regression test: repeated invocations of a gen-based
- # function should not result in accumulated stack_contexts
- def _stack_depth():
- head = stack_context._state.contexts[1]
- length = 0
-
- while head is not None:
- length += 1
- head = head.old_contexts[1]
-
- return length
-
- @gen.engine
- def inner(callback):
- yield gen.Task(self.io_loop.add_callback)
- callback()
-
- @gen.engine
- def outer():
- for i in range(10):
- yield gen.Task(inner)
-
- stack_increase = _stack_depth() - initial_stack_depth
- self.assertTrue(stack_increase <= 2)
- self.stop()
- initial_stack_depth = _stack_depth()
- self.run_gen(outer)
-
- def test_stack_context_leak_exception(self):
- # same as previous, but with a function that exits with an exception
- @gen.engine
- def inner(callback):
- yield gen.Task(self.io_loop.add_callback)
- 1 / 0
-
- @gen.engine
- def outer():
- for i in range(10):
- try:
- yield gen.Task(inner)
- except ZeroDivisionError:
- pass
- stack_increase = len(stack_context._state.contexts) - initial_stack_depth
- self.assertTrue(stack_increase <= 2)
- self.stop()
- initial_stack_depth = len(stack_context._state.contexts)
- self.run_gen(outer)
-
- def function_with_stack_context(self, callback):
- # Technically this function should stack_context.wrap its callback
- # upon entry. However, it is very common for this step to be
- # omitted.
- def step2():
- self.assertEqual(self.named_contexts, ['a'])
- self.io_loop.add_callback(callback)
-
- with stack_context.StackContext(self.named_context('a')):
- self.io_loop.add_callback(step2)
-
- @gen_test
- def test_wait_transfer_stack_context(self):
- # Wait should not pick up contexts from where callback was invoked,
- # even if that function improperly fails to wrap its callback.
- cb = yield gen.Callback('k1')
- self.function_with_stack_context(cb)
- self.assertEqual(self.named_contexts, [])
- yield gen.Wait('k1')
- self.assertEqual(self.named_contexts, [])
-
- @gen_test
- def test_task_transfer_stack_context(self):
- yield gen.Task(self.function_with_stack_context)
- self.assertEqual(self.named_contexts, [])
-
- def test_raise_after_stop(self):
- # This pattern will be used in the following tests so make sure
- # the exception propagates as expected.
- @gen.engine
- def f():
- self.stop()
- 1 / 0
-
- with self.assertRaises(ZeroDivisionError):
- self.run_gen(f)
-
- def test_sync_raise_return(self):
- # gen.Return is allowed in @gen.engine, but it may not be used
- # to return a value.
- @gen.engine
- def f():
- self.stop(42)
- raise gen.Return()
-
- result = self.run_gen(f)
- self.assertEqual(result, 42)
-
- def test_async_raise_return(self):
- @gen.engine
- def f():
- yield gen.Task(self.io_loop.add_callback)
- self.stop(42)
- raise gen.Return()
-
- result = self.run_gen(f)
- self.assertEqual(result, 42)
-
- def test_sync_raise_return_value(self):
- @gen.engine
- def f():
- raise gen.Return(42)
-
- with self.assertRaises(gen.ReturnValueIgnoredError):
- self.run_gen(f)
-
- def test_sync_raise_return_value_tuple(self):
- @gen.engine
- def f():
- raise gen.Return((1, 2))
-
- with self.assertRaises(gen.ReturnValueIgnoredError):
- self.run_gen(f)
-
- def test_async_raise_return_value(self):
- @gen.engine
- def f():
- yield gen.Task(self.io_loop.add_callback)
- raise gen.Return(42)
-
- with self.assertRaises(gen.ReturnValueIgnoredError):
- self.run_gen(f)
-
- def test_async_raise_return_value_tuple(self):
- @gen.engine
- def f():
- yield gen.Task(self.io_loop.add_callback)
- raise gen.Return((1, 2))
-
- with self.assertRaises(gen.ReturnValueIgnoredError):
- self.run_gen(f)
-
- def test_return_value(self):
- # It is an error to apply @gen.engine to a function that returns
- # a value.
- @gen.engine
- def f():
- return 42
-
- with self.assertRaises(gen.ReturnValueIgnoredError):
- self.run_gen(f)
-
- def test_return_value_tuple(self):
- # It is an error to apply @gen.engine to a function that returns
- # a value.
- @gen.engine
- def f():
- return (1, 2)
-
- with self.assertRaises(gen.ReturnValueIgnoredError):
- self.run_gen(f)
+ contextvars = None # type: ignore
- @skipNotCPython
- def test_task_refcounting(self):
- # On CPython, tasks and their arguments should be released immediately
- # without waiting for garbage collection.
- @gen.engine
- def f():
- class Foo(object):
- pass
- arg = Foo()
- self.arg_ref = weakref.ref(arg)
- task = gen.Task(self.io_loop.add_callback, arg=arg)
- self.task_ref = weakref.ref(task)
- yield task
- self.stop()
+import typing
- self.run_gen(f)
- self.assertIs(self.arg_ref(), None)
- self.assertIs(self.task_ref(), None)
+if typing.TYPE_CHECKING:
+ from typing import List, Optional # noqa: F401
-# GenBasicTest duplicates the non-deprecated portions of GenEngineTest
-# with gen.coroutine to ensure we don't lose coverage when gen.engine
-# goes away.
class GenBasicTest(AsyncTestCase):
@gen.coroutine
def delay(self, iterations, arg):
@@ -678,9 +35,10 @@ def delay(self, iterations, arg):
yield gen.moment
raise gen.Return(arg)
- @return_future
- def async_future(self, result, callback):
- self.io_loop.add_callback(callback, result)
+ @gen.coroutine
+ def async_future(self, result):
+ yield gen.moment
+ return result
@gen.coroutine
def async_exception(self, e):
@@ -696,12 +54,14 @@ def test_no_yield(self):
@gen.coroutine
def f():
pass
+
self.io_loop.run_sync(f)
def test_exception_phase1(self):
@gen.coroutine
def f():
1 / 0
+
self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f)
def test_exception_phase2(self):
@@ -709,24 +69,28 @@ def test_exception_phase2(self):
def f():
yield gen.moment
1 / 0
+
self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f)
def test_bogus_yield(self):
@gen.coroutine
def f():
yield 42
+
self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f)
def test_bogus_yield_tuple(self):
@gen.coroutine
def f():
yield (1, 2)
+
self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f)
def test_reuse(self):
@gen.coroutine
def f():
yield gen.moment
+
self.io_loop.run_sync(f)
self.io_loop.run_sync(f)
@@ -734,6 +98,7 @@ def test_none(self):
@gen.coroutine
def f():
yield None
+
self.io_loop.run_sync(f)
def test_multi(self):
@@ -741,6 +106,7 @@ def test_multi(self):
def f():
results = yield [self.add_one_async(1), self.add_one_async(2)]
self.assertEqual(results, [2, 3])
+
self.io_loop.run_sync(f)
def test_multi_dict(self):
@@ -748,28 +114,29 @@ def test_multi_dict(self):
def f():
results = yield dict(foo=self.add_one_async(1), bar=self.add_one_async(2))
self.assertEqual(results, dict(foo=2, bar=3))
+
self.io_loop.run_sync(f)
def test_multi_delayed(self):
@gen.coroutine
def f():
# callbacks run at different times
- responses = yield gen.multi_future([
- self.delay(3, "v1"),
- self.delay(1, "v2"),
- ])
+ responses = yield gen.multi_future(
+ [self.delay(3, "v1"), self.delay(1, "v2")]
+ )
self.assertEqual(responses, ["v1", "v2"])
+
self.io_loop.run_sync(f)
def test_multi_dict_delayed(self):
@gen.coroutine
def f():
# callbacks run at different times
- responses = yield gen.multi_future(dict(
- foo=self.delay(3, "v1"),
- bar=self.delay(1, "v2"),
- ))
+ responses = yield gen.multi_future(
+ dict(foo=self.delay(3, "v1"), bar=self.delay(1, "v2"))
+ )
self.assertEqual(responses, dict(foo="v1", bar="v2"))
+
self.io_loop.run_sync(f)
@skipOnTravis
@@ -803,6 +170,8 @@ def test_multi_future(self):
@gen_test
def test_multi_future_duplicate(self):
+        # Note that this doesn't work with native coroutines, only with
+        # decorated coroutines.
f = self.async_future(2)
results = yield [self.async_future(1), f, self.async_future(3), f]
self.assertEqual(results, [1, 2, 3, 2])
@@ -816,40 +185,53 @@ def test_multi_dict_future(self):
def test_multi_exceptions(self):
with ExpectLog(app_log, "Multiple exceptions in yield list"):
with self.assertRaises(RuntimeError) as cm:
- yield gen.Multi([self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))])
+ yield gen.Multi(
+ [
+ self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2")),
+ ]
+ )
self.assertEqual(str(cm.exception), "error 1")
# With only one exception, no error is logged.
with self.assertRaises(RuntimeError):
- yield gen.Multi([self.async_exception(RuntimeError("error 1")),
- self.async_future(2)])
+ yield gen.Multi(
+ [self.async_exception(RuntimeError("error 1")), self.async_future(2)]
+ )
# Exception logging may be explicitly quieted.
with self.assertRaises(RuntimeError):
- yield gen.Multi([self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))],
- quiet_exceptions=RuntimeError)
+ yield gen.Multi(
+ [
+ self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2")),
+ ],
+ quiet_exceptions=RuntimeError,
+ )
@gen_test
def test_multi_future_exceptions(self):
with ExpectLog(app_log, "Multiple exceptions in yield list"):
with self.assertRaises(RuntimeError) as cm:
- yield [self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))]
+ yield [
+ self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2")),
+ ]
self.assertEqual(str(cm.exception), "error 1")
# With only one exception, no error is logged.
with self.assertRaises(RuntimeError):
- yield [self.async_exception(RuntimeError("error 1")),
- self.async_future(2)]
+ yield [self.async_exception(RuntimeError("error 1")), self.async_future(2)]
# Exception logging may be explicitly quieted.
with self.assertRaises(RuntimeError):
yield gen.multi_future(
- [self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))],
- quiet_exceptions=RuntimeError)
+ [
+ self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2")),
+ ],
+ quiet_exceptions=RuntimeError,
+ )
def test_sync_raise_return(self):
@gen.coroutine
@@ -903,10 +285,10 @@ def setUp(self):
# so we need explicit checks here to make sure the tests run all
# the way through.
self.finished = False
- super(GenCoroutineTest, self).setUp()
+ super().setUp()
def tearDown(self):
- super(GenCoroutineTest, self).tearDown()
+ super().tearDown()
assert self.finished
def test_attributes(self):
@@ -918,7 +300,7 @@ def f():
coro = gen.coroutine(f)
self.assertEqual(coro.__name__, f.__name__)
self.assertEqual(coro.__module__, f.__module__)
- self.assertIs(coro.__wrapped__, f)
+ self.assertIs(coro.__wrapped__, f) # type: ignore
def test_is_coroutine_function(self):
self.finished = True
@@ -936,6 +318,7 @@ def test_sync_gen_return(self):
@gen.coroutine
def f():
raise gen.Return(42)
+
result = yield f()
self.assertEqual(result, 42)
self.finished = True
@@ -946,6 +329,7 @@ def test_async_gen_return(self):
def f():
yield gen.moment
raise gen.Return(42)
+
result = yield f()
self.assertEqual(result, 42)
self.finished = True
@@ -955,41 +339,37 @@ def test_sync_return(self):
@gen.coroutine
def f():
return 42
+
result = yield f()
self.assertEqual(result, 42)
self.finished = True
- @skipBefore33
@gen_test
def test_async_return(self):
- namespace = exec_test(globals(), locals(), """
@gen.coroutine
def f():
yield gen.moment
return 42
- """)
- result = yield namespace['f']()
+
+ result = yield f()
self.assertEqual(result, 42)
self.finished = True
- @skipBefore33
@gen_test
def test_async_early_return(self):
# A yield statement exists but is not executed, which means
# this function "returns" via an exception. This exception
# doesn't happen before the exception handling is set up.
- namespace = exec_test(globals(), locals(), """
@gen.coroutine
def f():
if True:
return 42
yield gen.Task(self.io_loop.add_callback)
- """)
- result = yield namespace['f']()
+
+ result = yield f()
self.assertEqual(result, 42)
self.finished = True
- @skipBefore35
@gen_test
def test_async_await(self):
@gen.coroutine
@@ -1000,82 +380,53 @@ def f1():
# This test verifies that an async function can await a
# yield-based gen.coroutine, and that a gen.coroutine
# (the test method itself) can yield an async function.
- namespace = exec_test(globals(), locals(), """
async def f2():
result = await f1()
return result
- """)
- result = yield namespace['f2']()
+
+ result = yield f2()
self.assertEqual(result, 42)
self.finished = True
- @skipBefore35
@gen_test
def test_asyncio_sleep_zero(self):
# asyncio.sleep(0) turns into a special case (equivalent to
# `yield None`)
- namespace = exec_test(globals(), locals(), """
async def f():
import asyncio
+
await asyncio.sleep(0)
return 42
- """)
- result = yield namespace['f']()
+
+ result = yield f()
self.assertEqual(result, 42)
self.finished = True
- @skipBefore35
@gen_test
def test_async_await_mixed_multi_native_future(self):
@gen.coroutine
def f1():
yield gen.moment
- namespace = exec_test(globals(), locals(), """
async def f2():
await f1()
return 42
- """)
@gen.coroutine
def f3():
yield gen.moment
raise gen.Return(43)
- results = yield [namespace['f2'](), f3()]
- self.assertEqual(results, [42, 43])
- self.finished = True
-
- @skipBefore35
- @gen_test
- def test_async_await_mixed_multi_native_yieldpoint(self):
- namespace = exec_test(globals(), locals(), """
- async def f1():
- await gen.Task(self.io_loop.add_callback)
- return 42
- """)
-
- @gen.coroutine
- def f2():
- yield gen.Task(self.io_loop.add_callback)
- raise gen.Return(43)
-
- with ignore_deprecation():
- f2(callback=(yield gen.Callback('cb')))
- results = yield [namespace['f1'](), gen.Wait('cb')]
+ results = yield [f2(), f3()]
self.assertEqual(results, [42, 43])
self.finished = True
- @skipBefore35
@gen_test
def test_async_with_timeout(self):
- namespace = exec_test(globals(), locals(), """
async def f1():
return 42
- """)
- result = yield gen.with_timeout(datetime.timedelta(hours=1),
- namespace['f1']())
+ result = yield gen.with_timeout(datetime.timedelta(hours=1), f1())
self.assertEqual(result, 42)
self.finished = True
@@ -1084,6 +435,7 @@ def test_sync_return_no_value(self):
@gen.coroutine
def f():
return
+
result = yield f()
self.assertEqual(result, None)
self.finished = True
@@ -1095,6 +447,7 @@ def test_async_return_no_value(self):
def f():
yield gen.moment
return
+
result = yield f()
self.assertEqual(result, None)
self.finished = True
@@ -1104,6 +457,7 @@ def test_sync_raise(self):
@gen.coroutine
def f():
1 / 0
+
# The exception is raised when the future is yielded
# (or equivalently when its result method is called),
# not when the function itself is called).
@@ -1118,21 +472,12 @@ def test_async_raise(self):
def f():
yield gen.moment
1 / 0
+
future = f()
with self.assertRaises(ZeroDivisionError):
yield future
self.finished = True
- @gen_test
- def test_pass_callback(self):
- with ignore_deprecation():
- @gen.coroutine
- def f():
- raise gen.Return(42)
- result = yield gen.Task(f)
- self.assertEqual(result, 42)
- self.finished = True
-
@gen_test
def test_replace_yieldpoint_exception(self):
# Test exception handling: a coroutine can catch one exception
@@ -1172,51 +517,6 @@ def f2():
self.assertEqual(result, 42)
self.finished = True
- @gen_test
- def test_replace_context_exception(self):
- with ignore_deprecation():
- # Test exception handling: exceptions thrown into the stack context
- # can be caught and replaced.
- # Note that this test and the following are for behavior that is
- # not really supported any more: coroutines no longer create a
- # stack context automatically; but one is created after the first
- # YieldPoint (i.e. not a Future).
- @gen.coroutine
- def f2():
- (yield gen.Callback(1))()
- yield gen.Wait(1)
- self.io_loop.add_callback(lambda: 1 / 0)
- try:
- yield gen.Task(self.io_loop.add_timeout,
- self.io_loop.time() + 10)
- except ZeroDivisionError:
- raise KeyError()
-
- future = f2()
- with self.assertRaises(KeyError):
- yield future
- self.finished = True
-
- @gen_test
- def test_swallow_context_exception(self):
- with ignore_deprecation():
- # Test exception handling: exceptions thrown into the stack context
- # can be caught and ignored.
- @gen.coroutine
- def f2():
- (yield gen.Callback(1))()
- yield gen.Wait(1)
- self.io_loop.add_callback(lambda: 1 / 0)
- try:
- yield gen.Task(self.io_loop.add_timeout,
- self.io_loop.time() + 10)
- except ZeroDivisionError:
- raise gen.Return(42)
-
- result = yield f2()
- self.assertEqual(result, 42)
- self.finished = True
-
@gen_test
def test_moment(self):
calls = []
@@ -1226,29 +526,29 @@ def f(name, yieldable):
for i in range(5):
calls.append(name)
yield yieldable
+
# First, confirm the behavior without moment: each coroutine
# monopolizes the event loop until it finishes.
- immediate = Future()
+ immediate = Future() # type: Future[None]
immediate.set_result(None)
- yield [f('a', immediate), f('b', immediate)]
- self.assertEqual(''.join(calls), 'aaaaabbbbb')
+ yield [f("a", immediate), f("b", immediate)]
+ self.assertEqual("".join(calls), "aaaaabbbbb")
# With moment, they take turns.
calls = []
- yield [f('a', gen.moment), f('b', gen.moment)]
- self.assertEqual(''.join(calls), 'ababababab')
+ yield [f("a", gen.moment), f("b", gen.moment)]
+ self.assertEqual("".join(calls), "ababababab")
self.finished = True
calls = []
- yield [f('a', gen.moment), f('b', immediate)]
- self.assertEqual(''.join(calls), 'abbbbbaaaa')
+ yield [f("a", gen.moment), f("b", immediate)]
+ self.assertEqual("".join(calls), "abbbbbaaaa")
@gen_test
def test_sleep(self):
yield gen.sleep(0.01)
self.finished = True
- @skipBefore33
@gen_test
def test_py3_leak_exception_context(self):
class LeakedException(Exception):
@@ -1273,8 +573,9 @@ def inner(iteration):
self.finished = True
@skipNotCPython
- @unittest.skipIf((3,) < sys.version_info < (3, 6),
- "asyncio.Future has reference cycles")
+ @unittest.skipIf(
+ (3,) < sys.version_info < (3, 6), "asyncio.Future has reference cycles"
+ )
def test_coroutine_refcounting(self):
# On CPython, tasks and their arguments should be released immediately
# without waiting for garbage collection.
@@ -1282,10 +583,15 @@ def test_coroutine_refcounting(self):
def inner():
class Foo(object):
pass
+
local_var = Foo()
self.local_ref = weakref.ref(local_var)
- yield gen.coroutine(lambda: None)()
- raise ValueError('Some error')
+
+ def dummy():
+ pass
+
+ yield gen.coroutine(dummy)()
+ raise ValueError("Some error")
@gen.coroutine
def inner2():
@@ -1299,8 +605,6 @@ def inner2():
self.assertIs(self.local_ref(), None)
self.finished = True
- @unittest.skipIf(sys.version_info < (3,),
- "test only relevant with asyncio Futures")
def test_asyncio_future_debug_info(self):
self.finished = True
# Enable debug mode
@@ -1315,12 +619,10 @@ def f():
self.assertIsInstance(coro, asyncio.Future)
# We expect the coroutine repr() to show the place where
# it was instantiated
- expected = ("created at %s:%d"
- % (__file__, f.__code__.co_firstlineno + 3))
+ expected = "created at %s:%d" % (__file__, f.__code__.co_firstlineno + 3)
actual = repr(coro)
self.assertIn(expected, actual)
- @unittest.skipIf(asyncio is None, "asyncio module not present")
@gen_test
def test_asyncio_gather(self):
# This demonstrates that tornado coroutines can be understood
@@ -1335,27 +637,6 @@ def f():
self.finished = True
-class GenSequenceHandler(RequestHandler):
- with ignore_deprecation():
- @asynchronous
- @gen.engine
- def get(self):
- # The outer ignore_deprecation applies at definition time.
- # We need another for serving time.
- with ignore_deprecation():
- self.io_loop = self.request.connection.stream.io_loop
- self.io_loop.add_callback((yield gen.Callback("k1")))
- yield gen.Wait("k1")
- self.write("1")
- self.io_loop.add_callback((yield gen.Callback("k2")))
- yield gen.Wait("k2")
- self.write("2")
- # reuse an old key
- self.io_loop.add_callback((yield gen.Callback("k1")))
- yield gen.Wait("k1")
- self.finish("3")
-
-
class GenCoroutineSequenceHandler(RequestHandler):
@gen.coroutine
def get(self):
@@ -1368,7 +649,6 @@ def get(self):
class GenCoroutineUnfinishedSequenceHandler(RequestHandler):
- @asynchronous
@gen.coroutine
def get(self):
yield gen.moment
@@ -1380,66 +660,21 @@ def get(self):
self.write("3")
-class GenTaskHandler(RequestHandler):
- @gen.coroutine
- def get(self):
- client = AsyncHTTPClient()
- with ignore_deprecation():
- response = yield gen.Task(client.fetch, self.get_argument('url'))
- response.rethrow()
- self.finish(b"got response: " + response.body)
-
-
-class GenExceptionHandler(RequestHandler):
- with ignore_deprecation():
- @asynchronous
- @gen.engine
- def get(self):
- # This test depends on the order of the two decorators.
- io_loop = self.request.connection.stream.io_loop
- yield gen.Task(io_loop.add_callback)
- raise Exception("oops")
-
-
-class GenCoroutineExceptionHandler(RequestHandler):
- @gen.coroutine
- def get(self):
- # This test depends on the order of the two decorators.
- io_loop = self.request.connection.stream.io_loop
- yield gen.Task(io_loop.add_callback)
- raise Exception("oops")
-
-
-class GenYieldExceptionHandler(RequestHandler):
- @gen.coroutine
- def get(self):
- io_loop = self.request.connection.stream.io_loop
- # Test the interaction of the two stack_contexts.
- with ignore_deprecation():
- def fail_task(callback):
- io_loop.add_callback(lambda: 1 / 0)
- try:
- yield gen.Task(fail_task)
- raise Exception("did not get expected exception")
- except ZeroDivisionError:
- self.finish('ok')
-
-
# "Undecorated" here refers to the absence of @asynchronous.
class UndecoratedCoroutinesHandler(RequestHandler):
@gen.coroutine
def prepare(self):
- self.chunks = []
+ self.chunks = [] # type: List[str]
yield gen.moment
- self.chunks.append('1')
+ self.chunks.append("1")
@gen.coroutine
def get(self):
- self.chunks.append('2')
+ self.chunks.append("2")
yield gen.moment
- self.chunks.append('3')
+ self.chunks.append("3")
yield gen.moment
- self.write(''.join(self.chunks))
+ self.write("".join(self.chunks))
class AsyncPrepareErrorHandler(RequestHandler):
@@ -1449,161 +684,133 @@ def prepare(self):
raise HTTPError(403)
def get(self):
- self.finish('ok')
+ self.finish("ok")
class NativeCoroutineHandler(RequestHandler):
- if sys.version_info > (3, 5):
- exec(textwrap.dedent("""
- async def get(self):
- import asyncio
- await asyncio.sleep(0)
- self.write("ok")
- """))
+ async def get(self):
+ await asyncio.sleep(0)
+ self.write("ok")
class GenWebTest(AsyncHTTPTestCase):
def get_app(self):
- return Application([
- ('/sequence', GenSequenceHandler),
- ('/coroutine_sequence', GenCoroutineSequenceHandler),
- ('/coroutine_unfinished_sequence',
- GenCoroutineUnfinishedSequenceHandler),
- ('/task', GenTaskHandler),
- ('/exception', GenExceptionHandler),
- ('/coroutine_exception', GenCoroutineExceptionHandler),
- ('/yield_exception', GenYieldExceptionHandler),
- ('/undecorated_coroutine', UndecoratedCoroutinesHandler),
- ('/async_prepare_error', AsyncPrepareErrorHandler),
- ('/native_coroutine', NativeCoroutineHandler),
- ])
-
- def test_sequence_handler(self):
- response = self.fetch('/sequence')
- self.assertEqual(response.body, b"123")
+ return Application(
+ [
+ ("/coroutine_sequence", GenCoroutineSequenceHandler),
+ (
+ "/coroutine_unfinished_sequence",
+ GenCoroutineUnfinishedSequenceHandler,
+ ),
+ ("/undecorated_coroutine", UndecoratedCoroutinesHandler),
+ ("/async_prepare_error", AsyncPrepareErrorHandler),
+ ("/native_coroutine", NativeCoroutineHandler),
+ ]
+ )
def test_coroutine_sequence_handler(self):
- response = self.fetch('/coroutine_sequence')
+ response = self.fetch("/coroutine_sequence")
self.assertEqual(response.body, b"123")
def test_coroutine_unfinished_sequence_handler(self):
- response = self.fetch('/coroutine_unfinished_sequence')
+ response = self.fetch("/coroutine_unfinished_sequence")
self.assertEqual(response.body, b"123")
- def test_task_handler(self):
- response = self.fetch('/task?url=%s' % url_escape(self.get_url('/sequence')))
- self.assertEqual(response.body, b"got response: 123")
-
- def test_exception_handler(self):
- # Make sure we get an error and not a timeout
- with ExpectLog(app_log, "Uncaught exception GET /exception"):
- response = self.fetch('/exception')
- self.assertEqual(500, response.code)
-
- def test_coroutine_exception_handler(self):
- # Make sure we get an error and not a timeout
- with ExpectLog(app_log, "Uncaught exception GET /coroutine_exception"):
- response = self.fetch('/coroutine_exception')
- self.assertEqual(500, response.code)
-
- def test_yield_exception_handler(self):
- response = self.fetch('/yield_exception')
- self.assertEqual(response.body, b'ok')
-
def test_undecorated_coroutines(self):
- response = self.fetch('/undecorated_coroutine')
- self.assertEqual(response.body, b'123')
+ response = self.fetch("/undecorated_coroutine")
+ self.assertEqual(response.body, b"123")
def test_async_prepare_error_handler(self):
- response = self.fetch('/async_prepare_error')
+ response = self.fetch("/async_prepare_error")
self.assertEqual(response.code, 403)
- @skipBefore35
def test_native_coroutine_handler(self):
- response = self.fetch('/native_coroutine')
+ response = self.fetch("/native_coroutine")
self.assertEqual(response.code, 200)
- self.assertEqual(response.body, b'ok')
+ self.assertEqual(response.body, b"ok")
class WithTimeoutTest(AsyncTestCase):
@gen_test
def test_timeout(self):
with self.assertRaises(gen.TimeoutError):
- yield gen.with_timeout(datetime.timedelta(seconds=0.1),
- Future())
+ yield gen.with_timeout(datetime.timedelta(seconds=0.1), Future())
@gen_test
def test_completes_before_timeout(self):
- future = Future()
- self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
- lambda: future.set_result('asdf'))
- result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
- future)
- self.assertEqual(result, 'asdf')
+ future = Future() # type: Future[str]
+ self.io_loop.add_timeout(
+ datetime.timedelta(seconds=0.1), lambda: future.set_result("asdf")
+ )
+ result = yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
+ self.assertEqual(result, "asdf")
@gen_test
def test_fails_before_timeout(self):
- future = Future()
+ future = Future() # type: Future[str]
self.io_loop.add_timeout(
datetime.timedelta(seconds=0.1),
- lambda: future.set_exception(ZeroDivisionError()))
+ lambda: future.set_exception(ZeroDivisionError()),
+ )
with self.assertRaises(ZeroDivisionError):
- yield gen.with_timeout(datetime.timedelta(seconds=3600),
- future)
+ yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
@gen_test
def test_already_resolved(self):
- future = Future()
- future.set_result('asdf')
- result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
- future)
- self.assertEqual(result, 'asdf')
+ future = Future() # type: Future[str]
+ future.set_result("asdf")
+ result = yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
+ self.assertEqual(result, "asdf")
- @unittest.skipIf(futures is None, 'futures module not present')
@gen_test
def test_timeout_concurrent_future(self):
# A concurrent future that does not resolve before the timeout.
with futures.ThreadPoolExecutor(1) as executor:
with self.assertRaises(gen.TimeoutError):
- yield gen.with_timeout(self.io_loop.time(),
- executor.submit(time.sleep, 0.1))
+ yield gen.with_timeout(
+ self.io_loop.time(), executor.submit(time.sleep, 0.1)
+ )
- @unittest.skipIf(futures is None, 'futures module not present')
@gen_test
def test_completed_concurrent_future(self):
# A concurrent future that is resolved before we even submit it
# to with_timeout.
with futures.ThreadPoolExecutor(1) as executor:
- f = executor.submit(lambda: None)
+
+ def dummy():
+ pass
+
+ f = executor.submit(dummy)
f.result() # wait for completion
yield gen.with_timeout(datetime.timedelta(seconds=3600), f)
- @unittest.skipIf(futures is None, 'futures module not present')
@gen_test
def test_normal_concurrent_future(self):
# A conccurrent future that resolves while waiting for the timeout.
with futures.ThreadPoolExecutor(1) as executor:
- yield gen.with_timeout(datetime.timedelta(seconds=3600),
- executor.submit(lambda: time.sleep(0.01)))
+ yield gen.with_timeout(
+ datetime.timedelta(seconds=3600),
+ executor.submit(lambda: time.sleep(0.01)),
+ )
class WaitIteratorTest(AsyncTestCase):
@gen_test
def test_empty_iterator(self):
g = gen.WaitIterator()
- self.assertTrue(g.done(), 'empty generator iterated')
+ self.assertTrue(g.done(), "empty generator iterated")
with self.assertRaises(ValueError):
- g = gen.WaitIterator(False, bar=False)
+ g = gen.WaitIterator(Future(), bar=Future())
self.assertEqual(g.current_index, None, "bad nil current index")
self.assertEqual(g.current_future, None, "bad nil current future")
@gen_test
def test_already_done(self):
- f1 = Future()
- f2 = Future()
- f3 = Future()
+ f1 = Future() # type: Future[int]
+ f2 = Future() # type: Future[int]
+ f3 = Future() # type: Future[int]
f1.set_result(24)
f2.set_result(42)
f3.set_result(84)
@@ -1636,14 +843,17 @@ def test_already_done(self):
while not dg.done():
dr = yield dg.next()
if dg.current_index == "f1":
- self.assertTrue(dg.current_future == f1 and dr == 24,
- "WaitIterator dict status incorrect")
+ self.assertTrue(
+ dg.current_future == f1 and dr == 24,
+ "WaitIterator dict status incorrect",
+ )
elif dg.current_index == "f2":
- self.assertTrue(dg.current_future == f2 and dr == 42,
- "WaitIterator dict status incorrect")
+ self.assertTrue(
+ dg.current_future == f2 and dr == 42,
+ "WaitIterator dict status incorrect",
+ )
else:
- self.fail("got bad WaitIterator index {}".format(
- dg.current_index))
+ self.fail("got bad WaitIterator index {}".format(dg.current_index))
i += 1
@@ -1664,7 +874,7 @@ def finish_coroutines(self, iteration, futures):
@gen_test
def test_iterator(self):
- futures = [Future(), Future(), Future(), Future()]
+ futures = [Future(), Future(), Future(), Future()] # type: List[Future[int]]
self.finish_coroutines(0, futures)
@@ -1675,39 +885,36 @@ def test_iterator(self):
try:
r = yield g.next()
except ZeroDivisionError:
- self.assertIs(g.current_future, futures[0],
- 'exception future invalid')
+ self.assertIs(g.current_future, futures[0], "exception future invalid")
else:
if i == 0:
- self.assertEqual(r, 24, 'iterator value incorrect')
- self.assertEqual(g.current_index, 2, 'wrong index')
+ self.assertEqual(r, 24, "iterator value incorrect")
+ self.assertEqual(g.current_index, 2, "wrong index")
elif i == 2:
- self.assertEqual(r, 42, 'iterator value incorrect')
- self.assertEqual(g.current_index, 1, 'wrong index')
+ self.assertEqual(r, 42, "iterator value incorrect")
+ self.assertEqual(g.current_index, 1, "wrong index")
elif i == 3:
- self.assertEqual(r, 84, 'iterator value incorrect')
- self.assertEqual(g.current_index, 3, 'wrong index')
+ self.assertEqual(r, 84, "iterator value incorrect")
+ self.assertEqual(g.current_index, 3, "wrong index")
i += 1
- @skipBefore35
@gen_test
def test_iterator_async_await(self):
# Recreate the previous test with py35 syntax. It's a little clunky
# because of the way the previous test handles an exception on
# a single iteration.
- futures = [Future(), Future(), Future(), Future()]
+ futures = [Future(), Future(), Future(), Future()] # type: List[Future[int]]
self.finish_coroutines(0, futures)
self.finished = False
- namespace = exec_test(globals(), locals(), """
async def f():
i = 0
g = gen.WaitIterator(*futures)
try:
async for r in g:
if i == 0:
- self.assertEqual(r, 24, 'iterator value incorrect')
- self.assertEqual(g.current_index, 2, 'wrong index')
+ self.assertEqual(r, 24, "iterator value incorrect")
+ self.assertEqual(g.current_index, 2, "wrong index")
else:
raise Exception("expected exception on iteration 1")
i += 1
@@ -1715,17 +922,17 @@ async def f():
i += 1
async for r in g:
if i == 2:
- self.assertEqual(r, 42, 'iterator value incorrect')
- self.assertEqual(g.current_index, 1, 'wrong index')
+ self.assertEqual(r, 42, "iterator value incorrect")
+ self.assertEqual(g.current_index, 1, "wrong index")
elif i == 3:
- self.assertEqual(r, 84, 'iterator value incorrect')
- self.assertEqual(g.current_index, 3, 'wrong index')
+ self.assertEqual(r, 84, "iterator value incorrect")
+ self.assertEqual(g.current_index, 3, "wrong index")
else:
raise Exception("didn't expect iteration %d" % i)
i += 1
self.finished = True
- """)
- yield namespace['f']()
+
+ yield f()
self.assertTrue(self.finished)
@gen_test
@@ -1734,46 +941,40 @@ def test_no_ref(self):
# WaitIterator itself, only the Future it returns. Since
# WaitIterator uses weak references internally to improve GC
# performance, this used to cause problems.
- yield gen.with_timeout(datetime.timedelta(seconds=0.1),
- gen.WaitIterator(gen.sleep(0)).next())
+ yield gen.with_timeout(
+ datetime.timedelta(seconds=0.1), gen.WaitIterator(gen.sleep(0)).next()
+ )
class RunnerGCTest(AsyncTestCase):
def is_pypy3(self):
- return (platform.python_implementation() == 'PyPy' and
- sys.version_info > (3,))
+ return platform.python_implementation() == "PyPy" and sys.version_info > (3,)
@gen_test
def test_gc(self):
- # Github issue 1769: Runner objects can get GCed unexpectedly
+ # GitHub issue 1769: Runner objects can get GCed unexpectedly
# while their future is alive.
- weakref_scope = [None]
+ weakref_scope = [None] # type: List[Optional[weakref.ReferenceType]]
def callback():
gc.collect(2)
- weakref_scope[0]().set_result(123)
+ weakref_scope[0]().set_result(123) # type: ignore
@gen.coroutine
def tester():
- fut = Future()
+ fut = Future() # type: Future[int]
weakref_scope[0] = weakref.ref(fut)
self.io_loop.add_callback(callback)
yield fut
- yield gen.with_timeout(
- datetime.timedelta(seconds=0.2),
- tester()
- )
+ yield gen.with_timeout(datetime.timedelta(seconds=0.2), tester())
def test_gc_infinite_coro(self):
- # Github issue 2229: suspended coroutines should be GCed when
+ # GitHub issue 2229: suspended coroutines should be GCed when
# their loop is closed, even if they're involved in a reference
# cycle.
- if IOLoop.configured_class().__name__.endswith('TwistedIOLoop'):
- raise unittest.SkipTest("Test may fail on TwistedIOLoop")
-
loop = self.get_new_ioloop()
- result = []
+ result = [] # type: List[Optional[bool]]
wfut = []
@gen.coroutine
@@ -1789,7 +990,7 @@ def infinite_coro():
@gen.coroutine
def do_something():
fut = infinite_coro()
- fut._refcycle = fut
+ fut._refcycle = fut # type: ignore
wfut.append(weakref.ref(fut))
yield gen.sleep(0.2)
@@ -1804,12 +1005,10 @@ def do_something():
# coroutine finalizer was called (not on PyPy3 apparently)
self.assertIs(result[-1], None)
- @skipBefore35
def test_gc_infinite_async_await(self):
# Same as test_gc_infinite_coro, but with a `async def` function
import asyncio
- namespace = exec_test(globals(), locals(), """
async def infinite_coro(result):
try:
while True:
@@ -1818,22 +1017,20 @@ async def infinite_coro(result):
finally:
# coroutine finalizer
result.append(None)
- """)
- infinite_coro = namespace['infinite_coro']
loop = self.get_new_ioloop()
- result = []
+ result = [] # type: List[Optional[bool]]
wfut = []
@gen.coroutine
def do_something():
fut = asyncio.get_event_loop().create_task(infinite_coro(result))
- fut._refcycle = fut
+ fut._refcycle = fut # type: ignore
wfut.append(weakref.ref(fut))
yield gen.sleep(0.2)
loop.run_sync(do_something)
- with ExpectLog('asyncio', "Task was destroyed but it is pending"):
+ with ExpectLog("asyncio", "Task was destroyed but it is pending"):
loop.close()
gc.collect()
# Future was collected
@@ -1857,5 +1054,66 @@ def wait_a_moment():
self.assertEqual(result, [None, None])
-if __name__ == '__main__':
+if contextvars is not None:
+ ctx_var = contextvars.ContextVar("ctx_var") # type: contextvars.ContextVar[int]
+
+
+@unittest.skipIf(contextvars is None, "contextvars module not present")
+class ContextVarsTest(AsyncTestCase):
+ async def native_root(self, x):
+ ctx_var.set(x)
+ await self.inner(x)
+
+ @gen.coroutine
+ def gen_root(self, x):
+ ctx_var.set(x)
+ yield
+ yield self.inner(x)
+
+ async def inner(self, x):
+ self.assertEqual(ctx_var.get(), x)
+ await self.gen_inner(x)
+ self.assertEqual(ctx_var.get(), x)
+
+ # IOLoop.run_in_executor doesn't automatically copy context
+ ctx = contextvars.copy_context()
+ await self.io_loop.run_in_executor(None, lambda: ctx.run(self.thread_inner, x))
+ self.assertEqual(ctx_var.get(), x)
+
+ # Neither does asyncio's run_in_executor.
+ await asyncio.get_event_loop().run_in_executor(
+ None, lambda: ctx.run(self.thread_inner, x)
+ )
+ self.assertEqual(ctx_var.get(), x)
+
+ @gen.coroutine
+ def gen_inner(self, x):
+ self.assertEqual(ctx_var.get(), x)
+ yield
+ self.assertEqual(ctx_var.get(), x)
+
+ def thread_inner(self, x):
+ self.assertEqual(ctx_var.get(), x)
+
+ @gen_test
+ def test_propagate(self):
+ # Verify that context vars get propagated across various
+ # combinations of native and decorated coroutines.
+ yield [
+ self.native_root(1),
+ self.native_root(2),
+ self.gen_root(3),
+ self.gen_root(4),
+ ]
+
+ @gen_test
+ def test_reset(self):
+ token = ctx_var.set(1)
+ yield
+ # reset asserts that we are still at the same level of the context tree,
+ # so we must make sure that we maintain that property across yield.
+ ctx_var.reset(token)
+
+
+if __name__ == "__main__":
unittest.main()
diff --git a/tornado/test/gettext_translations/extract_me.py b/tornado/test/gettext_translations/extract_me.py
index 283c13f413..08b29bc53c 100644
--- a/tornado/test/gettext_translations/extract_me.py
+++ b/tornado/test/gettext_translations/extract_me.py
@@ -8,7 +8,6 @@
# 3) msgfmt tornado_test.po -o tornado_test.mo
# 4) Put the file in the proper location: $LANG/LC_MESSAGES
-from __future__ import absolute_import, division, print_function
_("school")
pgettext("law", "right")
pgettext("good", "right")
diff --git a/tornado/test/http1connection_test.py b/tornado/test/http1connection_test.py
index 8aaaaf35b7..d21d506228 100644
--- a/tornado/test/http1connection_test.py
+++ b/tornado/test/http1connection_test.py
@@ -1,6 +1,5 @@
-from __future__ import absolute_import, division, print_function
-
import socket
+import typing
from tornado.http1connection import HTTP1Connection
from tornado.httputil import HTTPMessageDelegate
@@ -11,8 +10,10 @@
class HTTP1ConnectionTest(AsyncTestCase):
+ code = None # type: typing.Optional[int]
+
def setUp(self):
- super(HTTP1ConnectionTest, self).setUp()
+ super().setUp()
self.asyncSetUp()
@gen_test
@@ -28,8 +29,7 @@ def accept_callback(conn, addr):
add_accept_handler(listener, accept_callback)
self.client_stream = IOStream(socket.socket())
self.addCleanup(self.client_stream.close)
- yield [self.client_stream.connect(('127.0.0.1', port)),
- event.wait()]
+ yield [self.client_stream.connect(("127.0.0.1", port)), event.wait()]
self.io_loop.remove_handler(listener)
listener.close()
@@ -58,4 +58,4 @@ def finish(self):
yield conn.read_response(Delegate())
yield event.wait()
self.assertEqual(self.code, 200)
- self.assertEqual(b''.join(body), b'hello')
+ self.assertEqual(b"".join(body), b"hello")
diff --git a/tornado/test/httpclient_test.py b/tornado/test/httpclient_test.py
index 851f126885..fd9a978640 100644
--- a/tornado/test/httpclient_test.py
+++ b/tornado/test/httpclient_test.py
@@ -1,25 +1,34 @@
-from __future__ import absolute_import, division, print_function
-
import base64
import binascii
from contextlib import closing
import copy
-import sys
+import gzip
import threading
import datetime
from io import BytesIO
+import subprocess
+import sys
+import time
+import typing # noqa: F401
+import unicodedata
+import unittest
-from tornado.escape import utf8, native_str
+from tornado.escape import utf8, native_str, to_unicode
from tornado import gen
-from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
+from tornado.httpclient import (
+ HTTPRequest,
+ HTTPResponse,
+ _RequestProxy,
+ HTTPError,
+ HTTPClient,
+)
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
-from tornado.log import gen_log
+from tornado.log import gen_log, app_log
from tornado import netutil
-from tornado.stack_context import ExceptionStackContext, NullContext
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
-from tornado.test.util import unittest, skipOnTravis, ignore_deprecation
+from tornado.test.util import skipOnTravis
from tornado.web import Application, RequestHandler, url
from tornado.httputil import format_timestamp, HTTPHeaders
@@ -33,8 +42,10 @@ def get(self):
class PostHandler(RequestHandler):
def post(self):
- self.finish("Post arg1: %s, arg2: %s" % (
- self.get_argument("arg1"), self.get_argument("arg2")))
+ self.finish(
+ "Post arg1: %s, arg2: %s"
+ % (self.get_argument("arg1"), self.get_argument("arg2"))
+ )
class PutHandler(RequestHandler):
@@ -45,9 +56,17 @@ def put(self):
class RedirectHandler(RequestHandler):
def prepare(self):
- self.write('redirects can have bodies too')
- self.redirect(self.get_argument("url"),
- status=int(self.get_argument("status", "302")))
+ self.write("redirects can have bodies too")
+ self.redirect(
+ self.get_argument("url"), status=int(self.get_argument("status", "302"))
+ )
+
+
+class RedirectWithoutLocationHandler(RequestHandler):
+ def prepare(self):
+ # For testing error handling of a redirect with no location header.
+ self.set_status(301)
+ self.finish()
class ChunkHandler(RequestHandler):
@@ -81,44 +100,57 @@ def post(self):
class UserAgentHandler(RequestHandler):
def get(self):
- self.write(self.request.headers.get('User-Agent', 'User agent not set'))
+ self.write(self.request.headers.get("User-Agent", "User agent not set"))
class ContentLength304Handler(RequestHandler):
def get(self):
self.set_status(304)
- self.set_header('Content-Length', 42)
+ self.set_header("Content-Length", 42)
- def _clear_headers_for_304(self):
+ def _clear_representation_headers(self):
# Tornado strips content-length from 304 responses, but here we
# want to simulate servers that include the headers anyway.
pass
class PatchHandler(RequestHandler):
-
def patch(self):
"Return the request payload - so we can check it is being kept"
self.write(self.request.body)
class AllMethodsHandler(RequestHandler):
- SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',)
+ SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ("OTHER",) # type: ignore
def method(self):
+ assert self.request.method is not None
self.write(self.request.method)
- get = post = put = delete = options = patch = other = method
+ get = head = post = put = delete = options = patch = other = method # type: ignore
class SetHeaderHandler(RequestHandler):
def get(self):
# Use get_arguments for keys to get strings, but
# request.arguments for values to get bytes.
- for k, v in zip(self.get_arguments('k'),
- self.request.arguments['v']):
+ for k, v in zip(self.get_arguments("k"), self.request.arguments["v"]):
self.set_header(k, v)
+
+class InvalidGzipHandler(RequestHandler):
+ def get(self):
+ # set Content-Encoding manually to avoid automatic gzip encoding
+ self.set_header("Content-Type", "text/plain")
+ self.set_header("Content-Encoding", "gzip")
+ # Triggering the potential bug seems to depend on input length.
+ # This length is taken from the bad-response example reported in
+ # https://github.com/tornadoweb/tornado/pull/2875 (uncompressed).
+ body = "".join("Hello World {}\n".format(i) for i in range(9000))[:149051]
+ body = gzip.compress(body.encode(), compresslevel=6) + b"\00"
+ self.write(body)
+
+
# These tests end up getting run redundantly: once here with the default
# HTTPClient implementation, and then again in each implementation's own
# test suite.
@@ -126,25 +158,30 @@ def get(self):
class HTTPClientCommonTestCase(AsyncHTTPTestCase):
def get_app(self):
- return Application([
- url("/hello", HelloWorldHandler),
- url("/post", PostHandler),
- url("/put", PutHandler),
- url("/redirect", RedirectHandler),
- url("/chunk", ChunkHandler),
- url("/auth", AuthHandler),
- url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
- url("/echopost", EchoPostHandler),
- url("/user_agent", UserAgentHandler),
- url("/304_with_content_length", ContentLength304Handler),
- url("/all_methods", AllMethodsHandler),
- url('/patch', PatchHandler),
- url('/set_header', SetHeaderHandler),
- ], gzip=True)
+ return Application(
+ [
+ url("/hello", HelloWorldHandler),
+ url("/post", PostHandler),
+ url("/put", PutHandler),
+ url("/redirect", RedirectHandler),
+ url("/redirect_without_location", RedirectWithoutLocationHandler),
+ url("/chunk", ChunkHandler),
+ url("/auth", AuthHandler),
+ url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
+ url("/echopost", EchoPostHandler),
+ url("/user_agent", UserAgentHandler),
+ url("/304_with_content_length", ContentLength304Handler),
+ url("/all_methods", AllMethodsHandler),
+ url("/patch", PatchHandler),
+ url("/set_header", SetHeaderHandler),
+ url("/invalid_gzip", InvalidGzipHandler),
+ ],
+ gzip=True,
+ )
def test_patch_receives_payload(self):
body = b"some patch data"
- response = self.fetch("/patch", method='PATCH', body=body)
+ response = self.fetch("/patch", method="PATCH", body=body)
self.assertEqual(response.code, 200)
self.assertEqual(response.body, body)
@@ -154,6 +191,7 @@ def test_hello_world(self):
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["Content-Type"], "text/plain")
self.assertEqual(response.body, b"Hello world!")
+ assert response.request_time is not None
self.assertEqual(int(response.request_time), 0)
response = self.fetch("/hello?name=Ben")
@@ -161,16 +199,14 @@ def test_hello_world(self):
def test_streaming_callback(self):
# streaming_callback is also tested in test_chunked
- chunks = []
- response = self.fetch("/hello",
- streaming_callback=chunks.append)
+ chunks = [] # type: typing.List[bytes]
+ response = self.fetch("/hello", streaming_callback=chunks.append)
# with streaming_callback, data goes to the callback and not response.body
self.assertEqual(chunks, [b"Hello world!"])
self.assertFalse(response.body)
def test_post(self):
- response = self.fetch("/post", method="POST",
- body="arg1=foo&arg2=bar")
+ response = self.fetch("/post", method="POST", body="arg1=foo&arg2=bar")
self.assertEqual(response.code, 200)
self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
@@ -178,9 +214,8 @@ def test_chunked(self):
response = self.fetch("/chunk")
self.assertEqual(response.body, b"asdfqwer")
- chunks = []
- response = self.fetch("/chunk",
- streaming_callback=chunks.append)
+ chunks = [] # type: typing.List[bytes]
+ response = self.fetch("/chunk", streaming_callback=chunks.append)
self.assertEqual(chunks, [b"asdf", b"qwer"])
self.assertFalse(response.body)
@@ -189,6 +224,7 @@ def test_chunked_close(self):
# over several ioloop iterations, but the connection is already closed.
sock, port = bind_unused_port()
with closing(sock):
+
@gen.coroutine
def accept_callback(conn, address):
# fake an HTTP server using chunked encoding where the final chunks
@@ -197,7 +233,8 @@ def accept_callback(conn, address):
request_data = yield stream.read_until(b"\r\n\r\n")
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
- yield stream.write(b"""\
+ yield stream.write(
+ b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked
@@ -207,55 +244,66 @@ def accept_callback(conn, address):
2
0
-""".replace(b"\n", b"\r\n"))
+""".replace(
+ b"\n", b"\r\n"
+ )
+ )
stream.close()
- netutil.add_accept_handler(sock, accept_callback)
+
+ netutil.add_accept_handler(sock, accept_callback) # type: ignore
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
self.assertEqual(resp.body, b"12")
self.io_loop.remove_handler(sock.fileno())
- def test_streaming_stack_context(self):
- chunks = []
- exc_info = []
-
- def error_handler(typ, value, tb):
- exc_info.append((typ, value, tb))
- return True
-
- def streaming_cb(chunk):
- chunks.append(chunk)
- if chunk == b'qwer':
- 1 / 0
-
- with ExceptionStackContext(error_handler):
- self.fetch('/chunk', streaming_callback=streaming_cb)
-
- self.assertEqual(chunks, [b'asdf', b'qwer'])
- self.assertEqual(1, len(exc_info))
- self.assertIs(exc_info[0][0], ZeroDivisionError)
-
def test_basic_auth(self):
- self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
- auth_password="open sesame").body,
- b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
+ # This test data appears in section 2 of RFC 7617.
+ self.assertEqual(
+ self.fetch(
+ "/auth", auth_username="Aladdin", auth_password="open sesame"
+ ).body,
+ b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
+ )
def test_basic_auth_explicit_mode(self):
- self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
- auth_password="open sesame",
- auth_mode="basic").body,
- b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
+ self.assertEqual(
+ self.fetch(
+ "/auth",
+ auth_username="Aladdin",
+ auth_password="open sesame",
+ auth_mode="basic",
+ ).body,
+ b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
+ )
+
+ def test_basic_auth_unicode(self):
+ # This test data appears in section 2.1 of RFC 7617.
+ self.assertEqual(
+ self.fetch("/auth", auth_username="test", auth_password="123£").body,
+ b"Basic dGVzdDoxMjPCow==",
+ )
+
+ # The standard mandates NFC. Give it a decomposed username
+ # and ensure it is normalized to composed form.
+ username = unicodedata.normalize("NFD", u"josé")
+ self.assertEqual(
+ self.fetch("/auth", auth_username=username, auth_password="səcrət").body,
+ b"Basic am9zw6k6c8mZY3LJmXQ=",
+ )
def test_unsupported_auth_mode(self):
# curl and simple clients handle errors a bit differently; the
# important thing is that they don't fall back to basic auth
# on an unknown mode.
with ExpectLog(gen_log, "uncaught exception", required=False):
- with self.assertRaises((ValueError, HTTPError)):
- self.fetch("/auth", auth_username="Aladdin",
- auth_password="open sesame",
- auth_mode="asdf",
- raise_error=True)
+ with self.assertRaises((ValueError, HTTPError)): # type: ignore
+ self.fetch(
+ "/auth",
+ auth_username="Aladdin",
+ auth_password="open sesame",
+ auth_mode="asdf",
+ raise_error=True,
+ )
def test_follow_redirect(self):
response = self.fetch("/countdown/2", follow_redirects=False)
@@ -267,34 +315,97 @@ def test_follow_redirect(self):
self.assertTrue(response.effective_url.endswith("/countdown/0"))
self.assertEqual(b"Zero", response.body)
+ def test_redirect_without_location(self):
+ response = self.fetch("/redirect_without_location", follow_redirects=True)
+ # If there is no location header, the redirect response should
+ # just be returned as-is. (This should arguably raise an
+ # error, but libcurl doesn't treat this as an error, so we
+ # don't either).
+ self.assertEqual(301, response.code)
+
+ def test_redirect_put_with_body(self):
+ response = self.fetch(
+ "/redirect?url=/put&status=307", method="PUT", body="hello"
+ )
+ self.assertEqual(response.body, b"Put body: hello")
+
+ def test_redirect_put_without_body(self):
+ # This "without body" edge case is similar to what happens with body_producer.
+ response = self.fetch(
+ "/redirect?url=/put&status=307",
+ method="PUT",
+ allow_nonstandard_methods=True,
+ )
+ self.assertEqual(response.body, b"Put body: ")
+
+ def test_method_after_redirect(self):
+ # Legacy redirect codes (301, 302) convert POST requests to GET.
+ for status in [301, 302, 303]:
+ url = "/redirect?url=/all_methods&status=%d" % status
+ resp = self.fetch(url, method="POST", body=b"")
+ self.assertEqual(b"GET", resp.body)
+
+ # Other methods are left alone, except for 303 redirect, depending on client
+ for method in ["GET", "OPTIONS", "PUT", "DELETE"]:
+ resp = self.fetch(url, method=method, allow_nonstandard_methods=True)
+ if status in [301, 302]:
+ self.assertEqual(utf8(method), resp.body)
+ else:
+ self.assertIn(resp.body, [utf8(method), b"GET"])
+
+ # HEAD is different so check it separately.
+ resp = self.fetch(url, method="HEAD")
+ self.assertEqual(200, resp.code)
+ self.assertEqual(b"", resp.body)
+
+ # Newer redirects always preserve the original method.
+ for status in [307, 308]:
+ url = "/redirect?url=/all_methods&status=%d" % status
+ for method in ["GET", "OPTIONS", "POST", "PUT", "DELETE"]:
+ resp = self.fetch(url, method=method, allow_nonstandard_methods=True)
+ self.assertEqual(method, to_unicode(resp.body))
+ resp = self.fetch(url, method="HEAD")
+ self.assertEqual(200, resp.code)
+ self.assertEqual(b"", resp.body)
+
def test_credentials_in_url(self):
url = self.get_url("/auth").replace("http://", "http://me:secret@")
response = self.fetch(url)
- self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"),
- response.body)
+ self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"), response.body)
def test_body_encoding(self):
unicode_body = u"\xe9"
byte_body = binascii.a2b_hex(b"e9")
# unicode string in body gets converted to utf8
- response = self.fetch("/echopost", method="POST", body=unicode_body,
- headers={"Content-Type": "application/blah"})
+ response = self.fetch(
+ "/echopost",
+ method="POST",
+ body=unicode_body,
+ headers={"Content-Type": "application/blah"},
+ )
self.assertEqual(response.headers["Content-Length"], "2")
self.assertEqual(response.body, utf8(unicode_body))
# byte strings pass through directly
- response = self.fetch("/echopost", method="POST",
- body=byte_body,
- headers={"Content-Type": "application/blah"})
+ response = self.fetch(
+ "/echopost",
+ method="POST",
+ body=byte_body,
+ headers={"Content-Type": "application/blah"},
+ )
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
# Mixing unicode in headers and byte string bodies shouldn't
# break anything
- response = self.fetch("/echopost", method="POST", body=byte_body,
- headers={"Content-Type": "application/blah"},
- user_agent=u"foo")
+ response = self.fetch(
+ "/echopost",
+ method="POST",
+ body=byte_body,
+ headers={"Content-Type": "application/blah"},
+ user_agent=u"foo",
+ )
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
@@ -305,59 +416,75 @@ def test_types(self):
self.assertEqual(type(response.code), int)
self.assertEqual(type(response.effective_url), str)
+ def test_gzip(self):
+ # All the tests in this file should be using gzip, but this test
+ # ensures that it is in fact getting compressed, and also tests
+ # the httpclient's decompress=False option.
+ # Setting Accept-Encoding manually bypasses the client's
+ # decompression so we can see the raw data.
+ response = self.fetch(
+ "/chunk", decompress_response=False, headers={"Accept-Encoding": "gzip"}
+ )
+ self.assertEqual(response.headers["Content-Encoding"], "gzip")
+ self.assertNotEqual(response.body, b"asdfqwer")
+ # Our test data gets bigger when gzipped. Oops. :)
+ # Chunked encoding bypasses the MIN_LENGTH check.
+ self.assertEqual(len(response.body), 34)
+ f = gzip.GzipFile(mode="r", fileobj=response.buffer)
+ self.assertEqual(f.read(), b"asdfqwer")
+
+ def test_invalid_gzip(self):
+ # test if client hangs on tricky invalid gzip
+ # curl/simple httpclient have different behavior (exception, logging)
+ with ExpectLog(
+ app_log, "(Uncaught exception|Exception in callback)", required=False
+ ):
+ try:
+ response = self.fetch("/invalid_gzip")
+ self.assertEqual(response.code, 200)
+ self.assertEqual(response.body[:14], b"Hello World 0\n")
+ except HTTPError:
+ pass # acceptable
+
def test_header_callback(self):
first_line = []
headers = {}
chunks = []
def header_callback(header_line):
- if header_line.startswith('HTTP/1.1 101'):
+ if header_line.startswith("HTTP/1.1 101"):
# Upgrading to HTTP/2
pass
- elif header_line.startswith('HTTP/'):
+ elif header_line.startswith("HTTP/"):
first_line.append(header_line)
- elif header_line != '\r\n':
- k, v = header_line.split(':', 1)
+ elif header_line != "\r\n":
+ k, v = header_line.split(":", 1)
headers[k.lower()] = v.strip()
def streaming_callback(chunk):
# All header callbacks are run before any streaming callbacks,
# so the header data is available to process the data as it
# comes in.
- self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8')
+ self.assertEqual(headers["content-type"], "text/html; charset=UTF-8")
chunks.append(chunk)
- self.fetch('/chunk', header_callback=header_callback,
- streaming_callback=streaming_callback)
+ self.fetch(
+ "/chunk",
+ header_callback=header_callback,
+ streaming_callback=streaming_callback,
+ )
self.assertEqual(len(first_line), 1, first_line)
- self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n')
- self.assertEqual(chunks, [b'asdf', b'qwer'])
-
- def test_header_callback_stack_context(self):
- exc_info = []
-
- def error_handler(typ, value, tb):
- exc_info.append((typ, value, tb))
- return True
-
- def header_callback(header_line):
- if header_line.lower().startswith('content-type:'):
- 1 / 0
-
- with ExceptionStackContext(error_handler):
- self.fetch('/chunk', header_callback=header_callback)
- self.assertEqual(len(exc_info), 1)
- self.assertIs(exc_info[0][0], ZeroDivisionError)
+ self.assertRegex(first_line[0], "HTTP/[0-9]\\.[0-9] 200.*\r\n")
+ self.assertEqual(chunks, [b"asdf", b"qwer"])
@gen_test
def test_configure_defaults(self):
- defaults = dict(user_agent='TestDefaultUserAgent', allow_ipv6=False)
+ defaults = dict(user_agent="TestDefaultUserAgent", allow_ipv6=False)
# Construct a new instance of the configured client class
- client = self.http_client.__class__(force_instance=True,
- defaults=defaults)
+ client = self.http_client.__class__(force_instance=True, defaults=defaults)
try:
- response = yield client.fetch(self.get_url('/user_agent'))
- self.assertEqual(response.body, b'TestDefaultUserAgent')
+ response = yield client.fetch(self.get_url("/user_agent"))
+ self.assertEqual(response.body, b"TestDefaultUserAgent")
finally:
client.close()
@@ -369,84 +496,75 @@ def test_header_types(self):
for value in [u"MyUserAgent", b"MyUserAgent"]:
for container in [dict, HTTPHeaders]:
headers = container()
- headers['User-Agent'] = value
- resp = self.fetch('/user_agent', headers=headers)
+ headers["User-Agent"] = value
+ resp = self.fetch("/user_agent", headers=headers)
self.assertEqual(
- resp.body, b"MyUserAgent",
- "response=%r, value=%r, container=%r" %
- (resp.body, value, container))
+ resp.body,
+ b"MyUserAgent",
+ "response=%r, value=%r, container=%r"
+ % (resp.body, value, container),
+ )
def test_multi_line_headers(self):
# Multi-line http headers are rare but rfc-allowed
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
sock, port = bind_unused_port()
with closing(sock):
+
@gen.coroutine
def accept_callback(conn, address):
stream = IOStream(conn)
request_data = yield stream.read_until(b"\r\n\r\n")
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
- yield stream.write(b"""\
+ yield stream.write(
+ b"""\
HTTP/1.1 200 OK
X-XSS-Protection: 1;
\tmode=block
-""".replace(b"\n", b"\r\n"))
+""".replace(
+ b"\n", b"\r\n"
+ )
+ )
stream.close()
- netutil.add_accept_handler(sock, accept_callback)
- resp = self.fetch("http://127.0.0.1:%d/" % port)
- resp.rethrow()
- self.assertEqual(resp.headers['X-XSS-Protection'], "1; mode=block")
- self.io_loop.remove_handler(sock.fileno())
+ netutil.add_accept_handler(sock, accept_callback) # type: ignore
+ try:
+ resp = self.fetch("http://127.0.0.1:%d/" % port)
+ resp.rethrow()
+ self.assertEqual(resp.headers["X-XSS-Protection"], "1; mode=block")
+ finally:
+ self.io_loop.remove_handler(sock.fileno())
def test_304_with_content_length(self):
# According to the spec 304 responses SHOULD NOT include
# Content-Length or other entity headers, but some servers do it
# anyway.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5
- response = self.fetch('/304_with_content_length')
+ response = self.fetch("/304_with_content_length")
self.assertEqual(response.code, 304)
- self.assertEqual(response.headers['Content-Length'], '42')
-
- def test_final_callback_stack_context(self):
- # The final callback should be run outside of the httpclient's
- # stack_context. We want to ensure that there is not stack_context
- # between the user's callback and the IOLoop, so monkey-patch
- # IOLoop.handle_callback_exception and disable the test harness's
- # context with a NullContext.
- # Note that this does not apply to secondary callbacks (header
- # and streaming_callback), as errors there must be seen as errors
- # by the http client so it can clean up the connection.
- exc_info = []
-
- def handle_callback_exception(callback):
- exc_info.append(sys.exc_info())
- self.stop()
- self.io_loop.handle_callback_exception = handle_callback_exception
- with NullContext():
- with ignore_deprecation():
- self.http_client.fetch(self.get_url('/hello'),
- lambda response: 1 / 0)
- self.wait()
- self.assertEqual(exc_info[0][0], ZeroDivisionError)
+ self.assertEqual(response.headers["Content-Length"], "42")
@gen_test
def test_future_interface(self):
- response = yield self.http_client.fetch(self.get_url('/hello'))
- self.assertEqual(response.body, b'Hello world!')
+ response = yield self.http_client.fetch(self.get_url("/hello"))
+ self.assertEqual(response.body, b"Hello world!")
@gen_test
def test_future_http_error(self):
with self.assertRaises(HTTPError) as context:
- yield self.http_client.fetch(self.get_url('/notfound'))
+ yield self.http_client.fetch(self.get_url("/notfound"))
+ assert context.exception is not None
+ assert context.exception.response is not None
self.assertEqual(context.exception.code, 404)
self.assertEqual(context.exception.response.code, 404)
@gen_test
def test_future_http_error_no_raise(self):
- response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False)
+ response = yield self.http_client.fetch(
+ self.get_url("/notfound"), raise_error=False
+ )
self.assertEqual(response.code, 404)
@gen_test
@@ -455,48 +573,69 @@ def test_reuse_request_from_response(self):
# a _RequestProxy.
# This test uses self.http_client.fetch because self.fetch calls
# self.get_url on the input unconditionally.
- url = self.get_url('/hello')
+ url = self.get_url("/hello")
response = yield self.http_client.fetch(url)
self.assertEqual(response.request.url, url)
self.assertTrue(isinstance(response.request, HTTPRequest))
response2 = yield self.http_client.fetch(response.request)
- self.assertEqual(response2.body, b'Hello world!')
+ self.assertEqual(response2.body, b"Hello world!")
+
+ @gen_test
+ def test_bind_source_ip(self):
+ url = self.get_url("/hello")
+ request = HTTPRequest(url, network_interface="127.0.0.1")
+ response = yield self.http_client.fetch(request)
+ self.assertEqual(response.code, 200)
+
+ with self.assertRaises((ValueError, HTTPError)) as context: # type: ignore
+ request = HTTPRequest(url, network_interface="not-interface-or-ip")
+ yield self.http_client.fetch(request)
+ self.assertIn("not-interface-or-ip", str(context.exception))
def test_all_methods(self):
- for method in ['GET', 'DELETE', 'OPTIONS']:
- response = self.fetch('/all_methods', method=method)
+ for method in ["GET", "DELETE", "OPTIONS"]:
+ response = self.fetch("/all_methods", method=method)
self.assertEqual(response.body, utf8(method))
- for method in ['POST', 'PUT', 'PATCH']:
- response = self.fetch('/all_methods', method=method, body=b'')
+ for method in ["POST", "PUT", "PATCH"]:
+ response = self.fetch("/all_methods", method=method, body=b"")
self.assertEqual(response.body, utf8(method))
- response = self.fetch('/all_methods', method='HEAD')
- self.assertEqual(response.body, b'')
- response = self.fetch('/all_methods', method='OTHER',
- allow_nonstandard_methods=True)
- self.assertEqual(response.body, b'OTHER')
+ response = self.fetch("/all_methods", method="HEAD")
+ self.assertEqual(response.body, b"")
+ response = self.fetch(
+ "/all_methods", method="OTHER", allow_nonstandard_methods=True
+ )
+ self.assertEqual(response.body, b"OTHER")
def test_body_sanity_checks(self):
# These methods require a body.
- for method in ('POST', 'PUT', 'PATCH'):
+ for method in ("POST", "PUT", "PATCH"):
with self.assertRaises(ValueError) as context:
- self.fetch('/all_methods', method=method, raise_error=True)
- self.assertIn('must not be None', str(context.exception))
+ self.fetch("/all_methods", method=method, raise_error=True)
+ self.assertIn("must not be None", str(context.exception))
- resp = self.fetch('/all_methods', method=method,
- allow_nonstandard_methods=True)
+ resp = self.fetch(
+ "/all_methods", method=method, allow_nonstandard_methods=True
+ )
self.assertEqual(resp.code, 200)
# These methods don't allow a body.
- for method in ('GET', 'DELETE', 'OPTIONS'):
+ for method in ("GET", "DELETE", "OPTIONS"):
with self.assertRaises(ValueError) as context:
- self.fetch('/all_methods', method=method, body=b'asdf', raise_error=True)
- self.assertIn('must be None', str(context.exception))
+ self.fetch(
+ "/all_methods", method=method, body=b"asdf", raise_error=True
+ )
+ self.assertIn("must be None", str(context.exception))
# In most cases this can be overridden, but curl_httpclient
# does not allow body with a GET at all.
- if method != 'GET':
- self.fetch('/all_methods', method=method, body=b'asdf',
- allow_nonstandard_methods=True, raise_error=True)
+ if method != "GET":
+ self.fetch(
+ "/all_methods",
+ method=method,
+ body=b"asdf",
+ allow_nonstandard_methods=True,
+ raise_error=True,
+ )
self.assertEqual(resp.code, 200)
# This test causes odd failures with the combination of
@@ -516,8 +655,9 @@ def test_body_sanity_checks(self):
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_put_307(self):
- response = self.fetch("/redirect?status=307&url=/put",
- method="PUT", body=b"hello")
+ response = self.fetch(
+ "/redirect?status=307&url=/put", method="PUT", body=b"hello"
+ )
response.rethrow()
self.assertEqual(response.body, b"Put body: hello")
@@ -527,69 +667,111 @@ def test_non_ascii_header(self):
response.rethrow()
self.assertEqual(response.headers["Foo"], native_str(u"\u00e9"))
+ def test_response_times(self):
+ # A few simple sanity checks of the response time fields to
+ # make sure they're using the right basis (between the
+ # wall-time and monotonic clocks).
+ start_time = time.time()
+ response = self.fetch("/hello")
+ response.rethrow()
+ self.assertGreaterEqual(response.request_time, 0)
+ self.assertLess(response.request_time, 1.0)
+ # A very crude check to make sure that start_time is based on
+ # wall time and not the monotonic clock.
+ assert response.start_time is not None
+ self.assertLess(abs(response.start_time - start_time), 1.0)
+
+ for k, v in response.time_info.items():
+ self.assertTrue(0 <= v < 1.0, "time_info[%s] out of bounds: %s" % (k, v))
+
+ def test_zero_timeout(self):
+ response = self.fetch("/hello", connect_timeout=0)
+ self.assertEqual(response.code, 200)
+
+ response = self.fetch("/hello", request_timeout=0)
+ self.assertEqual(response.code, 200)
+
+ response = self.fetch("/hello", connect_timeout=0, request_timeout=0)
+ self.assertEqual(response.code, 200)
+
+ @gen_test
+ def test_error_after_cancel(self):
+ fut = self.http_client.fetch(self.get_url("/404"))
+ self.assertTrue(fut.cancel())
+ with ExpectLog(app_log, "Exception after Future was cancelled") as el:
+ # We can't wait on the cancelled Future any more, so just
+ # let the IOLoop run until the exception gets logged (or
+ # not, in which case we exit the loop and ExpectLog will
+ # raise).
+ for i in range(100):
+ yield gen.sleep(0.01)
+ if el.logged_stack:
+ break
+
class RequestProxyTest(unittest.TestCase):
def test_request_set(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/',
- user_agent='foo'),
- dict())
- self.assertEqual(proxy.user_agent, 'foo')
+ proxy = _RequestProxy(
+ HTTPRequest("http://example.com/", user_agent="foo"), dict()
+ )
+ self.assertEqual(proxy.user_agent, "foo")
def test_default_set(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/'),
- dict(network_interface='foo'))
- self.assertEqual(proxy.network_interface, 'foo')
+ proxy = _RequestProxy(
+ HTTPRequest("http://example.com/"), dict(network_interface="foo")
+ )
+ self.assertEqual(proxy.network_interface, "foo")
def test_both_set(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/',
- proxy_host='foo'),
- dict(proxy_host='bar'))
- self.assertEqual(proxy.proxy_host, 'foo')
+ proxy = _RequestProxy(
+ HTTPRequest("http://example.com/", proxy_host="foo"), dict(proxy_host="bar")
+ )
+ self.assertEqual(proxy.proxy_host, "foo")
def test_neither_set(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/'),
- dict())
+ proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict())
self.assertIs(proxy.auth_username, None)
def test_bad_attribute(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/'),
- dict())
+ proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict())
with self.assertRaises(AttributeError):
proxy.foo
def test_defaults_none(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/'), None)
+ proxy = _RequestProxy(HTTPRequest("http://example.com/"), None)
self.assertIs(proxy.auth_username, None)
class HTTPResponseTestCase(unittest.TestCase):
def test_str(self):
- response = HTTPResponse(HTTPRequest('http://example.com'),
- 200, headers={}, buffer=BytesIO())
+ response = HTTPResponse( # type: ignore
+ HTTPRequest("http://example.com"), 200, buffer=BytesIO()
+ )
s = str(response)
- self.assertTrue(s.startswith('HTTPResponse('))
- self.assertIn('code=200', s)
+ self.assertTrue(s.startswith("HTTPResponse("))
+ self.assertIn("code=200", s)
class SyncHTTPClientTest(unittest.TestCase):
def setUp(self):
- if IOLoop.configured_class().__name__ == 'TwistedIOLoop':
- # TwistedIOLoop only supports the global reactor, so we can't have
- # separate IOLoops for client and server threads.
- raise unittest.SkipTest(
- 'Sync HTTPClient not compatible with TwistedIOLoop')
self.server_ioloop = IOLoop()
+ event = threading.Event()
@gen.coroutine
def init_server():
sock, self.port = bind_unused_port()
- app = Application([('/', HelloWorldHandler)])
+ app = Application([("/", HelloWorldHandler)])
self.server = HTTPServer(app)
self.server.add_socket(sock)
- self.server_ioloop.run_sync(init_server)
+ event.set()
- self.server_thread = threading.Thread(target=self.server_ioloop.start)
+ def start():
+ self.server_ioloop.run_sync(init_server)
+ self.server_ioloop.start()
+
+ self.server_thread = threading.Thread(target=start)
self.server_thread.start()
+ event.wait()
self.http_client = HTTPClient()
@@ -604,61 +786,95 @@ def stop_server():
@gen.coroutine
def slow_stop():
+ yield self.server.close_all_connections()
# The number of iterations is difficult to predict. Typically,
# one is sufficient, although sometimes it needs more.
for i in range(5):
yield
self.server_ioloop.stop()
+
self.server_ioloop.add_callback(slow_stop)
+
self.server_ioloop.add_callback(stop_server)
self.server_thread.join()
self.http_client.close()
self.server_ioloop.close(all_fds=True)
def get_url(self, path):
- return 'http://127.0.0.1:%d%s' % (self.port, path)
+ return "http://127.0.0.1:%d%s" % (self.port, path)
def test_sync_client(self):
- response = self.http_client.fetch(self.get_url('/'))
- self.assertEqual(b'Hello world!', response.body)
+ response = self.http_client.fetch(self.get_url("/"))
+ self.assertEqual(b"Hello world!", response.body)
def test_sync_client_error(self):
# Synchronous HTTPClient raises errors directly; no need for
# response.rethrow()
with self.assertRaises(HTTPError) as assertion:
- self.http_client.fetch(self.get_url('/notfound'))
+ self.http_client.fetch(self.get_url("/notfound"))
self.assertEqual(assertion.exception.code, 404)
+class SyncHTTPClientSubprocessTest(unittest.TestCase):
+ def test_destructor_log(self):
+ # Regression test for
+ # https://github.com/tornadoweb/tornado/issues/2539
+ #
+ # In the past, the following program would log an
+ # "inconsistent AsyncHTTPClient cache" error from a destructor
+ # when the process is shutting down. The shutdown process is
+ # subtle and I don't fully understand it; the failure does not
+ # manifest if that lambda isn't there or is a simpler object
+ # like an int (nor does it manifest in the tornado test suite
+ # as a whole, which is why we use this subprocess).
+ proc = subprocess.run(
+ [
+ sys.executable,
+ "-c",
+ "from tornado.httpclient import HTTPClient; f = lambda: None; c = HTTPClient()",
+ ],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ check=True,
+ timeout=5,
+ )
+ if proc.stdout:
+ print("STDOUT:")
+ print(to_unicode(proc.stdout))
+ if proc.stdout:
+ self.fail("subprocess produced unexpected output")
+
+
class HTTPRequestTestCase(unittest.TestCase):
def test_headers(self):
- request = HTTPRequest('http://example.com', headers={'foo': 'bar'})
- self.assertEqual(request.headers, {'foo': 'bar'})
+ request = HTTPRequest("http://example.com", headers={"foo": "bar"})
+ self.assertEqual(request.headers, {"foo": "bar"})
def test_headers_setter(self):
- request = HTTPRequest('http://example.com')
- request.headers = {'bar': 'baz'}
- self.assertEqual(request.headers, {'bar': 'baz'})
+ request = HTTPRequest("http://example.com")
+ request.headers = {"bar": "baz"} # type: ignore
+ self.assertEqual(request.headers, {"bar": "baz"})
def test_null_headers_setter(self):
- request = HTTPRequest('http://example.com')
- request.headers = None
+ request = HTTPRequest("http://example.com")
+ request.headers = None # type: ignore
self.assertEqual(request.headers, {})
def test_body(self):
- request = HTTPRequest('http://example.com', body='foo')
- self.assertEqual(request.body, utf8('foo'))
+ request = HTTPRequest("http://example.com", body="foo")
+ self.assertEqual(request.body, utf8("foo"))
def test_body_setter(self):
- request = HTTPRequest('http://example.com')
- request.body = 'foo'
- self.assertEqual(request.body, utf8('foo'))
+ request = HTTPRequest("http://example.com")
+ request.body = "foo" # type: ignore
+ self.assertEqual(request.body, utf8("foo"))
def test_if_modified_since(self):
http_date = datetime.datetime.utcnow()
- request = HTTPRequest('http://example.com', if_modified_since=http_date)
- self.assertEqual(request.headers,
- {'If-Modified-Since': format_timestamp(http_date)})
+ request = HTTPRequest("http://example.com", if_modified_since=http_date)
+ self.assertEqual(
+ request.headers, {"If-Modified-Since": format_timestamp(http_date)}
+ )
class HTTPErrorTestCase(unittest.TestCase):
@@ -674,7 +890,7 @@ def test_plain_error(self):
self.assertEqual(repr(e), "HTTP 403: Forbidden")
def test_error_with_response(self):
- resp = HTTPResponse(HTTPRequest('http://example.com/'), 403)
+ resp = HTTPResponse(HTTPRequest("http://example.com/"), 403)
with self.assertRaises(HTTPError) as cm:
resp.rethrow()
e = cm.exception
diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py
index 4bca757a66..614dec7b8f 100644
--- a/tornado/test/httpserver_test.py
+++ b/tornado/test/httpserver_test.py
@@ -1,35 +1,58 @@
-from __future__ import absolute_import, division, print_function
-
from tornado import gen, netutil
-from tornado.concurrent import Future
-from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
+from tornado.escape import (
+ json_decode,
+ json_encode,
+ utf8,
+ _unicode,
+ recursive_unicode,
+ native_str,
+)
from tornado.http1connection import HTTP1Connection
from tornado.httpclient import HTTPError
from tornado.httpserver import HTTPServer
-from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine # noqa: E501
+from tornado.httputil import (
+ HTTPHeaders,
+ HTTPMessageDelegate,
+ HTTPServerConnectionDelegate,
+ ResponseStartLine,
+)
from tornado.iostream import IOStream
from tornado.locks import Event
from tornado.log import gen_log
from tornado.netutil import ssl_options_to_context
from tornado.simple_httpclient import SimpleAsyncHTTPClient
-from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test # noqa: E501
-from tornado.test.util import unittest, skipOnTravis, ignore_deprecation
-from tornado.web import Application, RequestHandler, asynchronous, stream_request_body
+from tornado.testing import (
+ AsyncHTTPTestCase,
+ AsyncHTTPSTestCase,
+ AsyncTestCase,
+ ExpectLog,
+ gen_test,
+)
+from tornado.test.util import skipOnTravis
+from tornado.web import Application, RequestHandler, stream_request_body
from contextlib import closing
import datetime
import gzip
+import logging
import os
import shutil
import socket
import ssl
import sys
import tempfile
+import unittest
+import urllib.parse
from io import BytesIO
+import typing
+
+if typing.TYPE_CHECKING:
+ from typing import Dict, List # noqa: F401
+
-def read_stream_body(stream, callback):
- """Reads an HTTP response from `stream` and runs callback with its
+async def read_stream_body(stream):
+ """Reads an HTTP response from `stream` and returns a tuple of its
start_line, headers and body."""
chunks = []
@@ -42,15 +65,19 @@ def data_received(self, chunk):
chunks.append(chunk)
def finish(self):
- conn.detach()
- callback((self.start_line, self.headers, b''.join(chunks)))
+ conn.detach() # type: ignore
+
conn = HTTP1Connection(stream, True)
- conn.read_response(Delegate())
+ delegate = Delegate()
+ await conn.read_response(delegate)
+ return delegate.start_line, delegate.headers, b"".join(chunks)
class HandlerBaseTestCase(AsyncHTTPTestCase):
+ Handler = None
+
def get_app(self):
- return Application([('/', self.__class__.Handler)])
+ return Application([("/", self.__class__.Handler)])
def fetch_json(self, *args, **kwargs):
response = self.fetch(*args, **kwargs)
@@ -77,55 +104,58 @@ def post(self):
# introduced in python3.2, it was present but undocumented in
# python 2.7
skipIfOldSSL = unittest.skipIf(
- getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0),
- "old version of ssl module and/or openssl")
+ getattr(ssl, "OPENSSL_VERSION_INFO", (0, 0)) < (1, 0),
+ "old version of ssl module and/or openssl",
+)
class BaseSSLTest(AsyncHTTPSTestCase):
def get_app(self):
- return Application([('/', HelloWorldRequestHandler,
- dict(protocol="https"))])
+ return Application([("/", HelloWorldRequestHandler, dict(protocol="https"))])
class SSLTestMixin(object):
def get_ssl_options(self):
- return dict(ssl_version=self.get_ssl_version(), # type: ignore
- **AsyncHTTPSTestCase.get_ssl_options())
+ return dict(
+ ssl_version=self.get_ssl_version(),
+ **AsyncHTTPSTestCase.default_ssl_options()
+ )
def get_ssl_version(self):
raise NotImplementedError()
- def test_ssl(self):
- response = self.fetch('/')
+ def test_ssl(self: typing.Any):
+ response = self.fetch("/")
self.assertEqual(response.body, b"Hello world")
- def test_large_post(self):
- response = self.fetch('/',
- method='POST',
- body='A' * 5000)
+ def test_large_post(self: typing.Any):
+ response = self.fetch("/", method="POST", body="A" * 5000)
self.assertEqual(response.body, b"Got 5000 bytes in POST")
- def test_non_ssl_request(self):
+ def test_non_ssl_request(self: typing.Any):
# Make sure the server closes the connection when it gets a non-ssl
# connection, rather than waiting for a timeout or otherwise
# misbehaving.
- with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
- with ExpectLog(gen_log, 'Uncaught exception', required=False):
- with self.assertRaises((IOError, HTTPError)):
+ with ExpectLog(gen_log, "(SSL Error|uncaught exception)"):
+ with ExpectLog(gen_log, "Uncaught exception", required=False):
+ with self.assertRaises((IOError, HTTPError)): # type: ignore
self.fetch(
- self.get_url("/").replace('https:', 'http:'),
+ self.get_url("/").replace("https:", "http:"),
request_timeout=3600,
connect_timeout=3600,
- raise_error=True)
+ raise_error=True,
+ )
- def test_error_logging(self):
+ def test_error_logging(self: typing.Any):
# No stack traces are logged for SSL errors.
- with ExpectLog(gen_log, 'SSL Error') as expect_log:
- with self.assertRaises((IOError, HTTPError)):
- self.fetch(self.get_url("/").replace("https:", "http:"),
- raise_error=True)
+ with ExpectLog(gen_log, "SSL Error") as expect_log:
+ with self.assertRaises((IOError, HTTPError)): # type: ignore
+ self.fetch(
+ self.get_url("/").replace("https:", "http:"), raise_error=True
+ )
self.assertFalse(expect_log.logged_stack)
+
# Python's SSL implementation differs significantly between versions.
# For example, SSLv3 and TLSv1 throw an exception if you try to read
# from the socket before the handshake is complete, but the default
@@ -151,8 +181,7 @@ def get_ssl_version(self):
class SSLContextTest(BaseSSLTest, SSLTestMixin):
def get_ssl_options(self):
- context = ssl_options_to_context(
- AsyncHTTPSTestCase.get_ssl_options(self))
+ context = ssl_options_to_context(AsyncHTTPSTestCase.get_ssl_options(self))
assert isinstance(context, ssl.SSLContext)
return context
@@ -160,85 +189,108 @@ def get_ssl_options(self):
class BadSSLOptionsTest(unittest.TestCase):
def test_missing_arguments(self):
application = Application()
- self.assertRaises(KeyError, HTTPServer, application, ssl_options={
- "keyfile": "/__missing__.crt",
- })
+ self.assertRaises(
+ KeyError,
+ HTTPServer,
+ application,
+ ssl_options={"keyfile": "/__missing__.crt"},
+ )
def test_missing_key(self):
"""A missing SSL key should cause an immediate exception."""
application = Application()
module_dir = os.path.dirname(__file__)
- existing_certificate = os.path.join(module_dir, 'test.crt')
- existing_key = os.path.join(module_dir, 'test.key')
-
- self.assertRaises((ValueError, IOError),
- HTTPServer, application, ssl_options={
- "certfile": "/__mising__.crt",
- })
- self.assertRaises((ValueError, IOError),
- HTTPServer, application, ssl_options={
- "certfile": existing_certificate,
- "keyfile": "/__missing__.key"
- })
+ existing_certificate = os.path.join(module_dir, "test.crt")
+ existing_key = os.path.join(module_dir, "test.key")
+
+ self.assertRaises(
+ (ValueError, IOError),
+ HTTPServer,
+ application,
+ ssl_options={"certfile": "/__mising__.crt"},
+ )
+ self.assertRaises(
+ (ValueError, IOError),
+ HTTPServer,
+ application,
+ ssl_options={
+ "certfile": existing_certificate,
+ "keyfile": "/__missing__.key",
+ },
+ )
# This actually works because both files exist
- HTTPServer(application, ssl_options={
- "certfile": existing_certificate,
- "keyfile": existing_key,
- })
+ HTTPServer(
+ application,
+ ssl_options={"certfile": existing_certificate, "keyfile": existing_key},
+ )
class MultipartTestHandler(RequestHandler):
def post(self):
- self.finish({"header": self.request.headers["X-Header-Encoding-Test"],
- "argument": self.get_argument("argument"),
- "filename": self.request.files["files"][0].filename,
- "filebody": _unicode(self.request.files["files"][0]["body"]),
- })
+ self.finish(
+ {
+ "header": self.request.headers["X-Header-Encoding-Test"],
+ "argument": self.get_argument("argument"),
+ "filename": self.request.files["files"][0].filename,
+ "filebody": _unicode(self.request.files["files"][0]["body"]),
+ }
+ )
# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase):
def get_handlers(self):
- return [("/multipart", MultipartTestHandler),
- ("/hello", HelloWorldRequestHandler)]
+ return [
+ ("/multipart", MultipartTestHandler),
+ ("/hello", HelloWorldRequestHandler),
+ ]
def get_app(self):
return Application(self.get_handlers())
def raw_fetch(self, headers, body, newline=b"\r\n"):
with closing(IOStream(socket.socket())) as stream:
- with ignore_deprecation():
- stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
- self.wait()
+ self.io_loop.run_sync(
+ lambda: stream.connect(("127.0.0.1", self.get_http_port()))
+ )
stream.write(
- newline.join(headers +
- [utf8("Content-Length: %d" % len(body))]) +
- newline + newline + body)
- read_stream_body(stream, self.stop)
- start_line, headers, body = self.wait()
+ newline.join(headers + [utf8("Content-Length: %d" % len(body))])
+ + newline
+ + newline
+ + body
+ )
+ start_line, headers, body = self.io_loop.run_sync(
+ lambda: read_stream_body(stream)
+ )
return body
def test_multipart_form(self):
# Encodings here are tricky: Headers are latin1, bodies can be
# anything (we use utf8 by default).
- response = self.raw_fetch([
- b"POST /multipart HTTP/1.0",
- b"Content-Type: multipart/form-data; boundary=1234567890",
- b"X-Header-encoding-test: \xe9",
- ],
- b"\r\n".join([
- b"Content-Disposition: form-data; name=argument",
- b"",
- u"\u00e1".encode("utf-8"),
- b"--1234567890",
- u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
- b"",
- u"\u00fa".encode("utf-8"),
- b"--1234567890--",
- b"",
- ]))
+ response = self.raw_fetch(
+ [
+ b"POST /multipart HTTP/1.0",
+ b"Content-Type: multipart/form-data; boundary=1234567890",
+ b"X-Header-encoding-test: \xe9",
+ ],
+ b"\r\n".join(
+ [
+ b"Content-Disposition: form-data; name=argument",
+ b"",
+ u"\u00e1".encode("utf-8"),
+ b"--1234567890",
+ u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode(
+ "utf8"
+ ),
+ b"",
+ u"\u00fa".encode("utf-8"),
+ b"--1234567890--",
+ b"",
+ ]
+ ),
+ )
data = json_decode(response)
self.assertEqual(u"\u00e9", data["header"])
self.assertEqual(u"\u00e1", data["argument"])
@@ -248,9 +300,8 @@ def test_multipart_form(self):
def test_newlines(self):
# We support both CRLF and bare LF as line separators.
for newline in (b"\r\n", b"\n"):
- response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"",
- newline=newline)
- self.assertEqual(response, b'Hello world')
+ response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"", newline=newline)
+ self.assertEqual(response, b"Hello world")
@gen_test
def test_100_continue(self):
@@ -259,19 +310,24 @@ def test_100_continue(self):
# headers, and then the real response after the body.
stream = IOStream(socket.socket())
yield stream.connect(("127.0.0.1", self.get_http_port()))
- yield stream.write(b"\r\n".join([
- b"POST /hello HTTP/1.1",
- b"Content-Length: 1024",
- b"Expect: 100-continue",
- b"Connection: close",
- b"\r\n"]))
+ yield stream.write(
+ b"\r\n".join(
+ [
+ b"POST /hello HTTP/1.1",
+ b"Content-Length: 1024",
+ b"Expect: 100-continue",
+ b"Connection: close",
+ b"\r\n",
+ ]
+ )
+ )
data = yield stream.read_until(b"\r\n\r\n")
self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
stream.write(b"a" * 1024)
first_line = yield stream.read_until(b"\r\n")
self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
header_data = yield stream.read_until(b"\r\n\r\n")
- headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
+ headers = HTTPHeaders.parse(native_str(header_data.decode("latin1")))
body = yield stream.read_bytes(int(headers["Content-Length"]))
self.assertEqual(body, b"Got 1024 bytes in POST")
stream.close()
@@ -287,32 +343,34 @@ def post(self):
class TypeCheckHandler(RequestHandler):
def prepare(self):
- self.errors = {}
+ self.errors = {} # type: Dict[str, str]
fields = [
- ('method', str),
- ('uri', str),
- ('version', str),
- ('remote_ip', str),
- ('protocol', str),
- ('host', str),
- ('path', str),
- ('query', str),
+ ("method", str),
+ ("uri", str),
+ ("version", str),
+ ("remote_ip", str),
+ ("protocol", str),
+ ("host", str),
+ ("path", str),
+ ("query", str),
]
for field, expected_type in fields:
self.check_type(field, getattr(self.request, field), expected_type)
- self.check_type('header_key', list(self.request.headers.keys())[0], str)
- self.check_type('header_value', list(self.request.headers.values())[0], str)
+ self.check_type("header_key", list(self.request.headers.keys())[0], str)
+ self.check_type("header_value", list(self.request.headers.values())[0], str)
- self.check_type('cookie_key', list(self.request.cookies.keys())[0], str)
- self.check_type('cookie_value', list(self.request.cookies.values())[0].value, str)
+ self.check_type("cookie_key", list(self.request.cookies.keys())[0], str)
+ self.check_type(
+ "cookie_value", list(self.request.cookies.values())[0].value, str
+ )
# secure cookies
- self.check_type('arg_key', list(self.request.arguments.keys())[0], str)
- self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes)
+ self.check_type("arg_key", list(self.request.arguments.keys())[0], str)
+ self.check_type("arg_value", list(self.request.arguments.values())[0][0], bytes)
def post(self):
- self.check_type('body', self.request.body, bytes)
+ self.check_type("body", self.request.body, bytes)
self.write(self.errors)
def get(self):
@@ -321,16 +379,33 @@ def get(self):
def check_type(self, name, obj, expected_type):
actual_type = type(obj)
if expected_type != actual_type:
- self.errors[name] = "expected %s, got %s" % (expected_type,
- actual_type)
+ self.errors[name] = "expected %s, got %s" % (expected_type, actual_type)
+
+
+class PostEchoHandler(RequestHandler):
+ def post(self, *path_args):
+ self.write(dict(echo=self.get_argument("data")))
+
+
+class PostEchoGBKHandler(PostEchoHandler):
+ def decode_argument(self, value, name=None):
+ try:
+ return value.decode("gbk")
+ except Exception:
+ raise HTTPError(400, "invalid gbk bytes: %r" % value)
class HTTPServerTest(AsyncHTTPTestCase):
def get_app(self):
- return Application([("/echo", EchoHandler),
- ("/typecheck", TypeCheckHandler),
- ("//doubleslash", EchoHandler),
- ])
+ return Application(
+ [
+ ("/echo", EchoHandler),
+ ("/typecheck", TypeCheckHandler),
+ ("//doubleslash", EchoHandler),
+ ("/post_utf8", PostEchoHandler),
+ ("/post_gbk", PostEchoGBKHandler),
+ ]
+ )
def test_query_string_encoding(self):
response = self.fetch("/echo?foo=%C3%A9")
@@ -353,7 +428,9 @@ def test_types(self):
data = json_decode(response.body)
self.assertEqual(data, {})
- response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers)
+ response = self.fetch(
+ "/typecheck", method="POST", body="foo=bar", headers=headers
+ )
data = json_decode(response.body)
self.assertEqual(data, {})
@@ -365,36 +442,38 @@ def test_double_slash(self):
self.assertEqual(200, response.code)
self.assertEqual(json_decode(response.body), {})
- def test_malformed_body(self):
- # parse_qs is pretty forgiving, but it will fail on python 3
- # if the data is not utf8. On python 2 parse_qs will work,
- # but then the recursive_unicode call in EchoHandler will
- # fail.
- if str is bytes:
- return
- with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'):
- response = self.fetch(
- '/echo', method="POST",
- headers={'Content-Type': 'application/x-www-form-urlencoded'},
- body=b'\xe9')
- self.assertEqual(200, response.code)
- self.assertEqual(b'{}', response.body)
+ def test_post_encodings(self):
+ headers = {"Content-Type": "application/x-www-form-urlencoded"}
+ uni_text = "chinese: \u5f20\u4e09"
+ for enc in ("utf8", "gbk"):
+ for quote in (True, False):
+ with self.subTest(enc=enc, quote=quote):
+ bin_text = uni_text.encode(enc)
+ if quote:
+ bin_text = urllib.parse.quote(bin_text).encode("ascii")
+ response = self.fetch(
+ "/post_" + enc,
+ method="POST",
+ headers=headers,
+ body=(b"data=" + bin_text),
+ )
+ self.assertEqual(json_decode(response.body), {"echo": uni_text})
class HTTPServerRawTest(AsyncHTTPTestCase):
def get_app(self):
- return Application([
- ('/echo', EchoHandler),
- ])
+ return Application([("/echo", EchoHandler)])
def setUp(self):
- super(HTTPServerRawTest, self).setUp()
+ super().setUp()
self.stream = IOStream(socket.socket())
- self.io_loop.run_sync(lambda: self.stream.connect(('127.0.0.1', self.get_http_port())))
+ self.io_loop.run_sync(
+ lambda: self.stream.connect(("127.0.0.1", self.get_http_port()))
+ )
def tearDown(self):
self.stream.close()
- super(HTTPServerRawTest, self).tearDown()
+ super().tearDown()
def test_empty_request(self):
self.stream.close()
@@ -402,34 +481,38 @@ def test_empty_request(self):
self.wait()
def test_malformed_first_line_response(self):
- with ExpectLog(gen_log, '.*Malformed HTTP request line'):
- self.stream.write(b'asdf\r\n\r\n')
- read_stream_body(self.stream, self.stop)
- start_line, headers, response = self.wait()
- self.assertEqual('HTTP/1.1', start_line.version)
+ with ExpectLog(gen_log, ".*Malformed HTTP request line", level=logging.INFO):
+ self.stream.write(b"asdf\r\n\r\n")
+ start_line, headers, response = self.io_loop.run_sync(
+ lambda: read_stream_body(self.stream)
+ )
+ self.assertEqual("HTTP/1.1", start_line.version)
self.assertEqual(400, start_line.code)
- self.assertEqual('Bad Request', start_line.reason)
+ self.assertEqual("Bad Request", start_line.reason)
def test_malformed_first_line_log(self):
- with ExpectLog(gen_log, '.*Malformed HTTP request line'):
- self.stream.write(b'asdf\r\n\r\n')
+ with ExpectLog(gen_log, ".*Malformed HTTP request line", level=logging.INFO):
+ self.stream.write(b"asdf\r\n\r\n")
# TODO: need an async version of ExpectLog so we don't need
# hard-coded timeouts here.
- self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
- self.stop)
+ self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop)
self.wait()
def test_malformed_headers(self):
- with ExpectLog(gen_log, '.*Malformed HTTP message.*no colon in header line'):
- self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
- self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
- self.stop)
+ with ExpectLog(
+ gen_log,
+ ".*Malformed HTTP message.*no colon in header line",
+ level=logging.INFO,
+ ):
+ self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n")
+ self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop)
self.wait()
def test_chunked_request_body(self):
# Chunked requests are not widely supported and we don't have a way
# to generate them in AsyncHTTPClient, but HTTPServer will read them.
- self.stream.write(b"""\
+ self.stream.write(
+ b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded
@@ -440,15 +523,20 @@ def test_chunked_request_body(self):
bar
0
-""".replace(b"\n", b"\r\n"))
- read_stream_body(self.stream, self.stop)
- start_line, headers, response = self.wait()
- self.assertEqual(json_decode(response), {u'foo': [u'bar']})
+""".replace(
+ b"\n", b"\r\n"
+ )
+ )
+ start_line, headers, response = self.io_loop.run_sync(
+ lambda: read_stream_body(self.stream)
+ )
+ self.assertEqual(json_decode(response), {u"foo": [u"bar"]})
def test_chunked_request_uppercase(self):
# As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
# case-insensitive.
- self.stream.write(b"""\
+ self.stream.write(
+ b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded
@@ -459,118 +547,136 @@ def test_chunked_request_uppercase(self):
bar
0
-""".replace(b"\n", b"\r\n"))
- read_stream_body(self.stream, self.stop)
- start_line, headers, response = self.wait()
- self.assertEqual(json_decode(response), {u'foo': [u'bar']})
+""".replace(
+ b"\n", b"\r\n"
+ )
+ )
+ start_line, headers, response = self.io_loop.run_sync(
+ lambda: read_stream_body(self.stream)
+ )
+ self.assertEqual(json_decode(response), {u"foo": [u"bar"]})
@gen_test
def test_invalid_content_length(self):
- with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'):
- self.stream.write(b"""\
+ with ExpectLog(
+ gen_log, ".*Only integer Content-Length is allowed", level=logging.INFO
+ ):
+ self.stream.write(
+ b"""\
POST /echo HTTP/1.1
Content-Length: foo
bar
-""".replace(b"\n", b"\r\n"))
+""".replace(
+ b"\n", b"\r\n"
+ )
+ )
yield self.stream.read_until_close()
class XHeaderTest(HandlerBaseTestCase):
class Handler(RequestHandler):
def get(self):
- self.set_header('request-version', self.request.version)
- self.write(dict(remote_ip=self.request.remote_ip,
- remote_protocol=self.request.protocol))
+ self.set_header("request-version", self.request.version)
+ self.write(
+ dict(
+ remote_ip=self.request.remote_ip,
+ remote_protocol=self.request.protocol,
+ )
+ )
def get_httpserver_options(self):
- return dict(xheaders=True, trusted_downstream=['5.5.5.5'])
+ return dict(xheaders=True, trusted_downstream=["5.5.5.5"])
def test_ip_headers(self):
self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1")
valid_ipv4 = {"X-Real-IP": "4.4.4.4"}
self.assertEqual(
- self.fetch_json("/", headers=valid_ipv4)["remote_ip"],
- "4.4.4.4")
+ self.fetch_json("/", headers=valid_ipv4)["remote_ip"], "4.4.4.4"
+ )
valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4"}
self.assertEqual(
- self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"],
- "4.4.4.4")
+ self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"], "4.4.4.4"
+ )
valid_ipv6 = {"X-Real-IP": "2620:0:1cfe:face:b00c::3"}
self.assertEqual(
self.fetch_json("/", headers=valid_ipv6)["remote_ip"],
- "2620:0:1cfe:face:b00c::3")
+ "2620:0:1cfe:face:b00c::3",
+ )
valid_ipv6_list = {"X-Forwarded-For": "::1, 2620:0:1cfe:face:b00c::3"}
self.assertEqual(
self.fetch_json("/", headers=valid_ipv6_list)["remote_ip"],
- "2620:0:1cfe:face:b00c::3")
+ "2620:0:1cfe:face:b00c::3",
+ )
invalid_chars = {"X-Real-IP": "4.4.4.4
-