Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delay connection closed (#69) #79

Merged
merged 7 commits into from
Nov 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 66 additions & 2 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,70 @@ async def handler(request):
with trio.fail_after(2):
async with open_websocket(
HOST, server.port, '/', use_ssl=False) as connection:
with pytest.raises(ConnectionClosed) as e:
with pytest.raises(ConnectionClosed) as exc_info:
await connection.get_message()
assert e.reason.name == 'NORMAL_CLOSURE'
exc = exc_info.value
belm0 marked this conversation as resolved.
Show resolved Hide resolved
assert exc.reason.name == 'NORMAL_CLOSURE'


@pytest.mark.skip(reason='Hangs because channel size is hard coded to 0')
async def test_read_messages_after_remote_close(nursery):
belm0 marked this conversation as resolved.
Show resolved Hide resolved
'''
When the remote endpoint closes, the local endpoint can still read all
of the messages sent prior to closing. Any attempt to read beyond that will
raise ConnectionClosed.
'''
server_closed = trio.Event()

async def handler(request):
server = await request.accept()
async with server:
await server.send_message('1')
mehaase marked this conversation as resolved.
Show resolved Hide resolved
await server.send_message('2')
server_closed.set()

server = await nursery.start(
partial(serve_websocket, handler, HOST, 0, ssl_context=None))

async with open_websocket(HOST, server.port, '/', use_ssl=False) as client:
await server_closed.wait()
assert await client.get_message() == '1'
assert await client.get_message() == '2'
with pytest.raises(ConnectionClosed):
await client.get_message()


async def test_no_messages_after_local_close(nursery):
'''
If the local endpoint initiates closing, then pending messages are discarded
and any attempt to read a message will raise ConnectionClosed.
'''
client_closed = trio.Event()

async def handler(request):
# The server sends some messages and then closes.
server = await request.accept()
async with server:
await server.send_message('1')
await server.send_message('2')
await client_closed.wait()

server = await nursery.start(
partial(serve_websocket, handler, HOST, 0, ssl_context=None))

async with open_websocket(HOST, server.port, '/', use_ssl=False) as client:
pass
with pytest.raises(ConnectionClosed):
await client.get_message()
client_closed.set()


async def test_client_cm_exit_with_pending_messages(echo_server, autojump_clock):
with trio.fail_after(1):
async with open_websocket(HOST, echo_server.port, RESOURCE,
use_ssl=False) as ws:
await ws.send_message('hello')
# allow time for the server to respond
await trio.sleep(.1)
# bug: context manager exit is blocked on unconsumed message
#await ws.get_message()
54 changes: 28 additions & 26 deletions trio_websocket/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import trio.ssl
import wsproto.connection as wsconnection
import wsproto.frame_protocol as wsframeproto
from wsproto.events import BytesReceived
from yarl import URL

from .version import __version__
Expand Down Expand Up @@ -440,8 +441,7 @@ def __init__(self, stream, wsproto, *, path=None):
self._stream = stream
self._stream_lock = trio.StrictFIFOLock()
self._wsproto = wsproto
self._bytes_message = b''
self._str_message = ''
self._message_parts = [] # type: List[bytes|str]
self._reader_running = True
self._path = path
self._subprotocol = None
Expand Down Expand Up @@ -520,6 +520,7 @@ async def aclose(self, code=1000, reason=None):
return
self._wsproto.close(code=code, reason=reason)
try:
await self._recv_channel.aclose()
await self._write_pending()
await self._close_handshake.wait()
finally:
Expand All @@ -532,17 +533,21 @@ async def get_message(self):
Receive the next WebSocket message.

If no message is available immediately, then this function blocks until
a message is ready. When the connection is closed, this message
a message is ready.

If the remote endpoint closes the connection, then the caller can still
get messages sent prior to closing. Once all pending messages have been
retrieved, additional calls to this method will raise
``ConnectionClosed``. If the local endpoint closes the connection, then
pending messages are discarded and calls to this method will immediately
raise ``ConnectionClosed``.

:rtype: str or bytes
:raises ConnectionClosed: if connection is closed before a message
arrives.
:raises ConnectionClosed: if the connection is closed.
'''
if self._close_reason:
raise ConnectionClosed(self._close_reason)
try:
message = await self._recv_channel.receive()
except trio.EndOfChannel:
except (trio.ClosedResourceError, trio.EndOfChannel):
raise ConnectionClosed(self._close_reason) from None
return message

Expand Down Expand Up @@ -720,27 +725,24 @@ async def _handle_connection_failed_event(self, event):
self._open_handshake.set()
self._close_handshake.set()

async def _handle_bytes_received_event(self, event):
'''
Handle a BytesReceived event.

:param event:
async def _handle_data_received_event(self, event):
'''
self._bytes_message += event.data
if event.message_finished:
await self._send_channel.send(self._bytes_message)
self._bytes_message = b''

async def _handle_text_received_event(self, event):
'''
Handle a TextReceived event.
Handle a BytesReceived or TextReceived event.

:param event:
'''
self._str_message += event.data
self._message_parts.append(event.data)
if event.message_finished:
await self._send_channel.send(self._str_message)
self._str_message = ''
msg = (b'' if isinstance(event, BytesReceived) else '') \
.join(self._message_parts)
self._message_parts = []
try:
await self._send_channel.send(msg)
except trio.BrokenResourceError:
belm0 marked this conversation as resolved.
Show resolved Hide resolved
# The receive channel is closed, probably because somebody
# called ``aclose()``. We don't want to abort the reader task,
# and there's no useful cleanup that we can do here.
pass

async def _handle_ping_received_event(self, event):
'''
Expand Down Expand Up @@ -790,8 +792,8 @@ async def _reader_task(self):
'ConnectionFailed': self._handle_connection_failed_event,
'ConnectionEstablished': self._handle_connection_established_event,
'ConnectionClosed': self._handle_connection_closed_event,
'BytesReceived': self._handle_bytes_received_event,
'TextReceived': self._handle_text_received_event,
'BytesReceived': self._handle_data_received_event,
'TextReceived': self._handle_data_received_event,
'PingReceived': self._handle_ping_received_event,
'PongReceived': self._handle_pong_received_event,
}
Expand Down