Skip to content

Commit

Permalink
[Azure AMQP] Remove Deprecated SSL.Wrap_Socket (Azure#31524)
Browse files Browse the repository at this point in the history
* remove deprecated code

* some re-ordering

* bring changes over to SB

* fix order

* update async side

* sb async transport

* comment for ordering

* changes to sni wrap code

* sync code

* enable auto ssl handshake

* fix type of ssl opts

* bring change to SB

* minor fix

* refactor out server_side

* remove comment

* sync changes

* remove whitespace

* get rid of whitespace

* remove whitespace

* fix pylint
  • Loading branch information
kashifkhan authored Aug 28, 2023
1 parent f8f8742 commit 7f6bb71
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 112 deletions.
61 changes: 27 additions & 34 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def __init__(
self, host, *, port=AMQPS_PORT, socket_timeout=None, ssl_opts=None, **kwargs
):
self.sslopts = ssl_opts if isinstance(ssl_opts, dict) else {}
self.sslopts['server_hostname'] = host
self._read_buffer = BytesIO()
super(SSLTransport, self).__init__(
host, port=port, socket_timeout=socket_timeout, **kwargs
Expand All @@ -509,7 +510,6 @@ def __init__(
def _setup_transport(self):
"""Wrap the socket in an SSL object."""
self.sock = self._wrap_socket(self.sock, **self.sslopts)
self.sock.do_handshake()
self._quick_recv = self.sock.recv

def _wrap_socket(self, sock, context=None, **sslopts):
Expand All @@ -531,10 +531,9 @@ def _wrap_socket_sni(
sock,
keyfile=None,
certfile=None,
server_side=False,
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=None,
do_handshake_on_connect=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None,
ciphers=None,
Expand All @@ -548,7 +547,6 @@ def _wrap_socket_sni(
:param socket.socket sock: socket to wrap
:param str or None keyfile: key file path
:param str or None certfile: cert file path
:param bool or None server_side: server side socket
:param int cert_reqs: cert requirements
:param str or None ca_certs: ca certs file path
:param bool do_handshake_on_connect: do handshake on connect
Expand All @@ -562,44 +560,39 @@ def _wrap_socket_sni(
# Setup the right SSL version; default to optimal versions across
# ssl implementations
if ssl_version is None:
ssl_version = ssl.PROTOCOL_TLS
ssl_version = ssl.PROTOCOL_TLS_CLIENT
purpose = ssl.Purpose.SERVER_AUTH

opts = {
"sock": sock,
"keyfile": keyfile,
"certfile": certfile,
"server_side": server_side,
"cert_reqs": cert_reqs,
"ca_certs": ca_certs,
"do_handshake_on_connect": do_handshake_on_connect,
"suppress_ragged_eofs": suppress_ragged_eofs,
"ciphers": ciphers,
#'ssl_version': ssl_version
"server_hostname": server_hostname,
}

# TODO: We need to refactor this.
try:
sock = ssl.wrap_socket(**opts) # pylint: disable=deprecated-method
except FileNotFoundError as exc:
# FileNotFoundError does not have missing filename info, so adding it below.
# Assuming that this must be ca_certs, since this is the only file path that
# users can pass in (`connection_verify` in the EH/SB clients) through opts above.
# For uamqp exception parity. Remove later when resolving issue #27128.
exc.filename = {"ca_certs": ca_certs}
raise exc
# Set SNI headers if supported
if (
(server_hostname is not None)
and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI)
and (hasattr(ssl, "SSLContext"))
):
context = ssl.SSLContext(opts["ssl_version"])
context = ssl.SSLContext(ssl_version)

if ca_certs is not None:
try:
context.load_verify_locations(ca_certs)
except FileNotFoundError as exc:
exc.filename = {"ca_certs": ca_certs}
raise exc from None
elif context.verify_mode != ssl.CERT_NONE:
# load the default system root CA certs.
context.load_default_certs(purpose=purpose)

if certfile is not None:
context.load_cert_chain(certfile, keyfile)

if ciphers is not None:
context.set_ciphers(ciphers)

if cert_reqs == ssl.CERT_NONE and server_hostname is None:
context.check_hostname = False
context.verify_mode = cert_reqs
if cert_reqs != ssl.CERT_NONE:
context.check_hostname = True
if (certfile is not None) and (keyfile is not None):
context.load_cert_chain(certfile, keyfile)
sock = context.wrap_socket(sock, server_hostname=server_hostname)

sock = context.wrap_socket(**opts)
return sock

def _shutdown_transport(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,31 +185,40 @@ def _build_ssl_opts(self, sslopts):
return self._build_ssl_context(**sslopts.pop("context"))
ssl_version = sslopts.get("ssl_version")
if ssl_version is None:
ssl_version = ssl.PROTOCOL_TLS
ssl_version = ssl.PROTOCOL_TLS_CLIENT

context = ssl.SSLContext(ssl_version)

purpose = ssl.Purpose.SERVER_AUTH

ca_certs = sslopts.get("ca_certs")

if ca_certs is not None:
try:
context.load_verify_locations(ca_certs)
except FileNotFoundError as exc:
# FileNotFoundError does not have missing filename info, so adding it below.
# since this is the only file path that users can pass in
# (`connection_verify` in the EH/SB clients) through opts above.
exc.filename = {"ca_certs": ca_certs}
raise exc from None
elif context.verify_mode != ssl.CERT_NONE:
# load the default system root CA certs.
context.load_default_certs(purpose=purpose)

certfile = sslopts.get("certfile")
keyfile = sslopts.get("keyfile")
if certfile is not None:
context.load_cert_chain(certfile, keyfile)


# Set SNI headers if supported
server_hostname = sslopts.get("server_hostname")
if (
(server_hostname is not None)
and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI)
and (hasattr(ssl, "SSLContext"))
):
context = ssl.SSLContext(ssl_version)
cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED)
certfile = sslopts.get("certfile")
keyfile = sslopts.get("keyfile")
cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED)
if cert_reqs == ssl.CERT_NONE and server_hostname is None:
context.check_hostname = False
context.verify_mode = cert_reqs
if cert_reqs != ssl.CERT_NONE:
context.check_hostname = True
if (certfile is not None) and (keyfile is not None):
context.load_cert_chain(certfile, keyfile)
return context
ca_certs = sslopts.get("ca_certs")
if ca_certs:
context = ssl.SSLContext(ssl_version)
context.load_verify_locations(ca_certs)
return context
return True

return context
except TypeError:
raise TypeError(
"SSL configuration must be a dictionary, or the value True."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def __init__(
self, host, *, port=AMQPS_PORT, socket_timeout=None, ssl_opts=None, **kwargs
):
self.sslopts = ssl_opts if isinstance(ssl_opts, dict) else {}
self.sslopts['server_hostname'] = host
self._read_buffer = BytesIO()
super(SSLTransport, self).__init__(
host, port=port, socket_timeout=socket_timeout, **kwargs
Expand All @@ -509,7 +510,6 @@ def __init__(
def _setup_transport(self):
"""Wrap the socket in an SSL object."""
self.sock = self._wrap_socket(self.sock, **self.sslopts)
self.sock.do_handshake()
self._quick_recv = self.sock.recv

def _wrap_socket(self, sock, context=None, **sslopts):
Expand All @@ -531,10 +531,9 @@ def _wrap_socket_sni(
sock,
keyfile=None,
certfile=None,
server_side=False,
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=None,
do_handshake_on_connect=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None,
ciphers=None,
Expand All @@ -548,7 +547,6 @@ def _wrap_socket_sni(
:param socket.socket sock: socket to wrap
:param str or None keyfile: key file path
:param str or None certfile: cert file path
:param bool or None server_side: server side socket
:param int cert_reqs: cert requirements
:param str or None ca_certs: ca certs file path
:param bool do_handshake_on_connect: do handshake on connect
Expand All @@ -562,44 +560,39 @@ def _wrap_socket_sni(
# Setup the right SSL version; default to optimal versions across
# ssl implementations
if ssl_version is None:
ssl_version = ssl.PROTOCOL_TLS
ssl_version = ssl.PROTOCOL_TLS_CLIENT
purpose = ssl.Purpose.SERVER_AUTH

opts = {
"sock": sock,
"keyfile": keyfile,
"certfile": certfile,
"server_side": server_side,
"cert_reqs": cert_reqs,
"ca_certs": ca_certs,
"do_handshake_on_connect": do_handshake_on_connect,
"suppress_ragged_eofs": suppress_ragged_eofs,
"ciphers": ciphers,
#'ssl_version': ssl_version
"server_hostname": server_hostname,
}

# TODO: We need to refactor this.
try:
sock = ssl.wrap_socket(**opts) # pylint: disable=deprecated-method
except FileNotFoundError as exc:
# FileNotFoundError does not have missing filename info, so adding it below.
# Assuming that this must be ca_certs, since this is the only file path that
# users can pass in (`connection_verify` in the EH/SB clients) through opts above.
# For uamqp exception parity. Remove later when resolving issue #27128.
exc.filename = {"ca_certs": ca_certs}
raise exc
# Set SNI headers if supported
if (
(server_hostname is not None)
and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI)
and (hasattr(ssl, "SSLContext"))
):
context = ssl.SSLContext(opts["ssl_version"])
context = ssl.SSLContext(ssl_version)

if ca_certs is not None:
try:
context.load_verify_locations(ca_certs)
except FileNotFoundError as exc:
exc.filename = {"ca_certs": ca_certs}
raise exc from None
elif context.verify_mode != ssl.CERT_NONE:
# load the default system root CA certs.
context.load_default_certs(purpose=purpose)

if certfile is not None:
context.load_cert_chain(certfile, keyfile)

if ciphers is not None:
context.set_ciphers(ciphers)

if cert_reqs == ssl.CERT_NONE and server_hostname is None:
context.check_hostname = False
context.verify_mode = cert_reqs
if cert_reqs != ssl.CERT_NONE:
context.check_hostname = True
if (certfile is not None) and (keyfile is not None):
context.load_cert_chain(certfile, keyfile)
sock = context.wrap_socket(sock, server_hostname=server_hostname)

sock = context.wrap_socket(**opts)
return sock

def _shutdown_transport(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,31 +185,40 @@ def _build_ssl_opts(self, sslopts):
return self._build_ssl_context(**sslopts.pop("context"))
ssl_version = sslopts.get("ssl_version")
if ssl_version is None:
ssl_version = ssl.PROTOCOL_TLS
ssl_version = ssl.PROTOCOL_TLS_CLIENT

context = ssl.SSLContext(ssl_version)

purpose = ssl.Purpose.SERVER_AUTH

ca_certs = sslopts.get("ca_certs")

if ca_certs is not None:
try:
context.load_verify_locations(ca_certs)
except FileNotFoundError as exc:
# FileNotFoundError does not have missing filename info, so adding it below.
# since this is the only file path that users can pass in
# (`connection_verify` in the EH/SB clients) through opts above.
exc.filename = {"ca_certs": ca_certs}
raise exc from None
elif context.verify_mode != ssl.CERT_NONE:
# load the default system root CA certs.
context.load_default_certs(purpose=purpose)

certfile = sslopts.get("certfile")
keyfile = sslopts.get("keyfile")
if certfile is not None:
context.load_cert_chain(certfile, keyfile)


# Set SNI headers if supported
server_hostname = sslopts.get("server_hostname")
if (
(server_hostname is not None)
and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI)
and (hasattr(ssl, "SSLContext"))
):
context = ssl.SSLContext(ssl_version)
cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED)
certfile = sslopts.get("certfile")
keyfile = sslopts.get("keyfile")
cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED)
if cert_reqs == ssl.CERT_NONE and server_hostname is None:
context.check_hostname = False
context.verify_mode = cert_reqs
if cert_reqs != ssl.CERT_NONE:
context.check_hostname = True
if (certfile is not None) and (keyfile is not None):
context.load_cert_chain(certfile, keyfile)
return context
ca_certs = sslopts.get("ca_certs")
if ca_certs:
context = ssl.SSLContext(ssl_version)
context.load_verify_locations(ca_certs)
return context
return True

return context
except TypeError:
raise TypeError(
"SSL configuration must be a dictionary, or the value True."
Expand Down

0 comments on commit 7f6bb71

Please sign in to comment.