diff --git a/README.md b/README.md
index 184eb480a..8eedea952 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@
# Starlette
Starlette is a lightweight [ASGI](https://asgi.readthedocs.io/en/latest/) framework/toolkit,
-which is ideal for building high performance asyncio services.
+which is ideal for building high performance async services.
It is production-ready, and gives you the following:
@@ -36,7 +36,8 @@ It is production-ready, and gives you the following:
* Session and Cookie support.
* 100% test coverage.
* 100% type annotated codebase.
-* Zero hard dependencies.
+* Few hard dependencies.
+* Compatible with `asyncio` and `trio` backends.
## Requirements
@@ -84,10 +85,9 @@ For a more complete example, see [encode/starlette-example](https://github.com/e
## Dependencies
-Starlette does not have any hard dependencies, but the following are optional:
+Starlette only requires `anyio`, and the following are optional:
* [`requests`][requests] - Required if you want to use the `TestClient`.
-* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
@@ -167,7 +167,6 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...
Starlette is BSD licensed code. Designed & built in Brighton, England.
[requests]: http://docs.python-requests.org/en/master/
-[aiofiles]: https://github.com/Tinche/aiofiles
[jinja2]: http://jinja.pocoo.org/
[python-multipart]: https://andrew-d.github.io/python-multipart/
[graphene]: https://graphene-python.org/
diff --git a/docs/index.md b/docs/index.md
index 4ae77f0e6..b9692a1fb 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -32,7 +32,7 @@ It is production-ready, and gives you the following:
* Session and Cookie support.
* 100% test coverage.
* 100% type annotated codebase.
-* Zero hard dependencies.
+* Few hard dependencies.
## Requirements
@@ -79,10 +79,9 @@ For a more complete example, [see here](https://github.com/encode/starlette-exam
## Dependencies
-Starlette does not have any hard dependencies, but the following are optional:
+Starlette only requires `anyio`, and the following dependencies are optional:
* [`requests`][requests] - Required if you want to use the `TestClient`.
-* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
@@ -161,7 +160,6 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...
Starlette is BSD licensed code. Designed & built in Brighton, England.
[requests]: http://docs.python-requests.org/en/master/
-[aiofiles]: https://github.com/Tinche/aiofiles
[jinja2]: http://jinja.pocoo.org/
[python-multipart]: https://andrew-d.github.io/python-multipart/
[graphene]: https://graphene-python.org/
diff --git a/docs/testclient.md b/docs/testclient.md
index 61f7201c6..f37858401 100644
--- a/docs/testclient.md
+++ b/docs/testclient.md
@@ -31,6 +31,22 @@ application. Occasionally you might want to test the content of 500 error
responses, rather than allowing client to raise the server exception. In this
case you should use `client = TestClient(app, raise_server_exceptions=False)`.
+### Selecting the Async backend
+
+`TestClient.async_backend` is a dictionary which allows you to set the options
+for the backend used to run tests. These options are passed to
+`anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options)
+for more information about backend options. By default, `asyncio` is used.
+
+To run `Trio`, set `async_backend["backend"] = "trio"`, for example:
+
+```python
+def test_app()
+ client = TestClient(app)
+ client.async_backend["backend"] = "trio"
+ ...
+```
+
### Testing WebSocket sessions
You can also test websocket sessions with the test client.
@@ -72,6 +88,8 @@ always raised by the test client.
May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection.
+`websocket_connect()` must be used as a context manager (in a `with` block).
+
#### Sending data
* `.send_text(data)` - Send the given text to the application.
diff --git a/requirements.txt b/requirements.txt
index 6ec5bf09e..ae3d91f26 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,9 +18,10 @@ types-requests
types-contextvars
types-aiofiles
types-PyYAML
+types-dataclasses
pytest
pytest-cov
-pytest-asyncio
+trio
# Documentation
mkdocs
diff --git a/setup.py b/setup.py
index c48356370..a687ad861 100644
--- a/setup.py
+++ b/setup.py
@@ -37,9 +37,9 @@ def get_long_description():
packages=find_packages(exclude=["tests*"]),
package_data={"starlette": ["py.typed"]},
include_package_data=True,
+ install_requires=["anyio>=3.0.0,<4"],
extras_require={
"full": [
- "aiofiles",
"graphene",
"itsdangerous",
"jinja2",
diff --git a/starlette/concurrency.py b/starlette/concurrency.py
index c8c5d57ac..e89d1e047 100644
--- a/starlette/concurrency.py
+++ b/starlette/concurrency.py
@@ -1,33 +1,32 @@
-import asyncio
import functools
-import sys
import typing
from typing import Any, AsyncGenerator, Iterator
+import anyio
+
try:
import contextvars # Python 3.7+ only or via contextvars backport.
except ImportError: # pragma: no cover
contextvars = None # type: ignore
-if sys.version_info >= (3, 7): # pragma: no cover
- from asyncio import create_task
-else: # pragma: no cover
- from asyncio import ensure_future as create_task
T = typing.TypeVar("T")
async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
- tasks = [create_task(handler(**kwargs)) for handler, kwargs in args]
- (done, pending) = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
- [task.cancel() for task in pending]
- [task.result() for task in done]
+ async with anyio.create_task_group() as task_group:
+
+ async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
+ await func()
+ task_group.cancel_scope.cancel()
+
+ for func, kwargs in args:
+ task_group.start_soon(run, functools.partial(func, **kwargs))
async def run_in_threadpool(
func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
) -> T:
- loop = asyncio.get_event_loop()
if contextvars is not None: # pragma: no cover
# Ensure we run in the same context
child = functools.partial(func, *args, **kwargs)
@@ -35,9 +34,9 @@ async def run_in_threadpool(
func = context.run
args = (child,)
elif kwargs: # pragma: no cover
- # loop.run_in_executor doesn't accept 'kwargs', so bind them in here
+ # run_sync doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
- return await loop.run_in_executor(None, func, *args)
+ return await anyio.to_thread.run_sync(func, *args)
class _StopIteration(Exception):
@@ -57,6 +56,6 @@ def _next(iterator: Iterator) -> Any:
async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator:
while True:
try:
- yield await run_in_threadpool(_next, iterator)
+ yield await anyio.to_thread.run_sync(_next, iterator)
except _StopIteration:
break
diff --git a/starlette/graphql.py b/starlette/graphql.py
index ed2274f89..6e5d6ec6a 100644
--- a/starlette/graphql.py
+++ b/starlette/graphql.py
@@ -31,29 +31,18 @@ class GraphQLApp:
def __init__(
self,
schema: "graphene.Schema",
- executor: typing.Any = None,
executor_class: type = None,
graphiql: bool = True,
) -> None:
self.schema = schema
self.graphiql = graphiql
- if executor is None:
- # New style in 0.10.0. Use 'executor_class'.
- # See issue https://github.com/encode/starlette/issues/242
- self.executor = executor
- self.executor_class = executor_class
- self.is_async = executor_class is not None and issubclass(
- executor_class, AsyncioExecutor
- )
- else:
- # Old style. Use 'executor'.
- # We should remove this in the next median/major version bump.
- self.executor = executor
- self.executor_class = None
- self.is_async = isinstance(executor, AsyncioExecutor)
+ self.executor_class = executor_class
+ self.is_async = executor_class is not None and issubclass(
+ executor_class, AsyncioExecutor
+ )
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- if self.executor is None and self.executor_class is not None:
+ if self.executor_class is not None:
self.executor = self.executor_class()
request = Request(scope, receive=receive)
diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py
index b347a6a2d..77ba66925 100644
--- a/starlette/middleware/base.py
+++ b/starlette/middleware/base.py
@@ -1,9 +1,10 @@
-import asyncio
import typing
+import anyio
+
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
-from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.types import ASGIApp, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
@@ -21,45 +22,39 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return
- request = Request(scope, receive=receive)
- response = await self.dispatch_func(request, self.call_next)
- await response(scope, receive, send)
+ async def call_next(request: Request) -> Response:
+ send_stream, recv_stream = anyio.create_memory_object_stream()
- async def call_next(self, request: Request) -> Response:
- loop = asyncio.get_event_loop()
- queue: "asyncio.Queue[typing.Optional[Message]]" = asyncio.Queue()
+ async def coro() -> None:
+ async with send_stream:
+ await self.app(scope, request.receive, send_stream.send)
- scope = request.scope
- receive = request.receive
- send = queue.put
+ task_group.start_soon(coro)
- async def coro() -> None:
try:
- await self.app(scope, receive, send)
- finally:
- await queue.put(None)
-
- task = loop.create_task(coro())
- message = await queue.get()
- if message is None:
- task.result()
- raise RuntimeError("No response returned.")
- assert message["type"] == "http.response.start"
-
- async def body_stream() -> typing.AsyncGenerator[bytes, None]:
- while True:
- message = await queue.get()
- if message is None:
- break
- assert message["type"] == "http.response.body"
- yield message.get("body", b"")
- task.result()
-
- response = StreamingResponse(
- status_code=message["status"], content=body_stream()
- )
- response.raw_headers = message["headers"]
- return response
+ message = await recv_stream.receive()
+ except anyio.EndOfStream:
+ raise RuntimeError("No response returned.")
+
+ assert message["type"] == "http.response.start"
+
+ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
+ async with recv_stream:
+ async for message in recv_stream:
+ assert message["type"] == "http.response.body"
+ yield message.get("body", b"")
+
+ response = StreamingResponse(
+ status_code=message["status"], content=body_stream()
+ )
+ response.raw_headers = message["headers"]
+ return response
+
+ async with anyio.create_task_group() as task_group:
+ request = Request(scope, receive=receive)
+ response = await self.dispatch_func(request, call_next)
+ await response(scope, receive, send)
+ task_group.cancel_scope.cancel()
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py
index 515cf3e76..7e69e1a6b 100644
--- a/starlette/middleware/wsgi.py
+++ b/starlette/middleware/wsgi.py
@@ -1,10 +1,11 @@
-import asyncio
import io
+import math
import sys
import typing
-from starlette.concurrency import run_in_threadpool
-from starlette.types import Message, Receive, Scope, Send
+import anyio
+
+from starlette.types import Receive, Scope, Send
def build_environ(scope: Scope, body: bytes) -> dict:
@@ -69,9 +70,9 @@ def __init__(self, app: typing.Callable, scope: Scope) -> None:
self.scope = scope
self.status = None
self.response_headers = None
- self.send_event = asyncio.Event()
- self.send_queue: typing.List[typing.Optional[Message]] = []
- self.loop = asyncio.get_event_loop()
+ self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
+ math.inf
+ )
self.response_started = False
self.exc_info: typing.Any = None
@@ -83,31 +84,18 @@ async def __call__(self, receive: Receive, send: Send) -> None:
body += message.get("body", b"")
more_body = message.get("more_body", False)
environ = build_environ(self.scope, body)
- sender = None
- try:
- sender = self.loop.create_task(self.sender(send))
- await run_in_threadpool(self.wsgi, environ, self.start_response)
- self.send_queue.append(None)
- self.send_event.set()
- await asyncio.wait_for(sender, None)
- if self.exc_info is not None:
- raise self.exc_info[0].with_traceback(
- self.exc_info[1], self.exc_info[2]
- )
- finally:
- if sender and not sender.done():
- sender.cancel() # pragma: no cover
+
+ async with anyio.create_task_group() as task_group:
+ task_group.start_soon(self.sender, send)
+ async with self.stream_send:
+ await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
+ if self.exc_info is not None:
+ raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
async def sender(self, send: Send) -> None:
- while True:
- if self.send_queue:
- message = self.send_queue.pop(0)
- if message is None:
- return
+ async with self.stream_receive:
+ async for message in self.stream_receive:
await send(message)
- else:
- await self.send_event.wait()
- self.send_event.clear()
def start_response(
self,
@@ -124,21 +112,22 @@ def start_response(
(name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
for name, value in response_headers
]
- self.send_queue.append(
+ anyio.from_thread.run(
+ self.stream_send.send,
{
"type": "http.response.start",
"status": status_code,
"headers": headers,
- }
+ },
)
- self.loop.call_soon_threadsafe(self.send_event.set)
def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
for chunk in self.app(environ, start_response):
- self.send_queue.append(
- {"type": "http.response.body", "body": chunk, "more_body": True}
+ anyio.from_thread.run(
+ self.stream_send.send,
+ {"type": "http.response.body", "body": chunk, "more_body": True},
)
- self.loop.call_soon_threadsafe(self.send_event.set)
- self.send_queue.append({"type": "http.response.body", "body": b""})
- self.loop.call_soon_threadsafe(self.send_event.set)
+ anyio.from_thread.run(
+ self.stream_send.send, {"type": "http.response.body", "body": b""}
+ )
diff --git a/starlette/requests.py b/starlette/requests.py
index ab6f51424..54ed8611e 100644
--- a/starlette/requests.py
+++ b/starlette/requests.py
@@ -1,9 +1,10 @@
-import asyncio
import json
import typing
from collections.abc import Mapping
from http import cookies as http_cookies
+import anyio
+
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
from starlette.formparsers import FormParser, MultiPartParser
from starlette.types import Message, Receive, Scope, Send
@@ -251,10 +252,12 @@ async def close(self) -> None:
async def is_disconnected(self) -> bool:
if not self._is_disconnected:
- try:
- message = await asyncio.wait_for(self._receive(), timeout=0.0000001)
- except asyncio.TimeoutError:
- message = {}
+ message: Message = {}
+
+ # If message isn't immediately available, move on
+ with anyio.CancelScope() as cs:
+ cs.cancel()
+ message = await self._receive()
if message.get("type") == "http.disconnect":
self._is_disconnected = True
diff --git a/starlette/responses.py b/starlette/responses.py
index 00f6be4db..d03df2329 100644
--- a/starlette/responses.py
+++ b/starlette/responses.py
@@ -6,24 +6,20 @@
import sys
import typing
from email.utils import formatdate
+from functools import partial
from mimetypes import guess_type as mimetypes_guess_type
from urllib.parse import quote
+import anyio
+
from starlette.background import BackgroundTask
-from starlette.concurrency import iterate_in_threadpool, run_until_first_complete
+from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, MutableHeaders
from starlette.types import Receive, Scope, Send
# Workaround for adding samesite support to pre 3.8 python
http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore
-try:
- import aiofiles
- from aiofiles.os import stat as aio_stat
-except ImportError: # pragma: nocover
- aiofiles = None # type: ignore
- aio_stat = None # type: ignore
-
# Compatibility wrapper for `mimetypes.guess_type` to support `os.PathLike` on None:
await send({"type": "http.response.body", "body": b"", "more_body": False})
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- await run_until_first_complete(
- (self.stream_response, {"send": send}),
- (self.listen_for_disconnect, {"receive": receive}),
- )
+ async with anyio.create_task_group() as task_group:
+
+ async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None:
+ await func()
+ task_group.cancel_scope.cancel()
+
+ task_group.start_soon(wrap, partial(self.stream_response, send))
+ await wrap(partial(self.listen_for_disconnect, receive))
if self.background is not None:
await self.background()
@@ -244,7 +244,6 @@ def __init__(
stat_result: os.stat_result = None,
method: str = None,
) -> None:
- assert aiofiles is not None, "'aiofiles' must be installed to use FileResponse"
self.path = path
self.status_code = status_code
self.filename = filename
@@ -280,7 +279,7 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.stat_result is None:
try:
- stat_result = await aio_stat(self.path)
+ stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
self.set_stat_headers(stat_result)
except FileNotFoundError:
raise RuntimeError(f"File at path {self.path} does not exist.")
@@ -298,10 +297,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
- # Tentatively ignoring type checking failure to work around the wrong type
- # definitions for aiofile that come with typeshed. See
- # https://github.com/python/typeshed/pull/4650
- async with aiofiles.open(self.path, mode="rb") as file: # type: ignore
+ async with await anyio.open_file(self.path, mode="rb") as file:
more_body = True
while more_body:
chunk = await file.read(self.chunk_size)
diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py
index 15a67fe35..33ea0b033 100644
--- a/starlette/staticfiles.py
+++ b/starlette/staticfiles.py
@@ -4,7 +4,7 @@
import typing
from email.utils import parsedate
-from aiofiles.os import stat as aio_stat
+import anyio
from starlette.datastructures import URL, Headers
from starlette.responses import (
@@ -154,7 +154,7 @@ async def lookup_path(
# directory.
continue
try:
- stat_result = await aio_stat(full_path)
+ stat_result = await anyio.to_thread.run_sync(os.stat, full_path)
return full_path, stat_result
except FileNotFoundError:
pass
@@ -187,7 +187,7 @@ async def check_config(self) -> None:
return
try:
- stat_result = await aio_stat(self.directory)
+ stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
except FileNotFoundError:
raise RuntimeError(
f"StaticFiles directory '{self.directory}' does not exist."
diff --git a/starlette/testclient.py b/starlette/testclient.py
index 77c038b17..c1c0fe165 100644
--- a/starlette/testclient.py
+++ b/starlette/testclient.py
@@ -1,15 +1,19 @@
import asyncio
+import contextlib
import http
import inspect
import io
import json
+import math
import queue
-import threading
import types
import typing
+from concurrent.futures import Future
from urllib.parse import unquote, urljoin, urlsplit
+import anyio
import requests
+from anyio.streams.stapled import StapledObjectStream
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect
@@ -89,11 +93,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(
- self, app: ASGI3App, raise_server_exceptions: bool = True, root_path: str = ""
+ self,
+ app: ASGI3App,
+ async_backend: typing.Dict[str, typing.Any],
+ raise_server_exceptions: bool = True,
+ root_path: str = "",
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
+ self.async_backend = async_backend
def send(
self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
@@ -142,7 +151,7 @@ def send(
"server": [host, port],
"subprotocols": subprotocols,
}
- session = WebSocketTestSession(self.app, scope)
+ session = WebSocketTestSession(self.app, scope, self.async_backend)
raise _Upgrade(session)
scope = {
@@ -161,17 +170,17 @@ def send(
request_complete = False
response_started = False
- response_complete = False
+ response_complete: anyio.Event
raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()}
template = None
context = None
async def receive() -> Message:
- nonlocal request_complete, response_complete
+ nonlocal request_complete
if request_complete:
- while not response_complete:
- await asyncio.sleep(0.0001)
+ if not response_complete.is_set():
+ await response_complete.wait()
return {"type": "http.disconnect"}
body = request.body
@@ -195,7 +204,7 @@ async def receive() -> Message:
return {"type": "http.request", "body": body_bytes}
async def send(message: Message) -> None:
- nonlocal raw_kwargs, response_started, response_complete, template, context
+ nonlocal raw_kwargs, response_started, template, context
if message["type"] == "http.response.start":
assert (
@@ -217,7 +226,7 @@ async def send(message: Message) -> None:
response_started
), 'Received "http.response.body" without "http.response.start".'
assert (
- not response_complete
+ not response_complete.is_set()
), 'Received "http.response.body" after response completed.'
body = message.get("body", b"")
more_body = message.get("more_body", False)
@@ -225,19 +234,15 @@ async def send(message: Message) -> None:
raw_kwargs["body"].write(body)
if not more_body:
raw_kwargs["body"].seek(0)
- response_complete = True
+ response_complete.set()
elif message["type"] == "http.response.template":
template = message["template"]
context = message["context"]
try:
- loop = asyncio.get_event_loop()
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(self.app(scope, receive, send))
+ with anyio.start_blocking_portal(**self.async_backend) as portal:
+ response_complete = portal.call(anyio.Event)
+ portal.call(self.app, scope, receive, send)
except BaseException as exc:
if self.raise_server_exceptions:
raise exc
@@ -264,48 +269,59 @@ async def send(message: Message) -> None:
class WebSocketTestSession:
- def __init__(self, app: ASGI3App, scope: Scope) -> None:
+ def __init__(
+ self, app: ASGI3App, scope: Scope, async_backend: typing.Dict[str, typing.Any]
+ ) -> None:
self.app = app
self.scope = scope
self.accepted_subprotocol = None
+ self.async_backend = async_backend
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
- self._thread = threading.Thread(target=self._run)
- self.send({"type": "websocket.connect"})
- self._thread.start()
- message = self.receive()
- self._raise_on_close(message)
- self.accepted_subprotocol = message.get("subprotocol", None)
def __enter__(self) -> "WebSocketTestSession":
+ self.exit_stack = contextlib.ExitStack()
+ self.portal = self.exit_stack.enter_context(
+ anyio.start_blocking_portal(**self.async_backend)
+ )
+
+ try:
+ _: "Future[None]" = self.portal.start_task_soon(self._run)
+ self.send({"type": "websocket.connect"})
+ message = self.receive()
+ self._raise_on_close(message)
+ except Exception:
+ self.exit_stack.close()
+ raise
+ self.accepted_subprotocol = message.get("subprotocol", None)
return self
def __exit__(self, *args: typing.Any) -> None:
- self.close(1000)
- self._thread.join()
+ try:
+ self.close(1000)
+ finally:
+ self.exit_stack.close()
while not self._send_queue.empty():
message = self._send_queue.get()
if isinstance(message, BaseException):
raise message
- def _run(self) -> None:
+ async def _run(self) -> None:
"""
The sub-thread in which the websocket session runs.
"""
- loop = asyncio.new_event_loop()
scope = self.scope
receive = self._asgi_receive
send = self._asgi_send
try:
- loop.run_until_complete(self.app(scope, receive, send))
+ await self.app(scope, receive, send)
except BaseException as exc:
self._send_queue.put(exc)
- finally:
- loop.close()
+ raise
async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
- await asyncio.sleep(0)
+ await anyio.sleep(0)
return self._receive_queue.get()
async def _asgi_send(self, message: Message) -> None:
@@ -365,6 +381,14 @@ def receive_json(self, mode: str = "text") -> typing.Any:
class TestClient(requests.Session):
__test__ = False # For pytest to not discover this up.
+ #: These options are passed to `anyio.start_blocking_portal()`
+ async_backend: typing.Dict[str, typing.Any] = {
+ "backend": "asyncio",
+ "backend_options": {},
+ }
+
+ task: "Future[None]"
+
def __init__(
self,
app: typing.Union[ASGI2App, ASGI3App],
@@ -381,6 +405,7 @@ def __init__(
asgi_app = _WrapASGI2(app) # type: ignore
adapter = _ASGIAdapter(
asgi_app,
+ self.async_backend,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
)
@@ -452,27 +477,40 @@ def websocket_connect(
return session
def __enter__(self) -> "TestClient":
- loop = asyncio.get_event_loop()
- self.send_queue: "asyncio.Queue[typing.Any]" = asyncio.Queue()
- self.receive_queue: "asyncio.Queue[typing.Any]" = asyncio.Queue()
- self.task = loop.create_task(self.lifespan())
- loop.run_until_complete(self.wait_startup())
+ self.exit_stack = contextlib.ExitStack()
+ self.portal = self.exit_stack.enter_context(
+ anyio.start_blocking_portal(**self.async_backend)
+ )
+ self.stream_send = StapledObjectStream(
+ *anyio.create_memory_object_stream(math.inf)
+ )
+ self.stream_receive = StapledObjectStream(
+ *anyio.create_memory_object_stream(math.inf)
+ )
+ try:
+ self.task = self.portal.start_task_soon(self.lifespan)
+ self.portal.call(self.wait_startup)
+ except Exception:
+ self.exit_stack.close()
+ raise
return self
def __exit__(self, *args: typing.Any) -> None:
- loop = asyncio.get_event_loop()
- loop.run_until_complete(self.wait_shutdown())
+ try:
+ self.portal.call(self.wait_shutdown)
+ finally:
+ self.exit_stack.close()
async def lifespan(self) -> None:
scope = {"type": "lifespan"}
try:
- await self.app(scope, self.receive_queue.get, self.send_queue.put)
+ await self.app(scope, self.stream_receive.receive, self.stream_send.send)
finally:
- await self.send_queue.put(None)
+ await self.stream_send.send(None)
async def wait_startup(self) -> None:
- await self.receive_queue.put({"type": "lifespan.startup"})
- message = await self.send_queue.get()
+ await self.stream_receive.send({"type": "lifespan.startup"})
+ message = await self.stream_send.receive()
if message is None:
self.task.result()
assert message["type"] in (
@@ -480,14 +518,14 @@ async def wait_startup(self) -> None:
"lifespan.startup.failed",
)
if message["type"] == "lifespan.startup.failed":
- message = await self.send_queue.get()
+ message = await self.stream_send.receive()
if message is None:
self.task.result()
async def wait_shutdown(self) -> None:
- await self.receive_queue.put({"type": "lifespan.shutdown"})
- message = await self.send_queue.get()
- if message is None:
- self.task.result()
- assert message["type"] == "lifespan.shutdown.complete"
- await self.task
+ async with self.stream_send:
+ await self.stream_receive.send({"type": "lifespan.shutdown"})
+ message = await self.stream_send.receive()
+ if message is None:
+ self.task.result()
+ assert message["type"] == "lifespan.shutdown.complete"
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 000000000..d1f3ba8e4
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,24 @@
+import pytest
+
+from starlette.testclient import TestClient
+
+
+@pytest.fixture(
+ params=[
+ pytest.param(
+ {"backend": "asyncio", "backend_options": {"use_uvloop": False}},
+ id="asyncio",
+ ),
+ pytest.param({"backend": "trio", "backend_options": {}}, id="trio"),
+ ],
+ autouse=True,
+)
+def anyio_backend(request, monkeypatch):
+ monkeypatch.setattr(TestClient, "async_backend", request.param)
+ return request.param["backend"]
+
+
+@pytest.fixture
+def no_trio_support(request):
+ if request.keywords.get("trio"):
+ pytest.skip("Trio not supported (yet!)")
diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py
index 048dd9ffb..df8901934 100644
--- a/tests/middleware/test_base.py
+++ b/tests/middleware/test_base.py
@@ -143,3 +143,18 @@ def homepage(request):
def test_middleware_repr():
middleware = Middleware(CustomMiddleware)
assert repr(middleware) == "Middleware(CustomMiddleware)"
+
+
+def test_fully_evaluated_response():
+ # Test for https://github.com/encode/starlette/issues/1022
+ class CustomMiddleware(BaseHTTPMiddleware):
+ async def dispatch(self, request, call_next):
+ await call_next(request)
+ return PlainTextResponse("Custom")
+
+ app = Starlette()
+ app.add_middleware(CustomMiddleware)
+
+ client = TestClient(app)
+ response = client.get("/does_not_exist")
+ assert response.text == "Custom"
diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py
index c178ef9da..28b2a7ba3 100644
--- a/tests/middleware/test_errors.py
+++ b/tests/middleware/test_errors.py
@@ -67,4 +67,5 @@ async def app(scope, receive, send):
with pytest.raises(RuntimeError):
client = TestClient(app)
- client.websocket_connect("/")
+ with client.websocket_connect("/"):
+ pass # pragma: nocover
diff --git a/tests/test_authentication.py b/tests/test_authentication.py
index 3373f67c5..8ee87932a 100644
--- a/tests/test_authentication.py
+++ b/tests/test_authentication.py
@@ -261,10 +261,14 @@ def test_authentication_required():
def test_websocket_authentication_required():
with TestClient(app) as client:
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect("/ws")
+ with client.websocket_connect("/ws"):
+ pass # pragma: nocover
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect("/ws", headers={"Authorization": "basic foobar"})
+ with client.websocket_connect(
+ "/ws", headers={"Authorization": "basic foobar"}
+ ):
+ pass # pragma: nocover
with client.websocket_connect(
"/ws", auth=("tomchristie", "example")
@@ -273,12 +277,14 @@ def test_websocket_authentication_required():
assert data == {"authenticated": True, "user": "tomchristie"}
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect("/ws/decorated")
+ with client.websocket_connect("/ws/decorated"):
+ pass # pragma: nocover
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect(
+ with client.websocket_connect(
"/ws/decorated", headers={"Authorization": "basic foobar"}
- )
+ ):
+ pass # pragma: nocover
with client.websocket_connect(
"/ws/decorated", auth=("tomchristie", "example")
diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py
new file mode 100644
index 000000000..cc5eba974
--- /dev/null
+++ b/tests/test_concurrency.py
@@ -0,0 +1,22 @@
+import anyio
+import pytest
+
+from starlette.concurrency import run_until_first_complete
+
+
+@pytest.mark.anyio
+async def test_run_until_first_complete():
+ task1_finished = anyio.Event()
+ task2_finished = anyio.Event()
+
+ async def task1():
+ task1_finished.set()
+
+ async def task2():
+ await task1_finished.wait()
+ await anyio.sleep(0) # pragma: nocover
+ task2_finished.set() # pragma: nocover
+
+ await run_until_first_complete((task1, {}), (task2, {}))
+ assert task1_finished.is_set()
+ assert not task2_finished.is_set()
diff --git a/tests/test_database.py b/tests/test_database.py
index 258a71ec5..f7280c2c7 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -19,6 +19,9 @@
)
+pytestmark = pytest.mark.usefixtures("no_trio_support")
+
+
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py
index b0e6baf98..bb71ba870 100644
--- a/tests/test_datastructures.py
+++ b/tests/test_datastructures.py
@@ -217,7 +217,7 @@ class BigUploadFile(UploadFile):
spool_max_size = 1024
-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_upload_file():
big_file = BigUploadFile("big-file")
await big_file.write(b"big-data" * 512)
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index 841c9a5cf..bab6961b5 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -54,7 +54,8 @@ def test_not_modified():
def test_websockets_should_raise():
with pytest.raises(RuntimeError):
- client.websocket_connect("/runtime_error")
+ with client.websocket_connect("/runtime_error"):
+ pass # pragma: nocover
def test_handled_exc_after_response():
diff --git a/tests/test_graphql.py b/tests/test_graphql.py
index 67f307231..b945a5cfe 100644
--- a/tests/test_graphql.py
+++ b/tests/test_graphql.py
@@ -1,5 +1,4 @@
import graphene
-import pytest
from graphql.execution.executors.asyncio import AsyncioExecutor
from starlette.applications import Starlette
@@ -142,27 +141,8 @@ async def resolve_hello(self, info, name):
async_app = GraphQLApp(schema=async_schema, executor_class=AsyncioExecutor)
-def test_graphql_async():
+def test_graphql_async(no_trio_support):
client = TestClient(async_app)
response = client.get("/?query={ hello }")
assert response.status_code == 200
assert response.json() == {"data": {"hello": "Hello stranger"}}
-
-
-async_schema = graphene.Schema(query=ASyncQuery)
-
-
-@pytest.fixture
-def old_style_async_app(event_loop) -> GraphQLApp:
- old_style_async_app = GraphQLApp(
- schema=async_schema, executor=AsyncioExecutor(loop=event_loop)
- )
- return old_style_async_app
-
-
-def test_graphql_async_old_style_executor(old_style_async_app: GraphQLApp):
- # See https://github.com/encode/starlette/issues/242
- client = TestClient(old_style_async_app)
- response = client.get("/?query={ hello }")
- assert response.status_code == 200
- assert response.json() == {"data": {"hello": "Hello stranger"}}
diff --git a/tests/test_requests.py b/tests/test_requests.py
index a83a2c480..fee059ab2 100644
--- a/tests/test_requests.py
+++ b/tests/test_requests.py
@@ -1,5 +1,4 @@
-import asyncio
-
+import anyio
import pytest
from starlette.requests import ClientDisconnect, Request, State
@@ -212,9 +211,8 @@ async def receiver():
return {"type": "http.disconnect"}
scope = {"type": "http", "method": "POST", "path": "/"}
- loop = asyncio.get_event_loop()
with pytest.raises(ClientDisconnect):
- loop.run_until_complete(app(scope, receiver, None))
+ anyio.run(app, scope, receiver, None)
def test_request_is_disconnected():
diff --git a/tests/test_responses.py b/tests/test_responses.py
index fd2ba0e42..496e64c86 100644
--- a/tests/test_responses.py
+++ b/tests/test_responses.py
@@ -1,6 +1,6 @@
-import asyncio
import os
+import anyio
import pytest
from starlette import status
@@ -83,7 +83,7 @@ async def numbers(minimum, maximum):
yield str(i)
if i != maximum:
yield ", "
- await asyncio.sleep(0)
+ await anyio.sleep(0)
async def numbers_for_cleanup(start=1, stop=5):
nonlocal filled_by_bg_task
@@ -197,7 +197,7 @@ async def numbers(minimum, maximum):
yield str(i)
if i != maximum:
yield ", "
- await asyncio.sleep(0)
+ await anyio.sleep(0)
async def numbers_for_cleanup(start=1, stop=5):
nonlocal filled_by_bg_task
diff --git a/tests/test_routing.py b/tests/test_routing.py
index fff3332db..1d8eb8d95 100644
--- a/tests/test_routing.py
+++ b/tests/test_routing.py
@@ -286,7 +286,8 @@ def test_protocol_switch():
assert session.receive_json() == {"URL": "ws://testserver/"}
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect("/404")
+ with client.websocket_connect("/404"):
+ pass # pragma: nocover
ok = PlainTextResponse("OK")
@@ -492,7 +493,8 @@ def test_standalone_ws_route_does_not_match():
app = WebSocketRoute("/", ws_helloworld)
client = TestClient(app)
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect("/invalid")
+ with client.websocket_connect("/invalid"):
+ pass # pragma: nocover
def test_lifespan_async():
diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py
index 6b325071f..3c8ff240e 100644
--- a/tests/test_staticfiles.py
+++ b/tests/test_staticfiles.py
@@ -1,8 +1,8 @@
-import asyncio
import os
import pathlib
import time
+import anyio
import pytest
from starlette.applications import Starlette
@@ -153,8 +153,7 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
# We can't test this with 'requests', so we test the app directly here.
path = app.get_path({"path": "/../example.txt"})
scope = {"method": "GET"}
- loop = asyncio.get_event_loop()
- response = loop.run_until_complete(app.get_response(path, scope))
+ response = anyio.run(app.get_response, path, scope)
assert response.status_code == 404
assert response.body == b"Not Found"
diff --git a/tests/test_testclient.py b/tests/test_testclient.py
index 00f4e0125..86f36e172 100644
--- a/tests/test_testclient.py
+++ b/tests/test_testclient.py
@@ -1,5 +1,4 @@
-import asyncio
-
+import anyio
import pytest
from starlette.applications import Starlette
@@ -118,13 +117,14 @@ async def respond(websocket):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
- asyncio.ensure_future(respond(websocket))
- try:
- # this will block as the client does not send us data
- # it should not prevent `respond` from executing though
- await websocket.receive_json()
- except WebSocketDisconnect:
- pass
+ async with anyio.create_task_group() as task_group:
+ task_group.start_soon(respond, websocket)
+ try:
+ # this will block as the client does not send us data
+ # it should not prevent `respond` from executing though
+ await websocket.receive_json()
+ except WebSocketDisconnect:
+ pass
return asgi
diff --git a/tests/test_websockets.py b/tests/test_websockets.py
index ffb1a44a8..63ecd050a 100644
--- a/tests/test_websockets.py
+++ b/tests/test_websockets.py
@@ -1,9 +1,7 @@
-import asyncio
-
+import anyio
import pytest
from starlette import status
-from starlette.concurrency import run_until_first_complete
from starlette.testclient import TestClient
from starlette.websockets import WebSocket, WebSocketDisconnect
@@ -208,23 +206,24 @@ async def asgi(receive, send):
def test_websocket_concurrency_pattern():
def app(scope):
- async def reader(websocket, queue):
- async for data in websocket.iter_json():
- await queue.put(data)
+ stream_send, stream_receive = anyio.create_memory_object_stream()
- async def writer(websocket, queue):
- while True:
- message = await queue.get()
- await websocket.send_json(message)
+ async def reader(websocket):
+ async with stream_send:
+ async for data in websocket.iter_json():
+ await stream_send.send(data)
+
+ async def writer(websocket):
+ async with stream_receive:
+ async for message in stream_receive:
+ await websocket.send_json(message)
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
- queue = asyncio.Queue()
await websocket.accept()
- await run_until_first_complete(
- (reader, {"websocket": websocket, "queue": queue}),
- (writer, {"websocket": websocket, "queue": queue}),
- )
+ async with anyio.create_task_group() as task_group:
+ task_group.start_soon(reader, websocket)
+ await writer(websocket)
await websocket.close()
return asgi
@@ -283,7 +282,8 @@ async def asgi(receive, send):
client = TestClient(app)
with pytest.raises(WebSocketDisconnect) as exc:
- client.websocket_connect("/")
+ with client.websocket_connect("/"):
+ pass # pragma: nocover
assert exc.value.code == status.WS_1001_GOING_AWAY
@@ -311,7 +311,8 @@ async def asgi(receive, send):
client = TestClient(app)
with pytest.raises(AssertionError):
- client.websocket_connect("/123?a=abc")
+ with client.websocket_connect("/123?a=abc"):
+ pass # pragma: nocover
def test_duplicate_close():
@@ -327,7 +328,7 @@ async def asgi(receive, send):
client = TestClient(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
- pass
+ pass # pragma: nocover
def test_duplicate_disconnect():