Skip to content

Commit

Permalink
Added a replacement for the default cluster node in the event of fail…
Browse files Browse the repository at this point in the history
…ure. Handles failovers better.
  • Loading branch information
barshaul committed Nov 20, 2022
1 parent fa45fb1 commit c466e62
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
* Fixed "cannot pickle '_thread.lock' object" bug (#2354, #2297)
* Added CredentialsProvider class to support password rotation
* Enable Lock for asyncio cluster mode
* Added a replacement for the default cluster node in the event of failure (#2463)

* 4.1.3 (Feb 8, 2022)
* Fix flushdb and flushall (#1926)
Expand Down
39 changes: 28 additions & 11 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def __init__(
reinitialize_steps: int = 5,
cluster_error_retry_attempts: int = 3,
connection_error_retry_attempts: int = 3,
max_connections: int = 2**31,
max_connections: int = 2 ** 31,
# Client related kwargs
db: Union[str, int] = 0,
path: Optional[str] = None,
Expand Down Expand Up @@ -516,35 +516,44 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No

async def _determine_nodes(
self, command: str, *args: Any, node_flag: Optional[str] = None
) -> List["ClusterNode"]:
) -> tuple[list["ClusterNode"], bool]:
"""Determine which nodes should be executed the command on
Returns:
tuple[list[Type[ClusterNode]], bool]:
A tuple containing a list of target nodes and a bool indicating
if the return node was chosen because it is the default node
"""
if not node_flag:
# get the nodes group for this command if it was predefined
node_flag = self.command_flags.get(command)

if node_flag in self.node_flags:
if node_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node]
return [self.nodes_manager.default_node], True
if node_flag == self.__class__.PRIMARIES:
# return all primaries
return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
return self.nodes_manager.get_nodes_by_server_type(PRIMARY), False
if node_flag == self.__class__.REPLICAS:
# return all replicas
return self.nodes_manager.get_nodes_by_server_type(REPLICA)
return self.nodes_manager.get_nodes_by_server_type(REPLICA), False
if node_flag == self.__class__.ALL_NODES:
# return all nodes
return list(self.nodes_manager.nodes_cache.values())
return list(self.nodes_manager.nodes_cache.values()), False
if node_flag == self.__class__.RANDOM:
# return a random node
return [random.choice(list(self.nodes_manager.nodes_cache.values()))]
return [
random.choice(list(self.nodes_manager.nodes_cache.values()))
], False

# get the node that holds the key's slot
return [
self.nodes_manager.get_node_from_slot(
await self._determine_slot(command, *args),
self.read_from_replicas and command in READ_COMMANDS,
)
]
], False

async def _determine_slot(self, command: str, *args: Any) -> int:
if self.command_flags.get(command) == SLOT_ID:
Expand Down Expand Up @@ -641,6 +650,7 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
command = args[0]
target_nodes = []
target_nodes_specified = False
is_default_node = False
retry_attempts = self.cluster_error_retry_attempts

passed_targets = kwargs.pop("target_nodes", None)
Expand All @@ -654,10 +664,13 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
for _ in range(execute_attempts):
if self._initialize:
await self.initialize()
if is_default_node:
# Replace the default cluster node
self.replace_default_node()
try:
if not target_nodes_specified:
# Determine the nodes to execute the command on
target_nodes = await self._determine_nodes(
target_nodes, is_default_node = await self._determine_nodes(
*args, node_flag=passed_targets
)
if not target_nodes:
Expand Down Expand Up @@ -882,7 +895,7 @@ def __init__(
port: Union[str, int],
server_type: Optional[str] = None,
*,
max_connections: int = 2**31,
max_connections: int = 2 ** 31,
connection_class: Type[Connection] = Connection,
**connection_kwargs: Any,
) -> None:
Expand Down Expand Up @@ -1436,12 +1449,13 @@ async def _execute(
]

nodes = {}
is_default_node = False
for cmd in todo:
passed_targets = cmd.kwargs.pop("target_nodes", None)
if passed_targets and not client._is_node_flag(passed_targets):
target_nodes = client._parse_target_nodes(passed_targets)
else:
target_nodes = await client._determine_nodes(
target_nodes, is_default_node = await client._determine_nodes(
*cmd.args, node_flag=passed_targets
)
if not target_nodes:
Expand Down Expand Up @@ -1487,6 +1501,9 @@ async def _execute(
result.args = (msg,) + result.args[1:]
raise result

if is_default_node:
self.replace_default_node()

return [cmd.result for cmd in stack]

def _split_command_across_slots(
Expand Down
61 changes: 50 additions & 11 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,30 @@ class AbstractRedisCluster:

ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError)

def replace_default_node(self, target_node: "ClusterNode" = None) -> None:
"""Replace the default cluster node.
A random cluster node will be chosen if target_node isn't passed, and primaries
will be prioritized. The default node will not be changed if there are no other
nodes in the cluster.
Args:
target_node (ClusterNode, optional): Target node to replace the default
node. Defaults to None.
"""
if target_node:
self.nodes_manager.default_node = target_node
else:
curr_node = self.get_default_node()
primaries = [node for node in self.get_primaries() if node != curr_node]
if primaries:
# Choose a primary if the cluster contains different primaries
self.nodes_manager.default_node = random.choice(primaries)
else:
# Otherwise, hoose a primary if the cluster contains different primaries
replicas = [node for node in self.get_replicas() if node != curr_node]
if replicas:
self.nodes_manager.default_node = random.choice(replicas)


class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
@classmethod
Expand Down Expand Up @@ -811,7 +835,14 @@ def set_response_callback(self, command, callback):
"""Set a custom Response Callback"""
self.cluster_response_callbacks[command] = callback

def _determine_nodes(self, *args, **kwargs):
def _determine_nodes(self, *args, **kwargs) -> tuple[list["ClusterNode"], bool]:
"""Determine which nodes should be executed the command on
Returns:
tuple[list[Type[ClusterNode]], bool]:
A tuple containing a list of target nodes and a bool indicating
if the return node was chosen because it is the default node
"""
command = args[0].upper()
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
command = f"{args[0]} {args[1]}".upper()
Expand All @@ -825,28 +856,28 @@ def _determine_nodes(self, *args, **kwargs):
command_flag = self.command_flags.get(command)
if command_flag == self.__class__.RANDOM:
# return a random node
return [self.get_random_node()]
return [self.get_random_node()], False
elif command_flag == self.__class__.PRIMARIES:
# return all primaries
return self.get_primaries()
return self.get_primaries(), False
elif command_flag == self.__class__.REPLICAS:
# return all replicas
return self.get_replicas()
return self.get_replicas(), False
elif command_flag == self.__class__.ALL_NODES:
# return all nodes
return self.get_nodes()
return self.get_nodes(), False
elif command_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node]
return [self.nodes_manager.default_node], True
elif command in self.__class__.SEARCH_COMMANDS[0]:
return [self.nodes_manager.default_node]
return [self.nodes_manager.default_node], True
else:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
)
return [node]
return [node], False

def _should_reinitialized(self):
# To reinitialize the cluster on every MOVED error,
Expand Down Expand Up @@ -990,6 +1021,7 @@ def execute_command(self, *args, **kwargs):
dict<Any, ClusterNode>
"""
target_nodes_specified = False
is_default_node = False
target_nodes = None
passed_targets = kwargs.pop("target_nodes", None)
if passed_targets is not None and not self._is_nodes_flag(passed_targets):
Expand All @@ -1013,7 +1045,7 @@ def execute_command(self, *args, **kwargs):
res = {}
if not target_nodes_specified:
# Determine the nodes to execute the command on
target_nodes = self._determine_nodes(
target_nodes, is_default_node = self._determine_nodes(
*args, **kwargs, nodes_flag=passed_targets
)
if not target_nodes:
Expand All @@ -1025,6 +1057,9 @@ def execute_command(self, *args, **kwargs):
# Return the processed result
return self._process_result(args[0], res, **kwargs)
except Exception as e:
if is_default_node:
# Replace the default cluster node
self.replace_default_node()
if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
# The nodes and slots cache were reinitialized.
# Try again with the new cluster setup.
Expand Down Expand Up @@ -1883,7 +1918,7 @@ def _send_cluster_commands(
# if we have to run through it again, we only retry
# the commands that failed.
attempt = sorted(stack, key=lambda x: x.position)

is_default_node = False
# build a list of node objects based on node names we need to
nodes = {}

Expand All @@ -1900,7 +1935,7 @@ def _send_cluster_commands(
if passed_targets and not self._is_nodes_flag(passed_targets):
target_nodes = self._parse_target_nodes(passed_targets)
else:
target_nodes = self._determine_nodes(
target_nodes, is_default_node = self._determine_nodes(
*c.args, node_flag=passed_targets
)
if not target_nodes:
Expand All @@ -1926,6 +1961,8 @@ def _send_cluster_commands(
# Connection retries are being handled in the node's
# Retry object. Reinitialize the node -> slot table.
self.nodes_manager.initialize()
if is_default_node:
self.replace_default_node()
raise
nodes[node_name] = NodeCommands(
redis_node.parse_response,
Expand Down Expand Up @@ -2007,6 +2044,8 @@ def _send_cluster_commands(
self.reinitialize_counter += 1
if self._should_reinitialized():
self.nodes_manager.initialize()
if is_default_node:
self.replace_default_node()
for c in attempt:
try:
# send each command individually like we
Expand Down
20 changes: 20 additions & 0 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,26 @@ async def test_can_run_concurrent_commands(self, request: FixtureRequest) -> Non
)
await rc.close()

def test_replace_cluster_node(self, r: RedisCluster) -> None:
prev_default_node = r.get_default_node()
r.replace_default_node()
assert r.get_default_node() != prev_default_node
r.replace_default_node(prev_default_node)
assert r.get_default_node() == prev_default_node

async def test_default_node_is_replaced_after_exception(self, r):
curr_default_node = r.get_default_node()
# CLUSTER NODES command is being executed on the default node
nodes = await r.cluster_nodes()
assert "myself" in nodes.get(curr_default_node.name).get("flags")

# Mock connection error for the default node
mock_node_resp_exc(curr_default_node, ConnectionError("error"))
# Test that the command succeed from a different node
nodes = await r.cluster_nodes()
assert "myself" not in nodes.get(curr_default_node.name).get("flags")
assert r.get_default_node() != curr_default_node


class TestClusterRedisCommands:
"""
Expand Down
23 changes: 23 additions & 0 deletions tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,29 @@ def test_cluster_retry_object(self, r) -> None:
== retry._retries
)

def test_replace_cluster_node(self, r) -> None:
prev_default_node = r.get_default_node()
r.replace_default_node()
assert r.get_default_node() != prev_default_node
r.replace_default_node(prev_default_node)
assert r.get_default_node() == prev_default_node

def test_default_node_is_replaced_after_exception(self, r):
curr_default_node = r.get_default_node()
# CLUSTER NODES command is being executed on the default node
nodes = r.cluster_nodes()
assert "myself" in nodes.get(curr_default_node.name).get("flags")

def raise_connection_error():
raise ConnectionError("error")

# Mock connection error for the default node
mock_node_resp_func(curr_default_node, raise_connection_error)
# Test that the command succeed from a different node
nodes = r.cluster_nodes()
assert "myself" not in nodes.get(curr_default_node.name).get("flags")
assert r.get_default_node() != curr_default_node


@pytest.mark.onlycluster
class TestClusterRedisCommands:
Expand Down

0 comments on commit c466e62

Please sign in to comment.