Skip to content

Commit

Permalink
Added test for SSL over SSL.
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed Jul 5, 2018
1 parent c178c6b commit 01be886
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 1 deletion.
138 changes: 138 additions & 0 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,144 @@ async def start_server():
for client in clients:
client.stop()

def test_create_server_ssl_over_ssl(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest('asyncio does not support SSL over SSL')

CNT = 0 # number of clients that were successful
TOTAL_CNT = 25 # total number of clients that test will create
TIMEOUT = 10.0 # timeout for this test

A_DATA = b'A' * 1024 * 1024
B_DATA = b'B' * 1024 * 1024

sslctx_1 = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
client_sslctx_1 = self._create_client_ssl_context()
sslctx_2 = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
client_sslctx_2 = self._create_client_ssl_context()

clients = []

async def handle_client(reader, writer):
nonlocal CNT

# hack reader and writer to call start_tls()
transport = writer._transport
writer._transport = None
reader._transport = None

transport = await self.loop.start_tls(
transport, writer._protocol, sslctx_2, server_side=True)

# restore with new transport
writer._transport = transport
reader._transport = transport

data = await reader.readexactly(len(A_DATA))
self.assertEqual(data, A_DATA)
writer.write(b'OK')

data = await reader.readexactly(len(B_DATA))
self.assertEqual(data, B_DATA)
writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

await writer.drain()
writer.close()

CNT += 1

async def test_client(addr):
fut = asyncio.Future(loop=self.loop)

def prog(sock):
try:
sock.connect(addr)
sock.starttls(client_sslctx_1)

# because wrap_socket() doesn't work correctly on
# SSLSocket, we have to do the 2nd level SSL manually
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)

def do(func):
while True:
try:
rv = func()
break
except ssl.SSLWantReadError:
if outgoing.pending:
sock.send(outgoing.read())
incoming.write(sock.recv(65536))
if outgoing.pending:
sock.send(outgoing.read())
return rv

do(sslobj.do_handshake)

do(lambda: sslobj.write(A_DATA))
data = do(lambda: sslobj.read(2))
self.assertEqual(data, b'OK')

do(lambda: sslobj.write(B_DATA))
data = b''
while data != b'SPAM':
data += do(lambda: sslobj.read(4))
self.assertEqual(data, b'SPAM')

do(sslobj.unwrap)
sock.close()

except Exception as ex:
self.loop.call_soon_threadsafe(fut.set_exception, ex)
else:
self.loop.call_soon_threadsafe(fut.set_result, None)

client = self.tcp_client(prog)
client.start()
clients.append(client)

await fut

async def start_server():
extras = {}
if self.implementation != 'asyncio' or self.PY37:
extras = dict(ssl_handshake_timeout=10.0)

srv = await asyncio.start_server(
handle_client,
'127.0.0.1', 0,
family=socket.AF_INET,
ssl=sslctx_1,
loop=self.loop,
**extras)

try:
srv_socks = srv.sockets
self.assertTrue(srv_socks)

addr = srv_socks[0].getsockname()

tasks = []
for _ in range(TOTAL_CNT):
tasks.append(test_client(addr))

await asyncio.wait_for(
asyncio.gather(*tasks, loop=self.loop),
TIMEOUT, loop=self.loop)

finally:
self.loop.call_soon(srv.close)
await srv.wait_closed()

with self._silence_eof_received_warning():
self.loop.run_until_complete(start_server())

self.assertEqual(CNT, TOTAL_CNT)

for client in clients:
client.stop()

def test_create_connection_ssl_1(self):
if self.implementation == 'asyncio':
# Don't crash on asyncio errors
Expand Down
3 changes: 2 additions & 1 deletion uvloop/loop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1534,7 +1534,8 @@ cdef class Loop:
f'sslcontext is expected to be an instance of ssl.SSLContext, '
f'got {sslcontext!r}')

if not isinstance(transport, (TCPTransport, UnixTransport)):
if not isinstance(transport, (TCPTransport, UnixTransport,
_SSLProtocolTransport)):
raise TypeError(
f'transport {transport!r} is not supported by start_tls()')

Expand Down

0 comments on commit 01be886

Please sign in to comment.