Skip to content

Commit

Permalink
Checkin progress on testclient
Browse files Browse the repository at this point in the history
  • Loading branch information
uSpike committed Mar 27, 2021
1 parent 7e2cd46 commit 03e312e
Showing 1 changed file with 47 additions and 39 deletions.
86 changes: 47 additions & 39 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import io
import json
import queue
import threading
import types
import typing
from urllib.parse import unquote, urljoin, urlsplit

import anyio
import requests

from starlette.types import Message, Receive, Scope, Send
Expand Down Expand Up @@ -171,7 +171,7 @@ async def receive() -> Message:

if request_complete:
while not response_complete:
await asyncio.sleep(0.0001)
await anyio.sleep(0.0001)
return {"type": "http.disconnect"}

body = request.body
Expand Down Expand Up @@ -231,13 +231,8 @@ async def send(message: Message) -> None:
context = message["context"]

try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

try:
loop.run_until_complete(self.app(scope, receive, send))
with anyio.start_blocking_portal(backend_options={"debug": True}) as portal:
portal.call(self.app, scope, receive, send)
except BaseException as exc:
if self.raise_server_exceptions:
raise exc from None
Expand Down Expand Up @@ -268,11 +263,11 @@ def __init__(self, app: ASGI3App, scope: Scope) -> None:
self.app = app
self.scope = scope
self.accepted_subprotocol = None
self.portal = anyio.start_blocking_portal(backend_options={"debug": True})
self._receive_queue = queue.Queue() # type: queue.Queue
self._send_queue = queue.Queue() # type: queue.Queue
self._thread = threading.Thread(target=self._run)
self.portal.spawn_task(self._run)
self.send({"type": "websocket.connect"})
self._thread.start()
message = self.receive()
self._raise_on_close(message)
self.accepted_subprotocol = message.get("subprotocol", None)
Expand All @@ -281,31 +276,30 @@ def __enter__(self) -> "WebSocketTestSession":
return self

def __exit__(self, *args: typing.Any) -> None:
self.close(1000)
self._thread.join()
try:
self.close(1000)
finally:
self.portal.stop_from_external_thread()
while not self._send_queue.empty():
message = self._send_queue.get()
if isinstance(message, BaseException):
raise message

def _run(self) -> None:
async def _run(self) -> None:
"""
The sub-thread in which the websocket session runs.
"""
loop = asyncio.new_event_loop()
scope = self.scope
receive = self._asgi_receive
send = self._asgi_send
try:
loop.run_until_complete(self.app(scope, receive, send))
await self.app(scope, receive, send)
except BaseException as exc:
self._send_queue.put(exc)
finally:
loop.close()

async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
await asyncio.sleep(0)
await anyio.sleep(0)
return self._receive_queue.get()

async def _asgi_send(self, message: Message) -> None:
Expand Down Expand Up @@ -452,42 +446,56 @@ def websocket_connect(
return session

def __enter__(self) -> "TestClient":
loop = asyncio.get_event_loop()
self.send_queue = asyncio.Queue() # type: asyncio.Queue
self.receive_queue = asyncio.Queue() # type: asyncio.Queue
self.task = loop.create_task(self.lifespan())
loop.run_until_complete(self.wait_startup())
self.stream_send, self.stream_receive = anyio.create_memory_object_stream()
self.portal = anyio.start_blocking_portal(
backend_options={"debug": True}
) # XXX backend
self.task = self.portal.spawn_task(self.lifespan)
self.portal.call(self.wait_startup)
return self

def __exit__(self, *args: typing.Any) -> None:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.wait_shutdown())
try:
self.portal.call(self.wait_shutdown)
finally:
self.portal.stop_from_external_thread()

async def lifespan(self) -> None:
scope = {"type": "lifespan"}
try:
await self.app(scope, self.receive_queue.get, self.send_queue.put)
finally:
await self.send_queue.put(None)
async with self.stream_send:
await self.app(scope, self.stream_receive.receive, self.stream_send.send)

async def wait_startup(self) -> None:
await self.receive_queue.put({"type": "lifespan.startup"})
message = await self.send_queue.get()
if message is None:
try:
await self.stream_send.send({"type": "lifespan.startup"})
except anyio.ClosedResourceError:
self.task.result()
return
try:
message = await self.stream_receive.receive()
except anyio.EndOfStream:
self.task.result()
return
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
)
if message["type"] == "lifespan.startup.failed":
message = await self.send_queue.get()
if message is None:
try:
message = await self.stream_receive.receive()
except anyio.EndOfStream:
self.task.result()

async def wait_shutdown(self) -> None:
await self.receive_queue.put({"type": "lifespan.shutdown"})
message = await self.send_queue.get()
if message is None:
try:
await self.stream_send.send({"type": "lifespan.shutdown"})
except anyio.ClosedResourceError:
self.task.result()
return
try:
message = await self.stream_receive.receive()
except anyio.EndOfStream:
self.task.result()
return
assert message["type"] == "lifespan.shutdown.complete"
await self.task
self.task.result()

0 comments on commit 03e312e

Please sign in to comment.