@@ -1256,6 +1256,144 @@ async def start_server():
12561256 for client in clients :
12571257 client .stop ()
12581258
1259+ def test_create_server_ssl_over_ssl (self ):
1260+ if self .implementation == 'asyncio' :
1261+ raise unittest .SkipTest ('asyncio does not support SSL over SSL' )
1262+
1263+ CNT = 0 # number of clients that were successful
1264+ TOTAL_CNT = 25 # total number of clients that test will create
1265+ TIMEOUT = 10.0 # timeout for this test
1266+
1267+ A_DATA = b'A' * 1024 * 1024
1268+ B_DATA = b'B' * 1024 * 1024
1269+
1270+ sslctx_1 = self ._create_server_ssl_context (self .ONLYCERT , self .ONLYKEY )
1271+ client_sslctx_1 = self ._create_client_ssl_context ()
1272+ sslctx_2 = self ._create_server_ssl_context (self .ONLYCERT , self .ONLYKEY )
1273+ client_sslctx_2 = self ._create_client_ssl_context ()
1274+
1275+ clients = []
1276+
1277+ async def handle_client (reader , writer ):
1278+ nonlocal CNT
1279+
1280+ # hack reader and writer to call start_tls()
1281+ transport = writer ._transport
1282+ writer ._transport = None
1283+ reader ._transport = None
1284+
1285+ transport = await self .loop .start_tls (
1286+ transport , writer ._protocol , sslctx_2 , server_side = True )
1287+
1288+ # restore with new transport
1289+ writer ._transport = transport
1290+ reader ._transport = transport
1291+
1292+ data = await reader .readexactly (len (A_DATA ))
1293+ self .assertEqual (data , A_DATA )
1294+ writer .write (b'OK' )
1295+
1296+ data = await reader .readexactly (len (B_DATA ))
1297+ self .assertEqual (data , B_DATA )
1298+ writer .writelines ([b'SP' , bytearray (b'A' ), memoryview (b'M' )])
1299+
1300+ await writer .drain ()
1301+ writer .close ()
1302+
1303+ CNT += 1
1304+
1305+ async def test_client (addr ):
1306+ fut = asyncio .Future (loop = self .loop )
1307+
1308+ def prog (sock ):
1309+ try :
1310+ sock .connect (addr )
1311+ sock .starttls (client_sslctx_1 )
1312+
1313+ # because wrap_socket() doesn't work correctly on
1314+ # SSLSocket, we have to do the 2nd level SSL manually
1315+ incoming = ssl .MemoryBIO ()
1316+ outgoing = ssl .MemoryBIO ()
1317+ sslobj = client_sslctx_2 .wrap_bio (incoming , outgoing )
1318+
1319+ def do (func ):
1320+ while True :
1321+ try :
1322+ rv = func ()
1323+ break
1324+ except ssl .SSLWantReadError :
1325+ if outgoing .pending :
1326+ sock .send (outgoing .read ())
1327+ incoming .write (sock .recv (65536 ))
1328+ if outgoing .pending :
1329+ sock .send (outgoing .read ())
1330+ return rv
1331+
1332+ do (sslobj .do_handshake )
1333+
1334+ do (lambda : sslobj .write (A_DATA ))
1335+ data = do (lambda : sslobj .read (2 ))
1336+ self .assertEqual (data , b'OK' )
1337+
1338+ do (lambda : sslobj .write (B_DATA ))
1339+ data = b''
1340+ while data != b'SPAM' :
1341+ data += do (lambda : sslobj .read (4 ))
1342+ self .assertEqual (data , b'SPAM' )
1343+
1344+ do (sslobj .unwrap )
1345+ sock .close ()
1346+
1347+ except Exception as ex :
1348+ self .loop .call_soon_threadsafe (fut .set_exception , ex )
1349+ else :
1350+ self .loop .call_soon_threadsafe (fut .set_result , None )
1351+
1352+ client = self .tcp_client (prog )
1353+ client .start ()
1354+ clients .append (client )
1355+
1356+ await fut
1357+
1358+ async def start_server ():
1359+ extras = {}
1360+ if self .implementation != 'asyncio' or self .PY37 :
1361+ extras = dict (ssl_handshake_timeout = 10.0 )
1362+
1363+ srv = await asyncio .start_server (
1364+ handle_client ,
1365+ '127.0.0.1' , 0 ,
1366+ family = socket .AF_INET ,
1367+ ssl = sslctx_1 ,
1368+ loop = self .loop ,
1369+ ** extras )
1370+
1371+ try :
1372+ srv_socks = srv .sockets
1373+ self .assertTrue (srv_socks )
1374+
1375+ addr = srv_socks [0 ].getsockname ()
1376+
1377+ tasks = []
1378+ for _ in range (TOTAL_CNT ):
1379+ tasks .append (test_client (addr ))
1380+
1381+ await asyncio .wait_for (
1382+ asyncio .gather (* tasks , loop = self .loop ),
1383+ TIMEOUT , loop = self .loop )
1384+
1385+ finally :
1386+ self .loop .call_soon (srv .close )
1387+ await srv .wait_closed ()
1388+
1389+ with self ._silence_eof_received_warning ():
1390+ self .loop .run_until_complete (start_server ())
1391+
1392+ self .assertEqual (CNT , TOTAL_CNT )
1393+
1394+ for client in clients :
1395+ client .stop ()
1396+
12591397 def test_create_connection_ssl_1 (self ):
12601398 if self .implementation == 'asyncio' :
12611399 # Don't crash on asyncio errors
0 commit comments