diff --git a/ClusterRedisExample.py b/ClusterRedisExample.py
new file mode 100644
index 0000000000..3ce71a184c
--- /dev/null
+++ b/ClusterRedisExample.py
@@ -0,0 +1,45 @@
+from redis import RedisCluster as Redis
+from redis.cluster import ClusterNode
+
+host = 'localhost'
+startup_nodes = [ClusterNode(host, 16379), ClusterNode(host, 16380)]
+
+# from_url examples
+rc_url = Redis.from_url("redis://localhost:16379/0")
+print(rc_url.cluster_slots())
+print(rc_url.ping(Redis.PRIMARIES))
+print(rc_url.ping(Redis.REPLICAS))
+print(rc_url.ping(Redis.RANDOM))
+print(rc_url.ping(Redis.ALL_NODES))
+print(rc_url.execute_command("STRALGO", "LCS", "STRINGS", "string1",
+                             "string2",
+                             target_nodes=rc_url.get_random_node()))
+print(rc_url.client_list())
+print(rc_url.set('foo', 'bar1'))
+print(rc_url.mget('{bar}1', '{bar}2'))
+print(rc_url.set('zzzsdfsdf', 'bar2'))
+print(rc_url.keyslot('bar'))
+print(rc_url.set('{000}', 'bar3'))
+print(f"get_nodes: {rc_url.get_nodes()}")
+print(rc_url.get('foo'))
+print(rc_url.keys())
+# rc = Redis(host=host, port=6379)
+rc = Redis(startup_nodes=startup_nodes, decode_responses=True)
+print(rc.get('{000}'))
+print(rc.keys())
+print(rc.cluster_save_config(rc.get_primaries()))
+print(rc.cluster_save_config(rc.get_node(host=host, port=16379)))
+
+# READONLY examples
+rc_readonly = Redis(startup_nodes=startup_nodes, read_from_replicas=True,
+                    debug=True)
+rc_readonly.set('bar', 'foo')
+for i in range(0, 4):
+    # Assigning the read command to the slot's servers in a Round-Robin manner
+    print(rc_readonly.get('bar'))
+# 'set' commands are always directed to the slot's primary node
+# reset the READONLY flag
+print(rc_readonly.readwrite())
+for i in range(0, 4):
+    # now the get command will be directed only to the slot's primary node
+    print(rc_readonly.get('bar'))
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/README.md b/README.md
index b6d3115a0b..80046bc928 100644
--- a/README.md
+++ b/README.md
@@ -941,8 +941,228 @@ C 3
 ### Cluster Mode
 
-redis-py does not currently support [Cluster
-Mode](https://redis.io/topics/cluster-tutorial).
+redis-py now supports cluster mode and provides a client for
+[Redis Cluster](https://redis.io/topics/cluster-tutorial).
+
+The cluster client is based on [redis-py-cluster](https://github.com/Grokzen/redis-py-cluster)
+by Grokzen, with a lot of added and
+changed functionality.
+
+**Create RedisCluster:**
+
+Connecting redis-py to the Redis Cluster instance(s) is easy.
+RedisCluster requires at least one node in order to discover the rest of the
+cluster's nodes, and there are multiple ways of creating a RedisCluster
+instance:
+
+- Use the 'host' and 'port' arguments:
+
+``` pycon
+ >>> from redis.cluster import RedisCluster as Redis
+ >>> rc = Redis(host='localhost', port=6379)
+ >>> print(rc.get_nodes())
+ [[host=127.0.0.1,port=6379,name=127.0.0.1:6379,server_type=primary,redis_connection=Redis<ConnectionPool<Connection<host=127.0.0.1,port=6379,db=0>>>], [host=127.0.0.1,port=6378,name=127.0.0.1:6378,server_type=primary,redis_connection=Redis<ConnectionPool<Connection<host=127.0.0.1,port=6378,db=0>>>], [host=127.0.0.1,port=6377,name=127.0.0.1:6377,server_type=replica,redis_connection=Redis<ConnectionPool<Connection<host=127.0.0.1,port=6377,db=0>>>]]
+```
+- Use a Redis URL:
+
+``` pycon
+ >>> from redis.cluster import RedisCluster as Redis
+ >>> rc = Redis.from_url("redis://localhost:6379/0")
+```
+
+- Use ClusterNode(s):
+
+``` pycon
+ >>> from redis.cluster import RedisCluster as Redis
+ >>> from redis.cluster import ClusterNode
+ >>> nodes = [ClusterNode('localhost', 6379), ClusterNode('localhost', 6378)]
+ >>> rc = Redis(startup_nodes=nodes)
+```
+
+When a RedisCluster instance is created, it first attempts to establish a
+connection to one of the provided startup nodes. If none of the startup nodes
+are reachable, a 'RedisClusterException' will be thrown.
+After a connection to one of the cluster's nodes is established, the
+RedisCluster instance will be initialized with 3 caches:
+a slots cache, which maps each of the 16384 slots to the node/s handling them;
+a nodes cache, which contains ClusterNode objects (name, host, port, redis
+connection) for all of the cluster's nodes; and a commands cache, which
+contains all the commands supported by the server, retrieved from the output
+of the Redis 'COMMAND' command.
+
+A RedisCluster instance can be used directly to execute Redis commands. When a
+command is executed through the cluster instance, the target node(s) will
+be internally determined. When using a key-based command, the target node will
+be the node that holds the key's slot.
+Cluster management commands and other cluster commands have predefined node
+group targets (all-primaries, all-nodes, random-node, all-replicas), which are
+outlined in each command's function documentation.
+For example, the 'KEYS' command will be sent to all primaries and return all
+keys in the cluster, and the 'CLUSTER NODES' command will be sent to a random
+node.
+Other management commands will require you to pass the target node/s on which
+to execute the command.
+
+``` pycon
+ >>> # target-nodes: the node that holds 'foo1's key slot
+ >>> rc.set('foo1', 'bar1')
+ >>> # target-nodes: the node that holds 'foo2's key slot
+ >>> rc.set('foo2', 'bar2')
+ >>> # target-nodes: the node that holds 'foo1's key slot
+ >>> print(rc.get('foo1'))
+ b'bar1'
+ >>> # target-nodes: all-primaries
+ >>> print(rc.keys())
+ [b'foo1', b'foo2']
+ >>> # target-nodes: all-nodes
+ >>> rc.flushall()
+```
+
+**Specifying Target Nodes:**
+
+As mentioned above, some RedisCluster commands will require you to provide the
+target node/s on which you want to execute the command, and in other cases, the
+target node will be determined by the client itself. That being said, ALL
+RedisCluster commands can be executed against a specific node or a group of
+nodes by passing the command kwarg `target_nodes`.
+The best practice is to specify target nodes using the RedisCluster class's
+node flags: PRIMARIES, REPLICAS, ALL_NODES, RANDOM. When a nodes flag is passed
+along with a command, it will be internally resolved to the relevant node/s.
+If the cluster's node topology changes during the execution of a command, the
+client will be able to resolve the nodes flag again with the new topology and
+attempt to retry executing the command.
+
+``` pycon
+ >>> from redis.cluster import RedisCluster as Redis
+ >>> # run cluster-meet command on all of the cluster's nodes
+ >>> rc.cluster_meet(Redis.ALL_NODES, '127.0.0.1', 6379)
+ >>> # ping all replicas
+ >>> rc.ping(Redis.REPLICAS)
+ >>> # ping a random node
+ >>> rc.ping(Redis.RANDOM)
+ >>> # ping all nodes in the cluster, default command behavior
+ >>> rc.ping()
+ >>> # execute bgsave in all primaries
+ >>> rc.bgsave(Redis.PRIMARIES)
+```
+
+You can also pass ClusterNodes directly if you want to execute a command on a
+specific node / node group that isn't addressed by the nodes flags. However, if
+the command execution fails due to cluster topology changes, a retry attempt
+will not be made, since the passed target node/s may no longer be valid, and
+the relevant cluster or connection error will be returned.
+
+``` pycon
+ >>> node = rc.get_node('localhost', 6379)
+ >>> # Get the keys only for that specific node
+ >>> rc.keys(node)
+ >>> # get Redis info from a subset of primaries
+ >>> subset_primaries = [node for node in rc.get_primaries() if node.port > 6378]
+ >>> rc.info(subset_primaries)
+```
+
+In addition, you can use the RedisCluster instance to obtain the Redis instance
+of a specific node and execute commands on that node directly. The Redis
+client, however, cannot handle cluster failures and retries.
+
+``` pycon
+ >>> cluster_node = rc.get_node(host='localhost', port=6379)
+ >>> print(cluster_node)
+ [host=127.0.0.1,port=6379,name=127.0.0.1:6379,server_type=primary,redis_connection=Redis<ConnectionPool<Connection<host=127.0.0.1,port=6379,db=0>>>]
+ >>> r = cluster_node.redis_connection
+ >>> r.client_list()
+ [{'id': '276', 'addr': '127.0.0.1:64108', 'fd': '16', 'name': '', 'age': '0', 'idle': '0', 'flags': 'N', 'db': '0', 'sub': '0', 'psub': '0', 'multi': '-1', 'qbuf': '26', 'qbuf-free': '32742', 'argv-mem': '10', 'obl': '0', 'oll': '0', 'omem': '0', 'tot-mem': '54298', 'events': 'r', 'cmd': 'client', 'user': 'default'}]
+ >>> # Get the keys only for that specific node
+ >>> r.keys()
+ [b'foo1']
+```
+
+**Multi-key commands:**
+
+Redis supports multi-key commands in Cluster Mode, such as Set type unions or
+intersections, mset and mget, as long as all the keys hash to the same slot.
+With the RedisCluster client, you can use the familiar functions (e.g. mget,
+mset) to perform an atomic multi-key operation. However, you must ensure that
+all keys are mapped to the same slot, otherwise a RedisClusterException will be
+thrown.
+Redis Cluster implements a concept called hash tags that can be used in order
+to force certain keys to be stored in the same hash slot, see
+[Keys hash tag](https://redis.io/topics/cluster-spec#keys-hash-tags).
+You can also use the non-atomic variants of some multi-key operations (e.g.
+mset_nonatomic, mget_nonatomic) and pass keys that aren't mapped to the same
+slot. The client will then map the keys to the relevant slots, sending the
+commands to the slots' node owners. Non-atomic operations batch the keys
+according to their hash slots, and each batch is then sent separately to the
+slot's owner.
+
+``` pycon
+ # Atomic operations can be used when all keys are mapped to the same slot
+ >>> rc.mset({'{foo}1': 'bar1', '{foo}2': 'bar2'})
+ >>> rc.mget('{foo}1', '{foo}2')
+ [b'bar1', b'bar2']
+ # Non-atomic multi-key operations split the keys across different slots
+ >>> rc.mset_nonatomic({'foo': 'value1', 'bar': 'value2', 'zzz': 'value3'})
+ >>> rc.mget_nonatomic('foo', 'bar', 'zzz')
+ [b'value1', b'value2', b'value3']
+```
+
+**Cluster PubSub:**
+
+When a ClusterPubSub instance is created without specifying a node, a single
+node will be transparently chosen for the pubsub connection on the
+first command execution. The node will be determined by:
+ 1. Hashing the channel name in the request to find its keyslot
+ 2. Selecting a node that handles the keyslot: If read_from_replicas is
+    set to true, a replica can be selected.
+
+*Known limitations with pubsub:*
+
+Pattern subscribe and publish do not work properly in cluster mode: if we hash
+a pattern like fo* we will get a keyslot for that string, but there are endless
+possibilities of channel names matching that pattern that we can't know in
+advance. The client does not block these commands, but using them is not
+recommended right now.
+See the [redis-py-cluster documentation](https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html)
+for more.
+
+``` pycon
+ >>> p1 = rc.pubsub()
+ # p1 connection will be set to the node that holds 'foo' keyslot
+ >>> p1.subscribe('foo')
+ # p2 connection will be set to node 'localhost:6379'
+ >>> p2 = rc.pubsub(rc.get_node('localhost', 6379))
+```
+
+**Read Only Mode**
+
+By default, Redis Cluster always returns a MOVED redirection response when a
+replica node is accessed. You can overcome this limitation and scale read
+commands with READONLY mode.
+
+To enable READONLY mode, pass read_from_replicas=True to the RedisCluster
+constructor. When set to true, read commands will be distributed between the
+primary and its replicas in a Round-Robin manner.
+
+You can also enable READONLY mode at runtime by calling the readonly() method,
+or disable it with readwrite().
+
+``` pycon
+ >>> from redis.cluster import RedisCluster as Redis
+ # Use 'debug' mode to print the node that each command is executed on
+ >>> rc_readonly = Redis(startup_nodes=startup_nodes,
+                         read_from_replicas=True, debug=True)
+ >>> rc_readonly.set('{foo}1', 'bar1')
+ >>> for i in range(0, 4):
+ # Assigns the read command to the slot's hosts in a Round-Robin manner
+ >>> rc_readonly.get('{foo}1')
+ # set command would be directed only to the slot's primary node
+ >>> rc_readonly.set('{foo}2', 'bar2')
+ # reset READONLY flag
+ >>> rc_readonly.readwrite()
+ # now the get command would be directed only to the slot's primary node
+ >>> rc_readonly.get('{foo}1')
+```
+
+See [Redis Cluster tutorial](https://redis.io/topics/cluster-tutorial) and
+[Redis Cluster specifications](https://redis.io/topics/cluster-spec)
+to learn more about Redis Cluster.
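+**Cluster Pipelines:**
+
+RedisCluster also exposes a pipeline() method, backed by the ClusterPipeline
+class added in redis/cluster.py below. A cluster pipeline queues commands
+client-side and, on execute(), routes each command to the node that owns its
+key slot; transactions and shard hints are not supported in cluster mode.
+A minimal sketch of how it might be used, reusing the `rc` instance from the
+examples above (the example and its output are illustrative, not taken from
+the original change):
+
+``` pycon
+ >>> pipe = rc.pipeline()
+ >>> # queue the commands; each call just pushes onto the command stack
+ >>> pipe.set('{foo}1', 'bar1')
+ >>> pipe.get('{foo}1')
+ >>> # execute() sends the queued commands and returns the result stack
+ >>> pipe.execute()
+ [True, b'bar1']  # illustrative output; actual values depend on your cluster
+```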
### Author diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docker-entry.sh b/docker-entry.sh new file mode 100755 index 0000000000..e69de29bb2 diff --git a/docker/base/Dockerfile.cluster b/docker/base/Dockerfile.cluster new file mode 100644 index 0000000000..7c9b4a9b9c --- /dev/null +++ b/docker/base/Dockerfile.cluster @@ -0,0 +1,8 @@ +FROM redis:6.2.6-buster + +COPY ../cluster/create_cluster.sh /create_cluster.sh +RUN chmod +x /create_cluster.sh + +EXPOSE 16379 16380 16381 16382 16383 16384 + +CMD [ "/create_cluster.sh"] \ No newline at end of file diff --git a/docker/cluster/create_cluster.sh b/docker/cluster/create_cluster.sh new file mode 100644 index 0000000000..d339f1836a --- /dev/null +++ b/docker/cluster/create_cluster.sh @@ -0,0 +1,21 @@ +#! /bin/bash +mkdir -p /nodes +echo -n > /nodes/nodemap +for PORT in $(seq 16379 16384); do + mkdir -p /nodes/$PORT + if [[ -e /redis.conf ]]; then + cp /redis.conf /nodes/$PORT/redis.conf + else + touch /nodes/$PORT/redis.conf + fi + cat << EOF >> /nodes/$PORT/redis.conf +port $PORT +daemonize yes +logfile /redis.log +dir /nodes/$PORT +EOF + redis-server /nodes/$PORT/redis.conf + echo 127.0.0.1:$PORT >> /nodes/nodemap +done +echo yes | redis-cli --cluster create $(seq -f 127.0.0.1:%g 16379 16384) --cluster-replicas 1 +tail -f /redis.log diff --git a/docker/cluster/redis.conf b/docker/cluster/redis.conf new file mode 100644 index 0000000000..cc22e16ffe --- /dev/null +++ b/docker/cluster/redis.conf @@ -0,0 +1,3 @@ +# Redis Cluster config file will be shared across all nodes. +# Dont pass node-unique arguments (e.g. port, dir). +cluster-enabled yes diff --git a/redis/__init__.py b/redis/__init__.py index 2458b5bc49..47adaa8c83 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -1,4 +1,5 @@ from redis.client import Redis, StrictRedis +from redis.cluster import RedisCluster from redis.connection import ( BlockingConnectionPool, ConnectionPool, @@ -49,6 +50,7 @@ def int_or_str(value): 'PubSubError', 'ReadOnlyError', 'Redis', + 'RedisCluster', 'RedisError', 'ResponseError', 'SSLConnection', diff --git a/redis/client.py b/redis/client.py index 986af7cfba..93979569b1 100755 --- a/redis/client.py +++ b/redis/client.py @@ -460,6 +460,7 @@ def _parse_node_line(line): line_items = line.split(' ') node_id, addr, flags, master_id, ping, pong, epoch, \ connected = line.split(' ')[:8] + addr = addr.split('@')[0] slots = [sl.split('-') for sl in line_items[8:]] node_dict = { 'node_id': node_id, @@ -475,8 +476,13 @@ def _parse_node_line(line): def parse_cluster_nodes(response, **options): - raw_lines = str_if_bytes(response).splitlines() - return dict(_parse_node_line(line) for line in raw_lines) + """ + @see: http://redis.io/commands/cluster-nodes # string + @see: http://redis.io/commands/cluster-replicas # list of string + """ + if isinstance(response, str): + response = response.splitlines() + return dict(_parse_node_line(str_if_bytes(node)) for node in response) def parse_geosearch_generic(response, **options): @@ -515,6 +521,21 @@ def parse_geosearch_generic(response, **options): ] +def parse_command(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict['name'] = cmd_name + cmd_dict['arity'] = str_if_bytes(command[1]) + cmd_dict['flags'] = [str_if_bytes(flag) for flag in command[2]] + cmd_dict['first_key_pos'] = command[3] + cmd_dict['last_key_pos'] = command[4] + cmd_dict['step_count'] = command[5] + commands[cmd_name] = 
cmd_dict + return commands + + def parse_pubsub_numsub(response, **options): return list(zip(response[0::2], response[1::2])) @@ -700,8 +721,10 @@ class Redis(RedisModuleCommands, CoreCommands, object): 'CLUSTER SET-CONFIG-EPOCH': bool_ok, 'CLUSTER SETSLOT': bool_ok, 'CLUSTER SLAVES': parse_cluster_nodes, - 'COMMAND': int, + 'CLUSTER REPLICAS': parse_cluster_nodes, + 'COMMAND': parse_command, 'COMMAND COUNT': int, + 'COMMAND GETKEYS': lambda r: list(map(str_if_bytes, r)), 'CONFIG GET': parse_config_get, 'CONFIG RESETSTAT': bool_ok, 'CONFIG SET': bool_ok, @@ -824,7 +847,7 @@ def __init__(self, host='localhost', port=6379, ssl_check_hostname=False, max_connections=None, single_connection_client=False, health_check_interval=0, client_name=None, username=None, - retry=None): + retry=None, redis_connect_func=None): """ Initialize a new Redis client. To specify a retry policy, first set `retry_on_timeout` to `True` @@ -852,7 +875,8 @@ def __init__(self, host='localhost', port=6379, 'retry': copy.deepcopy(retry), 'max_connections': max_connections, 'health_check_interval': health_check_interval, - 'client_name': client_name + 'client_name': client_name, + 'redis_connect_func': redis_connect_func } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -1188,14 +1212,16 @@ class PubSub: HEALTH_CHECK_MESSAGE = 'redis-py-health-check' def __init__(self, connection_pool, shard_hint=None, - ignore_subscribe_messages=False): + ignore_subscribe_messages=False, encoder=None): self.connection_pool = connection_pool self.shard_hint = shard_hint self.ignore_subscribe_messages = ignore_subscribe_messages self.connection = None # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. 
- self.encoder = self.connection_pool.get_encoder() + self.encoder = encoder + if self.encoder is None: + self.encoder = self.connection_pool.get_encoder() if self.encoder.decode_responses: self.health_check_response = ['pong', self.HEALTH_CHECK_MESSAGE] else: diff --git a/redis/cluster.py b/redis/cluster.py new file mode 100644 index 0000000000..e3976dcb03 --- /dev/null +++ b/redis/cluster.py @@ -0,0 +1,1883 @@ +import copy +import random +import socket +import time +import threading +import warnings +import sys + +from collections import OrderedDict +from redis.client import CaseInsensitiveDict, Redis, PubSub +from redis.commands import ( + ClusterCommands, + CommandsParser +) +from redis.connection import DefaultParser, ConnectionPool, Encoder, parse_url +from redis.crc import key_slot, REDIS_CLUSTER_HASH_SLOTS +from redis.exceptions import ( + AskError, + BusyLoadingError, + ClusterCrossSlotError, + ClusterDownError, + ClusterError, + DataError, + MasterDownError, + MovedError, + RedisClusterException, + RedisError, + ResponseError, + SlotNotCoveredError, + TimeoutError, + TryAgainError, +) +from redis.utils import ( + dict_merge, + list_keys_to_dict, + merge_result, + str_if_bytes, + safe_str +) + + +def get_node_name(host, port): + return '{0}:{1}'.format(host, port) + + +def get_connection(redis_node, *args, **options): + return redis_node.connection or redis_node.connection_pool.get_connection( + args[0], **options + ) + + +def parse_pubsub_numsub(command, res, **options): + numsub_d = OrderedDict() + for numsub_tups in res.values(): + for channel, numsubbed in numsub_tups: + try: + numsub_d[channel] += numsubbed + except KeyError: + numsub_d[channel] = numsubbed + + ret_numsub = [ + (channel, numsub) + for channel, numsub in numsub_d.items() + ] + return ret_numsub + + +def parse_cluster_slots(resp, **options): + current_host = options.get('current_host', '') + + def fix_server(*args): + return str_if_bytes(args[0]) or current_host, args[1] + + slots = {} + for slot in resp: + start, end, primary = slot[:3] + replicas = slot[3:] + slots[start, end] = { + 'primary': fix_server(*primary), + 'replicas': [fix_server(*replica) for replica in replicas], + } + + return slots + + +PRIMARY = "primary" +REPLICA = "replica" +SLOT_ID = 'slot-id' + +REDIS_ALLOWED_KEYS = ( + "charset", + "connection_class", + "connection_pool", + "db", + "decode_responses", + "encoding", + "encoding_errors", + "errors", + "host", + "max_connections", + "nodes_flag", + "redis_connect_func", + "password", + "port", + "retry_on_timeout", + "socket_connect_timeout", + "socket_keepalive", + "socket_keepalive_options", + "socket_timeout", + "ssl", + "ssl_ca_certs", + "ssl_certfile", + "ssl_cert_reqs", + "ssl_keyfile", + "unix_socket_path", + "username", +) +KWARGS_DISABLED_KEYS = ( + "host", + "port", +) + +# Not complete, but covers the major ones +# https://redis.io/commands +READ_COMMANDS = frozenset([ + "BITCOUNT", + "BITPOS", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUS", + "GEORADIUSBYMEMBER", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "KEYS", + "LINDEX", + "LLEN", + "LRANGE", + "MGET", + "PTTL", + "RANDOMKEY", + "SCARD", + "SDIFF", + "SINTER", + "SISMEMBER", + "SMEMBERS", + "SRANDMEMBER", + "STRLEN", + "SUNION", + "TTL", + "ZCARD", + "ZCOUNT", + "ZRANGE", + "ZSCORE", +]) + + +def cleanup_kwargs(**kwargs): + """ + Remove unsupported or disabled keys from kwargs + """ + connection_kwargs = { + k: v + for k, v 
in kwargs.items() + if k in REDIS_ALLOWED_KEYS and k not in KWARGS_DISABLED_KEYS + } + + return connection_kwargs + + +class ClusterParser(DefaultParser): + EXCEPTION_CLASSES = dict_merge( + DefaultParser.EXCEPTION_CLASSES, { + 'ASK': AskError, + 'TRYAGAIN': TryAgainError, + 'MOVED': MovedError, + 'CLUSTERDOWN': ClusterDownError, + 'CROSSSLOT': ClusterCrossSlotError, + 'MASTERDOWN': MasterDownError, + }) + + +class RedisCluster(ClusterCommands, object): + RedisClusterRequestTTL = 16 + + PRIMARIES = "all-primaries" + REPLICAS = "all-replicas" + ALL_NODES = "all-nodes" + RANDOM = "random" + + NODE_FLAGS = { + PRIMARIES, + REPLICAS, + ALL_NODES, + RANDOM + } + + COMMAND_FLAGS = dict_merge( + list_keys_to_dict( + [ + "CLIENT LIST", + "CLIENT SETNAME", + "CLIENT GETNAME", + "CONFIG GET", + "CONFIG SET", + "CONFIG REWRITE", + "CONFIG RESETSTAT", + "TIME", + "PUBSUB CHANNELS", + "PUBSUB NUMPAT", + "PUBSUB NUMSUB", + "PING", + "INFO", + "SHUTDOWN" + ], + ALL_NODES, + ), + list_keys_to_dict( + [ + "KEYS", + "SCAN", + "FLUSHALL", + "FLUSHDB", + "DBSIZE", + "BGSAVE", + "SLOWLOG GET", + "SLOWLOG LEN", + "SLOWLOG RESET", + "WAIT", + "TIME", + "SAVE", + "MEMORY PURGE", + "MEMORY MALLOC-STATS", + "MEMORY STATS", + "LASTSAVE", + "CLIENT TRACKINGINFO", + "CLIENT PAUSE", + "CLIENT UNPAUSE", + "CLIENT UNBLOCK", + "CLIENT ID", + "CLIENT REPLY", + "CLIENT GETREDIR", + "CLIENT INFO", + "CLIENT KILL" + ], + PRIMARIES, + ), + list_keys_to_dict( + [ + "READONLY", + "READWRITE", + ], + REPLICAS, + ), + list_keys_to_dict( + [ + "CLUSTER INFO", + "CLUSTER NODES", + "CLUSTER REPLICAS", + "CLUSTER SLOTS", + "CLUSTER COUNT-FAILURE-REPORTS", + "CLUSTER KEYSLOT", + "RANDOMKEY", + "COMMAND", + "COMMAND GETKEYS", + "DEBUG", + ], + RANDOM, + ), + list_keys_to_dict( + [ + "CLUSTER COUNTKEYSINSLOT", + "CLUSTER DELSLOTS", + "CLUSTER GETKEYSINSLOT", + "CLUSTER SETSLOT", + ], + SLOT_ID, + ), + ) + + CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { + 'CLUSTER ADDSLOTS': bool, + 'CLUSTER COUNT-FAILURE-REPORTS': int, + 'CLUSTER COUNTKEYSINSLOT': int, + 'CLUSTER DELSLOTS': bool, + 'CLUSTER FAILOVER': bool, + 'CLUSTER FORGET': bool, + 'CLUSTER GETKEYSINSLOT': list, + 'CLUSTER KEYSLOT': int, + 'CLUSTER MEET': bool, + 'CLUSTER REPLICATE': bool, + 'CLUSTER RESET': bool, + 'CLUSTER SAVECONFIG': bool, + 'CLUSTER SET-CONFIG-EPOCH': bool, + 'CLUSTER SETSLOT': bool, + 'CLUSTER SLOTS': parse_cluster_slots, + 'ASKING': bool, + 'READONLY': bool, + 'READWRITE': bool, + } + + RESULT_CALLBACKS = dict_merge( + list_keys_to_dict([ + "PUBSUB NUMSUB", + ], parse_pubsub_numsub), + list_keys_to_dict([ + "PUBSUB NUMPAT", + ], lambda command, res: sum(list(res.values()))), + list_keys_to_dict([ + "KEYS", + "PUBSUB CHANNELS", + ], merge_result), + list_keys_to_dict([ + "PING", + "CONFIG SET", + "CONFIG REWRITE", + "CONFIG RESETSTAT", + "CLIENT SETNAME", + "BGSAVE", + "SLOWLOG RESET", + "SAVE", + "MEMORY PURGE", + "CLIENT PAUSE", + "CLIENT UNPAUSE", + ], lambda command, res: all(res.values()) if isinstance(res, dict) + else res), + list_keys_to_dict([ + "DBSIZE", + "WAIT", + ], lambda command, res: sum(res.values()) if isinstance(res, dict) + else res), + list_keys_to_dict([ + "CLIENT UNBLOCK", + ], lambda command, res: 1 if sum(res.values()) > 0 else 0) + ) + + def __init__( + self, + host=None, + port=6379, + startup_nodes=None, + cluster_error_retry_attempts=3, + require_full_coverage=True, + skip_full_coverage_check=False, + reinitialize_steps=10, + read_from_replicas=False, + url=None, + debug=False, + **kwargs + ): + """ + :startup_nodes: 'list[ClusterNode]' + 
            List of nodes from which initial bootstrapping can be done
+        :host: 'str'
+            Can be used to point to a startup node
+        :port: 'int'
+            Can be used to point to a startup node
+        :require_full_coverage: 'bool'
+            If set to True, as it is by default, all slots must be covered.
+            If set to False and not all slots are covered, the instance
+            creation will succeed only if the 'cluster-require-full-coverage'
+            configuration is set to 'no' in all of the cluster's nodes.
+            Otherwise, a RedisClusterException will be thrown.
+        :skip_full_coverage_check: 'bool'
+            If require_full_coverage is set to False, a check of the
+            cluster-require-full-coverage config will be executed against all
+            nodes. Set skip_full_coverage_check to True to skip this check.
+            Useful for clusters without the CONFIG command (like ElastiCache)
+        :read_from_replicas: 'bool'
+            Enable read from replicas in READONLY mode. You can read possibly
+            stale data.
+            When set to true, read commands will be assigned between the
+            primary and its replicas in a Round-Robin manner.
+        :cluster_error_retry_attempts: 'int'
+            Retry command execution attempts when encountering ClusterDownError
+            or ConnectionError
+        :debug:
+            Add prints to debug the RedisCluster client
+
+        :**kwargs:
+            Extra arguments that will be sent into the Redis instance when
+            created (see the official redis-py documentation for supported
+            kwargs
+            [https://github.com/andymccurdy/redis-py/blob/master/redis/client.py])
+            Some kwargs are not supported and will raise a
+            RedisClusterException:
+                - db (Redis does not support database SELECT in cluster mode)
+        """
+
+        if startup_nodes is None:
+            startup_nodes = []
+
+        if "db" in kwargs:
+            # Argument 'db' is not possible to use in cluster mode
+            raise RedisClusterException(
+                "Argument 'db' is not possible to use in cluster mode"
+            )
+
+        # Get the startup node/s
+        from_url = False
+        if url is not None:
+            from_url = True
+            url_options = parse_url(url)
+            if "path" in url_options:
+                raise RedisClusterException(
+                    "RedisCluster does not currently support Unix Domain "
+                    "Socket connections")
+            if "db" in url_options and url_options["db"] != 0:
+                # Argument 'db' is not possible to use in cluster mode
+                raise RedisClusterException(
+                    "A ``db`` querystring option can only be 0 in cluster mode"
+                )
+            kwargs.update(url_options)
+            startup_nodes.append(ClusterNode(kwargs['host'], kwargs['port']))
+        elif host is not None and port is not None:
+            startup_nodes.append(ClusterNode(host, port))
+        elif len(startup_nodes) == 0:
+            # No startup node was provided
+            raise RedisClusterException(
+                "RedisCluster requires at least one node to discover the "
+                "cluster. Please provide one of the following:\n"
+                "1. host and port, for example:\n"
+                "   RedisCluster(host='localhost', port=6379)\n"
+                "2. list of startup nodes, for example:\n"
+                "   RedisCluster(startup_nodes=[ClusterNode('localhost', 6379),"
+                " ClusterNode('localhost', 6378)])")
+
+        # Update the connection arguments
+        # Whenever a new connection is established, RedisCluster's on_connect
+        # method should be run
+        # If the user passed an on_connect function, we'll save it and run it
+        # inside the RedisCluster.on_connect() function
+        self.user_on_connect_func = kwargs.pop("redis_connect_func", None)
+        kwargs.update({"redis_connect_func": self.on_connect})
+        kwargs = cleanup_kwargs(**kwargs)
+
+        self.encoder = Encoder(
+            kwargs.get("encoding", "utf-8"),
+            kwargs.get("encoding_errors", "strict"),
+            kwargs.get("decode_responses", False),
+        )
+        self.cluster_error_retry_attempts = cluster_error_retry_attempts
+        self.command_flags = self.__class__.COMMAND_FLAGS.copy()
+        self.node_flags = self.__class__.NODE_FLAGS.copy()
+        self.debug_mode = debug
+        self.read_from_replicas = read_from_replicas
+        self.reinitialize_counter = 0
+        self.reinitialize_steps = reinitialize_steps
+        self.nodes_manager = None
+        self.nodes_manager = NodesManager(
+            startup_nodes=startup_nodes,
+            from_url=from_url,
+            require_full_coverage=require_full_coverage,
+            skip_full_coverage_check=skip_full_coverage_check,
+            **kwargs,
+        )
+
+        self.cluster_response_callbacks = CaseInsensitiveDict(
+            self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS)
+        self.result_callbacks = CaseInsensitiveDict(
+            self.__class__.RESULT_CALLBACKS)
+        self.commands_parser = CommandsParser(self)
+        self._lock = threading.Lock()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
+    def __del__(self):
+        self.close()
+
+    def disconnect_connection_pools(self):
+        for node in self.get_nodes():
+            if node.redis_connection:
+                node.redis_connection.connection_pool.disconnect()
+
+    @classmethod
+    def from_url(cls, url, **kwargs):
+        """
+        Return a Redis client object configured from the given URL.
+
+        For example::
+
+            redis://[[username]:[password]]@localhost:6379/0
+            rediss://[[username]:[password]]@localhost:6379/0
+            unix://[[username]:[password]]@/path/to/socket.sock?db=0
+
+        Three URL schemes are supported:
+
+        - `redis://` creates a TCP socket connection. See more at:
+          <https://www.iana.org/assignments/uri-schemes/prov/redis>
+        - `rediss://` creates an SSL wrapped TCP socket connection. See more
+          at: <https://www.iana.org/assignments/uri-schemes/prov/rediss>
+        - ``unix://``: creates a Unix Domain Socket connection.
+
+        The username, password, hostname, path and all querystring values
+        are passed through urllib.parse.unquote in order to replace any
+        percent-encoded values with their corresponding characters.
+
+        There are several ways to specify a database number. The first value
+        found will be used:
+        1. A ``db`` querystring option, e.g. redis://localhost?db=0
+        2. If using the redis:// or rediss:// schemes, the path argument
+           of the url, e.g. redis://localhost/0
+        3. A ``db`` keyword argument to this function.
+
+        If none of these options are specified, the default db=0 is used.
+
+        All querystring options are cast to their appropriate Python types.
+        Boolean arguments can be specified with string values "True"/"False"
+        or "Yes"/"No". Values that cannot be properly cast cause a
+        ``ValueError`` to be raised. Once parsed, the querystring arguments
+        and keyword arguments are passed to the ``ConnectionPool``'s
+        class initializer. In the case of conflicting arguments, querystring
+        arguments always win.
+ + """ + return cls(url=url, **kwargs) + + def on_connect(self, connection): + """ + Initialize the connection, authenticate and select a database and send + READONLY if it is set during object initialization. + """ + connection.set_parser(ClusterParser) + connection.on_connect() + + if self.read_from_replicas: + # Sending READONLY command to server to configure connection as + # readonly. Since each cluster node may change its server type due + # to a failover, we should establish a READONLY connection + # regardless of the server type. If this is a primary connection, + # READONLY would not affect executing write commands. + connection.send_command('READONLY') + if str_if_bytes(connection.read_response()) != 'OK': + raise ConnectionError('READONLY command failed') + + if self.user_on_connect_func is not None: + self.user_on_connect_func(connection) + + def get_redis_connection(self, node): + if not node.redis_connection: + with self._lock: + if not node.redis_connection: + self.nodes_manager.create_redis_connections([node]) + return node.redis_connection + + def get_node(self, host=None, port=None, node_name=None): + return self.nodes_manager.get_node(host, port, node_name) + + def get_primaries(self): + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + + def get_replicas(self): + return self.nodes_manager.get_nodes_by_server_type(REPLICA) + + def get_random_node(self): + return random.choice(list(self.nodes_manager.nodes_cache.values())) + + def get_nodes(self): + return list(self.nodes_manager.nodes_cache.values()) + + def pubsub(self, node=None, host=None, port=None, **kwargs): + """ + Allows passing a ClusterNode, or host&port, to get a pubsub instance + connected to the specified node + """ + return ClusterPubSub(self, node=node, host=host, port=port, **kwargs) + + def pipeline(self, transaction=None, + shard_hint=None, read_from_replicas=False): + """ + Cluster impl: + Pipelines do not work in cluster mode the same way they + do in normal mode. Create a clone of this object so + that simulating pipelines will work correctly. Each + command will be called directly when used and + when calling execute() will only return the result stack. + """ + if shard_hint: + raise RedisClusterException( + "shard_hint is deprecated in cluster mode") + + if transaction: + raise RedisClusterException( + "transaction is deprecated in cluster mode") + + return ClusterPipeline( + nodes_manager=self.nodes_manager, + startup_nodes=self.nodes_manager.startup_nodes, + result_callbacks=self.result_callbacks, + cluster_response_callbacks=self.cluster_response_callbacks, + cluster_error_retry_attempts=self.cluster_error_retry_attempts, + read_from_replicas=read_from_replicas, + ) + + def _determine_nodes(self, *args, **kwargs): + command = args[0] + nodes_flag = kwargs.pop("nodes_flag", None) + if nodes_flag is not None: + # nodes flag passed by the user + command_flag = nodes_flag + else: + # get the predefined nodes group for this command + command_flag = self.command_flags.get(command) + + if command_flag == self.__class__.RANDOM: + return [self.get_random_node()] + elif command_flag == self.__class__.PRIMARIES: + return self.get_primaries() + elif command_flag == self.__class__.REPLICAS: + return self.get_replicas() + elif command_flag == self.__class__.ALL_NODES: + return self.get_nodes() + else: + # get the node that holds the key's slot + slot = self.determine_slot(*args) + return [self.nodes_manager. 
+                    get_node_from_slot(slot, self.read_from_replicas
+                                       and command in READ_COMMANDS)]
+
+    def _should_reinitialized(self):
+        # In order not to reinitialize the cluster, the user can set
+        # reinitialize_steps to 0.
+        if self.reinitialize_steps == 0:
+            return False
+        else:
+            return self.reinitialize_counter % self.reinitialize_steps == 0
+
+    def keyslot(self, key):
+        """
+        Calculate the keyslot for a given key.
+        """
+        k = self.encoder.encode(key)
+        return key_slot(k)
+
+    def determine_slot(self, *args):
+        """
+        Figure out what slot to use based on the command and args.
+        """
+        if self.command_flags.get(args[0]) == SLOT_ID:
+            # The command contains the slot ID
+            return args[1]
+
+        redis_conn = self.get_random_node().redis_connection
+        keys = self.commands_parser.get_keys(redis_conn, *args)
+        if keys is None or len(keys) == 0:
+            raise RedisClusterException(
+                "No way to dispatch this command to Redis Cluster. "
+                "Missing key.\nYou can execute the command by specifying "
+                "target nodes.\nCommand: {0}".format(args)
+            )
+
+        if len(keys) > 1:
+            # multi-key command, we need to make sure all keys are mapped to
+            # the same slot
+            slots = {self.keyslot(key) for key in keys}
+            if len(slots) != 1:
+                raise RedisClusterException("{0} - all keys must map to the "
+                                            "same key slot".format(args[0]))
+            return slots.pop()
+        else:
+            # single key command
+            return self.keyslot(keys[0])
+
+    def reinitialize_caches(self):
+        self.nodes_manager.initialize()
+
+    def _is_nodes_flag(self, target_nodes):
+        return isinstance(target_nodes, str) \
+            and target_nodes in self.node_flags
+
+    def _parse_target_nodes(self, target_nodes):
+        if isinstance(target_nodes, list):
+            nodes = target_nodes
+        elif isinstance(target_nodes, ClusterNode):
+            # Supports passing a single ClusterNode as a variable
+            nodes = [target_nodes]
+        elif isinstance(target_nodes, dict):
+            # Supports dictionaries of the format {node_name: node}.
+            # It enables executing commands on multiple nodes as follows:
+            # rc.cluster_save_config(rc.get_primaries())
+            nodes = target_nodes.values()
+        else:
+            raise TypeError("target_nodes type can be one of the "
+                            "following: node_flag (PRIMARIES, "
+                            "REPLICAS, RANDOM, ALL_NODES), "
+                            "ClusterNode, list, or "
+                            "dict. The passed type is {0}".
+                            format(type(target_nodes)))
+        return nodes
+
+    def execute_command(self, *args, **kwargs):
+        """
+        Wrapper for ClusterDownError and ConnectionError error handling.
+
+        It will try the number of times specified by the config option
+        "self.cluster_error_retry_attempts", which defaults to 3 unless
+        manually configured.
+
+        If it still fails after that number of attempts, the exception will
+        be raised to the caller.
+
+        The keyword argument :target_nodes: can be passed with the following
+        types:
+            nodes_flag: PRIMARIES, REPLICAS, ALL_NODES, RANDOM
+            ClusterNode
+            list
+            dict
+        """
+        target_nodes_specified = False
+        target_nodes = kwargs.pop("target_nodes", None)
+        if target_nodes is not None and not self._is_nodes_flag(target_nodes):
+            target_nodes = self._parse_target_nodes(target_nodes)
+            target_nodes_specified = True
+        # If ClusterDownError/ConnectionError were thrown, the nodes
+        # and slots cache were reinitialized. We will retry executing the
+        # command with the updated cluster setup only when the target nodes
+        # can be determined again with the new cache tables. Therefore,
+        # when target nodes were passed to this function, we cannot retry
+        # the command execution since the nodes may not be valid anymore
+        # after the tables were reinitialized. So in case of passed target
+        # nodes, retry_attempts will be set to 1.
+ retry_attempts = 1 if target_nodes_specified else \ + self.cluster_error_retry_attempts + exception = None + for _ in range(0, retry_attempts): + try: + res = {} + if not target_nodes_specified: + # Determine the nodes to execute the command on + target_nodes = self._determine_nodes( + *args, **kwargs, nodes_flag=target_nodes) + if not target_nodes: + raise RedisClusterException( + "No targets were found to execute" + " {} command on".format(args)) + for node in target_nodes: + res[node.name] = self._execute_command( + node, *args, **kwargs) + # Return the processed result + return self._process_result(args[0], res, **kwargs) + except (ClusterDownError, ConnectionError) as e: + # The nodes and slots cache were reinitialized. + # Try again with the new cluster setup. All other errors + # should be raised. + exception = e + + # If it fails the configured number of times then raise exception back + # to caller of this method + raise exception + + def _execute_command(self, target_node, *args, **kwargs): + """ + Send a command to a node in the cluster + """ + command = args[0] + redis_node = None + connection = None + redirect_addr = None + asking = False + moved = False + ttl = int(self.RedisClusterRequestTTL) + connection_error_retry_counter = 0 + + while ttl > 0: + ttl -= 1 + try: + if asking: + target_node = self.get_node(node_name=redirect_addr) + elif moved: + # MOVED occurred and the slots cache was updated, + # refresh the target node + slot = self.determine_slot(*args) + target_node = self.nodes_manager. \ + get_node_from_slot(slot, self.read_from_replicas and + command in READ_COMMANDS) + moved = False + + if self.debug_mode: + print("Executing command {0} on target node: {1} {2}". + format(command, target_node.server_type, + target_node.name)) + redis_node = self.get_redis_connection(target_node) + connection = get_connection(redis_node, *args, **kwargs) + if asking: + connection.send_command("ASKING") + redis_node.parse_response(connection, "ASKING", **kwargs) + asking = False + + connection.send_command(*args) + response = redis_node.parse_response(connection, command, + **kwargs) + if command in self.cluster_response_callbacks: + response = self.cluster_response_callbacks[command]( + response, **kwargs) + return response + + except (RedisClusterException, BusyLoadingError): + warnings.warn("RedisClusterException || BusyLoadingError") + raise + except ConnectionError: + warnings.warn("ConnectionError") + # ConnectionError can also be raised if we couldn't get a + # connection from the pool before timing out, so check that + # this is an actual connection before attempting to disconnect. + if connection is not None: + connection.disconnect() + connection_error_retry_counter += 1 + + # Give the node 0.25 seconds to get back up and retry again + # with same node and configuration. After 5 attempts then try + # to reinitialize the cluster and see if the nodes + # configuration has changed or not + if connection_error_retry_counter < 5: + time.sleep(0.25) + else: + # Hard force of reinitialize of the node/slots setup + # and try again with the new setup + self.nodes_manager.initialize() + raise + except TimeoutError: + warnings.warn("TimeoutError") + if connection is not None: + connection.disconnect() + + if ttl < self.RedisClusterRequestTTL / 2: + time.sleep(0.05) + except MovedError as e: + # First, we will try to patch the slots/nodes cache with the + # redirected node output and try again. 
If MovedError exceeds + # 'reinitialize_steps' number of times, we will force + # reinitializing the tables, and then try again. + # 'reinitialize_steps' counter will increase faster when the + # same client object is shared between multiple threads. To + # reduce the frequency you can set this variable in the + # RedisCluster constructor. + warnings.warn("MovedError") + self.reinitialize_counter += 1 + if self._should_reinitialized(): + self.nodes_manager.initialize() + else: + self.nodes_manager.update_moved_exception(e) + moved = True + except TryAgainError: + warnings.warn("TryAgainError") + + if ttl < self.RedisClusterRequestTTL / 2: + time.sleep(0.05) + except AskError as e: + warnings.warn("AskError") + + redirect_addr = get_node_name(host=e.host, port=e.port) + asking = True + except ClusterDownError as e: + warnings.warn("ClusterDownError") + # ClusterDownError can occur during a failover and to get + # self-healed, we will try to reinitialize the cluster layout + # and retry executing the command + time.sleep(0.05) + self.nodes_manager.initialize() + raise e + except ResponseError as e: + message = e.__str__() + warnings.warn("ResponseError: {0}".format(message)) + raise e + except BaseException as e: + warnings.warn("BaseException") + if connection: + connection.disconnect() + raise e + finally: + if connection is not None: + redis_node.connection_pool.release(connection) + + raise ClusterError("TTL exhausted.") + + def close(self): + try: + with self._lock: + if self.nodes_manager: + self.nodes_manager.close() + except AttributeError: + # RedisCluster's __init__ can fail before nodes_manager is set + pass + + def _process_result(self, command, res, **kwargs): + """ + Process the result of the executed command. + The function would return a dict or a single value. 
+
+        :type command: str
+        :type res: dict
+
+        `res` should be in the following format:
+            Dict<node_name, command_result>
+        """
+        if command in self.result_callbacks:
+            return self.result_callbacks[command](command, res, **kwargs)
+        elif len(res) == 1:
+            # When we execute the command on a single node, we can
+            # remove the dictionary and return a single response
+            return list(res.values())[0]
+        else:
+            return res
+
+
+class ClusterNode(object):
+    def __init__(self, host, port, server_type=None, redis_connection=None):
+        if host == 'localhost':
+            host = socket.gethostbyname(host)
+
+        self.host = host
+        self.port = port
+        self.name = get_node_name(host, port)
+        self.server_type = server_type
+        self.redis_connection = redis_connection
+
+    def __repr__(self):
+        return '[host={0},port={1},' \
+               'name={2},server_type={3},redis_connection={4}]' \
+            .format(self.host,
+                    self.port,
+                    self.name,
+                    self.server_type,
+                    self.redis_connection)
+
+    def __eq__(self, obj):
+        return isinstance(obj, ClusterNode) and obj.name == self.name
+
+
+class LoadBalancer:
+    """
+    Round-Robin Load Balancing
+    """
+
+    def __init__(self, start_index=0):
+        self.primary_to_idx = {}
+        self.start_index = start_index
+
+    def get_server_index(self, primary, list_size):
+        server_index = self.primary_to_idx.setdefault(primary,
+                                                      self.start_index)
+        # Update the index
+        self.primary_to_idx[primary] = (server_index + 1) % list_size
+        return server_index
+
+    def reset(self):
+        self.primary_to_idx.clear()
+
+
+class NodesManager:
+    def __init__(self, startup_nodes, from_url=False,
+                 require_full_coverage=True, skip_full_coverage_check=False,
+                 lock=None, **kwargs):
+        self.nodes_cache = {}
+        self.slots_cache = {}
+        self.startup_nodes = {}
+        self.populate_startup_nodes(startup_nodes)
+        self.from_url = from_url
+        self._require_full_coverage = require_full_coverage
+        self._skip_full_coverage_check = skip_full_coverage_check
+        self._moved_exception = None
+        self.connection_kwargs = kwargs
+        self.read_load_balancer = LoadBalancer()
+        if lock is None:
+            lock = threading.Lock()
+        self._lock = lock
+        self.initialize()
+
+    def get_node(self, host=None, port=None, node_name=None):
+        if node_name is None and (host is None or port is None):
+            warnings.warn(
+                "get_node requires one of the following: "
+                "1. node name "
+                "2. host and port"
+            )
+            return None
+        if host is not None and port is not None:
+            if host == "localhost":
+                host = socket.gethostbyname(host)
+            node_name = get_node_name(host=host, port=port)
+        return self.nodes_cache.get(node_name)
+
+    def update_moved_exception(self, exception):
+        self._moved_exception = exception
+
+    def _update_moved_slots(self):
+        e = self._moved_exception
+        redirected_node = self.get_node(host=e.host, port=e.port)
+        if redirected_node:
+            if redirected_node.server_type is not PRIMARY:
+                # Update the node's server type
+                redirected_node.server_type = PRIMARY
+        else:
+            # This is a new node, we will add it to the nodes cache
+            redirected_node = ClusterNode(e.host, e.port, PRIMARY)
+            self.nodes_cache[redirected_node.name] = redirected_node
+        if redirected_node in self.slots_cache[e.slot_id]:
+            # The MOVED error resulted from a failover, and the new slot owner
+            # had previously been a replica.
+            old_primary = self.slots_cache[e.slot_id][0]
+            # Update the old primary to be a replica and add it to the end of
+            # the slot's node list
+            old_primary.server_type = REPLICA
+            self.slots_cache[e.slot_id].append(old_primary)
+            # Remove the old replica, which is now a primary, from the slot's
+            # node list
+            self.slots_cache[e.slot_id].remove(redirected_node)
+            # Override the old primary with the new one
+            self.slots_cache[e.slot_id][0] = redirected_node
+        else:
+            # The new slot owner is a new server, or a server from a different
+            # shard. We need to remove all current nodes from the slot's list
+            # (including replicas) and add just the new node.
+            self.slots_cache[e.slot_id] = [redirected_node]
+        # Reset moved_exception
+        self._moved_exception = None
+
+    def get_node_from_slot(self, slot, read_from_replicas=False,
+                           server_type=None):
+        """
+        Gets a node that serves this hash slot
+        """
+        if self._moved_exception:
+            with self._lock:
+                if self._moved_exception:
+                    self._update_moved_slots()
+
+        if self.slots_cache.get(slot) is None or \
+                len(self.slots_cache[slot]) == 0:
+            raise SlotNotCoveredError(
+                'Slot "{0}" not covered by the cluster. '
+                '"require_full_coverage={1}"'.format(
+                    slot, self._require_full_coverage)
+            )
+
+        if read_from_replicas:
+            # get the server index in a Round-Robin manner
+            primary_name = self.slots_cache[slot][0].name
+            node_idx = self.read_load_balancer.get_server_index(
+                primary_name, len(self.slots_cache[slot]))
+        elif (
+                server_type is None
+                or server_type == PRIMARY
+                or len(self.slots_cache[slot]) == 1
+        ):
+            # return a primary
+            node_idx = 0
+        else:
+            # return a replica
+            # randomly choose one of the replicas
+            node_idx = random.randint(
+                1, len(self.slots_cache[slot]) - 1)
+
+        return self.slots_cache[slot][node_idx]
+
+    def get_nodes_by_server_type(self, server_type):
+        return [
+            node
+            for node in self.nodes_cache.values()
+            if node.server_type == server_type
+        ]
+
+    def populate_startup_nodes(self, nodes):
+        """
+        Populate all startup nodes and filter out any duplicates
+        """
+        for n in nodes:
+            self.startup_nodes[n.name] = n
+
+    def cluster_require_full_coverage(self, cluster_nodes):
+        """
+        If the 'cluster-require-full-coverage no' config exists on the redis
+        servers, the cluster will still be able to respond even when not all
+        slots are covered.
+        """
+
+        def node_require_full_coverage(node):
+            try:
+                return ("yes" in node.redis_connection.config_get(
+                    "cluster-require-full-coverage").values()
+                        )
+            except ConnectionError:
+                return False
+            except Exception as e:
+                raise RedisClusterException(
+                    'ERROR sending "config get cluster-require-full-coverage"'
+                    ' command to redis server: {0}, {1}'.format(node.name, e)
+                )
+
+        # at least one node should have cluster-require-full-coverage yes
+        return any(node_require_full_coverage(node)
+                   for node in cluster_nodes.values())
+
+    def check_slots_coverage(self, slots_cache):
+        # Validate if all slots are covered, or if we should try the next
+        # startup node
+        for i in range(0, REDIS_CLUSTER_HASH_SLOTS):
+            if i not in slots_cache:
+                return False
+        return True
+
+    def create_redis_connections(self, nodes):
+        """
+        This function will create a redis connection to all nodes in :nodes:
+        """
+        for node in nodes:
+            if node.redis_connection is None:
+                node.redis_connection = self.create_redis_node(
+                    host=node.host,
+                    port=node.port,
+                    **self.connection_kwargs,
+                )
+
+    def create_redis_node(self, host, port, **kwargs):
+        if self.from_url:
+            # Create a redis node with a custom connection pool
+            kwargs.update({"host": host})
+            kwargs.update({"port": port})
+            connection_pool = ConnectionPool(**kwargs)
+            r = Redis(
+                connection_pool=connection_pool
+            )
+        else:
+            r = Redis(
+                host=host,
+                port=port,
+                **kwargs
+            )
+        return r
+
+    def initialize(self):
+        """
+        Initializes the nodes cache, slots cache and redis connections.
+        :startup_nodes:
+            Responsible for discovering other nodes in the cluster
+        """
+        self.reset()
+        tmp_nodes_cache = {}
+        tmp_slots = {}
+        disagreements = []
+        startup_nodes_reachable = False
+        kwargs = self.connection_kwargs
+        for startup_node in self.startup_nodes.values():
+            try:
+                if startup_node.redis_connection:
+                    r = startup_node.redis_connection
+                else:
+                    # Create a new Redis connection and let Redis decode the
+                    # responses so we won't need to handle that
+                    copy_kwargs = copy.deepcopy(kwargs)
+                    copy_kwargs.update({"decode_responses": True})
+                    copy_kwargs.update({"encoding": "utf-8"})
+                    r = self.create_redis_node(
+                        startup_node.host, startup_node.port, **copy_kwargs)
+                    self.startup_nodes[startup_node.name].redis_connection = r
+                cluster_slots = r.execute_command("CLUSTER SLOTS")
+                startup_nodes_reachable = True
+            except (ConnectionError, TimeoutError):
+                continue
+            except ResponseError as e:
+                warnings.warn(
+                    'ResponseError sending "cluster slots" to redis server')
+
+                # Isn't a cluster connection, so it won't parse these
+                # exceptions automatically
+                message = e.__str__()
+                if "CLUSTERDOWN" in message or "MASTERDOWN" in message:
+                    continue
+                else:
+                    raise RedisClusterException(
+                        'ERROR sending "cluster slots" command to redis '
+                        'server: {0}. error: {1}'.format(
+                            startup_node, message)
+                    )
+            except Exception as e:
+                message = e.__str__()
+                raise RedisClusterException(
+                    'ERROR sending "cluster slots" command to redis '
+                    'server: {0}. 
error: {1}'.format( + startup_node, message) + ) + + # If there's only one server in the cluster, its ``host`` is '' + # Fix it to the host in startup_nodes + if (len(cluster_slots) == 1 + and len(cluster_slots[0][2][0]) == 0 + and len(self.startup_nodes) == 1): + cluster_slots[0][2][0] = startup_node.host + + for slot in cluster_slots: + primary_node = slot[2] + host = primary_node[0] + if host == "": + host = startup_node.host + port = int(primary_node[1]) + + target_node = tmp_nodes_cache.get(get_node_name(host, port)) + if target_node is None: + target_node = ClusterNode(host, port, PRIMARY) + # add this node to the nodes cache + tmp_nodes_cache[target_node.name] = target_node + + for i in range(int(slot[0]), int(slot[1]) + 1): + if i not in tmp_slots: + tmp_slots[i] = [] + tmp_slots[i].append(target_node) + replica_nodes = [slot[j] for j in range(3, len(slot))] + + for replica_node in replica_nodes: + host = replica_node[0] + port = replica_node[1] + + target_replica_node = tmp_nodes_cache.get( + get_node_name(host, port)) + if target_replica_node is None: + target_replica_node = ClusterNode( + host, port, REPLICA) + tmp_slots[i].append(target_replica_node) + # add this node to the nodes cache + tmp_nodes_cache[ + target_replica_node.name + ] = target_replica_node + else: + # Validate that 2 nodes want to use the same slot cache + # setup + if tmp_slots[i][0].name != target_node.name: + disagreements.append( + '{0} vs {1} on slot: {2}'.format( + tmp_slots[i][0].name, target_node.name, i) + ) + + if len(disagreements) > 5: + raise RedisClusterException( + 'startup_nodes could not agree on a valid' + ' slots cache: {0}'.format( + ", ".join(disagreements)) + ) + + if not startup_nodes_reachable: + raise RedisClusterException( + "Redis Cluster cannot be connected. Please provide at least " + "one reachable node. " + ) + + # Create Redis connections to all nodes + self.create_redis_connections(list(tmp_nodes_cache.values())) + + fully_covered = self.check_slots_coverage(tmp_slots) + if not fully_covered: + if self._require_full_coverage: + # Despite the requirement that the slots be covered, there + # isn't a full coverage + raise RedisClusterException( + 'All slots are not covered after query all startup_nodes.' + ' {0} of {1} covered...'.format( + len(self.slots_cache), REDIS_CLUSTER_HASH_SLOTS) + ) + else: + # The user set require_full_coverage to False. + # In case of full coverage requirement in the cluster's Redis + # configurations, we will raise an exception. Otherwise, we may + # continue with partial coverage. + # see Redis Cluster configuration parameters in + # https://redis.io/topics/cluster-tutorial + if not self._skip_full_coverage_check and \ + self.cluster_require_full_coverage(tmp_nodes_cache): + raise RedisClusterException( + 'Not all slots are covered but the cluster\'s ' + 'configuration requires full coverage. Set ' + 'cluster-require-full-coverage configuration to no on ' + 'all of the cluster nodes if you wish the cluster to ' + 'be able to serve without being fully covered.' 
+ ' {0} of {1} covered...'.format( + len(self.slots_cache), REDIS_CLUSTER_HASH_SLOTS) + ) + + # Set the tmp variables to the real variables + self.nodes_cache = tmp_nodes_cache + self.slots_cache = tmp_slots + # Populate the startup nodes with all discovered nodes + self.populate_startup_nodes(self.nodes_cache.values()) + + def close(self): + for node in self.nodes_cache.values(): + if node.redis_connection: + node.redis_connection.close() + + def reset(self): + if self.read_load_balancer is not None: + self.read_load_balancer.reset() + + +class ClusterPubSub(PubSub): + """ + Wrapper for PubSub class. + + IMPORTANT: before using ClusterPubSub, read about the known limitations + with pubsub in Cluster mode and learn how to workaround them: + https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html + """ + + def __init__(self, redis_cluster, node=None, host=None, port=None, + **kwargs): + """ + When a pubsub instance is created without specifying a node, a single + node will be transparently chosen for the pubsub connection on the + first command execution. The node will be determined by: + 1. Hashing the channel name in the request to find its keyslot + 2. Selecting a node that handles the keyslot: If read_from_replicas is + set to true, a replica can be selected. + """ + self.node = None + connection_pool = None + if host is not None and port is not None: + node = redis_cluster.get_node(host=host, port=port) + self.node = node + if node is not None: + if not isinstance(node, ClusterNode): + raise DataError("'node' must be a ClusterNode") + connection_pool = redis_cluster.get_redis_connection(node). \ + connection_pool + self.cluster = redis_cluster + super().__init__(**kwargs, connection_pool=connection_pool, + encoder=redis_cluster.encoder) + + def execute_command(self, *args, **kwargs): + """ + Execute a publish/subscribe command. + + Taken code from redis-py and tweak to make it work within a cluster. + """ + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + if self.connection is None: + if self.connection_pool is None: + if len(args) > 1: + # Hash the first channel and get one of the nodes holding + # this slot + channel = args[1] + slot = self.cluster.keyslot(channel) + node = self.cluster.nodes_manager. \ + get_node_from_slot(slot, self.cluster. + read_from_replicas) + else: + # Get a random node + node = self.cluster.get_random_node() + self.node = node + redis_connection = self.cluster.get_redis_connection(node) + self.connection_pool = redis_connection.connection_pool + self.connection = self.connection_pool.get_connection( + 'pubsub', + self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection.register_connect_callback(self.on_connect) + connection = self.connection + self._execute(connection, connection.send_command, *args) + + def get_redis_connection(self): + """ + Get the Redis connection of the pubsub connected node. 
+ """ + if self.node is not None: + return self.node.redis_connection + + +ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, + MovedError, AskError, TryAgainError) + + +class ClusterPipeline(RedisCluster): + """ + Support for Redis pipeline + in cluster mode + """ + + def __init__(self, nodes_manager, result_callbacks=None, + cluster_response_callbacks=None, startup_nodes=None, + read_from_replicas=False, cluster_error_retry_attempts=3, + debug=False, **kwargs): + """ + """ + self.command_stack = [] + self.debug_mode = debug + self.nodes_manager = nodes_manager + self.refresh_table_asap = False + self.result_callbacks = (result_callbacks or + self.__class__.RESULT_CALLBACKS.copy()) + self.startup_nodes = startup_nodes if startup_nodes else [] + self.read_from_replicas = read_from_replicas + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.cluster_response_callbacks = cluster_response_callbacks + self.cluster_error_retry_attempts = cluster_error_retry_attempts + + self.encoder = Encoder( + kwargs.get("encoding", "utf-8"), + kwargs.get("encoding_errors", "strict"), + kwargs.get("decode_responses", False), + ) + + # The commands parser refers to the parent + # so that we don't push the COMMAND command + # onto the stack + self.commands_parser = CommandsParser(super()) + + def __repr__(self): + """ + """ + return "{0}".format(type(self).__name__) + + def __enter__(self): + """ + """ + return self + + def __exit__(self, exc_type, exc_value, traceback): + """ + """ + self.reset() + + def __del__(self): + try: + self.reset() + except Exception: + pass + + def __len__(self): + """ + """ + return len(self.command_stack) + + def __nonzero__(self): + "Pipeline instances should always evaluate to True on Python 2.7" + return True + + def __bool__(self): + "Pipeline instances should always evaluate to True on Python 3+" + return True + + def execute_command(self, *args, **kwargs): + """ + """ + return self.pipeline_execute_command(*args, **kwargs) + + def pipeline_execute_command(self, *args, **options): + """ + """ + self.command_stack.append( + PipelineCommand(args, options, len(self.command_stack))) + return self + + def raise_first_error(self, stack): + """ + """ + for c in stack: + r = c.result + if isinstance(r, Exception): + self.annotate_exception(r, c.position + 1, c.args) + raise r + + def annotate_exception(self, exception, number, command): + """ + """ + cmd = ' '.join(map(safe_str, command)) + msg = 'Command # %d (%s) of pipeline caused error: %s' % ( + number, cmd, exception.args[0]) + exception.args = (msg,) + exception.args[1:] + + def execute(self, raise_on_error=True): + """ + """ + stack = self.command_stack + + if not stack: + return [] + + try: + return self.send_cluster_commands(stack, raise_on_error) + finally: + self.reset() + + def reset(self): + """ + Reset back to empty pipeline. 
+ """ + self.command_stack = [] + + self.scripts = set() + + # TODO: Implement + # make sure to reset the connection state in the event that we were + # watching something + # if self.watching and self.connection: + # try: + # # call this manually since our unwatch or + # # immediate_execute_command methods can call reset() + # self.connection.send_command('UNWATCH') + # self.connection.read_response() + # except ConnectionError: + # # disconnect will also remove any previous WATCHes + # self.connection.disconnect() + + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + + # TODO: Implement + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + # if self.connection: + # self.connection_pool.release(self.connection) + # self.connection = None + + def send_cluster_commands(self, stack, + raise_on_error=True, allow_redirections=True): + """ + Wrapper for CLUSTERDOWN error handling. + + If the cluster reports it is down it is assumed that: + - connection_pool was disconnected + - connection_pool was reseted + - refereh_table_asap set to True + + It will try the number of times specified by + the config option "self.cluster_error_retry_attempts" + which defaults to 3 unless manually configured. + + If it reaches the number of times, the command will + raises ClusterDownException. + """ + for _ in range(0, self.cluster_error_retry_attempts): + try: + return self._send_cluster_commands( + stack, + raise_on_error=raise_on_error, + allow_redirections=allow_redirections, + ) + except ClusterDownError: + # Try again with the new cluster setup. All other errors + # should be raised. + pass + + # If it fails the configured number of times then raise + # exception back to caller of this method + raise ClusterDownError( + "CLUSTERDOWN error. Unable to rebuild the cluster") + + def _send_cluster_commands(self, stack, + raise_on_error=True, + allow_redirections=True): + """ + Send a bunch of cluster commands to the redis cluster. + + `allow_redirections` If the pipeline should follow + `ASK` & `MOVED` responses automatically. If set + to false it will raise RedisClusterException. + """ + # the first time sending the commands we send all of + # the commands that were queued up. + # if we have to run through it again, we only retry + # the commands that failed. + attempt = sorted(stack, key=lambda x: x.position) + + # build a list of node objects based on node names we need to + nodes = {} + + # as we move through each command that still needs to be processed, + # we figure out the slot number that command maps to, then from + # the slot determine the node. + for c in attempt: + # refer to our internal node -> slot table that + # tells us where a given + # command should route to. + slot = self.determine_slot(*c.args) + node = self.nodes_manager.get_node_from_slot( + slot, self.read_from_replicas and c.args[0] in READ_COMMANDS) + + # now that we know the name of the node + # ( it's just a string in the form of host:port ) + # we can build a list of commands for each node. + node_name = node.name + if node_name not in nodes: + redis_node = self.get_redis_connection(node) + connection = get_connection(redis_node, c.args) + nodes[node_name] = NodeCommands(redis_node.parse_response, + redis_node.connection_pool, + connection) + + nodes[node_name].append(c) + + # send the commands in sequence. 
+        # we write to all the open sockets for each node first,
+        # before reading anything
+        # this allows us to flush all the requests out across the
+        # network essentially in parallel
+        # so that we can read them all in parallel as they come back.
+        # we don't multiplex on the sockets as they become available,
+        # but that shouldn't make too much difference.
+        node_commands = nodes.values()
+        for n in node_commands:
+            n.write()
+
+        for n in node_commands:
+            n.read()
+
+        # release all of the redis connections we allocated earlier
+        # back into the connection pool.
+        # we used to do this step as part of a try/finally block,
+        # but it is really dangerous to
+        # release connections back into the pool if for some
+        # reason the socket has data still left in it
+        # from a previous operation. The write and
+        # read operations already have try/catch around them for
+        # all known types of errors including connection
+        # and socket level errors.
+        # So if we hit an exception, something really bad
+        # happened and putting any of
+        # these connections back into the pool is a very bad idea.
+        # the socket might have an unread buffer still sitting in it,
+        # and then the next time we read from it we would pass that
+        # buffered result back from a previous command, and every
+        # request sent on that connection afterwards would get
+        # a mismatched result.
+        for n in nodes.values():
+            n.connection_pool.release(n.connection)
+
+        # if the response isn't an exception, it is a
+        # valid response from the node
+        # we're all done with that command, YAY!
+        # if we have more commands to attempt, we've run into problems.
+        # collect all the commands we are allowed to retry.
+        # (MOVED, ASK, or connection errors or timeout errors)
+        attempt = sorted([c for c in attempt
+                          if isinstance(c.result, ERRORS_ALLOW_RETRY)],
+                         key=lambda x: x.position)
+        if attempt and allow_redirections:
+            # RETRY MAGIC HAPPENS HERE!
+            # send these remaining commands one at a time using
+            # `execute_command` in the main client. This keeps our retry
+            # logic in one place mostly,
+            # and allows us to be more confident in correctness of behavior.
+            # at this point any speed gains from pipelining have been lost
+            # anyway, so we might as well make the best
+            # attempt to get the correct behavior.
+            #
+            # The client command will handle retries for each
+            # individual command sequentially as we pass each
+            # one into `execute_command`. Any exceptions
+            # that bubble out should only appear once all
+            # retries have been exhausted.
+            #
+            # If a lot of commands have failed, we'll be setting the
+            # flag to rebuild the slots table from scratch.
+            # So MOVED errors should correct themselves fairly quickly.
+            self.connection_pool.nodes. \
+                increment_reinitialize_counter(len(attempt))
+            for c in attempt:
+                try:
+                    # send each command individually like we
+                    # do in the main client.
+                    c.result = super(ClusterPipeline, self). \
+                        execute_command(*c.args, **c.options)
+                except RedisError as e:
+                    c.result = e
+
+        # turn the response back into a simple flat array that corresponds
+        # to the sequence of commands issued in the stack in pipeline.execute()
+        response = [c.result for c in sorted(stack, key=lambda x: x.position)]
+
+        if raise_on_error:
+            self.raise_first_error(stack)
+
+        return response
+
+    def _fail_on_redirect(self, allow_redirections):
+        """
+        """
+        if not allow_redirections:
+            raise RedisClusterException(
+                "ASK & MOVED redirection not allowed in this pipeline")
+
+    def multi(self):
+        """
+        """
+        raise RedisClusterException("method multi() is not implemented")
+
+    def immediate_execute_command(self, *args, **options):
+        """
+        """
+        raise RedisClusterException(
+            "method immediate_execute_command() is not implemented")
+
+    def _execute_transaction(self, *args, **kwargs):
+        """
+        """
+        raise RedisClusterException(
+            "method _execute_transaction() is not implemented")
+
+    def load_scripts(self):
+        """
+        """
+        raise RedisClusterException(
+            "method load_scripts() is not implemented")
+
+    def watch(self, *names):
+        """
+        """
+        raise RedisClusterException("method watch() is not implemented")
+
+    def unwatch(self):
+        """
+        """
+        raise RedisClusterException("method unwatch() is not implemented")
+
+    def script_load_for_pipeline(self, *args, **kwargs):
+        """
+        """
+        raise RedisClusterException(
+            "method script_load_for_pipeline() is not implemented")
+
+    def delete(self, *names):
+        """
+        Delete a key specified by ``names``
+        """
+        if len(names) != 1:
+            raise RedisClusterException(
+                "deleting multiple keys is not "
+                "implemented in pipeline command")
+
+        return self.execute_command('DEL', names[0])
+
+
+def block_pipeline_command(func):
+    """
+    Raises an error because some pipelined commands should
+    be blocked when running in cluster mode
+    """
+
+    def inner(*args, **kwargs):
+        raise RedisClusterException(
+            "ERROR: Calling pipelined function {0} is blocked when "
+            "running redis in cluster mode...".format(func.__name__))
+
+    return inner
+
+
+# Blocked pipeline commands
+ClusterPipeline.bitop = block_pipeline_command(RedisCluster.bitop)
+ClusterPipeline.brpoplpush = block_pipeline_command(RedisCluster.brpoplpush)
+ClusterPipeline.client_getname = \
+    block_pipeline_command(RedisCluster.client_getname)
+ClusterPipeline.client_list = block_pipeline_command(RedisCluster.client_list)
+ClusterPipeline.client_setname = \
+    block_pipeline_command(RedisCluster.client_setname)
+ClusterPipeline.config_set = block_pipeline_command(RedisCluster.config_set)
+ClusterPipeline.dbsize = block_pipeline_command(RedisCluster.dbsize)
+ClusterPipeline.flushall = block_pipeline_command(RedisCluster.flushall)
+ClusterPipeline.flushdb = block_pipeline_command(RedisCluster.flushdb)
+ClusterPipeline.keys = block_pipeline_command(RedisCluster.keys)
+ClusterPipeline.mget = block_pipeline_command(RedisCluster.mget)
+ClusterPipeline.move = block_pipeline_command(RedisCluster.move)
+ClusterPipeline.mset = block_pipeline_command(RedisCluster.mset)
+ClusterPipeline.msetnx = block_pipeline_command(RedisCluster.msetnx)
+ClusterPipeline.pfmerge = block_pipeline_command(RedisCluster.pfmerge)
+ClusterPipeline.pfcount = block_pipeline_command(RedisCluster.pfcount)
+ClusterPipeline.ping = block_pipeline_command(RedisCluster.ping)
+ClusterPipeline.publish = block_pipeline_command(RedisCluster.publish)
+ClusterPipeline.randomkey = block_pipeline_command(RedisCluster.randomkey)
+ClusterPipeline.rename = block_pipeline_command(RedisCluster.rename)
+ClusterPipeline.renamenx = block_pipeline_command(RedisCluster.renamenx) +ClusterPipeline.rpoplpush = block_pipeline_command(RedisCluster.rpoplpush) +ClusterPipeline.scan = block_pipeline_command(RedisCluster.scan) +ClusterPipeline.sdiff = block_pipeline_command(RedisCluster.sdiff) +ClusterPipeline.sdiffstore = block_pipeline_command(RedisCluster.sdiffstore) +ClusterPipeline.sinter = block_pipeline_command(RedisCluster.sinter) +ClusterPipeline.sinterstore = block_pipeline_command(RedisCluster.sinterstore) +ClusterPipeline.smove = block_pipeline_command(RedisCluster.smove) +ClusterPipeline.sort = block_pipeline_command(RedisCluster.sort) +ClusterPipeline.sunion = block_pipeline_command(RedisCluster.sunion) +ClusterPipeline.sunionstore = block_pipeline_command(RedisCluster.sunionstore) + + +class PipelineCommand(object): + """ + """ + + def __init__(self, args, options=None, position=None): + self.args = args + if options is None: + options = {} + self.options = options + self.position = position + self.result = None + self.node = None + self.asking = False + + +class NodeCommands(object): + """ + """ + + def __init__(self, parse_response, connection_pool, connection): + """ + """ + self.parse_response = parse_response + self.connection_pool = connection_pool + self.connection = connection + self.commands = [] + + def append(self, c): + """ + """ + self.commands.append(c) + + def write(self): + """ + Code borrowed from Redis so it can be fixed + """ + connection = self.connection + commands = self.commands + + # We are going to clobber the commands with the write, so go ahead + # and ensure that nothing is sitting there from a previous run. + for c in commands: + c.result = None + + # build up all commands into a single request to increase network perf + # send all the commands and catch connection and timeout errors. + try: + connection.send_packed_command( + connection.pack_commands([c.args for c in commands])) + except (ConnectionError, TimeoutError) as e: + for c in commands: + c.result = e + + def read(self): + """ + """ + connection = self.connection + for c in self.commands: + + # if there is a result on this command, + # it means we ran into an exception + # like a connection error. Trying to parse + # a response on a connection that + # is no longer open will result in a + # connection error raised by redis-py. + # but redis-py doesn't check in parse_response + # that the sock object is + # still set and if you try to + # read from a closed connection, it will + # result in an AttributeError because + # it will do a readline() call on None. + # This can have all kinds of nasty side-effects. + # Treating this case as a connection error + # is fine because it will dump + # the connection object back into the + # pool and on the next write, it will + # explicitly open the connection and all will be well. 
+            if c.result is None:
+                try:
+                    c.result = self.parse_response(
+                        connection, c.args[0], **c.options)
+                except (ConnectionError, TimeoutError) as e:
+                    for c in self.commands:
+                        c.result = e
+                    return
+                except RedisError:
+                    c.result = sys.exc_info()[1]
diff --git a/redis/commands.py b/redis/commands.py
new file mode 100644
index 0000000000..af9bc68ee5
--- /dev/null
+++ b/redis/commands.py
@@ -0,0 +1,3935 @@
+import datetime
+import time
+import warnings
+import hashlib
+
+from .helpers import list_or_args
+from redis.exceptions import (
+    ConnectionError,
+    DataError,
+    NoScriptError,
+    RedisError,
+    ResponseError
+)
+from redis.utils import str_if_bytes
+from redis.crc import key_slot
+
+
+class CommandsParser:
+    def __init__(self, redis_connection):
+        self.initialized = False
+        self.commands = {}
+        self.initialize(redis_connection)
+
+    def initialize(self, r):
+        self.commands = r.execute_command("COMMAND")
+
+    # As soon as this PR is merged into Redis, we should reimplement
+    # our logic to use COMMAND INFO changes to determine the key positions
+    # https://github.com/redis/redis/pull/8324
+    def get_keys(self, redis_conn, *args):
+        """
+        Get the keys from the passed command
+        """
+        if len(args) < 2:
+            # The command has no keys in it
+            return None
+
+        cmd_name = args[0].lower()
+        cmd_name_split = cmd_name.split()
+        if len(cmd_name_split) > 1:
+            # we need to take only the main command, e.g. 'memory' for
+            # 'memory usage'
+            cmd_name = cmd_name_split[0]
+        if cmd_name not in self.commands:
+            # We'll try to reinitialize the commands cache; if the engine
+            # version has changed, the cached commands may not be current
+            self.initialize(redis_conn)
+            if cmd_name not in self.commands:
+                raise RedisError("{0} command doesn't exist in Redis commands".
+                                 format(cmd_name.upper()))
+
+        command = self.commands.get(cmd_name)
+        if 'movablekeys' in command['flags']:
+            keys = self.get_moveable_keys(redis_conn, *args)
+        elif 'pubsub' in command['flags']:
+            keys = self.get_pubsub_keys(*args)
+        else:
+            if command['step_count'] == 0 and command['first_key_pos'] == 0 \
+                    and command['last_key_pos'] == 0:
+                # The command doesn't have keys in it
+                return None
+            last_key_pos = command['last_key_pos']
+            if last_key_pos == -1:
+                last_key_pos = len(args) - 1
+            keys_pos = list(range(command['first_key_pos'], last_key_pos + 1,
+                                  command['step_count']))
+            keys = [args[pos] for pos in keys_pos]
+
+        return keys
+
+    def get_moveable_keys(self, redis_conn, *args):
+        try:
+            pieces = []
+            cmd_name = args[0]
+            for arg in cmd_name.split():
+                # The command name should be split into separate arguments,
+                # e.g. 'MEMORY USAGE' will be split into ['MEMORY', 'USAGE']
+                pieces.append(arg)
+            pieces += args[1:]
+            keys = redis_conn.execute_command('COMMAND GETKEYS', *pieces)
+        except ResponseError as e:
+            message = e.__str__()
+            if 'Invalid arguments' in message or \
+                    'The command has no key arguments' in message:
+                return None
+            else:
+                raise e
+        return keys
+
+    def get_pubsub_keys(self, *args):
+        """
+        Get the keys from a pubsub command.
+        Although PubSub commands have predetermined key locations, they are
+        not reported in the 'COMMAND' output, so the key positions are
+        hardcoded in this method.
+        """
+        if len(args) < 2:
+            # The command has no keys in it
+            return None
+        args = [str_if_bytes(arg) for arg in args]
+        command = args[0].upper()
+        if command in ['PUBLISH', 'PUBSUB CHANNELS']:
+            # format example:
+            # PUBLISH channel message
+            keys = [args[1]]
+        elif command in ['SUBSCRIBE', 'PSUBSCRIBE', 'UNSUBSCRIBE',
+                         'PUNSUBSCRIBE', 'PUBSUB NUMSUB']:
+            keys = list(args[1:])
+        else:
+            keys = None
+        return keys
+
+
+class CoreCommands:
+    """
+    A class containing all of the implemented redis commands. This class is
+    to be used as a mixin.
+    """
+
+    # SERVER INFORMATION
+
+
+class AclCommands:
+    # ACL methods
+    def acl_cat(self, category=None):
+        """
+        Returns a list of categories or commands within a category.
+
+        If ``category`` is not supplied, returns a list of all categories.
+        If ``category`` is supplied, returns a list of all commands within
+        that category.
+        """
+        pieces = [category] if category else []
+        return self.execute_command('ACL CAT', *pieces)
+
+    def acl_deluser(self, username):
+        "Delete the ACL for the specified ``username``"
+        return self.execute_command('ACL DELUSER', username)
+
+    def acl_genpass(self):
+        "Generate a random password value"
+        return self.execute_command('ACL GENPASS')
+
+    def acl_getuser(self, username):
+        """
+        Get the ACL details for the specified ``username``.
+
+        If ``username`` does not exist, return None
+        """
+        return self.execute_command('ACL GETUSER', username)
+
+    def acl_list(self):
+        "Return a list of all ACLs on the server"
+        return self.execute_command('ACL LIST')
+
+    def acl_log(self, count=None):
+        """
+        Get ACL logs as a list.
+        :param int count: Get logs[0:count].
+        :rtype: List.
+        """
+        args = []
+        if count is not None:
+            if not isinstance(count, int):
+                raise DataError('ACL LOG count must be an '
+                                'integer')
+            args.append(count)
+
+        return self.execute_command('ACL LOG', *args)
+
+    def acl_log_reset(self):
+        """
+        Reset ACL logs.
+        :rtype: Boolean.
+        """
+        args = [b'RESET']
+        return self.execute_command('ACL LOG', *args)
+
+    def acl_load(self):
+        """
+        Load ACL rules from the configured ``aclfile``.
+
+        Note that the server must be configured with the ``aclfile``
+        directive to be able to load ACL rules from an aclfile.
+        """
+        return self.execute_command('ACL LOAD')
+
+    def acl_save(self):
+        """
+        Save ACL rules to the configured ``aclfile``.
+
+        Note that the server must be configured with the ``aclfile``
+        directive to be able to save ACL rules to an aclfile.
+        """
+        return self.execute_command('ACL SAVE')
+
+    def acl_setuser(self, username, enabled=False, nopass=False,
+                    passwords=None, hashed_passwords=None, categories=None,
+                    commands=None, keys=None, reset=False, reset_keys=False,
+                    reset_passwords=False):
+        """
+        Create or update an ACL user.
+
+        Create or update the ACL for ``username``. If the user already exists,
+        the existing ACL is completely overwritten and replaced with the
+        specified values.
+
+        ``enabled`` is a boolean indicating whether the user should be allowed
+        to authenticate or not. Defaults to ``False``.
+
+        ``nopass`` is a boolean indicating whether the user can authenticate
+        without a password. This cannot be True if ``passwords`` are also
+        specified.
+
+        ``passwords`` if specified is a list of plain text passwords
+        to add to or remove from the user. Each password must be prefixed with
+        a '+' to add or a '-' to remove.
For convenience, the value of + ``passwords`` can be a simple prefixed string when adding or + removing a single password. + + ``hashed_passwords`` if specified is a list of SHA-256 hashed passwords + to add to or remove from the user. Each hashed password must be + prefixed with a '+' to add or a '-' to remove. For convenience, + the value of ``hashed_passwords`` can be a simple prefixed string when + adding or removing a single password. + + ``categories`` if specified is a list of strings representing category + permissions. Each string must be prefixed with either a '+' to add the + category permission or a '-' to remove the category permission. + + ``commands`` if specified is a list of strings representing command + permissions. Each string must be prefixed with either a '+' to add the + command permission or a '-' to remove the command permission. + + ``keys`` if specified is a list of key patterns to grant the user + access to. Keys patterns allow '*' to support wildcard matching. For + example, '*' grants access to all keys while 'cache:*' grants access + to all keys that are prefixed with 'cache:'. ``keys`` should not be + prefixed with a '~'. + + ``reset`` is a boolean indicating whether the user should be fully + reset prior to applying the new ACL. Setting this to True will + remove all existing passwords, flags and privileges from the user and + then apply the specified rules. If this is False, the user's existing + passwords, flags and privileges will be kept and any new specified + rules will be applied on top. + + ``reset_keys`` is a boolean indicating whether the user's key + permissions should be reset prior to applying any new key permissions + specified in ``keys``. If this is False, the user's existing + key permissions will be kept and any new specified key permissions + will be applied on top. + + ``reset_passwords`` is a boolean indicating whether to remove all + existing passwords and the 'nopass' flag from the user prior to + applying any new passwords specified in 'passwords' or + 'hashed_passwords'. If this is False, the user's existing passwords + and 'nopass' status will be kept and any new specified passwords + or hashed_passwords will be applied on top. 
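+
+        A brief illustrative example (hypothetical user name and rules, not
+        part of the original docstring):
+
+        >>> r.acl_setuser('app-user', enabled=True,
+        ...               passwords=['+app-secret'],
+        ...               categories=['+@read'],
+        ...               keys=['cache:*'])
+        True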
+ """ + encoder = self.connection_pool.get_encoder() + pieces = [username] + + if reset: + pieces.append(b'reset') + + if reset_keys: + pieces.append(b'resetkeys') + + if reset_passwords: + pieces.append(b'resetpass') + + if enabled: + pieces.append(b'on') + else: + pieces.append(b'off') + + if (passwords or hashed_passwords) and nopass: + raise DataError('Cannot set \'nopass\' and supply ' + '\'passwords\' or \'hashed_passwords\'') + + if passwords: + # as most users will have only one password, allow remove_passwords + # to be specified as a simple string or a list + passwords = list_or_args(passwords, []) + for i, password in enumerate(passwords): + password = encoder.encode(password) + if password.startswith(b'+'): + pieces.append(b'>%s' % password[1:]) + elif password.startswith(b'-'): + pieces.append(b'<%s' % password[1:]) + else: + raise DataError('Password %d must be prefixeed with a ' + '"+" to add or a "-" to remove' % i) + + if hashed_passwords: + # as most users will have only one password, allow remove_passwords + # to be specified as a simple string or a list + hashed_passwords = list_or_args(hashed_passwords, []) + for i, hashed_password in enumerate(hashed_passwords): + hashed_password = encoder.encode(hashed_password) + if hashed_password.startswith(b'+'): + pieces.append(b'#%s' % hashed_password[1:]) + elif hashed_password.startswith(b'-'): + pieces.append(b'!%s' % hashed_password[1:]) + else: + raise DataError('Hashed %d password must be prefixeed ' + 'with a "+" to add or a "-" to remove' % i) + + if nopass: + pieces.append(b'nopass') + + if categories: + for category in categories: + category = encoder.encode(category) + # categories can be prefixed with one of (+@, +, -@, -) + if category.startswith(b'+@'): + pieces.append(category) + elif category.startswith(b'+'): + pieces.append(b'+@%s' % category[1:]) + elif category.startswith(b'-@'): + pieces.append(category) + elif category.startswith(b'-'): + pieces.append(b'-@%s' % category[1:]) + else: + raise DataError('Category "%s" must be prefixed with ' + '"+" or "-"' + % encoder.decode(category, force=True)) + if commands: + for cmd in commands: + cmd = encoder.encode(cmd) + if not cmd.startswith(b'+') and not cmd.startswith(b'-'): + raise DataError('Command "%s" must be prefixed with ' + '"+" or "-"' + % encoder.decode(cmd, force=True)) + pieces.append(cmd) + + if keys: + for key in keys: + key = encoder.encode(key) + pieces.append(b'~%s' % key) + + return self.execute_command('ACL SETUSER', *pieces) + + def acl_users(self): + "Returns a list of all registered users on the server." + return self.execute_command('ACL USERS') + + def acl_whoami(self): + "Get the username for the current connection" + return self.execute_command('ACL WHOAMI') + + +class ManagementCommands: + def bgrewriteaof(self): + "Tell the Redis server to rewrite the AOF file from data in memory." + return self.execute_command('BGREWRITEAOF') + + def bgsave(self): + """ + Tell the Redis server to save its data to disk. Unlike save(), + this method is asynchronous and returns immediately. 
+ """ + return self.execute_command('BGSAVE') + + def client_kill(self, address): + "Disconnects the client at ``address`` (ip:port)" + return self.execute_command('CLIENT KILL', address) + + def client_kill_filter(self, _id=None, _type=None, addr=None, + skipme=None, laddr=None): + """ + Disconnects client(s) using a variety of filter options + :param id: Kills a client by its unique ID field + :param type: Kills a client by type where type is one of 'normal', + 'master', 'slave' or 'pubsub' + :param addr: Kills a client by its 'address:port' + :param skipme: If True, then the client calling the command + :param laddr: Kills a cient by its 'local (bind) address:port' + will not get killed even if it is identified by one of the filter + options. If skipme is not provided, the server defaults to skipme=True + """ + args = [] + if _type is not None: + client_types = ('normal', 'master', 'slave', 'pubsub') + if str(_type).lower() not in client_types: + raise DataError("CLIENT KILL type must be one of %r" % ( + client_types,)) + args.extend((b'TYPE', _type)) + if skipme is not None: + if not isinstance(skipme, bool): + raise DataError("CLIENT KILL skipme must be a bool") + if skipme: + args.extend((b'SKIPME', b'YES')) + else: + args.extend((b'SKIPME', b'NO')) + if _id is not None: + args.extend((b'ID', _id)) + if addr is not None: + args.extend((b'ADDR', addr)) + if laddr is not None: + args.extend((b'LADDR', laddr)) + if not args: + raise DataError("CLIENT KILL ... ... " + " must specify at least one filter") + return self.execute_command('CLIENT KILL', *args) + + def client_info(self): + """ + Returns information and statistics about the current + client connection. + """ + return self.execute_command('CLIENT INFO') + + def client_list(self, _type=None, client_id=None): + """ + Returns a list of currently connected clients. + If type of client specified, only that type will be returned. + :param _type: optional. one of the client types (normal, master, + replica, pubsub) + """ + "Returns a list of currently connected clients" + args = [] + if _type is not None: + client_types = ('normal', 'master', 'replica', 'pubsub') + if str(_type).lower() not in client_types: + raise DataError("CLIENT LIST _type must be one of %r" % ( + client_types,)) + args.append(b'TYPE') + args.append(_type) + if client_id is not None: + args.append(b"ID") + args.append(client_id) + return self.execute_command('CLIENT LIST', *args) + + def client_getname(self): + "Returns the current connection name" + return self.execute_command('CLIENT GETNAME') + + def client_id(self): + "Returns the current connection id" + return self.execute_command('CLIENT ID') + + def client_setname(self, name): + "Sets the current connection name" + return self.execute_command('CLIENT SETNAME', name) + + def client_unblock(self, client_id, error=False): + """ + Unblocks a connection by its client id. + If ``error`` is True, unblocks the client with a special error message. + If ``error`` is False (default), the client is unblocked using the + regular timeout mechanism. 
+ """ + args = ['CLIENT UNBLOCK', int(client_id)] + if error: + args.append(b'ERROR') + return self.execute_command(*args) + + def client_pause(self, timeout): + """ + Suspend all the Redis clients for the specified amount of time + :param timeout: milliseconds to pause clients + """ + if not isinstance(timeout, int): + raise DataError("CLIENT PAUSE timeout must be an integer") + return self.execute_command('CLIENT PAUSE', str(timeout)) + + def client_unpause(self): + """ + Unpause all redis clients + """ + return self.execute_command('CLIENT UNPAUSE') + + def readwrite(self): + "Disables read queries for a connection to a Redis Cluster slave node" + return self.execute_command('READWRITE') + + def readonly(self): + "Enables read queries for a connection to a Redis Cluster replica node" + return self.execute_command('READONLY') + + def config_get(self, pattern="*"): + "Return a dictionary of configuration based on the ``pattern``" + return self.execute_command('CONFIG GET', pattern) + + def config_set(self, name, value): + "Set config item ``name`` with ``value``" + return self.execute_command('CONFIG SET', name, value) + + def config_resetstat(self): + "Reset runtime statistics" + return self.execute_command('CONFIG RESETSTAT') + + def config_rewrite(self): + "Rewrite config file with the minimal change to reflect running config" + return self.execute_command('CONFIG REWRITE') + + def cluster(self, cluster_arg, *args): + return self.execute_command('CLUSTER %s' % cluster_arg.upper(), *args) + + def dbsize(self): + "Returns the number of keys in the current database" + return self.execute_command('DBSIZE') + + def debug_object(self, key): + "Returns version specific meta information about a given key" + return self.execute_command('DEBUG OBJECT', key) + + def echo(self, value): + "Echo the string back from the server" + return self.execute_command('ECHO', value) + + def flushall(self, asynchronous=False): + """ + Delete all keys in all databases on the current host. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if asynchronous: + args.append(b'ASYNC') + return self.execute_command('FLUSHALL', *args) + + def flushdb(self, asynchronous=False): + """ + Delete all keys in the current database. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if asynchronous: + args.append(b'ASYNC') + return self.execute_command('FLUSHDB', *args) + + def swapdb(self, first, second): + "Swap two databases" + return self.execute_command('SWAPDB', first, second) + + def info(self, section=None): + """ + Returns a dictionary containing information about the Redis server + + The ``section`` option can be used to select a specific section + of information + + The section option is not supported by older versions of Redis Server, + and will generate ResponseError + """ + if section is None: + return self.execute_command('INFO') + else: + return self.execute_command('INFO', section) + + def lastsave(self): + """ + Return a Python datetime object representing the last time the + Redis database was saved to disk + """ + return self.execute_command('LASTSAVE') + + def migrate(self, host, port, keys, destination_db, timeout, + copy=False, replace=False, auth=None): + """ + Migrate 1 or more keys from the current Redis server to a different + server specified by the ``host``, ``port`` and ``destination_db``. 
+ + The ``timeout``, specified in milliseconds, indicates the maximum + time the connection between the two servers can be idle before the + command is interrupted. + + If ``copy`` is True, the specified ``keys`` are NOT deleted from + the source server. + + If ``replace`` is True, this operation will overwrite the keys + on the destination server if they exist. + + If ``auth`` is specified, authenticate to the destination server with + the password provided. + """ + keys = list_or_args(keys, []) + if not keys: + raise DataError('MIGRATE requires at least one key') + pieces = [] + if copy: + pieces.append(b'COPY') + if replace: + pieces.append(b'REPLACE') + if auth: + pieces.append(b'AUTH') + pieces.append(auth) + pieces.append(b'KEYS') + pieces.extend(keys) + return self.execute_command('MIGRATE', host, port, '', destination_db, + timeout, *pieces) + + def object(self, infotype, key): + "Return the encoding, idletime, or refcount about the key" + return self.execute_command('OBJECT', infotype, key, infotype=infotype) + + def memory_stats(self): + "Return a dictionary of memory stats" + return self.execute_command('MEMORY STATS') + + def memory_usage(self, key, samples=None): + """ + Return the total memory usage for key, its value and associated + administrative overheads. + + For nested data structures, ``samples`` is the number of elements to + sample. If left unspecified, the server's default is 5. Use 0 to sample + all elements. + """ + args = [] + if isinstance(samples, int): + args.extend([b'SAMPLES', samples]) + return self.execute_command('MEMORY USAGE', key, *args) + + def memory_purge(self): + "Attempts to purge dirty pages for reclamation by allocator" + return self.execute_command('MEMORY PURGE') + + def ping(self): + "Ping the Redis server" + return self.execute_command('PING') + + def quit(self): + """ + Ask the server to close the connection. + https://redis.io/commands/quit + """ + return self.execute_command('QUIT') + + def replicaof(self, *args): + """ + Update the replication settings of a redis replica, on the fly. + Examples of valid arguments include: + NO ONE (set no replication) + host port (set to the host and port of a redis server) + see: https://redis.io/commands/replicaof + """ + return self.execute_command('REPLICAOF', *args) + + def save(self): + """ + Tell the Redis server to save its data to disk, + blocking until the save is complete + """ + return self.execute_command('SAVE') + + def shutdown(self, save=False, nosave=False): + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. If the "save" option is set, + a data flush will be attempted even if there is no persistence + configured. If the "nosave" option is set, no data flush will be + attempted. The "save" and "nosave" options cannot both be set. + """ + if save and nosave: + raise DataError('SHUTDOWN save and nosave cannot both be set') + args = ['SHUTDOWN'] + if save: + args.append('SAVE') + if nosave: + args.append('NOSAVE') + try: + self.execute_command(*args) + except ConnectionError: + # a ConnectionError here is expected + return + raise RedisError("SHUTDOWN seems to have failed.") + + def slaveof(self, host=None, port=None): + """ + Set the server to be a replicated slave of the instance identified + by the ``host`` and ``port``. If called without arguments, the + instance is promoted to a master instead. 
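+
+        Illustrative usage (hypothetical host and port):
+
+        >>> r.slaveof('10.0.0.1', 6379)  # replicate from 10.0.0.1:6379
+        >>> r.slaveof()                  # promote this instance to a master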
+ """ + if host is None and port is None: + return self.execute_command('SLAVEOF', b'NO', b'ONE') + return self.execute_command('SLAVEOF', host, port) + + def slowlog_get(self, num=None): + """ + Get the entries from the slowlog. If ``num`` is specified, get the + most recent ``num`` items. + """ + args = ['SLOWLOG GET'] + if num is not None: + args.append(num) + decode_responses = self.connection_pool.connection_kwargs.get( + 'decode_responses', False) + return self.execute_command(*args, decode_responses=decode_responses) + + def slowlog_len(self): + "Get the number of items in the slowlog" + return self.execute_command('SLOWLOG LEN') + + def slowlog_reset(self): + "Remove all items in the slowlog" + return self.execute_command('SLOWLOG RESET') + + def time(self): + """ + Returns the server time as a 2-item tuple of ints: + (seconds since epoch, microseconds into this second). + """ + return self.execute_command('TIME') + + def wait(self, num_replicas, timeout): + """ + Redis synchronous replication + That returns the number of replicas that processed the query when + we finally have at least ``num_replicas``, or when the ``timeout`` was + reached. + """ + return self.execute_command('WAIT', num_replicas, timeout) + + +class BasicKeyCommands: + # BASIC KEY COMMANDS + def append(self, key, value): + """ + Appends the string ``value`` to the value at ``key``. If ``key`` + doesn't already exist, create it with a value of ``value``. + Returns the new length of the value at ``key``. + """ + return self.execute_command('APPEND', key, value) + + def bitcount(self, key, start=None, end=None): + """ + Returns the count of set bits in the value of ``key``. Optional + ``start`` and ``end`` parameters indicate which bytes to consider + """ + params = [key] + if start is not None and end is not None: + params.append(start) + params.append(end) + elif (start is not None and end is None) or \ + (end is not None and start is None): + raise DataError("Both start and end must be specified") + return self.execute_command('BITCOUNT', *params) + + def bitfield(self, key, default_overflow=None): + """ + Return a BitFieldOperation instance to conveniently construct one or + more bitfield operations on ``key``. + """ + return BitFieldOperation(self, key, default_overflow=default_overflow) + + def bitop(self, operation, dest, *keys): + """ + Perform a bitwise operation using ``operation`` between ``keys`` and + store the result in ``dest``. + """ + return self.execute_command('BITOP', operation, dest, *keys) + + def bitpos(self, key, bit, start=None, end=None): + """ + Return the position of the first bit set to 1 or 0 in a string. + ``start`` and ``end`` defines search range. The range is interpreted + as a range of bytes and not a range of bits, so start=0 and end=2 + means to look at the first three bytes. + """ + if bit not in (0, 1): + raise DataError('bit must be 0 or 1') + params = [key, bit] + + start is not None and params.append(start) + + if start is not None and end is not None: + params.append(end) + elif start is None and end is not None: + raise DataError("start argument is not set, " + "when end is specified") + return self.execute_command('BITPOS', *params) + + def copy(self, source, destination, destination_db=None, replace=False): + """ + Copy the value stored in the ``source`` key to the ``destination`` key. + + ``destination_db`` an alternative destination database. By default, + the ``destination`` key is created in the source Redis database. 
+
+        ``replace`` whether the ``destination`` key should be removed before
+        copying the value to it. By default, the value is not copied if
+        the ``destination`` key already exists.
+        """
+        params = [source, destination]
+        if destination_db is not None:
+            params.extend(["DB", destination_db])
+        if replace:
+            params.append("REPLACE")
+        return self.execute_command('COPY', *params)
+
+    def decr(self, name, amount=1):
+        """
+        Decrements the value of ``key`` by ``amount``. If no key exists,
+        the value will be initialized as 0 - ``amount``
+        """
+        # An alias for ``decrby()``, because DECR is implemented
+        # via the DECRBY redis command.
+        return self.decrby(name, amount)
+
+    def decrby(self, name, amount=1):
+        """
+        Decrements the value of ``key`` by ``amount``. If no key exists,
+        the value will be initialized as 0 - ``amount``
+        """
+        return self.execute_command('DECRBY', name, amount)
+
+    def delete(self, *names):
+        "Delete one or more keys specified by ``names``"
+        return self.execute_command('DEL', *names)
+
+    def __delitem__(self, name):
+        self.delete(name)
+
+    def dump(self, name):
+        """
+        Return a serialized version of the value stored at the specified key.
+        If the key does not exist, a nil bulk reply is returned.
+        """
+        return self.execute_command('DUMP', name)
+
+    def exists(self, *names):
+        "Returns the number of ``names`` that exist"
+        return self.execute_command('EXISTS', *names)
+
+    __contains__ = exists
+
+    def expire(self, name, time):
+        """
+        Set an expire flag on key ``name`` for ``time`` seconds. ``time``
+        can be represented by an integer or a Python timedelta object.
+        """
+        if isinstance(time, datetime.timedelta):
+            time = int(time.total_seconds())
+        return self.execute_command('EXPIRE', name, time)
+
+    def expireat(self, name, when):
+        """
+        Set an expire flag on key ``name``. ``when`` can be represented
+        as an integer indicating unix time or a Python datetime object.
+        """
+        if isinstance(when, datetime.datetime):
+            when = int(time.mktime(when.timetuple()))
+        return self.execute_command('EXPIREAT', name, when)
+
+    def get(self, name):
+        """
+        Return the value at key ``name``, or None if the key doesn't exist
+        """
+        return self.execute_command('GET', name)
+
+    def getdel(self, name):
+        """
+        Get the value at key ``name`` and delete the key. This command
+        is similar to GET, except for the fact that it also deletes
+        the key on success (if and only if the key's value type
+        is a string).
+        """
+        return self.execute_command('GETDEL', name)
+
+    def getex(self, name,
+              ex=None, px=None, exat=None, pxat=None, persist=False):
+        """
+        Get the value of key and optionally set its expiration.
+        GETEX is similar to GET, but is a write command with
+        additional options. All time parameters can be given as
+        datetime.timedelta or integers.
+
+        ``ex`` sets an expire flag on key ``name`` for ``ex`` seconds.
+
+        ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds.
+
+        ``exat`` sets an expire flag on key ``name`` to expire at the time
+        given by ``exat``, specified as a unix timestamp in seconds.
+
+        ``pxat`` sets an expire flag on key ``name`` to expire at the time
+        given by ``pxat``, specified as a unix timestamp in milliseconds.
+
+        ``persist`` removes the time to live associated with ``name``.
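+
+        Illustrative example (assumes ``r`` is a connected client):
+
+        >>> r.set('foo', 'bar')
+        >>> r.getex('foo', ex=10)         # b'bar', and a 10 second TTL is set
+        >>> r.getex('foo', persist=True)  # b'bar', and the TTL is removed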
+ """ + + opset = set([ex, px, exat, pxat]) + if len(opset) > 2 or len(opset) > 1 and persist: + raise DataError("``ex``, ``px``, ``exat``, ``pxat``", + "and ``persist`` are mutually exclusive.") + + pieces = [] + # similar to set command + if ex is not None: + pieces.append('EX') + if isinstance(ex, datetime.timedelta): + ex = int(ex.total_seconds()) + pieces.append(ex) + if px is not None: + pieces.append('PX') + if isinstance(px, datetime.timedelta): + px = int(px.total_seconds() * 1000) + pieces.append(px) + # similar to pexpireat command + if exat is not None: + pieces.append('EXAT') + if isinstance(exat, datetime.datetime): + s = int(exat.microsecond / 1000000) + exat = int(time.mktime(exat.timetuple())) + s + pieces.append(exat) + if pxat is not None: + pieces.append('PXAT') + if isinstance(pxat, datetime.datetime): + ms = int(pxat.microsecond / 1000) + pxat = int(time.mktime(pxat.timetuple())) * 1000 + ms + pieces.append(pxat) + if persist: + pieces.append('PERSIST') + + return self.execute_command('GETEX', name, *pieces) + + def __getitem__(self, name): + """ + Return the value at key ``name``, raises a KeyError if the key + doesn't exist. + """ + value = self.get(name) + if value is not None: + return value + raise KeyError(name) + + def getbit(self, name, offset): + "Returns a boolean indicating the value of ``offset`` in ``name``" + return self.execute_command('GETBIT', name, offset) + + def getrange(self, key, start, end): + """ + Returns the substring of the string value stored at ``key``, + determined by the offsets ``start`` and ``end`` (both are inclusive) + """ + return self.execute_command('GETRANGE', key, start, end) + + def getset(self, name, value): + """ + Sets the value at key ``name`` to ``value`` + and returns the old value at key ``name`` atomically. + + As per Redis 6.2, GETSET is considered deprecated. + Please use SET with GET parameter in new code. + """ + return self.execute_command('GETSET', name, value) + + def incr(self, name, amount=1): + """ + Increments the value of ``key`` by ``amount``. If no key exists, + the value will be initialized as ``amount`` + """ + return self.incrby(name, amount) + + def incrby(self, name, amount=1): + """ + Increments the value of ``key`` by ``amount``. If no key exists, + the value will be initialized as ``amount`` + """ + # An alias for ``incr()``, because it is already implemented + # as INCRBY redis command. + return self.execute_command('INCRBY', name, amount) + + def incrbyfloat(self, name, amount=1.0): + """ + Increments the value at key ``name`` by floating ``amount``. + If no key exists, the value will be initialized as ``amount`` + """ + return self.execute_command('INCRBYFLOAT', name, amount) + + def keys(self, pattern='*'): + "Returns a list of keys matching ``pattern``" + return self.execute_command('KEYS', pattern) + + def lmove(self, first_list, second_list, src="LEFT", dest="RIGHT"): + """ + Atomically returns and removes the first/last element of a list, + pushing it as the first/last element on the destination list. + Returns the element being popped and pushed. + """ + params = [first_list, second_list, src, dest] + return self.execute_command("LMOVE", *params) + + def blmove(self, first_list, second_list, timeout, + src="LEFT", dest="RIGHT"): + """ + Blocking version of lmove. 
+ """ + params = [first_list, second_list, src, dest, timeout] + return self.execute_command("BLMOVE", *params) + + def mget(self, keys, *args): + """ + Returns a list of values ordered identically to ``keys`` + """ + from redis.client import EMPTY_RESPONSE + args = list_or_args(keys, args) + options = {} + if not args: + options[EMPTY_RESPONSE] = [] + return self.execute_command('MGET', *args, **options) + + def mset(self, mapping): + """ + Sets key/values based on a mapping. Mapping is a dictionary of + key/value pairs. Both keys and values should be strings or types that + can be cast to a string via str(). + """ + items = [] + for pair in mapping.items(): + items.extend(pair) + return self.execute_command('MSET', *items) + + def msetnx(self, mapping): + """ + Sets key/values based on a mapping if none of the keys are already set. + Mapping is a dictionary of key/value pairs. Both keys and values + should be strings or types that can be cast to a string via str(). + Returns a boolean indicating if the operation was successful. + """ + items = [] + for pair in mapping.items(): + items.extend(pair) + return self.execute_command('MSETNX', *items) + + def move(self, name, db): + "Moves the key ``name`` to a different Redis database ``db``" + return self.execute_command('MOVE', name, db) + + def persist(self, name): + "Removes an expiration on ``name``" + return self.execute_command('PERSIST', name) + + def pexpire(self, name, time): + """ + Set an expire flag on key ``name`` for ``time`` milliseconds. + ``time`` can be represented by an integer or a Python timedelta + object. + """ + if isinstance(time, datetime.timedelta): + time = int(time.total_seconds() * 1000) + return self.execute_command('PEXPIRE', name, time) + + def pexpireat(self, name, when): + """ + Set an expire flag on key ``name``. ``when`` can be represented + as an integer representing unix time in milliseconds (unix time * 1000) + or a Python datetime object. + """ + if isinstance(when, datetime.datetime): + ms = int(when.microsecond / 1000) + when = int(time.mktime(when.timetuple())) * 1000 + ms + return self.execute_command('PEXPIREAT', name, when) + + def psetex(self, name, time_ms, value): + """ + Set the value of key ``name`` to ``value`` that expires in ``time_ms`` + milliseconds. ``time_ms`` can be represented by an integer or a Python + timedelta object + """ + if isinstance(time_ms, datetime.timedelta): + time_ms = int(time_ms.total_seconds() * 1000) + return self.execute_command('PSETEX', name, time_ms, value) + + def pttl(self, name): + "Returns the number of milliseconds until the key ``name`` will expire" + return self.execute_command('PTTL', name) + + def hrandfield(self, key, count=None, withvalues=False): + """ + Return a random field from the hash value stored at key. + + count: if the argument is positive, return an array of distinct fields. + If called with a negative count, the behavior changes and the command + is allowed to return the same field multiple times. In this case, + the number of returned fields is the absolute value of the + specified count. + withvalues: The optional WITHVALUES modifier changes the reply so it + includes the respective values of the randomly selected hash fields. 
+ """ + params = [] + if count is not None: + params.append(count) + if withvalues: + params.append("WITHVALUES") + + return self.execute_command("HRANDFIELD", key, *params) + + def randomkey(self): + "Returns the name of a random key" + return self.execute_command('RANDOMKEY') + + def rename(self, src, dst): + """ + Rename key ``src`` to ``dst`` + """ + return self.execute_command('RENAME', src, dst) + + def renamenx(self, src, dst): + "Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist" + return self.execute_command('RENAMENX', src, dst) + + def restore(self, name, ttl, value, replace=False, absttl=False): + """ + Create a key using the provided serialized value, previously obtained + using DUMP. + + ``replace`` allows an existing key on ``name`` to be overridden. If + it's not specified an error is raised on collision. + + ``absttl`` if True, specified ``ttl`` should represent an absolute Unix + timestamp in milliseconds in which the key will expire. (Redis 5.0 or + greater). + """ + params = [name, ttl, value] + if replace: + params.append('REPLACE') + if absttl: + params.append('ABSTTL') + return self.execute_command('RESTORE', *params) + + def set(self, name, value, + ex=None, px=None, nx=False, xx=False, keepttl=False, get=False): + """ + Set the value at key ``name`` to ``value`` + + ``ex`` sets an expire flag on key ``name`` for ``ex`` seconds. + + ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds. + + ``nx`` if set to True, set the value at key ``name`` to ``value`` only + if it does not exist. + + ``xx`` if set to True, set the value at key ``name`` to ``value`` only + if it already exists. + + ``keepttl`` if True, retain the time to live associated with the key. + (Available since Redis 6.0) + + ``get`` if True, set the value at key ``name`` to ``value`` and return + the old value stored at key, or None when key did not exist. + (Available since Redis 6.2) + """ + pieces = [name, value] + options = {} + if ex is not None: + pieces.append('EX') + if isinstance(ex, datetime.timedelta): + ex = int(ex.total_seconds()) + pieces.append(ex) + if px is not None: + pieces.append('PX') + if isinstance(px, datetime.timedelta): + px = int(px.total_seconds() * 1000) + pieces.append(px) + + if nx: + pieces.append('NX') + if xx: + pieces.append('XX') + + if keepttl: + pieces.append('KEEPTTL') + + if get: + pieces.append('GET') + options["get"] = True + + return self.execute_command('SET', *pieces, **options) + + def __setitem__(self, name, value): + self.set(name, value) + + def setbit(self, name, offset, value): + """ + Flag the ``offset`` in ``name`` as ``value``. Returns a boolean + indicating the previous value of ``offset``. + """ + value = value and 1 or 0 + return self.execute_command('SETBIT', name, offset, value) + + def setex(self, name, time, value): + """ + Set the value of key ``name`` to ``value`` that expires in ``time`` + seconds. ``time`` can be represented by an integer or a Python + timedelta object. + """ + if isinstance(time, datetime.timedelta): + time = int(time.total_seconds()) + return self.execute_command('SETEX', name, time, value) + + def setnx(self, name, value): + "Set the value of key ``name`` to ``value`` if key doesn't exist" + return self.execute_command('SETNX', name, value) + + def setrange(self, name, offset, value): + """ + Overwrite bytes in the value of ``name`` starting at ``offset`` with + ``value``. 
If ``offset`` plus the length of ``value`` exceeds the
+        length of the original value, the new value will be larger than
+        before. If ``offset`` exceeds the length of the original value, null
+        bytes will be used to pad between the end of the previous value and
+        the start of what's being injected.
+
+        Returns the length of the new string.
+        """
+        return self.execute_command('SETRANGE', name, offset, value)
+
+    def strlen(self, name):
+        "Return the number of bytes stored in the value of ``name``"
+        return self.execute_command('STRLEN', name)
+
+    def substr(self, name, start, end=-1):
+        """
+        Return a substring of the string at key ``name``. ``start`` and ``end``
+        are 0-based integers specifying the portion of the string to return.
+        """
+        return self.execute_command('SUBSTR', name, start, end)
+
+    def touch(self, *args):
+        """
+        Alters the last access time of the given key(s) ``*args``. A key is
+        ignored if it does not exist.
+        """
+        return self.execute_command('TOUCH', *args)
+
+    def ttl(self, name):
+        "Returns the number of seconds until the key ``name`` will expire"
+        return self.execute_command('TTL', name)
+
+    def type(self, name):
+        "Returns the type of key ``name``"
+        return self.execute_command('TYPE', name)
+
+    def watch(self, *names):
+        """
+        Watches the values at keys ``names``. WATCH is only available from
+        within a Pipeline object.
+        """
+        warnings.warn(DeprecationWarning('Call WATCH from a Pipeline object'))
+
+    def unwatch(self):
+        """
+        Unwatches all previously watched keys. UNWATCH is only available from
+        within a Pipeline object.
+        """
+        warnings.warn(
+            DeprecationWarning('Call UNWATCH from a Pipeline object'))
+
+    def unlink(self, *names):
+        "Unlink one or more keys specified by ``names``"
+        return self.execute_command('UNLINK', *names)
+
+
+class ListCommands:
+    # LIST COMMANDS
+    def blpop(self, keys, timeout=0):
+        """
+        LPOP a value off of the first non-empty list
+        named in the ``keys`` list.
+
+        If none of the lists in ``keys`` has a value to LPOP, then block
+        for ``timeout`` seconds, or until a value gets pushed on to one
+        of the lists.
+
+        If timeout is 0, then block indefinitely.
+        """
+        if timeout is None:
+            timeout = 0
+        keys = list_or_args(keys, None)
+        keys.append(timeout)
+        return self.execute_command('BLPOP', *keys)
+
+    def brpop(self, keys, timeout=0):
+        """
+        RPOP a value off of the first non-empty list
+        named in the ``keys`` list.
+
+        If none of the lists in ``keys`` has a value to RPOP, then block
+        for ``timeout`` seconds, or until a value gets pushed on to one
+        of the lists.
+
+        If timeout is 0, then block indefinitely.
+        """
+        if timeout is None:
+            timeout = 0
+        keys = list_or_args(keys, None)
+        keys.append(timeout)
+        return self.execute_command('BRPOP', *keys)
+
+    def brpoplpush(self, src, dst, timeout=0):
+        """
+        Pop a value off the tail of ``src``, push it on the head of ``dst``
+        and then return it.
+
+        This command blocks until a value is in ``src`` or until ``timeout``
+        seconds elapse, whichever is first. A ``timeout`` value of 0 blocks
+        forever.
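+
+        Illustrative example (hypothetical keys):
+
+        >>> r.rpush('src', 'a')
+        >>> r.brpoplpush('src', 'dst', timeout=1)  # b'a', or None on timeout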
+ """ + if timeout is None: + timeout = 0 + return self.execute_command('BRPOPLPUSH', src, dst, timeout) + + def lindex(self, name, index): + """ + Return the item from list ``name`` at position ``index`` + + Negative indexes are supported and will return an item at the + end of the list + """ + return self.execute_command('LINDEX', name, index) + + def linsert(self, name, where, refvalue, value): + """ + Insert ``value`` in list ``name`` either immediately before or after + [``where``] ``refvalue`` + + Returns the new length of the list on success or -1 if ``refvalue`` + is not in the list. + """ + return self.execute_command('LINSERT', name, where, refvalue, value) + + def llen(self, name): + "Return the length of the list ``name``" + return self.execute_command('LLEN', name) + + def lpop(self, name, count=None): + """ + Removes and returns the first elements of the list ``name``. + + By default, the command pops a single element from the beginning of + the list. When provided with the optional ``count`` argument, the reply + will consist of up to count elements, depending on the list's length. + """ + if count is not None: + return self.execute_command('LPOP', name, count) + else: + return self.execute_command('LPOP', name) + + def lpush(self, name, *values): + "Push ``values`` onto the head of the list ``name``" + return self.execute_command('LPUSH', name, *values) + + def lpushx(self, name, value): + "Push ``value`` onto the head of the list ``name`` if ``name`` exists" + return self.execute_command('LPUSHX', name, value) + + def lrange(self, name, start, end): + """ + Return a slice of the list ``name`` between + position ``start`` and ``end`` + + ``start`` and ``end`` can be negative numbers just like + Python slicing notation + """ + return self.execute_command('LRANGE', name, start, end) + + def lrem(self, name, count, value): + """ + Remove the first ``count`` occurrences of elements equal to ``value`` + from the list stored at ``name``. + + The count argument influences the operation in the following ways: + count > 0: Remove elements equal to value moving from head to tail. + count < 0: Remove elements equal to value moving from tail to head. + count = 0: Remove all elements equal to value. + """ + return self.execute_command('LREM', name, count, value) + + def lset(self, name, index, value): + "Set ``position`` of list ``name`` to ``value``" + return self.execute_command('LSET', name, index, value) + + def ltrim(self, name, start, end): + """ + Trim the list ``name``, removing all values not within the slice + between ``start`` and ``end`` + + ``start`` and ``end`` can be negative numbers just like + Python slicing notation + """ + return self.execute_command('LTRIM', name, start, end) + + def rpop(self, name, count=None): + """ + Removes and returns the last elements of the list ``name``. + + By default, the command pops a single element from the end of the list. + When provided with the optional ``count`` argument, the reply will + consist of up to count elements, depending on the list's length. + """ + if count is not None: + return self.execute_command('RPOP', name, count) + else: + return self.execute_command('RPOP', name) + + def rpoplpush(self, src, dst): + """ + RPOP a value off of the ``src`` list and atomically LPUSH it + on to the ``dst`` list. Returns the value. 
+ """ + return self.execute_command('RPOPLPUSH', src, dst) + + def rpush(self, name, *values): + "Push ``values`` onto the tail of the list ``name``" + return self.execute_command('RPUSH', name, *values) + + def rpushx(self, name, value): + "Push ``value`` onto the tail of the list ``name`` if ``name`` exists" + return self.execute_command('RPUSHX', name, value) + + def lpos(self, name, value, rank=None, count=None, maxlen=None): + """ + Get position of ``value`` within the list ``name`` + + If specified, ``rank`` indicates the "rank" of the first element to + return in case there are multiple copies of ``value`` in the list. + By default, LPOS returns the position of the first occurrence of + ``value`` in the list. When ``rank`` 2, LPOS returns the position of + the second ``value`` in the list. If ``rank`` is negative, LPOS + searches the list in reverse. For example, -1 would return the + position of the last occurrence of ``value`` and -2 would return the + position of the next to last occurrence of ``value``. + + If specified, ``count`` indicates that LPOS should return a list of + up to ``count`` positions. A ``count`` of 2 would return a list of + up to 2 positions. A ``count`` of 0 returns a list of all positions + matching ``value``. When ``count`` is specified and but ``value`` + does not exist in the list, an empty list is returned. + + If specified, ``maxlen`` indicates the maximum number of list + elements to scan. A ``maxlen`` of 1000 will only return the + position(s) of items within the first 1000 entries in the list. + A ``maxlen`` of 0 (the default) will scan the entire list. + """ + pieces = [name, value] + if rank is not None: + pieces.extend(['RANK', rank]) + + if count is not None: + pieces.extend(['COUNT', count]) + + if maxlen is not None: + pieces.extend(['MAXLEN', maxlen]) + + return self.execute_command('LPOS', *pieces) + + def sort(self, name, start=None, num=None, by=None, get=None, + desc=False, alpha=False, store=None, groups=False): + """ + Sort and return the list, set or sorted set at ``name``. + + ``start`` and ``num`` allow for paging through the sorted data + + ``by`` allows using an external key to weight and sort the items. + Use an "*" to indicate where in the key the item value is located + + ``get`` allows for returning items from external keys rather than the + sorted data itself. Use an "*" to indicate where in the key + the item value is located + + ``desc`` allows for reversing the sort + + ``alpha`` allows for sorting lexicographically rather than numerically + + ``store`` allows for storing the result of the sort into + the key ``store`` + + ``groups`` if set to True and if ``get`` contains at least two + elements, sort will return a list of tuples, each containing the + values fetched from the arguments to ``get``. + + """ + if (start is not None and num is None) or \ + (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + + pieces = [name] + if by is not None: + pieces.append(b'BY') + pieces.append(by) + if start is not None and num is not None: + pieces.append(b'LIMIT') + pieces.append(start) + pieces.append(num) + if get is not None: + # If get is a string assume we want to get a single value. + # Otherwise assume it's an interable and we want to get multiple + # values. We can't just iterate blindly because strings are + # iterable. 
+            if isinstance(get, (bytes, str)):
+                pieces.append(b'GET')
+                pieces.append(get)
+            else:
+                for g in get:
+                    pieces.append(b'GET')
+                    pieces.append(g)
+        if desc:
+            pieces.append(b'DESC')
+        if alpha:
+            pieces.append(b'ALPHA')
+        if store is not None:
+            pieces.append(b'STORE')
+            pieces.append(store)
+
+        if groups:
+            if not get or isinstance(get, (bytes, str)) or len(get) < 2:
+                raise DataError('when using "groups" the "get" argument '
+                                'must be specified and contain at least '
+                                'two keys')
+
+        options = {'groups': len(get) if groups else None}
+        return self.execute_command('SORT', *pieces, **options)
+
+
+class ScanCommands:
+    # SCAN COMMANDS
+    def scan(self, cursor=0, match=None, count=None, _type=None):
+        """
+        Incrementally return lists of key names. Also return a cursor
+        indicating the scan position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of keys to
+        return per batch.
+
+        ``_type`` filters the returned values by a particular Redis type.
+        Stock Redis instances allow for the following types:
+        HASH, LIST, SET, STREAM, STRING, ZSET
+        Additionally, Redis modules can expose other types as well.
+        """
+        pieces = [cursor]
+        if match is not None:
+            pieces.extend([b'MATCH', match])
+        if count is not None:
+            pieces.extend([b'COUNT', count])
+        if _type is not None:
+            pieces.extend([b'TYPE', _type])
+        return self.execute_command('SCAN', *pieces)
+
+    def scan_iter(self, match=None, count=None, _type=None):
+        """
+        Make an iterator using the SCAN command so that the client doesn't
+        need to remember the cursor position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of keys to
+        return per batch.
+
+        ``_type`` filters the returned values by a particular Redis type.
+        Stock Redis instances allow for the following types:
+        HASH, LIST, SET, STREAM, STRING, ZSET
+        Additionally, Redis modules can expose other types as well.
+        """
+        cursor = '0'
+        while cursor != 0:
+            cursor, data = self.scan(cursor=cursor, match=match,
+                                     count=count, _type=_type)
+            yield from data
+
+    def sscan(self, name, cursor=0, match=None, count=None):
+        """
+        Incrementally return lists of elements in a set. Also return a cursor
+        indicating the scan position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of elements to
+        return per batch.
+        """
+        pieces = [name, cursor]
+        if match is not None:
+            pieces.extend([b'MATCH', match])
+        if count is not None:
+            pieces.extend([b'COUNT', count])
+        return self.execute_command('SSCAN', *pieces)
+
+    def sscan_iter(self, name, match=None, count=None):
+        """
+        Make an iterator using the SSCAN command so that the client doesn't
+        need to remember the cursor position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of elements to
+        return per batch.
+        """
+        cursor = '0'
+        while cursor != 0:
+            cursor, data = self.sscan(name, cursor=cursor,
+                                      match=match, count=count)
+            yield from data
+
+    def hscan(self, name, cursor=0, match=None, count=None):
+        """
+        Incrementally return key/value slices in a hash. Also return a cursor
+        indicating the scan position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of fields to
+        return per batch.
+        """
+        pieces = [name, cursor]
+        if match is not None:
+            pieces.extend([b'MATCH', match])
+        if count is not None:
+            pieces.extend([b'COUNT', count])
+        return self.execute_command('HSCAN', *pieces)
+
+    def hscan_iter(self, name, match=None, count=None):
+        """
+        Make an iterator using the HSCAN command so that the client doesn't
+        need to remember the cursor position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of fields to
+        return per batch.
+        """
+        cursor = '0'
+        while cursor != 0:
+            cursor, data = self.hscan(name, cursor=cursor,
+                                      match=match, count=count)
+            yield from data.items()
+
+    def zscan(self, name, cursor=0, match=None, count=None,
+              score_cast_func=float):
+        """
+        Incrementally return lists of elements in a sorted set. Also return a
+        cursor indicating the scan position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of members to
+        return per batch.
+
+        ``score_cast_func`` a callable used to cast the score return value
+        """
+        pieces = [name, cursor]
+        if match is not None:
+            pieces.extend([b'MATCH', match])
+        if count is not None:
+            pieces.extend([b'COUNT', count])
+        options = {'score_cast_func': score_cast_func}
+        return self.execute_command('ZSCAN', *pieces, **options)
+
+    def zscan_iter(self, name, match=None, count=None,
+                   score_cast_func=float):
+        """
+        Make an iterator using the ZSCAN command so that the client doesn't
+        need to remember the cursor position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of members to
+        return per batch.
+
+        ``score_cast_func`` a callable used to cast the score return value
+        """
+        cursor = '0'
+        while cursor != 0:
+            cursor, data = self.zscan(name, cursor=cursor, match=match,
+                                      count=count,
+                                      score_cast_func=score_cast_func)
+            yield from data
+
+
+class SetCommands:
+    # SET COMMANDS
+    def sadd(self, name, *values):
+        "Add ``value(s)`` to set ``name``"
+        return self.execute_command('SADD', name, *values)
+
+    def scard(self, name):
+        "Return the number of elements in set ``name``"
+        return self.execute_command('SCARD', name)
+
+    def sdiff(self, keys, *args):
+        "Return the difference of sets specified by ``keys``"
+        args = list_or_args(keys, args)
+        return self.execute_command('SDIFF', *args)
+
+    def sdiffstore(self, dest, keys, *args):
+        """
+        Store the difference of sets specified by ``keys`` into a new
+        set named ``dest``. Returns the number of keys in the new set.
+        """
+        args = list_or_args(keys, args)
+        return self.execute_command('SDIFFSTORE', dest, *args)
+
+    def sinter(self, keys, *args):
+        "Return the intersection of sets specified by ``keys``"
+        args = list_or_args(keys, args)
+        return self.execute_command('SINTER', *args)
+
+    def sinterstore(self, dest, keys, *args):
+        """
+        Store the intersection of sets specified by ``keys`` into a new
+        set named ``dest``. Returns the number of keys in the new set.
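+
+        Example (sketch; assumes a client ``r``; the key names are
+        illustrative):
+
+            >>> r.sadd('s1', 'a', 'b')
+            2
+            >>> r.sadd('s2', 'b', 'c')
+            2
+            >>> r.sinterstore('dest', ['s1', 's2'])
+            1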
+ """ + args = list_or_args(keys, args) + return self.execute_command('SINTERSTORE', dest, *args) + + def sismember(self, name, value): + "Return a boolean indicating if ``value`` is a member of set ``name``" + return self.execute_command('SISMEMBER', name, value) + + def smembers(self, name): + "Return all members of the set ``name``" + return self.execute_command('SMEMBERS', name) + + def smove(self, src, dst, value): + "Move ``value`` from set ``src`` to set ``dst`` atomically" + return self.execute_command('SMOVE', src, dst, value) + + def spop(self, name, count=None): + "Remove and return a random member of set ``name``" + args = (count is not None) and [count] or [] + return self.execute_command('SPOP', name, *args) + + def srandmember(self, name, number=None): + """ + If ``number`` is None, returns a random member of set ``name``. + + If ``number`` is supplied, returns a list of ``number`` random + members of set ``name``. Note this is only available when running + Redis 2.6+. + """ + args = (number is not None) and [number] or [] + return self.execute_command('SRANDMEMBER', name, *args) + + def srem(self, name, *values): + "Remove ``values`` from set ``name``" + return self.execute_command('SREM', name, *values) + + def sunion(self, keys, *args): + "Return the union of sets specified by ``keys``" + args = list_or_args(keys, args) + return self.execute_command('SUNION', *args) + + def sunionstore(self, dest, keys, *args): + """ + Store the union of sets specified by ``keys`` into a new + set named ``dest``. Returns the number of keys in the new set. + """ + args = list_or_args(keys, args) + return self.execute_command('SUNIONSTORE', dest, *args) + + +class StreamsCommands: + # STREAMS COMMANDS + def xack(self, name, groupname, *ids): + """ + Acknowledges the successful processing of one or more messages. + name: name of the stream. + groupname: name of the consumer group. + *ids: message ids to acknowledge. + """ + return self.execute_command('XACK', name, groupname, *ids) + + def xadd(self, name, fields, id='*', maxlen=None, approximate=True, + nomkstream=False): + """ + Add to a stream. + name: name of the stream + fields: dict of field/value pairs to insert into the stream + id: Location to insert this record. By default it is appended. + maxlen: truncate old stream members beyond this size + approximate: actual stream length may be slightly more than maxlen + nomkstream: When set to true, do not make a stream + """ + pieces = [] + if maxlen is not None: + if not isinstance(maxlen, int) or maxlen < 1: + raise DataError('XADD maxlen must be a positive integer') + pieces.append(b'MAXLEN') + if approximate: + pieces.append(b'~') + pieces.append(str(maxlen)) + if nomkstream: + pieces.append(b'NOMKSTREAM') + pieces.append(id) + if not isinstance(fields, dict) or len(fields) == 0: + raise DataError('XADD fields must be a non-empty dict') + for pair in fields.items(): + pieces.extend(pair) + return self.execute_command('XADD', name, *pieces) + + def xautoclaim(self, name, groupname, consumername, min_idle_time, + start_id=0, count=None, justid=False): + """ + Transfers ownership of pending stream entries that match the specified + criteria. Conceptually, equivalent to calling XPENDING and then XCLAIM, + but provides a more straightforward way to deal with message delivery + failures via SCAN-like semantics. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of a consumer that claims the message. 
+        min_idle_time: only claim messages that have been idle for at least
+        this many milliseconds.
+        start_id: filter messages with equal or greater ID.
+        count: optional integer, upper limit of the number of entries that the
+        command attempts to claim. Set to 100 by default.
+        justid: optional boolean, false by default. Return just an array of
+        IDs of messages successfully claimed, without returning the actual
+        message.
+        """
+        try:
+            if int(min_idle_time) < 0:
+                raise DataError("XAUTOCLAIM min_idle_time must be a "
+                                "non-negative integer")
+        except TypeError:
+            pass
+
+        kwargs = {}
+        pieces = [name, groupname, consumername, min_idle_time, start_id]
+
+        try:
+            if int(count) < 0:
+                raise DataError("XAUTOCLAIM count must be an integer >= 0")
+            pieces.extend([b'COUNT', count])
+        except TypeError:
+            pass
+        if justid:
+            pieces.append(b'JUSTID')
+            kwargs['parse_justid'] = True
+
+        return self.execute_command('XAUTOCLAIM', *pieces, **kwargs)
+
+    def xclaim(self, name, groupname, consumername, min_idle_time, message_ids,
+               idle=None, time=None, retrycount=None, force=False,
+               justid=False):
+        """
+        Changes the ownership of a pending message.
+        name: name of the stream.
+        groupname: name of the consumer group.
+        consumername: name of a consumer that claims the message.
+        min_idle_time: only claim messages that have been idle for at least
+        this many milliseconds.
+        message_ids: non-empty list or tuple of message IDs to claim
+        idle: optional. Set the idle time (last time it was delivered) of the
+        message in ms
+        time: optional integer. This is the same as idle but instead of a
+        relative amount of milliseconds, it sets the idle time to a specific
+        Unix time (in milliseconds).
+        retrycount: optional integer. set the retry counter to the specified
+        value. This counter is incremented every time a message is delivered
+        again.
+        force: optional boolean, false by default. Creates the pending message
+        entry in the PEL even if certain specified IDs are not already in the
+        PEL assigned to a different client.
+        justid: optional boolean, false by default. Return just an array of
+        IDs of messages successfully claimed, without returning the actual
+        message.
+        """
+        if not isinstance(min_idle_time, int) or min_idle_time < 0:
+            raise DataError("XCLAIM min_idle_time must be a non-negative "
+                            "integer")
+        if not isinstance(message_ids, (list, tuple)) or not message_ids:
+            raise DataError("XCLAIM message_ids must be a non empty list or "
+                            "tuple of message IDs to claim")
+
+        kwargs = {}
+        pieces = [name, groupname, consumername, str(min_idle_time)]
+        pieces.extend(list(message_ids))
+
+        if idle is not None:
+            if not isinstance(idle, int):
+                raise DataError("XCLAIM idle must be an integer")
+            pieces.extend((b'IDLE', str(idle)))
+        if time is not None:
+            if not isinstance(time, int):
+                raise DataError("XCLAIM time must be an integer")
+            pieces.extend((b'TIME', str(time)))
+        if retrycount is not None:
+            if not isinstance(retrycount, int):
+                raise DataError("XCLAIM retrycount must be an integer")
+            pieces.extend((b'RETRYCOUNT', str(retrycount)))
+
+        if force:
+            if not isinstance(force, bool):
+                raise DataError("XCLAIM force must be a boolean")
+            pieces.append(b'FORCE')
+        if justid:
+            if not isinstance(justid, bool):
+                raise DataError("XCLAIM justid must be a boolean")
+            pieces.append(b'JUSTID')
+            kwargs['parse_justid'] = True
+        return self.execute_command('XCLAIM', *pieces, **kwargs)
+
+    def xdel(self, name, *ids):
+        """
+        Deletes one or more messages from a stream.
+        name: name of the stream.
+        *ids: message ids to delete.
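+
+        Example (sketch; assumes a client ``r``; the stream name is
+        illustrative):
+
+            >>> msg_id = r.xadd('mystream', {'field': 'value'})
+            >>> r.xdel('mystream', msg_id)
+            1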
+ """ + return self.execute_command('XDEL', name, *ids) + + def xgroup_create(self, name, groupname, id='$', mkstream=False): + """ + Create a new consumer group associated with a stream. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + """ + pieces = ['XGROUP CREATE', name, groupname, id] + if mkstream: + pieces.append(b'MKSTREAM') + return self.execute_command(*pieces) + + def xgroup_delconsumer(self, name, groupname, consumername): + """ + Remove a specific consumer from a consumer group. + Returns the number of pending messages that the consumer had before it + was deleted. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of consumer to delete + """ + return self.execute_command('XGROUP DELCONSUMER', name, groupname, + consumername) + + def xgroup_destroy(self, name, groupname): + """ + Destroy a consumer group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XGROUP DESTROY', name, groupname) + + def xgroup_setid(self, name, groupname, id): + """ + Set the consumer group last delivered ID to something else. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + """ + return self.execute_command('XGROUP SETID', name, groupname, id) + + def xinfo_consumers(self, name, groupname): + """ + Returns general information about the consumers in the group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XINFO CONSUMERS', name, groupname) + + def xinfo_groups(self, name): + """ + Returns general information about the consumer groups of the stream. + name: name of the stream. + """ + return self.execute_command('XINFO GROUPS', name) + + def xinfo_stream(self, name): + """ + Returns general information about the stream. + name: name of the stream. + """ + return self.execute_command('XINFO STREAM', name) + + def xlen(self, name): + """ + Returns the number of elements in a given stream. + """ + return self.execute_command('XLEN', name) + + def xpending(self, name, groupname): + """ + Returns information about pending messages of a group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command('XPENDING', name, groupname) + + def xpending_range(self, name, groupname, min, max, count, + consumername=None, idle=None): + """ + Returns information about pending messages, in a range. + name: name of the stream. + groupname: name of the consumer group. + min: minimum stream ID. + max: maximum stream ID. + count: number of messages to return + consumername: name of a consumer to filter by (optional). + idle: available from version 6.2. filter entries by their + idle-time, given in milliseconds (optional). 
+ """ + if {min, max, count} == {None}: + if idle is not None or consumername is not None: + raise DataError("if XPENDING is provided with idle time" + " or consumername, it must be provided" + " with min, max and count parameters") + return self.xpending(name, groupname) + + pieces = [name, groupname] + if min is None or max is None or count is None: + raise DataError("XPENDING must be provided with min, max " + "and count parameters, or none of them.") + # idle + try: + if int(idle) < 0: + raise DataError("XPENDING idle must be a integer >= 0") + pieces.extend(['IDLE', idle]) + except TypeError: + pass + # count + try: + if int(count) < 0: + raise DataError("XPENDING count must be a integer >= 0") + pieces.extend([min, max, count]) + except TypeError: + pass + + return self.execute_command('XPENDING', *pieces, parse_detail=True) + + def xrange(self, name, min='-', max='+', count=None): + """ + Read stream values within an interval. + name: name of the stream. + start: first stream ID. defaults to '-', + meaning the earliest available. + finish: last stream ID. defaults to '+', + meaning the latest available. + count: if set, only return this many items, beginning with the + earliest available. + """ + pieces = [min, max] + if count is not None: + if not isinstance(count, int) or count < 1: + raise DataError('XRANGE count must be a positive integer') + pieces.append(b'COUNT') + pieces.append(str(count)) + + return self.execute_command('XRANGE', name, *pieces) + + def xread(self, streams, count=None, block=None): + """ + Block and monitor multiple streams for new data. + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + count: if set, only return this many items, beginning with the + earliest available. + block: number of milliseconds to wait, if nothing already present. + """ + pieces = [] + if block is not None: + if not isinstance(block, int) or block < 0: + raise DataError('XREAD block must be a non-negative integer') + pieces.append(b'BLOCK') + pieces.append(str(block)) + if count is not None: + if not isinstance(count, int) or count < 1: + raise DataError('XREAD count must be a positive integer') + pieces.append(b'COUNT') + pieces.append(str(count)) + if not isinstance(streams, dict) or len(streams) == 0: + raise DataError('XREAD streams must be a non empty dict') + pieces.append(b'STREAMS') + keys, values = zip(*streams.items()) + pieces.extend(keys) + pieces.extend(values) + return self.execute_command('XREAD', *pieces) + + def xreadgroup(self, groupname, consumername, streams, count=None, + block=None, noack=False): + """ + Read from a stream via a consumer group. + groupname: name of the consumer group. + consumername: name of the requesting consumer. + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + count: if set, only return this many items, beginning with the + earliest available. + block: number of milliseconds to wait, if nothing already present. 
+        noack: do not add messages to the PEL
+        """
+        pieces = [b'GROUP', groupname, consumername]
+        if count is not None:
+            if not isinstance(count, int) or count < 1:
+                raise DataError("XREADGROUP count must be a positive integer")
+            pieces.append(b'COUNT')
+            pieces.append(str(count))
+        if block is not None:
+            if not isinstance(block, int) or block < 0:
+                raise DataError("XREADGROUP block must be a non-negative "
+                                "integer")
+            pieces.append(b'BLOCK')
+            pieces.append(str(block))
+        if noack:
+            pieces.append(b'NOACK')
+        if not isinstance(streams, dict) or len(streams) == 0:
+            raise DataError('XREADGROUP streams must be a non empty dict')
+        pieces.append(b'STREAMS')
+        pieces.extend(streams.keys())
+        pieces.extend(streams.values())
+        return self.execute_command('XREADGROUP', *pieces)
+
+    def xrevrange(self, name, max='+', min='-', count=None):
+        """
+        Read stream values within an interval, in reverse order.
+        name: name of the stream
+        max: first stream ID. defaults to '+',
+             meaning the latest available.
+        min: last stream ID. defaults to '-',
+             meaning the earliest available.
+        count: if set, only return this many items, beginning with the
+               latest available.
+        """
+        pieces = [max, min]
+        if count is not None:
+            if not isinstance(count, int) or count < 1:
+                raise DataError('XREVRANGE count must be a positive integer')
+            pieces.append(b'COUNT')
+            pieces.append(str(count))
+
+        return self.execute_command('XREVRANGE', name, *pieces)
+
+    def xtrim(self, name, maxlen=None, approximate=True, minid=None,
+              limit=None):
+        """
+        Trims old messages from a stream.
+        name: name of the stream.
+        maxlen: truncate old stream messages beyond this size
+        approximate: actual stream length may be slightly more than maxlen
+        minid: remove stream entries with IDs lower than the given ID
+        limit: optional upper bound on the number of entries to trim; can
+        only be used together with ``approximate``
+        """
+        pieces = []
+        if maxlen is not None and minid is not None:
+            raise DataError("Only one of ``maxlen`` or ``minid`` "
+                            "may be specified")
+
+        if maxlen is not None:
+            pieces.append(b'MAXLEN')
+        if minid is not None:
+            pieces.append(b'MINID')
+        if approximate:
+            pieces.append(b'~')
+        if maxlen is not None:
+            pieces.append(maxlen)
+        if minid is not None:
+            pieces.append(minid)
+        if limit is not None:
+            pieces.append(b"LIMIT")
+            pieces.append(limit)
+
+        return self.execute_command('XTRIM', name, *pieces)
+
+
+class SortedSetCommands:
+    # SORTED SET COMMANDS
+    def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False,
+             gt=None, lt=None):
+        """
+        Set any number of element-name, score pairs to the key ``name``. Pairs
+        are specified as a dict of element-name keys to score values.
+
+        ``nx`` forces ZADD to only create new elements and not to update
+        scores for elements that already exist.
+
+        ``xx`` forces ZADD to only update scores of elements that already
+        exist. New elements will not be added.
+
+        ``ch`` modifies the return value to be the number of elements changed.
+        Changed elements include new elements that were added and elements
+        whose scores changed.
+
+        ``incr`` modifies ZADD to behave like ZINCRBY. In this mode only a
+        single element/score pair can be specified and the score is the amount
+        the existing score will be incremented by. When using this mode the
+        return value of ZADD will be the new score of the element.
+
+        ``LT`` Only update existing elements if the new score is less than
+        the current score. This flag doesn't prevent adding new elements.
+
+        ``GT`` Only update existing elements if the new score is greater than
+        the current score. This flag doesn't prevent adding new elements.
+
+        The return value of ZADD varies based on the mode specified. With no
+        options, ZADD returns the number of new elements added to the sorted
+        set.
+
+        ``NX``, ``LT``, and ``GT`` are mutually exclusive options.
+        See: https://redis.io/commands/ZADD
+        """
+        if not mapping:
+            raise DataError("ZADD requires at least one element/score pair")
+        if nx and xx:
+            raise DataError("ZADD allows either 'nx' or 'xx', not both")
+        if incr and len(mapping) != 1:
+            raise DataError("ZADD option 'incr' only works when passing a "
+                            "single element/score pair")
+        if nx is True and (gt is not None or lt is not None):
+            raise DataError("Only one of 'nx', 'lt', or 'gt' may be defined.")
+
+        pieces = []
+        options = {}
+        if nx:
+            pieces.append(b'NX')
+        if xx:
+            pieces.append(b'XX')
+        if ch:
+            pieces.append(b'CH')
+        if incr:
+            pieces.append(b'INCR')
+            options['as_score'] = True
+        if gt:
+            pieces.append(b'GT')
+        if lt:
+            pieces.append(b'LT')
+        for pair in mapping.items():
+            pieces.append(pair[1])
+            pieces.append(pair[0])
+        return self.execute_command('ZADD', name, *pieces, **options)
+
+    def zcard(self, name):
+        "Return the number of elements in the sorted set ``name``"
+        return self.execute_command('ZCARD', name)
+
+    def zcount(self, name, min, max):
+        """
+        Returns the number of elements in the sorted set at key ``name`` with
+        a score between ``min`` and ``max``.
+        """
+        return self.execute_command('ZCOUNT', name, min, max)
+
+    def zdiff(self, keys, withscores=False):
+        """
+        Returns the difference between the first and all successive input
+        sorted sets provided in ``keys``.
+        """
+        pieces = [len(keys), *keys]
+        if withscores:
+            pieces.append("WITHSCORES")
+        return self.execute_command("ZDIFF", *pieces)
+
+    def zdiffstore(self, dest, keys):
+        """
+        Computes the difference between the first and all successive input
+        sorted sets provided in ``keys`` and stores the result in ``dest``.
+        """
+        pieces = [len(keys), *keys]
+        return self.execute_command("ZDIFFSTORE", dest, *pieces)
+
+    def zincrby(self, name, amount, value):
+        "Increment the score of ``value`` in sorted set ``name`` by ``amount``"
+        return self.execute_command('ZINCRBY', name, amount, value)
+
+    def zinter(self, keys, aggregate=None, withscores=False):
+        """
+        Return the intersect of multiple sorted sets specified by ``keys``.
+        With the ``aggregate`` option, it is possible to specify how the
+        results of the intersection are aggregated. This option defaults to
+        SUM, where the score of an element is summed across the inputs where
+        it exists. When this option is set to either MIN or MAX, the resulting
+        set will contain the minimum or maximum score of an element across
+        the inputs where it exists.
+        """
+        return self._zaggregate('ZINTER', None, keys, aggregate,
+                                withscores=withscores)
+
+    def zinterstore(self, dest, keys, aggregate=None):
+        """
+        Intersect multiple sorted sets specified by ``keys`` into a new
+        sorted set, ``dest``. Scores in the destination will be aggregated
+        based on the ``aggregate``. This option defaults to SUM, where the
+        score of an element is summed across the inputs where it exists.
+        When this option is set to either MIN or MAX, the resulting set will
+        contain the minimum or maximum score of an element across the inputs
+        where it exists.
+        """
+        return self._zaggregate('ZINTERSTORE', dest, keys, aggregate)
+
+    def zlexcount(self, name, min, max):
+        """
+        Return the number of items in the sorted set ``name`` between the
+        lexicographical range ``min`` and ``max``.
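+
+        Example (sketch; assumes a client ``r``; the key name is
+        illustrative):
+
+            >>> r.zadd('myzset', {'a': 0, 'b': 0, 'c': 0})
+            3
+            >>> r.zlexcount('myzset', '[b', '+')
+            2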
+ """ + return self.execute_command('ZLEXCOUNT', name, min, max) + + def zpopmax(self, name, count=None): + """ + Remove and return up to ``count`` members with the highest scores + from the sorted set ``name``. + """ + args = (count is not None) and [count] or [] + options = { + 'withscores': True + } + return self.execute_command('ZPOPMAX', name, *args, **options) + + def zpopmin(self, name, count=None): + """ + Remove and return up to ``count`` members with the lowest scores + from the sorted set ``name``. + """ + args = (count is not None) and [count] or [] + options = { + 'withscores': True + } + return self.execute_command('ZPOPMIN', name, *args, **options) + + def zrandmember(self, key, count=None, withscores=False): + """ + Return a random element from the sorted set value stored at key. + + ``count`` if the argument is positive, return an array of distinct + fields. If called with a negative count, the behavior changes and + the command is allowed to return the same field multiple times. + In this case, the number of returned fields is the absolute value + of the specified count. + + ``withscores`` The optional WITHSCORES modifier changes the reply so it + includes the respective scores of the randomly selected elements from + the sorted set. + """ + params = [] + if count is not None: + params.append(count) + if withscores: + params.append("WITHSCORES") + + return self.execute_command("ZRANDMEMBER", key, *params) + + def bzpopmax(self, keys, timeout=0): + """ + ZPOPMAX a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMAX, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command('BZPOPMAX', *keys) + + def bzpopmin(self, keys, timeout=0): + """ + ZPOPMIN a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMIN, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command('BZPOPMIN', *keys) + + def zrange(self, name, start, end, desc=False, withscores=False, + score_cast_func=float): + """ + Return a range of values from sorted set ``name`` between + ``start`` and ``end`` sorted in ascending order. + + ``start`` and ``end`` can be negative, indicating the end of the range. + + ``desc`` a boolean indicating whether to sort the results descendingly + + ``withscores`` indicates to return the scores along with the values. + The return type is a list of (value, score) pairs + + ``score_cast_func`` a callable used to cast the score return value + """ + if desc: + return self.zrevrange(name, start, end, withscores, + score_cast_func) + pieces = ['ZRANGE', name, start, end] + if withscores: + pieces.append(b'WITHSCORES') + options = { + 'withscores': withscores, + 'score_cast_func': score_cast_func + } + return self.execute_command(*pieces, **options) + + def zrangestore(self, dest, name, start, end): + """ + Stores in ``dest`` the result of a range of values from sorted set + ``name`` between ``start`` and ``end`` sorted in ascending order. 
+ + ``start`` and ``end`` can be negative, indicating the end of the range. + """ + return self.execute_command('ZRANGESTORE', dest, name, start, end) + + def zrangebylex(self, name, min, max, start=None, num=None): + """ + Return the lexicographical range of values from sorted set ``name`` + between ``min`` and ``max``. + + If ``start`` and ``num`` are specified, then return a slice of the + range. + """ + if (start is not None and num is None) or \ + (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces = ['ZRANGEBYLEX', name, min, max] + if start is not None and num is not None: + pieces.extend([b'LIMIT', start, num]) + return self.execute_command(*pieces) + + def zrevrangebylex(self, name, max, min, start=None, num=None): + """ + Return the reversed lexicographical range of values from sorted set + ``name`` between ``max`` and ``min``. + + If ``start`` and ``num`` are specified, then return a slice of the + range. + """ + if (start is not None and num is None) or \ + (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces = ['ZREVRANGEBYLEX', name, max, min] + if start is not None and num is not None: + pieces.extend([b'LIMIT', start, num]) + return self.execute_command(*pieces) + + def zrangebyscore(self, name, min, max, start=None, num=None, + withscores=False, score_cast_func=float): + """ + Return a range of values from the sorted set ``name`` with scores + between ``min`` and ``max``. + + If ``start`` and ``num`` are specified, then return a slice + of the range. + + ``withscores`` indicates to return the scores along with the values. + The return type is a list of (value, score) pairs + + `score_cast_func`` a callable used to cast the score return value + """ + if (start is not None and num is None) or \ + (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces = ['ZRANGEBYSCORE', name, min, max] + if start is not None and num is not None: + pieces.extend([b'LIMIT', start, num]) + if withscores: + pieces.append(b'WITHSCORES') + options = { + 'withscores': withscores, + 'score_cast_func': score_cast_func + } + return self.execute_command(*pieces, **options) + + def zrank(self, name, value): + """ + Returns a 0-based value indicating the rank of ``value`` in sorted set + ``name`` + """ + return self.execute_command('ZRANK', name, value) + + def zrem(self, name, *values): + "Remove member ``values`` from sorted set ``name``" + return self.execute_command('ZREM', name, *values) + + def zremrangebylex(self, name, min, max): + """ + Remove all elements in the sorted set ``name`` between the + lexicographical range specified by ``min`` and ``max``. + + Returns the number of elements removed. + """ + return self.execute_command('ZREMRANGEBYLEX', name, min, max) + + def zremrangebyrank(self, name, min, max): + """ + Remove all elements in the sorted set ``name`` with ranks between + ``min`` and ``max``. Values are 0-based, ordered from smallest score + to largest. Values can be negative indicating the highest scores. + Returns the number of elements removed + """ + return self.execute_command('ZREMRANGEBYRANK', name, min, max) + + def zremrangebyscore(self, name, min, max): + """ + Remove all elements in the sorted set ``name`` with scores + between ``min`` and ``max``. Returns the number of elements removed. 
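+
+        Example (sketch; assumes a client ``r``; the key name is
+        illustrative):
+
+            >>> r.zadd('scores', {'a': 1, 'b': 2, 'c': 3})
+            3
+            >>> r.zremrangebyscore('scores', 2, 3)
+            2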
+ """ + return self.execute_command('ZREMRANGEBYSCORE', name, min, max) + + def zrevrange(self, name, start, end, withscores=False, + score_cast_func=float): + """ + Return a range of values from sorted set ``name`` between + ``start`` and ``end`` sorted in descending order. + + ``start`` and ``end`` can be negative, indicating the end of the range. + + ``withscores`` indicates to return the scores along with the values + The return type is a list of (value, score) pairs + + ``score_cast_func`` a callable used to cast the score return value + """ + pieces = ['ZREVRANGE', name, start, end] + if withscores: + pieces.append(b'WITHSCORES') + options = { + 'withscores': withscores, + 'score_cast_func': score_cast_func + } + return self.execute_command(*pieces, **options) + + def zrevrangebyscore(self, name, max, min, start=None, num=None, + withscores=False, score_cast_func=float): + """ + Return a range of values from the sorted set ``name`` with scores + between ``min`` and ``max`` in descending order. + + If ``start`` and ``num`` are specified, then return a slice + of the range. + + ``withscores`` indicates to return the scores along with the values. + The return type is a list of (value, score) pairs + + ``score_cast_func`` a callable used to cast the score return value + """ + if (start is not None and num is None) or \ + (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces = ['ZREVRANGEBYSCORE', name, max, min] + if start is not None and num is not None: + pieces.extend([b'LIMIT', start, num]) + if withscores: + pieces.append(b'WITHSCORES') + options = { + 'withscores': withscores, + 'score_cast_func': score_cast_func + } + return self.execute_command(*pieces, **options) + + def zrevrank(self, name, value): + """ + Returns a 0-based value indicating the descending rank of + ``value`` in sorted set ``name`` + """ + return self.execute_command('ZREVRANK', name, value) + + def zscore(self, name, value): + "Return the score of element ``value`` in sorted set ``name``" + return self.execute_command('ZSCORE', name, value) + + def zunion(self, keys, aggregate=None, withscores=False): + """ + Return the union of multiple sorted sets specified by ``keys``. + ``keys`` can be provided as dictionary of keys and their weights. + Scores will be aggregated based on the ``aggregate``, or SUM if + none is provided. + """ + return self._zaggregate('ZUNION', None, keys, aggregate, + withscores=withscores) + + def zunionstore(self, dest, keys, aggregate=None): + """ + Union multiple sorted sets specified by ``keys`` into + a new sorted set, ``dest``. Scores in the destination will be + aggregated based on the ``aggregate``, or SUM if none is provided. 
+ """ + return self._zaggregate('ZUNIONSTORE', dest, keys, aggregate) + + def _zaggregate(self, command, dest, keys, aggregate=None, + **options): + pieces = [command] + if dest is not None: + pieces.append(dest) + pieces.append(len(keys)) + if isinstance(keys, dict): + keys, weights = keys.keys(), keys.values() + else: + weights = None + pieces.extend(keys) + if weights: + pieces.append(b'WEIGHTS') + pieces.extend(weights) + if aggregate: + if aggregate.upper() in ['SUM', 'MIN', 'MAX']: + pieces.append(b'AGGREGATE') + pieces.append(aggregate) + else: + raise DataError("aggregate can be sum, min or max.") + if options.get('withscores', False): + pieces.append(b'WITHSCORES') + return self.execute_command(*pieces, **options) + + +class HyperLogLogCommands: + # HYPERLOGLOG COMMANDS + def pfadd(self, name, *values): + "Adds the specified elements to the specified HyperLogLog." + return self.execute_command('PFADD', name, *values) + + def pfcount(self, *sources): + """ + Return the approximated cardinality of + the set observed by the HyperLogLog at key(s). + """ + return self.execute_command('PFCOUNT', *sources) + + def pfmerge(self, dest, *sources): + "Merge N different HyperLogLogs into a single one." + return self.execute_command('PFMERGE', dest, *sources) + + +class HashCommands: + # HASH COMMANDS + def hdel(self, name, *keys): + "Delete ``keys`` from hash ``name``" + return self.execute_command('HDEL', name, *keys) + + def hexists(self, name, key): + "Returns a boolean indicating if ``key`` exists within hash ``name``" + return self.execute_command('HEXISTS', name, key) + + def hget(self, name, key): + "Return the value of ``key`` within the hash ``name``" + return self.execute_command('HGET', name, key) + + def hgetall(self, name): + "Return a Python dict of the hash's name/value pairs" + return self.execute_command('HGETALL', name) + + def hincrby(self, name, key, amount=1): + "Increment the value of ``key`` in hash ``name`` by ``amount``" + return self.execute_command('HINCRBY', name, key, amount) + + def hincrbyfloat(self, name, key, amount=1.0): + """ + Increment the value of ``key`` in hash ``name`` by floating ``amount`` + """ + return self.execute_command('HINCRBYFLOAT', name, key, amount) + + def hkeys(self, name): + "Return the list of keys within hash ``name``" + return self.execute_command('HKEYS', name) + + def hlen(self, name): + "Return the number of elements in hash ``name``" + return self.execute_command('HLEN', name) + + def hset(self, name, key=None, value=None, mapping=None): + """ + Set ``key`` to ``value`` within hash ``name``, + ``mapping`` accepts a dict of key/value pairs that will be + added to hash ``name``. + Returns the number of fields that were added. + """ + if key is None and not mapping: + raise DataError("'hset' with no key value pairs") + items = [] + if key is not None: + items.extend((key, value)) + if mapping: + for pair in mapping.items(): + items.extend(pair) + + return self.execute_command('HSET', name, *items) + + def hsetnx(self, name, key, value): + """ + Set ``key`` to ``value`` within hash ``name`` if ``key`` does not + exist. Returns 1 if HSETNX created a field, otherwise 0. + """ + return self.execute_command('HSETNX', name, key, value) + + def hmset(self, name, mapping): + """ + Set key to value within hash ``name`` for each corresponding + key and value from the ``mapping`` dict. + """ + warnings.warn( + '%s.hmset() is deprecated. Use %s.hset() instead.' 
+            % (self.__class__.__name__, self.__class__.__name__),
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        if not mapping:
+            raise DataError("'hmset' with 'mapping' of length 0")
+        items = []
+        for pair in mapping.items():
+            items.extend(pair)
+        return self.execute_command('HMSET', name, *items)
+
+    def hmget(self, name, keys, *args):
+        "Returns a list of values ordered identically to ``keys``"
+        args = list_or_args(keys, args)
+        return self.execute_command('HMGET', name, *args)
+
+    def hvals(self, name):
+        "Return the list of values within hash ``name``"
+        return self.execute_command('HVALS', name)
+
+    def hstrlen(self, name, key):
+        """
+        Return the number of bytes stored in the value of ``key``
+        within hash ``name``
+        """
+        return self.execute_command('HSTRLEN', name, key)
+
+
+class PubSubCommands:
+    def publish(self, channel, message):
+        """
+        Publish ``message`` on ``channel``.
+        Returns the number of subscribers the message was delivered to.
+        """
+        return self.execute_command('PUBLISH', channel, message)
+
+    def pubsub_channels(self, pattern='*'):
+        """
+        Return a list of channels that have at least one subscriber
+        """
+        return self.execute_command('PUBSUB CHANNELS', pattern)
+
+    def pubsub_numpat(self):
+        """
+        Returns the number of subscriptions to patterns
+        """
+        return self.execute_command('PUBSUB NUMPAT')
+
+    def pubsub_numsub(self, *args):
+        """
+        Return a list of (channel, number of subscribers) tuples
+        for each channel given in ``*args``
+        """
+        return self.execute_command('PUBSUB NUMSUB', *args)
+
+
+class ScriptCommands:
+    def eval(self, script, numkeys, *keys_and_args):
+        """
+        Execute the Lua ``script``, specifying the ``numkeys`` the script
+        will touch and the key names and argument values in ``keys_and_args``.
+        Returns the result of the script.
+
+        In practice, use the object returned by ``register_script``. This
+        function exists purely for Redis API completeness.
+        """
+        return self.execute_command('EVAL', script, numkeys, *keys_and_args)
+
+    def evalsha(self, sha, numkeys, *keys_and_args):
+        """
+        Use the ``sha`` to execute a Lua script already registered via EVAL
+        or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the
+        key names and argument values in ``keys_and_args``. Returns the result
+        of the script.
+
+        In practice, use the object returned by ``register_script``. This
+        function exists purely for Redis API completeness.
+        """
+        return self.execute_command('EVALSHA', sha, numkeys, *keys_and_args)
+
+    def script_exists(self, *args):
+        """
+        Check if a script exists in the script cache by specifying the SHAs of
+        each script as ``args``. Returns a list of boolean values indicating
+        whether each script already exists in the cache.
+        """
+        return self.execute_command('SCRIPT EXISTS', *args)
+
+    def script_flush(self):
+        "Flush all scripts from the script cache"
+        return self.execute_command('SCRIPT FLUSH')
+
+    def script_kill(self):
+        "Kill the currently executing Lua script"
+        return self.execute_command('SCRIPT KILL')
+
+    def script_load(self, script):
+        "Load a Lua ``script`` into the script cache. Returns the SHA."
+        return self.execute_command('SCRIPT LOAD', script)
+
+    def register_script(self, script):
+        """
+        Register a Lua ``script``. Returns a Script object that is callable
+        and hides the complexity of dealing with scripts, keys, and shas.
+        This is the preferred way to work with Lua scripts.
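+
+        Example (sketch; assumes a client ``r``):
+
+            >>> double = r.register_script("return tonumber(ARGV[1]) * 2")
+            >>> double(args=[21])
+            42
+
+        The Script object calls EVALSHA and falls back to SCRIPT LOAD if the
+        server's script cache does not yet contain the script.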
+ """ + return Script(self, script) + + +class GeoCommands: + # GEO COMMANDS + def geoadd(self, name, *values): + """ + Add the specified geospatial items to the specified key identified + by the ``name`` argument. The Geospatial items are given as ordered + members of the ``values`` argument, each item or place is formed by + the triad longitude, latitude and name. + """ + if len(values) % 3 != 0: + raise DataError("GEOADD requires places with lon, lat and name" + " values") + return self.execute_command('GEOADD', name, *values) + + def geodist(self, name, place1, place2, unit=None): + """ + Return the distance between ``place1`` and ``place2`` members of the + ``name`` key. + The units must be one of the following : m, km mi, ft. By default + meters are used. + """ + pieces = [name, place1, place2] + if unit and unit not in ('m', 'km', 'mi', 'ft'): + raise DataError("GEODIST invalid unit") + elif unit: + pieces.append(unit) + return self.execute_command('GEODIST', *pieces) + + def geohash(self, name, *values): + """ + Return the geo hash string for each item of ``values`` members of + the specified key identified by the ``name`` argument. + """ + return self.execute_command('GEOHASH', name, *values) + + def geopos(self, name, *values): + """ + Return the positions of each item of ``values`` as members of + the specified key identified by the ``name`` argument. Each position + is represented by the pairs lon and lat. + """ + return self.execute_command('GEOPOS', name, *values) + + def georadius(self, name, longitude, latitude, radius, unit=None, + withdist=False, withcoord=False, withhash=False, count=None, + sort=None, store=None, store_dist=None): + """ + Return the members of the specified key identified by the + ``name`` argument which are within the borders of the area specified + with the ``latitude`` and ``longitude`` location and the maximum + distance from the center specified by the ``radius`` value. + + The units must be one of the following : m, km mi, ft. By default + + ``withdist`` indicates to return the distances of each place. + + ``withcoord`` indicates to return the latitude and longitude of + each place. + + ``withhash`` indicates to return the geohash string of each place. + + ``count`` indicates to return the number of elements up to N. + + ``sort`` indicates to return the places in a sorted way, ASC for + nearest to fairest and DESC for fairest to nearest. + + ``store`` indicates to save the places names in a sorted set named + with a specific key, each element of the destination sorted set is + populated with the score got from the original geo sorted set. + + ``store_dist`` indicates to save the places names in a sorted set + named with a specific key, instead of ``store`` the sorted set + destination score is set with the distance. + """ + return self._georadiusgeneric('GEORADIUS', + name, longitude, latitude, radius, + unit=unit, withdist=withdist, + withcoord=withcoord, withhash=withhash, + count=count, sort=sort, store=store, + store_dist=store_dist) + + def georadiusbymember(self, name, member, radius, unit=None, + withdist=False, withcoord=False, withhash=False, + count=None, sort=None, store=None, store_dist=None): + """ + This command is exactly like ``georadius`` with the sole difference + that instead of taking, as the center of the area to query, a longitude + and latitude value, it takes the name of a member already existing + inside the geospatial index represented by the sorted set. 
+ """ + return self._georadiusgeneric('GEORADIUSBYMEMBER', + name, member, radius, unit=unit, + withdist=withdist, withcoord=withcoord, + withhash=withhash, count=count, + sort=sort, store=store, + store_dist=store_dist) + + def _georadiusgeneric(self, command, *args, **kwargs): + pieces = list(args) + if kwargs['unit'] and kwargs['unit'] not in ('m', 'km', 'mi', 'ft'): + raise DataError("GEORADIUS invalid unit") + elif kwargs['unit']: + pieces.append(kwargs['unit']) + else: + pieces.append('m', ) + + for arg_name, byte_repr in ( + ('withdist', b'WITHDIST'), + ('withcoord', b'WITHCOORD'), + ('withhash', b'WITHHASH')): + if kwargs[arg_name]: + pieces.append(byte_repr) + + if kwargs['count']: + pieces.extend([b'COUNT', kwargs['count']]) + + if kwargs['sort']: + if kwargs['sort'] == 'ASC': + pieces.append(b'ASC') + elif kwargs['sort'] == 'DESC': + pieces.append(b'DESC') + else: + raise DataError("GEORADIUS invalid sort") + + if kwargs['store'] and kwargs['store_dist']: + raise DataError("GEORADIUS store and store_dist cant be set" + " together") + + if kwargs['store']: + pieces.extend([b'STORE', kwargs['store']]) + + if kwargs['store_dist']: + pieces.extend([b'STOREDIST', kwargs['store_dist']]) + + return self.execute_command(command, *pieces, **kwargs) + + +class ModuleCommands: + # MODULE COMMANDS + def module_load(self, path): + """ + Loads the module from ``path``. + Raises ``ModuleError`` if a module is not found at ``path``. + """ + return self.execute_command('MODULE LOAD', path) + + def module_unload(self, name): + """ + Unloads the module ``name``. + Raises ``ModuleError`` if ``name`` is not in loaded modules. + """ + return self.execute_command('MODULE UNLOAD', name) + + def module_list(self): + """ + Returns a list of dictionaries containing the name and version of + all loaded modules. + """ + return self.execute_command('MODULE LIST') + + +class Script: + "An executable Lua script object returned by ``register_script``" + + def __init__(self, registered_client, script): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + def __call__(self, keys=[], args=[], client=None): + "Execute the script, passing any required ``args``" + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.client import Pipeline + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. + self.sha = client.script_load(self.script) + return client.evalsha(self.sha, len(keys), *args) + + +class BitFieldOperation: + """ + Command builder for BITFIELD commands. 
+ """ + + def __init__(self, client, key, default_overflow=None): + self.client = client + self.key = key + self._default_overflow = default_overflow + self.reset() + + def reset(self): + """ + Reset the state of the instance to when it was constructed + """ + self.operations = [] + self._last_overflow = 'WRAP' + self.overflow(self._default_overflow or self._last_overflow) + + def overflow(self, overflow): + """ + Update the overflow algorithm of successive INCRBY operations + :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the + Redis docs for descriptions of these algorithmsself. + :returns: a :py:class:`BitFieldOperation` instance. + """ + overflow = overflow.upper() + if overflow != self._last_overflow: + self._last_overflow = overflow + self.operations.append(('OVERFLOW', overflow)) + return self + + def incrby(self, fmt, offset, increment, overflow=None): + """ + Increment a bitfield by a given amount. + :param fmt: format-string for the bitfield being updated, e.g. 'u8' + for an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int increment: value to increment the bitfield by. + :param str overflow: overflow algorithm. Defaults to WRAP, but other + acceptable values are SAT and FAIL. See the Redis docs for + descriptions of these algorithms. + :returns: a :py:class:`BitFieldOperation` instance. + """ + if overflow is not None: + self.overflow(overflow) + + self.operations.append(('INCRBY', fmt, offset, increment)) + return self + + def get(self, fmt, offset): + """ + Get the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(('GET', fmt, offset)) + return self + + def set(self, fmt, offset, value): + """ + Set the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int value: value to set at the given position. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(('SET', fmt, offset, value)) + return self + + @property + def command(self): + cmd = ['BITFIELD', self.key] + for ops in self.operations: + cmd.extend(ops) + return cmd + + def execute(self): + """ + Execute the operation(s) in a single BITFIELD command. The return value + is a list of values corresponding to each operation. If the client + used to create this instance was a pipeline, the list of values + will be present within the pipeline's execute. + """ + command = self.command + self.reset() + return self.client.execute_command(*command) + + +class SentinalCommands: + """ + A class containing the commands specific to redis sentinal. This class is + to be used as a mixin. + """ + + def sentinel(self, *args): + "Redis Sentinel's SENTINEL command." 
+
+
+class SentinelCommands:
+    """
+    A class containing the commands specific to Redis Sentinel. This class is
+    to be used as a mixin.
+    """
+
+    def sentinel(self, *args):
+        "Redis Sentinel's SENTINEL command."
+        warnings.warn(
+            DeprecationWarning('Use the individual sentinel_* methods'))
+
+    def sentinel_get_master_addr_by_name(self, service_name):
+        "Returns a (host, port) pair for the given ``service_name``"
+        return self.execute_command('SENTINEL GET-MASTER-ADDR-BY-NAME',
+                                    service_name)
+
+    def sentinel_master(self, service_name):
+        "Returns a dictionary containing the specified master's state."
+        return self.execute_command('SENTINEL MASTER', service_name)
+
+    def sentinel_masters(self):
+        "Returns a list of dictionaries containing each master's state."
+        return self.execute_command('SENTINEL MASTERS')
+
+    def sentinel_monitor(self, name, ip, port, quorum):
+        "Add a new master to Sentinel to be monitored"
+        return self.execute_command('SENTINEL MONITOR', name, ip, port, quorum)
+
+    def sentinel_remove(self, name):
+        "Remove a master from Sentinel's monitoring"
+        return self.execute_command('SENTINEL REMOVE', name)
+
+    def sentinel_sentinels(self, service_name):
+        "Returns a list of sentinels for ``service_name``"
+        return self.execute_command('SENTINEL SENTINELS', service_name)
+
+    def sentinel_set(self, name, option, value):
+        "Set Sentinel monitoring parameters for a given master"
+        return self.execute_command('SENTINEL SET', name, option, value)
+
+    def sentinel_slaves(self, service_name):
+        "Returns a list of slaves for ``service_name``"
+        return self.execute_command('SENTINEL SLAVES', service_name)
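A short usage sketch of the `SentinelCommands` mixin above, assuming it is mixed into the regular `Redis` client (as it is in redis-py) and that a Sentinel instance is listening on localhost:26379 monitoring a master named 'mymaster' (both are illustrative):

```python
import redis

# A plain client pointed at the sentinel port exposes the sentinel_* methods
sentinel = redis.Redis(host='localhost', port=26379)
print(sentinel.sentinel_masters())                            # state of all masters
print(sentinel.sentinel_get_master_addr_by_name('mymaster'))  # e.g. ('127.0.0.1', 6379)
```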
+ """ + + # Partition the keys by slot + slots_to_pairs = {} + for pair in mapping.items(): + # encode the key + k = self.encoder.encode(pair[0]) + slot = key_slot(k) + slots_to_pairs.setdefault(slot, []).extend(pair) + + # Call MSET for every slot and concatenate + # the results (one result per slot) + res = [] + for pairs in slots_to_pairs.values(): + res.append(self.execute_command('MSET', *pairs)) + + return res + + def _split_command_across_slots(self, command, *keys): + """ + Runs the given command once for the keys + of each slot. Returns the sum of the return values. + """ + # Partition the keys by slot + slots_to_keys = self._partition_keys_by_slot(keys) + + # Sum up the reply from each command + total = 0 + for slot_keys in slots_to_keys.values(): + total += self.execute_command(command, *slot_keys) + + return total + + def exists(self, *keys): + """ + Returns the number of ``names`` that exist in the + whole cluster. The keys are first split up into slots + and then an EXISTS command is sent for every slot + """ + return self._split_command_across_slots('EXISTS', *keys) + + def delete(self, *keys): + """ + Deletes the given keys in the cluster. + The keys are first split up into slots + and then an DEL command is sent for every slot + + Non-existant keys are ignored. + Returns the number of keys that were deleted. + """ + return self._split_command_across_slots('DEL', *keys) + + def touch(self, *keys): + """ + Updates the last access time of given keys across the + cluster. + + The keys are first split up into slots + and then an TOUCH command is sent for every slot + + Non-existant keys are ignored. + Returns the number of keys that were touched. + """ + return self._split_command_across_slots('TOUCH', *keys) + + def unlink(self, *keys): + """ + Remove the specified keys in a different thread. + + The keys are first split up into slots + and then an TOUCH command is sent for every slot + + Non-existant keys are ignored. + Returns the number of keys that were unlinked. + """ + return self._split_command_across_slots('UNLINK', *keys) + + +class DataAccessCommands(BasicKeyCommands, ListCommands, + ScanCommands, SetCommands, StreamsCommands, + SortedSetCommands, + HyperLogLogCommands, HashCommands, GeoCommands, + ): + """ + A class containing all of the implemented data access redis commands. + This class is to be used as a mixin. + """ + + +class Commands(AclCommands, DataAccessCommands, ManagementCommands, + ModuleCommands, PubSubCommands, ScriptCommands): + """ + A class containing all of the implemented redis commands. This class is + to be used as a mixin. + """ + + +class ClusterManagementCommands: + def bgsave(self, schedule=True, target_nodes=None): + """ + Tell the Redis server to save its data to disk. Unlike save(), + this method is asynchronous and returns immediately. + """ + pieces = [] + if schedule: + pieces.append("SCHEDULE") + return self.execute_command('BGSAVE', + *pieces, + target_nodes=target_nodes) + + def client_getname(self, target_nodes=None): + """ + Returns the current connection name from all nodes. + The result will be a dictionary with the IP and + connection name. + """ + return self.execute_command('CLIENT GETNAME', + target_nodes=target_nodes) + + def client_getredir(self, target_nodes=None): + """Returns the ID (an integer) of the client to whom we are + redirecting tracking notifications. 
+
+
+class ClusterManagementCommands:
+    def bgsave(self, schedule=True, target_nodes=None):
+        """
+        Tell the Redis server to save its data to disk. Unlike save(),
+        this method is asynchronous and returns immediately.
+        """
+        pieces = []
+        if schedule:
+            pieces.append("SCHEDULE")
+        return self.execute_command('BGSAVE',
+                                    *pieces,
+                                    target_nodes=target_nodes)
+
+    def client_getname(self, target_nodes=None):
+        """
+        Returns the current connection name from all nodes.
+        The result will be a dictionary with the IP and
+        connection name.
+        """
+        return self.execute_command('CLIENT GETNAME',
+                                    target_nodes=target_nodes)
+
+    def client_getredir(self, target_nodes=None):
+        """Returns the ID (an integer) of the client to whom we are
+        redirecting tracking notifications.
+
+        see: https://redis.io/commands/client-getredir
+        """
+        return self.execute_command('CLIENT GETREDIR',
+                                    target_nodes=target_nodes)
+
+    def client_id(self, target_nodes=None):
+        """Returns the current connection id"""
+        return self.execute_command('CLIENT ID',
+                                    target_nodes=target_nodes)
+
+    def client_info(self, target_nodes=None):
+        """
+        Returns information and statistics about the current
+        client connection.
+        """
+        return self.execute_command('CLIENT INFO',
+                                    target_nodes=target_nodes)
+
+    def client_kill_filter(self, _id=None, _type=None, addr=None,
+                           skipme=None, laddr=None, user=None,
+                           target_nodes=None):
+        """
+        Disconnects client(s) using a variety of filter options
+        :param _id: Kills a client by its unique ID field
+        :param _type: Kills a client by type where type is one of 'normal',
+            'master', 'slave' or 'pubsub'
+        :param addr: Kills a client by its 'address:port'
+        :param skipme: If True, then the client calling the command
+            will not get killed even if it is identified by one of the filter
+            options. If skipme is not provided, the server defaults to
+            skipme=True
+        :param laddr: Kills a client by its 'local (bind) address:port'
+        :param user: Kills a client for a specific user name
+        """
+        args = []
+        if _type is not None:
+            client_types = ('normal', 'master', 'slave', 'pubsub')
+            if str(_type).lower() not in client_types:
+                raise DataError("CLIENT KILL type must be one of %r" % (
+                    client_types,))
+            args.extend((b'TYPE', _type))
+        if skipme is not None:
+            if not isinstance(skipme, bool):
+                raise DataError("CLIENT KILL skipme must be a bool")
+            if skipme:
+                args.extend((b'SKIPME', b'YES'))
+            else:
+                args.extend((b'SKIPME', b'NO'))
+        if _id is not None:
+            args.extend((b'ID', _id))
+        if addr is not None:
+            args.extend((b'ADDR', addr))
+        if laddr is not None:
+            args.extend((b'LADDR', laddr))
+        if user is not None:
+            args.extend((b'USER', user))
+        if not args:
+            raise DataError("CLIENT KILL <filter> <value> ... ... <filter> "
+                            "<value> must specify at least one filter")
+        return self.execute_command('CLIENT KILL', *args,
+                                    target_nodes=target_nodes)
+
+    def client_kill(self, address, target_nodes=None):
+        "Disconnects the client at ``address`` (ip:port)"
+        return self.execute_command('CLIENT KILL', address,
+                                    target_nodes=target_nodes)
+
+    def client_list(self, _type=None, target_nodes=None):
+        """
+        Returns a list of currently connected clients to the entire cluster.
+        If type of client specified, only that type will be returned.
+        :param _type: optional. one of the client types (normal, master,
+            replica, pubsub)
+        """
+        if _type is not None:
+            client_types = ('normal', 'master', 'replica', 'pubsub')
+            if str(_type).lower() not in client_types:
+                raise DataError("CLIENT LIST _type must be one of %r" % (
+                    client_types,))
+            return self.execute_command('CLIENT LIST',
+                                        b'TYPE',
+                                        _type,
+                                        target_nodes=target_nodes)
+        return self.execute_command('CLIENT LIST',
+                                    target_nodes=target_nodes)
+
+    def client_pause(self, timeout, target_nodes=None):
+        """
+        Suspend all the Redis clients for the specified amount of time
+        :param timeout: milliseconds to pause clients
+        """
+        if not isinstance(timeout, int):
+            raise DataError("CLIENT PAUSE timeout must be an integer")
+        return self.execute_command('CLIENT PAUSE', str(timeout),
+                                    target_nodes=target_nodes)
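Worth noting: every method in this class threads a `target_nodes` kwarg through to `execute_command`, so connection-management commands can be fanned out or pinned. A sketch, with an illustrative cluster address:

```python
from redis.cluster import RedisCluster

rc = RedisCluster(host='localhost', port=16379)

# Node flags resolve to concrete nodes when the command is executed
print(rc.client_list(target_nodes=RedisCluster.PRIMARIES))
print(rc.client_getname(target_nodes=RedisCluster.ALL_NODES))
```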
+
+    def client_reply(self, reply, target_nodes=None):
+        """Enable and disable redis server replies.
+        ``reply`` must be ON, OFF or SKIP,
+        ON - The default mode in which the server replies to commands
+        OFF - Disable server responses to commands
+        SKIP - Skip the response of the immediately following command.
+
+        Note: When setting OFF or SKIP replies, you will need a client object
+        with a timeout specified in seconds, and will need to catch the
+        TimeoutError.
+        The test_client_reply unit test illustrates this, and
+        conftest.py has a client with a timeout.
+        See https://redis.io/commands/client-reply
+        """
+        replies = ['ON', 'OFF', 'SKIP']
+        if reply not in replies:
+            raise DataError('CLIENT REPLY must be one of %r' % replies)
+        return self.execute_command("CLIENT REPLY", reply,
+                                    target_nodes=target_nodes)
+
+    def client_setname(self, name, target_nodes=None):
+        "Sets the current connection name"
+        return self.execute_command('CLIENT SETNAME', name,
+                                    target_nodes=target_nodes)
+
+    def client_trackinginfo(self, target_nodes=None):
+        """
+        Returns the information about the current client connection's
+        use of the server assisted client side cache.
+        See https://redis.io/commands/client-trackinginfo
+        """
+        return self.execute_command('CLIENT TRACKINGINFO',
+                                    target_nodes=target_nodes)
+
+    def client_unblock(self, client_id, error=False, target_nodes=None):
+        """
+        Unblocks a connection by its client id.
+        If ``error`` is True, unblocks the client with a special error
+        message.
+        If ``error`` is False (default), the client is unblocked using the
+        regular timeout mechanism.
+        """
+        args = ['CLIENT UNBLOCK', int(client_id)]
+        if error:
+            args.append(b'ERROR')
+        return self.execute_command(*args, target_nodes=target_nodes)
+
+    def client_unpause(self, target_nodes=None):
+        """
+        Unpause all redis clients
+        """
+        return self.execute_command('CLIENT UNPAUSE',
+                                    target_nodes=target_nodes)
+
+    def config_get(self, pattern="*", target_nodes=None):
+        """Return a dictionary of configuration based on the ``pattern``"""
+        return self.execute_command('CONFIG GET',
+                                    pattern,
+                                    target_nodes=target_nodes)
+
+    def config_resetstat(self, target_nodes=None):
+        """Reset runtime statistics"""
+        return self.execute_command('CONFIG RESETSTAT',
+                                    target_nodes=target_nodes)
+
+    def config_rewrite(self, target_nodes=None):
+        """
+        Rewrite config file with the minimal change to reflect running config.
+        """
+        return self.execute_command('CONFIG REWRITE',
+                                    target_nodes=target_nodes)
+
+    def config_set(self, name, value, target_nodes=None):
+        "Set config item ``name`` with ``value``"
+        return self.execute_command('CONFIG SET',
+                                    name,
+                                    value,
+                                    target_nodes=target_nodes)
+
+    def dbsize(self, target_nodes=None):
+        """
+        Sums the number of keys in the target nodes' DB.
+        If no target nodes are specified, send to the entire cluster and sum
+        the results.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('DBSIZE',
+                                    target_nodes=target_nodes)
+
+    def debug_object(self, key):
+        raise NotImplementedError(
+            "DEBUG OBJECT is intentionally not implemented in the client."
+        )
+
+    def debug_segfault(self):
+        raise NotImplementedError(
+            "DEBUG SEGFAULT is intentionally not implemented in the client."
+        )
+
+    def echo(self, value, target_nodes):
+        """Echo the string back from the server"""
+        return self.execute_command('ECHO', value,
+                                    target_nodes=target_nodes)
+
+    def flushall(self, asynchronous=False, target_nodes=None):
+        """
+        Delete all keys in the database on all hosts.
+ In cluster mode this method is the same as flushdb + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if asynchronous: + args.append(b'ASYNC') + return self.execute_command('FLUSHALL', + *args, + target_nodes=target_nodes) + + def flushdb(self, asynchronous=False, target_nodes=None): + """ + Delete all keys in the database. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if asynchronous: + args.append(b'ASYNC') + return self.execute_command('FLUSHDB', + *args, + target_nodes=target_nodes) + + def info(self, section=None, target_nodes=None): + """ + Returns a dictionary containing information about the Redis server + + The ``section`` option can be used to select a specific section + of information + + The section option is not supported by older versions of Redis Server, + and will generate ResponseError + """ + if section is None: + return self.execute_command('INFO', + target_nodes=target_nodes) + else: + return self.execute_command('INFO', + section, + target_nodes=target_nodes) + + def lastsave(self, target_nodes=None): + """ + Return a Python datetime object representing the last time the + Redis database was saved to disk + """ + return self.execute_command('LASTSAVE', + target_nodes=target_nodes) + + def memory_doctor(self): + raise NotImplementedError( + "MEMORY DOCTOR is intentionally not implemented in the client." + ) + + def memory_help(self): + raise NotImplementedError( + "MEMORY HELP is intentionally not implemented in the client." + ) + + def memory_malloc_stats(self, target_nodes=None): + """Return an internal statistics report from the memory allocator.""" + return self.execute_command('MEMORY MALLOC-STATS', + target_nodes=target_nodes) + + def memory_purge(self, target_nodes=None): + """Attempts to purge dirty pages for reclamation by allocator""" + return self.execute_command('MEMORY PURGE', + target_nodes=target_nodes) + + def memory_stats(self, target_nodes=None): + """Return a dictionary of memory stats""" + return self.execute_command('MEMORY STATS', + target_nodes=target_nodes) + + def memory_usage(self, key, samples=None): + """ + Return the total memory usage for key, its value and associated + administrative overheads. + + For nested data structures, ``samples`` is the number of elements to + sample. If left unspecified, the server's default is 5. Use 0 to sample + all elements. + """ + args = [] + if isinstance(samples, int): + args.extend([b'SAMPLES', samples]) + return self.execute_command('MEMORY USAGE', key, *args) + + def migrate(self, host, source_node, port, keys, destination_db, timeout, + copy=False, replace=False, auth=None): + """ + Migrate 1 or more keys from the source_node Redis server to a different + server specified by the ``host``, ``port`` and ``destination_db``. + + The ``timeout``, specified in milliseconds, indicates the maximum + time the connection between the two servers can be idle before the + command is interrupted. + + If ``copy`` is True, the specified ``keys`` are NOT deleted from + the source server. + + If ``replace`` is True, this operation will overwrite the keys + on the destination server if they exist. + + If ``auth`` is specified, authenticate to the destination server with + the password provided. 
+ """ + keys = list_or_args(keys, []) + if not keys: + raise DataError('MIGRATE requires at least one key') + pieces = [] + if copy: + pieces.append(b'COPY') + if replace: + pieces.append(b'REPLACE') + if auth: + pieces.append(b'AUTH') + pieces.append(auth) + pieces.append(b'KEYS') + pieces.extend(keys) + return self.execute_command('MIGRATE', host, port, '', destination_db, + timeout, *pieces, + target_nodes=source_node) + + def object(self, infotype, key): + """Return the encoding, idletime, or refcount about the key""" + return self.execute_command('OBJECT', infotype, key, infotype=infotype) + + def ping(self, target_nodes=None): + """ + Ping the cluster's servers. + If no target nodes are specified, sent to all nodes and returns True if + the ping was successful across all nodes. + + :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' + The node/s to execute the command on + """ + return self.execute_command('PING', + target_nodes=target_nodes) + + def save(self): + """ + Tell the Redis server to save its data to disk, + blocking until the save is complete + """ + return self.execute_command('SAVE') + + def shutdown(self, save=False, nosave=False): + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. If the "save" option is set, + a data flush will be attempted even if there is no persistence + configured. If the "nosave" option is set, no data flush will be + attempted. The "save" and "nosave" options cannot both be set. + """ + if save and nosave: + raise DataError('SHUTDOWN save and nosave cannot both be set') + args = ['SHUTDOWN'] + if save: + args.append('SAVE') + if nosave: + args.append('NOSAVE') + try: + self.execute_command(*args) + except ConnectionError: + # a ConnectionError here is expected + return + raise RedisError("SHUTDOWN seems to have failed.") + + def slowlog_get(self, num=None, target_nodes=None): + """ + Get the entries from the slowlog. If ``num`` is specified, get the + most recent ``num`` items. + """ + args = ['SLOWLOG GET'] + if num is not None: + args.append(num) + + return self.execute_command(*args, + target_nodes=target_nodes) + + def slowlog_len(self, target_nodes=None): + "Get the number of items in the slowlog" + return self.execute_command('SLOWLOG LEN', + target_nodes=target_nodes) + + def slowlog_reset(self, target_nodes=None): + "Remove all items in the slowlog" + return self.execute_command('SLOWLOG RESET', + target_nodes=target_nodes) + + def time(self, target_nodes=None): + """ + Returns the server time as a 2-item tuple of ints: + (seconds since epoch, microseconds into this second). + """ + return self.execute_command('TIME', target_nodes=target_nodes) + + def wait(self, num_replicas, timeout, target_nodes=None): + """ + Redis synchronous replication + That returns the number of replicas that processed the query when + we finally have at least ``num_replicas``, or when the ``timeout`` was + reached. + + In cluster mode the WAIT command will be sent to all primaries + and the result will be summed up + """ + return self.execute_command('WAIT', num_replicas, + timeout, + target_nodes=target_nodes) + + +class ClusterCommands(ClusterManagementCommands, ClusterMultiKeyCommands, + DataAccessCommands, PubSubCommands): + def cluster_addslots(self, target_node, *slots): + """ + Assign new hash slots to receiving node. Sends to specified node. 
+
+
+class ClusterCommands(ClusterManagementCommands, ClusterMultiKeyCommands,
+                      DataAccessCommands, PubSubCommands):
+    def cluster_addslots(self, target_node, *slots):
+        """
+        Assign new hash slots to receiving node. Sends to specified node.
+
+        :target_node: 'ClusterNode'
+            The node to execute the command on
+        """
+        return self.execute_command('CLUSTER ADDSLOTS', *slots,
+                                    target_nodes=target_node)
+
+    def cluster_countkeysinslot(self, slot_id):
+        """
+        Return the number of local keys in the specified hash slot
+        Send to node based on specified slot_id
+        """
+        return self.execute_command('CLUSTER COUNTKEYSINSLOT', slot_id)
+
+    def cluster_count_failure_report(self, node_id):
+        """
+        Return the number of failure reports active for a given node
+        Sends to a random node
+        """
+        return self.execute_command('CLUSTER COUNT-FAILURE-REPORTS', node_id)
+
+    def cluster_delslots(self, *slots):
+        """
+        Set hash slots as unbound in the cluster.
+        The client determines by itself which node each slot belongs to and
+        sends the command there.
+
+        Returns a list of the results for each processed slot.
+        """
+        return [
+            self.execute_command('CLUSTER DELSLOTS', slot)
+            for slot in slots
+        ]
+
+    def cluster_failover(self, target_node, option=None):
+        """
+        Forces a slave to perform a manual failover of its master
+        Sends to specified node
+
+        :target_node: 'ClusterNode'
+            The node to execute the command on
+        """
+        if option:
+            if option.upper() not in ['FORCE', 'TAKEOVER']:
+                raise RedisError(
+                    'Invalid option for CLUSTER FAILOVER command: {0}'.format(
+                        option))
+            else:
+                return self.execute_command('CLUSTER FAILOVER', option,
+                                            target_nodes=target_node)
+        else:
+            return self.execute_command('CLUSTER FAILOVER',
+                                        target_nodes=target_node)
+
+    def cluster_info(self, target_node=None):
+        """
+        Provides info about Redis Cluster node state.
+        The command will be sent to a random node in the cluster if no target
+        node is specified.
+
+        :target_node: 'ClusterNode'
+            The node to execute the command on
+        """
+        return self.execute_command('CLUSTER INFO', target_nodes=target_node)
+
+    def cluster_keyslot(self, key):
+        """
+        Returns the hash slot of the specified key
+        Sends to random node in the cluster
+        """
+        return self.execute_command('CLUSTER KEYSLOT', key)
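The slot-oriented commands combine naturally: hash a key server-side, then inspect that slot (see also cluster_get_keys_in_slot below). A sketch, with an illustrative address:

```python
from redis.cluster import RedisCluster

rc = RedisCluster(host='localhost', port=16379)
rc.set('foo', 'bar')
slot = rc.cluster_keyslot('foo')               # ask the server for the slot
print(slot, rc.cluster_countkeysinslot(slot))  # e.g. 12182 1
print(rc.cluster_get_keys_in_slot(slot, 10))   # up to 10 key names in the slot
```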
+
+    def cluster_meet(self, target_nodes, host, port):
+        """
+        Force a node cluster to handshake with another node.
+        Sends to specified node.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER MEET', host, port,
+                                    target_nodes=target_nodes)
+
+    def cluster_nodes(self):
+        """
+        Get Cluster config for the node.
+
+        Sends to random node in the cluster
+        """
+        return self.execute_command('CLUSTER NODES')
+
+    def cluster_replicate(self, target_nodes, node_id):
+        """
+        Reconfigure a node as a slave of the specified master node
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER REPLICATE', node_id,
+                                    target_nodes=target_nodes)
+
+    def cluster_reset(self, target_nodes, soft=True):
+        """
+        Reset a Redis Cluster node
+
+        If 'soft' is True then it will send 'SOFT' argument
+        If 'soft' is False then it will send 'HARD' argument
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER RESET',
+                                    b'SOFT' if soft else b'HARD',
+                                    target_nodes=target_nodes)
+
+    def cluster_save_config(self, target_nodes):
+        """
+        Forces the node to save cluster state on disk
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER SAVECONFIG',
+                                    target_nodes=target_nodes)
+
+    def cluster_get_keys_in_slot(self, slot, num_keys):
+        """
+        Returns up to ``num_keys`` key names from the specified cluster slot
+        """
+        return self.execute_command('CLUSTER GETKEYSINSLOT', slot, num_keys)
+
+    def cluster_set_config_epoch(self, target_nodes, epoch):
+        """
+        Set the configuration epoch in a new node
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER SET-CONFIG-EPOCH', epoch,
+                                    target_nodes=target_nodes)
+
+    def cluster_setslot(self, target_node, node_id, slot_id, state):
+        """
+        Bind a hash slot to a specific node
+
+        :target_node: 'ClusterNode'
+            The node to execute the command on
+        """
+        if state.upper() in ('IMPORTING', 'NODE', 'MIGRATING'):
+            return self.execute_command('CLUSTER SETSLOT', slot_id, state,
+                                        node_id, target_nodes=target_node)
+        elif state.upper() == 'STABLE':
+            raise RedisError('For "stable" state please use '
+                             'cluster_setslot_stable')
+        else:
+            raise RedisError('Invalid slot state: {0}'.format(state))
+
+    def cluster_setslot_stable(self, slot_id):
+        """
+        Clears migrating / importing state from the slot.
+        The client determines by itself which node the slot belongs to and
+        sends the command there.
+        """
+        return self.execute_command('CLUSTER SETSLOT', slot_id, 'STABLE')
+
+    def cluster_replicas(self, node_id):
+        """
+        Provides a list of replica nodes replicating from the specified
+        primary target node.
+        Sends to random node in the cluster.
+        """
+        return self.execute_command('CLUSTER REPLICAS', node_id)
+
+    def cluster_slots(self):
+        """
+        Get array of Cluster slot to node mappings
+
+        Sends to random node in the cluster
+        """
+        return self.execute_command('CLUSTER SLOTS')
+
+    def readonly(self, target_nodes=None):
+        """
+        Enables read queries.
+        The command will be sent to all replica nodes if target_nodes is not
+        specified.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        self.read_from_replicas = True
+        return self.execute_command('READONLY', target_nodes=target_nodes)
+
+    def readwrite(self, target_nodes=None):
+        """
+        Disables read queries.
+ The command will be sent to all replica nodes if target_nodes is not + specified. + + :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' + The node/s to execute the command on + """ + # Reset read from replicas flag + self.read_from_replicas = False + return self.execute_command('READWRITE', target_nodes=target_nodes) diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index f1ddaaabc1..60f13d8d35 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -2,9 +2,13 @@ from .redismodules import RedisModuleCommands from .helpers import list_or_args from .sentinel import SentinelCommands +from .cluster import ClusterCommands +from .parser import CommandsParser __all__ = [ 'CoreCommands', + 'ClusterCommands', + 'CommandsParser', 'RedisModuleCommands', 'SentinelCommands', 'list_or_args' diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py new file mode 100644 index 0000000000..d3c2e8572b --- /dev/null +++ b/redis/commands/cluster.py @@ -0,0 +1,801 @@ +from redis.exceptions import ( + ConnectionError, + DataError, + RedisError, +) +from redis.crc import key_slot +from .core import DataAccessCommands, PubSubCommands +from .helpers import list_or_args + + +class ClusterMultiKeyCommands: + """ + A class containing commands that handle more than one key + """ + + def _partition_keys_by_slot(self, keys): + """ + Split keys into a dictionary that maps a slot to + a list of keys. + """ + slots_to_keys = {} + for key in keys: + k = self.encoder.encode(key) + slot = key_slot(k) + slots_to_keys.setdefault(slot, []).append(key) + + return slots_to_keys + + def mget_nonatomic(self, keys, *args): + """ + Splits the keys into different slots and then calls MGET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. + + Returns a list of values ordered identically to ``keys`` + """ + + from redis.client import EMPTY_RESPONSE + options = {} + if not args: + options[EMPTY_RESPONSE] = [] + + # Concatenate all keys into a list + keys = list_or_args(keys, args) + # Split keys into slots + slots_to_keys = self._partition_keys_by_slot(keys) + + # Call MGET for every slot and concatenate + # the results + # We must make sure that the keys are returned in order + all_results = {} + for slot_keys in slots_to_keys.values(): + slot_values = self.execute_command( + 'MGET', *slot_keys, **options) + + slot_results = dict(zip(slot_keys, slot_values)) + all_results.update(slot_results) + + # Sort the results + vals_in_order = [all_results[key] for key in keys] + return vals_in_order + + def mset_nonatomic(self, mapping): + """ + Sets key/values based on a mapping. Mapping is a dictionary of + key/value pairs. Both keys and values should be strings or types that + can be cast to a string via str(). + + Splits the keys into different slots and then calls MSET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. + """ + + # Partition the keys by slot + slots_to_pairs = {} + for pair in mapping.items(): + # encode the key + k = self.encoder.encode(pair[0]) + slot = key_slot(k) + slots_to_pairs.setdefault(slot, []).extend(pair) + + # Call MSET for every slot and concatenate + # the results (one result per slot) + res = [] + for pairs in slots_to_pairs.values(): + res.append(self.execute_command('MSET', *pairs)) + + return res + + def _split_command_across_slots(self, command, *keys): + """ + Runs the given command once for the keys + of each slot. 
Returns the sum of the return values.
+        """
+        # Partition the keys by slot
+        slots_to_keys = self._partition_keys_by_slot(keys)
+
+        # Sum up the reply from each command
+        total = 0
+        for slot_keys in slots_to_keys.values():
+            total += self.execute_command(command, *slot_keys)
+
+        return total
+
+    def exists(self, *keys):
+        """
+        Returns the number of ``keys`` that exist in the
+        whole cluster. The keys are first split up into slots
+        and then an EXISTS command is sent for every slot
+        """
+        return self._split_command_across_slots('EXISTS', *keys)
+
+    def delete(self, *keys):
+        """
+        Deletes the given keys in the cluster.
+        The keys are first split up into slots
+        and then a DEL command is sent for every slot
+
+        Non-existent keys are ignored.
+        Returns the number of keys that were deleted.
+        """
+        return self._split_command_across_slots('DEL', *keys)
+
+    def touch(self, *keys):
+        """
+        Updates the last access time of given keys across the
+        cluster.
+
+        The keys are first split up into slots
+        and then a TOUCH command is sent for every slot
+
+        Non-existent keys are ignored.
+        Returns the number of keys that were touched.
+        """
+        return self._split_command_across_slots('TOUCH', *keys)
+
+    def unlink(self, *keys):
+        """
+        Remove the specified keys in a different thread.
+
+        The keys are first split up into slots
+        and then an UNLINK command is sent for every slot
+
+        Non-existent keys are ignored.
+        Returns the number of keys that were unlinked.
+        """
+        return self._split_command_across_slots('UNLINK', *keys)
+
+
+class ClusterManagementCommands:
+    def bgsave(self, schedule=True, target_nodes=None):
+        """
+        Tell the Redis server to save its data to disk. Unlike save(),
+        this method is asynchronous and returns immediately.
+        """
+        pieces = []
+        if schedule:
+            pieces.append("SCHEDULE")
+        return self.execute_command('BGSAVE',
+                                    *pieces,
+                                    target_nodes=target_nodes)
+
+    def client_getname(self, target_nodes=None):
+        """
+        Returns the current connection name from all nodes.
+        The result will be a dictionary with the IP and
+        connection name.
+        """
+        return self.execute_command('CLIENT GETNAME',
+                                    target_nodes=target_nodes)
+
+    def client_getredir(self, target_nodes=None):
+        """Returns the ID (an integer) of the client to whom we are
+        redirecting tracking notifications.
+
+        see: https://redis.io/commands/client-getredir
+        """
+        return self.execute_command('CLIENT GETREDIR',
+                                    target_nodes=target_nodes)
+
+    def client_id(self, target_nodes=None):
+        """Returns the current connection id"""
+        return self.execute_command('CLIENT ID',
+                                    target_nodes=target_nodes)
+
+    def client_info(self, target_nodes=None):
+        """
+        Returns information and statistics about the current
+        client connection.
+        """
+        return self.execute_command('CLIENT INFO',
+                                    target_nodes=target_nodes)
+
+    def client_kill_filter(self, _id=None, _type=None, addr=None,
+                           skipme=None, laddr=None, user=None,
+                           target_nodes=None):
+        """
+        Disconnects client(s) using a variety of filter options
+        :param _id: Kills a client by its unique ID field
+        :param _type: Kills a client by type where type is one of 'normal',
+            'master', 'slave' or 'pubsub'
+        :param addr: Kills a client by its 'address:port'
+        :param skipme: If True, then the client calling the command
+            will not get killed even if it is identified by one of the filter
+            options. If skipme is not provided, the server defaults to
+            skipme=True
+        :param laddr: Kills a client by its 'local (bind) address:port'
+        :param user: Kills a client for a specific user name
+        """
+        args = []
+        if _type is not None:
+            client_types = ('normal', 'master', 'slave', 'pubsub')
+            if str(_type).lower() not in client_types:
+                raise DataError("CLIENT KILL type must be one of %r" % (
+                    client_types,))
+            args.extend((b'TYPE', _type))
+        if skipme is not None:
+            if not isinstance(skipme, bool):
+                raise DataError("CLIENT KILL skipme must be a bool")
+            if skipme:
+                args.extend((b'SKIPME', b'YES'))
+            else:
+                args.extend((b'SKIPME', b'NO'))
+        if _id is not None:
+            args.extend((b'ID', _id))
+        if addr is not None:
+            args.extend((b'ADDR', addr))
+        if laddr is not None:
+            args.extend((b'LADDR', laddr))
+        if user is not None:
+            args.extend((b'USER', user))
+        if not args:
+            raise DataError("CLIENT KILL <filter> <value> ... ... <filter> "
+                            "<value> must specify at least one filter")
+        return self.execute_command('CLIENT KILL', *args,
+                                    target_nodes=target_nodes)
+
+    def client_kill(self, address, target_nodes=None):
+        "Disconnects the client at ``address`` (ip:port)"
+        return self.execute_command('CLIENT KILL', address,
+                                    target_nodes=target_nodes)
+
+    def client_list(self, _type=None, target_nodes=None):
+        """
+        Returns a list of currently connected clients to the entire cluster.
+        If type of client specified, only that type will be returned.
+        :param _type: optional. one of the client types (normal, master,
+            replica, pubsub)
+        """
+        if _type is not None:
+            client_types = ('normal', 'master', 'replica', 'pubsub')
+            if str(_type).lower() not in client_types:
+                raise DataError("CLIENT LIST _type must be one of %r" % (
+                    client_types,))
+            return self.execute_command('CLIENT LIST',
+                                        b'TYPE',
+                                        _type,
+                                        target_nodes=target_nodes)
+        return self.execute_command('CLIENT LIST',
+                                    target_nodes=target_nodes)
+
+    def client_pause(self, timeout, target_nodes=None):
+        """
+        Suspend all the Redis clients for the specified amount of time
+        :param timeout: milliseconds to pause clients
+        """
+        if not isinstance(timeout, int):
+            raise DataError("CLIENT PAUSE timeout must be an integer")
+        return self.execute_command('CLIENT PAUSE', str(timeout),
+                                    target_nodes=target_nodes)
+
+    def client_reply(self, reply, target_nodes=None):
+        """Enable and disable redis server replies.
+        ``reply`` must be ON, OFF or SKIP,
+        ON - The default mode in which the server replies to commands
+        OFF - Disable server responses to commands
+        SKIP - Skip the response of the immediately following command.
+
+        Note: When setting OFF or SKIP replies, you will need a client object
+        with a timeout specified in seconds, and will need to catch the
+        TimeoutError.
+        The test_client_reply unit test illustrates this, and
+        conftest.py has a client with a timeout.
+        See https://redis.io/commands/client-reply
+        """
+        replies = ['ON', 'OFF', 'SKIP']
+        if reply not in replies:
+            raise DataError('CLIENT REPLY must be one of %r' % replies)
+        return self.execute_command("CLIENT REPLY", reply,
+                                    target_nodes=target_nodes)
+
+    def client_setname(self, name, target_nodes=None):
+        "Sets the current connection name"
+        return self.execute_command('CLIENT SETNAME', name,
+                                    target_nodes=target_nodes)
+
+    def client_trackinginfo(self, target_nodes=None):
+        """
+        Returns the information about the current client connection's
+        use of the server assisted client side cache.
+ See https://redis.io/commands/client-trackinginfo + """ + return self.execute_command('CLIENT TRACKINGINFO', + target_nodes=target_nodes) + + def client_unblock(self, client_id, error=False, target_nodes=None): + """ + Unblocks a connection by its client id. + If ``error`` is True, unblocks the client with a special error message. + If ``error`` is False (default), the client is unblocked using the + regular timeout mechanism. + """ + args = ['CLIENT UNBLOCK', int(client_id)] + if error: + args.append(b'ERROR') + return self.execute_command(*args, target_nodes=target_nodes) + + def client_unpause(self, target_nodes=None): + """ + Unpause all redis clients + """ + return self.execute_command('CLIENT UNPAUSE', + target_nodes=target_nodes) + + def config_get(self, pattern="*", target_nodes=None): + """Return a dictionary of configuration based on the ``pattern``""" + return self.execute_command('CONFIG GET', + pattern, + target_nodes=target_nodes) + + def config_resetstat(self, target_nodes=None): + """Reset runtime statistics""" + return self.execute_command('CONFIG RESETSTAT', + target_nodes=target_nodes) + + def config_rewrite(self, target_nodes=None): + """ + Rewrite config file with the minimal change to reflect running config. + """ + return self.execute_command('CONFIG REWRITE', + target_nodes=target_nodes) + + def config_set(self, name, value, target_nodes=None): + "Set config item ``name`` with ``value``" + return self.execute_command('CONFIG SET', + name, + value, + target_nodes=target_nodes) + + def dbsize(self, target_nodes=None): + """ + Sums the number of keys in the target nodes' DB. + If no target nodes are specified, send to the entire cluster and sum + the results. + + :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' + The node/s to execute the command on + """ + return self.execute_command('DBSIZE', + target_nodes=target_nodes) + + def debug_object(self, key): + raise NotImplementedError( + "DEBUG OBJECT is intentionally not implemented in the client." + ) + + def debug_segfault(self): + raise NotImplementedError( + "DEBUG SEGFAULT is intentionally not implemented in the client." + ) + + def echo(self, value, target_nodes): + """Echo the string back from the server""" + return self.execute_command('ECHO', value, + target_nodes=target_nodes) + + def flushall(self, asynchronous=False, target_nodes=None): + """ + Delete all keys in the database on all hosts. + In cluster mode this method is the same as flushdb + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if asynchronous: + args.append(b'ASYNC') + return self.execute_command('FLUSHALL', + *args, + target_nodes=target_nodes) + + def flushdb(self, asynchronous=False, target_nodes=None): + """ + Delete all keys in the database. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. 
+ """ + args = [] + if asynchronous: + args.append(b'ASYNC') + return self.execute_command('FLUSHDB', + *args, + target_nodes=target_nodes) + + def info(self, section=None, target_nodes=None): + """ + Returns a dictionary containing information about the Redis server + + The ``section`` option can be used to select a specific section + of information + + The section option is not supported by older versions of Redis Server, + and will generate ResponseError + """ + if section is None: + return self.execute_command('INFO', + target_nodes=target_nodes) + else: + return self.execute_command('INFO', + section, + target_nodes=target_nodes) + + def lastsave(self, target_nodes=None): + """ + Return a Python datetime object representing the last time the + Redis database was saved to disk + """ + return self.execute_command('LASTSAVE', + target_nodes=target_nodes) + + def memory_doctor(self): + raise NotImplementedError( + "MEMORY DOCTOR is intentionally not implemented in the client." + ) + + def memory_help(self): + raise NotImplementedError( + "MEMORY HELP is intentionally not implemented in the client." + ) + + def memory_malloc_stats(self, target_nodes=None): + """Return an internal statistics report from the memory allocator.""" + return self.execute_command('MEMORY MALLOC-STATS', + target_nodes=target_nodes) + + def memory_purge(self, target_nodes=None): + """Attempts to purge dirty pages for reclamation by allocator""" + return self.execute_command('MEMORY PURGE', + target_nodes=target_nodes) + + def memory_stats(self, target_nodes=None): + """Return a dictionary of memory stats""" + return self.execute_command('MEMORY STATS', + target_nodes=target_nodes) + + def memory_usage(self, key, samples=None): + """ + Return the total memory usage for key, its value and associated + administrative overheads. + + For nested data structures, ``samples`` is the number of elements to + sample. If left unspecified, the server's default is 5. Use 0 to sample + all elements. + """ + args = [] + if isinstance(samples, int): + args.extend([b'SAMPLES', samples]) + return self.execute_command('MEMORY USAGE', key, *args) + + def migrate(self, host, source_node, port, keys, destination_db, timeout, + copy=False, replace=False, auth=None): + """ + Migrate 1 or more keys from the source_node Redis server to a different + server specified by the ``host``, ``port`` and ``destination_db``. + + The ``timeout``, specified in milliseconds, indicates the maximum + time the connection between the two servers can be idle before the + command is interrupted. + + If ``copy`` is True, the specified ``keys`` are NOT deleted from + the source server. + + If ``replace`` is True, this operation will overwrite the keys + on the destination server if they exist. + + If ``auth`` is specified, authenticate to the destination server with + the password provided. + """ + keys = list_or_args(keys, []) + if not keys: + raise DataError('MIGRATE requires at least one key') + pieces = [] + if copy: + pieces.append(b'COPY') + if replace: + pieces.append(b'REPLACE') + if auth: + pieces.append(b'AUTH') + pieces.append(auth) + pieces.append(b'KEYS') + pieces.extend(keys) + return self.execute_command('MIGRATE', host, port, '', destination_db, + timeout, *pieces, + target_nodes=source_node) + + def object(self, infotype, key): + """Return the encoding, idletime, or refcount about the key""" + return self.execute_command('OBJECT', infotype, key, infotype=infotype) + + def ping(self, target_nodes=None): + """ + Ping the cluster's servers. 
+        If no target nodes are specified, the command is sent to all nodes
+        and returns True if the ping was successful across all nodes.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('PING',
+                                    target_nodes=target_nodes)
+
+    def save(self):
+        """
+        Tell the Redis server to save its data to disk,
+        blocking until the save is complete
+        """
+        return self.execute_command('SAVE')
+
+    def shutdown(self, save=False, nosave=False):
+        """Shutdown the Redis server. If Redis has persistence configured,
+        data will be flushed before shutdown. If the "save" option is set,
+        a data flush will be attempted even if there is no persistence
+        configured. If the "nosave" option is set, no data flush will be
+        attempted. The "save" and "nosave" options cannot both be set.
+        """
+        if save and nosave:
+            raise DataError('SHUTDOWN save and nosave cannot both be set')
+        args = ['SHUTDOWN']
+        if save:
+            args.append('SAVE')
+        if nosave:
+            args.append('NOSAVE')
+        try:
+            self.execute_command(*args)
+        except ConnectionError:
+            # a ConnectionError here is expected
+            return
+        raise RedisError("SHUTDOWN seems to have failed.")
+
+    def slowlog_get(self, num=None, target_nodes=None):
+        """
+        Get the entries from the slowlog. If ``num`` is specified, get the
+        most recent ``num`` items.
+        """
+        args = ['SLOWLOG GET']
+        if num is not None:
+            args.append(num)
+
+        return self.execute_command(*args,
+                                    target_nodes=target_nodes)
+
+    def slowlog_len(self, target_nodes=None):
+        "Get the number of items in the slowlog"
+        return self.execute_command('SLOWLOG LEN',
+                                    target_nodes=target_nodes)
+
+    def slowlog_reset(self, target_nodes=None):
+        "Remove all items in the slowlog"
+        return self.execute_command('SLOWLOG RESET',
+                                    target_nodes=target_nodes)
+
+    def time(self, target_nodes=None):
+        """
+        Returns the server time as a 2-item tuple of ints:
+        (seconds since epoch, microseconds into this second).
+        """
+        return self.execute_command('TIME', target_nodes=target_nodes)
+
+    def wait(self, num_replicas, timeout, target_nodes=None):
+        """
+        Redis synchronous replication.
+        Returns the number of replicas that processed the query when
+        we finally have at least ``num_replicas``, or when the ``timeout``
+        is reached.
+
+        In cluster mode the WAIT command will be sent to all primaries
+        and the result will be summed up.
+        """
+        return self.execute_command('WAIT', num_replicas,
+                                    timeout,
+                                    target_nodes=target_nodes)
+
+
+class ClusterCommands(ClusterManagementCommands, ClusterMultiKeyCommands,
+                      DataAccessCommands, PubSubCommands):
+    def cluster_addslots(self, target_node, *slots):
+        """
+        Assign new hash slots to receiving node. Sends to specified node.
+
+        :target_node: 'ClusterNode'
+            The node to execute the command on
+        """
+        return self.execute_command('CLUSTER ADDSLOTS', *slots,
+                                    target_nodes=target_node)
+
+    def cluster_countkeysinslot(self, slot_id):
+        """
+        Return the number of local keys in the specified hash slot
+        Send to node based on specified slot_id
+        """
+        return self.execute_command('CLUSTER COUNTKEYSINSLOT', slot_id)
+
+    def cluster_count_failure_report(self, node_id):
+        """
+        Return the number of failure reports active for a given node
+        Sends to a random node
+        """
+        return self.execute_command('CLUSTER COUNT-FAILURE-REPORTS', node_id)
+
+    def cluster_delslots(self, *slots):
+        """
+        Set hash slots as unbound in the cluster.
+        The client determines by itself which node each slot belongs to and
+        sends the command there.
+
+        Returns a list of the results for each processed slot.
+        """
+        return [
+            self.execute_command('CLUSTER DELSLOTS', slot)
+            for slot in slots
+        ]
+
+    def cluster_failover(self, target_node, option=None):
+        """
+        Forces a slave to perform a manual failover of its master
+        Sends to specified node
+
+        :target_node: 'ClusterNode'
+            The node to execute the command on
+        """
+        if option:
+            if option.upper() not in ['FORCE', 'TAKEOVER']:
+                raise RedisError(
+                    'Invalid option for CLUSTER FAILOVER command: {0}'.format(
+                        option))
+            else:
+                return self.execute_command('CLUSTER FAILOVER', option,
+                                            target_nodes=target_node)
+        else:
+            return self.execute_command('CLUSTER FAILOVER',
+                                        target_nodes=target_node)
+
+    def cluster_info(self, target_node=None):
+        """
+        Provides info about Redis Cluster node state.
+        The command will be sent to a random node in the cluster if no target
+        node is specified.
+
+        :target_node: 'ClusterNode'
+            The node to execute the command on
+        """
+        return self.execute_command('CLUSTER INFO', target_nodes=target_node)
+
+    def cluster_keyslot(self, key):
+        """
+        Returns the hash slot of the specified key
+        Sends to random node in the cluster
+        """
+        return self.execute_command('CLUSTER KEYSLOT', key)
+
+    def cluster_meet(self, target_nodes, host, port):
+        """
+        Force a node cluster to handshake with another node.
+        Sends to specified node.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER MEET', host, port,
+                                    target_nodes=target_nodes)
+
+    def cluster_nodes(self):
+        """
+        Get Cluster config for the node.
+
+        Sends to random node in the cluster
+        """
+        return self.execute_command('CLUSTER NODES')
+
+    def cluster_replicate(self, target_nodes, node_id):
+        """
+        Reconfigure a node as a slave of the specified master node
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER REPLICATE', node_id,
+                                    target_nodes=target_nodes)
+
+    def cluster_reset(self, target_nodes, soft=True):
+        """
+        Reset a Redis Cluster node
+
+        If 'soft' is True then it will send 'SOFT' argument
+        If 'soft' is False then it will send 'HARD' argument
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER RESET',
+                                    b'SOFT' if soft else b'HARD',
+                                    target_nodes=target_nodes)
+
+    def cluster_save_config(self, target_nodes):
+        """
+        Forces the node to save cluster state on disk
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER SAVECONFIG',
+                                    target_nodes=target_nodes)
+
+    def cluster_get_keys_in_slot(self, slot, num_keys):
+        """
+        Returns up to ``num_keys`` key names from the specified cluster slot
+        """
+        return self.execute_command('CLUSTER GETKEYSINSLOT', slot, num_keys)
+
+    def cluster_set_config_epoch(self, target_nodes, epoch):
+        """
+        Set the configuration epoch in a new node
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER SET-CONFIG-EPOCH', epoch,
+                                    target_nodes=target_nodes)
+
+    def cluster_setslot(self, target_node, node_id, slot_id, state):
+        """
+        Bind a hash slot to a specific node
+
+        :target_node: 'ClusterNode'
+            The node to execute the command on
+        """
+        if state.upper() in ('IMPORTING', 'NODE', 'MIGRATING'):
+            return self.execute_command('CLUSTER SETSLOT', slot_id, state,
+                                        node_id, target_nodes=target_node)
+        elif state.upper() == 'STABLE':
+            raise RedisError('For "stable" state please use '
+                             'cluster_setslot_stable')
+        else:
+            raise RedisError('Invalid slot state: {0}'.format(state))
+
+    def cluster_setslot_stable(self, slot_id):
+        """
+        Clears migrating / importing state from the slot.
+        The client determines by itself which node the slot belongs to and
+        sends the command there.
+        """
+        return self.execute_command('CLUSTER SETSLOT', slot_id, 'STABLE')
+
+    def cluster_replicas(self, node_id):
+        """
+        Provides a list of replica nodes replicating from the specified
+        primary target node.
+        Sends to random node in the cluster.
+        """
+        return self.execute_command('CLUSTER REPLICAS', node_id)
+
+    def cluster_slots(self):
+        """
+        Get array of Cluster slot to node mappings
+
+        Sends to random node in the cluster
+        """
+        return self.execute_command('CLUSTER SLOTS')
+
+    def readonly(self, target_nodes=None):
+        """
+        Enables read queries.
+        The command will be sent to all replica nodes if target_nodes is not
+        specified.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        self.read_from_replicas = True
+        return self.execute_command('READONLY', target_nodes=target_nodes)
+
+    def readwrite(self, target_nodes=None):
+        """
+        Disables read queries.
+        The command will be sent to all replica nodes if target_nodes is not
+        specified.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        # Reset read from replicas flag
+        self.read_from_replicas = False
+        return self.execute_command('READWRITE', target_nodes=target_nodes)
diff --git a/redis/commands/core.py b/redis/commands/core.py
index 6512b45a42..c6f589d21e 100644
--- a/redis/commands/core.py
+++ b/redis/commands/core.py
@@ -12,14 +12,7 @@
 )
 
-class CoreCommands:
-    """
-    A class containing all of the implemented redis commands. This class is
-    to be used as a mixin.
-    """
-
-    # SERVER INFORMATION
-
+class AclCommands:
     # ACL methods
     def acl_cat(self, category=None):
         """
@@ -267,6 +260,8 @@ def acl_whoami(self):
         "Get the username for the current connection"
         return self.execute_command('ACL WHOAMI')
 
+
+class ManagementCommands:
     def bgrewriteaof(self):
         "Tell the Redis server to rewrite the AOF file from data in memory."
         return self.execute_command('BGREWRITEAOF')
@@ -431,6 +426,14 @@ def client_unpause(self):
         """
         return self.execute_command('CLIENT UNPAUSE')
 
+    def command_info(self):
+        raise NotImplementedError(
+            "COMMAND INFO is intentionally not implemented in the client."
+        )
+
+    def command_count(self):
+        return self.execute_command('COMMAND COUNT')
+
     def readwrite(self):
         """
         Disables read queries for a connection to a Redis Cluster slave node.
@@ -461,6 +464,9 @@ def config_rewrite(self):
         """
         return self.execute_command('CONFIG REWRITE')
 
+    def cluster(self, cluster_arg, *args):
+        return self.execute_command('CLUSTER %s' % cluster_arg.upper(), *args)
+
     def dbsize(self):
         """Returns the number of keys in the current database"""
         return self.execute_command('DBSIZE')
@@ -624,6 +630,16 @@ def quit(self):
         """
         return self.execute_command('QUIT')
 
+    def replicaof(self, *args):
+        """
+        Update the replication settings of a redis replica, on the fly.
+ Examples of valid arguments include: + NO ONE (set no replication) + host port (set to the host and port of a redis server) + see: https://redis.io/commands/replicaof + """ + return self.execute_command('REPLICAOF', *args) + def save(self): """ Tell the Redis server to save its data to disk, @@ -698,6 +714,8 @@ def wait(self, num_replicas, timeout): """ return self.execute_command('WAIT', num_replicas, timeout) + +class BasicKeyCommands: # BASIC KEY COMMANDS def append(self, key, value): """ @@ -1324,6 +1342,8 @@ def unlink(self, *names): "Unlink one or more keys specified by ``names``" return self.execute_command('UNLINK', *names) + +class ListCommands: # LIST COMMANDS def blpop(self, keys, timeout=0): """ @@ -1576,6 +1596,8 @@ def sort(self, name, start=None, num=None, by=None, get=None, options = {'groups': len(get) if groups else None} return self.execute_command('SORT', *pieces, **options) + +class ScanCommands: # SCAN COMMANDS def scan(self, cursor=0, match=None, count=None, _type=None): """ @@ -1723,6 +1745,8 @@ def zscan_iter(self, name, match=None, count=None, score_cast_func=score_cast_func) yield from data + +class SetCommands: # SET COMMANDS def sadd(self, name, *values): """Add ``value(s)`` to set ``name``""" @@ -1805,6 +1829,8 @@ def sunionstore(self, dest, keys, *args): args = list_or_args(keys, args) return self.execute_command('SUNIONSTORE', dest, *args) + +class StreamsCommands: # STREAMS COMMANDS def xack(self, name, groupname, *ids): """ @@ -2243,6 +2269,8 @@ def xtrim(self, name, maxlen=None, approximate=True, minid=None, return self.execute_command('XTRIM', name, *pieces) + +class SortedSetCommands: # SORTED SET COMMANDS def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False, gt=None, lt=None): @@ -2721,6 +2749,8 @@ def _zaggregate(self, command, dest, keys, aggregate=None, pieces.append(b'WITHSCORES') return self.execute_command(*pieces, **options) + +class HyperLogLogCommands: # HYPERLOGLOG COMMANDS def pfadd(self, name, *values): "Adds the specified elements to the specified HyperLogLog." @@ -2737,6 +2767,8 @@ def pfmerge(self, dest, *sources): "Merge N different HyperLogLogs into a single one." return self.execute_command('PFMERGE', dest, *sources) + +class HashCommands: # HASH COMMANDS def hdel(self, name, *keys): "Delete ``keys`` from hash ``name``" @@ -2831,6 +2863,9 @@ def hstrlen(self, name, key): """ return self.execute_command('HSTRLEN', name, key) + +class PubSubCommands: + # PUBSUB COMMANDS def publish(self, channel, message): """ Publish ``message`` on ``channel``. @@ -2857,19 +2892,9 @@ def pubsub_numsub(self, *args): """ return self.execute_command('PUBSUB NUMSUB', *args) - def cluster(self, cluster_arg, *args): - return self.execute_command('CLUSTER %s' % cluster_arg.upper(), *args) - - def replicaof(self, *args): - """ - Update the replication settings of a redis replica, on the fly. 
- Examples of valid arguments include: - NO ONE (set no replication) - host port (set to the host and port of a redis server) - see: https://redis.io/commands/replicaof - """ - return self.execute_command('REPLICAOF', *args) +class ScriptCommands: + # SCRIPT COMMANDS def eval(self, script, numkeys, *keys_and_args): """ Execute the Lua ``script``, specifying the ``numkeys`` the script @@ -2941,6 +2966,8 @@ def register_script(self, script): """ return Script(self, script) + +class GeoCommands: # GEO COMMANDS def geoadd(self, name, values, nx=False, xx=False, ch=False): """ @@ -3235,6 +3262,8 @@ def _geosearchgeneric(self, command, *args, **kwargs): return self.execute_command(command, *pieces, **kwargs) + +class ModuleCommands: # MODULE COMMANDS def module_load(self, path, *args): """ @@ -3258,14 +3287,6 @@ def module_list(self): """ return self.execute_command('MODULE LIST') - def command_info(self): - raise NotImplementedError( - "COMMAND INFO is intentionally not implemented in the client." - ) - - def command_count(self): - return self.execute_command('COMMAND COUNT') - class Script: "An executable Lua script object returned by ``register_script``" @@ -3397,3 +3418,22 @@ def execute(self): command = self.command self.reset() return self.client.execute_command(*command) + + +class DataAccessCommands(BasicKeyCommands, ListCommands, + ScanCommands, SetCommands, StreamsCommands, + SortedSetCommands, + HyperLogLogCommands, HashCommands, GeoCommands, + ): + """ + A class containing all of the implemented data access redis commands. + This class is to be used as a mixin. + """ + + +class CoreCommands(AclCommands, DataAccessCommands, ManagementCommands, + ModuleCommands, PubSubCommands, ScriptCommands): + """ + A class containing all of the implemented redis commands. This class is + to be used as a mixin. + """ diff --git a/redis/commands/parser.py b/redis/commands/parser.py new file mode 100644 index 0000000000..22478ed2ed --- /dev/null +++ b/redis/commands/parser.py @@ -0,0 +1,108 @@ +from redis.exceptions import ( + RedisError, + ResponseError +) +from redis.utils import str_if_bytes + + +class CommandsParser: + """ + Parses Redis commands to get command keys. + COMMAND output is used to determine key locations. + Commands that do not have a predefined key location are flagged with + 'movablekeys', and these commands' keys are determined by the command + 'COMMAND GETKEYS'. + """ + def __init__(self, redis_connection): + self.initialized = False + self.commands = {} + self.initialize(redis_connection) + + def initialize(self, r): + self.commands = r.execute_command("COMMAND") + + # As soon as this PR is merged into Redis, we should reimplement + # our logic to use COMMAND INFO changes to determine the key positions + # https://github.com/redis/redis/pull/8324 + def get_keys(self, redis_conn, *args): + """ + Get the keys from the passed command + """ + if len(args) < 2: + # The command has no keys in it + return None + + cmd_name = args[0].lower() + cmd_name_split = cmd_name.split() + if len(cmd_name_split) > 1: + # we need to take only the main command, e.g. 'memory' for + # 'memory usage' + cmd_name = cmd_name_split[0] + if cmd_name not in self.commands: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + self.initialize(redis_conn) + if cmd_name not in self.commands: + raise RedisError("{0} command doesn't exist in Redis commands". 
+ format(cmd_name.upper())) + + command = self.commands.get(cmd_name) + if 'movablekeys' in command['flags']: + keys = self._get_moveable_keys(redis_conn, *args) + elif 'pubsub' in command['flags']: + keys = self._get_pubsub_keys(*args) + else: + if command['step_count'] == 0 and command['first_key_pos'] == 0 \ + and command['last_key_pos'] == 0: + # The command doesn't have keys in it + return None + last_key_pos = command['last_key_pos'] + if last_key_pos == -1: + last_key_pos = len(args) - 1 + keys_pos = list(range(command['first_key_pos'], last_key_pos + 1, + command['step_count'])) + keys = [args[pos] for pos in keys_pos] + + return keys + + def _get_moveable_keys(self, redis_conn, *args): + try: + pieces = [] + cmd_name = args[0] + for arg in cmd_name.split(): + # The command name should be splitted into separate arguments, + # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE'] + pieces.append(arg) + pieces += args[1:] + keys = redis_conn.execute_command('COMMAND GETKEYS', *pieces) + except ResponseError as e: + message = e.__str__() + if 'Invalid arguments' in message or \ + 'The command has no key arguments' in message: + return None + else: + raise e + return keys + + def _get_pubsub_keys(self, *args): + """ + Get the keys from pubsub command. + Although PubSub commands have predetermined key locations, they are not + supported in the 'COMMAND's output, so the key positions are hardcoded + in this method + """ + if len(args) < 2: + # The command has no keys in it + return None + args = [str_if_bytes(arg) for arg in args] + command = args[0].upper() + if command in ['PUBLISH', 'PUBSUB CHANNELS']: + # format example: + # PUBLISH channel message + keys = [args[1]] + elif command in ['SUBSCRIBE', 'PSUBSCRIBE', 'UNSUBSCRIBE', + 'PUNSUBSCRIBE', 'PUBSUB NUMSUB']: + keys = list(args[1:]) + else: + keys = None + return keys diff --git a/redis/connection.py b/redis/connection.py index c99c550ecd..e1ad6ea7f2 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -11,6 +11,7 @@ import threading import warnings +from redis.backoff import NoBackoff from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -28,9 +29,9 @@ TimeoutError, ModuleError, ) -from redis.utils import HIREDIS_AVAILABLE, str_if_bytes -from redis.backoff import NoBackoff + from redis.retry import Retry +from redis.utils import HIREDIS_AVAILABLE, str_if_bytes try: import ssl @@ -506,7 +507,7 @@ def __init__(self, host='localhost', port=6379, db=0, password=None, encoding_errors='strict', decode_responses=False, parser_class=DefaultParser, socket_read_size=65536, health_check_interval=0, client_name=None, username=None, - retry=None): + retry=None, redis_connect_func=None): """ Initialize a new Connection. 
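Putting `CommandsParser` together: `get_keys` answers "which arguments of this command are keys?", falling back to `COMMAND GETKEYS` for movable-key commands and to the hardcoded rules for pub/sub. A hedged usage sketch (assumes a reachable Redis server; the import path matches the one used in the tests later in this diff):

```python
from redis import Redis
from redis.commands import CommandsParser

r = Redis(host='localhost', port=6379)
parser = CommandsParser(r)

# Fixed key positions come straight from the COMMAND metadata
print(parser.get_keys(r, 'GET', 'foo'))                    # ['foo']
print(parser.get_keys(r, 'MSET', 'k1', 'v1', 'k2', 'v2'))  # ['k1', 'k2']
# Pub/sub channels are resolved by the hardcoded pubsub rules
print(parser.get_keys(r, 'PUBLISH', 'chan', 'msg'))        # ['chan']
```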
To specify a retry policy, first set `retry_on_timeout` to `True` @@ -536,8 +537,10 @@ def __init__(self, host='localhost', port=6379, db=0, password=None, self.health_check_interval = health_check_interval self.next_health_check = 0 self.encoder = Encoder(encoding, encoding_errors, decode_responses) + self.redis_connect_func = redis_connect_func self._sock = None - self._parser = parser_class(socket_read_size=socket_read_size) + self._socket_read_size = socket_read_size + self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 @@ -567,6 +570,9 @@ def register_connect_callback(self, callback): def clear_connect_callbacks(self): self._connect_callbacks = [] + def set_parser(self, parser_class): + self._parser = parser_class(socket_read_size=self._socket_read_size) + def connect(self): "Connects to the Redis server if not already connected" if self._sock: @@ -580,7 +586,12 @@ def connect(self): self._sock = sock try: - self.on_connect() + if self.redis_connect_func is None: + # Use the default on_connect function + self.on_connect() + else: + # Use the passed function redis_connect_func + self.redis_connect_func(self) except RedisError: # clean up after any error in on_connect self.disconnect() @@ -910,7 +921,8 @@ def __init__(self, path='', db=0, username=None, password=None, self.next_health_check = 0 self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None - self._parser = parser_class(socket_read_size=socket_read_size) + self._socket_read_size = socket_read_size + self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 diff --git a/redis/crc.py b/redis/crc.py new file mode 100644 index 0000000000..a4dfdf69f5 --- /dev/null +++ b/redis/crc.py @@ -0,0 +1,28 @@ +from binascii import crc_hqx + +# Redis Cluster's key space is divided into 16384 slots. +# For more information see: https://github.com/redis/redis/issues/2576 +REDIS_CLUSTER_HASH_SLOTS = 16384 + +__all__ = [ + "crc16", + "key_slot", + "REDIS_CLUSTER_HASH_SLOTS" +] + + +def crc16(data): + return crc_hqx(data, 0) + + +def key_slot(key, bucket=REDIS_CLUSTER_HASH_SLOTS): + """Calculate key slot for a given key. 
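Hash tags are what make multi-key operations possible in cluster mode: only the substring between the first `{` and the next `}` is hashed, so differently named keys can be forced into the same slot. For example:

``` pycon
>>> from redis.crc import key_slot
>>> key_slot(b'foo')
12182
>>> key_slot(b'bar{foo}baz') == key_slot(b'foo')  # only 'foo' is hashed
True
>>> key_slot(b'{foo}') == key_slot(b'foo')
True
```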
+ :param key - bytes + :param bucket - int + """ + start = key.find(b"{") + if start > -1: + end = key.find(b"}", start + 1) + if end > -1 and end != start + 1: + key = key[start + 1: end] + return crc16(key) % bucket diff --git a/redis/exceptions.py b/redis/exceptions.py index 91eb3c7257..5ea7fe9c30 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -84,3 +84,67 @@ class AuthenticationWrongNumberOfArgsError(ResponseError): were sent to the AUTH command """ pass + + +class RedisClusterException(Exception): + pass + + +class ClusterError(RedisError): + pass + + +class ClusterDownError(ClusterError, ResponseError): + + def __init__(self, resp): + self.args = (resp,) + self.message = resp + + +class AskError(ResponseError): + """ + src node: MIGRATING to dst node + get > ASK error + ask dst node > ASKING command + dst node: IMPORTING from src node + asking command only affects next command + any op will be allowed after asking command + """ + + def __init__(self, resp): + """should only redirect to master node""" + self.args = (resp,) + self.message = resp + slot_id, new_node = resp.split(' ') + host, port = new_node.rsplit(':', 1) + self.slot_id = int(slot_id) + self.node_addr = self.host, self.port = host, int(port) + + +class TryAgainError(ResponseError): + + def __init__(self, *args, **kwargs): + pass + + +class ClusterCrossSlotError(ResponseError): + message = "Keys in request don't hash to the same slot" + + +class MovedError(AskError): + pass + + +class MasterDownError(ClusterDownError): + pass + + +class SlotNotCoveredError(RedisClusterException): + """ + This error only happens in the case where the connection pool will try to + fetch what node that is covered by a given slot. + + If this error is raised the client should drop the current node layout and + attempt to reconnect and refresh the node layout again + """ + pass diff --git a/redis/utils.py b/redis/utils.py index 26fb002b89..0e78cc5f3b 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -36,3 +36,39 @@ def str_if_bytes(value): def safe_str(value): return str(str_if_bytes(value)) + + +def dict_merge(*dicts): + """ + Merge all provided dicts into 1 dict. + *dicts : `dict` + dictionaries to merge + """ + merged = {} + + for d in dicts: + merged.update(d) + + return merged + + +def list_keys_to_dict(key_list, callback): + return dict.fromkeys(key_list, callback) + + +def merge_result(command, res): + """ + Merge all items in `res` into a list. + + This command is used when sending a command to multiple nodes + and they result from each node should be merged into a single list. + + res : 'dict' + """ + result = set() + + for v in res.values(): + for value in v: + result.add(value) + + return list(result) diff --git a/tasks.py b/tasks.py index aa965c6902..44b652908d 100644 --- a/tasks.py +++ b/tasks.py @@ -40,7 +40,10 @@ def tests(c): """Run the redis-py test suite against the current python, with and without hiredis. 
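The redirect exceptions added to `redis/exceptions.py` above do the parsing for the cluster client: an `ASK`/`MOVED` payload has the form `<slot> <host>:<port>`, and `AskError.__init__` (inherited by `MovedError`) splits it into the fields the client needs in order to retry. For example:

``` pycon
>>> from redis.exceptions import MovedError
>>> e = MovedError('12182 127.0.0.1:7001')
>>> e.slot_id, e.host, e.port
(12182, '127.0.0.1', 7001)
>>> e.node_addr
('127.0.0.1', 7001)
```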
""" + print("Starting Redis tests") run("tox -e plain -e hiredis") + print("Starting RedisCluster tests") + run("tox -e plain -e hiredis -- --redis-url=redis://localhost:16379/0") @task diff --git a/tests/conftest.py b/tests/conftest.py index 47188df07f..df809bf81d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,17 +3,19 @@ import pytest import random import redis +import time from distutils.version import LooseVersion from redis.connection import parse_url +from redis.exceptions import RedisClusterException from unittest.mock import Mock from urllib.parse import urlparse - REDIS_INFO = {} default_redis_url = "redis://localhost:6379/9" default_redismod_url = "redis://localhost:36379/9" default_redismod_url = "redis://localhost:36379" +default_cluster_nodes = 6 def pytest_addoption(parser): @@ -28,6 +30,12 @@ def pytest_addoption(parser): " with loaded modules," " defaults to `%(default)s`") + parser.addoption('--cluster-nodes', default=default_cluster_nodes, + action="store", + help="The number of cluster nodes that need to be " + "available before the test can start," + " defaults to `%(default)s`") + def _get_info(redis_url): client = redis.Redis.from_url(redis_url) @@ -41,14 +49,37 @@ def pytest_sessionstart(session): info = _get_info(redis_url) version = info["redis_version"] arch_bits = info["arch_bits"] + cluster_enabled = info["cluster_enabled"] REDIS_INFO["version"] = version REDIS_INFO["arch_bits"] = arch_bits + REDIS_INFO["cluster_enabled"] = cluster_enabled # module info redismod_url = session.config.getoption("--redismod-url") info = _get_info(redismod_url) REDIS_INFO["modules"] = info["modules"] + if cluster_enabled: + cluster_nodes = session.config.getoption("--cluster-nodes") + wait_for_cluster_creation(redis_url, cluster_nodes) + + +def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=20): + now = time.time() + timeout = now + timeout + print("Waiting for {0} cluster nodes to become available". 
+ format(cluster_nodes)) + while now < timeout: + try: + client = redis.RedisCluster.from_url(redis_url) + if len(client.get_nodes()) == cluster_nodes: + print("All nodes are available!") + break + except RedisClusterException: + pass + time.sleep(1) + now = time.time() + def skip_if_server_version_lt(min_version): redis_version = REDIS_INFO["version"] @@ -86,6 +117,17 @@ def skip_ifmodversion_lt(min_version: str, module_name: str): raise AttributeError("No redis module named {}".format(module_name)) +def skip_if_cluster_mode(): + return pytest.mark.skipif(REDIS_INFO["cluster_enabled"], + reason="This test isn't supported with cluster " + "mode") + + +def skip_if_not_cluster_mode(): + return pytest.mark.skipif(not REDIS_INFO["cluster_enabled"], + reason="Cluster-mode is required for this test") + + def _get_client(cls, request, single_connection_client=True, flushdb=True, from_url=None, **kwargs): @@ -100,10 +142,14 @@ def _get_client(cls, request, single_connection_client=True, flushdb=True, redis_url = request.config.getoption("--redis-url") else: redis_url = from_url - url_options = parse_url(redis_url) - url_options.update(kwargs) - pool = redis.ConnectionPool(**url_options) - client = cls(connection_pool=pool) + if REDIS_INFO["cluster_enabled"]: + client = redis.RedisCluster.from_url(redis_url, **kwargs) + single_connection_client = False + else: + url_options = parse_url(redis_url) + url_options.update(kwargs) + pool = redis.ConnectionPool(**url_options) + client = cls(connection_pool=pool) if single_connection_client: client = client.client() if request: @@ -116,7 +162,10 @@ def teardown(): # just manually retry the flushdb client.flushdb() client.close() - client.connection_pool.disconnect() + if REDIS_INFO["cluster_enabled"]: + client.disconnect_connection_pools() + else: + client.connection_pool.disconnect() request.addfinalizer(teardown) return client @@ -220,6 +269,13 @@ def master_host(request): yield parts.hostname +@pytest.fixture(scope="session") +def master_port(request): + url = request.config.getoption("--redis-url") + parts = urlparse(url) + yield parts.port + + def wait_for_command(client, monitor, command): # issue a command with a key name that's local to this process. # if we find a command with our key before the command we're waiting diff --git a/tests/test_cluster.py b/tests/test_cluster.py new file mode 100644 index 0000000000..fa78c04710 --- /dev/null +++ b/tests/test_cluster.py @@ -0,0 +1,1403 @@ +import pytest +import datetime +import warnings + +from time import sleep +from unittest.mock import call, patch, DEFAULT, Mock +from redis import Redis +from redis.cluster import get_node_name, ClusterNode, \ + RedisCluster, NodesManager, PRIMARY, REDIS_CLUSTER_HASH_SLOTS, REPLICA +from redis.commands import CommandsParser +from redis.connection import Connection +from redis.utils import str_if_bytes +from redis.exceptions import ( + AskError, + ClusterDownError, + MovedError, + RedisClusterException, + RedisError +) + +from redis.crc import key_slot +from .conftest import ( + skip_if_not_cluster_mode, + _get_client, + skip_if_server_version_lt +) + +default_host = "127.0.0.1" +default_port = 7000 +default_cluster_slots = [ + [ + 0, 8191, + ['127.0.0.1', 7000, 'node_0'], + ['127.0.0.1', 7003, 'node_3'], + ], + [ + 8192, 16383, + ['127.0.0.1', 7001, 'node_1'], + ['127.0.0.1', 7002, 'node_2'] + ] +] + + +@pytest.fixture() +def slowlog(request, r): + """ + Set the slowlog threshold to 0, and the + max length to 128. 
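One caveat about `wait_for_cluster_creation` above: if the timeout elapses before all nodes report in, the loop silently falls through and the session proceeds with a partial cluster. A stricter variant (a sketch, not the shipped helper) would fail fast instead:

```python
import time
import redis
from redis.exceptions import RedisClusterException

def wait_for_cluster_creation_strict(redis_url, cluster_nodes, timeout=20):
    # Hypothetical variant: identical polling, but raises on timeout.
    # Note the int() cast: pytest option values arrive as strings when
    # passed on the command line.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            client = redis.RedisCluster.from_url(redis_url)
            if len(client.get_nodes()) == int(cluster_nodes):
                return client
        except RedisClusterException:
            pass
        time.sleep(1)
    raise RedisClusterException(
        "Only a partial cluster was available after {0} seconds".format(
            timeout))
```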
This will force every + command into the slowlog and allow us + to test it + """ + # Save old values + current_config = r.config_get( + target_nodes=r.get_primaries()[0]) + old_slower_than_value = current_config['slowlog-log-slower-than'] + old_max_legnth_value = current_config['slowlog-max-len'] + + # Function to restore the old values + def cleanup(): + r.config_set('slowlog-log-slower-than', old_slower_than_value) + r.config_set('slowlog-max-len', old_max_legnth_value) + request.addfinalizer(cleanup) + + # Set the new values + r.config_set('slowlog-log-slower-than', 0) + r.config_set('slowlog-max-len', 128) + + +def get_mocked_redis_client(func=None, *args, **kwargs): + """ + Return a stable RedisCluster object that have deterministic + nodes and slots setup to remove the problem of different IP addresses + on different installations and machines. + """ + cluster_slots = kwargs.pop('cluster_slots', default_cluster_slots) + coverage_res = kwargs.pop('coverage_result', 'yes') + with patch.object(Redis, 'execute_command') as execute_command_mock: + def execute_command(*_args, **_kwargs): + if _args[0] == 'CLUSTER SLOTS': + mock_cluster_slots = cluster_slots + return mock_cluster_slots + elif _args[1] == 'cluster-require-full-coverage': + return {'cluster-require-full-coverage': coverage_res} + elif func is not None: + return func(*args, **kwargs) + else: + return execute_command_mock(*_args, **_kwargs) + + execute_command_mock.side_effect = execute_command + + with patch.object(CommandsParser, 'initialize', + autospec=True) as cmd_parser_initialize: + + def cmd_init_mock(self, r): + self.commands = {'get': {'name': 'get', 'arity': 2, + 'flags': ['readonly', + 'fast'], + 'first_key_pos': 1, + 'last_key_pos': 1, + 'step_count': 1}} + + cmd_parser_initialize.side_effect = cmd_init_mock + + return RedisCluster(*args, **kwargs) + + +def mock_node_resp(node, response): + connection = Mock() + connection.read_response.return_value = response + node.redis_connection.connection = connection + return node + + +def mock_all_nodes_resp(rc, response): + for node in rc.get_nodes(): + mock_node_resp(node, response) + return rc + + +def find_node_ip_based_on_port(cluster_client, port): + for node in cluster_client.get_nodes(): + if node.port == port: + return node.host + + +def moved_redirection_helper(request, failover=False): + """ + Test that the client handles MOVED response after a failover. + Redirection after a failover means that the redirection address is of a + replica that was promoted to a primary. + + At first call it should return a MOVED ResponseError that will point + the client to the next server it should talk to. + + Verify that: + 1. it tries to talk to the redirected node + 2. it updates the slot's primary to the redirected node + + For a failover, also verify: + 3. the redirected node's server type updated to 'primary' + 4. 
the server type of the previous slot owner updated to 'replica' + """ + rc = _get_client(RedisCluster, request, flushdb=False) + slot = 12182 + redirect_node = None + # Get the current primary that holds this slot + prev_primary = rc.nodes_manager.get_node_from_slot(slot) + if failover: + if len(rc.nodes_manager.slots_cache[slot]) < 2: + warnings.warn("Skipping this test since it requires to have a " + "replica") + return + redirect_node = rc.nodes_manager.slots_cache[slot][1] + else: + # Use one of the primaries to be the redirected node + redirect_node = rc.get_primaries()[0] + r_host = redirect_node.host + r_port = redirect_node.port + with patch.object(Redis, 'parse_response') as parse_response: + def moved_redirect_effect(connection, *args, **options): + def ok_response(connection, *args, **options): + assert connection.host == r_host + assert connection.port == r_port + + return "MOCK_OK" + + parse_response.side_effect = ok_response + raise MovedError("{0} {1}:{2}".format(slot, r_host, r_port)) + + parse_response.side_effect = moved_redirect_effect + assert rc.execute_command("SET", "foo", "bar") == "MOCK_OK" + slot_primary = rc.nodes_manager.slots_cache[slot][0] + assert slot_primary == redirect_node + if failover: + assert rc.get_node(host=r_host, port=r_port).server_type == PRIMARY + assert prev_primary.server_type == REPLICA + + +@skip_if_not_cluster_mode() +class TestRedisClusterObj: + def test_host_port_startup_node(self): + """ + Test that it is possible to use host & port arguments as startup node + args + """ + cluster = get_mocked_redis_client(host=default_host, port=default_port) + assert cluster.get_node(host=default_host, + port=default_port) is not None + + def test_startup_nodes(self): + """ + Test that it is possible to use startup_nodes + argument to init the cluster + """ + port_1 = 7000 + port_2 = 7001 + startup_nodes = [ClusterNode(default_host, port_1), + ClusterNode(default_host, port_2)] + cluster = get_mocked_redis_client(startup_nodes=startup_nodes) + assert cluster.get_node(host=default_host, port=port_1) is not None \ + and cluster.get_node(host=default_host, port=port_2) is not None + + def test_empty_startup_nodes(self): + """ + Test that exception is raised when empty providing empty startup_nodes + """ + with pytest.raises(RedisClusterException) as ex: + RedisCluster(startup_nodes=[]) + + assert str(ex.value).startswith( + "RedisCluster requires at least one node to discover the " + "cluster"), str_if_bytes(ex.value) + + def test_from_url(self, r): + redis_url = "redis://{0}:{1}/0".format(default_host, default_port) + with patch.object(RedisCluster, 'from_url') as from_url: + def from_url_mocked(_url, **_kwargs): + return get_mocked_redis_client(url=_url, **_kwargs) + + from_url.side_effect = from_url_mocked + cluster = RedisCluster.from_url(redis_url) + assert cluster.get_node(host=default_host, + port=default_port) is not None + + def test_execute_command_errors(self, r): + """ + Test that if no key is provided then exception should be raised. + """ + with pytest.raises(RedisClusterException) as ex: + r.execute_command("GET") + assert str(ex.value).startswith("No way to dispatch this command to " + "Redis Cluster. 
Missing key.") + + def test_execute_command_node_flag_primaries(self, r): + """ + Test command execution with nodes flag PRIMARIES + """ + primaries = r.get_primaries() + replicas = r.get_replicas() + mock_all_nodes_resp(r, 'PONG') + assert r.ping(RedisCluster.PRIMARIES) is True + for primary in primaries: + conn = primary.redis_connection.connection + assert conn.read_response.called is True + for replica in replicas: + conn = replica.redis_connection.connection + assert conn.read_response.called is not True + + def test_execute_command_node_flag_replicas(self, r): + """ + Test command execution with nodes flag REPLICAS + """ + replicas = r.get_replicas() + if not replicas: + r = get_mocked_redis_client(default_host, default_port) + primaries = r.get_primaries() + mock_all_nodes_resp(r, 'PONG') + assert r.ping(RedisCluster.REPLICAS) is True + for replica in replicas: + conn = replica.redis_connection.connection + assert conn.read_response.called is True + for primary in primaries: + conn = primary.redis_connection.connection + assert conn.read_response.called is not True + + def test_execute_command_node_flag_all_nodes(self, r): + """ + Test command execution with nodes flag ALL_NODES + """ + mock_all_nodes_resp(r, 'PONG') + assert r.ping(RedisCluster.ALL_NODES) is True + for node in r.get_nodes(): + conn = node.redis_connection.connection + assert conn.read_response.called is True + + def test_execute_command_node_flag_random(self, r): + """ + Test command execution with nodes flag RANDOM + """ + mock_all_nodes_resp(r, 'PONG') + assert r.ping(RedisCluster.RANDOM) is True + called_count = 0 + for node in r.get_nodes(): + conn = node.redis_connection.connection + if conn.read_response.called is True: + called_count += 1 + assert called_count == 1 + + @pytest.mark.filterwarnings("ignore:AskError") + def test_ask_redirection(self, r): + """ + Test that the server handles ASK response. + + At first call it should return a ASK ResponseError that will point + the client to the next server it should talk to. + + Important thing to verify is that it tries to talk to the second node. + """ + redirect_node = r.get_nodes()[0] + with patch.object(Redis, 'parse_response') as parse_response: + def ask_redirect_effect(connection, *args, **options): + def ok_response(connection, *args, **options): + assert connection.host == redirect_node.host + assert connection.port == redirect_node.port + + return "MOCK_OK" + + parse_response.side_effect = ok_response + raise AskError("12182 {0}:{1}".format(redirect_node.host, + redirect_node.port)) + + parse_response.side_effect = ask_redirect_effect + + assert r.execute_command("SET", "foo", "bar") == "MOCK_OK" + + @pytest.mark.filterwarnings("ignore:MovedError") + def test_moved_redirection(self, request): + """ + Test that the client handles MOVED response. + """ + moved_redirection_helper(request, failover=False) + + @pytest.mark.filterwarnings("ignore:MovedError") + def test_moved_redirection_after_failover(self, request): + """ + Test that the client handles MOVED response after a failover. 
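`test_ask_redirection` above exercises the flow documented on `AskError`: the client must send `ASKING` to the redirect target and retry the command there once, without updating its slot cache (unlike `MOVED`). Schematically (illustrative only; the real handling is internal to the client, presumably in `RedisCluster._execute_command`, which the retry tests below patch):

```python
from redis.exceptions import AskError

def run_with_ask_retry(rc, *args):
    # Hypothetical one-shot ASK handler for a RedisCluster client 'rc'
    try:
        return rc.execute_command(*args)
    except AskError as e:
        node = rc.get_node(host=e.host, port=e.port)
        rc.execute_command('ASKING', target_nodes=node)  # permit next command
        return rc.execute_command(*args, target_nodes=node)
```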
+ """ + moved_redirection_helper(request, failover=True) + + @pytest.mark.filterwarnings("ignore:ClusterDownError") + def test_refresh_using_specific_nodes(self, request): + """ + Test making calls on specific nodes when the cluster has failed over to + another node + """ + node_7006 = ClusterNode(host=default_host, port=7006, + server_type=PRIMARY) + node_7007 = ClusterNode(host=default_host, port=7007, + server_type=PRIMARY) + with patch.object(Redis, 'parse_response') as parse_response: + with patch.object(NodesManager, 'initialize', autospec=True) as \ + initialize: + with patch.multiple(Connection, + send_command=DEFAULT, + connect=DEFAULT, + can_read=DEFAULT) as mocks: + # simulate 7006 as a failed node + def parse_response_mock(connection, command_name, + **options): + if connection.port == 7006: + parse_response.failed_calls += 1 + raise ClusterDownError( + 'CLUSTERDOWN The cluster is ' + 'down. Use CLUSTER INFO for ' + 'more information') + elif connection.port == 7007: + parse_response.successful_calls += 1 + + def initialize_mock(self): + # start with all slots mapped to 7006 + self.nodes_cache = {node_7006.name: node_7006} + self.slots_cache = {} + + for i in range(0, 16383): + self.slots_cache[i] = [node_7006] + + # After the first connection fails, a reinitialize + # should follow the cluster to 7007 + def map_7007(self): + self.nodes_cache = { + node_7007.name: node_7007} + self.slots_cache = {} + + for i in range(0, 16383): + self.slots_cache[i] = [node_7007] + + # Change initialize side effect for the second call + initialize.side_effect = map_7007 + + parse_response.side_effect = parse_response_mock + parse_response.successful_calls = 0 + parse_response.failed_calls = 0 + initialize.side_effect = initialize_mock + mocks['can_read'].return_value = False + mocks['send_command'].return_value = "MOCK_OK" + mocks['connect'].return_value = None + with patch.object(CommandsParser, 'initialize', + autospec=True) as cmd_parser_initialize: + + def cmd_init_mock(self, r): + self.commands = {'get': {'name': 'get', 'arity': 2, + 'flags': ['readonly', + 'fast'], + 'first_key_pos': 1, + 'last_key_pos': 1, + 'step_count': 1}} + + cmd_parser_initialize.side_effect = cmd_init_mock + + rc = _get_client( + RedisCluster, request, flushdb=False) + assert len(rc.get_nodes()) == 1 + assert rc.get_node(node_name=node_7006.name) is not \ + None + + rc.get('foo') + + # Cluster should now point to 7007, and there should be + # one failed and one successful call + assert len(rc.get_nodes()) == 1 + assert rc.get_node(node_name=node_7007.name) is not \ + None + assert rc.get_node(node_name=node_7006.name) is None + assert parse_response.failed_calls == 1 + assert parse_response.successful_calls == 1 + + def test_reading_from_replicas_in_round_robin(self): + with patch.multiple(Connection, send_command=DEFAULT, + read_response=DEFAULT, _connect=DEFAULT, + can_read=DEFAULT, on_connect=DEFAULT) as mocks: + with patch.object(Redis, 'parse_response') as parse_response: + def parse_response_mock_first(connection, *args, **options): + # Primary + assert connection.port == 7001 + parse_response.side_effect = parse_response_mock_second + return "MOCK_OK" + + def parse_response_mock_second(connection, *args, **options): + # Replica + assert connection.port == 7002 + parse_response.side_effect = parse_response_mock_third + return "MOCK_OK" + + def parse_response_mock_third(connection, *args, **options): + # Primary + assert connection.port == 7001 + return "MOCK_OK" + + # We don't need to create a real cluster 
connection but we + # do want RedisCluster.on_connect function to get called, + # so we'll mock some of the Connection's functions to allow it + parse_response.side_effect = parse_response_mock_first + mocks['send_command'].return_value = True + mocks['read_response'].return_value = "OK" + mocks['_connect'].return_value = True + mocks['can_read'].return_value = False + mocks['on_connect'].return_value = True + + # Create a cluster with reading from replicas + read_cluster = get_mocked_redis_client(host=default_host, + port=default_port, + read_from_replicas=True) + assert read_cluster.read_from_replicas is True + # Check that we read from the slot's nodes in a round-robin + # manner. + # 'foo' belongs to slot 12182 and the slot's nodes are: + # [(127.0.0.1,7001,primary), (127.0.0.1,7002,replica)] + read_cluster.get("foo") + read_cluster.get("foo") + read_cluster.get("foo") + mocks['send_command'].assert_has_calls([call('READONLY')]) + + def test_keyslot(self, r): + """ + Test that the method computes the correct key slot in all supported + cases + """ + assert r.keyslot("foo") == 12182 + assert r.keyslot("{foo}bar") == 12182 + assert r.keyslot("{foo}") == 12182 + assert r.keyslot(1337) == 4314 + + assert r.keyslot(125) == r.keyslot(b"125") + assert r.keyslot(125) == r.keyslot("\x31\x32\x35") + assert r.keyslot("大奖") == r.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96") + assert r.keyslot(u"大奖") == r.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96") + assert r.keyslot(1337.1234) == r.keyslot("1337.1234") + assert r.keyslot(1337) == r.keyslot("1337") + assert r.keyslot(b"abc") == r.keyslot("abc") + + def test_get_node_name(self): + assert get_node_name(default_host, default_port) == \ + "{0}:{1}".format(default_host, default_port) + + def test_all_nodes(self, r): + """ + Set a list of nodes and it should be possible to iterate over all + """ + nodes = [node for node in r.nodes_manager.nodes_cache.values()] + + for i, node in enumerate(r.get_nodes()): + assert node in nodes + + def test_all_nodes_masters(self, r): + """ + Set a list of nodes with a random primaries/replicas config and it + should be possible to iterate over all of them. + """ + nodes = [node for node in r.nodes_manager.nodes_cache.values() + if node.server_type == PRIMARY] + + for node in r.get_primaries(): + assert node in nodes + + @pytest.mark.filterwarnings("ignore:ClusterDownError") + def test_cluster_down_overreaches_retry_attempts(self): + """ + When ClusterDownError is thrown, test that we retry executing the + command as many times as configured in cluster_error_retry_attempts + and then raise the exception + """ + with patch.object(RedisCluster, '_execute_command') as execute_command: + def raise_cluster_down_error(target_node, *args, **kwargs): + execute_command.failed_calls += 1 + raise ClusterDownError( + 'CLUSTERDOWN The cluster is
Use CLUSTER INFO for ' + 'more information') + + execute_command.side_effect = raise_cluster_down_error + + rc = get_mocked_redis_client(host=default_host, port=default_port) + + with pytest.raises(ClusterDownError): + rc.get("bar") + assert execute_command.failed_calls == \ + rc.cluster_error_retry_attempts + + @pytest.mark.filterwarnings("ignore:ConnectionError") + def test_connection_error_overreaches_retry_attempts(self): + """ + When ConnectionError is thrown, test that we retry executing the + command as many times as configured in cluster_error_retry_attempts + and then raise the exception + """ + with patch.object(RedisCluster, '_execute_command') as execute_command: + def raise_conn_error(target_node, *args, **kwargs): + execute_command.failed_calls += 1 + raise ConnectionError() + + execute_command.side_effect = raise_conn_error + + rc = get_mocked_redis_client(host=default_host, port=default_port) + + with pytest.raises(ConnectionError): + rc.get("bar") + assert execute_command.failed_calls == \ + rc.cluster_error_retry_attempts + + def test_user_on_connect_function(self, request): + """ + Test support in passing on_connect function by the user + """ + + def on_connect(connection): + assert connection is not None + + mock = Mock(side_effect=on_connect) + + _get_client(RedisCluster, request, redis_connect_func=mock) + assert mock.called is True + + +@skip_if_not_cluster_mode() +class TestClusterRedisCommands: + def test_case_insensitive_command_names(self, r): + assert r.cluster_response_callbacks['cluster addslots'] == \ + r.cluster_response_callbacks['CLUSTER ADDSLOTS'] + + def test_get_and_set(self, r): + # get and set can't be tested independently of each other + assert r.get('a') is None + byte_string = b'value' + integer = 5 + unicode_string = chr(3456) + 'abcd' + chr(3421) + assert r.set('byte_string', byte_string) + assert r.set('integer', 5) + assert r.set('unicode_string', unicode_string) + assert r.get('byte_string') == byte_string + assert r.get('integer') == str(integer).encode() + assert r.get('unicode_string').decode('utf-8') == unicode_string + + def test_mget_nonatomic(self, r): + assert r.mget_nonatomic([]) == [] + assert r.mget_nonatomic(['a', 'b']) == [None, None] + r['a'] = '1' + r['b'] = '2' + r['c'] = '3' + + assert (r.mget_nonatomic('a', 'other', 'b', 'c') == + [b'1', None, b'2', b'3']) + + def test_mset_nonatomic(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + assert r.mset_nonatomic(d) + for k, v in d.items(): + assert r[k] == v + + def test_dbsize(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + assert r.mset_nonatomic(d) + assert r.dbsize() == len(d) + + def test_config_set(self, r): + assert r.config_set('slowlog-log-slower-than', 0) + + def test_client_setname(self, r): + r.client_setname('redis_py_test') + res = r.client_getname() + for client_name in res.values(): + assert client_name == 'redis_py_test' + + def test_exists(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + r.mset_nonatomic(d) + assert r.exists(*d.keys()) == len(d) + + def test_delete(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + r.mset_nonatomic(d) + assert r.delete(*d.keys()) == len(d) + assert r.delete(*d.keys()) == 0 + + def test_touch(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + r.mset_nonatomic(d) + assert r.touch(*d.keys()) == len(d) + + def test_unlink(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + r.mset_nonatomic(d) + assert r.unlink(*d.keys()) == len(d) + # Unlink is non-blocking 
so we sleep before + # verifying the deletion + sleep(0.1) + assert r.unlink(*d.keys()) == 0 + + def test_pubsub_channels_merge_results(self, r): + nodes = r.get_nodes() + channels = [] + i = 0 + for node in nodes: + channel = "foo{0}".format(i) + # We will create different pubsub clients where each one is + # connected to a different node + p = r.pubsub(node) + p.subscribe(channel) + b_channel = channel.encode('utf-8') + channels.append(b_channel) + # Assert that each node returns only the channel it subscribed to + sub_channels = node.redis_connection.pubsub_channels() + if not sub_channels: + # Try again after a short sleep + sleep(0.3) + sub_channels = node.redis_connection.pubsub_channels() + assert sub_channels == [b_channel] + i += 1 + # Assert that the cluster's pubsub_channels function returns ALL of + # the cluster's channels + result = r.pubsub_channels() + result.sort() + assert result == channels + + def test_pubsub_numsub_merge_results(self, r): + nodes = r.get_nodes() + channel = "foo" + b_channel = channel.encode('utf-8') + for node in nodes: + # We will create different pubsub clients where each one is + # connected to a different node + p = r.pubsub(node) + p.subscribe(channel) + # Assert that each node returns that only one client is subscribed + sub_chann_num = node.redis_connection.pubsub_numsub(channel) + if sub_chann_num == [(b_channel, 0)]: + sleep(0.3) + sub_chann_num = node.redis_connection.pubsub_numsub(channel) + assert sub_chann_num == [(b_channel, 1)] + # Assert that the cluster's pubsub_numsub function returns ALL clients + # subscribed to this channel in the entire cluster + assert r.pubsub_numsub(channel) == [(b_channel, len(nodes))] + + def test_pubsub_numpat_merge_results(self, r): + nodes = r.get_nodes() + pattern = "foo*" + for node in nodes: + # We will create different pubsub clients where each one is + # connected to a different node + p = r.pubsub(node) + p.psubscribe(pattern) + # Assert that each node returns that only one client is subscribed + sub_num_pat = node.redis_connection.pubsub_numpat() + if sub_num_pat == 0: + sleep(0.3) + sub_num_pat = node.redis_connection.pubsub_numpat() + assert sub_num_pat == 1 + # Assert that the cluster's pubsub_numsub function returns ALL clients + # subscribed to this channel in the entire cluster + assert r.pubsub_numpat() == len(nodes) + + def test_cluster_slots(self, r): + mock_all_nodes_resp(r, default_cluster_slots) + cluster_slots = r.cluster_slots() + assert isinstance(cluster_slots, dict) + assert len(default_cluster_slots) == len(cluster_slots) + assert cluster_slots.get((0, 8191)) is not None + assert cluster_slots.get((0, 8191)).get('primary') == \ + ('127.0.0.1', 7000) + + def test_cluster_addslots(self, r): + node = r.get_random_node() + mock_node_resp(node, 'OK') + assert r.cluster_addslots(node, 1, 2, 3) is True + + def test_cluster_countkeysinslot(self, r): + node = r.nodes_manager.get_node_from_slot(1) + mock_node_resp(node, 2) + assert r.cluster_countkeysinslot(1) == 2 + + def test_cluster_count_failure_report(self, r): + mock_all_nodes_resp(r, 0) + assert r.cluster_count_failure_report('node_0') == 0 + + def test_cluster_delslots(self): + cluster_slots = [ + [ + 0, 8191, + ['127.0.0.1', 7000, 'node_0'], + ], + [ + 8192, 16383, + ['127.0.0.1', 7001, 'node_1'], + ] + ] + r = get_mocked_redis_client(host=default_host, port=default_port, + cluster_slots=cluster_slots) + mock_all_nodes_resp(r, 'OK') + node0 = r.get_node(default_host, 7000) + node1 = r.get_node(default_host, 7001) + assert 
r.cluster_delslots(0, 8192) == [True, True] + assert node0.redis_connection.connection.read_response.called + assert node1.redis_connection.connection.read_response.called + + def test_cluster_failover(self, r): + node = r.get_random_node() + mock_node_resp(node, 'OK') + assert r.cluster_failover(node) is True + assert r.cluster_failover(node, 'FORCE') is True + assert r.cluster_failover(node, 'TAKEOVER') is True + with pytest.raises(RedisError): + r.cluster_failover(node, 'FORCT') + + def test_cluster_info(self, r): + info = r.cluster_info() + assert isinstance(info, dict) + assert info['cluster_state'] == 'ok' + + def test_cluster_keyslot(self, r): + mock_all_nodes_resp(r, 12182) + assert r.cluster_keyslot('foo') == 12182 + + def test_cluster_meet(self, r): + node = r.get_random_node() + mock_node_resp(node, 'OK') + assert r.cluster_meet(node, '127.0.0.1', 6379) is True + + def test_cluster_nodes(self, r): + response = ( + 'c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 ' + 'slave aa90da731f673a99617dfe930306549a09f83a6b 0 ' + '1447836263059 5 connected\n' + '9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 ' + 'master - 0 1447836264065 0 connected\n' + 'aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 ' + 'myself,master - 0 0 2 connected 5461-10922\n' + '1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 ' + 'slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 ' + '1447836262556 3 connected\n' + '4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 ' + 'master - 0 1447836262555 7 connected 0-5460\n' + '19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 ' + 'master - 0 1447836263562 3 connected 10923-16383\n' + 'fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 ' + 'master,fail - 1447829446956 1447829444948 1 disconnected\n' + ) + mock_all_nodes_resp(r, response) + nodes = r.cluster_nodes() + assert len(nodes) == 7 + assert nodes.get('172.17.0.7:7006') is not None + assert nodes.get('172.17.0.7:7006').get('node_id') == \ + "c8253bae761cb1ecb2b61857d85dfe455a0fec8b" + + def test_cluster_replicate(self, r): + node = r.get_random_node() + all_replicas = r.get_replicas() + mock_all_nodes_resp(r, 'OK') + assert r.cluster_replicate(node, 'c8253bae761cb61857d') is True + results = r.cluster_replicate(all_replicas, 'c8253bae761cb61857d') + for res in results.values(): + assert res is True + + def test_cluster_reset(self, r): + node = r.get_random_node() + all_nodes = r.get_nodes() + mock_all_nodes_resp(r, 'OK') + assert r.cluster_reset(node) is True + assert r.cluster_reset(node, False) is True + all_results = r.cluster_reset(all_nodes, False) + for res in all_results.values(): + assert res is True + + def test_cluster_save_config(self, r): + node = r.get_random_node() + all_nodes = r.get_nodes() + mock_all_nodes_resp(r, 'OK') + assert r.cluster_save_config(node) is True + all_results = r.cluster_save_config(all_nodes) + for res in all_results.values(): + assert res is True + + def test_cluster_get_keys_in_slot(self, r): + response = [b'{foo}1', b'{foo}2'] + node = r.nodes_manager.get_node_from_slot(12182) + mock_node_resp(node, response) + keys = r.cluster_get_keys_in_slot(12182, 4) + assert keys == response + + def test_cluster_set_config_epoch(self, r): + node = r.get_random_node() + all_nodes = r.get_nodes() + mock_all_nodes_resp(r, 'OK') + assert r.cluster_set_config_epoch(node, 3) is True + all_results = r.cluster_set_config_epoch(all_nodes, 3) + for res in all_results.values(): + assert res is True + + def test_cluster_setslot(self, r): + node 
= r.get_random_node() + mock_node_resp(node, 'OK') + assert r.cluster_setslot(node, 'node_0', 1218, 'IMPORTING') is True + assert r.cluster_setslot(node, 'node_0', 1218, 'NODE') is True + assert r.cluster_setslot(node, 'node_0', 1218, 'MIGRATING') is True + with pytest.raises(RedisError): + r.cluster_failover(node, 'STABLE') + with pytest.raises(RedisError): + r.cluster_failover(node, 'STATE') + + def test_cluster_setslot_stable(self, r): + node = r.nodes_manager.get_node_from_slot(12182) + mock_node_resp(node, 'OK') + assert r.cluster_setslot_stable(12182) is True + assert node.redis_connection.connection.read_response.called + + def test_cluster_replicas(self, r): + response = [b'01eca22229cf3c652b6fca0d09ff6941e0d2e3 ' + b'127.0.0.1:6377@16377 slave ' + b'52611e796814b78e90ad94be9d769a4f668f9a 0 ' + b'1634550063436 4 connected', + b'r4xfga22229cf3c652b6fca0d09ff69f3e0d4d ' + b'127.0.0.1:6378@16378 slave ' + b'52611e796814b78e90ad94be9d769a4f668f9a 0 ' + b'1634550063436 4 connected'] + mock_all_nodes_resp(r, response) + replicas = r.cluster_replicas('52611e796814b78e90ad94be9d769a4f668f9a') + assert replicas.get('127.0.0.1:6377') is not None + assert replicas.get('127.0.0.1:6378') is not None + assert replicas.get('127.0.0.1:6378').get('node_id') == \ + 'r4xfga22229cf3c652b6fca0d09ff69f3e0d4d' + + def test_readonly(self): + r = get_mocked_redis_client(host=default_host, port=default_port) + node = r.get_random_node() + all_replicas = r.get_replicas() + mock_all_nodes_resp(r, 'OK') + assert r.readonly(node) is True + all_replicas_results = r.readonly() + for res in all_replicas_results.values(): + assert res is True + for replica in all_replicas: + assert replica.redis_connection.connection.read_response.called + + def test_readwrite(self): + r = get_mocked_redis_client(host=default_host, port=default_port) + node = r.get_random_node() + mock_all_nodes_resp(r, 'OK') + all_replicas = r.get_replicas() + assert r.readwrite(node) is True + all_replicas_results = r.readwrite() + for res in all_replicas_results.values(): + assert res is True + for replica in all_replicas: + assert replica.redis_connection.connection.read_response.called + + def test_bgsave(self, r): + assert r.bgsave() + sleep(0.3) + assert r.bgsave(True) + + def test_info(self, r): + # Map keys to same slot + r.set('x{1}', 1) + r.set('y{1}', 2) + r.set('z{1}', 3) + # Get node that handles the slot + slot = r.keyslot('x{1}') + node = r.nodes_manager.get_node_from_slot(slot) + # Run info on that node + info = r.info(target_nodes=node) + assert isinstance(info, dict) + assert info['db0']['keys'] == 3 + + def test_slowlog_get(self, r, slowlog): + assert r.slowlog_reset() + unicode_string = chr(3456) + 'abcd' + chr(3421) + r.get(unicode_string) + + slot = r.keyslot(unicode_string) + node = r.nodes_manager.get_node_from_slot(slot) + slowlog = r.slowlog_get(target_nodes=node) + assert isinstance(slowlog, list) + commands = [log['command'] for log in slowlog] + + get_command = b' '.join((b'GET', unicode_string.encode('utf-8'))) + assert get_command in commands + assert b'SLOWLOG RESET' in commands + + # the order should be ['GET ', 'SLOWLOG RESET'], + # but if other clients are executing commands at the same time, there + # could be commands, before, between, or after, so just check that + # the two we care about are in the appropriate order. 
+ assert commands.index(get_command) < commands.index(b'SLOWLOG RESET') + + # make sure other attributes are typed correctly + assert isinstance(slowlog[0]['start_time'], int) + assert isinstance(slowlog[0]['duration'], int) + + def test_slowlog_get_limit(self, r, slowlog): + assert r.slowlog_reset() + r.get('foo') + node = r.nodes_manager.get_node_from_slot(key_slot(b'foo')) + slowlog = r.slowlog_get(1, target_nodes=node) + assert isinstance(slowlog, list) + # only one command, based on the number we passed to slowlog_get() + assert len(slowlog) == 1 + + def test_slowlog_length(self, r, slowlog): + r.get('foo') + node = r.nodes_manager.get_node_from_slot(key_slot(b'foo')) + slowlog_len = r.slowlog_len(target_nodes=node) + assert isinstance(slowlog_len, int) + + def test_time(self, r): + t = r.time(target_nodes=r.get_primaries()[0]) + assert len(t) == 2 + assert isinstance(t[0], int) + assert isinstance(t[1], int) + + @skip_if_server_version_lt('4.0.0') + def test_memory_usage(self, r): + r.set('foo', 'bar') + assert isinstance(r.memory_usage('foo'), int) + + @skip_if_server_version_lt('4.0.0') + def test_memory_malloc_stats(self, r): + assert r.memory_malloc_stats() + + @skip_if_server_version_lt('4.0.0') + def test_memory_stats(self, r): + # put a key into the current db to make sure that "db." + # has data + r.set('foo', 'bar') + node = r.nodes_manager.get_node_from_slot(key_slot(b'foo')) + stats = r.memory_stats(target_nodes=node) + assert isinstance(stats, dict) + for key, value in stats.items(): + if key.startswith('db.'): + assert isinstance(value, dict) + + @skip_if_server_version_lt('4.0.0') + def test_memory_help(self, r): + with pytest.raises(NotImplementedError): + r.memory_help() + + @skip_if_server_version_lt('4.0.0') + def test_memory_doctor(self, r): + with pytest.raises(NotImplementedError): + r.memory_doctor() + + def test_object(self, r): + r['a'] = 'foo' + assert isinstance(r.object('refcount', 'a'), int) + assert isinstance(r.object('idletime', 'a'), int) + assert r.object('encoding', 'a') in (b'raw', b'embstr') + assert r.object('idletime', 'invalid-key') is None + + def test_lastsave(self, r): + node = r.get_primaries()[0] + assert isinstance(r.lastsave(target_nodes=node), + datetime.datetime) + + def test_echo(self, r): + node = r.get_primaries()[0] + assert r.echo('foo bar', node) == b'foo bar' + + @skip_if_server_version_lt('1.0.0') + def test_debug_segfault(self, r): + with pytest.raises(NotImplementedError): + r.debug_segfault() + + def test_config_resetstat(self, r): + node = r.get_primaries()[0] + r.ping(target_nodes=node) + prior_commands_processed = \ + int(r.info(target_nodes=node)['total_commands_processed']) + assert prior_commands_processed >= 1 + r.config_resetstat(target_nodes=node) + reset_commands_processed = \ + int(r.info(target_nodes=node)['total_commands_processed']) + assert reset_commands_processed < prior_commands_processed + + @skip_if_server_version_lt('6.2.0') + def test_client_trackinginfo(self, r): + node = r.get_primaries()[0] + res = r.client_trackinginfo(target_nodes=node) + assert len(res) > 2 + assert 'prefixes' in res + + @skip_if_server_version_lt('2.9.50') + def test_client_pause(self, r): + node = r.get_primaries()[0] + assert r.client_pause(1, target_nodes=node) + assert r.client_pause(timeout=1, target_nodes=node) + with pytest.raises(RedisError): + r.client_pause(timeout='not an integer', target_nodes=node) + + @skip_if_server_version_lt('6.2.0') + def test_client_unpause(self, r): + assert r.client_unpause() + + 
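A pattern worth noting across these management-command tests: node-local results such as `CLIENT ID`, `CLIENT INFO`, or `SLOWLOG GET` are always requested from one concrete node via `target_nodes`, typically the first primary. For instance, given a connected `RedisCluster` client `rc` (the returned values shown are illustrative):

``` pycon
>>> node = rc.get_primaries()[0]
>>> rc.client_id(target_nodes=node)
12
>>> rc.client_info(target_nodes=node)['addr']
'127.0.0.1:59434'
```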
@skip_if_server_version_lt('5.0.0') + def test_client_id(self, r): + node = r.get_primaries()[0] + assert r.client_id(target_nodes=node) > 0 + + @skip_if_server_version_lt('5.0.0') + def test_client_unblock(self, r): + node = r.get_primaries()[0] + myid = r.client_id(target_nodes=node) + assert not r.client_unblock(myid, target_nodes=node) + assert not r.client_unblock(myid, error=True, target_nodes=node) + assert not r.client_unblock(myid, error=False, target_nodes=node) + + @skip_if_server_version_lt('6.0.0') + def test_client_getredir(self, r): + node = r.get_primaries()[0] + assert isinstance(r.client_getredir(target_nodes=node), int) + assert r.client_getredir(target_nodes=node) == -1 + + @skip_if_server_version_lt('6.2.0') + def test_client_info(self, r): + node = r.get_primaries()[0] + info = r.client_info(target_nodes=node) + assert isinstance(info, dict) + assert 'addr' in info + + @skip_if_server_version_lt('2.6.9') + def test_client_kill(self, r, r2): + node = r.get_primaries()[0] + r.client_setname('redis-py-c1') + r2.client_setname('redis-py-c2') + clients = [client for client in r.client_list()[node.name] + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 2 + clients_by_name = dict([(client.get('name'), client) + for client in clients]) + + client_addr = clients_by_name['redis-py-c2'].get('addr') + assert r.client_kill(client_addr, target_nodes=node) is True + + clients = [client for client in r.client_list()[node.name] + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 1 + assert clients[0].get('name') == 'redis-py-c1' + + +@skip_if_not_cluster_mode() +class TestNodesManager: + def test_load_balancer(self, r): + n_manager = r.nodes_manager + lb = n_manager.read_load_balancer + slot_1 = 1257 + slot_2 = 8975 + node_1 = ClusterNode(default_host, 6379, PRIMARY) + node_2 = ClusterNode(default_host, 6378, REPLICA) + node_3 = ClusterNode(default_host, 6377, REPLICA) + node_4 = ClusterNode(default_host, 6376, PRIMARY) + node_5 = ClusterNode(default_host, 6375, REPLICA) + n_manager.slots_cache = { + slot_1: [node_1, node_2, node_3], + slot_2: [node_4, node_5] + } + primary1_name = n_manager.slots_cache[slot_1][0].name + primary2_name = n_manager.slots_cache[slot_2][0].name + list1_size = len(n_manager.slots_cache[slot_1]) + list2_size = len(n_manager.slots_cache[slot_2]) + # slot 1 + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary1_name, list1_size) == 1 + assert lb.get_server_index(primary1_name, list1_size) == 2 + assert lb.get_server_index(primary1_name, list1_size) == 0 + # slot 2 + assert lb.get_server_index(primary2_name, list2_size) == 0 + assert lb.get_server_index(primary2_name, list2_size) == 1 + assert lb.get_server_index(primary2_name, list2_size) == 0 + + lb.reset() + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary2_name, list2_size) == 0 + + def test_init_slots_cache_not_all_slots_covered(self): + """ + Test that if not all slots are covered it should raise an exception + """ + # Missing slot 5460 + cluster_slots = [ + [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]], + [5461, 10922, ['127.0.0.1', 7001], + ['127.0.0.1', 7004]], + [10923, 16383, ['127.0.0.1', 7002], + ['127.0.0.1', 7005]], + ] + with pytest.raises(RedisClusterException) as ex: + get_mocked_redis_client(host=default_host, port=default_port, + cluster_slots=cluster_slots) + assert str(ex.value).startswith( + "All slots are not covered after 
query all startup_nodes.") + + def test_init_slots_cache_not_require_full_coverage_error(self): + """ + When require_full_coverage is set to False and not all slots are + covered, if one of the nodes has the 'cluster-require_full_coverage' + config set to 'yes', the cluster initialization should fail + """ + # Missing slot 5460 + cluster_slots = [ + [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]], + [5461, 10922, ['127.0.0.1', 7001], + ['127.0.0.1', 7004]], + [10923, 16383, ['127.0.0.1', 7002], + ['127.0.0.1', 7005]], + ] + + with pytest.raises(RedisClusterException): + get_mocked_redis_client(host=default_host, port=default_port, + cluster_slots=cluster_slots, + require_full_coverage=False, + coverage_result='yes') + + def test_init_slots_cache_not_require_full_coverage_success(self): + """ + When require_full_coverage is set to False and not all slots are + covered, if all of the nodes have the 'cluster-require_full_coverage' + config set to 'no', the cluster initialization should succeed + """ + # Missing slot 5460 + cluster_slots = [ + [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]], + [5461, 10922, ['127.0.0.1', 7001], + ['127.0.0.1', 7004]], + [10923, 16383, ['127.0.0.1', 7002], + ['127.0.0.1', 7005]], + ] + + rc = get_mocked_redis_client(host=default_host, port=default_port, + cluster_slots=cluster_slots, + require_full_coverage=False, + coverage_result='no') + + assert 5460 not in rc.nodes_manager.slots_cache + + def test_init_slots_cache_not_require_full_coverage_skips_check(self): + """ + Test that when require_full_coverage is set to False and + skip_full_coverage_check is set to True, the cluster initialization + succeeds without checking the nodes' Redis configurations + """ + # Missing slot 5460 + cluster_slots = [ + [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]], + [5461, 10922, ['127.0.0.1', 7001], + ['127.0.0.1', 7004]], + [10923, 16383, ['127.0.0.1', 7002], + ['127.0.0.1', 7005]], + ] + + with patch.object(NodesManager, + 'cluster_require_full_coverage') as conf_check_mock: + rc = get_mocked_redis_client(host=default_host, port=default_port, + cluster_slots=cluster_slots, + require_full_coverage=False, + skip_full_coverage_check=True, + coverage_result='no') + + assert conf_check_mock.called is False + assert 5460 not in rc.nodes_manager.slots_cache + + def test_init_slots_cache(self): + """ + Test that the slots cache can be initialized and that all slots are + covered + """ + good_slots_resp = [ + [0, 5460, ['127.0.0.1', 7000], ['127.0.0.2', 7003]], + [5461, 10922, ['127.0.0.1', 7001], ['127.0.0.2', 7004]], + [10923, 16383, ['127.0.0.1', 7002], ['127.0.0.2', 7005]], + ] + + rc = get_mocked_redis_client(host=default_host, port=default_port, + cluster_slots=good_slots_resp) + n_manager = rc.nodes_manager + assert len(n_manager.slots_cache) == REDIS_CLUSTER_HASH_SLOTS + for slot_info in good_slots_resp: + all_hosts = ['127.0.0.1', '127.0.0.2'] + all_ports = [7000, 7001, 7002, 7003, 7004, 7005] + slot_start = slot_info[0] + slot_end = slot_info[1] + for i in range(slot_start, slot_end + 1): + assert len(n_manager.slots_cache[i]) == len(slot_info[2:]) + assert n_manager.slots_cache[i][0].host in all_hosts + assert n_manager.slots_cache[i][1].host in all_hosts + assert n_manager.slots_cache[i][0].port in all_ports + assert n_manager.slots_cache[i][1].port in all_ports + + assert len(n_manager.nodes_cache) == 6 + + def test_empty_startup_nodes(self): + """ + It should not be possible to create a node manager with no nodes + specified + """ + with pytest.raises(RedisClusterException):
NodesManager([]) + + def test_wrong_startup_nodes_type(self): + """ + If something other than a list-type iterable is provided, it should + fail + """ + with pytest.raises(RedisClusterException): + NodesManager({}) + + def test_init_slots_cache_slots_collision(self, request): + """ + Test that if two nodes do not agree on the same slots setup, an error + is raised. In this test both nodes will say that the first slots + block should be bound to different servers. + """ + with patch.object(NodesManager, + 'create_redis_node') as create_redis_node: + def create_mocked_redis_node(host, port, **kwargs): + """ + Helper function to return custom slots cache data from + different redis nodes + """ + if port == 7000: + result = [ + [ + 0, + 5460, + ['127.0.0.1', 7000], + ['127.0.0.1', 7003], + ], + [ + 5461, + 10922, + ['127.0.0.1', 7001], + ['127.0.0.1', 7004], + ], + ] + + elif port == 7001: + result = [ + [ + 0, + 5460, + ['127.0.0.1', 7001], + ['127.0.0.1', 7003], + ], + [ + 5461, + 10922, + ['127.0.0.1', 7000], + ['127.0.0.1', 7004], + ], + ] + else: + result = [] + + r_node = Redis( + host=host, + port=port + ) + + orig_execute_command = r_node.execute_command + + def execute_command(*args, **kwargs): + if args[0] == 'CLUSTER SLOTS': + return result + elif args[1] == 'cluster-require-full-coverage': + return {'cluster-require-full-coverage': 'yes'} + else: + return orig_execute_command(*args, **kwargs) + + r_node.execute_command = execute_command + return r_node + + create_redis_node.side_effect = create_mocked_redis_node + + with pytest.raises(RedisClusterException) as ex: + node_1 = ClusterNode('127.0.0.1', 7000) + node_2 = ClusterNode('127.0.0.1', 7001) + RedisCluster(startup_nodes=[node_1, node_2]) + assert str(ex.value).startswith( + "startup_nodes could not agree on a valid slots cache"), str( + ex.value) + + def test_cluster_one_instance(self): + """ + If the cluster consists of only one node, there are some special cases + that must be validated to work. + """ + node = ClusterNode(default_host, default_port) + cluster_slots = [[0, 16383, ['', default_port]]] + rc = get_mocked_redis_client(startup_nodes=[node], + cluster_slots=cluster_slots) + + n = rc.nodes_manager + assert len(n.nodes_cache) == 1 + n_node = rc.get_node(node_name=node.name) + assert n_node is not None + assert n_node == node + assert n_node.server_type == PRIMARY + assert len(n.slots_cache) == REDIS_CLUSTER_HASH_SLOTS + for i in range(0, REDIS_CLUSTER_HASH_SLOTS): + assert n.slots_cache[i] == [n_node] + + def test_init_with_down_node(self): + """ + If I can't connect to one of the nodes, everything should still work. + But if I can't connect to any of the nodes, an exception should be + thrown.
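All of these `NodesManager` tests hand-craft `CLUSTER SLOTS` replies, so the shape is worth spelling out once: each entry is `[start_slot, end_slot, primary, replica, ...]`, where every node is `[host, port]` with an optional trailing node id (some mocks include it, some omit it):

```python
# One slot range, 0-5460, served by a primary at 7000 with a replica at 7003.
cluster_slots_entry = [
    0, 5460,                        # inclusive slot range
    ['127.0.0.1', 7000, 'node_0'],  # primary (node id optional)
    ['127.0.0.1', 7003, 'node_3'],  # replica (node id optional)
]
```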
+ """ + with patch.object(NodesManager, + 'create_redis_node') as create_redis_node: + def create_mocked_redis_node(host, port, **kwargs): + if port == 7000: + raise ConnectionError('mock connection error for 7000') + + r_node = Redis(host=host, port=port, decode_responses=True) + + def execute_command(*args, **kwargs): + if args[0] == 'CLUSTER SLOTS': + return [ + [ + 0, 8191, + ['127.0.0.1', 7001, 'node_1'], + ], + [ + 8192, 16383, + ['127.0.0.1', 7002, 'node_2'], + ] + ] + elif args[1] == 'cluster-require-full-coverage': + return {'cluster-require-full-coverage': 'yes'} + + r_node.execute_command = execute_command + + return r_node + + create_redis_node.side_effect = create_mocked_redis_node + + node_1 = ClusterNode('127.0.0.1', 7000) + node_2 = ClusterNode('127.0.0.1', 7001) + + # If all startup nodes fail to connect, connection error should be + # thrown + with pytest.raises(RedisClusterException) as e: + RedisCluster(startup_nodes=[node_1]) + assert 'Redis Cluster cannot be connected' in str(e.value) + + with patch.object(CommandsParser, 'initialize', + autospec=True) as cmd_parser_initialize: + + def cmd_init_mock(self, r): + self.commands = {'get': {'name': 'get', 'arity': 2, + 'flags': ['readonly', + 'fast'], + 'first_key_pos': 1, + 'last_key_pos': 1, + 'step_count': 1}} + + cmd_parser_initialize.side_effect = cmd_init_mock + # When at least one startup node is reachable, the cluster + # initialization should succeeds + rc = RedisCluster(startup_nodes=[node_1, node_2]) + assert rc.get_node(host=default_host, port=7001) is not None + assert rc.get_node(host=default_host, port=7002) is not None diff --git a/tests/test_commands.py b/tests/test_commands.py index 6d65931539..998bc0f0e6 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -8,9 +8,10 @@ from redis.client import parse_info from redis import exceptions - +from redis.commands import CommandsParser from .conftest import ( _get_client, + skip_if_cluster_mode, skip_if_server_version_gte, skip_if_server_version_lt, skip_unless_arch_bits, @@ -46,6 +47,7 @@ def get_stream_message(client, stream, message_id): # RESPONSE CALLBACKS +@skip_if_cluster_mode() class TestResponseCallbacks: "Tests for the response callback system" @@ -60,6 +62,7 @@ def test_case_insensitive_command_names(self, r): assert r.response_callbacks['del'] == r.response_callbacks['DEL'] +@skip_if_cluster_mode() class TestRedisCommands: def test_command_on_invalid_key_type(self, r): r.lpush('a', '1') @@ -122,6 +125,7 @@ def test_acl_getuser_setuser(self, r, request): def teardown(): r.acl_deluser(username) + request.addfinalizer(teardown) # test enabled=False @@ -215,6 +219,7 @@ def test_acl_list(self, r, request): def teardown(): r.acl_deluser(username) + request.addfinalizer(teardown) assert r.acl_setuser(username, enabled=False, reset=True) @@ -262,6 +267,7 @@ def test_acl_setuser_categories_without_prefix_fails(self, r, request): def teardown(): r.acl_deluser(username) + request.addfinalizer(teardown) with pytest.raises(exceptions.DataError): @@ -273,6 +279,7 @@ def test_acl_setuser_commands_without_prefix_fails(self, r, request): def teardown(): r.acl_deluser(username) + request.addfinalizer(teardown) with pytest.raises(exceptions.DataError): @@ -284,6 +291,7 @@ def test_acl_setuser_add_passwords_and_nopass_fails(self, r, request): def teardown(): r.acl_deluser(username) + request.addfinalizer(teardown) with pytest.raises(exceptions.DataError): @@ -593,6 +601,7 @@ def parse_response(connection, command_name, **options): # Complexity info stored 
as fourth item in list response.insert(3, COMPLEXITY_STATEMENT) return r.response_callbacks[command_name](responses, **options) + r.parse_response = parse_response # test @@ -1195,22 +1204,22 @@ def test_stralgo_lcs(self, r): # test other labels assert r.stralgo('LCS', value1, value2, len=True) == len(res) assert r.stralgo('LCS', value1, value2, idx=True) == \ - { - 'len': len(res), - 'matches': [[(4, 7), (5, 8)], [(2, 3), (0, 1)]] - } + { + 'len': len(res), + 'matches': [[(4, 7), (5, 8)], [(2, 3), (0, 1)]] + } assert r.stralgo('LCS', value1, value2, idx=True, withmatchlen=True) == \ - { - 'len': len(res), - 'matches': [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]] - } + { + 'len': len(res), + 'matches': [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]] + } assert r.stralgo('LCS', value1, value2, idx=True, minmatchlen=4, withmatchlen=True) == \ - { - 'len': len(res), - 'matches': [[4, (4, 7), (5, 8)]] - } + { + 'len': len(res), + 'matches': [[4, (4, 7), (5, 8)]] + } @skip_if_server_version_lt('6.0.0') def test_stralgo_negative(self, r): @@ -1758,16 +1767,16 @@ def test_zinter(self, r): r.zinter(['a', 'b', 'c'], aggregate='foo', withscores=True) # aggregate with SUM assert r.zinter(['a', 'b', 'c'], withscores=True) \ - == [(b'a3', 8), (b'a1', 9)] + == [(b'a3', 8), (b'a1', 9)] # aggregate with MAX assert r.zinter(['a', 'b', 'c'], aggregate='MAX', withscores=True) \ - == [(b'a3', 5), (b'a1', 6)] + == [(b'a3', 5), (b'a1', 6)] # aggregate with MIN assert r.zinter(['a', 'b', 'c'], aggregate='MIN', withscores=True) \ - == [(b'a1', 1), (b'a3', 1)] + == [(b'a1', 1), (b'a3', 1)] # with weights assert r.zinter({'a': 1, 'b': 2, 'c': 3}, withscores=True) \ - == [(b'a3', 20), (b'a1', 23)] + == [(b'a3', 20), (b'a1', 23)] def test_zinterstore_sum(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) @@ -2059,14 +2068,14 @@ def test_zunion(self, r): assert r.zunion(['a', 'b', 'c'], withscores=True) == \ [(b'a2', 3), (b'a4', 4), (b'a3', 8), (b'a1', 9)] # max - assert r.zunion(['a', 'b', 'c'], aggregate='MAX', withscores=True)\ - == [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] + assert r.zunion(['a', 'b', 'c'], aggregate='MAX', withscores=True) \ + == [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] # min - assert r.zunion(['a', 'b', 'c'], aggregate='MIN', withscores=True)\ - == [(b'a1', 1), (b'a2', 1), (b'a3', 1), (b'a4', 4)] + assert r.zunion(['a', 'b', 'c'], aggregate='MIN', withscores=True) \ + == [(b'a1', 1), (b'a2', 1), (b'a3', 1), (b'a4', 4)] # with weight - assert r.zunion({'a': 1, 'b': 2, 'c': 3}, withscores=True)\ - == [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] + assert r.zunion({'a': 1, 'b': 2, 'c': 3}, withscores=True) \ + == [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] def test_zunionstore_sum(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) @@ -2927,10 +2936,10 @@ def test_xautoclaim(self, r): # which only returns message ids assert r.xautoclaim(stream, group, consumer1, min_idle_time=0, start_id=0, justid=True) == \ - [message_id1, message_id2] + [message_id1, message_id2] assert r.xautoclaim(stream, group, consumer1, min_idle_time=0, start_id=message_id2, justid=True) == \ - [message_id2] + [message_id2] @skip_if_server_version_lt('6.2.0') def test_xautoclaim_negative(self, r): @@ -3511,51 +3520,51 @@ def test_bitfield_operations(self, r): # comments show affected bits bf = r.bitfield('a') resp = (bf - .set('u8', 8, 255) # 00000000 11111111 - .get('u8', 0) # 00000000 - .get('u4', 8) # 1111 - .get('u4', 12) # 1111 - .get('u4', 13) # 111 0 + .set('u8', 8, 255) # 00000000 11111111 + 
.get('u8', 0) # 00000000 + .get('u4', 8) # 1111 + .get('u4', 12) # 1111 + .get('u4', 13) # 111 0 .execute()) assert resp == [0, 0, 15, 15, 14] # .set() returns the previous value... resp = (bf - .set('u8', 4, 1) # 0000 0001 - .get('u16', 0) # 00000000 00011111 - .set('u16', 0, 0) # 00000000 00000000 + .set('u8', 4, 1) # 0000 0001 + .get('u16', 0) # 00000000 00011111 + .set('u16', 0, 0) # 00000000 00000000 .execute()) assert resp == [15, 31, 31] # incrby adds to the value resp = (bf .incrby('u8', 8, 254) # 00000000 11111110 - .incrby('u8', 8, 1) # 00000000 11111111 - .get('u16', 0) # 00000000 11111111 + .incrby('u8', 8, 1) # 00000000 11111111 + .get('u16', 0) # 00000000 11111111 .execute()) assert resp == [254, 255, 255] # Verify overflow protection works as a method: r.delete('a') resp = (bf - .set('u8', 8, 254) # 00000000 11111110 + .set('u8', 8, 254) # 00000000 11111110 .overflow('fail') - .incrby('u8', 8, 2) # incrby 2 would overflow, None returned - .incrby('u8', 8, 1) # 00000000 11111111 - .incrby('u8', 8, 1) # incrby 1 would overflow, None returned - .get('u16', 0) # 00000000 11111111 + .incrby('u8', 8, 2) # incrby 2 would overflow, None returned + .incrby('u8', 8, 1) # 00000000 11111111 + .incrby('u8', 8, 1) # incrby 1 would overflow, None returned + .get('u16', 0) # 00000000 11111111 .execute()) assert resp == [0, None, 255, None, 255] # Verify overflow protection works as arg to incrby: r.delete('a') resp = (bf - .set('u8', 8, 255) # 00000000 11111111 - .incrby('u8', 8, 1) # 00000000 00000000 wrap default - .set('u8', 8, 255) # 00000000 11111111 + .set('u8', 8, 255) # 00000000 11111111 + .incrby('u8', 8, 1) # 00000000 00000000 wrap default + .set('u8', 8, 255) # 00000000 11111111 .incrby('u8', 8, 1, 'FAIL') # 00000000 11111111 fail - .incrby('u8', 8, 1) # 00000000 11111111 still fail - .get('u16', 0) # 00000000 11111111 + .incrby('u8', 8, 1) # 00000000 11111111 still fail + .get('u16', 0) # 00000000 11111111 .execute()) assert resp == [0, 0, 0, None, None, 255] @@ -3563,9 +3572,9 @@ def test_bitfield_operations(self, r): r.delete('a') bf = r.bitfield('a', default_overflow='FAIL') resp = (bf - .set('u8', 8, 255) # 00000000 11111111 - .incrby('u8', 8, 1) # 00000000 11111111 fail default - .get('u16', 0) # 00000000 11111111 + .set('u8', 8, 255) # 00000000 11111111 + .incrby('u8', 8, 1) # 00000000 11111111 fail default + .get('u16', 0) # 00000000 11111111 .execute()) assert resp == [0, None, 255] @@ -3672,6 +3681,7 @@ def test_replicaof(self, r): assert r.replicaof("NO", "ONE") +@skip_if_cluster_mode() class TestBinarySave: def test_binary_get_set(self, r): @@ -3757,3 +3767,60 @@ def test_floating_point_encoding(self, r): timestamp = 1349673917.939762 r.zadd('a', {'a1': timestamp}) assert r.zscore('a', 'a1') == timestamp + + +class TestCommandsParser: + def test_init_commands(self, r): + commands_parser = CommandsParser(r) + assert commands_parser.commands is not None + assert 'get' in commands_parser.commands + + def test_get_keys_predetermined_key_location(self, r): + commands_parser = CommandsParser(r) + args1 = ['GET', 'foo'] + args2 = ['OBJECT', 'encoding', 'foo'] + args3 = ['MGET', 'foo', 'bar', 'foobar'] + assert commands_parser.get_keys(r, *args1) == ['foo'] + assert commands_parser.get_keys(r, *args2) == ['foo'] + assert commands_parser.get_keys(r, *args3) == ['foo', 'bar', 'foobar'] + + @pytest.mark.filterwarnings("ignore:ResponseError") + def test_get_moveable_keys(self, r): + commands_parser = CommandsParser(r) + args1 = ['EVAL', 'return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}', 
2, 'key1', + 'key2', 'first', 'second'] + args2 = ['XREAD', 'COUNT', 2, b'STREAMS', 'mystream', 'writers', 0, 0] + args3 = ['ZUNIONSTORE', 'out', 2, 'zset1', 'zset2', 'WEIGHTS', 2, 3] + args4 = ['GEORADIUS', 'Sicily', 15, 37, 200, 'km', 'WITHCOORD', + b'STORE', 'out'] + args5 = ['MEMORY USAGE', 'foo'] + args6 = ['MIGRATE', '192.168.1.34', 6379, "", 0, 5000, b'KEYS', + 'key1', 'key2', 'key3'] + args7 = ['MIGRATE', '192.168.1.34', 6379, "key1", 0, 5000] + args8 = ['STRALGO', 'LCS', 'STRINGS', 'string_a', 'string_b'] + args9 = ['STRALGO', 'LCS', 'KEYS', 'key1', 'key2'] + + # compare via sorted(); list.sort() returns None and would make these + # assertions vacuous + assert sorted(commands_parser.get_keys( + r, *args1)) == sorted(['key1', 'key2']) + assert sorted(commands_parser.get_keys( + r, *args2)) == sorted(['mystream', 'writers']) + assert sorted(commands_parser.get_keys( + r, *args3)) == sorted(['out', 'zset1', 'zset2']) + assert sorted(commands_parser.get_keys( + r, *args4)) == sorted(['Sicily', 'out']) + assert sorted(commands_parser.get_keys(r, *args5)) == sorted(['foo']) + assert sorted(commands_parser.get_keys( + r, *args6)) == sorted(['key1', 'key2', 'key3']) + assert sorted(commands_parser.get_keys(r, *args7)) == sorted(['key1']) + assert commands_parser.get_keys(r, *args8) is None + assert sorted(commands_parser.get_keys( + r, *args9)) == sorted(['key1', 'key2']) + + def test_get_pubsub_keys(self, r): + commands_parser = CommandsParser(r) + args1 = ['PUBLISH', 'foo', 'bar'] + args2 = ['PUBSUB NUMSUB', 'foo1', 'foo2', 'foo3'] + args3 = ['SUBSCRIBE', 'foo1', 'foo2', 'foo3'] + assert commands_parser.get_keys(r, *args1) == ['foo'] + assert commands_parser.get_keys(r, *args2) == ['foo1', 'foo2', 'foo3'] + assert commands_parser.get_keys(r, *args3) == ['foo1', 'foo2', 'foo3'] diff --git a/tests/test_connection.py b/tests/test_connection.py index fa9a2b0c90..2ca858d263 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,10 +4,11 @@ from redis.exceptions import InvalidResponse, ModuleError from redis.utils import HIREDIS_AVAILABLE -from .conftest import skip_if_server_version_lt +from .conftest import skip_if_server_version_lt, skip_if_cluster_mode @pytest.mark.skipif(HIREDIS_AVAILABLE, reason='PythonParser only') +@skip_if_cluster_mode() def test_invalid_response(r): raw = b'x' parser = r.connection._parser @@ -17,12 +18,14 @@ def test_invalid_response(r): assert str(cm.value) == 'Protocol Error: %r' % raw +@skip_if_cluster_mode() @skip_if_server_version_lt('4.0.0') def test_loaded_modules(r, modclient): assert r.loaded_modules == [] assert 'rejson' in modclient.loaded_modules.keys() +@skip_if_cluster_mode() @skip_if_server_version_lt('4.0.0') def test_loading_external_modules(r, modclient): def inner(): diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 8d2ad041a0..4708057e98 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -7,7 +7,8 @@ from threading import Thread from redis.connection import ssl_available, to_bool -from .conftest import skip_if_server_version_lt, _get_client +from .conftest import skip_if_server_version_lt, skip_if_cluster_mode,\ + _get_client from .test_pubsub import wait_for_message @@ -43,15 +44,15 @@ def test_connection_creation(self): assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs - def test_multiple_connections(self, master_host): - connection_kwargs = {'host': master_host} + def test_multiple_connections(self, master_host, master_port): + connection_kwargs = {'host': master_host, 'port': master_port} pool = 
self.get_pool(connection_kwargs=connection_kwargs) c1 = pool.get_connection('_') c2 = pool.get_connection('_') assert c1 != c2 - def test_max_connections(self, master_host): - connection_kwargs = {'host': master_host} + def test_max_connections(self, master_host, master_port): + connection_kwargs = {'host': master_host, 'port': master_port} pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) pool.get_connection('_') @@ -59,8 +60,9 @@ def test_max_connections(self, master_host): with pytest.raises(redis.ConnectionError): pool.get_connection('_') - def test_reuse_previously_released_connection(self, master_host): - connection_kwargs = {'host': master_host} + def test_reuse_previously_released_connection(self, master_host, + master_port): + connection_kwargs = {'host': master_host, 'port': master_port} pool = self.get_pool(connection_kwargs=connection_kwargs) c1 = pool.get_connection('_') pool.release(c1) @@ -463,6 +465,7 @@ def get_connection(self, *args, **kwargs): assert pool.get_connection('_').check_hostname is True +@pytest.mark.filterwarnings("ignore:BaseException") class TestConnection: def test_on_connect_error(self): """ @@ -479,6 +482,7 @@ def test_on_connect_error(self): assert len(pool._available_connections) == 1 assert not pool._available_connections[0]._sock + @skip_if_cluster_mode() @skip_if_server_version_lt('2.8.8') def test_busy_loading_disconnects_socket(self, r): """ @@ -489,6 +493,7 @@ def test_busy_loading_disconnects_socket(self, r): r.execute_command('DEBUG', 'ERROR', 'LOADING fake message') assert not r.connection._sock + @skip_if_cluster_mode() @skip_if_server_version_lt('2.8.8') def test_busy_loading_from_pipeline_immediate_command(self, r): """ @@ -504,6 +509,7 @@ def test_busy_loading_from_pipeline_immediate_command(self, r): assert len(pool._available_connections) == 1 assert not pool._available_connections[0]._sock + @skip_if_cluster_mode() @skip_if_server_version_lt('2.8.8') def test_busy_loading_from_pipeline(self, r): """ @@ -519,6 +525,7 @@ def test_busy_loading_from_pipeline(self, r): assert len(pool._available_connections) == 1 assert not pool._available_connections[0]._sock + @pytest.mark.filterwarnings("ignore:ResponseError") @skip_if_server_version_lt('2.8.8') def test_read_only_error(self, r): "READONLY errors get turned in ReadOnlyError exceptions" @@ -560,6 +567,7 @@ def test_connect_invalid_password_supplied(self, r): r.execute_command('DEBUG', 'ERROR', 'ERR invalid password') +@skip_if_cluster_mode() class TestMultiConnectionClient: @pytest.fixture() def r(self, request): @@ -573,6 +581,7 @@ def test_multi_connection_command(self, r): assert r.get('a') == b'123' +@skip_if_cluster_mode() class TestHealthCheck: interval = 60 diff --git a/tests/test_encoding.py b/tests/test_encoding.py index 706654f89f..955735338b 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -91,6 +91,7 @@ def test_basic_command(self, r): r.set('hello', 'world') +@pytest.mark.filterwarnings("ignore:BaseException") class TestInvalidUserInput: def test_boolean_fails(self, r): with pytest.raises(redis.DataError): diff --git a/tests/test_json.py b/tests/test_json.py index 83fbf28669..c0b4d9ee4c 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,7 +1,7 @@ import pytest import redis from redis.commands.json.path import Path -from .conftest import skip_ifmodversion_lt +from .conftest import skip_ifmodversion_lt, skip_if_cluster_mode @pytest.fixture @@ -10,226 +10,207 @@ def client(modclient): return modclient 
-@pytest.mark.redismod -def test_json_setbinarykey(client): - d = {"hello": "world", b"some": "value"} - with pytest.raises(TypeError): - client.json().set("somekey", Path.rootPath(), d) - assert client.json().set("somekey", Path.rootPath(), d, decode_keys=True) - - -@pytest.mark.redismod -def test_json_setgetdeleteforget(client): - assert client.json().set("foo", Path.rootPath(), "bar") - assert client.json().get("foo") == "bar" - assert client.json().get("baz") is None - assert client.json().delete("foo") == 1 - assert client.json().forget("foo") == 0 # second delete - assert client.exists("foo") == 0 - - -@pytest.mark.redismod -def test_justaget(client): - client.json().set("foo", Path.rootPath(), "bar") - assert client.json().get("foo") == "bar" - - -@pytest.mark.redismod -def test_json_get_jset(client): - assert client.json().set("foo", Path.rootPath(), "bar") - assert "bar" == client.json().get("foo") - assert client.json().get("baz") is None - assert 1 == client.json().delete("foo") - assert client.exists("foo") == 0 - - -@pytest.mark.redismod -def test_nonascii_setgetdelete(client): - assert client.json().set("notascii", Path.rootPath(), - "hyvää-élève") is True - assert "hyvää-élève" == client.json().get("notascii", no_escape=True) - assert 1 == client.json().delete("notascii") - assert client.exists("notascii") == 0 - - -@pytest.mark.redismod -def test_jsonsetexistentialmodifiersshouldsucceed(client): - obj = {"foo": "bar"} - assert client.json().set("obj", Path.rootPath(), obj) - - # Test that flags prevent updates when conditions are unmet - assert client.json().set("obj", Path("foo"), "baz", nx=True) is None - assert client.json().set("obj", Path("qaz"), "baz", xx=True) is None - - # Test that flags allow updates when conditions are met - assert client.json().set("obj", Path("foo"), "baz", xx=True) - assert client.json().set("obj", Path("qaz"), "baz", nx=True) - - # Test that flags are mutually exlusive - with pytest.raises(Exception): - client.json().set("obj", Path("foo"), "baz", nx=True, xx=True) - - -@pytest.mark.redismod -def test_mgetshouldsucceed(client): - client.json().set("1", Path.rootPath(), 1) - client.json().set("2", Path.rootPath(), 2) - r = client.json().mget(Path.rootPath(), "1", "2") - e = [1, 2] - assert e == r - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release -def test_clearShouldSucceed(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 1 == client.json().clear("arr", Path.rootPath()) - assert [] == client.json().get("arr") - - -@pytest.mark.redismod -def test_typeshouldsucceed(client): - client.json().set("1", Path.rootPath(), 1) - assert b"integer" == client.json().type("1") - - -@pytest.mark.redismod -def test_numincrbyshouldsucceed(client): - client.json().set("num", Path.rootPath(), 1) - assert 2 == client.json().numincrby("num", Path.rootPath(), 1) - assert 2.5 == client.json().numincrby("num", Path.rootPath(), 0.5) - assert 1.25 == client.json().numincrby("num", Path.rootPath(), -1.25) - - -@pytest.mark.redismod -def test_nummultbyshouldsucceed(client): - client.json().set("num", Path.rootPath(), 1) - assert 2 == client.json().nummultby("num", Path.rootPath(), 2) - assert 5 == client.json().nummultby("num", Path.rootPath(), 2.5) - assert 2.5 == client.json().nummultby("num", Path.rootPath(), 0.5) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release -def test_toggleShouldSucceed(client): - client.json().set("bool", 
Path.rootPath(), False) - assert client.json().toggle("bool", Path.rootPath()) - assert not client.json().toggle("bool", Path.rootPath()) - # check non-boolean value - client.json().set("num", Path.rootPath(), 1) - with pytest.raises(redis.exceptions.ResponseError): - client.json().toggle("num", Path.rootPath()) - - -@pytest.mark.redismod -def test_strappendshouldsucceed(client): - client.json().set("str", Path.rootPath(), "foo") - assert 6 == client.json().strappend("str", "bar", Path.rootPath()) - assert "foobar" == client.json().get("str", Path.rootPath()) - - -@pytest.mark.redismod -def test_debug(client): - client.json().set("str", Path.rootPath(), "foo") - assert 24 == client.json().debug("str", Path.rootPath()) - - -@pytest.mark.redismod -def test_strlenshouldsucceed(client): - client.json().set("str", Path.rootPath(), "foo") - assert 3 == client.json().strlen("str", Path.rootPath()) - client.json().strappend("str", "bar", Path.rootPath()) - assert 6 == client.json().strlen("str", Path.rootPath()) - - -@pytest.mark.redismod -def test_arrappendshouldsucceed(client): - client.json().set("arr", Path.rootPath(), [1]) - assert 2 == client.json().arrappend("arr", Path.rootPath(), 2) - assert 4 == client.json().arrappend("arr", Path.rootPath(), 3, 4) - assert 7 == client.json().arrappend("arr", Path.rootPath(), *[5, 6, 7]) - - -@pytest.mark.redismod -def testArrIndexShouldSucceed(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 1 == client.json().arrindex("arr", Path.rootPath(), 1) - assert -1 == client.json().arrindex("arr", Path.rootPath(), 1, 2) - - -@pytest.mark.redismod -def test_arrinsertshouldsucceed(client): - client.json().set("arr", Path.rootPath(), [0, 4]) - assert 5 - -client.json().arrinsert( - "arr", - Path.rootPath(), - 1, - *[ +@skip_if_cluster_mode() +class TestJson: + @pytest.mark.redismod + def test_json_setbinarykey(self, client): + d = {"hello": "world", b"some": "value"} + with pytest.raises(TypeError): + client.json().set("somekey", Path.rootPath(), d) + assert client.json().set("somekey", Path.rootPath(), d, + decode_keys=True) + + @pytest.mark.redismod + def test_json_setgetdeleteforget(self, client): + assert client.json().set("foo", Path.rootPath(), "bar") + assert client.json().get("foo") == "bar" + assert client.json().get("baz") is None + assert client.json().delete("foo") == 1 + assert client.json().forget("foo") == 0 # second delete + assert client.exists("foo") == 0 + + @pytest.mark.redismod + def test_justaget(self, client): + client.json().set("foo", Path.rootPath(), "bar") + assert client.json().get("foo") == "bar" + + @pytest.mark.redismod + def test_json_get_jset(self, client): + assert client.json().set("foo", Path.rootPath(), "bar") + assert "bar" == client.json().get("foo") + assert client.json().get("baz") is None + assert 1 == client.json().delete("foo") + assert client.exists("foo") == 0 + + @pytest.mark.redismod + def test_nonascii_setgetdelete(self, client): + assert client.json().set("notascii", Path.rootPath(), + "hyvää-élève") is True + assert "hyvää-élève" == client.json().get("notascii", no_escape=True) + assert 1 == client.json().delete("notascii") + assert client.exists("notascii") == 0 + + @pytest.mark.redismod + def test_jsonsetexistentialmodifiersshouldsucceed(self, client): + obj = {"foo": "bar"} + assert client.json().set("obj", Path.rootPath(), obj) + + # Test that flags prevent updates when conditions are unmet + assert client.json().set("obj", Path("foo"), "baz", nx=True) is None + assert 
client.json().set("obj", Path("qaz"), "baz", xx=True) is None + + # Test that flags allow updates when conditions are met + assert client.json().set("obj", Path("foo"), "baz", xx=True) + assert client.json().set("obj", Path("qaz"), "baz", nx=True) + + # Test that flags are mutually exlusive + with pytest.raises(Exception): + client.json().set("obj", Path("foo"), "baz", nx=True, xx=True) + + @pytest.mark.redismod + def test_mgetshouldsucceed(self, client): + client.json().set("1", Path.rootPath(), 1) + client.json().set("2", Path.rootPath(), 2) + r = client.json().mget(Path.rootPath(), "1", "2") + e = [1, 2] + assert e == r + + @pytest.mark.redismod + @skip_ifmodversion_lt("99.99.99", + "ReJSON") # todo: update after the release + def test_clearShouldSucceed(self, client): + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 1 == client.json().clear("arr", Path.rootPath()) + assert [] == client.json().get("arr") + + @pytest.mark.redismod + def test_typeshouldsucceed(self, client): + client.json().set("1", Path.rootPath(), 1) + assert b"integer" == client.json().type("1") + + @pytest.mark.redismod + def test_numincrbyshouldsucceed(self, client): + client.json().set("num", Path.rootPath(), 1) + assert 2 == client.json().numincrby("num", Path.rootPath(), 1) + assert 2.5 == client.json().numincrby("num", Path.rootPath(), 0.5) + assert 1.25 == client.json().numincrby("num", Path.rootPath(), -1.25) + + @pytest.mark.redismod + def test_nummultbyshouldsucceed(self, client): + client.json().set("num", Path.rootPath(), 1) + assert 2 == client.json().nummultby("num", Path.rootPath(), 2) + assert 5 == client.json().nummultby("num", Path.rootPath(), 2.5) + assert 2.5 == client.json().nummultby("num", Path.rootPath(), 0.5) + + @pytest.mark.redismod + @skip_ifmodversion_lt("99.99.99", + "ReJSON") # todo: update after the release + def test_toggleShouldSucceed(self, client): + client.json().set("bool", Path.rootPath(), False) + assert client.json().toggle("bool", Path.rootPath()) + assert not client.json().toggle("bool", Path.rootPath()) + # check non-boolean value + client.json().set("num", Path.rootPath(), 1) + with pytest.raises(redis.exceptions.ResponseError): + client.json().toggle("num", Path.rootPath()) + + @pytest.mark.redismod + def test_strappendshouldsucceed(self, client): + client.json().set("str", Path.rootPath(), "foo") + assert 6 == client.json().strappend("str", "bar", Path.rootPath()) + assert "foobar" == client.json().get("str", Path.rootPath()) + + @pytest.mark.redismod + def test_debug(self, client): + client.json().set("str", Path.rootPath(), "foo") + assert 24 == client.json().debug("str", Path.rootPath()) + + @pytest.mark.redismod + def test_strlenshouldsucceed(self, client): + client.json().set("str", Path.rootPath(), "foo") + assert 3 == client.json().strlen("str", Path.rootPath()) + client.json().strappend("str", "bar", Path.rootPath()) + assert 6 == client.json().strlen("str", Path.rootPath()) + + @pytest.mark.redismod + def test_arrappendshouldsucceed(self, client): + client.json().set("arr", Path.rootPath(), [1]) + assert 2 == client.json().arrappend("arr", Path.rootPath(), 2) + assert 4 == client.json().arrappend("arr", Path.rootPath(), 3, 4) + assert 7 == client.json().arrappend("arr", Path.rootPath(), *[5, 6, 7]) + + @pytest.mark.redismod + def testArrIndexShouldSucceed(self, client): + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 1 == client.json().arrindex("arr", Path.rootPath(), 1) + assert -1 == client.json().arrindex("arr", 
Path.rootPath(), 1, 2) + + @pytest.mark.redismod + def test_arrinsertshouldsucceed(self, client): + client.json().set("arr", Path.rootPath(), [0, 4]) + assert 5 == client.json().arrinsert( + "arr", + Path.rootPath(), 1, - 2, - 3, - ] - ) - assert [0, 1, 2, 3, 4] == client.json().get("arr") - - -@pytest.mark.redismod -def test_arrlenshouldsucceed(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 5 == client.json().arrlen("arr", Path.rootPath()) - - -@pytest.mark.redismod -def test_arrpopshouldsucceed(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 4 == client.json().arrpop("arr", Path.rootPath(), 4) - assert 3 == client.json().arrpop("arr", Path.rootPath(), -1) - assert 2 == client.json().arrpop("arr", Path.rootPath()) - assert 0 == client.json().arrpop("arr", Path.rootPath(), 0) - assert [1] == client.json().get("arr") - - -@pytest.mark.redismod -def test_arrtrimshouldsucceed(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 3 == client.json().arrtrim("arr", Path.rootPath(), 1, 3) - assert [1, 2, 3] == client.json().get("arr") - - -@pytest.mark.redismod -def test_respshouldsucceed(client): - obj = {"foo": "bar", "baz": 1, "qaz": True} - client.json().set("obj", Path.rootPath(), obj) - assert b"bar" == client.json().resp("obj", Path("foo")) - assert 1 == client.json().resp("obj", Path("baz")) - assert client.json().resp("obj", Path("qaz")) - - -@pytest.mark.redismod -def test_objkeysshouldsucceed(client): - obj = {"foo": "bar", "baz": "qaz"} - client.json().set("obj", Path.rootPath(), obj) - keys = client.json().objkeys("obj", Path.rootPath()) - keys.sort() - exp = list(obj.keys()) - exp.sort() - assert exp == keys - - -@pytest.mark.redismod -def test_objlenshouldsucceed(client): - obj = {"foo": "bar", "baz": "qaz"} - client.json().set("obj", Path.rootPath(), obj) - assert len(obj) == client.json().objlen("obj", Path.rootPath()) - - -# @pytest.mark.pipeline -# @pytest.mark.redismod -# def test_pipelineshouldsucceed(client): -# p = client.json().pipeline() -# p.set("foo", Path.rootPath(), "bar") -# p.get("foo") -# p.delete("foo") -# assert [True, "bar", 1] == p.execute() -# assert client.keys() == [] -# assert client.get("foo") is None + *[ + 1, + 2, + 3, + ] + ) + assert [0, 1, 2, 3, 4] == client.json().get("arr") + + @pytest.mark.redismod + def test_arrlenshouldsucceed(self, client): + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 5 == client.json().arrlen("arr", Path.rootPath()) + + @pytest.mark.redismod + def test_arrpopshouldsucceed(self, client): + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 4 == client.json().arrpop("arr", Path.rootPath(), 4) + assert 3 == client.json().arrpop("arr", Path.rootPath(), -1) + assert 2 == client.json().arrpop("arr", Path.rootPath()) + assert 0 == client.json().arrpop("arr", Path.rootPath(), 0) + assert [1] == client.json().get("arr") + + @pytest.mark.redismod + def test_arrtrimshouldsucceed(self, client): + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 3 == client.json().arrtrim("arr", Path.rootPath(), 1, 3) + assert [1, 2, 3] == client.json().get("arr") + + @pytest.mark.redismod + def test_respshouldsucceed(self, client): + obj = {"foo": "bar", "baz": 1, "qaz": True} + client.json().set("obj", Path.rootPath(), obj) + assert b"bar" == client.json().resp("obj", Path("foo")) + assert 1 == client.json().resp("obj", Path("baz")) + assert client.json().resp("obj", Path("qaz")) + + @pytest.mark.redismod + 
def test_objkeysshouldsucceed(self, client): + obj = {"foo": "bar", "baz": "qaz"} + client.json().set("obj", Path.rootPath(), obj) + keys = client.json().objkeys("obj", Path.rootPath()) + keys.sort() + exp = list(obj.keys()) + exp.sort() + assert exp == keys + + @pytest.mark.redismod + def test_objlenshouldsucceed(self, client): + obj = {"foo": "bar", "baz": "qaz"} + client.json().set("obj", Path.rootPath(), obj) + assert len(obj) == client.json().objlen("obj", Path.rootPath()) + + # @pytest.mark.pipeline + # @pytest.mark.redismod + # def test_pipelineshouldsucceed(client): + # p = client.json().pipeline() + # p.set("foo", Path.rootPath(), "bar") + # p.get("foo") + # p.delete("foo") + # assert [True, "bar", 1] == p.execute() + # assert client.keys() == [] + # assert client.get("foo") is None diff --git a/tests/test_lock.py b/tests/test_lock.py index fa76385221..ab62dfc820 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -4,9 +4,10 @@ from redis.exceptions import LockError, LockNotOwnedError from redis.client import Redis from redis.lock import Lock -from .conftest import _get_client +from .conftest import _get_client, skip_if_cluster_mode +@skip_if_cluster_mode() class TestLock: @pytest.fixture() def r_decoded(self, request): @@ -220,6 +221,7 @@ def test_reacquiring_lock_no_longer_owned_raises_error(self, r): lock.reacquire() +@skip_if_cluster_mode() class TestLockClassSelection: def test_lock_class_argument(self, r): class MyLock: diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 1013202f22..5d065c9206 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -1,6 +1,7 @@ -from .conftest import wait_for_command +from .conftest import wait_for_command, skip_if_cluster_mode +@skip_if_cluster_mode() class TestMonitor: def test_wait_command_not_found(self, r): "Make sure the wait_for_command func works when command is not found" diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 2d27c4e8bb..a298af39c7 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -30,12 +30,12 @@ def r(self, request): request=request, single_connection_client=False) - def test_close_connection_in_child(self, master_host): + def test_close_connection_in_child(self, master_host, master_port): """ A connection owned by a parent and closed by a child doesn't destroy the file descriptors so a parent can still use it. """ - conn = Connection(host=master_host) + conn = Connection(host=master_host, port=master_port) conn.send_command('ping') assert conn.read_response() == b'PONG' @@ -56,12 +56,12 @@ def target(conn): conn.send_command('ping') assert conn.read_response() == b'PONG' - def test_close_connection_in_parent(self, master_host): + def test_close_connection_in_parent(self, master_host, master_port): """ A connection owned by a parent is unusable by a child if the parent (the owning process) closes the connection. """ - conn = Connection(host=master_host) + conn = Connection(host=master_host, port=master_port) conn.send_command('ping') assert conn.read_response() == b'PONG' @@ -84,12 +84,13 @@ def target(conn, ev): assert proc.exitcode == 0 @pytest.mark.parametrize('max_connections', [1, 2, None]) - def test_pool(self, max_connections, master_host): + def test_pool(self, max_connections, master_host, master_port): """ A child will create its own connections when using a pool created by a parent. 
""" - pool = ConnectionPool.from_url('redis://{}'.format(master_host), + pool = ConnectionPool.from_url('redis://{}:{}'.format(master_host, + master_port), max_connections=max_connections) conn = pool.get_connection('ping') @@ -119,12 +120,14 @@ def target(pool): assert conn.read_response() == b'PONG' @pytest.mark.parametrize('max_connections', [1, 2, None]) - def test_close_pool_in_main(self, max_connections, master_host): + def test_close_pool_in_main(self, max_connections, master_host, + master_port): """ A child process that uses the same pool as its parent isn't affected when the parent disconnects all connections within the pool. """ - pool = ConnectionPool.from_url('redis://{}'.format(master_host), + pool = ConnectionPool.from_url('redis://{}:{}'.format(master_host, + master_port), max_connections=max_connections) conn = pool.get_connection('ping') diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 08bd40bacd..8fadf46bf1 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,7 +1,8 @@ import pytest import redis -from .conftest import wait_for_command, skip_if_server_version_lt +from .conftest import wait_for_command, skip_if_server_version_lt, \ + skip_if_cluster_mode class TestPipeline: @@ -59,6 +60,7 @@ def test_pipeline_no_transaction(self, r): assert r['b'] == b'b1' assert r['c'] == b'c1' + @skip_if_cluster_mode() def test_pipeline_no_transaction_watch(self, r): r['a'] = 0 @@ -70,6 +72,7 @@ def test_pipeline_no_transaction_watch(self, r): pipe.set('a', int(a) + 1) assert pipe.execute() == [True] + @skip_if_cluster_mode() def test_pipeline_no_transaction_watch_failure(self, r): r['a'] = 0 @@ -129,6 +132,7 @@ def test_exec_error_raised(self, r): assert pipe.set('z', 'zzz').execute() == [True] assert r['z'] == b'zzz' + @skip_if_cluster_mode() def test_transaction_with_empty_error_command(self, r): """ Commands with custom EMPTY_ERROR functionality return their default @@ -143,6 +147,7 @@ def test_transaction_with_empty_error_command(self, r): assert result[1] == [] assert result[2] + @skip_if_cluster_mode() def test_pipeline_with_empty_error_command(self, r): """ Commands with custom EMPTY_ERROR functionality return their default @@ -171,6 +176,7 @@ def test_parse_error_raised(self, r): assert pipe.set('z', 'zzz').execute() == [True] assert r['z'] == b'zzz' + @skip_if_cluster_mode() def test_parse_error_raised_transaction(self, r): with r.pipeline() as pipe: pipe.multi() @@ -186,6 +192,7 @@ def test_parse_error_raised_transaction(self, r): assert pipe.set('z', 'zzz').execute() == [True] assert r['z'] == b'zzz' + @skip_if_cluster_mode() def test_watch_succeed(self, r): r['a'] = 1 r['b'] = 2 @@ -203,6 +210,7 @@ def test_watch_succeed(self, r): assert pipe.execute() == [True] assert not pipe.watching + @skip_if_cluster_mode() def test_watch_failure(self, r): r['a'] = 1 r['b'] = 2 @@ -217,6 +225,7 @@ def test_watch_failure(self, r): assert not pipe.watching + @skip_if_cluster_mode() def test_watch_failure_in_empty_transaction(self, r): r['a'] = 1 r['b'] = 2 @@ -230,6 +239,7 @@ def test_watch_failure_in_empty_transaction(self, r): assert not pipe.watching + @skip_if_cluster_mode() def test_unwatch(self, r): r['a'] = 1 r['b'] = 2 @@ -242,6 +252,7 @@ def test_unwatch(self, r): pipe.get('a') assert pipe.execute() == [b'1'] + @skip_if_cluster_mode() def test_watch_exec_no_unwatch(self, r): r['a'] = 1 r['b'] = 2 @@ -262,6 +273,7 @@ def test_watch_exec_no_unwatch(self, r): unwatch_command = wait_for_command(r, m, 'UNWATCH') assert unwatch_command is None, 
"should not send UNWATCH" + @skip_if_cluster_mode() def test_watch_reset_unwatch(self, r): r['a'] = 1 @@ -276,6 +288,7 @@ def test_watch_reset_unwatch(self, r): assert unwatch_command is not None assert unwatch_command['command'] == 'UNWATCH' + @skip_if_cluster_mode() def test_transaction_callable(self, r): r['a'] = 1 r['b'] = 2 @@ -300,6 +313,7 @@ def my_transaction(pipe): assert result == [True] assert r['c'] == b'4' + @skip_if_cluster_mode() def test_transaction_callable_returns_value_from_callable(self, r): def callback(pipe): # No need to do anything here since we only want the return value @@ -354,6 +368,7 @@ def test_pipeline_with_bitfield(self, r): assert pipe == pipe2 assert response == [True, [0, 0, 15, 15, 14], b'1'] + @skip_if_cluster_mode() @skip_if_server_version_lt('2.0.0') def test_pipeline_discard(self, r): diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 6a4f0aafa4..ebb96de58b 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -7,7 +7,8 @@ import redis from redis.exceptions import ConnectionError -from .conftest import _get_client, skip_if_server_version_lt +from .conftest import _get_client, skip_if_cluster_mode, \ + skip_if_server_version_lt def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): @@ -119,6 +120,7 @@ def test_resubscribe_to_channels_on_reconnection(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'channel') self._test_resubscribe_on_reconnection(**kwargs) + @skip_if_cluster_mode() def test_resubscribe_to_patterns_on_reconnection(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') self._test_resubscribe_on_reconnection(**kwargs) @@ -173,6 +175,7 @@ def test_subscribe_property_with_channels(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'channel') self._test_subscribed_property(**kwargs) + @skip_if_cluster_mode() def test_subscribe_property_with_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') self._test_subscribed_property(**kwargs) @@ -216,6 +219,7 @@ def test_sub_unsub_resub_channels(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'channel') self._test_sub_unsub_resub(**kwargs) + @skip_if_cluster_mode() def test_sub_unsub_resub_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') self._test_sub_unsub_resub(**kwargs) @@ -303,6 +307,7 @@ def test_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message('message', 'foo', 'test message') + @skip_if_cluster_mode() def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) p.psubscribe(**{'f*': self.message_handler}) @@ -322,6 +327,9 @@ def test_unicode_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message('message', channel, 'test message') + @skip_if_cluster_mode() + # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html + # #known-limitations-with-pubsub def test_unicode_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) pattern = 'uni' + chr(4456) + '*' @@ -397,6 +405,7 @@ def test_channel_publish(self, r): self.channel, self.data) + @skip_if_cluster_mode() def test_pattern_publish(self, r): p = r.pubsub() p.psubscribe(self.pattern) @@ -493,7 +502,7 @@ def test_pubsub_numsub(self, r): assert wait_for_message(p3)['type'] == 'subscribe' channels = [(b'foo', 1), (b'bar', 2), (b'baz', 3)] - assert channels == r.pubsub_numsub('foo', 'bar', 'baz') + assert r.pubsub_numsub('foo', 'bar', 'baz') == channels 
@skip_if_server_version_lt('2.8.0') def test_pubsub_numpat(self, r): @@ -525,6 +534,7 @@ def test_send_pubsub_ping_message(self, r): pattern=None) +@skip_if_cluster_mode() class TestPubSubConnectionKilled: @skip_if_server_version_lt('3.0.0') diff --git a/tests/test_scripting.py b/tests/test_scripting.py index c3c2094d4a..46a684e36d 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -1,7 +1,9 @@ import pytest from redis import exceptions - +from .conftest import ( + skip_if_cluster_mode, +) multiply_script = """ local value = redis.call('GET', KEYS[1]) @@ -20,6 +22,7 @@ """ +@skip_if_cluster_mode() class TestScripting: @pytest.fixture(autouse=True) def reset_scripts(self, r): diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index 54cf262c43..7f66603085 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -5,6 +5,7 @@ from redis import exceptions from redis.sentinel import (Sentinel, SentinelConnectionPool, MasterNotFoundError, SlaveNotFoundError) +from .conftest import skip_if_cluster_mode import redis.sentinel @@ -13,6 +14,7 @@ def master_ip(master_host): yield socket.gethostbyname(master_host) +@skip_if_cluster_mode() class SentinelTestClient: def __init__(self, cluster, id): self.cluster = cluster @@ -36,6 +38,24 @@ def execute_command(self, *args, **kwargs): return bool_ok +@pytest.fixture() +def cluster(request, master_ip): + def teardown(): + redis.sentinel.Redis = saved_Redis + + cluster = SentinelTestCluster(ip=master_ip) + saved_Redis = redis.sentinel.Redis + redis.sentinel.Redis = cluster.client + request.addfinalizer(teardown) + return cluster + + +@pytest.fixture() +def sentinel(request, cluster): + return Sentinel([('foo', 26379), ('bar', 26379)]) + + +@skip_if_cluster_mode() class SentinelTestCluster: def __init__(self, servisentinel_ce_name='mymaster', ip='127.0.0.1', port=6379): @@ -64,156 +84,129 @@ def timeout_if_down(self, node): def client(self, host, port, **kwargs): return SentinelTestClient(self, (host, port)) - -@pytest.fixture() -def cluster(request, master_ip): - def teardown(): - redis.sentinel.Redis = saved_Redis - cluster = SentinelTestCluster(ip=master_ip) - saved_Redis = redis.sentinel.Redis - redis.sentinel.Redis = cluster.client - request.addfinalizer(teardown) - return cluster - - -@pytest.fixture() -def sentinel(request, cluster): - return Sentinel([('foo', 26379), ('bar', 26379)]) - - -def test_discover_master(sentinel, master_ip): - address = sentinel.discover_master('mymaster') - assert address == (master_ip, 6379) - - -def test_discover_master_error(sentinel): - with pytest.raises(MasterNotFoundError): - sentinel.discover_master('xxx') - - -def test_discover_master_sentinel_down(cluster, sentinel, master_ip): - # Put first sentinel 'foo' down - cluster.nodes_down.add(('foo', 26379)) - address = sentinel.discover_master('mymaster') - assert address == (master_ip, 6379) - # 'bar' is now first sentinel - assert sentinel.sentinels[0].id == ('bar', 26379) - - -def test_discover_master_sentinel_timeout(cluster, sentinel, master_ip): - # Put first sentinel 'foo' down - cluster.nodes_timeout.add(('foo', 26379)) - address = sentinel.discover_master('mymaster') - assert address == (master_ip, 6379) - # 'bar' is now first sentinel - assert sentinel.sentinels[0].id == ('bar', 26379) - - -def test_master_min_other_sentinels(cluster, master_ip): - sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1) - # min_other_sentinels - with pytest.raises(MasterNotFoundError): - sentinel.discover_master('mymaster') - 
cluster.master['num-other-sentinels'] = 2 - address = sentinel.discover_master('mymaster') - assert address == (master_ip, 6379) - - -def test_master_odown(cluster, sentinel): - cluster.master['is_odown'] = True - with pytest.raises(MasterNotFoundError): - sentinel.discover_master('mymaster') - - -def test_master_sdown(cluster, sentinel): - cluster.master['is_sdown'] = True - with pytest.raises(MasterNotFoundError): - sentinel.discover_master('mymaster') - - -def test_discover_slaves(cluster, sentinel): - assert sentinel.discover_slaves('mymaster') == [] - - cluster.slaves = [ - {'ip': 'slave0', 'port': 1234, 'is_odown': False, 'is_sdown': False}, - {'ip': 'slave1', 'port': 1234, 'is_odown': False, 'is_sdown': False}, - ] - assert sentinel.discover_slaves('mymaster') == [ - ('slave0', 1234), ('slave1', 1234)] - - # slave0 -> ODOWN - cluster.slaves[0]['is_odown'] = True - assert sentinel.discover_slaves('mymaster') == [ - ('slave1', 1234)] - - # slave1 -> SDOWN - cluster.slaves[1]['is_sdown'] = True - assert sentinel.discover_slaves('mymaster') == [] - - cluster.slaves[0]['is_odown'] = False - cluster.slaves[1]['is_sdown'] = False - - # node0 -> DOWN - cluster.nodes_down.add(('foo', 26379)) - assert sentinel.discover_slaves('mymaster') == [ - ('slave0', 1234), ('slave1', 1234)] - cluster.nodes_down.clear() - - # node0 -> TIMEOUT - cluster.nodes_timeout.add(('foo', 26379)) - assert sentinel.discover_slaves('mymaster') == [ - ('slave0', 1234), ('slave1', 1234)] - - -def test_master_for(cluster, sentinel, master_ip): - master = sentinel.master_for('mymaster', db=9) - assert master.ping() - assert master.connection_pool.master_address == (master_ip, 6379) - - # Use internal connection check - master = sentinel.master_for('mymaster', db=9, check_connection=True) - assert master.ping() - - -def test_slave_for(cluster, sentinel): - cluster.slaves = [ - {'ip': '127.0.0.1', 'port': 6379, - 'is_odown': False, 'is_sdown': False}, - ] - slave = sentinel.slave_for('mymaster', db=9) - assert slave.ping() - - -def test_slave_for_slave_not_found_error(cluster, sentinel): - cluster.master['is_odown'] = True - slave = sentinel.slave_for('mymaster', db=9) - with pytest.raises(SlaveNotFoundError): - slave.ping() - - -def test_slave_round_robin(cluster, sentinel, master_ip): - cluster.slaves = [ - {'ip': 'slave0', 'port': 6379, 'is_odown': False, 'is_sdown': False}, - {'ip': 'slave1', 'port': 6379, 'is_odown': False, 'is_sdown': False}, - ] - pool = SentinelConnectionPool('mymaster', sentinel) - rotator = pool.rotate_slaves() - assert next(rotator) in (('slave0', 6379), ('slave1', 6379)) - assert next(rotator) in (('slave0', 6379), ('slave1', 6379)) - # Fallback to master - assert next(rotator) == (master_ip, 6379) - with pytest.raises(SlaveNotFoundError): - next(rotator) - - -def test_ckquorum(cluster, sentinel): - assert sentinel.sentinel_ckquorum("mymaster") - - -def test_flushconfig(cluster, sentinel): - assert sentinel.sentinel_flushconfig() - - -def test_reset(cluster, sentinel): - cluster.master['is_odown'] = True - assert sentinel.sentinel_reset('mymaster') + def test_discover_master(self, sentinel, master_ip): + address = sentinel.discover_master('mymaster') + assert address == (master_ip, 6379) + + def test_discover_master_error(self, sentinel): + with pytest.raises(MasterNotFoundError): + sentinel.discover_master('xxx') + + def test_discover_master_sentinel_down(self, cluster, sentinel, master_ip): + # Put first sentinel 'foo' down + cluster.nodes_down.add(('foo', 26379)) + address = 
sentinel.discover_master('mymaster') + assert address == (master_ip, 6379) + # 'bar' is now first sentinel + assert sentinel.sentinels[0].id == ('bar', 26379) + + def test_discover_master_sentinel_timeout(self, cluster, sentinel, master_ip): + # Put first sentinel 'foo' down + cluster.nodes_timeout.add(('foo', 26379)) + address = sentinel.discover_master('mymaster') + assert address == (master_ip, 6379) + # 'bar' is now first sentinel + assert sentinel.sentinels[0].id == ('bar', 26379) + + def test_master_min_other_sentinels(self, cluster, master_ip): + sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1) + # min_other_sentinels + with pytest.raises(MasterNotFoundError): + sentinel.discover_master('mymaster') + cluster.master['num-other-sentinels'] = 2 + address = sentinel.discover_master('mymaster') + assert address == (master_ip, 6379) + + def test_master_odown(self, cluster, sentinel): + cluster.master['is_odown'] = True + with pytest.raises(MasterNotFoundError): + sentinel.discover_master('mymaster') + + def test_master_sdown(self, cluster, sentinel): + cluster.master['is_sdown'] = True + with pytest.raises(MasterNotFoundError): + sentinel.discover_master('mymaster') + + def test_discover_slaves(self, cluster, sentinel): + assert sentinel.discover_slaves('mymaster') == [] + + cluster.slaves = [ + {'ip': 'slave0', 'port': 1234, 'is_odown': False, + 'is_sdown': False}, + {'ip': 'slave1', 'port': 1234, 'is_odown': False, + 'is_sdown': False}, + ] + assert sentinel.discover_slaves('mymaster') == [ + ('slave0', 1234), ('slave1', 1234)] + + # slave0 -> ODOWN + cluster.slaves[0]['is_odown'] = True + assert sentinel.discover_slaves('mymaster') == [ + ('slave1', 1234)] + + # slave1 -> SDOWN + cluster.slaves[1]['is_sdown'] = True + assert sentinel.discover_slaves('mymaster') == [] + + cluster.slaves[0]['is_odown'] = False + cluster.slaves[1]['is_sdown'] = False + + # node0 -> DOWN + cluster.nodes_down.add(('foo', 26379)) + assert sentinel.discover_slaves('mymaster') == [ + ('slave0', 1234), ('slave1', 1234)] + cluster.nodes_down.clear() + + # node0 -> TIMEOUT + cluster.nodes_timeout.add(('foo', 26379)) + assert sentinel.discover_slaves('mymaster') == [ + ('slave0', 1234), ('slave1', 1234)] + + def test_master_for(self, cluster, sentinel, master_ip): + master = sentinel.master_for('mymaster', db=9) + assert master.ping() + assert master.connection_pool.master_address == (master_ip, 6379) + + # Use internal connection check + master = sentinel.master_for('mymaster', db=9, check_connection=True) + assert master.ping() + + def test_slave_for(self, cluster, sentinel): + cluster.slaves = [ + {'ip': '127.0.0.1', 'port': 6379, + 'is_odown': False, 'is_sdown': False}, + ] + slave = sentinel.slave_for('mymaster', db=9) + assert slave.ping() + + def test_slave_for_slave_not_found_error(self, cluster, sentinel): + cluster.master['is_odown'] = True + slave = sentinel.slave_for('mymaster', db=9) + with pytest.raises(SlaveNotFoundError): + slave.ping() + + def test_slave_round_robin(self, cluster, sentinel, master_ip): + cluster.slaves = [ + {'ip': 'slave0', 'port': 6379, 'is_odown': False, + 'is_sdown': False}, + {'ip': 'slave1', 'port': 6379, 'is_odown': False, + 'is_sdown': False}, + ] + pool = SentinelConnectionPool('mymaster', sentinel) + rotator = pool.rotate_slaves() + assert next(rotator) in (('slave0', 6379), ('slave1', 6379)) + assert next(rotator) in (('slave0', 6379), ('slave1', 6379)) + # Fallback to master + assert next(rotator) == (master_ip, 6379) + with pytest.raises(SlaveNotFoundError): + next(rotator) + + def test_ckquorum(self, cluster, 
sentinel): + assert sentinel.sentinel_ckquorum("mymaster") + + def test_flushconfig(self, cluster, sentinel): + assert sentinel.sentinel_flushconfig() + + def test_reset(self, cluster, sentinel): + cluster.master['is_odown'] = True + assert sentinel.sentinel_reset('mymaster') diff --git a/tox.ini b/tox.ini index 67b7e7575d..e46a0c47f5 100644 --- a/tox.ini +++ b/tox.ini @@ -74,6 +74,21 @@ image = redisfab/lots-of-pythons volumes = bind:rw:{toxinidir}:/data +[docker:redis_cluster] +name = redis_cluster +image = barshaul/redis-py:6.2.6-cluster +healthcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',16379)) else False" +ports = + 16379:16379/tcp + 16380:16380/tcp + 16381:16381/tcp + 16382:16382/tcp + 16383:16383/tcp + 16384:16384/tcp +volumes = + bind:rw:{toxinidir}/docker/cluster/redis.conf:/redis.conf + + [testenv] deps = -r {toxinidir}/dev_requirements.txt docker = @@ -82,6 +97,7 @@ docker = sentinel_1 sentinel_2 sentinel_3 + redis_cluster redismod extras = hiredis: hiredis @@ -98,6 +114,7 @@ docker = sentinel_1 sentinel_2 sentinel_3 + redis_cluster redismod lots-of-pythons commands = /usr/bin/echo
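The `healthcheck_cmd` one-liner in the `[docker:redis_cluster]` section above is dense; expanded into plain Python, it is equivalent to the following sketch (16379 is the first cluster port mapped above):

``` python
# Expanded form of the tox.ini healthcheck one-liner: the container counts
# as healthy once 127.0.0.1:16379 accepts a TCP connection.
import socket

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
print(True) if sock.connect_ex(('127.0.0.1', 16379)) == 0 else False
# As in the original one-liner, nothing is printed while the port is still
# unreachable: the else-branch is the bare expression False, not a print.
```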
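`skip_if_cluster_mode` and the `master_port` fixture are imported from `tests/conftest.py` throughout the test changes above, but conftest.py itself is not part of this excerpt. A minimal sketch of what these helpers might look like; the `REDIS_CLUSTER` environment variable and the port values are illustrative assumptions, not the PR's actual implementation:

``` python
# Hypothetical sketch of the conftest helpers referenced in the tests above.
import os

import pytest


def skip_if_cluster_mode():
    # Skip the decorated test (or test class) when the suite targets a
    # Redis Cluster. REDIS_CLUSTER is an assumed switch; the real conftest
    # may detect cluster mode differently (e.g. from the server's INFO).
    return pytest.mark.skipif(
        os.environ.get('REDIS_CLUSTER') == '1',
        reason='Not supported in cluster mode')


@pytest.fixture()
def master_port():
    # Companion to the existing master_host fixture. 16379 matches the first
    # port mapped for the redis_cluster docker env in tox.ini; 6379 is the
    # standalone default used elsewhere in the suite.
    if os.environ.get('REDIS_CLUSTER') == '1':
        return 16379
    return 6379
```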