Skip to content

Commit

Permalink
Implement loop.start_tls()
Browse files Browse the repository at this point in the history
Side change: no longer defer "start_reading()" call after
"connection_made()".  The reading should start synchronously to copy
asyncio behaviour.  The race condition in sslproto.py that prompted that
change has been fixed.
  • Loading branch information
1st1 committed Jun 4, 2018
1 parent eb2afa6 commit 622ed9c
Show file tree
Hide file tree
Showing 4 changed files with 456 additions and 24 deletions.
373 changes: 373 additions & 0 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,9 @@ class _TestSSL(tb.SSLTestCase):
ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

PAYLOAD_SIZE = 1024 * 100
TIMEOUT = 60

def test_create_server_ssl_1(self):
CNT = 0 # number of clients that were successful
TOTAL_CNT = 25 # total number of clients that test will create
Expand Down Expand Up @@ -1418,6 +1421,21 @@ async def client(addr):
with self.assertRaises(exc_type):
self.loop.run_until_complete(client(srv.addr))

def test_start_tls_wrong_args(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

async def main():
with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
await self.loop.start_tls(None, None, None)

sslctx = self._create_server_ssl_context(
self.ONLYCERT, self.ONLYKEY)
with self.assertRaisesRegex(TypeError, 'is not supported'):
await self.loop.start_tls(None, None, sslctx)

self.loop.run_until_complete(main())

def test_ssl_handshake_timeout(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()
Expand Down Expand Up @@ -1480,6 +1498,361 @@ def test_ssl_connect_accepted_socket(self):
Test_UV_TCP.test_connect_accepted_socket(
self, server_context, client_context)

def test_start_tls_client_corrupted_ssl(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

self.loop.set_exception_handler(lambda loop, ctx: None)

sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
client_sslctx = self._create_client_ssl_context()

def server(sock):
orig_sock = sock.dup()
try:
sock.starttls(
sslctx,
server_side=True)
sock.sendall(b'A\n')
sock.recv_all(1)
orig_sock.send(b'please corrupt the SSL connection')
except ssl.SSLError:
pass
finally:
sock.close()
orig_sock.close()

async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='',
loop=self.loop)

self.assertEqual(await reader.readline(), b'A\n')
writer.write(b'B')
with self.assertRaises(ssl.SSLError):
await reader.readline()
writer.close()
return 'OK'

with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:

res = self.loop.run_until_complete(client(srv.addr))

self.assertEqual(res, 'OK')

def test_start_tls_client_reg_proto_1(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

HELLO_MSG = b'1' * self.PAYLOAD_SIZE

server_context = self._create_server_ssl_context(
self.ONLYCERT, self.ONLYKEY)
client_context = self._create_client_ssl_context()

def serve(sock):
sock.settimeout(self.TIMEOUT)

data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

sock.starttls(server_context, server_side=True)

sock.sendall(b'O')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

sock.shutdown(socket.SHUT_RDWR)
sock.close()

class ClientProto(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0

def connection_made(proto, tr):
proto.con_made_cnt += 1
# Ensure connection_made gets called only once.
self.assertEqual(proto.con_made_cnt, 1)

def data_received(self, data):
self.on_data.set_result(data)

def eof_received(self):
self.on_eof.set_result(True)

async def client(addr):
await asyncio.sleep(0.5, loop=self.loop)

on_data = self.loop.create_future()
on_eof = self.loop.create_future()

tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr)

tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context)

self.assertEqual(await on_data, b'O')
new_tr.write(HELLO_MSG)
await on_eof

new_tr.close()

with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))

def test_start_tls_client_buf_proto_1(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

HELLO_MSG = b'1' * self.PAYLOAD_SIZE

server_context = self._create_server_ssl_context(
self.ONLYCERT, self.ONLYKEY)
client_context = self._create_client_ssl_context()

client_con_made_calls = 0

def serve(sock):
sock.settimeout(self.TIMEOUT)

data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

sock.starttls(server_context, server_side=True)

sock.sendall(b'O')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

sock.sendall(b'2')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

sock.shutdown(socket.SHUT_RDWR)
sock.close()

class ClientProtoFirst(asyncio.BaseProtocol):
def __init__(self, on_data):
self.on_data = on_data
self.buf = bytearray(1)

def connection_made(self, tr):
nonlocal client_con_made_calls
client_con_made_calls += 1

def get_buffer(self, sizehint):
return self.buf

def buffer_updated(self, nsize):
assert nsize == 1
self.on_data.set_result(bytes(self.buf[:nsize]))

def eof_received(self):
pass

class ClientProtoSecond(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0

def connection_made(self, tr):
nonlocal client_con_made_calls
client_con_made_calls += 1

def data_received(self, data):
self.on_data.set_result(data)

def eof_received(self):
self.on_eof.set_result(True)

async def client(addr):
await asyncio.sleep(0.5, loop=self.loop)

on_data1 = self.loop.create_future()
on_data2 = self.loop.create_future()
on_eof = self.loop.create_future()

tr, proto = await self.loop.create_connection(
lambda: ClientProtoFirst(on_data1), *addr)

tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context)

self.assertEqual(await on_data1, b'O')
new_tr.write(HELLO_MSG)

new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
self.assertEqual(await on_data2, b'2')
new_tr.write(HELLO_MSG)
await on_eof

new_tr.close()

# connection_made() should be called only once -- when
# we establish connection for the first time. Start TLS
# doesn't call connection_made() on application protocols.
self.assertEqual(client_con_made_calls, 1)

with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr),
loop=self.loop, timeout=self.TIMEOUT))

def test_start_tls_slow_client_cancel(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

HELLO_MSG = b'1' * self.PAYLOAD_SIZE

client_context = self._create_client_ssl_context()
server_waits_on_handshake = self.loop.create_future()

def serve(sock):
sock.settimeout(self.TIMEOUT)

data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

try:
self.loop.call_soon_threadsafe(
server_waits_on_handshake.set_result, None)
data = sock.recv_all(1024 * 1024)
except ConnectionAbortedError:
pass
finally:
sock.close()

class ClientProto(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0

def connection_made(proto, tr):
proto.con_made_cnt += 1
# Ensure connection_made gets called only once.
self.assertEqual(proto.con_made_cnt, 1)

def data_received(self, data):
self.on_data.set_result(data)

def eof_received(self):
self.on_eof.set_result(True)

async def client(addr):
await asyncio.sleep(0.5, loop=self.loop)

on_data = self.loop.create_future()
on_eof = self.loop.create_future()

tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr)

tr.write(HELLO_MSG)

await server_waits_on_handshake

with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(
self.loop.start_tls(tr, proto, client_context),
0.5,
loop=self.loop)

with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))

def test_start_tls_server_1(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

HELLO_MSG = b'1' * self.PAYLOAD_SIZE

server_context = self._create_server_ssl_context(
self.ONLYCERT, self.ONLYKEY)
client_context = self._create_client_ssl_context()

def client(sock, addr):
sock.settimeout(self.TIMEOUT)

sock.connect(addr)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

sock.starttls(client_context)
sock.sendall(HELLO_MSG)

sock.shutdown(socket.SHUT_RDWR)
sock.close()

class ServerProto(asyncio.Protocol):
def __init__(self, on_con, on_eof, on_con_lost):
self.on_con = on_con
self.on_eof = on_eof
self.on_con_lost = on_con_lost
self.data = b''

def connection_made(self, tr):
self.on_con.set_result(tr)

def data_received(self, data):
self.data += data

def eof_received(self):
self.on_eof.set_result(1)

def connection_lost(self, exc):
if exc is None:
self.on_con_lost.set_result(None)
else:
self.on_con_lost.set_exception(exc)

async def main(proto, on_con, on_eof, on_con_lost):
tr = await on_con
tr.write(HELLO_MSG)

self.assertEqual(proto.data, b'')

new_tr = await self.loop.start_tls(
tr, proto, server_context,
server_side=True,
ssl_handshake_timeout=self.TIMEOUT)

await on_eof
await on_con_lost
self.assertEqual(proto.data, HELLO_MSG)
new_tr.close()

async def run_main():
on_con = self.loop.create_future()
on_eof = self.loop.create_future()
on_con_lost = self.loop.create_future()
proto = ServerProto(on_con, on_eof, on_con_lost)

server = await self.loop.create_server(
lambda: proto, '127.0.0.1', 0)
addr = server.sockets[0].getsockname()

with self.tcp_client(lambda sock: client(sock, addr),
timeout=self.TIMEOUT):
await asyncio.wait_for(
main(proto, on_con, on_eof, on_con_lost),
loop=self.loop, timeout=self.TIMEOUT)

server.close()
await server.wait_closed()

self.loop.run_until_complete(run_main())


class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase):
pass
Expand Down
Loading

0 comments on commit 622ed9c

Please sign in to comment.