Skip to content

Commit

Permalink
Add unit tests for the connect method of all Redis connection class…
Browse files Browse the repository at this point in the history
…es (#2631)

* tests: move certificate discovery to a separate module

* tests: add 'connect' tests for all Redis connection classes

---------

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
  • Loading branch information
woutdenolf and dvora-h authored Jun 23, 2023
1 parent 2bb7f10 commit 53bed27
Show file tree
Hide file tree
Showing 5 changed files with 350 additions and 24 deletions.
14 changes: 14 additions & 0 deletions tests/ssl_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os


def get_ssl_filename(name):
root = os.path.join(os.path.dirname(__file__), "..")
cert_dir = os.path.abspath(os.path.join(root, "docker", "stunnel", "keys"))
if not os.path.isdir(cert_dir): # github actions package validation case
cert_dir = os.path.abspath(
os.path.join(root, "..", "docker", "stunnel", "keys")
)
if not os.path.isdir(cert_dir):
raise IOError(f"No SSL certificates found. They should be in {cert_dir}")

return os.path.join(cert_dir, name)
15 changes: 3 additions & 12 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import binascii
import datetime
import os
import warnings
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union
from urllib.parse import urlparse
Expand Down Expand Up @@ -36,6 +35,7 @@
skip_unless_arch_bits,
)

from ..ssl_utils import get_ssl_filename
from .compat import mock

pytestmark = pytest.mark.onlycluster
Expand Down Expand Up @@ -2744,17 +2744,8 @@ class TestSSL:
appropriate port.
"""

ROOT = os.path.join(os.path.dirname(__file__), "../..")
CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys"))
if not os.path.isdir(CERT_DIR): # github actions package validation case
CERT_DIR = os.path.abspath(
os.path.join(ROOT, "..", "docker", "stunnel", "keys")
)
if not os.path.isdir(CERT_DIR):
raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}")

SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem")
SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem")
SERVER_CERT = get_ssl_filename("server-cert.pem")
SERVER_KEY = get_ssl_filename("server-key.pem")

@pytest_asyncio.fixture()
def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]:
Expand Down
145 changes: 145 additions & 0 deletions tests/test_asyncio/test_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import asyncio
import logging
import re
import socket
import ssl

import pytest

from redis.asyncio.connection import (
Connection,
SSLConnection,
UnixDomainSocketConnection,
)

from ..ssl_utils import get_ssl_filename

_logger = logging.getLogger(__name__)


_CLIENT_NAME = "test-suite-client"
_CMD_SEP = b"\r\n"
_SUCCESS_RESP = b"+OK" + _CMD_SEP
_ERROR_RESP = b"-ERR" + _CMD_SEP
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}


@pytest.fixture
def tcp_address():
with socket.socket() as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()


@pytest.fixture
def uds_address(tmpdir):
return tmpdir / "uds.sock"


async def test_tcp_connect(tcp_address):
host, port = tcp_address
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10)
await _assert_connect(conn, tcp_address)


async def test_uds_connect(uds_address):
path = str(uds_address)
conn = UnixDomainSocketConnection(
path=path, client_name=_CLIENT_NAME, socket_timeout=10
)
await _assert_connect(conn, path)


@pytest.mark.ssl
async def test_tcp_ssl_connect(tcp_address):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
)
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


async def _assert_connect(conn, server_address, certfile=None, keyfile=None):
stop_event = asyncio.Event()
finished = asyncio.Event()

async def _handler(reader, writer):
try:
return await _redis_request_handler(reader, writer, stop_event)
finally:
finished.set()

if isinstance(server_address, str):
server = await asyncio.start_unix_server(_handler, path=server_address)
elif certfile:
host, port = server_address
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.minimum_version = ssl.TLSVersion.TLSv1_2
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
server = await asyncio.start_server(_handler, host=host, port=port, ssl=context)
else:
host, port = server_address
server = await asyncio.start_server(_handler, host=host, port=port)

async with server as aserver:
await aserver.start_serving()
try:
await conn.connect()
await conn.disconnect()
finally:
stop_event.set()
aserver.close()
await aserver.wait_closed()
await finished.wait()


async def _redis_request_handler(reader, writer, stop_event):
buffer = b""
command = None
command_ptr = None
fragment_length = None
while not stop_event.is_set() or buffer:
_logger.info(str(stop_event.is_set()))
try:
buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5)
except TimeoutError:
continue
if not buffer:
continue
parts = re.split(_CMD_SEP, buffer)
buffer = parts[-1]
for fragment in parts[:-1]:
fragment = fragment.decode()
_logger.info("Command fragment: %s", fragment)

if fragment.startswith("*") and command is None:
command = [None for _ in range(int(fragment[1:]))]
command_ptr = 0
fragment_length = None
continue

if fragment.startswith("$") and command[command_ptr] is None:
fragment_length = int(fragment[1:])
continue

assert len(fragment) == fragment_length
command[command_ptr] = fragment
command_ptr += 1

if command_ptr < len(command):
continue

command = " ".join(command)
_logger.info("Command %s", command)
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
_logger.info("Response from %s", resp)
writer.write(resp)
await writer.drain()
command = None
_logger.info("Exit handler")
185 changes: 185 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import logging
import re
import socket
import socketserver
import ssl
import threading

import pytest

from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection

from .ssl_utils import get_ssl_filename

_logger = logging.getLogger(__name__)


_CLIENT_NAME = "test-suite-client"
_CMD_SEP = b"\r\n"
_SUCCESS_RESP = b"+OK" + _CMD_SEP
_ERROR_RESP = b"-ERR" + _CMD_SEP
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}


@pytest.fixture
def tcp_address():
with socket.socket() as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()


@pytest.fixture
def uds_address(tmpdir):
return tmpdir / "uds.sock"


def test_tcp_connect(tcp_address):
host, port = tcp_address
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10)
_assert_connect(conn, tcp_address)


def test_uds_connect(uds_address):
path = str(uds_address)
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10)
_assert_connect(conn, path)


@pytest.mark.ssl
def test_tcp_ssl_connect(tcp_address):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
)
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


def _assert_connect(conn, server_address, certfile=None, keyfile=None):
if isinstance(server_address, str):
server = _RedisUDSServer(server_address, _RedisRequestHandler)
else:
server = _RedisTCPServer(
server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile
)
with server as aserver:
t = threading.Thread(target=aserver.serve_forever)
t.start()
try:
aserver.wait_online()
conn.connect()
conn.disconnect()
finally:
aserver.stop()
t.join(timeout=5)


class _RedisTCPServer(socketserver.TCPServer):
def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None:
self._ready_event = threading.Event()
self._stop_requested = False
self._certfile = certfile
self._keyfile = keyfile
super().__init__(*args, **kw)

def service_actions(self):
self._ready_event.set()

def wait_online(self):
self._ready_event.wait()

def stop(self):
self._stop_requested = True
self.shutdown()

def is_serving(self):
return not self._stop_requested

def get_request(self):
if self._certfile is None:
return super().get_request()
newsocket, fromaddr = self.socket.accept()
connstream = ssl.wrap_socket(
newsocket,
server_side=True,
certfile=self._certfile,
keyfile=self._keyfile,
ssl_version=ssl.PROTOCOL_TLSv1_2,
)
return connstream, fromaddr


class _RedisUDSServer(socketserver.UnixStreamServer):
def __init__(self, *args, **kw) -> None:
self._ready_event = threading.Event()
self._stop_requested = False
super().__init__(*args, **kw)

def service_actions(self):
self._ready_event.set()

def wait_online(self):
self._ready_event.wait()

def stop(self):
self._stop_requested = True
self.shutdown()

def is_serving(self):
return not self._stop_requested


class _RedisRequestHandler(socketserver.StreamRequestHandler):
def setup(self):
_logger.info("%s connected", self.client_address)

def finish(self):
_logger.info("%s disconnected", self.client_address)

def handle(self):
buffer = b""
command = None
command_ptr = None
fragment_length = None
while self.server.is_serving() or buffer:
try:
buffer += self.request.recv(1024)
except socket.timeout:
continue
if not buffer:
continue
parts = re.split(_CMD_SEP, buffer)
buffer = parts[-1]
for fragment in parts[:-1]:
fragment = fragment.decode()
_logger.info("Command fragment: %s", fragment)

if fragment.startswith("*") and command is None:
command = [None for _ in range(int(fragment[1:]))]
command_ptr = 0
fragment_length = None
continue

if fragment.startswith("$") and command[command_ptr] is None:
fragment_length = int(fragment[1:])
continue

assert len(fragment) == fragment_length
command[command_ptr] = fragment
command_ptr += 1

if command_ptr < len(command):
continue

command = " ".join(command)
_logger.info("Command %s", command)
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
_logger.info("Response %s", resp)
self.request.sendall(resp)
command = None
_logger.info("Exit handler")
Loading

0 comments on commit 53bed27

Please sign in to comment.