@@ -532,6 +532,14 @@ def __init__(self, certificate, ssl_version=None,
532532 threading .Thread .__init__ (self )
533533 self .daemon = True
534534
535+ def __enter__ (self ):
536+ self .start (threading .Event ())
537+ self .flag .wait ()
538+
539+ def __exit__ (self , * args ):
540+ self .stop ()
541+ self .join ()
542+
535543 def start (self , flag = None ):
536544 self .flag = flag
537545 threading .Thread .start (self )
@@ -638,6 +646,20 @@ def __init__(self, certfile):
638646 def __str__ (self ):
639647 return "<%s %s>" % (self .__class__ .__name__ , self .server )
640648
649+ def __enter__ (self ):
650+ self .start (threading .Event ())
651+ self .flag .wait ()
652+
653+ def __exit__ (self , * args ):
654+ if test_support .verbose :
655+ sys .stdout .write (" cleanup: stopping server.\n " )
656+ self .stop ()
657+ if test_support .verbose :
658+ sys .stdout .write (" cleanup: joining server thread.\n " )
659+ self .join ()
660+ if test_support .verbose :
661+ sys .stdout .write (" cleanup: successfully joined.\n " )
662+
641663 def start (self , flag = None ):
642664 self .flag = flag
643665 threading .Thread .start (self )
@@ -752,12 +774,7 @@ def bad_cert_test(certfile):
752774 server = ThreadedEchoServer (CERTFILE ,
753775 certreqs = ssl .CERT_REQUIRED ,
754776 cacerts = CERTFILE , chatty = False )
755- flag = threading .Event ()
756- server .start (flag )
757- # wait for it to start
758- flag .wait ()
759- # try to connect
760- try :
777+ with server :
761778 try :
762779 s = ssl .wrap_socket (socket .socket (),
763780 certfile = certfile ,
@@ -771,9 +788,6 @@ def bad_cert_test(certfile):
771788 sys .stdout .write ("\n socket.error is %s\n " % x [1 ])
772789 else :
773790 raise AssertionError ("Use of invalid cert should have failed!" )
774- finally :
775- server .stop ()
776- server .join ()
777791
778792 def server_params_test (certfile , protocol , certreqs , cacertsfile ,
779793 client_certfile , client_protocol = None , indata = "FOO\n " ,
@@ -791,14 +805,10 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile,
791805 chatty = chatty ,
792806 connectionchatty = connectionchatty ,
793807 wrap_accepting_socket = wrap_accepting_socket )
794- flag = threading .Event ()
795- server .start (flag )
796- # wait for it to start
797- flag .wait ()
798- # try to connect
799- if client_protocol is None :
800- client_protocol = protocol
801- try :
808+ with server :
809+ # try to connect
810+ if client_protocol is None :
811+ client_protocol = protocol
802812 s = ssl .wrap_socket (socket .socket (),
803813 certfile = client_certfile ,
804814 ca_certs = cacertsfile ,
@@ -826,9 +836,6 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile,
826836 if test_support .verbose :
827837 sys .stdout .write (" client: closing connection.\n " )
828838 s .close ()
829- finally :
830- server .stop ()
831- server .join ()
832839
833840 def try_protocol_combo (server_protocol ,
834841 client_protocol ,
@@ -930,12 +937,7 @@ def test_getpeercert(self):
930937 ssl_version = ssl .PROTOCOL_SSLv23 ,
931938 cacerts = CERTFILE ,
932939 chatty = False )
933- flag = threading .Event ()
934- server .start (flag )
935- # wait for it to start
936- flag .wait ()
937- # try to connect
938- try :
940+ with server :
939941 s = ssl .wrap_socket (socket .socket (),
940942 certfile = CERTFILE ,
941943 ca_certs = CERTFILE ,
@@ -957,9 +959,6 @@ def test_getpeercert(self):
957959 "Missing or invalid 'organizationName' field in certificate subject; "
958960 "should be 'Python Software Foundation'." )
959961 s .close ()
960- finally :
961- server .stop ()
962- server .join ()
963962
964963 def test_empty_cert (self ):
965964 """Connecting with an empty cert file"""
@@ -1042,13 +1041,8 @@ def test_starttls(self):
10421041 starttls_server = True ,
10431042 chatty = True ,
10441043 connectionchatty = True )
1045- flag = threading .Event ()
1046- server .start (flag )
1047- # wait for it to start
1048- flag .wait ()
1049- # try to connect
10501044 wrapped = False
1051- try :
1045+ with server :
10521046 s = socket .socket ()
10531047 s .setblocking (1 )
10541048 s .connect ((HOST , server .port ))
@@ -1093,9 +1087,6 @@ def test_starttls(self):
10931087 else :
10941088 s .send ("over\n " )
10951089 s .close ()
1096- finally :
1097- server .stop ()
1098- server .join ()
10991090
11001091 def test_socketserver (self ):
11011092 """Using a SocketServer to create and manage SSL connections."""
@@ -1145,12 +1136,7 @@ def test_asyncore_server(self):
11451136 if test_support .verbose :
11461137 sys .stdout .write ("\n " )
11471138 server = AsyncoreEchoServer (CERTFILE )
1148- flag = threading .Event ()
1149- server .start (flag )
1150- # wait for it to start
1151- flag .wait ()
1152- # try to connect
1153- try :
1139+ with server :
11541140 s = ssl .wrap_socket (socket .socket ())
11551141 s .connect (('127.0.0.1' , server .port ))
11561142 if test_support .verbose :
@@ -1169,10 +1155,6 @@ def test_asyncore_server(self):
11691155 if test_support .verbose :
11701156 sys .stdout .write (" client: closing connection.\n " )
11711157 s .close ()
1172- finally :
1173- server .stop ()
1174- # wait for server thread to end
1175- server .join ()
11761158
11771159 def test_recv_send (self ):
11781160 """Test recv(), send() and friends."""
@@ -1185,19 +1167,14 @@ def test_recv_send(self):
11851167 cacerts = CERTFILE ,
11861168 chatty = True ,
11871169 connectionchatty = False )
1188- flag = threading .Event ()
1189- server .start (flag )
1190- # wait for it to start
1191- flag .wait ()
1192- # try to connect
1193- s = ssl .wrap_socket (socket .socket (),
1194- server_side = False ,
1195- certfile = CERTFILE ,
1196- ca_certs = CERTFILE ,
1197- cert_reqs = ssl .CERT_NONE ,
1198- ssl_version = ssl .PROTOCOL_TLSv1 )
1199- s .connect ((HOST , server .port ))
1200- try :
1170+ with server :
1171+ s = ssl .wrap_socket (socket .socket (),
1172+ server_side = False ,
1173+ certfile = CERTFILE ,
1174+ ca_certs = CERTFILE ,
1175+ cert_reqs = ssl .CERT_NONE ,
1176+ ssl_version = ssl .PROTOCOL_TLSv1 )
1177+ s .connect ((HOST , server .port ))
12011178 # helper methods for standardising recv* method signatures
12021179 def _recv_into ():
12031180 b = bytearray ("\0 " * 100 )
@@ -1285,9 +1262,6 @@ def _recvfrom_into():
12851262
12861263 s .write ("over\n " .encode ("ASCII" , "strict" ))
12871264 s .close ()
1288- finally :
1289- server .stop ()
1290- server .join ()
12911265
12921266 def test_handshake_timeout (self ):
12931267 # Issue #5103: SSL handshake must respect the socket timeout
0 commit comments