Skip to content

Commit 2c12155

Browse files
authored
Added a replacement for the default cluster node in the event of failure. (#2463)
1 parent f4d07dd commit 2c12155

File tree

5 files changed

+128
-3
lines changed

5 files changed

+128
-3
lines changed

CHANGES

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
* Added CredentialsProvider class to support password rotation
3030
* Enable Lock for asyncio cluster mode
3131
* Fix Sentinel.execute_command doesn't execute across the entire sentinel cluster bug (#2458)
32+
* Added a replacement for the default cluster node in the event of failure (#2463)
3233

3334
* 4.1.3 (Feb 8, 2022)
3435
* Fix flushdb and flushall (#1926)

redis/asyncio/cluster.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,8 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No
517517
async def _determine_nodes(
518518
self, command: str, *args: Any, node_flag: Optional[str] = None
519519
) -> List["ClusterNode"]:
520+
# Determine which nodes the command should be executed on.
521+
# Returns a list of target nodes.
520522
if not node_flag:
521523
# get the nodes group for this command if it was predefined
522524
node_flag = self.command_flags.get(command)
@@ -654,6 +656,12 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
654656
for _ in range(execute_attempts):
655657
if self._initialize:
656658
await self.initialize()
659+
if (
660+
len(target_nodes) == 1
661+
and target_nodes[0] == self.get_default_node()
662+
):
663+
# Replace the default cluster node
664+
self.replace_default_node()
657665
try:
658666
if not target_nodes_specified:
659667
# Determine the nodes to execute the command on
@@ -1450,7 +1458,6 @@ async def _execute(
14501458
)
14511459
if len(target_nodes) > 1:
14521460
raise RedisClusterException(f"Too many targets for command {cmd.args}")
1453-
14541461
node = target_nodes[0]
14551462
if node.name not in nodes:
14561463
nodes[node.name] = (node, [])
@@ -1487,6 +1494,19 @@ async def _execute(
14871494
result.args = (msg,) + result.args[1:]
14881495
raise result
14891496

1497+
default_node = nodes.get(client.get_default_node().name)
1498+
if default_node is not None:
1499+
# This pipeline execution used the default node, check if we need
1500+
# to replace it.
1501+
# Note: when the error is raised we'll reset the default node in the
1502+
# caller function.
1503+
for cmd in default_node[1]:
1504+
# Check if it has a command that failed with a relevant
1505+
# exception
1506+
if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY:
1507+
client.replace_default_node()
1508+
break
1509+
14901510
return [cmd.result for cmd in stack]
14911511

14921512
def _split_command_across_slots(

redis/cluster.py

+43-2
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,30 @@ class AbstractRedisCluster:
379379

380380
ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError)
381381

382+
def replace_default_node(self, target_node: "ClusterNode" = None) -> None:
    """Replace the default cluster node.

    When ``target_node`` is passed it becomes the new default node.
    Otherwise a random node other than the current default is chosen:
    primaries are prioritized, and replicas are used only when no other
    primary exists. The default node is left unchanged if there are no
    other nodes in the cluster.

    Args:
        target_node (ClusterNode, optional): Target node to replace the
            default node. Defaults to None, meaning a node is picked at
            random.
    """
    # Compare against None explicitly so "not passed" is the only case
    # that triggers the random selection.
    if target_node is not None:
        self.nodes_manager.default_node = target_node
        return

    curr_node = self.get_default_node()
    # Choose a primary if the cluster contains a different primary.
    candidates = [node for node in self.get_primaries() if node != curr_node]
    if not candidates:
        # Otherwise, fall back to a replica if the cluster contains one
        # that differs from the current default node.
        candidates = [node for node in self.get_replicas() if node != curr_node]
    if candidates:
        self.nodes_manager.default_node = random.choice(candidates)
405+
382406

383407
class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
384408
@classmethod
@@ -811,7 +835,9 @@ def set_response_callback(self, command, callback):
811835
"""Set a custom Response Callback"""
812836
self.cluster_response_callbacks[command] = callback
813837

814-
def _determine_nodes(self, *args, **kwargs):
838+
def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
839+
# Determine which nodes the command should be executed on.
840+
# Returns a list of target nodes.
815841
command = args[0].upper()
816842
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
817843
command = f"{args[0]} {args[1]}".upper()
@@ -990,6 +1016,7 @@ def execute_command(self, *args, **kwargs):
9901016
dict<Any, ClusterNode>
9911017
"""
9921018
target_nodes_specified = False
1019+
is_default_node = False
9931020
target_nodes = None
9941021
passed_targets = kwargs.pop("target_nodes", None)
9951022
if passed_targets is not None and not self._is_nodes_flag(passed_targets):
@@ -1020,12 +1047,20 @@ def execute_command(self, *args, **kwargs):
10201047
raise RedisClusterException(
10211048
f"No targets were found to execute {args} command on"
10221049
)
1050+
if (
1051+
len(target_nodes) == 1
1052+
and target_nodes[0] == self.get_default_node()
1053+
):
1054+
is_default_node = True
10231055
for node in target_nodes:
10241056
res[node.name] = self._execute_command(node, *args, **kwargs)
10251057
# Return the processed result
10261058
return self._process_result(args[0], res, **kwargs)
10271059
except Exception as e:
10281060
if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
1061+
if is_default_node:
1062+
# Replace the default cluster node
1063+
self.replace_default_node()
10291064
# The nodes and slots cache were reinitialized.
10301065
# Try again with the new cluster setup.
10311066
retry_attempts -= 1
@@ -1883,7 +1918,7 @@ def _send_cluster_commands(
18831918
# if we have to run through it again, we only retry
18841919
# the commands that failed.
18851920
attempt = sorted(stack, key=lambda x: x.position)
1886-
1921+
is_default_node = False
18871922
# build a list of node objects based on node names we need to
18881923
nodes = {}
18891924

@@ -1913,6 +1948,8 @@ def _send_cluster_commands(
19131948
)
19141949

19151950
node = target_nodes[0]
1951+
if node == self.get_default_node():
1952+
is_default_node = True
19161953

19171954
# now that we know the name of the node
19181955
# ( it's just a string in the form of host:port )
@@ -1926,6 +1963,8 @@ def _send_cluster_commands(
19261963
# Connection retries are being handled in the node's
19271964
# Retry object. Reinitialize the node -> slot table.
19281965
self.nodes_manager.initialize()
1966+
if is_default_node:
1967+
self.replace_default_node()
19291968
raise
19301969
nodes[node_name] = NodeCommands(
19311970
redis_node.parse_response,
@@ -2007,6 +2046,8 @@ def _send_cluster_commands(
20072046
self.reinitialize_counter += 1
20082047
if self._should_reinitialized():
20092048
self.nodes_manager.initialize()
2049+
if is_default_node:
2050+
self.replace_default_node()
20102051
for c in attempt:
20112052
try:
20122053
# send each command individually like we

tests/test_asyncio/test_cluster.py

+40
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,27 @@ async def test_can_run_concurrent_commands(self, request: FixtureRequest) -> Non
788788
)
789789
await rc.close()
790790

791+
def test_replace_cluster_node(self, r: RedisCluster) -> None:
792+
prev_default_node = r.get_default_node()
793+
r.replace_default_node()
794+
assert r.get_default_node() != prev_default_node
795+
r.replace_default_node(prev_default_node)
796+
assert r.get_default_node() == prev_default_node
797+
798+
async def test_default_node_is_replaced_after_exception(self, r):
799+
curr_default_node = r.get_default_node()
800+
# CLUSTER NODES command is being executed on the default node
801+
nodes = await r.cluster_nodes()
802+
assert "myself" in nodes.get(curr_default_node.name).get("flags")
803+
# Mock connection error for the default node
804+
mock_node_resp_exc(curr_default_node, ConnectionError("error"))
805+
# Test that the command succeeds on a different node
806+
nodes = await r.cluster_nodes()
807+
assert "myself" not in nodes.get(curr_default_node.name).get("flags")
808+
assert r.get_default_node() != curr_default_node
809+
# Rollback to the old default node
810+
r.replace_default_node(curr_default_node)
811+
791812

792813
class TestClusterRedisCommands:
793814
"""
@@ -2591,6 +2612,25 @@ async def test_can_run_concurrent_pipelines(self, r: RedisCluster) -> None:
25912612
*(self.test_multi_key_operation_with_multi_slots(r) for i in range(100)),
25922613
)
25932614

2615+
@pytest.mark.onlycluster
2616+
async def test_pipeline_with_default_node_error_command(self, create_redis):
2617+
"""
2618+
Test that the default node is being replaced when it raises a relevant exception
2619+
"""
2620+
r = await create_redis(cls=RedisCluster, flushdb=False)
2621+
curr_default_node = r.get_default_node()
2622+
err = ConnectionError("error")
2623+
cmd_count = await r.command_count()
2624+
mock_node_resp_exc(curr_default_node, err)
2625+
async with r.pipeline(transaction=False) as pipe:
2626+
pipe.command_count()
2627+
result = await pipe.execute(raise_on_error=False)
2628+
assert result[0] == err
2629+
assert r.get_default_node() != curr_default_node
2630+
pipe.command_count()
2631+
result = await pipe.execute(raise_on_error=False)
2632+
assert result[0] == cmd_count
2633+
25942634

25952635
@pytest.mark.ssl
25962636
class TestSSL:

tests/test_cluster.py

+23
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,29 @@ def test_cluster_retry_object(self, r) -> None:
791791
== retry._retries
792792
)
793793

794+
def test_replace_cluster_node(self, r) -> None:
795+
prev_default_node = r.get_default_node()
796+
r.replace_default_node()
797+
assert r.get_default_node() != prev_default_node
798+
r.replace_default_node(prev_default_node)
799+
assert r.get_default_node() == prev_default_node
800+
801+
def test_default_node_is_replaced_after_exception(self, r):
802+
curr_default_node = r.get_default_node()
803+
# CLUSTER NODES command is being executed on the default node
804+
nodes = r.cluster_nodes()
805+
assert "myself" in nodes.get(curr_default_node.name).get("flags")
806+
807+
def raise_connection_error():
808+
raise ConnectionError("error")
809+
810+
# Mock connection error for the default node
811+
mock_node_resp_func(curr_default_node, raise_connection_error)
812+
# Test that the command succeeds on a different node
813+
nodes = r.cluster_nodes()
814+
assert "myself" not in nodes.get(curr_default_node.name).get("flags")
815+
assert r.get_default_node() != curr_default_node
816+
794817

795818
@pytest.mark.onlycluster
796819
class TestClusterRedisCommands:

0 commit comments

Comments
 (0)