diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 76403a65..12e2d4d2 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -1601,6 +1601,55 @@ async def client(addr): # exception or log an error, even if the handshake failed self.assertEqual(messages, []) + def test_ssl_handshake_connection_lost(self): + # #246: make sure that no connection_lost() is called before + # connection_made() is called first + + client_sslctx = self._create_client_ssl_context() + + # silence error logger + self.loop.set_exception_handler(lambda loop, ctx: None) + + connection_made_called = False + connection_lost_called = False + + def server(sock): + sock.recv(1024) + # break the connection during handshake + sock.close() + + class ClientProto(asyncio.Protocol): + def connection_made(self, transport): + nonlocal connection_made_called + connection_made_called = True + + def connection_lost(self, exc): + nonlocal connection_lost_called + connection_lost_called = True + + async def client(addr): + await self.loop.create_connection( + ClientProto, + *addr, + ssl=client_sslctx, + server_hostname=''), + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaises(ConnectionResetError): + self.loop.run_until_complete(client(srv.addr)) + + if connection_lost_called: + if connection_made_called: + self.fail("unexpected call to connection_lost()") + else: + self.fail("unexpected call to connection_lost() without" + "calling connection_made()") + elif connection_made_called: + self.fail("unexpected call to connection_made()") + def test_ssl_connect_accepted_socket(self): server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) server_context.load_cert_chain(self.ONLYCERT, self.ONLYKEY) diff --git a/uvloop/sslproto.pxd b/uvloop/sslproto.pxd index e3f79a2d..c29af7ba 100644 --- a/uvloop/sslproto.pxd +++ b/uvloop/sslproto.pxd @@ -6,6 +6,22 @@ cdef enum SSLProtocolState: SHUTDOWN = 4 +cdef enum AppProtocolState: + # This tracks the state of app protocol (https://git.io/fj59P): + # + # INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST + # + # * cm: connection_made() + # * dr: data_received() + # * er: eof_received() + # * cl: connection_lost() + + STATE_INIT = 0 + STATE_CON_MADE = 1 + STATE_EOF = 2 + STATE_CON_LOST = 3 + + cdef class _SSLProtocolTransport: cdef: object _loop @@ -30,7 +46,6 @@ cdef class SSLProtocol: bint _app_transport_created object _transport - bint _call_connection_made object _ssl_handshake_timeout object _ssl_shutdown_timeout @@ -46,7 +61,7 @@ cdef class SSLProtocol: object _ssl_buffer_view SSLProtocolState _state size_t _conn_lost - bint _eof_received + AppProtocolState _app_state bint _ssl_writing_paused bint _app_reading_paused diff --git a/uvloop/sslproto.pyx b/uvloop/sslproto.pyx index 42bbb74b..17f801c8 100644 --- a/uvloop/sslproto.pyx +++ b/uvloop/sslproto.pyx @@ -253,7 +253,6 @@ cdef class SSLProtocol: self._app_transport_created = False # transport, ex: SelectorSocketTransport self._transport = None - self._call_connection_made = call_connection_made self._ssl_handshake_timeout = ssl_handshake_timeout self._ssl_shutdown_timeout = ssl_shutdown_timeout # SSL and state machine @@ -264,7 +263,10 @@ cdef class SSLProtocol: self._outgoing_read = self._outgoing.read self._state = UNWRAPPED self._conn_lost = 0 # Set when connection_lost called - self._eof_received = False + if call_connection_made: + self._app_state = STATE_INIT + else: + self._app_state = STATE_CON_MADE # Flow Control @@ -335,7 +337,10 @@ cdef class SSLProtocol: self._app_transport._closed = True if self._state != DO_HANDSHAKE: - self._loop.call_soon(self._app_protocol.connection_lost, exc) + if self._app_state == STATE_CON_MADE or \ + self._app_state == STATE_EOF: + self._app_state = STATE_CON_LOST + self._loop.call_soon(self._app_protocol.connection_lost, exc) self._set_state(UNWRAPPED) self._transport = None self._app_transport = None @@ -518,7 +523,8 @@ cdef class SSLProtocol: cipher=sslobj.cipher(), compression=sslobj.compression(), ssl_object=sslobj) - if self._call_connection_made: + if self._app_state == STATE_INIT: + self._app_state = STATE_CON_MADE self._app_protocol.connection_made(self._get_app_transport()) self._wakeup_waiter() self._do_read() @@ -735,8 +741,8 @@ cdef class SSLProtocol: cdef _call_eof_received(self): try: - if not self._eof_received: - self._eof_received = True + if self._app_state == STATE_CON_MADE: + self._app_state = STATE_EOF keep_open = self._app_protocol.eof_received() if keep_open: aio_logger.warning('returning true from eof_received() '