Skip to content

Commit 484861e

Browse files
committed
Changed determine_nodes to return only the target nodes, added a comparison to determine whether a node is the default node instead
1 parent b3ab42a commit 484861e

File tree

2 files changed

+37
-37
lines changed

2 files changed

+37
-37
lines changed

redis/asyncio/cluster.py

+18-20
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
List,
1212
Mapping,
1313
Optional,
14-
Tuple,
1514
Type,
1615
TypeVar,
1716
Union,
@@ -517,44 +516,37 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No
517516

518517
async def _determine_nodes(
519518
self, command: str, *args: Any, node_flag: Optional[str] = None
520-
) -> Tuple[List["ClusterNode"], bool]:
521-
"""Determine which nodes should be executed the command on
522-
523-
Returns:
524-
tuple[list[Type[ClusterNode]], bool]:
525-
A tuple containing a list of target nodes and a bool indicating
526-
if the return node was chosen because it is the default node
527-
"""
519+
) -> List["ClusterNode"]:
520+
# Determine which nodes should be executed the command on.
521+
# Returns a list of target nodes.
528522
if not node_flag:
529523
# get the nodes group for this command if it was predefined
530524
node_flag = self.command_flags.get(command)
531525

532526
if node_flag in self.node_flags:
533527
if node_flag == self.__class__.DEFAULT_NODE:
534528
# return the cluster's default node
535-
return [self.nodes_manager.default_node], True
529+
return [self.nodes_manager.default_node]
536530
if node_flag == self.__class__.PRIMARIES:
537531
# return all primaries
538-
return self.nodes_manager.get_nodes_by_server_type(PRIMARY), False
532+
return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
539533
if node_flag == self.__class__.REPLICAS:
540534
# return all replicas
541-
return self.nodes_manager.get_nodes_by_server_type(REPLICA), False
535+
return self.nodes_manager.get_nodes_by_server_type(REPLICA)
542536
if node_flag == self.__class__.ALL_NODES:
543537
# return all nodes
544-
return list(self.nodes_manager.nodes_cache.values()), False
538+
return list(self.nodes_manager.nodes_cache.values())
545539
if node_flag == self.__class__.RANDOM:
546540
# return a random node
547-
return [
548-
random.choice(list(self.nodes_manager.nodes_cache.values()))
549-
], False
541+
return [random.choice(list(self.nodes_manager.nodes_cache.values()))]
550542

551543
# get the node that holds the key's slot
552544
return [
553545
self.nodes_manager.get_node_from_slot(
554546
await self._determine_slot(command, *args),
555547
self.read_from_replicas and command in READ_COMMANDS,
556548
)
557-
], False
549+
]
558550

559551
async def _determine_slot(self, command: str, *args: Any) -> int:
560552
if self.command_flags.get(command) == SLOT_ID:
@@ -671,13 +663,18 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
671663
try:
672664
if not target_nodes_specified:
673665
# Determine the nodes to execute the command on
674-
target_nodes, is_default_node = await self._determine_nodes(
666+
target_nodes = await self._determine_nodes(
675667
*args, node_flag=passed_targets
676668
)
677669
if not target_nodes:
678670
raise RedisClusterException(
679671
f"No targets were found to execute {args} command on"
680672
)
673+
if (
674+
len(target_nodes) == 1
675+
and target_nodes[0] == self.get_default_node()
676+
):
677+
is_default_node = True
681678

682679
if len(target_nodes) == 1:
683680
# Return the processed result
@@ -1456,7 +1453,7 @@ async def _execute(
14561453
if passed_targets and not client._is_node_flag(passed_targets):
14571454
target_nodes = client._parse_target_nodes(passed_targets)
14581455
else:
1459-
target_nodes, is_default_node = await client._determine_nodes(
1456+
target_nodes = await client._determine_nodes(
14601457
*cmd.args, node_flag=passed_targets
14611458
)
14621459
if not target_nodes:
@@ -1465,8 +1462,9 @@ async def _execute(
14651462
)
14661463
if len(target_nodes) > 1:
14671464
raise RedisClusterException(f"Too many targets for command {cmd.args}")
1468-
14691465
node = target_nodes[0]
1466+
if node == client.get_default_node():
1467+
is_default_node = True
14701468
if node.name not in nodes:
14711469
nodes[node.name] = (node, [])
14721470
nodes[node.name][1].append(cmd)

redis/cluster.py

+19-17
Original file line numberDiff line numberDiff line change
@@ -835,14 +835,9 @@ def set_response_callback(self, command, callback):
835835
"""Set a custom Response Callback"""
836836
self.cluster_response_callbacks[command] = callback
837837

838-
def _determine_nodes(self, *args, **kwargs) -> Tuple[List["ClusterNode"], bool]:
839-
"""Determine which nodes should be executed the command on
840-
841-
Returns:
842-
tuple[list[Type[ClusterNode]], bool]:
843-
A tuple containing a list of target nodes and a bool indicating
844-
if the return node was chosen because it is the default node
845-
"""
838+
def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
839+
# Determine which nodes should be executed the command on.
840+
# Returns a list of target nodes.
846841
command = args[0].upper()
847842
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
848843
command = f"{args[0]} {args[1]}".upper()
@@ -856,28 +851,28 @@ def _determine_nodes(self, *args, **kwargs) -> Tuple[List["ClusterNode"], bool]:
856851
command_flag = self.command_flags.get(command)
857852
if command_flag == self.__class__.RANDOM:
858853
# return a random node
859-
return [self.get_random_node()], False
854+
return [self.get_random_node()]
860855
elif command_flag == self.__class__.PRIMARIES:
861856
# return all primaries
862-
return self.get_primaries(), False
857+
return self.get_primaries()
863858
elif command_flag == self.__class__.REPLICAS:
864859
# return all replicas
865-
return self.get_replicas(), False
860+
return self.get_replicas()
866861
elif command_flag == self.__class__.ALL_NODES:
867862
# return all nodes
868-
return self.get_nodes(), False
863+
return self.get_nodes()
869864
elif command_flag == self.__class__.DEFAULT_NODE:
870865
# return the cluster's default node
871-
return [self.nodes_manager.default_node], True
866+
return [self.nodes_manager.default_node]
872867
elif command in self.__class__.SEARCH_COMMANDS[0]:
873-
return [self.nodes_manager.default_node], True
868+
return [self.nodes_manager.default_node]
874869
else:
875870
# get the node that holds the key's slot
876871
slot = self.determine_slot(*args)
877872
node = self.nodes_manager.get_node_from_slot(
878873
slot, self.read_from_replicas and command in READ_COMMANDS
879874
)
880-
return [node], False
875+
return [node]
881876

882877
def _should_reinitialized(self):
883878
# To reinitialize the cluster on every MOVED error,
@@ -1045,13 +1040,18 @@ def execute_command(self, *args, **kwargs):
10451040
res = {}
10461041
if not target_nodes_specified:
10471042
# Determine the nodes to execute the command on
1048-
target_nodes, is_default_node = self._determine_nodes(
1043+
target_nodes = self._determine_nodes(
10491044
*args, **kwargs, nodes_flag=passed_targets
10501045
)
10511046
if not target_nodes:
10521047
raise RedisClusterException(
10531048
f"No targets were found to execute {args} command on"
10541049
)
1050+
if (
1051+
len(target_nodes) == 1
1052+
and target_nodes[0] == self.get_default_node()
1053+
):
1054+
is_default_node = True
10551055
for node in target_nodes:
10561056
res[node.name] = self._execute_command(node, *args, **kwargs)
10571057
# Return the processed result
@@ -1935,7 +1935,7 @@ def _send_cluster_commands(
19351935
if passed_targets and not self._is_nodes_flag(passed_targets):
19361936
target_nodes = self._parse_target_nodes(passed_targets)
19371937
else:
1938-
target_nodes, is_default_node = self._determine_nodes(
1938+
target_nodes = self._determine_nodes(
19391939
*c.args, node_flag=passed_targets
19401940
)
19411941
if not target_nodes:
@@ -1948,6 +1948,8 @@ def _send_cluster_commands(
19481948
)
19491949

19501950
node = target_nodes[0]
1951+
if node == self.get_default_node():
1952+
is_default_node = True
19511953

19521954
# now that we know the name of the node
19531955
# ( it's just a string in the form of host:port )

0 commit comments

Comments
 (0)