From d7d433641f3d1ea0dca2794aef064a7d3e28514e Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Mon, 27 Jun 2022 16:10:36 +0530 Subject: [PATCH] commands/cluster: use pipeline to execute split commands (#2230) - allow passing target_nodes to pipeline commands - move READ_COMMANDS to commands/cluster to avoid import cycle - add types to list_or_args --- redis/asyncio/cluster.py | 17 ++-- redis/cluster.py | 65 +++--------- redis/commands/__init__.py | 13 +-- redis/commands/cluster.py | 204 +++++++++++++++++++++---------------- redis/commands/helpers.py | 5 +- 5 files changed, 151 insertions(+), 153 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index a7bea3029a..2894004403 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -23,7 +23,6 @@ from redis.cluster import ( PIPELINE_BLOCKED_COMMANDS, PRIMARY, - READ_COMMANDS, REPLICA, SLOT_ID, AbstractRedisCluster, @@ -32,7 +31,7 @@ get_node_name, parse_cluster_slots, ) -from redis.commands import AsyncRedisClusterCommands +from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( AskError, @@ -1350,11 +1349,17 @@ async def _execute( nodes = {} for cmd in todo: - target_nodes = await client._determine_nodes(*cmd.args) - if not target_nodes: - raise RedisClusterException( - f"No targets were found to execute {cmd.args} command on" + passed_targets = cmd.kwargs.pop("target_nodes", None) + if passed_targets and not client._is_node_flag(passed_targets): + target_nodes = client._parse_target_nodes(passed_targets) + else: + target_nodes = await client._determine_nodes( + *cmd.args, node_flag=passed_targets ) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {cmd.args} command on" + ) if len(target_nodes) > 1: raise RedisClusterException(f"Too many targets for command {cmd.args}") diff --git a/redis/cluster.py b/redis/cluster.py index b301cdc4e0..6034e9606f 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, Tuple from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan -from redis.commands import CommandsParser, RedisClusterCommands +from redis.commands import READ_COMMANDS, CommandsParser, RedisClusterCommands from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( @@ -154,52 +154,6 @@ def parse_cluster_shards(resp, **options): ) KWARGS_DISABLED_KEYS = ("host", "port") -# Not complete, but covers the major ones -# https://redis.io/commands -READ_COMMANDS = frozenset( - [ - "BITCOUNT", - "BITPOS", - "EXISTS", - "GEODIST", - "GEOHASH", - "GEOPOS", - "GEORADIUS", - "GEORADIUSBYMEMBER", - "GET", - "GETBIT", - "GETRANGE", - "HEXISTS", - "HGET", - "HGETALL", - "HKEYS", - "HLEN", - "HMGET", - "HSTRLEN", - "HVALS", - "KEYS", - "LINDEX", - "LLEN", - "LRANGE", - "MGET", - "PTTL", - "RANDOMKEY", - "SCARD", - "SDIFF", - "SINTER", - "SISMEMBER", - "SMEMBERS", - "SRANDMEMBER", - "STRLEN", - "SUNION", - "TTL", - "ZCARD", - "ZCOUNT", - "ZRANGE", - "ZSCORE", - ] -) - def cleanup_kwargs(**kwargs): """ @@ -1993,14 +1947,25 @@ def _send_cluster_commands( # refer to our internal node -> slot table that # tells us where a given # command should route to. - node = self._determine_nodes(*c.args) + passed_targets = c.options.pop("target_nodes", None) + if passed_targets and not self._is_nodes_flag(passed_targets): + target_nodes = self._parse_target_nodes(passed_targets) + else: + target_nodes = self._determine_nodes(*c.args, node_flag=passed_targets) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {c.args} command on" + ) + if len(target_nodes) > 1: + raise RedisClusterException(f"Too many targets for command {c.args}") + node = target_nodes[0] # now that we know the name of the node # ( it's just a string in the form of host:port ) # we can build a list of commands for each node. - node_name = node[0].name + node_name = node.name if node_name not in nodes: - redis_node = self.get_redis_connection(node[0]) + redis_node = self.get_redis_connection(node) connection = get_connection(redis_node, c.args) nodes[node_name] = NodeCommands( redis_node.parse_response, redis_node.connection_pool, connection diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index e3383ff722..f3f08286c8 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -1,4 +1,4 @@ -from .cluster import AsyncRedisClusterCommands, RedisClusterCommands +from .cluster import READ_COMMANDS, AsyncRedisClusterCommands, RedisClusterCommands from .core import AsyncCoreCommands, CoreCommands from .helpers import list_or_args from .parser import CommandsParser @@ -6,14 +6,15 @@ from .sentinel import AsyncSentinelCommands, SentinelCommands __all__ = [ + "AsyncCoreCommands", "AsyncRedisClusterCommands", - "RedisClusterCommands", + "AsyncRedisModuleCommands", + "AsyncSentinelCommands", "CommandsParser", - "AsyncCoreCommands", "CoreCommands", - "list_or_args", - "AsyncRedisModuleCommands", + "READ_COMMANDS", + "RedisClusterCommands", "RedisModuleCommands", - "AsyncSentinelCommands", "SentinelCommands", + "list_or_args", ] diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index b91b65f083..a1060d2cbb 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -46,25 +46,111 @@ from redis.asyncio.cluster import TargetNodesT +# Not complete, but covers the major ones +# https://redis.io/commands +READ_COMMANDS = frozenset( + [ + "BITCOUNT", + "BITPOS", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUS", + "GEORADIUSBYMEMBER", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "KEYS", + "LINDEX", + "LLEN", + "LRANGE", + "MGET", + "PTTL", + "RANDOMKEY", + "SCARD", + "SDIFF", + "SINTER", + "SISMEMBER", + "SMEMBERS", + "SRANDMEMBER", + "STRLEN", + "SUNION", + "TTL", + "ZCARD", + "ZCOUNT", + "ZRANGE", + "ZSCORE", + ] +) + + class ClusterMultiKeyCommands(ClusterCommandsProtocol): """ A class containing commands that handle more than one key """ def _partition_keys_by_slot(self, keys: Iterable[KeyT]) -> Dict[int, List[KeyT]]: - """ - Split keys into a dictionary that maps a slot to - a list of keys. - """ + """Split keys into a dictionary that maps a slot to a list of keys.""" + slots_to_keys = {} for key in keys: - k = self.encoder.encode(key) - slot = key_slot(k) + slot = key_slot(self.encoder.encode(key)) slots_to_keys.setdefault(slot, []).append(key) return slots_to_keys - def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]: + def _partition_pairs_by_slot( + self, mapping: Mapping[AnyKeyT, EncodableT] + ) -> Dict[int, List[EncodableT]]: + """Split pairs into a dictionary that maps a slot to a list of pairs.""" + + slots_to_pairs = {} + for pair in mapping.items(): + slot = key_slot(self.encoder.encode(pair[0])) + slots_to_pairs.setdefault(slot, []).extend(pair) + + return slots_to_pairs + + def _execute_pipeline_by_slot( + self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]] + ) -> List[Any]: + read_from_replicas = self.read_from_replicas and command in READ_COMMANDS + pipe = self.pipeline() + [ + pipe.execute_command( + command, + *slot_args, + target_nodes=[ + self.nodes_manager.get_node_from_slot(slot, read_from_replicas) + ], + ) + for slot, slot_args in slots_to_args.items() + ] + return pipe.execute() + + def _reorder_keys_by_command( + self, + keys: Iterable[KeyT], + slots_to_args: Mapping[int, Iterable[EncodableT]], + responses: Iterable[Any], + ) -> List[Any]: + results = { + k: v + for slot_values, response in zip(slots_to_args.values(), responses) + for k, v in zip(slot_values, response) + } + return [results[key] for key in keys] + + def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]: """ Splits the keys into different slots and then calls MGET for the keys of every slot. This operation will not be atomic @@ -75,30 +161,17 @@ def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]: For more information see https://redis.io/commands/mget """ - from redis.client import EMPTY_RESPONSE - - options = {} - if not args: - options[EMPTY_RESPONSE] = [] - # Concatenate all keys into a list keys = list_or_args(keys, args) + # Split keys into slots slots_to_keys = self._partition_keys_by_slot(keys) - # Call MGET for every slot and concatenate - # the results - # We must make sure that the keys are returned in order - all_results = {} - for slot_keys in slots_to_keys.values(): - slot_values = self.execute_command("MGET", *slot_keys, **options) + # Execute commands using a pipeline + res = self._execute_pipeline_by_slot("MGET", slots_to_keys) - slot_results = dict(zip(slot_keys, slot_values)) - all_results.update(slot_results) - - # Sort the results - vals_in_order = [all_results[key] for key in keys] - return vals_in_order + # Reorder keys in the order the user provided & return + return self._reorder_keys_by_command(keys, slots_to_keys, res) def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]: """ @@ -114,35 +187,22 @@ def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]: """ # Partition the keys by slot - slots_to_pairs = {} - for pair in mapping.items(): - # encode the key - k = self.encoder.encode(pair[0]) - slot = key_slot(k) - slots_to_pairs.setdefault(slot, []).extend(pair) - - # Call MSET for every slot and concatenate - # the results (one result per slot) - res = [] - for pairs in slots_to_pairs.values(): - res.append(self.execute_command("MSET", *pairs)) + slots_to_pairs = self._partition_pairs_by_slot(mapping) - return res + # Execute commands using a pipeline & return list of replies + return self._execute_pipeline_by_slot("MSET", slots_to_pairs) def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: """ Runs the given command once for the keys of each slot. Returns the sum of the return values. """ + # Partition the keys by slot slots_to_keys = self._partition_keys_by_slot(keys) # Sum up the reply from each command - total = 0 - for slot_keys in slots_to_keys.values(): - total += self.execute_command(command, *slot_keys) - - return total + return sum(self._execute_pipeline_by_slot(command, slots_to_keys)) def exists(self, *keys: KeyT) -> ResponseT: """ @@ -202,7 +262,7 @@ class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands): A class containing commands that handle more than one key """ - async def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]: + async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]: """ Splits the keys into different slots and then calls MGET for the keys of every slot. This operation will not be atomic @@ -213,36 +273,17 @@ async def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]: For more information see https://redis.io/commands/mget """ - from redis.client import EMPTY_RESPONSE - - options = {} - if not args: - options[EMPTY_RESPONSE] = [] - # Concatenate all keys into a list keys = list_or_args(keys, args) + # Split keys into slots slots_to_keys = self._partition_keys_by_slot(keys) - # Call MGET for every slot and concatenate - # the results - # We must make sure that the keys are returned in order - all_values = await asyncio.gather( - *( - asyncio.ensure_future( - self.execute_command("MGET", *slot_keys, **options) - ) - for slot_keys in slots_to_keys.values() - ) - ) + # Execute commands using a pipeline + res = await self._execute_pipeline_by_slot("MGET", slots_to_keys) - all_results = {} - for slot_keys, slot_values in zip(slots_to_keys.values(), all_values): - all_results.update(dict(zip(slot_keys, slot_values))) - - # Sort the results - vals_in_order = [all_results[key] for key in keys] - return vals_in_order + # Reorder keys in the order the user provided & return + return self._reorder_keys_by_command(keys, slots_to_keys, res) async def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]: """ @@ -258,39 +299,22 @@ async def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bo """ # Partition the keys by slot - slots_to_pairs = {} - for pair in mapping.items(): - # encode the key - k = self.encoder.encode(pair[0]) - slot = key_slot(k) - slots_to_pairs.setdefault(slot, []).extend(pair) + slots_to_pairs = self._partition_pairs_by_slot(mapping) - # Call MSET for every slot and concatenate - # the results (one result per slot) - return await asyncio.gather( - *( - asyncio.ensure_future(self.execute_command("MSET", *pairs)) - for pairs in slots_to_pairs.values() - ) - ) + # Execute commands using a pipeline & return list of replies + return await self._execute_pipeline_by_slot("MSET", slots_to_pairs) async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: """ Runs the given command once for the keys of each slot. Returns the sum of the return values. """ + # Partition the keys by slot slots_to_keys = self._partition_keys_by_slot(keys) # Sum up the reply from each command - return sum( - await asyncio.gather( - *( - asyncio.ensure_future(self.execute_command(command, *slot_keys)) - for slot_keys in slots_to_keys.values() - ) - ) - ) + return sum(await self._execute_pipeline_by_slot(command, slots_to_keys)) class ClusterManagementCommands(ManagementCommands): diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py index 2b873e3385..6989ab59fa 100644 --- a/redis/commands/helpers.py +++ b/redis/commands/helpers.py @@ -1,9 +1,12 @@ import copy import random import string +from typing import List, Tuple +from redis.typing import KeysT, KeyT -def list_or_args(keys, args): + +def list_or_args(keys: KeysT, args: Tuple[KeyT, ...]) -> List[KeyT]: # returns a single new list combining keys and args try: iter(keys)