Skip to content

Commit

Permalink
Make sure we always close transports if connection waiter has failed
Browse files Browse the repository at this point in the history
  • Loading branch information
1st1 committed Jun 1, 2018
1 parent 874555c commit ac90d8b
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 28 deletions.
52 changes: 51 additions & 1 deletion tests/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,9 +1255,59 @@ async def client(addr):
max_clients=1,
backlog=1) as srv:

with self.assertRaises(ssl.SSLCertVerificationError):
exc_type = ssl.SSLError
if self.PY37:
exc_type = ssl.SSLCertVerificationError
with self.assertRaises(exc_type):
self.loop.run_until_complete(client(srv.addr))

def test_ssl_handshake_timeout(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

# bpo-29970: Check that a connection is aborted if handshake is not
# completed in timeout period, instead of remaining open indefinitely
client_sslctx = self._create_client_ssl_context()

# silence error logger
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

server_side_aborted = False

def server(sock):
nonlocal server_side_aborted
try:
sock.recv_all(1024 * 1024)
except ConnectionAbortedError:
server_side_aborted = True
finally:
sock.close()

async def client(addr):
await asyncio.wait_for(
self.loop.create_connection(
asyncio.Protocol,
*addr,
ssl=client_sslctx,
server_hostname='',
ssl_handshake_timeout=10.0),
0.5,
loop=self.loop)

with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:

with self.assertRaises(asyncio.TimeoutError):
self.loop.run_until_complete(client(srv.addr))

self.assertTrue(server_side_aborted)

# Python issue #23197: cancelling a handshake must not raise an
# exception or log an error, even if the handshake failed
self.assertEqual(messages, [])

def test_ssl_connect_accepted_socket(self):
server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
server_context.load_cert_chain(self.ONLYCERT, self.ONLYKEY)
Expand Down
51 changes: 34 additions & 17 deletions uvloop/loop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,7 @@ cdef class Loop:
raise OSError(err.errno, 'error while attempting '
'to bind on address %r: %s'
% (pyaddr, err.strerror.lower()))
except:
except Exception:
tcp._close()
raise

Expand Down Expand Up @@ -1625,7 +1625,7 @@ cdef class Loop:
try:
tcp._open(sock.fileno())
tcp.listen(backlog)
except:
except Exception:
tcp._close()
raise

Expand Down Expand Up @@ -1794,7 +1794,7 @@ cdef class Loop:
tr._close()
tr = None
exceptions.append(exc)
except:
except Exception:
if tr is not None:
tr._close()
tr = None
Expand Down Expand Up @@ -1832,7 +1832,7 @@ cdef class Loop:
tr._open(sock.fileno())
tr._init_protocol()
await waiter
except:
except Exception:
# It's OK to call `_close()` here, as opposed to
# `_force_close()` or `close()` as we want to terminate the
# transport immediately. The `waiter` can only be waken
Expand All @@ -1844,7 +1844,11 @@ cdef class Loop:
tr._attach_fileobj(sock)

if ssl:
await ssl_waiter
try:
await ssl_waiter
except Exception:
tr._close()
raise
return protocol._app_transport, app_protocol
else:
return tr, protocol
Expand Down Expand Up @@ -1934,7 +1938,7 @@ cdef class Loop:
raise OSError(errno.EADDRINUSE, msg) from None
else:
raise
except:
except Exception:
sock.close()
raise

Expand All @@ -1957,14 +1961,14 @@ cdef class Loop:

try:
pipe._open(sock.fileno())
except:
except Exception:
pipe._close()
sock.close()
raise

try:
pipe.listen(backlog)
except:
except Exception:
pipe._close()
raise

Expand Down Expand Up @@ -2026,7 +2030,7 @@ cdef class Loop:
tr.connect(path)
try:
await waiter
except:
except Exception:
tr._close()
raise

Expand All @@ -2049,14 +2053,18 @@ cdef class Loop:
tr._open(sock.fileno())
tr._init_protocol()
await waiter
except:
except Exception:
tr._close()
raise

tr._attach_fileobj(sock)

if ssl:
await ssl_waiter
try:
await ssl_waiter
except Exception:
tr._close()
raise
return protocol._app_transport, app_protocol
else:
return tr, protocol
Expand Down Expand Up @@ -2408,7 +2416,11 @@ cdef class Loop:
transport._init_protocol()
transport._attach_fileobj(sock)

await waiter
try:
await waiter
except Exception:
transport.close()
raise

if ssl:
return protocol._app_transport, protocol
Expand Down Expand Up @@ -2488,7 +2500,7 @@ cdef class Loop:

try:
await waiter
except:
except Exception:
proc.close()
raise

Expand Down Expand Up @@ -2540,7 +2552,7 @@ cdef class Loop:
transp._open(pipe.fileno())
transp._init_protocol()
await waiter
except:
except Exception:
transp.close()
raise
transp._attach_fileobj(pipe)
Expand All @@ -2565,7 +2577,7 @@ cdef class Loop:
transp._open(pipe.fileno())
transp._init_protocol()
await waiter
except:
except Exception:
transp.close()
raise
transp._attach_fileobj(pipe)
Expand Down Expand Up @@ -2803,7 +2815,7 @@ cdef class Loop:
if sock is not None:
sock.close()
exceptions.append(exc)
except:
except Exception:
if sock is not None:
sock.close()
raise
Expand All @@ -2821,7 +2833,12 @@ cdef class Loop:
udp._set_waiter(waiter)
udp._init_protocol()

await waiter
try:
await waiter
except Exception:
udp.close()
raise

return udp, protocol

def _asyncgen_finalizer_hook(self, agen):
Expand Down
16 changes: 6 additions & 10 deletions uvloop/sslproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ class SSLProtocol(object):
self._waiter = waiter
self._loop = loop
self._app_protocol = app_protocol
self._app_transport = _SSLProtocolTransport(self._loop, self)
self._app_transport = None
# _SSLPipe instance (None until the connection is made)
self._sslpipe = None
self._session_established = False
Expand Down Expand Up @@ -466,12 +466,6 @@ class SSLProtocol(object):
if self._session_established:
self._session_established = False
self._loop.call_soon(self._app_protocol.connection_lost, exc)
else:
# Most likely an exception occurred while in SSL handshake.
# Just mark the app transport as closed so that its __del__
# doesn't complain.
if self._app_transport is not None:
self._app_transport._closed = True

self._transport = None
self._app_transport = None
Expand Down Expand Up @@ -617,8 +611,10 @@ class SSLProtocol(object):
self._extra.update(peercert=peercert,
cipher=sslobj.cipher(),
compression=sslobj.compression(),
ssl_object=sslobj,
)
ssl_object=sslobj)

self._app_transport = _SSLProtocolTransport(self._loop, self)

if self._call_connection_made:
self._app_protocol.connection_made(self._app_transport)
self._wakeup_waiter()
Expand Down Expand Up @@ -684,7 +680,7 @@ class SSLProtocol(object):
ConnectionAbortedError)):
if self._loop.get_debug():
aio_logger.debug("%r: %s", self, message, exc_info=True)
else:
elif not isinstance(exc, aio_CancelledError):
self._loop.call_exception_handler({
'message': message,
'exception': exc,
Expand Down

0 comments on commit ac90d8b

Please sign in to comment.