Skip to content

Commit 3ea7246

Browse files
committed
tests: add 'connect' tests for all Redis connection classes
1 parent 6ebf8eb commit 3ea7246

File tree

1 file changed

+166
-0
lines changed

1 file changed

+166
-0
lines changed

tests/test_connect.py

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import logging
2+
import re
3+
import socket
4+
import ssl
5+
import threading
6+
7+
import pytest
8+
9+
from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection
10+
11+
from .ssl_utils import get_ssl_filename
12+
13+
_logger = logging.getLogger(__name__)
14+
15+
16+
_CLIENT_NAME = "test-suite-client"
17+
_CMD_SEP = b"\r\n"
18+
_SUCCESS_RESP = b"+OK" + _CMD_SEP
19+
_ERROR_RESP = b"-ERR" + _CMD_SEP
20+
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}
21+
22+
23+
@pytest.fixture
24+
def tcp_address():
25+
with socket.socket() as sock:
26+
sock.bind(("127.0.0.1", 0))
27+
return sock.getsockname()
28+
29+
30+
@pytest.fixture
31+
def uds_address(tmpdir):
32+
return tmpdir / "uds.sock"
33+
34+
35+
def test_tcp_connect(tcp_address):
36+
host, port = tcp_address
37+
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME)
38+
_assert_connect(conn, tcp_address)
39+
40+
41+
def test_uds_connect(uds_address):
42+
path = str(uds_address)
43+
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME)
44+
_assert_connect(conn, path)
45+
46+
47+
@pytest.mark.ssl
48+
def test_tcp_ssl_connect(tcp_address):
49+
host, port = tcp_address
50+
certfile = get_ssl_filename("server-cert.pem")
51+
keyfile = get_ssl_filename("server-key.pem")
52+
conn = SSLConnection(
53+
host=host, port=port, client_name=_CLIENT_NAME, ssl_ca_certs=certfile
54+
)
55+
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
56+
57+
58+
def _assert_connect(conn, server_address, certfile=None, keyfile=None):
59+
ready = threading.Event()
60+
stop = threading.Event()
61+
t = threading.Thread(
62+
target=_redis_mock_server,
63+
args=(server_address, ready, stop),
64+
kwargs={"certfile": certfile, "keyfile": keyfile},
65+
)
66+
t.start()
67+
try:
68+
ready.wait()
69+
conn.connect()
70+
conn.disconnect()
71+
finally:
72+
stop.set()
73+
t.join(timeout=5)
74+
75+
76+
def _redis_mock_server(server_address, ready, stop, certfile=None, keyfile=None):
77+
try:
78+
if isinstance(server_address, str):
79+
family = socket.AF_UNIX
80+
mockname = "Redis mock server (UDS)"
81+
elif certfile:
82+
family = socket.AF_INET
83+
mockname = "Redis mock server (TCP-SSL)"
84+
else:
85+
family = socket.AF_INET
86+
mockname = "Redis mock server (TCP)"
87+
88+
with socket.socket(family, socket.SOCK_STREAM) as s:
89+
s.bind(server_address)
90+
s.listen(1)
91+
s.settimeout(0.1)
92+
93+
if certfile:
94+
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
95+
context.minimum_version = ssl.TLSVersion.TLSv1_2
96+
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
97+
98+
_logger.info("Start %s: %s", mockname, server_address)
99+
ready.set()
100+
101+
# Wait for a client connection
102+
while not stop.is_set():
103+
try:
104+
sconn, _ = s.accept()
105+
sconn.settimeout(0.1)
106+
break
107+
except socket.timeout:
108+
pass
109+
if stop.is_set():
110+
_logger.info("Exit %s: %s", mockname, server_address)
111+
return
112+
113+
# Handle commands from the client
114+
with sconn:
115+
if certfile:
116+
with context.wrap_socket(sconn, server_side=True) as wconn:
117+
_redis_mock_server_handle(wconn, stop, mockname)
118+
else:
119+
_redis_mock_server_handle(sconn, stop, mockname)
120+
_logger.info("Exit %s: %s", mockname, server_address)
121+
except BaseException as e:
122+
_logger.exception("Error in %s: %s", mockname, e)
123+
raise
124+
125+
126+
def _redis_mock_server_handle(conn, stop, mockname):
127+
buffer = b""
128+
command = None
129+
command_ptr = None
130+
fragment_length = None
131+
while not stop.is_set() or buffer:
132+
try:
133+
buffer += conn.recv(1024)
134+
except socket.timeout:
135+
continue
136+
if not buffer:
137+
continue
138+
parts = re.split(_CMD_SEP, buffer)
139+
buffer = parts[-1]
140+
for fragment in parts[:-1]:
141+
fragment = fragment.decode()
142+
_logger.info("Command fragment in %s: %s", mockname, fragment)
143+
144+
if fragment.startswith("*") and command is None:
145+
command = [None for _ in range(int(fragment[1:]))]
146+
command_ptr = 0
147+
fragment_length = None
148+
continue
149+
150+
if fragment.startswith("$") and command[command_ptr] is None:
151+
fragment_length = int(fragment[1:])
152+
continue
153+
154+
assert len(fragment) == fragment_length
155+
command[command_ptr] = fragment
156+
command_ptr += 1
157+
158+
if command_ptr < len(command):
159+
continue
160+
161+
command = " ".join(command)
162+
_logger.info("Command in %s: %s", mockname, command)
163+
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
164+
_logger.info("Response from %s: %s", mockname, resp)
165+
conn.sendall(resp)
166+
command = None

0 commit comments

Comments
 (0)