Merge existing changes to the forked version (redis#1)
* [GROW-2938] do not reset redis_connection on an error

* [GROW-2938] add backoff to more errors

* [GROW-2938] recover from SlotNotCoveredError

* [GROW-2938] prevent get_node_from_slot from failing due to concurrent cluster slots refresh

* [GROW-2938] add retry to ClusterPipeline
zach-iee authored Jun 23, 2023
1 parent 2bb7f10 commit 63e06dd
Showing 2 changed files with 120 additions and 84 deletions.
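Taken together, the bullets above amount to: keep node connections alive across topology refreshes, back off before cluster-level retries, treat SlotNotCoveredError as recoverable, and let pipelines share the client's retry policy. A minimal usage sketch of the retry surface this commit touches (host/port are placeholders; Retry and the backoff classes are redis-py's public API):

```python
from redis.backoff import ExponentialBackoff
from redis.cluster import RedisCluster
from redis.retry import Retry

# Placeholder address; any reachable cluster node works.
rc = RedisCluster(
    host="localhost",
    port=6379,
    retry=Retry(ExponentialBackoff(), 3),  # per-connection retry policy
    cluster_error_retry_attempts=3,        # passes over cluster-level errors
)
rc.set("foo", "bar")
```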
117 changes: 73 additions & 44 deletions redis/cluster.py
@@ -591,7 +591,8 @@ def __init__(
self.retry = retry
kwargs.update({"retry": self.retry})
else:
kwargs.update({"retry": Retry(default_backoff(), 0)})
self.retry = Retry(default_backoff(), 0)
kwargs["retry"] = self.retry

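If no retry is supplied, the hunk above now keeps the default on `self.retry` rather than only pushing it into `kwargs`, so the same object can later be handed to pipelines. A sketch of the equivalent default (`default_backoff()` is redis-py's stock backoff; zero means no connection-level retries):

```python
from redis.backoff import default_backoff
from redis.retry import Retry

# Equivalent of the default the client now stores on self.retry:
default_retry = Retry(default_backoff(), 0)  # backoff strategy, 0 retries
```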
self.encoder = Encoder(
kwargs.get("encoding", "utf-8"),
@@ -775,6 +776,7 @@ def pipeline(self, transaction=None, shard_hint=None):
read_from_replicas=self.read_from_replicas,
reinitialize_steps=self.reinitialize_steps,
lock=self._lock,
retry=self.retry,
)

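With `retry=self.retry` forwarded here, a pipeline created from the client shares its retry policy. A small sketch of the expected behavior, assuming a reachable cluster at the placeholder address:

```python
from redis.backoff import ConstantBackoff
from redis.cluster import RedisCluster
from redis.retry import Retry

rc = RedisCluster(host="localhost", port=6379, retry=Retry(ConstantBackoff(1), 3))
pipe = rc.pipeline()
assert pipe.retry is rc.retry  # pipeline now shares the client's Retry object
```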
def lock(
@@ -858,41 +860,49 @@ def set_response_callback(self, command, callback):
def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
# Determine which nodes the command should be executed on.
# Returns a list of target nodes.
command = args[0].upper()
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
command = f"{args[0]} {args[1]}".upper()

nodes_flag = kwargs.pop("nodes_flag", None)
if nodes_flag is not None:
# nodes flag passed by the user
command_flag = nodes_flag
else:
# get the nodes group for this command if it was predefined
command_flag = self.command_flags.get(command)
if command_flag == self.__class__.RANDOM:
# return a random node
return [self.get_random_node()]
elif command_flag == self.__class__.PRIMARIES:
# return all primaries
return self.get_primaries()
elif command_flag == self.__class__.REPLICAS:
# return all replicas
return self.get_replicas()
elif command_flag == self.__class__.ALL_NODES:
# return all nodes
return self.get_nodes()
elif command_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node]
elif command in self.__class__.SEARCH_COMMANDS[0]:
return [self.nodes_manager.default_node]
else:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
)
return [node]
try:
command = args[0].upper()
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
command = f"{args[0]} {args[1]}".upper()

nodes_flag = kwargs.pop("nodes_flag", None)
if nodes_flag is not None:
# nodes flag passed by the user
command_flag = nodes_flag
else:
# get the nodes group for this command if it was predefined
command_flag = self.command_flags.get(command)
if command_flag == self.__class__.RANDOM:
# return a random node
return [self.get_random_node()]
elif command_flag == self.__class__.PRIMARIES:
# return all primaries
return self.get_primaries()
elif command_flag == self.__class__.REPLICAS:
# return all replicas
return self.get_replicas()
elif command_flag == self.__class__.ALL_NODES:
# return all nodes
return self.get_nodes()
elif command_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node]
elif command in self.__class__.SEARCH_COMMANDS[0]:
return [self.nodes_manager.default_node]
else:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
)
return [node]
except SlotNotCoveredError as e:
self.reinitialize_counter += 1
if self._should_reinitialized():
self.nodes_manager.initialize()
# Reset the counter
self.reinitialize_counter = 0
raise e

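The new `except` block treats SlotNotCoveredError the way MOVED redirects are treated: count occurrences, rebuild the topology on every `reinitialize_steps`-th one, and re-raise so an outer retry can run. A standalone sketch of that counter-gated pattern (class and helper names here are illustrative, not redis-py's):

```python
class SlotNotCoveredError(Exception):
    pass

class SlotTableClient:
    """Illustrative stand-in for the nodes-manager bookkeeping."""

    def __init__(self, reinitialize_steps=5):
        self.reinitialize_steps = reinitialize_steps
        self.reinitialize_counter = 0
        self.slot_table = {}

    def refresh_topology(self):
        pass  # placeholder: would re-query CLUSTER SLOTS and rebuild slot_table

    def node_for_slot(self, slot):
        try:
            return self.slot_table[slot]
        except KeyError:
            self.reinitialize_counter += 1
            if self.reinitialize_counter % self.reinitialize_steps == 0:
                self.refresh_topology()
                self.reinitialize_counter = 0  # reset after a rebuild
            raise SlotNotCoveredError(f"slot {slot} is not covered")
```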
def _should_reinitialized(self):
# To reinitialize the cluster on every MOVED error,
@@ -1084,6 +1094,12 @@ def execute_command(self, *args, **kwargs):
# The nodes and slots cache were reinitialized.
# Try again with the new cluster setup.
retry_attempts -= 1
if self.retry and isinstance(e, self.retry._supported_errors):
backoff = self.retry._backoff.compute(
self.cluster_error_retry_attempts - retry_attempts
)
if backoff > 0:
time.sleep(backoff)
continue
else:
# raise the exception
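The added lines sleep between cluster-level attempts instead of retrying immediately, feeding the number of attempts made so far to the backoff policy. A simplified standalone sketch (the `_backoff`/`_supported_errors` attributes are the same private Retry internals the hunk itself reaches into):

```python
import time

def execute_with_cluster_retries(do, retry, cluster_error_retry_attempts=3):
    retry_attempts = cluster_error_retry_attempts
    while True:
        try:
            return do()
        except Exception as e:  # redis-py narrows this to specific error types
            if retry_attempts <= 0:
                raise
            retry_attempts -= 1
            if retry and isinstance(e, retry._supported_errors):
                # attempts made so far -> backoff delay
                backoff = retry._backoff.compute(
                    cluster_error_retry_attempts - retry_attempts
                )
                if backoff > 0:
                    time.sleep(backoff)
```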
@@ -1143,8 +1159,6 @@ def _execute_command(self, target_node, *args, **kwargs):
# Remove the failed node from the startup nodes before we try
# to reinitialize the cluster
self.nodes_manager.startup_nodes.pop(target_node.name, None)
# Reset the cluster node's connection
target_node.redis_connection = None
self.nodes_manager.initialize()
raise e
except MovedError as e:
@@ -1164,6 +1178,13 @@
else:
self.nodes_manager.update_moved_exception(e)
moved = True
except SlotNotCoveredError as e:
self.reinitialize_counter += 1
if self._should_reinitialized():
self.nodes_manager.initialize()
# Reset the counter
self.reinitialize_counter = 0
raise e
except TryAgainError:
if ttl < self.RedisClusterRequestTTL / 2:
time.sleep(0.05)
@@ -1397,7 +1418,10 @@ def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None):
# randomly choose one of the replicas
node_idx = random.randint(1, len(self.slots_cache[slot]) - 1)

return self.slots_cache[slot][node_idx]
try:
return self.slots_cache[slot][node_idx]
except IndexError:
return self.slots_cache[slot][0]

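The IndexError fallback covers a race: `node_idx` is chosen from the replica count observed at call time, but a concurrent slots refresh may shrink `slots_cache[slot]` before the lookup. A toy illustration with hypothetical values:

```python
slots_cache = {12182: ["primary", "replica-1", "replica-2"]}

node_idx = 2                      # picked while three nodes covered the slot
slots_cache[12182] = ["primary"]  # concurrent refresh shrinks the entry

try:
    node = slots_cache[12182][node_idx]
except IndexError:
    node = slots_cache[12182][0]  # fall back to the primary at index 0
```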
def get_nodes_by_server_type(self, server_type):
"""
@@ -1774,6 +1798,7 @@ def __init__(
cluster_error_retry_attempts: int = 3,
reinitialize_steps: int = 5,
lock=None,
retry: Optional["Retry"] = None,
**kwargs,
):
""" """
@@ -1799,6 +1824,7 @@
if lock is None:
lock = threading.Lock()
self._lock = lock
self.retry = retry

def __repr__(self):
""" """
@@ -1931,8 +1957,9 @@ def send_cluster_commands(
stack,
raise_on_error=raise_on_error,
allow_redirections=allow_redirections,
attempts_count=self.cluster_error_retry_attempts - retry_attempts,
)
except (ClusterDownError, ConnectionError) as e:
except (ClusterDownError, ConnectionError, TimeoutError) as e:
if retry_attempts > 0:
# Try again with the new cluster setup. All other errors
# should be raised.
@@ -1942,7 +1969,7 @@
raise e

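`send_cluster_commands` now tells each pass how many attempts came before it, and TimeoutError joins the retried errors. A simplified sketch of that caller loop (names are illustrative; builtin exception types stand in for redis-py's):

```python
def send_with_retries(send_once, cluster_error_retry_attempts=3):
    retry_attempts = cluster_error_retry_attempts
    while True:
        try:
            # first pass runs with attempts_count=0, the next with 1, ...
            return send_once(
                attempts_count=cluster_error_retry_attempts - retry_attempts
            )
        except (ConnectionError, TimeoutError):
            if retry_attempts > 0:
                retry_attempts -= 1  # topology is rebuilt, then try again
            else:
                raise
```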
def _send_cluster_commands(
self, stack, raise_on_error=True, allow_redirections=True
self, stack, raise_on_error=True, allow_redirections=True, attempts_count=0
):
"""
Send a bunch of cluster commands to the redis cluster.
@@ -1997,9 +2024,11 @@ def _send_cluster_commands(
redis_node = self.get_redis_connection(node)
try:
connection = get_connection(redis_node, c.args)
except ConnectionError:
# Connection retries are being handled in the node's
# Retry object. Reinitialize the node -> slot table.
except (ConnectionError, TimeoutError) as e:
if self.retry and isinstance(e, self.retry._supported_errors):
backoff = self.retry._backoff.compute(attempts_count)
if backoff > 0:
time.sleep(backoff)
self.nodes_manager.initialize()
if is_default_node:
self.replace_default_node()
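Inside `_send_cluster_commands`, a failed connection acquisition now backs off (scaled by `attempts_count`) before the node-to-slot table is rebuilt, instead of reinitializing immediately. A condensed sketch of that path (the helper callables are placeholders):

```python
import time

def acquire_with_backoff(get_connection, retry, attempts_count, reinitialize):
    try:
        return get_connection()
    except (ConnectionError, TimeoutError) as e:
        if retry and isinstance(e, retry._supported_errors):
            backoff = retry._backoff.compute(attempts_count)
            if backoff > 0:
                time.sleep(backoff)  # wait out the disruption before refreshing
        reinitialize()  # rebuild the node -> slot mapping
        raise           # let the outer pass-level retry take over
```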
87 changes: 47 additions & 40 deletions tests/test_cluster.py
@@ -4,6 +4,7 @@
import socket
import socketserver
import threading
import uuid
import warnings
from queue import LifoQueue, Queue
from time import sleep
@@ -12,7 +13,12 @@
import pytest

from redis import Redis
from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
from redis.backoff import (
ConstantBackoff,
ExponentialBackoff,
NoBackoff,
default_backoff,
)
from redis.cluster import (
PRIMARY,
REDIS_CLUSTER_HASH_SLOTS,
@@ -35,6 +41,7 @@
RedisClusterException,
RedisError,
ResponseError,
SlotNotCoveredError,
TimeoutError,
)
from redis.retry import Retry
Expand Down Expand Up @@ -788,45 +795,6 @@ def test_not_require_full_coverage_cluster_down_error(self, r):
else:
raise e

def test_timeout_error_topology_refresh_reuse_connections(self, r):
"""
By mocking TIMEOUT errors, we'll force the cluster topology to be reinitialized,
and then ensure that only the impacted connection is replaced
"""
node = r.get_node_from_key("key")
r.set("key", "value")
node_conn_origin = {}
for n in r.get_nodes():
node_conn_origin[n.name] = n.redis_connection
real_func = r.get_redis_connection(node).parse_response

class counter:
def __init__(self, val=0):
self.val = int(val)

count = counter(0)
with patch.object(Redis, "parse_response") as parse_response:

def moved_redirect_effect(connection, *args, **options):
# raise a timeout for 5 times so we'll need to reinitialize the topology
if count.val == 4:
parse_response.side_effect = real_func
count.val += 1
raise TimeoutError()

parse_response.side_effect = moved_redirect_effect
assert r.get("key") == b"value"
for node_name, conn in node_conn_origin.items():
if node_name == node.name:
# The old redis connection of the timed out node should have been
# deleted and replaced
assert conn != r.get_redis_connection(node)
else:
# other nodes' redis connection should have been reused during the
# topology refresh
cur_node = r.get_node(node_name=node_name)
assert conn == r.get_redis_connection(cur_node)

def test_cluster_get_set_retry_object(self, request):
retry = Retry(NoBackoff(), 2)
r = _get_client(RedisCluster, request, retry=retry)
@@ -939,6 +907,45 @@ def address_remap(address):
n_used = sum((1 if p.n_connections else 0) for p in proxies)
assert n_used > 1

@pytest.mark.parametrize("error", [ConnectionError, TimeoutError])
def test_additional_backoff_redis_cluster(self, error):
with patch.object(ConstantBackoff, "compute") as compute:

def _compute(target_node, *args, **kwargs):
return 1

compute.side_effect = _compute
with patch.object(RedisCluster, "_execute_command") as execute_command:

def raise_error(target_node, *args, **kwargs):
execute_command.failed_calls += 1
raise error("mocked error")

execute_command.side_effect = raise_error

rc = get_mocked_redis_client(
host=default_host,
port=default_port,
retry=Retry(ConstantBackoff(1), 3),
)

with pytest.raises(error):
rc.get("bar")
assert compute.call_count == rc.cluster_error_retry_attempts

@pytest.mark.parametrize("reinitialize_steps", [2, 10, 99])
def test_recover_slot_not_covered_error(self, request, reinitialize_steps):
rc = _get_client(RedisCluster, request, reinitialize_steps=reinitialize_steps)
key = uuid.uuid4().hex

rc.nodes_manager.slots_cache[rc.keyslot(key)] = []

for _ in range(0, reinitialize_steps):
with pytest.raises(SlotNotCoveredError):
rc.get(key)

rc.get(key)


@pytest.mark.onlycluster
class TestClusterRedisCommands:
