Skip to content

Commit

Permalink
Fixed CommandsParser relevant bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
barshaul committed Oct 26, 2021
1 parent 37aa5c0 commit 5516db7
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 11 deletions.
4 changes: 3 additions & 1 deletion redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ class RedisCluster(ClusterCommands, DataAccessCommands,
"CLUSTER SLOTS",
"RANDOMKEY",
"COMMAND",
"DEBUG",
],
RANDOM,
),
Expand Down Expand Up @@ -458,7 +459,8 @@ def __init__(
Redis.RESPONSE_CALLBACKS,
self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS))
self.result_callbacks = self.__class__.RESULT_CALLBACKS
self.commands_parser = CommandsParser(self)
self.commands_parser = CommandsParser(self.get_random_node().
redis_connection)

def __enter__(self):
return self
Expand Down
32 changes: 28 additions & 4 deletions redis/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,17 @@


class CommandsParser:
DEFAULT_KEY_POS = 1

def __init__(self, redis_connection):
self.commands = redis_connection.execute_command("COMMAND")
self.initialized = False
self.commands = {}
self.initialize(redis_connection)

def initialize(self, r):
if r is not None:
self.commands = r.execute_command("COMMAND")
self.initialized = True

# As soon as this PR is merged into Redis, we should reimplement
# our logic to use COMMAND INFO changes to determine the key positions
Expand All @@ -27,6 +36,10 @@ def get_keys(self, *args):
# The command has no keys in it
return None

if not self.initialized:
# return the argument in the default position for keys
return [args[__class__.DEFAULT_KEY_POS]]

cmd_name = args[0].lower()
if len(cmd_name.split()) > 1:
# we need to take only the main command, e.g. 'memory' for
Expand All @@ -38,7 +51,13 @@ def get_keys(self, *args):
return None

command = self.commands.get(cmd_name)
if 'movablekeys' not in command['flags']:
if 'movablekeys' not in command['flags'] and 'pubsub' not in \
command['flags']:
if command['step_count'] == 0 and command['first_key_pos'] == 0 \
and command['last_key_pos'] == 0:
# We don't have further info, return the argument in the
# default position for keys
return [args[__class__.DEFAULT_KEY_POS]]
last_key_pos = command['last_key_pos']
if last_key_pos == -1:
last_key_pos = len(args) - 1
Expand Down Expand Up @@ -108,10 +127,12 @@ def get_keys_complex(self, *args):
if 'STOREDIST' in args:
storedist_idx = args.index('STOREDIST')
keys.append(args[storedist_idx + 1])
elif command == 'MEMORY USAGE':
elif command in ['MEMORY USAGE', 'PUBLISH', 'PUBSUB CHANNELS']:
# format example:
# PUBLISH channel message
keys = [args[1]]
elif command == 'MIGRATE':
# format exapmle:
# format example:
# MIGRATE 192.168.1.34 6379 "" 0 5000 KEYS key1 key2 key3
if args[3] == "":
keys_idx = args.index('KEYS')
Expand All @@ -125,6 +146,9 @@ def get_keys_complex(self, *args):
keys = None
else:
keys = list(args[3:5])
elif command in ['SUBSCRIBE', 'PSUBSCRIBE', 'UNSUBSCRIBE',
'PUNSUBSCRIBE', 'PUBSUB NUMSUB']:
keys = list(args[1:])
else:
keys = None
return keys
Expand Down
17 changes: 11 additions & 6 deletions tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def execute_command(*_args, **_kwargs):
if _args[0] == 'CLUSTER SLOTS':
mock_cluster_slots = cluster_slots
return mock_cluster_slots
elif _args[0] == 'COMMAND':
return {'get': {'name': 'get', 'arity': 2, 'flags':
['readonly', 'fast'], 'first_key_pos': 1,
'last_key_pos': 1, 'step_count': 1}}
elif _args[1] == 'cluster-require-full-coverage':
return {'cluster-require-full-coverage': 'yes'}
elif func is not None:
Expand Down Expand Up @@ -132,7 +136,7 @@ def test_startup_nodes(self):
ClusterNode(default_host, port_2)]
cluster = get_mocked_redis_client(startup_nodes=startup_nodes)
assert cluster.get_node(host=default_host, port=port_1) is not None \
and cluster.get_node(host=default_host, port=port_2) is not None
and cluster.get_node(host=default_host, port=port_2) is not None

def test_empty_startup_nodes(self):
"""
Expand Down Expand Up @@ -355,7 +359,7 @@ def test_keyslot(self, r):

def test_get_node_name(self):
assert get_node_name(default_host, default_port) == \
"{0}:{1}".format(default_host, default_port)
"{0}:{1}".format(default_host, default_port)

def test_all_nodes(self, r):
"""
Expand Down Expand Up @@ -398,7 +402,7 @@ def raise_cluster_down_error(target_node, *args, **kwargs):
with pytest.raises(ClusterDownError):
rc.get("bar")
assert execute_command.failed_calls == \
rc.cluster_error_retry_attempts
rc.cluster_error_retry_attempts

@pytest.mark.filterwarnings("ignore:ConnectionError")
def test_connection_error_overreaches_retry_attempts(self):
Expand All @@ -419,7 +423,7 @@ def raise_conn_error(target_node, *args, **kwargs):
with pytest.raises(ConnectionError):
rc.get("bar")
assert execute_command.failed_calls == \
rc.cluster_error_retry_attempts
rc.cluster_error_retry_attempts


@skip_if_not_cluster_mode()
Expand Down Expand Up @@ -469,7 +473,7 @@ def test_pubsub_numsub_merge_results(self, r):
p.subscribe(channel)
# Assert that each node returns that only one client is subscribed
assert node.redis_connection.pubsub_numsub(channel) == \
[(b_channel, 1)]
[(b_channel, 1)]
# Assert that the cluster's pubsub_numsub function returns ALL clients
# subscribed to this channel in the entire cluster
assert r.pubsub_numsub(channel) == [(b_channel, len(nodes))]
Expand Down Expand Up @@ -718,7 +722,8 @@ def execute_command(*args, **kwargs):
['127.0.0.1', 7002, 'node_2'],
]
]

elif args[0] == 'COMMAND':
return {}
elif args[1] == 'cluster-require-full-coverage':
return {'cluster-require-full-coverage': 'yes'}

Expand Down

0 comments on commit 5516db7

Please sign in to comment.