Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add "address_remap" feature to RedisCluster #2726

Merged
merged 7 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Add `address_remap` parameter to `RedisCluster`
* Fix incorrect usage of once flag in async Sentinel
* asyncio: Fix memory leak caused by hiredis (#2693)
* Allow data to drain from async PythonParser when reading during a disconnect()
Expand Down
31 changes: 30 additions & 1 deletion redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import warnings
from typing import (
Any,
Callable,
Deque,
Dict,
Generator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -147,6 +149,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
maximum number of connections are already created, a
:class:`~.MaxConnectionsError` is raised. This error may be retried as defined
by :attr:`connection_error_retry_attempts`
:param address_remap:
| An optional callable which, when provided with an internal network
address of a node, e.g. a `(host, port)` tuple, will return the address
where the node is reachable. This can be used to map the addresses at
which the nodes _think_ they are, to addresses at which a client may
reach them, such as when they sit behind a proxy.

| Rest of the arguments will be passed to the
:class:`~redis.asyncio.connection.Connection` instances when created
Expand Down Expand Up @@ -250,6 +258,7 @@ def __init__(
ssl_certfile: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
if db:
raise RedisClusterException(
Expand Down Expand Up @@ -337,7 +346,12 @@ def __init__(
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
self.nodes_manager = NodesManager(
startup_nodes,
require_full_coverage,
kwargs,
address_remap=address_remap,
)
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
self.reinitialize_steps = reinitialize_steps
Expand Down Expand Up @@ -1059,17 +1073,20 @@ class NodesManager:
"require_full_coverage",
"slots_cache",
"startup_nodes",
"address_remap",
)

def __init__(
self,
startup_nodes: List["ClusterNode"],
require_full_coverage: bool,
connection_kwargs: Dict[str, Any],
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
self.startup_nodes = {node.name: node for node in startup_nodes}
self.require_full_coverage = require_full_coverage
self.connection_kwargs = connection_kwargs
self.address_remap = address_remap

self.default_node: "ClusterNode" = None
self.nodes_cache: Dict[str, "ClusterNode"] = {}
Expand Down Expand Up @@ -1228,6 +1245,7 @@ async def initialize(self) -> None:
if host == "":
host = startup_node.host
port = int(primary_node[1])
host, port = self.remap_host_port(host, port)

target_node = tmp_nodes_cache.get(get_node_name(host, port))
if not target_node:
Expand All @@ -1246,6 +1264,7 @@ async def initialize(self) -> None:
for replica_node in replica_nodes:
host = replica_node[0]
port = replica_node[1]
host, port = self.remap_host_port(host, port)

target_replica_node = tmp_nodes_cache.get(
get_node_name(host, port)
Expand Down Expand Up @@ -1319,6 +1338,16 @@ async def close(self, attr: str = "nodes_cache") -> None:
)
)

def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
"""
Remap the host and port returned from the cluster to a different
internal value. Useful if the client is not connecting directly
to the cluster.
"""
if self.address_remap:
return self.address_remap((host, port))
return host, port


class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
"""
Expand Down
22 changes: 22 additions & 0 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def __init__(
read_from_replicas: bool = False,
dynamic_startup_nodes: bool = True,
url: Optional[str] = None,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -514,6 +515,12 @@ def __init__(
reinitialize_steps to 1.
To avoid reinitializing the cluster on moved errors, set
reinitialize_steps to 0.
:param address_remap:
An optional callable which, when provided with an internal network
address of a node, e.g. a `(host, port)` tuple, will return the address
where the node is reachable. This can be used to map the addresses at
which the nodes _think_ they are, to addresses at which a client may
reach them, such as when they sit behind a proxy.

:**kwargs:
Extra arguments that will be sent into Redis instance when created
Expand Down Expand Up @@ -594,6 +601,7 @@ def __init__(
from_url=from_url,
require_full_coverage=require_full_coverage,
dynamic_startup_nodes=dynamic_startup_nodes,
address_remap=address_remap,
**kwargs,
)

Expand Down Expand Up @@ -1269,6 +1277,7 @@ def __init__(
lock=None,
dynamic_startup_nodes=True,
connection_pool_class=ConnectionPool,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
**kwargs,
):
self.nodes_cache = {}
Expand All @@ -1280,6 +1289,7 @@ def __init__(
self._require_full_coverage = require_full_coverage
self._dynamic_startup_nodes = dynamic_startup_nodes
self.connection_pool_class = connection_pool_class
self.address_remap = address_remap
self._moved_exception = None
self.connection_kwargs = kwargs
self.read_load_balancer = LoadBalancer()
Expand Down Expand Up @@ -1502,6 +1512,7 @@ def initialize(self):
if host == "":
host = startup_node.host
port = int(primary_node[1])
host, port = self.remap_host_port(host, port)

target_node = self._get_or_create_cluster_node(
host, port, PRIMARY, tmp_nodes_cache
Expand All @@ -1518,6 +1529,7 @@ def initialize(self):
for replica_node in replica_nodes:
host = str_if_bytes(replica_node[0])
port = replica_node[1]
host, port = self.remap_host_port(host, port)

target_replica_node = self._get_or_create_cluster_node(
host, port, REPLICA, tmp_nodes_cache
Expand Down Expand Up @@ -1591,6 +1603,16 @@ def reset(self):
# The read_load_balancer is None, do nothing
pass

def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
"""
Remap the host and port returned from the cluster to a different
internal value. Useful if the client is not connecting directly
to the cluster.
"""
if self.address_remap:
return self.address_remap((host, port))
return host, port


class ClusterPubSub(PubSub):
"""
Expand Down
110 changes: 109 additions & 1 deletion tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from _pytest.fixtures import FixtureRequest

from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster
from redis.asyncio.connection import Connection, SSLConnection
from redis.asyncio.connection import Connection, SSLConnection, async_timeout
from redis.asyncio.parser import CommandsParser
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
Expand Down Expand Up @@ -49,6 +49,71 @@
]


class NodeProxy:
"""A class to proxy a node connection to a different port"""

def __init__(self, addr, redis_addr):
self.addr = addr
self.redis_addr = redis_addr
self.send_event = asyncio.Event()
self.server = None
self.task = None
self.n_connections = 0

async def start(self):
# test that we can connect to redis
async with async_timeout(2):
_, redis_writer = await asyncio.open_connection(*self.redis_addr)
redis_writer.close()
self.server = await asyncio.start_server(
self.handle, *self.addr, reuse_address=True
)
self.task = asyncio.create_task(self.server.serve_forever())

async def handle(self, reader, writer):
# establish connection to redis
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
try:
self.n_connections += 1
pipe1 = asyncio.create_task(self.pipe(reader, redis_writer))
pipe2 = asyncio.create_task(self.pipe(redis_reader, writer))
await asyncio.gather(pipe1, pipe2)
finally:
redis_writer.close()

async def aclose(self):
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
await self.server.wait_closed()

async def pipe(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
):
while True:
data = await reader.read(1000)
if not data:
break
writer.write(data)
await writer.drain()


@pytest.fixture
def redis_addr(request):
redis_url = request.config.getoption("--redis-url")
scheme, netloc = urlparse(redis_url)[:2]
assert scheme == "redis"
if ":" in netloc:
host, port = netloc.split(":")
return host, int(port)
else:
return netloc, 6379


@pytest_asyncio.fixture()
async def slowlog(r: RedisCluster) -> None:
"""
Expand Down Expand Up @@ -809,6 +874,49 @@ async def test_default_node_is_replaced_after_exception(self, r):
# Rollback to the old default node
r.replace_default_node(curr_default_node)

async def test_address_remap(self, create_redis, redis_addr):
"""Test that we can create a rediscluster object with
a host-port remapper and map connections through proxy objects
"""

# we remap the first n nodes
offset = 1000
n = 6
ports = [redis_addr[1] + i for i in range(n)]

def address_remap(address):
# remap first three nodes to our local proxy
# old = host, port
host, port = address
if int(port) in ports:
host, port = "127.0.0.1", int(port) + offset
# print(f"{old} {host, port}")
return host, port

# create the proxies
proxies = [
NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port))
for port in ports
]
await asyncio.gather(*[p.start() for p in proxies])
try:
# create cluster:
r = await create_redis(
cls=RedisCluster, flushdb=False, address_remap=address_remap
)
try:
assert await r.ping() is True
assert await r.set("byte_string", b"giraffe")
assert await r.get("byte_string") == b"giraffe"
finally:
await r.close()
finally:
await asyncio.gather(*[p.aclose() for p in proxies])

# verify that the proxies were indeed used
n_used = sum((1 if p.n_connections else 0) for p in proxies)
assert n_used > 1


class TestClusterRedisCommands:
"""
Expand Down
Loading