From 7f6bb71099524460a98baed80ed5432b2d8b8284 Mon Sep 17 00:00:00 2001 From: Kashif Khan <361477+kashifkhan@users.noreply.github.com> Date: Mon, 28 Aug 2023 14:13:18 -0500 Subject: [PATCH] [Azure AMQP] Remove Deprecated SSL.Wrap_Socket (#31524) * remove deprecated code * some re-ordering * bring changes over to SB * fix order * update async side * sb async transport * comment for ordering * changes to sni wrap code * sync code * enable auto ssl handshake * fix type of ssl opts * bring change to SB * minor fix * refactor out server_side * remove comment * sync changes * remove whitespace * get rid of whitespace * remove whitespace * fix pylint --- .../azure/eventhub/_pyamqp/_transport.py | 61 ++++++++----------- .../eventhub/_pyamqp/aio/_transport_async.py | 53 +++++++++------- .../azure/servicebus/_pyamqp/_transport.py | 61 ++++++++----------- .../_pyamqp/aio/_transport_async.py | 53 +++++++++------- 4 files changed, 116 insertions(+), 112 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 7c67990aecaf..c75253ca1fef 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -501,6 +501,7 @@ def __init__( self, host, *, port=AMQPS_PORT, socket_timeout=None, ssl_opts=None, **kwargs ): self.sslopts = ssl_opts if isinstance(ssl_opts, dict) else {} + self.sslopts['server_hostname'] = host self._read_buffer = BytesIO() super(SSLTransport, self).__init__( host, port=port, socket_timeout=socket_timeout, **kwargs @@ -509,7 +510,6 @@ def __init__( def _setup_transport(self): """Wrap the socket in an SSL object.""" self.sock = self._wrap_socket(self.sock, **self.sslopts) - self.sock.do_handshake() self._quick_recv = self.sock.recv def _wrap_socket(self, sock, context=None, **sslopts): @@ -531,10 +531,9 @@ def _wrap_socket_sni( sock, keyfile=None, certfile=None, - server_side=False, cert_reqs=ssl.CERT_REQUIRED, ca_certs=None, - do_handshake_on_connect=False, + do_handshake_on_connect=True, suppress_ragged_eofs=True, server_hostname=None, ciphers=None, @@ -548,7 +547,6 @@ def _wrap_socket_sni( :param socket.socket sock: socket to wrap :param str or None keyfile: key file path :param str or None certfile: cert file path - :param bool or None server_side: server side socket :param int cert_reqs: cert requirements :param str or None ca_certs: ca certs file path :param bool do_handshake_on_connect: do handshake on connect @@ -562,44 +560,39 @@ def _wrap_socket_sni( # Setup the right SSL version; default to optimal versions across # ssl implementations if ssl_version is None: - ssl_version = ssl.PROTOCOL_TLS + ssl_version = ssl.PROTOCOL_TLS_CLIENT + purpose = ssl.Purpose.SERVER_AUTH opts = { "sock": sock, - "keyfile": keyfile, - "certfile": certfile, - "server_side": server_side, - "cert_reqs": cert_reqs, - "ca_certs": ca_certs, "do_handshake_on_connect": do_handshake_on_connect, "suppress_ragged_eofs": suppress_ragged_eofs, - "ciphers": ciphers, - #'ssl_version': ssl_version + "server_hostname": server_hostname, } - # TODO: We need to refactor this. - try: - sock = ssl.wrap_socket(**opts) # pylint: disable=deprecated-method - except FileNotFoundError as exc: - # FileNotFoundError does not have missing filename info, so adding it below. - # Assuming that this must be ca_certs, since this is the only file path that - # users can pass in (`connection_verify` in the EH/SB clients) through opts above. - # For uamqp exception parity. Remove later when resolving issue #27128. - exc.filename = {"ca_certs": ca_certs} - raise exc - # Set SNI headers if supported - if ( - (server_hostname is not None) - and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI) - and (hasattr(ssl, "SSLContext")) - ): - context = ssl.SSLContext(opts["ssl_version"]) + context = ssl.SSLContext(ssl_version) + + if ca_certs is not None: + try: + context.load_verify_locations(ca_certs) + except FileNotFoundError as exc: + exc.filename = {"ca_certs": ca_certs} + raise exc from None + elif context.verify_mode != ssl.CERT_NONE: + # load the default system root CA certs. + context.load_default_certs(purpose=purpose) + + if certfile is not None: + context.load_cert_chain(certfile, keyfile) + + if ciphers is not None: + context.set_ciphers(ciphers) + + if cert_reqs == ssl.CERT_NONE and server_hostname is None: + context.check_hostname = False context.verify_mode = cert_reqs - if cert_reqs != ssl.CERT_NONE: - context.check_hostname = True - if (certfile is not None) and (keyfile is not None): - context.load_cert_chain(certfile, keyfile) - sock = context.wrap_socket(sock, server_hostname=server_hostname) + + sock = context.wrap_socket(**opts) return sock def _shutdown_transport(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index effac268676e..5f1ba00548a1 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -185,31 +185,40 @@ def _build_ssl_opts(self, sslopts): return self._build_ssl_context(**sslopts.pop("context")) ssl_version = sslopts.get("ssl_version") if ssl_version is None: - ssl_version = ssl.PROTOCOL_TLS + ssl_version = ssl.PROTOCOL_TLS_CLIENT + + context = ssl.SSLContext(ssl_version) + + purpose = ssl.Purpose.SERVER_AUTH + + ca_certs = sslopts.get("ca_certs") + + if ca_certs is not None: + try: + context.load_verify_locations(ca_certs) + except FileNotFoundError as exc: + # FileNotFoundError does not have missing filename info, so adding it below. + # since this is the only file path that users can pass in + # (`connection_verify` in the EH/SB clients) through opts above. + exc.filename = {"ca_certs": ca_certs} + raise exc from None + elif context.verify_mode != ssl.CERT_NONE: + # load the default system root CA certs. + context.load_default_certs(purpose=purpose) + + certfile = sslopts.get("certfile") + keyfile = sslopts.get("keyfile") + if certfile is not None: + context.load_cert_chain(certfile, keyfile) + - # Set SNI headers if supported server_hostname = sslopts.get("server_hostname") - if ( - (server_hostname is not None) - and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI) - and (hasattr(ssl, "SSLContext")) - ): - context = ssl.SSLContext(ssl_version) - cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED) - certfile = sslopts.get("certfile") - keyfile = sslopts.get("keyfile") + cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED) + if cert_reqs == ssl.CERT_NONE and server_hostname is None: + context.check_hostname = False context.verify_mode = cert_reqs - if cert_reqs != ssl.CERT_NONE: - context.check_hostname = True - if (certfile is not None) and (keyfile is not None): - context.load_cert_chain(certfile, keyfile) - return context - ca_certs = sslopts.get("ca_certs") - if ca_certs: - context = ssl.SSLContext(ssl_version) - context.load_verify_locations(ca_certs) - return context - return True + + return context except TypeError: raise TypeError( "SSL configuration must be a dictionary, or the value True." diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py index 7c67990aecaf..c75253ca1fef 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py @@ -501,6 +501,7 @@ def __init__( self, host, *, port=AMQPS_PORT, socket_timeout=None, ssl_opts=None, **kwargs ): self.sslopts = ssl_opts if isinstance(ssl_opts, dict) else {} + self.sslopts['server_hostname'] = host self._read_buffer = BytesIO() super(SSLTransport, self).__init__( host, port=port, socket_timeout=socket_timeout, **kwargs @@ -509,7 +510,6 @@ def __init__( def _setup_transport(self): """Wrap the socket in an SSL object.""" self.sock = self._wrap_socket(self.sock, **self.sslopts) - self.sock.do_handshake() self._quick_recv = self.sock.recv def _wrap_socket(self, sock, context=None, **sslopts): @@ -531,10 +531,9 @@ def _wrap_socket_sni( sock, keyfile=None, certfile=None, - server_side=False, cert_reqs=ssl.CERT_REQUIRED, ca_certs=None, - do_handshake_on_connect=False, + do_handshake_on_connect=True, suppress_ragged_eofs=True, server_hostname=None, ciphers=None, @@ -548,7 +547,6 @@ def _wrap_socket_sni( :param socket.socket sock: socket to wrap :param str or None keyfile: key file path :param str or None certfile: cert file path - :param bool or None server_side: server side socket :param int cert_reqs: cert requirements :param str or None ca_certs: ca certs file path :param bool do_handshake_on_connect: do handshake on connect @@ -562,44 +560,39 @@ def _wrap_socket_sni( # Setup the right SSL version; default to optimal versions across # ssl implementations if ssl_version is None: - ssl_version = ssl.PROTOCOL_TLS + ssl_version = ssl.PROTOCOL_TLS_CLIENT + purpose = ssl.Purpose.SERVER_AUTH opts = { "sock": sock, - "keyfile": keyfile, - "certfile": certfile, - "server_side": server_side, - "cert_reqs": cert_reqs, - "ca_certs": ca_certs, "do_handshake_on_connect": do_handshake_on_connect, "suppress_ragged_eofs": suppress_ragged_eofs, - "ciphers": ciphers, - #'ssl_version': ssl_version + "server_hostname": server_hostname, } - # TODO: We need to refactor this. - try: - sock = ssl.wrap_socket(**opts) # pylint: disable=deprecated-method - except FileNotFoundError as exc: - # FileNotFoundError does not have missing filename info, so adding it below. - # Assuming that this must be ca_certs, since this is the only file path that - # users can pass in (`connection_verify` in the EH/SB clients) through opts above. - # For uamqp exception parity. Remove later when resolving issue #27128. - exc.filename = {"ca_certs": ca_certs} - raise exc - # Set SNI headers if supported - if ( - (server_hostname is not None) - and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI) - and (hasattr(ssl, "SSLContext")) - ): - context = ssl.SSLContext(opts["ssl_version"]) + context = ssl.SSLContext(ssl_version) + + if ca_certs is not None: + try: + context.load_verify_locations(ca_certs) + except FileNotFoundError as exc: + exc.filename = {"ca_certs": ca_certs} + raise exc from None + elif context.verify_mode != ssl.CERT_NONE: + # load the default system root CA certs. + context.load_default_certs(purpose=purpose) + + if certfile is not None: + context.load_cert_chain(certfile, keyfile) + + if ciphers is not None: + context.set_ciphers(ciphers) + + if cert_reqs == ssl.CERT_NONE and server_hostname is None: + context.check_hostname = False context.verify_mode = cert_reqs - if cert_reqs != ssl.CERT_NONE: - context.check_hostname = True - if (certfile is not None) and (keyfile is not None): - context.load_cert_chain(certfile, keyfile) - sock = context.wrap_socket(sock, server_hostname=server_hostname) + + sock = context.wrap_socket(**opts) return sock def _shutdown_transport(self): diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py index effac268676e..5f1ba00548a1 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py @@ -185,31 +185,40 @@ def _build_ssl_opts(self, sslopts): return self._build_ssl_context(**sslopts.pop("context")) ssl_version = sslopts.get("ssl_version") if ssl_version is None: - ssl_version = ssl.PROTOCOL_TLS + ssl_version = ssl.PROTOCOL_TLS_CLIENT + + context = ssl.SSLContext(ssl_version) + + purpose = ssl.Purpose.SERVER_AUTH + + ca_certs = sslopts.get("ca_certs") + + if ca_certs is not None: + try: + context.load_verify_locations(ca_certs) + except FileNotFoundError as exc: + # FileNotFoundError does not have missing filename info, so adding it below. + # since this is the only file path that users can pass in + # (`connection_verify` in the EH/SB clients) through opts above. + exc.filename = {"ca_certs": ca_certs} + raise exc from None + elif context.verify_mode != ssl.CERT_NONE: + # load the default system root CA certs. + context.load_default_certs(purpose=purpose) + + certfile = sslopts.get("certfile") + keyfile = sslopts.get("keyfile") + if certfile is not None: + context.load_cert_chain(certfile, keyfile) + - # Set SNI headers if supported server_hostname = sslopts.get("server_hostname") - if ( - (server_hostname is not None) - and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI) - and (hasattr(ssl, "SSLContext")) - ): - context = ssl.SSLContext(ssl_version) - cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED) - certfile = sslopts.get("certfile") - keyfile = sslopts.get("keyfile") + cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED) + if cert_reqs == ssl.CERT_NONE and server_hostname is None: + context.check_hostname = False context.verify_mode = cert_reqs - if cert_reqs != ssl.CERT_NONE: - context.check_hostname = True - if (certfile is not None) and (keyfile is not None): - context.load_cert_chain(certfile, keyfile) - return context - ca_certs = sslopts.get("ca_certs") - if ca_certs: - context = ssl.SSLContext(ssl_version) - context.load_verify_locations(ca_certs) - return context - return True + + return context except TypeError: raise TypeError( "SSL configuration must be a dictionary, or the value True."