Skip to content

Commit f11ddae

Browse files
committed
add 'connect' tests for all Redis connection classes
1 parent b167df0 commit f11ddae

File tree

1 file changed

+176
-0
lines changed

1 file changed

+176
-0
lines changed

tests/test_connect.py

+176
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import logging
2+
import re
3+
import socket
4+
import ssl
5+
import threading
6+
7+
import pytest
8+
9+
from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection
10+
11+
from .ssl_certificates import get_ssl_certificate
12+
13+
_logger = logging.getLogger(__name__)
14+
15+
16+
_CLIENT_NAME = "test-suite-client"
17+
_CMD_SEP = b"\r\n"
18+
_SUCCESS_RESP = b"+OK" + _CMD_SEP
19+
_ERROR_RESP = b"-ERR" + _CMD_SEP
20+
_COMMANDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}
21+
22+
23+
@pytest.fixture
24+
def tcp_address():
25+
with socket.socket() as sock:
26+
sock.bind(("127.0.0.1", 0))
27+
return sock.getsockname()
28+
29+
30+
@pytest.fixture
31+
def uds_address(tmpdir):
32+
return tmpdir / "uds.sock"
33+
34+
35+
def test_tcp_connect(tcp_address):
36+
host, port = tcp_address
37+
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME)
38+
_assert_connect(conn, tcp_address)
39+
40+
41+
def test_uds_connect(uds_address):
42+
path = str(uds_address)
43+
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME)
44+
_assert_connect(conn, path)
45+
46+
47+
@pytest.mark.ssl
48+
def test_tcp_ssl_connect(tcp_address):
49+
host, port = tcp_address
50+
certfile = get_ssl_certificate("server-cert.pem")
51+
keyfile = get_ssl_certificate("server-key.pem")
52+
conn = SSLConnection(
53+
host=host,
54+
port=port,
55+
client_name=_CLIENT_NAME,
56+
ssl_certfile=certfile,
57+
ssl_keyfile=keyfile,
58+
ssl_ca_certs=certfile,
59+
)
60+
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
61+
62+
63+
def _assert_connect(conn, server_address, certfile=None, keyfile=None):
64+
ready = threading.Event()
65+
stop = threading.Event()
66+
t = threading.Thread(
67+
target=_redis_mock_server,
68+
args=(server_address, ready, stop),
69+
kwargs={"certfile": certfile, "keyfile": keyfile},
70+
)
71+
t.start()
72+
try:
73+
ready.wait()
74+
conn.connect()
75+
conn.disconnect()
76+
finally:
77+
stop.set()
78+
t.join(timeout=5)
79+
80+
81+
def _redis_mock_server(server_address, ready, stop, certfile=None, keyfile=None):
82+
try:
83+
if isinstance(server_address, str):
84+
family = socket.AF_UNIX
85+
mockname = "Redis mock server (UDS)"
86+
elif certfile:
87+
family = socket.AF_INET
88+
mockname = "Redis mock server (TCP-SSL)"
89+
else:
90+
family = socket.AF_INET
91+
mockname = "Redis mock server (TCP)"
92+
93+
with socket.socket(family, socket.SOCK_STREAM) as s:
94+
s.bind(server_address)
95+
s.listen(1)
96+
s.settimeout(0.1)
97+
98+
if certfile:
99+
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
100+
context.minimum_version = ssl.TLSVersion.TLSv1_2
101+
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
102+
103+
_logger.info("Start %s: %s", mockname, server_address)
104+
ready.set()
105+
106+
# Wait a client connection
107+
while not stop.is_set():
108+
try:
109+
sconn, _ = s.accept()
110+
sconn.settimeout(0.1)
111+
break
112+
except socket.timeout:
113+
pass
114+
if stop.is_set():
115+
_logger.info("Exit %s: %s", mockname, server_address)
116+
return
117+
118+
# Receive commands from the client
119+
with sconn:
120+
if certfile:
121+
conn = context.wrap_socket(sconn, server_side=True)
122+
else:
123+
conn = sconn
124+
try:
125+
buffer = b""
126+
command = None
127+
command_ptr = None
128+
fragment_length = None
129+
while not stop.is_set() or buffer:
130+
try:
131+
buffer += conn.recv(1024)
132+
except socket.timeout:
133+
continue
134+
if not buffer:
135+
continue
136+
parts = re.split(_CMD_SEP, buffer)
137+
buffer = parts[-1]
138+
for fragment in parts[:-1]:
139+
fragment = fragment.decode()
140+
_logger.info(
141+
"Command fragment in %s: %s", mockname, fragment
142+
)
143+
144+
if fragment.startswith("*") and command is None:
145+
command = [None for _ in range(int(fragment[1:]))]
146+
command_ptr = 0
147+
fragment_length = None
148+
continue
149+
150+
if (
151+
fragment.startswith("$")
152+
and command[command_ptr] is None
153+
):
154+
fragment_length = int(fragment[1:])
155+
continue
156+
157+
assert len(fragment) == fragment_length
158+
command[command_ptr] = fragment
159+
command_ptr += 1
160+
161+
if command_ptr < len(command):
162+
continue
163+
164+
command = " ".join(command)
165+
_logger.info("Command in %s: %s", mockname, command)
166+
resp = _COMMANDS.get(command, _ERROR_RESP)
167+
_logger.info("Response from %s: %s", mockname, resp)
168+
conn.sendall(resp)
169+
command = None
170+
finally:
171+
if certfile:
172+
conn.close()
173+
_logger.info("Exit %s: %s", mockname, server_address)
174+
except BaseException as e:
175+
_logger.exception("Error in %s: %s", mockname, e)
176+
raise

0 commit comments

Comments
 (0)