Skip to content

Commit fb9471a

Browse files
poiujvladvildanov
authored andcommitted
Allow to control the minimum SSL version (#3127)
* Allow to control the minimum SSL version It's useful for applications that has strict security requirements. * Add tests for minimum SSL version The commit updates test_tcp_ssl_connect for both sync and async connections. Now it sets the minimum SSL version. The test is ran with both TLSv1.2 and TLSv1.3 (if supported). A new test case is test_tcp_ssl_version_mismatch. The test added for both sync and async connections. It uses TLS 1.3 on the client side, and TLS 1.2 on the server side. It expects a connection error. The test is skipped if TLS 1.3 is not supported. * Add example of using a minimum TLS version
1 parent c1c6671 commit fb9471a

File tree

9 files changed

+161
-10
lines changed

9 files changed

+161
-10
lines changed

CHANGES

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Allow to control the minimum SSL version
12
* Add an optional lock_name attribute to LockError.
23
* Fix return types for `get`, `set_path` and `strappend` in JSONCommands
34
* Connection.register_connect_callback() is made public.

docs/examples/ssl_connection_examples.ipynb

+36
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,42 @@
7676
"ssl_connection.ping()"
7777
]
7878
},
79+
{
80+
"cell_type": "markdown",
81+
"metadata": {},
82+
"source": [
83+
"## Connecting to a Redis instance via SSL, while specifying a minimum TLS version"
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": null,
89+
"metadata": {},
90+
"outputs": [
91+
{
92+
"data": {
93+
"text/plain": [
94+
"True"
95+
]
96+
},
97+
"execution_count": 6,
98+
"metadata": {},
99+
"output_type": "execute_result"
100+
}
101+
],
102+
"source": [
103+
"import redis\n",
104+
"import ssl\n",
105+
"\n",
106+
"ssl_conn = redis.Redis(\n",
107+
" host=\"localhost\",\n",
108+
" port=6666,\n",
109+
" ssl=True,\n",
110+
" ssl_min_version=ssl.TLSVersion.TLSv1_3,\n",
111+
")\n",
112+
"ssl_conn.ping()"
113+
]
114+
},
79115
{
80116
"cell_type": "markdown",
81117
"metadata": {},

redis/asyncio/client.py

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import copy
33
import inspect
44
import re
5+
import ssl
56
import warnings
67
from typing import (
78
TYPE_CHECKING,
@@ -225,6 +226,7 @@ def __init__(
225226
ssl_ca_certs: Optional[str] = None,
226227
ssl_ca_data: Optional[str] = None,
227228
ssl_check_hostname: bool = False,
229+
ssl_min_version: Optional[ssl.TLSVersion] = None,
228230
max_connections: Optional[int] = None,
229231
single_connection_client: bool = False,
230232
health_check_interval: int = 0,
@@ -331,6 +333,7 @@ def __init__(
331333
"ssl_ca_certs": ssl_ca_certs,
332334
"ssl_ca_data": ssl_ca_data,
333335
"ssl_check_hostname": ssl_check_hostname,
336+
"ssl_min_version": ssl_min_version,
334337
}
335338
)
336339
# This arg only used if no pool is passed in

redis/asyncio/cluster.py

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import collections
33
import random
44
import socket
5+
import ssl
56
import warnings
67
from typing import (
78
Any,
@@ -271,6 +272,7 @@ def __init__(
271272
ssl_certfile: Optional[str] = None,
272273
ssl_check_hostname: bool = False,
273274
ssl_keyfile: Optional[str] = None,
275+
ssl_min_version: Optional[ssl.TLSVersion] = None,
274276
protocol: Optional[int] = 2,
275277
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
276278
cache_enabled: bool = False,
@@ -344,6 +346,7 @@ def __init__(
344346
"ssl_certfile": ssl_certfile,
345347
"ssl_check_hostname": ssl_check_hostname,
346348
"ssl_keyfile": ssl_keyfile,
349+
"ssl_min_version": ssl_min_version,
347350
}
348351
)
349352

redis/asyncio/connection.py

+11
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,7 @@ def __init__(
822822
ssl_ca_certs: Optional[str] = None,
823823
ssl_ca_data: Optional[str] = None,
824824
ssl_check_hostname: bool = False,
825+
ssl_min_version: Optional[ssl.TLSVersion] = None,
825826
**kwargs,
826827
):
827828
self.ssl_context: RedisSSLContext = RedisSSLContext(
@@ -831,6 +832,7 @@ def __init__(
831832
ca_certs=ssl_ca_certs,
832833
ca_data=ssl_ca_data,
833834
check_hostname=ssl_check_hostname,
835+
min_version=ssl_min_version,
834836
)
835837
super().__init__(**kwargs)
836838

@@ -863,6 +865,10 @@ def ca_data(self):
863865
def check_hostname(self):
864866
return self.ssl_context.check_hostname
865867

868+
@property
869+
def min_version(self):
870+
return self.ssl_context.min_version
871+
866872

867873
class RedisSSLContext:
868874
__slots__ = (
@@ -873,6 +879,7 @@ class RedisSSLContext:
873879
"ca_data",
874880
"context",
875881
"check_hostname",
882+
"min_version",
876883
)
877884

878885
def __init__(
@@ -883,6 +890,7 @@ def __init__(
883890
ca_certs: Optional[str] = None,
884891
ca_data: Optional[str] = None,
885892
check_hostname: bool = False,
893+
min_version: Optional[ssl.TLSVersion] = None,
886894
):
887895
self.keyfile = keyfile
888896
self.certfile = certfile
@@ -902,6 +910,7 @@ def __init__(
902910
self.ca_certs = ca_certs
903911
self.ca_data = ca_data
904912
self.check_hostname = check_hostname
913+
self.min_version = min_version
905914
self.context: Optional[ssl.SSLContext] = None
906915

907916
def get(self) -> ssl.SSLContext:
@@ -913,6 +922,8 @@ def get(self) -> ssl.SSLContext:
913922
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
914923
if self.ca_certs or self.ca_data:
915924
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
925+
if self.min_version is not None:
926+
context.minimum_version = self.min_version
916927
self.context = context
917928
return self.context
918929

redis/client.py

+2
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def __init__(
198198
ssl_validate_ocsp_stapled=False,
199199
ssl_ocsp_context=None,
200200
ssl_ocsp_expected_cert=None,
201+
ssl_min_version=None,
201202
max_connections=None,
202203
single_connection_client=False,
203204
health_check_interval=0,
@@ -311,6 +312,7 @@ def __init__(
311312
"ssl_validate_ocsp": ssl_validate_ocsp,
312313
"ssl_ocsp_context": ssl_ocsp_context,
313314
"ssl_ocsp_expected_cert": ssl_ocsp_expected_cert,
315+
"ssl_min_version": ssl_min_version,
314316
}
315317
)
316318
connection_pool = ConnectionPool(**kwargs)

redis/connection.py

+5
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,7 @@ def __init__(
769769
ssl_validate_ocsp_stapled=False,
770770
ssl_ocsp_context=None,
771771
ssl_ocsp_expected_cert=None,
772+
ssl_min_version=None,
772773
**kwargs,
773774
):
774775
"""Constructor
@@ -787,6 +788,7 @@ def __init__(
787788
ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
788789
ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
789790
ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
791+
ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
790792
791793
Raises:
792794
RedisError
@@ -819,6 +821,7 @@ def __init__(
819821
self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
820822
self.ssl_ocsp_context = ssl_ocsp_context
821823
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
824+
self.ssl_min_version = ssl_min_version
822825
super().__init__(**kwargs)
823826

824827
def _connect(self):
@@ -841,6 +844,8 @@ def _connect(self):
841844
context.load_verify_locations(
842845
cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
843846
)
847+
if self.ssl_min_version is not None:
848+
context.minimum_version = self.ssl_min_version
844849
sslsock = context.wrap_socket(sock, server_hostname=self.host)
845850
if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
846851
raise RedisError("cryptography is not installed.")

tests/test_asyncio/test_connect.py

+51-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
SSLConnection,
1111
UnixDomainSocketConnection,
1212
)
13+
from redis.exceptions import ConnectionError
1314

1415
from ..ssl_utils import get_ssl_filename
1516

@@ -50,7 +51,17 @@ async def test_uds_connect(uds_address):
5051

5152

5253
@pytest.mark.ssl
53-
async def test_tcp_ssl_connect(tcp_address):
54+
@pytest.mark.parametrize(
55+
"ssl_min_version",
56+
[
57+
ssl.TLSVersion.TLSv1_2,
58+
pytest.param(
59+
ssl.TLSVersion.TLSv1_3,
60+
marks=pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3"),
61+
),
62+
],
63+
)
64+
async def test_tcp_ssl_connect(tcp_address, ssl_min_version):
5465
host, port = tcp_address
5566
certfile = get_ssl_filename("server-cert.pem")
5667
keyfile = get_ssl_filename("server-key.pem")
@@ -60,12 +71,44 @@ async def test_tcp_ssl_connect(tcp_address):
6071
client_name=_CLIENT_NAME,
6172
ssl_ca_certs=certfile,
6273
socket_timeout=10,
74+
ssl_min_version=ssl_min_version,
6375
)
6476
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
6577
await conn.disconnect()
6678

6779

68-
async def _assert_connect(conn, server_address, certfile=None, keyfile=None):
80+
@pytest.mark.ssl
81+
@pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3")
82+
async def test_tcp_ssl_version_mismatch(tcp_address):
83+
host, port = tcp_address
84+
certfile = get_ssl_filename("server-cert.pem")
85+
keyfile = get_ssl_filename("server-key.pem")
86+
conn = SSLConnection(
87+
host=host,
88+
port=port,
89+
client_name=_CLIENT_NAME,
90+
ssl_ca_certs=certfile,
91+
socket_timeout=1,
92+
ssl_min_version=ssl.TLSVersion.TLSv1_3,
93+
)
94+
with pytest.raises(ConnectionError):
95+
await _assert_connect(
96+
conn,
97+
tcp_address,
98+
certfile=certfile,
99+
keyfile=keyfile,
100+
ssl_version=ssl.TLSVersion.TLSv1_2,
101+
)
102+
await conn.disconnect()
103+
104+
105+
async def _assert_connect(
106+
conn,
107+
server_address,
108+
certfile=None,
109+
keyfile=None,
110+
ssl_version=None,
111+
):
69112
stop_event = asyncio.Event()
70113
finished = asyncio.Event()
71114

@@ -82,7 +125,9 @@ async def _handler(reader, writer):
82125
elif certfile:
83126
host, port = server_address
84127
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
85-
context.minimum_version = ssl.TLSVersion.TLSv1_2
128+
if ssl_version is not None:
129+
context.minimum_version = ssl_version
130+
context.maximum_version = ssl_version
86131
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
87132
server = await asyncio.start_server(_handler, host=host, port=port, ssl=context)
88133
else:
@@ -94,6 +139,9 @@ async def _handler(reader, writer):
94139
try:
95140
await conn.connect()
96141
await conn.disconnect()
142+
except ConnectionError:
143+
finished.set()
144+
raise
97145
finally:
98146
stop_event.set()
99147
aserver.close()

0 commit comments

Comments
 (0)