diff --git a/.ci/requirements.txt b/.ci/requirements.txt index 655ac694..96067c59 100644 --- a/.ci/requirements.txt +++ b/.ci/requirements.txt @@ -3,3 +3,4 @@ aiohttp tinys3 twine psutil +pyOpenSSL==18.0.0 diff --git a/examples/bench/echoserver.py b/examples/bench/echoserver.py index 91bf72fd..66c020da 100644 --- a/examples/bench/echoserver.py +++ b/examples/bench/echoserver.py @@ -74,6 +74,23 @@ def data_received(self, data): self.transport.write(data) +class EchoBufferedProtocol(asyncio.BufferedProtocol): + def connection_made(self, transport): + self.transport = transport + # Here the buffer is intended to be copied, so that the outgoing buffer + # won't be wrongly updated by next read + self.buffer = bytearray(256 * 1024) + + def connection_lost(self, exc): + self.transport = None + + def get_buffer(self, sizehint): + return self.buffer + + def buffer_updated(self, nbytes): + self.transport.write(self.buffer[:nbytes]) + + async def print_debug(loop): while True: print(chr(27) + "[2J") # clear screen @@ -89,6 +106,7 @@ async def print_debug(loop): parser.add_argument('--addr', default='127.0.0.1:25000', type=str) parser.add_argument('--print', default=False, action='store_true') parser.add_argument('--ssl', default=False, action='store_true') + parser.add_argument('--buffered', default=False, action='store_true') args = parser.parse_args() if args.uvloop: @@ -140,6 +158,10 @@ async def print_debug(loop): print('cannot use --stream and --proto simultaneously') exit(1) + if args.buffered: + print('cannot use --stream and --buffered simultaneously') + exit(1) + print('using asyncio/streams') if unix: coro = asyncio.start_unix_server(echo_client_streams, @@ -155,12 +177,18 @@ async def print_debug(loop): print('cannot use --stream and --proto simultaneously') exit(1) - print('using simple protocol') + if args.buffered: + print('using buffered protocol') + protocol = EchoBufferedProtocol + else: + print('using simple protocol') + protocol = EchoProtocol + if unix: - coro = loop.create_unix_server(EchoProtocol, addr, + coro = loop.create_unix_server(protocol, addr, ssl=server_context) else: - coro = loop.create_server(EchoProtocol, *addr, + coro = loop.create_server(protocol, *addr, ssl=server_context) srv = loop.run_until_complete(coro) else: diff --git a/requirements.dev.txt b/requirements.dev.txt index 49d07e0b..0e7df81e 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,3 +1,4 @@ Cython==0.28.4 Sphinx>=1.4.1 psutil +pyOpenSSL==18.0.0 diff --git a/tests/test_tcp.py b/tests/test_tcp.py index b1c9ddea..e19d6377 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -1,14 +1,18 @@ import asyncio import asyncio.sslproto import gc +import os +import select import socket import unittest.mock import uvloop import ssl import sys import threading +import time import weakref +from OpenSSL import SSL as openssl_ssl from uvloop import _testbase as tb @@ -1145,6 +1149,46 @@ def test_create_server_stream_bittype(self): srv.close() self.loop.run_until_complete(srv.wait_closed()) + def test_flowcontrol_mixin_set_write_limits(self): + async def client(addr): + paused = False + + class Protocol(asyncio.Protocol): + def pause_writing(self): + nonlocal paused + paused = True + + def resume_writing(self): + nonlocal paused + paused = False + + t, p = await self.loop.create_connection(Protocol, *addr) + + t.write(b'q' * 512) + self.assertEqual(t.get_write_buffer_size(), 512) + + t.set_write_buffer_limits(low=16385) + self.assertFalse(paused) + self.assertEqual(t.get_write_buffer_limits(), (16385, 65540)) + + with self.assertRaisesRegex(ValueError, 'high.*must be >= low'): + t.set_write_buffer_limits(high=0, low=1) + + t.set_write_buffer_limits(high=1024, low=128) + self.assertFalse(paused) + self.assertEqual(t.get_write_buffer_limits(), (128, 1024)) + + t.set_write_buffer_limits(high=256, low=128) + self.assertTrue(paused) + self.assertEqual(t.get_write_buffer_limits(), (128, 256)) + + t.close() + + with self.tcp_server(lambda sock: sock.recv_all(1), + max_clients=1, + backlog=1) as srv: + self.loop.run_until_complete(client(srv.addr)) + class Test_AIO_TCP(_TestTCP, tb.AIOTestCase): pass @@ -1569,7 +1613,7 @@ def serve(sock): data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) - sock.shutdown(socket.SHUT_RDWR) + sock.unwrap() sock.close() class ClientProto(asyncio.Protocol): @@ -1639,7 +1683,7 @@ def serve(sock): data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) - sock.shutdown(socket.SHUT_RDWR) + sock.unwrap() sock.close() class ClientProtoFirst(asyncio.BaseProtocol): @@ -1794,7 +1838,7 @@ def client(sock, addr): sock.starttls(client_context) sock.sendall(HELLO_MSG) - sock.shutdown(socket.SHUT_RDWR) + sock.unwrap() sock.close() class ServerProto(asyncio.Protocol): @@ -1856,6 +1900,716 @@ async def run_main(): self.loop.run_until_complete(run_main()) + def test_create_server_ssl_over_ssl(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest('asyncio does not support SSL over SSL') + + CNT = 0 # number of clients that were successful + TOTAL_CNT = 25 # total number of clients that test will create + TIMEOUT = 10.0 # timeout for this test + + A_DATA = b'A' * 1024 * 1024 + B_DATA = b'B' * 1024 * 1024 + + sslctx_1 = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) + client_sslctx_1 = self._create_client_ssl_context() + sslctx_2 = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) + client_sslctx_2 = self._create_client_ssl_context() + + clients = [] + + async def handle_client(reader, writer): + nonlocal CNT + + data = await reader.readexactly(len(A_DATA)) + self.assertEqual(data, A_DATA) + writer.write(b'OK') + + data = await reader.readexactly(len(B_DATA)) + self.assertEqual(data, B_DATA) + writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) + + await writer.drain() + writer.close() + + CNT += 1 + + class ServerProtocol(asyncio.StreamReaderProtocol): + def connection_made(self, transport): + super_ = super() + transport.pause_reading() + fut = self._loop.create_task(self._loop.start_tls( + transport, self, sslctx_2, server_side=True)) + + def cb(_): + try: + tr = fut.result() + except Exception as ex: + super_.connection_lost(ex) + else: + super_.connection_made(tr) + fut.add_done_callback(cb) + + def server_protocol_factory(): + reader = asyncio.StreamReader(loop=self.loop) + protocol = ServerProtocol(reader, handle_client, loop=self.loop) + return protocol + + async def test_client(addr): + fut = asyncio.Future(loop=self.loop) + + def prog(sock): + try: + sock.connect(addr) + sock.starttls(client_sslctx_1) + + # because wrap_socket() doesn't work correctly on + # SSLSocket, we have to do the 2nd level SSL manually + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + sslobj = client_sslctx_2.wrap_bio(incoming, outgoing) + + def do(func, *args): + while True: + try: + rv = func(*args) + break + except ssl.SSLWantReadError: + if outgoing.pending: + sock.send(outgoing.read()) + incoming.write(sock.recv(65536)) + if outgoing.pending: + sock.send(outgoing.read()) + return rv + + do(sslobj.do_handshake) + + do(sslobj.write, A_DATA) + data = do(sslobj.read, 2) + self.assertEqual(data, b'OK') + + do(sslobj.write, B_DATA) + data = b'' + while True: + chunk = do(sslobj.read, 4) + if not chunk: + break + data += chunk + self.assertEqual(data, b'SPAM') + + do(sslobj.unwrap) + sock.close() + + except Exception as ex: + self.loop.call_soon_threadsafe(fut.set_exception, ex) + sock.close() + else: + self.loop.call_soon_threadsafe(fut.set_result, None) + + client = self.tcp_client(prog) + client.start() + clients.append(client) + + await fut + + async def start_server(): + extras = {} + if self.implementation != 'asyncio' or self.PY37: + extras = dict(ssl_handshake_timeout=10.0) + + srv = await self.loop.create_server( + server_protocol_factory, + '127.0.0.1', 0, + family=socket.AF_INET, + ssl=sslctx_1, + **extras) + + try: + srv_socks = srv.sockets + self.assertTrue(srv_socks) + + addr = srv_socks[0].getsockname() + + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(test_client(addr)) + + await asyncio.wait_for( + asyncio.gather(*tasks, loop=self.loop), + TIMEOUT, loop=self.loop) + + finally: + self.loop.call_soon(srv.close) + await srv.wait_closed() + + with self._silence_eof_received_warning(): + self.loop.run_until_complete(start_server()) + + self.assertEqual(CNT, TOTAL_CNT) + + for client in clients: + client.stop() + + def test_renegotiation(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest('asyncio does not support renegotiation') + + CNT = 0 + TOTAL_CNT = 25 + + A_DATA = b'A' * 1024 * 1024 + B_DATA = b'B' * 1024 * 1024 + + sslctx = openssl_ssl.Context(openssl_ssl.SSLv23_METHOD) + if hasattr(openssl_ssl, 'OP_NO_SSLV2'): + sslctx.set_options(openssl_ssl.OP_NO_SSLV2) + sslctx.use_privatekey_file(self.ONLYKEY) + sslctx.use_certificate_chain_file(self.ONLYCERT) + client_sslctx = self._create_client_ssl_context() + + def server(sock): + conn = openssl_ssl.Connection(sslctx, sock) + conn.set_accept_state() + + data = b'' + while len(data) < len(A_DATA): + try: + chunk = conn.recv(len(A_DATA) - len(data)) + if not chunk: + break + data += chunk + except openssl_ssl.WantReadError: + pass + self.assertEqual(data, A_DATA) + conn.renegotiate() + if conn.renegotiate_pending(): + conn.send(b'OK') + else: + conn.send(b'ER') + + data = b'' + while len(data) < len(B_DATA): + try: + chunk = conn.recv(len(B_DATA) - len(data)) + if not chunk: + break + data += chunk + except openssl_ssl.WantReadError: + pass + self.assertEqual(data, B_DATA) + if conn.renegotiate_pending(): + conn.send(b'ERRO') + else: + conn.send(b'SPAM') + + conn.shutdown() + + async def client(addr): + extras = {} + if self.implementation != 'asyncio' or self.PY37: + extras = dict(ssl_handshake_timeout=10.0) + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + loop=self.loop, + **extras) + + writer.write(A_DATA) + self.assertEqual(await reader.readexactly(2), b'OK') + + writer.write(B_DATA) + self.assertEqual(await reader.readexactly(4), b'SPAM') + + nonlocal CNT + CNT += 1 + + writer.close() + + async def client_sock(addr): + sock = socket.socket() + sock.connect(addr) + reader, writer = await asyncio.open_connection( + sock=sock, + ssl=client_sslctx, + server_hostname='', + loop=self.loop) + + writer.write(A_DATA) + self.assertEqual(await reader.readexactly(2), b'OK') + + writer.write(B_DATA) + self.assertEqual(await reader.readexactly(4), b'SPAM') + + nonlocal CNT + CNT += 1 + + writer.close() + sock.close() + + def run(coro): + nonlocal CNT + CNT = 0 + + with self.tcp_server(server, + max_clients=TOTAL_CNT, + backlog=TOTAL_CNT) as srv: + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(coro(srv.addr)) + + self.loop.run_until_complete( + asyncio.gather(*tasks, loop=self.loop)) + + self.assertEqual(CNT, TOTAL_CNT) + + with self._silence_eof_received_warning(): + run(client) + + with self._silence_eof_received_warning(): + run(client_sock) + + def test_shutdown_timeout(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest() + + CNT = 0 # number of clients that were successful + TOTAL_CNT = 25 # total number of clients that test will create + TIMEOUT = 10.0 # timeout for this test + + A_DATA = b'A' * 1024 * 1024 + + sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) + client_sslctx = self._create_client_ssl_context() + + clients = [] + + async def handle_client(reader, writer): + nonlocal CNT + + data = await reader.readexactly(len(A_DATA)) + self.assertEqual(data, A_DATA) + writer.write(b'OK') + await writer.drain() + writer.close() + with self.assertRaisesRegex(asyncio.TimeoutError, + 'SSL shutdown timed out'): + await reader.read() + CNT += 1 + + async def test_client(addr): + fut = asyncio.Future(loop=self.loop) + + def prog(sock): + try: + sock.starttls(client_sslctx) + sock.connect(addr) + sock.send(A_DATA) + + data = sock.recv_all(2) + self.assertEqual(data, b'OK') + + data = sock.recv(1024) + self.assertEqual(data, b'') + + fd = sock.detach() + try: + select.select([fd], [], [], 3) + finally: + os.close(fd) + + except Exception as ex: + self.loop.call_soon_threadsafe(fut.set_exception, ex) + else: + self.loop.call_soon_threadsafe(fut.set_result, None) + + client = self.tcp_client(prog) + client.start() + clients.append(client) + + await fut + + async def start_server(): + extras = {} + if self.implementation != 'asyncio' or self.PY37: + extras['ssl_handshake_timeout'] = 10.0 + if self.implementation != 'asyncio': # or self.PY38 + extras['ssl_shutdown_timeout'] = 0.5 + + srv = await asyncio.start_server( + handle_client, + '127.0.0.1', 0, + family=socket.AF_INET, + ssl=sslctx, + loop=self.loop, + **extras) + + try: + srv_socks = srv.sockets + self.assertTrue(srv_socks) + + addr = srv_socks[0].getsockname() + + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(test_client(addr)) + + await asyncio.wait_for( + asyncio.gather(*tasks, loop=self.loop), + TIMEOUT, loop=self.loop) + + finally: + self.loop.call_soon(srv.close) + await srv.wait_closed() + + with self._silence_eof_received_warning(): + self.loop.run_until_complete(start_server()) + + self.assertEqual(CNT, TOTAL_CNT) + + for client in clients: + client.stop() + + def test_shutdown_cleanly(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest() + + CNT = 0 + TOTAL_CNT = 25 + + A_DATA = b'A' * 1024 * 1024 + + sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) + client_sslctx = self._create_client_ssl_context() + + def server(sock): + sock.starttls( + sslctx, + server_side=True) + + data = sock.recv_all(len(A_DATA)) + self.assertEqual(data, A_DATA) + sock.send(b'OK') + + sock.unwrap() + + sock.close() + + async def client(addr): + extras = {} + if self.implementation != 'asyncio' or self.PY37: + extras = dict(ssl_handshake_timeout=10.0) + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + loop=self.loop, + **extras) + + writer.write(A_DATA) + self.assertEqual(await reader.readexactly(2), b'OK') + + self.assertEqual(await reader.read(), b'') + + nonlocal CNT + CNT += 1 + + writer.close() + + def run(coro): + nonlocal CNT + CNT = 0 + + with self.tcp_server(server, + max_clients=TOTAL_CNT, + backlog=TOTAL_CNT) as srv: + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(coro(srv.addr)) + + self.loop.run_until_complete( + asyncio.gather(*tasks, loop=self.loop)) + + self.assertEqual(CNT, TOTAL_CNT) + + with self._silence_eof_received_warning(): + run(client) + + def test_write_to_closed_transport(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest() + + sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) + client_sslctx = self._create_client_ssl_context() + future = None + + def server(sock): + sock.starttls(sslctx, server_side=True) + sock.shutdown(socket.SHUT_RDWR) + sock.close() + + def unwrap_server(sock): + sock.starttls(sslctx, server_side=True) + while True: + try: + sock.unwrap() + break + except OSError as ex: + if ex.errno == 0: + pass + sock.close() + + async def client(addr): + nonlocal future + future = self.loop.create_future() + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + loop=self.loop) + writer.write(b'I AM WRITING NOWHERE1' * 100) + + try: + data = await reader.read() + self.assertEqual(data, b'') + except (ConnectionResetError, BrokenPipeError): + pass + + for i in range(25): + writer.write(b'I AM WRITING NOWHERE2' * 100) + + self.assertEqual( + writer.transport.get_write_buffer_size(), 0) + + await future + + def run(meth): + def wrapper(sock): + try: + meth(sock) + except Exception as ex: + self.loop.call_soon_threadsafe(future.set_exception, ex) + else: + self.loop.call_soon_threadsafe(future.set_result, None) + return wrapper + + with self._silence_eof_received_warning(): + with self.tcp_server(run(server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + with self.tcp_server(run(unwrap_server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + def test_flush_before_shutdown(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest() + + CHUNK = 1024 * 128 + SIZE = 32 + + sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) + sslctx_openssl = openssl_ssl.Context(openssl_ssl.SSLv23_METHOD) + if hasattr(openssl_ssl, 'OP_NO_SSLV2'): + sslctx_openssl.set_options(openssl_ssl.OP_NO_SSLV2) + sslctx_openssl.use_privatekey_file(self.ONLYKEY) + sslctx_openssl.use_certificate_chain_file(self.ONLYCERT) + client_sslctx = self._create_client_ssl_context() + + future = None + + def server(sock): + sock.starttls(sslctx, server_side=True) + self.assertEqual(sock.recv_all(4), b'ping') + sock.send(b'pong') + time.sleep(0.5) # hopefully stuck the TCP buffer + data = sock.recv_all(CHUNK * SIZE) + self.assertEqual(len(data), CHUNK * SIZE) + sock.close() + + def openssl_server(sock): + conn = openssl_ssl.Connection(sslctx_openssl, sock) + conn.set_accept_state() + + while True: + try: + data = conn.recv(16384) + self.assertEqual(data, b'ping') + break + except openssl_ssl.WantReadError: + pass + + # use renegotiation to queue data in peer _write_backlog + conn.renegotiate() + conn.send(b'pong') + + data_size = 0 + while True: + try: + chunk = conn.recv(16384) + if not chunk: + break + data_size += len(chunk) + except openssl_ssl.WantReadError: + pass + except openssl_ssl.ZeroReturnError: + break + self.assertEqual(data_size, CHUNK * SIZE) + + def run(meth): + def wrapper(sock): + try: + meth(sock) + except Exception as ex: + self.loop.call_soon_threadsafe(future.set_exception, ex) + else: + self.loop.call_soon_threadsafe(future.set_result, None) + return wrapper + + async def client(addr): + nonlocal future + future = self.loop.create_future() + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + loop=self.loop) + writer.write(b'ping') + data = await reader.readexactly(4) + self.assertEqual(data, b'pong') + for _ in range(SIZE): + writer.write(b'x' * CHUNK) + writer.close() + try: + data = await reader.read() + self.assertEqual(data, b'') + except ConnectionResetError: + pass + await future + + with self.tcp_server(run(server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + with self.tcp_server(run(openssl_server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + def test_remote_shutdown_receives_trailing_data(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest() + + CHUNK = 1024 * 128 + SIZE = 32 + + sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) + client_sslctx = self._create_client_ssl_context() + future = None + + def server(sock): + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) + + while True: + try: + sslobj.do_handshake() + except ssl.SSLWantReadError: + if outgoing.pending: + sock.send(outgoing.read()) + incoming.write(sock.recv(16384)) + else: + if outgoing.pending: + sock.send(outgoing.read()) + break + + incoming.write(sock.recv(16384)) + self.assertEqual(sslobj.read(4), b'ping') + sslobj.write(b'pong') + sock.send(outgoing.read()) + + time.sleep(0.2) # wait for the peer to fill its backlog + + # send close_notify but don't wait for response + with self.assertRaises(ssl.SSLWantReadError): + sslobj.unwrap() + sock.send(outgoing.read()) + + # should receive all data + data_len = 0 + while True: + try: + chunk = len(sslobj.read(16384)) + data_len += chunk + except ssl.SSLWantReadError: + incoming.write(sock.recv(16384)) + except ssl.SSLZeroReturnError: + break + + self.assertEqual(data_len, CHUNK * SIZE) + + # verify that close_notify is received + sslobj.unwrap() + + sock.close() + + def eof_server(sock): + sock.starttls(sslctx, server_side=True) + self.assertEqual(sock.recv_all(4), b'ping') + sock.send(b'pong') + + time.sleep(0.2) # wait for the peer to fill its backlog + + # send EOF + sock.shutdown(socket.SHUT_WR) + + # should receive all data + data = sock.recv_all(CHUNK * SIZE) + self.assertEqual(len(data), CHUNK * SIZE) + + sock.close() + + async def client(addr): + nonlocal future + future = self.loop.create_future() + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + loop=self.loop) + writer.write(b'ping') + data = await reader.readexactly(4) + self.assertEqual(data, b'pong') + + # fill write backlog in a hacky way - renegotiation won't help + for _ in range(SIZE): + writer.transport._test__append_write_backlog(b'x' * CHUNK) + + try: + data = await reader.read() + self.assertEqual(data, b'') + except (BrokenPipeError, ConnectionResetError): + pass + + await future + + def run(meth): + def wrapper(sock): + try: + meth(sock) + except Exception as ex: + self.loop.call_soon_threadsafe(future.set_exception, ex) + else: + self.loop.call_soon_threadsafe(future.set_result, None) + return wrapper + + with self.tcp_server(run(server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + with self.tcp_server(run(eof_server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase): pass diff --git a/uvloop/handles/basetransport.pxd b/uvloop/handles/basetransport.pxd index fa6be651..ba356a78 100644 --- a/uvloop/handles/basetransport.pxd +++ b/uvloop/handles/basetransport.pxd @@ -21,7 +21,6 @@ cdef class UVBaseTransport(UVSocketHandle): # All "inline" methods are final - cdef inline _set_write_buffer_limits(self, int high=*, int low=*) cdef inline _maybe_pause_protocol(self) cdef inline _maybe_resume_protocol(self) diff --git a/uvloop/handles/basetransport.pyx b/uvloop/handles/basetransport.pyx index 356b5afb..20a62728 100644 --- a/uvloop/handles/basetransport.pyx +++ b/uvloop/handles/basetransport.pyx @@ -2,8 +2,8 @@ cdef class UVBaseTransport(UVSocketHandle): def __cinit__(self): # Flow control - self._high_water = FLOW_CONTROL_HIGH_WATER - self._low_water = FLOW_CONTROL_LOW_WATER + self._high_water = FLOW_CONTROL_HIGH_WATER * 1024 + self._low_water = FLOW_CONTROL_HIGH_WATER // 4 self._protocol = None self._protocol_connected = 0 @@ -59,25 +59,6 @@ cdef class UVBaseTransport(UVSocketHandle): 'protocol': self._protocol, }) - cdef inline _set_write_buffer_limits(self, int high=-1, int low=-1): - if high == -1: - if low == -1: - high = FLOW_CONTROL_HIGH_WATER - else: - high = FLOW_CONTROL_LOW_WATER - - if low == -1: - low = high // 4 - - if not high >= low >= 0: - raise ValueError('high (%r) must be >= low (%r) must be >= 0' % - (high, low)) - - self._high_water = high - self._low_water = low - - self._maybe_pause_protocol() - cdef inline _maybe_pause_protocol(self): cdef: size_t size = self._get_write_buffer_size() @@ -283,12 +264,10 @@ cdef class UVBaseTransport(UVSocketHandle): def set_write_buffer_limits(self, high=None, low=None): self._ensure_alive() - if high is None: - high = -1 - if low is None: - low = -1 + self._high_water, self._low_water = add_flowcontrol_defaults( + high, low, FLOW_CONTROL_HIGH_WATER) - self._set_write_buffer_limits(high, low) + self._maybe_pause_protocol() def get_write_buffer_limits(self): return (self._low_water, self._high_water) diff --git a/uvloop/handles/pipe.pxd b/uvloop/handles/pipe.pxd index 6c283c5f..56bc4f17 100644 --- a/uvloop/handles/pipe.pxd +++ b/uvloop/handles/pipe.pxd @@ -4,7 +4,9 @@ cdef class UnixServer(UVStreamServer): @staticmethod cdef UnixServer new(Loop loop, object protocol_factory, Server server, - object ssl, object ssl_handshake_timeout) + object ssl, + object ssl_handshake_timeout, + object ssl_shutdown_timeout) cdef class UnixTransport(UVStream): diff --git a/uvloop/handles/pipe.pyx b/uvloop/handles/pipe.pyx index 77617f5a..182d0bf7 100644 --- a/uvloop/handles/pipe.pyx +++ b/uvloop/handles/pipe.pyx @@ -39,12 +39,14 @@ cdef class UnixServer(UVStreamServer): @staticmethod cdef UnixServer new(Loop loop, object protocol_factory, Server server, - object ssl, object ssl_handshake_timeout): + object ssl, + object ssl_handshake_timeout, + object ssl_shutdown_timeout): cdef UnixServer handle handle = UnixServer.__new__(UnixServer) handle._init(loop, protocol_factory, server, - ssl, ssl_handshake_timeout) + ssl, ssl_handshake_timeout, ssl_shutdown_timeout) __pipe_init_uv_handle(handle, loop) return handle diff --git a/uvloop/handles/streamserver.pxd b/uvloop/handles/streamserver.pxd index 8a2b1f4e..019d022e 100644 --- a/uvloop/handles/streamserver.pxd +++ b/uvloop/handles/streamserver.pxd @@ -2,6 +2,7 @@ cdef class UVStreamServer(UVSocketHandle): cdef: object ssl object ssl_handshake_timeout + object ssl_shutdown_timeout object protocol_factory bint opened Server _server @@ -9,7 +10,9 @@ cdef class UVStreamServer(UVSocketHandle): # All "inline" methods are final cdef inline _init(self, Loop loop, object protocol_factory, - Server server, object ssl, object ssl_handshake_timeout) + Server server, object ssl, + object ssl_handshake_timeout, + object ssl_shutdown_timeout) cdef inline _mark_as_open(self) diff --git a/uvloop/handles/streamserver.pyx b/uvloop/handles/streamserver.pyx index 26734e58..8027597b 100644 --- a/uvloop/handles/streamserver.pyx +++ b/uvloop/handles/streamserver.pyx @@ -6,10 +6,13 @@ cdef class UVStreamServer(UVSocketHandle): self._server = None self.ssl = None self.ssl_handshake_timeout = None + self.ssl_shutdown_timeout = None self.protocol_factory = None cdef inline _init(self, Loop loop, object protocol_factory, - Server server, object ssl, object ssl_handshake_timeout): + Server server, object ssl, + object ssl_handshake_timeout, + object ssl_shutdown_timeout): if ssl is not None: if not isinstance(ssl, ssl_SSLContext): @@ -20,9 +23,13 @@ cdef class UVStreamServer(UVSocketHandle): if ssl_handshake_timeout is not None: raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') self.ssl = ssl self.ssl_handshake_timeout = ssl_handshake_timeout + self.ssl_shutdown_timeout = ssl_shutdown_timeout self._start_init(loop) self.protocol_factory = protocol_factory @@ -67,7 +74,8 @@ cdef class UVStreamServer(UVSocketHandle): waiter, server_side=True, server_hostname=None, - ssl_handshake_timeout=self.ssl_handshake_timeout) + ssl_handshake_timeout=self.ssl_handshake_timeout, + ssl_shutdown_timeout=self.ssl_shutdown_timeout) client = self._make_new_transport(ssl_protocol, None) diff --git a/uvloop/handles/tcp.pxd b/uvloop/handles/tcp.pxd index 4a8067a4..ed886d8b 100644 --- a/uvloop/handles/tcp.pxd +++ b/uvloop/handles/tcp.pxd @@ -4,7 +4,8 @@ cdef class TCPServer(UVStreamServer): @staticmethod cdef TCPServer new(Loop loop, object protocol_factory, Server server, object ssl, unsigned int flags, - object ssl_handshake_timeout) + object ssl_handshake_timeout, + object ssl_shutdown_timeout) cdef class TCPTransport(UVStream): diff --git a/uvloop/handles/tcp.pyx b/uvloop/handles/tcp.pyx index f3f37cf6..db1e6607 100644 --- a/uvloop/handles/tcp.pyx +++ b/uvloop/handles/tcp.pyx @@ -59,12 +59,13 @@ cdef class TCPServer(UVStreamServer): @staticmethod cdef TCPServer new(Loop loop, object protocol_factory, Server server, object ssl, unsigned int flags, - object ssl_handshake_timeout): + object ssl_handshake_timeout, + object ssl_shutdown_timeout): cdef TCPServer handle handle = TCPServer.__new__(TCPServer) handle._init(loop, protocol_factory, server, - ssl, ssl_handshake_timeout) + ssl, ssl_handshake_timeout, ssl_shutdown_timeout) __tcp_init_uv_handle(handle, loop, flags) return handle diff --git a/uvloop/includes/consts.pxi b/uvloop/includes/consts.pxi index e6b8ea18..ad33323b 100644 --- a/uvloop/includes/consts.pxi +++ b/uvloop/includes/consts.pxi @@ -1,7 +1,8 @@ DEF UV_STREAM_RECV_BUF_SIZE = 256000 # 250kb -DEF FLOW_CONTROL_HIGH_WATER = 65536 -DEF FLOW_CONTROL_LOW_WATER = 16384 +DEF FLOW_CONTROL_HIGH_WATER = 64 # KiB +DEF FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB +DEF FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB DEF DEFAULT_FREELIST_SIZE = 250 @@ -17,3 +18,7 @@ DEF LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 # Number of seconds to wait for SSL handshake to complete # The default timeout matches that of Nginx. DEF SSL_HANDSHAKE_TIMEOUT = 60.0 +# Number of seconds to wait for SSL shutdown to complete +# The default timeout mimics lingering_time +DEF SSL_SHUTDOWN_TIMEOUT = 30.0 +DEF SSL_READ_MAX_SIZE = 256 * 1024 diff --git a/uvloop/includes/flowcontrol.pxd b/uvloop/includes/flowcontrol.pxd new file mode 100644 index 00000000..9a99caef --- /dev/null +++ b/uvloop/includes/flowcontrol.pxd @@ -0,0 +1,20 @@ +cdef inline add_flowcontrol_defaults(high, low, int kb): + cdef int h, l + if high is None: + if low is None: + h = kb * 1024 + else: + l = low + h = 4 * l + else: + h = high + if low is None: + l = h // 4 + else: + l = low + + if not h >= l >= 0: + raise ValueError('high (%r) must be >= low (%r) must be >= 0' % + (h, l)) + + return h, l diff --git a/uvloop/includes/python.pxd b/uvloop/includes/python.pxd index 8c2b01c1..d77b4a09 100644 --- a/uvloop/includes/python.pxd +++ b/uvloop/includes/python.pxd @@ -14,6 +14,13 @@ cdef extern from "Python.h": int _PyImport_ReleaseLock() void _Py_RestoreSignals() + object PyMemoryView_FromMemory(char *mem, ssize_t size, int flags) + object PyMemoryView_FromObject(object obj) + int PyMemoryView_Check(object obj) + + cdef enum: + PyBUF_WRITE + cdef extern from "includes/compat.h": ctypedef struct PyContext diff --git a/uvloop/includes/stdlib.pxi b/uvloop/includes/stdlib.pxi index ff2de986..e337f67b 100644 --- a/uvloop/includes/stdlib.pxi +++ b/uvloop/includes/stdlib.pxi @@ -120,6 +120,7 @@ cdef ssl_SSLContext = ssl.SSLContext cdef ssl_MemoryBIO = ssl.MemoryBIO cdef ssl_create_default_context = ssl.create_default_context cdef ssl_SSLError = ssl.SSLError +cdef ssl_SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError) cdef ssl_CertificateError = ssl.CertificateError cdef int ssl_SSL_ERROR_WANT_READ = ssl.SSL_ERROR_WANT_READ cdef int ssl_SSL_ERROR_WANT_WRITE = ssl.SSL_ERROR_WANT_WRITE diff --git a/uvloop/loop.pxd b/uvloop/loop.pxd index 3086b235..48e5d610 100644 --- a/uvloop/loop.pxd +++ b/uvloop/loop.pxd @@ -168,7 +168,8 @@ cdef class Loop: object ssl, bint reuse_port, object backlog, - object ssl_handshake_timeout) + object ssl_handshake_timeout, + object ssl_shutdown_timeout) cdef _track_transport(self, UVBaseTransport transport) cdef _fileobj_to_fd(self, fileobj) @@ -226,6 +227,7 @@ include "handles/pipe.pxd" include "handles/process.pxd" include "request.pxd" +include "sslproto.pxd" include "handles/udp.pxd" diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index 7745239e..dfb3cd91 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -18,7 +18,10 @@ from .includes.python cimport PY_VERSION_HEX, \ PyContext, \ PyContext_CopyCurrent, \ PyContext_Enter, \ - PyContext_Exit + PyContext_Exit, \ + PyMemoryView_FromMemory, PyBUF_WRITE, \ + PyMemoryView_FromObject, PyMemoryView_Check +from .includes.flowcontrol cimport add_flowcontrol_defaults from libc.stdint cimport uint64_t from libc.string cimport memset, strerror, memcpy @@ -1085,13 +1088,15 @@ cdef class Loop: object ssl, bint reuse_port, object backlog, - object ssl_handshake_timeout): + object ssl_handshake_timeout, + object ssl_shutdown_timeout): cdef: TCPServer tcp int bind_flags tcp = TCPServer.new(self, protocol_factory, server, ssl, - addr.sa_family, ssl_handshake_timeout) + addr.sa_family, + ssl_handshake_timeout, ssl_shutdown_timeout) if reuse_port: self._sock_set_reuseport(tcp._fileno()) @@ -1523,7 +1528,8 @@ cdef class Loop: async def start_tls(self, transport, protocol, sslcontext, *, server_side=False, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): """Upgrade transport to TLS. Return a new transport that *protocol* should start using @@ -1534,7 +1540,8 @@ cdef class Loop: f'sslcontext is expected to be an instance of ssl.SSLContext, ' f'got {sslcontext!r}') - if not isinstance(transport, (TCPTransport, UnixTransport)): + if not isinstance(transport, (TCPTransport, UnixTransport, + _SSLProtocolTransport)): raise TypeError( f'transport {transport!r} is not supported by start_tls()') @@ -1543,6 +1550,7 @@ cdef class Loop: self, protocol, sslcontext, waiter, server_side, server_hostname, ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout, call_connection_made=False) # Pause early so that "ssl_protocol.data_received()" doesn't @@ -1561,7 +1569,7 @@ cdef class Loop: resume_cb.cancel() raise - return ssl_protocol._app_transport + return (ssl_protocol)._app_transport @cython.iterable_coroutine async def create_server(self, protocol_factory, host=None, port=None, @@ -1573,7 +1581,8 @@ cdef class Loop: ssl=None, reuse_address=None, # ignored, libuv sets it reuse_port=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): """A coroutine which creates a TCP server bound to host and port. The return value is a Server object which can be used to stop @@ -1612,6 +1621,10 @@ cdef class Loop: ssl_handshake_timeout is the time in seconds that an SSL server will wait for completion of the SSL handshake before aborting the connection. Default is 60s. + + ssl_shutdown_timeout is the time in seconds that an SSL server + will wait for completion of the SSL shutdown before aborting the + connection. Default is 30s. """ cdef: TCPServer tcp @@ -1634,6 +1647,9 @@ cdef class Loop: if ssl_handshake_timeout is not None: raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') if host is not None or port is not None: if sock is not None: @@ -1668,7 +1684,8 @@ cdef class Loop: tcp = self._create_server( addrinfo.ai_addr, protocol_factory, server, - ssl, reuse_port, backlog, ssl_handshake_timeout) + ssl, reuse_port, backlog, + ssl_handshake_timeout, ssl_shutdown_timeout) server._add_server(tcp) @@ -1690,7 +1707,8 @@ cdef class Loop: sock.setblocking(False) tcp = TCPServer.new(self, protocol_factory, server, ssl, - uv.AF_UNSPEC, ssl_handshake_timeout) + uv.AF_UNSPEC, + ssl_handshake_timeout, ssl_shutdown_timeout) try: tcp._open(sock.fileno()) @@ -1709,7 +1727,8 @@ cdef class Loop: async def create_connection(self, protocol_factory, host=None, port=None, *, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): """Connect to a TCP server. Create a streaming transport connection to a given Internet host and @@ -1763,13 +1782,17 @@ cdef class Loop: protocol = SSLProtocol( self, app_protocol, sslcontext, ssl_waiter, False, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) else: if server_hostname is not None: raise ValueError('server_hostname is only meaningful with ssl') if ssl_handshake_timeout is not None: raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') if host is not None or port is not None: if sock is not None: @@ -1919,14 +1942,15 @@ cdef class Loop: except Exception: tr._close() raise - return protocol._app_transport, app_protocol + return (protocol)._app_transport, app_protocol else: return tr, protocol @cython.iterable_coroutine async def create_unix_server(self, protocol_factory, path=None, *, backlog=100, sock=None, ssl=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): """A coroutine which creates a UNIX Domain Socket server. The return value is a Server object, which can be used to stop @@ -1947,6 +1971,10 @@ cdef class Loop: ssl_handshake_timeout is the time in seconds that an SSL server will wait for completion of the SSL handshake before aborting the connection. Default is 60s. + + ssl_shutdown_timeout is the time in seconds that an SSL server + will wait for completion of the SSL shutdown before aborting the + connection. Default is 30s. """ cdef: UnixServer pipe @@ -1959,6 +1987,9 @@ cdef class Loop: if ssl_handshake_timeout is not None: raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') if path is not None: if sock is not None: @@ -2027,7 +2058,8 @@ cdef class Loop: sock.setblocking(False) pipe = UnixServer.new( - self, protocol_factory, server, ssl, ssl_handshake_timeout) + self, protocol_factory, server, ssl, + ssl_handshake_timeout, ssl_shutdown_timeout) try: pipe._open(sock.fileno()) @@ -2050,7 +2082,8 @@ cdef class Loop: async def create_unix_connection(self, protocol_factory, path=None, *, ssl=None, sock=None, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): cdef: UnixTransport tr @@ -2070,13 +2103,17 @@ cdef class Loop: protocol = SSLProtocol( self, app_protocol, sslcontext, ssl_waiter, False, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) else: if server_hostname is not None: raise ValueError('server_hostname is only meaningful with ssl') if ssl_handshake_timeout is not None: raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') if path is not None: if sock is not None: @@ -2135,7 +2172,7 @@ cdef class Loop: except Exception: tr._close() raise - return protocol._app_transport, app_protocol + return (protocol)._app_transport, app_protocol else: return tr, protocol @@ -2431,7 +2468,9 @@ cdef class Loop: @cython.iterable_coroutine async def connect_accepted_socket(self, protocol_factory, sock, *, - ssl=None, ssl_handshake_timeout=None): + ssl=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): """Handle an accepted connection. This is used by servers that accept connections outside of @@ -2451,6 +2490,9 @@ cdef class Loop: if ssl_handshake_timeout is not None: raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') if not _is_sock_stream(sock.type): raise ValueError( @@ -2468,7 +2510,8 @@ cdef class Loop: self, app_protocol, ssl, waiter, server_side=True, server_hostname=None, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) transport_waiter = None if sock.family == uv.AF_UNIX: @@ -2493,7 +2536,7 @@ cdef class Loop: raise if ssl: - return protocol._app_transport, protocol + return (protocol)._app_transport, protocol else: return transport, protocol diff --git a/uvloop/sslproto.pxd b/uvloop/sslproto.pxd new file mode 100644 index 00000000..30efa5cd --- /dev/null +++ b/uvloop/sslproto.pxd @@ -0,0 +1,120 @@ +cdef enum SSLProtocolState: + UNWRAPPED = 0 + DO_HANDSHAKE = 1 + WRAPPED = 2 + FLUSHING = 3 + SHUTDOWN = 4 + + +cdef class _SSLProtocolTransport: + cdef: + object _loop + SSLProtocol _ssl_protocol + bint _closed + + +cdef class SSLProtocol: + cdef: + bint _server_side + str _server_hostname + object _sslcontext + + object _extra + + object _write_backlog + size_t _write_buffer_size + + object _waiter + object _loop + _SSLProtocolTransport _app_transport + + object _transport + bint _call_connection_made + object _ssl_handshake_timeout + object _ssl_shutdown_timeout + + object _sslobj + object _sslobj_read + object _sslobj_write + object _incoming + object _incoming_write + object _outgoing + object _outgoing_read + char* _ssl_buffer + size_t _ssl_buffer_len + object _ssl_buffer_view + SSLProtocolState _state + size_t _conn_lost + bint _eof_received + + bint _ssl_writing_paused + bint _app_reading_paused + + size_t _incoming_high_water + size_t _incoming_low_water + bint _ssl_reading_paused + + bint _app_writing_paused + size_t _outgoing_high_water + size_t _outgoing_low_water + + object _app_protocol + bint _app_protocol_is_buffer + object _app_protocol_get_buffer + object _app_protocol_buffer_updated + + object _handshake_start_time + object _handshake_timeout_handle + object _shutdown_timeout_handle + + cdef _set_app_protocol(self, app_protocol) + cdef _wakeup_waiter(self, exc=*) + cdef _get_extra_info(self, name, default=*) + cdef _set_state(self, SSLProtocolState new_state) + + # Handshake flow + + cdef _start_handshake(self) + cdef _check_handshake_timeout(self) + cdef _do_handshake(self) + cdef _on_handshake_complete(self, handshake_exc) + + # Shutdown flow + + cdef _start_shutdown(self) + cdef _check_shutdown_timeout(self) + cdef _do_flush(self) + cdef _do_shutdown(self) + cdef _on_shutdown_complete(self, shutdown_exc) + cdef _abort(self, exc) + + # Outgoing flow + + cdef _write_appdata(self, list_of_data) + cdef _do_write(self) + cdef _process_outgoing(self) + + # Incoming flow + + cdef _do_read(self) + cdef _do_read__buffered(self) + cdef _do_read__copied(self) + cdef _call_eof_received(self) + + # Flow control for writes from APP socket + + cdef _control_app_writing(self) + cdef size_t _get_write_buffer_size(self) + cdef _set_write_buffer_limits(self, high=*, low=*) + + # Flow control for reads to APP socket + + cdef _pause_reading(self) + cdef _resume_reading(self) + + # Flow control for reads from SSL socket + + cdef _control_ssl_reading(self) + cdef _set_read_buffer_limits(self, high=*, low=*) + cdef size_t _get_read_buffer_size(self) + cdef _fatal_error(self, exc, message=*) diff --git a/uvloop/sslproto.pyx b/uvloop/sslproto.pyx index 40e2a6b8..892e790f 100644 --- a/uvloop/sslproto.pyx +++ b/uvloop/sslproto.pyx @@ -1,6 +1,3 @@ -# Adapted from CPython/Lib/asyncio/sslproto.py. -# License: PSFL. - cdef _create_transport_context(server_side, server_hostname): if server_side: raise ValueError('Server side SSL needs a valid SSLContext') @@ -15,261 +12,12 @@ cdef _create_transport_context(server_side, server_hostname): return sslcontext -# States of an _SSLPipe. -cdef: - str _UNWRAPPED = "UNWRAPPED" - str _DO_HANDSHAKE = "DO_HANDSHAKE" - str _WRAPPED = "WRAPPED" - str _SHUTDOWN = "SHUTDOWN" - - -cdef ssize_t READ_MAX_SIZE = 256 * 1024 - - -@cython.no_gc_clear -cdef class _SSLPipe: - """An SSL "Pipe". - - An SSL pipe allows you to communicate with an SSL/TLS protocol instance - through memory buffers. It can be used to implement a security layer for an - existing connection where you don't have access to the connection's file - descriptor, or for some reason you don't want to use it. - - An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode, - data is passed through untransformed. In wrapped mode, application level - data is encrypted to SSL record level data and vice versa. The SSL record - level is the lowest level in the SSL protocol suite and is what travels - as-is over the wire. - - An SslPipe initially is in "unwrapped" mode. To start SSL, call - do_handshake(). To shutdown SSL again, call unwrap(). - """ - - cdef: - object _context - object _server_side - object _server_hostname - object _state - object _incoming - object _outgoing - object _sslobj - bint _need_ssldata - object _handshake_cb - object _shutdown_cb - - - def __init__(self, context, server_side, server_hostname=None): - """ - The *context* argument specifies the ssl.SSLContext to use. - - The *server_side* argument indicates whether this is a server side or - client side transport. - - The optional *server_hostname* argument can be used to specify the - hostname you are connecting to. You may only specify this parameter if - the _ssl module supports Server Name Indication (SNI). - """ - self._context = context - self._server_side = server_side - self._server_hostname = server_hostname - self._state = _UNWRAPPED - self._incoming = ssl_MemoryBIO() - self._outgoing = ssl_MemoryBIO() - self._sslobj = None - self._need_ssldata = False - self._handshake_cb = None - self._shutdown_cb = None - - cdef do_handshake(self, callback=None): - """Start the SSL handshake. - - Return a list of ssldata. A ssldata element is a list of buffers - - The optional *callback* argument can be used to install a callback that - will be called when the handshake is complete. The callback will be - called with None if successful, else an exception instance. - """ - if self._state is not _UNWRAPPED: - raise RuntimeError('handshake in progress or completed') - self._sslobj = self._context.wrap_bio( - self._incoming, self._outgoing, - server_side=self._server_side, - server_hostname=self._server_hostname) - self._state = _DO_HANDSHAKE - self._handshake_cb = callback - ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) - assert len(appdata) == 0 - return ssldata - - cdef shutdown(self, callback=None): - """Start the SSL shutdown sequence. - - Return a list of ssldata. A ssldata element is a list of buffers - - The optional *callback* argument can be used to install a callback that - will be called when the shutdown is complete. The callback will be - called without arguments. - """ - if self._state is _UNWRAPPED: - raise RuntimeError('no security layer present') - if self._state is _SHUTDOWN: - raise RuntimeError('shutdown in progress') - assert self._state in (_WRAPPED, _DO_HANDSHAKE) - self._state = _SHUTDOWN - self._shutdown_cb = callback - ssldata, appdata = self.feed_ssldata(b'') - assert appdata == [] or appdata == [b''] - return ssldata - - cdef feed_eof(self): - """Send a potentially "ragged" EOF. - - This method will raise an SSL_ERROR_EOF exception if the EOF is - unexpected. - """ - self._incoming.write_eof() - ssldata, appdata = self.feed_ssldata(b'') - assert appdata == [] or appdata == [b''] - - cdef feed_ssldata(self, data, bint only_handshake=False): - """Feed SSL record level data into the pipe. - - The data must be a bytes instance. It is OK to send an empty bytes - instance. This can be used to get ssldata for a handshake initiated by - this endpoint. - - Return a (ssldata, appdata) tuple. The ssldata element is a list of - buffers containing SSL data that needs to be sent to the remote SSL. - - The appdata element is a list of buffers containing plaintext data that - needs to be forwarded to the application. The appdata list may contain - an empty buffer indicating an SSL "close_notify" alert. This alert must - be acknowledged by calling shutdown(). - """ - cdef: - list appdata - list ssldata - int errno - - if self._state is _UNWRAPPED: - # If unwrapped, pass plaintext data straight through. - if data: - appdata = [data] - else: - appdata = [] - return ([], appdata) - - self._need_ssldata = False - if data: - self._incoming.write(data) - - ssldata = [] - appdata = [] - try: - if self._state is _DO_HANDSHAKE: - # Call do_handshake() until it doesn't raise anymore. - self._sslobj.do_handshake() - self._state = _WRAPPED - if self._handshake_cb: - self._handshake_cb(None) - if only_handshake: - return (ssldata, appdata) - # Handshake done: execute the wrapped block - - if self._state is _WRAPPED: - # Main state: read data from SSL until close_notify - while True: - chunk = self._sslobj.read(READ_MAX_SIZE) - appdata.append(chunk) - if not chunk: # close_notify - break - - elif self._state is _SHUTDOWN: - # Call shutdown() until it doesn't raise anymore. - self._sslobj.unwrap() - self._sslobj = None - self._state = _UNWRAPPED - if self._shutdown_cb: - self._shutdown_cb() - - elif self._state is _UNWRAPPED: - # Drain possible plaintext data after close_notify. - appdata.append(self._incoming.read()) - except (ssl_SSLError, ssl_CertificateError) as exc: - errno = getattr(exc, 'errno', 0) # SSL_ERROR_NONE = 0 - if errno not in (ssl_SSL_ERROR_WANT_READ, ssl_SSL_ERROR_WANT_WRITE, - ssl_SSL_ERROR_SYSCALL): - if self._state is _DO_HANDSHAKE and self._handshake_cb: - self._handshake_cb(exc) - raise - self._need_ssldata = (errno == ssl_SSL_ERROR_WANT_READ) - - # Check for record level data that needs to be sent back. - # Happens for the initial handshake and renegotiations. - if self._outgoing.pending: - ssldata.append(self._outgoing.read()) - return (ssldata, appdata) - - cdef feed_appdata(self, data, int offset=0): - """Feed plaintext data into the pipe. - - Return an (ssldata, offset) tuple. The ssldata element is a list of - buffers containing record level data that needs to be sent to the - remote SSL instance. The offset is the number of plaintext bytes that - were processed, which may be less than the length of data. - - NOTE: In case of short writes, this call MUST be retried with the SAME - buffer passed into the *data* argument (i.e. the id() must be the - same). This is an OpenSSL requirement. A further particularity is that - a short write will always have offset == 0, because the _ssl module - does not enable partial writes. And even though the offset is zero, - there will still be encrypted data in ssldata. - """ - cdef: - int errno - assert 0 <= offset <= len(data) - if self._state is _UNWRAPPED: - # pass through data in unwrapped mode - if offset < len(data): - ssldata = [data[offset:]] - else: - ssldata = [] - return (ssldata, len(data)) - - ssldata = [] - view = memoryview(data) - while True: - self._need_ssldata = False - try: - if offset < len(view): - offset += self._sslobj.write(view[offset:]) - except ssl_SSLError as exc: - errno = getattr(exc, 'errno', 0) # SSL_ERROR_NONE = 0 - # It is not allowed to call write() after unwrap() until the - # close_notify is acknowledged. We return the condition to the - # caller as a short write. - if exc.reason == 'PROTOCOL_IS_SHUTDOWN': - exc.errno = errno = ssl_SSL_ERROR_WANT_READ - if errno not in (ssl_SSL_ERROR_WANT_READ, - ssl_SSL_ERROR_WANT_WRITE, - ssl_SSL_ERROR_SYSCALL): - raise - self._need_ssldata = (errno == ssl_SSL_ERROR_WANT_READ) - - # See if there's any record level data back for us. - if self._outgoing.pending: - ssldata.append(self._outgoing.read()) - if offset == len(view) or self._need_ssldata: - break - return (ssldata, offset) - - -class _SSLProtocolTransport(aio_FlowControlMixin, aio_Transport): +cdef class _SSLProtocolTransport: # TODO: # _sendfile_compatible = constants._SendfileMode.FALLBACK - def __init__(self, loop, ssl_protocol): + def __cinit__(self, loop, ssl_protocol): self._loop = loop # SSLProtocol instance self._ssl_protocol = ssl_protocol @@ -299,18 +47,15 @@ class _SSLProtocolTransport(aio_FlowControlMixin, aio_Transport): self._closed = True self._ssl_protocol._start_shutdown() - def __del__(self): + def __dealloc__(self): if not self._closed: _warn_with_source( "unclosed transport {!r}".format(self), ResourceWarning, self) - self.close() + self._closed = True def is_reading(self): - tr = self._ssl_protocol._transport - if tr is None: - raise RuntimeError('SSL transport has not been initialized yet') - return tr.is_reading() + return not self._ssl_protocol._app_reading_paused def pause_reading(self): """Pause the receiving end. @@ -318,7 +63,7 @@ class _SSLProtocolTransport(aio_FlowControlMixin, aio_Transport): No data will be passed to the protocol's data_received() method until resume_reading() is called. """ - self._ssl_protocol._transport.pause_reading() + self._ssl_protocol._pause_reading() def resume_reading(self): """Resume the receiving end. @@ -326,7 +71,7 @@ class _SSLProtocolTransport(aio_FlowControlMixin, aio_Transport): Data received will once again be passed to the protocol's data_received() method. """ - self._ssl_protocol._transport.resume_reading() + self._ssl_protocol._resume_reading() def set_write_buffer_limits(self, high=None, low=None): """Set the high- and low-water limits for write flow control. @@ -347,16 +92,51 @@ class _SSLProtocolTransport(aio_FlowControlMixin, aio_Transport): reduces opportunities for doing I/O and computation concurrently. """ - self._ssl_protocol._transport.set_write_buffer_limits(high, low) + self._ssl_protocol._set_write_buffer_limits(high, low) + self._ssl_protocol._control_app_writing() + + def get_write_buffer_limits(self): + return (self._ssl_protocol._outgoing_low_water, + self._ssl_protocol._outgoing_high_water) def get_write_buffer_size(self): - """Return the current size of the write buffer.""" - return self._ssl_protocol._transport.get_write_buffer_size() + """Return the current size of the write buffers.""" + return self._ssl_protocol._get_write_buffer_size() + + def set_read_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for read flow control. + + These two values control when to call the upstream transport's + pause_reading() and resume_reading() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to an + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_reading() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_reading() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + self._ssl_protocol._set_read_buffer_limits(high, low) + self._ssl_protocol._control_ssl_reading() + + def get_read_buffer_limits(self): + return (self._ssl_protocol._incoming_low_water, + self._ssl_protocol._incoming_high_water) + + def get_read_buffer_size(self): + """Return the current size of the read buffer.""" + return self._ssl_protocol._get_read_buffer_size() @property def _protocol_paused(self): # Required for sendfile fallback pause_writing/resume_writing logic - return self._ssl_protocol._transport._protocol_paused + return self._ssl_protocol._app_writing_paused def write(self, data): """Write some data bytes to the transport. @@ -369,7 +149,22 @@ class _SSLProtocolTransport(aio_FlowControlMixin, aio_Transport): f"got {type(data).__name__}") if not data: return - self._ssl_protocol._write_appdata(data) + self._ssl_protocol._write_appdata((data,)) + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation concatenates the arguments and + calls write() on the result. + """ + self._ssl_protocol._write_appdata(list_of_data) + + def write_eof(self): + """Close the write end after flushing buffered data. + + This raises :exc:`NotImplementedError` right now. + """ + raise NotImplementedError def can_write_eof(self): """Return True if this transport supports write_eof(), False if not.""" @@ -382,27 +177,56 @@ class _SSLProtocolTransport(aio_FlowControlMixin, aio_Transport): The protocol's connection_lost() method will (eventually) be called with None as its argument. """ - self._ssl_protocol._abort() + self._force_close(None) + + def _force_close(self, exc): self._closed = True + self._ssl_protocol._abort(exc) + + def _test__append_write_backlog(self, data): + # for test only + self._ssl_protocol._write_backlog.append(data) + self._ssl_protocol._write_buffer_size += len(data) -class SSLProtocol(object): +cdef class SSLProtocol: """SSL protocol. Implementation of SSL on top of a socket using incoming and outgoing buffers which are ssl.MemoryBIO objects. """ + def __cinit__(self, *args, **kwargs): + self._ssl_buffer_len = SSL_READ_MAX_SIZE + self._ssl_buffer = PyMem_RawMalloc(self._ssl_buffer_len) + if not self._ssl_buffer: + raise MemoryError() + self._ssl_buffer_view = PyMemoryView_FromMemory( + self._ssl_buffer, self._ssl_buffer_len, PyBUF_WRITE) + + def __dealloc__(self): + self._ssl_buffer_view = None + PyMem_RawFree(self._ssl_buffer) + self._ssl_buffer = NULL + self._ssl_buffer_len = 0 + def __init__(self, loop, app_protocol, sslcontext, waiter, server_side=False, server_hostname=None, call_connection_made=True, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): if ssl_handshake_timeout is None: ssl_handshake_timeout = SSL_HANDSHAKE_TIMEOUT elif ssl_handshake_timeout <= 0: raise ValueError( f"ssl_handshake_timeout should be a positive number, " f"got {ssl_handshake_timeout}") + if ssl_shutdown_timeout is None: + ssl_shutdown_timeout = SSL_SHUTDOWN_TIMEOUT + elif ssl_shutdown_timeout <= 0: + raise ValueError( + f"ssl_shutdown_timeout should be a positive number, " + f"got {ssl_shutdown_timeout}") if not sslcontext: sslcontext = _create_transport_context( @@ -425,25 +249,49 @@ class SSLProtocol(object): self._waiter = waiter self._loop = loop self._set_app_protocol(app_protocol) - self._app_transport = None - # _SSLPipe instance (None until the connection is made) - self._sslpipe = None - self._session_established = False - self._in_handshake = False - self._in_shutdown = False + self._app_transport = _SSLProtocolTransport(self._loop, self) # transport, ex: SelectorSocketTransport self._transport = None self._call_connection_made = call_connection_made self._ssl_handshake_timeout = ssl_handshake_timeout + self._ssl_shutdown_timeout = ssl_shutdown_timeout + # SSL and state machine + self._sslobj = None + self._incoming = ssl_MemoryBIO() + self._incoming_write = self._incoming.write + self._outgoing = ssl_MemoryBIO() + self._outgoing_read = self._outgoing.read + self._state = UNWRAPPED + self._conn_lost = 0 # Set when connection_lost called + self._eof_received = False + + # Flow Control + + self._ssl_writing_paused = False + + self._app_reading_paused = False - def _set_app_protocol(self, app_protocol): + self._ssl_reading_paused = False + self._incoming_high_water = 0 + self._incoming_low_water = 0 + self._set_read_buffer_limits() + + self._app_writing_paused = False + self._outgoing_high_water = 0 + self._outgoing_low_water = 0 + self._set_write_buffer_limits() + + cdef _set_app_protocol(self, app_protocol): self._app_protocol = app_protocol - self._app_protocol_is_buffer = ( - not hasattr(app_protocol, 'data_received') and - hasattr(app_protocol, 'get_buffer') - ) + if (hasattr(app_protocol, 'get_buffer') and + not isinstance(app_protocol, aio_Protocol)): + self._app_protocol_get_buffer = app_protocol.get_buffer + self._app_protocol_buffer_updated = app_protocol.buffer_updated + self._app_protocol_is_buffer = True + else: + self._app_protocol_is_buffer = False - def _wakeup_waiter(self, exc=None): + cdef _wakeup_waiter(self, exc=None): if self._waiter is None: return if not self._waiter.cancelled(): @@ -459,9 +307,6 @@ class SSLProtocol(object): Start the SSL handshake. """ self._transport = transport - self._sslpipe = _SSLPipe(self._sslcontext, - self._server_side, - self._server_hostname) self._start_handshake() def connection_lost(self, exc): @@ -471,62 +316,55 @@ class SSLProtocol(object): meaning a regular EOF is received or the connection was aborted or closed). """ - if self._session_established: - self._session_established = False - self._loop.call_soon(self._app_protocol.connection_lost, exc) + self._write_backlog.clear() + self._outgoing_read() + self._conn_lost += 1 + + # Just mark the app transport as closed so that its __dealloc__ + # doesn't complain. + if self._app_transport is not None: + self._app_transport._closed = True + if self._state != DO_HANDSHAKE: + self._loop.call_soon(self._app_protocol.connection_lost, exc) + self._set_state(UNWRAPPED) self._transport = None self._app_transport = None self._wakeup_waiter(exc) - def pause_writing(self): - """Called when the low-level transport's buffer goes over - the high-water mark. - """ - self._app_protocol.pause_writing() + if getattr(self, '_shutdown_timeout_handle', None): + self._shutdown_timeout_handle.cancel() + if getattr(self, '_handshake_timeout_handle', None): + self._handshake_timeout_handle.cancel() - def resume_writing(self): - """Called when the low-level transport's buffer drains below - the low-water mark. - """ - self._app_protocol.resume_writing() + def get_buffer(self, n): + cdef size_t want = n + if want > SSL_READ_MAX_SIZE: + want = SSL_READ_MAX_SIZE + if self._ssl_buffer_len < want: + self._ssl_buffer = PyMem_RawRealloc(self._ssl_buffer, want) + if not self._ssl_buffer: + raise MemoryError() + self._ssl_buffer_len = want + self._ssl_buffer_view = PyMemoryView_FromMemory( + self._ssl_buffer, want, PyBUF_WRITE) + return self._ssl_buffer_view - def data_received(self, data): - """Called when some SSL data is received. + def buffer_updated(self, nbytes): + self._incoming_write(PyMemoryView_FromMemory( + self._ssl_buffer, nbytes, PyBUF_WRITE)) - The argument is a bytes object. - """ - if self._sslpipe is None: - # transport closing, sslpipe is destroyed - return + if self._state == DO_HANDSHAKE: + self._do_handshake() - try: - ssldata, appdata = (<_SSLPipe>self._sslpipe).feed_ssldata(data) - except ssl_SSLError as e: - msg = ( - f'SSL error errno:{getattr(e, "errno", "missing")} ' - f'reason: {getattr(e, "reason", "missing")}' - ) - self._fatal_error(e, msg) - return + elif self._state == WRAPPED: + self._do_read() - self._transport.writelines(ssldata) + elif self._state == FLUSHING: + self._do_flush() - for chunk in appdata: - if chunk: - try: - if self._app_protocol_is_buffer: - _feed_data_to_bufferred_proto( - self._app_protocol, chunk) - else: - self._app_protocol.data_received(chunk) - except Exception as ex: - self._fatal_error( - ex, 'application protocol failed to receive SSL data') - return - else: - self._start_shutdown() - break + elif self._state == SHUTDOWN: + self._do_shutdown() def eof_received(self): """Called when the other end of the low-level stream @@ -540,17 +378,27 @@ class SSLProtocol(object): if self._loop.get_debug(): aio_logger.debug("%r received EOF", self) - self._wakeup_waiter(ConnectionResetError) + if self._state == DO_HANDSHAKE: + self._on_handshake_complete(ConnectionResetError) + + elif self._state == WRAPPED: + self._set_state(FLUSHING) + self._do_write() + self._set_state(SHUTDOWN) + self._do_shutdown() + + elif self._state == FLUSHING: + self._do_write() + self._set_state(SHUTDOWN) + self._do_shutdown() + + elif self._state == SHUTDOWN: + self._do_shutdown() - if not self._in_handshake: - keep_open = self._app_protocol.eof_received() - if keep_open: - aio_logger.warning('returning true from eof_received() ' - 'has no effect when using ssl') finally: self._transport.close() - def _get_extra_info(self, name, default=None): + cdef _get_extra_info(self, name, default=None): if name in self._extra: return self._extra[name] elif self._transport is not None: @@ -558,40 +406,62 @@ class SSLProtocol(object): else: return default - def _mark_closed(self): - self._closed = True + cdef _set_state(self, SSLProtocolState new_state): + cdef bint allowed = False + + if new_state == UNWRAPPED: + allowed = True + + elif self._state == UNWRAPPED and new_state == DO_HANDSHAKE: + allowed = True + + elif self._state == DO_HANDSHAKE and new_state == WRAPPED: + allowed = True + + elif self._state == WRAPPED and new_state == FLUSHING: + allowed = True + + elif self._state == FLUSHING and new_state == SHUTDOWN: + allowed = True + + if allowed: + self._state = new_state - def _start_shutdown(self): - if self._in_shutdown: - return - if self._in_handshake: - self._abort() else: - self._in_shutdown = True - self._write_appdata(b'') + raise RuntimeError( + 'cannot switch state from {} to {}'.format( + self._state, new_state)) - def _write_appdata(self, data): - self._write_backlog.append((data, 0)) - self._write_buffer_size += len(data) - self._process_write_backlog() + # Handshake flow - def _start_handshake(self): + cdef _start_handshake(self): if self._loop.get_debug(): aio_logger.debug("%r starts SSL handshake", self) self._handshake_start_time = self._loop.time() else: self._handshake_start_time = None - self._in_handshake = True - # (b'', 1) is a special value in _process_write_backlog() to do - # the SSL handshake - self._write_backlog.append((b'', 1)) + + self._set_state(DO_HANDSHAKE) + + # start handshake timeout count down self._handshake_timeout_handle = \ self._loop.call_later(self._ssl_handshake_timeout, - self._check_handshake_timeout) - self._process_write_backlog() + lambda: self._check_handshake_timeout()) - def _check_handshake_timeout(self): - if self._in_handshake is True: + try: + self._sslobj = self._sslcontext.wrap_bio( + self._incoming, self._outgoing, + server_side=self._server_side, + server_hostname=self._server_hostname) + self._sslobj_read = self._sslobj.read + self._sslobj_write = self._sslobj.write + except Exception as ex: + self._on_handshake_complete(ex) + else: + self._do_handshake() + + cdef _check_handshake_timeout(self): + if self._state == DO_HANDSHAKE: msg = ( f"SSL handshake is taking longer than " f"{self._ssl_handshake_timeout} seconds: " @@ -599,17 +469,29 @@ class SSLProtocol(object): ) self._fatal_error(ConnectionAbortedError(msg)) - def _on_handshake_complete(self, handshake_exc): - self._in_handshake = False + cdef _do_handshake(self): + try: + self._sslobj.do_handshake() + except ssl_SSLAgainErrors as exc: + self._process_outgoing() + except ssl_SSLError as exc: + self._on_handshake_complete(exc) + else: + self._on_handshake_complete(None) + + cdef _on_handshake_complete(self, handshake_exc): self._handshake_timeout_handle.cancel() - sslobj = (<_SSLPipe>self._sslpipe)._sslobj + sslobj = self._sslobj try: - if handshake_exc is not None: + if handshake_exc is None: + self._set_state(WRAPPED) + else: raise handshake_exc peercert = sslobj.getpeercert() except Exception as exc: + self._set_state(UNWRAPPED) if isinstance(exc, ssl_CertificateError): msg = 'SSL handshake failed on verifying the certificate' else: @@ -626,63 +508,324 @@ class SSLProtocol(object): cipher=sslobj.cipher(), compression=sslobj.compression(), ssl_object=sslobj) - - self._app_transport = _SSLProtocolTransport(self._loop, self) - if self._call_connection_made: self._app_protocol.connection_made(self._app_transport) self._wakeup_waiter() - self._session_established = True - # In case transport.write() was already called. Don't call - # immediately _process_write_backlog(), but schedule it: - # _on_handshake_complete() can be called indirectly from - # _process_write_backlog(), and _process_write_backlog() is not - # reentrant. - self._loop.call_soon(self._process_write_backlog) - - def _process_write_backlog(self): - # Try to make progress on the write backlog. - if self._transport is None or self._sslpipe is None: + self._do_read() + + # Shutdown flow + + cdef _start_shutdown(self): + if self._state in (FLUSHING, SHUTDOWN, UNWRAPPED): return + if self._app_transport is not None: + self._app_transport._closed = True + if self._state == DO_HANDSHAKE: + self._abort(None) + else: + self._set_state(FLUSHING) + self._shutdown_timeout_handle = \ + self._loop.call_later(self._ssl_shutdown_timeout, + lambda: self._check_shutdown_timeout()) + self._do_flush() + + cdef _check_shutdown_timeout(self): + if self._state in (FLUSHING, SHUTDOWN): + self._transport._force_close( + aio_TimeoutError('SSL shutdown timed out')) + + cdef _do_flush(self): + if self._write_backlog: + try: + while True: + # data is discarded when FLUSHING + chunk_size = len(self._sslobj_read(SSL_READ_MAX_SIZE)) + if not chunk_size: + # close_notify + break + except ssl_SSLAgainErrors as exc: + pass + except ssl_SSLError as exc: + self._on_shutdown_complete(exc) + return + + try: + self._do_write() + except Exception as exc: + self._on_shutdown_complete(exc) + return + + if not self._write_backlog: + self._set_state(SHUTDOWN) + self._do_shutdown() + + cdef _do_shutdown(self): + try: + self._sslobj.unwrap() + except ssl_SSLAgainErrors as exc: + self._process_outgoing() + except ssl_SSLError as exc: + self._on_shutdown_complete(exc) + else: + self._process_outgoing() + self._call_eof_received() + self._on_shutdown_complete(None) + + cdef _on_shutdown_complete(self, shutdown_exc): + self._shutdown_timeout_handle.cancel() + + if shutdown_exc: + self._fatal_error(shutdown_exc) + else: + self._loop.call_soon(self._transport.close) + + cdef _abort(self, exc): + self._set_state(UNWRAPPED) + if self._transport is not None: + self._transport._force_close(exc) + + # Outgoing flow + + cdef _write_appdata(self, list_of_data): + if self._state in (FLUSHING, SHUTDOWN, UNWRAPPED): + if self._conn_lost >= LOG_THRESHOLD_FOR_CONNLOST_WRITES: + aio_logger.warning('SSL connection is closed') + self._conn_lost += 1 + return + + for data in list_of_data: + self._write_backlog.append(data) + self._write_buffer_size += len(data) + try: + if self._state == WRAPPED: + self._do_write() + + except Exception as ex: + self._fatal_error(ex, 'Fatal error on SSL protocol') + + cdef _do_write(self): + cdef size_t data_len, count + try: + while self._write_backlog: + data = self._write_backlog[0] + count = self._sslobj_write(data) + data_len = len(data) + if count < data_len: + if not PyMemoryView_Check(data): + data = PyMemoryView_FromObject(data) + self._write_backlog[0] = data[count:] + self._write_buffer_size -= count + else: + del self._write_backlog[0] + self._write_buffer_size -= data_len + except ssl_SSLAgainErrors as exc: + pass + self._process_outgoing() + + cdef _process_outgoing(self): + if not self._ssl_writing_paused: + data = self._outgoing_read() + if len(data): + self._transport.write(data) + self._control_app_writing() + + # Incoming flow + + cdef _do_read(self): + if self._state != WRAPPED: + return + try: + if not self._app_reading_paused: + if self._app_protocol_is_buffer: + self._do_read__buffered() + else: + self._do_read__copied() + if self._write_backlog: + self._do_write() + else: + self._process_outgoing() + self._control_ssl_reading() + except Exception as ex: + self._fatal_error(ex, 'Fatal error on SSL protocol') + + cdef _do_read__buffered(self): cdef: - _SSLPipe sslpipe = <_SSLPipe>self._sslpipe + Py_buffer pybuf + bint pybuf_inited = False + size_t wants, offset = 0 + int count = 1 + object buf + + buf = self._app_protocol_get_buffer(self._get_read_buffer_size()) + wants = len(buf) try: - for i in range(len(self._write_backlog)): - data, offset = self._write_backlog[0] - if data: - ssldata, offset = sslpipe.feed_appdata(data, offset) - elif offset: - ssldata = sslpipe.do_handshake(self._on_handshake_complete) - offset = 1 + count = self._sslobj_read(wants, buf) + + if count > 0: + offset = count + if offset < wants: + PyObject_GetBuffer(buf, &pybuf, PyBUF_WRITABLE) + pybuf_inited = True + while offset < wants: + buf = PyMemoryView_FromMemory( + (pybuf.buf) + offset, + wants - offset, + PyBUF_WRITE) + count = self._sslobj_read(wants - offset, buf) + if count > 0: + offset += count + else: + break else: - ssldata = sslpipe.shutdown(self._finalize) - offset = 1 - - self._transport.writelines(ssldata) - - if offset < len(data): - self._write_backlog[0] = (data, offset) - # A short write means that a write is blocked on a read - # We need to enable reading if it is paused! - assert sslpipe._need_ssldata - if self._transport._paused: - self._transport.resume_reading() + self._loop.call_soon(lambda: self._do_read()) + except ssl_SSLAgainErrors as exc: + pass + finally: + if pybuf_inited: + PyBuffer_Release(&pybuf) + if offset > 0: + self._app_protocol_buffer_updated(offset) + if not count: + # close_notify + self._call_eof_received() + self._start_shutdown() + + cdef _do_read__copied(self): + cdef: + list data + bytes first, chunk = b'1' + bint zero = True, one = False + + try: + while True: + chunk = self._sslobj_read(SSL_READ_MAX_SIZE) + if not chunk: break + if zero: + zero = False + one = True + first = chunk + elif one: + one = False + data = [first, chunk] + else: + data.append(chunk) + except ssl_SSLAgainErrors as exc: + pass + if one: + self._app_protocol.data_received(first) + elif not zero: + self._app_protocol.data_received(b''.join(data)) + if not chunk: + # close_notify + self._call_eof_received() + self._start_shutdown() + + cdef _call_eof_received(self): + try: + if not self._eof_received: + self._eof_received = True + keep_open = self._app_protocol.eof_received() + if keep_open: + aio_logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + except Exception as ex: + self._fatal_error(ex, 'Error calling eof_received()') - # An entire chunk from the backlog was processed. We can - # delete it and reduce the outstanding buffer size. - del self._write_backlog[0] - self._write_buffer_size -= len(data) - except Exception as exc: - if self._in_handshake: - # Exceptions will be re-raised in _on_handshake_complete. - self._on_handshake_complete(exc) - else: - self._fatal_error(exc, 'Fatal error on SSL transport') + # Flow control for writes from APP socket - def _fatal_error(self, exc, message='Fatal error on transport'): + cdef _control_app_writing(self): + cdef size_t size = self._get_write_buffer_size() + if size >= self._outgoing_high_water and not self._app_writing_paused: + self._app_writing_paused = True + try: + self._app_protocol.pause_writing() + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.pause_writing() failed', + 'exception': exc, + 'transport': self._app_transport, + 'protocol': self, + }) + elif size <= self._outgoing_low_water and self._app_writing_paused: + self._app_writing_paused = False + try: + self._app_protocol.resume_writing() + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.resume_writing() failed', + 'exception': exc, + 'transport': self._app_transport, + 'protocol': self, + }) + + cdef size_t _get_write_buffer_size(self): + return self._outgoing.pending + self._write_buffer_size + + cdef _set_write_buffer_limits(self, high=None, low=None): + high, low = add_flowcontrol_defaults( + high, low, FLOW_CONTROL_HIGH_WATER_SSL_WRITE) + self._outgoing_high_water = high + self._outgoing_low_water = low + + # Flow control for reads to APP socket + + cdef _pause_reading(self): + self._app_reading_paused = True + + cdef _resume_reading(self): + if self._app_reading_paused: + self._app_reading_paused = False + + def resume(): + if self._state == WRAPPED: + self._do_read() + elif self._state == FLUSHING: + self._do_flush() + elif self._state == SHUTDOWN: + self._do_shutdown() + self._loop.call_soon(resume) + + # Flow control for reads from SSL socket + + cdef _control_ssl_reading(self): + cdef size_t size = self._get_read_buffer_size() + if size >= self._incoming_high_water and not self._ssl_reading_paused: + self._ssl_reading_paused = True + self._transport.pause_reading() + elif size <= self._incoming_low_water and self._ssl_reading_paused: + self._ssl_reading_paused = False + self._transport.resume_reading() + + cdef _set_read_buffer_limits(self, high=None, low=None): + high, low = add_flowcontrol_defaults( + high, low, FLOW_CONTROL_HIGH_WATER_SSL_READ) + self._incoming_high_water = high + self._incoming_low_water = low + + cdef size_t _get_read_buffer_size(self): + return self._incoming.pending + + # Flow control for writes to SSL socket + + def pause_writing(self): + """Called when the low-level transport's buffer goes over + the high-water mark. + """ + assert not self._ssl_writing_paused + self._ssl_writing_paused = True + + def resume_writing(self): + """Called when the low-level transport's buffer drains below + the low-water mark. + """ + assert self._ssl_writing_paused + self._ssl_writing_paused = False + self._process_outgoing() + + cdef _fatal_error(self, exc, message='Fatal error on transport'): if self._transport: self._transport._force_close(exc) @@ -698,35 +841,3 @@ class SSLProtocol(object): 'transport': self._transport, 'protocol': self, }) - - def _finalize(self): - self._sslpipe = None - - if self._transport is not None: - self._transport.close() - - def _abort(self): - try: - if self._transport is not None: - self._transport.abort() - finally: - self._finalize() - - -cdef _feed_data_to_bufferred_proto(proto, data): - data_len = len(data) - while data_len: - buf = proto.get_buffer(data_len) - buf_len = len(buf) - if not buf_len: - raise RuntimeError('get_buffer() returned an empty buffer') - - if buf_len >= data_len: - buf[:data_len] = data - proto.buffer_updated(data_len) - return - else: - buf[:buf_len] = data[:buf_len] - proto.buffer_updated(buf_len) - data = data[buf_len:] - data_len = len(data)