Skip to content

Commit

Permalink
Merge pull request #2 from MagicStack/master
Browse files Browse the repository at this point in the history
add app state check (MagicStack#263)
  • Loading branch information
Justin Kula authored Aug 24, 2019
2 parents ba8c9ab + 82104fb commit 11da8f7
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 8 deletions.
49 changes: 49 additions & 0 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,6 +1601,55 @@ async def client(addr):
# exception or log an error, even if the handshake failed
self.assertEqual(messages, [])

def test_ssl_handshake_connection_lost(self):
# #246: make sure that no connection_lost() is called before
# connection_made() is called first

client_sslctx = self._create_client_ssl_context()

# silence error logger
self.loop.set_exception_handler(lambda loop, ctx: None)

connection_made_called = False
connection_lost_called = False

def server(sock):
sock.recv(1024)
# break the connection during handshake
sock.close()

class ClientProto(asyncio.Protocol):
def connection_made(self, transport):
nonlocal connection_made_called
connection_made_called = True

def connection_lost(self, exc):
nonlocal connection_lost_called
connection_lost_called = True

async def client(addr):
await self.loop.create_connection(
ClientProto,
*addr,
ssl=client_sslctx,
server_hostname=''),

with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:

with self.assertRaises(ConnectionResetError):
self.loop.run_until_complete(client(srv.addr))

if connection_lost_called:
if connection_made_called:
self.fail("unexpected call to connection_lost()")
else:
self.fail("unexpected call to connection_lost() without"
"calling connection_made()")
elif connection_made_called:
self.fail("unexpected call to connection_made()")

def test_ssl_connect_accepted_socket(self):
server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
server_context.load_cert_chain(self.ONLYCERT, self.ONLYKEY)
Expand Down
19 changes: 17 additions & 2 deletions uvloop/sslproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,22 @@ cdef enum SSLProtocolState:
SHUTDOWN = 4


cdef enum AppProtocolState:
# This tracks the state of app protocol (https://git.io/fj59P):
#
# INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST
#
# * cm: connection_made()
# * dr: data_received()
# * er: eof_received()
# * cl: connection_lost()

STATE_INIT = 0
STATE_CON_MADE = 1
STATE_EOF = 2
STATE_CON_LOST = 3


cdef class _SSLProtocolTransport:
cdef:
object _loop
Expand All @@ -30,7 +46,6 @@ cdef class SSLProtocol:
bint _app_transport_created

object _transport
bint _call_connection_made
object _ssl_handshake_timeout
object _ssl_shutdown_timeout

Expand All @@ -46,7 +61,7 @@ cdef class SSLProtocol:
object _ssl_buffer_view
SSLProtocolState _state
size_t _conn_lost
bint _eof_received
AppProtocolState _app_state

bint _ssl_writing_paused
bint _app_reading_paused
Expand Down
18 changes: 12 additions & 6 deletions uvloop/sslproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ cdef class SSLProtocol:
self._app_transport_created = False
# transport, ex: SelectorSocketTransport
self._transport = None
self._call_connection_made = call_connection_made
self._ssl_handshake_timeout = ssl_handshake_timeout
self._ssl_shutdown_timeout = ssl_shutdown_timeout
# SSL and state machine
Expand All @@ -264,7 +263,10 @@ cdef class SSLProtocol:
self._outgoing_read = self._outgoing.read
self._state = UNWRAPPED
self._conn_lost = 0 # Set when connection_lost called
self._eof_received = False
if call_connection_made:
self._app_state = STATE_INIT
else:
self._app_state = STATE_CON_MADE

# Flow Control

Expand Down Expand Up @@ -335,7 +337,10 @@ cdef class SSLProtocol:
self._app_transport._closed = True

if self._state != DO_HANDSHAKE:
self._loop.call_soon(self._app_protocol.connection_lost, exc)
if self._app_state == STATE_CON_MADE or \
self._app_state == STATE_EOF:
self._app_state = STATE_CON_LOST
self._loop.call_soon(self._app_protocol.connection_lost, exc)
self._set_state(UNWRAPPED)
self._transport = None
self._app_transport = None
Expand Down Expand Up @@ -518,7 +523,8 @@ cdef class SSLProtocol:
cipher=sslobj.cipher(),
compression=sslobj.compression(),
ssl_object=sslobj)
if self._call_connection_made:
if self._app_state == STATE_INIT:
self._app_state = STATE_CON_MADE
self._app_protocol.connection_made(self._get_app_transport())
self._wakeup_waiter()
self._do_read()
Expand Down Expand Up @@ -735,8 +741,8 @@ cdef class SSLProtocol:

cdef _call_eof_received(self):
try:
if not self._eof_received:
self._eof_received = True
if self._app_state == STATE_CON_MADE:
self._app_state = STATE_EOF
keep_open = self._app_protocol.eof_received()
if keep_open:
aio_logger.warning('returning true from eof_received() '
Expand Down

0 comments on commit 11da8f7

Please sign in to comment.