From f789eb84ba39800b129abee2353efaed09d86479 Mon Sep 17 00:00:00 2001 From: Anas Date: Sun, 21 Nov 2021 19:32:32 +0200 Subject: [PATCH 1/8] Added black and isort to tox and dev_requirements.txt --- benchmarks/socket_read_size.py | 27 +++--- dev_requirements.txt | 2 + redis/backoff.py | 2 +- redis/commands/json/__init__.py | 10 +-- redis/commands/json/decoders.py | 7 +- redis/commands/redismodules.py | 16 ++-- redis/commands/search/__init__.py | 4 +- redis/commands/search/commands.py | 18 ++-- redis/commands/search/field.py | 6 +- redis/commands/search/query.py | 9 +- redis/commands/search/result.py | 2 +- redis/commands/search/suggestion.py | 6 +- redis/commands/sentinel.py | 30 +++---- redis/commands/timeseries/__init__.py | 11 +-- redis/commands/timeseries/commands.py | 29 ++----- redis/commands/timeseries/info.py | 2 +- redis/commands/timeseries/utils.py | 11 +-- redis/exceptions.py | 1 + redis/lock.py | 55 ++++++------ redis/retry.py | 5 +- redis/utils.py | 7 +- setup.cfg | 7 ++ tests/test_encoding.py | 81 +++++++++-------- tests/test_lock.py | 117 +++++++++++++------------ tests/test_monitor.py | 46 +++++----- tests/test_retry.py | 7 +- tests/test_scripting.py | 58 ++++++------- tests/test_sentinel.py | 120 +++++++++++++------------- tests/test_timeseries.py | 115 +++++------------------- tox.ini | 5 +- 30 files changed, 357 insertions(+), 459 deletions(-) create mode 100644 setup.cfg diff --git a/benchmarks/socket_read_size.py b/benchmarks/socket_read_size.py index 72a1b0a7e3..3427956ced 100644 --- a/benchmarks/socket_read_size.py +++ b/benchmarks/socket_read_size.py @@ -1,34 +1,27 @@ -from redis.connection import PythonParser, HiredisParser from base import Benchmark +from redis.connection import HiredisParser, PythonParser + class SocketReadBenchmark(Benchmark): ARGUMENTS = ( + {"name": "parser", "values": [PythonParser, HiredisParser]}, { - 'name': 'parser', - 'values': [PythonParser, HiredisParser] - }, - { - 'name': 'value_size', - 'values': [10, 100, 1000, 10000, 100000, 1000000, 10000000, - 100000000] + "name": "value_size", + "values": [10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000], }, - { - 'name': 'read_size', - 'values': [4096, 8192, 16384, 32768, 65536, 131072] - } + {"name": "read_size", "values": [4096, 8192, 16384, 32768, 65536, 131072]}, ) def setup(self, value_size, read_size, parser): - r = self.get_client(parser_class=parser, - socket_read_size=read_size) - r.set('benchmark', 'a' * value_size) + r = self.get_client(parser_class=parser, socket_read_size=read_size) + r.set("benchmark", "a" * value_size) def run(self, value_size, read_size, parser): r = self.get_client() - r.get('benchmark') + r.get("benchmark") -if __name__ == '__main__': +if __name__ == "__main__": SocketReadBenchmark().run_benchmark() diff --git a/dev_requirements.txt b/dev_requirements.txt index 56ac08efe2..e1c131e2c9 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,5 +1,7 @@ +black==21.11b1 flake8>=3.9.2 flynt~=0.69.0 +isort==5.10.1 pytest==6.2.5 pytest-timeout==2.0.1 tox==3.24.4 diff --git a/redis/backoff.py b/redis/backoff.py index 9162778cc0..cbb4e73779 100644 --- a/redis/backoff.py +++ b/redis/backoff.py @@ -1,5 +1,5 @@ -from abc import ABC, abstractmethod import random +from abc import ABC, abstractmethod class AbstractBackoff(ABC): diff --git a/redis/commands/json/__init__.py b/redis/commands/json/__init__.py index d634dbd3f4..12c0648722 100644 --- a/redis/commands/json/__init__.py +++ b/redis/commands/json/__init__.py @@ -1,12 +1,10 @@ -from json import JSONDecoder, JSONEncoder, JSONDecodeError +from json import JSONDecodeError, JSONDecoder, JSONEncoder + +import redis -from .decoders import ( - decode_list, - bulk_of_jsons, -) from ..helpers import nativestr from .commands import JSONCommands -import redis +from .decoders import bulk_of_jsons, decode_list class JSON(JSONCommands): diff --git a/redis/commands/json/decoders.py b/redis/commands/json/decoders.py index b19395c73b..b93847112b 100644 --- a/redis/commands/json/decoders.py +++ b/redis/commands/json/decoders.py @@ -1,6 +1,7 @@ -from ..helpers import nativestr -import re import copy +import re + +from ..helpers import nativestr def bulk_of_jsons(d): @@ -33,7 +34,7 @@ def unstring(obj): One can't simply call int/float in a try/catch because there is a semantic difference between (for example) 15.0 and 15. """ - floatreg = '^\\d+.\\d+$' + floatreg = "^\\d+.\\d+$" match = re.findall(floatreg, obj) if match != []: return float(match[0]) diff --git a/redis/commands/redismodules.py b/redis/commands/redismodules.py index 5f629fb5ea..2420d7b6fb 100644 --- a/redis/commands/redismodules.py +++ b/redis/commands/redismodules.py @@ -1,4 +1,4 @@ -from json import JSONEncoder, JSONDecoder +from json import JSONDecoder, JSONEncoder class RedisModuleCommands: @@ -7,21 +7,18 @@ class RedisModuleCommands: """ def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()): - """Access the json namespace, providing support for redis json. - """ + """Access the json namespace, providing support for redis json.""" from .json import JSON - jj = JSON( - client=self, - encoder=encoder, - decoder=decoder) + + jj = JSON(client=self, encoder=encoder, decoder=decoder) return jj def ft(self, index_name="idx"): - """Access the search namespace, providing support for redis search. - """ + """Access the search namespace, providing support for redis search.""" from .search import Search + s = Search(client=self, index_name=index_name) return s @@ -31,5 +28,6 @@ def ts(self): """ from .timeseries import TimeSeries + s = TimeSeries(client=self) return s diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index a30cebe1b7..94bc037c3d 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -35,7 +35,7 @@ def add_document( replace=False, partial=False, no_create=False, - **fields + **fields, ): """ Add a document to the batch query @@ -49,7 +49,7 @@ def add_document( replace=replace, partial=partial, no_create=no_create, - **fields + **fields, ) self.current_chunk += 1 self.total += 1 diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 553bc39839..cfc82cd127 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -1,11 +1,11 @@ import itertools import time -from .document import Document -from .result import Result -from .query import Query from ._util import to_string from .aggregation import AggregateRequest, AggregateResult, Cursor +from .document import Document +from .query import Query +from .result import Result from .suggestion import SuggestionParser from ..helpers import parse_to_dict @@ -148,7 +148,7 @@ def _add_document( partial=False, language=None, no_create=False, - **fields + **fields, ): """ Internal add_document used for both batch and single doc indexing @@ -211,7 +211,7 @@ def add_document( partial=False, language=None, no_create=False, - **fields + **fields, ): """ Add a single document to the index. @@ -253,7 +253,7 @@ def add_document( partial=partial, language=language, no_create=no_create, - **fields + **fields, ) def add_document_hash( @@ -535,8 +535,7 @@ def spellcheck(self, query, distance=None, include=None, exclude=None): # ] # } corrections[_correction[1]] = [ - {"score": _item[0], "suggestion": _item[1]} - for _item in _correction[2] + {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] ] return corrections @@ -704,8 +703,7 @@ def sugdel(self, key, string): return self.execute_command(SUGDEL_COMMAND, key, string) def sugget( - self, key, prefix, fuzzy=False, num=10, with_scores=False, - with_payloads=False + self, key, prefix, fuzzy=False, num=10, with_scores=False, with_payloads=False ): """ Get a list of suggestions from the AutoCompleter, for a given prefix. diff --git a/redis/commands/search/field.py b/redis/commands/search/field.py index 076c872b62..69e39083b0 100644 --- a/redis/commands/search/field.py +++ b/redis/commands/search/field.py @@ -9,8 +9,7 @@ class Field: NOINDEX = "NOINDEX" AS = "AS" - def __init__(self, name, args=[], sortable=False, - no_index=False, as_name=None): + def __init__(self, name, args=[], sortable=False, no_index=False, as_name=None): self.name = name self.args = args self.args_suffix = list() @@ -47,8 +46,7 @@ class TextField(Field): def __init__( self, name, weight=1.0, no_stem=False, phonetic_matcher=None, **kwargs ): - Field.__init__(self, name, - args=[Field.TEXT, Field.WEIGHT, weight], **kwargs) + Field.__init__(self, name, args=[Field.TEXT, Field.WEIGHT, weight], **kwargs) if no_stem: Field.append_arg(self, self.NOSTEM) diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index 5534f7b88e..2bb8347dbc 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -62,11 +62,9 @@ def return_field(self, field, as_field=None): def _mk_field_list(self, fields): if not fields: return [] - return \ - [fields] if isinstance(fields, str) else list(fields) + return [fields] if isinstance(fields, str) else list(fields) - def summarize(self, fields=None, context_len=None, - num_frags=None, sep=None): + def summarize(self, fields=None, context_len=None, num_frags=None, sep=None): """ Return an abridged format of the field, containing only the segments of the field which contain the matching term(s). @@ -300,8 +298,7 @@ class NumericFilter(Filter): INF = "+inf" NEG_INF = "-inf" - def __init__(self, field, minval, maxval, minExclusive=False, - maxExclusive=False): + def __init__(self, field, minval, maxval, minExclusive=False, maxExclusive=False): args = [ minval if not minExclusive else f"({minval}", maxval if not maxExclusive else f"({maxval}", diff --git a/redis/commands/search/result.py b/redis/commands/search/result.py index 57ba53d5ca..5f4aca6411 100644 --- a/redis/commands/search/result.py +++ b/redis/commands/search/result.py @@ -1,5 +1,5 @@ -from .document import Document from ._util import to_string +from .document import Document class Result: diff --git a/redis/commands/search/suggestion.py b/redis/commands/search/suggestion.py index 6d295a652f..5d1eba64b8 100644 --- a/redis/commands/search/suggestion.py +++ b/redis/commands/search/suggestion.py @@ -46,8 +46,6 @@ def __init__(self, with_scores, with_payloads, ret): def __iter__(self): for i in range(0, len(self._sugs), self.sugsize): ss = self._sugs[i] - score = float(self._sugs[i + self._scoreidx]) \ - if self.with_scores else 1.0 - payload = self._sugs[i + self._payloadidx] \ - if self.with_payloads else None + score = float(self._sugs[i + self._scoreidx]) if self.with_scores else 1.0 + payload = self._sugs[i + self._payloadidx] if self.with_payloads else None yield Suggestion(ss, score, payload) diff --git a/redis/commands/sentinel.py b/redis/commands/sentinel.py index 1f02984bed..a9b06c2f6e 100644 --- a/redis/commands/sentinel.py +++ b/redis/commands/sentinel.py @@ -9,41 +9,39 @@ class SentinelCommands: def sentinel(self, *args): "Redis Sentinel's SENTINEL command." - warnings.warn( - DeprecationWarning('Use the individual sentinel_* methods')) + warnings.warn(DeprecationWarning("Use the individual sentinel_* methods")) def sentinel_get_master_addr_by_name(self, service_name): "Returns a (host, port) pair for the given ``service_name``" - return self.execute_command('SENTINEL GET-MASTER-ADDR-BY-NAME', - service_name) + return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name) def sentinel_master(self, service_name): "Returns a dictionary containing the specified masters state." - return self.execute_command('SENTINEL MASTER', service_name) + return self.execute_command("SENTINEL MASTER", service_name) def sentinel_masters(self): "Returns a list of dictionaries containing each master's state." - return self.execute_command('SENTINEL MASTERS') + return self.execute_command("SENTINEL MASTERS") def sentinel_monitor(self, name, ip, port, quorum): "Add a new master to Sentinel to be monitored" - return self.execute_command('SENTINEL MONITOR', name, ip, port, quorum) + return self.execute_command("SENTINEL MONITOR", name, ip, port, quorum) def sentinel_remove(self, name): "Remove a master from Sentinel's monitoring" - return self.execute_command('SENTINEL REMOVE', name) + return self.execute_command("SENTINEL REMOVE", name) def sentinel_sentinels(self, service_name): "Returns a list of sentinels for ``service_name``" - return self.execute_command('SENTINEL SENTINELS', service_name) + return self.execute_command("SENTINEL SENTINELS", service_name) def sentinel_set(self, name, option, value): "Set Sentinel monitoring parameters for a given master" - return self.execute_command('SENTINEL SET', name, option, value) + return self.execute_command("SENTINEL SET", name, option, value) def sentinel_slaves(self, service_name): "Returns a list of slaves for ``service_name``" - return self.execute_command('SENTINEL SLAVES', service_name) + return self.execute_command("SENTINEL SLAVES", service_name) def sentinel_reset(self, pattern): """ @@ -54,7 +52,7 @@ def sentinel_reset(self, pattern): failover in progress), and removes every slave and sentinel already discovered and associated with the master. """ - return self.execute_command('SENTINEL RESET', pattern, once=True) + return self.execute_command("SENTINEL RESET", pattern, once=True) def sentinel_failover(self, new_master_name): """ @@ -63,7 +61,7 @@ def sentinel_failover(self, new_master_name): configuration will be published so that the other Sentinels will update their configurations). """ - return self.execute_command('SENTINEL FAILOVER', new_master_name) + return self.execute_command("SENTINEL FAILOVER", new_master_name) def sentinel_ckquorum(self, new_master_name): """ @@ -74,9 +72,7 @@ def sentinel_ckquorum(self, new_master_name): This command should be used in monitoring systems to check if a Sentinel deployment is ok. """ - return self.execute_command('SENTINEL CKQUORUM', - new_master_name, - once=True) + return self.execute_command("SENTINEL CKQUORUM", new_master_name, once=True) def sentinel_flushconfig(self): """ @@ -94,4 +90,4 @@ def sentinel_flushconfig(self): This command works even if the previous configuration file is completely missing. """ - return self.execute_command('SENTINEL FLUSHCONFIG') + return self.execute_command("SENTINEL FLUSHCONFIG") diff --git a/redis/commands/timeseries/__init__.py b/redis/commands/timeseries/__init__.py index 5ce538f675..5b1f15114d 100644 --- a/redis/commands/timeseries/__init__.py +++ b/redis/commands/timeseries/__init__.py @@ -1,19 +1,12 @@ import redis.client -from .utils import ( - parse_range, - parse_get, - parse_m_range, - parse_m_get, -) -from .info import TSInfo from ..helpers import parse_to_list from .commands import ( ALTER_CMD, CREATE_CMD, CREATERULE_CMD, - DELETERULE_CMD, DEL_CMD, + DELETERULE_CMD, GET_CMD, INFO_CMD, MGET_CMD, @@ -24,6 +17,8 @@ REVRANGE_CMD, TimeSeriesCommands, ) +from .info import TSInfo +from .utils import parse_get, parse_m_get, parse_m_range, parse_range class TimeSeries(TimeSeriesCommands): diff --git a/redis/commands/timeseries/commands.py b/redis/commands/timeseries/commands.py index 460ba766a9..b7e33bc799 100644 --- a/redis/commands/timeseries/commands.py +++ b/redis/commands/timeseries/commands.py @@ -1,6 +1,5 @@ from redis.exceptions import DataError - ADD_CMD = "TS.ADD" ALTER_CMD = "TS.ALTER" CREATERULE_CMD = "TS.CREATERULE" @@ -276,13 +275,7 @@ def delete(self, key, from_time, to_time): """ # noqa return self.execute_command(DEL_CMD, key, from_time, to_time) - def createrule( - self, - source_key, - dest_key, - aggregation_type, - bucket_size_msec - ): + def createrule(self, source_key, dest_key, aggregation_type, bucket_size_msec): """ Create a compaction rule from values added to `source_key` into `dest_key`. Aggregating for `bucket_size_msec` where an `aggregation_type` can be @@ -321,11 +314,7 @@ def __range_params( """Create TS.RANGE and TS.REVRANGE arguments.""" params = [key, from_time, to_time] self._appendFilerByTs(params, filter_by_ts) - self._appendFilerByValue( - params, - filter_by_min_value, - filter_by_max_value - ) + self._appendFilerByValue(params, filter_by_min_value, filter_by_max_value) self._appendCount(params, count) self._appendAlign(params, align) self._appendAggregation(params, aggregation_type, bucket_size_msec) @@ -471,11 +460,7 @@ def __mrange_params( """Create TS.MRANGE and TS.MREVRANGE arguments.""" params = [from_time, to_time] self._appendFilerByTs(params, filter_by_ts) - self._appendFilerByValue( - params, - filter_by_min_value, - filter_by_max_value - ) + self._appendFilerByValue(params, filter_by_min_value, filter_by_max_value) self._appendCount(params, count) self._appendAlign(params, align) self._appendAggregation(params, aggregation_type, bucket_size_msec) @@ -654,7 +639,7 @@ def mrevrange( return self.execute_command(MREVRANGE_CMD, *params) def get(self, key): - """ # noqa + """# noqa Get the last sample of `key`. For more information: https://oss.redis.com/redistimeseries/master/commands/#tsget @@ -662,7 +647,7 @@ def get(self, key): return self.execute_command(GET_CMD, key) def mget(self, filters, with_labels=False): - """ # noqa + """# noqa Get the last samples matching the specific `filter`. For more information: https://oss.redis.com/redistimeseries/master/commands/#tsmget @@ -674,7 +659,7 @@ def mget(self, filters, with_labels=False): return self.execute_command(MGET_CMD, *params) def info(self, key): - """ # noqa + """# noqa Get information of `key`. For more information: https://oss.redis.com/redistimeseries/master/commands/#tsinfo @@ -682,7 +667,7 @@ def info(self, key): return self.execute_command(INFO_CMD, key) def queryindex(self, filters): - """ # noqa + """# noqa Get all the keys matching the `filter` list. For more information: https://oss.redis.com/redistimeseries/master/commands/#tsqueryindex diff --git a/redis/commands/timeseries/info.py b/redis/commands/timeseries/info.py index 2b8acd1b66..fba7f093b1 100644 --- a/redis/commands/timeseries/info.py +++ b/redis/commands/timeseries/info.py @@ -1,5 +1,5 @@ -from .utils import list_to_dict from ..helpers import nativestr +from .utils import list_to_dict class TSInfo: diff --git a/redis/commands/timeseries/utils.py b/redis/commands/timeseries/utils.py index c33b7c591e..c49b040271 100644 --- a/redis/commands/timeseries/utils.py +++ b/redis/commands/timeseries/utils.py @@ -2,9 +2,7 @@ def list_to_dict(aList): - return { - nativestr(aList[i][0]): nativestr(aList[i][1]) - for i in range(len(aList))} + return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))} def parse_range(response): @@ -16,9 +14,7 @@ def parse_m_range(response): """Parse multi range response. Used by TS.MRANGE and TS.MREVRANGE.""" res = [] for item in response: - res.append( - {nativestr(item[0]): - [list_to_dict(item[1]), parse_range(item[2])]}) + res.append({nativestr(item[0]): [list_to_dict(item[1]), parse_range(item[2])]}) return sorted(res, key=lambda d: list(d.keys())) @@ -34,8 +30,7 @@ def parse_m_get(response): res = [] for item in response: if not item[2]: - res.append( - {nativestr(item[0]): [list_to_dict(item[1]), None, None]}) + res.append({nativestr(item[0]): [list_to_dict(item[1]), None, None]}) else: res.append( { diff --git a/redis/exceptions.py b/redis/exceptions.py index eb6ecc2dc5..4d5d530925 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -83,6 +83,7 @@ class AuthenticationWrongNumberOfArgsError(ResponseError): An error to indicate that the wrong number of args were sent to the AUTH command """ + pass diff --git a/redis/lock.py b/redis/lock.py index d2297526a0..95bb413d7e 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -2,6 +2,7 @@ import time as mod_time import uuid from types import SimpleNamespace + from redis.exceptions import LockError, LockNotOwnedError @@ -70,8 +71,16 @@ class Lock: return 1 """ - def __init__(self, redis, name, timeout=None, sleep=0.1, - blocking=True, blocking_timeout=None, thread_local=True): + def __init__( + self, + redis, + name, + timeout=None, + sleep=0.1, + blocking=True, + blocking_timeout=None, + thread_local=True, + ): """ Create a new Lock instance named ``name`` using the Redis client supplied by ``redis``. @@ -129,11 +138,7 @@ def __init__(self, redis, name, timeout=None, sleep=0.1, self.blocking = blocking self.blocking_timeout = blocking_timeout self.thread_local = bool(thread_local) - self.local = ( - threading.local() - if self.thread_local - else SimpleNamespace() - ) + self.local = threading.local() if self.thread_local else SimpleNamespace() self.local.token = None self.register_scripts() @@ -145,8 +150,7 @@ def register_scripts(self): if cls.lua_extend is None: cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT) if cls.lua_reacquire is None: - cls.lua_reacquire = \ - client.register_script(cls.LUA_REACQUIRE_SCRIPT) + cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT) def __enter__(self): if self.acquire(): @@ -222,8 +226,7 @@ def owned(self): if stored_token and not isinstance(stored_token, bytes): encoder = self.redis.connection_pool.get_encoder() stored_token = encoder.encode(stored_token) - return self.local.token is not None and \ - stored_token == self.local.token + return self.local.token is not None and stored_token == self.local.token def release(self): "Releases the already acquired lock" @@ -234,11 +237,10 @@ def release(self): self.do_release(expected_token) def do_release(self, expected_token): - if not bool(self.lua_release(keys=[self.name], - args=[expected_token], - client=self.redis)): - raise LockNotOwnedError("Cannot release a lock" - " that's no longer owned") + if not bool( + self.lua_release(keys=[self.name], args=[expected_token], client=self.redis) + ): + raise LockNotOwnedError("Cannot release a lock" " that's no longer owned") def extend(self, additional_time, replace_ttl=False): """ @@ -262,17 +264,11 @@ def do_extend(self, additional_time, replace_ttl): if not bool( self.lua_extend( keys=[self.name], - args=[ - self.local.token, - additional_time, - replace_ttl and "1" or "0" - ], + args=[self.local.token, additional_time, replace_ttl and "1" or "0"], client=self.redis, ) ): - raise LockNotOwnedError( - "Cannot extend a lock that's" " no longer owned" - ) + raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned") return True def reacquire(self): @@ -287,9 +283,10 @@ def reacquire(self): def do_reacquire(self): timeout = int(self.timeout * 1000) - if not bool(self.lua_reacquire(keys=[self.name], - args=[self.local.token, timeout], - client=self.redis)): - raise LockNotOwnedError("Cannot reacquire a lock that's" - " no longer owned") + if not bool( + self.lua_reacquire( + keys=[self.name], args=[self.local.token, timeout], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned") return True diff --git a/redis/retry.py b/redis/retry.py index cd06a23e3d..75504c77e7 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -6,8 +6,9 @@ class Retry: """Retry a specific number of times after a failure""" - def __init__(self, backoff, retries, - supported_errors=(ConnectionError, TimeoutError)): + def __init__( + self, backoff, retries, supported_errors=(ConnectionError, TimeoutError) + ): """ Initialize a `Retry` object with a `Backoff` object that retries a maximum of `retries` times. diff --git a/redis/utils.py b/redis/utils.py index 0e78cc5f3b..50961cb767 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,8 +1,8 @@ from contextlib import contextmanager - try: import hiredis # noqa + HIREDIS_AVAILABLE = True except ImportError: HIREDIS_AVAILABLE = False @@ -16,6 +16,7 @@ def from_url(url, **kwargs): none is provided. """ from redis.client import Redis + return Redis.from_url(url, **kwargs) @@ -28,9 +29,7 @@ def pipeline(redis_obj): def str_if_bytes(value): return ( - value.decode('utf-8', errors='replace') - if isinstance(value, bytes) - else value + value.decode("utf-8", errors="replace") if isinstance(value, bytes) else value ) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000..5abd1b3c7b --- /dev/null +++ b/setup.cfg @@ -0,0 +1,7 @@ +[flake8] +max-line-length = 88 +extend-ignore = E203, W503 + +[isort] +profile = black +multi_line_output = 3 diff --git a/tests/test_encoding.py b/tests/test_encoding.py index 706654f89f..bd0f09fcc0 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -1,7 +1,8 @@ import pytest -import redis +import redis from redis.connection import Connection + from .conftest import _get_client @@ -19,62 +20,70 @@ def r_no_decode(self, request): ) def test_simple_encoding(self, r_no_decode): - unicode_string = chr(3456) + 'abcd' + chr(3421) - r_no_decode['unicode-string'] = unicode_string.encode('utf-8') - cached_val = r_no_decode['unicode-string'] + unicode_string = chr(3456) + "abcd" + chr(3421) + r_no_decode["unicode-string"] = unicode_string.encode("utf-8") + cached_val = r_no_decode["unicode-string"] assert isinstance(cached_val, bytes) - assert unicode_string == cached_val.decode('utf-8') + assert unicode_string == cached_val.decode("utf-8") def test_simple_encoding_and_decoding(self, r): - unicode_string = chr(3456) + 'abcd' + chr(3421) - r['unicode-string'] = unicode_string - cached_val = r['unicode-string'] + unicode_string = chr(3456) + "abcd" + chr(3421) + r["unicode-string"] = unicode_string + cached_val = r["unicode-string"] assert isinstance(cached_val, str) assert unicode_string == cached_val def test_memoryview_encoding(self, r_no_decode): - unicode_string = chr(3456) + 'abcd' + chr(3421) - unicode_string_view = memoryview(unicode_string.encode('utf-8')) - r_no_decode['unicode-string-memoryview'] = unicode_string_view - cached_val = r_no_decode['unicode-string-memoryview'] + unicode_string = chr(3456) + "abcd" + chr(3421) + unicode_string_view = memoryview(unicode_string.encode("utf-8")) + r_no_decode["unicode-string-memoryview"] = unicode_string_view + cached_val = r_no_decode["unicode-string-memoryview"] # The cached value won't be a memoryview because it's a copy from Redis assert isinstance(cached_val, bytes) - assert unicode_string == cached_val.decode('utf-8') + assert unicode_string == cached_val.decode("utf-8") def test_memoryview_encoding_and_decoding(self, r): - unicode_string = chr(3456) + 'abcd' + chr(3421) - unicode_string_view = memoryview(unicode_string.encode('utf-8')) - r['unicode-string-memoryview'] = unicode_string_view - cached_val = r['unicode-string-memoryview'] + unicode_string = chr(3456) + "abcd" + chr(3421) + unicode_string_view = memoryview(unicode_string.encode("utf-8")) + r["unicode-string-memoryview"] = unicode_string_view + cached_val = r["unicode-string-memoryview"] assert isinstance(cached_val, str) assert unicode_string == cached_val def test_list_encoding(self, r): - unicode_string = chr(3456) + 'abcd' + chr(3421) + unicode_string = chr(3456) + "abcd" + chr(3421) result = [unicode_string, unicode_string, unicode_string] - r.rpush('a', *result) - assert r.lrange('a', 0, -1) == result + r.rpush("a", *result) + assert r.lrange("a", 0, -1) == result class TestEncodingErrors: def test_ignore(self, request): - r = _get_client(redis.Redis, request=request, decode_responses=True, - encoding_errors='ignore') - r.set('a', b'foo\xff') - assert r.get('a') == 'foo' + r = _get_client( + redis.Redis, + request=request, + decode_responses=True, + encoding_errors="ignore", + ) + r.set("a", b"foo\xff") + assert r.get("a") == "foo" def test_replace(self, request): - r = _get_client(redis.Redis, request=request, decode_responses=True, - encoding_errors='replace') - r.set('a', b'foo\xff') - assert r.get('a') == 'foo\ufffd' + r = _get_client( + redis.Redis, + request=request, + decode_responses=True, + encoding_errors="replace", + ) + r.set("a", b"foo\xff") + assert r.get("a") == "foo\ufffd" class TestMemoryviewsAreNotPacked: def test_memoryviews_are_not_packed(self): c = Connection() - arg = memoryview(b'some_arg') - arg_list = ['SOME_COMMAND', arg] + arg = memoryview(b"some_arg") + arg_list = ["SOME_COMMAND", arg] cmd = c.pack_command(*arg_list) assert cmd[1] is arg cmds = c.pack_commands([arg_list, arg_list]) @@ -85,25 +94,25 @@ def test_memoryviews_are_not_packed(self): class TestCommandsAreNotEncoded: @pytest.fixture() def r(self, request): - return _get_client(redis.Redis, request=request, encoding='utf-16') + return _get_client(redis.Redis, request=request, encoding="utf-16") def test_basic_command(self, r): - r.set('hello', 'world') + r.set("hello", "world") class TestInvalidUserInput: def test_boolean_fails(self, r): with pytest.raises(redis.DataError): - r.set('a', True) + r.set("a", True) def test_none_fails(self, r): with pytest.raises(redis.DataError): - r.set('a', None) + r.set("a", None) def test_user_type_fails(self, r): class Foo: def __str__(self): - return 'Foo' + return "Foo" with pytest.raises(redis.DataError): - r.set('a', Foo()) + r.set("a", Foo()) diff --git a/tests/test_lock.py b/tests/test_lock.py index 66148edcfc..02cca1b522 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -1,9 +1,11 @@ -import pytest import time -from redis.exceptions import LockError, LockNotOwnedError +import pytest + from redis.client import Redis +from redis.exceptions import LockError, LockNotOwnedError from redis.lock import Lock + from .conftest import _get_client @@ -14,36 +16,36 @@ def r_decoded(self, request): return _get_client(Redis, request=request, decode_responses=True) def get_lock(self, redis, *args, **kwargs): - kwargs['lock_class'] = Lock + kwargs["lock_class"] = Lock return redis.lock(*args, **kwargs) def test_lock(self, r): - lock = self.get_lock(r, 'foo') + lock = self.get_lock(r, "foo") assert lock.acquire(blocking=False) - assert r.get('foo') == lock.local.token - assert r.ttl('foo') == -1 + assert r.get("foo") == lock.local.token + assert r.ttl("foo") == -1 lock.release() - assert r.get('foo') is None + assert r.get("foo") is None def test_lock_token(self, r): - lock = self.get_lock(r, 'foo') + lock = self.get_lock(r, "foo") self._test_lock_token(r, lock) def test_lock_token_thread_local_false(self, r): - lock = self.get_lock(r, 'foo', thread_local=False) + lock = self.get_lock(r, "foo", thread_local=False) self._test_lock_token(r, lock) def _test_lock_token(self, r, lock): - assert lock.acquire(blocking=False, token='test') - assert r.get('foo') == b'test' - assert lock.local.token == b'test' - assert r.ttl('foo') == -1 + assert lock.acquire(blocking=False, token="test") + assert r.get("foo") == b"test" + assert lock.local.token == b"test" + assert r.ttl("foo") == -1 lock.release() - assert r.get('foo') is None + assert r.get("foo") is None assert lock.local.token is None def test_locked(self, r): - lock = self.get_lock(r, 'foo') + lock = self.get_lock(r, "foo") assert lock.locked() is False lock.acquire(blocking=False) assert lock.locked() is True @@ -51,14 +53,14 @@ def test_locked(self, r): assert lock.locked() is False def _test_owned(self, client): - lock = self.get_lock(client, 'foo') + lock = self.get_lock(client, "foo") assert lock.owned() is False lock.acquire(blocking=False) assert lock.owned() is True lock.release() assert lock.owned() is False - lock2 = self.get_lock(client, 'foo') + lock2 = self.get_lock(client, "foo") assert lock.owned() is False assert lock2.owned() is False lock2.acquire(blocking=False) @@ -75,8 +77,8 @@ def test_owned_with_decoded_responses(self, r_decoded): self._test_owned(r_decoded) def test_competing_locks(self, r): - lock1 = self.get_lock(r, 'foo') - lock2 = self.get_lock(r, 'foo') + lock1 = self.get_lock(r, "foo") + lock2 = self.get_lock(r, "foo") assert lock1.acquire(blocking=False) assert not lock2.acquire(blocking=False) lock1.release() @@ -85,23 +87,23 @@ def test_competing_locks(self, r): lock2.release() def test_timeout(self, r): - lock = self.get_lock(r, 'foo', timeout=10) + lock = self.get_lock(r, "foo", timeout=10) assert lock.acquire(blocking=False) - assert 8 < r.ttl('foo') <= 10 + assert 8 < r.ttl("foo") <= 10 lock.release() def test_float_timeout(self, r): - lock = self.get_lock(r, 'foo', timeout=9.5) + lock = self.get_lock(r, "foo", timeout=9.5) assert lock.acquire(blocking=False) - assert 8 < r.pttl('foo') <= 9500 + assert 8 < r.pttl("foo") <= 9500 lock.release() def test_blocking_timeout(self, r): - lock1 = self.get_lock(r, 'foo') + lock1 = self.get_lock(r, "foo") assert lock1.acquire(blocking=False) bt = 0.2 sleep = 0.05 - lock2 = self.get_lock(r, 'foo', sleep=sleep, blocking_timeout=bt) + lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) start = time.monotonic() assert not lock2.acquire() # The elapsed duration should be less than the total blocking_timeout @@ -111,22 +113,22 @@ def test_blocking_timeout(self, r): def test_context_manager(self, r): # blocking_timeout prevents a deadlock if the lock can't be acquired # for some reason - with self.get_lock(r, 'foo', blocking_timeout=0.2) as lock: - assert r.get('foo') == lock.local.token - assert r.get('foo') is None + with self.get_lock(r, "foo", blocking_timeout=0.2) as lock: + assert r.get("foo") == lock.local.token + assert r.get("foo") is None def test_context_manager_raises_when_locked_not_acquired(self, r): - r.set('foo', 'bar') + r.set("foo", "bar") with pytest.raises(LockError): - with self.get_lock(r, 'foo', blocking_timeout=0.1): + with self.get_lock(r, "foo", blocking_timeout=0.1): pass def test_high_sleep_small_blocking_timeout(self, r): - lock1 = self.get_lock(r, 'foo') + lock1 = self.get_lock(r, "foo") assert lock1.acquire(blocking=False) sleep = 60 bt = 1 - lock2 = self.get_lock(r, 'foo', sleep=sleep, blocking_timeout=bt) + lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) start = time.monotonic() assert not lock2.acquire() # the elapsed timed is less than the blocking_timeout as the lock is @@ -135,88 +137,88 @@ def test_high_sleep_small_blocking_timeout(self, r): lock1.release() def test_releasing_unlocked_lock_raises_error(self, r): - lock = self.get_lock(r, 'foo') + lock = self.get_lock(r, "foo") with pytest.raises(LockError): lock.release() def test_releasing_lock_no_longer_owned_raises_error(self, r): - lock = self.get_lock(r, 'foo') + lock = self.get_lock(r, "foo") lock.acquire(blocking=False) # manually change the token - r.set('foo', 'a') + r.set("foo", "a") with pytest.raises(LockNotOwnedError): lock.release() # even though we errored, the token is still cleared assert lock.local.token is None def test_extend_lock(self, r): - lock = self.get_lock(r, 'foo', timeout=10) + lock = self.get_lock(r, "foo", timeout=10) assert lock.acquire(blocking=False) - assert 8000 < r.pttl('foo') <= 10000 + assert 8000 < r.pttl("foo") <= 10000 assert lock.extend(10) - assert 16000 < r.pttl('foo') <= 20000 + assert 16000 < r.pttl("foo") <= 20000 lock.release() def test_extend_lock_replace_ttl(self, r): - lock = self.get_lock(r, 'foo', timeout=10) + lock = self.get_lock(r, "foo", timeout=10) assert lock.acquire(blocking=False) - assert 8000 < r.pttl('foo') <= 10000 + assert 8000 < r.pttl("foo") <= 10000 assert lock.extend(10, replace_ttl=True) - assert 8000 < r.pttl('foo') <= 10000 + assert 8000 < r.pttl("foo") <= 10000 lock.release() def test_extend_lock_float(self, r): - lock = self.get_lock(r, 'foo', timeout=10.0) + lock = self.get_lock(r, "foo", timeout=10.0) assert lock.acquire(blocking=False) - assert 8000 < r.pttl('foo') <= 10000 + assert 8000 < r.pttl("foo") <= 10000 assert lock.extend(10.0) - assert 16000 < r.pttl('foo') <= 20000 + assert 16000 < r.pttl("foo") <= 20000 lock.release() def test_extending_unlocked_lock_raises_error(self, r): - lock = self.get_lock(r, 'foo', timeout=10) + lock = self.get_lock(r, "foo", timeout=10) with pytest.raises(LockError): lock.extend(10) def test_extending_lock_with_no_timeout_raises_error(self, r): - lock = self.get_lock(r, 'foo') + lock = self.get_lock(r, "foo") assert lock.acquire(blocking=False) with pytest.raises(LockError): lock.extend(10) lock.release() def test_extending_lock_no_longer_owned_raises_error(self, r): - lock = self.get_lock(r, 'foo', timeout=10) + lock = self.get_lock(r, "foo", timeout=10) assert lock.acquire(blocking=False) - r.set('foo', 'a') + r.set("foo", "a") with pytest.raises(LockNotOwnedError): lock.extend(10) def test_reacquire_lock(self, r): - lock = self.get_lock(r, 'foo', timeout=10) + lock = self.get_lock(r, "foo", timeout=10) assert lock.acquire(blocking=False) - assert r.pexpire('foo', 5000) - assert r.pttl('foo') <= 5000 + assert r.pexpire("foo", 5000) + assert r.pttl("foo") <= 5000 assert lock.reacquire() - assert 8000 < r.pttl('foo') <= 10000 + assert 8000 < r.pttl("foo") <= 10000 lock.release() def test_reacquiring_unlocked_lock_raises_error(self, r): - lock = self.get_lock(r, 'foo', timeout=10) + lock = self.get_lock(r, "foo", timeout=10) with pytest.raises(LockError): lock.reacquire() def test_reacquiring_lock_with_no_timeout_raises_error(self, r): - lock = self.get_lock(r, 'foo') + lock = self.get_lock(r, "foo") assert lock.acquire(blocking=False) with pytest.raises(LockError): lock.reacquire() lock.release() def test_reacquiring_lock_no_longer_owned_raises_error(self, r): - lock = self.get_lock(r, 'foo', timeout=10) + lock = self.get_lock(r, "foo", timeout=10) assert lock.acquire(blocking=False) - r.set('foo', 'a') + r.set("foo", "a") with pytest.raises(LockNotOwnedError): lock.reacquire() @@ -228,5 +230,6 @@ class MyLock: def __init__(self, *args, **kwargs): pass - lock = r.lock('foo', lock_class=MyLock) + + lock = r.lock("foo", lock_class=MyLock) assert type(lock) == MyLock diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 6c3ea33bce..09e70d828f 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -2,7 +2,7 @@ from .conftest import ( skip_if_redis_enterprise, skip_ifnot_redis_enterprise, - wait_for_command + wait_for_command, ) @@ -11,56 +11,56 @@ class TestMonitor: def test_wait_command_not_found(self, r): "Make sure the wait_for_command func works when command is not found" with r.monitor() as m: - response = wait_for_command(r, m, 'nothing') + response = wait_for_command(r, m, "nothing") assert response is None def test_response_values(self, r): - db = r.connection_pool.connection_kwargs.get('db', 0) + db = r.connection_pool.connection_kwargs.get("db", 0) with r.monitor() as m: r.ping() - response = wait_for_command(r, m, 'PING') - assert isinstance(response['time'], float) - assert response['db'] == db - assert response['client_type'] in ('tcp', 'unix') - assert isinstance(response['client_address'], str) - assert isinstance(response['client_port'], str) - assert response['command'] == 'PING' + response = wait_for_command(r, m, "PING") + assert isinstance(response["time"], float) + assert response["db"] == db + assert response["client_type"] in ("tcp", "unix") + assert isinstance(response["client_address"], str) + assert isinstance(response["client_port"], str) + assert response["command"] == "PING" def test_command_with_quoted_key(self, r): with r.monitor() as m: r.get('foo"bar') response = wait_for_command(r, m, 'GET foo"bar') - assert response['command'] == 'GET foo"bar' + assert response["command"] == 'GET foo"bar' def test_command_with_binary_data(self, r): with r.monitor() as m: - byte_string = b'foo\x92' + byte_string = b"foo\x92" r.get(byte_string) - response = wait_for_command(r, m, 'GET foo\\x92') - assert response['command'] == 'GET foo\\x92' + response = wait_for_command(r, m, "GET foo\\x92") + assert response["command"] == "GET foo\\x92" def test_command_with_escaped_data(self, r): with r.monitor() as m: - byte_string = b'foo\\x92' + byte_string = b"foo\\x92" r.get(byte_string) - response = wait_for_command(r, m, 'GET foo\\\\x92') - assert response['command'] == 'GET foo\\\\x92' + response = wait_for_command(r, m, "GET foo\\\\x92") + assert response["command"] == "GET foo\\\\x92" @skip_if_redis_enterprise def test_lua_script(self, r): with r.monitor() as m: script = 'return redis.call("GET", "foo")' assert r.eval(script, 0) is None - response = wait_for_command(r, m, 'GET foo') - assert response['command'] == 'GET foo' - assert response['client_type'] == 'lua' - assert response['client_address'] == 'lua' - assert response['client_port'] == '' + response = wait_for_command(r, m, "GET foo") + assert response["command"] == "GET foo" + assert response["client_type"] == "lua" + assert response["client_address"] == "lua" + assert response["client_port"] == "" @skip_ifnot_redis_enterprise def test_lua_script_in_enterprise(self, r): with r.monitor() as m: script = 'return redis.call("GET", "foo")' assert r.eval(script, 0) is None - response = wait_for_command(r, m, 'GET foo') + response = wait_for_command(r, m, "GET foo") assert response is None diff --git a/tests/test_retry.py b/tests/test_retry.py index 535485acae..c4650bc650 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,8 +1,8 @@ -from redis.backoff import NoBackoff import pytest -from redis.exceptions import ConnectionError +from redis.backoff import NoBackoff from redis.connection import Connection, UnixDomainSocketConnection +from redis.exceptions import ConnectionError from redis.retry import Retry @@ -34,8 +34,7 @@ def test_retry_on_timeout_boolean(self, Class, retry_on_timeout): @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) def test_retry_on_timeout_retry(self, Class, retries): retry_on_timeout = retries > 0 - c = Class(retry_on_timeout=retry_on_timeout, - retry=Retry(NoBackoff(), retries)) + c = Class(retry_on_timeout=retry_on_timeout, retry=Retry(NoBackoff(), retries)) assert c.retry_on_timeout == retry_on_timeout assert isinstance(c.retry, Retry) assert c.retry._retries == retries diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 7614b1233f..9f4f82023f 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -1,10 +1,8 @@ import pytest from redis import exceptions - from tests.conftest import skip_if_server_version_lt - multiply_script = """ local value = redis.call('GET', KEYS[1]) value = tonumber(value) @@ -29,52 +27,52 @@ def reset_scripts(self, r): r.script_flush() def test_eval(self, r): - r.set('a', 2) + r.set("a", 2) # 2 * 3 == 6 - assert r.eval(multiply_script, 1, 'a', 3) == 6 + assert r.eval(multiply_script, 1, "a", 3) == 6 - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_script_flush_620(self, r): - r.set('a', 2) + r.set("a", 2) r.script_load(multiply_script) - r.script_flush('ASYNC') + r.script_flush("ASYNC") - r.set('a', 2) + r.set("a", 2) r.script_load(multiply_script) - r.script_flush('SYNC') + r.script_flush("SYNC") - r.set('a', 2) + r.set("a", 2) r.script_load(multiply_script) r.script_flush() with pytest.raises(exceptions.DataError): - r.set('a', 2) + r.set("a", 2) r.script_load(multiply_script) r.script_flush("NOTREAL") def test_script_flush(self, r): - r.set('a', 2) + r.set("a", 2) r.script_load(multiply_script) r.script_flush(None) with pytest.raises(exceptions.DataError): - r.set('a', 2) + r.set("a", 2) r.script_load(multiply_script) r.script_flush("NOTREAL") def test_evalsha(self, r): - r.set('a', 2) + r.set("a", 2) sha = r.script_load(multiply_script) # 2 * 3 == 6 - assert r.evalsha(sha, 1, 'a', 3) == 6 + assert r.evalsha(sha, 1, "a", 3) == 6 def test_evalsha_script_not_loaded(self, r): - r.set('a', 2) + r.set("a", 2) sha = r.script_load(multiply_script) # remove the script from Redis's cache r.script_flush() with pytest.raises(exceptions.NoScriptError): - r.evalsha(sha, 1, 'a', 3) + r.evalsha(sha, 1, "a", 3) def test_script_loading(self, r): # get the sha, then clear the cache @@ -85,31 +83,31 @@ def test_script_loading(self, r): assert r.script_exists(sha) == [True] def test_script_object(self, r): - r.set('a', 2) + r.set("a", 2) multiply = r.register_script(multiply_script) precalculated_sha = multiply.sha assert precalculated_sha assert r.script_exists(multiply.sha) == [False] # Test second evalsha block (after NoScriptError) - assert multiply(keys=['a'], args=[3]) == 6 + assert multiply(keys=["a"], args=[3]) == 6 # At this point, the script should be loaded assert r.script_exists(multiply.sha) == [True] # Test that the precalculated sha matches the one from redis assert multiply.sha == precalculated_sha # Test first evalsha block - assert multiply(keys=['a'], args=[3]) == 6 + assert multiply(keys=["a"], args=[3]) == 6 def test_script_object_in_pipeline(self, r): multiply = r.register_script(multiply_script) precalculated_sha = multiply.sha assert precalculated_sha pipe = r.pipeline() - pipe.set('a', 2) - pipe.get('a') - multiply(keys=['a'], args=[3], client=pipe) + pipe.set("a", 2) + pipe.get("a") + multiply(keys=["a"], args=[3], client=pipe) assert r.script_exists(multiply.sha) == [False] # [SET worked, GET 'a', result of multiple script] - assert pipe.execute() == [True, b'2', 6] + assert pipe.execute() == [True, b"2", 6] # The script should have been loaded by pipe.execute() assert r.script_exists(multiply.sha) == [True] # The precalculated sha should have been the correct one @@ -119,12 +117,12 @@ def test_script_object_in_pipeline(self, r): # the multiply script should be reloaded by pipe.execute() r.script_flush() pipe = r.pipeline() - pipe.set('a', 2) - pipe.get('a') - multiply(keys=['a'], args=[3], client=pipe) + pipe.set("a", 2) + pipe.get("a") + multiply(keys=["a"], args=[3], client=pipe) assert r.script_exists(multiply.sha) == [False] # [SET worked, GET 'a', result of multiple script] - assert pipe.execute() == [True, b'2', 6] + assert pipe.execute() == [True, b"2", 6] assert r.script_exists(multiply.sha) == [True] def test_eval_msgpack_pipeline_error_in_lua(self, r): @@ -135,12 +133,12 @@ def test_eval_msgpack_pipeline_error_in_lua(self, r): # avoiding a dependency to msgpack, this is the output of # msgpack.dumps({"name": "joe"}) - msgpack_message_1 = b'\x81\xa4name\xa3Joe' + msgpack_message_1 = b"\x81\xa4name\xa3Joe" msgpack_hello(args=[msgpack_message_1], client=pipe) assert r.script_exists(msgpack_hello.sha) == [False] - assert pipe.execute()[0] == b'hello Joe' + assert pipe.execute()[0] == b"hello Joe" assert r.script_exists(msgpack_hello.sha) == [True] msgpack_hello_broken = r.register_script(msgpack_hello_script_broken) diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index 9377d5ba65..0357443a14 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -2,10 +2,14 @@ import pytest -from redis import exceptions -from redis.sentinel import (Sentinel, SentinelConnectionPool, - MasterNotFoundError, SlaveNotFoundError) import redis.sentinel +from redis import exceptions +from redis.sentinel import ( + MasterNotFoundError, + Sentinel, + SentinelConnectionPool, + SlaveNotFoundError, +) @pytest.fixture(scope="module") @@ -33,20 +37,20 @@ def sentinel_slaves(self, master_name): def execute_command(self, *args, **kwargs): # wrapper purely to validate the calls don't explode from redis.client import bool_ok + return bool_ok class SentinelTestCluster: - def __init__(self, servisentinel_ce_name='mymaster', ip='127.0.0.1', - port=6379): + def __init__(self, servisentinel_ce_name="mymaster", ip="127.0.0.1", port=6379): self.clients = {} self.master = { - 'ip': ip, - 'port': port, - 'is_master': True, - 'is_sdown': False, - 'is_odown': False, - 'num-other-sentinels': 0, + "ip": ip, + "port": port, + "is_master": True, + "is_sdown": False, + "is_odown": False, + "num-other-sentinels": 0, } self.service_name = servisentinel_ce_name self.slaves = [] @@ -69,6 +73,7 @@ def client(self, host, port, **kwargs): def cluster(request, master_ip): def teardown(): redis.sentinel.Redis = saved_Redis + cluster = SentinelTestCluster(ip=master_ip) saved_Redis = redis.sentinel.Redis redis.sentinel.Redis = cluster.client @@ -78,126 +83,121 @@ def teardown(): @pytest.fixture() def sentinel(request, cluster): - return Sentinel([('foo', 26379), ('bar', 26379)]) + return Sentinel([("foo", 26379), ("bar", 26379)]) @pytest.mark.onlynoncluster def test_discover_master(sentinel, master_ip): - address = sentinel.discover_master('mymaster') + address = sentinel.discover_master("mymaster") assert address == (master_ip, 6379) @pytest.mark.onlynoncluster def test_discover_master_error(sentinel): with pytest.raises(MasterNotFoundError): - sentinel.discover_master('xxx') + sentinel.discover_master("xxx") @pytest.mark.onlynoncluster def test_discover_master_sentinel_down(cluster, sentinel, master_ip): # Put first sentinel 'foo' down - cluster.nodes_down.add(('foo', 26379)) - address = sentinel.discover_master('mymaster') + cluster.nodes_down.add(("foo", 26379)) + address = sentinel.discover_master("mymaster") assert address == (master_ip, 6379) # 'bar' is now first sentinel - assert sentinel.sentinels[0].id == ('bar', 26379) + assert sentinel.sentinels[0].id == ("bar", 26379) @pytest.mark.onlynoncluster def test_discover_master_sentinel_timeout(cluster, sentinel, master_ip): # Put first sentinel 'foo' down - cluster.nodes_timeout.add(('foo', 26379)) - address = sentinel.discover_master('mymaster') + cluster.nodes_timeout.add(("foo", 26379)) + address = sentinel.discover_master("mymaster") assert address == (master_ip, 6379) # 'bar' is now first sentinel - assert sentinel.sentinels[0].id == ('bar', 26379) + assert sentinel.sentinels[0].id == ("bar", 26379) @pytest.mark.onlynoncluster def test_master_min_other_sentinels(cluster, master_ip): - sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1) + sentinel = Sentinel([("foo", 26379)], min_other_sentinels=1) # min_other_sentinels with pytest.raises(MasterNotFoundError): - sentinel.discover_master('mymaster') - cluster.master['num-other-sentinels'] = 2 - address = sentinel.discover_master('mymaster') + sentinel.discover_master("mymaster") + cluster.master["num-other-sentinels"] = 2 + address = sentinel.discover_master("mymaster") assert address == (master_ip, 6379) @pytest.mark.onlynoncluster def test_master_odown(cluster, sentinel): - cluster.master['is_odown'] = True + cluster.master["is_odown"] = True with pytest.raises(MasterNotFoundError): - sentinel.discover_master('mymaster') + sentinel.discover_master("mymaster") @pytest.mark.onlynoncluster def test_master_sdown(cluster, sentinel): - cluster.master['is_sdown'] = True + cluster.master["is_sdown"] = True with pytest.raises(MasterNotFoundError): - sentinel.discover_master('mymaster') + sentinel.discover_master("mymaster") @pytest.mark.onlynoncluster def test_discover_slaves(cluster, sentinel): - assert sentinel.discover_slaves('mymaster') == [] + assert sentinel.discover_slaves("mymaster") == [] cluster.slaves = [ - {'ip': 'slave0', 'port': 1234, 'is_odown': False, 'is_sdown': False}, - {'ip': 'slave1', 'port': 1234, 'is_odown': False, 'is_sdown': False}, + {"ip": "slave0", "port": 1234, "is_odown": False, "is_sdown": False}, + {"ip": "slave1", "port": 1234, "is_odown": False, "is_sdown": False}, ] - assert sentinel.discover_slaves('mymaster') == [ - ('slave0', 1234), ('slave1', 1234)] + assert sentinel.discover_slaves("mymaster") == [("slave0", 1234), ("slave1", 1234)] # slave0 -> ODOWN - cluster.slaves[0]['is_odown'] = True - assert sentinel.discover_slaves('mymaster') == [ - ('slave1', 1234)] + cluster.slaves[0]["is_odown"] = True + assert sentinel.discover_slaves("mymaster") == [("slave1", 1234)] # slave1 -> SDOWN - cluster.slaves[1]['is_sdown'] = True - assert sentinel.discover_slaves('mymaster') == [] + cluster.slaves[1]["is_sdown"] = True + assert sentinel.discover_slaves("mymaster") == [] - cluster.slaves[0]['is_odown'] = False - cluster.slaves[1]['is_sdown'] = False + cluster.slaves[0]["is_odown"] = False + cluster.slaves[1]["is_sdown"] = False # node0 -> DOWN - cluster.nodes_down.add(('foo', 26379)) - assert sentinel.discover_slaves('mymaster') == [ - ('slave0', 1234), ('slave1', 1234)] + cluster.nodes_down.add(("foo", 26379)) + assert sentinel.discover_slaves("mymaster") == [("slave0", 1234), ("slave1", 1234)] cluster.nodes_down.clear() # node0 -> TIMEOUT - cluster.nodes_timeout.add(('foo', 26379)) - assert sentinel.discover_slaves('mymaster') == [ - ('slave0', 1234), ('slave1', 1234)] + cluster.nodes_timeout.add(("foo", 26379)) + assert sentinel.discover_slaves("mymaster") == [("slave0", 1234), ("slave1", 1234)] @pytest.mark.onlynoncluster def test_master_for(cluster, sentinel, master_ip): - master = sentinel.master_for('mymaster', db=9) + master = sentinel.master_for("mymaster", db=9) assert master.ping() assert master.connection_pool.master_address == (master_ip, 6379) # Use internal connection check - master = sentinel.master_for('mymaster', db=9, check_connection=True) + master = sentinel.master_for("mymaster", db=9, check_connection=True) assert master.ping() @pytest.mark.onlynoncluster def test_slave_for(cluster, sentinel): cluster.slaves = [ - {'ip': '127.0.0.1', 'port': 6379, - 'is_odown': False, 'is_sdown': False}, + {"ip": "127.0.0.1", "port": 6379, "is_odown": False, "is_sdown": False}, ] - slave = sentinel.slave_for('mymaster', db=9) + slave = sentinel.slave_for("mymaster", db=9) assert slave.ping() @pytest.mark.onlynoncluster def test_slave_for_slave_not_found_error(cluster, sentinel): - cluster.master['is_odown'] = True - slave = sentinel.slave_for('mymaster', db=9) + cluster.master["is_odown"] = True + slave = sentinel.slave_for("mymaster", db=9) with pytest.raises(SlaveNotFoundError): slave.ping() @@ -205,13 +205,13 @@ def test_slave_for_slave_not_found_error(cluster, sentinel): @pytest.mark.onlynoncluster def test_slave_round_robin(cluster, sentinel, master_ip): cluster.slaves = [ - {'ip': 'slave0', 'port': 6379, 'is_odown': False, 'is_sdown': False}, - {'ip': 'slave1', 'port': 6379, 'is_odown': False, 'is_sdown': False}, + {"ip": "slave0", "port": 6379, "is_odown": False, "is_sdown": False}, + {"ip": "slave1", "port": 6379, "is_odown": False, "is_sdown": False}, ] - pool = SentinelConnectionPool('mymaster', sentinel) + pool = SentinelConnectionPool("mymaster", sentinel) rotator = pool.rotate_slaves() - assert next(rotator) in (('slave0', 6379), ('slave1', 6379)) - assert next(rotator) in (('slave0', 6379), ('slave1', 6379)) + assert next(rotator) in (("slave0", 6379), ("slave1", 6379)) + assert next(rotator) in (("slave0", 6379), ("slave1", 6379)) # Fallback to master assert next(rotator) == (master_ip, 6379) with pytest.raises(SlaveNotFoundError): @@ -230,5 +230,5 @@ def test_flushconfig(cluster, sentinel): @pytest.mark.onlynoncluster def test_reset(cluster, sentinel): - cluster.master['is_odown'] = True - assert sentinel.sentinel_reset('mymaster') + cluster.master["is_odown"] = True + assert sentinel.sentinel_reset("mymaster") diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index 07433574f1..8c97ab804d 100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -1,6 +1,8 @@ -import pytest import time from time import sleep + +import pytest + from .conftest import skip_ifmodversion_lt @@ -68,8 +70,7 @@ def test_add(client): assert 4 == client.ts().add( 4, 4, 2, retention_msecs=10, labels={"Redis": "Labs", "Time": "Series"} ) - assert round(time.time()) == \ - round(float(client.ts().add(5, "*", 1)) / 1000) + assert round(time.time()) == round(float(client.ts().add(5, "*", 1)) / 1000) info = client.ts().info(4) assert 10 == info.retention_msecs @@ -88,12 +89,7 @@ def test_add_duplicate_policy(client): # Test for duplicate policy BLOCK assert 1 == client.ts().add("time-serie-add-ooo-block", 1, 5.0) with pytest.raises(Exception): - client.ts().add( - "time-serie-add-ooo-block", - 1, - 5.0, - duplicate_policy="block" - ) + client.ts().add("time-serie-add-ooo-block", 1, 5.0, duplicate_policy="block") # Test for duplicate policy LAST assert 1 == client.ts().add("time-serie-add-ooo-last", 1, 5.0) @@ -127,8 +123,7 @@ def test_add_duplicate_policy(client): @pytest.mark.redismod def test_madd(client): client.ts().create("a") - assert [1, 2, 3] == \ - client.ts().madd([("a", 1, 5), ("a", 2, 10), ("a", 3, 15)]) + assert [1, 2, 3] == client.ts().madd([("a", 1, 5), ("a", 2, 10), ("a", 3, 15)]) @pytest.mark.redismod @@ -206,13 +201,7 @@ def test_range(client): assert 200 == len(client.ts().range(1, 0, 500)) # last sample isn't returned assert 20 == len( - client.ts().range( - 1, - 0, - 500, - aggregation_type="avg", - bucket_size_msec=10 - ) + client.ts().range(1, 0, 500, aggregation_type="avg", bucket_size_msec=10) ) assert 10 == len(client.ts().range(1, 0, 500, count=10)) @@ -253,13 +242,7 @@ def test_rev_range(client): assert 200 == len(client.ts().range(1, 0, 500)) # first sample isn't returned assert 20 == len( - client.ts().revrange( - 1, - 0, - 500, - aggregation_type="avg", - bucket_size_msec=10 - ) + client.ts().revrange(1, 0, 500, aggregation_type="avg", bucket_size_msec=10) ) assert 10 == len(client.ts().revrange(1, 0, 500, count=10)) assert 2 == len( @@ -283,10 +266,7 @@ def test_rev_range(client): @pytest.mark.redismod def testMultiRange(client): client.ts().create(1, labels={"Test": "This", "team": "ny"}) - client.ts().create( - 2, - labels={"Test": "This", "Taste": "That", "team": "sf"} - ) + client.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"}) for i in range(100): client.ts().add(1, i, i % 7) client.ts().add(2, i, i % 11) @@ -301,11 +281,7 @@ def testMultiRange(client): for i in range(100): client.ts().add(1, i + 200, i % 7) res = client.ts().mrange( - 0, - 500, - filters=["Test=This"], - aggregation_type="avg", - bucket_size_msec=10 + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 ) assert 2 == len(res) assert 20 == len(res[0]["1"][1]) @@ -320,21 +296,13 @@ def testMultiRange(client): @skip_ifmodversion_lt("99.99.99", "timeseries") def test_multi_range_advanced(client): client.ts().create(1, labels={"Test": "This", "team": "ny"}) - client.ts().create( - 2, - labels={"Test": "This", "Taste": "That", "team": "sf"} - ) + client.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"}) for i in range(100): client.ts().add(1, i, i % 7) client.ts().add(2, i, i % 11) # test with selected labels - res = client.ts().mrange( - 0, - 200, - filters=["Test=This"], - select_labels=["team"] - ) + res = client.ts().mrange(0, 200, filters=["Test=This"], select_labels=["team"]) assert {"team": "ny"} == res[0]["1"][0] assert {"team": "sf"} == res[1]["2"][0] @@ -350,28 +318,11 @@ def test_multi_range_advanced(client): assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] # test groupby - res = client.ts().mrange( - 0, - 3, - filters=["Test=This"], - groupby="Test", - reduce="sum" - ) + res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="sum") assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] - res = client.ts().mrange( - 0, - 3, - filters=["Test=This"], - groupby="Test", - reduce="max" - ) + res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="max") assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] - res = client.ts().mrange( - 0, - 3, - filters=["Test=This"], - groupby="team", - reduce="min") + res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="team", reduce="min") assert 2 == len(res) assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] @@ -401,10 +352,7 @@ def test_multi_range_advanced(client): @skip_ifmodversion_lt("99.99.99", "timeseries") def test_multi_reverse_range(client): client.ts().create(1, labels={"Test": "This", "team": "ny"}) - client.ts().create( - 2, - labels={"Test": "This", "Taste": "That", "team": "sf"} - ) + client.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"}) for i in range(100): client.ts().add(1, i, i % 7) client.ts().add(2, i, i % 11) @@ -419,31 +367,18 @@ def test_multi_reverse_range(client): for i in range(100): client.ts().add(1, i + 200, i % 7) res = client.ts().mrevrange( - 0, - 500, - filters=["Test=This"], - aggregation_type="avg", - bucket_size_msec=10 + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 ) assert 2 == len(res) assert 20 == len(res[0]["1"][1]) assert {} == res[0]["1"][0] # test withlabels - res = client.ts().mrevrange( - 0, - 200, - filters=["Test=This"], - with_labels=True - ) + res = client.ts().mrevrange(0, 200, filters=["Test=This"], with_labels=True) assert {"Test": "This", "team": "ny"} == res[0]["1"][0] # test with selected labels - res = client.ts().mrevrange( - 0, - 200, - filters=["Test=This"], select_labels=["team"] - ) + res = client.ts().mrevrange(0, 200, filters=["Test=This"], select_labels=["team"]) assert {"team": "ny"} == res[0]["1"][0] assert {"team": "sf"} == res[1]["2"][0] @@ -529,11 +464,7 @@ def test_mget(client): @pytest.mark.redismod def test_info(client): - client.ts().create( - 1, - retention_msecs=5, - labels={"currentLabel": "currentData"} - ) + client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) info = client.ts().info(1) assert 5 == info.retention_msecs assert info.labels["currentLabel"] == "currentData" @@ -542,11 +473,7 @@ def test_info(client): @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") def testInfoDuplicatePolicy(client): - client.ts().create( - 1, - retention_msecs=5, - labels={"currentLabel": "currentData"} - ) + client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) info = client.ts().info(1) assert info.duplicate_policy is None diff --git a/tox.ini b/tox.ini index f710bbaca8..0bab8b3d09 100644 --- a/tox.ini +++ b/tox.ini @@ -130,7 +130,9 @@ commands = /usr/bin/echo deps_files = dev_requirements.txt docker = commands = - flake8 --max-line-length=88 + flake8 + black --target-version py38 --check --diff --exclude tests/test_commands.py . + isort --check-only --diff . vulture redis whitelist.py --min-confidence 80 flynt --fail-on-change --dry-run . skipsdist = true @@ -150,6 +152,7 @@ allowlist_externals = make commands = make html [flake8] +max-line-length = 88 exclude = *.egg-info, *.pyc, From 1be25917da91b9cef5fad2b818ef4635a6a65835 Mon Sep 17 00:00:00 2001 From: Anas Date: Sun, 21 Nov 2021 20:28:15 +0200 Subject: [PATCH 2/8] Fixed exclude param of black and removed setup.cfg --- setup.cfg | 7 ------- tox.ini | 10 +++++++++- 2 files changed, 9 insertions(+), 8 deletions(-) delete mode 100644 setup.cfg diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 5abd1b3c7b..0000000000 --- a/setup.cfg +++ /dev/null @@ -1,7 +0,0 @@ -[flake8] -max-line-length = 88 -extend-ignore = E203, W503 - -[isort] -profile = black -multi_line_output = 3 diff --git a/tox.ini b/tox.ini index 0bab8b3d09..50e165f719 100644 --- a/tox.ini +++ b/tox.ini @@ -90,6 +90,9 @@ healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket volumes = bind:rw:{toxinidir}/docker/cluster/redis.conf:/redis.conf +[isort] +profile = black +multi_line_output = 3 [testenv] deps = @@ -131,7 +134,7 @@ deps_files = dev_requirements.txt docker = commands = flake8 - black --target-version py38 --check --diff --exclude tests/test_commands.py . + black --target-version py38 --check --diff --exclude="(tests\/test_commands\.py|\.tox|venv)" . isort --check-only --diff . vulture redis whitelist.py --min-confidence 80 flynt --fail-on-change --dry-run . @@ -165,3 +168,8 @@ exclude = docker, venv*, whitelist.py +ignore = + W503 + E203 + E126 +max-line-length = 88 From c48d244bdd952433674d372b82e6dc3a0a0a88ec Mon Sep 17 00:00:00 2001 From: Anas Date: Sun, 21 Nov 2021 21:57:23 +0200 Subject: [PATCH 3/8] Added pre-commit and solved test_commands.py black error --- .pre-commit-config.yaml | 30 ++++++++++++++++++++++++++++++ dev_requirements.txt | 3 ++- 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..d4bc54d983 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,30 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-ast + - id: check-case-conflict + - id: check-docstring-first + - id: check-merge-conflict + - id: check-symlinks + - id: debug-statements + +- repo: https://github.com/asottile/pyupgrade + rev: v2.29.0 + hooks: + - id: pyupgrade + +- repo: https://github.com/psf/black + rev: 21.11b1 + hooks: + - id: black + +- repo: https://github.com/PyCQA/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + +- repo: https://github.com/PyCQA/isort + rev: 5.10.1 + hooks: + - id: isort diff --git a/dev_requirements.txt b/dev_requirements.txt index e1c131e2c9..248f567166 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,5 +1,5 @@ black==21.11b1 -flake8>=3.9.2 +flake8==4.0.1 flynt~=0.69.0 isort==5.10.1 pytest==6.2.5 @@ -7,6 +7,7 @@ pytest-timeout==2.0.1 tox==3.24.4 tox-docker==3.1.0 invoke==1.6.0 +pre-commit==2.15.0 pytest-cov>=3.0.0 vulture>=2.3.0 ujson>=4.2.0 From 8c722c95cfdac2b4149fc24fd441c2b0a8adef2e Mon Sep 17 00:00:00 2001 From: Anas Date: Mon, 22 Nov 2021 12:03:56 +0200 Subject: [PATCH 4/8] Removed pre-commit --- .pre-commit-config.yaml | 30 ------------------------------ dev_requirements.txt | 1 - 2 files changed, 31 deletions(-) delete mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index d4bc54d983..0000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,30 +0,0 @@ -repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 - hooks: - - id: check-ast - - id: check-case-conflict - - id: check-docstring-first - - id: check-merge-conflict - - id: check-symlinks - - id: debug-statements - -- repo: https://github.com/asottile/pyupgrade - rev: v2.29.0 - hooks: - - id: pyupgrade - -- repo: https://github.com/psf/black - rev: 21.11b1 - hooks: - - id: black - -- repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 - hooks: - - id: flake8 - -- repo: https://github.com/PyCQA/isort - rev: 5.10.1 - hooks: - - id: isort diff --git a/dev_requirements.txt b/dev_requirements.txt index 248f567166..2a4f37762f 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -7,7 +7,6 @@ pytest-timeout==2.0.1 tox==3.24.4 tox-docker==3.1.0 invoke==1.6.0 -pre-commit==2.15.0 pytest-cov>=3.0.0 vulture>=2.3.0 ujson>=4.2.0 From 9be7eb5343d4fe76e7a460fbe9ec49c21c7b7928 Mon Sep 17 00:00:00 2001 From: Anas Date: Mon, 22 Nov 2021 12:09:07 +0200 Subject: [PATCH 5/8] Removed tests/test_commands.py from exclusion of black --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 50e165f719..9981f4f77e 100644 --- a/tox.ini +++ b/tox.ini @@ -134,7 +134,7 @@ deps_files = dev_requirements.txt docker = commands = flake8 - black --target-version py38 --check --diff --exclude="(tests\/test_commands\.py|\.tox|venv)" . + black --target-version py38 --check --diff . isort --check-only --diff . vulture redis whitelist.py --min-confidence 80 flynt --fail-on-change --dry-run . From a66e857d0cb735cc2976537e00c2488cbc07f597 Mon Sep 17 00:00:00 2001 From: Anas Date: Mon, 29 Nov 2021 14:23:08 +0200 Subject: [PATCH 6/8] Changed target version of black --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 9981f4f77e..90ff8e71ea 100644 --- a/tox.ini +++ b/tox.ini @@ -134,7 +134,7 @@ deps_files = dev_requirements.txt docker = commands = flake8 - black --target-version py38 --check --diff . + black --target-version py36 --check --diff . isort --check-only --diff . vulture redis whitelist.py --min-confidence 80 flynt --fail-on-change --dry-run . From 08f980efb92015bc961c3d59e0f7671a065c947e Mon Sep 17 00:00:00 2001 From: Anas Date: Mon, 29 Nov 2021 14:53:10 +0200 Subject: [PATCH 7/8] Made linters happy --- docs/conf.py | 27 +- redis/__init__.py | 74 ++-- redis/commands/__init__.py | 12 +- redis/commands/helpers.py | 12 +- redis/commands/search/commands.py | 19 +- redis/crc.py | 7 +- redis/exceptions.py | 12 +- setup.py | 7 +- tests/test_command_parser.py | 100 +++--- tests/test_connection_pool.py | 569 +++++++++++++++--------------- tests/test_helpers.py | 59 ++-- tests/test_json.py | 106 ++---- tests/test_monitor.py | 1 + tests/test_search.py | 410 ++++++++------------- 14 files changed, 657 insertions(+), 758 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 8520969d24..7e83e42156 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -55,7 +55,8 @@ # # The short X.Y version. import redis -version = '.'.join(redis.__version__.split(".")[0:2]) + +version = ".".join(redis.__version__.split(".")[0:2]) # The full version, including alpha/beta/rc tags. release = redis.__version__ @@ -108,13 +109,13 @@ # further. For a list of options available for each theme, see the # documentation. html_theme_options = { - 'display_version': True, - 'prev_next_buttons_location': 'bottom', - 'style_external_links': False, + "display_version": True, + "prev_next_buttons_location": "bottom", + "style_external_links": False, # Toc options - 'collapse_navigation': True, - 'sticky_navigation': True, - 'navigation_depth': 4, + "collapse_navigation": True, + "sticky_navigation": True, + "navigation_depth": 4, } # Add any paths that contain custom themes here, relative to this directory. @@ -201,11 +202,7 @@ # (source start file, target name, title, author, documentclass # [howto/manual]). latex_documents = [ - ("index", - "redis-py.tex", - "redis-py Documentation", - "Redis Inc", - "manual"), + ("index", "redis-py.tex", "redis-py Documentation", "Redis Inc", "manual"), ] # The name of an image file (relative to this directory) to place at the top of @@ -233,11 +230,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [( - "index", - "redis-py", - "redis-py Documentation", - ["Andy McCurdy"], 1)] +man_pages = [("index", "redis-py", "redis-py Documentation", ["Andy McCurdy"], 1)] # If true, show URL addresses after external links. # man_show_urls = False diff --git a/redis/__init__.py b/redis/__init__.py index bc7f3c9d9c..33b4369382 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -2,18 +2,11 @@ from redis.cluster import RedisCluster from redis.connection import ( BlockingConnectionPool, - ConnectionPool, Connection, + ConnectionPool, SSLConnection, - UnixDomainSocketConnection -) -from redis.sentinel import ( - Sentinel, - SentinelConnectionPool, - SentinelManagedConnection, - SentinelManagedSSLConnection, + UnixDomainSocketConnection, ) -from redis.utils import from_url from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -27,8 +20,15 @@ RedisError, ResponseError, TimeoutError, - WatchError + WatchError, +) +from redis.sentinel import ( + Sentinel, + SentinelConnectionPool, + SentinelManagedConnection, + SentinelManagedSSLConnection, ) +from redis.utils import from_url def int_or_str(value): @@ -41,33 +41,33 @@ def int_or_str(value): __version__ = "4.1.0rc1" -VERSION = tuple(map(int_or_str, __version__.split('.'))) +VERSION = tuple(map(int_or_str, __version__.split("."))) __all__ = [ - 'AuthenticationError', - 'AuthenticationWrongNumberOfArgsError', - 'BlockingConnectionPool', - 'BusyLoadingError', - 'ChildDeadlockedError', - 'Connection', - 'ConnectionError', - 'ConnectionPool', - 'DataError', - 'from_url', - 'InvalidResponse', - 'PubSubError', - 'ReadOnlyError', - 'Redis', - 'RedisCluster', - 'RedisError', - 'ResponseError', - 'Sentinel', - 'SentinelConnectionPool', - 'SentinelManagedConnection', - 'SentinelManagedSSLConnection', - 'SSLConnection', - 'StrictRedis', - 'TimeoutError', - 'UnixDomainSocketConnection', - 'WatchError', + "AuthenticationError", + "AuthenticationWrongNumberOfArgsError", + "BlockingConnectionPool", + "BusyLoadingError", + "ChildDeadlockedError", + "Connection", + "ConnectionError", + "ConnectionPool", + "DataError", + "from_url", + "InvalidResponse", + "PubSubError", + "ReadOnlyError", + "Redis", + "RedisCluster", + "RedisError", + "ResponseError", + "Sentinel", + "SentinelConnectionPool", + "SentinelManagedConnection", + "SentinelManagedSSLConnection", + "SSLConnection", + "StrictRedis", + "TimeoutError", + "UnixDomainSocketConnection", + "WatchError", ] diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index a4728d0ac4..bc1e78c60c 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -6,10 +6,10 @@ from .sentinel import SentinelCommands __all__ = [ - 'ClusterCommands', - 'CommandsParser', - 'CoreCommands', - 'list_or_args', - 'RedisModuleCommands', - 'SentinelCommands' + "ClusterCommands", + "CommandsParser", + "CoreCommands", + "list_or_args", + "RedisModuleCommands", + "SentinelCommands", ] diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py index dc5705b80b..80dfd76a15 100644 --- a/redis/commands/helpers.py +++ b/redis/commands/helpers.py @@ -22,7 +22,7 @@ def list_or_args(keys, args): def nativestr(x): """Return the decoded binary string, or a string, depending on type.""" r = x.decode("utf-8", "replace") if isinstance(x, bytes) else x - if r == 'null': + if r == "null": return return r @@ -58,14 +58,14 @@ def parse_list_to_dict(response): res = {} for i in range(0, len(response), 2): if isinstance(response[i], list): - res['Child iterators'].append(parse_list_to_dict(response[i])) - elif isinstance(response[i+1], list): - res['Child iterators'] = [parse_list_to_dict(response[i+1])] + res["Child iterators"].append(parse_list_to_dict(response[i])) + elif isinstance(response[i + 1], list): + res["Child iterators"] = [parse_list_to_dict(response[i + 1])] else: try: - res[response[i]] = float(response[i+1]) + res[response[i]] = float(response[i + 1]) except (TypeError, ValueError): - res[response[i]] = response[i+1] + res[response[i]] = response[i + 1] return res diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index cfc82cd127..09e50855ba 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -1,13 +1,13 @@ import itertools import time +from ..helpers import parse_to_dict from ._util import to_string from .aggregation import AggregateRequest, AggregateResult, Cursor from .document import Document from .query import Query from .result import Result from .suggestion import SuggestionParser -from ..helpers import parse_to_dict NUMERIC = "NUMERIC" @@ -453,7 +453,7 @@ def profile(self, query, limited=False): cmd = [PROFILE_CMD, self.index_name, ""] if limited: cmd.append("LIMITED") - cmd.append('QUERY') + cmd.append("QUERY") if isinstance(query, AggregateRequest): cmd[2] = "AGGREGATE" @@ -462,19 +462,20 @@ def profile(self, query, limited=False): cmd[2] = "SEARCH" cmd += query.get_args() else: - raise ValueError("Must provide AggregateRequest object or " - "Query object.") + raise ValueError("Must provide AggregateRequest object or " "Query object.") res = self.execute_command(*cmd) if isinstance(query, AggregateRequest): result = self._get_AggregateResult(res[0], query, query._cursor) else: - result = Result(res[0], - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores,) + result = Result( + res[0], + not query._no_content, + duration=(time.time() - st) * 1000.0, + has_payload=query._with_payloads, + with_scores=query._with_scores, + ) return result, parse_to_dict(res[1]) diff --git a/redis/crc.py b/redis/crc.py index 7d2ee507be..c47e2acede 100644 --- a/redis/crc.py +++ b/redis/crc.py @@ -4,10 +4,7 @@ # For more information see: https://github.com/redis/redis/issues/2576 REDIS_CLUSTER_HASH_SLOTS = 16384 -__all__ = [ - "key_slot", - "REDIS_CLUSTER_HASH_SLOTS" -] +__all__ = ["key_slot", "REDIS_CLUSTER_HASH_SLOTS"] def key_slot(key, bucket=REDIS_CLUSTER_HASH_SLOTS): @@ -20,5 +17,5 @@ def key_slot(key, bucket=REDIS_CLUSTER_HASH_SLOTS): if start > -1: end = key.find(b"}", start + 1) if end > -1 and end != start + 1: - key = key[start + 1: end] + key = key[start + 1 : end] return crc_hqx(key, 0) % bucket diff --git a/redis/exceptions.py b/redis/exceptions.py index 4d5d530925..e37cad358e 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -91,6 +91,7 @@ class RedisClusterException(Exception): """ Base exception for the RedisCluster client """ + pass @@ -99,6 +100,7 @@ class ClusterError(RedisError): Cluster errors occurred multiple times, resulting in an exhaustion of the command execution TTL """ + pass @@ -112,6 +114,7 @@ class ClusterDownError(ClusterError, ResponseError): unavailable. It automatically returns available as soon as all the slots are covered again. """ + def __init__(self, resp): self.args = (resp,) self.message = resp @@ -136,8 +139,8 @@ def __init__(self, resp): """should only redirect to master node""" self.args = (resp,) self.message = resp - slot_id, new_node = resp.split(' ') - host, port = new_node.rsplit(':', 1) + slot_id, new_node = resp.split(" ") + host, port = new_node.rsplit(":", 1) self.slot_id = int(slot_id) self.node_addr = self.host, self.port = host, int(port) @@ -148,6 +151,7 @@ class TryAgainError(ResponseError): Operations on keys that don't exist or are - during resharding - split between the source and destination nodes, will generate a -TRYAGAIN error. """ + def __init__(self, *args, **kwargs): pass @@ -158,6 +162,7 @@ class ClusterCrossSlotError(ResponseError): A CROSSSLOT error is generated when keys in a request don't hash to the same slot. """ + message = "Keys in request don't hash to the same slot" @@ -167,6 +172,7 @@ class MovedError(AskError): A request sent to a node that doesn't serve this key will be replayed with a MOVED error that points to the correct node. """ + pass @@ -175,6 +181,7 @@ class MasterDownError(ClusterDownError): Error indicated MASTERDOWN error received from cluster. Link with MASTER is down and replica-serve-stale-data is set to 'no'. """ + pass @@ -186,4 +193,5 @@ class SlotNotCoveredError(RedisClusterException): If this error is raised the client should drop the current node layout and attempt to reconnect and refresh the node layout again """ + pass diff --git a/setup.py b/setup.py index 9acb501633..ee91298289 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ #!/usr/bin/env python -from setuptools import setup, find_packages +from setuptools import find_packages, setup + import redis setup( @@ -24,8 +25,8 @@ author_email="oss@redis.com", python_requires=">=3.6", install_requires=[ - 'deprecated==1.2.3', - 'packaging==21.3', + "deprecated==1.2.3", + "packaging==21.3", ], classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index ba129ba673..ad29e69f37 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -7,56 +7,74 @@ class TestCommandsParser: def test_init_commands(self, r): commands_parser = CommandsParser(r) assert commands_parser.commands is not None - assert 'get' in commands_parser.commands + assert "get" in commands_parser.commands def test_get_keys_predetermined_key_location(self, r): commands_parser = CommandsParser(r) - args1 = ['GET', 'foo'] - args2 = ['OBJECT', 'encoding', 'foo'] - args3 = ['MGET', 'foo', 'bar', 'foobar'] - assert commands_parser.get_keys(r, *args1) == ['foo'] - assert commands_parser.get_keys(r, *args2) == ['foo'] - assert commands_parser.get_keys(r, *args3) == ['foo', 'bar', 'foobar'] + args1 = ["GET", "foo"] + args2 = ["OBJECT", "encoding", "foo"] + args3 = ["MGET", "foo", "bar", "foobar"] + assert commands_parser.get_keys(r, *args1) == ["foo"] + assert commands_parser.get_keys(r, *args2) == ["foo"] + assert commands_parser.get_keys(r, *args3) == ["foo", "bar", "foobar"] @pytest.mark.filterwarnings("ignore:ResponseError") def test_get_moveable_keys(self, r): commands_parser = CommandsParser(r) - args1 = ['EVAL', 'return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}', 2, 'key1', - 'key2', 'first', 'second'] - args2 = ['XREAD', 'COUNT', 2, b'STREAMS', 'mystream', 'writers', 0, 0] - args3 = ['ZUNIONSTORE', 'out', 2, 'zset1', 'zset2', 'WEIGHTS', 2, 3] - args4 = ['GEORADIUS', 'Sicily', 15, 37, 200, 'km', 'WITHCOORD', - b'STORE', 'out'] - args5 = ['MEMORY USAGE', 'foo'] - args6 = ['MIGRATE', '192.168.1.34', 6379, "", 0, 5000, b'KEYS', - 'key1', 'key2', 'key3'] - args7 = ['MIGRATE', '192.168.1.34', 6379, "key1", 0, 5000] - args8 = ['STRALGO', 'LCS', 'STRINGS', 'string_a', 'string_b'] - args9 = ['STRALGO', 'LCS', 'KEYS', 'key1', 'key2'] + args1 = [ + "EVAL", + "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", + 2, + "key1", + "key2", + "first", + "second", + ] + args2 = ["XREAD", "COUNT", 2, b"STREAMS", "mystream", "writers", 0, 0] + args3 = ["ZUNIONSTORE", "out", 2, "zset1", "zset2", "WEIGHTS", 2, 3] + args4 = ["GEORADIUS", "Sicily", 15, 37, 200, "km", "WITHCOORD", b"STORE", "out"] + args5 = ["MEMORY USAGE", "foo"] + args6 = [ + "MIGRATE", + "192.168.1.34", + 6379, + "", + 0, + 5000, + b"KEYS", + "key1", + "key2", + "key3", + ] + args7 = ["MIGRATE", "192.168.1.34", 6379, "key1", 0, 5000] + args8 = ["STRALGO", "LCS", "STRINGS", "string_a", "string_b"] + args9 = ["STRALGO", "LCS", "KEYS", "key1", "key2"] - assert commands_parser.get_keys( - r, *args1).sort() == ['key1', 'key2'].sort() - assert commands_parser.get_keys( - r, *args2).sort() == ['mystream', 'writers'].sort() - assert commands_parser.get_keys( - r, *args3).sort() == ['out', 'zset1', 'zset2'].sort() - assert commands_parser.get_keys( - r, *args4).sort() == ['Sicily', 'out'].sort() - assert commands_parser.get_keys(r, *args5).sort() == ['foo'].sort() - assert commands_parser.get_keys( - r, *args6).sort() == ['key1', 'key2', 'key3'].sort() - assert commands_parser.get_keys(r, *args7).sort() == ['key1'].sort() + assert commands_parser.get_keys(r, *args1).sort() == ["key1", "key2"].sort() + assert ( + commands_parser.get_keys(r, *args2).sort() == ["mystream", "writers"].sort() + ) + assert ( + commands_parser.get_keys(r, *args3).sort() + == ["out", "zset1", "zset2"].sort() + ) + assert commands_parser.get_keys(r, *args4).sort() == ["Sicily", "out"].sort() + assert commands_parser.get_keys(r, *args5).sort() == ["foo"].sort() + assert ( + commands_parser.get_keys(r, *args6).sort() + == ["key1", "key2", "key3"].sort() + ) + assert commands_parser.get_keys(r, *args7).sort() == ["key1"].sort() assert commands_parser.get_keys(r, *args8) is None - assert commands_parser.get_keys( - r, *args9).sort() == ['key1', 'key2'].sort() + assert commands_parser.get_keys(r, *args9).sort() == ["key1", "key2"].sort() def test_get_pubsub_keys(self, r): commands_parser = CommandsParser(r) - args1 = ['PUBLISH', 'foo', 'bar'] - args2 = ['PUBSUB NUMSUB', 'foo1', 'foo2', 'foo3'] - args3 = ['PUBSUB channels', '*'] - args4 = ['SUBSCRIBE', 'foo1', 'foo2', 'foo3'] - assert commands_parser.get_keys(r, *args1) == ['foo'] - assert commands_parser.get_keys(r, *args2) == ['foo1', 'foo2', 'foo3'] - assert commands_parser.get_keys(r, *args3) == ['*'] - assert commands_parser.get_keys(r, *args4) == ['foo1', 'foo2', 'foo3'] + args1 = ["PUBLISH", "foo", "bar"] + args2 = ["PUBSUB NUMSUB", "foo1", "foo2", "foo3"] + args3 = ["PUBSUB channels", "*"] + args4 = ["SUBSCRIBE", "foo1", "foo2", "foo3"] + assert commands_parser.get_keys(r, *args1) == ["foo"] + assert commands_parser.get_keys(r, *args2) == ["foo1", "foo2", "foo3"] + assert commands_parser.get_keys(r, *args3) == ["*"] + assert commands_parser.get_keys(r, *args4) == ["foo1", "foo2", "foo3"] diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 288d43dfd7..2602af82e1 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -1,17 +1,15 @@ import os -import pytest import re -import redis import time +from threading import Thread from unittest import mock -from threading import Thread +import pytest + +import redis from redis.connection import ssl_available, to_bool -from .conftest import ( - skip_if_server_version_lt, - skip_if_redis_enterprise, - _get_client -) + +from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt from .test_pubsub import wait_for_message @@ -30,107 +28,122 @@ def can_read(self): class TestConnectionPool: - def get_pool(self, connection_kwargs=None, max_connections=None, - connection_class=redis.Connection): + def get_pool( + self, + connection_kwargs=None, + max_connections=None, + connection_class=redis.Connection, + ): connection_kwargs = connection_kwargs or {} pool = redis.ConnectionPool( connection_class=connection_class, max_connections=max_connections, - **connection_kwargs) + **connection_kwargs, + ) return pool def test_connection_creation(self): - connection_kwargs = {'foo': 'bar', 'biz': 'baz'} - pool = self.get_pool(connection_kwargs=connection_kwargs, - connection_class=DummyConnection) - connection = pool.get_connection('_') + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = self.get_pool( + connection_kwargs=connection_kwargs, connection_class=DummyConnection + ) + connection = pool.get_connection("_") assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs def test_multiple_connections(self, master_host): - connection_kwargs = {'host': master_host[0], 'port': master_host[1]} + connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection('_') - c2 = pool.get_connection('_') + c1 = pool.get_connection("_") + c2 = pool.get_connection("_") assert c1 != c2 def test_max_connections(self, master_host): - connection_kwargs = {'host': master_host[0], 'port': master_host[1]} - pool = self.get_pool(max_connections=2, - connection_kwargs=connection_kwargs) - pool.get_connection('_') - pool.get_connection('_') + connection_kwargs = {"host": master_host[0], "port": master_host[1]} + pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) + pool.get_connection("_") + pool.get_connection("_") with pytest.raises(redis.ConnectionError): - pool.get_connection('_') + pool.get_connection("_") def test_reuse_previously_released_connection(self, master_host): - connection_kwargs = {'host': master_host[0], 'port': master_host[1]} + connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection('_') + c1 = pool.get_connection("_") pool.release(c1) - c2 = pool.get_connection('_') + c2 = pool.get_connection("_") assert c1 == c2 def test_repr_contains_db_info_tcp(self): connection_kwargs = { - 'host': 'localhost', - 'port': 6379, - 'db': 1, - 'client_name': 'test-client' + "host": "localhost", + "port": 6379, + "db": 1, + "client_name": "test-client", } - pool = self.get_pool(connection_kwargs=connection_kwargs, - connection_class=redis.Connection) - expected = ('ConnectionPool>') + pool = self.get_pool( + connection_kwargs=connection_kwargs, connection_class=redis.Connection + ) + expected = ( + "ConnectionPool>" + ) assert repr(pool) == expected def test_repr_contains_db_info_unix(self): - connection_kwargs = { - 'path': '/abc', - 'db': 1, - 'client_name': 'test-client' - } - pool = self.get_pool(connection_kwargs=connection_kwargs, - connection_class=redis.UnixDomainSocketConnection) - expected = ('ConnectionPool>') + connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"} + pool = self.get_pool( + connection_kwargs=connection_kwargs, + connection_class=redis.UnixDomainSocketConnection, + ) + expected = ( + "ConnectionPool>" + ) assert repr(pool) == expected class TestBlockingConnectionPool: def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): connection_kwargs = connection_kwargs or {} - pool = redis.BlockingConnectionPool(connection_class=DummyConnection, - max_connections=max_connections, - timeout=timeout, - **connection_kwargs) + pool = redis.BlockingConnectionPool( + connection_class=DummyConnection, + max_connections=max_connections, + timeout=timeout, + **connection_kwargs, + ) return pool def test_connection_creation(self, master_host): - connection_kwargs = {'foo': 'bar', 'biz': 'baz', - 'host': master_host[0], 'port': master_host[1]} + connection_kwargs = { + "foo": "bar", + "biz": "baz", + "host": master_host[0], + "port": master_host[1], + } pool = self.get_pool(connection_kwargs=connection_kwargs) - connection = pool.get_connection('_') + connection = pool.get_connection("_") assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs def test_multiple_connections(self, master_host): - connection_kwargs = {'host': master_host[0], 'port': master_host[1]} + connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection('_') - c2 = pool.get_connection('_') + c1 = pool.get_connection("_") + c2 = pool.get_connection("_") assert c1 != c2 def test_connection_pool_blocks_until_timeout(self, master_host): "When out of connections, block for timeout seconds, then raise" - connection_kwargs = {'host': master_host[0], 'port': master_host[1]} - pool = self.get_pool(max_connections=1, timeout=0.1, - connection_kwargs=connection_kwargs) - pool.get_connection('_') + connection_kwargs = {"host": master_host[0], "port": master_host[1]} + pool = self.get_pool( + max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs + ) + pool.get_connection("_") start = time.time() with pytest.raises(redis.ConnectionError): - pool.get_connection('_') + pool.get_connection("_") # we should have waited at least 0.1 seconds assert time.time() - start >= 0.1 @@ -139,10 +152,11 @@ def test_connection_pool_blocks_until_conn_available(self, master_host): When out of connections, block until another connection is released to the pool """ - connection_kwargs = {'host': master_host[0], 'port': master_host[1]} - pool = self.get_pool(max_connections=1, timeout=2, - connection_kwargs=connection_kwargs) - c1 = pool.get_connection('_') + connection_kwargs = {"host": master_host[0], "port": master_host[1]} + pool = self.get_pool( + max_connections=1, timeout=2, connection_kwargs=connection_kwargs + ) + c1 = pool.get_connection("_") def target(): time.sleep(0.1) @@ -150,294 +164,295 @@ def target(): start = time.time() Thread(target=target).start() - pool.get_connection('_') + pool.get_connection("_") assert time.time() - start >= 0.1 def test_reuse_previously_released_connection(self, master_host): - connection_kwargs = {'host': master_host[0], 'port': master_host[1]} + connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection('_') + c1 = pool.get_connection("_") pool.release(c1) - c2 = pool.get_connection('_') + c2 = pool.get_connection("_") assert c1 == c2 def test_repr_contains_db_info_tcp(self): pool = redis.ConnectionPool( - host='localhost', - port=6379, - client_name='test-client' + host="localhost", port=6379, client_name="test-client" + ) + expected = ( + "ConnectionPool>" ) - expected = ('ConnectionPool>') assert repr(pool) == expected def test_repr_contains_db_info_unix(self): pool = redis.ConnectionPool( connection_class=redis.UnixDomainSocketConnection, - path='abc', - client_name='test-client' + path="abc", + client_name="test-client", + ) + expected = ( + "ConnectionPool>" ) - expected = ('ConnectionPool>') assert repr(pool) == expected class TestConnectionPoolURLParsing: def test_hostname(self): - pool = redis.ConnectionPool.from_url('redis://my.host') + pool = redis.ConnectionPool.from_url("redis://my.host") assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'my.host', + "host": "my.host", } def test_quoted_hostname(self): - pool = redis.ConnectionPool.from_url('redis://my %2F host %2B%3D+') + pool = redis.ConnectionPool.from_url("redis://my %2F host %2B%3D+") assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'my / host +=+', + "host": "my / host +=+", } def test_port(self): - pool = redis.ConnectionPool.from_url('redis://localhost:6380') + pool = redis.ConnectionPool.from_url("redis://localhost:6380") assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'localhost', - 'port': 6380, + "host": "localhost", + "port": 6380, } @skip_if_server_version_lt("6.0.0") def test_username(self): - pool = redis.ConnectionPool.from_url('redis://myuser:@localhost') + pool = redis.ConnectionPool.from_url("redis://myuser:@localhost") assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'localhost', - 'username': 'myuser', + "host": "localhost", + "username": "myuser", } @skip_if_server_version_lt("6.0.0") def test_quoted_username(self): pool = redis.ConnectionPool.from_url( - 'redis://%2Fmyuser%2F%2B name%3D%24+:@localhost') + "redis://%2Fmyuser%2F%2B name%3D%24+:@localhost" + ) assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'localhost', - 'username': '/myuser/+ name=$+', + "host": "localhost", + "username": "/myuser/+ name=$+", } def test_password(self): - pool = redis.ConnectionPool.from_url('redis://:mypassword@localhost') + pool = redis.ConnectionPool.from_url("redis://:mypassword@localhost") assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'localhost', - 'password': 'mypassword', + "host": "localhost", + "password": "mypassword", } def test_quoted_password(self): pool = redis.ConnectionPool.from_url( - 'redis://:%2Fmypass%2F%2B word%3D%24+@localhost') + "redis://:%2Fmypass%2F%2B word%3D%24+@localhost" + ) assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'localhost', - 'password': '/mypass/+ word=$+', + "host": "localhost", + "password": "/mypass/+ word=$+", } @skip_if_server_version_lt("6.0.0") def test_username_and_password(self): - pool = redis.ConnectionPool.from_url('redis://myuser:mypass@localhost') + pool = redis.ConnectionPool.from_url("redis://myuser:mypass@localhost") assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'localhost', - 'username': 'myuser', - 'password': 'mypass', + "host": "localhost", + "username": "myuser", + "password": "mypass", } def test_db_as_argument(self): - pool = redis.ConnectionPool.from_url('redis://localhost', db=1) + pool = redis.ConnectionPool.from_url("redis://localhost", db=1) assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'localhost', - 'db': 1, + "host": "localhost", + "db": 1, } def test_db_in_path(self): - pool = redis.ConnectionPool.from_url('redis://localhost/2', db=1) + pool = redis.ConnectionPool.from_url("redis://localhost/2", db=1) assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'localhost', - 'db': 2, + "host": "localhost", + "db": 2, } def test_db_in_querystring(self): - pool = redis.ConnectionPool.from_url('redis://localhost/2?db=3', - db=1) + pool = redis.ConnectionPool.from_url("redis://localhost/2?db=3", db=1) assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'localhost', - 'db': 3, + "host": "localhost", + "db": 3, } def test_extra_typed_querystring_options(self): pool = redis.ConnectionPool.from_url( - 'redis://localhost/2?socket_timeout=20&socket_connect_timeout=10' - '&socket_keepalive=&retry_on_timeout=Yes&max_connections=10' + "redis://localhost/2?socket_timeout=20&socket_connect_timeout=10" + "&socket_keepalive=&retry_on_timeout=Yes&max_connections=10" ) assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { - 'host': 'localhost', - 'db': 2, - 'socket_timeout': 20.0, - 'socket_connect_timeout': 10.0, - 'retry_on_timeout': True, + "host": "localhost", + "db": 2, + "socket_timeout": 20.0, + "socket_connect_timeout": 10.0, + "retry_on_timeout": True, } assert pool.max_connections == 10 def test_boolean_parsing(self): for expected, value in ( - (None, None), - (None, ''), - (False, 0), (False, '0'), - (False, 'f'), (False, 'F'), (False, 'False'), - (False, 'n'), (False, 'N'), (False, 'No'), - (True, 1), (True, '1'), - (True, 'y'), (True, 'Y'), (True, 'Yes'), + (None, None), + (None, ""), + (False, 0), + (False, "0"), + (False, "f"), + (False, "F"), + (False, "False"), + (False, "n"), + (False, "N"), + (False, "No"), + (True, 1), + (True, "1"), + (True, "y"), + (True, "Y"), + (True, "Yes"), ): assert expected is to_bool(value) def test_client_name_in_querystring(self): - pool = redis.ConnectionPool.from_url( - 'redis://location?client_name=test-client' - ) - assert pool.connection_kwargs['client_name'] == 'test-client' + pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client") + assert pool.connection_kwargs["client_name"] == "test-client" def test_invalid_extra_typed_querystring_options(self): with pytest.raises(ValueError): redis.ConnectionPool.from_url( - 'redis://localhost/2?socket_timeout=_&' - 'socket_connect_timeout=abc' + "redis://localhost/2?socket_timeout=_&" "socket_connect_timeout=abc" ) def test_extra_querystring_options(self): - pool = redis.ConnectionPool.from_url('redis://localhost?a=1&b=2') + pool = redis.ConnectionPool.from_url("redis://localhost?a=1&b=2") assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - 'host': 'localhost', - 'a': '1', - 'b': '2' - } + assert pool.connection_kwargs == {"host": "localhost", "a": "1", "b": "2"} def test_calling_from_subclass_returns_correct_instance(self): - pool = redis.BlockingConnectionPool.from_url('redis://localhost') + pool = redis.BlockingConnectionPool.from_url("redis://localhost") assert isinstance(pool, redis.BlockingConnectionPool) def test_client_creates_connection_pool(self): - r = redis.Redis.from_url('redis://myhost') + r = redis.Redis.from_url("redis://myhost") assert r.connection_pool.connection_class == redis.Connection assert r.connection_pool.connection_kwargs == { - 'host': 'myhost', + "host": "myhost", } def test_invalid_scheme_raises_error(self): with pytest.raises(ValueError) as cm: - redis.ConnectionPool.from_url('localhost') + redis.ConnectionPool.from_url("localhost") assert str(cm.value) == ( - 'Redis URL must specify one of the following schemes ' - '(redis://, rediss://, unix://)' + "Redis URL must specify one of the following schemes " + "(redis://, rediss://, unix://)" ) class TestConnectionPoolUnixSocketURLParsing: def test_defaults(self): - pool = redis.ConnectionPool.from_url('unix:///socket') + pool = redis.ConnectionPool.from_url("unix:///socket") assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { - 'path': '/socket', + "path": "/socket", } @skip_if_server_version_lt("6.0.0") def test_username(self): - pool = redis.ConnectionPool.from_url('unix://myuser:@/socket') + pool = redis.ConnectionPool.from_url("unix://myuser:@/socket") assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { - 'path': '/socket', - 'username': 'myuser', + "path": "/socket", + "username": "myuser", } @skip_if_server_version_lt("6.0.0") def test_quoted_username(self): pool = redis.ConnectionPool.from_url( - 'unix://%2Fmyuser%2F%2B name%3D%24+:@/socket') + "unix://%2Fmyuser%2F%2B name%3D%24+:@/socket" + ) assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { - 'path': '/socket', - 'username': '/myuser/+ name=$+', + "path": "/socket", + "username": "/myuser/+ name=$+", } def test_password(self): - pool = redis.ConnectionPool.from_url('unix://:mypassword@/socket') + pool = redis.ConnectionPool.from_url("unix://:mypassword@/socket") assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { - 'path': '/socket', - 'password': 'mypassword', + "path": "/socket", + "password": "mypassword", } def test_quoted_password(self): pool = redis.ConnectionPool.from_url( - 'unix://:%2Fmypass%2F%2B word%3D%24+@/socket') + "unix://:%2Fmypass%2F%2B word%3D%24+@/socket" + ) assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { - 'path': '/socket', - 'password': '/mypass/+ word=$+', + "path": "/socket", + "password": "/mypass/+ word=$+", } def test_quoted_path(self): pool = redis.ConnectionPool.from_url( - 'unix://:mypassword@/my%2Fpath%2Fto%2F..%2F+_%2B%3D%24ocket') + "unix://:mypassword@/my%2Fpath%2Fto%2F..%2F+_%2B%3D%24ocket" + ) assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { - 'path': '/my/path/to/../+_+=$ocket', - 'password': 'mypassword', + "path": "/my/path/to/../+_+=$ocket", + "password": "mypassword", } def test_db_as_argument(self): - pool = redis.ConnectionPool.from_url('unix:///socket', db=1) + pool = redis.ConnectionPool.from_url("unix:///socket", db=1) assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { - 'path': '/socket', - 'db': 1, + "path": "/socket", + "db": 1, } def test_db_in_querystring(self): - pool = redis.ConnectionPool.from_url('unix:///socket?db=2', db=1) + pool = redis.ConnectionPool.from_url("unix:///socket?db=2", db=1) assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { - 'path': '/socket', - 'db': 2, + "path": "/socket", + "db": 2, } def test_client_name_in_querystring(self): - pool = redis.ConnectionPool.from_url( - 'redis://location?client_name=test-client' - ) - assert pool.connection_kwargs['client_name'] == 'test-client' + pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client") + assert pool.connection_kwargs["client_name"] == "test-client" def test_extra_querystring_options(self): - pool = redis.ConnectionPool.from_url('unix:///socket?a=1&b=2') + pool = redis.ConnectionPool.from_url("unix:///socket?a=1&b=2") assert pool.connection_class == redis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - 'path': '/socket', - 'a': '1', - 'b': '2' - } + assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"} @pytest.mark.skipif(not ssl_available, reason="SSL not installed") class TestSSLConnectionURLParsing: def test_host(self): - pool = redis.ConnectionPool.from_url('rediss://my.host') + pool = redis.ConnectionPool.from_url("rediss://my.host") assert pool.connection_class == redis.SSLConnection assert pool.connection_kwargs == { - 'host': 'my.host', + "host": "my.host", } def test_cert_reqs_options(self): @@ -447,25 +462,20 @@ class DummyConnectionPool(redis.ConnectionPool): def get_connection(self, *args, **kwargs): return self.make_connection() - pool = DummyConnectionPool.from_url( - 'rediss://?ssl_cert_reqs=none') - assert pool.get_connection('_').cert_reqs == ssl.CERT_NONE + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=none") + assert pool.get_connection("_").cert_reqs == ssl.CERT_NONE - pool = DummyConnectionPool.from_url( - 'rediss://?ssl_cert_reqs=optional') - assert pool.get_connection('_').cert_reqs == ssl.CERT_OPTIONAL + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=optional") + assert pool.get_connection("_").cert_reqs == ssl.CERT_OPTIONAL - pool = DummyConnectionPool.from_url( - 'rediss://?ssl_cert_reqs=required') - assert pool.get_connection('_').cert_reqs == ssl.CERT_REQUIRED + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=required") + assert pool.get_connection("_").cert_reqs == ssl.CERT_REQUIRED - pool = DummyConnectionPool.from_url( - 'rediss://?ssl_check_hostname=False') - assert pool.get_connection('_').check_hostname is False + pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=False") + assert pool.get_connection("_").check_hostname is False - pool = DummyConnectionPool.from_url( - 'rediss://?ssl_check_hostname=True') - assert pool.get_connection('_').check_hostname is True + pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True") + assert pool.get_connection("_").check_hostname is True class TestConnection: @@ -485,7 +495,7 @@ def test_on_connect_error(self): assert not pool._available_connections[0]._sock @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.8') + @skip_if_server_version_lt("2.8.8") @skip_if_redis_enterprise def test_busy_loading_disconnects_socket(self, r): """ @@ -493,11 +503,11 @@ def test_busy_loading_disconnects_socket(self, r): disconnected and a BusyLoadingError raised """ with pytest.raises(redis.BusyLoadingError): - r.execute_command('DEBUG', 'ERROR', 'LOADING fake message') + r.execute_command("DEBUG", "ERROR", "LOADING fake message") assert not r.connection._sock @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.8') + @skip_if_server_version_lt("2.8.8") @skip_if_redis_enterprise def test_busy_loading_from_pipeline_immediate_command(self, r): """ @@ -506,15 +516,14 @@ def test_busy_loading_from_pipeline_immediate_command(self, r): """ pipe = r.pipeline() with pytest.raises(redis.BusyLoadingError): - pipe.immediate_execute_command('DEBUG', 'ERROR', - 'LOADING fake message') + pipe.immediate_execute_command("DEBUG", "ERROR", "LOADING fake message") pool = r.connection_pool assert not pipe.connection assert len(pool._available_connections) == 1 assert not pool._available_connections[0]._sock @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.8') + @skip_if_server_version_lt("2.8.8") @skip_if_redis_enterprise def test_busy_loading_from_pipeline(self, r): """ @@ -522,7 +531,7 @@ def test_busy_loading_from_pipeline(self, r): regardless of the raise_on_error flag. """ pipe = r.pipeline() - pipe.execute_command('DEBUG', 'ERROR', 'LOADING fake message') + pipe.execute_command("DEBUG", "ERROR", "LOADING fake message") with pytest.raises(redis.BusyLoadingError): pipe.execute() pool = r.connection_pool @@ -530,31 +539,31 @@ def test_busy_loading_from_pipeline(self, r): assert len(pool._available_connections) == 1 assert not pool._available_connections[0]._sock - @skip_if_server_version_lt('2.8.8') + @skip_if_server_version_lt("2.8.8") @skip_if_redis_enterprise def test_read_only_error(self, r): "READONLY errors get turned in ReadOnlyError exceptions" with pytest.raises(redis.ReadOnlyError): - r.execute_command('DEBUG', 'ERROR', 'READONLY blah blah') + r.execute_command("DEBUG", "ERROR", "READONLY blah blah") def test_connect_from_url_tcp(self): - connection = redis.Redis.from_url('redis://localhost') + connection = redis.Redis.from_url("redis://localhost") pool = connection.connection_pool - assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == ( - 'ConnectionPool', - 'Connection', - 'host=localhost,port=6379,db=0', + assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + "ConnectionPool", + "Connection", + "host=localhost,port=6379,db=0", ) def test_connect_from_url_unix(self): - connection = redis.Redis.from_url('unix:///path/to/socket') + connection = redis.Redis.from_url("unix:///path/to/socket") pool = connection.connection_pool - assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == ( - 'ConnectionPool', - 'UnixDomainSocketConnection', - 'path=/path/to/socket,db=0', + assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + "ConnectionPool", + "UnixDomainSocketConnection", + "path=/path/to/socket,db=0", ) @skip_if_redis_enterprise @@ -564,28 +573,27 @@ def test_connect_no_auth_supplied_when_required(self, r): password but one isn't supplied. """ with pytest.raises(redis.AuthenticationError): - r.execute_command('DEBUG', 'ERROR', - 'ERR Client sent AUTH, but no password is set') + r.execute_command( + "DEBUG", "ERROR", "ERR Client sent AUTH, but no password is set" + ) @skip_if_redis_enterprise def test_connect_invalid_password_supplied(self, r): "AuthenticationError should be raised when sending the wrong password" with pytest.raises(redis.AuthenticationError): - r.execute_command('DEBUG', 'ERROR', 'ERR invalid password') + r.execute_command("DEBUG", "ERROR", "ERR invalid password") @pytest.mark.onlynoncluster class TestMultiConnectionClient: @pytest.fixture() def r(self, request): - return _get_client(redis.Redis, - request, - single_connection_client=False) + return _get_client(redis.Redis, request, single_connection_client=False) def test_multi_connection_command(self, r): assert not r.connection - assert r.set('a', '123') - assert r.get('a') == b'123' + assert r.set("a", "123") + assert r.get("a") == b"123" @pytest.mark.onlynoncluster @@ -594,8 +602,7 @@ class TestHealthCheck: @pytest.fixture() def r(self, request): - return _get_client(redis.Redis, request, - health_check_interval=self.interval) + return _get_client(redis.Redis, request, health_check_interval=self.interval) def assert_interval_advanced(self, connection): diff = connection.next_health_check - time.time() @@ -608,61 +615,66 @@ def test_health_check_runs(self, r): def test_arbitrary_command_invokes_health_check(self, r): # invoke a command to make sure the connection is entirely setup - r.get('foo') + r.get("foo") r.connection.next_health_check = time.time() - with mock.patch.object(r.connection, 'send_command', - wraps=r.connection.send_command) as m: - r.get('foo') - m.assert_called_with('PING', check_health=False) + with mock.patch.object( + r.connection, "send_command", wraps=r.connection.send_command + ) as m: + r.get("foo") + m.assert_called_with("PING", check_health=False) self.assert_interval_advanced(r.connection) def test_arbitrary_command_advances_next_health_check(self, r): - r.get('foo') + r.get("foo") next_health_check = r.connection.next_health_check - r.get('foo') + r.get("foo") assert next_health_check < r.connection.next_health_check def test_health_check_not_invoked_within_interval(self, r): - r.get('foo') - with mock.patch.object(r.connection, 'send_command', - wraps=r.connection.send_command) as m: - r.get('foo') - ping_call_spec = (('PING',), {'check_health': False}) + r.get("foo") + with mock.patch.object( + r.connection, "send_command", wraps=r.connection.send_command + ) as m: + r.get("foo") + ping_call_spec = (("PING",), {"check_health": False}) assert ping_call_spec not in m.call_args_list def test_health_check_in_pipeline(self, r): with r.pipeline(transaction=False) as pipe: - pipe.connection = pipe.connection_pool.get_connection('_') + pipe.connection = pipe.connection_pool.get_connection("_") pipe.connection.next_health_check = 0 - with mock.patch.object(pipe.connection, 'send_command', - wraps=pipe.connection.send_command) as m: - responses = pipe.set('foo', 'bar').get('foo').execute() - m.assert_any_call('PING', check_health=False) - assert responses == [True, b'bar'] + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + responses = pipe.set("foo", "bar").get("foo").execute() + m.assert_any_call("PING", check_health=False) + assert responses == [True, b"bar"] def test_health_check_in_transaction(self, r): with r.pipeline(transaction=True) as pipe: - pipe.connection = pipe.connection_pool.get_connection('_') + pipe.connection = pipe.connection_pool.get_connection("_") pipe.connection.next_health_check = 0 - with mock.patch.object(pipe.connection, 'send_command', - wraps=pipe.connection.send_command) as m: - responses = pipe.set('foo', 'bar').get('foo').execute() - m.assert_any_call('PING', check_health=False) - assert responses == [True, b'bar'] + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + responses = pipe.set("foo", "bar").get("foo").execute() + m.assert_any_call("PING", check_health=False) + assert responses == [True, b"bar"] def test_health_check_in_watched_pipeline(self, r): - r.set('foo', 'bar') + r.set("foo", "bar") with r.pipeline(transaction=False) as pipe: - pipe.connection = pipe.connection_pool.get_connection('_') + pipe.connection = pipe.connection_pool.get_connection("_") pipe.connection.next_health_check = 0 - with mock.patch.object(pipe.connection, 'send_command', - wraps=pipe.connection.send_command) as m: - pipe.watch('foo') + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + pipe.watch("foo") # the health check should be called when watching - m.assert_called_with('PING', check_health=False) + m.assert_called_with("PING", check_health=False) self.assert_interval_advanced(pipe.connection) - assert pipe.get('foo') == b'bar' + assert pipe.get("foo") == b"bar" # reset the mock to clear the call list and schedule another # health check @@ -670,27 +682,28 @@ def test_health_check_in_watched_pipeline(self, r): pipe.connection.next_health_check = 0 pipe.multi() - responses = pipe.set('foo', 'not-bar').get('foo').execute() - assert responses == [True, b'not-bar'] - m.assert_any_call('PING', check_health=False) + responses = pipe.set("foo", "not-bar").get("foo").execute() + assert responses == [True, b"not-bar"] + m.assert_any_call("PING", check_health=False) def test_health_check_in_pubsub_before_subscribe(self, r): "A health check happens before the first [p]subscribe" p = r.pubsub() - p.connection = p.connection_pool.get_connection('_') + p.connection = p.connection_pool.get_connection("_") p.connection.next_health_check = 0 - with mock.patch.object(p.connection, 'send_command', - wraps=p.connection.send_command) as m: + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: assert not p.subscribed - p.subscribe('foo') + p.subscribe("foo") # the connection is not yet in pubsub mode, so the normal # ping/pong within connection.send_command should check # the health of the connection - m.assert_any_call('PING', check_health=False) + m.assert_any_call("PING", check_health=False) self.assert_interval_advanced(p.connection) subscribe_message = wait_for_message(p) - assert subscribe_message['type'] == 'subscribe' + assert subscribe_message["type"] == "subscribe" def test_health_check_in_pubsub_after_subscribed(self, r): """ @@ -698,38 +711,38 @@ def test_health_check_in_pubsub_after_subscribed(self, r): connection health """ p = r.pubsub() - p.connection = p.connection_pool.get_connection('_') + p.connection = p.connection_pool.get_connection("_") p.connection.next_health_check = 0 - with mock.patch.object(p.connection, 'send_command', - wraps=p.connection.send_command) as m: - p.subscribe('foo') + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + p.subscribe("foo") subscribe_message = wait_for_message(p) - assert subscribe_message['type'] == 'subscribe' + assert subscribe_message["type"] == "subscribe" self.assert_interval_advanced(p.connection) # because we weren't subscribed when sending the subscribe # message to 'foo', the connection's standard check_health ran # prior to subscribing. - m.assert_any_call('PING', check_health=False) + m.assert_any_call("PING", check_health=False) p.connection.next_health_check = 0 m.reset_mock() - p.subscribe('bar') + p.subscribe("bar") # the second subscribe issues exactly only command (the subscribe) # and the health check is not invoked - m.assert_called_once_with('SUBSCRIBE', 'bar', check_health=False) + m.assert_called_once_with("SUBSCRIBE", "bar", check_health=False) # since no message has been read since the health check was # reset, it should still be 0 assert p.connection.next_health_check == 0 subscribe_message = wait_for_message(p) - assert subscribe_message['type'] == 'subscribe' + assert subscribe_message["type"] == "subscribe" assert wait_for_message(p) is None # now that the connection is subscribed, the pubsub health # check should have taken over and include the HEALTH_CHECK_MESSAGE - m.assert_any_call('PING', p.HEALTH_CHECK_MESSAGE, - check_health=False) + m.assert_any_call("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) self.assert_interval_advanced(p.connection) def test_health_check_in_pubsub_poll(self, r): @@ -738,12 +751,13 @@ def test_health_check_in_pubsub_poll(self, r): check the connection's health. """ p = r.pubsub() - p.connection = p.connection_pool.get_connection('_') - with mock.patch.object(p.connection, 'send_command', - wraps=p.connection.send_command) as m: - p.subscribe('foo') + p.connection = p.connection_pool.get_connection("_") + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + p.subscribe("foo") subscribe_message = wait_for_message(p) - assert subscribe_message['type'] == 'subscribe' + assert subscribe_message["type"] == "subscribe" self.assert_interval_advanced(p.connection) # polling the connection before the health check interval @@ -759,6 +773,5 @@ def test_health_check_in_pubsub_poll(self, r): # should be advanced p.connection.next_health_check = 0 assert wait_for_message(p) is None - m.assert_called_with('PING', p.HEALTH_CHECK_MESSAGE, - check_health=False) + m.assert_called_with("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) self.assert_interval_advanced(p.connection) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 402eccf0a2..359582909f 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,19 +1,20 @@ import string + from redis.commands.helpers import ( delist, list_or_args, nativestr, + parse_to_dict, parse_to_list, quote_string, random_string, - parse_to_dict ) def test_list_or_args(): k = ["hello, world"] a = ["some", "argument", "list"] - assert list_or_args(k, a) == k+a + assert list_or_args(k, a) == k + a for i in ["banana", b"banana"]: assert list_or_args(i, a) == [i] + a @@ -22,42 +23,50 @@ def test_list_or_args(): def test_parse_to_list(): assert parse_to_list(None) == [] r = ["hello", b"my name", "45", "555.55", "is simon!", None] - assert parse_to_list(r) == \ - ["hello", "my name", 45, 555.55, "is simon!", None] + assert parse_to_list(r) == ["hello", "my name", 45, 555.55, "is simon!", None] def test_parse_to_dict(): assert parse_to_dict(None) == {} - r = [['Some number', '1.0345'], - ['Some string', 'hello'], - ['Child iterators', - ['Time', '0.2089', 'Counter', 3, 'Child iterators', - ['Type', 'bar', 'Time', '0.0729', 'Counter', 3], - ['Type', 'barbar', 'Time', '0.058', 'Counter', 3]]]] + r = [ + ["Some number", "1.0345"], + ["Some string", "hello"], + [ + "Child iterators", + [ + "Time", + "0.2089", + "Counter", + 3, + "Child iterators", + ["Type", "bar", "Time", "0.0729", "Counter", 3], + ["Type", "barbar", "Time", "0.058", "Counter", 3], + ], + ], + ] assert parse_to_dict(r) == { - 'Child iterators': { - 'Child iterators': [ - {'Counter': 3.0, 'Time': 0.0729, 'Type': 'bar'}, - {'Counter': 3.0, 'Time': 0.058, 'Type': 'barbar'} + "Child iterators": { + "Child iterators": [ + {"Counter": 3.0, "Time": 0.0729, "Type": "bar"}, + {"Counter": 3.0, "Time": 0.058, "Type": "barbar"}, ], - 'Counter': 3.0, - 'Time': 0.2089 + "Counter": 3.0, + "Time": 0.2089, }, - 'Some number': 1.0345, - 'Some string': 'hello' + "Some number": 1.0345, + "Some string": "hello", } def test_nativestr(): - assert nativestr('teststr') == 'teststr' - assert nativestr(b'teststr') == 'teststr' - assert nativestr('null') is None + assert nativestr("teststr") == "teststr" + assert nativestr(b"teststr") == "teststr" + assert nativestr("null") is None def test_delist(): assert delist(None) is None - assert delist([b'hello', 'world', b'banana']) == \ - ['hello', 'world', 'banana'] + assert delist([b"hello", "world", b"banana"]) == ["hello", "world", "banana"] def test_random_string(): @@ -69,5 +78,5 @@ def test_random_string(): def test_quote_string(): assert quote_string("hello world!") == '"hello world!"' - assert quote_string('') == '""' - assert quote_string('hello world!') == '"hello world!"' + assert quote_string("") == '""' + assert quote_string("hello world!") == '"hello world!"' diff --git a/tests/test_json.py b/tests/test_json.py index 187bfe2289..1686f9d05e 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,8 +1,10 @@ import pytest + import redis -from redis.commands.json.path import Path from redis import exceptions -from redis.commands.json.decoders import unstring, decode_list +from redis.commands.json.decoders import decode_list, unstring +from redis.commands.json.path import Path + from .conftest import skip_ifmodversion_lt @@ -48,9 +50,7 @@ def test_json_get_jset(client): @pytest.mark.redismod def test_nonascii_setgetdelete(client): assert client.json().set("notascii", Path.rootPath(), "hyvää-élève") - assert "hyvää-élève" == client.json().get( - "notascii", - no_escape=True) + assert "hyvää-élève" == client.json().get("notascii", no_escape=True) assert 1 == client.json().delete("notascii") assert client.exists("notascii") == 0 @@ -179,7 +179,7 @@ def test_arrinsert(client): 1, 2, 3, - ] + ], ) assert [0, 1, 2, 3, 4] == client.json().get("arr") @@ -307,8 +307,7 @@ def test_json_delete_with_dollar(client): r = client.json().get("doc1", "$") assert r == [{"nested": {"b": 3}}] - doc2 = {"a": {"a": 2, "b": 3}, "b": [ - "a", "b"], "nested": {"b": [True, "a", "b"]}} + doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} assert client.json().set("doc2", "$", doc2) assert client.json().delete("doc2", "$..a") == 1 res = client.json().get("doc2", "$") @@ -361,8 +360,7 @@ def test_json_forget_with_dollar(client): r = client.json().get("doc1", "$") assert r == [{"nested": {"b": 3}}] - doc2 = {"a": {"a": 2, "b": 3}, "b": [ - "a", "b"], "nested": {"b": [True, "a", "b"]}} + doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} assert client.json().set("doc2", "$", doc2) assert client.json().forget("doc2", "$..a") == 1 res = client.json().get("doc2", "$") @@ -413,16 +411,12 @@ def test_json_mget_dollar(client): client.json().set( "doc1", "$", - {"a": 1, - "b": 2, - "nested": {"a": 3}, - "c": None, "nested2": {"a": None}}, + {"a": 1, "b": 2, "nested": {"a": 3}, "c": None, "nested2": {"a": None}}, ) client.json().set( "doc2", "$", - {"a": 4, "b": 5, "nested": {"a": 6}, - "c": None, "nested2": {"a": [None]}}, + {"a": 4, "b": 5, "nested": {"a": 6}, "c": None, "nested2": {"a": [None]}}, ) # Compare also to single JSON.GET assert client.json().get("doc1", "$..a") == [1, 3, None] @@ -431,8 +425,7 @@ def test_json_mget_dollar(client): # Test mget with single path client.json().mget("doc1", "$..a") == [1, 3, None] # Test mget with multi path - client.json().mget(["doc1", "doc2"], "$..a") == [ - [1, 3, None], [4, 6, [None]]] + client.json().mget(["doc1", "doc2"], "$..a") == [[1, 3, None], [4, 6, [None]]] # Test missing key client.json().mget(["doc1", "missing_doc"], "$..a") == [[1, 3, None], None] @@ -444,15 +437,11 @@ def test_json_mget_dollar(client): def test_numby_commands_dollar(client): # Test NUMINCRBY - client.json().set( - "doc1", - "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) + client.json().set("doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) # Test multi - assert client.json().numincrby("doc1", "$..a", 2) == \ - [None, 4, 7.0, None] + assert client.json().numincrby("doc1", "$..a", 2) == [None, 4, 7.0, None] - assert client.json().numincrby("doc1", "$..a", 2.5) == \ - [None, 6.5, 9.5, None] + assert client.json().numincrby("doc1", "$..a", 2.5) == [None, 6.5, 9.5, None] # Test single assert client.json().numincrby("doc1", "$.b[1].a", 2) == [11.5] @@ -460,15 +449,12 @@ def test_numby_commands_dollar(client): assert client.json().numincrby("doc1", "$.b[1].a", 3.5) == [15.0] # Test NUMMULTBY - client.json().set("doc1", "$", {"a": "b", "b": [ - {"a": 2}, {"a": 5.0}, {"a": "c"}]}) + client.json().set("doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) # test list with pytest.deprecated_call(): - assert client.json().nummultby("doc1", "$..a", 2) == \ - [None, 4, 10, None] - assert client.json().nummultby("doc1", "$..a", 2.5) == \ - [None, 10.0, 25.0, None] + assert client.json().nummultby("doc1", "$..a", 2) == [None, 4, 10, None] + assert client.json().nummultby("doc1", "$..a", 2.5) == [None, 10.0, 25.0, None] # Test single with pytest.deprecated_call(): @@ -482,13 +468,11 @@ def test_numby_commands_dollar(client): client.json().nummultby("non_existing_doc", "$..a", 2) # Test legacy NUMINCRBY - client.json().set("doc1", "$", {"a": "b", "b": [ - {"a": 2}, {"a": 5.0}, {"a": "c"}]}) + client.json().set("doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) client.json().numincrby("doc1", ".b[0].a", 3) == 5 # Test legacy NUMMULTBY - client.json().set("doc1", "$", {"a": "b", "b": [ - {"a": 2}, {"a": 5.0}, {"a": "c"}]}) + client.json().set("doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) with pytest.deprecated_call(): client.json().nummultby("doc1", ".b[0].a", 3) == 6 @@ -498,8 +482,7 @@ def test_numby_commands_dollar(client): def test_strappend_dollar(client): client.json().set( - "doc1", "$", {"a": "foo", "nested1": { - "a": "hello"}, "nested2": {"a": 31}} + "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) # Test multi client.json().strappend("doc1", "bar", "$..a") == [6, 8, None] @@ -534,8 +517,7 @@ def test_strlen_dollar(client): # Test multi client.json().set( - "doc1", "$", {"a": "foo", "nested1": { - "a": "hello"}, "nested2": {"a": 31}} + "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) assert client.json().strlen("doc1", "$..a") == [3, 5, None] @@ -634,8 +616,7 @@ def test_arrinsert_dollar(client): }, ) # Test multi - assert client.json().arrinsert("doc1", "$..a", "1", - "bar", "racuda") == [3, 5, None] + assert client.json().arrinsert("doc1", "$..a", "1", "bar", "racuda") == [3, 5, None] assert client.json().get("doc1", "$") == [ { @@ -674,8 +655,11 @@ def test_arrlen_dollar(client): # Test multi assert client.json().arrlen("doc1", "$..a") == [1, 3, None] - assert client.json().arrappend("doc1", "$..a", "non", "abba", "stanza") \ - == [4, 6, None] + assert client.json().arrappend("doc1", "$..a", "non", "abba", "stanza") == [ + 4, + 6, + None, + ] client.json().clear("doc1", "$.a") assert client.json().arrlen("doc1", "$..a") == [0, 6, None] @@ -924,8 +908,7 @@ def test_clear_dollar(client): assert client.json().clear("doc1", "$..a") == 3 assert client.json().get("doc1", "$") == [ - {"nested1": {"a": {}}, "a": [], "nested2": { - "a": "claro"}, "nested3": {"a": {}}} + {"nested1": {"a": {}}, "a": [], "nested2": {"a": "claro"}, "nested3": {"a": {}}} ] # Test single @@ -994,8 +977,7 @@ def test_debug_dollar(client): client.json().set("doc1", "$", jdata) # Test multi - assert client.json().debug("MEMORY", "doc1", "$..a") == [ - 72, 24, 24, 16, 16, 1, 0] + assert client.json().debug("MEMORY", "doc1", "$..a") == [72, 24, 24, 16, 16, 1, 0] # Test single assert client.json().debug("MEMORY", "doc1", "$.nested2.a") == [24] @@ -1234,12 +1216,10 @@ def test_arrindex_dollar(client): [], ] - assert client.json().arrindex("test_num", "$..arr", 3) == [ - 3, 2, -1, None, -1] + assert client.json().arrindex("test_num", "$..arr", 3) == [3, 2, -1, None, -1] # Test index of double scalar in multi values - assert client.json().arrindex("test_num", "$..arr", 3.0) == [ - 2, 8, -1, None, -1] + assert client.json().arrindex("test_num", "$..arr", 3.0) == [2, 8, -1, None, -1] # Test index of string scalar in multi values client.json().set( @@ -1249,10 +1229,7 @@ def test_arrindex_dollar(client): {"arr": ["bazzz", "bar", 2, "baz", 2, "ba", "baz", 3]}, { "nested1_found": { - "arr": [ - None, - "baz2", - "buzz", 2, 1, 0, 1, "2", "baz", 2, 4, 5] + "arr": [None, "baz2", "buzz", 2, 1, 0, 1, "2", "baz", 2, 4, 5] } }, {"nested2_not_found": {"arr": ["baz2", 4, 6]}}, @@ -1344,11 +1321,7 @@ def test_arrindex_dollar(client): {"arr": ["bazzz", "None", 2, None, 2, "ba", "baz", 3]}, { "nested1_found": { - "arr": [ - "zaz", - "baz2", - "buzz", - 2, 1, 0, 1, "2", None, 2, 4, 5] + "arr": ["zaz", "baz2", "buzz", 2, 1, 0, 1, "2", None, 2, 4, 5] } }, {"nested2_not_found": {"arr": ["None", 4, 6]}}, @@ -1369,8 +1342,7 @@ def test_arrindex_dollar(client): # Fail with none-scalar value with pytest.raises(exceptions.ResponseError): - client.json().arrindex( - "test_None", "$..nested42_empty_arr.arr", {"arr": []}) + client.json().arrindex("test_None", "$..nested42_empty_arr.arr", {"arr": []}) # Do not fail with none-scalar value in legacy mode assert ( @@ -1392,10 +1364,7 @@ def test_arrindex_dollar(client): assert client.json().arrindex("test_string", ".[0].arr", "faz") == -1 # Test index of None scalar in single value assert client.json().arrindex("test_None", ".[0].arr", "None") == 1 - assert client.json().arrindex( - "test_None", - "..nested2_not_found.arr", - "None") == 0 + assert client.json().arrindex("test_None", "..nested2_not_found.arr", "None") == 0 @pytest.mark.redismod @@ -1406,14 +1375,15 @@ def test_decoders_and_unstring(): assert decode_list(b"45.55") == 45.55 assert decode_list("45.55") == 45.55 - assert decode_list(['hello', b'world']) == ['hello', 'world'] + assert decode_list(["hello", b"world"]) == ["hello", "world"] @pytest.mark.redismod def test_custom_decoder(client): - import ujson import json + import ujson + cj = client.json(encoder=ujson, decoder=ujson) assert cj.set("foo", Path.rootPath(), "bar") assert "bar" == cj.get("foo") diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 09e70d828f..40d9e43094 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -1,4 +1,5 @@ import pytest + from .conftest import ( skip_if_redis_enterprise, skip_ifnot_redis_enterprise, diff --git a/tests/test_search.py b/tests/test_search.py index c7b570cdd1..5b6a66009a 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,52 +1,32 @@ -import pytest -import redis import bz2 import csv -import time import os - +import time from io import TextIOWrapper -from .conftest import skip_ifmodversion_lt, default_redismod_url -from redis import Redis +import pytest + +import redis import redis.commands.search +import redis.commands.search.aggregation as aggregations +import redis.commands.search.reducers as reducers +from redis import Redis from redis.commands.json.path import Path from redis.commands.search import Search -from redis.commands.search.field import ( - GeoField, - NumericField, - TagField, - TextField -) -from redis.commands.search.query import ( - GeoFilter, - NumericFilter, - Query -) -from redis.commands.search.result import Result +from redis.commands.search.field import GeoField, NumericField, TagField, TextField from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.query import GeoFilter, NumericFilter, Query +from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion -import redis.commands.search.aggregation as aggregations -import redis.commands.search.reducers as reducers -WILL_PLAY_TEXT = ( - os.path.abspath( - os.path.join( - os.path.dirname(__file__), - "testdata", - "will_play_text.csv.bz2" - ) - ) +from .conftest import default_redismod_url, skip_ifmodversion_lt + +WILL_PLAY_TEXT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2") ) -TITLES_CSV = ( - os.path.abspath( - os.path.join( - os.path.dirname(__file__), - "testdata", - "titles.csv" - ) - ) +TITLES_CSV = os.path.abspath( + os.path.join(os.path.dirname(__file__), "testdata", "titles.csv") ) @@ -81,9 +61,7 @@ def getClient(): def createIndex(client, num_docs=100, definition=None): try: client.create_index( - (TextField("play", weight=5.0), - TextField("txt"), - NumericField("chapter")), + (TextField("play", weight=5.0), TextField("txt"), NumericField("chapter")), definition=definition, ) except redis.ResponseError: @@ -96,8 +74,7 @@ def createIndex(client, num_docs=100, definition=None): r = csv.reader(bzfp, delimiter=";") for n, line in enumerate(r): - play, chapter, _, text = \ - line[1], line[2], line[4], line[5] + play, chapter, _, text = line[1], line[2], line[4], line[5] key = f"{play}:{chapter}".lower() d = chapters.setdefault(key, {}) @@ -183,12 +160,10 @@ def test_client(client): # test in fields txt_total = ( - client.ft().search( - Query("henry").no_content().limit_fields("txt")).total + client.ft().search(Query("henry").no_content().limit_fields("txt")).total ) play_total = ( - client.ft().search( - Query("henry").no_content().limit_fields("play")).total + client.ft().search(Query("henry").no_content().limit_fields("play")).total ) both_total = ( client.ft() @@ -217,10 +192,8 @@ def test_client(client): # test slop and in order assert 193 == client.ft().search(Query("henry king")).total - assert 3 == client.ft().search( - Query("henry king").slop(0).in_order()).total - assert 52 == client.ft().search( - Query("king henry").slop(0).in_order()).total + assert 3 == client.ft().search(Query("henry king").slop(0).in_order()).total + assert 52 == client.ft().search(Query("king henry").slop(0).in_order()).total assert 53 == client.ft().search(Query("henry king").slop(0)).total assert 167 == client.ft().search(Query("henry king").slop(100)).total @@ -284,11 +257,7 @@ def test_replace(client): res = client.ft().search("foo bar") assert 2 == res.total - client.ft().add_document( - "doc1", - replace=True, - txt="this is a replaced doc" - ) + client.ft().add_document("doc1", replace=True, txt="this is a replaced doc") res = client.ft().search("foo bar") assert 1 == res.total @@ -301,10 +270,7 @@ def test_replace(client): @pytest.mark.redismod def test_stopwords(client): - client.ft().create_index( - (TextField("txt"),), - stopwords=["foo", "bar", "baz"] - ) + client.ft().create_index((TextField("txt"),), stopwords=["foo", "bar", "baz"]) client.ft().add_document("doc1", txt="foo bar") client.ft().add_document("doc2", txt="hello world") waitForIndex(client, "idx") @@ -318,17 +284,8 @@ def test_stopwords(client): @pytest.mark.redismod def test_filters(client): - client.ft().create_index( - (TextField("txt"), - NumericField("num"), - GeoField("loc")) - ) - client.ft().add_document( - "doc1", - txt="foo bar", - num=3.141, - loc="-0.441,51.458" - ) + client.ft().create_index((TextField("txt"), NumericField("num"), GeoField("loc"))) + client.ft().add_document("doc1", txt="foo bar", num=3.141, loc="-0.441,51.458") client.ft().add_document("doc2", txt="foo baz", num=2, loc="-0.1,51.2") waitForIndex(client, "idx") @@ -336,8 +293,7 @@ def test_filters(client): q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() q2 = ( Query("foo") - .add_filter( - NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) + .add_filter(NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) .no_content() ) res1, res2 = client.ft().search(q1), client.ft().search(q2) @@ -348,10 +304,8 @@ def test_filters(client): assert "doc1" == res2.docs[0].id # Test geo filter - q1 = Query("foo").add_filter( - GeoFilter("loc", -0.44, 51.45, 10)).no_content() - q2 = Query("foo").add_filter( - GeoFilter("loc", -0.44, 51.45, 100)).no_content() + q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() + q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() res1, res2 = client.ft().search(q1), client.ft().search(q2) assert 1 == res1.total @@ -377,10 +331,7 @@ def test_payloads_with_no_content(client): @pytest.mark.redismod def test_sort_by(client): - client.ft().create_index( - (TextField("txt"), - NumericField("num", sortable=True)) - ) + client.ft().create_index((TextField("txt"), NumericField("num", sortable=True))) client.ft().add_document("doc1", txt="foo bar", num=1) client.ft().add_document("doc2", txt="foo baz", num=2) client.ft().add_document("doc3", txt="foo qux", num=3) @@ -422,10 +373,7 @@ def test_drop_index(): @pytest.mark.redismod def test_example(client): # Creating the index definition and schema - client.ft().create_index( - (TextField("title", weight=5.0), - TextField("body")) - ) + client.ft().create_index((TextField("title", weight=5.0), TextField("body"))) # Indexing a document client.ft().add_document( @@ -483,12 +431,7 @@ def test_auto_complete(client): client.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) client.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) - sugs = client.ft().sugget( - "ac", - "pay", - with_payloads=True, - with_scores=True - ) + sugs = client.ft().sugget("ac", "pay", with_payloads=True, with_scores=True) assert 3 == len(sugs) for sug in sugs: assert sug.payload @@ -550,11 +493,7 @@ def test_no_index(client): @pytest.mark.redismod def test_partial(client): - client.ft().create_index( - (TextField("f1"), - TextField("f2"), - TextField("f3")) - ) + client.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) client.ft().add_document("doc1", f1="f1_val", f2="f2_val") client.ft().add_document("doc2", f1="f1_val", f2="f2_val") client.ft().add_document("doc1", f3="f3_val", partial=True) @@ -572,11 +511,7 @@ def test_partial(client): @pytest.mark.redismod def test_no_create(client): - client.ft().create_index( - (TextField("f1"), - TextField("f2"), - TextField("f3")) - ) + client.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) client.ft().add_document("doc1", f1="f1_val", f2="f2_val") client.ft().add_document("doc2", f1="f1_val", f2="f2_val") client.ft().add_document("doc1", f3="f3_val", no_create=True) @@ -592,21 +527,12 @@ def test_no_create(client): assert 1 == res.total with pytest.raises(redis.ResponseError): - client.ft().add_document( - "doc3", - f2="f2_val", - f3="f3_val", - no_create=True - ) + client.ft().add_document("doc3", f2="f2_val", f3="f3_val", no_create=True) @pytest.mark.redismod def test_explain(client): - client.ft().create_index( - (TextField("f1"), - TextField("f2"), - TextField("f3")) - ) + client.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) res = client.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") assert res @@ -629,8 +555,8 @@ def test_summarize(client): doc = sorted(client.ft().search(q).docs)[0] assert "Henry IV" == doc.play assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt ) q = Query("king henry").paging(0, 1).summarize().highlight() @@ -638,8 +564,8 @@ def test_summarize(client): doc = sorted(client.ft().search(q).docs)[0] assert "Henry ... " == doc.play assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt ) @@ -786,11 +712,7 @@ def test_alter_schema_add(client): def test_spell_check(client): client.ft().create_index((TextField("f1"), TextField("f2"))) - client.ft().add_document( - "doc1", - f1="some valid content", - f2="this is sample text" - ) + client.ft().add_document("doc1", f1="some valid content", f2="this is sample text") client.ft().add_document("doc2", f1="very important", f2="lorem ipsum") waitForIndex(client, "idx") @@ -812,10 +734,10 @@ def test_spell_check(client): res = client.ft().spellcheck("lorm", include="dict") assert len(res["lorm"]) == 3 assert ( - res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") + res["lorm"][0]["suggestion"], + res["lorm"][1]["suggestion"], + res["lorm"][2]["suggestion"], + ) == ("lorem", "lore", "lorm") assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") # test spellcheck exclude @@ -873,7 +795,7 @@ def test_scorer(client): ) client.ft().add_document( "doc2", - description="Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.", # noqa + description="Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.", # noqa ) # default scorer is TFIDF @@ -881,8 +803,7 @@ def test_scorer(client): assert 1.0 == res.docs[0].score res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) assert 1.0 == res.docs[0].score - res = client.ft().search( - Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) assert 0.1111111111111111 == res.docs[0].score res = client.ft().search(Query("quick").scorer("BM25").with_scores()) assert 0.17699114465425977 == res.docs[0].score @@ -1060,7 +981,7 @@ def test_aggregations_groupby(client): ) res = client.ft().aggregate(req).rows[0] - assert res == ['parent', 'redis', 'first', 'RediSearch'] + assert res == ["parent", "redis", "first", "RediSearch"] req = aggregations.AggregateRequest("redis").group_by( "@parent", @@ -1083,35 +1004,33 @@ def test_aggregations_sort_by_and_limit(client): ) ) - client.ft().client.hset("doc1", mapping={'t1': 'a', 't2': 'b'}) - client.ft().client.hset("doc2", mapping={'t1': 'b', 't2': 'a'}) + client.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) + client.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) # test sort_by using SortDirection - req = aggregations.AggregateRequest("*") \ - .sort_by(aggregations.Asc("@t2"), aggregations.Desc("@t1")) + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) res = client.ft().aggregate(req) - assert res.rows[0] == ['t2', 'a', 't1', 'b'] - assert res.rows[1] == ['t2', 'b', 't1', 'a'] + assert res.rows[0] == ["t2", "a", "t1", "b"] + assert res.rows[1] == ["t2", "b", "t1", "a"] # test sort_by without SortDirection - req = aggregations.AggregateRequest("*") \ - .sort_by("@t1") + req = aggregations.AggregateRequest("*").sort_by("@t1") res = client.ft().aggregate(req) - assert res.rows[0] == ['t1', 'a'] - assert res.rows[1] == ['t1', 'b'] + assert res.rows[0] == ["t1", "a"] + assert res.rows[1] == ["t1", "b"] # test sort_by with max - req = aggregations.AggregateRequest("*") \ - .sort_by("@t1", max=1) + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) res = client.ft().aggregate(req) assert len(res.rows) == 1 # test limit - req = aggregations.AggregateRequest("*") \ - .sort_by("@t1").limit(1, 1) + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) res = client.ft().aggregate(req) assert len(res.rows) == 1 - assert res.rows[0] == ['t1', 'b'] + assert res.rows[0] == ["t1", "b"] @pytest.mark.redismod @@ -1123,17 +1042,17 @@ def test_aggregations_load(client): ) ) - client.ft().client.hset("doc1", mapping={'t1': 'hello', 't2': 'world'}) + client.ft().client.hset("doc1", mapping={"t1": "hello", "t2": "world"}) # load t1 req = aggregations.AggregateRequest("*").load("t1") res = client.ft().aggregate(req) - assert res.rows[0] == ['t1', 'hello'] + assert res.rows[0] == ["t1", "hello"] # load t2 req = aggregations.AggregateRequest("*").load("t2") res = client.ft().aggregate(req) - assert res.rows[0] == ['t2', 'world'] + assert res.rows[0] == ["t2", "world"] @pytest.mark.redismod @@ -1147,24 +1066,19 @@ def test_aggregations_apply(client): client.ft().client.hset( "doc1", - mapping={ - 'PrimaryKey': '9::362330', - 'CreatedDateTimeUTC': '637387878524969984' - } + mapping={"PrimaryKey": "9::362330", "CreatedDateTimeUTC": "637387878524969984"}, ) client.ft().client.hset( "doc2", - mapping={ - 'PrimaryKey': '9::362329', - 'CreatedDateTimeUTC': '637387875859270016' - } + mapping={"PrimaryKey": "9::362329", "CreatedDateTimeUTC": "637387875859270016"}, ) - req = aggregations.AggregateRequest("*") \ - .apply(CreatedDateTimeUTC='@CreatedDateTimeUTC * 10') + req = aggregations.AggregateRequest("*").apply( + CreatedDateTimeUTC="@CreatedDateTimeUTC * 10" + ) res = client.ft().aggregate(req) - assert res.rows[0] == ['CreatedDateTimeUTC', '6373878785249699840'] - assert res.rows[1] == ['CreatedDateTimeUTC', '6373878758592700416'] + assert res.rows[0] == ["CreatedDateTimeUTC", "6373878785249699840"] + assert res.rows[1] == ["CreatedDateTimeUTC", "6373878758592700416"] @pytest.mark.redismod @@ -1176,33 +1090,19 @@ def test_aggregations_filter(client): ) ) - client.ft().client.hset( - "doc1", - mapping={ - 'name': 'bar', - 'age': '25' - } - ) - client.ft().client.hset( - "doc2", - mapping={ - 'name': 'foo', - 'age': '19' - } - ) + client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"}) + client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"}) - req = aggregations.AggregateRequest("*") \ - .filter("@name=='foo' && @age < 20") + req = aggregations.AggregateRequest("*").filter("@name=='foo' && @age < 20") res = client.ft().aggregate(req) assert len(res.rows) == 1 - assert res.rows[0] == ['name', 'foo', 'age', '19'] + assert res.rows[0] == ["name", "foo", "age", "19"] - req = aggregations.AggregateRequest("*") \ - .filter("@age > 15").sort_by("@age") + req = aggregations.AggregateRequest("*").filter("@age > 15").sort_by("@age") res = client.ft().aggregate(req) assert len(res.rows) == 2 - assert res.rows[0] == ['age', '19'] - assert res.rows[1] == ['age', '25'] + assert res.rows[0] == ["age", "19"] + assert res.rows[1] == ["age", "25"] @pytest.mark.redismod @@ -1226,25 +1126,25 @@ def test_index_definition(client): ) assert [ - "ON", - "JSON", - "PREFIX", - 2, - "hset:", - "henry", - "FILTER", - "@f1==32", - "LANGUAGE_FIELD", - "play", - "LANGUAGE", - "English", - "SCORE_FIELD", - "chapter", - "SCORE", - 0.5, - "PAYLOAD_FIELD", - "txt", - ] == definition.args + "ON", + "JSON", + "PREFIX", + 2, + "hset:", + "henry", + "FILTER", + "@f1==32", + "LANGUAGE_FIELD", + "play", + "LANGUAGE", + "English", + "SCORE_FIELD", + "chapter", + "SCORE", + 0.5, + "PAYLOAD_FIELD", + "txt", + ] == definition.args createIndex(client.ft(), num_docs=500, definition=definition) @@ -1274,10 +1174,7 @@ def test_create_client_definition_hash(client): Create definition with IndexType.HASH as index type (ON HASH), and use hset to test the client definition. """ - definition = IndexDefinition( - prefix=["hset:", "henry"], - index_type=IndexType.HASH - ) + definition = IndexDefinition(prefix=["hset:", "henry"], index_type=IndexType.HASH) createIndex(client.ft(), num_docs=500, definition=definition) info = client.ft().info() @@ -1320,15 +1217,10 @@ def test_fields_as_name(client): client.ft().create_index(SCHEMA, definition=definition) # insert json data - res = client.json().set( - "doc:1", - Path.rootPath(), - {"name": "Jon", "age": 25} - ) + res = client.json().set("doc:1", Path.rootPath(), {"name": "Jon", "age": 25}) assert res - total = client.ft().search( - Query("Jon").return_fields("name", "just_a_number")).docs + total = client.ft().search(Query("Jon").return_fields("name", "just_a_number")).docs assert 1 == len(total) assert "doc:1" == total[0].id assert "Jon" == total[0].name @@ -1354,14 +1246,12 @@ def test_search_return_fields(client): client.ft().create_index(SCHEMA, definition=definition) waitForIndex(client, "idx") - total = client.ft().search( - Query("*").return_field("$.t", as_field="txt")).docs + total = client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs assert 1 == len(total) assert "doc:1" == total[0].id assert "riceratops" == total[0].txt - total = client.ft().search( - Query("*").return_field("$.t2", as_field="txt")).docs + total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")).docs assert 1 == len(total) assert "doc:1" == total[0].id assert "telmatosaurus" == total[0].txt @@ -1379,17 +1269,10 @@ def test_synupdate(client): ) client.ft().synupdate("id1", True, "boy", "child", "offspring") - client.ft().add_document( - "doc1", - title="he is a baby", - body="this is a test") + client.ft().add_document("doc1", title="he is a baby", body="this is a test") client.ft().synupdate("id1", True, "baby") - client.ft().add_document( - "doc2", - title="he is another baby", - body="another test" - ) + client.ft().add_document("doc2", title="he is another baby", body="another test") res = client.ft().search(Query("child").expander("SYNONYM")) assert res.docs[0].id == "doc2" @@ -1431,15 +1314,12 @@ def test_create_json_with_alias(client): """ definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) client.ft().create_index( - (TextField("$.name", as_name="name"), - NumericField("$.num", as_name="num")), - definition=definition + (TextField("$.name", as_name="name"), NumericField("$.num", as_name="num")), + definition=definition, ) - client.json().set("king:1", Path.rootPath(), {"name": "henry", - "num": 42}) - client.json().set("king:2", Path.rootPath(), {"name": "james", - "num": 3.14}) + client.json().set("king:1", Path.rootPath(), {"name": "henry", "num": 42}) + client.json().set("king:2", Path.rootPath(), {"name": "james", "num": 3.14}) res = client.ft().search("@name:henry") assert res.docs[0].id == "king:1" @@ -1466,12 +1346,12 @@ def test_json_with_multipath(client): """ definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) client.ft().create_index( - (TagField("$..name", as_name="name")), - definition=definition + (TagField("$..name", as_name="name")), definition=definition ) - client.json().set("king:1", Path.rootPath(), - {"name": "henry", "country": {"name": "england"}}) + client.json().set( + "king:1", Path.rootPath(), {"name": "henry", "country": {"name": "england"}} + ) res = client.ft().search("@name:{henry}") assert res.docs[0].id == "king:1" @@ -1489,9 +1369,11 @@ def test_json_with_multipath(client): def test_json_with_jsonpath(client): definition = IndexDefinition(index_type=IndexType.JSON) client.ft().create_index( - (TextField('$["prod:name"]', as_name="name"), - TextField('$.prod:name', as_name="name_unsupported")), - definition=definition + ( + TextField('$["prod:name"]', as_name="name"), + TextField("$.prod:name", as_name="name_unsupported"), + ), + definition=definition, ) client.json().set("doc:1", Path.rootPath(), {"prod:name": "RediSearch"}) @@ -1510,11 +1392,10 @@ def test_json_with_jsonpath(client): res = client.ft().search(Query("@name:RediSearch").return_field("name")) assert res.total == 1 assert res.docs[0].id == "doc:1" - assert res.docs[0].name == 'RediSearch' + assert res.docs[0].name == "RediSearch" # return of an unsupported field fails - res = client.ft().search(Query("@name:RediSearch") - .return_field("name_unsupported")) + res = client.ft().search(Query("@name:RediSearch").return_field("name_unsupported")) assert res.total == 1 assert res.docs[0].id == "doc:1" with pytest.raises(Exception): @@ -1523,42 +1404,49 @@ def test_json_with_jsonpath(client): @pytest.mark.redismod def test_profile(client): - client.ft().create_index((TextField('t'),)) - client.ft().client.hset('1', 't', 'hello') - client.ft().client.hset('2', 't', 'world') + client.ft().create_index((TextField("t"),)) + client.ft().client.hset("1", "t", "hello") + client.ft().client.hset("2", "t", "world") # check using Query - q = Query('hello|world').no_content() + q = Query("hello|world").no_content() res, det = client.ft().profile(q) - assert det['Iterators profile']['Counter'] == 2.0 - assert len(det['Iterators profile']['Child iterators']) == 2 - assert det['Iterators profile']['Type'] == 'UNION' - assert det['Parsing time'] < 0.3 + assert det["Iterators profile"]["Counter"] == 2.0 + assert len(det["Iterators profile"]["Child iterators"]) == 2 + assert det["Iterators profile"]["Type"] == "UNION" + assert det["Parsing time"] < 0.3 assert len(res.docs) == 2 # check also the search result # check using AggregateRequest - req = aggregations.AggregateRequest("*").load("t")\ + req = ( + aggregations.AggregateRequest("*") + .load("t") .apply(prefix="startswith(@t, 'hel')") + ) res, det = client.ft().profile(req) - assert det['Iterators profile']['Counter'] == 2.0 - assert det['Iterators profile']['Type'] == 'WILDCARD' - assert det['Parsing time'] < 0.3 + assert det["Iterators profile"]["Counter"] == 2.0 + assert det["Iterators profile"]["Type"] == "WILDCARD" + assert det["Parsing time"] < 0.3 assert len(res.rows) == 2 # check also the search result @pytest.mark.redismod def test_profile_limited(client): - client.ft().create_index((TextField('t'),)) - client.ft().client.hset('1', 't', 'hello') - client.ft().client.hset('2', 't', 'hell') - client.ft().client.hset('3', 't', 'help') - client.ft().client.hset('4', 't', 'helowa') + client.ft().create_index((TextField("t"),)) + client.ft().client.hset("1", "t", "hello") + client.ft().client.hset("2", "t", "hell") + client.ft().client.hset("3", "t", "help") + client.ft().client.hset("4", "t", "helowa") - q = Query('%hell% hel*') + q = Query("%hell% hel*") res, det = client.ft().profile(q, limited=True) - assert det['Iterators profile']['Child iterators'][0]['Child iterators'] \ - == 'The number of iterators in the union is 3' - assert det['Iterators profile']['Child iterators'][1]['Child iterators'] \ - == 'The number of iterators in the union is 4' - assert det['Iterators profile']['Type'] == 'INTERSECT' + assert ( + det["Iterators profile"]["Child iterators"][0]["Child iterators"] + == "The number of iterators in the union is 3" + ) + assert ( + det["Iterators profile"]["Child iterators"][1]["Child iterators"] + == "The number of iterators in the union is 4" + ) + assert det["Iterators profile"]["Type"] == "INTERSECT" assert len(res.docs) == 3 # check also the search result From 3a9daf40b994299b5762ce811a7f3e4cb011ea2d Mon Sep 17 00:00:00 2001 From: Anas Date: Tue, 30 Nov 2021 17:47:19 +0200 Subject: [PATCH 8/8] Rebased on upstream/master and made linters happy --- benchmarks/base.py | 17 +- benchmarks/basic_operations.py | 70 +- benchmarks/command_packer_benchmark.py | 49 +- redis/client.py | 1032 +++--- redis/cluster.py | 769 ++-- redis/commands/cluster.py | 406 ++- redis/commands/core.py | 1713 +++++---- redis/commands/json/commands.py | 38 +- redis/commands/parser.py | 41 +- redis/commands/search/commands.py | 6 +- redis/commands/search/querystring.py | 7 +- redis/commands/timeseries/commands.py | 6 +- redis/connection.py | 429 ++- redis/sentinel.py | 136 +- tasks.py | 10 +- tests/conftest.py | 168 +- tests/test_cluster.py | 1749 +++++----- tests/test_commands.py | 4464 +++++++++++++----------- tests/test_connection.py | 23 +- tests/test_multiprocessing.py | 78 +- tests/test_pipeline.py | 289 +- tests/test_pubsub.py | 340 +- tox.ini | 1 - 23 files changed, 6326 insertions(+), 5515 deletions(-) diff --git a/benchmarks/base.py b/benchmarks/base.py index 519c9ccab5..f52657f072 100644 --- a/benchmarks/base.py +++ b/benchmarks/base.py @@ -1,9 +1,10 @@ import functools import itertools -import redis import sys import timeit +import redis + class Benchmark: ARGUMENTS = () @@ -15,9 +16,7 @@ def get_client(self, **kwargs): # eventually make this more robust and take optional args from # argparse if self._client is None or kwargs: - defaults = { - 'db': 9 - } + defaults = {"db": 9} defaults.update(kwargs) pool = redis.ConnectionPool(**kwargs) self._client = redis.Redis(connection_pool=pool) @@ -30,16 +29,16 @@ def run(self, **kwargs): pass def run_benchmark(self): - group_names = [group['name'] for group in self.ARGUMENTS] - group_values = [group['values'] for group in self.ARGUMENTS] + group_names = [group["name"] for group in self.ARGUMENTS] + group_values = [group["values"] for group in self.ARGUMENTS] for value_set in itertools.product(*group_values): pairs = list(zip(group_names, value_set)) - arg_string = ', '.join(f'{p[0]}={p[1]}' for p in pairs) - sys.stdout.write(f'Benchmark: {arg_string}... ') + arg_string = ", ".join(f"{p[0]}={p[1]}" for p in pairs) + sys.stdout.write(f"Benchmark: {arg_string}... ") sys.stdout.flush() kwargs = dict(pairs) setup = functools.partial(self.setup, **kwargs) run = functools.partial(self.run, **kwargs) t = timeit.timeit(stmt=run, setup=setup, number=1000) - sys.stdout.write(f'{t:f}\n') + sys.stdout.write(f"{t:f}\n") sys.stdout.flush() diff --git a/benchmarks/basic_operations.py b/benchmarks/basic_operations.py index cb009debbd..1dc4a87cc8 100644 --- a/benchmarks/basic_operations.py +++ b/benchmarks/basic_operations.py @@ -1,24 +1,27 @@ -import redis import time -from functools import wraps from argparse import ArgumentParser +from functools import wraps + +import redis def parse_args(): parser = ArgumentParser() - parser.add_argument('-n', - type=int, - help='Total number of requests (default 100000)', - default=100000) - parser.add_argument('-P', - type=int, - help=('Pipeline requests.' - ' Default 1 (no pipeline).'), - default=1) - parser.add_argument('-s', - type=int, - help='Data size of SET/GET value in bytes (default 2)', - default=2) + parser.add_argument( + "-n", type=int, help="Total number of requests (default 100000)", default=100000 + ) + parser.add_argument( + "-P", + type=int, + help=("Pipeline requests." " Default 1 (no pipeline)."), + default=1, + ) + parser.add_argument( + "-s", + type=int, + help="Data size of SET/GET value in bytes (default 2)", + default=2, + ) args = parser.parse_args() return args @@ -45,15 +48,16 @@ def wrapper(*args, **kwargs): start = time.monotonic() ret = func(*args, **kwargs) duration = time.monotonic() - start - if 'num' in kwargs: - count = kwargs['num'] + if "num" in kwargs: + count = kwargs["num"] else: count = args[1] - print(f'{func.__name__} - {count} Requests') - print(f'Duration = {duration}') - print(f'Rate = {count/duration}') + print(f"{func.__name__} - {count} Requests") + print(f"Duration = {duration}") + print(f"Rate = {count/duration}") print() return ret + return wrapper @@ -62,9 +66,9 @@ def set_str(conn, num, pipeline_size, data_size): if pipeline_size > 1: conn = conn.pipeline() - set_data = 'a'.ljust(data_size, '0') + set_data = "a".ljust(data_size, "0") for i in range(num): - conn.set(f'set_str:{i}', set_data) + conn.set(f"set_str:{i}", set_data) if pipeline_size > 1 and i % pipeline_size == 0: conn.execute() @@ -79,7 +83,7 @@ def set_int(conn, num, pipeline_size, data_size): set_data = 10 ** (data_size - 1) for i in range(num): - conn.set(f'set_int:{i}', set_data) + conn.set(f"set_int:{i}", set_data) if pipeline_size > 1 and i % pipeline_size == 0: conn.execute() @@ -93,7 +97,7 @@ def get_str(conn, num, pipeline_size, data_size): conn = conn.pipeline() for i in range(num): - conn.get(f'set_str:{i}') + conn.get(f"set_str:{i}") if pipeline_size > 1 and i % pipeline_size == 0: conn.execute() @@ -107,7 +111,7 @@ def get_int(conn, num, pipeline_size, data_size): conn = conn.pipeline() for i in range(num): - conn.get(f'set_int:{i}') + conn.get(f"set_int:{i}") if pipeline_size > 1 and i % pipeline_size == 0: conn.execute() @@ -121,7 +125,7 @@ def incr(conn, num, pipeline_size, *args, **kwargs): conn = conn.pipeline() for i in range(num): - conn.incr('incr_key') + conn.incr("incr_key") if pipeline_size > 1 and i % pipeline_size == 0: conn.execute() @@ -136,7 +140,7 @@ def lpush(conn, num, pipeline_size, data_size): set_data = 10 ** (data_size - 1) for i in range(num): - conn.lpush('lpush_key', set_data) + conn.lpush("lpush_key", set_data) if pipeline_size > 1 and i % pipeline_size == 0: conn.execute() @@ -150,7 +154,7 @@ def lrange_300(conn, num, pipeline_size, data_size): conn = conn.pipeline() for i in range(num): - conn.lrange('lpush_key', i, i+300) + conn.lrange("lpush_key", i, i + 300) if pipeline_size > 1 and i % pipeline_size == 0: conn.execute() @@ -163,7 +167,7 @@ def lpop(conn, num, pipeline_size, data_size): if pipeline_size > 1: conn = conn.pipeline() for i in range(num): - conn.lpop('lpush_key') + conn.lpop("lpush_key") if pipeline_size > 1 and i % pipeline_size == 0: conn.execute() if pipeline_size > 1: @@ -175,11 +179,9 @@ def hmset(conn, num, pipeline_size, data_size): if pipeline_size > 1: conn = conn.pipeline() - set_data = {'str_value': 'string', - 'int_value': 123456, - 'float_value': 123456.0} + set_data = {"str_value": "string", "int_value": 123456, "float_value": 123456.0} for i in range(num): - conn.hmset('hmset_key', set_data) + conn.hmset("hmset_key", set_data) if pipeline_size > 1 and i % pipeline_size == 0: conn.execute() @@ -187,5 +189,5 @@ def hmset(conn, num, pipeline_size, data_size): conn.execute() -if __name__ == '__main__': +if __name__ == "__main__": run() diff --git a/benchmarks/command_packer_benchmark.py b/benchmarks/command_packer_benchmark.py index 3176c06800..e66dbbcbf9 100644 --- a/benchmarks/command_packer_benchmark.py +++ b/benchmarks/command_packer_benchmark.py @@ -1,7 +1,7 @@ -from redis.connection import (Connection, SYM_STAR, SYM_DOLLAR, SYM_EMPTY, - SYM_CRLF) from base import Benchmark +from redis.connection import SYM_CRLF, SYM_DOLLAR, SYM_EMPTY, SYM_STAR, Connection + class StringJoiningConnection(Connection): def send_packed_command(self, command, check_health=True): @@ -13,7 +13,7 @@ def send_packed_command(self, command, check_health=True): except OSError as e: self.disconnect() if len(e.args) == 1: - _errno, errmsg = 'UNKNOWN', e.args[0] + _errno, errmsg = "UNKNOWN", e.args[0] else: _errno, errmsg = e.args raise ConnectionError(f"Error {_errno} while writing to socket. {errmsg}.") @@ -23,12 +23,17 @@ def send_packed_command(self, command, check_health=True): def pack_command(self, *args): "Pack a series of arguments into a value Redis command" - args_output = SYM_EMPTY.join([ - SYM_EMPTY.join( - (SYM_DOLLAR, str(len(k)).encode(), SYM_CRLF, k, SYM_CRLF)) - for k in map(self.encoder.encode, args)]) + args_output = SYM_EMPTY.join( + [ + SYM_EMPTY.join( + (SYM_DOLLAR, str(len(k)).encode(), SYM_CRLF, k, SYM_CRLF) + ) + for k in map(self.encoder.encode, args) + ] + ) output = SYM_EMPTY.join( - (SYM_STAR, str(len(args)).encode(), SYM_CRLF, args_output)) + (SYM_STAR, str(len(args)).encode(), SYM_CRLF, args_output) + ) return output @@ -44,7 +49,7 @@ def send_packed_command(self, command, check_health=True): except OSError as e: self.disconnect() if len(e.args) == 1: - _errno, errmsg = 'UNKNOWN', e.args[0] + _errno, errmsg = "UNKNOWN", e.args[0] else: _errno, errmsg = e.args raise ConnectionError(f"Error {_errno} while writing to socket. {errmsg}.") @@ -54,19 +59,20 @@ def send_packed_command(self, command, check_health=True): def pack_command(self, *args): output = [] - buff = SYM_EMPTY.join( - (SYM_STAR, str(len(args)).encode(), SYM_CRLF)) + buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) for k in map(self.encoder.encode, args): if len(buff) > 6000 or len(k) > 6000: buff = SYM_EMPTY.join( - (buff, SYM_DOLLAR, str(len(k)).encode(), SYM_CRLF)) + (buff, SYM_DOLLAR, str(len(k)).encode(), SYM_CRLF) + ) output.append(buff) output.append(k) buff = SYM_CRLF else: - buff = SYM_EMPTY.join((buff, SYM_DOLLAR, str(len(k)).encode(), - SYM_CRLF, k, SYM_CRLF)) + buff = SYM_EMPTY.join( + (buff, SYM_DOLLAR, str(len(k)).encode(), SYM_CRLF, k, SYM_CRLF) + ) output.append(buff) return output @@ -75,13 +81,12 @@ class CommandPackerBenchmark(Benchmark): ARGUMENTS = ( { - 'name': 'connection_class', - 'values': [StringJoiningConnection, ListJoiningConnection] + "name": "connection_class", + "values": [StringJoiningConnection, ListJoiningConnection], }, { - 'name': 'value_size', - 'values': [10, 100, 1000, 10000, 100000, 1000000, 10000000, - 100000000] + "name": "value_size", + "values": [10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000], }, ) @@ -90,9 +95,9 @@ def setup(self, connection_class, value_size): def run(self, connection_class, value_size): r = self.get_client() - x = 'a' * value_size - r.set('benchmark', x) + x = "a" * value_size + r.set("benchmark", x) -if __name__ == '__main__': +if __name__ == "__main__": CommandPackerBenchmark().run_benchmark() diff --git a/redis/client.py b/redis/client.py index 9f2907ee52..14e588a1d7 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,15 +1,18 @@ -from itertools import chain import copy import datetime import re import threading import time import warnings -from redis.commands import (CoreCommands, RedisModuleCommands, - SentinelCommands, list_or_args) -from redis.connection import (ConnectionPool, UnixDomainSocketConnection, - SSLConnection) -from redis.lock import Lock +from itertools import chain + +from redis.commands import ( + CoreCommands, + RedisModuleCommands, + SentinelCommands, + list_or_args, +) +from redis.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection from redis.exceptions import ( ConnectionError, ExecAbortError, @@ -20,13 +23,14 @@ TimeoutError, WatchError, ) +from redis.lock import Lock from redis.utils import safe_str, str_if_bytes -SYM_EMPTY = b'' -EMPTY_RESPONSE = 'EMPTY_RESPONSE' +SYM_EMPTY = b"" +EMPTY_RESPONSE = "EMPTY_RESPONSE" # some responses (ie. dump) are binary, and just meant to never be decoded -NEVER_DECODE = 'NEVER_DECODE' +NEVER_DECODE = "NEVER_DECODE" def timestamp_to_datetime(response): @@ -76,12 +80,12 @@ def parse_debug_object(response): # The 'type' of the object is the first item in the response, but isn't # prefixed with a name response = str_if_bytes(response) - response = 'type:' + response - response = dict(kv.split(':') for kv in response.split()) + response = "type:" + response + response = dict(kv.split(":") for kv in response.split()) # parse some expected int values from the string response # note: this cmd isn't spec'd so these may not appear in all redis versions - int_fields = ('refcount', 'serializedlength', 'lru', 'lru_seconds_idle') + int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle") for field in int_fields: if field in response: response[field] = int(response[field]) @@ -91,7 +95,7 @@ def parse_debug_object(response): def parse_object(response, infotype): "Parse the results of an OBJECT command" - if infotype in ('idletime', 'refcount'): + if infotype in ("idletime", "refcount"): return int_or_none(response) return response @@ -102,9 +106,9 @@ def parse_info(response): response = str_if_bytes(response) def get_value(value): - if ',' not in value or '=' not in value: + if "," not in value or "=" not in value: try: - if '.' in value: + if "." in value: return float(value) else: return int(value) @@ -112,82 +116,84 @@ def get_value(value): return value else: sub_dict = {} - for item in value.split(','): - k, v = item.rsplit('=', 1) + for item in value.split(","): + k, v = item.rsplit("=", 1) sub_dict[k] = get_value(v) return sub_dict for line in response.splitlines(): - if line and not line.startswith('#'): - if line.find(':') != -1: + if line and not line.startswith("#"): + if line.find(":") != -1: # Split, the info fields keys and values. # Note that the value may contain ':'. but the 'host:' # pseudo-command is the only case where the key contains ':' - key, value = line.split(':', 1) - if key == 'cmdstat_host': - key, value = line.rsplit(':', 1) + key, value = line.split(":", 1) + if key == "cmdstat_host": + key, value = line.rsplit(":", 1) - if key == 'module': + if key == "module": # Hardcode a list for key 'modules' since there could be # multiple lines that started with 'module' - info.setdefault('modules', []).append(get_value(value)) + info.setdefault("modules", []).append(get_value(value)) else: info[key] = get_value(value) else: # if the line isn't splittable, append it to the "__raw__" key - info.setdefault('__raw__', []).append(line) + info.setdefault("__raw__", []).append(line) return info def parse_memory_stats(response, **kwargs): "Parse the results of MEMORY STATS" - stats = pairs_to_dict(response, - decode_keys=True, - decode_string_values=True) + stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) for key, value in stats.items(): - if key.startswith('db.'): - stats[key] = pairs_to_dict(value, - decode_keys=True, - decode_string_values=True) + if key.startswith("db."): + stats[key] = pairs_to_dict( + value, decode_keys=True, decode_string_values=True + ) return stats SENTINEL_STATE_TYPES = { - 'can-failover-its-master': int, - 'config-epoch': int, - 'down-after-milliseconds': int, - 'failover-timeout': int, - 'info-refresh': int, - 'last-hello-message': int, - 'last-ok-ping-reply': int, - 'last-ping-reply': int, - 'last-ping-sent': int, - 'master-link-down-time': int, - 'master-port': int, - 'num-other-sentinels': int, - 'num-slaves': int, - 'o-down-time': int, - 'pending-commands': int, - 'parallel-syncs': int, - 'port': int, - 'quorum': int, - 'role-reported-time': int, - 's-down-time': int, - 'slave-priority': int, - 'slave-repl-offset': int, - 'voted-leader-epoch': int + "can-failover-its-master": int, + "config-epoch": int, + "down-after-milliseconds": int, + "failover-timeout": int, + "info-refresh": int, + "last-hello-message": int, + "last-ok-ping-reply": int, + "last-ping-reply": int, + "last-ping-sent": int, + "master-link-down-time": int, + "master-port": int, + "num-other-sentinels": int, + "num-slaves": int, + "o-down-time": int, + "pending-commands": int, + "parallel-syncs": int, + "port": int, + "quorum": int, + "role-reported-time": int, + "s-down-time": int, + "slave-priority": int, + "slave-repl-offset": int, + "voted-leader-epoch": int, } def parse_sentinel_state(item): result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) - flags = set(result['flags'].split(',')) - for name, flag in (('is_master', 'master'), ('is_slave', 'slave'), - ('is_sdown', 's_down'), ('is_odown', 'o_down'), - ('is_sentinel', 'sentinel'), - ('is_disconnected', 'disconnected'), - ('is_master_down', 'master_down')): + flags = set(result["flags"].split(",")) + for name, flag in ( + ("is_master", "master"), + ("is_slave", "slave"), + ("is_sdown", "s_down"), + ("is_odown", "o_down"), + ("is_sentinel", "sentinel"), + ("is_disconnected", "disconnected"), + ("is_master_down", "master_down"), + ): result[name] = flag in flags return result @@ -200,7 +206,7 @@ def parse_sentinel_masters(response): result = {} for item in response: state = parse_sentinel_state(map(str_if_bytes, item)) - result[state['name']] = state + result[state["name"]] = state return result @@ -251,9 +257,9 @@ def zset_score_pairs(response, **options): If ``withscores`` is specified in the options, return the response as a list of (value, score) pairs """ - if not response or not options.get('withscores'): + if not response or not options.get("withscores"): return response - score_cast_func = options.get('score_cast_func', float) + score_cast_func = options.get("score_cast_func", float) it = iter(response) return list(zip(it, map(score_cast_func, it))) @@ -263,9 +269,9 @@ def sort_return_tuples(response, **options): If ``groups`` is specified, return the response as a list of n-element tuples with n being the value found in options['groups'] """ - if not response or not options.get('groups'): + if not response or not options.get("groups"): return response - n = options['groups'] + n = options["groups"] return list(zip(*[response[i::n] for i in range(n)])) @@ -296,34 +302,30 @@ def parse_list_of_dicts(response): def parse_xclaim(response, **options): - if options.get('parse_justid', False): + if options.get("parse_justid", False): return response return parse_stream_list(response) def parse_xautoclaim(response, **options): - if options.get('parse_justid', False): + if options.get("parse_justid", False): return response[1] return parse_stream_list(response[1]) def parse_xinfo_stream(response, **options): data = pairs_to_dict(response, decode_keys=True) - if not options.get('full', False): - first = data['first-entry'] + if not options.get("full", False): + first = data["first-entry"] if first is not None: - data['first-entry'] = (first[0], pairs_to_dict(first[1])) - last = data['last-entry'] + data["first-entry"] = (first[0], pairs_to_dict(first[1])) + last = data["last-entry"] if last is not None: - data['last-entry'] = (last[0], pairs_to_dict(last[1])) + data["last-entry"] = (last[0], pairs_to_dict(last[1])) else: - data['entries'] = { - _id: pairs_to_dict(entry) - for _id, entry in data['entries'] - } - data['groups'] = [ - pairs_to_dict(group, decode_keys=True) - for group in data['groups'] + data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} + data["groups"] = [ + pairs_to_dict(group, decode_keys=True) for group in data["groups"] ] return data @@ -335,19 +337,19 @@ def parse_xread(response): def parse_xpending(response, **options): - if options.get('parse_detail', False): + if options.get("parse_detail", False): return parse_xpending_range(response) - consumers = [{'name': n, 'pending': int(p)} for n, p in response[3] or []] + consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []] return { - 'pending': response[0], - 'min': response[1], - 'max': response[2], - 'consumers': consumers + "pending": response[0], + "min": response[1], + "max": response[2], + "consumers": consumers, } def parse_xpending_range(response): - k = ('message_id', 'consumer', 'time_since_delivered', 'times_delivered') + k = ("message_id", "consumer", "time_since_delivered", "times_delivered") return [dict(zip(k, r)) for r in response] @@ -358,13 +360,13 @@ def float_or_none(response): def bool_ok(response): - return str_if_bytes(response) == 'OK' + return str_if_bytes(response) == "OK" def parse_zadd(response, **options): if response is None: return None - if options.get('as_score'): + if options.get("as_score"): return float(response) return int(response) @@ -373,7 +375,7 @@ def parse_client_list(response, **options): clients = [] for c in str_if_bytes(response).splitlines(): # Values might contain '=' - clients.append(dict(pair.split('=', 1) for pair in c.split(' '))) + clients.append(dict(pair.split("=", 1) for pair in c.split(" "))) return clients @@ -393,7 +395,7 @@ def parse_hscan(response, **options): def parse_zscan(response, **options): - score_cast_func = options.get('score_cast_func', float) + score_cast_func = options.get("score_cast_func", float) cursor, r = response it = iter(r) return int(cursor), list(zip(it, map(score_cast_func, it))) @@ -405,23 +407,24 @@ def parse_zmscore(response, **options): def parse_slowlog_get(response, **options): - space = ' ' if options.get('decode_responses', False) else b' ' + space = " " if options.get("decode_responses", False) else b" " def parse_item(item): result = { - 'id': item[0], - 'start_time': int(item[1]), - 'duration': int(item[2]), + "id": item[0], + "start_time": int(item[1]), + "duration": int(item[2]), } # Redis Enterprise injects another entry at index [3], which has # the complexity info (i.e. the value N in case the command has # an O(N) complexity) instead of the command. if isinstance(item[3], list): - result['command'] = space.join(item[3]) + result["command"] = space.join(item[3]) else: - result['complexity'] = item[3] - result['command'] = space.join(item[4]) + result["complexity"] = item[3] + result["command"] = space.join(item[4]) return result + return [parse_item(item) for item in response] @@ -437,42 +440,42 @@ def parse_stralgo(response, **options): When WITHMATCHLEN is given, each array representing a match will also have the length of the match at the beginning of the array. """ - if options.get('len', False): + if options.get("len", False): return int(response) - if options.get('idx', False): - if options.get('withmatchlen', False): - matches = [[(int(match[-1]))] + list(map(tuple, match[:-1])) - for match in response[1]] + if options.get("idx", False): + if options.get("withmatchlen", False): + matches = [ + [(int(match[-1]))] + list(map(tuple, match[:-1])) + for match in response[1] + ] else: - matches = [list(map(tuple, match)) - for match in response[1]] + matches = [list(map(tuple, match)) for match in response[1]] return { str_if_bytes(response[0]): matches, - str_if_bytes(response[2]): int(response[3]) + str_if_bytes(response[2]): int(response[3]), } return str_if_bytes(response) def parse_cluster_info(response, **options): response = str_if_bytes(response) - return dict(line.split(':') for line in response.splitlines() if line) + return dict(line.split(":") for line in response.splitlines() if line) def _parse_node_line(line): - line_items = line.split(' ') - node_id, addr, flags, master_id, ping, pong, epoch, \ - connected = line.split(' ')[:8] - addr = addr.split('@')[0] - slots = [sl.split('-') for sl in line_items[8:]] + line_items = line.split(" ") + node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8] + addr = addr.split("@")[0] + slots = [sl.split("-") for sl in line_items[8:]] node_dict = { - 'node_id': node_id, - 'flags': flags, - 'master_id': master_id, - 'last_ping_sent': ping, - 'last_pong_rcvd': pong, - 'epoch': epoch, - 'slots': slots, - 'connected': True if connected == 'connected' else False + "node_id": node_id, + "flags": flags, + "master_id": master_id, + "last_ping_sent": ping, + "last_pong_rcvd": pong, + "epoch": epoch, + "slots": slots, + "connected": True if connected == "connected" else False, } return addr, node_dict @@ -492,7 +495,7 @@ def parse_geosearch_generic(response, **options): Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER' commands according to 'withdist', 'withhash' and 'withcoord' labels. """ - if options['store'] or options['store_dist']: + if options["store"] or options["store_dist"]: # `store` and `store_dist` cant be combined # with other command arguments. # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' @@ -503,24 +506,21 @@ def parse_geosearch_generic(response, **options): else: response_list = response - if not options['withdist'] and not options['withcoord'] \ - and not options['withhash']: + if not options["withdist"] and not options["withcoord"] and not options["withhash"]: # just a bunch of places return response_list cast = { - 'withdist': float, - 'withcoord': lambda ll: (float(ll[0]), float(ll[1])), - 'withhash': int + "withdist": float, + "withcoord": lambda ll: (float(ll[0]), float(ll[1])), + "withhash": int, } # zip all output results with each casting function to get # the properly native Python value. f = [lambda x: x] - f += [cast[o] for o in ['withdist', 'withhash', 'withcoord'] if options[o]] - return [ - list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list - ] + f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]] + return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list] def parse_command(response, **options): @@ -528,12 +528,12 @@ def parse_command(response, **options): for command in response: cmd_dict = {} cmd_name = str_if_bytes(command[0]) - cmd_dict['name'] = cmd_name - cmd_dict['arity'] = int(command[1]) - cmd_dict['flags'] = [str_if_bytes(flag) for flag in command[2]] - cmd_dict['first_key_pos'] = command[3] - cmd_dict['last_key_pos'] = command[4] - cmd_dict['step_count'] = command[5] + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = int(command[1]) + cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] commands[cmd_name] = cmd_dict return commands @@ -545,7 +545,7 @@ def parse_pubsub_numsub(response, **options): def parse_client_kill(response, **options): if isinstance(response, int): return response - return str_if_bytes(response) == 'OK' + return str_if_bytes(response) == "OK" def parse_acl_getuser(response, **options): @@ -554,21 +554,21 @@ def parse_acl_getuser(response, **options): data = pairs_to_dict(response, decode_keys=True) # convert everything but user-defined data in 'keys' to native strings - data['flags'] = list(map(str_if_bytes, data['flags'])) - data['passwords'] = list(map(str_if_bytes, data['passwords'])) - data['commands'] = str_if_bytes(data['commands']) + data["flags"] = list(map(str_if_bytes, data["flags"])) + data["passwords"] = list(map(str_if_bytes, data["passwords"])) + data["commands"] = str_if_bytes(data["commands"]) # split 'commands' into separate 'categories' and 'commands' lists commands, categories = [], [] - for command in data['commands'].split(' '): - if '@' in command: + for command in data["commands"].split(" "): + if "@" in command: categories.append(command) else: commands.append(command) - data['commands'] = commands - data['categories'] = categories - data['enabled'] = 'on' in data['flags'] + data["commands"] = commands + data["categories"] = categories + data["enabled"] = "on" in data["flags"] return data @@ -579,7 +579,7 @@ def parse_acl_log(response, **options): data = [] for log in response: log_data = pairs_to_dict(log, True, True) - client_info = log_data.get('client-info', '') + client_info = log_data.get("client-info", "") log_data["client-info"] = parse_client_info(client_info) # float() is lossy comparing to the "double" in C @@ -602,9 +602,22 @@ def parse_client_info(value): client_info[key] = value # Those fields are defined as int in networking.c - for int_key in {"id", "age", "idle", "db", "sub", "psub", - "multi", "qbuf", "qbuf-free", "obl", - "argv-mem", "oll", "omem", "tot-mem"}: + for int_key in { + "id", + "age", + "idle", + "db", + "sub", + "psub", + "multi", + "qbuf", + "qbuf-free", + "obl", + "argv-mem", + "oll", + "omem", + "tot-mem", + }: client_info[int_key] = int(client_info[int_key]) return client_info @@ -622,11 +635,11 @@ def parse_set_result(response, **options): - BOOL - String when GET argument is used """ - if options.get('get'): + if options.get("get"): # Redis will return a getCommand result. # See `setGenericCommand` in t_string.c return response - return response and str_if_bytes(response) == 'OK' + return response and str_if_bytes(response) == "OK" class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): @@ -641,158 +654,156 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): configuration, an instance will either use a ConnectionPool, or Connection object to talk to redis. """ + RESPONSE_CALLBACKS = { **string_keys_to_dict( - 'AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT ' - 'HEXISTS HMSET LMOVE BLMOVE MOVE ' - 'MSETNX PERSIST PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX', - bool + "AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT " + "HEXISTS HMSET LMOVE BLMOVE MOVE " + "MSETNX PERSIST PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX", + bool, ), **string_keys_to_dict( - 'BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN ' - 'HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD ' - 'SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN ' - 'SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM ' - 'ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE', - int + "BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN " + "HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD " + "SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN " + "SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM " + "ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE", + int, ), + **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), **string_keys_to_dict( - 'INCRBYFLOAT HINCRBYFLOAT', - float + # these return OK, or int if redis-server is >=1.3.4 + "LPUSH RPUSH", + lambda r: isinstance(r, int) and r or str_if_bytes(r) == "OK", ), + **string_keys_to_dict("SORT", sort_return_tuples), + **string_keys_to_dict("ZSCORE ZINCRBY GEODIST", float_or_none), **string_keys_to_dict( - # these return OK, or int if redis-server is >=1.3.4 - 'LPUSH RPUSH', - lambda r: isinstance(r, int) and r or str_if_bytes(r) == 'OK' + "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE " + "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", + bool_ok, ), - **string_keys_to_dict('SORT', sort_return_tuples), - **string_keys_to_dict('ZSCORE ZINCRBY GEODIST', float_or_none), + **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), **string_keys_to_dict( - 'FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE ' - 'RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ', - bool_ok + "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() ), - **string_keys_to_dict('BLPOP BRPOP', lambda r: r and tuple(r) or None), **string_keys_to_dict( - 'SDIFF SINTER SMEMBERS SUNION', - lambda r: r and set(r) or set() + "ZPOPMAX ZPOPMIN ZINTER ZDIFF ZUNION ZRANGE ZRANGEBYSCORE " + "ZREVRANGE ZREVRANGEBYSCORE", + zset_score_pairs, ), **string_keys_to_dict( - 'ZPOPMAX ZPOPMIN ZINTER ZDIFF ZUNION ZRANGE ZRANGEBYSCORE ' - 'ZREVRANGE ZREVRANGEBYSCORE', zset_score_pairs + "BZPOPMIN BZPOPMAX", lambda r: r and (r[0], r[1], float(r[2])) or None + ), + **string_keys_to_dict("ZRANK ZREVRANK", int_or_none), + **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread), + **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), + "ACL CAT": lambda r: list(map(str_if_bytes, r)), + "ACL DELUSER": int, + "ACL GENPASS": str_if_bytes, + "ACL GETUSER": parse_acl_getuser, + "ACL HELP": lambda r: list(map(str_if_bytes, r)), + "ACL LIST": lambda r: list(map(str_if_bytes, r)), + "ACL LOAD": bool_ok, + "ACL LOG": parse_acl_log, + "ACL SAVE": bool_ok, + "ACL SETUSER": bool_ok, + "ACL USERS": lambda r: list(map(str_if_bytes, r)), + "ACL WHOAMI": str_if_bytes, + "CLIENT GETNAME": str_if_bytes, + "CLIENT ID": int, + "CLIENT KILL": parse_client_kill, + "CLIENT LIST": parse_client_list, + "CLIENT INFO": parse_client_info, + "CLIENT SETNAME": bool_ok, + "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, + "CLIENT PAUSE": bool_ok, + "CLIENT GETREDIR": int, + "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), + "CLUSTER ADDSLOTS": bool_ok, + "CLUSTER COUNT-FAILURE-REPORTS": lambda x: int(x), + "CLUSTER COUNTKEYSINSLOT": lambda x: int(x), + "CLUSTER DELSLOTS": bool_ok, + "CLUSTER FAILOVER": bool_ok, + "CLUSTER FORGET": bool_ok, + "CLUSTER INFO": parse_cluster_info, + "CLUSTER KEYSLOT": lambda x: int(x), + "CLUSTER MEET": bool_ok, + "CLUSTER NODES": parse_cluster_nodes, + "CLUSTER REPLICATE": bool_ok, + "CLUSTER RESET": bool_ok, + "CLUSTER SAVECONFIG": bool_ok, + "CLUSTER SET-CONFIG-EPOCH": bool_ok, + "CLUSTER SETSLOT": bool_ok, + "CLUSTER SLAVES": parse_cluster_nodes, + "CLUSTER REPLICAS": parse_cluster_nodes, + "COMMAND": parse_command, + "COMMAND COUNT": int, + "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), + "CONFIG GET": parse_config_get, + "CONFIG RESETSTAT": bool_ok, + "CONFIG SET": bool_ok, + "DEBUG OBJECT": parse_debug_object, + "GEOHASH": lambda r: list(map(str_if_bytes, r)), + "GEOPOS": lambda r: list( + map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) ), - **string_keys_to_dict('BZPOPMIN BZPOPMAX', \ - lambda r: - r and (r[0], r[1], float(r[2])) or None), - **string_keys_to_dict('ZRANK ZREVRANK', int_or_none), - **string_keys_to_dict('XREVRANGE XRANGE', parse_stream_list), - **string_keys_to_dict('XREAD XREADGROUP', parse_xread), - **string_keys_to_dict('BGREWRITEAOF BGSAVE', lambda r: True), - 'ACL CAT': lambda r: list(map(str_if_bytes, r)), - 'ACL DELUSER': int, - 'ACL GENPASS': str_if_bytes, - 'ACL GETUSER': parse_acl_getuser, - 'ACL HELP': lambda r: list(map(str_if_bytes, r)), - 'ACL LIST': lambda r: list(map(str_if_bytes, r)), - 'ACL LOAD': bool_ok, - 'ACL LOG': parse_acl_log, - 'ACL SAVE': bool_ok, - 'ACL SETUSER': bool_ok, - 'ACL USERS': lambda r: list(map(str_if_bytes, r)), - 'ACL WHOAMI': str_if_bytes, - 'CLIENT GETNAME': str_if_bytes, - 'CLIENT ID': int, - 'CLIENT KILL': parse_client_kill, - 'CLIENT LIST': parse_client_list, - 'CLIENT INFO': parse_client_info, - 'CLIENT SETNAME': bool_ok, - 'CLIENT UNBLOCK': lambda r: r and int(r) == 1 or False, - 'CLIENT PAUSE': bool_ok, - 'CLIENT GETREDIR': int, - 'CLIENT TRACKINGINFO': lambda r: list(map(str_if_bytes, r)), - 'CLUSTER ADDSLOTS': bool_ok, - 'CLUSTER COUNT-FAILURE-REPORTS': lambda x: int(x), - 'CLUSTER COUNTKEYSINSLOT': lambda x: int(x), - 'CLUSTER DELSLOTS': bool_ok, - 'CLUSTER FAILOVER': bool_ok, - 'CLUSTER FORGET': bool_ok, - 'CLUSTER INFO': parse_cluster_info, - 'CLUSTER KEYSLOT': lambda x: int(x), - 'CLUSTER MEET': bool_ok, - 'CLUSTER NODES': parse_cluster_nodes, - 'CLUSTER REPLICATE': bool_ok, - 'CLUSTER RESET': bool_ok, - 'CLUSTER SAVECONFIG': bool_ok, - 'CLUSTER SET-CONFIG-EPOCH': bool_ok, - 'CLUSTER SETSLOT': bool_ok, - 'CLUSTER SLAVES': parse_cluster_nodes, - 'CLUSTER REPLICAS': parse_cluster_nodes, - 'COMMAND': parse_command, - 'COMMAND COUNT': int, - 'COMMAND GETKEYS': lambda r: list(map(str_if_bytes, r)), - 'CONFIG GET': parse_config_get, - 'CONFIG RESETSTAT': bool_ok, - 'CONFIG SET': bool_ok, - 'DEBUG OBJECT': parse_debug_object, - 'GEOHASH': lambda r: list(map(str_if_bytes, r)), - 'GEOPOS': lambda r: list(map(lambda ll: (float(ll[0]), - float(ll[1])) - if ll is not None else None, r)), - 'GEOSEARCH': parse_geosearch_generic, - 'GEORADIUS': parse_geosearch_generic, - 'GEORADIUSBYMEMBER': parse_geosearch_generic, - 'HGETALL': lambda r: r and pairs_to_dict(r) or {}, - 'HSCAN': parse_hscan, - 'INFO': parse_info, - 'LASTSAVE': timestamp_to_datetime, - 'MEMORY PURGE': bool_ok, - 'MEMORY STATS': parse_memory_stats, - 'MEMORY USAGE': int_or_none, - 'MODULE LOAD': parse_module_result, - 'MODULE UNLOAD': parse_module_result, - 'MODULE LIST': lambda r: [pairs_to_dict(m) for m in r], - 'OBJECT': parse_object, - 'PING': lambda r: str_if_bytes(r) == 'PONG', - 'QUIT': bool_ok, - 'STRALGO': parse_stralgo, - 'PUBSUB NUMSUB': parse_pubsub_numsub, - 'RANDOMKEY': lambda r: r and r or None, - 'SCAN': parse_scan, - 'SCRIPT EXISTS': lambda r: list(map(bool, r)), - 'SCRIPT FLUSH': bool_ok, - 'SCRIPT KILL': bool_ok, - 'SCRIPT LOAD': str_if_bytes, - 'SENTINEL CKQUORUM': bool_ok, - 'SENTINEL FAILOVER': bool_ok, - 'SENTINEL FLUSHCONFIG': bool_ok, - 'SENTINEL GET-MASTER-ADDR-BY-NAME': parse_sentinel_get_master, - 'SENTINEL MASTER': parse_sentinel_master, - 'SENTINEL MASTERS': parse_sentinel_masters, - 'SENTINEL MONITOR': bool_ok, - 'SENTINEL RESET': bool_ok, - 'SENTINEL REMOVE': bool_ok, - 'SENTINEL SENTINELS': parse_sentinel_slaves_and_sentinels, - 'SENTINEL SET': bool_ok, - 'SENTINEL SLAVES': parse_sentinel_slaves_and_sentinels, - 'SET': parse_set_result, - 'SLOWLOG GET': parse_slowlog_get, - 'SLOWLOG LEN': int, - 'SLOWLOG RESET': bool_ok, - 'SSCAN': parse_scan, - 'TIME': lambda x: (int(x[0]), int(x[1])), - 'XCLAIM': parse_xclaim, - 'XAUTOCLAIM': parse_xautoclaim, - 'XGROUP CREATE': bool_ok, - 'XGROUP DELCONSUMER': int, - 'XGROUP DESTROY': bool, - 'XGROUP SETID': bool_ok, - 'XINFO CONSUMERS': parse_list_of_dicts, - 'XINFO GROUPS': parse_list_of_dicts, - 'XINFO STREAM': parse_xinfo_stream, - 'XPENDING': parse_xpending, - 'ZADD': parse_zadd, - 'ZSCAN': parse_zscan, - 'ZMSCORE': parse_zmscore, + "GEOSEARCH": parse_geosearch_generic, + "GEORADIUS": parse_geosearch_generic, + "GEORADIUSBYMEMBER": parse_geosearch_generic, + "HGETALL": lambda r: r and pairs_to_dict(r) or {}, + "HSCAN": parse_hscan, + "INFO": parse_info, + "LASTSAVE": timestamp_to_datetime, + "MEMORY PURGE": bool_ok, + "MEMORY STATS": parse_memory_stats, + "MEMORY USAGE": int_or_none, + "MODULE LOAD": parse_module_result, + "MODULE UNLOAD": parse_module_result, + "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], + "OBJECT": parse_object, + "PING": lambda r: str_if_bytes(r) == "PONG", + "QUIT": bool_ok, + "STRALGO": parse_stralgo, + "PUBSUB NUMSUB": parse_pubsub_numsub, + "RANDOMKEY": lambda r: r and r or None, + "SCAN": parse_scan, + "SCRIPT EXISTS": lambda r: list(map(bool, r)), + "SCRIPT FLUSH": bool_ok, + "SCRIPT KILL": bool_ok, + "SCRIPT LOAD": str_if_bytes, + "SENTINEL CKQUORUM": bool_ok, + "SENTINEL FAILOVER": bool_ok, + "SENTINEL FLUSHCONFIG": bool_ok, + "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, + "SENTINEL MASTER": parse_sentinel_master, + "SENTINEL MASTERS": parse_sentinel_masters, + "SENTINEL MONITOR": bool_ok, + "SENTINEL RESET": bool_ok, + "SENTINEL REMOVE": bool_ok, + "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, + "SENTINEL SET": bool_ok, + "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, + "SET": parse_set_result, + "SLOWLOG GET": parse_slowlog_get, + "SLOWLOG LEN": int, + "SLOWLOG RESET": bool_ok, + "SSCAN": parse_scan, + "TIME": lambda x: (int(x[0]), int(x[1])), + "XCLAIM": parse_xclaim, + "XAUTOCLAIM": parse_xautoclaim, + "XGROUP CREATE": bool_ok, + "XGROUP DELCONSUMER": int, + "XGROUP DESTROY": bool, + "XGROUP SETID": bool_ok, + "XINFO CONSUMERS": parse_list_of_dicts, + "XINFO GROUPS": parse_list_of_dicts, + "XINFO STREAM": parse_xinfo_stream, + "XPENDING": parse_xpending, + "ZADD": parse_zadd, + "ZSCAN": parse_zscan, + "ZMSCORE": parse_zmscore, } @classmethod @@ -839,20 +850,38 @@ class initializer. In the case of conflicting arguments, querystring connection_pool = ConnectionPool.from_url(url, **kwargs) return cls(connection_pool=connection_pool) - def __init__(self, host='localhost', port=6379, - db=0, password=None, socket_timeout=None, - socket_connect_timeout=None, - socket_keepalive=None, socket_keepalive_options=None, - connection_pool=None, unix_socket_path=None, - encoding='utf-8', encoding_errors='strict', - charset=None, errors=None, - decode_responses=False, retry_on_timeout=False, - ssl=False, ssl_keyfile=None, ssl_certfile=None, - ssl_cert_reqs='required', ssl_ca_certs=None, - ssl_check_hostname=False, - max_connections=None, single_connection_client=False, - health_check_interval=0, client_name=None, username=None, - retry=None, redis_connect_func=None): + def __init__( + self, + host="localhost", + port=6379, + db=0, + password=None, + socket_timeout=None, + socket_connect_timeout=None, + socket_keepalive=None, + socket_keepalive_options=None, + connection_pool=None, + unix_socket_path=None, + encoding="utf-8", + encoding_errors="strict", + charset=None, + errors=None, + decode_responses=False, + retry_on_timeout=False, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs="required", + ssl_ca_certs=None, + ssl_check_hostname=False, + max_connections=None, + single_connection_client=False, + health_check_interval=0, + client_name=None, + username=None, + retry=None, + redis_connect_func=None, + ): """ Initialize a new Redis client. To specify a retry policy, first set `retry_on_timeout` to `True` @@ -860,62 +889,73 @@ def __init__(self, host='localhost', port=6379, """ if not connection_pool: if charset is not None: - warnings.warn(DeprecationWarning( - '"charset" is deprecated. Use "encoding" instead')) + warnings.warn( + DeprecationWarning( + '"charset" is deprecated. Use "encoding" instead' + ) + ) encoding = charset if errors is not None: - warnings.warn(DeprecationWarning( - '"errors" is deprecated. Use "encoding_errors" instead')) + warnings.warn( + DeprecationWarning( + '"errors" is deprecated. Use "encoding_errors" instead' + ) + ) encoding_errors = errors kwargs = { - 'db': db, - 'username': username, - 'password': password, - 'socket_timeout': socket_timeout, - 'encoding': encoding, - 'encoding_errors': encoding_errors, - 'decode_responses': decode_responses, - 'retry_on_timeout': retry_on_timeout, - 'retry': copy.deepcopy(retry), - 'max_connections': max_connections, - 'health_check_interval': health_check_interval, - 'client_name': client_name, - 'redis_connect_func': redis_connect_func + "db": db, + "username": username, + "password": password, + "socket_timeout": socket_timeout, + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + "retry_on_timeout": retry_on_timeout, + "retry": copy.deepcopy(retry), + "max_connections": max_connections, + "health_check_interval": health_check_interval, + "client_name": client_name, + "redis_connect_func": redis_connect_func, } # based on input, setup appropriate connection args if unix_socket_path is not None: - kwargs.update({ - 'path': unix_socket_path, - 'connection_class': UnixDomainSocketConnection - }) + kwargs.update( + { + "path": unix_socket_path, + "connection_class": UnixDomainSocketConnection, + } + ) else: # TCP specific options - kwargs.update({ - 'host': host, - 'port': port, - 'socket_connect_timeout': socket_connect_timeout, - 'socket_keepalive': socket_keepalive, - 'socket_keepalive_options': socket_keepalive_options, - }) + kwargs.update( + { + "host": host, + "port": port, + "socket_connect_timeout": socket_connect_timeout, + "socket_keepalive": socket_keepalive, + "socket_keepalive_options": socket_keepalive_options, + } + ) if ssl: - kwargs.update({ - 'connection_class': SSLConnection, - 'ssl_keyfile': ssl_keyfile, - 'ssl_certfile': ssl_certfile, - 'ssl_cert_reqs': ssl_cert_reqs, - 'ssl_ca_certs': ssl_ca_certs, - 'ssl_check_hostname': ssl_check_hostname, - }) + kwargs.update( + { + "connection_class": SSLConnection, + "ssl_keyfile": ssl_keyfile, + "ssl_certfile": ssl_certfile, + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": ssl_ca_certs, + "ssl_check_hostname": ssl_check_hostname, + } + ) connection_pool = ConnectionPool(**kwargs) self.connection_pool = connection_pool self.connection = None if single_connection_client: - self.connection = self.connection_pool.get_connection('_') + self.connection = self.connection_pool.get_connection("_") - self.response_callbacks = CaseInsensitiveDict( - self.__class__.RESPONSE_CALLBACKS) + self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) def __repr__(self): return f"{type(self).__name__}<{repr(self.connection_pool)}>" @@ -924,8 +964,11 @@ def set_response_callback(self, command, callback): "Set a custom Response Callback" self.response_callbacks[command] = callback - def load_external_module(self, funcname, func, - ): + def load_external_module( + self, + funcname, + func, + ): """ This function can be used to add externally defined redis modules, and their namespaces to the redis client. @@ -957,10 +1000,8 @@ def pipeline(self, transaction=True, shard_hint=None): between the client and server. """ return Pipeline( - self.connection_pool, - self.response_callbacks, - transaction, - shard_hint) + self.connection_pool, self.response_callbacks, transaction, shard_hint + ) def transaction(self, func, *watches, **kwargs): """ @@ -968,9 +1009,9 @@ def transaction(self, func, *watches, **kwargs): while watching all keys specified in `watches`. The 'func' callable should expect a single argument which is a Pipeline object. """ - shard_hint = kwargs.pop('shard_hint', None) - value_from_callable = kwargs.pop('value_from_callable', False) - watch_delay = kwargs.pop('watch_delay', None) + shard_hint = kwargs.pop("shard_hint", None) + value_from_callable = kwargs.pop("value_from_callable", False) + watch_delay = kwargs.pop("watch_delay", None) with self.pipeline(True, shard_hint) as pipe: while True: try: @@ -984,8 +1025,15 @@ def transaction(self, func, *watches, **kwargs): time.sleep(watch_delay) continue - def lock(self, name, timeout=None, sleep=0.1, blocking_timeout=None, - lock_class=None, thread_local=True): + def lock( + self, + name, + timeout=None, + sleep=0.1, + blocking_timeout=None, + lock_class=None, + thread_local=True, + ): """ Return a new Lock object using key ``name`` that mimics the behavior of threading.Lock. @@ -1028,12 +1076,17 @@ def lock(self, name, timeout=None, sleep=0.1, blocking_timeout=None, local storage isn't disabled in this case, the worker thread won't see the token set by the thread that acquired the lock. Our assumption is that these cases aren't common and as such default to using - thread local storage. """ + thread local storage.""" if lock_class is None: lock_class = Lock - return lock_class(self, name, timeout=timeout, sleep=sleep, - blocking_timeout=blocking_timeout, - thread_local=thread_local) + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) def pubsub(self, **kwargs): """ @@ -1047,8 +1100,9 @@ def monitor(self): return Monitor(self.connection_pool) def client(self): - return self.__class__(connection_pool=self.connection_pool, - single_connection_client=True) + return self.__class__( + connection_pool=self.connection_pool, single_connection_client=True + ) def __enter__(self): return self @@ -1065,11 +1119,7 @@ def close(self): self.connection = None self.connection_pool.release(conn) - def _send_command_parse_response(self, - conn, - command_name, - *args, - **options): + def _send_command_parse_response(self, conn, command_name, *args, **options): """ Send a command and parse the response """ @@ -1095,11 +1145,11 @@ def execute_command(self, *args, **options): try: return conn.retry.call_with_retry( - lambda: self._send_command_parse_response(conn, - command_name, - *args, - **options), - lambda error: self._disconnect_raise(conn, error)) + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) finally: if not self.connection: pool.release(conn) @@ -1129,19 +1179,20 @@ class Monitor: next_command() method returns one command from monitor listen() method yields commands from monitor. """ - monitor_re = re.compile(r'\[(\d+) (.*)\] (.*)') + + monitor_re = re.compile(r"\[(\d+) (.*)\] (.*)") command_re = re.compile(r'"(.*?)(? conn.next_health_check: - conn.send_command('PING', self.HEALTH_CHECK_MESSAGE, - check_health=False) + conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False) def _normalize_keys(self, data): """ @@ -1371,7 +1425,7 @@ def psubscribe(self, *args, **kwargs): args = list_or_args(args[0], args[1:]) new_patterns = dict.fromkeys(args) new_patterns.update(kwargs) - ret_val = self.execute_command('PSUBSCRIBE', *new_patterns.keys()) + ret_val = self.execute_command("PSUBSCRIBE", *new_patterns.keys()) # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again # for the reconnection. @@ -1391,7 +1445,7 @@ def punsubscribe(self, *args): else: patterns = self.patterns self.pending_unsubscribe_patterns.update(patterns) - return self.execute_command('PUNSUBSCRIBE', *args) + return self.execute_command("PUNSUBSCRIBE", *args) def subscribe(self, *args, **kwargs): """ @@ -1405,7 +1459,7 @@ def subscribe(self, *args, **kwargs): args = list_or_args(args[0], args[1:]) new_channels = dict.fromkeys(args) new_channels.update(kwargs) - ret_val = self.execute_command('SUBSCRIBE', *new_channels.keys()) + ret_val = self.execute_command("SUBSCRIBE", *new_channels.keys()) # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. @@ -1425,7 +1479,7 @@ def unsubscribe(self, *args): else: channels = self.channels self.pending_unsubscribe_channels.update(channels) - return self.execute_command('UNSUBSCRIBE', *args) + return self.execute_command("UNSUBSCRIBE", *args) def listen(self): "Listen for messages on channels this client has been subscribed to" @@ -1451,8 +1505,8 @@ def ping(self, message=None): """ Ping the Redis server """ - message = '' if message is None else message - return self.execute_command('PING', message) + message = "" if message is None else message + return self.execute_command("PING", message) def handle_message(self, response, ignore_subscribe_messages=False): """ @@ -1461,31 +1515,31 @@ def handle_message(self, response, ignore_subscribe_messages=False): message being returned. """ message_type = str_if_bytes(response[0]) - if message_type == 'pmessage': + if message_type == "pmessage": message = { - 'type': message_type, - 'pattern': response[1], - 'channel': response[2], - 'data': response[3] + "type": message_type, + "pattern": response[1], + "channel": response[2], + "data": response[3], } - elif message_type == 'pong': + elif message_type == "pong": message = { - 'type': message_type, - 'pattern': None, - 'channel': None, - 'data': response[1] + "type": message_type, + "pattern": None, + "channel": None, + "data": response[1], } else: message = { - 'type': message_type, - 'pattern': None, - 'channel': response[1], - 'data': response[2] + "type": message_type, + "pattern": None, + "channel": response[1], + "data": response[2], } # if this is an unsubscribe message, remove it from memory if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: - if message_type == 'punsubscribe': + if message_type == "punsubscribe": pattern = response[1] if pattern in self.pending_unsubscribe_patterns: self.pending_unsubscribe_patterns.remove(pattern) @@ -1498,14 +1552,14 @@ def handle_message(self, response, ignore_subscribe_messages=False): if message_type in self.PUBLISH_MESSAGE_TYPES: # if there's a message handler, invoke it - if message_type == 'pmessage': - handler = self.patterns.get(message['pattern'], None) + if message_type == "pmessage": + handler = self.patterns.get(message["pattern"], None) else: - handler = self.channels.get(message['channel'], None) + handler = self.channels.get(message["channel"], None) if handler: handler(message) return None - elif message_type != 'pong': + elif message_type != "pong": # this is a subscribe/unsubscribe message. ignore if we don't # want them if ignore_subscribe_messages or self.ignore_subscribe_messages: @@ -1513,8 +1567,7 @@ def handle_message(self, response, ignore_subscribe_messages=False): return message - def run_in_thread(self, sleep_time=0, daemon=False, - exception_handler=None): + def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None): for channel, handler in self.channels.items(): if handler is None: raise PubSubError(f"Channel: '{channel}' has no handler registered") @@ -1523,18 +1576,14 @@ def run_in_thread(self, sleep_time=0, daemon=False, raise PubSubError(f"Pattern: '{pattern}' has no handler registered") thread = PubSubWorkerThread( - self, - sleep_time, - daemon=daemon, - exception_handler=exception_handler + self, sleep_time, daemon=daemon, exception_handler=exception_handler ) thread.start() return thread class PubSubWorkerThread(threading.Thread): - def __init__(self, pubsub, sleep_time, daemon=False, - exception_handler=None): + def __init__(self, pubsub, sleep_time, daemon=False, exception_handler=None): super().__init__() self.daemon = daemon self.pubsub = pubsub @@ -1550,8 +1599,7 @@ def run(self): sleep_time = self.sleep_time while self._running.is_set(): try: - pubsub.get_message(ignore_subscribe_messages=True, - timeout=sleep_time) + pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time) except BaseException as e: if self.exception_handler is None: raise @@ -1584,10 +1632,9 @@ class Pipeline(Redis): on a key of a different datatype. """ - UNWATCH_COMMANDS = {'DISCARD', 'EXEC', 'UNWATCH'} + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} - def __init__(self, connection_pool, response_callbacks, transaction, - shard_hint): + def __init__(self, connection_pool, response_callbacks, transaction, shard_hint): self.connection_pool = connection_pool self.connection = None self.response_callbacks = response_callbacks @@ -1625,7 +1672,7 @@ def reset(self): try: # call this manually since our unwatch or # immediate_execute_command methods can call reset() - self.connection.send_command('UNWATCH') + self.connection.send_command("UNWATCH") self.connection.read_response() except ConnectionError: # disconnect will also remove any previous WATCHes @@ -1645,15 +1692,15 @@ def multi(self): are issued. End the transactional block with `execute`. """ if self.explicit_transaction: - raise RedisError('Cannot issue nested calls to MULTI') + raise RedisError("Cannot issue nested calls to MULTI") if self.command_stack: - raise RedisError('Commands without an initial WATCH have already ' - 'been issued') + raise RedisError( + "Commands without an initial WATCH have already " "been issued" + ) self.explicit_transaction = True def execute_command(self, *args, **kwargs): - if (self.watching or args[0] == 'WATCH') and \ - not self.explicit_transaction: + if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) @@ -1670,8 +1717,9 @@ def _disconnect_reset_raise(self, conn, error): # indicates the user should retry this transaction. if self.watching: self.reset() - raise WatchError("A ConnectionError occurred on while " - "watching one or more keys") + raise WatchError( + "A ConnectionError occurred on while " "watching one or more keys" + ) # if retry_on_timeout is not set, or the error is not # a TimeoutError, raise it if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): @@ -1689,16 +1737,15 @@ def immediate_execute_command(self, *args, **options): conn = self.connection # if this is the first call, we need a connection if not conn: - conn = self.connection_pool.get_connection(command_name, - self.shard_hint) + conn = self.connection_pool.get_connection(command_name, self.shard_hint) self.connection = conn return conn.retry.call_with_retry( - lambda: self._send_command_parse_response(conn, - command_name, - *args, - **options), - lambda error: self._disconnect_reset_raise(conn, error)) + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_reset_raise(conn, error), + ) def pipeline_execute_command(self, *args, **options): """ @@ -1716,9 +1763,10 @@ def pipeline_execute_command(self, *args, **options): return self def _execute_transaction(self, connection, commands, raise_on_error): - cmds = chain([(('MULTI', ), {})], commands, [(('EXEC', ), {})]) - all_cmds = connection.pack_commands([args for args, options in cmds - if EMPTY_RESPONSE not in options]) + cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) + all_cmds = connection.pack_commands( + [args for args, options in cmds if EMPTY_RESPONSE not in options] + ) connection.send_packed_command(all_cmds) errors = [] @@ -1727,7 +1775,7 @@ def _execute_transaction(self, connection, commands, raise_on_error): # so that we read all the additional command messages from # the socket try: - self.parse_response(connection, '_') + self.parse_response(connection, "_") except ResponseError as e: errors.append((0, e)) @@ -1737,14 +1785,14 @@ def _execute_transaction(self, connection, commands, raise_on_error): errors.append((i, command[1][EMPTY_RESPONSE])) else: try: - self.parse_response(connection, '_') + self.parse_response(connection, "_") except ResponseError as e: self.annotate_exception(e, i + 1, command[0]) errors.append((i, e)) # parse the EXEC. try: - response = self.parse_response(connection, '_') + response = self.parse_response(connection, "_") except ExecAbortError: if errors: raise errors[0][1] @@ -1762,8 +1810,9 @@ def _execute_transaction(self, connection, commands, raise_on_error): if len(response) != len(commands): self.connection.disconnect() - raise ResponseError("Wrong number of response items from " - "pipeline execution") + raise ResponseError( + "Wrong number of response items from " "pipeline execution" + ) # find any errors in the response and raise if necessary if raise_on_error: @@ -1788,8 +1837,7 @@ def _execute_pipeline(self, connection, commands, raise_on_error): response = [] for args, options in commands: try: - response.append( - self.parse_response(connection, args[0], **options)) + response.append(self.parse_response(connection, args[0], **options)) except ResponseError as e: response.append(e) @@ -1804,19 +1852,18 @@ def raise_first_error(self, commands, response): raise r def annotate_exception(self, exception, number, command): - cmd = ' '.join(map(safe_str, command)) + cmd = " ".join(map(safe_str, command)) msg = ( - f'Command # {number} ({cmd}) of pipeline ' - f'caused error: {exception.args[0]}' + f"Command # {number} ({cmd}) of pipeline " + f"caused error: {exception.args[0]}" ) exception.args = (msg,) + exception.args[1:] def parse_response(self, connection, command_name, **options): - result = Redis.parse_response( - self, connection, command_name, **options) + result = Redis.parse_response(self, connection, command_name, **options) if command_name in self.UNWATCH_COMMANDS: self.watching = False - elif command_name == 'WATCH': + elif command_name == "WATCH": self.watching = True return result @@ -1827,11 +1874,11 @@ def load_scripts(self): shas = [s.sha for s in scripts] # we can't use the normal script_* methods because they would just # get buffered in the pipeline. - exists = immediate('SCRIPT EXISTS', *shas) + exists = immediate("SCRIPT EXISTS", *shas) if not all(exists): for s, exist in zip(scripts, exists): if not exist: - s.sha = immediate('SCRIPT LOAD', s.script) + s.sha = immediate("SCRIPT LOAD", s.script) def _disconnect_raise_reset(self, conn, error): """ @@ -1844,8 +1891,9 @@ def _disconnect_raise_reset(self, conn, error): # since this connection has died. raise a WatchError, which # indicates the user should retry this transaction. if self.watching: - raise WatchError("A ConnectionError occurred on while " - "watching one or more keys") + raise WatchError( + "A ConnectionError occurred on while " "watching one or more keys" + ) # if retry_on_timeout is not set, or the error is not # a TimeoutError, raise it if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): @@ -1866,8 +1914,7 @@ def execute(self, raise_on_error=True): conn = self.connection if not conn: - conn = self.connection_pool.get_connection('MULTI', - self.shard_hint) + conn = self.connection_pool.get_connection("MULTI", self.shard_hint) # assign to self.connection so reset() releases the connection # back to the pool after we're done self.connection = conn @@ -1875,7 +1922,8 @@ def execute(self, raise_on_error=True): try: return conn.retry.call_with_retry( lambda: execute(conn, stack, raise_on_error), - lambda error: self._disconnect_raise_reset(conn, error)) + lambda error: self._disconnect_raise_reset(conn, error), + ) finally: self.reset() @@ -1888,9 +1936,9 @@ def discard(self): def watch(self, *names): "Watches the values at keys ``names``" if self.explicit_transaction: - raise RedisError('Cannot issue a WATCH after a MULTI') - return self.execute_command('WATCH', *names) + raise RedisError("Cannot issue a WATCH after a MULTI") + return self.execute_command("WATCH", *names) def unwatch(self): "Unwatches all previously specified keys" - return self.watching and self.execute_command('UNWATCH') or True + return self.watching and self.execute_command("UNWATCH") or True diff --git a/redis/cluster.py b/redis/cluster.py index c1853aa876..57e8316ba2 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2,18 +2,15 @@ import logging import random import socket -import time -import threading import sys - +import threading +import time from collections import OrderedDict -from redis.client import CaseInsensitiveDict, Redis, PubSub -from redis.commands import ( - ClusterCommands, - CommandsParser -) -from redis.connection import DefaultParser, ConnectionPool, Encoder, parse_url -from redis.crc import key_slot, REDIS_CLUSTER_HASH_SLOTS + +from redis.client import CaseInsensitiveDict, PubSub, Redis +from redis.commands import ClusterCommands, CommandsParser +from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url +from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( AskError, BusyLoadingError, @@ -34,15 +31,15 @@ dict_merge, list_keys_to_dict, merge_result, + safe_str, str_if_bytes, - safe_str ) log = logging.getLogger(__name__) def get_node_name(host, port): - return f'{host}:{port}' + return f"{host}:{port}" def get_connection(redis_node, *args, **options): @@ -67,15 +64,12 @@ def parse_pubsub_numsub(command, res, **options): except KeyError: numsub_d[channel] = numsubbed - ret_numsub = [ - (channel, numsub) - for channel, numsub in numsub_d.items() - ] + ret_numsub = [(channel, numsub) for channel, numsub in numsub_d.items()] return ret_numsub def parse_cluster_slots(resp, **options): - current_host = options.get('current_host', '') + current_host = options.get("current_host", "") def fix_server(*args): return str_if_bytes(args[0]) or current_host, args[1] @@ -85,8 +79,8 @@ def fix_server(*args): start, end, primary = slot[:3] replicas = slot[3:] slots[start, end] = { - 'primary': fix_server(*primary), - 'replicas': [fix_server(*replica) for replica in replicas], + "primary": fix_server(*primary), + "replicas": [fix_server(*replica) for replica in replicas], } return slots @@ -132,47 +126,49 @@ def fix_server(*args): # Not complete, but covers the major ones # https://redis.io/commands -READ_COMMANDS = frozenset([ - "BITCOUNT", - "BITPOS", - "EXISTS", - "GEODIST", - "GEOHASH", - "GEOPOS", - "GEORADIUS", - "GEORADIUSBYMEMBER", - "GET", - "GETBIT", - "GETRANGE", - "HEXISTS", - "HGET", - "HGETALL", - "HKEYS", - "HLEN", - "HMGET", - "HSTRLEN", - "HVALS", - "KEYS", - "LINDEX", - "LLEN", - "LRANGE", - "MGET", - "PTTL", - "RANDOMKEY", - "SCARD", - "SDIFF", - "SINTER", - "SISMEMBER", - "SMEMBERS", - "SRANDMEMBER", - "STRLEN", - "SUNION", - "TTL", - "ZCARD", - "ZCOUNT", - "ZRANGE", - "ZSCORE", -]) +READ_COMMANDS = frozenset( + [ + "BITCOUNT", + "BITPOS", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUS", + "GEORADIUSBYMEMBER", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "KEYS", + "LINDEX", + "LLEN", + "LRANGE", + "MGET", + "PTTL", + "RANDOMKEY", + "SCARD", + "SDIFF", + "SINTER", + "SISMEMBER", + "SMEMBERS", + "SRANDMEMBER", + "STRLEN", + "SUNION", + "TTL", + "ZCARD", + "ZCOUNT", + "ZRANGE", + "ZSCORE", + ] +) def cleanup_kwargs(**kwargs): @@ -190,14 +186,16 @@ def cleanup_kwargs(**kwargs): class ClusterParser(DefaultParser): EXCEPTION_CLASSES = dict_merge( - DefaultParser.EXCEPTION_CLASSES, { - 'ASK': AskError, - 'TRYAGAIN': TryAgainError, - 'MOVED': MovedError, - 'CLUSTERDOWN': ClusterDownError, - 'CROSSSLOT': ClusterCrossSlotError, - 'MASTERDOWN': MasterDownError, - }) + DefaultParser.EXCEPTION_CLASSES, + { + "ASK": AskError, + "TRYAGAIN": TryAgainError, + "MOVED": MovedError, + "CLUSTERDOWN": ClusterDownError, + "CROSSSLOT": ClusterCrossSlotError, + "MASTERDOWN": MasterDownError, + }, + ) class RedisCluster(ClusterCommands): @@ -209,13 +207,7 @@ class RedisCluster(ClusterCommands): RANDOM = "random" DEFAULT_NODE = "default-node" - NODE_FLAGS = { - PRIMARIES, - REPLICAS, - ALL_NODES, - RANDOM, - DEFAULT_NODE - } + NODE_FLAGS = {PRIMARIES, REPLICAS, ALL_NODES, RANDOM, DEFAULT_NODE} COMMAND_FLAGS = dict_merge( list_keys_to_dict( @@ -292,119 +284,138 @@ class RedisCluster(ClusterCommands): ) CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { - 'CLUSTER ADDSLOTS': bool, - 'CLUSTER COUNT-FAILURE-REPORTS': int, - 'CLUSTER COUNTKEYSINSLOT': int, - 'CLUSTER DELSLOTS': bool, - 'CLUSTER FAILOVER': bool, - 'CLUSTER FORGET': bool, - 'CLUSTER GETKEYSINSLOT': list, - 'CLUSTER KEYSLOT': int, - 'CLUSTER MEET': bool, - 'CLUSTER REPLICATE': bool, - 'CLUSTER RESET': bool, - 'CLUSTER SAVECONFIG': bool, - 'CLUSTER SET-CONFIG-EPOCH': bool, - 'CLUSTER SETSLOT': bool, - 'CLUSTER SLOTS': parse_cluster_slots, - 'ASKING': bool, - 'READONLY': bool, - 'READWRITE': bool, + "CLUSTER ADDSLOTS": bool, + "CLUSTER COUNT-FAILURE-REPORTS": int, + "CLUSTER COUNTKEYSINSLOT": int, + "CLUSTER DELSLOTS": bool, + "CLUSTER FAILOVER": bool, + "CLUSTER FORGET": bool, + "CLUSTER GETKEYSINSLOT": list, + "CLUSTER KEYSLOT": int, + "CLUSTER MEET": bool, + "CLUSTER REPLICATE": bool, + "CLUSTER RESET": bool, + "CLUSTER SAVECONFIG": bool, + "CLUSTER SET-CONFIG-EPOCH": bool, + "CLUSTER SETSLOT": bool, + "CLUSTER SLOTS": parse_cluster_slots, + "ASKING": bool, + "READONLY": bool, + "READWRITE": bool, } RESULT_CALLBACKS = dict_merge( - list_keys_to_dict([ - "PUBSUB NUMSUB", - ], parse_pubsub_numsub), - list_keys_to_dict([ - "PUBSUB NUMPAT", - ], lambda command, res: sum(list(res.values()))), - list_keys_to_dict([ - "KEYS", - "PUBSUB CHANNELS", - ], merge_result), - list_keys_to_dict([ - "PING", - "CONFIG SET", - "CONFIG REWRITE", - "CONFIG RESETSTAT", - "CLIENT SETNAME", - "BGSAVE", - "SLOWLOG RESET", - "SAVE", - "MEMORY PURGE", - "CLIENT PAUSE", - "CLIENT UNPAUSE", - ], lambda command, res: all(res.values()) if isinstance(res, dict) - else res), - list_keys_to_dict([ - "DBSIZE", - "WAIT", - ], lambda command, res: sum(res.values()) if isinstance(res, dict) - else res), - list_keys_to_dict([ - "CLIENT UNBLOCK", - ], lambda command, res: 1 if sum(res.values()) > 0 else 0), - list_keys_to_dict([ - "SCAN", - ], parse_scan_result) + list_keys_to_dict( + [ + "PUBSUB NUMSUB", + ], + parse_pubsub_numsub, + ), + list_keys_to_dict( + [ + "PUBSUB NUMPAT", + ], + lambda command, res: sum(list(res.values())), + ), + list_keys_to_dict( + [ + "KEYS", + "PUBSUB CHANNELS", + ], + merge_result, + ), + list_keys_to_dict( + [ + "PING", + "CONFIG SET", + "CONFIG REWRITE", + "CONFIG RESETSTAT", + "CLIENT SETNAME", + "BGSAVE", + "SLOWLOG RESET", + "SAVE", + "MEMORY PURGE", + "CLIENT PAUSE", + "CLIENT UNPAUSE", + ], + lambda command, res: all(res.values()) if isinstance(res, dict) else res, + ), + list_keys_to_dict( + [ + "DBSIZE", + "WAIT", + ], + lambda command, res: sum(res.values()) if isinstance(res, dict) else res, + ), + list_keys_to_dict( + [ + "CLIENT UNBLOCK", + ], + lambda command, res: 1 if sum(res.values()) > 0 else 0, + ), + list_keys_to_dict( + [ + "SCAN", + ], + parse_scan_result, + ), ) def __init__( - self, - host=None, - port=6379, - startup_nodes=None, - cluster_error_retry_attempts=3, - require_full_coverage=True, - skip_full_coverage_check=False, - reinitialize_steps=10, - read_from_replicas=False, - url=None, - retry_on_timeout=False, - retry=None, - **kwargs + self, + host=None, + port=6379, + startup_nodes=None, + cluster_error_retry_attempts=3, + require_full_coverage=True, + skip_full_coverage_check=False, + reinitialize_steps=10, + read_from_replicas=False, + url=None, + retry_on_timeout=False, + retry=None, + **kwargs, ): """ - Initialize a new RedisCluster client. - - :startup_nodes: 'list[ClusterNode]' - List of nodes from which initial bootstrapping can be done - :host: 'str' - Can be used to point to a startup node - :port: 'int' - Can be used to point to a startup node - :require_full_coverage: 'bool' - If set to True, as it is by default, all slots must be covered. - If set to False and not all slots are covered, the instance - creation will succeed only if 'cluster-require-full-coverage' - configuration is set to 'no' in all of the cluster's nodes. - Otherwise, RedisClusterException will be thrown. - :skip_full_coverage_check: 'bool' - If require_full_coverage is set to False, a check of - cluster-require-full-coverage config will be executed against all - nodes. Set skip_full_coverage_check to True to skip this check. - Useful for clusters without the CONFIG command (like ElastiCache) - :read_from_replicas: 'bool' - Enable read from replicas in READONLY mode. You can read possibly - stale data. - When set to true, read commands will be assigned between the - primary and its replications in a Round-Robin manner. - :cluster_error_retry_attempts: 'int' - Retry command execution attempts when encountering ClusterDownError - or ConnectionError - :retry_on_timeout: 'bool' - To specify a retry policy, first set `retry_on_timeout` to `True` - then set `retry` to a valid `Retry` object - :retry: 'Retry' - a `Retry` object - :**kwargs: - Extra arguments that will be sent into Redis instance when created - (See Official redis-py doc for supported kwargs - [https://github.com/andymccurdy/redis-py/blob/master/redis/client.py]) - Some kwargs are not supported and will raise a - RedisClusterException: - - db (Redis do not support database SELECT in cluster mode) + Initialize a new RedisCluster client. + + :startup_nodes: 'list[ClusterNode]' + List of nodes from which initial bootstrapping can be done + :host: 'str' + Can be used to point to a startup node + :port: 'int' + Can be used to point to a startup node + :require_full_coverage: 'bool' + If set to True, as it is by default, all slots must be covered. + If set to False and not all slots are covered, the instance + creation will succeed only if 'cluster-require-full-coverage' + configuration is set to 'no' in all of the cluster's nodes. + Otherwise, RedisClusterException will be thrown. + :skip_full_coverage_check: 'bool' + If require_full_coverage is set to False, a check of + cluster-require-full-coverage config will be executed against all + nodes. Set skip_full_coverage_check to True to skip this check. + Useful for clusters without the CONFIG command (like ElastiCache) + :read_from_replicas: 'bool' + Enable read from replicas in READONLY mode. You can read possibly + stale data. + When set to true, read commands will be assigned between the + primary and its replications in a Round-Robin manner. + :cluster_error_retry_attempts: 'int' + Retry command execution attempts when encountering ClusterDownError + or ConnectionError + :retry_on_timeout: 'bool' + To specify a retry policy, first set `retry_on_timeout` to `True` + then set `retry` to a valid `Retry` object + :retry: 'Retry' + a `Retry` object + :**kwargs: + Extra arguments that will be sent into Redis instance when created + (See Official redis-py doc for supported kwargs + [https://github.com/andymccurdy/redis-py/blob/master/redis/client.py]) + Some kwargs are not supported and will raise a + RedisClusterException: + - db (Redis do not support database SELECT in cluster mode) """ log.info("Creating a new instance of RedisCluster client") @@ -418,8 +429,7 @@ def __init__( ) if retry_on_timeout: - kwargs.update({'retry_on_timeout': retry_on_timeout, - 'retry': retry}) + kwargs.update({"retry_on_timeout": retry_on_timeout, "retry": retry}) # Get the startup node/s from_url = False @@ -429,15 +439,16 @@ def __init__( if "path" in url_options: raise RedisClusterException( "RedisCluster does not currently support Unix Domain " - "Socket connections") + "Socket connections" + ) if "db" in url_options and url_options["db"] != 0: # Argument 'db' is not possible to use in cluster mode raise RedisClusterException( "A ``db`` querystring option can only be 0 in cluster mode" ) kwargs.update(url_options) - host = kwargs.get('host') - port = kwargs.get('port', port) + host = kwargs.get("host") + port = kwargs.get("port", port) startup_nodes.append(ClusterNode(host, port)) elif host is not None and port is not None: startup_nodes.append(ClusterNode(host, port)) @@ -450,7 +461,8 @@ def __init__( " RedisCluster(host='localhost', port=6379)\n" "2. list of startup nodes, for example:\n" " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," - " ClusterNode('localhost', 6378)])") + " ClusterNode('localhost', 6378)])" + ) log.debug(f"startup_nodes : {startup_nodes}") # Update the connection arguments # Whenever a new connection is established, RedisCluster's on_connect @@ -482,9 +494,9 @@ def __init__( ) self.cluster_response_callbacks = CaseInsensitiveDict( - self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS) - self.result_callbacks = CaseInsensitiveDict( - self.__class__.RESULT_CALLBACKS) + self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS + ) + self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) self.commands_parser = CommandsParser(self) self._lock = threading.Lock() @@ -563,9 +575,9 @@ def on_connect(self, connection): # to a failover, we should establish a READONLY connection # regardless of the server type. If this is a primary connection, # READONLY would not affect executing write commands. - connection.send_command('READONLY') - if str_if_bytes(connection.read_response()) != 'OK': - raise ConnectionError('READONLY command failed') + connection.send_command("READONLY") + if str_if_bytes(connection.read_response()) != "OK": + raise ConnectionError("READONLY command failed") if self.user_on_connect_func is not None: self.user_on_connect_func(connection) @@ -601,9 +613,7 @@ def get_node_from_key(self, key, replica=False): slot = self.keyslot(key) slot_cache = self.nodes_manager.slots_cache.get(slot) if slot_cache is None or len(slot_cache) == 0: - raise SlotNotCoveredError( - f'Slot "{slot}" is not covered by the cluster.' - ) + raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') if replica and len(self.nodes_manager.slots_cache[slot]) < 2: return None elif replica: @@ -627,8 +637,10 @@ def set_default_node(self, node): :return True if the default node was set, else False """ if node is None or self.get_node(node_name=node.name) is None: - log.info("The requested node does not exist in the cluster, so " - "the default node was not changed.") + log.info( + "The requested node does not exist in the cluster, so " + "the default node was not changed." + ) return False self.nodes_manager.default_node = node log.info(f"Changed the default cluster node to {node}") @@ -651,12 +663,10 @@ def pipeline(self, transaction=None, shard_hint=None): when calling execute() will only return the result stack. """ if shard_hint: - raise RedisClusterException( - "shard_hint is deprecated in cluster mode") + raise RedisClusterException("shard_hint is deprecated in cluster mode") if transaction: - raise RedisClusterException( - "transaction is deprecated in cluster mode") + raise RedisClusterException("transaction is deprecated in cluster mode") return ClusterPipeline( nodes_manager=self.nodes_manager, @@ -665,7 +675,7 @@ def pipeline(self, transaction=None, shard_hint=None): cluster_response_callbacks=self.cluster_response_callbacks, cluster_error_retry_attempts=self.cluster_error_retry_attempts, read_from_replicas=self.read_from_replicas, - reinitialize_steps=self.reinitialize_steps + reinitialize_steps=self.reinitialize_steps, ) def _determine_nodes(self, *args, **kwargs): @@ -698,7 +708,8 @@ def _determine_nodes(self, *args, **kwargs): # get the node that holds the key's slot slot = self.determine_slot(*args) node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS) + slot, self.read_from_replicas and command in READ_COMMANDS + ) log.debug(f"Target for {args}: slot {slot}") return [node] @@ -760,8 +771,7 @@ def reinitialize_caches(self): self.nodes_manager.initialize() def _is_nodes_flag(self, target_nodes): - return isinstance(target_nodes, str) \ - and target_nodes in self.node_flags + return isinstance(target_nodes, str) and target_nodes in self.node_flags def _parse_target_nodes(self, target_nodes): if isinstance(target_nodes, list): @@ -812,8 +822,9 @@ def execute_command(self, *args, **kwargs): # the command execution since the nodes may not be valid anymore # after the tables were reinitialized. So in case of passed target # nodes, retry_attempts will be set to 1. - retry_attempts = 1 if target_nodes_specified else \ - self.cluster_error_retry_attempts + retry_attempts = ( + 1 if target_nodes_specified else self.cluster_error_retry_attempts + ) exception = None for _ in range(0, retry_attempts): try: @@ -821,13 +832,14 @@ def execute_command(self, *args, **kwargs): if not target_nodes_specified: # Determine the nodes to execute the command on target_nodes = self._determine_nodes( - *args, **kwargs, nodes_flag=target_nodes) + *args, **kwargs, nodes_flag=target_nodes + ) if not target_nodes: raise RedisClusterException( - f"No targets were found to execute {args} command on") + f"No targets were found to execute {args} command on" + ) for node in target_nodes: - res[node.name] = self._execute_command( - node, *args, **kwargs) + res[node.name] = self._execute_command(node, *args, **kwargs) # Return the processed result return self._process_result(args[0], res, **kwargs) except (ClusterDownError, ConnectionError) as e: @@ -862,9 +874,9 @@ def _execute_command(self, target_node, *args, **kwargs): # MOVED occurred and the slots cache was updated, # refresh the target node slot = self.determine_slot(*args) - target_node = self.nodes_manager. \ - get_node_from_slot(slot, self.read_from_replicas and - command in READ_COMMANDS) + target_node = self.nodes_manager.get_node_from_slot( + slot, self.read_from_replicas and command in READ_COMMANDS + ) moved = False log.debug( @@ -879,11 +891,11 @@ def _execute_command(self, target_node, *args, **kwargs): asking = False connection.send_command(*args) - response = redis_node.parse_response(connection, command, - **kwargs) + response = redis_node.parse_response(connection, command, **kwargs) if command in self.cluster_response_callbacks: response = self.cluster_response_callbacks[command]( - response, **kwargs) + response, **kwargs + ) return response except (RedisClusterException, BusyLoadingError): @@ -997,7 +1009,7 @@ def _process_result(self, command, res, **kwargs): class ClusterNode: def __init__(self, host, port, server_type=None, redis_connection=None): - if host == 'localhost': + if host == "localhost": host = socket.gethostbyname(host) self.host = host @@ -1008,11 +1020,11 @@ def __init__(self, host, port, server_type=None, redis_connection=None): def __repr__(self): return ( - f'[host={self.host},' - f'port={self.port},' - f'name={self.name},' - f'server_type={self.server_type},' - f'redis_connection={self.redis_connection}]' + f"[host={self.host}," + f"port={self.port}," + f"name={self.name}," + f"server_type={self.server_type}," + f"redis_connection={self.redis_connection}]" ) def __eq__(self, obj): @@ -1029,8 +1041,7 @@ def __init__(self, start_index=0): self.start_index = start_index def get_server_index(self, primary, list_size): - server_index = self.primary_to_idx.setdefault(primary, - self.start_index) + server_index = self.primary_to_idx.setdefault(primary, self.start_index) # Update the index self.primary_to_idx[primary] = (server_index + 1) % list_size return server_index @@ -1040,9 +1051,15 @@ def reset(self): class NodesManager: - def __init__(self, startup_nodes, from_url=False, - require_full_coverage=True, skip_full_coverage_check=False, - lock=None, **kwargs): + def __init__( + self, + startup_nodes, + from_url=False, + require_full_coverage=True, + skip_full_coverage_check=False, + lock=None, + **kwargs, + ): self.nodes_cache = {} self.slots_cache = {} self.startup_nodes = {} @@ -1122,8 +1139,7 @@ def _update_moved_slots(self): # Reset moved_exception self._moved_exception = None - def get_node_from_slot(self, slot, read_from_replicas=False, - server_type=None): + def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): """ Gets a node that servers this hash slot """ @@ -1132,8 +1148,7 @@ def get_node_from_slot(self, slot, read_from_replicas=False, if self._moved_exception: self._update_moved_slots() - if self.slots_cache.get(slot) is None or \ - len(self.slots_cache[slot]) == 0: + if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0: raise SlotNotCoveredError( f'Slot "{slot}" not covered by the cluster. ' f'"require_full_coverage={self._require_full_coverage}"' @@ -1143,19 +1158,19 @@ def get_node_from_slot(self, slot, read_from_replicas=False, # get the server index in a Round-Robin manner primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot])) + primary_name, len(self.slots_cache[slot]) + ) elif ( - server_type is None - or server_type == PRIMARY - or len(self.slots_cache[slot]) == 1 + server_type is None + or server_type == PRIMARY + or len(self.slots_cache[slot]) == 1 ): # return a primary node_idx = 0 else: # return a replica # randomly choose one of the replicas - node_idx = random.randint( - 1, len(self.slots_cache[slot]) - 1) + node_idx = random.randint(1, len(self.slots_cache[slot]) - 1) return self.slots_cache[slot][node_idx] @@ -1187,20 +1202,22 @@ def cluster_require_full_coverage(self, cluster_nodes): def node_require_full_coverage(node): try: - return ("yes" in node.redis_connection.config_get( - "cluster-require-full-coverage").values() + return ( + "yes" + in node.redis_connection.config_get( + "cluster-require-full-coverage" + ).values() ) except ConnectionError: return False except Exception as e: raise RedisClusterException( 'ERROR sending "config get cluster-require-full-coverage"' - f' command to redis server: {node.name}, {e}' + f" command to redis server: {node.name}, {e}" ) # at least one node should have cluster-require-full-coverage yes - return any(node_require_full_coverage(node) - for node in cluster_nodes.values()) + return any(node_require_full_coverage(node) for node in cluster_nodes.values()) def check_slots_coverage(self, slots_cache): # Validate if all slots are covered or if we should try next @@ -1229,11 +1246,7 @@ def create_redis_node(self, host, port, **kwargs): kwargs.update({"port": port}) r = Redis(connection_pool=ConnectionPool(**kwargs)) else: - r = Redis( - host=host, - port=port, - **kwargs - ) + r = Redis(host=host, port=port, **kwargs) return r def initialize(self): @@ -1257,22 +1270,23 @@ def initialize(self): # Create a new Redis connection and let Redis decode the # responses so we won't need to handle that copy_kwargs = copy.deepcopy(kwargs) - copy_kwargs.update({"decode_responses": True, - "encoding": "utf-8"}) + copy_kwargs.update({"decode_responses": True, "encoding": "utf-8"}) r = self.create_redis_node( - startup_node.host, startup_node.port, **copy_kwargs) + startup_node.host, startup_node.port, **copy_kwargs + ) self.startup_nodes[startup_node.name].redis_connection = r cluster_slots = r.execute_command("CLUSTER SLOTS") startup_nodes_reachable = True except (ConnectionError, TimeoutError) as e: msg = e.__str__ - log.exception('An exception occurred while trying to' - ' initialize the cluster using the seed node' - f' {startup_node.name}:\n{msg}') + log.exception( + "An exception occurred while trying to" + " initialize the cluster using the seed node" + f" {startup_node.name}:\n{msg}" + ) continue except ResponseError as e: - log.exception( - 'ReseponseError sending "cluster slots" to redis server') + log.exception('ReseponseError sending "cluster slots" to redis server') # Isn't a cluster connection, so it won't parse these # exceptions automatically @@ -1282,13 +1296,13 @@ def initialize(self): else: raise RedisClusterException( 'ERROR sending "cluster slots" command to redis ' - f'server: {startup_node}. error: {message}' + f"server: {startup_node}. error: {message}" ) except Exception as e: message = e.__str__() raise RedisClusterException( 'ERROR sending "cluster slots" command to redis ' - f'server: {startup_node}. error: {message}' + f"server: {startup_node}. error: {message}" ) # CLUSTER SLOTS command results in the following output: @@ -1298,9 +1312,11 @@ def initialize(self): # primary node of the first slot section. # If there's only one server in the cluster, its ``host`` is '' # Fix it to the host in startup_nodes - if (len(cluster_slots) == 1 - and len(cluster_slots[0][2][0]) == 0 - and len(self.startup_nodes) == 1): + if ( + len(cluster_slots) == 1 + and len(cluster_slots[0][2][0]) == 0 + and len(self.startup_nodes) == 1 + ): cluster_slots[0][2][0] = startup_node.host for slot in cluster_slots: @@ -1327,10 +1343,10 @@ def initialize(self): port = replica_node[1] target_replica_node = tmp_nodes_cache.get( - get_node_name(host, port)) + get_node_name(host, port) + ) if target_replica_node is None: - target_replica_node = ClusterNode( - host, port, REPLICA) + target_replica_node = ClusterNode(host, port, REPLICA) tmp_slots[i].append(target_replica_node) # add this node to the nodes cache tmp_nodes_cache[ @@ -1342,12 +1358,12 @@ def initialize(self): tmp_slot = tmp_slots[i][0] if tmp_slot.name != target_node.name: disagreements.append( - f'{tmp_slot.name} vs {target_node.name} on slot: {i}' + f"{tmp_slot.name} vs {target_node.name} on slot: {i}" ) if len(disagreements) > 5: raise RedisClusterException( - f'startup_nodes could not agree on a valid ' + f"startup_nodes could not agree on a valid " f'slots cache: {", ".join(disagreements)}' ) @@ -1366,8 +1382,8 @@ def initialize(self): # Despite the requirement that the slots be covered, there # isn't a full coverage raise RedisClusterException( - f'All slots are not covered after query all startup_nodes. ' - f'{len(self.slots_cache)} of {REDIS_CLUSTER_HASH_SLOTS} covered...' + f"All slots are not covered after query all startup_nodes. " + f"{len(self.slots_cache)} of {REDIS_CLUSTER_HASH_SLOTS} covered..." ) elif not fully_covered and not self._require_full_coverage: # The user set require_full_coverage to False. @@ -1376,15 +1392,17 @@ def initialize(self): # continue with partial coverage. # see Redis Cluster configuration parameters in # https://redis.io/topics/cluster-tutorial - if not self._skip_full_coverage_check and \ - self.cluster_require_full_coverage(tmp_nodes_cache): + if ( + not self._skip_full_coverage_check + and self.cluster_require_full_coverage(tmp_nodes_cache) + ): raise RedisClusterException( - 'Not all slots are covered but the cluster\'s ' - 'configuration requires full coverage. Set ' - 'cluster-require-full-coverage configuration to no on ' - 'all of the cluster nodes if you wish the cluster to ' - 'be able to serve without being fully covered.' - f'{len(self.slots_cache)} of {REDIS_CLUSTER_HASH_SLOTS} covered...' + "Not all slots are covered but the cluster's " + "configuration requires full coverage. Set " + "cluster-require-full-coverage configuration to no on " + "all of the cluster nodes if you wish the cluster to " + "be able to serve without being fully covered." + f"{len(self.slots_cache)} of {REDIS_CLUSTER_HASH_SLOTS} covered..." ) # Set the tmp variables to the real variables @@ -1418,8 +1436,7 @@ class ClusterPubSub(PubSub): https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html """ - def __init__(self, redis_cluster, node=None, host=None, port=None, - **kwargs): + def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): """ When a pubsub instance is created without specifying a node, a single node will be transparently chosen for the pubsub connection on the @@ -1436,11 +1453,15 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, log.info("Creating new instance of ClusterPubSub") self.node = None self.set_pubsub_node(redis_cluster, node, host, port) - connection_pool = None if self.node is None else \ - redis_cluster.get_redis_connection(self.node).connection_pool + connection_pool = ( + None + if self.node is None + else redis_cluster.get_redis_connection(self.node).connection_pool + ) self.cluster = redis_cluster - super().__init__(**kwargs, connection_pool=connection_pool, - encoder=redis_cluster.encoder) + super().__init__( + **kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder + ) def set_pubsub_node(self, cluster, node=None, host=None, port=None): """ @@ -1468,8 +1489,7 @@ def set_pubsub_node(self, cluster, node=None, host=None, port=None): pubsub_node = node elif any([host, port]) is True: # only 'host' or 'port' passed - raise DataError('Passing a host requires passing a port, ' - 'and vice versa') + raise DataError("Passing a host requires passing a port, " "and vice versa") else: # nothing passed by the user. set node to None pubsub_node = None @@ -1489,7 +1509,8 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port): """ if node is None or redis_cluster.get_node(node_name=node.name) is None: raise RedisClusterException( - f"Node {host}:{port} doesn't exist in the cluster") + f"Node {host}:{port} doesn't exist in the cluster" + ) def execute_command(self, *args, **kwargs): """ @@ -1508,9 +1529,9 @@ def execute_command(self, *args, **kwargs): # this slot channel = args[1] slot = self.cluster.keyslot(channel) - node = self.cluster.nodes_manager. \ - get_node_from_slot(slot, self.cluster. - read_from_replicas) + node = self.cluster.nodes_manager.get_node_from_slot( + slot, self.cluster.read_from_replicas + ) else: # Get a random node node = self.cluster.get_random_node() @@ -1518,8 +1539,7 @@ def execute_command(self, *args, **kwargs): redis_connection = self.cluster.get_redis_connection(node) self.connection_pool = redis_connection.connection_pool self.connection = self.connection_pool.get_connection( - 'pubsub', - self.shard_hint + "pubsub", self.shard_hint ) # register a callback that re-subscribes to any channels we # were listening to when we were disconnected @@ -1535,8 +1555,13 @@ def get_redis_connection(self): return self.node.redis_connection -ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, - MovedError, AskError, TryAgainError) +ERRORS_ALLOW_RETRY = ( + ConnectionError, + TimeoutError, + MovedError, + AskError, + TryAgainError, +) class ClusterPipeline(RedisCluster): @@ -1545,18 +1570,25 @@ class ClusterPipeline(RedisCluster): in cluster mode """ - def __init__(self, nodes_manager, result_callbacks=None, - cluster_response_callbacks=None, startup_nodes=None, - read_from_replicas=False, cluster_error_retry_attempts=3, - reinitialize_steps=10, **kwargs): - """ - """ + def __init__( + self, + nodes_manager, + result_callbacks=None, + cluster_response_callbacks=None, + startup_nodes=None, + read_from_replicas=False, + cluster_error_retry_attempts=3, + reinitialize_steps=10, + **kwargs, + ): + """ """ log.info("Creating new instance of ClusterPipeline") self.command_stack = [] self.nodes_manager = nodes_manager self.refresh_table_asap = False - self.result_callbacks = (result_callbacks or - self.__class__.RESULT_CALLBACKS.copy()) + self.result_callbacks = ( + result_callbacks or self.__class__.RESULT_CALLBACKS.copy() + ) self.startup_nodes = startup_nodes if startup_nodes else [] self.read_from_replicas = read_from_replicas self.command_flags = self.__class__.COMMAND_FLAGS.copy() @@ -1576,18 +1608,15 @@ def __init__(self, nodes_manager, result_callbacks=None, self.commands_parser = CommandsParser(super()) def __repr__(self): - """ - """ + """ """ return f"{type(self).__name__}" def __enter__(self): - """ - """ + """ """ return self def __exit__(self, exc_type, exc_value, traceback): - """ - """ + """ """ self.reset() def __del__(self): @@ -1597,8 +1626,7 @@ def __del__(self): pass def __len__(self): - """ - """ + """ """ return len(self.command_stack) def __nonzero__(self): @@ -1620,7 +1648,8 @@ def pipeline_execute_command(self, *args, **options): Appends the executed command to the pipeline's command stack """ self.command_stack.append( - PipelineCommand(args, options, len(self.command_stack))) + PipelineCommand(args, options, len(self.command_stack)) + ) return self def raise_first_error(self, stack): @@ -1637,10 +1666,10 @@ def annotate_exception(self, exception, number, command): """ Provides extra context to the exception prior to it being handled """ - cmd = ' '.join(map(safe_str, command)) + cmd = " ".join(map(safe_str, command)) msg = ( - f'Command # {number} ({cmd}) of pipeline ' - f'caused error: {exception.args[0]}' + f"Command # {number} ({cmd}) of pipeline " + f"caused error: {exception.args[0]}" ) exception.args = (msg,) + exception.args[1:] @@ -1686,8 +1715,9 @@ def reset(self): # self.connection_pool.release(self.connection) # self.connection = None - def send_cluster_commands(self, stack, - raise_on_error=True, allow_redirections=True): + def send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): """ Wrapper for CLUSTERDOWN error handling. @@ -1720,12 +1750,11 @@ def send_cluster_commands(self, stack, # If it fails the configured number of times then raise # exception back to caller of this method - raise ClusterDownError( - "CLUSTERDOWN error. Unable to rebuild the cluster") + raise ClusterDownError("CLUSTERDOWN error. Unable to rebuild the cluster") - def _send_cluster_commands(self, stack, - raise_on_error=True, - allow_redirections=True): + def _send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): """ Send a bunch of cluster commands to the redis cluster. @@ -1751,7 +1780,8 @@ def _send_cluster_commands(self, stack, # command should route to. slot = self.determine_slot(*c.args) node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and c.args[0] in READ_COMMANDS) + slot, self.read_from_replicas and c.args[0] in READ_COMMANDS + ) # now that we know the name of the node # ( it's just a string in the form of host:port ) @@ -1760,9 +1790,9 @@ def _send_cluster_commands(self, stack, if node_name not in nodes: redis_node = self.get_redis_connection(node) connection = get_connection(redis_node, c.args) - nodes[node_name] = NodeCommands(redis_node.parse_response, - redis_node.connection_pool, - connection) + nodes[node_name] = NodeCommands( + redis_node.parse_response, redis_node.connection_pool, connection + ) nodes[node_name].append(c) @@ -1808,9 +1838,10 @@ def _send_cluster_commands(self, stack, # if we have more commands to attempt, we've run into problems. # collect all the commands we are allowed to retry. # (MOVED, ASK, or connection errors or timeout errors) - attempt = sorted((c for c in attempt - if isinstance(c.result, ERRORS_ALLOW_RETRY)), - key=lambda x: x.position) + attempt = sorted( + (c for c in attempt if isinstance(c.result, ERRORS_ALLOW_RETRY)), + key=lambda x: x.position, + ) if attempt and allow_redirections: # RETRY MAGIC HAPPENS HERE! # send these remaing comamnds one at a time using `execute_command` @@ -1831,10 +1862,10 @@ def _send_cluster_commands(self, stack, # flag to rebuild the slots table from scratch. # So MOVED errors should correct themselves fairly quickly. log.exception( - f'An exception occurred during pipeline execution. ' - f'args: {attempt[-1].args}, ' - f'error: {type(attempt[-1].result).__name__} ' - f'{str(attempt[-1].result)}' + f"An exception occurred during pipeline execution. " + f"args: {attempt[-1].args}, " + f"error: {type(attempt[-1].result).__name__} " + f"{str(attempt[-1].result)}" ) self.reinitialize_counter += 1 if self._should_reinitialized(): @@ -1857,55 +1888,47 @@ def _send_cluster_commands(self, stack, return response def _fail_on_redirect(self, allow_redirections): - """ - """ + """ """ if not allow_redirections: raise RedisClusterException( - "ASK & MOVED redirection not allowed in this pipeline") + "ASK & MOVED redirection not allowed in this pipeline" + ) def eval(self): - """ - """ + """ """ raise RedisClusterException("method eval() is not implemented") def multi(self): - """ - """ + """ """ raise RedisClusterException("method multi() is not implemented") def immediate_execute_command(self, *args, **options): - """ - """ + """ """ raise RedisClusterException( - "method immediate_execute_command() is not implemented") + "method immediate_execute_command() is not implemented" + ) def _execute_transaction(self, *args, **kwargs): - """ - """ - raise RedisClusterException( - "method _execute_transaction() is not implemented") + """ """ + raise RedisClusterException("method _execute_transaction() is not implemented") def load_scripts(self): - """ - """ - raise RedisClusterException( - "method load_scripts() is not implemented") + """ """ + raise RedisClusterException("method load_scripts() is not implemented") def watch(self, *names): - """ - """ + """ """ raise RedisClusterException("method watch() is not implemented") def unwatch(self): - """ - """ + """ """ raise RedisClusterException("method unwatch() is not implemented") def script_load_for_pipeline(self, *args, **kwargs): - """ - """ + """ """ raise RedisClusterException( - "method script_load_for_pipeline() is not implemented") + "method script_load_for_pipeline() is not implemented" + ) def delete(self, *names): """ @@ -1913,10 +1936,10 @@ def delete(self, *names): """ if len(names) != 1: raise RedisClusterException( - "deleting multiple keys is not " - "implemented in pipeline command") + "deleting multiple keys is not " "implemented in pipeline command" + ) - return self.execute_command('DEL', names[0]) + return self.execute_command("DEL", names[0]) def block_pipeline_command(func): @@ -1928,7 +1951,8 @@ def block_pipeline_command(func): def inner(*args, **kwargs): raise RedisClusterException( f"ERROR: Calling pipelined function {func.__name__} is blocked when " - f"running redis in cluster mode...") + f"running redis in cluster mode..." + ) return inner @@ -1936,11 +1960,9 @@ def inner(*args, **kwargs): # Blocked pipeline commands ClusterPipeline.bitop = block_pipeline_command(RedisCluster.bitop) ClusterPipeline.brpoplpush = block_pipeline_command(RedisCluster.brpoplpush) -ClusterPipeline.client_getname = \ - block_pipeline_command(RedisCluster.client_getname) +ClusterPipeline.client_getname = block_pipeline_command(RedisCluster.client_getname) ClusterPipeline.client_list = block_pipeline_command(RedisCluster.client_list) -ClusterPipeline.client_setname = \ - block_pipeline_command(RedisCluster.client_setname) +ClusterPipeline.client_setname = block_pipeline_command(RedisCluster.client_setname) ClusterPipeline.config_set = block_pipeline_command(RedisCluster.config_set) ClusterPipeline.dbsize = block_pipeline_command(RedisCluster.dbsize) ClusterPipeline.flushall = block_pipeline_command(RedisCluster.flushall) @@ -1972,8 +1994,7 @@ def inner(*args, **kwargs): class PipelineCommand: - """ - """ + """ """ def __init__(self, args, options=None, position=None): self.args = args @@ -1987,20 +2008,17 @@ def __init__(self, args, options=None, position=None): class NodeCommands: - """ - """ + """ """ def __init__(self, parse_response, connection_pool, connection): - """ - """ + """ """ self.parse_response = parse_response self.connection_pool = connection_pool self.connection = connection self.commands = [] def append(self, c): - """ - """ + """ """ self.commands.append(c) def write(self): @@ -2019,14 +2037,14 @@ def write(self): # send all the commands and catch connection and timeout errors. try: connection.send_packed_command( - connection.pack_commands([c.args for c in commands])) + connection.pack_commands([c.args for c in commands]) + ) except (ConnectionError, TimeoutError) as e: for c in commands: c.result = e def read(self): - """ - """ + """ """ connection = self.connection for c in self.commands: @@ -2050,8 +2068,7 @@ def read(self): # explicitly open the connection and all will be well. if c.result is None: try: - c.result = self.parse_response( - connection, c.args[0], **c.options) + c.result = self.parse_response(connection, c.args[0], **c.options) except (ConnectionError, TimeoutError) as e: for c in self.commands: c.result = e diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index e6b0a08924..0df073ab16 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -1,9 +1,6 @@ -from redis.exceptions import ( - ConnectionError, - DataError, - RedisError, -) from redis.crc import key_slot +from redis.exceptions import ConnectionError, DataError, RedisError + from .core import DataAccessCommands from .helpers import list_or_args @@ -36,6 +33,7 @@ def mget_nonatomic(self, keys, *args): """ from redis.client import EMPTY_RESPONSE + options = {} if not args: options[EMPTY_RESPONSE] = [] @@ -50,8 +48,7 @@ def mget_nonatomic(self, keys, *args): # We must make sure that the keys are returned in order all_results = {} for slot_keys in slots_to_keys.values(): - slot_values = self.execute_command( - 'MGET', *slot_keys, **options) + slot_values = self.execute_command("MGET", *slot_keys, **options) slot_results = dict(zip(slot_keys, slot_values)) all_results.update(slot_results) @@ -83,7 +80,7 @@ def mset_nonatomic(self, mapping): # the results (one result per slot) res = [] for pairs in slots_to_pairs.values(): - res.append(self.execute_command('MSET', *pairs)) + res.append(self.execute_command("MSET", *pairs)) return res @@ -108,7 +105,7 @@ def exists(self, *keys): whole cluster. The keys are first split up into slots and then an EXISTS command is sent for every slot """ - return self._split_command_across_slots('EXISTS', *keys) + return self._split_command_across_slots("EXISTS", *keys) def delete(self, *keys): """ @@ -119,7 +116,7 @@ def delete(self, *keys): Non-existant keys are ignored. Returns the number of keys that were deleted. """ - return self._split_command_across_slots('DEL', *keys) + return self._split_command_across_slots("DEL", *keys) def touch(self, *keys): """ @@ -132,7 +129,7 @@ def touch(self, *keys): Non-existant keys are ignored. Returns the number of keys that were touched. """ - return self._split_command_across_slots('TOUCH', *keys) + return self._split_command_across_slots("TOUCH", *keys) def unlink(self, *keys): """ @@ -144,7 +141,7 @@ def unlink(self, *keys): Non-existant keys are ignored. Returns the number of keys that were unlinked. """ - return self._split_command_across_slots('UNLINK', *keys) + return self._split_command_across_slots("UNLINK", *keys) class ClusterManagementCommands: @@ -166,6 +163,7 @@ class ClusterManagementCommands: r.bgsave(target_nodes=primary) r.bgsave(target_nodes='primaries') """ + def bgsave(self, schedule=True, target_nodes=None): """ Tell the Redis server to save its data to disk. Unlike save(), @@ -174,9 +172,7 @@ def bgsave(self, schedule=True, target_nodes=None): pieces = [] if schedule: pieces.append("SCHEDULE") - return self.execute_command('BGSAVE', - *pieces, - target_nodes=target_nodes) + return self.execute_command("BGSAVE", *pieces, target_nodes=target_nodes) def client_getname(self, target_nodes=None): """ @@ -184,8 +180,7 @@ def client_getname(self, target_nodes=None): The result will be a dictionary with the IP and connection name. """ - return self.execute_command('CLIENT GETNAME', - target_nodes=target_nodes) + return self.execute_command("CLIENT GETNAME", target_nodes=target_nodes) def client_getredir(self, target_nodes=None): """Returns the ID (an integer) of the client to whom we are @@ -193,25 +188,29 @@ def client_getredir(self, target_nodes=None): see: https://redis.io/commands/client-getredir """ - return self.execute_command('CLIENT GETREDIR', - target_nodes=target_nodes) + return self.execute_command("CLIENT GETREDIR", target_nodes=target_nodes) def client_id(self, target_nodes=None): """Returns the current connection id""" - return self.execute_command('CLIENT ID', - target_nodes=target_nodes) + return self.execute_command("CLIENT ID", target_nodes=target_nodes) def client_info(self, target_nodes=None): """ Returns information and statistics about the current client connection. """ - return self.execute_command('CLIENT INFO', - target_nodes=target_nodes) + return self.execute_command("CLIENT INFO", target_nodes=target_nodes) - def client_kill_filter(self, _id=None, _type=None, addr=None, - skipme=None, laddr=None, user=None, - target_nodes=None): + def client_kill_filter( + self, + _id=None, + _type=None, + addr=None, + skipme=None, + laddr=None, + user=None, + target_nodes=None, + ): """ Disconnects client(s) using a variety of filter options :param id: Kills a client by its unique ID field @@ -226,35 +225,35 @@ def client_kill_filter(self, _id=None, _type=None, addr=None, """ args = [] if _type is not None: - client_types = ('normal', 'master', 'slave', 'pubsub') + client_types = ("normal", "master", "slave", "pubsub") if str(_type).lower() not in client_types: raise DataError(f"CLIENT KILL type must be one of {client_types!r}") - args.extend((b'TYPE', _type)) + args.extend((b"TYPE", _type)) if skipme is not None: if not isinstance(skipme, bool): raise DataError("CLIENT KILL skipme must be a bool") if skipme: - args.extend((b'SKIPME', b'YES')) + args.extend((b"SKIPME", b"YES")) else: - args.extend((b'SKIPME', b'NO')) + args.extend((b"SKIPME", b"NO")) if _id is not None: - args.extend((b'ID', _id)) + args.extend((b"ID", _id)) if addr is not None: - args.extend((b'ADDR', addr)) + args.extend((b"ADDR", addr)) if laddr is not None: - args.extend((b'LADDR', laddr)) + args.extend((b"LADDR", laddr)) if user is not None: - args.extend((b'USER', user)) + args.extend((b"USER", user)) if not args: - raise DataError("CLIENT KILL ... ... " - " must specify at least one filter") - return self.execute_command('CLIENT KILL', *args, - target_nodes=target_nodes) + raise DataError( + "CLIENT KILL ... ... " + " must specify at least one filter" + ) + return self.execute_command("CLIENT KILL", *args, target_nodes=target_nodes) def client_kill(self, address, target_nodes=None): "Disconnects the client at ``address`` (ip:port)" - return self.execute_command('CLIENT KILL', address, - target_nodes=target_nodes) + return self.execute_command("CLIENT KILL", address, target_nodes=target_nodes) def client_list(self, _type=None, target_nodes=None): """ @@ -264,15 +263,13 @@ def client_list(self, _type=None, target_nodes=None): replica, pubsub) """ if _type is not None: - client_types = ('normal', 'master', 'replica', 'pubsub') + client_types = ("normal", "master", "replica", "pubsub") if str(_type).lower() not in client_types: raise DataError(f"CLIENT LIST _type must be one of {client_types!r}") - return self.execute_command('CLIENT LIST', - b'TYPE', - _type, - target_noes=target_nodes) - return self.execute_command('CLIENT LIST', - target_nodes=target_nodes) + return self.execute_command( + "CLIENT LIST", b"TYPE", _type, target_noes=target_nodes + ) + return self.execute_command("CLIENT LIST", target_nodes=target_nodes) def client_pause(self, timeout, target_nodes=None): """ @@ -281,8 +278,9 @@ def client_pause(self, timeout, target_nodes=None): """ if not isinstance(timeout, int): raise DataError("CLIENT PAUSE timeout must be an integer") - return self.execute_command('CLIENT PAUSE', str(timeout), - target_nodes=target_nodes) + return self.execute_command( + "CLIENT PAUSE", str(timeout), target_nodes=target_nodes + ) def client_reply(self, reply, target_nodes=None): """Enable and disable redis server replies. @@ -298,16 +296,14 @@ def client_reply(self, reply, target_nodes=None): conftest.py has a client with a timeout. See https://redis.io/commands/client-reply """ - replies = ['ON', 'OFF', 'SKIP'] + replies = ["ON", "OFF", "SKIP"] if reply not in replies: - raise DataError(f'CLIENT REPLY must be one of {replies!r}') - return self.execute_command("CLIENT REPLY", reply, - target_nodes=target_nodes) + raise DataError(f"CLIENT REPLY must be one of {replies!r}") + return self.execute_command("CLIENT REPLY", reply, target_nodes=target_nodes) def client_setname(self, name, target_nodes=None): "Sets the current connection name" - return self.execute_command('CLIENT SETNAME', name, - target_nodes=target_nodes) + return self.execute_command("CLIENT SETNAME", name, target_nodes=target_nodes) def client_trackinginfo(self, target_nodes=None): """ @@ -315,8 +311,7 @@ def client_trackinginfo(self, target_nodes=None): use of the server assisted client side cache. See https://redis.io/commands/client-trackinginfo """ - return self.execute_command('CLIENT TRACKINGINFO', - target_nodes=target_nodes) + return self.execute_command("CLIENT TRACKINGINFO", target_nodes=target_nodes) def client_unblock(self, client_id, error=False, target_nodes=None): """ @@ -325,56 +320,50 @@ def client_unblock(self, client_id, error=False, target_nodes=None): If ``error`` is False (default), the client is unblocked using the regular timeout mechanism. """ - args = ['CLIENT UNBLOCK', int(client_id)] + args = ["CLIENT UNBLOCK", int(client_id)] if error: - args.append(b'ERROR') + args.append(b"ERROR") return self.execute_command(*args, target_nodes=target_nodes) def client_unpause(self, target_nodes=None): """ Unpause all redis clients """ - return self.execute_command('CLIENT UNPAUSE', - target_nodes=target_nodes) + return self.execute_command("CLIENT UNPAUSE", target_nodes=target_nodes) def command(self, target_nodes=None): """ Returns dict reply of details about all Redis commands. """ - return self.execute_command('COMMAND', target_nodes=target_nodes) + return self.execute_command("COMMAND", target_nodes=target_nodes) def command_count(self, target_nodes=None): """ Returns Integer reply of number of total commands in this Redis server. """ - return self.execute_command('COMMAND COUNT', target_nodes=target_nodes) + return self.execute_command("COMMAND COUNT", target_nodes=target_nodes) def config_get(self, pattern="*", target_nodes=None): """ Return a dictionary of configuration based on the ``pattern`` """ - return self.execute_command('CONFIG GET', - pattern, - target_nodes=target_nodes) + return self.execute_command("CONFIG GET", pattern, target_nodes=target_nodes) def config_resetstat(self, target_nodes=None): """Reset runtime statistics""" - return self.execute_command('CONFIG RESETSTAT', - target_nodes=target_nodes) + return self.execute_command("CONFIG RESETSTAT", target_nodes=target_nodes) def config_rewrite(self, target_nodes=None): """ Rewrite config file with the minimal change to reflect running config. """ - return self.execute_command('CONFIG REWRITE', - target_nodes=target_nodes) + return self.execute_command("CONFIG REWRITE", target_nodes=target_nodes) def config_set(self, name, value, target_nodes=None): "Set config item ``name`` with ``value``" - return self.execute_command('CONFIG SET', - name, - value, - target_nodes=target_nodes) + return self.execute_command( + "CONFIG SET", name, value, target_nodes=target_nodes + ) def dbsize(self, target_nodes=None): """ @@ -383,8 +372,7 @@ def dbsize(self, target_nodes=None): :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' The node/s to execute the command on """ - return self.execute_command('DBSIZE', - target_nodes=target_nodes) + return self.execute_command("DBSIZE", target_nodes=target_nodes) def debug_object(self, key): raise NotImplementedError( @@ -398,8 +386,7 @@ def debug_segfault(self): def echo(self, value, target_nodes): """Echo the string back from the server""" - return self.execute_command('ECHO', value, - target_nodes=target_nodes) + return self.execute_command("ECHO", value, target_nodes=target_nodes) def flushall(self, asynchronous=False, target_nodes=None): """ @@ -411,10 +398,8 @@ def flushall(self, asynchronous=False, target_nodes=None): """ args = [] if asynchronous: - args.append(b'ASYNC') - return self.execute_command('FLUSHALL', - *args, - target_nodes=target_nodes) + args.append(b"ASYNC") + return self.execute_command("FLUSHALL", *args, target_nodes=target_nodes) def flushdb(self, asynchronous=False, target_nodes=None): """ @@ -425,10 +410,8 @@ def flushdb(self, asynchronous=False, target_nodes=None): """ args = [] if asynchronous: - args.append(b'ASYNC') - return self.execute_command('FLUSHDB', - *args, - target_nodes=target_nodes) + args.append(b"ASYNC") + return self.execute_command("FLUSHDB", *args, target_nodes=target_nodes) def info(self, section=None, target_nodes=None): """ @@ -441,24 +424,20 @@ def info(self, section=None, target_nodes=None): and will generate ResponseError """ if section is None: - return self.execute_command('INFO', - target_nodes=target_nodes) + return self.execute_command("INFO", target_nodes=target_nodes) else: - return self.execute_command('INFO', - section, - target_nodes=target_nodes) + return self.execute_command("INFO", section, target_nodes=target_nodes) - def keys(self, pattern='*', target_nodes=None): + def keys(self, pattern="*", target_nodes=None): "Returns a list of keys matching ``pattern``" - return self.execute_command('KEYS', pattern, target_nodes=target_nodes) + return self.execute_command("KEYS", pattern, target_nodes=target_nodes) def lastsave(self, target_nodes=None): """ Return a Python datetime object representing the last time the Redis database was saved to disk """ - return self.execute_command('LASTSAVE', - target_nodes=target_nodes) + return self.execute_command("LASTSAVE", target_nodes=target_nodes) def memory_doctor(self): raise NotImplementedError( @@ -472,18 +451,15 @@ def memory_help(self): def memory_malloc_stats(self, target_nodes=None): """Return an internal statistics report from the memory allocator.""" - return self.execute_command('MEMORY MALLOC-STATS', - target_nodes=target_nodes) + return self.execute_command("MEMORY MALLOC-STATS", target_nodes=target_nodes) def memory_purge(self, target_nodes=None): """Attempts to purge dirty pages for reclamation by allocator""" - return self.execute_command('MEMORY PURGE', - target_nodes=target_nodes) + return self.execute_command("MEMORY PURGE", target_nodes=target_nodes) def memory_stats(self, target_nodes=None): """Return a dictionary of memory stats""" - return self.execute_command('MEMORY STATS', - target_nodes=target_nodes) + return self.execute_command("MEMORY STATS", target_nodes=target_nodes) def memory_usage(self, key, samples=None): """ @@ -496,12 +472,12 @@ def memory_usage(self, key, samples=None): """ args = [] if isinstance(samples, int): - args.extend([b'SAMPLES', samples]) - return self.execute_command('MEMORY USAGE', key, *args) + args.extend([b"SAMPLES", samples]) + return self.execute_command("MEMORY USAGE", key, *args) def object(self, infotype, key): """Return the encoding, idletime, or refcount about the key""" - return self.execute_command('OBJECT', infotype, key, infotype=infotype) + return self.execute_command("OBJECT", infotype, key, infotype=infotype) def ping(self, target_nodes=None): """ @@ -509,24 +485,22 @@ def ping(self, target_nodes=None): If no target nodes are specified, sent to all nodes and returns True if the ping was successful across all nodes. """ - return self.execute_command('PING', - target_nodes=target_nodes) + return self.execute_command("PING", target_nodes=target_nodes) def randomkey(self, target_nodes=None): """ Returns the name of a random key" """ - return self.execute_command('RANDOMKEY', target_nodes=target_nodes) + return self.execute_command("RANDOMKEY", target_nodes=target_nodes) def save(self, target_nodes=None): """ Tell the Redis server to save its data to disk, blocking until the save is complete """ - return self.execute_command('SAVE', target_nodes=target_nodes) + return self.execute_command("SAVE", target_nodes=target_nodes) - def scan(self, cursor=0, match=None, count=None, _type=None, - target_nodes=None): + def scan(self, cursor=0, match=None, count=None, _type=None, target_nodes=None): """ Incrementally return lists of key names. Also return a cursor indicating the scan position. @@ -543,12 +517,12 @@ def scan(self, cursor=0, match=None, count=None, _type=None, """ pieces = [cursor] if match is not None: - pieces.extend([b'MATCH', match]) + pieces.extend([b"MATCH", match]) if count is not None: - pieces.extend([b'COUNT', count]) + pieces.extend([b"COUNT", count]) if _type is not None: - pieces.extend([b'TYPE', _type]) - return self.execute_command('SCAN', *pieces, target_nodes=target_nodes) + pieces.extend([b"TYPE", _type]) + return self.execute_command("SCAN", *pieces, target_nodes=target_nodes) def scan_iter(self, match=None, count=None, _type=None, target_nodes=None): """ @@ -565,11 +539,15 @@ def scan_iter(self, match=None, count=None, _type=None, target_nodes=None): HASH, LIST, SET, STREAM, STRING, ZSET Additionally, Redis modules can expose other types as well. """ - cursor = '0' + cursor = "0" while cursor != 0: - cursor, data = self.scan(cursor=cursor, match=match, - count=count, _type=_type, - target_nodes=target_nodes) + cursor, data = self.scan( + cursor=cursor, + match=match, + count=count, + _type=_type, + target_nodes=target_nodes, + ) yield from data def shutdown(self, save=False, nosave=False, target_nodes=None): @@ -580,12 +558,12 @@ def shutdown(self, save=False, nosave=False, target_nodes=None): attempted. The "save" and "nosave" options cannot both be set. """ if save and nosave: - raise DataError('SHUTDOWN save and nosave cannot both be set') - args = ['SHUTDOWN'] + raise DataError("SHUTDOWN save and nosave cannot both be set") + args = ["SHUTDOWN"] if save: - args.append('SAVE') + args.append("SAVE") if nosave: - args.append('NOSAVE') + args.append("NOSAVE") try: self.execute_command(*args, target_nodes=target_nodes) except ConnectionError: @@ -598,26 +576,32 @@ def slowlog_get(self, num=None, target_nodes=None): Get the entries from the slowlog. If ``num`` is specified, get the most recent ``num`` items. """ - args = ['SLOWLOG GET'] + args = ["SLOWLOG GET"] if num is not None: args.append(num) - return self.execute_command(*args, - target_nodes=target_nodes) + return self.execute_command(*args, target_nodes=target_nodes) def slowlog_len(self, target_nodes=None): "Get the number of items in the slowlog" - return self.execute_command('SLOWLOG LEN', - target_nodes=target_nodes) + return self.execute_command("SLOWLOG LEN", target_nodes=target_nodes) def slowlog_reset(self, target_nodes=None): "Remove all items in the slowlog" - return self.execute_command('SLOWLOG RESET', - target_nodes=target_nodes) - - def stralgo(self, algo, value1, value2, specific_argument='strings', - len=False, idx=False, minmatchlen=None, withmatchlen=False, - target_nodes=None): + return self.execute_command("SLOWLOG RESET", target_nodes=target_nodes) + + def stralgo( + self, + algo, + value1, + value2, + specific_argument="strings", + len=False, + idx=False, + minmatchlen=None, + withmatchlen=False, + target_nodes=None, + ): """ Implements complex algorithms that operate on strings. Right now the only algorithm implemented is the LCS algorithm @@ -636,40 +620,45 @@ def stralgo(self, algo, value1, value2, specific_argument='strings', Can be provided only when ``idx`` set to True. """ # check validity - supported_algo = ['LCS'] + supported_algo = ["LCS"] if algo not in supported_algo: - supported_algos_str = ', '.join(supported_algo) + supported_algos_str = ", ".join(supported_algo) raise DataError(f"The supported algorithms are: {supported_algos_str}") - if specific_argument not in ['keys', 'strings']: + if specific_argument not in ["keys", "strings"]: raise DataError("specific_argument can be only keys or strings") if len and idx: raise DataError("len and idx cannot be provided together.") pieces = [algo, specific_argument.upper(), value1, value2] if len: - pieces.append(b'LEN') + pieces.append(b"LEN") if idx: - pieces.append(b'IDX') + pieces.append(b"IDX") try: int(minmatchlen) - pieces.extend([b'MINMATCHLEN', minmatchlen]) + pieces.extend([b"MINMATCHLEN", minmatchlen]) except TypeError: pass if withmatchlen: - pieces.append(b'WITHMATCHLEN') - if specific_argument == 'strings' and target_nodes is None: - target_nodes = 'default-node' - return self.execute_command('STRALGO', *pieces, len=len, idx=idx, - minmatchlen=minmatchlen, - withmatchlen=withmatchlen, - target_nodes=target_nodes) + pieces.append(b"WITHMATCHLEN") + if specific_argument == "strings" and target_nodes is None: + target_nodes = "default-node" + return self.execute_command( + "STRALGO", + *pieces, + len=len, + idx=idx, + minmatchlen=minmatchlen, + withmatchlen=withmatchlen, + target_nodes=target_nodes, + ) def time(self, target_nodes=None): """ Returns the server time as a 2-item tuple of ints: (seconds since epoch, microseconds into this second). """ - return self.execute_command('TIME', target_nodes=target_nodes) + return self.execute_command("TIME", target_nodes=target_nodes) def wait(self, num_replicas, timeout, target_nodes=None): """ @@ -680,9 +669,9 @@ def wait(self, num_replicas, timeout, target_nodes=None): If more than one target node are passed the result will be summed up """ - return self.execute_command('WAIT', num_replicas, - timeout, - target_nodes=target_nodes) + return self.execute_command( + "WAIT", num_replicas, timeout, target_nodes=target_nodes + ) class ClusterPubSubCommands: @@ -690,38 +679,44 @@ class ClusterPubSubCommands: Redis PubSub commands for RedisCluster use. see https://redis.io/topics/pubsub """ + def publish(self, channel, message, target_nodes=None): """ Publish ``message`` on ``channel``. Returns the number of subscribers the message was delivered to. """ - return self.execute_command('PUBLISH', channel, message, - target_nodes=target_nodes) + return self.execute_command( + "PUBLISH", channel, message, target_nodes=target_nodes + ) - def pubsub_channels(self, pattern='*', target_nodes=None): + def pubsub_channels(self, pattern="*", target_nodes=None): """ Return a list of channels that have at least one subscriber """ - return self.execute_command('PUBSUB CHANNELS', pattern, - target_nodes=target_nodes) + return self.execute_command( + "PUBSUB CHANNELS", pattern, target_nodes=target_nodes + ) def pubsub_numpat(self, target_nodes=None): """ Returns the number of subscriptions to patterns """ - return self.execute_command('PUBSUB NUMPAT', target_nodes=target_nodes) + return self.execute_command("PUBSUB NUMPAT", target_nodes=target_nodes) def pubsub_numsub(self, *args, target_nodes=None): """ Return a list of (channel, number of subscribers) tuples for each channel given in ``*args`` """ - return self.execute_command('PUBSUB NUMSUB', *args, - target_nodes=target_nodes) + return self.execute_command("PUBSUB NUMSUB", *args, target_nodes=target_nodes) -class ClusterCommands(ClusterManagementCommands, ClusterMultiKeyCommands, - ClusterPubSubCommands, DataAccessCommands): +class ClusterCommands( + ClusterManagementCommands, + ClusterMultiKeyCommands, + ClusterPubSubCommands, + DataAccessCommands, +): """ Redis Cluster commands @@ -738,6 +733,7 @@ class ClusterCommands(ClusterManagementCommands, ClusterMultiKeyCommands, for example: r.cluster_info(target_nodes='all') """ + def cluster_addslots(self, target_node, *slots): """ Assign new hash slots to receiving node. Sends to specified node. @@ -745,22 +741,23 @@ def cluster_addslots(self, target_node, *slots): :target_node: 'ClusterNode' The node to execute the command on """ - return self.execute_command('CLUSTER ADDSLOTS', *slots, - target_nodes=target_node) + return self.execute_command( + "CLUSTER ADDSLOTS", *slots, target_nodes=target_node + ) def cluster_countkeysinslot(self, slot_id): """ Return the number of local keys in the specified hash slot Send to node based on specified slot_id """ - return self.execute_command('CLUSTER COUNTKEYSINSLOT', slot_id) + return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id) def cluster_count_failure_report(self, node_id): """ Return the number of failure reports active for a given node Sends to a random node """ - return self.execute_command('CLUSTER COUNT-FAILURE-REPORTS', node_id) + return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id) def cluster_delslots(self, *slots): """ @@ -769,10 +766,7 @@ def cluster_delslots(self, *slots): Returns a list of the results for each processed slot. """ - return [ - self.execute_command('CLUSTER DELSLOTS', slot) - for slot in slots - ] + return [self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] def cluster_failover(self, target_node, option=None): """ @@ -783,15 +777,16 @@ def cluster_failover(self, target_node, option=None): The node to execute the command on """ if option: - if option.upper() not in ['FORCE', 'TAKEOVER']: + if option.upper() not in ["FORCE", "TAKEOVER"]: raise RedisError( - f'Invalid option for CLUSTER FAILOVER command: {option}') + f"Invalid option for CLUSTER FAILOVER command: {option}" + ) else: - return self.execute_command('CLUSTER FAILOVER', option, - target_nodes=target_node) + return self.execute_command( + "CLUSTER FAILOVER", option, target_nodes=target_node + ) else: - return self.execute_command('CLUSTER FAILOVER', - target_nodes=target_node) + return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node) def cluster_info(self, target_nodes=None): """ @@ -799,22 +794,23 @@ def cluster_info(self, target_nodes=None): The command will be sent to a random node in the cluster if no target node is specified. """ - return self.execute_command('CLUSTER INFO', target_nodes=target_nodes) + return self.execute_command("CLUSTER INFO", target_nodes=target_nodes) def cluster_keyslot(self, key): """ Returns the hash slot of the specified key Sends to random node in the cluster """ - return self.execute_command('CLUSTER KEYSLOT', key) + return self.execute_command("CLUSTER KEYSLOT", key) def cluster_meet(self, host, port, target_nodes=None): """ Force a node cluster to handshake with another node. Sends to specified node. """ - return self.execute_command('CLUSTER MEET', host, port, - target_nodes=target_nodes) + return self.execute_command( + "CLUSTER MEET", host, port, target_nodes=target_nodes + ) def cluster_nodes(self): """ @@ -822,14 +818,15 @@ def cluster_nodes(self): Sends to random node in the cluster """ - return self.execute_command('CLUSTER NODES') + return self.execute_command("CLUSTER NODES") def cluster_replicate(self, target_nodes, node_id): """ Reconfigure a node as a slave of the specified master node """ - return self.execute_command('CLUSTER REPLICATE', node_id, - target_nodes=target_nodes) + return self.execute_command( + "CLUSTER REPLICATE", node_id, target_nodes=target_nodes + ) def cluster_reset(self, soft=True, target_nodes=None): """ @@ -838,29 +835,29 @@ def cluster_reset(self, soft=True, target_nodes=None): If 'soft' is True then it will send 'SOFT' argument If 'soft' is False then it will send 'HARD' argument """ - return self.execute_command('CLUSTER RESET', - b'SOFT' if soft else b'HARD', - target_nodes=target_nodes) + return self.execute_command( + "CLUSTER RESET", b"SOFT" if soft else b"HARD", target_nodes=target_nodes + ) def cluster_save_config(self, target_nodes=None): """ Forces the node to save cluster state on disk """ - return self.execute_command('CLUSTER SAVECONFIG', - target_nodes=target_nodes) + return self.execute_command("CLUSTER SAVECONFIG", target_nodes=target_nodes) def cluster_get_keys_in_slot(self, slot, num_keys): """ Returns the number of keys in the specified cluster slot """ - return self.execute_command('CLUSTER GETKEYSINSLOT', slot, num_keys) + return self.execute_command("CLUSTER GETKEYSINSLOT", slot, num_keys) def cluster_set_config_epoch(self, epoch, target_nodes=None): """ Set the configuration epoch in a new node """ - return self.execute_command('CLUSTER SET-CONFIG-EPOCH', epoch, - target_nodes=target_nodes) + return self.execute_command( + "CLUSTER SET-CONFIG-EPOCH", epoch, target_nodes=target_nodes + ) def cluster_setslot(self, target_node, node_id, slot_id, state): """ @@ -869,47 +866,48 @@ def cluster_setslot(self, target_node, node_id, slot_id, state): :target_node: 'ClusterNode' The node to execute the command on """ - if state.upper() in ('IMPORTING', 'NODE', 'MIGRATING'): - return self.execute_command('CLUSTER SETSLOT', slot_id, state, - node_id, target_nodes=target_node) - elif state.upper() == 'STABLE': - raise RedisError('For "stable" state please use ' - 'cluster_setslot_stable') + if state.upper() in ("IMPORTING", "NODE", "MIGRATING"): + return self.execute_command( + "CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node + ) + elif state.upper() == "STABLE": + raise RedisError('For "stable" state please use ' "cluster_setslot_stable") else: - raise RedisError(f'Invalid slot state: {state}') + raise RedisError(f"Invalid slot state: {state}") def cluster_setslot_stable(self, slot_id): """ Clears migrating / importing state from the slot. It determines by it self what node the slot is in and sends it there. """ - return self.execute_command('CLUSTER SETSLOT', slot_id, 'STABLE') + return self.execute_command("CLUSTER SETSLOT", slot_id, "STABLE") def cluster_replicas(self, node_id, target_nodes=None): """ Provides a list of replica nodes replicating from the specified primary target node. """ - return self.execute_command('CLUSTER REPLICAS', node_id, - target_nodes=target_nodes) + return self.execute_command( + "CLUSTER REPLICAS", node_id, target_nodes=target_nodes + ) def cluster_slots(self, target_nodes=None): """ Get array of Cluster slot to node mappings """ - return self.execute_command('CLUSTER SLOTS', target_nodes=target_nodes) + return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes) def readonly(self, target_nodes=None): """ Enables read queries. The command will be sent to the default cluster node if target_nodes is not specified. - """ - if target_nodes == 'replicas' or target_nodes == 'all': + """ + if target_nodes == "replicas" or target_nodes == "all": # read_from_replicas will only be enabled if the READONLY command # is sent to all replicas self.read_from_replicas = True - return self.execute_command('READONLY', target_nodes=target_nodes) + return self.execute_command("READONLY", target_nodes=target_nodes) def readwrite(self, target_nodes=None): """ @@ -919,4 +917,4 @@ def readwrite(self, target_nodes=None): """ # Reset read from replicas flag self.read_from_replicas = False - return self.execute_command('READWRITE', target_nodes=target_nodes) + return self.execute_command("READWRITE", target_nodes=target_nodes) diff --git a/redis/commands/core.py b/redis/commands/core.py index 0285f80e0f..688e1dda1b 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -1,15 +1,11 @@ import datetime +import hashlib import time import warnings -import hashlib + +from redis.exceptions import ConnectionError, DataError, NoScriptError, RedisError from .helpers import list_or_args -from redis.exceptions import ( - ConnectionError, - DataError, - NoScriptError, - RedisError, -) class ACLCommands: @@ -17,6 +13,7 @@ class ACLCommands: Redis Access Control List (ACL) commands. see: https://redis.io/topics/acl """ + def acl_cat(self, category=None): """ Returns a list of categories or commands within a category. @@ -28,7 +25,7 @@ def acl_cat(self, category=None): For more information check https://redis.io/commands/acl-cat """ pieces = [category] if category else [] - return self.execute_command('ACL CAT', *pieces) + return self.execute_command("ACL CAT", *pieces) def acl_deluser(self, *username): """ @@ -36,7 +33,7 @@ def acl_deluser(self, *username): For more information check https://redis.io/commands/acl-deluser """ - return self.execute_command('ACL DELUSER', *username) + return self.execute_command("ACL DELUSER", *username) def acl_genpass(self, bits=None): """Generate a random password value. @@ -51,9 +48,10 @@ def acl_genpass(self, bits=None): if b < 0 or b > 4096: raise ValueError except ValueError: - raise DataError('genpass optionally accepts a bits argument, ' - 'between 0 and 4096.') - return self.execute_command('ACL GENPASS', *pieces) + raise DataError( + "genpass optionally accepts a bits argument, " "between 0 and 4096." + ) + return self.execute_command("ACL GENPASS", *pieces) def acl_getuser(self, username): """ @@ -63,7 +61,7 @@ def acl_getuser(self, username): For more information check https://redis.io/commands/acl-getuser """ - return self.execute_command('ACL GETUSER', username) + return self.execute_command("ACL GETUSER", username) def acl_help(self): """The ACL HELP command returns helpful text describing @@ -71,7 +69,7 @@ def acl_help(self): For more information check https://redis.io/commands/acl-help """ - return self.execute_command('ACL HELP') + return self.execute_command("ACL HELP") def acl_list(self): """ @@ -79,7 +77,7 @@ def acl_list(self): For more information check https://redis.io/commands/acl-list """ - return self.execute_command('ACL LIST') + return self.execute_command("ACL LIST") def acl_log(self, count=None): """ @@ -92,11 +90,10 @@ def acl_log(self, count=None): args = [] if count is not None: if not isinstance(count, int): - raise DataError('ACL LOG count must be an ' - 'integer') + raise DataError("ACL LOG count must be an " "integer") args.append(count) - return self.execute_command('ACL LOG', *args) + return self.execute_command("ACL LOG", *args) def acl_log_reset(self): """ @@ -105,8 +102,8 @@ def acl_log_reset(self): For more information check https://redis.io/commands/acl-log """ - args = [b'RESET'] - return self.execute_command('ACL LOG', *args) + args = [b"RESET"] + return self.execute_command("ACL LOG", *args) def acl_load(self): """ @@ -117,7 +114,7 @@ def acl_load(self): For more information check https://redis.io/commands/acl-load """ - return self.execute_command('ACL LOAD') + return self.execute_command("ACL LOAD") def acl_save(self): """ @@ -128,12 +125,22 @@ def acl_save(self): For more information check https://redis.io/commands/acl-save """ - return self.execute_command('ACL SAVE') - - def acl_setuser(self, username, enabled=False, nopass=False, - passwords=None, hashed_passwords=None, categories=None, - commands=None, keys=None, reset=False, reset_keys=False, - reset_passwords=False): + return self.execute_command("ACL SAVE") + + def acl_setuser( + self, + username, + enabled=False, + nopass=False, + passwords=None, + hashed_passwords=None, + categories=None, + commands=None, + keys=None, + reset=False, + reset_keys=False, + reset_passwords=False, + ): """ Create or update an ACL user. @@ -199,22 +206,23 @@ def acl_setuser(self, username, enabled=False, nopass=False, pieces = [username] if reset: - pieces.append(b'reset') + pieces.append(b"reset") if reset_keys: - pieces.append(b'resetkeys') + pieces.append(b"resetkeys") if reset_passwords: - pieces.append(b'resetpass') + pieces.append(b"resetpass") if enabled: - pieces.append(b'on') + pieces.append(b"on") else: - pieces.append(b'off') + pieces.append(b"off") if (passwords or hashed_passwords) and nopass: - raise DataError('Cannot set \'nopass\' and supply ' - '\'passwords\' or \'hashed_passwords\'') + raise DataError( + "Cannot set 'nopass' and supply " "'passwords' or 'hashed_passwords'" + ) if passwords: # as most users will have only one password, allow remove_passwords @@ -222,13 +230,15 @@ def acl_setuser(self, username, enabled=False, nopass=False, passwords = list_or_args(passwords, []) for i, password in enumerate(passwords): password = encoder.encode(password) - if password.startswith(b'+'): - pieces.append(b'>%s' % password[1:]) - elif password.startswith(b'-'): - pieces.append(b'<%s' % password[1:]) + if password.startswith(b"+"): + pieces.append(b">%s" % password[1:]) + elif password.startswith(b"-"): + pieces.append(b"<%s" % password[1:]) else: - raise DataError(f'Password {i} must be prefixed with a ' - f'"+" to add or a "-" to remove') + raise DataError( + f"Password {i} must be prefixed with a " + f'"+" to add or a "-" to remove' + ) if hashed_passwords: # as most users will have only one password, allow remove_passwords @@ -236,29 +246,31 @@ def acl_setuser(self, username, enabled=False, nopass=False, hashed_passwords = list_or_args(hashed_passwords, []) for i, hashed_password in enumerate(hashed_passwords): hashed_password = encoder.encode(hashed_password) - if hashed_password.startswith(b'+'): - pieces.append(b'#%s' % hashed_password[1:]) - elif hashed_password.startswith(b'-'): - pieces.append(b'!%s' % hashed_password[1:]) + if hashed_password.startswith(b"+"): + pieces.append(b"#%s" % hashed_password[1:]) + elif hashed_password.startswith(b"-"): + pieces.append(b"!%s" % hashed_password[1:]) else: - raise DataError(f'Hashed password {i} must be prefixed with a ' - f'"+" to add or a "-" to remove') + raise DataError( + f"Hashed password {i} must be prefixed with a " + f'"+" to add or a "-" to remove' + ) if nopass: - pieces.append(b'nopass') + pieces.append(b"nopass") if categories: for category in categories: category = encoder.encode(category) # categories can be prefixed with one of (+@, +, -@, -) - if category.startswith(b'+@'): + if category.startswith(b"+@"): pieces.append(category) - elif category.startswith(b'+'): - pieces.append(b'+@%s' % category[1:]) - elif category.startswith(b'-@'): + elif category.startswith(b"+"): + pieces.append(b"+@%s" % category[1:]) + elif category.startswith(b"-@"): pieces.append(category) - elif category.startswith(b'-'): - pieces.append(b'-@%s' % category[1:]) + elif category.startswith(b"-"): + pieces.append(b"-@%s" % category[1:]) else: raise DataError( f'Category "{encoder.decode(category, force=True)}" ' @@ -267,7 +279,7 @@ def acl_setuser(self, username, enabled=False, nopass=False, if commands: for cmd in commands: cmd = encoder.encode(cmd) - if not cmd.startswith(b'+') and not cmd.startswith(b'-'): + if not cmd.startswith(b"+") and not cmd.startswith(b"-"): raise DataError( f'Command "{encoder.decode(cmd, force=True)}" ' 'must be prefixed with "+" or "-"' @@ -277,35 +289,36 @@ def acl_setuser(self, username, enabled=False, nopass=False, if keys: for key in keys: key = encoder.encode(key) - pieces.append(b'~%s' % key) + pieces.append(b"~%s" % key) - return self.execute_command('ACL SETUSER', *pieces) + return self.execute_command("ACL SETUSER", *pieces) def acl_users(self): """Returns a list of all registered users on the server. For more information check https://redis.io/commands/acl-users """ - return self.execute_command('ACL USERS') + return self.execute_command("ACL USERS") def acl_whoami(self): """Get the username for the current connection For more information check https://redis.io/commands/acl-whoami """ - return self.execute_command('ACL WHOAMI') + return self.execute_command("ACL WHOAMI") class ManagementCommands: """ Redis management commands """ + def bgrewriteaof(self): """Tell the Redis server to rewrite the AOF file from data in memory. For more information check https://redis.io/commands/bgrewriteaof """ - return self.execute_command('BGREWRITEAOF') + return self.execute_command("BGREWRITEAOF") def bgsave(self, schedule=True): """ @@ -317,17 +330,18 @@ def bgsave(self, schedule=True): pieces = [] if schedule: pieces.append("SCHEDULE") - return self.execute_command('BGSAVE', *pieces) + return self.execute_command("BGSAVE", *pieces) def client_kill(self, address): """Disconnects the client at ``address`` (ip:port) For more information check https://redis.io/commands/client-kill """ - return self.execute_command('CLIENT KILL', address) + return self.execute_command("CLIENT KILL", address) - def client_kill_filter(self, _id=None, _type=None, addr=None, - skipme=None, laddr=None, user=None): + def client_kill_filter( + self, _id=None, _type=None, addr=None, skipme=None, laddr=None, user=None + ): """ Disconnects client(s) using a variety of filter options :param id: Kills a client by its unique ID field @@ -342,29 +356,31 @@ def client_kill_filter(self, _id=None, _type=None, addr=None, """ args = [] if _type is not None: - client_types = ('normal', 'master', 'slave', 'pubsub') + client_types = ("normal", "master", "slave", "pubsub") if str(_type).lower() not in client_types: raise DataError(f"CLIENT KILL type must be one of {client_types!r}") - args.extend((b'TYPE', _type)) + args.extend((b"TYPE", _type)) if skipme is not None: if not isinstance(skipme, bool): raise DataError("CLIENT KILL skipme must be a bool") if skipme: - args.extend((b'SKIPME', b'YES')) + args.extend((b"SKIPME", b"YES")) else: - args.extend((b'SKIPME', b'NO')) + args.extend((b"SKIPME", b"NO")) if _id is not None: - args.extend((b'ID', _id)) + args.extend((b"ID", _id)) if addr is not None: - args.extend((b'ADDR', addr)) + args.extend((b"ADDR", addr)) if laddr is not None: - args.extend((b'LADDR', laddr)) + args.extend((b"LADDR", laddr)) if user is not None: - args.extend((b'USER', user)) + args.extend((b"USER", user)) if not args: - raise DataError("CLIENT KILL ... ... " - " must specify at least one filter") - return self.execute_command('CLIENT KILL', *args) + raise DataError( + "CLIENT KILL ... ... " + " must specify at least one filter" + ) + return self.execute_command("CLIENT KILL", *args) def client_info(self): """ @@ -373,7 +389,7 @@ def client_info(self): For more information check https://redis.io/commands/client-info """ - return self.execute_command('CLIENT INFO') + return self.execute_command("CLIENT INFO") def client_list(self, _type=None, client_id=[]): """ @@ -387,17 +403,17 @@ def client_list(self, _type=None, client_id=[]): """ args = [] if _type is not None: - client_types = ('normal', 'master', 'replica', 'pubsub') + client_types = ("normal", "master", "replica", "pubsub") if str(_type).lower() not in client_types: raise DataError(f"CLIENT LIST _type must be one of {client_types!r}") - args.append(b'TYPE') + args.append(b"TYPE") args.append(_type) if not isinstance(client_id, list): raise DataError("client_id must be a list") if client_id != []: args.append(b"ID") - args.append(' '.join(client_id)) - return self.execute_command('CLIENT LIST', *args) + args.append(" ".join(client_id)) + return self.execute_command("CLIENT LIST", *args) def client_getname(self): """ @@ -405,7 +421,7 @@ def client_getname(self): For more information check https://redis.io/commands/client-getname """ - return self.execute_command('CLIENT GETNAME') + return self.execute_command("CLIENT GETNAME") def client_getredir(self): """ @@ -414,7 +430,7 @@ def client_getredir(self): see: https://redis.io/commands/client-getredir """ - return self.execute_command('CLIENT GETREDIR') + return self.execute_command("CLIENT GETREDIR") def client_reply(self, reply): """ @@ -432,9 +448,9 @@ def client_reply(self, reply): See https://redis.io/commands/client-reply """ - replies = ['ON', 'OFF', 'SKIP'] + replies = ["ON", "OFF", "SKIP"] if reply not in replies: - raise DataError(f'CLIENT REPLY must be one of {replies!r}') + raise DataError(f"CLIENT REPLY must be one of {replies!r}") return self.execute_command("CLIENT REPLY", reply) def client_id(self): @@ -443,7 +459,7 @@ def client_id(self): For more information check https://redis.io/commands/client-id """ - return self.execute_command('CLIENT ID') + return self.execute_command("CLIENT ID") def client_trackinginfo(self): """ @@ -452,7 +468,7 @@ def client_trackinginfo(self): See https://redis.io/commands/client-trackinginfo """ - return self.execute_command('CLIENT TRACKINGINFO') + return self.execute_command("CLIENT TRACKINGINFO") def client_setname(self, name): """ @@ -460,7 +476,7 @@ def client_setname(self, name): For more information check https://redis.io/commands/client-setname """ - return self.execute_command('CLIENT SETNAME', name) + return self.execute_command("CLIENT SETNAME", name) def client_unblock(self, client_id, error=False): """ @@ -471,9 +487,9 @@ def client_unblock(self, client_id, error=False): For more information check https://redis.io/commands/client-unblock """ - args = ['CLIENT UNBLOCK', int(client_id)] + args = ["CLIENT UNBLOCK", int(client_id)] if error: - args.append(b'ERROR') + args.append(b"ERROR") return self.execute_command(*args) def client_pause(self, timeout): @@ -485,7 +501,7 @@ def client_pause(self, timeout): """ if not isinstance(timeout, int): raise DataError("CLIENT PAUSE timeout must be an integer") - return self.execute_command('CLIENT PAUSE', str(timeout)) + return self.execute_command("CLIENT PAUSE", str(timeout)) def client_unpause(self): """ @@ -493,7 +509,7 @@ def client_unpause(self): For more information check https://redis.io/commands/client-unpause """ - return self.execute_command('CLIENT UNPAUSE') + return self.execute_command("CLIENT UNPAUSE") def command_info(self): raise NotImplementedError( @@ -501,7 +517,7 @@ def command_info(self): ) def command_count(self): - return self.execute_command('COMMAND COUNT') + return self.execute_command("COMMAND COUNT") def readwrite(self): """ @@ -509,7 +525,7 @@ def readwrite(self): For more information check https://redis.io/commands/readwrite """ - return self.execute_command('READWRITE') + return self.execute_command("READWRITE") def readonly(self): """ @@ -517,7 +533,7 @@ def readonly(self): For more information check https://redis.io/commands/readonly """ - return self.execute_command('READONLY') + return self.execute_command("READONLY") def config_get(self, pattern="*"): """ @@ -525,14 +541,14 @@ def config_get(self, pattern="*"): For more information check https://redis.io/commands/config-get """ - return self.execute_command('CONFIG GET', pattern) + return self.execute_command("CONFIG GET", pattern) def config_set(self, name, value): """Set config item ``name`` with ``value`` For more information check https://redis.io/commands/config-set """ - return self.execute_command('CONFIG SET', name, value) + return self.execute_command("CONFIG SET", name, value) def config_resetstat(self): """ @@ -540,7 +556,7 @@ def config_resetstat(self): For more information check https://redis.io/commands/config-resetstat """ - return self.execute_command('CONFIG RESETSTAT') + return self.execute_command("CONFIG RESETSTAT") def config_rewrite(self): """ @@ -548,10 +564,10 @@ def config_rewrite(self): For more information check https://redis.io/commands/config-rewrite """ - return self.execute_command('CONFIG REWRITE') + return self.execute_command("CONFIG REWRITE") def cluster(self, cluster_arg, *args): - return self.execute_command(f'CLUSTER {cluster_arg.upper()}', *args) + return self.execute_command(f"CLUSTER {cluster_arg.upper()}", *args) def dbsize(self): """ @@ -559,7 +575,7 @@ def dbsize(self): For more information check https://redis.io/commands/dbsize """ - return self.execute_command('DBSIZE') + return self.execute_command("DBSIZE") def debug_object(self, key): """ @@ -567,7 +583,7 @@ def debug_object(self, key): For more information check https://redis.io/commands/debug-object """ - return self.execute_command('DEBUG OBJECT', key) + return self.execute_command("DEBUG OBJECT", key) def debug_segfault(self): raise NotImplementedError( @@ -584,7 +600,7 @@ def echo(self, value): For more information check https://redis.io/commands/echo """ - return self.execute_command('ECHO', value) + return self.execute_command("ECHO", value) def flushall(self, asynchronous=False): """ @@ -597,8 +613,8 @@ def flushall(self, asynchronous=False): """ args = [] if asynchronous: - args.append(b'ASYNC') - return self.execute_command('FLUSHALL', *args) + args.append(b"ASYNC") + return self.execute_command("FLUSHALL", *args) def flushdb(self, asynchronous=False): """ @@ -611,8 +627,8 @@ def flushdb(self, asynchronous=False): """ args = [] if asynchronous: - args.append(b'ASYNC') - return self.execute_command('FLUSHDB', *args) + args.append(b"ASYNC") + return self.execute_command("FLUSHDB", *args) def swapdb(self, first, second): """ @@ -620,7 +636,7 @@ def swapdb(self, first, second): For more information check https://redis.io/commands/swapdb """ - return self.execute_command('SWAPDB', first, second) + return self.execute_command("SWAPDB", first, second) def info(self, section=None): """ @@ -635,9 +651,9 @@ def info(self, section=None): For more information check https://redis.io/commands/info """ if section is None: - return self.execute_command('INFO') + return self.execute_command("INFO") else: - return self.execute_command('INFO', section) + return self.execute_command("INFO", section) def lastsave(self): """ @@ -646,7 +662,7 @@ def lastsave(self): For more information check https://redis.io/commands/lastsave """ - return self.execute_command('LASTSAVE') + return self.execute_command("LASTSAVE") def lolwut(self, *version_numbers): """ @@ -655,12 +671,21 @@ def lolwut(self, *version_numbers): See: https://redis.io/commands/lolwut """ if version_numbers: - return self.execute_command('LOLWUT VERSION', *version_numbers) + return self.execute_command("LOLWUT VERSION", *version_numbers) else: - return self.execute_command('LOLWUT') - - def migrate(self, host, port, keys, destination_db, timeout, - copy=False, replace=False, auth=None): + return self.execute_command("LOLWUT") + + def migrate( + self, + host, + port, + keys, + destination_db, + timeout, + copy=False, + replace=False, + auth=None, + ): """ Migrate 1 or more keys from the current Redis server to a different server specified by the ``host``, ``port`` and ``destination_db``. @@ -682,25 +707,26 @@ def migrate(self, host, port, keys, destination_db, timeout, """ keys = list_or_args(keys, []) if not keys: - raise DataError('MIGRATE requires at least one key') + raise DataError("MIGRATE requires at least one key") pieces = [] if copy: - pieces.append(b'COPY') + pieces.append(b"COPY") if replace: - pieces.append(b'REPLACE') + pieces.append(b"REPLACE") if auth: - pieces.append(b'AUTH') + pieces.append(b"AUTH") pieces.append(auth) - pieces.append(b'KEYS') + pieces.append(b"KEYS") pieces.extend(keys) - return self.execute_command('MIGRATE', host, port, '', destination_db, - timeout, *pieces) + return self.execute_command( + "MIGRATE", host, port, "", destination_db, timeout, *pieces + ) def object(self, infotype, key): """ Return the encoding, idletime, or refcount about the key """ - return self.execute_command('OBJECT', infotype, key, infotype=infotype) + return self.execute_command("OBJECT", infotype, key, infotype=infotype) def memory_doctor(self): raise NotImplementedError( @@ -726,7 +752,7 @@ def memory_stats(self): For more information check https://redis.io/commands/memory-stats """ - return self.execute_command('MEMORY STATS') + return self.execute_command("MEMORY STATS") def memory_malloc_stats(self): """ @@ -734,7 +760,7 @@ def memory_malloc_stats(self): See: https://redis.io/commands/memory-malloc-stats """ - return self.execute_command('MEMORY MALLOC-STATS') + return self.execute_command("MEMORY MALLOC-STATS") def memory_usage(self, key, samples=None): """ @@ -749,8 +775,8 @@ def memory_usage(self, key, samples=None): """ args = [] if isinstance(samples, int): - args.extend([b'SAMPLES', samples]) - return self.execute_command('MEMORY USAGE', key, *args) + args.extend([b"SAMPLES", samples]) + return self.execute_command("MEMORY USAGE", key, *args) def memory_purge(self): """ @@ -758,7 +784,7 @@ def memory_purge(self): For more information check https://redis.io/commands/memory-purge """ - return self.execute_command('MEMORY PURGE') + return self.execute_command("MEMORY PURGE") def ping(self): """ @@ -766,7 +792,7 @@ def ping(self): For more information check https://redis.io/commands/ping """ - return self.execute_command('PING') + return self.execute_command("PING") def quit(self): """ @@ -774,7 +800,7 @@ def quit(self): For more information check https://redis.io/commands/quit """ - return self.execute_command('QUIT') + return self.execute_command("QUIT") def replicaof(self, *args): """ @@ -785,7 +811,7 @@ def replicaof(self, *args): For more information check https://redis.io/commands/replicaof """ - return self.execute_command('REPLICAOF', *args) + return self.execute_command("REPLICAOF", *args) def save(self): """ @@ -794,7 +820,7 @@ def save(self): For more information check https://redis.io/commands/save """ - return self.execute_command('SAVE') + return self.execute_command("SAVE") def shutdown(self, save=False, nosave=False): """Shutdown the Redis server. If Redis has persistence configured, @@ -806,12 +832,12 @@ def shutdown(self, save=False, nosave=False): For more information check https://redis.io/commands/shutdown """ if save and nosave: - raise DataError('SHUTDOWN save and nosave cannot both be set') - args = ['SHUTDOWN'] + raise DataError("SHUTDOWN save and nosave cannot both be set") + args = ["SHUTDOWN"] if save: - args.append('SAVE') + args.append("SAVE") if nosave: - args.append('NOSAVE') + args.append("NOSAVE") try: self.execute_command(*args) except ConnectionError: @@ -828,8 +854,8 @@ def slaveof(self, host=None, port=None): For more information check https://redis.io/commands/slaveof """ if host is None and port is None: - return self.execute_command('SLAVEOF', b'NO', b'ONE') - return self.execute_command('SLAVEOF', host, port) + return self.execute_command("SLAVEOF", b"NO", b"ONE") + return self.execute_command("SLAVEOF", host, port) def slowlog_get(self, num=None): """ @@ -838,11 +864,12 @@ def slowlog_get(self, num=None): For more information check https://redis.io/commands/slowlog-get """ - args = ['SLOWLOG GET'] + args = ["SLOWLOG GET"] if num is not None: args.append(num) decode_responses = self.connection_pool.connection_kwargs.get( - 'decode_responses', False) + "decode_responses", False + ) return self.execute_command(*args, decode_responses=decode_responses) def slowlog_len(self): @@ -851,7 +878,7 @@ def slowlog_len(self): For more information check https://redis.io/commands/slowlog-len """ - return self.execute_command('SLOWLOG LEN') + return self.execute_command("SLOWLOG LEN") def slowlog_reset(self): """ @@ -859,7 +886,7 @@ def slowlog_reset(self): For more information check https://redis.io/commands/slowlog-reset """ - return self.execute_command('SLOWLOG RESET') + return self.execute_command("SLOWLOG RESET") def time(self): """ @@ -868,7 +895,7 @@ def time(self): For more information check https://redis.io/commands/time """ - return self.execute_command('TIME') + return self.execute_command("TIME") def wait(self, num_replicas, timeout): """ @@ -879,13 +906,14 @@ def wait(self, num_replicas, timeout): For more information check https://redis.io/commands/wait """ - return self.execute_command('WAIT', num_replicas, timeout) + return self.execute_command("WAIT", num_replicas, timeout) class BasicKeyCommands: """ Redis basic key-based commands """ + def append(self, key, value): """ Appends the string ``value`` to the value at ``key``. If ``key`` @@ -894,7 +922,7 @@ def append(self, key, value): For more information check https://redis.io/commands/append """ - return self.execute_command('APPEND', key, value) + return self.execute_command("APPEND", key, value) def bitcount(self, key, start=None, end=None): """ @@ -907,10 +935,9 @@ def bitcount(self, key, start=None, end=None): if start is not None and end is not None: params.append(start) params.append(end) - elif (start is not None and end is None) or \ - (end is not None and start is None): + elif (start is not None and end is None) or (end is not None and start is None): raise DataError("Both start and end must be specified") - return self.execute_command('BITCOUNT', *params) + return self.execute_command("BITCOUNT", *params) def bitfield(self, key, default_overflow=None): """ @@ -928,7 +955,7 @@ def bitop(self, operation, dest, *keys): For more information check https://redis.io/commands/bitop """ - return self.execute_command('BITOP', operation, dest, *keys) + return self.execute_command("BITOP", operation, dest, *keys) def bitpos(self, key, bit, start=None, end=None): """ @@ -940,7 +967,7 @@ def bitpos(self, key, bit, start=None, end=None): For more information check https://redis.io/commands/bitpos """ if bit not in (0, 1): - raise DataError('bit must be 0 or 1') + raise DataError("bit must be 0 or 1") params = [key, bit] start is not None and params.append(start) @@ -948,9 +975,8 @@ def bitpos(self, key, bit, start=None, end=None): if start is not None and end is not None: params.append(end) elif start is None and end is not None: - raise DataError("start argument is not set, " - "when end is specified") - return self.execute_command('BITPOS', *params) + raise DataError("start argument is not set, " "when end is specified") + return self.execute_command("BITPOS", *params) def copy(self, source, destination, destination_db=None, replace=False): """ @@ -970,7 +996,7 @@ def copy(self, source, destination, destination_db=None, replace=False): params.extend(["DB", destination_db]) if replace: params.append("REPLACE") - return self.execute_command('COPY', *params) + return self.execute_command("COPY", *params) def decr(self, name, amount=1): """ @@ -990,13 +1016,13 @@ def decrby(self, name, amount=1): For more information check https://redis.io/commands/decrby """ - return self.execute_command('DECRBY', name, amount) + return self.execute_command("DECRBY", name, amount) def delete(self, *names): """ Delete one or more keys specified by ``names`` """ - return self.execute_command('DEL', *names) + return self.execute_command("DEL", *names) def __delitem__(self, name): self.delete(name) @@ -1009,9 +1035,10 @@ def dump(self, name): For more information check https://redis.io/commands/dump """ from redis.client import NEVER_DECODE + options = {} options[NEVER_DECODE] = [] - return self.execute_command('DUMP', name, **options) + return self.execute_command("DUMP", name, **options) def exists(self, *names): """ @@ -1019,7 +1046,8 @@ def exists(self, *names): For more information check https://redis.io/commands/exists """ - return self.execute_command('EXISTS', *names) + return self.execute_command("EXISTS", *names) + __contains__ = exists def expire(self, name, time): @@ -1031,7 +1059,7 @@ def expire(self, name, time): """ if isinstance(time, datetime.timedelta): time = int(time.total_seconds()) - return self.execute_command('EXPIRE', name, time) + return self.execute_command("EXPIRE", name, time) def expireat(self, name, when): """ @@ -1042,7 +1070,7 @@ def expireat(self, name, when): """ if isinstance(when, datetime.datetime): when = int(time.mktime(when.timetuple())) - return self.execute_command('EXPIREAT', name, when) + return self.execute_command("EXPIREAT", name, when) def get(self, name): """ @@ -1050,7 +1078,7 @@ def get(self, name): For more information check https://redis.io/commands/get """ - return self.execute_command('GET', name) + return self.execute_command("GET", name) def getdel(self, name): """ @@ -1061,10 +1089,9 @@ def getdel(self, name): For more information check https://redis.io/commands/getdel """ - return self.execute_command('GETDEL', name) + return self.execute_command("GETDEL", name) - def getex(self, name, - ex=None, px=None, exat=None, pxat=None, persist=False): + def getex(self, name, ex=None, px=None, exat=None, pxat=None, persist=False): """ Get the value of key and optionally set its expiration. GETEX is similar to GET, but is a write command with @@ -1088,38 +1115,40 @@ def getex(self, name, opset = {ex, px, exat, pxat} if len(opset) > 2 or len(opset) > 1 and persist: - raise DataError("``ex``, ``px``, ``exat``, ``pxat``, " - "and ``persist`` are mutually exclusive.") + raise DataError( + "``ex``, ``px``, ``exat``, ``pxat``, " + "and ``persist`` are mutually exclusive." + ) pieces = [] # similar to set command if ex is not None: - pieces.append('EX') + pieces.append("EX") if isinstance(ex, datetime.timedelta): ex = int(ex.total_seconds()) pieces.append(ex) if px is not None: - pieces.append('PX') + pieces.append("PX") if isinstance(px, datetime.timedelta): px = int(px.total_seconds() * 1000) pieces.append(px) # similar to pexpireat command if exat is not None: - pieces.append('EXAT') + pieces.append("EXAT") if isinstance(exat, datetime.datetime): s = int(exat.microsecond / 1000000) exat = int(time.mktime(exat.timetuple())) + s pieces.append(exat) if pxat is not None: - pieces.append('PXAT') + pieces.append("PXAT") if isinstance(pxat, datetime.datetime): ms = int(pxat.microsecond / 1000) pxat = int(time.mktime(pxat.timetuple())) * 1000 + ms pieces.append(pxat) if persist: - pieces.append('PERSIST') + pieces.append("PERSIST") - return self.execute_command('GETEX', name, *pieces) + return self.execute_command("GETEX", name, *pieces) def __getitem__(self, name): """ @@ -1137,7 +1166,7 @@ def getbit(self, name, offset): For more information check https://redis.io/commands/getbit """ - return self.execute_command('GETBIT', name, offset) + return self.execute_command("GETBIT", name, offset) def getrange(self, key, start, end): """ @@ -1146,7 +1175,7 @@ def getrange(self, key, start, end): For more information check https://redis.io/commands/getrange """ - return self.execute_command('GETRANGE', key, start, end) + return self.execute_command("GETRANGE", key, start, end) def getset(self, name, value): """ @@ -1158,7 +1187,7 @@ def getset(self, name, value): For more information check https://redis.io/commands/getset """ - return self.execute_command('GETSET', name, value) + return self.execute_command("GETSET", name, value) def incr(self, name, amount=1): """ @@ -1178,7 +1207,7 @@ def incrby(self, name, amount=1): """ # An alias for ``incr()``, because it is already implemented # as INCRBY redis command. - return self.execute_command('INCRBY', name, amount) + return self.execute_command("INCRBY", name, amount) def incrbyfloat(self, name, amount=1.0): """ @@ -1187,15 +1216,15 @@ def incrbyfloat(self, name, amount=1.0): For more information check https://redis.io/commands/incrbyfloat """ - return self.execute_command('INCRBYFLOAT', name, amount) + return self.execute_command("INCRBYFLOAT", name, amount) - def keys(self, pattern='*'): + def keys(self, pattern="*"): """ Returns a list of keys matching ``pattern`` For more information check https://redis.io/commands/keys """ - return self.execute_command('KEYS', pattern) + return self.execute_command("KEYS", pattern) def lmove(self, first_list, second_list, src="LEFT", dest="RIGHT"): """ @@ -1208,8 +1237,7 @@ def lmove(self, first_list, second_list, src="LEFT", dest="RIGHT"): params = [first_list, second_list, src, dest] return self.execute_command("LMOVE", *params) - def blmove(self, first_list, second_list, timeout, - src="LEFT", dest="RIGHT"): + def blmove(self, first_list, second_list, timeout, src="LEFT", dest="RIGHT"): """ Blocking version of lmove. @@ -1225,11 +1253,12 @@ def mget(self, keys, *args): For more information check https://redis.io/commands/mget """ from redis.client import EMPTY_RESPONSE + args = list_or_args(keys, args) options = {} if not args: options[EMPTY_RESPONSE] = [] - return self.execute_command('MGET', *args, **options) + return self.execute_command("MGET", *args, **options) def mset(self, mapping): """ @@ -1242,7 +1271,7 @@ def mset(self, mapping): items = [] for pair in mapping.items(): items.extend(pair) - return self.execute_command('MSET', *items) + return self.execute_command("MSET", *items) def msetnx(self, mapping): """ @@ -1256,7 +1285,7 @@ def msetnx(self, mapping): items = [] for pair in mapping.items(): items.extend(pair) - return self.execute_command('MSETNX', *items) + return self.execute_command("MSETNX", *items) def move(self, name, db): """ @@ -1264,7 +1293,7 @@ def move(self, name, db): For more information check https://redis.io/commands/move """ - return self.execute_command('MOVE', name, db) + return self.execute_command("MOVE", name, db) def persist(self, name): """ @@ -1272,7 +1301,7 @@ def persist(self, name): For more information check https://redis.io/commands/persist """ - return self.execute_command('PERSIST', name) + return self.execute_command("PERSIST", name) def pexpire(self, name, time): """ @@ -1284,7 +1313,7 @@ def pexpire(self, name, time): """ if isinstance(time, datetime.timedelta): time = int(time.total_seconds() * 1000) - return self.execute_command('PEXPIRE', name, time) + return self.execute_command("PEXPIRE", name, time) def pexpireat(self, name, when): """ @@ -1297,7 +1326,7 @@ def pexpireat(self, name, when): if isinstance(when, datetime.datetime): ms = int(when.microsecond / 1000) when = int(time.mktime(when.timetuple())) * 1000 + ms - return self.execute_command('PEXPIREAT', name, when) + return self.execute_command("PEXPIREAT", name, when) def psetex(self, name, time_ms, value): """ @@ -1309,7 +1338,7 @@ def psetex(self, name, time_ms, value): """ if isinstance(time_ms, datetime.timedelta): time_ms = int(time_ms.total_seconds() * 1000) - return self.execute_command('PSETEX', name, time_ms, value) + return self.execute_command("PSETEX", name, time_ms, value) def pttl(self, name): """ @@ -1317,7 +1346,7 @@ def pttl(self, name): For more information check https://redis.io/commands/pttl """ - return self.execute_command('PTTL', name) + return self.execute_command("PTTL", name) def hrandfield(self, key, count=None, withvalues=False): """ @@ -1347,7 +1376,7 @@ def randomkey(self): For more information check https://redis.io/commands/randomkey """ - return self.execute_command('RANDOMKEY') + return self.execute_command("RANDOMKEY") def rename(self, src, dst): """ @@ -1355,7 +1384,7 @@ def rename(self, src, dst): For more information check https://redis.io/commands/rename """ - return self.execute_command('RENAME', src, dst) + return self.execute_command("RENAME", src, dst) def renamenx(self, src, dst): """ @@ -1363,10 +1392,18 @@ def renamenx(self, src, dst): For more information check https://redis.io/commands/renamenx """ - return self.execute_command('RENAMENX', src, dst) + return self.execute_command("RENAMENX", src, dst) - def restore(self, name, ttl, value, replace=False, absttl=False, - idletime=None, frequency=None): + def restore( + self, + name, + ttl, + value, + replace=False, + absttl=False, + idletime=None, + frequency=None, + ): """ Create a key using the provided serialized value, previously obtained using DUMP. @@ -1388,28 +1425,38 @@ def restore(self, name, ttl, value, replace=False, absttl=False, """ params = [name, ttl, value] if replace: - params.append('REPLACE') + params.append("REPLACE") if absttl: - params.append('ABSTTL') + params.append("ABSTTL") if idletime is not None: - params.append('IDLETIME') + params.append("IDLETIME") try: params.append(int(idletime)) except ValueError: raise DataError("idletimemust be an integer") if frequency is not None: - params.append('FREQ') + params.append("FREQ") try: params.append(int(frequency)) except ValueError: raise DataError("frequency must be an integer") - return self.execute_command('RESTORE', *params) - - def set(self, name, value, - ex=None, px=None, nx=False, xx=False, keepttl=False, get=False, - exat=None, pxat=None): + return self.execute_command("RESTORE", *params) + + def set( + self, + name, + value, + ex=None, + px=None, + nx=False, + xx=False, + keepttl=False, + get=False, + exat=None, + pxat=None, + ): """ Set the value at key ``name`` to ``value`` @@ -1441,7 +1488,7 @@ def set(self, name, value, pieces = [name, value] options = {} if ex is not None: - pieces.append('EX') + pieces.append("EX") if isinstance(ex, datetime.timedelta): pieces.append(int(ex.total_seconds())) elif isinstance(ex, int): @@ -1449,7 +1496,7 @@ def set(self, name, value, else: raise DataError("ex must be datetime.timedelta or int") if px is not None: - pieces.append('PX') + pieces.append("PX") if isinstance(px, datetime.timedelta): pieces.append(int(px.total_seconds() * 1000)) elif isinstance(px, int): @@ -1457,30 +1504,30 @@ def set(self, name, value, else: raise DataError("px must be datetime.timedelta or int") if exat is not None: - pieces.append('EXAT') + pieces.append("EXAT") if isinstance(exat, datetime.datetime): s = int(exat.microsecond / 1000000) exat = int(time.mktime(exat.timetuple())) + s pieces.append(exat) if pxat is not None: - pieces.append('PXAT') + pieces.append("PXAT") if isinstance(pxat, datetime.datetime): ms = int(pxat.microsecond / 1000) pxat = int(time.mktime(pxat.timetuple())) * 1000 + ms pieces.append(pxat) if keepttl: - pieces.append('KEEPTTL') + pieces.append("KEEPTTL") if nx: - pieces.append('NX') + pieces.append("NX") if xx: - pieces.append('XX') + pieces.append("XX") if get: - pieces.append('GET') + pieces.append("GET") options["get"] = True - return self.execute_command('SET', *pieces, **options) + return self.execute_command("SET", *pieces, **options) def __setitem__(self, name, value): self.set(name, value) @@ -1493,7 +1540,7 @@ def setbit(self, name, offset, value): For more information check https://redis.io/commands/setbit """ value = value and 1 or 0 - return self.execute_command('SETBIT', name, offset, value) + return self.execute_command("SETBIT", name, offset, value) def setex(self, name, time, value): """ @@ -1505,7 +1552,7 @@ def setex(self, name, time, value): """ if isinstance(time, datetime.timedelta): time = int(time.total_seconds()) - return self.execute_command('SETEX', name, time, value) + return self.execute_command("SETEX", name, time, value) def setnx(self, name, value): """ @@ -1513,7 +1560,7 @@ def setnx(self, name, value): For more information check https://redis.io/commands/setnx """ - return self.execute_command('SETNX', name, value) + return self.execute_command("SETNX", name, value) def setrange(self, name, offset, value): """ @@ -1528,10 +1575,19 @@ def setrange(self, name, offset, value): For more information check https://redis.io/commands/setrange """ - return self.execute_command('SETRANGE', name, offset, value) + return self.execute_command("SETRANGE", name, offset, value) - def stralgo(self, algo, value1, value2, specific_argument='strings', - len=False, idx=False, minmatchlen=None, withmatchlen=False): + def stralgo( + self, + algo, + value1, + value2, + specific_argument="strings", + len=False, + idx=False, + minmatchlen=None, + withmatchlen=False, + ): """ Implements complex algorithms that operate on strings. Right now the only algorithm implemented is the LCS algorithm @@ -1552,31 +1608,36 @@ def stralgo(self, algo, value1, value2, specific_argument='strings', For more information check https://redis.io/commands/stralgo """ # check validity - supported_algo = ['LCS'] + supported_algo = ["LCS"] if algo not in supported_algo: - supported_algos_str = ', '.join(supported_algo) + supported_algos_str = ", ".join(supported_algo) raise DataError(f"The supported algorithms are: {supported_algos_str}") - if specific_argument not in ['keys', 'strings']: + if specific_argument not in ["keys", "strings"]: raise DataError("specific_argument can be only keys or strings") if len and idx: raise DataError("len and idx cannot be provided together.") pieces = [algo, specific_argument.upper(), value1, value2] if len: - pieces.append(b'LEN') + pieces.append(b"LEN") if idx: - pieces.append(b'IDX') + pieces.append(b"IDX") try: int(minmatchlen) - pieces.extend([b'MINMATCHLEN', minmatchlen]) + pieces.extend([b"MINMATCHLEN", minmatchlen]) except TypeError: pass if withmatchlen: - pieces.append(b'WITHMATCHLEN') - - return self.execute_command('STRALGO', *pieces, len=len, idx=idx, - minmatchlen=minmatchlen, - withmatchlen=withmatchlen) + pieces.append(b"WITHMATCHLEN") + + return self.execute_command( + "STRALGO", + *pieces, + len=len, + idx=idx, + minmatchlen=minmatchlen, + withmatchlen=withmatchlen, + ) def strlen(self, name): """ @@ -1584,14 +1645,14 @@ def strlen(self, name): For more information check https://redis.io/commands/strlen """ - return self.execute_command('STRLEN', name) + return self.execute_command("STRLEN", name) def substr(self, name, start, end=-1): """ Return a substring of the string at key ``name``. ``start`` and ``end`` are 0-based integers specifying the portion of the string to return. """ - return self.execute_command('SUBSTR', name, start, end) + return self.execute_command("SUBSTR", name, start, end) def touch(self, *args): """ @@ -1600,7 +1661,7 @@ def touch(self, *args): For more information check https://redis.io/commands/touch """ - return self.execute_command('TOUCH', *args) + return self.execute_command("TOUCH", *args) def ttl(self, name): """ @@ -1608,7 +1669,7 @@ def ttl(self, name): For more information check https://redis.io/commands/ttl """ - return self.execute_command('TTL', name) + return self.execute_command("TTL", name) def type(self, name): """ @@ -1616,7 +1677,7 @@ def type(self, name): For more information check https://redis.io/commands/type """ - return self.execute_command('TYPE', name) + return self.execute_command("TYPE", name) def watch(self, *names): """ @@ -1624,7 +1685,7 @@ def watch(self, *names): For more information check https://redis.io/commands/type """ - warnings.warn(DeprecationWarning('Call WATCH from a Pipeline object')) + warnings.warn(DeprecationWarning("Call WATCH from a Pipeline object")) def unwatch(self): """ @@ -1632,8 +1693,7 @@ def unwatch(self): For more information check https://redis.io/commands/unwatch """ - warnings.warn( - DeprecationWarning('Call UNWATCH from a Pipeline object')) + warnings.warn(DeprecationWarning("Call UNWATCH from a Pipeline object")) def unlink(self, *names): """ @@ -1641,7 +1701,7 @@ def unlink(self, *names): For more information check https://redis.io/commands/unlink """ - return self.execute_command('UNLINK', *names) + return self.execute_command("UNLINK", *names) class ListCommands: @@ -1649,6 +1709,7 @@ class ListCommands: Redis commands for List data type. see: https://redis.io/topics/data-types#lists """ + def blpop(self, keys, timeout=0): """ LPOP a value off of the first non-empty list @@ -1666,7 +1727,7 @@ def blpop(self, keys, timeout=0): timeout = 0 keys = list_or_args(keys, None) keys.append(timeout) - return self.execute_command('BLPOP', *keys) + return self.execute_command("BLPOP", *keys) def brpop(self, keys, timeout=0): """ @@ -1685,7 +1746,7 @@ def brpop(self, keys, timeout=0): timeout = 0 keys = list_or_args(keys, None) keys.append(timeout) - return self.execute_command('BRPOP', *keys) + return self.execute_command("BRPOP", *keys) def brpoplpush(self, src, dst, timeout=0): """ @@ -1700,7 +1761,7 @@ def brpoplpush(self, src, dst, timeout=0): """ if timeout is None: timeout = 0 - return self.execute_command('BRPOPLPUSH', src, dst, timeout) + return self.execute_command("BRPOPLPUSH", src, dst, timeout) def lindex(self, name, index): """ @@ -1711,7 +1772,7 @@ def lindex(self, name, index): For more information check https://redis.io/commands/lindex """ - return self.execute_command('LINDEX', name, index) + return self.execute_command("LINDEX", name, index) def linsert(self, name, where, refvalue, value): """ @@ -1723,7 +1784,7 @@ def linsert(self, name, where, refvalue, value): For more information check https://redis.io/commands/linsert """ - return self.execute_command('LINSERT', name, where, refvalue, value) + return self.execute_command("LINSERT", name, where, refvalue, value) def llen(self, name): """ @@ -1731,7 +1792,7 @@ def llen(self, name): For more information check https://redis.io/commands/llen """ - return self.execute_command('LLEN', name) + return self.execute_command("LLEN", name) def lpop(self, name, count=None): """ @@ -1744,9 +1805,9 @@ def lpop(self, name, count=None): For more information check https://redis.io/commands/lpop """ if count is not None: - return self.execute_command('LPOP', name, count) + return self.execute_command("LPOP", name, count) else: - return self.execute_command('LPOP', name) + return self.execute_command("LPOP", name) def lpush(self, name, *values): """ @@ -1754,7 +1815,7 @@ def lpush(self, name, *values): For more information check https://redis.io/commands/lpush """ - return self.execute_command('LPUSH', name, *values) + return self.execute_command("LPUSH", name, *values) def lpushx(self, name, *values): """ @@ -1762,7 +1823,7 @@ def lpushx(self, name, *values): For more information check https://redis.io/commands/lpushx """ - return self.execute_command('LPUSHX', name, *values) + return self.execute_command("LPUSHX", name, *values) def lrange(self, name, start, end): """ @@ -1774,7 +1835,7 @@ def lrange(self, name, start, end): For more information check https://redis.io/commands/lrange """ - return self.execute_command('LRANGE', name, start, end) + return self.execute_command("LRANGE", name, start, end) def lrem(self, name, count, value): """ @@ -1788,7 +1849,7 @@ def lrem(self, name, count, value): For more information check https://redis.io/commands/lrem """ - return self.execute_command('LREM', name, count, value) + return self.execute_command("LREM", name, count, value) def lset(self, name, index, value): """ @@ -1796,7 +1857,7 @@ def lset(self, name, index, value): For more information check https://redis.io/commands/lset """ - return self.execute_command('LSET', name, index, value) + return self.execute_command("LSET", name, index, value) def ltrim(self, name, start, end): """ @@ -1808,7 +1869,7 @@ def ltrim(self, name, start, end): For more information check https://redis.io/commands/ltrim """ - return self.execute_command('LTRIM', name, start, end) + return self.execute_command("LTRIM", name, start, end) def rpop(self, name, count=None): """ @@ -1821,9 +1882,9 @@ def rpop(self, name, count=None): For more information check https://redis.io/commands/rpop """ if count is not None: - return self.execute_command('RPOP', name, count) + return self.execute_command("RPOP", name, count) else: - return self.execute_command('RPOP', name) + return self.execute_command("RPOP", name) def rpoplpush(self, src, dst): """ @@ -1832,7 +1893,7 @@ def rpoplpush(self, src, dst): For more information check https://redis.io/commands/rpoplpush """ - return self.execute_command('RPOPLPUSH', src, dst) + return self.execute_command("RPOPLPUSH", src, dst) def rpush(self, name, *values): """ @@ -1840,7 +1901,7 @@ def rpush(self, name, *values): For more information check https://redis.io/commands/rpush """ - return self.execute_command('RPUSH', name, *values) + return self.execute_command("RPUSH", name, *values) def rpushx(self, name, value): """ @@ -1848,7 +1909,7 @@ def rpushx(self, name, value): For more information check https://redis.io/commands/rpushx """ - return self.execute_command('RPUSHX', name, value) + return self.execute_command("RPUSHX", name, value) def lpos(self, name, value, rank=None, count=None, maxlen=None): """ @@ -1878,18 +1939,28 @@ def lpos(self, name, value, rank=None, count=None, maxlen=None): """ pieces = [name, value] if rank is not None: - pieces.extend(['RANK', rank]) + pieces.extend(["RANK", rank]) if count is not None: - pieces.extend(['COUNT', count]) + pieces.extend(["COUNT", count]) if maxlen is not None: - pieces.extend(['MAXLEN', maxlen]) - - return self.execute_command('LPOS', *pieces) - - def sort(self, name, start=None, num=None, by=None, get=None, - desc=False, alpha=False, store=None, groups=False): + pieces.extend(["MAXLEN", maxlen]) + + return self.execute_command("LPOS", *pieces) + + def sort( + self, + name, + start=None, + num=None, + by=None, + get=None, + desc=False, + alpha=False, + store=None, + groups=False, + ): """ Sort and return the list, set or sorted set at ``name``. @@ -1915,39 +1986,40 @@ def sort(self, name, start=None, num=None, by=None, get=None, For more information check https://redis.io/commands/sort """ - if (start is not None and num is None) or \ - (num is not None and start is None): + if (start is not None and num is None) or (num is not None and start is None): raise DataError("``start`` and ``num`` must both be specified") pieces = [name] if by is not None: - pieces.extend([b'BY', by]) + pieces.extend([b"BY", by]) if start is not None and num is not None: - pieces.extend([b'LIMIT', start, num]) + pieces.extend([b"LIMIT", start, num]) if get is not None: # If get is a string assume we want to get a single value. # Otherwise assume it's an interable and we want to get multiple # values. We can't just iterate blindly because strings are # iterable. if isinstance(get, (bytes, str)): - pieces.extend([b'GET', get]) + pieces.extend([b"GET", get]) else: for g in get: - pieces.extend([b'GET', g]) + pieces.extend([b"GET", g]) if desc: - pieces.append(b'DESC') + pieces.append(b"DESC") if alpha: - pieces.append(b'ALPHA') + pieces.append(b"ALPHA") if store is not None: - pieces.extend([b'STORE', store]) + pieces.extend([b"STORE", store]) if groups: if not get or isinstance(get, (bytes, str)) or len(get) < 2: - raise DataError('when using "groups" the "get" argument ' - 'must be specified and contain at least ' - 'two keys') + raise DataError( + 'when using "groups" the "get" argument ' + "must be specified and contain at least " + "two keys" + ) - options = {'groups': len(get) if groups else None} - return self.execute_command('SORT', *pieces, **options) + options = {"groups": len(get) if groups else None} + return self.execute_command("SORT", *pieces, **options) class ScanCommands: @@ -1955,6 +2027,7 @@ class ScanCommands: Redis SCAN commands. see: https://redis.io/commands/scan """ + def scan(self, cursor=0, match=None, count=None, _type=None): """ Incrementally return lists of key names. Also return a cursor @@ -1974,12 +2047,12 @@ def scan(self, cursor=0, match=None, count=None, _type=None): """ pieces = [cursor] if match is not None: - pieces.extend([b'MATCH', match]) + pieces.extend([b"MATCH", match]) if count is not None: - pieces.extend([b'COUNT', count]) + pieces.extend([b"COUNT", count]) if _type is not None: - pieces.extend([b'TYPE', _type]) - return self.execute_command('SCAN', *pieces) + pieces.extend([b"TYPE", _type]) + return self.execute_command("SCAN", *pieces) def scan_iter(self, match=None, count=None, _type=None): """ @@ -1996,10 +2069,11 @@ def scan_iter(self, match=None, count=None, _type=None): HASH, LIST, SET, STREAM, STRING, ZSET Additionally, Redis modules can expose other types as well. """ - cursor = '0' + cursor = "0" while cursor != 0: - cursor, data = self.scan(cursor=cursor, match=match, - count=count, _type=_type) + cursor, data = self.scan( + cursor=cursor, match=match, count=count, _type=_type + ) yield from data def sscan(self, name, cursor=0, match=None, count=None): @@ -2015,10 +2089,10 @@ def sscan(self, name, cursor=0, match=None, count=None): """ pieces = [name, cursor] if match is not None: - pieces.extend([b'MATCH', match]) + pieces.extend([b"MATCH", match]) if count is not None: - pieces.extend([b'COUNT', count]) - return self.execute_command('SSCAN', *pieces) + pieces.extend([b"COUNT", count]) + return self.execute_command("SSCAN", *pieces) def sscan_iter(self, name, match=None, count=None): """ @@ -2029,10 +2103,9 @@ def sscan_iter(self, name, match=None, count=None): ``count`` allows for hint the minimum number of returns """ - cursor = '0' + cursor = "0" while cursor != 0: - cursor, data = self.sscan(name, cursor=cursor, - match=match, count=count) + cursor, data = self.sscan(name, cursor=cursor, match=match, count=count) yield from data def hscan(self, name, cursor=0, match=None, count=None): @@ -2048,10 +2121,10 @@ def hscan(self, name, cursor=0, match=None, count=None): """ pieces = [name, cursor] if match is not None: - pieces.extend([b'MATCH', match]) + pieces.extend([b"MATCH", match]) if count is not None: - pieces.extend([b'COUNT', count]) - return self.execute_command('HSCAN', *pieces) + pieces.extend([b"COUNT", count]) + return self.execute_command("HSCAN", *pieces) def hscan_iter(self, name, match=None, count=None): """ @@ -2062,14 +2135,12 @@ def hscan_iter(self, name, match=None, count=None): ``count`` allows for hint the minimum number of returns """ - cursor = '0' + cursor = "0" while cursor != 0: - cursor, data = self.hscan(name, cursor=cursor, - match=match, count=count) + cursor, data = self.hscan(name, cursor=cursor, match=match, count=count) yield from data.items() - def zscan(self, name, cursor=0, match=None, count=None, - score_cast_func=float): + def zscan(self, name, cursor=0, match=None, count=None, score_cast_func=float): """ Incrementally return lists of elements in a sorted set. Also return a cursor indicating the scan position. @@ -2084,14 +2155,13 @@ def zscan(self, name, cursor=0, match=None, count=None, """ pieces = [name, cursor] if match is not None: - pieces.extend([b'MATCH', match]) + pieces.extend([b"MATCH", match]) if count is not None: - pieces.extend([b'COUNT', count]) - options = {'score_cast_func': score_cast_func} - return self.execute_command('ZSCAN', *pieces, **options) + pieces.extend([b"COUNT", count]) + options = {"score_cast_func": score_cast_func} + return self.execute_command("ZSCAN", *pieces, **options) - def zscan_iter(self, name, match=None, count=None, - score_cast_func=float): + def zscan_iter(self, name, match=None, count=None, score_cast_func=float): """ Make an iterator using the ZSCAN command so that the client doesn't need to remember the cursor position. @@ -2102,11 +2172,15 @@ def zscan_iter(self, name, match=None, count=None, ``score_cast_func`` a callable used to cast the score return value """ - cursor = '0' + cursor = "0" while cursor != 0: - cursor, data = self.zscan(name, cursor=cursor, match=match, - count=count, - score_cast_func=score_cast_func) + cursor, data = self.zscan( + name, + cursor=cursor, + match=match, + count=count, + score_cast_func=score_cast_func, + ) yield from data @@ -2115,13 +2189,14 @@ class SetCommands: Redis commands for Set data type. see: https://redis.io/topics/data-types#sets """ + def sadd(self, name, *values): """ Add ``value(s)`` to set ``name`` For more information check https://redis.io/commands/sadd """ - return self.execute_command('SADD', name, *values) + return self.execute_command("SADD", name, *values) def scard(self, name): """ @@ -2129,7 +2204,7 @@ def scard(self, name): For more information check https://redis.io/commands/scard """ - return self.execute_command('SCARD', name) + return self.execute_command("SCARD", name) def sdiff(self, keys, *args): """ @@ -2138,7 +2213,7 @@ def sdiff(self, keys, *args): For more information check https://redis.io/commands/sdiff """ args = list_or_args(keys, args) - return self.execute_command('SDIFF', *args) + return self.execute_command("SDIFF", *args) def sdiffstore(self, dest, keys, *args): """ @@ -2148,7 +2223,7 @@ def sdiffstore(self, dest, keys, *args): For more information check https://redis.io/commands/sdiffstore """ args = list_or_args(keys, args) - return self.execute_command('SDIFFSTORE', dest, *args) + return self.execute_command("SDIFFSTORE", dest, *args) def sinter(self, keys, *args): """ @@ -2157,7 +2232,7 @@ def sinter(self, keys, *args): For more information check https://redis.io/commands/sinter """ args = list_or_args(keys, args) - return self.execute_command('SINTER', *args) + return self.execute_command("SINTER", *args) def sinterstore(self, dest, keys, *args): """ @@ -2167,7 +2242,7 @@ def sinterstore(self, dest, keys, *args): For more information check https://redis.io/commands/sinterstore """ args = list_or_args(keys, args) - return self.execute_command('SINTERSTORE', dest, *args) + return self.execute_command("SINTERSTORE", dest, *args) def sismember(self, name, value): """ @@ -2175,7 +2250,7 @@ def sismember(self, name, value): For more information check https://redis.io/commands/sismember """ - return self.execute_command('SISMEMBER', name, value) + return self.execute_command("SISMEMBER", name, value) def smembers(self, name): """ @@ -2183,7 +2258,7 @@ def smembers(self, name): For more information check https://redis.io/commands/smembers """ - return self.execute_command('SMEMBERS', name) + return self.execute_command("SMEMBERS", name) def smismember(self, name, values, *args): """ @@ -2193,7 +2268,7 @@ def smismember(self, name, values, *args): For more information check https://redis.io/commands/smismember """ args = list_or_args(values, args) - return self.execute_command('SMISMEMBER', name, *args) + return self.execute_command("SMISMEMBER", name, *args) def smove(self, src, dst, value): """ @@ -2201,7 +2276,7 @@ def smove(self, src, dst, value): For more information check https://redis.io/commands/smove """ - return self.execute_command('SMOVE', src, dst, value) + return self.execute_command("SMOVE", src, dst, value) def spop(self, name, count=None): """ @@ -2210,7 +2285,7 @@ def spop(self, name, count=None): For more information check https://redis.io/commands/spop """ args = (count is not None) and [count] or [] - return self.execute_command('SPOP', name, *args) + return self.execute_command("SPOP", name, *args) def srandmember(self, name, number=None): """ @@ -2223,7 +2298,7 @@ def srandmember(self, name, number=None): For more information check https://redis.io/commands/srandmember """ args = (number is not None) and [number] or [] - return self.execute_command('SRANDMEMBER', name, *args) + return self.execute_command("SRANDMEMBER", name, *args) def srem(self, name, *values): """ @@ -2231,7 +2306,7 @@ def srem(self, name, *values): For more information check https://redis.io/commands/srem """ - return self.execute_command('SREM', name, *values) + return self.execute_command("SREM", name, *values) def sunion(self, keys, *args): """ @@ -2240,7 +2315,7 @@ def sunion(self, keys, *args): For more information check https://redis.io/commands/sunion """ args = list_or_args(keys, args) - return self.execute_command('SUNION', *args) + return self.execute_command("SUNION", *args) def sunionstore(self, dest, keys, *args): """ @@ -2250,7 +2325,7 @@ def sunionstore(self, dest, keys, *args): For more information check https://redis.io/commands/sunionstore """ args = list_or_args(keys, args) - return self.execute_command('SUNIONSTORE', dest, *args) + return self.execute_command("SUNIONSTORE", dest, *args) class StreamCommands: @@ -2258,6 +2333,7 @@ class StreamCommands: Redis commands for Stream data type. see: https://redis.io/topics/streams-intro """ + def xack(self, name, groupname, *ids): """ Acknowledges the successful processing of one or more messages. @@ -2267,10 +2343,19 @@ def xack(self, name, groupname, *ids): For more information check https://redis.io/commands/xack """ - return self.execute_command('XACK', name, groupname, *ids) + return self.execute_command("XACK", name, groupname, *ids) - def xadd(self, name, fields, id='*', maxlen=None, approximate=True, - nomkstream=False, minid=None, limit=None): + def xadd( + self, + name, + fields, + id="*", + maxlen=None, + approximate=True, + nomkstream=False, + minid=None, + limit=None, + ): """ Add to a stream. name: name of the stream @@ -2288,34 +2373,43 @@ def xadd(self, name, fields, id='*', maxlen=None, approximate=True, """ pieces = [] if maxlen is not None and minid is not None: - raise DataError("Only one of ```maxlen``` or ```minid``` " - "may be specified") + raise DataError( + "Only one of ```maxlen``` or ```minid``` " "may be specified" + ) if maxlen is not None: if not isinstance(maxlen, int) or maxlen < 1: - raise DataError('XADD maxlen must be a positive integer') - pieces.append(b'MAXLEN') + raise DataError("XADD maxlen must be a positive integer") + pieces.append(b"MAXLEN") if approximate: - pieces.append(b'~') + pieces.append(b"~") pieces.append(str(maxlen)) if minid is not None: - pieces.append(b'MINID') + pieces.append(b"MINID") if approximate: - pieces.append(b'~') + pieces.append(b"~") pieces.append(minid) if limit is not None: - pieces.extend([b'LIMIT', limit]) + pieces.extend([b"LIMIT", limit]) if nomkstream: - pieces.append(b'NOMKSTREAM') + pieces.append(b"NOMKSTREAM") pieces.append(id) if not isinstance(fields, dict) or len(fields) == 0: - raise DataError('XADD fields must be a non-empty dict') + raise DataError("XADD fields must be a non-empty dict") for pair in fields.items(): pieces.extend(pair) - return self.execute_command('XADD', name, *pieces) - - def xautoclaim(self, name, groupname, consumername, min_idle_time, - start_id=0, count=None, justid=False): + return self.execute_command("XADD", name, *pieces) + + def xautoclaim( + self, + name, + groupname, + consumername, + min_idle_time, + start_id=0, + count=None, + justid=False, + ): """ Transfers ownership of pending stream entries that match the specified criteria. Conceptually, equivalent to calling XPENDING and then XCLAIM, @@ -2336,8 +2430,9 @@ def xautoclaim(self, name, groupname, consumername, min_idle_time, """ try: if int(min_idle_time) < 0: - raise DataError("XAUTOCLAIM min_idle_time must be a non" - "negative integer") + raise DataError( + "XAUTOCLAIM min_idle_time must be a non" "negative integer" + ) except TypeError: pass @@ -2347,18 +2442,28 @@ def xautoclaim(self, name, groupname, consumername, min_idle_time, try: if int(count) < 0: raise DataError("XPENDING count must be a integer >= 0") - pieces.extend([b'COUNT', count]) + pieces.extend([b"COUNT", count]) except TypeError: pass if justid: - pieces.append(b'JUSTID') - kwargs['parse_justid'] = True - - return self.execute_command('XAUTOCLAIM', *pieces, **kwargs) - - def xclaim(self, name, groupname, consumername, min_idle_time, message_ids, - idle=None, time=None, retrycount=None, force=False, - justid=False): + pieces.append(b"JUSTID") + kwargs["parse_justid"] = True + + return self.execute_command("XAUTOCLAIM", *pieces, **kwargs) + + def xclaim( + self, + name, + groupname, + consumername, + min_idle_time, + message_ids, + idle=None, + time=None, + retrycount=None, + force=False, + justid=False, + ): """ Changes the ownership of a pending message. name: name of the stream. @@ -2384,11 +2489,12 @@ def xclaim(self, name, groupname, consumername, min_idle_time, message_ids, For more information check https://redis.io/commands/xclaim """ if not isinstance(min_idle_time, int) or min_idle_time < 0: - raise DataError("XCLAIM min_idle_time must be a non negative " - "integer") + raise DataError("XCLAIM min_idle_time must be a non negative " "integer") if not isinstance(message_ids, (list, tuple)) or not message_ids: - raise DataError("XCLAIM message_ids must be a non empty list or " - "tuple of message IDs to claim") + raise DataError( + "XCLAIM message_ids must be a non empty list or " + "tuple of message IDs to claim" + ) kwargs = {} pieces = [name, groupname, consumername, str(min_idle_time)] @@ -2397,26 +2503,26 @@ def xclaim(self, name, groupname, consumername, min_idle_time, message_ids, if idle is not None: if not isinstance(idle, int): raise DataError("XCLAIM idle must be an integer") - pieces.extend((b'IDLE', str(idle))) + pieces.extend((b"IDLE", str(idle))) if time is not None: if not isinstance(time, int): raise DataError("XCLAIM time must be an integer") - pieces.extend((b'TIME', str(time))) + pieces.extend((b"TIME", str(time))) if retrycount is not None: if not isinstance(retrycount, int): raise DataError("XCLAIM retrycount must be an integer") - pieces.extend((b'RETRYCOUNT', str(retrycount))) + pieces.extend((b"RETRYCOUNT", str(retrycount))) if force: if not isinstance(force, bool): raise DataError("XCLAIM force must be a boolean") - pieces.append(b'FORCE') + pieces.append(b"FORCE") if justid: if not isinstance(justid, bool): raise DataError("XCLAIM justid must be a boolean") - pieces.append(b'JUSTID') - kwargs['parse_justid'] = True - return self.execute_command('XCLAIM', *pieces, **kwargs) + pieces.append(b"JUSTID") + kwargs["parse_justid"] = True + return self.execute_command("XCLAIM", *pieces, **kwargs) def xdel(self, name, *ids): """ @@ -2426,9 +2532,9 @@ def xdel(self, name, *ids): For more information check https://redis.io/commands/xdel """ - return self.execute_command('XDEL', name, *ids) + return self.execute_command("XDEL", name, *ids) - def xgroup_create(self, name, groupname, id='$', mkstream=False): + def xgroup_create(self, name, groupname, id="$", mkstream=False): """ Create a new consumer group associated with a stream. name: name of the stream. @@ -2437,9 +2543,9 @@ def xgroup_create(self, name, groupname, id='$', mkstream=False): For more information check https://redis.io/commands/xgroup-create """ - pieces = ['XGROUP CREATE', name, groupname, id] + pieces = ["XGROUP CREATE", name, groupname, id] if mkstream: - pieces.append(b'MKSTREAM') + pieces.append(b"MKSTREAM") return self.execute_command(*pieces) def xgroup_delconsumer(self, name, groupname, consumername): @@ -2453,8 +2559,7 @@ def xgroup_delconsumer(self, name, groupname, consumername): For more information check https://redis.io/commands/xgroup-delconsumer """ - return self.execute_command('XGROUP DELCONSUMER', name, groupname, - consumername) + return self.execute_command("XGROUP DELCONSUMER", name, groupname, consumername) def xgroup_destroy(self, name, groupname): """ @@ -2464,7 +2569,7 @@ def xgroup_destroy(self, name, groupname): For more information check https://redis.io/commands/xgroup-destroy """ - return self.execute_command('XGROUP DESTROY', name, groupname) + return self.execute_command("XGROUP DESTROY", name, groupname) def xgroup_createconsumer(self, name, groupname, consumername): """ @@ -2477,8 +2582,9 @@ def xgroup_createconsumer(self, name, groupname, consumername): See: https://redis.io/commands/xgroup-createconsumer """ - return self.execute_command('XGROUP CREATECONSUMER', name, groupname, - consumername) + return self.execute_command( + "XGROUP CREATECONSUMER", name, groupname, consumername + ) def xgroup_setid(self, name, groupname, id): """ @@ -2489,7 +2595,7 @@ def xgroup_setid(self, name, groupname, id): For more information check https://redis.io/commands/xgroup-setid """ - return self.execute_command('XGROUP SETID', name, groupname, id) + return self.execute_command("XGROUP SETID", name, groupname, id) def xinfo_consumers(self, name, groupname): """ @@ -2499,7 +2605,7 @@ def xinfo_consumers(self, name, groupname): For more information check https://redis.io/commands/xinfo-consumers """ - return self.execute_command('XINFO CONSUMERS', name, groupname) + return self.execute_command("XINFO CONSUMERS", name, groupname) def xinfo_groups(self, name): """ @@ -2508,7 +2614,7 @@ def xinfo_groups(self, name): For more information check https://redis.io/commands/xinfo-groups """ - return self.execute_command('XINFO GROUPS', name) + return self.execute_command("XINFO GROUPS", name) def xinfo_stream(self, name, full=False): """ @@ -2521,9 +2627,9 @@ def xinfo_stream(self, name, full=False): pieces = [name] options = {} if full: - pieces.append(b'FULL') - options = {'full': full} - return self.execute_command('XINFO STREAM', *pieces, **options) + pieces.append(b"FULL") + options = {"full": full} + return self.execute_command("XINFO STREAM", *pieces, **options) def xlen(self, name): """ @@ -2531,7 +2637,7 @@ def xlen(self, name): For more information check https://redis.io/commands/xlen """ - return self.execute_command('XLEN', name) + return self.execute_command("XLEN", name) def xpending(self, name, groupname): """ @@ -2541,11 +2647,18 @@ def xpending(self, name, groupname): For more information check https://redis.io/commands/xpending """ - return self.execute_command('XPENDING', name, groupname) + return self.execute_command("XPENDING", name, groupname) - def xpending_range(self, name, groupname, idle=None, - min=None, max=None, count=None, - consumername=None): + def xpending_range( + self, + name, + groupname, + idle=None, + min=None, + max=None, + count=None, + consumername=None, + ): """ Returns information about pending messages, in a range. @@ -2560,20 +2673,24 @@ def xpending_range(self, name, groupname, idle=None, """ if {min, max, count} == {None}: if idle is not None or consumername is not None: - raise DataError("if XPENDING is provided with idle time" - " or consumername, it must be provided" - " with min, max and count parameters") + raise DataError( + "if XPENDING is provided with idle time" + " or consumername, it must be provided" + " with min, max and count parameters" + ) return self.xpending(name, groupname) pieces = [name, groupname] if min is None or max is None or count is None: - raise DataError("XPENDING must be provided with min, max " - "and count parameters, or none of them.") + raise DataError( + "XPENDING must be provided with min, max " + "and count parameters, or none of them." + ) # idle try: if int(idle) < 0: raise DataError("XPENDING idle must be a integer >= 0") - pieces.extend(['IDLE', idle]) + pieces.extend(["IDLE", idle]) except TypeError: pass # count @@ -2587,9 +2704,9 @@ def xpending_range(self, name, groupname, idle=None, if consumername: pieces.append(consumername) - return self.execute_command('XPENDING', *pieces, parse_detail=True) + return self.execute_command("XPENDING", *pieces, parse_detail=True) - def xrange(self, name, min='-', max='+', count=None): + def xrange(self, name, min="-", max="+", count=None): """ Read stream values within an interval. name: name of the stream. @@ -2605,11 +2722,11 @@ def xrange(self, name, min='-', max='+', count=None): pieces = [min, max] if count is not None: if not isinstance(count, int) or count < 1: - raise DataError('XRANGE count must be a positive integer') - pieces.append(b'COUNT') + raise DataError("XRANGE count must be a positive integer") + pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command('XRANGE', name, *pieces) + return self.execute_command("XRANGE", name, *pieces) def xread(self, streams, count=None, block=None): """ @@ -2625,24 +2742,25 @@ def xread(self, streams, count=None, block=None): pieces = [] if block is not None: if not isinstance(block, int) or block < 0: - raise DataError('XREAD block must be a non-negative integer') - pieces.append(b'BLOCK') + raise DataError("XREAD block must be a non-negative integer") + pieces.append(b"BLOCK") pieces.append(str(block)) if count is not None: if not isinstance(count, int) or count < 1: - raise DataError('XREAD count must be a positive integer') - pieces.append(b'COUNT') + raise DataError("XREAD count must be a positive integer") + pieces.append(b"COUNT") pieces.append(str(count)) if not isinstance(streams, dict) or len(streams) == 0: - raise DataError('XREAD streams must be a non empty dict') - pieces.append(b'STREAMS') + raise DataError("XREAD streams must be a non empty dict") + pieces.append(b"STREAMS") keys, values = zip(*streams.items()) pieces.extend(keys) pieces.extend(values) - return self.execute_command('XREAD', *pieces) + return self.execute_command("XREAD", *pieces) - def xreadgroup(self, groupname, consumername, streams, count=None, - block=None, noack=False): + def xreadgroup( + self, groupname, consumername, streams, count=None, block=None, noack=False + ): """ Read from a stream via a consumer group. groupname: name of the consumer group. @@ -2656,28 +2774,27 @@ def xreadgroup(self, groupname, consumername, streams, count=None, For more information check https://redis.io/commands/xreadgroup """ - pieces = [b'GROUP', groupname, consumername] + pieces = [b"GROUP", groupname, consumername] if count is not None: if not isinstance(count, int) or count < 1: raise DataError("XREADGROUP count must be a positive integer") - pieces.append(b'COUNT') + pieces.append(b"COUNT") pieces.append(str(count)) if block is not None: if not isinstance(block, int) or block < 0: - raise DataError("XREADGROUP block must be a non-negative " - "integer") - pieces.append(b'BLOCK') + raise DataError("XREADGROUP block must be a non-negative " "integer") + pieces.append(b"BLOCK") pieces.append(str(block)) if noack: - pieces.append(b'NOACK') + pieces.append(b"NOACK") if not isinstance(streams, dict) or len(streams) == 0: - raise DataError('XREADGROUP streams must be a non empty dict') - pieces.append(b'STREAMS') + raise DataError("XREADGROUP streams must be a non empty dict") + pieces.append(b"STREAMS") pieces.extend(streams.keys()) pieces.extend(streams.values()) - return self.execute_command('XREADGROUP', *pieces) + return self.execute_command("XREADGROUP", *pieces) - def xrevrange(self, name, max='+', min='-', count=None): + def xrevrange(self, name, max="+", min="-", count=None): """ Read stream values within an interval, in reverse order. name: name of the stream @@ -2693,14 +2810,13 @@ def xrevrange(self, name, max='+', min='-', count=None): pieces = [max, min] if count is not None: if not isinstance(count, int) or count < 1: - raise DataError('XREVRANGE count must be a positive integer') - pieces.append(b'COUNT') + raise DataError("XREVRANGE count must be a positive integer") + pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command('XREVRANGE', name, *pieces) + return self.execute_command("XREVRANGE", name, *pieces) - def xtrim(self, name, maxlen=None, approximate=True, minid=None, - limit=None): + def xtrim(self, name, maxlen=None, approximate=True, minid=None, limit=None): """ Trims old messages from a stream. name: name of the stream. @@ -2715,15 +2831,14 @@ def xtrim(self, name, maxlen=None, approximate=True, minid=None, """ pieces = [] if maxlen is not None and minid is not None: - raise DataError("Only one of ``maxlen`` or ``minid`` " - "may be specified") + raise DataError("Only one of ``maxlen`` or ``minid`` " "may be specified") if maxlen is not None: - pieces.append(b'MAXLEN') + pieces.append(b"MAXLEN") if minid is not None: - pieces.append(b'MINID') + pieces.append(b"MINID") if approximate: - pieces.append(b'~') + pieces.append(b"~") if maxlen is not None: pieces.append(maxlen) if minid is not None: @@ -2732,7 +2847,7 @@ def xtrim(self, name, maxlen=None, approximate=True, minid=None, pieces.append(b"LIMIT") pieces.append(limit) - return self.execute_command('XTRIM', name, *pieces) + return self.execute_command("XTRIM", name, *pieces) class SortedSetCommands: @@ -2740,8 +2855,10 @@ class SortedSetCommands: Redis commands for Sorted Sets data type. see: https://redis.io/topics/data-types-intro#redis-sorted-sets """ - def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False, - gt=None, lt=None): + + def zadd( + self, name, mapping, nx=False, xx=False, ch=False, incr=False, gt=None, lt=None + ): """ Set any number of element-name, score pairs to the key ``name``. Pairs are specified as a dict of element-names keys to score values. @@ -2780,30 +2897,32 @@ def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False, if nx and xx: raise DataError("ZADD allows either 'nx' or 'xx', not both") if incr and len(mapping) != 1: - raise DataError("ZADD option 'incr' only works when passing a " - "single element/score pair") + raise DataError( + "ZADD option 'incr' only works when passing a " + "single element/score pair" + ) if nx is True and (gt is not None or lt is not None): raise DataError("Only one of 'nx', 'lt', or 'gr' may be defined.") pieces = [] options = {} if nx: - pieces.append(b'NX') + pieces.append(b"NX") if xx: - pieces.append(b'XX') + pieces.append(b"XX") if ch: - pieces.append(b'CH') + pieces.append(b"CH") if incr: - pieces.append(b'INCR') - options['as_score'] = True + pieces.append(b"INCR") + options["as_score"] = True if gt: - pieces.append(b'GT') + pieces.append(b"GT") if lt: - pieces.append(b'LT') + pieces.append(b"LT") for pair in mapping.items(): pieces.append(pair[1]) pieces.append(pair[0]) - return self.execute_command('ZADD', name, *pieces, **options) + return self.execute_command("ZADD", name, *pieces, **options) def zcard(self, name): """ @@ -2811,7 +2930,7 @@ def zcard(self, name): For more information check https://redis.io/commands/zcard """ - return self.execute_command('ZCARD', name) + return self.execute_command("ZCARD", name) def zcount(self, name, min, max): """ @@ -2820,7 +2939,7 @@ def zcount(self, name, min, max): For more information check https://redis.io/commands/zcount """ - return self.execute_command('ZCOUNT', name, min, max) + return self.execute_command("ZCOUNT", name, min, max) def zdiff(self, keys, withscores=False): """ @@ -2850,7 +2969,7 @@ def zincrby(self, name, amount, value): For more information check https://redis.io/commands/zincrby """ - return self.execute_command('ZINCRBY', name, amount, value) + return self.execute_command("ZINCRBY", name, amount, value) def zinter(self, keys, aggregate=None, withscores=False): """ @@ -2864,8 +2983,7 @@ def zinter(self, keys, aggregate=None, withscores=False): For more information check https://redis.io/commands/zinter """ - return self._zaggregate('ZINTER', None, keys, aggregate, - withscores=withscores) + return self._zaggregate("ZINTER", None, keys, aggregate, withscores=withscores) def zinterstore(self, dest, keys, aggregate=None): """ @@ -2879,7 +2997,7 @@ def zinterstore(self, dest, keys, aggregate=None): For more information check https://redis.io/commands/zinterstore """ - return self._zaggregate('ZINTERSTORE', dest, keys, aggregate) + return self._zaggregate("ZINTERSTORE", dest, keys, aggregate) def zlexcount(self, name, min, max): """ @@ -2888,7 +3006,7 @@ def zlexcount(self, name, min, max): For more information check https://redis.io/commands/zlexcount """ - return self.execute_command('ZLEXCOUNT', name, min, max) + return self.execute_command("ZLEXCOUNT", name, min, max) def zpopmax(self, name, count=None): """ @@ -2898,10 +3016,8 @@ def zpopmax(self, name, count=None): For more information check https://redis.io/commands/zpopmax """ args = (count is not None) and [count] or [] - options = { - 'withscores': True - } - return self.execute_command('ZPOPMAX', name, *args, **options) + options = {"withscores": True} + return self.execute_command("ZPOPMAX", name, *args, **options) def zpopmin(self, name, count=None): """ @@ -2911,10 +3027,8 @@ def zpopmin(self, name, count=None): For more information check https://redis.io/commands/zpopmin """ args = (count is not None) and [count] or [] - options = { - 'withscores': True - } - return self.execute_command('ZPOPMIN', name, *args, **options) + options = {"withscores": True} + return self.execute_command("ZPOPMIN", name, *args, **options) def zrandmember(self, key, count=None, withscores=False): """ @@ -2957,7 +3071,7 @@ def bzpopmax(self, keys, timeout=0): timeout = 0 keys = list_or_args(keys, None) keys.append(timeout) - return self.execute_command('BZPOPMAX', *keys) + return self.execute_command("BZPOPMAX", *keys) def bzpopmin(self, keys, timeout=0): """ @@ -2976,43 +3090,63 @@ def bzpopmin(self, keys, timeout=0): timeout = 0 keys = list_or_args(keys, None) keys.append(timeout) - return self.execute_command('BZPOPMIN', *keys) - - def _zrange(self, command, dest, name, start, end, desc=False, - byscore=False, bylex=False, withscores=False, - score_cast_func=float, offset=None, num=None): + return self.execute_command("BZPOPMIN", *keys) + + def _zrange( + self, + command, + dest, + name, + start, + end, + desc=False, + byscore=False, + bylex=False, + withscores=False, + score_cast_func=float, + offset=None, + num=None, + ): if byscore and bylex: - raise DataError("``byscore`` and ``bylex`` can not be " - "specified together.") - if (offset is not None and num is None) or \ - (num is not None and offset is None): + raise DataError( + "``byscore`` and ``bylex`` can not be " "specified together." + ) + if (offset is not None and num is None) or (num is not None and offset is None): raise DataError("``offset`` and ``num`` must both be specified.") if bylex and withscores: - raise DataError("``withscores`` not supported in combination " - "with ``bylex``.") + raise DataError( + "``withscores`` not supported in combination " "with ``bylex``." + ) pieces = [command] if dest: pieces.append(dest) pieces.extend([name, start, end]) if byscore: - pieces.append('BYSCORE') + pieces.append("BYSCORE") if bylex: - pieces.append('BYLEX') + pieces.append("BYLEX") if desc: - pieces.append('REV') + pieces.append("REV") if offset is not None and num is not None: - pieces.extend(['LIMIT', offset, num]) + pieces.extend(["LIMIT", offset, num]) if withscores: - pieces.append('WITHSCORES') - options = { - 'withscores': withscores, - 'score_cast_func': score_cast_func - } + pieces.append("WITHSCORES") + options = {"withscores": withscores, "score_cast_func": score_cast_func} return self.execute_command(*pieces, **options) - def zrange(self, name, start, end, desc=False, withscores=False, - score_cast_func=float, byscore=False, bylex=False, - offset=None, num=None): + def zrange( + self, + name, + start, + end, + desc=False, + withscores=False, + score_cast_func=float, + byscore=False, + bylex=False, + offset=None, + num=None, + ): """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -3043,16 +3177,25 @@ def zrange(self, name, start, end, desc=False, withscores=False, """ # Need to support ``desc`` also when using old redis version # because it was supported in 3.5.3 (of redis-py) - if not byscore and not bylex and (offset is None and num is None) \ - and desc: - return self.zrevrange(name, start, end, withscores, - score_cast_func) - - return self._zrange('ZRANGE', None, name, start, end, desc, byscore, - bylex, withscores, score_cast_func, offset, num) + if not byscore and not bylex and (offset is None and num is None) and desc: + return self.zrevrange(name, start, end, withscores, score_cast_func) + + return self._zrange( + "ZRANGE", + None, + name, + start, + end, + desc, + byscore, + bylex, + withscores, + score_cast_func, + offset, + num, + ) - def zrevrange(self, name, start, end, withscores=False, - score_cast_func=float): + def zrevrange(self, name, start, end, withscores=False, score_cast_func=float): """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in descending order. @@ -3066,18 +3209,24 @@ def zrevrange(self, name, start, end, withscores=False, For more information check https://redis.io/commands/zrevrange """ - pieces = ['ZREVRANGE', name, start, end] + pieces = ["ZREVRANGE", name, start, end] if withscores: - pieces.append(b'WITHSCORES') - options = { - 'withscores': withscores, - 'score_cast_func': score_cast_func - } + pieces.append(b"WITHSCORES") + options = {"withscores": withscores, "score_cast_func": score_cast_func} return self.execute_command(*pieces, **options) - def zrangestore(self, dest, name, start, end, - byscore=False, bylex=False, desc=False, - offset=None, num=None): + def zrangestore( + self, + dest, + name, + start, + end, + byscore=False, + bylex=False, + desc=False, + offset=None, + num=None, + ): """ Stores in ``dest`` the result of a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -3101,8 +3250,20 @@ def zrangestore(self, dest, name, start, end, For more information check https://redis.io/commands/zrangestore """ - return self._zrange('ZRANGESTORE', dest, name, start, end, desc, - byscore, bylex, False, None, offset, num) + return self._zrange( + "ZRANGESTORE", + dest, + name, + start, + end, + desc, + byscore, + bylex, + False, + None, + offset, + num, + ) def zrangebylex(self, name, min, max, start=None, num=None): """ @@ -3114,12 +3275,11 @@ def zrangebylex(self, name, min, max, start=None, num=None): For more information check https://redis.io/commands/zrangebylex """ - if (start is not None and num is None) or \ - (num is not None and start is None): + if (start is not None and num is None) or (num is not None and start is None): raise DataError("``start`` and ``num`` must both be specified") - pieces = ['ZRANGEBYLEX', name, min, max] + pieces = ["ZRANGEBYLEX", name, min, max] if start is not None and num is not None: - pieces.extend([b'LIMIT', start, num]) + pieces.extend([b"LIMIT", start, num]) return self.execute_command(*pieces) def zrevrangebylex(self, name, max, min, start=None, num=None): @@ -3132,16 +3292,23 @@ def zrevrangebylex(self, name, max, min, start=None, num=None): For more information check https://redis.io/commands/zrevrangebylex """ - if (start is not None and num is None) or \ - (num is not None and start is None): + if (start is not None and num is None) or (num is not None and start is None): raise DataError("``start`` and ``num`` must both be specified") - pieces = ['ZREVRANGEBYLEX', name, max, min] + pieces = ["ZREVRANGEBYLEX", name, max, min] if start is not None and num is not None: - pieces.extend(['LIMIT', start, num]) + pieces.extend(["LIMIT", start, num]) return self.execute_command(*pieces) - def zrangebyscore(self, name, min, max, start=None, num=None, - withscores=False, score_cast_func=float): + def zrangebyscore( + self, + name, + min, + max, + start=None, + num=None, + withscores=False, + score_cast_func=float, + ): """ Return a range of values from the sorted set ``name`` with scores between ``min`` and ``max``. @@ -3156,22 +3323,26 @@ def zrangebyscore(self, name, min, max, start=None, num=None, For more information check https://redis.io/commands/zrangebyscore """ - if (start is not None and num is None) or \ - (num is not None and start is None): + if (start is not None and num is None) or (num is not None and start is None): raise DataError("``start`` and ``num`` must both be specified") - pieces = ['ZRANGEBYSCORE', name, min, max] + pieces = ["ZRANGEBYSCORE", name, min, max] if start is not None and num is not None: - pieces.extend(['LIMIT', start, num]) + pieces.extend(["LIMIT", start, num]) if withscores: - pieces.append('WITHSCORES') - options = { - 'withscores': withscores, - 'score_cast_func': score_cast_func - } + pieces.append("WITHSCORES") + options = {"withscores": withscores, "score_cast_func": score_cast_func} return self.execute_command(*pieces, **options) - def zrevrangebyscore(self, name, max, min, start=None, num=None, - withscores=False, score_cast_func=float): + def zrevrangebyscore( + self, + name, + max, + min, + start=None, + num=None, + withscores=False, + score_cast_func=float, + ): """ Return a range of values from the sorted set ``name`` with scores between ``min`` and ``max`` in descending order. @@ -3186,18 +3357,14 @@ def zrevrangebyscore(self, name, max, min, start=None, num=None, For more information check https://redis.io/commands/zrevrangebyscore """ - if (start is not None and num is None) or \ - (num is not None and start is None): + if (start is not None and num is None) or (num is not None and start is None): raise DataError("``start`` and ``num`` must both be specified") - pieces = ['ZREVRANGEBYSCORE', name, max, min] + pieces = ["ZREVRANGEBYSCORE", name, max, min] if start is not None and num is not None: - pieces.extend(['LIMIT', start, num]) + pieces.extend(["LIMIT", start, num]) if withscores: - pieces.append('WITHSCORES') - options = { - 'withscores': withscores, - 'score_cast_func': score_cast_func - } + pieces.append("WITHSCORES") + options = {"withscores": withscores, "score_cast_func": score_cast_func} return self.execute_command(*pieces, **options) def zrank(self, name, value): @@ -3207,7 +3374,7 @@ def zrank(self, name, value): For more information check https://redis.io/commands/zrank """ - return self.execute_command('ZRANK', name, value) + return self.execute_command("ZRANK", name, value) def zrem(self, name, *values): """ @@ -3215,7 +3382,7 @@ def zrem(self, name, *values): For more information check https://redis.io/commands/zrem """ - return self.execute_command('ZREM', name, *values) + return self.execute_command("ZREM", name, *values) def zremrangebylex(self, name, min, max): """ @@ -3226,7 +3393,7 @@ def zremrangebylex(self, name, min, max): For more information check https://redis.io/commands/zremrangebylex """ - return self.execute_command('ZREMRANGEBYLEX', name, min, max) + return self.execute_command("ZREMRANGEBYLEX", name, min, max) def zremrangebyrank(self, name, min, max): """ @@ -3237,7 +3404,7 @@ def zremrangebyrank(self, name, min, max): For more information check https://redis.io/commands/zremrangebyrank """ - return self.execute_command('ZREMRANGEBYRANK', name, min, max) + return self.execute_command("ZREMRANGEBYRANK", name, min, max) def zremrangebyscore(self, name, min, max): """ @@ -3246,7 +3413,7 @@ def zremrangebyscore(self, name, min, max): For more information check https://redis.io/commands/zremrangebyscore """ - return self.execute_command('ZREMRANGEBYSCORE', name, min, max) + return self.execute_command("ZREMRANGEBYSCORE", name, min, max) def zrevrank(self, name, value): """ @@ -3255,7 +3422,7 @@ def zrevrank(self, name, value): For more information check https://redis.io/commands/zrevrank """ - return self.execute_command('ZREVRANK', name, value) + return self.execute_command("ZREVRANK", name, value) def zscore(self, name, value): """ @@ -3263,7 +3430,7 @@ def zscore(self, name, value): For more information check https://redis.io/commands/zscore """ - return self.execute_command('ZSCORE', name, value) + return self.execute_command("ZSCORE", name, value) def zunion(self, keys, aggregate=None, withscores=False): """ @@ -3274,8 +3441,7 @@ def zunion(self, keys, aggregate=None, withscores=False): For more information check https://redis.io/commands/zunion """ - return self._zaggregate('ZUNION', None, keys, aggregate, - withscores=withscores) + return self._zaggregate("ZUNION", None, keys, aggregate, withscores=withscores) def zunionstore(self, dest, keys, aggregate=None): """ @@ -3285,7 +3451,7 @@ def zunionstore(self, dest, keys, aggregate=None): For more information check https://redis.io/commands/zunionstore """ - return self._zaggregate('ZUNIONSTORE', dest, keys, aggregate) + return self._zaggregate("ZUNIONSTORE", dest, keys, aggregate) def zmscore(self, key, members): """ @@ -3299,12 +3465,11 @@ def zmscore(self, key, members): For more information check https://redis.io/commands/zmscore """ if not members: - raise DataError('ZMSCORE members must be a non-empty list') + raise DataError("ZMSCORE members must be a non-empty list") pieces = [key] + members - return self.execute_command('ZMSCORE', *pieces) + return self.execute_command("ZMSCORE", *pieces) - def _zaggregate(self, command, dest, keys, aggregate=None, - **options): + def _zaggregate(self, command, dest, keys, aggregate=None, **options): pieces = [command] if dest is not None: pieces.append(dest) @@ -3315,16 +3480,16 @@ def _zaggregate(self, command, dest, keys, aggregate=None, weights = None pieces.extend(keys) if weights: - pieces.append(b'WEIGHTS') + pieces.append(b"WEIGHTS") pieces.extend(weights) if aggregate: - if aggregate.upper() in ['SUM', 'MIN', 'MAX']: - pieces.append(b'AGGREGATE') + if aggregate.upper() in ["SUM", "MIN", "MAX"]: + pieces.append(b"AGGREGATE") pieces.append(aggregate) else: raise DataError("aggregate can be sum, min or max.") - if options.get('withscores', False): - pieces.append(b'WITHSCORES') + if options.get("withscores", False): + pieces.append(b"WITHSCORES") return self.execute_command(*pieces, **options) @@ -3333,13 +3498,14 @@ class HyperlogCommands: Redis commands of HyperLogLogs data type. see: https://redis.io/topics/data-types-intro#hyperloglogs """ + def pfadd(self, name, *values): """ Adds the specified elements to the specified HyperLogLog. For more information check https://redis.io/commands/pfadd """ - return self.execute_command('PFADD', name, *values) + return self.execute_command("PFADD", name, *values) def pfcount(self, *sources): """ @@ -3348,7 +3514,7 @@ def pfcount(self, *sources): For more information check https://redis.io/commands/pfcount """ - return self.execute_command('PFCOUNT', *sources) + return self.execute_command("PFCOUNT", *sources) def pfmerge(self, dest, *sources): """ @@ -3356,7 +3522,7 @@ def pfmerge(self, dest, *sources): For more information check https://redis.io/commands/pfmerge """ - return self.execute_command('PFMERGE', dest, *sources) + return self.execute_command("PFMERGE", dest, *sources) class HashCommands: @@ -3364,13 +3530,14 @@ class HashCommands: Redis commands for Hash data type. see: https://redis.io/topics/data-types-intro#redis-hashes """ + def hdel(self, name, *keys): """ Delete ``keys`` from hash ``name`` For more information check https://redis.io/commands/hdel """ - return self.execute_command('HDEL', name, *keys) + return self.execute_command("HDEL", name, *keys) def hexists(self, name, key): """ @@ -3378,7 +3545,7 @@ def hexists(self, name, key): For more information check https://redis.io/commands/hexists """ - return self.execute_command('HEXISTS', name, key) + return self.execute_command("HEXISTS", name, key) def hget(self, name, key): """ @@ -3386,7 +3553,7 @@ def hget(self, name, key): For more information check https://redis.io/commands/hget """ - return self.execute_command('HGET', name, key) + return self.execute_command("HGET", name, key) def hgetall(self, name): """ @@ -3394,7 +3561,7 @@ def hgetall(self, name): For more information check https://redis.io/commands/hgetall """ - return self.execute_command('HGETALL', name) + return self.execute_command("HGETALL", name) def hincrby(self, name, key, amount=1): """ @@ -3402,7 +3569,7 @@ def hincrby(self, name, key, amount=1): For more information check https://redis.io/commands/hincrby """ - return self.execute_command('HINCRBY', name, key, amount) + return self.execute_command("HINCRBY", name, key, amount) def hincrbyfloat(self, name, key, amount=1.0): """ @@ -3410,7 +3577,7 @@ def hincrbyfloat(self, name, key, amount=1.0): For more information check https://redis.io/commands/hincrbyfloat """ - return self.execute_command('HINCRBYFLOAT', name, key, amount) + return self.execute_command("HINCRBYFLOAT", name, key, amount) def hkeys(self, name): """ @@ -3418,7 +3585,7 @@ def hkeys(self, name): For more information check https://redis.io/commands/hkeys """ - return self.execute_command('HKEYS', name) + return self.execute_command("HKEYS", name) def hlen(self, name): """ @@ -3426,7 +3593,7 @@ def hlen(self, name): For more information check https://redis.io/commands/hlen """ - return self.execute_command('HLEN', name) + return self.execute_command("HLEN", name) def hset(self, name, key=None, value=None, mapping=None): """ @@ -3446,7 +3613,7 @@ def hset(self, name, key=None, value=None, mapping=None): for pair in mapping.items(): items.extend(pair) - return self.execute_command('HSET', name, *items) + return self.execute_command("HSET", name, *items) def hsetnx(self, name, key, value): """ @@ -3455,7 +3622,7 @@ def hsetnx(self, name, key, value): For more information check https://redis.io/commands/hsetnx """ - return self.execute_command('HSETNX', name, key, value) + return self.execute_command("HSETNX", name, key, value) def hmset(self, name, mapping): """ @@ -3465,8 +3632,8 @@ def hmset(self, name, mapping): For more information check https://redis.io/commands/hmset """ warnings.warn( - f'{self.__class__.__name__}.hmset() is deprecated. ' - f'Use {self.__class__.__name__}.hset() instead.', + f"{self.__class__.__name__}.hmset() is deprecated. " + f"Use {self.__class__.__name__}.hset() instead.", DeprecationWarning, stacklevel=2, ) @@ -3475,7 +3642,7 @@ def hmset(self, name, mapping): items = [] for pair in mapping.items(): items.extend(pair) - return self.execute_command('HMSET', name, *items) + return self.execute_command("HMSET", name, *items) def hmget(self, name, keys, *args): """ @@ -3484,7 +3651,7 @@ def hmget(self, name, keys, *args): For more information check https://redis.io/commands/hmget """ args = list_or_args(keys, args) - return self.execute_command('HMGET', name, *args) + return self.execute_command("HMGET", name, *args) def hvals(self, name): """ @@ -3492,7 +3659,7 @@ def hvals(self, name): For more information check https://redis.io/commands/hvals """ - return self.execute_command('HVALS', name) + return self.execute_command("HVALS", name) def hstrlen(self, name, key): """ @@ -3501,7 +3668,7 @@ def hstrlen(self, name, key): For more information check https://redis.io/commands/hstrlen """ - return self.execute_command('HSTRLEN', name, key) + return self.execute_command("HSTRLEN", name, key) class PubSubCommands: @@ -3509,6 +3676,7 @@ class PubSubCommands: Redis PubSub commands. see https://redis.io/topics/pubsub """ + def publish(self, channel, message): """ Publish ``message`` on ``channel``. @@ -3516,15 +3684,15 @@ def publish(self, channel, message): For more information check https://redis.io/commands/publish """ - return self.execute_command('PUBLISH', channel, message) + return self.execute_command("PUBLISH", channel, message) - def pubsub_channels(self, pattern='*'): + def pubsub_channels(self, pattern="*"): """ Return a list of channels that have at least one subscriber For more information check https://redis.io/commands/pubsub-channels """ - return self.execute_command('PUBSUB CHANNELS', pattern) + return self.execute_command("PUBSUB CHANNELS", pattern) def pubsub_numpat(self): """ @@ -3532,7 +3700,7 @@ def pubsub_numpat(self): For more information check https://redis.io/commands/pubsub-numpat """ - return self.execute_command('PUBSUB NUMPAT') + return self.execute_command("PUBSUB NUMPAT") def pubsub_numsub(self, *args): """ @@ -3541,7 +3709,7 @@ def pubsub_numsub(self, *args): For more information check https://redis.io/commands/pubsub-numsub """ - return self.execute_command('PUBSUB NUMSUB', *args) + return self.execute_command("PUBSUB NUMSUB", *args) class ScriptCommands: @@ -3549,6 +3717,7 @@ class ScriptCommands: Redis Lua script commands. see: https://redis.com/ebook/part-3-next-steps/chapter-11-scripting-redis-with-lua/ """ + def eval(self, script, numkeys, *keys_and_args): """ Execute the Lua ``script``, specifying the ``numkeys`` the script @@ -3560,7 +3729,7 @@ def eval(self, script, numkeys, *keys_and_args): For more information check https://redis.io/commands/eval """ - return self.execute_command('EVAL', script, numkeys, *keys_and_args) + return self.execute_command("EVAL", script, numkeys, *keys_and_args) def evalsha(self, sha, numkeys, *keys_and_args): """ @@ -3574,7 +3743,7 @@ def evalsha(self, sha, numkeys, *keys_and_args): For more information check https://redis.io/commands/evalsha """ - return self.execute_command('EVALSHA', sha, numkeys, *keys_and_args) + return self.execute_command("EVALSHA", sha, numkeys, *keys_and_args) def script_exists(self, *args): """ @@ -3584,7 +3753,7 @@ def script_exists(self, *args): For more information check https://redis.io/commands/script-exists """ - return self.execute_command('SCRIPT EXISTS', *args) + return self.execute_command("SCRIPT EXISTS", *args) def script_debug(self, *args): raise NotImplementedError( @@ -3600,14 +3769,16 @@ def script_flush(self, sync_type=None): # Redis pre 6 had no sync_type. if sync_type not in ["SYNC", "ASYNC", None]: - raise DataError("SCRIPT FLUSH defaults to SYNC in redis > 6.2, or " - "accepts SYNC/ASYNC. For older versions, " - "of redis leave as None.") + raise DataError( + "SCRIPT FLUSH defaults to SYNC in redis > 6.2, or " + "accepts SYNC/ASYNC. For older versions, " + "of redis leave as None." + ) if sync_type is None: pieces = [] else: pieces = [sync_type] - return self.execute_command('SCRIPT FLUSH', *pieces) + return self.execute_command("SCRIPT FLUSH", *pieces) def script_kill(self): """ @@ -3615,7 +3786,7 @@ def script_kill(self): For more information check https://redis.io/commands/script-kill """ - return self.execute_command('SCRIPT KILL') + return self.execute_command("SCRIPT KILL") def script_load(self, script): """ @@ -3623,7 +3794,7 @@ def script_load(self, script): For more information check https://redis.io/commands/script-load """ - return self.execute_command('SCRIPT LOAD', script) + return self.execute_command("SCRIPT LOAD", script) def register_script(self, script): """ @@ -3640,6 +3811,7 @@ class GeoCommands: Redis Geospatial commands. see: https://redis.com/redis-best-practices/indexing-patterns/geospatial/ """ + def geoadd(self, name, values, nx=False, xx=False, ch=False): """ Add the specified geospatial items to the specified key identified @@ -3664,17 +3836,16 @@ def geoadd(self, name, values, nx=False, xx=False, ch=False): if nx and xx: raise DataError("GEOADD allows either 'nx' or 'xx', not both") if len(values) % 3 != 0: - raise DataError("GEOADD requires places with lon, lat and name" - " values") + raise DataError("GEOADD requires places with lon, lat and name" " values") pieces = [name] if nx: - pieces.append('NX') + pieces.append("NX") if xx: - pieces.append('XX') + pieces.append("XX") if ch: - pieces.append('CH') + pieces.append("CH") pieces.extend(values) - return self.execute_command('GEOADD', *pieces) + return self.execute_command("GEOADD", *pieces) def geodist(self, name, place1, place2, unit=None): """ @@ -3686,11 +3857,11 @@ def geodist(self, name, place1, place2, unit=None): For more information check https://redis.io/commands/geodist """ pieces = [name, place1, place2] - if unit and unit not in ('m', 'km', 'mi', 'ft'): + if unit and unit not in ("m", "km", "mi", "ft"): raise DataError("GEODIST invalid unit") elif unit: pieces.append(unit) - return self.execute_command('GEODIST', *pieces) + return self.execute_command("GEODIST", *pieces) def geohash(self, name, *values): """ @@ -3699,7 +3870,7 @@ def geohash(self, name, *values): For more information check https://redis.io/commands/geohash """ - return self.execute_command('GEOHASH', name, *values) + return self.execute_command("GEOHASH", name, *values) def geopos(self, name, *values): """ @@ -3709,11 +3880,24 @@ def geopos(self, name, *values): For more information check https://redis.io/commands/geopos """ - return self.execute_command('GEOPOS', name, *values) - - def georadius(self, name, longitude, latitude, radius, unit=None, - withdist=False, withcoord=False, withhash=False, count=None, - sort=None, store=None, store_dist=None, any=False): + return self.execute_command("GEOPOS", name, *values) + + def georadius( + self, + name, + longitude, + latitude, + radius, + unit=None, + withdist=False, + withcoord=False, + withhash=False, + count=None, + sort=None, + store=None, + store_dist=None, + any=False, + ): """ Return the members of the specified key identified by the ``name`` argument which are within the borders of the area specified @@ -3744,17 +3928,38 @@ def georadius(self, name, longitude, latitude, radius, unit=None, For more information check https://redis.io/commands/georadius """ - return self._georadiusgeneric('GEORADIUS', - name, longitude, latitude, radius, - unit=unit, withdist=withdist, - withcoord=withcoord, withhash=withhash, - count=count, sort=sort, store=store, - store_dist=store_dist, any=any) + return self._georadiusgeneric( + "GEORADIUS", + name, + longitude, + latitude, + radius, + unit=unit, + withdist=withdist, + withcoord=withcoord, + withhash=withhash, + count=count, + sort=sort, + store=store, + store_dist=store_dist, + any=any, + ) - def georadiusbymember(self, name, member, radius, unit=None, - withdist=False, withcoord=False, withhash=False, - count=None, sort=None, store=None, store_dist=None, - any=False): + def georadiusbymember( + self, + name, + member, + radius, + unit=None, + withdist=False, + withcoord=False, + withhash=False, + count=None, + sort=None, + store=None, + store_dist=None, + any=False, + ): """ This command is exactly like ``georadius`` with the sole difference that instead of taking, as the center of the area to query, a longitude @@ -3763,61 +3968,85 @@ def georadiusbymember(self, name, member, radius, unit=None, For more information check https://redis.io/commands/georadiusbymember """ - return self._georadiusgeneric('GEORADIUSBYMEMBER', - name, member, radius, unit=unit, - withdist=withdist, withcoord=withcoord, - withhash=withhash, count=count, - sort=sort, store=store, - store_dist=store_dist, any=any) + return self._georadiusgeneric( + "GEORADIUSBYMEMBER", + name, + member, + radius, + unit=unit, + withdist=withdist, + withcoord=withcoord, + withhash=withhash, + count=count, + sort=sort, + store=store, + store_dist=store_dist, + any=any, + ) def _georadiusgeneric(self, command, *args, **kwargs): pieces = list(args) - if kwargs['unit'] and kwargs['unit'] not in ('m', 'km', 'mi', 'ft'): + if kwargs["unit"] and kwargs["unit"] not in ("m", "km", "mi", "ft"): raise DataError("GEORADIUS invalid unit") - elif kwargs['unit']: - pieces.append(kwargs['unit']) + elif kwargs["unit"]: + pieces.append(kwargs["unit"]) else: - pieces.append('m',) + pieces.append( + "m", + ) - if kwargs['any'] and kwargs['count'] is None: + if kwargs["any"] and kwargs["count"] is None: raise DataError("``any`` can't be provided without ``count``") for arg_name, byte_repr in ( - ('withdist', 'WITHDIST'), - ('withcoord', 'WITHCOORD'), - ('withhash', 'WITHHASH')): + ("withdist", "WITHDIST"), + ("withcoord", "WITHCOORD"), + ("withhash", "WITHHASH"), + ): if kwargs[arg_name]: pieces.append(byte_repr) - if kwargs['count'] is not None: - pieces.extend(['COUNT', kwargs['count']]) - if kwargs['any']: - pieces.append('ANY') + if kwargs["count"] is not None: + pieces.extend(["COUNT", kwargs["count"]]) + if kwargs["any"]: + pieces.append("ANY") - if kwargs['sort']: - if kwargs['sort'] == 'ASC': - pieces.append('ASC') - elif kwargs['sort'] == 'DESC': - pieces.append('DESC') + if kwargs["sort"]: + if kwargs["sort"] == "ASC": + pieces.append("ASC") + elif kwargs["sort"] == "DESC": + pieces.append("DESC") else: raise DataError("GEORADIUS invalid sort") - if kwargs['store'] and kwargs['store_dist']: - raise DataError("GEORADIUS store and store_dist cant be set" - " together") + if kwargs["store"] and kwargs["store_dist"]: + raise DataError("GEORADIUS store and store_dist cant be set" " together") - if kwargs['store']: - pieces.extend([b'STORE', kwargs['store']]) + if kwargs["store"]: + pieces.extend([b"STORE", kwargs["store"]]) - if kwargs['store_dist']: - pieces.extend([b'STOREDIST', kwargs['store_dist']]) + if kwargs["store_dist"]: + pieces.extend([b"STOREDIST", kwargs["store_dist"]]) return self.execute_command(command, *pieces, **kwargs) - def geosearch(self, name, member=None, longitude=None, latitude=None, - unit='m', radius=None, width=None, height=None, sort=None, - count=None, any=False, withcoord=False, - withdist=False, withhash=False): + def geosearch( + self, + name, + member=None, + longitude=None, + latitude=None, + unit="m", + radius=None, + width=None, + height=None, + sort=None, + count=None, + any=False, + withcoord=False, + withdist=False, + withhash=False, + ): """ Return the members of specified key identified by the ``name`` argument, which are within the borders of the @@ -3853,19 +4082,42 @@ def geosearch(self, name, member=None, longitude=None, latitude=None, For more information check https://redis.io/commands/geosearch """ - return self._geosearchgeneric('GEOSEARCH', - name, member=member, longitude=longitude, - latitude=latitude, unit=unit, - radius=radius, width=width, - height=height, sort=sort, count=count, - any=any, withcoord=withcoord, - withdist=withdist, withhash=withhash, - store=None, store_dist=None) + return self._geosearchgeneric( + "GEOSEARCH", + name, + member=member, + longitude=longitude, + latitude=latitude, + unit=unit, + radius=radius, + width=width, + height=height, + sort=sort, + count=count, + any=any, + withcoord=withcoord, + withdist=withdist, + withhash=withhash, + store=None, + store_dist=None, + ) - def geosearchstore(self, dest, name, member=None, longitude=None, - latitude=None, unit='m', radius=None, width=None, - height=None, sort=None, count=None, any=False, - storedist=False): + def geosearchstore( + self, + dest, + name, + member=None, + longitude=None, + latitude=None, + unit="m", + radius=None, + width=None, + height=None, + sort=None, + count=None, + any=False, + storedist=False, + ): """ This command is like GEOSEARCH, but stores the result in ``dest``. By default, it stores the results in the destination @@ -3876,74 +4128,86 @@ def geosearchstore(self, dest, name, member=None, longitude=None, For more information check https://redis.io/commands/geosearchstore """ - return self._geosearchgeneric('GEOSEARCHSTORE', - dest, name, member=member, - longitude=longitude, latitude=latitude, - unit=unit, radius=radius, width=width, - height=height, sort=sort, count=count, - any=any, withcoord=None, - withdist=None, withhash=None, - store=None, store_dist=storedist) + return self._geosearchgeneric( + "GEOSEARCHSTORE", + dest, + name, + member=member, + longitude=longitude, + latitude=latitude, + unit=unit, + radius=radius, + width=width, + height=height, + sort=sort, + count=count, + any=any, + withcoord=None, + withdist=None, + withhash=None, + store=None, + store_dist=storedist, + ) def _geosearchgeneric(self, command, *args, **kwargs): pieces = list(args) # FROMMEMBER or FROMLONLAT - if kwargs['member'] is None: - if kwargs['longitude'] is None or kwargs['latitude'] is None: - raise DataError("GEOSEARCH must have member or" - " longitude and latitude") - if kwargs['member']: - if kwargs['longitude'] or kwargs['latitude']: - raise DataError("GEOSEARCH member and longitude or latitude" - " cant be set together") - pieces.extend([b'FROMMEMBER', kwargs['member']]) - if kwargs['longitude'] and kwargs['latitude']: - pieces.extend([b'FROMLONLAT', - kwargs['longitude'], kwargs['latitude']]) + if kwargs["member"] is None: + if kwargs["longitude"] is None or kwargs["latitude"] is None: + raise DataError( + "GEOSEARCH must have member or" " longitude and latitude" + ) + if kwargs["member"]: + if kwargs["longitude"] or kwargs["latitude"]: + raise DataError( + "GEOSEARCH member and longitude or latitude" " cant be set together" + ) + pieces.extend([b"FROMMEMBER", kwargs["member"]]) + if kwargs["longitude"] and kwargs["latitude"]: + pieces.extend([b"FROMLONLAT", kwargs["longitude"], kwargs["latitude"]]) # BYRADIUS or BYBOX - if kwargs['radius'] is None: - if kwargs['width'] is None or kwargs['height'] is None: - raise DataError("GEOSEARCH must have radius or" - " width and height") - if kwargs['unit'] is None: + if kwargs["radius"] is None: + if kwargs["width"] is None or kwargs["height"] is None: + raise DataError("GEOSEARCH must have radius or" " width and height") + if kwargs["unit"] is None: raise DataError("GEOSEARCH must have unit") - if kwargs['unit'].lower() not in ('m', 'km', 'mi', 'ft'): + if kwargs["unit"].lower() not in ("m", "km", "mi", "ft"): raise DataError("GEOSEARCH invalid unit") - if kwargs['radius']: - if kwargs['width'] or kwargs['height']: - raise DataError("GEOSEARCH radius and width or height" - " cant be set together") - pieces.extend([b'BYRADIUS', kwargs['radius'], kwargs['unit']]) - if kwargs['width'] and kwargs['height']: - pieces.extend([b'BYBOX', - kwargs['width'], kwargs['height'], kwargs['unit']]) + if kwargs["radius"]: + if kwargs["width"] or kwargs["height"]: + raise DataError( + "GEOSEARCH radius and width or height" " cant be set together" + ) + pieces.extend([b"BYRADIUS", kwargs["radius"], kwargs["unit"]]) + if kwargs["width"] and kwargs["height"]: + pieces.extend([b"BYBOX", kwargs["width"], kwargs["height"], kwargs["unit"]]) # sort - if kwargs['sort']: - if kwargs['sort'].upper() == 'ASC': - pieces.append(b'ASC') - elif kwargs['sort'].upper() == 'DESC': - pieces.append(b'DESC') + if kwargs["sort"]: + if kwargs["sort"].upper() == "ASC": + pieces.append(b"ASC") + elif kwargs["sort"].upper() == "DESC": + pieces.append(b"DESC") else: raise DataError("GEOSEARCH invalid sort") # count any - if kwargs['count']: - pieces.extend([b'COUNT', kwargs['count']]) - if kwargs['any']: - pieces.append(b'ANY') - elif kwargs['any']: - raise DataError("GEOSEARCH ``any`` can't be provided " - "without count") + if kwargs["count"]: + pieces.extend([b"COUNT", kwargs["count"]]) + if kwargs["any"]: + pieces.append(b"ANY") + elif kwargs["any"]: + raise DataError("GEOSEARCH ``any`` can't be provided " "without count") # other properties for arg_name, byte_repr in ( - ('withdist', b'WITHDIST'), - ('withcoord', b'WITHCOORD'), - ('withhash', b'WITHHASH'), - ('store_dist', b'STOREDIST')): + ("withdist", b"WITHDIST"), + ("withcoord", b"WITHCOORD"), + ("withhash", b"WITHHASH"), + ("store_dist", b"STOREDIST"), + ): if kwargs[arg_name]: pieces.append(byte_repr) @@ -3955,6 +4219,7 @@ class ModuleCommands: Redis Module commands. see: https://redis.io/topics/modules-intro """ + def module_load(self, path, *args): """ Loads the module from ``path``. @@ -3963,7 +4228,7 @@ def module_load(self, path, *args): For more information check https://redis.io/commands/module-load """ - return self.execute_command('MODULE LOAD', path, *args) + return self.execute_command("MODULE LOAD", path, *args) def module_unload(self, name): """ @@ -3972,7 +4237,7 @@ def module_unload(self, name): For more information check https://redis.io/commands/module-unload """ - return self.execute_command('MODULE UNLOAD', name) + return self.execute_command("MODULE UNLOAD", name) def module_list(self): """ @@ -3981,7 +4246,7 @@ def module_list(self): For more information check https://redis.io/commands/module-list """ - return self.execute_command('MODULE LIST') + return self.execute_command("MODULE LIST") def command_info(self): raise NotImplementedError( @@ -3989,13 +4254,13 @@ def command_info(self): ) def command_count(self): - return self.execute_command('COMMAND COUNT') + return self.execute_command("COMMAND COUNT") def command_getkeys(self, *args): - return self.execute_command('COMMAND GETKEYS', *args) + return self.execute_command("COMMAND GETKEYS", *args) def command(self): - return self.execute_command('COMMAND') + return self.execute_command("COMMAND") class Script: @@ -4022,6 +4287,7 @@ def __call__(self, keys=[], args=[], client=None): args = tuple(keys) + tuple(args) # make sure the Redis server knows about the script from redis.client import Pipeline + if isinstance(client, Pipeline): # Make sure the pipeline can register the script before executing. client.scripts.add(self) @@ -4039,6 +4305,7 @@ class BitFieldOperation: """ Command builder for BITFIELD commands. """ + def __init__(self, client, key, default_overflow=None): self.client = client self.key = key @@ -4050,7 +4317,7 @@ def reset(self): Reset the state of the instance to when it was constructed """ self.operations = [] - self._last_overflow = 'WRAP' + self._last_overflow = "WRAP" self.overflow(self._default_overflow or self._last_overflow) def overflow(self, overflow): @@ -4063,7 +4330,7 @@ def overflow(self, overflow): overflow = overflow.upper() if overflow != self._last_overflow: self._last_overflow = overflow - self.operations.append(('OVERFLOW', overflow)) + self.operations.append(("OVERFLOW", overflow)) return self def incrby(self, fmt, offset, increment, overflow=None): @@ -4083,7 +4350,7 @@ def incrby(self, fmt, offset, increment, overflow=None): if overflow is not None: self.overflow(overflow) - self.operations.append(('INCRBY', fmt, offset, increment)) + self.operations.append(("INCRBY", fmt, offset, increment)) return self def get(self, fmt, offset): @@ -4096,7 +4363,7 @@ def get(self, fmt, offset): fmt='u8', offset='#2', the offset will be 16. :returns: a :py:class:`BitFieldOperation` instance. """ - self.operations.append(('GET', fmt, offset)) + self.operations.append(("GET", fmt, offset)) return self def set(self, fmt, offset, value): @@ -4110,12 +4377,12 @@ def set(self, fmt, offset, value): :param int value: value to set at the given position. :returns: a :py:class:`BitFieldOperation` instance. """ - self.operations.append(('SET', fmt, offset, value)) + self.operations.append(("SET", fmt, offset, value)) return self @property def command(self): - cmd = ['BITFIELD', self.key] + cmd = ["BITFIELD", self.key] for ops in self.operations: cmd.extend(ops) return cmd @@ -4132,19 +4399,31 @@ def execute(self): return self.client.execute_command(*command) -class DataAccessCommands(BasicKeyCommands, ListCommands, - ScanCommands, SetCommands, StreamCommands, - SortedSetCommands, - HyperlogCommands, HashCommands, GeoCommands, - ): +class DataAccessCommands( + BasicKeyCommands, + ListCommands, + ScanCommands, + SetCommands, + StreamCommands, + SortedSetCommands, + HyperlogCommands, + HashCommands, + GeoCommands, +): """ A class containing all of the implemented data access redis commands. This class is to be used as a mixin. """ -class CoreCommands(ACLCommands, DataAccessCommands, ManagementCommands, - ModuleCommands, PubSubCommands, ScriptCommands): +class CoreCommands( + ACLCommands, + DataAccessCommands, + ManagementCommands, + ModuleCommands, + PubSubCommands, + ScriptCommands, +): """ A class containing all of the implemented redis commands. This class is to be used as a mixin. diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py index 1affaafaf6..e7f07b612f 100644 --- a/redis/commands/json/commands.py +++ b/redis/commands/json/commands.py @@ -1,8 +1,10 @@ -from .path import Path -from .decoders import decode_dict_keys from deprecated import deprecated + from redis.exceptions import DataError +from .decoders import decode_dict_keys +from .path import Path + class JSONCommands: """json commands.""" @@ -29,8 +31,7 @@ def arrindex(self, name, path, scalar, start=0, stop=-1): For more information: https://oss.redis.com/redisjson/commands/#jsonarrindex """ # noqa return self.execute_command( - "JSON.ARRINDEX", name, str(path), self._encode(scalar), - start, stop + "JSON.ARRINDEX", name, str(path), self._encode(scalar), start, stop ) def arrinsert(self, name, path, index, *args): @@ -66,8 +67,7 @@ def arrtrim(self, name, path, start, stop): For more information: https://oss.redis.com/redisjson/commands/#jsonarrtrim """ # noqa - return self.execute_command("JSON.ARRTRIM", name, str(path), - start, stop) + return self.execute_command("JSON.ARRTRIM", name, str(path), start, stop) def type(self, name, path=Path.rootPath()): """Get the type of the JSON value under ``path`` from key ``name``. @@ -109,7 +109,7 @@ def numincrby(self, name, path, number): "JSON.NUMINCRBY", name, str(path), self._encode(number) ) - @deprecated(version='4.0.0', reason='deprecated since redisjson 1.0.0') + @deprecated(version="4.0.0", reason="deprecated since redisjson 1.0.0") def nummultby(self, name, path, number): """Multiply the numeric (integer or floating point) JSON value under ``path`` at key ``name`` with the provided ``number``. @@ -218,7 +218,7 @@ def strlen(self, name, path=None): ``name``. For more information: https://oss.redis.com/redisjson/commands/#jsonstrlen - """ # noqa + """ # noqa pieces = [name] if path is not None: pieces.append(str(path)) @@ -240,9 +240,7 @@ def strappend(self, name, value, path=Path.rootPath()): For more information: https://oss.redis.com/redisjson/commands/#jsonstrappend """ # noqa pieces = [name, str(path), self._encode(value)] - return self.execute_command( - "JSON.STRAPPEND", *pieces - ) + return self.execute_command("JSON.STRAPPEND", *pieces) def debug(self, subcommand, key=None, path=Path.rootPath()): """Return the memory usage in bytes of a value under ``path`` from @@ -252,8 +250,7 @@ def debug(self, subcommand, key=None, path=Path.rootPath()): """ # noqa valid_subcommands = ["MEMORY", "HELP"] if subcommand not in valid_subcommands: - raise DataError("The only valid subcommands are ", - str(valid_subcommands)) + raise DataError("The only valid subcommands are ", str(valid_subcommands)) pieces = [subcommand] if subcommand == "MEMORY": if key is None: @@ -262,17 +259,20 @@ def debug(self, subcommand, key=None, path=Path.rootPath()): pieces.append(str(path)) return self.execute_command("JSON.DEBUG", *pieces) - @deprecated(version='4.0.0', - reason='redisjson-py supported this, call get directly.') + @deprecated( + version="4.0.0", reason="redisjson-py supported this, call get directly." + ) def jsonget(self, *args, **kwargs): return self.get(*args, **kwargs) - @deprecated(version='4.0.0', - reason='redisjson-py supported this, call get directly.') + @deprecated( + version="4.0.0", reason="redisjson-py supported this, call get directly." + ) def jsonmget(self, *args, **kwargs): return self.mget(*args, **kwargs) - @deprecated(version='4.0.0', - reason='redisjson-py supported this, call get directly.') + @deprecated( + version="4.0.0", reason="redisjson-py supported this, call get directly." + ) def jsonset(self, *args, **kwargs): return self.set(*args, **kwargs) diff --git a/redis/commands/parser.py b/redis/commands/parser.py index 26b190c674..dadf3c6bf8 100644 --- a/redis/commands/parser.py +++ b/redis/commands/parser.py @@ -1,7 +1,4 @@ -from redis.exceptions import ( - RedisError, - ResponseError -) +from redis.exceptions import RedisError, ResponseError from redis.utils import str_if_bytes @@ -13,6 +10,7 @@ class CommandsParser: 'movablekeys', and these commands' keys are determined by the command 'COMMAND GETKEYS'. """ + def __init__(self, redis_connection): self.initialized = False self.commands = {} @@ -51,20 +49,24 @@ def get_keys(self, redis_conn, *args): ) command = self.commands.get(cmd_name) - if 'movablekeys' in command['flags']: + if "movablekeys" in command["flags"]: keys = self._get_moveable_keys(redis_conn, *args) - elif 'pubsub' in command['flags']: + elif "pubsub" in command["flags"]: keys = self._get_pubsub_keys(*args) else: - if command['step_count'] == 0 and command['first_key_pos'] == 0 \ - and command['last_key_pos'] == 0: + if ( + command["step_count"] == 0 + and command["first_key_pos"] == 0 + and command["last_key_pos"] == 0 + ): # The command doesn't have keys in it return None - last_key_pos = command['last_key_pos'] + last_key_pos = command["last_key_pos"] if last_key_pos < 0: last_key_pos = len(args) - abs(last_key_pos) - keys_pos = list(range(command['first_key_pos'], last_key_pos + 1, - command['step_count'])) + keys_pos = list( + range(command["first_key_pos"], last_key_pos + 1, command["step_count"]) + ) keys = [args[pos] for pos in keys_pos] return keys @@ -77,11 +79,13 @@ def _get_moveable_keys(self, redis_conn, *args): pieces = pieces + cmd_name.split() pieces = pieces + list(args[1:]) try: - keys = redis_conn.execute_command('COMMAND GETKEYS', *pieces) + keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces) except ResponseError as e: message = e.__str__() - if 'Invalid arguments' in message or \ - 'The command has no key arguments' in message: + if ( + "Invalid arguments" in message + or "The command has no key arguments" in message + ): return None else: raise e @@ -99,18 +103,17 @@ def _get_pubsub_keys(self, *args): return None args = [str_if_bytes(arg) for arg in args] command = args[0].upper() - if command == 'PUBSUB': + if command == "PUBSUB": # the second argument is a part of the command name, e.g. # ['PUBSUB', 'NUMSUB', 'foo']. pubsub_type = args[1].upper() - if pubsub_type in ['CHANNELS', 'NUMSUB']: + if pubsub_type in ["CHANNELS", "NUMSUB"]: keys = args[2:] - elif command in ['SUBSCRIBE', 'PSUBSCRIBE', 'UNSUBSCRIBE', - 'PUNSUBSCRIBE']: + elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]: # format example: # SUBSCRIBE channel [channel ...] keys = list(args[1:]) - elif command == 'PUBLISH': + elif command == "PUBLISH": # format example: # PUBLISH channel message keys = [args[1]] diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 09e50855ba..4ec6fc9dfa 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -274,7 +274,7 @@ def add_document_hash( - **replace**: if True, and the document already is in the index, we perform an update and reindex the document - **language**: Specify the language used for document tokenization. - + For more information: https://oss.redis.com/redisearch/Commands/#ftaddhash """ # noqa return self._add_document_hash( @@ -294,7 +294,7 @@ def delete_document(self, doc_id, conn=None, delete_actual_document=False): - **delete_actual_document**: if set to True, RediSearch also delete the actual document if it is in the index - + For more information: https://oss.redis.com/redisearch/Commands/#ftdel """ # noqa args = [DEL_CMD, self.index_name, doc_id] @@ -768,7 +768,7 @@ def synupdate(self, groupid, skipinitial=False, *terms): If set to true, we do not scan and index. terms : The terms. - + For more information: https://oss.redis.com/redisearch/Commands/#ftsynupdate """ # noqa cmd = [SYNUPDATE_CMD, self.index_name, groupid] diff --git a/redis/commands/search/querystring.py b/redis/commands/search/querystring.py index ffba542e31..1da0387eb8 100644 --- a/redis/commands/search/querystring.py +++ b/redis/commands/search/querystring.py @@ -15,8 +15,7 @@ def between(a, b, inclusive_min=True, inclusive_max=True): """ Indicate that value is a numeric range """ - return RangeValue(a, b, inclusive_min=inclusive_min, - inclusive_max=inclusive_max) + return RangeValue(a, b, inclusive_min=inclusive_min, inclusive_max=inclusive_max) def equal(n): @@ -200,9 +199,7 @@ def join_fields(self, key, vals): return [BaseNode(f"@{key}:{vals[0].to_string()}")] if not vals[0].combinable: return [BaseNode(f"@{key}:{v.to_string()}") for v in vals] - s = BaseNode( - f"@{key}:({self.JOINSTR.join(v.to_string() for v in vals)})" - ) + s = BaseNode(f"@{key}:({self.JOINSTR.join(v.to_string() for v in vals)})") return [s] @classmethod diff --git a/redis/commands/timeseries/commands.py b/redis/commands/timeseries/commands.py index b7e33bc799..c86e0b98b7 100644 --- a/redis/commands/timeseries/commands.py +++ b/redis/commands/timeseries/commands.py @@ -57,7 +57,7 @@ def create(self, key, **kwargs): - 'min': only override if the value is lower than the existing value. - 'max': only override if the value is higher than the existing value. When this is not set, the server-wide default will be used. - + For more information: https://oss.redis.com/redistimeseries/commands/#tscreate """ # noqa retention_msecs = kwargs.get("retention_msecs", None) @@ -80,7 +80,7 @@ def alter(self, key, **kwargs): For more information see The parameters are the same as TS.CREATE. - + For more information: https://oss.redis.com/redistimeseries/commands/#tsalter """ # noqa retention_msecs = kwargs.get("retention_msecs", None) @@ -128,7 +128,7 @@ def add(self, key, timestamp, value, **kwargs): - 'min': only override if the value is lower than the existing value. - 'max': only override if the value is higher than the existing value. When this is not set, the server-wide default will be used. - + For more information: https://oss.redis.com/redistimeseries/master/commands/#tsadd """ # noqa retention_msecs = kwargs.get("retention_msecs", None) diff --git a/redis/connection.py b/redis/connection.py index ef3a667c16..d13fe65ef8 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,8 +1,3 @@ -from packaging.version import Version -from itertools import chain -from time import time -from queue import LifoQueue, Empty, Full -from urllib.parse import parse_qs, unquote, urlparse import copy import errno import io @@ -10,6 +5,12 @@ import socket import threading import weakref +from itertools import chain +from queue import Empty, Full, LifoQueue +from time import time +from urllib.parse import parse_qs, unquote, urlparse + +from packaging.version import Version from redis.backoff import NoBackoff from redis.exceptions import ( @@ -21,20 +22,20 @@ DataError, ExecAbortError, InvalidResponse, + ModuleError, NoPermissionError, NoScriptError, ReadOnlyError, RedisError, ResponseError, TimeoutError, - ModuleError, ) - from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE, str_if_bytes try: import ssl + ssl_available = True except ImportError: ssl_available = False @@ -44,7 +45,7 @@ } if ssl_available: - if hasattr(ssl, 'SSLWantReadError'): + if hasattr(ssl, "SSLWantReadError"): NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2 NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2 else: @@ -56,34 +57,31 @@ import hiredis hiredis_version = Version(hiredis.__version__) - HIREDIS_SUPPORTS_CALLABLE_ERRORS = \ - hiredis_version >= Version('0.1.3') - HIREDIS_SUPPORTS_BYTE_BUFFER = \ - hiredis_version >= Version('0.1.4') - HIREDIS_SUPPORTS_ENCODING_ERRORS = \ - hiredis_version >= Version('1.0.0') + HIREDIS_SUPPORTS_CALLABLE_ERRORS = hiredis_version >= Version("0.1.3") + HIREDIS_SUPPORTS_BYTE_BUFFER = hiredis_version >= Version("0.1.4") + HIREDIS_SUPPORTS_ENCODING_ERRORS = hiredis_version >= Version("1.0.0") HIREDIS_USE_BYTE_BUFFER = True # only use byte buffer if hiredis supports it if not HIREDIS_SUPPORTS_BYTE_BUFFER: HIREDIS_USE_BYTE_BUFFER = False -SYM_STAR = b'*' -SYM_DOLLAR = b'$' -SYM_CRLF = b'\r\n' -SYM_EMPTY = b'' +SYM_STAR = b"*" +SYM_DOLLAR = b"$" +SYM_CRLF = b"\r\n" +SYM_EMPTY = b"" SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." SENTINEL = object() -MODULE_LOAD_ERROR = 'Error loading the extension. ' \ - 'Please check the server logs.' -NO_SUCH_MODULE_ERROR = 'Error unloading module: no such module with that name' -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = 'Error unloading module: operation not ' \ - 'possible.' -MODULE_EXPORTS_DATA_TYPES_ERROR = "Error unloading module: the module " \ - "exports one or more module-side data " \ - "types, can't unload" +MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs." +NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible." +MODULE_EXPORTS_DATA_TYPES_ERROR = ( + "Error unloading module: the module " + "exports one or more module-side data " + "types, can't unload" +) class Encoder: @@ -100,15 +98,19 @@ def encode(self, value): return value elif isinstance(value, bool): # special case bool since it is a subclass of int - raise DataError("Invalid input of type: 'bool'. Convert to a " - "bytes, string, int or float first.") + raise DataError( + "Invalid input of type: 'bool'. Convert to a " + "bytes, string, int or float first." + ) elif isinstance(value, (int, float)): value = repr(value).encode() elif not isinstance(value, str): # a value we don't know how to deal with. throw an error typename = type(value).__name__ - raise DataError(f"Invalid input of type: '{typename}'. " - f"Convert to a bytes, string, int or float first.") + raise DataError( + f"Invalid input of type: '{typename}'. " + f"Convert to a bytes, string, int or float first." + ) if isinstance(value, str): value = value.encode(self.encoding, self.encoding_errors) return value @@ -125,36 +127,36 @@ def decode(self, value, force=False): class BaseParser: EXCEPTION_CLASSES = { - 'ERR': { - 'max number of clients reached': ConnectionError, - 'Client sent AUTH, but no password is set': AuthenticationError, - 'invalid password': AuthenticationError, + "ERR": { + "max number of clients reached": ConnectionError, + "Client sent AUTH, but no password is set": AuthenticationError, + "invalid password": AuthenticationError, # some Redis server versions report invalid command syntax # in lowercase - 'wrong number of arguments for \'auth\' command': - AuthenticationWrongNumberOfArgsError, + "wrong number of arguments " + "for 'auth' command": AuthenticationWrongNumberOfArgsError, # some Redis server versions report invalid command syntax # in uppercase - 'wrong number of arguments for \'AUTH\' command': - AuthenticationWrongNumberOfArgsError, + "wrong number of arguments " + "for 'AUTH' command": AuthenticationWrongNumberOfArgsError, MODULE_LOAD_ERROR: ModuleError, MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, NO_SUCH_MODULE_ERROR: ModuleError, MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, }, - 'EXECABORT': ExecAbortError, - 'LOADING': BusyLoadingError, - 'NOSCRIPT': NoScriptError, - 'READONLY': ReadOnlyError, - 'NOAUTH': AuthenticationError, - 'NOPERM': NoPermissionError, + "EXECABORT": ExecAbortError, + "LOADING": BusyLoadingError, + "NOSCRIPT": NoScriptError, + "READONLY": ReadOnlyError, + "NOAUTH": AuthenticationError, + "NOPERM": NoPermissionError, } def parse_error(self, response): "Parse an error response" - error_code = response.split(' ')[0] + error_code = response.split(" ")[0] if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1:] + response = response[len(error_code) + 1 :] exception_class = self.EXCEPTION_CLASSES[error_code] if isinstance(exception_class, dict): exception_class = exception_class.get(response, ResponseError) @@ -177,8 +179,7 @@ def __init__(self, socket, socket_read_size, socket_timeout): def length(self): return self.bytes_written - self.bytes_read - def _read_from_socket(self, length=None, timeout=SENTINEL, - raise_on_timeout=True): + def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True): sock = self._sock socket_read_size = self.socket_read_size buf = self._buffer @@ -220,9 +221,9 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, sock.settimeout(self.socket_timeout) def can_read(self, timeout): - return bool(self.length) or \ - self._read_from_socket(timeout=timeout, - raise_on_timeout=False) + return bool(self.length) or self._read_from_socket( + timeout=timeout, raise_on_timeout=False + ) def read(self, length): length = length + 2 # make sure to read the \r\n terminator @@ -283,6 +284,7 @@ def close(self): class PythonParser(BaseParser): "Plain Python parsing class" + def __init__(self, socket_read_size): self.socket_read_size = socket_read_size self.encoder = None @@ -298,9 +300,9 @@ def __del__(self): def on_connect(self, connection): "Called when the socket connects" self._sock = connection._sock - self._buffer = SocketBuffer(self._sock, - self.socket_read_size, - connection.socket_timeout) + self._buffer = SocketBuffer( + self._sock, self.socket_read_size, connection.socket_timeout + ) self.encoder = connection.encoder def on_disconnect(self): @@ -321,12 +323,12 @@ def read_response(self, disable_decoding=False): byte, response = raw[:1], raw[1:] - if byte not in (b'-', b'+', b':', b'$', b'*'): + if byte not in (b"-", b"+", b":", b"$", b"*"): raise InvalidResponse(f"Protocol Error: {raw!r}") # server returned an error - if byte == b'-': - response = response.decode('utf-8', errors='replace') + if byte == b"-": + response = response.decode("utf-8", errors="replace") error = self.parse_error(response) # if the error is a ConnectionError, raise immediately so the user # is notified @@ -338,24 +340,26 @@ def read_response(self, disable_decoding=False): # necessary, so just return the exception instance here. return error # single value - elif byte == b'+': + elif byte == b"+": pass # int value - elif byte == b':': + elif byte == b":": response = int(response) # bulk response - elif byte == b'$': + elif byte == b"$": length = int(response) if length == -1: return None response = self._buffer.read(length) # multi-bulk response - elif byte == b'*': + elif byte == b"*": length = int(response) if length == -1: return None - response = [self.read_response(disable_decoding=disable_decoding) - for i in range(length)] + response = [ + self.read_response(disable_decoding=disable_decoding) + for i in range(length) + ] if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) return response @@ -363,6 +367,7 @@ def read_response(self, disable_decoding=False): class HiredisParser(BaseParser): "Parser class for connections using Hiredis" + def __init__(self, socket_read_size): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not installed") @@ -381,18 +386,18 @@ def on_connect(self, connection): self._sock = connection._sock self._socket_timeout = connection.socket_timeout kwargs = { - 'protocolError': InvalidResponse, - 'replyError': self.parse_error, + "protocolError": InvalidResponse, + "replyError": self.parse_error, } # hiredis < 0.1.3 doesn't support functions that create exceptions if not HIREDIS_SUPPORTS_CALLABLE_ERRORS: - kwargs['replyError'] = ResponseError + kwargs["replyError"] = ResponseError if connection.encoder.decode_responses: - kwargs['encoding'] = connection.encoder.encoding + kwargs["encoding"] = connection.encoder.encoding if HIREDIS_SUPPORTS_ENCODING_ERRORS: - kwargs['errors'] = connection.encoder.encoding_errors + kwargs["errors"] = connection.encoder.encoding_errors self._reader = hiredis.Reader(**kwargs) self._next_response = False @@ -408,8 +413,7 @@ def can_read(self, timeout): if self._next_response is False: self._next_response = self._reader.gets() if self._next_response is False: - return self.read_from_socket(timeout=timeout, - raise_on_timeout=False) + return self.read_from_socket(timeout=timeout, raise_on_timeout=False) return True def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): @@ -468,16 +472,22 @@ def read_response(self, disable_decoding=False): if not HIREDIS_SUPPORTS_CALLABLE_ERRORS: if isinstance(response, ResponseError): response = self.parse_error(response.args[0]) - elif isinstance(response, list) and response and \ - isinstance(response[0], ResponseError): + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ResponseError) + ): response[0] = self.parse_error(response[0].args[0]) # if the response is a ConnectionError or the response is a list and # the first item is a ConnectionError, raise it as something bad # happened if isinstance(response, ConnectionError): raise response - elif isinstance(response, list) and response and \ - isinstance(response[0], ConnectionError): + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): raise response[0] return response @@ -491,14 +501,29 @@ def read_response(self, disable_decoding=False): class Connection: "Manages TCP communication to and from a Redis server" - def __init__(self, host='localhost', port=6379, db=0, password=None, - socket_timeout=None, socket_connect_timeout=None, - socket_keepalive=False, socket_keepalive_options=None, - socket_type=0, retry_on_timeout=False, encoding='utf-8', - encoding_errors='strict', decode_responses=False, - parser_class=DefaultParser, socket_read_size=65536, - health_check_interval=0, client_name=None, username=None, - retry=None, redis_connect_func=None): + def __init__( + self, + host="localhost", + port=6379, + db=0, + password=None, + socket_timeout=None, + socket_connect_timeout=None, + socket_keepalive=False, + socket_keepalive_options=None, + socket_type=0, + retry_on_timeout=False, + encoding="utf-8", + encoding_errors="strict", + decode_responses=False, + parser_class=DefaultParser, + socket_read_size=65536, + health_check_interval=0, + client_name=None, + username=None, + retry=None, + redis_connect_func=None, + ): """ Initialize a new Connection. To specify a retry policy, first set `retry_on_timeout` to `True` @@ -536,17 +561,13 @@ def __init__(self, host='localhost', port=6379, db=0, password=None, self._buffer_cutoff = 6000 def __repr__(self): - repr_args = ','.join([f'{k}={v}' for k, v in self.repr_pieces()]) - return f'{self.__class__.__name__}<{repr_args}>' + repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) + return f"{self.__class__.__name__}<{repr_args}>" def repr_pieces(self): - pieces = [ - ('host', self.host), - ('port', self.port), - ('db', self.db) - ] + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] if self.client_name: - pieces.append(('client_name', self.client_name)) + pieces.append(("client_name", self.client_name)) return pieces def __del__(self): @@ -606,8 +627,9 @@ def _connect(self): # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None - for res in socket.getaddrinfo(self.host, self.port, self.socket_type, - socket.SOCK_STREAM): + for res in socket.getaddrinfo( + self.host, self.port, self.socket_type, socket.SOCK_STREAM + ): family, socktype, proto, canonname, socket_address = res sock = None try: @@ -658,12 +680,12 @@ def on_connect(self): # if username and/or password are set, authenticate if self.username or self.password: if self.username: - auth_args = (self.username, self.password or '') + auth_args = (self.username, self.password or "") else: auth_args = (self.password,) # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH - self.send_command('AUTH', *auth_args, check_health=False) + self.send_command("AUTH", *auth_args, check_health=False) try: auth_response = self.read_response() @@ -672,23 +694,23 @@ def on_connect(self): # server seems to be < 6.0.0 which expects a single password # arg. retry auth with just the password. # https://github.com/andymccurdy/redis-py/issues/1274 - self.send_command('AUTH', self.password, check_health=False) + self.send_command("AUTH", self.password, check_health=False) auth_response = self.read_response() - if str_if_bytes(auth_response) != 'OK': - raise AuthenticationError('Invalid Username or Password') + if str_if_bytes(auth_response) != "OK": + raise AuthenticationError("Invalid Username or Password") # if a client_name is given, set it if self.client_name: - self.send_command('CLIENT', 'SETNAME', self.client_name) - if str_if_bytes(self.read_response()) != 'OK': - raise ConnectionError('Error setting client name') + self.send_command("CLIENT", "SETNAME", self.client_name) + if str_if_bytes(self.read_response()) != "OK": + raise ConnectionError("Error setting client name") # if a database is specified, switch to it if self.db: - self.send_command('SELECT', self.db) - if str_if_bytes(self.read_response()) != 'OK': - raise ConnectionError('Invalid Database') + self.send_command("SELECT", self.db) + if str_if_bytes(self.read_response()) != "OK": + raise ConnectionError("Invalid Database") def disconnect(self): "Disconnects from the Redis server" @@ -705,9 +727,9 @@ def disconnect(self): def _send_ping(self): """Send PING, expect PONG in return""" - self.send_command('PING', check_health=False) - if str_if_bytes(self.read_response()) != 'PONG': - raise ConnectionError('Bad response from PING health check') + self.send_command("PING", check_health=False) + if str_if_bytes(self.read_response()) != "PONG": + raise ConnectionError("Bad response from PING health check") def _ping_failed(self, error): """Function to call when PING fails""" @@ -736,7 +758,7 @@ def send_packed_command(self, command, check_health=True): except OSError as e: self.disconnect() if len(e.args) == 1: - errno, errmsg = 'UNKNOWN', e.args[0] + errno, errmsg = "UNKNOWN", e.args[0] else: errno = e.args[0] errmsg = e.args[1] @@ -747,8 +769,9 @@ def send_packed_command(self, command, check_health=True): def send_command(self, *args, **kwargs): """Pack and send a command to the Redis server""" - self.send_packed_command(self.pack_command(*args), - check_health=kwargs.get('check_health', True)) + self.send_packed_command( + self.pack_command(*args), check_health=kwargs.get("check_health", True) + ) def can_read(self, timeout=0): """Poll the socket to see if there's data that can be read.""" @@ -760,17 +783,15 @@ def can_read(self, timeout=0): def read_response(self, disable_decoding=False): """Read the response from a previously sent command""" try: - response = self._parser.read_response( - disable_decoding=disable_decoding - ) + response = self._parser.read_response(disable_decoding=disable_decoding) except socket.timeout: self.disconnect() raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: self.disconnect() raise ConnectionError( - f"Error while reading from {self.host}:{self.port}" - f" : {e.args}") + f"Error while reading from {self.host}:{self.port}" f" : {e.args}" + ) except BaseException: self.disconnect() raise @@ -792,7 +813,7 @@ def pack_command(self, *args): # not encoded. if isinstance(args[0], str): args = tuple(args[0].encode().split()) + args[1:] - elif b' ' in args[0]: + elif b" " in args[0]: args = tuple(args[0].split()) + args[1:] buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) @@ -802,17 +823,28 @@ def pack_command(self, *args): # to avoid large string mallocs, chunk the command into the # output list if we're sending large values or memoryviews arg_length = len(arg) - if (len(buff) > buffer_cutoff or arg_length > buffer_cutoff - or isinstance(arg, memoryview)): + if ( + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) + ): buff = SYM_EMPTY.join( - (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)) + (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) + ) output.append(buff) output.append(arg) buff = SYM_CRLF else: buff = SYM_EMPTY.join( - (buff, SYM_DOLLAR, str(arg_length).encode(), - SYM_CRLF, arg, SYM_CRLF)) + ( + buff, + SYM_DOLLAR, + str(arg_length).encode(), + SYM_CRLF, + arg, + SYM_CRLF, + ) + ) output.append(buff) return output @@ -826,8 +858,11 @@ def pack_commands(self, commands): for cmd in commands: for chunk in self.pack_command(*cmd): chunklen = len(chunk) - if (buffer_length > buffer_cutoff or chunklen > buffer_cutoff - or isinstance(chunk, memoryview)): + if ( + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) + ): output.append(SYM_EMPTY.join(pieces)) buffer_length = 0 pieces = [] @@ -844,10 +879,15 @@ def pack_commands(self, commands): class SSLConnection(Connection): - - def __init__(self, ssl_keyfile=None, ssl_certfile=None, - ssl_cert_reqs='required', ssl_ca_certs=None, - ssl_check_hostname=False, **kwargs): + def __init__( + self, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs="required", + ssl_ca_certs=None, + ssl_check_hostname=False, + **kwargs, + ): if not ssl_available: raise RedisError("Python wasn't built with SSL support") @@ -859,13 +899,14 @@ def __init__(self, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs = ssl.CERT_NONE elif isinstance(ssl_cert_reqs, str): CERT_REQS = { - 'none': ssl.CERT_NONE, - 'optional': ssl.CERT_OPTIONAL, - 'required': ssl.CERT_REQUIRED + "none": ssl.CERT_NONE, + "optional": ssl.CERT_OPTIONAL, + "required": ssl.CERT_REQUIRED, } if ssl_cert_reqs not in CERT_REQS: raise RedisError( - f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}") + f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}" + ) ssl_cert_reqs = CERT_REQS[ssl_cert_reqs] self.cert_reqs = ssl_cert_reqs self.ca_certs = ssl_ca_certs @@ -878,22 +919,30 @@ def _connect(self): context.check_hostname = self.check_hostname context.verify_mode = self.cert_reqs if self.certfile and self.keyfile: - context.load_cert_chain(certfile=self.certfile, - keyfile=self.keyfile) + context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) if self.ca_certs: context.load_verify_locations(self.ca_certs) return context.wrap_socket(sock, server_hostname=self.host) class UnixDomainSocketConnection(Connection): - - def __init__(self, path='', db=0, username=None, password=None, - socket_timeout=None, encoding='utf-8', - encoding_errors='strict', decode_responses=False, - retry_on_timeout=False, - parser_class=DefaultParser, socket_read_size=65536, - health_check_interval=0, client_name=None, - retry=None): + def __init__( + self, + path="", + db=0, + username=None, + password=None, + socket_timeout=None, + encoding="utf-8", + encoding_errors="strict", + decode_responses=False, + retry_on_timeout=False, + parser_class=DefaultParser, + socket_read_size=65536, + health_check_interval=0, + client_name=None, + retry=None, + ): """ Initialize a new UnixDomainSocketConnection. To specify a retry policy, first set `retry_on_timeout` to `True` @@ -926,11 +975,11 @@ def __init__(self, path='', db=0, username=None, password=None, def repr_pieces(self): pieces = [ - ('path', self.path), - ('db', self.db), + ("path", self.path), + ("db", self.db), ] if self.client_name: - pieces.append(('client_name', self.client_name)) + pieces.append(("client_name", self.client_name)) return pieces def _connect(self): @@ -952,11 +1001,11 @@ def _error_message(self, exception): ) -FALSE_STRINGS = ('0', 'F', 'FALSE', 'N', 'NO') +FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") def to_bool(value): - if value is None or value == '': + if value is None or value == "": return None if isinstance(value, str) and value.upper() in FALSE_STRINGS: return False @@ -964,14 +1013,14 @@ def to_bool(value): URL_QUERY_ARGUMENT_PARSERS = { - 'db': int, - 'socket_timeout': float, - 'socket_connect_timeout': float, - 'socket_keepalive': to_bool, - 'retry_on_timeout': to_bool, - 'max_connections': int, - 'health_check_interval': int, - 'ssl_check_hostname': to_bool, + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, } @@ -987,42 +1036,42 @@ def parse_url(url): try: kwargs[name] = parser(value) except (TypeError, ValueError): - raise ValueError( - f"Invalid value for `{name}` in connection URL." - ) + raise ValueError(f"Invalid value for `{name}` in connection URL.") else: kwargs[name] = value if url.username: - kwargs['username'] = unquote(url.username) + kwargs["username"] = unquote(url.username) if url.password: - kwargs['password'] = unquote(url.password) + kwargs["password"] = unquote(url.password) # We only support redis://, rediss:// and unix:// schemes. - if url.scheme == 'unix': + if url.scheme == "unix": if url.path: - kwargs['path'] = unquote(url.path) - kwargs['connection_class'] = UnixDomainSocketConnection + kwargs["path"] = unquote(url.path) + kwargs["connection_class"] = UnixDomainSocketConnection - elif url.scheme in ('redis', 'rediss'): + elif url.scheme in ("redis", "rediss"): if url.hostname: - kwargs['host'] = unquote(url.hostname) + kwargs["host"] = unquote(url.hostname) if url.port: - kwargs['port'] = int(url.port) + kwargs["port"] = int(url.port) # If there's a path argument, use it as the db argument if a # querystring value wasn't specified - if url.path and 'db' not in kwargs: + if url.path and "db" not in kwargs: try: - kwargs['db'] = int(unquote(url.path).replace('/', '')) + kwargs["db"] = int(unquote(url.path).replace("/", "")) except (AttributeError, ValueError): pass - if url.scheme == 'rediss': - kwargs['connection_class'] = SSLConnection + if url.scheme == "rediss": + kwargs["connection_class"] = SSLConnection else: - raise ValueError('Redis URL must specify one of the following ' - 'schemes (redis://, rediss://, unix://)') + raise ValueError( + "Redis URL must specify one of the following " + "schemes (redis://, rediss://, unix://)" + ) return kwargs @@ -1040,6 +1089,7 @@ class ConnectionPool: Any additional keyword arguments are passed to the constructor of ``connection_class``. """ + @classmethod def from_url(cls, url, **kwargs): """ @@ -1084,8 +1134,9 @@ class initializer. In the case of conflicting arguments, querystring kwargs.update(url_options) return cls(**kwargs) - def __init__(self, connection_class=Connection, max_connections=None, - **connection_kwargs): + def __init__( + self, connection_class=Connection, max_connections=None, **connection_kwargs + ): max_connections = max_connections or 2 ** 31 if not isinstance(max_connections, int) or max_connections < 0: raise ValueError('"max_connections" must be a positive integer') @@ -1194,12 +1245,12 @@ def get_connection(self, command_name, *keys, **options): # closed. either way, reconnect and verify everything is good. try: if connection.can_read(): - raise ConnectionError('Connection has data') + raise ConnectionError("Connection has data") except ConnectionError: connection.disconnect() connection.connect() if connection.can_read(): - raise ConnectionError('Connection not ready') + raise ConnectionError("Connection not ready") except BaseException: # release the connection back to the pool so that we don't # leak it @@ -1212,9 +1263,9 @@ def get_encoder(self): "Return an encoder based on encoding settings" kwargs = self.connection_kwargs return Encoder( - encoding=kwargs.get('encoding', 'utf-8'), - encoding_errors=kwargs.get('encoding_errors', 'strict'), - decode_responses=kwargs.get('decode_responses', False) + encoding=kwargs.get("encoding", "utf-8"), + encoding_errors=kwargs.get("encoding_errors", "strict"), + decode_responses=kwargs.get("decode_responses", False), ) def make_connection(self): @@ -1259,8 +1310,9 @@ def disconnect(self, inuse_connections=True): self._checkpid() with self._lock: if inuse_connections: - connections = chain(self._available_connections, - self._in_use_connections) + connections = chain( + self._available_connections, self._in_use_connections + ) else: connections = self._available_connections @@ -1301,16 +1353,23 @@ class BlockingConnectionPool(ConnectionPool): >>> # not available. >>> pool = BlockingConnectionPool(timeout=5) """ - def __init__(self, max_connections=50, timeout=20, - connection_class=Connection, queue_class=LifoQueue, - **connection_kwargs): + + def __init__( + self, + max_connections=50, + timeout=20, + connection_class=Connection, + queue_class=LifoQueue, + **connection_kwargs, + ): self.queue_class = queue_class self.timeout = timeout super().__init__( connection_class=connection_class, max_connections=max_connections, - **connection_kwargs) + **connection_kwargs, + ) def reset(self): # Create and fill up a thread safe queue with ``None`` values. @@ -1381,12 +1440,12 @@ def get_connection(self, command_name, *keys, **options): # closed. either way, reconnect and verify everything is good. try: if connection.can_read(): - raise ConnectionError('Connection has data') + raise ConnectionError("Connection has data") except ConnectionError: connection.disconnect() connection.connect() if connection.can_read(): - raise ConnectionError('Connection not ready') + raise ConnectionError("Connection not ready") except BaseException: # release the connection back to the pool so that we don't leak it self.release(connection) diff --git a/redis/sentinel.py b/redis/sentinel.py index 06877bd167..c9383d30a9 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -3,9 +3,8 @@ from redis.client import Redis from redis.commands import SentinelCommands -from redis.connection import ConnectionPool, Connection, SSLConnection -from redis.exceptions import (ConnectionError, ResponseError, ReadOnlyError, - TimeoutError) +from redis.connection import Connection, ConnectionPool, SSLConnection +from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError from redis.utils import str_if_bytes @@ -19,14 +18,14 @@ class SlaveNotFoundError(ConnectionError): class SentinelManagedConnection(Connection): def __init__(self, **kwargs): - self.connection_pool = kwargs.pop('connection_pool') + self.connection_pool = kwargs.pop("connection_pool") super().__init__(**kwargs) def __repr__(self): pool = self.connection_pool - s = f'{type(self).__name__}' + s = f"{type(self).__name__}" if self.host: - host_info = f',host={self.host},port={self.port}' + host_info = f",host={self.host},port={self.port}" s = s % host_info return s @@ -34,9 +33,9 @@ def connect_to(self, address): self.host, self.port = address super().connect() if self.connection_pool.check_connection: - self.send_command('PING') - if str_if_bytes(self.read_response()) != 'PONG': - raise ConnectionError('PING failed') + self.send_command("PING") + if str_if_bytes(self.read_response()) != "PONG": + raise ConnectionError("PING failed") def connect(self): if self._sock: @@ -62,7 +61,7 @@ def read_response(self, disable_decoding=False): # calling disconnect will force the connection to re-query # sentinel during the next connect() attempt. self.disconnect() - raise ConnectionError('The previous master is now a slave') + raise ConnectionError("The previous master is now a slave") raise @@ -79,19 +78,21 @@ class SentinelConnectionPool(ConnectionPool): """ def __init__(self, service_name, sentinel_manager, **kwargs): - kwargs['connection_class'] = kwargs.get( - 'connection_class', - SentinelManagedSSLConnection if kwargs.pop('ssl', False) - else SentinelManagedConnection) - self.is_master = kwargs.pop('is_master', True) - self.check_connection = kwargs.pop('check_connection', False) + kwargs["connection_class"] = kwargs.get( + "connection_class", + SentinelManagedSSLConnection + if kwargs.pop("ssl", False) + else SentinelManagedConnection, + ) + self.is_master = kwargs.pop("is_master", True) + self.check_connection = kwargs.pop("check_connection", False) super().__init__(**kwargs) - self.connection_kwargs['connection_pool'] = weakref.proxy(self) + self.connection_kwargs["connection_pool"] = weakref.proxy(self) self.service_name = service_name self.sentinel_manager = sentinel_manager def __repr__(self): - role = 'master' if self.is_master else 'slave' + role = "master" if self.is_master else "slave" return f"{type(self).__name__}' def check_master_state(self, state, service_name): - if not state['is_master'] or state['is_sdown'] or state['is_odown']: + if not state["is_master"] or state["is_sdown"] or state["is_odown"]: return False # Check if our sentinel doesn't see other nodes - if state['num-other-sentinels'] < self.min_other_sentinels: + if state["num-other-sentinels"] < self.min_other_sentinels: return False return True @@ -232,17 +238,19 @@ def discover_master(self, service_name): if state and self.check_master_state(state, service_name): # Put this sentinel at the top of the list self.sentinels[0], self.sentinels[sentinel_no] = ( - sentinel, self.sentinels[0]) - return state['ip'], state['port'] + sentinel, + self.sentinels[0], + ) + return state["ip"], state["port"] raise MasterNotFoundError(f"No master found for {service_name!r}") def filter_slaves(self, slaves): "Remove slaves that are in an ODOWN or SDOWN state" slaves_alive = [] for slave in slaves: - if slave['is_odown'] or slave['is_sdown']: + if slave["is_odown"] or slave["is_sdown"]: continue - slaves_alive.append((slave['ip'], slave['port'])) + slaves_alive.append((slave["ip"], slave["port"])) return slaves_alive def discover_slaves(self, service_name): @@ -257,8 +265,13 @@ def discover_slaves(self, service_name): return slaves return [] - def master_for(self, service_name, redis_class=Redis, - connection_pool_class=SentinelConnectionPool, **kwargs): + def master_for( + self, + service_name, + redis_class=Redis, + connection_pool_class=SentinelConnectionPool, + **kwargs, + ): """ Returns a redis client instance for the ``service_name`` master. @@ -281,14 +294,22 @@ def master_for(self, service_name, redis_class=Redis, passed to this class and passed to the connection pool as keyword arguments to be used to initialize Redis connections. """ - kwargs['is_master'] = True + kwargs["is_master"] = True connection_kwargs = dict(self.connection_kwargs) connection_kwargs.update(kwargs) - return redis_class(connection_pool=connection_pool_class( - service_name, self, **connection_kwargs)) - - def slave_for(self, service_name, redis_class=Redis, - connection_pool_class=SentinelConnectionPool, **kwargs): + return redis_class( + connection_pool=connection_pool_class( + service_name, self, **connection_kwargs + ) + ) + + def slave_for( + self, + service_name, + redis_class=Redis, + connection_pool_class=SentinelConnectionPool, + **kwargs, + ): """ Returns redis client instance for the ``service_name`` slave(s). @@ -306,8 +327,11 @@ def slave_for(self, service_name, redis_class=Redis, passed to this class and passed to the connection pool as keyword arguments to be used to initialize Redis connections. """ - kwargs['is_master'] = False + kwargs["is_master"] = False connection_kwargs = dict(self.connection_kwargs) connection_kwargs.update(kwargs) - return redis_class(connection_pool=connection_pool_class( - service_name, self, **connection_kwargs)) + return redis_class( + connection_pool=connection_pool_class( + service_name, self, **connection_kwargs + ) + ) diff --git a/tasks.py b/tasks.py index 8d9c4c64be..9291e7effb 100644 --- a/tasks.py +++ b/tasks.py @@ -1,11 +1,11 @@ import os import shutil -from invoke import task, run -with open('tox.ini') as fp: +from invoke import run, task + +with open("tox.ini") as fp: lines = fp.read().split("\n") - dockers = [line.split("=")[1].strip() for line in lines - if line.find("name") != -1] + dockers = [line.split("=")[1].strip() for line in lines if line.find("name") != -1] @task @@ -14,7 +14,7 @@ def devenv(c): specified in the tox.ini file. """ clean(c) - cmd = 'tox -e devenv' + cmd = "tox -e devenv" for d in dockers: cmd += f" --docker-dont-stop={d}" run(cmd) diff --git a/tests/conftest.py b/tests/conftest.py index 8ed39abddc..24783c0466 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,16 @@ -from redis.backoff import NoBackoff -from redis.retry import Retry -import pytest import random -import redis import time from distutils.version import LooseVersion -from redis.connection import parse_url -from redis.exceptions import RedisClusterException from unittest.mock import Mock from urllib.parse import urlparse +import pytest + +import redis +from redis.backoff import NoBackoff +from redis.connection import parse_url +from redis.exceptions import RedisClusterException +from redis.retry import Retry REDIS_INFO = {} default_redis_url = "redis://localhost:6379/9" @@ -19,29 +20,37 @@ def pytest_addoption(parser): - parser.addoption('--redis-url', default=default_redis_url, - action="store", - help="Redis connection string," - " defaults to `%(default)s`") - - parser.addoption('--redismod-url', default=default_redismod_url, - action="store", - help="Connection string to redis server" - " with loaded modules," - " defaults to `%(default)s`") - - parser.addoption('--redis-cluster-nodes', default=default_cluster_nodes, - action="store", - help="The number of cluster nodes that need to be " - "available before the test can start," - " defaults to `%(default)s`") + parser.addoption( + "--redis-url", + default=default_redis_url, + action="store", + help="Redis connection string," " defaults to `%(default)s`", + ) + + parser.addoption( + "--redismod-url", + default=default_redismod_url, + action="store", + help="Connection string to redis server" + " with loaded modules," + " defaults to `%(default)s`", + ) + + parser.addoption( + "--redis-cluster-nodes", + default=default_cluster_nodes, + action="store", + help="The number of cluster nodes that need to be " + "available before the test can start," + " defaults to `%(default)s`", + ) def _get_info(redis_url): client = redis.Redis.from_url(redis_url) info = client.info() cmds = [command.upper() for command in client.command().keys()] - if 'dping' in cmds: + if "dping" in cmds: info["enterprise"] = True else: info["enterprise"] = False @@ -102,42 +111,39 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=20): available_nodes = 0 if client is None else len(client.get_nodes()) raise RedisClusterException( f"The cluster did not become available after {timeout} seconds. " - f"Only {available_nodes} nodes out of {cluster_nodes} are available") + f"Only {available_nodes} nodes out of {cluster_nodes} are available" + ) def skip_if_server_version_lt(min_version): redis_version = REDIS_INFO["version"] check = LooseVersion(redis_version) < LooseVersion(min_version) - return pytest.mark.skipif( - check, - reason=f"Redis version required >= {min_version}") + return pytest.mark.skipif(check, reason=f"Redis version required >= {min_version}") def skip_if_server_version_gte(min_version): redis_version = REDIS_INFO["version"] check = LooseVersion(redis_version) >= LooseVersion(min_version) - return pytest.mark.skipif( - check, - reason=f"Redis version required < {min_version}") + return pytest.mark.skipif(check, reason=f"Redis version required < {min_version}") def skip_unless_arch_bits(arch_bits): - return pytest.mark.skipif(REDIS_INFO["arch_bits"] != arch_bits, - reason=f"server is not {arch_bits}-bit") + return pytest.mark.skipif( + REDIS_INFO["arch_bits"] != arch_bits, reason=f"server is not {arch_bits}-bit" + ) def skip_ifmodversion_lt(min_version: str, module_name: str): try: modules = REDIS_INFO["modules"] except KeyError: - return pytest.mark.skipif(True, - reason="Redis server does not have modules") + return pytest.mark.skipif(True, reason="Redis server does not have modules") if modules == []: return pytest.mark.skipif(True, reason="No redis modules found") for j in modules: - if module_name == j.get('name'): - version = j.get('ver') + if module_name == j.get("name"): + version = j.get("ver") mv = int(min_version.replace(".", "")) check = version < mv return pytest.mark.skipif(check, reason="Redis module version") @@ -155,9 +161,9 @@ def skip_ifnot_redis_enterprise(func): return pytest.mark.skipif(check, reason="Not running in redis enterprise") -def _get_client(cls, request, single_connection_client=True, flushdb=True, - from_url=None, - **kwargs): +def _get_client( + cls, request, single_connection_client=True, flushdb=True, from_url=None, **kwargs +): """ Helper for fixtures or tests that need a Redis client @@ -181,6 +187,7 @@ def _get_client(cls, request, single_connection_client=True, flushdb=True, if single_connection_client: client = client.client() if request: + def teardown(): if not cluster_mode: if flushdb: @@ -194,6 +201,7 @@ def teardown(): client.connection_pool.disconnect() else: cluster_teardown(client, flushdb) + request.addfinalizer(teardown) return client @@ -201,11 +209,11 @@ def teardown(): def cluster_teardown(client, flushdb): if flushdb: try: - client.flushdb(target_nodes='primaries') + client.flushdb(target_nodes="primaries") except redis.ConnectionError: # handle cases where a test disconnected a client # just manually retry the flushdb - client.flushdb(target_nodes='primaries') + client.flushdb(target_nodes="primaries") client.close() client.disconnect_connection_pools() @@ -214,9 +222,10 @@ def cluster_teardown(client, flushdb): # an index on db != 0 raises a ResponseError in redis @pytest.fixture() def modclient(request, **kwargs): - rmurl = request.config.getoption('--redismod-url') - with _get_client(redis.Redis, request, from_url=rmurl, - decode_responses=True, **kwargs) as client: + rmurl = request.config.getoption("--redismod-url") + with _get_client( + redis.Redis, request, from_url=rmurl, decode_responses=True, **kwargs + ) as client: yield client @@ -250,56 +259,61 @@ def _gen_cluster_mock_resp(r, response): @pytest.fixture() def mock_cluster_resp_ok(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - return _gen_cluster_mock_resp(r, 'OK') + return _gen_cluster_mock_resp(r, "OK") @pytest.fixture() def mock_cluster_resp_int(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - return _gen_cluster_mock_resp(r, '2') + return _gen_cluster_mock_resp(r, "2") @pytest.fixture() def mock_cluster_resp_info(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - response = ('cluster_state:ok\r\ncluster_slots_assigned:16384\r\n' - 'cluster_slots_ok:16384\r\ncluster_slots_pfail:0\r\n' - 'cluster_slots_fail:0\r\ncluster_known_nodes:7\r\n' - 'cluster_size:3\r\ncluster_current_epoch:7\r\n' - 'cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n' - 'cluster_stats_messages_received:105653\r\n') + response = ( + "cluster_state:ok\r\ncluster_slots_assigned:16384\r\n" + "cluster_slots_ok:16384\r\ncluster_slots_pfail:0\r\n" + "cluster_slots_fail:0\r\ncluster_known_nodes:7\r\n" + "cluster_size:3\r\ncluster_current_epoch:7\r\n" + "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n" + "cluster_stats_messages_received:105653\r\n" + ) return _gen_cluster_mock_resp(r, response) @pytest.fixture() def mock_cluster_resp_nodes(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - response = ('c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 ' - 'slave aa90da731f673a99617dfe930306549a09f83a6b 0 ' - '1447836263059 5 connected\n' - '9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 ' - 'master - 0 1447836264065 0 connected\n' - 'aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 ' - 'myself,master - 0 0 2 connected 5461-10922\n' - '1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 ' - 'slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 ' - '1447836262556 3 connected\n' - '4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 ' - 'master - 0 1447836262555 7 connected 0-5460\n' - '19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 ' - 'master - 0 1447836263562 3 connected 10923-16383\n' - 'fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 ' - 'master,fail - 1447829446956 1447829444948 1 disconnected\n' - ) + response = ( + "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 " + "slave aa90da731f673a99617dfe930306549a09f83a6b 0 " + "1447836263059 5 connected\n" + "9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 " + "master - 0 1447836264065 0 connected\n" + "aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 " + "myself,master - 0 0 2 connected 5461-10922\n" + "1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836262556 3 connected\n" + "4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 " + "master - 0 1447836262555 7 connected 0-5460\n" + "19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 " + "master - 0 1447836263562 3 connected 10923-16383\n" + "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " + "master,fail - 1447829446956 1447829444948 1 disconnected\n" + ) return _gen_cluster_mock_resp(r, response) @pytest.fixture() def mock_cluster_resp_slaves(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - response = ("['1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " - "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " - "1447836789290 3 connected']") + response = ( + "['1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836789290 3 connected']" + ) return _gen_cluster_mock_resp(r, response) @@ -315,15 +329,15 @@ def wait_for_command(client, monitor, command): # if we find a command with our key before the command we're waiting # for, something went wrong redis_version = REDIS_INFO["version"] - if LooseVersion(redis_version) >= LooseVersion('5.0.0'): + if LooseVersion(redis_version) >= LooseVersion("5.0.0"): id_str = str(client.client_id()) else: - id_str = f'{random.randrange(2 ** 32):08x}' - key = f'__REDIS-PY-{id_str}__' + id_str = f"{random.randrange(2 ** 32):08x}" + key = f"__REDIS-PY-{id_str}__" client.get(key) while True: monitor_response = monitor.next_command() - if command in monitor_response['command']: + if command in monitor_response["command"]: return monitor_response - if key in monitor_response['command']: + if key in monitor_response["command"]: return None diff --git a/tests/test_cluster.py b/tests/test_cluster.py index d12e47ed02..84d74bd43b 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1,46 +1,47 @@ import binascii import datetime -import pytest import warnings - from time import sleep -from tests.test_pubsub import wait_for_message -from unittest.mock import call, patch, DEFAULT, Mock +from unittest.mock import DEFAULT, Mock, call, patch + +import pytest + from redis import Redis -from redis.cluster import get_node_name, ClusterNode, \ - RedisCluster, NodesManager, PRIMARY, REDIS_CLUSTER_HASH_SLOTS, REPLICA +from redis.cluster import ( + PRIMARY, + REDIS_CLUSTER_HASH_SLOTS, + REPLICA, + ClusterNode, + NodesManager, + RedisCluster, + get_node_name, +) from redis.commands import CommandsParser from redis.connection import Connection -from redis.utils import str_if_bytes +from redis.crc import key_slot from redis.exceptions import ( AskError, ClusterDownError, DataError, MovedError, RedisClusterException, - RedisError + RedisError, ) +from redis.utils import str_if_bytes +from tests.test_pubsub import wait_for_message -from redis.crc import key_slot -from .conftest import ( - _get_client, - skip_if_server_version_lt, - skip_unless_arch_bits -) +from .conftest import _get_client, skip_if_server_version_lt, skip_unless_arch_bits default_host = "127.0.0.1" default_port = 7000 default_cluster_slots = [ [ - 0, 8191, - ['127.0.0.1', 7000, 'node_0'], - ['127.0.0.1', 7003, 'node_3'], + 0, + 8191, + ["127.0.0.1", 7000, "node_0"], + ["127.0.0.1", 7003, "node_3"], ], - [ - 8192, 16383, - ['127.0.0.1', 7001, 'node_1'], - ['127.0.0.1', 7002, 'node_2'] - ] + [8192, 16383, ["127.0.0.1", 7001, "node_1"], ["127.0.0.1", 7002, "node_2"]], ] @@ -53,21 +54,20 @@ def slowlog(request, r): to test it """ # Save old values - current_config = r.config_get( - target_nodes=r.get_primaries()[0]) - old_slower_than_value = current_config['slowlog-log-slower-than'] - old_max_legnth_value = current_config['slowlog-max-len'] + current_config = r.config_get(target_nodes=r.get_primaries()[0]) + old_slower_than_value = current_config["slowlog-log-slower-than"] + old_max_legnth_value = current_config["slowlog-max-len"] # Function to restore the old values def cleanup(): - r.config_set('slowlog-log-slower-than', old_slower_than_value) - r.config_set('slowlog-max-len', old_max_legnth_value) + r.config_set("slowlog-log-slower-than", old_slower_than_value) + r.config_set("slowlog-max-len", old_max_legnth_value) request.addfinalizer(cleanup) # Set the new values - r.config_set('slowlog-log-slower-than', 0) - r.config_set('slowlog-max-len', 128) + r.config_set("slowlog-log-slower-than", 0) + r.config_set("slowlog-max-len", 128) def get_mocked_redis_client(func=None, *args, **kwargs): @@ -76,17 +76,18 @@ def get_mocked_redis_client(func=None, *args, **kwargs): nodes and slots setup to remove the problem of different IP addresses on different installations and machines. """ - cluster_slots = kwargs.pop('cluster_slots', default_cluster_slots) - coverage_res = kwargs.pop('coverage_result', 'yes') - with patch.object(Redis, 'execute_command') as execute_command_mock: + cluster_slots = kwargs.pop("cluster_slots", default_cluster_slots) + coverage_res = kwargs.pop("coverage_result", "yes") + with patch.object(Redis, "execute_command") as execute_command_mock: + def execute_command(*_args, **_kwargs): - if _args[0] == 'CLUSTER SLOTS': + if _args[0] == "CLUSTER SLOTS": mock_cluster_slots = cluster_slots return mock_cluster_slots - elif _args[0] == 'COMMAND': - return {'get': [], 'set': []} - elif _args[1] == 'cluster-require-full-coverage': - return {'cluster-require-full-coverage': coverage_res} + elif _args[0] == "COMMAND": + return {"get": [], "set": []} + elif _args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": coverage_res} elif func is not None: return func(*args, **kwargs) else: @@ -94,16 +95,21 @@ def execute_command(*_args, **_kwargs): execute_command_mock.side_effect = execute_command - with patch.object(CommandsParser, 'initialize', - autospec=True) as cmd_parser_initialize: + with patch.object( + CommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize: def cmd_init_mock(self, r): - self.commands = {'get': {'name': 'get', 'arity': 2, - 'flags': ['readonly', - 'fast'], - 'first_key_pos': 1, - 'last_key_pos': 1, - 'step_count': 1}} + self.commands = { + "get": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } cmd_parser_initialize.side_effect = cmd_init_mock @@ -138,21 +144,21 @@ def find_node_ip_based_on_port(cluster_client, port): def moved_redirection_helper(request, failover=False): """ - Test that the client handles MOVED response after a failover. - Redirection after a failover means that the redirection address is of a - replica that was promoted to a primary. + Test that the client handles MOVED response after a failover. + Redirection after a failover means that the redirection address is of a + replica that was promoted to a primary. - At first call it should return a MOVED ResponseError that will point - the client to the next server it should talk to. + At first call it should return a MOVED ResponseError that will point + the client to the next server it should talk to. - Verify that: - 1. it tries to talk to the redirected node - 2. it updates the slot's primary to the redirected node + Verify that: + 1. it tries to talk to the redirected node + 2. it updates the slot's primary to the redirected node - For a failover, also verify: - 3. the redirected node's server type updated to 'primary' - 4. the server type of the previous slot owner updated to 'replica' - """ + For a failover, also verify: + 3. the redirected node's server type updated to 'primary' + 4. the server type of the previous slot owner updated to 'replica' + """ rc = _get_client(RedisCluster, request, flushdb=False) slot = 12182 redirect_node = None @@ -160,8 +166,7 @@ def moved_redirection_helper(request, failover=False): prev_primary = rc.nodes_manager.get_node_from_slot(slot) if failover: if len(rc.nodes_manager.slots_cache[slot]) < 2: - warnings.warn("Skipping this test since it requires to have a " - "replica") + warnings.warn("Skipping this test since it requires to have a " "replica") return redirect_node = rc.nodes_manager.slots_cache[slot][1] else: @@ -169,7 +174,8 @@ def moved_redirection_helper(request, failover=False): redirect_node = rc.get_primaries()[0] r_host = redirect_node.host r_port = redirect_node.port - with patch.object(Redis, 'parse_response') as parse_response: + with patch.object(Redis, "parse_response") as parse_response: + def moved_redirect_effect(connection, *args, **options): def ok_response(connection, *args, **options): assert connection.host == r_host @@ -201,8 +207,7 @@ def test_host_port_startup_node(self): args """ cluster = get_mocked_redis_client(host=default_host, port=default_port) - assert cluster.get_node(host=default_host, - port=default_port) is not None + assert cluster.get_node(host=default_host, port=default_port) is not None def test_startup_nodes(self): """ @@ -211,11 +216,15 @@ def test_startup_nodes(self): """ port_1 = 7000 port_2 = 7001 - startup_nodes = [ClusterNode(default_host, port_1), - ClusterNode(default_host, port_2)] + startup_nodes = [ + ClusterNode(default_host, port_1), + ClusterNode(default_host, port_2), + ] cluster = get_mocked_redis_client(startup_nodes=startup_nodes) - assert cluster.get_node(host=default_host, port=port_1) is not None \ - and cluster.get_node(host=default_host, port=port_2) is not None + assert ( + cluster.get_node(host=default_host, port=port_1) is not None + and cluster.get_node(host=default_host, port=port_2) is not None + ) def test_empty_startup_nodes(self): """ @@ -225,19 +234,19 @@ def test_empty_startup_nodes(self): RedisCluster(startup_nodes=[]) assert str(ex.value).startswith( - "RedisCluster requires at least one node to discover the " - "cluster"), str_if_bytes(ex.value) + "RedisCluster requires at least one node to discover the " "cluster" + ), str_if_bytes(ex.value) def test_from_url(self, r): redis_url = f"redis://{default_host}:{default_port}/0" - with patch.object(RedisCluster, 'from_url') as from_url: + with patch.object(RedisCluster, "from_url") as from_url: + def from_url_mocked(_url, **_kwargs): return get_mocked_redis_client(url=_url, **_kwargs) from_url.side_effect = from_url_mocked cluster = RedisCluster.from_url(redis_url) - assert cluster.get_node(host=default_host, - port=default_port) is not None + assert cluster.get_node(host=default_host, port=default_port) is not None def test_execute_command_errors(self, r): """ @@ -245,8 +254,9 @@ def test_execute_command_errors(self, r): """ with pytest.raises(RedisClusterException) as ex: r.execute_command("GET") - assert str(ex.value).startswith("No way to dispatch this command to " - "Redis Cluster. Missing key.") + assert str(ex.value).startswith( + "No way to dispatch this command to " "Redis Cluster. Missing key." + ) def test_execute_command_node_flag_primaries(self, r): """ @@ -254,7 +264,7 @@ def test_execute_command_node_flag_primaries(self, r): """ primaries = r.get_primaries() replicas = r.get_replicas() - mock_all_nodes_resp(r, 'PONG') + mock_all_nodes_resp(r, "PONG") assert r.ping(RedisCluster.PRIMARIES) is True for primary in primaries: conn = primary.redis_connection.connection @@ -271,7 +281,7 @@ def test_execute_command_node_flag_replicas(self, r): if not replicas: r = get_mocked_redis_client(default_host, default_port) primaries = r.get_primaries() - mock_all_nodes_resp(r, 'PONG') + mock_all_nodes_resp(r, "PONG") assert r.ping(RedisCluster.REPLICAS) is True for replica in replicas: conn = replica.redis_connection.connection @@ -284,7 +294,7 @@ def test_execute_command_node_flag_all_nodes(self, r): """ Test command execution with nodes flag ALL_NODES """ - mock_all_nodes_resp(r, 'PONG') + mock_all_nodes_resp(r, "PONG") assert r.ping(RedisCluster.ALL_NODES) is True for node in r.get_nodes(): conn = node.redis_connection.connection @@ -294,7 +304,7 @@ def test_execute_command_node_flag_random(self, r): """ Test command execution with nodes flag RANDOM """ - mock_all_nodes_resp(r, 'PONG') + mock_all_nodes_resp(r, "PONG") assert r.ping(RedisCluster.RANDOM) is True called_count = 0 for node in r.get_nodes(): @@ -309,7 +319,7 @@ def test_execute_command_default_node(self, r): default node """ def_node = r.get_default_node() - mock_node_resp(def_node, 'PONG') + mock_node_resp(def_node, "PONG") assert r.ping() is True conn = def_node.redis_connection.connection assert conn.read_response.called @@ -324,7 +334,8 @@ def test_ask_redirection(self, r): Important thing to verify is that it tries to talk to the second node. """ redirect_node = r.get_nodes()[0] - with patch.object(Redis, 'parse_response') as parse_response: + with patch.object(Redis, "parse_response") as parse_response: + def ask_redirect_effect(connection, *args, **options): def ok_response(connection, *args, **options): assert connection.host == redirect_node.host @@ -356,26 +367,22 @@ def test_refresh_using_specific_nodes(self, request): Test making calls on specific nodes when the cluster has failed over to another node """ - node_7006 = ClusterNode(host=default_host, port=7006, - server_type=PRIMARY) - node_7007 = ClusterNode(host=default_host, port=7007, - server_type=PRIMARY) - with patch.object(Redis, 'parse_response') as parse_response: - with patch.object(NodesManager, 'initialize', autospec=True) as \ - initialize: - with patch.multiple(Connection, - send_command=DEFAULT, - connect=DEFAULT, - can_read=DEFAULT) as mocks: + node_7006 = ClusterNode(host=default_host, port=7006, server_type=PRIMARY) + node_7007 = ClusterNode(host=default_host, port=7007, server_type=PRIMARY) + with patch.object(Redis, "parse_response") as parse_response: + with patch.object(NodesManager, "initialize", autospec=True) as initialize: + with patch.multiple( + Connection, send_command=DEFAULT, connect=DEFAULT, can_read=DEFAULT + ) as mocks: # simulate 7006 as a failed node - def parse_response_mock(connection, command_name, - **options): + def parse_response_mock(connection, command_name, **options): if connection.port == 7006: parse_response.failed_calls += 1 raise ClusterDownError( - 'CLUSTERDOWN The cluster is ' - 'down. Use CLUSTER INFO for ' - 'more information') + "CLUSTERDOWN The cluster is " + "down. Use CLUSTER INFO for " + "more information" + ) elif connection.port == 7007: parse_response.successful_calls += 1 @@ -391,8 +398,7 @@ def initialize_mock(self): # After the first connection fails, a reinitialize # should follow the cluster to 7007 def map_7007(self): - self.nodes_cache = { - node_7007.name: node_7007} + self.nodes_cache = {node_7007.name: node_7007} self.default_node = node_7007 self.slots_cache = {} @@ -406,44 +412,52 @@ def map_7007(self): parse_response.successful_calls = 0 parse_response.failed_calls = 0 initialize.side_effect = initialize_mock - mocks['can_read'].return_value = False - mocks['send_command'].return_value = "MOCK_OK" - mocks['connect'].return_value = None - with patch.object(CommandsParser, 'initialize', - autospec=True) as cmd_parser_initialize: + mocks["can_read"].return_value = False + mocks["send_command"].return_value = "MOCK_OK" + mocks["connect"].return_value = None + with patch.object( + CommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize: def cmd_init_mock(self, r): - self.commands = {'get': {'name': 'get', 'arity': 2, - 'flags': ['readonly', - 'fast'], - 'first_key_pos': 1, - 'last_key_pos': 1, - 'step_count': 1}} + self.commands = { + "get": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } cmd_parser_initialize.side_effect = cmd_init_mock - rc = _get_client( - RedisCluster, request, flushdb=False) + rc = _get_client(RedisCluster, request, flushdb=False) assert len(rc.get_nodes()) == 1 - assert rc.get_node(node_name=node_7006.name) is not \ - None + assert rc.get_node(node_name=node_7006.name) is not None - rc.get('foo') + rc.get("foo") # Cluster should now point to 7007, and there should be # one failed and one successful call assert len(rc.get_nodes()) == 1 - assert rc.get_node(node_name=node_7007.name) is not \ - None + assert rc.get_node(node_name=node_7007.name) is not None assert rc.get_node(node_name=node_7006.name) is None assert parse_response.failed_calls == 1 assert parse_response.successful_calls == 1 def test_reading_from_replicas_in_round_robin(self): - with patch.multiple(Connection, send_command=DEFAULT, - read_response=DEFAULT, _connect=DEFAULT, - can_read=DEFAULT, on_connect=DEFAULT) as mocks: - with patch.object(Redis, 'parse_response') as parse_response: + with patch.multiple( + Connection, + send_command=DEFAULT, + read_response=DEFAULT, + _connect=DEFAULT, + can_read=DEFAULT, + on_connect=DEFAULT, + ) as mocks: + with patch.object(Redis, "parse_response") as parse_response: + def parse_response_mock_first(connection, *args, **options): # Primary assert connection.port == 7001 @@ -465,16 +479,16 @@ def parse_response_mock_third(connection, *args, **options): # do want RedisCluster.on_connect function to get called, # so we'll mock some of the Connection's functions to allow it parse_response.side_effect = parse_response_mock_first - mocks['send_command'].return_value = True - mocks['read_response'].return_value = "OK" - mocks['_connect'].return_value = True - mocks['can_read'].return_value = False - mocks['on_connect'].return_value = True + mocks["send_command"].return_value = True + mocks["read_response"].return_value = "OK" + mocks["_connect"].return_value = True + mocks["can_read"].return_value = False + mocks["on_connect"].return_value = True # Create a cluster with reading from replications - read_cluster = get_mocked_redis_client(host=default_host, - port=default_port, - read_from_replicas=True) + read_cluster = get_mocked_redis_client( + host=default_host, port=default_port, read_from_replicas=True + ) assert read_cluster.read_from_replicas is True # Check that we read from the slot's nodes in a round robin # matter. @@ -483,7 +497,7 @@ def parse_response_mock_third(connection, *args, **options): read_cluster.get("foo") read_cluster.get("foo") read_cluster.get("foo") - mocks['send_command'].assert_has_calls([call('READONLY')]) + mocks["send_command"].assert_has_calls([call("READONLY")]) def test_keyslot(self, r): """ @@ -503,8 +517,10 @@ def test_keyslot(self, r): assert r.keyslot(b"abc") == r.keyslot("abc") def test_get_node_name(self): - assert get_node_name(default_host, default_port) == \ - f"{default_host}:{default_port}" + assert ( + get_node_name(default_host, default_port) + == f"{default_host}:{default_port}" + ) def test_all_nodes(self, r): """ @@ -520,8 +536,11 @@ def test_all_nodes_masters(self, r): Set a list of nodes with random primaries/replicas config and it shold be possible to iterate over all of them. """ - nodes = [node for node in r.nodes_manager.nodes_cache.values() - if node.server_type == PRIMARY] + nodes = [ + node + for node in r.nodes_manager.nodes_cache.values() + if node.server_type == PRIMARY + ] for node in r.get_primaries(): assert node in nodes @@ -532,12 +551,14 @@ def test_cluster_down_overreaches_retry_attempts(self): command as many times as configured in cluster_error_retry_attempts and then raise the exception """ - with patch.object(RedisCluster, '_execute_command') as execute_command: + with patch.object(RedisCluster, "_execute_command") as execute_command: + def raise_cluster_down_error(target_node, *args, **kwargs): execute_command.failed_calls += 1 raise ClusterDownError( - 'CLUSTERDOWN The cluster is down. Use CLUSTER INFO for ' - 'more information') + "CLUSTERDOWN The cluster is down. Use CLUSTER INFO for " + "more information" + ) execute_command.side_effect = raise_cluster_down_error @@ -545,8 +566,7 @@ def raise_cluster_down_error(target_node, *args, **kwargs): with pytest.raises(ClusterDownError): rc.get("bar") - assert execute_command.failed_calls == \ - rc.cluster_error_retry_attempts + assert execute_command.failed_calls == rc.cluster_error_retry_attempts def test_connection_error_overreaches_retry_attempts(self): """ @@ -554,7 +574,8 @@ def test_connection_error_overreaches_retry_attempts(self): command as many times as configured in cluster_error_retry_attempts and then raise the exception """ - with patch.object(RedisCluster, '_execute_command') as execute_command: + with patch.object(RedisCluster, "_execute_command") as execute_command: + def raise_conn_error(target_node, *args, **kwargs): execute_command.failed_calls += 1 raise ConnectionError() @@ -565,8 +586,7 @@ def raise_conn_error(target_node, *args, **kwargs): with pytest.raises(ConnectionError): rc.get("bar") - assert execute_command.failed_calls == \ - rc.cluster_error_retry_attempts + assert execute_command.failed_calls == rc.cluster_error_retry_attempts def test_user_on_connect_function(self, request): """ @@ -600,7 +620,7 @@ def test_set_default_node_failure(self, r): test failed replacement of the default cluster node """ default_node = r.get_default_node() - new_def_node = ClusterNode('1.1.1.1', 1111) + new_def_node = ClusterNode("1.1.1.1", 1111) assert r.set_default_node(None) is False assert r.set_default_node(new_def_node) is False assert r.get_default_node() == default_node @@ -609,7 +629,7 @@ def test_get_node_from_key(self, r): """ Test that get_node_from_key function returns the correct node """ - key = 'bar' + key = "bar" slot = r.keyslot(key) slot_nodes = r.nodes_manager.slots_cache.get(slot) primary = slot_nodes[0] @@ -627,78 +647,79 @@ class TestClusterRedisCommands: """ def test_case_insensitive_command_names(self, r): - assert r.cluster_response_callbacks['cluster addslots'] == \ - r.cluster_response_callbacks['CLUSTER ADDSLOTS'] + assert ( + r.cluster_response_callbacks["cluster addslots"] + == r.cluster_response_callbacks["CLUSTER ADDSLOTS"] + ) def test_get_and_set(self, r): # get and set can't be tested independently of each other - assert r.get('a') is None - byte_string = b'value' + assert r.get("a") is None + byte_string = b"value" integer = 5 - unicode_string = chr(3456) + 'abcd' + chr(3421) - assert r.set('byte_string', byte_string) - assert r.set('integer', 5) - assert r.set('unicode_string', unicode_string) - assert r.get('byte_string') == byte_string - assert r.get('integer') == str(integer).encode() - assert r.get('unicode_string').decode('utf-8') == unicode_string + unicode_string = chr(3456) + "abcd" + chr(3421) + assert r.set("byte_string", byte_string) + assert r.set("integer", 5) + assert r.set("unicode_string", unicode_string) + assert r.get("byte_string") == byte_string + assert r.get("integer") == str(integer).encode() + assert r.get("unicode_string").decode("utf-8") == unicode_string def test_mget_nonatomic(self, r): assert r.mget_nonatomic([]) == [] - assert r.mget_nonatomic(['a', 'b']) == [None, None] - r['a'] = '1' - r['b'] = '2' - r['c'] = '3' + assert r.mget_nonatomic(["a", "b"]) == [None, None] + r["a"] = "1" + r["b"] = "2" + r["c"] = "3" - assert (r.mget_nonatomic('a', 'other', 'b', 'c') == - [b'1', None, b'2', b'3']) + assert r.mget_nonatomic("a", "other", "b", "c") == [b"1", None, b"2", b"3"] def test_mset_nonatomic(self, r): - d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} assert r.mset_nonatomic(d) for k, v in d.items(): assert r[k] == v def test_config_set(self, r): - assert r.config_set('slowlog-log-slower-than', 0) + assert r.config_set("slowlog-log-slower-than", 0) def test_cluster_config_resetstat(self, r): - r.ping(target_nodes='all') - all_info = r.info(target_nodes='all') + r.ping(target_nodes="all") + all_info = r.info(target_nodes="all") prior_commands_processed = -1 for node_info in all_info.values(): - prior_commands_processed = node_info['total_commands_processed'] + prior_commands_processed = node_info["total_commands_processed"] assert prior_commands_processed >= 1 - r.config_resetstat(target_nodes='all') - all_info = r.info(target_nodes='all') + r.config_resetstat(target_nodes="all") + all_info = r.info(target_nodes="all") for node_info in all_info.values(): - reset_commands_processed = node_info['total_commands_processed'] + reset_commands_processed = node_info["total_commands_processed"] assert reset_commands_processed < prior_commands_processed def test_client_setname(self, r): node = r.get_random_node() - r.client_setname('redis_py_test', target_nodes=node) + r.client_setname("redis_py_test", target_nodes=node) client_name = r.client_getname(target_nodes=node) - assert client_name == 'redis_py_test' + assert client_name == "redis_py_test" def test_exists(self, r): - d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} r.mset_nonatomic(d) assert r.exists(*d.keys()) == len(d) def test_delete(self, r): - d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} r.mset_nonatomic(d) assert r.delete(*d.keys()) == len(d) assert r.delete(*d.keys()) == 0 def test_touch(self, r): - d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} r.mset_nonatomic(d) assert r.touch(*d.keys()) == len(d) def test_unlink(self, r): - d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} r.mset_nonatomic(d) assert r.unlink(*d.keys()) == len(d) # Unlink is non-blocking so we sleep before @@ -718,7 +739,7 @@ def test_pubsub_channels_merge_results(self, r): p = r.pubsub(node) pubsub_nodes.append(p) p.subscribe(channel) - b_channel = channel.encode('utf-8') + b_channel = channel.encode("utf-8") channels.append(b_channel) # Assert that each node returns only the channel it subscribed to sub_channels = node.redis_connection.pubsub_channels() @@ -730,7 +751,7 @@ def test_pubsub_channels_merge_results(self, r): i += 1 # Assert that the cluster's pubsub_channels function returns ALL of # the cluster's channels - result = r.pubsub_channels(target_nodes='all') + result = r.pubsub_channels(target_nodes="all") result.sort() assert result == channels @@ -738,7 +759,7 @@ def test_pubsub_numsub_merge_results(self, r): nodes = r.get_nodes() pubsub_nodes = [] channel = "foo" - b_channel = channel.encode('utf-8') + b_channel = channel.encode("utf-8") for node in nodes: # We will create different pubsub clients where each one is # connected to a different node @@ -753,8 +774,7 @@ def test_pubsub_numsub_merge_results(self, r): assert sub_chann_num == [(b_channel, 1)] # Assert that the cluster's pubsub_numsub function returns ALL clients # subscribed to this channel in the entire cluster - assert r.pubsub_numsub(channel, target_nodes='all') == \ - [(b_channel, len(nodes))] + assert r.pubsub_numsub(channel, target_nodes="all") == [(b_channel, len(nodes))] def test_pubsub_numpat_merge_results(self, r): nodes = r.get_nodes() @@ -774,35 +794,35 @@ def test_pubsub_numpat_merge_results(self, r): assert sub_num_pat == 1 # Assert that the cluster's pubsub_numsub function returns ALL clients # subscribed to this channel in the entire cluster - assert r.pubsub_numpat(target_nodes='all') == len(nodes) + assert r.pubsub_numpat(target_nodes="all") == len(nodes) - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_cluster_pubsub_channels(self, r): p = r.pubsub() - p.subscribe('foo', 'bar', 'baz', 'quux') + p.subscribe("foo", "bar", "baz", "quux") for i in range(4): - assert wait_for_message(p, timeout=0.5)['type'] == 'subscribe' - expected = [b'bar', b'baz', b'foo', b'quux'] - assert all([channel in r.pubsub_channels(target_nodes='all') - for channel in expected]) + assert wait_for_message(p, timeout=0.5)["type"] == "subscribe" + expected = [b"bar", b"baz", b"foo", b"quux"] + assert all( + [channel in r.pubsub_channels(target_nodes="all") for channel in expected] + ) - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_cluster_pubsub_numsub(self, r): p1 = r.pubsub() - p1.subscribe('foo', 'bar', 'baz') + p1.subscribe("foo", "bar", "baz") for i in range(3): - assert wait_for_message(p1, timeout=0.5)['type'] == 'subscribe' + assert wait_for_message(p1, timeout=0.5)["type"] == "subscribe" p2 = r.pubsub() - p2.subscribe('bar', 'baz') + p2.subscribe("bar", "baz") for i in range(2): - assert wait_for_message(p2, timeout=0.5)['type'] == 'subscribe' + assert wait_for_message(p2, timeout=0.5)["type"] == "subscribe" p3 = r.pubsub() - p3.subscribe('baz') - assert wait_for_message(p3, timeout=0.5)['type'] == 'subscribe' + p3.subscribe("baz") + assert wait_for_message(p3, timeout=0.5)["type"] == "subscribe" - channels = [(b'foo', 1), (b'bar', 2), (b'baz', 3)] - assert r.pubsub_numsub('foo', 'bar', 'baz', target_nodes='all') \ - == channels + channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] + assert r.pubsub_numsub("foo", "bar", "baz", target_nodes="all") == channels def test_cluster_slots(self, r): mock_all_nodes_resp(r, default_cluster_slots) @@ -810,12 +830,11 @@ def test_cluster_slots(self, r): assert isinstance(cluster_slots, dict) assert len(default_cluster_slots) == len(cluster_slots) assert cluster_slots.get((0, 8191)) is not None - assert cluster_slots.get((0, 8191)).get('primary') == \ - ('127.0.0.1', 7000) + assert cluster_slots.get((0, 8191)).get("primary") == ("127.0.0.1", 7000) def test_cluster_addslots(self, r): node = r.get_random_node() - mock_node_resp(node, 'OK') + mock_node_resp(node, "OK") assert r.cluster_addslots(node, 1, 2, 3) is True def test_cluster_countkeysinslot(self, r): @@ -825,22 +844,25 @@ def test_cluster_countkeysinslot(self, r): def test_cluster_count_failure_report(self, r): mock_all_nodes_resp(r, 0) - assert r.cluster_count_failure_report('node_0') == 0 + assert r.cluster_count_failure_report("node_0") == 0 def test_cluster_delslots(self): cluster_slots = [ [ - 0, 8191, - ['127.0.0.1', 7000, 'node_0'], + 0, + 8191, + ["127.0.0.1", 7000, "node_0"], ], [ - 8192, 16383, - ['127.0.0.1', 7001, 'node_1'], - ] + 8192, + 16383, + ["127.0.0.1", 7001, "node_1"], + ], ] - r = get_mocked_redis_client(host=default_host, port=default_port, - cluster_slots=cluster_slots) - mock_all_nodes_resp(r, 'OK') + r = get_mocked_redis_client( + host=default_host, port=default_port, cluster_slots=cluster_slots + ) + mock_all_nodes_resp(r, "OK") node0 = r.get_node(default_host, 7000) node1 = r.get_node(default_host, 7001) assert r.cluster_delslots(0, 8192) == [True, True] @@ -849,59 +871,61 @@ def test_cluster_delslots(self): def test_cluster_failover(self, r): node = r.get_random_node() - mock_node_resp(node, 'OK') + mock_node_resp(node, "OK") assert r.cluster_failover(node) is True - assert r.cluster_failover(node, 'FORCE') is True - assert r.cluster_failover(node, 'TAKEOVER') is True + assert r.cluster_failover(node, "FORCE") is True + assert r.cluster_failover(node, "TAKEOVER") is True with pytest.raises(RedisError): - r.cluster_failover(node, 'FORCT') + r.cluster_failover(node, "FORCT") def test_cluster_info(self, r): info = r.cluster_info() assert isinstance(info, dict) - assert info['cluster_state'] == 'ok' + assert info["cluster_state"] == "ok" def test_cluster_keyslot(self, r): mock_all_nodes_resp(r, 12182) - assert r.cluster_keyslot('foo') == 12182 + assert r.cluster_keyslot("foo") == 12182 def test_cluster_meet(self, r): node = r.get_default_node() - mock_node_resp(node, 'OK') - assert r.cluster_meet('127.0.0.1', 6379) is True + mock_node_resp(node, "OK") + assert r.cluster_meet("127.0.0.1", 6379) is True def test_cluster_nodes(self, r): response = ( - 'c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 ' - 'slave aa90da731f673a99617dfe930306549a09f83a6b 0 ' - '1447836263059 5 connected\n' - '9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 ' - 'master - 0 1447836264065 0 connected\n' - 'aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 ' - 'myself,master - 0 0 2 connected 5461-10922\n' - '1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 ' - 'slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 ' - '1447836262556 3 connected\n' - '4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 ' - 'master - 0 1447836262555 7 connected 0-5460\n' - '19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 ' - 'master - 0 1447836263562 3 connected 10923-16383\n' - 'fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 ' - 'master,fail - 1447829446956 1447829444948 1 disconnected\n' + "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 " + "slave aa90da731f673a99617dfe930306549a09f83a6b 0 " + "1447836263059 5 connected\n" + "9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 " + "master - 0 1447836264065 0 connected\n" + "aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 " + "myself,master - 0 0 2 connected 5461-10922\n" + "1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836262556 3 connected\n" + "4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 " + "master - 0 1447836262555 7 connected 0-5460\n" + "19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 " + "master - 0 1447836263562 3 connected 10923-16383\n" + "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " + "master,fail - 1447829446956 1447829444948 1 disconnected\n" ) mock_all_nodes_resp(r, response) nodes = r.cluster_nodes() assert len(nodes) == 7 - assert nodes.get('172.17.0.7:7006') is not None - assert nodes.get('172.17.0.7:7006').get('node_id') == \ - "c8253bae761cb1ecb2b61857d85dfe455a0fec8b" + assert nodes.get("172.17.0.7:7006") is not None + assert ( + nodes.get("172.17.0.7:7006").get("node_id") + == "c8253bae761cb1ecb2b61857d85dfe455a0fec8b" + ) def test_cluster_replicate(self, r): node = r.get_random_node() all_replicas = r.get_replicas() - mock_all_nodes_resp(r, 'OK') - assert r.cluster_replicate(node, 'c8253bae761cb61857d') is True - results = r.cluster_replicate(all_replicas, 'c8253bae761cb61857d') + mock_all_nodes_resp(r, "OK") + assert r.cluster_replicate(node, "c8253bae761cb61857d") is True + results = r.cluster_replicate(all_replicas, "c8253bae761cb61857d") if isinstance(results, dict): for res in results.values(): assert res is True @@ -909,74 +933,78 @@ def test_cluster_replicate(self, r): assert results is True def test_cluster_reset(self, r): - mock_all_nodes_resp(r, 'OK') + mock_all_nodes_resp(r, "OK") assert r.cluster_reset() is True assert r.cluster_reset(False) is True - all_results = r.cluster_reset(False, target_nodes='all') + all_results = r.cluster_reset(False, target_nodes="all") for res in all_results.values(): assert res is True def test_cluster_save_config(self, r): node = r.get_random_node() all_nodes = r.get_nodes() - mock_all_nodes_resp(r, 'OK') + mock_all_nodes_resp(r, "OK") assert r.cluster_save_config(node) is True all_results = r.cluster_save_config(all_nodes) for res in all_results.values(): assert res is True def test_cluster_get_keys_in_slot(self, r): - response = [b'{foo}1', b'{foo}2'] + response = [b"{foo}1", b"{foo}2"] node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, response) keys = r.cluster_get_keys_in_slot(12182, 4) assert keys == response def test_cluster_set_config_epoch(self, r): - mock_all_nodes_resp(r, 'OK') + mock_all_nodes_resp(r, "OK") assert r.cluster_set_config_epoch(3) is True - all_results = r.cluster_set_config_epoch(3, target_nodes='all') + all_results = r.cluster_set_config_epoch(3, target_nodes="all") for res in all_results.values(): assert res is True def test_cluster_setslot(self, r): node = r.get_random_node() - mock_node_resp(node, 'OK') - assert r.cluster_setslot(node, 'node_0', 1218, 'IMPORTING') is True - assert r.cluster_setslot(node, 'node_0', 1218, 'NODE') is True - assert r.cluster_setslot(node, 'node_0', 1218, 'MIGRATING') is True + mock_node_resp(node, "OK") + assert r.cluster_setslot(node, "node_0", 1218, "IMPORTING") is True + assert r.cluster_setslot(node, "node_0", 1218, "NODE") is True + assert r.cluster_setslot(node, "node_0", 1218, "MIGRATING") is True with pytest.raises(RedisError): - r.cluster_failover(node, 'STABLE') + r.cluster_failover(node, "STABLE") with pytest.raises(RedisError): - r.cluster_failover(node, 'STATE') + r.cluster_failover(node, "STATE") def test_cluster_setslot_stable(self, r): node = r.nodes_manager.get_node_from_slot(12182) - mock_node_resp(node, 'OK') + mock_node_resp(node, "OK") assert r.cluster_setslot_stable(12182) is True assert node.redis_connection.connection.read_response.called def test_cluster_replicas(self, r): - response = [b'01eca22229cf3c652b6fca0d09ff6941e0d2e3 ' - b'127.0.0.1:6377@16377 slave ' - b'52611e796814b78e90ad94be9d769a4f668f9a 0 ' - b'1634550063436 4 connected', - b'r4xfga22229cf3c652b6fca0d09ff69f3e0d4d ' - b'127.0.0.1:6378@16378 slave ' - b'52611e796814b78e90ad94be9d769a4f668f9a 0 ' - b'1634550063436 4 connected'] + response = [ + b"01eca22229cf3c652b6fca0d09ff6941e0d2e3 " + b"127.0.0.1:6377@16377 slave " + b"52611e796814b78e90ad94be9d769a4f668f9a 0 " + b"1634550063436 4 connected", + b"r4xfga22229cf3c652b6fca0d09ff69f3e0d4d " + b"127.0.0.1:6378@16378 slave " + b"52611e796814b78e90ad94be9d769a4f668f9a 0 " + b"1634550063436 4 connected", + ] mock_all_nodes_resp(r, response) - replicas = r.cluster_replicas('52611e796814b78e90ad94be9d769a4f668f9a') - assert replicas.get('127.0.0.1:6377') is not None - assert replicas.get('127.0.0.1:6378') is not None - assert replicas.get('127.0.0.1:6378').get('node_id') == \ - 'r4xfga22229cf3c652b6fca0d09ff69f3e0d4d' + replicas = r.cluster_replicas("52611e796814b78e90ad94be9d769a4f668f9a") + assert replicas.get("127.0.0.1:6377") is not None + assert replicas.get("127.0.0.1:6378") is not None + assert ( + replicas.get("127.0.0.1:6378").get("node_id") + == "r4xfga22229cf3c652b6fca0d09ff69f3e0d4d" + ) def test_readonly(self): r = get_mocked_redis_client(host=default_host, port=default_port) - mock_all_nodes_resp(r, 'OK') + mock_all_nodes_resp(r, "OK") assert r.readonly() is True - all_replicas_results = r.readonly(target_nodes='replicas') + all_replicas_results = r.readonly(target_nodes="replicas") for res in all_replicas_results.values(): assert res is True for replica in r.get_replicas(): @@ -984,9 +1012,9 @@ def test_readonly(self): def test_readwrite(self): r = get_mocked_redis_client(host=default_host, port=default_port) - mock_all_nodes_resp(r, 'OK') + mock_all_nodes_resp(r, "OK") assert r.readwrite() is True - all_replicas_results = r.readwrite(target_nodes='replicas') + all_replicas_results = r.readwrite(target_nodes="replicas") for res in all_replicas_results.values(): assert res is True for replica in r.get_replicas(): @@ -999,59 +1027,59 @@ def test_bgsave(self, r): def test_info(self, r): # Map keys to same slot - r.set('x{1}', 1) - r.set('y{1}', 2) - r.set('z{1}', 3) + r.set("x{1}", 1) + r.set("y{1}", 2) + r.set("z{1}", 3) # Get node that handles the slot - slot = r.keyslot('x{1}') + slot = r.keyslot("x{1}") node = r.nodes_manager.get_node_from_slot(slot) # Run info on that node info = r.info(target_nodes=node) assert isinstance(info, dict) - assert info['db0']['keys'] == 3 + assert info["db0"]["keys"] == 3 def _init_slowlog_test(self, r, node): - slowlog_lim = r.config_get('slowlog-log-slower-than', - target_nodes=node) - assert r.config_set('slowlog-log-slower-than', 0, target_nodes=node) \ - is True - return slowlog_lim['slowlog-log-slower-than'] + slowlog_lim = r.config_get("slowlog-log-slower-than", target_nodes=node) + assert r.config_set("slowlog-log-slower-than", 0, target_nodes=node) is True + return slowlog_lim["slowlog-log-slower-than"] def _teardown_slowlog_test(self, r, node, prev_limit): - assert r.config_set('slowlog-log-slower-than', prev_limit, - target_nodes=node) is True + assert ( + r.config_set("slowlog-log-slower-than", prev_limit, target_nodes=node) + is True + ) def test_slowlog_get(self, r, slowlog): - unicode_string = chr(3456) + 'abcd' + chr(3421) + unicode_string = chr(3456) + "abcd" + chr(3421) node = r.get_node_from_key(unicode_string) slowlog_limit = self._init_slowlog_test(r, node) assert r.slowlog_reset(target_nodes=node) r.get(unicode_string) slowlog = r.slowlog_get(target_nodes=node) assert isinstance(slowlog, list) - commands = [log['command'] for log in slowlog] + commands = [log["command"] for log in slowlog] - get_command = b' '.join((b'GET', unicode_string.encode('utf-8'))) + get_command = b" ".join((b"GET", unicode_string.encode("utf-8"))) assert get_command in commands - assert b'SLOWLOG RESET' in commands + assert b"SLOWLOG RESET" in commands # the order should be ['GET ', 'SLOWLOG RESET'], # but if other clients are executing commands at the same time, there # could be commands, before, between, or after, so just check that # the two we care about are in the appropriate order. - assert commands.index(get_command) < commands.index(b'SLOWLOG RESET') + assert commands.index(get_command) < commands.index(b"SLOWLOG RESET") # make sure other attributes are typed correctly - assert isinstance(slowlog[0]['start_time'], int) - assert isinstance(slowlog[0]['duration'], int) + assert isinstance(slowlog[0]["start_time"], int) + assert isinstance(slowlog[0]["duration"], int) # rollback the slowlog limit to its original value self._teardown_slowlog_test(r, node, slowlog_limit) def test_slowlog_get_limit(self, r, slowlog): assert r.slowlog_reset() - node = r.get_node_from_key('foo') + node = r.get_node_from_key("foo") slowlog_limit = self._init_slowlog_test(r, node) - r.get('foo') + r.get("foo") slowlog = r.slowlog_get(1, target_nodes=node) assert isinstance(slowlog, list) # only one command, based on the number we passed to slowlog_get() @@ -1059,8 +1087,8 @@ def test_slowlog_get_limit(self, r, slowlog): self._teardown_slowlog_test(r, node, slowlog_limit) def test_slowlog_length(self, r, slowlog): - r.get('foo') - node = r.nodes_manager.get_node_from_slot(key_slot(b'foo')) + r.get("foo") + node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) slowlog_len = r.slowlog_len(target_nodes=node) assert isinstance(slowlog_len, int) @@ -1070,47 +1098,46 @@ def test_time(self, r): assert isinstance(t[0], int) assert isinstance(t[1], int) - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_memory_usage(self, r): - r.set('foo', 'bar') - assert isinstance(r.memory_usage('foo'), int) + r.set("foo", "bar") + assert isinstance(r.memory_usage("foo"), int) - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_memory_malloc_stats(self, r): assert r.memory_malloc_stats() - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_memory_stats(self, r): # put a key into the current db to make sure that "db." # has data - r.set('foo', 'bar') - node = r.nodes_manager.get_node_from_slot(key_slot(b'foo')) + r.set("foo", "bar") + node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) stats = r.memory_stats(target_nodes=node) assert isinstance(stats, dict) for key, value in stats.items(): - if key.startswith('db.'): + if key.startswith("db."): assert isinstance(value, dict) - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_memory_help(self, r): with pytest.raises(NotImplementedError): r.memory_help() - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_memory_doctor(self, r): with pytest.raises(NotImplementedError): r.memory_doctor() def test_lastsave(self, r): node = r.get_primaries()[0] - assert isinstance(r.lastsave(target_nodes=node), - datetime.datetime) + assert isinstance(r.lastsave(target_nodes=node), datetime.datetime) def test_cluster_echo(self, r): node = r.get_primaries()[0] - assert r.echo('foo bar', node) == b'foo bar' + assert r.echo("foo bar", node) == b"foo bar" - @skip_if_server_version_lt('1.0.0') + @skip_if_server_version_lt("1.0.0") def test_debug_segfault(self, r): with pytest.raises(NotImplementedError): r.debug_segfault() @@ -1118,39 +1145,41 @@ def test_debug_segfault(self, r): def test_config_resetstat(self, r): node = r.get_primaries()[0] r.ping(target_nodes=node) - prior_commands_processed = \ - int(r.info(target_nodes=node)['total_commands_processed']) + prior_commands_processed = int( + r.info(target_nodes=node)["total_commands_processed"] + ) assert prior_commands_processed >= 1 r.config_resetstat(target_nodes=node) - reset_commands_processed = \ - int(r.info(target_nodes=node)['total_commands_processed']) + reset_commands_processed = int( + r.info(target_nodes=node)["total_commands_processed"] + ) assert reset_commands_processed < prior_commands_processed - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_client_trackinginfo(self, r): node = r.get_primaries()[0] res = r.client_trackinginfo(target_nodes=node) assert len(res) > 2 - assert 'prefixes' in res + assert "prefixes" in res - @skip_if_server_version_lt('2.9.50') + @skip_if_server_version_lt("2.9.50") def test_client_pause(self, r): node = r.get_primaries()[0] assert r.client_pause(1, target_nodes=node) assert r.client_pause(timeout=1, target_nodes=node) with pytest.raises(RedisError): - r.client_pause(timeout='not an integer', target_nodes=node) + r.client_pause(timeout="not an integer", target_nodes=node) - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_client_unpause(self, r): assert r.client_unpause() - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_client_id(self, r): node = r.get_primaries()[0] assert r.client_id(target_nodes=node) > 0 - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_client_unblock(self, r): node = r.get_primaries()[0] myid = r.client_id(target_nodes=node) @@ -1158,82 +1187,88 @@ def test_client_unblock(self, r): assert not r.client_unblock(myid, error=True, target_nodes=node) assert not r.client_unblock(myid, error=False, target_nodes=node) - @skip_if_server_version_lt('6.0.0') + @skip_if_server_version_lt("6.0.0") def test_client_getredir(self, r): node = r.get_primaries()[0] assert isinstance(r.client_getredir(target_nodes=node), int) assert r.client_getredir(target_nodes=node) == -1 - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_client_info(self, r): node = r.get_primaries()[0] info = r.client_info(target_nodes=node) assert isinstance(info, dict) - assert 'addr' in info + assert "addr" in info - @skip_if_server_version_lt('2.6.9') + @skip_if_server_version_lt("2.6.9") def test_client_kill(self, r, r2): node = r.get_primaries()[0] - r.client_setname('redis-py-c1', target_nodes='all') - r2.client_setname('redis-py-c2', target_nodes='all') - clients = [client for client in r.client_list(target_nodes=node) - if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + r.client_setname("redis-py-c1", target_nodes="all") + r2.client_setname("redis-py-c2", target_nodes="all") + clients = [ + client + for client in r.client_list(target_nodes=node) + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] assert len(clients) == 2 - clients_by_name = {client.get('name'): client for client in clients} + clients_by_name = {client.get("name"): client for client in clients} - client_addr = clients_by_name['redis-py-c2'].get('addr') + client_addr = clients_by_name["redis-py-c2"].get("addr") assert r.client_kill(client_addr, target_nodes=node) is True - clients = [client for client in r.client_list(target_nodes=node) - if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + clients = [ + client + for client in r.client_list(target_nodes=node) + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] assert len(clients) == 1 - assert clients[0].get('name') == 'redis-py-c1' + assert clients[0].get("name") == "redis-py-c1" - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_cluster_bitop_not_empty_string(self, r): - r['{foo}a'] = '' - r.bitop('not', '{foo}r', '{foo}a') - assert r.get('{foo}r') is None + r["{foo}a"] = "" + r.bitop("not", "{foo}r", "{foo}a") + assert r.get("{foo}r") is None - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_cluster_bitop_not(self, r): - test_str = b'\xAA\x00\xFF\x55' + test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF - r['{foo}a'] = test_str - r.bitop('not', '{foo}r', '{foo}a') - assert int(binascii.hexlify(r['{foo}r']), 16) == correct + r["{foo}a"] = test_str + r.bitop("not", "{foo}r", "{foo}a") + assert int(binascii.hexlify(r["{foo}r"]), 16) == correct - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_cluster_bitop_not_in_place(self, r): - test_str = b'\xAA\x00\xFF\x55' + test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF - r['{foo}a'] = test_str - r.bitop('not', '{foo}a', '{foo}a') - assert int(binascii.hexlify(r['{foo}a']), 16) == correct + r["{foo}a"] = test_str + r.bitop("not", "{foo}a", "{foo}a") + assert int(binascii.hexlify(r["{foo}a"]), 16) == correct - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_cluster_bitop_single_string(self, r): - test_str = b'\x01\x02\xFF' - r['{foo}a'] = test_str - r.bitop('and', '{foo}res1', '{foo}a') - r.bitop('or', '{foo}res2', '{foo}a') - r.bitop('xor', '{foo}res3', '{foo}a') - assert r['{foo}res1'] == test_str - assert r['{foo}res2'] == test_str - assert r['{foo}res3'] == test_str - - @skip_if_server_version_lt('2.6.0') + test_str = b"\x01\x02\xFF" + r["{foo}a"] = test_str + r.bitop("and", "{foo}res1", "{foo}a") + r.bitop("or", "{foo}res2", "{foo}a") + r.bitop("xor", "{foo}res3", "{foo}a") + assert r["{foo}res1"] == test_str + assert r["{foo}res2"] == test_str + assert r["{foo}res3"] == test_str + + @skip_if_server_version_lt("2.6.0") def test_cluster_bitop_string_operands(self, r): - r['{foo}a'] = b'\x01\x02\xFF\xFF' - r['{foo}b'] = b'\x01\x02\xFF' - r.bitop('and', '{foo}res1', '{foo}a', '{foo}b') - r.bitop('or', '{foo}res2', '{foo}a', '{foo}b') - r.bitop('xor', '{foo}res3', '{foo}a', '{foo}b') - assert int(binascii.hexlify(r['{foo}res1']), 16) == 0x0102FF00 - assert int(binascii.hexlify(r['{foo}res2']), 16) == 0x0102FFFF - assert int(binascii.hexlify(r['{foo}res3']), 16) == 0x000000FF - - @skip_if_server_version_lt('6.2.0') + r["{foo}a"] = b"\x01\x02\xFF\xFF" + r["{foo}b"] = b"\x01\x02\xFF" + r.bitop("and", "{foo}res1", "{foo}a", "{foo}b") + r.bitop("or", "{foo}res2", "{foo}a", "{foo}b") + r.bitop("xor", "{foo}res3", "{foo}a", "{foo}b") + assert int(binascii.hexlify(r["{foo}res1"]), 16) == 0x0102FF00 + assert int(binascii.hexlify(r["{foo}res2"]), 16) == 0x0102FFFF + assert int(binascii.hexlify(r["{foo}res3"]), 16) == 0x000000FF + + @skip_if_server_version_lt("6.2.0") def test_cluster_copy(self, r): assert r.copy("{foo}a", "{foo}b") == 0 r.set("{foo}a", "bar") @@ -1241,449 +1276,493 @@ def test_cluster_copy(self, r): assert r.get("{foo}a") == b"bar" assert r.get("{foo}b") == b"bar" - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_cluster_copy_and_replace(self, r): r.set("{foo}a", "foo1") r.set("{foo}b", "foo2") assert r.copy("{foo}a", "{foo}b") == 0 assert r.copy("{foo}a", "{foo}b", replace=True) == 1 - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_cluster_lmove(self, r): - r.rpush('{foo}a', 'one', 'two', 'three', 'four') - assert r.lmove('{foo}a', '{foo}b') - assert r.lmove('{foo}a', '{foo}b', 'right', 'left') + r.rpush("{foo}a", "one", "two", "three", "four") + assert r.lmove("{foo}a", "{foo}b") + assert r.lmove("{foo}a", "{foo}b", "right", "left") - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_cluster_blmove(self, r): - r.rpush('{foo}a', 'one', 'two', 'three', 'four') - assert r.blmove('{foo}a', '{foo}b', 5) - assert r.blmove('{foo}a', '{foo}b', 1, 'RIGHT', 'LEFT') + r.rpush("{foo}a", "one", "two", "three", "four") + assert r.blmove("{foo}a", "{foo}b", 5) + assert r.blmove("{foo}a", "{foo}b", 1, "RIGHT", "LEFT") def test_cluster_msetnx(self, r): - d = {'{foo}a': b'1', '{foo}b': b'2', '{foo}c': b'3'} + d = {"{foo}a": b"1", "{foo}b": b"2", "{foo}c": b"3"} assert r.msetnx(d) - d2 = {'{foo}a': b'x', '{foo}d': b'4'} + d2 = {"{foo}a": b"x", "{foo}d": b"4"} assert not r.msetnx(d2) for k, v in d.items(): assert r[k] == v - assert r.get('{foo}d') is None + assert r.get("{foo}d") is None def test_cluster_rename(self, r): - r['{foo}a'] = '1' - assert r.rename('{foo}a', '{foo}b') - assert r.get('{foo}a') is None - assert r['{foo}b'] == b'1' + r["{foo}a"] = "1" + assert r.rename("{foo}a", "{foo}b") + assert r.get("{foo}a") is None + assert r["{foo}b"] == b"1" def test_cluster_renamenx(self, r): - r['{foo}a'] = '1' - r['{foo}b'] = '2' - assert not r.renamenx('{foo}a', '{foo}b') - assert r['{foo}a'] == b'1' - assert r['{foo}b'] == b'2' + r["{foo}a"] = "1" + r["{foo}b"] = "2" + assert not r.renamenx("{foo}a", "{foo}b") + assert r["{foo}a"] == b"1" + assert r["{foo}b"] == b"2" # LIST COMMANDS def test_cluster_blpop(self, r): - r.rpush('{foo}a', '1', '2') - r.rpush('{foo}b', '3', '4') - assert r.blpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}b', b'3') - assert r.blpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}b', b'4') - assert r.blpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}a', b'1') - assert r.blpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}a', b'2') - assert r.blpop(['{foo}b', '{foo}a'], timeout=1) is None - r.rpush('{foo}c', '1') - assert r.blpop('{foo}c', timeout=1) == (b'{foo}c', b'1') + r.rpush("{foo}a", "1", "2") + r.rpush("{foo}b", "3", "4") + assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") + assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") + assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert r.blpop(["{foo}b", "{foo}a"], timeout=1) is None + r.rpush("{foo}c", "1") + assert r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") def test_cluster_brpop(self, r): - r.rpush('{foo}a', '1', '2') - r.rpush('{foo}b', '3', '4') - assert r.brpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}b', b'4') - assert r.brpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}b', b'3') - assert r.brpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}a', b'2') - assert r.brpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}a', b'1') - assert r.brpop(['{foo}b', '{foo}a'], timeout=1) is None - r.rpush('{foo}c', '1') - assert r.brpop('{foo}c', timeout=1) == (b'{foo}c', b'1') + r.rpush("{foo}a", "1", "2") + r.rpush("{foo}b", "3", "4") + assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") + assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") + assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert r.brpop(["{foo}b", "{foo}a"], timeout=1) is None + r.rpush("{foo}c", "1") + assert r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") def test_cluster_brpoplpush(self, r): - r.rpush('{foo}a', '1', '2') - r.rpush('{foo}b', '3', '4') - assert r.brpoplpush('{foo}a', '{foo}b') == b'2' - assert r.brpoplpush('{foo}a', '{foo}b') == b'1' - assert r.brpoplpush('{foo}a', '{foo}b', timeout=1) is None - assert r.lrange('{foo}a', 0, -1) == [] - assert r.lrange('{foo}b', 0, -1) == [b'1', b'2', b'3', b'4'] + r.rpush("{foo}a", "1", "2") + r.rpush("{foo}b", "3", "4") + assert r.brpoplpush("{foo}a", "{foo}b") == b"2" + assert r.brpoplpush("{foo}a", "{foo}b") == b"1" + assert r.brpoplpush("{foo}a", "{foo}b", timeout=1) is None + assert r.lrange("{foo}a", 0, -1) == [] + assert r.lrange("{foo}b", 0, -1) == [b"1", b"2", b"3", b"4"] def test_cluster_brpoplpush_empty_string(self, r): - r.rpush('{foo}a', '') - assert r.brpoplpush('{foo}a', '{foo}b') == b'' + r.rpush("{foo}a", "") + assert r.brpoplpush("{foo}a", "{foo}b") == b"" def test_cluster_rpoplpush(self, r): - r.rpush('{foo}a', 'a1', 'a2', 'a3') - r.rpush('{foo}b', 'b1', 'b2', 'b3') - assert r.rpoplpush('{foo}a', '{foo}b') == b'a3' - assert r.lrange('{foo}a', 0, -1) == [b'a1', b'a2'] - assert r.lrange('{foo}b', 0, -1) == [b'a3', b'b1', b'b2', b'b3'] + r.rpush("{foo}a", "a1", "a2", "a3") + r.rpush("{foo}b", "b1", "b2", "b3") + assert r.rpoplpush("{foo}a", "{foo}b") == b"a3" + assert r.lrange("{foo}a", 0, -1) == [b"a1", b"a2"] + assert r.lrange("{foo}b", 0, -1) == [b"a3", b"b1", b"b2", b"b3"] def test_cluster_sdiff(self, r): - r.sadd('{foo}a', '1', '2', '3') - assert r.sdiff('{foo}a', '{foo}b') == {b'1', b'2', b'3'} - r.sadd('{foo}b', '2', '3') - assert r.sdiff('{foo}a', '{foo}b') == {b'1'} + r.sadd("{foo}a", "1", "2", "3") + assert r.sdiff("{foo}a", "{foo}b") == {b"1", b"2", b"3"} + r.sadd("{foo}b", "2", "3") + assert r.sdiff("{foo}a", "{foo}b") == {b"1"} def test_cluster_sdiffstore(self, r): - r.sadd('{foo}a', '1', '2', '3') - assert r.sdiffstore('{foo}c', '{foo}a', '{foo}b') == 3 - assert r.smembers('{foo}c') == {b'1', b'2', b'3'} - r.sadd('{foo}b', '2', '3') - assert r.sdiffstore('{foo}c', '{foo}a', '{foo}b') == 1 - assert r.smembers('{foo}c') == {b'1'} + r.sadd("{foo}a", "1", "2", "3") + assert r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 3 + assert r.smembers("{foo}c") == {b"1", b"2", b"3"} + r.sadd("{foo}b", "2", "3") + assert r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 1 + assert r.smembers("{foo}c") == {b"1"} def test_cluster_sinter(self, r): - r.sadd('{foo}a', '1', '2', '3') - assert r.sinter('{foo}a', '{foo}b') == set() - r.sadd('{foo}b', '2', '3') - assert r.sinter('{foo}a', '{foo}b') == {b'2', b'3'} + r.sadd("{foo}a", "1", "2", "3") + assert r.sinter("{foo}a", "{foo}b") == set() + r.sadd("{foo}b", "2", "3") + assert r.sinter("{foo}a", "{foo}b") == {b"2", b"3"} def test_cluster_sinterstore(self, r): - r.sadd('{foo}a', '1', '2', '3') - assert r.sinterstore('{foo}c', '{foo}a', '{foo}b') == 0 - assert r.smembers('{foo}c') == set() - r.sadd('{foo}b', '2', '3') - assert r.sinterstore('{foo}c', '{foo}a', '{foo}b') == 2 - assert r.smembers('{foo}c') == {b'2', b'3'} + r.sadd("{foo}a", "1", "2", "3") + assert r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 0 + assert r.smembers("{foo}c") == set() + r.sadd("{foo}b", "2", "3") + assert r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 2 + assert r.smembers("{foo}c") == {b"2", b"3"} def test_cluster_smove(self, r): - r.sadd('{foo}a', 'a1', 'a2') - r.sadd('{foo}b', 'b1', 'b2') - assert r.smove('{foo}a', '{foo}b', 'a1') - assert r.smembers('{foo}a') == {b'a2'} - assert r.smembers('{foo}b') == {b'b1', b'b2', b'a1'} + r.sadd("{foo}a", "a1", "a2") + r.sadd("{foo}b", "b1", "b2") + assert r.smove("{foo}a", "{foo}b", "a1") + assert r.smembers("{foo}a") == {b"a2"} + assert r.smembers("{foo}b") == {b"b1", b"b2", b"a1"} def test_cluster_sunion(self, r): - r.sadd('{foo}a', '1', '2') - r.sadd('{foo}b', '2', '3') - assert r.sunion('{foo}a', '{foo}b') == {b'1', b'2', b'3'} + r.sadd("{foo}a", "1", "2") + r.sadd("{foo}b", "2", "3") + assert r.sunion("{foo}a", "{foo}b") == {b"1", b"2", b"3"} def test_cluster_sunionstore(self, r): - r.sadd('{foo}a', '1', '2') - r.sadd('{foo}b', '2', '3') - assert r.sunionstore('{foo}c', '{foo}a', '{foo}b') == 3 - assert r.smembers('{foo}c') == {b'1', b'2', b'3'} + r.sadd("{foo}a", "1", "2") + r.sadd("{foo}b", "2", "3") + assert r.sunionstore("{foo}c", "{foo}a", "{foo}b") == 3 + assert r.smembers("{foo}c") == {b"1", b"2", b"3"} - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_cluster_zdiff(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 3}) - r.zadd('{foo}b', {'a1': 1, 'a2': 2}) - assert r.zdiff(['{foo}a', '{foo}b']) == [b'a3'] - assert r.zdiff(['{foo}a', '{foo}b'], withscores=True) == [b'a3', b'3'] + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("{foo}b", {"a1": 1, "a2": 2}) + assert r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] + assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_cluster_zdiffstore(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 3}) - r.zadd('{foo}b', {'a1': 1, 'a2': 2}) - assert r.zdiffstore("{foo}out", ['{foo}a', '{foo}b']) - assert r.zrange("{foo}out", 0, -1) == [b'a3'] - assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b'a3', 3.0)] + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("{foo}b", {"a1": 1, "a2": 2}) + assert r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) + assert r.zrange("{foo}out", 0, -1) == [b"a3"] + assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_cluster_zinter(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 1}) - r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zinter(['{foo}a', '{foo}b', '{foo}c']) == [b'a3', b'a1'] + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinter(["{foo}a", "{foo}b", "{foo}c"]) == [b"a3", b"a1"] # invalid aggregation with pytest.raises(DataError): - r.zinter(['{foo}a', '{foo}b', '{foo}c'], - aggregate='foo', withscores=True) + r.zinter(["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True) # aggregate with SUM - assert r.zinter(['{foo}a', '{foo}b', '{foo}c'], withscores=True) \ - == [(b'a3', 8), (b'a1', 9)] + assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + (b"a3", 8), + (b"a1", 9), + ] # aggregate with MAX - assert r.zinter(['{foo}a', '{foo}b', '{foo}c'], aggregate='MAX', - withscores=True) \ - == [(b'a3', 5), (b'a1', 6)] + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ) == [(b"a3", 5), (b"a1", 6)] # aggregate with MIN - assert r.zinter(['{foo}a', '{foo}b', '{foo}c'], aggregate='MIN', - withscores=True) \ - == [(b'a1', 1), (b'a3', 1)] + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ) == [(b"a1", 1), (b"a3", 1)] # with weights - assert r.zinter({'{foo}a': 1, '{foo}b': 2, '{foo}c': 3}, - withscores=True) \ - == [(b'a3', 20), (b'a1', 23)] + assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ + (b"a3", 20), + (b"a1", 23), + ] def test_cluster_zinterstore_sum(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zinterstore('{foo}d', ['{foo}a', '{foo}b', '{foo}c']) == 2 - assert r.zrange('{foo}d', 0, -1, withscores=True) == \ - [(b'a3', 8), (b'a1', 9)] + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 + assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] def test_cluster_zinterstore_max(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zinterstore( - '{foo}d', ['{foo}a', '{foo}b', '{foo}c'], aggregate='MAX') == 2 - assert r.zrange('{foo}d', 0, -1, withscores=True) == \ - [(b'a3', 5), (b'a1', 6)] + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") + == 2 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] def test_cluster_zinterstore_min(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 3}) - r.zadd('{foo}b', {'a1': 2, 'a2': 3, 'a3': 5}) - r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zinterstore( - '{foo}d', ['{foo}a', '{foo}b', '{foo}c'], aggregate='MIN') == 2 - assert r.zrange('{foo}d', 0, -1, withscores=True) == \ - [(b'a1', 1), (b'a3', 3)] + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("{foo}b", {"a1": 2, "a2": 3, "a3": 5}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") + == 2 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] def test_cluster_zinterstore_with_weight(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zinterstore( - '{foo}d', {'{foo}a': 1, '{foo}b': 2, '{foo}c': 3}) == 2 - assert r.zrange('{foo}d', 0, -1, withscores=True) == \ - [(b'a3', 20), (b'a1', 23)] - - @skip_if_server_version_lt('4.9.0') + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 + assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + + @skip_if_server_version_lt("4.9.0") def test_cluster_bzpopmax(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 2}) - r.zadd('{foo}b', {'b1': 10, 'b2': 20}) - assert r.bzpopmax(['{foo}b', '{foo}a'], timeout=1) == ( - b'{foo}b', b'b2', 20) - assert r.bzpopmax(['{foo}b', '{foo}a'], timeout=1) == ( - b'{foo}b', b'b1', 10) - assert r.bzpopmax(['{foo}b', '{foo}a'], timeout=1) == ( - b'{foo}a', b'a2', 2) - assert r.bzpopmax(['{foo}b', '{foo}a'], timeout=1) == ( - b'{foo}a', b'a1', 1) - assert r.bzpopmax(['{foo}b', '{foo}a'], timeout=1) is None - r.zadd('{foo}c', {'c1': 100}) - assert r.bzpopmax('{foo}c', timeout=1) == (b'{foo}c', b'c1', 100) - - @skip_if_server_version_lt('4.9.0') + r.zadd("{foo}a", {"a1": 1, "a2": 2}) + r.zadd("{foo}b", {"b1": 10, "b2": 20}) + assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) + assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) + assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) + assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) + assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None + r.zadd("{foo}c", {"c1": 100}) + assert r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + + @skip_if_server_version_lt("4.9.0") def test_cluster_bzpopmin(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 2}) - r.zadd('{foo}b', {'b1': 10, 'b2': 20}) - assert r.bzpopmin(['{foo}b', '{foo}a'], timeout=1) == ( - b'{foo}b', b'b1', 10) - assert r.bzpopmin(['{foo}b', '{foo}a'], timeout=1) == ( - b'{foo}b', b'b2', 20) - assert r.bzpopmin(['{foo}b', '{foo}a'], timeout=1) == ( - b'{foo}a', b'a1', 1) - assert r.bzpopmin(['{foo}b', '{foo}a'], timeout=1) == ( - b'{foo}a', b'a2', 2) - assert r.bzpopmin(['{foo}b', '{foo}a'], timeout=1) is None - r.zadd('{foo}c', {'c1': 100}) - assert r.bzpopmin('{foo}c', timeout=1) == (b'{foo}c', b'c1', 100) - - @skip_if_server_version_lt('6.2.0') + r.zadd("{foo}a", {"a1": 1, "a2": 2}) + r.zadd("{foo}b", {"b1": 10, "b2": 20}) + assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) + assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) + assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) + assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) + assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None + r.zadd("{foo}c", {"c1": 100}) + assert r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + + @skip_if_server_version_lt("6.2.0") def test_cluster_zrangestore(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zrangestore('{foo}b', '{foo}a', 0, 1) - assert r.zrange('{foo}b', 0, -1) == [b'a1', b'a2'] - assert r.zrangestore('{foo}b', '{foo}a', 1, 2) - assert r.zrange('{foo}b', 0, -1) == [b'a2', b'a3'] - assert r.zrange('{foo}b', 0, -1, withscores=True) == \ - [(b'a2', 2), (b'a3', 3)] + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zrangestore("{foo}b", "{foo}a", 0, 1) + assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] + assert r.zrangestore("{foo}b", "{foo}a", 1, 2) + assert r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] + assert r.zrange("{foo}b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] # reversed order - assert r.zrangestore('{foo}b', '{foo}a', 1, 2, desc=True) - assert r.zrange('{foo}b', 0, -1) == [b'a1', b'a2'] + assert r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) + assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] # by score - assert r.zrangestore('{foo}b', '{foo}a', 2, 1, byscore=True, - offset=0, num=1, desc=True) - assert r.zrange('{foo}b', 0, -1) == [b'a2'] + assert r.zrangestore( + "{foo}b", "{foo}a", 2, 1, byscore=True, offset=0, num=1, desc=True + ) + assert r.zrange("{foo}b", 0, -1) == [b"a2"] # by lex - assert r.zrangestore('{foo}b', '{foo}a', '[a2', '(a3', bylex=True, - offset=0, num=1) - assert r.zrange('{foo}b', 0, -1) == [b'a2'] + assert r.zrangestore( + "{foo}b", "{foo}a", "[a2", "(a3", bylex=True, offset=0, num=1 + ) + assert r.zrange("{foo}b", 0, -1) == [b"a2"] - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_cluster_zunion(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) # sum - assert r.zunion(['{foo}a', '{foo}b', '{foo}c']) == \ - [b'a2', b'a4', b'a3', b'a1'] - assert r.zunion(['{foo}a', '{foo}b', '{foo}c'], withscores=True) == \ - [(b'a2', 3), (b'a4', 4), (b'a3', 8), (b'a1', 9)] + assert r.zunion(["{foo}a", "{foo}b", "{foo}c"]) == [b"a2", b"a4", b"a3", b"a1"] + assert r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] # max - assert r.zunion(['{foo}a', '{foo}b', '{foo}c'], aggregate='MAX', - withscores=True) \ - == [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] + assert r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] # min - assert r.zunion(['{foo}a', '{foo}b', '{foo}c'], aggregate='MIN', - withscores=True) \ - == [(b'a1', 1), (b'a2', 1), (b'a3', 1), (b'a4', 4)] + assert r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] # with weight - assert r.zunion({'{foo}a': 1, '{foo}b': 2, '{foo}c': 3}, - withscores=True) \ - == [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] + assert r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] def test_cluster_zunionstore_sum(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zunionstore('{foo}d', ['{foo}a', '{foo}b', '{foo}c']) == 4 - assert r.zrange('{foo}d', 0, -1, withscores=True) == \ - [(b'a2', 3), (b'a4', 4), (b'a3', 8), (b'a1', 9)] + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] def test_cluster_zunionstore_max(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zunionstore( - '{foo}d', ['{foo}a', '{foo}b', '{foo}c'], aggregate='MAX') == 4 - assert r.zrange('{foo}d', 0, -1, withscores=True) == \ - [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") + == 4 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a2", 2), + (b"a4", 4), + (b"a3", 5), + (b"a1", 6), + ] def test_cluster_zunionstore_min(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 3}) - r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 4}) - r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zunionstore( - '{foo}d', ['{foo}a', '{foo}b', '{foo}c'], aggregate='MIN') == 4 - assert r.zrange('{foo}d', 0, -1, withscores=True) == \ - [(b'a1', 1), (b'a2', 2), (b'a3', 3), (b'a4', 4)] + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 4}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") + == 4 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a1", 1), + (b"a2", 2), + (b"a3", 3), + (b"a4", 4), + ] def test_cluster_zunionstore_with_weight(self, r): - r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zunionstore( - '{foo}d', {'{foo}a': 1, '{foo}b': 2, '{foo}c': 3}) == 4 - assert r.zrange('{foo}d', 0, -1, withscores=True) == \ - [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] - - @skip_if_server_version_lt('2.8.9') + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] + + @skip_if_server_version_lt("2.8.9") def test_cluster_pfcount(self, r): - members = {b'1', b'2', b'3'} - r.pfadd('{foo}a', *members) - assert r.pfcount('{foo}a') == len(members) - members_b = {b'2', b'3', b'4'} - r.pfadd('{foo}b', *members_b) - assert r.pfcount('{foo}b') == len(members_b) - assert r.pfcount('{foo}a', '{foo}b') == len(members_b.union(members)) - - @skip_if_server_version_lt('2.8.9') + members = {b"1", b"2", b"3"} + r.pfadd("{foo}a", *members) + assert r.pfcount("{foo}a") == len(members) + members_b = {b"2", b"3", b"4"} + r.pfadd("{foo}b", *members_b) + assert r.pfcount("{foo}b") == len(members_b) + assert r.pfcount("{foo}a", "{foo}b") == len(members_b.union(members)) + + @skip_if_server_version_lt("2.8.9") def test_cluster_pfmerge(self, r): - mema = {b'1', b'2', b'3'} - memb = {b'2', b'3', b'4'} - memc = {b'5', b'6', b'7'} - r.pfadd('{foo}a', *mema) - r.pfadd('{foo}b', *memb) - r.pfadd('{foo}c', *memc) - r.pfmerge('{foo}d', '{foo}c', '{foo}a') - assert r.pfcount('{foo}d') == 6 - r.pfmerge('{foo}d', '{foo}b') - assert r.pfcount('{foo}d') == 7 + mema = {b"1", b"2", b"3"} + memb = {b"2", b"3", b"4"} + memc = {b"5", b"6", b"7"} + r.pfadd("{foo}a", *mema) + r.pfadd("{foo}b", *memb) + r.pfadd("{foo}c", *memc) + r.pfmerge("{foo}d", "{foo}c", "{foo}a") + assert r.pfcount("{foo}d") == 6 + r.pfmerge("{foo}d", "{foo}b") + assert r.pfcount("{foo}d") == 7 def test_cluster_sort_store(self, r): - r.rpush('{foo}a', '2', '3', '1') - assert r.sort('{foo}a', store='{foo}sorted_values') == 3 - assert r.lrange('{foo}sorted_values', 0, -1) == [b'1', b'2', b'3'] + r.rpush("{foo}a", "2", "3", "1") + assert r.sort("{foo}a", store="{foo}sorted_values") == 3 + assert r.lrange("{foo}sorted_values", 0, -1) == [b"1", b"2", b"3"] # GEO COMMANDS - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_cluster_geosearchstore(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('{foo}barcelona', values) - r.geosearchstore('{foo}places_barcelona', '{foo}barcelona', - longitude=2.191, latitude=41.433, radius=1000) - assert r.zrange('{foo}places_barcelona', 0, -1) == [b'place1'] + r.geoadd("{foo}barcelona", values) + r.geosearchstore( + "{foo}places_barcelona", + "{foo}barcelona", + longitude=2.191, + latitude=41.433, + radius=1000, + ) + assert r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] @skip_unless_arch_bits(64) - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_geosearchstore_dist(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('{foo}barcelona', values) - r.geosearchstore('{foo}places_barcelona', '{foo}barcelona', - longitude=2.191, latitude=41.433, - radius=1000, storedist=True) + r.geoadd("{foo}barcelona", values) + r.geosearchstore( + "{foo}places_barcelona", + "{foo}barcelona", + longitude=2.191, + latitude=41.433, + radius=1000, + storedist=True, + ) # instead of save the geo score, the distance is saved. - assert r.zscore('{foo}places_barcelona', 'place1') == 88.05060698409301 + assert r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_cluster_georadius_store(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('{foo}barcelona', values) - r.georadius('{foo}barcelona', 2.191, 41.433, - 1000, store='{foo}places_barcelona') - assert r.zrange('{foo}places_barcelona', 0, -1) == [b'place1'] + r.geoadd("{foo}barcelona", values) + r.georadius( + "{foo}barcelona", 2.191, 41.433, 1000, store="{foo}places_barcelona" + ) + assert r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] @skip_unless_arch_bits(64) - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_cluster_georadius_store_dist(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('{foo}barcelona', values) - r.georadius('{foo}barcelona', 2.191, 41.433, 1000, - store_dist='{foo}places_barcelona') + r.geoadd("{foo}barcelona", values) + r.georadius( + "{foo}barcelona", 2.191, 41.433, 1000, store_dist="{foo}places_barcelona" + ) # instead of save the geo score, the distance is saved. - assert r.zscore('{foo}places_barcelona', 'place1') == 88.05060698409301 + assert r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 def test_cluster_dbsize(self, r): - d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} assert r.mset_nonatomic(d) - assert r.dbsize(target_nodes='primaries') == len(d) + assert r.dbsize(target_nodes="primaries") == len(d) def test_cluster_keys(self, r): assert r.keys() == [] - keys_with_underscores = {b'test_a', b'test_b'} - keys = keys_with_underscores.union({b'testc'}) + keys_with_underscores = {b"test_a", b"test_b"} + keys = keys_with_underscores.union({b"testc"}) for key in keys: r[key] = 1 - assert set(r.keys(pattern='test_*', target_nodes='primaries')) == \ - keys_with_underscores - assert set(r.keys(pattern='test*', target_nodes='primaries')) == keys + assert ( + set(r.keys(pattern="test_*", target_nodes="primaries")) + == keys_with_underscores + ) + assert set(r.keys(pattern="test*", target_nodes="primaries")) == keys # SCAN COMMANDS - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_cluster_scan(self, r): - r.set('a', 1) - r.set('b', 2) - r.set('c', 3) - cursor, keys = r.scan(target_nodes='primaries') + r.set("a", 1) + r.set("b", 2) + r.set("c", 3) + cursor, keys = r.scan(target_nodes="primaries") assert cursor == 0 - assert set(keys) == {b'a', b'b', b'c'} - _, keys = r.scan(match='a', target_nodes='primaries') - assert set(keys) == {b'a'} + assert set(keys) == {b"a", b"b", b"c"} + _, keys = r.scan(match="a", target_nodes="primaries") + assert set(keys) == {b"a"} @skip_if_server_version_lt("6.0.0") def test_cluster_scan_type(self, r): - r.sadd('a-set', 1) - r.hset('a-hash', 'foo', 2) - r.lpush('a-list', 'aux', 3) - _, keys = r.scan(match='a*', _type='SET', target_nodes='primaries') - assert set(keys) == {b'a-set'} + r.sadd("a-set", 1) + r.hset("a-hash", "foo", 2) + r.lpush("a-list", "aux", 3) + _, keys = r.scan(match="a*", _type="SET", target_nodes="primaries") + assert set(keys) == {b"a-set"} - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_cluster_scan_iter(self, r): - r.set('a', 1) - r.set('b', 2) - r.set('c', 3) - keys = list(r.scan_iter(target_nodes='primaries')) - assert set(keys) == {b'a', b'b', b'c'} - keys = list(r.scan_iter(match='a', target_nodes='primaries')) - assert set(keys) == {b'a'} + r.set("a", 1) + r.set("b", 2) + r.set("c", 3) + keys = list(r.scan_iter(target_nodes="primaries")) + assert set(keys) == {b"a", b"b", b"c"} + keys = list(r.scan_iter(match="a", target_nodes="primaries")) + assert set(keys) == {b"a"} def test_cluster_randomkey(self, r): - node = r.get_node_from_key('{foo}') + node = r.get_node_from_key("{foo}") assert r.randomkey(target_nodes=node) is None - for key in ('{foo}a', '{foo}b', '{foo}c'): + for key in ("{foo}a", "{foo}b", "{foo}c"): r[key] = 1 - assert r.randomkey(target_nodes=node) in \ - (b'{foo}a', b'{foo}b', b'{foo}c') + assert r.randomkey(target_nodes=node) in (b"{foo}a", b"{foo}b", b"{foo}c") @pytest.mark.onlycluster @@ -1704,7 +1783,7 @@ def test_load_balancer(self, r): node_5 = ClusterNode(default_host, 6375, REPLICA) n_manager.slots_cache = { slot_1: [node_1, node_2, node_3], - slot_2: [node_4, node_5] + slot_2: [node_4, node_5], } primary1_name = n_manager.slots_cache[slot_1][0].name primary2_name = n_manager.slots_cache[slot_2][0].name @@ -1730,17 +1809,17 @@ def test_init_slots_cache_not_all_slots_covered(self): """ # Missing slot 5460 cluster_slots = [ - [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]], - [5461, 10922, ['127.0.0.1', 7001], - ['127.0.0.1', 7004]], - [10923, 16383, ['127.0.0.1', 7002], - ['127.0.0.1', 7005]], + [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], + [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], ] with pytest.raises(RedisClusterException) as ex: - get_mocked_redis_client(host=default_host, port=default_port, - cluster_slots=cluster_slots) + get_mocked_redis_client( + host=default_host, port=default_port, cluster_slots=cluster_slots + ) assert str(ex.value).startswith( - "All slots are not covered after query all startup_nodes.") + "All slots are not covered after query all startup_nodes." + ) def test_init_slots_cache_not_require_full_coverage_error(self): """ @@ -1750,18 +1829,19 @@ def test_init_slots_cache_not_require_full_coverage_error(self): """ # Missing slot 5460 cluster_slots = [ - [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]], - [5461, 10922, ['127.0.0.1', 7001], - ['127.0.0.1', 7004]], - [10923, 16383, ['127.0.0.1', 7002], - ['127.0.0.1', 7005]], + [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], + [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], ] with pytest.raises(RedisClusterException): - get_mocked_redis_client(host=default_host, port=default_port, - cluster_slots=cluster_slots, - require_full_coverage=False, - coverage_result='yes') + get_mocked_redis_client( + host=default_host, + port=default_port, + cluster_slots=cluster_slots, + require_full_coverage=False, + coverage_result="yes", + ) def test_init_slots_cache_not_require_full_coverage_success(self): """ @@ -1771,17 +1851,18 @@ def test_init_slots_cache_not_require_full_coverage_success(self): """ # Missing slot 5460 cluster_slots = [ - [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]], - [5461, 10922, ['127.0.0.1', 7001], - ['127.0.0.1', 7004]], - [10923, 16383, ['127.0.0.1', 7002], - ['127.0.0.1', 7005]], + [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], + [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], ] - rc = get_mocked_redis_client(host=default_host, port=default_port, - cluster_slots=cluster_slots, - require_full_coverage=False, - coverage_result='no') + rc = get_mocked_redis_client( + host=default_host, + port=default_port, + cluster_slots=cluster_slots, + require_full_coverage=False, + coverage_result="no", + ) assert 5460 not in rc.nodes_manager.slots_cache @@ -1793,20 +1874,22 @@ def test_init_slots_cache_not_require_full_coverage_skips_check(self): """ # Missing slot 5460 cluster_slots = [ - [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]], - [5461, 10922, ['127.0.0.1', 7001], - ['127.0.0.1', 7004]], - [10923, 16383, ['127.0.0.1', 7002], - ['127.0.0.1', 7005]], + [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], + [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], ] - with patch.object(NodesManager, - 'cluster_require_full_coverage') as conf_check_mock: - rc = get_mocked_redis_client(host=default_host, port=default_port, - cluster_slots=cluster_slots, - require_full_coverage=False, - skip_full_coverage_check=True, - coverage_result='no') + with patch.object( + NodesManager, "cluster_require_full_coverage" + ) as conf_check_mock: + rc = get_mocked_redis_client( + host=default_host, + port=default_port, + cluster_slots=cluster_slots, + require_full_coverage=False, + skip_full_coverage_check=True, + coverage_result="no", + ) assert conf_check_mock.called is False assert 5460 not in rc.nodes_manager.slots_cache @@ -1816,17 +1899,18 @@ def test_init_slots_cache(self): Test that slots cache can in initialized and all slots are covered """ good_slots_resp = [ - [0, 5460, ['127.0.0.1', 7000], ['127.0.0.2', 7003]], - [5461, 10922, ['127.0.0.1', 7001], ['127.0.0.2', 7004]], - [10923, 16383, ['127.0.0.1', 7002], ['127.0.0.2', 7005]], + [0, 5460, ["127.0.0.1", 7000], ["127.0.0.2", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.2", 7004]], + [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.2", 7005]], ] - rc = get_mocked_redis_client(host=default_host, port=default_port, - cluster_slots=good_slots_resp) + rc = get_mocked_redis_client( + host=default_host, port=default_port, cluster_slots=good_slots_resp + ) n_manager = rc.nodes_manager assert len(n_manager.slots_cache) == REDIS_CLUSTER_HASH_SLOTS for slot_info in good_slots_resp: - all_hosts = ['127.0.0.1', '127.0.0.2'] + all_hosts = ["127.0.0.1", "127.0.0.2"] all_ports = [7000, 7001, 7002, 7003, 7004, 7005] slot_start = slot_info[0] slot_end = slot_info[1] @@ -1861,8 +1945,8 @@ def test_init_slots_cache_slots_collision(self, request): raise an error. In this test both nodes will say that the first slots block should be bound to different servers. """ - with patch.object(NodesManager, - 'create_redis_node') as create_redis_node: + with patch.object(NodesManager, "create_redis_node") as create_redis_node: + def create_mocked_redis_node(host, port, **kwargs): """ Helper function to return custom slots cache data from @@ -1873,14 +1957,14 @@ def create_mocked_redis_node(host, port, **kwargs): [ 0, 5460, - ['127.0.0.1', 7000], - ['127.0.0.1', 7003], + ["127.0.0.1", 7000], + ["127.0.0.1", 7003], ], [ 5461, 10922, - ['127.0.0.1', 7001], - ['127.0.0.1', 7004], + ["127.0.0.1", 7001], + ["127.0.0.1", 7004], ], ] @@ -1889,31 +1973,28 @@ def create_mocked_redis_node(host, port, **kwargs): [ 0, 5460, - ['127.0.0.1', 7001], - ['127.0.0.1', 7003], + ["127.0.0.1", 7001], + ["127.0.0.1", 7003], ], [ 5461, 10922, - ['127.0.0.1', 7000], - ['127.0.0.1', 7004], + ["127.0.0.1", 7000], + ["127.0.0.1", 7004], ], ] else: result = [] - r_node = Redis( - host=host, - port=port - ) + r_node = Redis(host=host, port=port) orig_execute_command = r_node.execute_command def execute_command(*args, **kwargs): - if args[0] == 'CLUSTER SLOTS': + if args[0] == "CLUSTER SLOTS": return result - elif args[1] == 'cluster-require-full-coverage': - return {'cluster-require-full-coverage': 'yes'} + elif args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": "yes"} else: return orig_execute_command(*args, **kwargs) @@ -1923,12 +2004,12 @@ def execute_command(*args, **kwargs): create_redis_node.side_effect = create_mocked_redis_node with pytest.raises(RedisClusterException) as ex: - node_1 = ClusterNode('127.0.0.1', 7000) - node_2 = ClusterNode('127.0.0.1', 7001) + node_1 = ClusterNode("127.0.0.1", 7000) + node_2 = ClusterNode("127.0.0.1", 7001) RedisCluster(startup_nodes=[node_1, node_2]) assert str(ex.value).startswith( - "startup_nodes could not agree on a valid slots cache"), str( - ex.value) + "startup_nodes could not agree on a valid slots cache" + ), str(ex.value) def test_cluster_one_instance(self): """ @@ -1936,9 +2017,8 @@ def test_cluster_one_instance(self): be validated they work. """ node = ClusterNode(default_host, default_port) - cluster_slots = [[0, 16383, ['', default_port]]] - rc = get_mocked_redis_client(startup_nodes=[node], - cluster_slots=cluster_slots) + cluster_slots = [[0, 16383, ["", default_port]]] + rc = get_mocked_redis_client(startup_nodes=[node], cluster_slots=cluster_slots) n = rc.nodes_manager assert len(n.nodes_cache) == 1 @@ -1955,28 +2035,30 @@ def test_init_with_down_node(self): If I can't connect to one of the nodes, everything should still work. But if I can't connect to any of the nodes, exception should be thrown. """ - with patch.object(NodesManager, - 'create_redis_node') as create_redis_node: + with patch.object(NodesManager, "create_redis_node") as create_redis_node: + def create_mocked_redis_node(host, port, **kwargs): if port == 7000: - raise ConnectionError('mock connection error for 7000') + raise ConnectionError("mock connection error for 7000") r_node = Redis(host=host, port=port, decode_responses=True) def execute_command(*args, **kwargs): - if args[0] == 'CLUSTER SLOTS': + if args[0] == "CLUSTER SLOTS": return [ [ - 0, 8191, - ['127.0.0.1', 7001, 'node_1'], + 0, + 8191, + ["127.0.0.1", 7001, "node_1"], ], [ - 8192, 16383, - ['127.0.0.1', 7002, 'node_2'], - ] + 8192, + 16383, + ["127.0.0.1", 7002, "node_2"], + ], ] - elif args[1] == 'cluster-require-full-coverage': - return {'cluster-require-full-coverage': 'yes'} + elif args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": "yes"} r_node.execute_command = execute_command @@ -1984,25 +2066,30 @@ def execute_command(*args, **kwargs): create_redis_node.side_effect = create_mocked_redis_node - node_1 = ClusterNode('127.0.0.1', 7000) - node_2 = ClusterNode('127.0.0.1', 7001) + node_1 = ClusterNode("127.0.0.1", 7000) + node_2 = ClusterNode("127.0.0.1", 7001) # If all startup nodes fail to connect, connection error should be # thrown with pytest.raises(RedisClusterException) as e: RedisCluster(startup_nodes=[node_1]) - assert 'Redis Cluster cannot be connected' in str(e.value) + assert "Redis Cluster cannot be connected" in str(e.value) - with patch.object(CommandsParser, 'initialize', - autospec=True) as cmd_parser_initialize: + with patch.object( + CommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize: def cmd_init_mock(self, r): - self.commands = {'get': {'name': 'get', 'arity': 2, - 'flags': ['readonly', - 'fast'], - 'first_key_pos': 1, - 'last_key_pos': 1, - 'step_count': 1}} + self.commands = { + "get": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } cmd_parser_initialize.side_effect = cmd_init_mock # When at least one startup node is reachable, the cluster @@ -2040,7 +2127,7 @@ def test_init_pubusub_without_specifying_node(self, r): should be determined based on the keyslot of the first command execution. """ - channel_name = 'foo' + channel_name = "foo" node = r.get_node_from_key(channel_name) p = r.pubsub() assert p.get_pubsub_node() is None @@ -2052,7 +2139,7 @@ def test_init_pubsub_with_a_non_existent_node(self, r): Test creation of pubsub instance with node that doesn't exists in the cluster. RedisClusterException should be raised. """ - node = ClusterNode('1.1.1.1', 1111) + node = ClusterNode("1.1.1.1", 1111) with pytest.raises(RedisClusterException): r.pubsub(node) @@ -2063,7 +2150,7 @@ def test_init_pubsub_with_a_non_existent_host_port(self, r): RedisClusterException should be raised. """ with pytest.raises(RedisClusterException): - r.pubsub(host='1.1.1.1', port=1111) + r.pubsub(host="1.1.1.1", port=1111) def test_init_pubsub_host_or_port(self, r): """ @@ -2071,7 +2158,7 @@ def test_init_pubsub_host_or_port(self, r): versa. DataError should be raised. """ with pytest.raises(DataError): - r.pubsub(host='localhost') + r.pubsub(host="localhost") with pytest.raises(DataError): r.pubsub(port=16379) @@ -2131,14 +2218,17 @@ def test_blocked_arguments(self, r): with pytest.raises(RedisClusterException) as ex: r.pipeline(transaction=True) - assert str(ex.value).startswith( - "transaction is deprecated in cluster mode") is True + assert ( + str(ex.value).startswith("transaction is deprecated in cluster mode") + is True + ) with pytest.raises(RedisClusterException) as ex: r.pipeline(shard_hint=True) - assert str(ex.value).startswith( - "shard_hint is deprecated in cluster mode") is True + assert ( + str(ex.value).startswith("shard_hint is deprecated in cluster mode") is True + ) def test_redis_cluster_pipeline(self, r): """ @@ -2147,7 +2237,7 @@ def test_redis_cluster_pipeline(self, r): with r.pipeline() as pipe: pipe.set("foo", "bar") pipe.get("foo") - assert pipe.execute() == [True, b'bar'] + assert pipe.execute() == [True, b"bar"] def test_mget_disabled(self, r): """ @@ -2155,7 +2245,7 @@ def test_mget_disabled(self, r): """ with r.pipeline() as pipe: with pytest.raises(RedisClusterException): - pipe.mget(['a']) + pipe.mget(["a"]) def test_mset_disabled(self, r): """ @@ -2163,7 +2253,7 @@ def test_mset_disabled(self, r): """ with r.pipeline() as pipe: with pytest.raises(RedisClusterException): - pipe.mset({'a': 1, 'b': 2}) + pipe.mset({"a": 1, "b": 2}) def test_rename_disabled(self, r): """ @@ -2171,7 +2261,7 @@ def test_rename_disabled(self, r): """ with r.pipeline(transaction=False) as pipe: with pytest.raises(RedisClusterException): - pipe.rename('a', 'b') + pipe.rename("a", "b") def test_renamenx_disabled(self, r): """ @@ -2179,15 +2269,15 @@ def test_renamenx_disabled(self, r): """ with r.pipeline(transaction=False) as pipe: with pytest.raises(RedisClusterException): - pipe.renamenx('a', 'b') + pipe.renamenx("a", "b") def test_delete_single(self, r): """ Test a single delete operation """ - r['a'] = 1 + r["a"] = 1 with r.pipeline(transaction=False) as pipe: - pipe.delete('a') + pipe.delete("a") assert pipe.execute() == [1] def test_multi_delete_unsupported(self, r): @@ -2195,10 +2285,10 @@ def test_multi_delete_unsupported(self, r): Test that multi delete operation is unsupported """ with r.pipeline(transaction=False) as pipe: - r['a'] = 1 - r['b'] = 2 + r["a"] = 1 + r["b"] = 2 with pytest.raises(RedisClusterException): - pipe.delete('a', 'b') + pipe.delete("a", "b") def test_brpoplpush_disabled(self, r): """ @@ -2293,41 +2383,40 @@ def test_multi_key_operation_with_a_single_slot(self, r): Test multi key operation with a single slot """ pipe = r.pipeline(transaction=False) - pipe.set('a{foo}', 1) - pipe.set('b{foo}', 2) - pipe.set('c{foo}', 3) - pipe.get('a{foo}') - pipe.get('b{foo}') - pipe.get('c{foo}') + pipe.set("a{foo}", 1) + pipe.set("b{foo}", 2) + pipe.set("c{foo}", 3) + pipe.get("a{foo}") + pipe.get("b{foo}") + pipe.get("c{foo}") res = pipe.execute() - assert res == [True, True, True, b'1', b'2', b'3'] + assert res == [True, True, True, b"1", b"2", b"3"] def test_multi_key_operation_with_multi_slots(self, r): """ Test multi key operation with more than one slot """ pipe = r.pipeline(transaction=False) - pipe.set('a{foo}', 1) - pipe.set('b{foo}', 2) - pipe.set('c{foo}', 3) - pipe.set('bar', 4) - pipe.set('bazz', 5) - pipe.get('a{foo}') - pipe.get('b{foo}') - pipe.get('c{foo}') - pipe.get('bar') - pipe.get('bazz') + pipe.set("a{foo}", 1) + pipe.set("b{foo}", 2) + pipe.set("c{foo}", 3) + pipe.set("bar", 4) + pipe.set("bazz", 5) + pipe.get("a{foo}") + pipe.get("b{foo}") + pipe.get("c{foo}") + pipe.get("bar") + pipe.get("bazz") res = pipe.execute() - assert res == [True, True, True, True, True, b'1', b'2', b'3', b'4', - b'5'] + assert res == [True, True, True, True, True, b"1", b"2", b"3", b"4", b"5"] def test_connection_error_not_raised(self, r): """ Test that the pipeline doesn't raise an error on connection error when raise_on_error=False """ - key = 'foo' + key = "foo" node = r.get_node_from_key(key, False) def raise_connection_error(): @@ -2345,7 +2434,7 @@ def test_connection_error_raised(self, r): Test that the pipeline raises an error on connection error when raise_on_error=True """ - key = 'foo' + key = "foo" node = r.get_node_from_key(key, False) def raise_connection_error(): @@ -2361,7 +2450,7 @@ def test_asking_error(self, r): """ Test redirection on ASK error """ - key = 'foo' + key = "foo" first_node = r.get_node_from_key(key, False) ask_node = None for node in r.get_nodes(): @@ -2369,8 +2458,7 @@ def test_asking_error(self, r): ask_node = node break if ask_node is None: - warnings.warn("skipping this test since the cluster has only one " - "node") + warnings.warn("skipping this test since the cluster has only one " "node") return ask_msg = f"{r.keyslot(key)} {ask_node.host}:{ask_node.port}" @@ -2379,11 +2467,11 @@ def raise_ask_error(): with r.pipeline() as pipe: mock_node_resp_func(first_node, raise_ask_error) - mock_node_resp(ask_node, 'MOCK_OK') + mock_node_resp(ask_node, "MOCK_OK") res = pipe.get(key).execute() assert first_node.redis_connection.connection.read_response.called assert ask_node.redis_connection.connection.read_response.called - assert res == ['MOCK_OK'] + assert res == ["MOCK_OK"] def test_empty_stack(self, r): """ @@ -2405,17 +2493,16 @@ def test_pipeline_readonly(self, r): """ On readonly mode, we supports get related stuff only. """ - r.readonly(target_nodes='all') - r.set('foo71', 'a1') # we assume this key is set on 127.0.0.1:7001 - r.zadd('foo88', - {'z1': 1}) # we assume this key is set on 127.0.0.1:7002 - r.zadd('foo88', {'z2': 4}) + r.readonly(target_nodes="all") + r.set("foo71", "a1") # we assume this key is set on 127.0.0.1:7001 + r.zadd("foo88", {"z1": 1}) # we assume this key is set on 127.0.0.1:7002 + r.zadd("foo88", {"z2": 4}) with r.pipeline() as readonly_pipe: - readonly_pipe.get('foo71').zrange('foo88', 0, 5, withscores=True) + readonly_pipe.get("foo71").zrange("foo88", 0, 5, withscores=True) assert readonly_pipe.execute() == [ - b'a1', - [(b'z1', 1.0), (b'z2', 4)], + b"a1", + [(b"z1", 1.0), (b"z2", 4)], ] def test_moved_redirection_on_slave_with_default(self, r): @@ -2423,8 +2510,8 @@ def test_moved_redirection_on_slave_with_default(self, r): On Pipeline, we redirected once and finally get from master with readonly client when data is completely moved. """ - key = 'bar' - r.set(key, 'foo') + key = "bar" + r.set(key, "foo") # set read_from_replicas to True r.read_from_replicas = True primary = r.get_node_from_key(key, False) @@ -2456,15 +2543,15 @@ def test_readonly_pipeline_from_readonly_client(self, request): """ # Create a cluster with reading from replications ro = _get_client(RedisCluster, request, read_from_replicas=True) - key = 'bar' - ro.set(key, 'foo') + key = "bar" + ro.set(key, "foo") import time + time.sleep(0.2) with ro.pipeline() as readonly_pipe: - mock_all_nodes_resp(ro, 'MOCK_OK') + mock_all_nodes_resp(ro, "MOCK_OK") assert readonly_pipe.read_from_replicas is True - assert readonly_pipe.get(key).get( - key).execute() == ['MOCK_OK', 'MOCK_OK'] + assert readonly_pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)] if len(slot_nodes) > 1: executed_on_replica = False diff --git a/tests/test_commands.py b/tests/test_commands.py index 444a163489..1eb35f8673 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1,19 +1,20 @@ import binascii import datetime -import pytest import re -import redis import time from string import ascii_letters -from redis.client import parse_info +import pytest + +import redis from redis import exceptions +from redis.client import parse_info from .conftest import ( _get_client, + skip_if_redis_enterprise, skip_if_server_version_gte, skip_if_server_version_lt, - skip_if_redis_enterprise, skip_unless_arch_bits, ) @@ -21,21 +22,22 @@ @pytest.fixture() def slowlog(request, r): current_config = r.config_get() - old_slower_than_value = current_config['slowlog-log-slower-than'] - old_max_legnth_value = current_config['slowlog-max-len'] + old_slower_than_value = current_config["slowlog-log-slower-than"] + old_max_legnth_value = current_config["slowlog-max-len"] def cleanup(): - r.config_set('slowlog-log-slower-than', old_slower_than_value) - r.config_set('slowlog-max-len', old_max_legnth_value) + r.config_set("slowlog-log-slower-than", old_slower_than_value) + r.config_set("slowlog-max-len", old_max_legnth_value) + request.addfinalizer(cleanup) - r.config_set('slowlog-log-slower-than', 0) - r.config_set('slowlog-max-len', 128) + r.config_set("slowlog-log-slower-than", 0) + r.config_set("slowlog-max-len", 128) def redis_server_time(client): seconds, milliseconds = client.time() - timestamp = float(f'{seconds}.{milliseconds}') + timestamp = float(f"{seconds}.{milliseconds}") return datetime.datetime.fromtimestamp(timestamp) @@ -54,19 +56,19 @@ class TestResponseCallbacks: def test_response_callbacks(self, r): assert r.response_callbacks == redis.Redis.RESPONSE_CALLBACKS assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) - r.set_response_callback('GET', lambda x: 'static') - r['a'] = 'foo' - assert r['a'] == 'static' + r.set_response_callback("GET", lambda x: "static") + r["a"] = "foo" + assert r["a"] == "static" def test_case_insensitive_command_names(self, r): - assert r.response_callbacks['del'] == r.response_callbacks['DEL'] + assert r.response_callbacks["del"] == r.response_callbacks["DEL"] class TestRedisCommands: def test_command_on_invalid_key_type(self, r): - r.lpush('a', '1') + r.lpush("a", "1") with pytest.raises(redis.ResponseError): - r['a'] + r["a"] # SERVER INFORMATION @pytest.mark.onlynoncluster @@ -74,20 +76,20 @@ def test_command_on_invalid_key_type(self, r): def test_acl_cat_no_category(self, r): categories = r.acl_cat() assert isinstance(categories, list) - assert 'read' in categories + assert "read" in categories @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_cat_with_category(self, r): - commands = r.acl_cat('read') + commands = r.acl_cat("read") assert isinstance(commands, list) - assert 'get' in commands + assert "get" in commands @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise def test_acl_deluser(self, r, request): - username = 'redis-py-user' + username = "redis-py-user" def teardown(): r.acl_deluser(username) @@ -99,7 +101,7 @@ def teardown(): assert r.acl_deluser(username) == 1 # now, a group of users - users = [f'bogususer_{r}' for r in range(0, 5)] + users = [f"bogususer_{r}" for r in range(0, 5)] for u in users: r.acl_setuser(u, enabled=False, reset=True) assert r.acl_deluser(*users) > 1 @@ -117,7 +119,7 @@ def test_acl_genpass(self, r): assert isinstance(password, str) with pytest.raises(exceptions.DataError): - r.acl_genpass('value') + r.acl_genpass("value") r.acl_genpass(-5) r.acl_genpass(5555) @@ -128,90 +130,109 @@ def test_acl_genpass(self, r): @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise def test_acl_getuser_setuser(self, r, request): - username = 'redis-py-user' + username = "redis-py-user" def teardown(): r.acl_deluser(username) + request.addfinalizer(teardown) # test enabled=False assert r.acl_setuser(username, enabled=False, reset=True) acl = r.acl_getuser(username) - assert acl['categories'] == ['-@all'] - assert acl['commands'] == [] - assert acl['keys'] == [] - assert acl['passwords'] == [] - assert 'off' in acl['flags'] - assert acl['enabled'] is False + assert acl["categories"] == ["-@all"] + assert acl["commands"] == [] + assert acl["keys"] == [] + assert acl["passwords"] == [] + assert "off" in acl["flags"] + assert acl["enabled"] is False # test nopass=True assert r.acl_setuser(username, enabled=True, reset=True, nopass=True) acl = r.acl_getuser(username) - assert acl['categories'] == ['-@all'] - assert acl['commands'] == [] - assert acl['keys'] == [] - assert acl['passwords'] == [] - assert 'on' in acl['flags'] - assert 'nopass' in acl['flags'] - assert acl['enabled'] is True + assert acl["categories"] == ["-@all"] + assert acl["commands"] == [] + assert acl["keys"] == [] + assert acl["passwords"] == [] + assert "on" in acl["flags"] + assert "nopass" in acl["flags"] + assert acl["enabled"] is True # test all args - assert r.acl_setuser(username, enabled=True, reset=True, - passwords=['+pass1', '+pass2'], - categories=['+set', '+@hash', '-geo'], - commands=['+get', '+mget', '-hset'], - keys=['cache:*', 'objects:*']) + assert r.acl_setuser( + username, + enabled=True, + reset=True, + passwords=["+pass1", "+pass2"], + categories=["+set", "+@hash", "-geo"], + commands=["+get", "+mget", "-hset"], + keys=["cache:*", "objects:*"], + ) acl = r.acl_getuser(username) - assert set(acl['categories']) == {'-@all', '+@set', '+@hash'} - assert set(acl['commands']) == {'+get', '+mget', '-hset'} - assert acl['enabled'] is True - assert 'on' in acl['flags'] - assert set(acl['keys']) == {b'cache:*', b'objects:*'} - assert len(acl['passwords']) == 2 + assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["commands"]) == {"+get", "+mget", "-hset"} + assert acl["enabled"] is True + assert "on" in acl["flags"] + assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert len(acl["passwords"]) == 2 # test reset=False keeps existing ACL and applies new ACL on top - assert r.acl_setuser(username, enabled=True, reset=True, - passwords=['+pass1'], - categories=['+@set'], - commands=['+get'], - keys=['cache:*']) - assert r.acl_setuser(username, enabled=True, - passwords=['+pass2'], - categories=['+@hash'], - commands=['+mget'], - keys=['objects:*']) + assert r.acl_setuser( + username, + enabled=True, + reset=True, + passwords=["+pass1"], + categories=["+@set"], + commands=["+get"], + keys=["cache:*"], + ) + assert r.acl_setuser( + username, + enabled=True, + passwords=["+pass2"], + categories=["+@hash"], + commands=["+mget"], + keys=["objects:*"], + ) acl = r.acl_getuser(username) - assert set(acl['categories']) == {'-@all', '+@set', '+@hash'} - assert set(acl['commands']) == {'+get', '+mget'} - assert acl['enabled'] is True - assert 'on' in acl['flags'] - assert set(acl['keys']) == {b'cache:*', b'objects:*'} - assert len(acl['passwords']) == 2 + assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["commands"]) == {"+get", "+mget"} + assert acl["enabled"] is True + assert "on" in acl["flags"] + assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert len(acl["passwords"]) == 2 # test removal of passwords - assert r.acl_setuser(username, enabled=True, reset=True, - passwords=['+pass1', '+pass2']) - assert len(r.acl_getuser(username)['passwords']) == 2 - assert r.acl_setuser(username, enabled=True, - passwords=['-pass2']) - assert len(r.acl_getuser(username)['passwords']) == 1 + assert r.acl_setuser( + username, enabled=True, reset=True, passwords=["+pass1", "+pass2"] + ) + assert len(r.acl_getuser(username)["passwords"]) == 2 + assert r.acl_setuser(username, enabled=True, passwords=["-pass2"]) + assert len(r.acl_getuser(username)["passwords"]) == 1 # Resets and tests that hashed passwords are set properly. - hashed_password = ('5e884898da28047151d0e56f8dc629' - '2773603d0d6aabbdd62a11ef721d1542d8') - assert r.acl_setuser(username, enabled=True, reset=True, - hashed_passwords=['+' + hashed_password]) + hashed_password = ( + "5e884898da28047151d0e56f8dc629" "2773603d0d6aabbdd62a11ef721d1542d8" + ) + assert r.acl_setuser( + username, enabled=True, reset=True, hashed_passwords=["+" + hashed_password] + ) acl = r.acl_getuser(username) - assert acl['passwords'] == [hashed_password] + assert acl["passwords"] == [hashed_password] # test removal of hashed passwords - assert r.acl_setuser(username, enabled=True, reset=True, - hashed_passwords=['+' + hashed_password], - passwords=['+pass1']) - assert len(r.acl_getuser(username)['passwords']) == 2 - assert r.acl_setuser(username, enabled=True, - hashed_passwords=['-' + hashed_password]) - assert len(r.acl_getuser(username)['passwords']) == 1 + assert r.acl_setuser( + username, + enabled=True, + reset=True, + hashed_passwords=["+" + hashed_password], + passwords=["+pass1"], + ) + assert len(r.acl_getuser(username)["passwords"]) == 2 + assert r.acl_setuser( + username, enabled=True, hashed_passwords=["-" + hashed_password] + ) + assert len(r.acl_getuser(username)["passwords"]) == 1 @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") @@ -224,10 +245,11 @@ def test_acl_help(self, r): @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise def test_acl_list(self, r, request): - username = 'redis-py-user' + username = "redis-py-user" def teardown(): r.acl_deluser(username) + request.addfinalizer(teardown) assert r.acl_setuser(username, enabled=False, reset=True) @@ -238,77 +260,86 @@ def teardown(): @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise def test_acl_log(self, r, request): - username = 'redis-py-user' + username = "redis-py-user" def teardown(): r.acl_deluser(username) request.addfinalizer(teardown) - r.acl_setuser(username, enabled=True, reset=True, - commands=['+get', '+set', '+select'], - keys=['cache:*'], nopass=True) + r.acl_setuser( + username, + enabled=True, + reset=True, + commands=["+get", "+set", "+select"], + keys=["cache:*"], + nopass=True, + ) r.acl_log_reset() - user_client = _get_client(redis.Redis, request, flushdb=False, - username=username) + user_client = _get_client( + redis.Redis, request, flushdb=False, username=username + ) # Valid operation and key - assert user_client.set('cache:0', 1) - assert user_client.get('cache:0') == b'1' + assert user_client.set("cache:0", 1) + assert user_client.get("cache:0") == b"1" # Invalid key with pytest.raises(exceptions.NoPermissionError): - user_client.get('violated_cache:0') + user_client.get("violated_cache:0") # Invalid operation with pytest.raises(exceptions.NoPermissionError): - user_client.hset('cache:0', 'hkey', 'hval') + user_client.hset("cache:0", "hkey", "hval") assert isinstance(r.acl_log(), list) assert len(r.acl_log()) == 2 assert len(r.acl_log(count=1)) == 1 assert isinstance(r.acl_log()[0], dict) - assert 'client-info' in r.acl_log(count=1)[0] + assert "client-info" in r.acl_log(count=1)[0] assert r.acl_log_reset() @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise def test_acl_setuser_categories_without_prefix_fails(self, r, request): - username = 'redis-py-user' + username = "redis-py-user" def teardown(): r.acl_deluser(username) + request.addfinalizer(teardown) with pytest.raises(exceptions.DataError): - r.acl_setuser(username, categories=['list']) + r.acl_setuser(username, categories=["list"]) @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise def test_acl_setuser_commands_without_prefix_fails(self, r, request): - username = 'redis-py-user' + username = "redis-py-user" def teardown(): r.acl_deluser(username) + request.addfinalizer(teardown) with pytest.raises(exceptions.DataError): - r.acl_setuser(username, commands=['get']) + r.acl_setuser(username, commands=["get"]) @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise def test_acl_setuser_add_passwords_and_nopass_fails(self, r, request): - username = 'redis-py-user' + username = "redis-py-user" def teardown(): r.acl_deluser(username) + request.addfinalizer(teardown) with pytest.raises(exceptions.DataError): - r.acl_setuser(username, passwords='+mypass', nopass=True) + r.acl_setuser(username, passwords="+mypass", nopass=True) @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") @@ -327,36 +358,36 @@ def test_acl_whoami(self, r): def test_client_list(self, r): clients = r.client_list() assert isinstance(clients[0], dict) - assert 'addr' in clients[0] + assert "addr" in clients[0] @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_client_info(self, r): info = r.client_info() assert isinstance(info, dict) - assert 'addr' in info + assert "addr" in info @pytest.mark.onlynoncluster - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_client_list_types_not_replica(self, r): with pytest.raises(exceptions.RedisError): - r.client_list(_type='not a client type') - for client_type in ['normal', 'master', 'pubsub']: + r.client_list(_type="not a client type") + for client_type in ["normal", "master", "pubsub"]: clients = r.client_list(_type=client_type) assert isinstance(clients, list) @skip_if_redis_enterprise def test_client_list_replica(self, r): - clients = r.client_list(_type='replica') + clients = r.client_list(_type="replica") assert isinstance(clients, list) @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_client_list_client_id(self, r, request): clients = r.client_list() - clients = r.client_list(client_id=[clients[0]['id']]) + clients = r.client_list(client_id=[clients[0]["id"]]) assert len(clients) == 1 - assert 'addr' in clients[0] + assert "addr" in clients[0] # testing multiple client ids _get_client(redis.Redis, request, flushdb=False) @@ -366,19 +397,19 @@ def test_client_list_client_id(self, r, request): assert len(clients_listed) > 1 @pytest.mark.onlynoncluster - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_client_id(self, r): assert r.client_id() > 0 @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_client_trackinginfo(self, r): res = r.client_trackinginfo() assert len(res) > 2 - assert 'prefixes' in res + assert "prefixes" in res @pytest.mark.onlynoncluster - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_client_unblock(self, r): myid = r.client_id() assert not r.client_unblock(myid) @@ -386,36 +417,42 @@ def test_client_unblock(self, r): assert not r.client_unblock(myid, error=False) @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.6.9') + @skip_if_server_version_lt("2.6.9") def test_client_getname(self, r): assert r.client_getname() is None @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.6.9') + @skip_if_server_version_lt("2.6.9") def test_client_setname(self, r): - assert r.client_setname('redis_py_test') - assert r.client_getname() == 'redis_py_test' + assert r.client_setname("redis_py_test") + assert r.client_getname() == "redis_py_test" @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.6.9') + @skip_if_server_version_lt("2.6.9") def test_client_kill(self, r, r2): - r.client_setname('redis-py-c1') - r2.client_setname('redis-py-c2') - clients = [client for client in r.client_list() - if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + r.client_setname("redis-py-c1") + r2.client_setname("redis-py-c2") + clients = [ + client + for client in r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] assert len(clients) == 2 - clients_by_name = {client.get('name'): client for client in clients} + clients_by_name = {client.get("name"): client for client in clients} - client_addr = clients_by_name['redis-py-c2'].get('addr') + client_addr = clients_by_name["redis-py-c2"].get("addr") assert r.client_kill(client_addr) is True - clients = [client for client in r.client_list() - if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + clients = [ + client + for client in r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] assert len(clients) == 1 - assert clients[0].get('name') == 'redis-py-c1' + assert clients[0].get("name") == "redis-py-c1" - @skip_if_server_version_lt('2.8.12') + @skip_if_server_version_lt("2.8.12") def test_client_kill_filter_invalid_params(self, r): # empty with pytest.raises(exceptions.DataError): @@ -430,110 +467,130 @@ def test_client_kill_filter_invalid_params(self, r): r.client_kill_filter(_type="caster") @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.12') + @skip_if_server_version_lt("2.8.12") def test_client_kill_filter_by_id(self, r, r2): - r.client_setname('redis-py-c1') - r2.client_setname('redis-py-c2') - clients = [client for client in r.client_list() - if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + r.client_setname("redis-py-c1") + r2.client_setname("redis-py-c2") + clients = [ + client + for client in r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] assert len(clients) == 2 - clients_by_name = {client.get('name'): client for client in clients} + clients_by_name = {client.get("name"): client for client in clients} - client_2_id = clients_by_name['redis-py-c2'].get('id') + client_2_id = clients_by_name["redis-py-c2"].get("id") resp = r.client_kill_filter(_id=client_2_id) assert resp == 1 - clients = [client for client in r.client_list() - if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + clients = [ + client + for client in r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] assert len(clients) == 1 - assert clients[0].get('name') == 'redis-py-c1' + assert clients[0].get("name") == "redis-py-c1" @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.12') + @skip_if_server_version_lt("2.8.12") def test_client_kill_filter_by_addr(self, r, r2): - r.client_setname('redis-py-c1') - r2.client_setname('redis-py-c2') - clients = [client for client in r.client_list() - if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + r.client_setname("redis-py-c1") + r2.client_setname("redis-py-c2") + clients = [ + client + for client in r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] assert len(clients) == 2 - clients_by_name = {client.get('name'): client for client in clients} + clients_by_name = {client.get("name"): client for client in clients} - client_2_addr = clients_by_name['redis-py-c2'].get('addr') + client_2_addr = clients_by_name["redis-py-c2"].get("addr") resp = r.client_kill_filter(addr=client_2_addr) assert resp == 1 - clients = [client for client in r.client_list() - if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + clients = [ + client + for client in r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] assert len(clients) == 1 - assert clients[0].get('name') == 'redis-py-c1' + assert clients[0].get("name") == "redis-py-c1" - @skip_if_server_version_lt('2.6.9') + @skip_if_server_version_lt("2.6.9") def test_client_list_after_client_setname(self, r): - r.client_setname('redis_py_test') + r.client_setname("redis_py_test") clients = r.client_list() # we don't know which client ours will be - assert 'redis_py_test' in [c['name'] for c in clients] + assert "redis_py_test" in [c["name"] for c in clients] - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_client_kill_filter_by_laddr(self, r, r2): - r.client_setname('redis-py-c1') - r2.client_setname('redis-py-c2') - clients = [client for client in r.client_list() - if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + r.client_setname("redis-py-c1") + r2.client_setname("redis-py-c2") + clients = [ + client + for client in r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] assert len(clients) == 2 - clients_by_name = {client.get('name'): client for client in clients} + clients_by_name = {client.get("name"): client for client in clients} - client_2_addr = clients_by_name['redis-py-c2'].get('laddr') + client_2_addr = clients_by_name["redis-py-c2"].get("laddr") assert r.client_kill_filter(laddr=client_2_addr) - @skip_if_server_version_lt('6.0.0') + @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise def test_client_kill_filter_by_user(self, r, request): - killuser = 'user_to_kill' - r.acl_setuser(killuser, enabled=True, reset=True, - commands=['+get', '+set', '+select'], - keys=['cache:*'], nopass=True) + killuser = "user_to_kill" + r.acl_setuser( + killuser, + enabled=True, + reset=True, + commands=["+get", "+set", "+select"], + keys=["cache:*"], + nopass=True, + ) _get_client(redis.Redis, request, flushdb=False, username=killuser) r.client_kill_filter(user=killuser) clients = r.client_list() for c in clients: - assert c['user'] != killuser + assert c["user"] != killuser r.acl_deluser(killuser) @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.9.50') + @skip_if_server_version_lt("2.9.50") @skip_if_redis_enterprise def test_client_pause(self, r): assert r.client_pause(1) assert r.client_pause(timeout=1) with pytest.raises(exceptions.RedisError): - r.client_pause(timeout='not an integer') + r.client_pause(timeout="not an integer") @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") @skip_if_redis_enterprise def test_client_unpause(self, r): - assert r.client_unpause() == b'OK' + assert r.client_unpause() == b"OK" @pytest.mark.onlynoncluster - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_client_reply(self, r, r_timeout): - assert r_timeout.client_reply('ON') == b'OK' + assert r_timeout.client_reply("ON") == b"OK" with pytest.raises(exceptions.TimeoutError): - r_timeout.client_reply('OFF') + r_timeout.client_reply("OFF") - r_timeout.client_reply('SKIP') + r_timeout.client_reply("SKIP") - assert r_timeout.set('foo', 'bar') + assert r_timeout.set("foo", "bar") # validate it was set - assert r.get('foo') == b'bar' + assert r.get("foo") == b"bar" @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.0.0') + @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise def test_client_getredir(self, r): assert isinstance(r.client_getredir(), int) @@ -549,37 +606,37 @@ def test_config_get(self, r): @skip_if_redis_enterprise def test_config_resetstat(self, r): r.ping() - prior_commands_processed = int(r.info()['total_commands_processed']) + prior_commands_processed = int(r.info()["total_commands_processed"]) assert prior_commands_processed >= 1 r.config_resetstat() - reset_commands_processed = int(r.info()['total_commands_processed']) + reset_commands_processed = int(r.info()["total_commands_processed"]) assert reset_commands_processed < prior_commands_processed @skip_if_redis_enterprise def test_config_set(self, r): - r.config_set('timeout', 70) - assert r.config_get()['timeout'] == '70' - assert r.config_set('timeout', 0) - assert r.config_get()['timeout'] == '0' + r.config_set("timeout", 70) + assert r.config_get()["timeout"] == "70" + assert r.config_set("timeout", 0) + assert r.config_get()["timeout"] == "0" @pytest.mark.onlynoncluster def test_dbsize(self, r): - r['a'] = 'foo' - r['b'] = 'bar' + r["a"] = "foo" + r["b"] = "bar" assert r.dbsize() == 2 @pytest.mark.onlynoncluster def test_echo(self, r): - assert r.echo('foo bar') == b'foo bar' + assert r.echo("foo bar") == b"foo bar" @pytest.mark.onlynoncluster def test_info(self, r): - r['a'] = 'foo' - r['b'] = 'bar' + r["a"] = "foo" + r["b"] = "bar" info = r.info() assert isinstance(info, dict) - assert 'arch_bits' in info.keys() - assert 'redis_version' in info.keys() + assert "arch_bits" in info.keys() + assert "redis_version" in info.keys() @pytest.mark.onlynoncluster @skip_if_redis_enterprise @@ -587,20 +644,20 @@ def test_lastsave(self, r): assert isinstance(r.lastsave(), datetime.datetime) @pytest.mark.onlynoncluster - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_lolwut(self, r): - lolwut = r.lolwut().decode('utf-8') - assert 'Redis ver.' in lolwut + lolwut = r.lolwut().decode("utf-8") + assert "Redis ver." in lolwut - lolwut = r.lolwut(5, 6, 7, 8).decode('utf-8') - assert 'Redis ver.' in lolwut + lolwut = r.lolwut(5, 6, 7, 8).decode("utf-8") + assert "Redis ver." in lolwut def test_object(self, r): - r['a'] = 'foo' - assert isinstance(r.object('refcount', 'a'), int) - assert isinstance(r.object('idletime', 'a'), int) - assert r.object('encoding', 'a') in (b'raw', b'embstr') - assert r.object('idletime', 'invalid-key') is None + r["a"] = "foo" + assert isinstance(r.object("refcount", "a"), int) + assert isinstance(r.object("idletime", "a"), int) + assert r.object("encoding", "a") in (b"raw", b"embstr") + assert r.object("idletime", "invalid-key") is None def test_ping(self, r): assert r.ping() @@ -612,36 +669,34 @@ def test_quit(self, r): @pytest.mark.onlynoncluster def test_slowlog_get(self, r, slowlog): assert r.slowlog_reset() - unicode_string = chr(3456) + 'abcd' + chr(3421) + unicode_string = chr(3456) + "abcd" + chr(3421) r.get(unicode_string) slowlog = r.slowlog_get() assert isinstance(slowlog, list) - commands = [log['command'] for log in slowlog] + commands = [log["command"] for log in slowlog] - get_command = b' '.join((b'GET', unicode_string.encode('utf-8'))) + get_command = b" ".join((b"GET", unicode_string.encode("utf-8"))) assert get_command in commands - assert b'SLOWLOG RESET' in commands + assert b"SLOWLOG RESET" in commands # the order should be ['GET ', 'SLOWLOG RESET'], # but if other clients are executing commands at the same time, there # could be commands, before, between, or after, so just check that # the two we care about are in the appropriate order. - assert commands.index(get_command) < commands.index(b'SLOWLOG RESET') + assert commands.index(get_command) < commands.index(b"SLOWLOG RESET") # make sure other attributes are typed correctly - assert isinstance(slowlog[0]['start_time'], int) - assert isinstance(slowlog[0]['duration'], int) + assert isinstance(slowlog[0]["start_time"], int) + assert isinstance(slowlog[0]["duration"], int) # Mock result if we didn't get slowlog complexity info. - if 'complexity' not in slowlog[0]: + if "complexity" not in slowlog[0]: # monkey patch parse_response() COMPLEXITY_STATEMENT = "Complexity info: N:4712,M:3788" old_parse_response = r.parse_response def parse_response(connection, command_name, **options): - if command_name != 'SLOWLOG GET': - return old_parse_response(connection, - command_name, - **options) + if command_name != "SLOWLOG GET": + return old_parse_response(connection, command_name, **options) responses = connection.read_response() for response in responses: # Complexity info stored as fourth item in list @@ -653,10 +708,10 @@ def parse_response(connection, command_name, **options): # test slowlog = r.slowlog_get() assert isinstance(slowlog, list) - commands = [log['command'] for log in slowlog] + commands = [log["command"] for log in slowlog] assert get_command in commands idx = commands.index(get_command) - assert slowlog[idx]['complexity'] == COMPLEXITY_STATEMENT + assert slowlog[idx]["complexity"] == COMPLEXITY_STATEMENT # tear down monkeypatch r.parse_response = old_parse_response @@ -664,7 +719,7 @@ def parse_response(connection, command_name, **options): @pytest.mark.onlynoncluster def test_slowlog_get_limit(self, r, slowlog): assert r.slowlog_reset() - r.get('foo') + r.get("foo") slowlog = r.slowlog_get(1) assert isinstance(slowlog, list) # only one command, based on the number we passed to slowlog_get() @@ -672,10 +727,10 @@ def test_slowlog_get_limit(self, r, slowlog): @pytest.mark.onlynoncluster def test_slowlog_length(self, r, slowlog): - r.get('foo') + r.get("foo") assert isinstance(r.slowlog_len(), int) - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_time(self, r): t = r.time() assert len(t) == 2 @@ -690,104 +745,104 @@ def test_bgsave(self, r): # BASIC KEY COMMANDS def test_append(self, r): - assert r.append('a', 'a1') == 2 - assert r['a'] == b'a1' - assert r.append('a', 'a2') == 4 - assert r['a'] == b'a1a2' + assert r.append("a", "a1") == 2 + assert r["a"] == b"a1" + assert r.append("a", "a2") == 4 + assert r["a"] == b"a1a2" - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_bitcount(self, r): - r.setbit('a', 5, True) - assert r.bitcount('a') == 1 - r.setbit('a', 6, True) - assert r.bitcount('a') == 2 - r.setbit('a', 5, False) - assert r.bitcount('a') == 1 - r.setbit('a', 9, True) - r.setbit('a', 17, True) - r.setbit('a', 25, True) - r.setbit('a', 33, True) - assert r.bitcount('a') == 5 - assert r.bitcount('a', 0, -1) == 5 - assert r.bitcount('a', 2, 3) == 2 - assert r.bitcount('a', 2, -1) == 3 - assert r.bitcount('a', -2, -1) == 2 - assert r.bitcount('a', 1, 1) == 1 - - @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.6.0') + r.setbit("a", 5, True) + assert r.bitcount("a") == 1 + r.setbit("a", 6, True) + assert r.bitcount("a") == 2 + r.setbit("a", 5, False) + assert r.bitcount("a") == 1 + r.setbit("a", 9, True) + r.setbit("a", 17, True) + r.setbit("a", 25, True) + r.setbit("a", 33, True) + assert r.bitcount("a") == 5 + assert r.bitcount("a", 0, -1) == 5 + assert r.bitcount("a", 2, 3) == 2 + assert r.bitcount("a", 2, -1) == 3 + assert r.bitcount("a", -2, -1) == 2 + assert r.bitcount("a", 1, 1) == 1 + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.6.0") def test_bitop_not_empty_string(self, r): - r['a'] = '' - r.bitop('not', 'r', 'a') - assert r.get('r') is None + r["a"] = "" + r.bitop("not", "r", "a") + assert r.get("r") is None @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_bitop_not(self, r): - test_str = b'\xAA\x00\xFF\x55' + test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF - r['a'] = test_str - r.bitop('not', 'r', 'a') - assert int(binascii.hexlify(r['r']), 16) == correct + r["a"] = test_str + r.bitop("not", "r", "a") + assert int(binascii.hexlify(r["r"]), 16) == correct @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_bitop_not_in_place(self, r): - test_str = b'\xAA\x00\xFF\x55' + test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF - r['a'] = test_str - r.bitop('not', 'a', 'a') - assert int(binascii.hexlify(r['a']), 16) == correct + r["a"] = test_str + r.bitop("not", "a", "a") + assert int(binascii.hexlify(r["a"]), 16) == correct @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_bitop_single_string(self, r): - test_str = b'\x01\x02\xFF' - r['a'] = test_str - r.bitop('and', 'res1', 'a') - r.bitop('or', 'res2', 'a') - r.bitop('xor', 'res3', 'a') - assert r['res1'] == test_str - assert r['res2'] == test_str - assert r['res3'] == test_str + test_str = b"\x01\x02\xFF" + r["a"] = test_str + r.bitop("and", "res1", "a") + r.bitop("or", "res2", "a") + r.bitop("xor", "res3", "a") + assert r["res1"] == test_str + assert r["res2"] == test_str + assert r["res3"] == test_str @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_bitop_string_operands(self, r): - r['a'] = b'\x01\x02\xFF\xFF' - r['b'] = b'\x01\x02\xFF' - r.bitop('and', 'res1', 'a', 'b') - r.bitop('or', 'res2', 'a', 'b') - r.bitop('xor', 'res3', 'a', 'b') - assert int(binascii.hexlify(r['res1']), 16) == 0x0102FF00 - assert int(binascii.hexlify(r['res2']), 16) == 0x0102FFFF - assert int(binascii.hexlify(r['res3']), 16) == 0x000000FF + r["a"] = b"\x01\x02\xFF\xFF" + r["b"] = b"\x01\x02\xFF" + r.bitop("and", "res1", "a", "b") + r.bitop("or", "res2", "a", "b") + r.bitop("xor", "res3", "a", "b") + assert int(binascii.hexlify(r["res1"]), 16) == 0x0102FF00 + assert int(binascii.hexlify(r["res2"]), 16) == 0x0102FFFF + assert int(binascii.hexlify(r["res3"]), 16) == 0x000000FF @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.7') + @skip_if_server_version_lt("2.8.7") def test_bitpos(self, r): - key = 'key:bitpos' - r.set(key, b'\xff\xf0\x00') + key = "key:bitpos" + r.set(key, b"\xff\xf0\x00") assert r.bitpos(key, 0) == 12 assert r.bitpos(key, 0, 2, -1) == 16 assert r.bitpos(key, 0, -2, -1) == 12 - r.set(key, b'\x00\xff\xf0') + r.set(key, b"\x00\xff\xf0") assert r.bitpos(key, 1, 0) == 8 assert r.bitpos(key, 1, 1) == 8 - r.set(key, b'\x00\x00\x00') + r.set(key, b"\x00\x00\x00") assert r.bitpos(key, 1) == -1 - @skip_if_server_version_lt('2.8.7') + @skip_if_server_version_lt("2.8.7") def test_bitpos_wrong_arguments(self, r): - key = 'key:bitpos:wrong:args' - r.set(key, b'\xff\xf0\x00') + key = "key:bitpos:wrong:args" + r.set(key, b"\xff\xf0\x00") with pytest.raises(exceptions.RedisError): r.bitpos(key, 0, end=1) == 12 with pytest.raises(exceptions.RedisError): r.bitpos(key, 7) == 12 @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_copy(self, r): assert r.copy("a", "b") == 0 r.set("a", "foo") @@ -796,7 +851,7 @@ def test_copy(self, r): assert r.get("b") == b"foo" @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_copy_and_replace(self, r): r.set("a", "foo1") r.set("b", "foo2") @@ -804,7 +859,7 @@ def test_copy_and_replace(self, r): assert r.copy("a", "b", replace=True) == 1 @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_copy_to_another_database(self, request): r0 = _get_client(redis.Redis, request, db=0) r1 = _get_client(redis.Redis, request, db=1) @@ -813,2268 +868,2477 @@ def test_copy_to_another_database(self, request): assert r1.get("b") == b"foo" def test_decr(self, r): - assert r.decr('a') == -1 - assert r['a'] == b'-1' - assert r.decr('a') == -2 - assert r['a'] == b'-2' - assert r.decr('a', amount=5) == -7 - assert r['a'] == b'-7' + assert r.decr("a") == -1 + assert r["a"] == b"-1" + assert r.decr("a") == -2 + assert r["a"] == b"-2" + assert r.decr("a", amount=5) == -7 + assert r["a"] == b"-7" def test_decrby(self, r): - assert r.decrby('a', amount=2) == -2 - assert r.decrby('a', amount=3) == -5 - assert r['a'] == b'-5' + assert r.decrby("a", amount=2) == -2 + assert r.decrby("a", amount=3) == -5 + assert r["a"] == b"-5" def test_delete(self, r): - assert r.delete('a') == 0 - r['a'] = 'foo' - assert r.delete('a') == 1 + assert r.delete("a") == 0 + r["a"] = "foo" + assert r.delete("a") == 1 def test_delete_with_multiple_keys(self, r): - r['a'] = 'foo' - r['b'] = 'bar' - assert r.delete('a', 'b') == 2 - assert r.get('a') is None - assert r.get('b') is None + r["a"] = "foo" + r["b"] = "bar" + assert r.delete("a", "b") == 2 + assert r.get("a") is None + assert r.get("b") is None def test_delitem(self, r): - r['a'] = 'foo' - del r['a'] - assert r.get('a') is None + r["a"] = "foo" + del r["a"] + assert r.get("a") is None - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_unlink(self, r): - assert r.unlink('a') == 0 - r['a'] = 'foo' - assert r.unlink('a') == 1 - assert r.get('a') is None + assert r.unlink("a") == 0 + r["a"] = "foo" + assert r.unlink("a") == 1 + assert r.get("a") is None - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_unlink_with_multiple_keys(self, r): - r['a'] = 'foo' - r['b'] = 'bar' - assert r.unlink('a', 'b') == 2 - assert r.get('a') is None - assert r.get('b') is None + r["a"] = "foo" + r["b"] = "bar" + assert r.unlink("a", "b") == 2 + assert r.get("a") is None + assert r.get("b") is None - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_dump_and_restore(self, r): - r['a'] = 'foo' - dumped = r.dump('a') - del r['a'] - r.restore('a', 0, dumped) - assert r['a'] == b'foo' + r["a"] = "foo" + dumped = r.dump("a") + del r["a"] + r.restore("a", 0, dumped) + assert r["a"] == b"foo" - @skip_if_server_version_lt('3.0.0') + @skip_if_server_version_lt("3.0.0") def test_dump_and_restore_and_replace(self, r): - r['a'] = 'bar' - dumped = r.dump('a') + r["a"] = "bar" + dumped = r.dump("a") with pytest.raises(redis.ResponseError): - r.restore('a', 0, dumped) + r.restore("a", 0, dumped) - r.restore('a', 0, dumped, replace=True) - assert r['a'] == b'bar' + r.restore("a", 0, dumped, replace=True) + assert r["a"] == b"bar" - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_dump_and_restore_absttl(self, r): - r['a'] = 'foo' - dumped = r.dump('a') - del r['a'] + r["a"] = "foo" + dumped = r.dump("a") + del r["a"] ttl = int( - (redis_server_time(r) + datetime.timedelta(minutes=1)).timestamp() - * 1000 + (redis_server_time(r) + datetime.timedelta(minutes=1)).timestamp() * 1000 ) - r.restore('a', ttl, dumped, absttl=True) - assert r['a'] == b'foo' - assert 0 < r.ttl('a') <= 61 + r.restore("a", ttl, dumped, absttl=True) + assert r["a"] == b"foo" + assert 0 < r.ttl("a") <= 61 def test_exists(self, r): - assert r.exists('a') == 0 - r['a'] = 'foo' - r['b'] = 'bar' - assert r.exists('a') == 1 - assert r.exists('a', 'b') == 2 + assert r.exists("a") == 0 + r["a"] = "foo" + r["b"] = "bar" + assert r.exists("a") == 1 + assert r.exists("a", "b") == 2 def test_exists_contains(self, r): - assert 'a' not in r - r['a'] = 'foo' - assert 'a' in r + assert "a" not in r + r["a"] = "foo" + assert "a" in r def test_expire(self, r): - assert r.expire('a', 10) is False - r['a'] = 'foo' - assert r.expire('a', 10) is True - assert 0 < r.ttl('a') <= 10 - assert r.persist('a') - assert r.ttl('a') == -1 + assert r.expire("a", 10) is False + r["a"] = "foo" + assert r.expire("a", 10) is True + assert 0 < r.ttl("a") <= 10 + assert r.persist("a") + assert r.ttl("a") == -1 def test_expireat_datetime(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) - r['a'] = 'foo' - assert r.expireat('a', expire_at) is True - assert 0 < r.ttl('a') <= 61 + r["a"] = "foo" + assert r.expireat("a", expire_at) is True + assert 0 < r.ttl("a") <= 61 def test_expireat_no_key(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) - assert r.expireat('a', expire_at) is False + assert r.expireat("a", expire_at) is False def test_expireat_unixtime(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) - r['a'] = 'foo' + r["a"] = "foo" expire_at_seconds = int(time.mktime(expire_at.timetuple())) - assert r.expireat('a', expire_at_seconds) is True - assert 0 < r.ttl('a') <= 61 + assert r.expireat("a", expire_at_seconds) is True + assert 0 < r.ttl("a") <= 61 def test_get_and_set(self, r): # get and set can't be tested independently of each other - assert r.get('a') is None - byte_string = b'value' + assert r.get("a") is None + byte_string = b"value" integer = 5 - unicode_string = chr(3456) + 'abcd' + chr(3421) - assert r.set('byte_string', byte_string) - assert r.set('integer', 5) - assert r.set('unicode_string', unicode_string) - assert r.get('byte_string') == byte_string - assert r.get('integer') == str(integer).encode() - assert r.get('unicode_string').decode('utf-8') == unicode_string - - @skip_if_server_version_lt('6.2.0') + unicode_string = chr(3456) + "abcd" + chr(3421) + assert r.set("byte_string", byte_string) + assert r.set("integer", 5) + assert r.set("unicode_string", unicode_string) + assert r.get("byte_string") == byte_string + assert r.get("integer") == str(integer).encode() + assert r.get("unicode_string").decode("utf-8") == unicode_string + + @skip_if_server_version_lt("6.2.0") def test_getdel(self, r): - assert r.getdel('a') is None - r.set('a', 1) - assert r.getdel('a') == b'1' - assert r.getdel('a') is None + assert r.getdel("a") is None + r.set("a", 1) + assert r.getdel("a") == b"1" + assert r.getdel("a") is None - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_getex(self, r): - r.set('a', 1) - assert r.getex('a') == b'1' - assert r.ttl('a') == -1 - assert r.getex('a', ex=60) == b'1' - assert r.ttl('a') == 60 - assert r.getex('a', px=6000) == b'1' - assert r.ttl('a') == 6 + r.set("a", 1) + assert r.getex("a") == b"1" + assert r.ttl("a") == -1 + assert r.getex("a", ex=60) == b"1" + assert r.ttl("a") == 60 + assert r.getex("a", px=6000) == b"1" + assert r.ttl("a") == 6 expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) - assert r.getex('a', pxat=expire_at) == b'1' - assert r.ttl('a') <= 61 - assert r.getex('a', persist=True) == b'1' - assert r.ttl('a') == -1 + assert r.getex("a", pxat=expire_at) == b"1" + assert r.ttl("a") <= 61 + assert r.getex("a", persist=True) == b"1" + assert r.ttl("a") == -1 def test_getitem_and_setitem(self, r): - r['a'] = 'bar' - assert r['a'] == b'bar' + r["a"] = "bar" + assert r["a"] == b"bar" def test_getitem_raises_keyerror_for_missing_key(self, r): with pytest.raises(KeyError): - r['a'] + r["a"] def test_getitem_does_not_raise_keyerror_for_empty_string(self, r): - r['a'] = b"" - assert r['a'] == b"" + r["a"] = b"" + assert r["a"] == b"" def test_get_set_bit(self, r): # no value - assert not r.getbit('a', 5) + assert not r.getbit("a", 5) # set bit 5 - assert not r.setbit('a', 5, True) - assert r.getbit('a', 5) + assert not r.setbit("a", 5, True) + assert r.getbit("a", 5) # unset bit 4 - assert not r.setbit('a', 4, False) - assert not r.getbit('a', 4) + assert not r.setbit("a", 4, False) + assert not r.getbit("a", 4) # set bit 4 - assert not r.setbit('a', 4, True) - assert r.getbit('a', 4) + assert not r.setbit("a", 4, True) + assert r.getbit("a", 4) # set bit 5 again - assert r.setbit('a', 5, True) - assert r.getbit('a', 5) + assert r.setbit("a", 5, True) + assert r.getbit("a", 5) def test_getrange(self, r): - r['a'] = 'foo' - assert r.getrange('a', 0, 0) == b'f' - assert r.getrange('a', 0, 2) == b'foo' - assert r.getrange('a', 3, 4) == b'' + r["a"] = "foo" + assert r.getrange("a", 0, 0) == b"f" + assert r.getrange("a", 0, 2) == b"foo" + assert r.getrange("a", 3, 4) == b"" def test_getset(self, r): - assert r.getset('a', 'foo') is None - assert r.getset('a', 'bar') == b'foo' - assert r.get('a') == b'bar' + assert r.getset("a", "foo") is None + assert r.getset("a", "bar") == b"foo" + assert r.get("a") == b"bar" def test_incr(self, r): - assert r.incr('a') == 1 - assert r['a'] == b'1' - assert r.incr('a') == 2 - assert r['a'] == b'2' - assert r.incr('a', amount=5) == 7 - assert r['a'] == b'7' + assert r.incr("a") == 1 + assert r["a"] == b"1" + assert r.incr("a") == 2 + assert r["a"] == b"2" + assert r.incr("a", amount=5) == 7 + assert r["a"] == b"7" def test_incrby(self, r): - assert r.incrby('a') == 1 - assert r.incrby('a', 4) == 5 - assert r['a'] == b'5' + assert r.incrby("a") == 1 + assert r.incrby("a", 4) == 5 + assert r["a"] == b"5" - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_incrbyfloat(self, r): - assert r.incrbyfloat('a') == 1.0 - assert r['a'] == b'1' - assert r.incrbyfloat('a', 1.1) == 2.1 - assert float(r['a']) == float(2.1) + assert r.incrbyfloat("a") == 1.0 + assert r["a"] == b"1" + assert r.incrbyfloat("a", 1.1) == 2.1 + assert float(r["a"]) == float(2.1) @pytest.mark.onlynoncluster def test_keys(self, r): assert r.keys() == [] - keys_with_underscores = {b'test_a', b'test_b'} - keys = keys_with_underscores.union({b'testc'}) + keys_with_underscores = {b"test_a", b"test_b"} + keys = keys_with_underscores.union({b"testc"}) for key in keys: r[key] = 1 - assert set(r.keys(pattern='test_*')) == keys_with_underscores - assert set(r.keys(pattern='test*')) == keys + assert set(r.keys(pattern="test_*")) == keys_with_underscores + assert set(r.keys(pattern="test*")) == keys @pytest.mark.onlynoncluster def test_mget(self, r): assert r.mget([]) == [] - assert r.mget(['a', 'b']) == [None, None] - r['a'] = '1' - r['b'] = '2' - r['c'] = '3' - assert r.mget('a', 'other', 'b', 'c') == [b'1', None, b'2', b'3'] + assert r.mget(["a", "b"]) == [None, None] + r["a"] = "1" + r["b"] = "2" + r["c"] = "3" + assert r.mget("a", "other", "b", "c") == [b"1", None, b"2", b"3"] @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_lmove(self, r): - r.rpush('a', 'one', 'two', 'three', 'four') - assert r.lmove('a', 'b') - assert r.lmove('a', 'b', 'right', 'left') + r.rpush("a", "one", "two", "three", "four") + assert r.lmove("a", "b") + assert r.lmove("a", "b", "right", "left") @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_blmove(self, r): - r.rpush('a', 'one', 'two', 'three', 'four') - assert r.blmove('a', 'b', 5) - assert r.blmove('a', 'b', 1, 'RIGHT', 'LEFT') + r.rpush("a", "one", "two", "three", "four") + assert r.blmove("a", "b", 5) + assert r.blmove("a", "b", 1, "RIGHT", "LEFT") @pytest.mark.onlynoncluster def test_mset(self, r): - d = {'a': b'1', 'b': b'2', 'c': b'3'} + d = {"a": b"1", "b": b"2", "c": b"3"} assert r.mset(d) for k, v in d.items(): assert r[k] == v @pytest.mark.onlynoncluster def test_msetnx(self, r): - d = {'a': b'1', 'b': b'2', 'c': b'3'} + d = {"a": b"1", "b": b"2", "c": b"3"} assert r.msetnx(d) - d2 = {'a': b'x', 'd': b'4'} + d2 = {"a": b"x", "d": b"4"} assert not r.msetnx(d2) for k, v in d.items(): assert r[k] == v - assert r.get('d') is None + assert r.get("d") is None - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_pexpire(self, r): - assert r.pexpire('a', 60000) is False - r['a'] = 'foo' - assert r.pexpire('a', 60000) is True - assert 0 < r.pttl('a') <= 60000 - assert r.persist('a') - assert r.pttl('a') == -1 - - @skip_if_server_version_lt('2.6.0') + assert r.pexpire("a", 60000) is False + r["a"] = "foo" + assert r.pexpire("a", 60000) is True + assert 0 < r.pttl("a") <= 60000 + assert r.persist("a") + assert r.pttl("a") == -1 + + @skip_if_server_version_lt("2.6.0") def test_pexpireat_datetime(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) - r['a'] = 'foo' - assert r.pexpireat('a', expire_at) is True - assert 0 < r.pttl('a') <= 61000 + r["a"] = "foo" + assert r.pexpireat("a", expire_at) is True + assert 0 < r.pttl("a") <= 61000 - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_pexpireat_no_key(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) - assert r.pexpireat('a', expire_at) is False + assert r.pexpireat("a", expire_at) is False - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_pexpireat_unixtime(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) - r['a'] = 'foo' + r["a"] = "foo" expire_at_seconds = int(time.mktime(expire_at.timetuple())) * 1000 - assert r.pexpireat('a', expire_at_seconds) is True - assert 0 < r.pttl('a') <= 61000 + assert r.pexpireat("a", expire_at_seconds) is True + assert 0 < r.pttl("a") <= 61000 - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_psetex(self, r): - assert r.psetex('a', 1000, 'value') - assert r['a'] == b'value' - assert 0 < r.pttl('a') <= 1000 + assert r.psetex("a", 1000, "value") + assert r["a"] == b"value" + assert 0 < r.pttl("a") <= 1000 - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_psetex_timedelta(self, r): expire_at = datetime.timedelta(milliseconds=1000) - assert r.psetex('a', expire_at, 'value') - assert r['a'] == b'value' - assert 0 < r.pttl('a') <= 1000 + assert r.psetex("a", expire_at, "value") + assert r["a"] == b"value" + assert 0 < r.pttl("a") <= 1000 - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_pttl(self, r): - assert r.pexpire('a', 10000) is False - r['a'] = '1' - assert r.pexpire('a', 10000) is True - assert 0 < r.pttl('a') <= 10000 - assert r.persist('a') - assert r.pttl('a') == -1 - - @skip_if_server_version_lt('2.8.0') + assert r.pexpire("a", 10000) is False + r["a"] = "1" + assert r.pexpire("a", 10000) is True + assert 0 < r.pttl("a") <= 10000 + assert r.persist("a") + assert r.pttl("a") == -1 + + @skip_if_server_version_lt("2.8.0") def test_pttl_no_key(self, r): "PTTL on servers 2.8 and after return -2 when the key doesn't exist" - assert r.pttl('a') == -2 + assert r.pttl("a") == -2 - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_hrandfield(self, r): - assert r.hrandfield('key') is None - r.hset('key', mapping={'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5}) - assert r.hrandfield('key') is not None - assert len(r.hrandfield('key', 2)) == 2 + assert r.hrandfield("key") is None + r.hset("key", mapping={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) + assert r.hrandfield("key") is not None + assert len(r.hrandfield("key", 2)) == 2 # with values - assert len(r.hrandfield('key', 2, True)) == 4 + assert len(r.hrandfield("key", 2, True)) == 4 # without duplications - assert len(r.hrandfield('key', 10)) == 5 + assert len(r.hrandfield("key", 10)) == 5 # with duplications - assert len(r.hrandfield('key', -10)) == 10 + assert len(r.hrandfield("key", -10)) == 10 @pytest.mark.onlynoncluster def test_randomkey(self, r): assert r.randomkey() is None - for key in ('a', 'b', 'c'): + for key in ("a", "b", "c"): r[key] = 1 - assert r.randomkey() in (b'a', b'b', b'c') + assert r.randomkey() in (b"a", b"b", b"c") @pytest.mark.onlynoncluster def test_rename(self, r): - r['a'] = '1' - assert r.rename('a', 'b') - assert r.get('a') is None - assert r['b'] == b'1' + r["a"] = "1" + assert r.rename("a", "b") + assert r.get("a") is None + assert r["b"] == b"1" @pytest.mark.onlynoncluster def test_renamenx(self, r): - r['a'] = '1' - r['b'] = '2' - assert not r.renamenx('a', 'b') - assert r['a'] == b'1' - assert r['b'] == b'2' + r["a"] = "1" + r["b"] = "2" + assert not r.renamenx("a", "b") + assert r["a"] == b"1" + assert r["b"] == b"2" - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_set_nx(self, r): - assert r.set('a', '1', nx=True) - assert not r.set('a', '2', nx=True) - assert r['a'] == b'1' + assert r.set("a", "1", nx=True) + assert not r.set("a", "2", nx=True) + assert r["a"] == b"1" - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_set_xx(self, r): - assert not r.set('a', '1', xx=True) - assert r.get('a') is None - r['a'] = 'bar' - assert r.set('a', '2', xx=True) - assert r.get('a') == b'2' + assert not r.set("a", "1", xx=True) + assert r.get("a") is None + r["a"] = "bar" + assert r.set("a", "2", xx=True) + assert r.get("a") == b"2" - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_set_px(self, r): - assert r.set('a', '1', px=10000) - assert r['a'] == b'1' - assert 0 < r.pttl('a') <= 10000 - assert 0 < r.ttl('a') <= 10 + assert r.set("a", "1", px=10000) + assert r["a"] == b"1" + assert 0 < r.pttl("a") <= 10000 + assert 0 < r.ttl("a") <= 10 with pytest.raises(exceptions.DataError): - assert r.set('a', '1', px=10.0) + assert r.set("a", "1", px=10.0) - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_set_px_timedelta(self, r): expire_at = datetime.timedelta(milliseconds=1000) - assert r.set('a', '1', px=expire_at) - assert 0 < r.pttl('a') <= 1000 - assert 0 < r.ttl('a') <= 1 + assert r.set("a", "1", px=expire_at) + assert 0 < r.pttl("a") <= 1000 + assert 0 < r.ttl("a") <= 1 - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_set_ex(self, r): - assert r.set('a', '1', ex=10) - assert 0 < r.ttl('a') <= 10 + assert r.set("a", "1", ex=10) + assert 0 < r.ttl("a") <= 10 with pytest.raises(exceptions.DataError): - assert r.set('a', '1', ex=10.0) + assert r.set("a", "1", ex=10.0) - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_set_ex_timedelta(self, r): expire_at = datetime.timedelta(seconds=60) - assert r.set('a', '1', ex=expire_at) - assert 0 < r.ttl('a') <= 60 + assert r.set("a", "1", ex=expire_at) + assert 0 < r.ttl("a") <= 60 - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_set_exat_timedelta(self, r): expire_at = redis_server_time(r) + datetime.timedelta(seconds=10) - assert r.set('a', '1', exat=expire_at) - assert 0 < r.ttl('a') <= 10 + assert r.set("a", "1", exat=expire_at) + assert 0 < r.ttl("a") <= 10 - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_set_pxat_timedelta(self, r): expire_at = redis_server_time(r) + datetime.timedelta(seconds=50) - assert r.set('a', '1', pxat=expire_at) - assert 0 < r.ttl('a') <= 100 + assert r.set("a", "1", pxat=expire_at) + assert 0 < r.ttl("a") <= 100 - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_set_multipleoptions(self, r): - r['a'] = 'val' - assert r.set('a', '1', xx=True, px=10000) - assert 0 < r.ttl('a') <= 10 + r["a"] = "val" + assert r.set("a", "1", xx=True, px=10000) + assert 0 < r.ttl("a") <= 10 @skip_if_server_version_lt("6.0.0") def test_set_keepttl(self, r): - r['a'] = 'val' - assert r.set('a', '1', xx=True, px=10000) - assert 0 < r.ttl('a') <= 10 - r.set('a', '2', keepttl=True) - assert r.get('a') == b'2' - assert 0 < r.ttl('a') <= 10 - - @skip_if_server_version_lt('6.2.0') + r["a"] = "val" + assert r.set("a", "1", xx=True, px=10000) + assert 0 < r.ttl("a") <= 10 + r.set("a", "2", keepttl=True) + assert r.get("a") == b"2" + assert 0 < r.ttl("a") <= 10 + + @skip_if_server_version_lt("6.2.0") def test_set_get(self, r): - assert r.set('a', 'True', get=True) is None - assert r.set('a', 'True', get=True) == b'True' - assert r.set('a', 'foo') is True - assert r.set('a', 'bar', get=True) == b'foo' - assert r.get('a') == b'bar' + assert r.set("a", "True", get=True) is None + assert r.set("a", "True", get=True) == b"True" + assert r.set("a", "foo") is True + assert r.set("a", "bar", get=True) == b"foo" + assert r.get("a") == b"bar" def test_setex(self, r): - assert r.setex('a', 60, '1') - assert r['a'] == b'1' - assert 0 < r.ttl('a') <= 60 + assert r.setex("a", 60, "1") + assert r["a"] == b"1" + assert 0 < r.ttl("a") <= 60 def test_setnx(self, r): - assert r.setnx('a', '1') - assert r['a'] == b'1' - assert not r.setnx('a', '2') - assert r['a'] == b'1' + assert r.setnx("a", "1") + assert r["a"] == b"1" + assert not r.setnx("a", "2") + assert r["a"] == b"1" def test_setrange(self, r): - assert r.setrange('a', 5, 'foo') == 8 - assert r['a'] == b'\0\0\0\0\0foo' - r['a'] = 'abcdefghijh' - assert r.setrange('a', 6, '12345') == 11 - assert r['a'] == b'abcdef12345' + assert r.setrange("a", 5, "foo") == 8 + assert r["a"] == b"\0\0\0\0\0foo" + r["a"] = "abcdefghijh" + assert r.setrange("a", 6, "12345") == 11 + assert r["a"] == b"abcdef12345" - @skip_if_server_version_lt('6.0.0') + @skip_if_server_version_lt("6.0.0") def test_stralgo_lcs(self, r): - key1 = '{foo}key1' - key2 = '{foo}key2' - value1 = 'ohmytext' - value2 = 'mynewtext' - res = 'mytext' + key1 = "{foo}key1" + key2 = "{foo}key2" + value1 = "ohmytext" + value2 = "mynewtext" + res = "mytext" if skip_if_redis_enterprise(None).args[0] is True: with pytest.raises(redis.exceptions.ResponseError): - assert r.stralgo('LCS', value1, value2) == res + assert r.stralgo("LCS", value1, value2) == res return # test LCS of strings - assert r.stralgo('LCS', value1, value2) == res + assert r.stralgo("LCS", value1, value2) == res # test using keys r.mset({key1: value1, key2: value2}) - assert r.stralgo('LCS', key1, key2, specific_argument="keys") == res + assert r.stralgo("LCS", key1, key2, specific_argument="keys") == res # test other labels - assert r.stralgo('LCS', value1, value2, len=True) == len(res) - assert r.stralgo('LCS', value1, value2, idx=True) == \ - { - 'len': len(res), - 'matches': [[(4, 7), (5, 8)], [(2, 3), (0, 1)]] - } - assert r.stralgo('LCS', value1, value2, - idx=True, withmatchlen=True) == \ - { - 'len': len(res), - 'matches': [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]] - } - assert r.stralgo('LCS', value1, value2, - idx=True, minmatchlen=4, withmatchlen=True) == \ - { - 'len': len(res), - 'matches': [[4, (4, 7), (5, 8)]] - } - - @skip_if_server_version_lt('6.0.0') + assert r.stralgo("LCS", value1, value2, len=True) == len(res) + assert r.stralgo("LCS", value1, value2, idx=True) == { + "len": len(res), + "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]], + } + assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { + "len": len(res), + "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]], + } + assert r.stralgo( + "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True + ) == {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]} + + @skip_if_server_version_lt("6.0.0") def test_stralgo_negative(self, r): with pytest.raises(exceptions.DataError): - r.stralgo('ISSUB', 'value1', 'value2') + r.stralgo("ISSUB", "value1", "value2") with pytest.raises(exceptions.DataError): - r.stralgo('LCS', 'value1', 'value2', len=True, idx=True) + r.stralgo("LCS", "value1", "value2", len=True, idx=True) with pytest.raises(exceptions.DataError): - r.stralgo('LCS', 'value1', 'value2', specific_argument="INT") + r.stralgo("LCS", "value1", "value2", specific_argument="INT") with pytest.raises(ValueError): - r.stralgo('LCS', 'value1', 'value2', idx=True, minmatchlen="one") + r.stralgo("LCS", "value1", "value2", idx=True, minmatchlen="one") def test_strlen(self, r): - r['a'] = 'foo' - assert r.strlen('a') == 3 + r["a"] = "foo" + assert r.strlen("a") == 3 def test_substr(self, r): - r['a'] = '0123456789' + r["a"] = "0123456789" if skip_if_redis_enterprise(None).args[0] is True: with pytest.raises(redis.exceptions.ResponseError): - assert r.substr('a', 0) == b'0123456789' + assert r.substr("a", 0) == b"0123456789" return - assert r.substr('a', 0) == b'0123456789' - assert r.substr('a', 2) == b'23456789' - assert r.substr('a', 3, 5) == b'345' - assert r.substr('a', 3, -2) == b'345678' + assert r.substr("a", 0) == b"0123456789" + assert r.substr("a", 2) == b"23456789" + assert r.substr("a", 3, 5) == b"345" + assert r.substr("a", 3, -2) == b"345678" def test_ttl(self, r): - r['a'] = '1' - assert r.expire('a', 10) - assert 0 < r.ttl('a') <= 10 - assert r.persist('a') - assert r.ttl('a') == -1 + r["a"] = "1" + assert r.expire("a", 10) + assert 0 < r.ttl("a") <= 10 + assert r.persist("a") + assert r.ttl("a") == -1 - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_ttl_nokey(self, r): "TTL on servers 2.8 and after return -2 when the key doesn't exist" - assert r.ttl('a') == -2 + assert r.ttl("a") == -2 def test_type(self, r): - assert r.type('a') == b'none' - r['a'] = '1' - assert r.type('a') == b'string' - del r['a'] - r.lpush('a', '1') - assert r.type('a') == b'list' - del r['a'] - r.sadd('a', '1') - assert r.type('a') == b'set' - del r['a'] - r.zadd('a', {'1': 1}) - assert r.type('a') == b'zset' + assert r.type("a") == b"none" + r["a"] = "1" + assert r.type("a") == b"string" + del r["a"] + r.lpush("a", "1") + assert r.type("a") == b"list" + del r["a"] + r.sadd("a", "1") + assert r.type("a") == b"set" + del r["a"] + r.zadd("a", {"1": 1}) + assert r.type("a") == b"zset" # LIST COMMANDS @pytest.mark.onlynoncluster def test_blpop(self, r): - r.rpush('a', '1', '2') - r.rpush('b', '3', '4') - assert r.blpop(['b', 'a'], timeout=1) == (b'b', b'3') - assert r.blpop(['b', 'a'], timeout=1) == (b'b', b'4') - assert r.blpop(['b', 'a'], timeout=1) == (b'a', b'1') - assert r.blpop(['b', 'a'], timeout=1) == (b'a', b'2') - assert r.blpop(['b', 'a'], timeout=1) is None - r.rpush('c', '1') - assert r.blpop('c', timeout=1) == (b'c', b'1') + r.rpush("a", "1", "2") + r.rpush("b", "3", "4") + assert r.blpop(["b", "a"], timeout=1) == (b"b", b"3") + assert r.blpop(["b", "a"], timeout=1) == (b"b", b"4") + assert r.blpop(["b", "a"], timeout=1) == (b"a", b"1") + assert r.blpop(["b", "a"], timeout=1) == (b"a", b"2") + assert r.blpop(["b", "a"], timeout=1) is None + r.rpush("c", "1") + assert r.blpop("c", timeout=1) == (b"c", b"1") @pytest.mark.onlynoncluster def test_brpop(self, r): - r.rpush('a', '1', '2') - r.rpush('b', '3', '4') - assert r.brpop(['b', 'a'], timeout=1) == (b'b', b'4') - assert r.brpop(['b', 'a'], timeout=1) == (b'b', b'3') - assert r.brpop(['b', 'a'], timeout=1) == (b'a', b'2') - assert r.brpop(['b', 'a'], timeout=1) == (b'a', b'1') - assert r.brpop(['b', 'a'], timeout=1) is None - r.rpush('c', '1') - assert r.brpop('c', timeout=1) == (b'c', b'1') + r.rpush("a", "1", "2") + r.rpush("b", "3", "4") + assert r.brpop(["b", "a"], timeout=1) == (b"b", b"4") + assert r.brpop(["b", "a"], timeout=1) == (b"b", b"3") + assert r.brpop(["b", "a"], timeout=1) == (b"a", b"2") + assert r.brpop(["b", "a"], timeout=1) == (b"a", b"1") + assert r.brpop(["b", "a"], timeout=1) is None + r.rpush("c", "1") + assert r.brpop("c", timeout=1) == (b"c", b"1") @pytest.mark.onlynoncluster def test_brpoplpush(self, r): - r.rpush('a', '1', '2') - r.rpush('b', '3', '4') - assert r.brpoplpush('a', 'b') == b'2' - assert r.brpoplpush('a', 'b') == b'1' - assert r.brpoplpush('a', 'b', timeout=1) is None - assert r.lrange('a', 0, -1) == [] - assert r.lrange('b', 0, -1) == [b'1', b'2', b'3', b'4'] + r.rpush("a", "1", "2") + r.rpush("b", "3", "4") + assert r.brpoplpush("a", "b") == b"2" + assert r.brpoplpush("a", "b") == b"1" + assert r.brpoplpush("a", "b", timeout=1) is None + assert r.lrange("a", 0, -1) == [] + assert r.lrange("b", 0, -1) == [b"1", b"2", b"3", b"4"] @pytest.mark.onlynoncluster def test_brpoplpush_empty_string(self, r): - r.rpush('a', '') - assert r.brpoplpush('a', 'b') == b'' + r.rpush("a", "") + assert r.brpoplpush("a", "b") == b"" def test_lindex(self, r): - r.rpush('a', '1', '2', '3') - assert r.lindex('a', '0') == b'1' - assert r.lindex('a', '1') == b'2' - assert r.lindex('a', '2') == b'3' + r.rpush("a", "1", "2", "3") + assert r.lindex("a", "0") == b"1" + assert r.lindex("a", "1") == b"2" + assert r.lindex("a", "2") == b"3" def test_linsert(self, r): - r.rpush('a', '1', '2', '3') - assert r.linsert('a', 'after', '2', '2.5') == 4 - assert r.lrange('a', 0, -1) == [b'1', b'2', b'2.5', b'3'] - assert r.linsert('a', 'before', '2', '1.5') == 5 - assert r.lrange('a', 0, -1) == \ - [b'1', b'1.5', b'2', b'2.5', b'3'] + r.rpush("a", "1", "2", "3") + assert r.linsert("a", "after", "2", "2.5") == 4 + assert r.lrange("a", 0, -1) == [b"1", b"2", b"2.5", b"3"] + assert r.linsert("a", "before", "2", "1.5") == 5 + assert r.lrange("a", 0, -1) == [b"1", b"1.5", b"2", b"2.5", b"3"] def test_llen(self, r): - r.rpush('a', '1', '2', '3') - assert r.llen('a') == 3 + r.rpush("a", "1", "2", "3") + assert r.llen("a") == 3 def test_lpop(self, r): - r.rpush('a', '1', '2', '3') - assert r.lpop('a') == b'1' - assert r.lpop('a') == b'2' - assert r.lpop('a') == b'3' - assert r.lpop('a') is None + r.rpush("a", "1", "2", "3") + assert r.lpop("a") == b"1" + assert r.lpop("a") == b"2" + assert r.lpop("a") == b"3" + assert r.lpop("a") is None - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_lpop_count(self, r): - r.rpush('a', '1', '2', '3') - assert r.lpop('a', 2) == [b'1', b'2'] - assert r.lpop('a', 1) == [b'3'] - assert r.lpop('a') is None - assert r.lpop('a', 3) is None + r.rpush("a", "1", "2", "3") + assert r.lpop("a", 2) == [b"1", b"2"] + assert r.lpop("a", 1) == [b"3"] + assert r.lpop("a") is None + assert r.lpop("a", 3) is None def test_lpush(self, r): - assert r.lpush('a', '1') == 1 - assert r.lpush('a', '2') == 2 - assert r.lpush('a', '3', '4') == 4 - assert r.lrange('a', 0, -1) == [b'4', b'3', b'2', b'1'] + assert r.lpush("a", "1") == 1 + assert r.lpush("a", "2") == 2 + assert r.lpush("a", "3", "4") == 4 + assert r.lrange("a", 0, -1) == [b"4", b"3", b"2", b"1"] def test_lpushx(self, r): - assert r.lpushx('a', '1') == 0 - assert r.lrange('a', 0, -1) == [] - r.rpush('a', '1', '2', '3') - assert r.lpushx('a', '4') == 4 - assert r.lrange('a', 0, -1) == [b'4', b'1', b'2', b'3'] + assert r.lpushx("a", "1") == 0 + assert r.lrange("a", 0, -1) == [] + r.rpush("a", "1", "2", "3") + assert r.lpushx("a", "4") == 4 + assert r.lrange("a", 0, -1) == [b"4", b"1", b"2", b"3"] - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_lpushx_with_list(self, r): # now with a list - r.lpush('somekey', 'a') - r.lpush('somekey', 'b') - assert r.lpushx('somekey', 'foo', 'asdasd', 55, 'asdasdas') == 6 - res = r.lrange('somekey', 0, -1) - assert res == [b'asdasdas', b'55', b'asdasd', b'foo', b'b', b'a'] + r.lpush("somekey", "a") + r.lpush("somekey", "b") + assert r.lpushx("somekey", "foo", "asdasd", 55, "asdasdas") == 6 + res = r.lrange("somekey", 0, -1) + assert res == [b"asdasdas", b"55", b"asdasd", b"foo", b"b", b"a"] def test_lrange(self, r): - r.rpush('a', '1', '2', '3', '4', '5') - assert r.lrange('a', 0, 2) == [b'1', b'2', b'3'] - assert r.lrange('a', 2, 10) == [b'3', b'4', b'5'] - assert r.lrange('a', 0, -1) == [b'1', b'2', b'3', b'4', b'5'] + r.rpush("a", "1", "2", "3", "4", "5") + assert r.lrange("a", 0, 2) == [b"1", b"2", b"3"] + assert r.lrange("a", 2, 10) == [b"3", b"4", b"5"] + assert r.lrange("a", 0, -1) == [b"1", b"2", b"3", b"4", b"5"] def test_lrem(self, r): - r.rpush('a', 'Z', 'b', 'Z', 'Z', 'c', 'Z', 'Z') + r.rpush("a", "Z", "b", "Z", "Z", "c", "Z", "Z") # remove the first 'Z' item - assert r.lrem('a', 1, 'Z') == 1 - assert r.lrange('a', 0, -1) == [b'b', b'Z', b'Z', b'c', b'Z', b'Z'] + assert r.lrem("a", 1, "Z") == 1 + assert r.lrange("a", 0, -1) == [b"b", b"Z", b"Z", b"c", b"Z", b"Z"] # remove the last 2 'Z' items - assert r.lrem('a', -2, 'Z') == 2 - assert r.lrange('a', 0, -1) == [b'b', b'Z', b'Z', b'c'] + assert r.lrem("a", -2, "Z") == 2 + assert r.lrange("a", 0, -1) == [b"b", b"Z", b"Z", b"c"] # remove all 'Z' items - assert r.lrem('a', 0, 'Z') == 2 - assert r.lrange('a', 0, -1) == [b'b', b'c'] + assert r.lrem("a", 0, "Z") == 2 + assert r.lrange("a", 0, -1) == [b"b", b"c"] def test_lset(self, r): - r.rpush('a', '1', '2', '3') - assert r.lrange('a', 0, -1) == [b'1', b'2', b'3'] - assert r.lset('a', 1, '4') - assert r.lrange('a', 0, 2) == [b'1', b'4', b'3'] + r.rpush("a", "1", "2", "3") + assert r.lrange("a", 0, -1) == [b"1", b"2", b"3"] + assert r.lset("a", 1, "4") + assert r.lrange("a", 0, 2) == [b"1", b"4", b"3"] def test_ltrim(self, r): - r.rpush('a', '1', '2', '3') - assert r.ltrim('a', 0, 1) - assert r.lrange('a', 0, -1) == [b'1', b'2'] + r.rpush("a", "1", "2", "3") + assert r.ltrim("a", 0, 1) + assert r.lrange("a", 0, -1) == [b"1", b"2"] def test_rpop(self, r): - r.rpush('a', '1', '2', '3') - assert r.rpop('a') == b'3' - assert r.rpop('a') == b'2' - assert r.rpop('a') == b'1' - assert r.rpop('a') is None + r.rpush("a", "1", "2", "3") + assert r.rpop("a") == b"3" + assert r.rpop("a") == b"2" + assert r.rpop("a") == b"1" + assert r.rpop("a") is None - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_rpop_count(self, r): - r.rpush('a', '1', '2', '3') - assert r.rpop('a', 2) == [b'3', b'2'] - assert r.rpop('a', 1) == [b'1'] - assert r.rpop('a') is None - assert r.rpop('a', 3) is None + r.rpush("a", "1", "2", "3") + assert r.rpop("a", 2) == [b"3", b"2"] + assert r.rpop("a", 1) == [b"1"] + assert r.rpop("a") is None + assert r.rpop("a", 3) is None @pytest.mark.onlynoncluster def test_rpoplpush(self, r): - r.rpush('a', 'a1', 'a2', 'a3') - r.rpush('b', 'b1', 'b2', 'b3') - assert r.rpoplpush('a', 'b') == b'a3' - assert r.lrange('a', 0, -1) == [b'a1', b'a2'] - assert r.lrange('b', 0, -1) == [b'a3', b'b1', b'b2', b'b3'] + r.rpush("a", "a1", "a2", "a3") + r.rpush("b", "b1", "b2", "b3") + assert r.rpoplpush("a", "b") == b"a3" + assert r.lrange("a", 0, -1) == [b"a1", b"a2"] + assert r.lrange("b", 0, -1) == [b"a3", b"b1", b"b2", b"b3"] def test_rpush(self, r): - assert r.rpush('a', '1') == 1 - assert r.rpush('a', '2') == 2 - assert r.rpush('a', '3', '4') == 4 - assert r.lrange('a', 0, -1) == [b'1', b'2', b'3', b'4'] + assert r.rpush("a", "1") == 1 + assert r.rpush("a", "2") == 2 + assert r.rpush("a", "3", "4") == 4 + assert r.lrange("a", 0, -1) == [b"1", b"2", b"3", b"4"] - @skip_if_server_version_lt('6.0.6') + @skip_if_server_version_lt("6.0.6") def test_lpos(self, r): - assert r.rpush('a', 'a', 'b', 'c', '1', '2', '3', 'c', 'c') == 8 - assert r.lpos('a', 'a') == 0 - assert r.lpos('a', 'c') == 2 + assert r.rpush("a", "a", "b", "c", "1", "2", "3", "c", "c") == 8 + assert r.lpos("a", "a") == 0 + assert r.lpos("a", "c") == 2 - assert r.lpos('a', 'c', rank=1) == 2 - assert r.lpos('a', 'c', rank=2) == 6 - assert r.lpos('a', 'c', rank=4) is None - assert r.lpos('a', 'c', rank=-1) == 7 - assert r.lpos('a', 'c', rank=-2) == 6 + assert r.lpos("a", "c", rank=1) == 2 + assert r.lpos("a", "c", rank=2) == 6 + assert r.lpos("a", "c", rank=4) is None + assert r.lpos("a", "c", rank=-1) == 7 + assert r.lpos("a", "c", rank=-2) == 6 - assert r.lpos('a', 'c', count=0) == [2, 6, 7] - assert r.lpos('a', 'c', count=1) == [2] - assert r.lpos('a', 'c', count=2) == [2, 6] - assert r.lpos('a', 'c', count=100) == [2, 6, 7] + assert r.lpos("a", "c", count=0) == [2, 6, 7] + assert r.lpos("a", "c", count=1) == [2] + assert r.lpos("a", "c", count=2) == [2, 6] + assert r.lpos("a", "c", count=100) == [2, 6, 7] - assert r.lpos('a', 'c', count=0, rank=2) == [6, 7] - assert r.lpos('a', 'c', count=2, rank=-1) == [7, 6] + assert r.lpos("a", "c", count=0, rank=2) == [6, 7] + assert r.lpos("a", "c", count=2, rank=-1) == [7, 6] - assert r.lpos('axxx', 'c', count=0, rank=2) == [] - assert r.lpos('axxx', 'c') is None + assert r.lpos("axxx", "c", count=0, rank=2) == [] + assert r.lpos("axxx", "c") is None - assert r.lpos('a', 'x', count=2) == [] - assert r.lpos('a', 'x') is None + assert r.lpos("a", "x", count=2) == [] + assert r.lpos("a", "x") is None - assert r.lpos('a', 'a', count=0, maxlen=1) == [0] - assert r.lpos('a', 'c', count=0, maxlen=1) == [] - assert r.lpos('a', 'c', count=0, maxlen=3) == [2] - assert r.lpos('a', 'c', count=0, maxlen=3, rank=-1) == [7, 6] - assert r.lpos('a', 'c', count=0, maxlen=7, rank=2) == [6] + assert r.lpos("a", "a", count=0, maxlen=1) == [0] + assert r.lpos("a", "c", count=0, maxlen=1) == [] + assert r.lpos("a", "c", count=0, maxlen=3) == [2] + assert r.lpos("a", "c", count=0, maxlen=3, rank=-1) == [7, 6] + assert r.lpos("a", "c", count=0, maxlen=7, rank=2) == [6] def test_rpushx(self, r): - assert r.rpushx('a', 'b') == 0 - assert r.lrange('a', 0, -1) == [] - r.rpush('a', '1', '2', '3') - assert r.rpushx('a', '4') == 4 - assert r.lrange('a', 0, -1) == [b'1', b'2', b'3', b'4'] + assert r.rpushx("a", "b") == 0 + assert r.lrange("a", 0, -1) == [] + r.rpush("a", "1", "2", "3") + assert r.rpushx("a", "4") == 4 + assert r.lrange("a", 0, -1) == [b"1", b"2", b"3", b"4"] # SCAN COMMANDS @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_scan(self, r): - r.set('a', 1) - r.set('b', 2) - r.set('c', 3) + r.set("a", 1) + r.set("b", 2) + r.set("c", 3) cursor, keys = r.scan() assert cursor == 0 - assert set(keys) == {b'a', b'b', b'c'} - _, keys = r.scan(match='a') - assert set(keys) == {b'a'} + assert set(keys) == {b"a", b"b", b"c"} + _, keys = r.scan(match="a") + assert set(keys) == {b"a"} @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_scan_type(self, r): - r.sadd('a-set', 1) - r.hset('a-hash', 'foo', 2) - r.lpush('a-list', 'aux', 3) - _, keys = r.scan(match='a*', _type='SET') - assert set(keys) == {b'a-set'} + r.sadd("a-set", 1) + r.hset("a-hash", "foo", 2) + r.lpush("a-list", "aux", 3) + _, keys = r.scan(match="a*", _type="SET") + assert set(keys) == {b"a-set"} @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_scan_iter(self, r): - r.set('a', 1) - r.set('b', 2) - r.set('c', 3) + r.set("a", 1) + r.set("b", 2) + r.set("c", 3) keys = list(r.scan_iter()) - assert set(keys) == {b'a', b'b', b'c'} - keys = list(r.scan_iter(match='a')) - assert set(keys) == {b'a'} + assert set(keys) == {b"a", b"b", b"c"} + keys = list(r.scan_iter(match="a")) + assert set(keys) == {b"a"} - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_sscan(self, r): - r.sadd('a', 1, 2, 3) - cursor, members = r.sscan('a') + r.sadd("a", 1, 2, 3) + cursor, members = r.sscan("a") assert cursor == 0 - assert set(members) == {b'1', b'2', b'3'} - _, members = r.sscan('a', match=b'1') - assert set(members) == {b'1'} + assert set(members) == {b"1", b"2", b"3"} + _, members = r.sscan("a", match=b"1") + assert set(members) == {b"1"} - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_sscan_iter(self, r): - r.sadd('a', 1, 2, 3) - members = list(r.sscan_iter('a')) - assert set(members) == {b'1', b'2', b'3'} - members = list(r.sscan_iter('a', match=b'1')) - assert set(members) == {b'1'} + r.sadd("a", 1, 2, 3) + members = list(r.sscan_iter("a")) + assert set(members) == {b"1", b"2", b"3"} + members = list(r.sscan_iter("a", match=b"1")) + assert set(members) == {b"1"} - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_hscan(self, r): - r.hset('a', mapping={'a': 1, 'b': 2, 'c': 3}) - cursor, dic = r.hscan('a') + r.hset("a", mapping={"a": 1, "b": 2, "c": 3}) + cursor, dic = r.hscan("a") assert cursor == 0 - assert dic == {b'a': b'1', b'b': b'2', b'c': b'3'} - _, dic = r.hscan('a', match='a') - assert dic == {b'a': b'1'} + assert dic == {b"a": b"1", b"b": b"2", b"c": b"3"} + _, dic = r.hscan("a", match="a") + assert dic == {b"a": b"1"} - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_hscan_iter(self, r): - r.hset('a', mapping={'a': 1, 'b': 2, 'c': 3}) - dic = dict(r.hscan_iter('a')) - assert dic == {b'a': b'1', b'b': b'2', b'c': b'3'} - dic = dict(r.hscan_iter('a', match='a')) - assert dic == {b'a': b'1'} + r.hset("a", mapping={"a": 1, "b": 2, "c": 3}) + dic = dict(r.hscan_iter("a")) + assert dic == {b"a": b"1", b"b": b"2", b"c": b"3"} + dic = dict(r.hscan_iter("a", match="a")) + assert dic == {b"a": b"1"} - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_zscan(self, r): - r.zadd('a', {'a': 1, 'b': 2, 'c': 3}) - cursor, pairs = r.zscan('a') + r.zadd("a", {"a": 1, "b": 2, "c": 3}) + cursor, pairs = r.zscan("a") assert cursor == 0 - assert set(pairs) == {(b'a', 1), (b'b', 2), (b'c', 3)} - _, pairs = r.zscan('a', match='a') - assert set(pairs) == {(b'a', 1)} + assert set(pairs) == {(b"a", 1), (b"b", 2), (b"c", 3)} + _, pairs = r.zscan("a", match="a") + assert set(pairs) == {(b"a", 1)} - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_zscan_iter(self, r): - r.zadd('a', {'a': 1, 'b': 2, 'c': 3}) - pairs = list(r.zscan_iter('a')) - assert set(pairs) == {(b'a', 1), (b'b', 2), (b'c', 3)} - pairs = list(r.zscan_iter('a', match='a')) - assert set(pairs) == {(b'a', 1)} + r.zadd("a", {"a": 1, "b": 2, "c": 3}) + pairs = list(r.zscan_iter("a")) + assert set(pairs) == {(b"a", 1), (b"b", 2), (b"c", 3)} + pairs = list(r.zscan_iter("a", match="a")) + assert set(pairs) == {(b"a", 1)} # SET COMMANDS def test_sadd(self, r): - members = {b'1', b'2', b'3'} - r.sadd('a', *members) - assert r.smembers('a') == members + members = {b"1", b"2", b"3"} + r.sadd("a", *members) + assert r.smembers("a") == members def test_scard(self, r): - r.sadd('a', '1', '2', '3') - assert r.scard('a') == 3 + r.sadd("a", "1", "2", "3") + assert r.scard("a") == 3 @pytest.mark.onlynoncluster def test_sdiff(self, r): - r.sadd('a', '1', '2', '3') - assert r.sdiff('a', 'b') == {b'1', b'2', b'3'} - r.sadd('b', '2', '3') - assert r.sdiff('a', 'b') == {b'1'} + r.sadd("a", "1", "2", "3") + assert r.sdiff("a", "b") == {b"1", b"2", b"3"} + r.sadd("b", "2", "3") + assert r.sdiff("a", "b") == {b"1"} @pytest.mark.onlynoncluster def test_sdiffstore(self, r): - r.sadd('a', '1', '2', '3') - assert r.sdiffstore('c', 'a', 'b') == 3 - assert r.smembers('c') == {b'1', b'2', b'3'} - r.sadd('b', '2', '3') - assert r.sdiffstore('c', 'a', 'b') == 1 - assert r.smembers('c') == {b'1'} + r.sadd("a", "1", "2", "3") + assert r.sdiffstore("c", "a", "b") == 3 + assert r.smembers("c") == {b"1", b"2", b"3"} + r.sadd("b", "2", "3") + assert r.sdiffstore("c", "a", "b") == 1 + assert r.smembers("c") == {b"1"} @pytest.mark.onlynoncluster def test_sinter(self, r): - r.sadd('a', '1', '2', '3') - assert r.sinter('a', 'b') == set() - r.sadd('b', '2', '3') - assert r.sinter('a', 'b') == {b'2', b'3'} + r.sadd("a", "1", "2", "3") + assert r.sinter("a", "b") == set() + r.sadd("b", "2", "3") + assert r.sinter("a", "b") == {b"2", b"3"} @pytest.mark.onlynoncluster def test_sinterstore(self, r): - r.sadd('a', '1', '2', '3') - assert r.sinterstore('c', 'a', 'b') == 0 - assert r.smembers('c') == set() - r.sadd('b', '2', '3') - assert r.sinterstore('c', 'a', 'b') == 2 - assert r.smembers('c') == {b'2', b'3'} + r.sadd("a", "1", "2", "3") + assert r.sinterstore("c", "a", "b") == 0 + assert r.smembers("c") == set() + r.sadd("b", "2", "3") + assert r.sinterstore("c", "a", "b") == 2 + assert r.smembers("c") == {b"2", b"3"} def test_sismember(self, r): - r.sadd('a', '1', '2', '3') - assert r.sismember('a', '1') - assert r.sismember('a', '2') - assert r.sismember('a', '3') - assert not r.sismember('a', '4') + r.sadd("a", "1", "2", "3") + assert r.sismember("a", "1") + assert r.sismember("a", "2") + assert r.sismember("a", "3") + assert not r.sismember("a", "4") def test_smembers(self, r): - r.sadd('a', '1', '2', '3') - assert r.smembers('a') == {b'1', b'2', b'3'} + r.sadd("a", "1", "2", "3") + assert r.smembers("a") == {b"1", b"2", b"3"} - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_smismember(self, r): - r.sadd('a', '1', '2', '3') + r.sadd("a", "1", "2", "3") result_list = [True, False, True, True] - assert r.smismember('a', '1', '4', '2', '3') == result_list - assert r.smismember('a', ['1', '4', '2', '3']) == result_list + assert r.smismember("a", "1", "4", "2", "3") == result_list + assert r.smismember("a", ["1", "4", "2", "3"]) == result_list @pytest.mark.onlynoncluster def test_smove(self, r): - r.sadd('a', 'a1', 'a2') - r.sadd('b', 'b1', 'b2') - assert r.smove('a', 'b', 'a1') - assert r.smembers('a') == {b'a2'} - assert r.smembers('b') == {b'b1', b'b2', b'a1'} + r.sadd("a", "a1", "a2") + r.sadd("b", "b1", "b2") + assert r.smove("a", "b", "a1") + assert r.smembers("a") == {b"a2"} + assert r.smembers("b") == {b"b1", b"b2", b"a1"} def test_spop(self, r): - s = [b'1', b'2', b'3'] - r.sadd('a', *s) - value = r.spop('a') + s = [b"1", b"2", b"3"] + r.sadd("a", *s) + value = r.spop("a") assert value in s - assert r.smembers('a') == set(s) - {value} + assert r.smembers("a") == set(s) - {value} - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_spop_multi_value(self, r): - s = [b'1', b'2', b'3'] - r.sadd('a', *s) - values = r.spop('a', 2) + s = [b"1", b"2", b"3"] + r.sadd("a", *s) + values = r.spop("a", 2) assert len(values) == 2 for value in values: assert value in s - assert r.spop('a', 1) == list(set(s) - set(values)) + assert r.spop("a", 1) == list(set(s) - set(values)) def test_srandmember(self, r): - s = [b'1', b'2', b'3'] - r.sadd('a', *s) - assert r.srandmember('a') in s + s = [b"1", b"2", b"3"] + r.sadd("a", *s) + assert r.srandmember("a") in s - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_srandmember_multi_value(self, r): - s = [b'1', b'2', b'3'] - r.sadd('a', *s) - randoms = r.srandmember('a', number=2) + s = [b"1", b"2", b"3"] + r.sadd("a", *s) + randoms = r.srandmember("a", number=2) assert len(randoms) == 2 assert set(randoms).intersection(s) == set(randoms) def test_srem(self, r): - r.sadd('a', '1', '2', '3', '4') - assert r.srem('a', '5') == 0 - assert r.srem('a', '2', '4') == 2 - assert r.smembers('a') == {b'1', b'3'} + r.sadd("a", "1", "2", "3", "4") + assert r.srem("a", "5") == 0 + assert r.srem("a", "2", "4") == 2 + assert r.smembers("a") == {b"1", b"3"} @pytest.mark.onlynoncluster def test_sunion(self, r): - r.sadd('a', '1', '2') - r.sadd('b', '2', '3') - assert r.sunion('a', 'b') == {b'1', b'2', b'3'} + r.sadd("a", "1", "2") + r.sadd("b", "2", "3") + assert r.sunion("a", "b") == {b"1", b"2", b"3"} @pytest.mark.onlynoncluster def test_sunionstore(self, r): - r.sadd('a', '1', '2') - r.sadd('b', '2', '3') - assert r.sunionstore('c', 'a', 'b') == 3 - assert r.smembers('c') == {b'1', b'2', b'3'} + r.sadd("a", "1", "2") + r.sadd("b", "2", "3") + assert r.sunionstore("c", "a", "b") == 3 + assert r.smembers("c") == {b"1", b"2", b"3"} - @skip_if_server_version_lt('1.0.0') + @skip_if_server_version_lt("1.0.0") def test_debug_segfault(self, r): with pytest.raises(NotImplementedError): r.debug_segfault() @pytest.mark.onlynoncluster - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_script_debug(self, r): with pytest.raises(NotImplementedError): r.script_debug() # SORTED SET COMMANDS def test_zadd(self, r): - mapping = {'a1': 1.0, 'a2': 2.0, 'a3': 3.0} - r.zadd('a', mapping) - assert r.zrange('a', 0, -1, withscores=True) == \ - [(b'a1', 1.0), (b'a2', 2.0), (b'a3', 3.0)] + mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} + r.zadd("a", mapping) + assert r.zrange("a", 0, -1, withscores=True) == [ + (b"a1", 1.0), + (b"a2", 2.0), + (b"a3", 3.0), + ] # error cases with pytest.raises(exceptions.DataError): - r.zadd('a', {}) + r.zadd("a", {}) # cannot use both nx and xx options with pytest.raises(exceptions.DataError): - r.zadd('a', mapping, nx=True, xx=True) + r.zadd("a", mapping, nx=True, xx=True) # cannot use the incr options with more than one value with pytest.raises(exceptions.DataError): - r.zadd('a', mapping, incr=True) + r.zadd("a", mapping, incr=True) def test_zadd_nx(self, r): - assert r.zadd('a', {'a1': 1}) == 1 - assert r.zadd('a', {'a1': 99, 'a2': 2}, nx=True) == 1 - assert r.zrange('a', 0, -1, withscores=True) == \ - [(b'a1', 1.0), (b'a2', 2.0)] + assert r.zadd("a", {"a1": 1}) == 1 + assert r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 + assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] def test_zadd_xx(self, r): - assert r.zadd('a', {'a1': 1}) == 1 - assert r.zadd('a', {'a1': 99, 'a2': 2}, xx=True) == 0 - assert r.zrange('a', 0, -1, withscores=True) == \ - [(b'a1', 99.0)] + assert r.zadd("a", {"a1": 1}) == 1 + assert r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 + assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] def test_zadd_ch(self, r): - assert r.zadd('a', {'a1': 1}) == 1 - assert r.zadd('a', {'a1': 99, 'a2': 2}, ch=True) == 2 - assert r.zrange('a', 0, -1, withscores=True) == \ - [(b'a2', 2.0), (b'a1', 99.0)] + assert r.zadd("a", {"a1": 1}) == 1 + assert r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 + assert r.zrange("a", 0, -1, withscores=True) == [(b"a2", 2.0), (b"a1", 99.0)] def test_zadd_incr(self, r): - assert r.zadd('a', {'a1': 1}) == 1 - assert r.zadd('a', {'a1': 4.5}, incr=True) == 5.5 + assert r.zadd("a", {"a1": 1}) == 1 + assert r.zadd("a", {"a1": 4.5}, incr=True) == 5.5 def test_zadd_incr_with_xx(self, r): # this asks zadd to incr 'a1' only if it exists, but it clearly # doesn't. Redis returns a null value in this case and so should # redis-py - assert r.zadd('a', {'a1': 1}, xx=True, incr=True) is None + assert r.zadd("a", {"a1": 1}, xx=True, incr=True) is None - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_zadd_gt_lt(self, r): for i in range(1, 20): - r.zadd('a', {f'a{i}': i}) - assert r.zadd('a', {'a20': 5}, gt=3) == 1 + r.zadd("a", {f"a{i}": i}) + assert r.zadd("a", {"a20": 5}, gt=3) == 1 for i in range(1, 20): - r.zadd('a', {f'a{i}': i}) - assert r.zadd('a', {'a2': 5}, lt=1) == 0 + r.zadd("a", {f"a{i}": i}) + assert r.zadd("a", {"a2": 5}, lt=1) == 0 # cannot use both nx and xx options with pytest.raises(exceptions.DataError): - r.zadd('a', {'a15': 155}, nx=True, lt=True) - r.zadd('a', {'a15': 155}, nx=True, gt=True) - r.zadd('a', {'a15': 155}, lt=True, gt=True) + r.zadd("a", {"a15": 155}, nx=True, lt=True) + r.zadd("a", {"a15": 155}, nx=True, gt=True) + r.zadd("a", {"a15": 155}, lt=True, gt=True) def test_zcard(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zcard('a') == 3 + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zcard("a") == 3 def test_zcount(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zcount('a', '-inf', '+inf') == 3 - assert r.zcount('a', 1, 2) == 2 - assert r.zcount('a', '(' + str(1), 2) == 1 - assert r.zcount('a', 1, '(' + str(2)) == 1 - assert r.zcount('a', 10, 20) == 0 + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zcount("a", "-inf", "+inf") == 3 + assert r.zcount("a", 1, 2) == 2 + assert r.zcount("a", "(" + str(1), 2) == 1 + assert r.zcount("a", 1, "(" + str(2)) == 1 + assert r.zcount("a", 10, 20) == 0 @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_zdiff(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - r.zadd('b', {'a1': 1, 'a2': 2}) - assert r.zdiff(['a', 'b']) == [b'a3'] - assert r.zdiff(['a', 'b'], withscores=True) == [b'a3', b'3'] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("b", {"a1": 1, "a2": 2}) + assert r.zdiff(["a", "b"]) == [b"a3"] + assert r.zdiff(["a", "b"], withscores=True) == [b"a3", b"3"] @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_zdiffstore(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - r.zadd('b', {'a1': 1, 'a2': 2}) - assert r.zdiffstore("out", ['a', 'b']) - assert r.zrange("out", 0, -1) == [b'a3'] - assert r.zrange("out", 0, -1, withscores=True) == [(b'a3', 3.0)] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("b", {"a1": 1, "a2": 2}) + assert r.zdiffstore("out", ["a", "b"]) + assert r.zrange("out", 0, -1) == [b"a3"] + assert r.zrange("out", 0, -1, withscores=True) == [(b"a3", 3.0)] def test_zincrby(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zincrby('a', 1, 'a2') == 3.0 - assert r.zincrby('a', 5, 'a3') == 8.0 - assert r.zscore('a', 'a2') == 3.0 - assert r.zscore('a', 'a3') == 8.0 + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zincrby("a", 1, "a2") == 3.0 + assert r.zincrby("a", 5, "a3") == 8.0 + assert r.zscore("a", "a2") == 3.0 + assert r.zscore("a", "a3") == 8.0 - @skip_if_server_version_lt('2.8.9') + @skip_if_server_version_lt("2.8.9") def test_zlexcount(self, r): - r.zadd('a', {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'e': 0, 'f': 0, 'g': 0}) - assert r.zlexcount('a', '-', '+') == 7 - assert r.zlexcount('a', '[b', '[f') == 5 + r.zadd("a", {"a": 0, "b": 0, "c": 0, "d": 0, "e": 0, "f": 0, "g": 0}) + assert r.zlexcount("a", "-", "+") == 7 + assert r.zlexcount("a", "[b", "[f") == 5 @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_zinter(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 1}) - r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zinter(['a', 'b', 'c']) == [b'a3', b'a1'] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinter(["a", "b", "c"]) == [b"a3", b"a1"] # invalid aggregation with pytest.raises(exceptions.DataError): - r.zinter(['a', 'b', 'c'], aggregate='foo', withscores=True) + r.zinter(["a", "b", "c"], aggregate="foo", withscores=True) # aggregate with SUM - assert r.zinter(['a', 'b', 'c'], withscores=True) \ - == [(b'a3', 8), (b'a1', 9)] + assert r.zinter(["a", "b", "c"], withscores=True) == [(b"a3", 8), (b"a1", 9)] # aggregate with MAX - assert r.zinter(['a', 'b', 'c'], aggregate='MAX', withscores=True) \ - == [(b'a3', 5), (b'a1', 6)] + assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ + (b"a3", 5), + (b"a1", 6), + ] # aggregate with MIN - assert r.zinter(['a', 'b', 'c'], aggregate='MIN', withscores=True) \ - == [(b'a1', 1), (b'a3', 1)] + assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ + (b"a1", 1), + (b"a3", 1), + ] # with weights - assert r.zinter({'a': 1, 'b': 2, 'c': 3}, withscores=True) \ - == [(b'a3', 20), (b'a1', 23)] + assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ + (b"a3", 20), + (b"a1", 23), + ] @pytest.mark.onlynoncluster def test_zinterstore_sum(self, r): - r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zinterstore('d', ['a', 'b', 'c']) == 2 - assert r.zrange('d', 0, -1, withscores=True) == \ - [(b'a3', 8), (b'a1', 9)] + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinterstore("d", ["a", "b", "c"]) == 2 + assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] @pytest.mark.onlynoncluster def test_zinterstore_max(self, r): - r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zinterstore('d', ['a', 'b', 'c'], aggregate='MAX') == 2 - assert r.zrange('d', 0, -1, withscores=True) == \ - [(b'a3', 5), (b'a1', 6)] + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 + assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] @pytest.mark.onlynoncluster def test_zinterstore_min(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - r.zadd('b', {'a1': 2, 'a2': 3, 'a3': 5}) - r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zinterstore('d', ['a', 'b', 'c'], aggregate='MIN') == 2 - assert r.zrange('d', 0, -1, withscores=True) == \ - [(b'a1', 1), (b'a3', 3)] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 + assert r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] @pytest.mark.onlynoncluster def test_zinterstore_with_weight(self, r): - r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zinterstore('d', {'a': 1, 'b': 2, 'c': 3}) == 2 - assert r.zrange('d', 0, -1, withscores=True) == \ - [(b'a3', 20), (b'a1', 23)] - - @skip_if_server_version_lt('4.9.0') + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 + assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + + @skip_if_server_version_lt("4.9.0") def test_zpopmax(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zpopmax('a') == [(b'a3', 3)] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zpopmax("a") == [(b"a3", 3)] # with count - assert r.zpopmax('a', count=2) == \ - [(b'a2', 2), (b'a1', 1)] + assert r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] - @skip_if_server_version_lt('4.9.0') + @skip_if_server_version_lt("4.9.0") def test_zpopmin(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zpopmin('a') == [(b'a1', 1)] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zpopmin("a") == [(b"a1", 1)] # with count - assert r.zpopmin('a', count=2) == \ - [(b'a2', 2), (b'a3', 3)] + assert r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_zrandemember(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) - assert r.zrandmember('a') is not None - assert len(r.zrandmember('a', 2)) == 2 + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zrandmember("a") is not None + assert len(r.zrandmember("a", 2)) == 2 # with scores - assert len(r.zrandmember('a', 2, True)) == 4 + assert len(r.zrandmember("a", 2, True)) == 4 # without duplications - assert len(r.zrandmember('a', 10)) == 5 + assert len(r.zrandmember("a", 10)) == 5 # with duplications - assert len(r.zrandmember('a', -10)) == 10 + assert len(r.zrandmember("a", -10)) == 10 @pytest.mark.onlynoncluster - @skip_if_server_version_lt('4.9.0') + @skip_if_server_version_lt("4.9.0") def test_bzpopmax(self, r): - r.zadd('a', {'a1': 1, 'a2': 2}) - r.zadd('b', {'b1': 10, 'b2': 20}) - assert r.bzpopmax(['b', 'a'], timeout=1) == (b'b', b'b2', 20) - assert r.bzpopmax(['b', 'a'], timeout=1) == (b'b', b'b1', 10) - assert r.bzpopmax(['b', 'a'], timeout=1) == (b'a', b'a2', 2) - assert r.bzpopmax(['b', 'a'], timeout=1) == (b'a', b'a1', 1) - assert r.bzpopmax(['b', 'a'], timeout=1) is None - r.zadd('c', {'c1': 100}) - assert r.bzpopmax('c', timeout=1) == (b'c', b'c1', 100) - - @pytest.mark.onlynoncluster - @skip_if_server_version_lt('4.9.0') + r.zadd("a", {"a1": 1, "a2": 2}) + r.zadd("b", {"b1": 10, "b2": 20}) + assert r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b2", 20) + assert r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b1", 10) + assert r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a2", 2) + assert r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a1", 1) + assert r.bzpopmax(["b", "a"], timeout=1) is None + r.zadd("c", {"c1": 100}) + assert r.bzpopmax("c", timeout=1) == (b"c", b"c1", 100) + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("4.9.0") def test_bzpopmin(self, r): - r.zadd('a', {'a1': 1, 'a2': 2}) - r.zadd('b', {'b1': 10, 'b2': 20}) - assert r.bzpopmin(['b', 'a'], timeout=1) == (b'b', b'b1', 10) - assert r.bzpopmin(['b', 'a'], timeout=1) == (b'b', b'b2', 20) - assert r.bzpopmin(['b', 'a'], timeout=1) == (b'a', b'a1', 1) - assert r.bzpopmin(['b', 'a'], timeout=1) == (b'a', b'a2', 2) - assert r.bzpopmin(['b', 'a'], timeout=1) is None - r.zadd('c', {'c1': 100}) - assert r.bzpopmin('c', timeout=1) == (b'c', b'c1', 100) + r.zadd("a", {"a1": 1, "a2": 2}) + r.zadd("b", {"b1": 10, "b2": 20}) + assert r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b1", 10) + assert r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b2", 20) + assert r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a1", 1) + assert r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a2", 2) + assert r.bzpopmin(["b", "a"], timeout=1) is None + r.zadd("c", {"c1": 100}) + assert r.bzpopmin("c", timeout=1) == (b"c", b"c1", 100) def test_zrange(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zrange('a', 0, 1) == [b'a1', b'a2'] - assert r.zrange('a', 1, 2) == [b'a2', b'a3'] - assert r.zrange('a', 0, 2) == [b'a1', b'a2', b'a3'] - assert r.zrange('a', 0, 2, desc=True) == [b'a3', b'a2', b'a1'] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zrange("a", 0, 1) == [b"a1", b"a2"] + assert r.zrange("a", 1, 2) == [b"a2", b"a3"] + assert r.zrange("a", 0, 2) == [b"a1", b"a2", b"a3"] + assert r.zrange("a", 0, 2, desc=True) == [b"a3", b"a2", b"a1"] # withscores - assert r.zrange('a', 0, 1, withscores=True) == \ - [(b'a1', 1.0), (b'a2', 2.0)] - assert r.zrange('a', 1, 2, withscores=True) == \ - [(b'a2', 2.0), (b'a3', 3.0)] + assert r.zrange("a", 0, 1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] + assert r.zrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a3", 3.0)] # custom score function - assert r.zrange('a', 0, 1, withscores=True, score_cast_func=int) == \ - [(b'a1', 1), (b'a2', 2)] + assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + (b"a1", 1), + (b"a2", 2), + ] def test_zrange_errors(self, r): with pytest.raises(exceptions.DataError): - r.zrange('a', 0, 1, byscore=True, bylex=True) + r.zrange("a", 0, 1, byscore=True, bylex=True) with pytest.raises(exceptions.DataError): - r.zrange('a', 0, 1, bylex=True, withscores=True) + r.zrange("a", 0, 1, bylex=True, withscores=True) with pytest.raises(exceptions.DataError): - r.zrange('a', 0, 1, byscore=True, withscores=True, offset=4) + r.zrange("a", 0, 1, byscore=True, withscores=True, offset=4) with pytest.raises(exceptions.DataError): - r.zrange('a', 0, 1, byscore=True, withscores=True, num=2) + r.zrange("a", 0, 1, byscore=True, withscores=True, num=2) - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_zrange_params(self, r): # bylex - r.zadd('a', {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'e': 0, 'f': 0, 'g': 0}) - assert r.zrange('a', '[aaa', '(g', bylex=True) == \ - [b'b', b'c', b'd', b'e', b'f'] - assert r.zrange('a', '[f', '+', bylex=True) == [b'f', b'g'] - assert r.zrange('a', '+', '[f', desc=True, bylex=True) == [b'g', b'f'] - assert r.zrange('a', '-', '+', bylex=True, offset=3, num=2) == \ - [b'd', b'e'] - assert r.zrange('a', '+', '-', desc=True, bylex=True, - offset=3, num=2) == \ - [b'd', b'c'] + r.zadd("a", {"a": 0, "b": 0, "c": 0, "d": 0, "e": 0, "f": 0, "g": 0}) + assert r.zrange("a", "[aaa", "(g", bylex=True) == [b"b", b"c", b"d", b"e", b"f"] + assert r.zrange("a", "[f", "+", bylex=True) == [b"f", b"g"] + assert r.zrange("a", "+", "[f", desc=True, bylex=True) == [b"g", b"f"] + assert r.zrange("a", "-", "+", bylex=True, offset=3, num=2) == [b"d", b"e"] + assert r.zrange("a", "+", "-", desc=True, bylex=True, offset=3, num=2) == [ + b"d", + b"c", + ] # byscore - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) - assert r.zrange('a', 2, 4, byscore=True, offset=1, num=2) == \ - [b'a3', b'a4'] - assert r.zrange('a', 4, 2, desc=True, byscore=True, - offset=1, num=2) == \ - [b'a3', b'a2'] - assert r.zrange('a', 2, 4, byscore=True, withscores=True) == \ - [(b'a2', 2.0), (b'a3', 3.0), (b'a4', 4.0)] - assert r.zrange('a', 4, 2, desc=True, byscore=True, - withscores=True, score_cast_func=int) == \ - [(b'a4', 4), (b'a3', 3), (b'a2', 2)] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zrange("a", 2, 4, byscore=True, offset=1, num=2) == [b"a3", b"a4"] + assert r.zrange("a", 4, 2, desc=True, byscore=True, offset=1, num=2) == [ + b"a3", + b"a2", + ] + assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ + (b"a2", 2.0), + (b"a3", 3.0), + (b"a4", 4.0), + ] + assert r.zrange( + "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int + ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] # rev - assert r.zrange('a', 0, 1, desc=True) == [b'a5', b'a4'] + assert r.zrange("a", 0, 1, desc=True) == [b"a5", b"a4"] @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_zrangestore(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zrangestore('b', 'a', 0, 1) - assert r.zrange('b', 0, -1) == [b'a1', b'a2'] - assert r.zrangestore('b', 'a', 1, 2) - assert r.zrange('b', 0, -1) == [b'a2', b'a3'] - assert r.zrange('b', 0, -1, withscores=True) == \ - [(b'a2', 2), (b'a3', 3)] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zrangestore("b", "a", 0, 1) + assert r.zrange("b", 0, -1) == [b"a1", b"a2"] + assert r.zrangestore("b", "a", 1, 2) + assert r.zrange("b", 0, -1) == [b"a2", b"a3"] + assert r.zrange("b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] # reversed order - assert r.zrangestore('b', 'a', 1, 2, desc=True) - assert r.zrange('b', 0, -1) == [b'a1', b'a2'] + assert r.zrangestore("b", "a", 1, 2, desc=True) + assert r.zrange("b", 0, -1) == [b"a1", b"a2"] # by score - assert r.zrangestore('b', 'a', 2, 1, byscore=True, - offset=0, num=1, desc=True) - assert r.zrange('b', 0, -1) == [b'a2'] + assert r.zrangestore("b", "a", 2, 1, byscore=True, offset=0, num=1, desc=True) + assert r.zrange("b", 0, -1) == [b"a2"] # by lex - assert r.zrangestore('b', 'a', '[a2', '(a3', bylex=True, - offset=0, num=1) - assert r.zrange('b', 0, -1) == [b'a2'] + assert r.zrangestore("b", "a", "[a2", "(a3", bylex=True, offset=0, num=1) + assert r.zrange("b", 0, -1) == [b"a2"] - @skip_if_server_version_lt('2.8.9') + @skip_if_server_version_lt("2.8.9") def test_zrangebylex(self, r): - r.zadd('a', {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'e': 0, 'f': 0, 'g': 0}) - assert r.zrangebylex('a', '-', '[c') == [b'a', b'b', b'c'] - assert r.zrangebylex('a', '-', '(c') == [b'a', b'b'] - assert r.zrangebylex('a', '[aaa', '(g') == \ - [b'b', b'c', b'd', b'e', b'f'] - assert r.zrangebylex('a', '[f', '+') == [b'f', b'g'] - assert r.zrangebylex('a', '-', '+', start=3, num=2) == [b'd', b'e'] - - @skip_if_server_version_lt('2.9.9') + r.zadd("a", {"a": 0, "b": 0, "c": 0, "d": 0, "e": 0, "f": 0, "g": 0}) + assert r.zrangebylex("a", "-", "[c") == [b"a", b"b", b"c"] + assert r.zrangebylex("a", "-", "(c") == [b"a", b"b"] + assert r.zrangebylex("a", "[aaa", "(g") == [b"b", b"c", b"d", b"e", b"f"] + assert r.zrangebylex("a", "[f", "+") == [b"f", b"g"] + assert r.zrangebylex("a", "-", "+", start=3, num=2) == [b"d", b"e"] + + @skip_if_server_version_lt("2.9.9") def test_zrevrangebylex(self, r): - r.zadd('a', {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'e': 0, 'f': 0, 'g': 0}) - assert r.zrevrangebylex('a', '[c', '-') == [b'c', b'b', b'a'] - assert r.zrevrangebylex('a', '(c', '-') == [b'b', b'a'] - assert r.zrevrangebylex('a', '(g', '[aaa') == \ - [b'f', b'e', b'd', b'c', b'b'] - assert r.zrevrangebylex('a', '+', '[f') == [b'g', b'f'] - assert r.zrevrangebylex('a', '+', '-', start=3, num=2) == \ - [b'd', b'c'] + r.zadd("a", {"a": 0, "b": 0, "c": 0, "d": 0, "e": 0, "f": 0, "g": 0}) + assert r.zrevrangebylex("a", "[c", "-") == [b"c", b"b", b"a"] + assert r.zrevrangebylex("a", "(c", "-") == [b"b", b"a"] + assert r.zrevrangebylex("a", "(g", "[aaa") == [b"f", b"e", b"d", b"c", b"b"] + assert r.zrevrangebylex("a", "+", "[f") == [b"g", b"f"] + assert r.zrevrangebylex("a", "+", "-", start=3, num=2) == [b"d", b"c"] def test_zrangebyscore(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) - assert r.zrangebyscore('a', 2, 4) == [b'a2', b'a3', b'a4'] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zrangebyscore("a", 2, 4) == [b"a2", b"a3", b"a4"] # slicing with start/num - assert r.zrangebyscore('a', 2, 4, start=1, num=2) == \ - [b'a3', b'a4'] + assert r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] # withscores - assert r.zrangebyscore('a', 2, 4, withscores=True) == \ - [(b'a2', 2.0), (b'a3', 3.0), (b'a4', 4.0)] - assert r.zrangebyscore('a', 2, 4, withscores=True, - score_cast_func=int) == \ - [(b'a2', 2), (b'a3', 3), (b'a4', 4)] + assert r.zrangebyscore("a", 2, 4, withscores=True) == [ + (b"a2", 2.0), + (b"a3", 3.0), + (b"a4", 4.0), + ] + assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ + (b"a2", 2), + (b"a3", 3), + (b"a4", 4), + ] def test_zrank(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) - assert r.zrank('a', 'a1') == 0 - assert r.zrank('a', 'a2') == 1 - assert r.zrank('a', 'a6') is None + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zrank("a", "a1") == 0 + assert r.zrank("a", "a2") == 1 + assert r.zrank("a", "a6") is None def test_zrem(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zrem('a', 'a2') == 1 - assert r.zrange('a', 0, -1) == [b'a1', b'a3'] - assert r.zrem('a', 'b') == 0 - assert r.zrange('a', 0, -1) == [b'a1', b'a3'] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zrem("a", "a2") == 1 + assert r.zrange("a", 0, -1) == [b"a1", b"a3"] + assert r.zrem("a", "b") == 0 + assert r.zrange("a", 0, -1) == [b"a1", b"a3"] def test_zrem_multiple_keys(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zrem('a', 'a1', 'a2') == 2 - assert r.zrange('a', 0, 5) == [b'a3'] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zrem("a", "a1", "a2") == 2 + assert r.zrange("a", 0, 5) == [b"a3"] - @skip_if_server_version_lt('2.8.9') + @skip_if_server_version_lt("2.8.9") def test_zremrangebylex(self, r): - r.zadd('a', {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'e': 0, 'f': 0, 'g': 0}) - assert r.zremrangebylex('a', '-', '[c') == 3 - assert r.zrange('a', 0, -1) == [b'd', b'e', b'f', b'g'] - assert r.zremrangebylex('a', '[f', '+') == 2 - assert r.zrange('a', 0, -1) == [b'd', b'e'] - assert r.zremrangebylex('a', '[h', '+') == 0 - assert r.zrange('a', 0, -1) == [b'd', b'e'] + r.zadd("a", {"a": 0, "b": 0, "c": 0, "d": 0, "e": 0, "f": 0, "g": 0}) + assert r.zremrangebylex("a", "-", "[c") == 3 + assert r.zrange("a", 0, -1) == [b"d", b"e", b"f", b"g"] + assert r.zremrangebylex("a", "[f", "+") == 2 + assert r.zrange("a", 0, -1) == [b"d", b"e"] + assert r.zremrangebylex("a", "[h", "+") == 0 + assert r.zrange("a", 0, -1) == [b"d", b"e"] def test_zremrangebyrank(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) - assert r.zremrangebyrank('a', 1, 3) == 3 - assert r.zrange('a', 0, 5) == [b'a1', b'a5'] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zremrangebyrank("a", 1, 3) == 3 + assert r.zrange("a", 0, 5) == [b"a1", b"a5"] def test_zremrangebyscore(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) - assert r.zremrangebyscore('a', 2, 4) == 3 - assert r.zrange('a', 0, -1) == [b'a1', b'a5'] - assert r.zremrangebyscore('a', 2, 4) == 0 - assert r.zrange('a', 0, -1) == [b'a1', b'a5'] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zremrangebyscore("a", 2, 4) == 3 + assert r.zrange("a", 0, -1) == [b"a1", b"a5"] + assert r.zremrangebyscore("a", 2, 4) == 0 + assert r.zrange("a", 0, -1) == [b"a1", b"a5"] def test_zrevrange(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zrevrange('a', 0, 1) == [b'a3', b'a2'] - assert r.zrevrange('a', 1, 2) == [b'a2', b'a1'] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zrevrange("a", 0, 1) == [b"a3", b"a2"] + assert r.zrevrange("a", 1, 2) == [b"a2", b"a1"] # withscores - assert r.zrevrange('a', 0, 1, withscores=True) == \ - [(b'a3', 3.0), (b'a2', 2.0)] - assert r.zrevrange('a', 1, 2, withscores=True) == \ - [(b'a2', 2.0), (b'a1', 1.0)] + assert r.zrevrange("a", 0, 1, withscores=True) == [(b"a3", 3.0), (b"a2", 2.0)] + assert r.zrevrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a1", 1.0)] # custom score function - assert r.zrevrange('a', 0, 1, withscores=True, - score_cast_func=int) == \ - [(b'a3', 3.0), (b'a2', 2.0)] + assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + (b"a3", 3.0), + (b"a2", 2.0), + ] def test_zrevrangebyscore(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) - assert r.zrevrangebyscore('a', 4, 2) == [b'a4', b'a3', b'a2'] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zrevrangebyscore("a", 4, 2) == [b"a4", b"a3", b"a2"] # slicing with start/num - assert r.zrevrangebyscore('a', 4, 2, start=1, num=2) == \ - [b'a3', b'a2'] + assert r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] # withscores - assert r.zrevrangebyscore('a', 4, 2, withscores=True) == \ - [(b'a4', 4.0), (b'a3', 3.0), (b'a2', 2.0)] + assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ + (b"a4", 4.0), + (b"a3", 3.0), + (b"a2", 2.0), + ] # custom score function - assert r.zrevrangebyscore('a', 4, 2, withscores=True, - score_cast_func=int) == \ - [(b'a4', 4), (b'a3', 3), (b'a2', 2)] + assert r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int) == [ + (b"a4", 4), + (b"a3", 3), + (b"a2", 2), + ] def test_zrevrank(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) - assert r.zrevrank('a', 'a1') == 4 - assert r.zrevrank('a', 'a2') == 3 - assert r.zrevrank('a', 'a6') is None + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zrevrank("a", "a1") == 4 + assert r.zrevrank("a", "a2") == 3 + assert r.zrevrank("a", "a6") is None def test_zscore(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - assert r.zscore('a', 'a1') == 1.0 - assert r.zscore('a', 'a2') == 2.0 - assert r.zscore('a', 'a4') is None + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zscore("a", "a1") == 1.0 + assert r.zscore("a", "a2") == 2.0 + assert r.zscore("a", "a4") is None @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_zunion(self, r): - r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) # sum - assert r.zunion(['a', 'b', 'c']) == \ - [b'a2', b'a4', b'a3', b'a1'] - assert r.zunion(['a', 'b', 'c'], withscores=True) == \ - [(b'a2', 3), (b'a4', 4), (b'a3', 8), (b'a1', 9)] + assert r.zunion(["a", "b", "c"]) == [b"a2", b"a4", b"a3", b"a1"] + assert r.zunion(["a", "b", "c"], withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] # max - assert r.zunion(['a', 'b', 'c'], aggregate='MAX', withscores=True)\ - == [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] + assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ + (b"a2", 2), + (b"a4", 4), + (b"a3", 5), + (b"a1", 6), + ] # min - assert r.zunion(['a', 'b', 'c'], aggregate='MIN', withscores=True)\ - == [(b'a1', 1), (b'a2', 1), (b'a3', 1), (b'a4', 4)] + assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ + (b"a1", 1), + (b"a2", 1), + (b"a3", 1), + (b"a4", 4), + ] # with weight - assert r.zunion({'a': 1, 'b': 2, 'c': 3}, withscores=True)\ - == [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] + assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] @pytest.mark.onlynoncluster def test_zunionstore_sum(self, r): - r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zunionstore('d', ['a', 'b', 'c']) == 4 - assert r.zrange('d', 0, -1, withscores=True) == \ - [(b'a2', 3), (b'a4', 4), (b'a3', 8), (b'a1', 9)] + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zunionstore("d", ["a", "b", "c"]) == 4 + assert r.zrange("d", 0, -1, withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] @pytest.mark.onlynoncluster def test_zunionstore_max(self, r): - r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zunionstore('d', ['a', 'b', 'c'], aggregate='MAX') == 4 - assert r.zrange('d', 0, -1, withscores=True) == \ - [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 + assert r.zrange("d", 0, -1, withscores=True) == [ + (b"a2", 2), + (b"a4", 4), + (b"a3", 5), + (b"a1", 6), + ] @pytest.mark.onlynoncluster def test_zunionstore_min(self, r): - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) - r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 4}) - r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zunionstore('d', ['a', 'b', 'c'], aggregate='MIN') == 4 - assert r.zrange('d', 0, -1, withscores=True) == \ - [(b'a1', 1), (b'a2', 2), (b'a3', 3), (b'a4', 4)] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 + assert r.zrange("d", 0, -1, withscores=True) == [ + (b"a1", 1), + (b"a2", 2), + (b"a3", 3), + (b"a4", 4), + ] @pytest.mark.onlynoncluster def test_zunionstore_with_weight(self, r): - r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) - r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) - r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) - assert r.zunionstore('d', {'a': 1, 'b': 2, 'c': 3}) == 4 - assert r.zrange('d', 0, -1, withscores=True) == \ - [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] - - @skip_if_server_version_lt('6.1.240') + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 + assert r.zrange("d", 0, -1, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] + + @skip_if_server_version_lt("6.1.240") def test_zmscore(self, r): with pytest.raises(exceptions.DataError): - r.zmscore('invalid_key', []) + r.zmscore("invalid_key", []) - assert r.zmscore('invalid_key', ['invalid_member']) == [None] + assert r.zmscore("invalid_key", ["invalid_member"]) == [None] - r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3.5}) - assert r.zmscore('a', ['a1', 'a2', 'a3', 'a4']) == \ - [1.0, 2.0, 3.5, None] + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3.5}) + assert r.zmscore("a", ["a1", "a2", "a3", "a4"]) == [1.0, 2.0, 3.5, None] # HYPERLOGLOG TESTS - @skip_if_server_version_lt('2.8.9') + @skip_if_server_version_lt("2.8.9") def test_pfadd(self, r): - members = {b'1', b'2', b'3'} - assert r.pfadd('a', *members) == 1 - assert r.pfadd('a', *members) == 0 - assert r.pfcount('a') == len(members) + members = {b"1", b"2", b"3"} + assert r.pfadd("a", *members) == 1 + assert r.pfadd("a", *members) == 0 + assert r.pfcount("a") == len(members) @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.9') + @skip_if_server_version_lt("2.8.9") def test_pfcount(self, r): - members = {b'1', b'2', b'3'} - r.pfadd('a', *members) - assert r.pfcount('a') == len(members) - members_b = {b'2', b'3', b'4'} - r.pfadd('b', *members_b) - assert r.pfcount('b') == len(members_b) - assert r.pfcount('a', 'b') == len(members_b.union(members)) + members = {b"1", b"2", b"3"} + r.pfadd("a", *members) + assert r.pfcount("a") == len(members) + members_b = {b"2", b"3", b"4"} + r.pfadd("b", *members_b) + assert r.pfcount("b") == len(members_b) + assert r.pfcount("a", "b") == len(members_b.union(members)) @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.9') + @skip_if_server_version_lt("2.8.9") def test_pfmerge(self, r): - mema = {b'1', b'2', b'3'} - memb = {b'2', b'3', b'4'} - memc = {b'5', b'6', b'7'} - r.pfadd('a', *mema) - r.pfadd('b', *memb) - r.pfadd('c', *memc) - r.pfmerge('d', 'c', 'a') - assert r.pfcount('d') == 6 - r.pfmerge('d', 'b') - assert r.pfcount('d') == 7 + mema = {b"1", b"2", b"3"} + memb = {b"2", b"3", b"4"} + memc = {b"5", b"6", b"7"} + r.pfadd("a", *mema) + r.pfadd("b", *memb) + r.pfadd("c", *memc) + r.pfmerge("d", "c", "a") + assert r.pfcount("d") == 6 + r.pfmerge("d", "b") + assert r.pfcount("d") == 7 # HASH COMMANDS def test_hget_and_hset(self, r): - r.hset('a', mapping={'1': 1, '2': 2, '3': 3}) - assert r.hget('a', '1') == b'1' - assert r.hget('a', '2') == b'2' - assert r.hget('a', '3') == b'3' + r.hset("a", mapping={"1": 1, "2": 2, "3": 3}) + assert r.hget("a", "1") == b"1" + assert r.hget("a", "2") == b"2" + assert r.hget("a", "3") == b"3" # field was updated, redis returns 0 - assert r.hset('a', '2', 5) == 0 - assert r.hget('a', '2') == b'5' + assert r.hset("a", "2", 5) == 0 + assert r.hget("a", "2") == b"5" # field is new, redis returns 1 - assert r.hset('a', '4', 4) == 1 - assert r.hget('a', '4') == b'4' + assert r.hset("a", "4", 4) == 1 + assert r.hget("a", "4") == b"4" # key inside of hash that doesn't exist returns null value - assert r.hget('a', 'b') is None + assert r.hget("a", "b") is None # keys with bool(key) == False - assert r.hset('a', 0, 10) == 1 - assert r.hset('a', '', 10) == 1 + assert r.hset("a", 0, 10) == 1 + assert r.hset("a", "", 10) == 1 def test_hset_with_multi_key_values(self, r): - r.hset('a', mapping={'1': 1, '2': 2, '3': 3}) - assert r.hget('a', '1') == b'1' - assert r.hget('a', '2') == b'2' - assert r.hget('a', '3') == b'3' + r.hset("a", mapping={"1": 1, "2": 2, "3": 3}) + assert r.hget("a", "1") == b"1" + assert r.hget("a", "2") == b"2" + assert r.hget("a", "3") == b"3" - r.hset('b', "foo", "bar", mapping={'1': 1, '2': 2}) - assert r.hget('b', '1') == b'1' - assert r.hget('b', '2') == b'2' - assert r.hget('b', 'foo') == b'bar' + r.hset("b", "foo", "bar", mapping={"1": 1, "2": 2}) + assert r.hget("b", "1") == b"1" + assert r.hget("b", "2") == b"2" + assert r.hget("b", "foo") == b"bar" def test_hset_without_data(self, r): with pytest.raises(exceptions.DataError): r.hset("x") def test_hdel(self, r): - r.hset('a', mapping={'1': 1, '2': 2, '3': 3}) - assert r.hdel('a', '2') == 1 - assert r.hget('a', '2') is None - assert r.hdel('a', '1', '3') == 2 - assert r.hlen('a') == 0 + r.hset("a", mapping={"1": 1, "2": 2, "3": 3}) + assert r.hdel("a", "2") == 1 + assert r.hget("a", "2") is None + assert r.hdel("a", "1", "3") == 2 + assert r.hlen("a") == 0 def test_hexists(self, r): - r.hset('a', mapping={'1': 1, '2': 2, '3': 3}) - assert r.hexists('a', '1') - assert not r.hexists('a', '4') + r.hset("a", mapping={"1": 1, "2": 2, "3": 3}) + assert r.hexists("a", "1") + assert not r.hexists("a", "4") def test_hgetall(self, r): - h = {b'a1': b'1', b'a2': b'2', b'a3': b'3'} - r.hset('a', mapping=h) - assert r.hgetall('a') == h + h = {b"a1": b"1", b"a2": b"2", b"a3": b"3"} + r.hset("a", mapping=h) + assert r.hgetall("a") == h def test_hincrby(self, r): - assert r.hincrby('a', '1') == 1 - assert r.hincrby('a', '1', amount=2) == 3 - assert r.hincrby('a', '1', amount=-2) == 1 + assert r.hincrby("a", "1") == 1 + assert r.hincrby("a", "1", amount=2) == 3 + assert r.hincrby("a", "1", amount=-2) == 1 - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_hincrbyfloat(self, r): - assert r.hincrbyfloat('a', '1') == 1.0 - assert r.hincrbyfloat('a', '1') == 2.0 - assert r.hincrbyfloat('a', '1', 1.2) == 3.2 + assert r.hincrbyfloat("a", "1") == 1.0 + assert r.hincrbyfloat("a", "1") == 2.0 + assert r.hincrbyfloat("a", "1", 1.2) == 3.2 def test_hkeys(self, r): - h = {b'a1': b'1', b'a2': b'2', b'a3': b'3'} - r.hset('a', mapping=h) + h = {b"a1": b"1", b"a2": b"2", b"a3": b"3"} + r.hset("a", mapping=h) local_keys = list(h.keys()) - remote_keys = r.hkeys('a') - assert (sorted(local_keys) == sorted(remote_keys)) + remote_keys = r.hkeys("a") + assert sorted(local_keys) == sorted(remote_keys) def test_hlen(self, r): - r.hset('a', mapping={'1': 1, '2': 2, '3': 3}) - assert r.hlen('a') == 3 + r.hset("a", mapping={"1": 1, "2": 2, "3": 3}) + assert r.hlen("a") == 3 def test_hmget(self, r): - assert r.hset('a', mapping={'a': 1, 'b': 2, 'c': 3}) - assert r.hmget('a', 'a', 'b', 'c') == [b'1', b'2', b'3'] + assert r.hset("a", mapping={"a": 1, "b": 2, "c": 3}) + assert r.hmget("a", "a", "b", "c") == [b"1", b"2", b"3"] def test_hmset(self, r): redis_class = type(r).__name__ - warning_message = (r'^{0}\.hmset\(\) is deprecated\. ' - r'Use {0}\.hset\(\) instead\.$'.format(redis_class)) - h = {b'a': b'1', b'b': b'2', b'c': b'3'} + warning_message = ( + r"^{0}\.hmset\(\) is deprecated\. " + r"Use {0}\.hset\(\) instead\.$".format(redis_class) + ) + h = {b"a": b"1", b"b": b"2", b"c": b"3"} with pytest.warns(DeprecationWarning, match=warning_message): - assert r.hmset('a', h) - assert r.hgetall('a') == h + assert r.hmset("a", h) + assert r.hgetall("a") == h def test_hsetnx(self, r): # Initially set the hash field - assert r.hsetnx('a', '1', 1) - assert r.hget('a', '1') == b'1' - assert not r.hsetnx('a', '1', 2) - assert r.hget('a', '1') == b'1' + assert r.hsetnx("a", "1", 1) + assert r.hget("a", "1") == b"1" + assert not r.hsetnx("a", "1", 2) + assert r.hget("a", "1") == b"1" def test_hvals(self, r): - h = {b'a1': b'1', b'a2': b'2', b'a3': b'3'} - r.hset('a', mapping=h) + h = {b"a1": b"1", b"a2": b"2", b"a3": b"3"} + r.hset("a", mapping=h) local_vals = list(h.values()) - remote_vals = r.hvals('a') + remote_vals = r.hvals("a") assert sorted(local_vals) == sorted(remote_vals) - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_hstrlen(self, r): - r.hset('a', mapping={'1': '22', '2': '333'}) - assert r.hstrlen('a', '1') == 2 - assert r.hstrlen('a', '2') == 3 + r.hset("a", mapping={"1": "22", "2": "333"}) + assert r.hstrlen("a", "1") == 2 + assert r.hstrlen("a", "2") == 3 # SORT def test_sort_basic(self, r): - r.rpush('a', '3', '2', '1', '4') - assert r.sort('a') == [b'1', b'2', b'3', b'4'] + r.rpush("a", "3", "2", "1", "4") + assert r.sort("a") == [b"1", b"2", b"3", b"4"] def test_sort_limited(self, r): - r.rpush('a', '3', '2', '1', '4') - assert r.sort('a', start=1, num=2) == [b'2', b'3'] + r.rpush("a", "3", "2", "1", "4") + assert r.sort("a", start=1, num=2) == [b"2", b"3"] @pytest.mark.onlynoncluster def test_sort_by(self, r): - r['score:1'] = 8 - r['score:2'] = 3 - r['score:3'] = 5 - r.rpush('a', '3', '2', '1') - assert r.sort('a', by='score:*') == [b'2', b'3', b'1'] + r["score:1"] = 8 + r["score:2"] = 3 + r["score:3"] = 5 + r.rpush("a", "3", "2", "1") + assert r.sort("a", by="score:*") == [b"2", b"3", b"1"] @pytest.mark.onlynoncluster def test_sort_get(self, r): - r['user:1'] = 'u1' - r['user:2'] = 'u2' - r['user:3'] = 'u3' - r.rpush('a', '2', '3', '1') - assert r.sort('a', get='user:*') == [b'u1', b'u2', b'u3'] + r["user:1"] = "u1" + r["user:2"] = "u2" + r["user:3"] = "u3" + r.rpush("a", "2", "3", "1") + assert r.sort("a", get="user:*") == [b"u1", b"u2", b"u3"] @pytest.mark.onlynoncluster def test_sort_get_multi(self, r): - r['user:1'] = 'u1' - r['user:2'] = 'u2' - r['user:3'] = 'u3' - r.rpush('a', '2', '3', '1') - assert r.sort('a', get=('user:*', '#')) == \ - [b'u1', b'1', b'u2', b'2', b'u3', b'3'] + r["user:1"] = "u1" + r["user:2"] = "u2" + r["user:3"] = "u3" + r.rpush("a", "2", "3", "1") + assert r.sort("a", get=("user:*", "#")) == [ + b"u1", + b"1", + b"u2", + b"2", + b"u3", + b"3", + ] @pytest.mark.onlynoncluster def test_sort_get_groups_two(self, r): - r['user:1'] = 'u1' - r['user:2'] = 'u2' - r['user:3'] = 'u3' - r.rpush('a', '2', '3', '1') - assert r.sort('a', get=('user:*', '#'), groups=True) == \ - [(b'u1', b'1'), (b'u2', b'2'), (b'u3', b'3')] + r["user:1"] = "u1" + r["user:2"] = "u2" + r["user:3"] = "u3" + r.rpush("a", "2", "3", "1") + assert r.sort("a", get=("user:*", "#"), groups=True) == [ + (b"u1", b"1"), + (b"u2", b"2"), + (b"u3", b"3"), + ] @pytest.mark.onlynoncluster def test_sort_groups_string_get(self, r): - r['user:1'] = 'u1' - r['user:2'] = 'u2' - r['user:3'] = 'u3' - r.rpush('a', '2', '3', '1') + r["user:1"] = "u1" + r["user:2"] = "u2" + r["user:3"] = "u3" + r.rpush("a", "2", "3", "1") with pytest.raises(exceptions.DataError): - r.sort('a', get='user:*', groups=True) + r.sort("a", get="user:*", groups=True) @pytest.mark.onlynoncluster def test_sort_groups_just_one_get(self, r): - r['user:1'] = 'u1' - r['user:2'] = 'u2' - r['user:3'] = 'u3' - r.rpush('a', '2', '3', '1') + r["user:1"] = "u1" + r["user:2"] = "u2" + r["user:3"] = "u3" + r.rpush("a", "2", "3", "1") with pytest.raises(exceptions.DataError): - r.sort('a', get=['user:*'], groups=True) + r.sort("a", get=["user:*"], groups=True) def test_sort_groups_no_get(self, r): - r['user:1'] = 'u1' - r['user:2'] = 'u2' - r['user:3'] = 'u3' - r.rpush('a', '2', '3', '1') + r["user:1"] = "u1" + r["user:2"] = "u2" + r["user:3"] = "u3" + r.rpush("a", "2", "3", "1") with pytest.raises(exceptions.DataError): - r.sort('a', groups=True) + r.sort("a", groups=True) @pytest.mark.onlynoncluster def test_sort_groups_three_gets(self, r): - r['user:1'] = 'u1' - r['user:2'] = 'u2' - r['user:3'] = 'u3' - r['door:1'] = 'd1' - r['door:2'] = 'd2' - r['door:3'] = 'd3' - r.rpush('a', '2', '3', '1') - assert r.sort('a', get=('user:*', 'door:*', '#'), groups=True) == \ - [ - (b'u1', b'd1', b'1'), - (b'u2', b'd2', b'2'), - (b'u3', b'd3', b'3') - ] + r["user:1"] = "u1" + r["user:2"] = "u2" + r["user:3"] = "u3" + r["door:1"] = "d1" + r["door:2"] = "d2" + r["door:3"] = "d3" + r.rpush("a", "2", "3", "1") + assert r.sort("a", get=("user:*", "door:*", "#"), groups=True) == [ + (b"u1", b"d1", b"1"), + (b"u2", b"d2", b"2"), + (b"u3", b"d3", b"3"), + ] def test_sort_desc(self, r): - r.rpush('a', '2', '3', '1') - assert r.sort('a', desc=True) == [b'3', b'2', b'1'] + r.rpush("a", "2", "3", "1") + assert r.sort("a", desc=True) == [b"3", b"2", b"1"] def test_sort_alpha(self, r): - r.rpush('a', 'e', 'c', 'b', 'd', 'a') - assert r.sort('a', alpha=True) == \ - [b'a', b'b', b'c', b'd', b'e'] + r.rpush("a", "e", "c", "b", "d", "a") + assert r.sort("a", alpha=True) == [b"a", b"b", b"c", b"d", b"e"] @pytest.mark.onlynoncluster def test_sort_store(self, r): - r.rpush('a', '2', '3', '1') - assert r.sort('a', store='sorted_values') == 3 - assert r.lrange('sorted_values', 0, -1) == [b'1', b'2', b'3'] + r.rpush("a", "2", "3", "1") + assert r.sort("a", store="sorted_values") == 3 + assert r.lrange("sorted_values", 0, -1) == [b"1", b"2", b"3"] @pytest.mark.onlynoncluster def test_sort_all_options(self, r): - r['user:1:username'] = 'zeus' - r['user:2:username'] = 'titan' - r['user:3:username'] = 'hermes' - r['user:4:username'] = 'hercules' - r['user:5:username'] = 'apollo' - r['user:6:username'] = 'athena' - r['user:7:username'] = 'hades' - r['user:8:username'] = 'dionysus' - - r['user:1:favorite_drink'] = 'yuengling' - r['user:2:favorite_drink'] = 'rum' - r['user:3:favorite_drink'] = 'vodka' - r['user:4:favorite_drink'] = 'milk' - r['user:5:favorite_drink'] = 'pinot noir' - r['user:6:favorite_drink'] = 'water' - r['user:7:favorite_drink'] = 'gin' - r['user:8:favorite_drink'] = 'apple juice' - - r.rpush('gods', '5', '8', '3', '1', '2', '7', '6', '4') - num = r.sort('gods', start=2, num=4, by='user:*:username', - get='user:*:favorite_drink', desc=True, alpha=True, - store='sorted') + r["user:1:username"] = "zeus" + r["user:2:username"] = "titan" + r["user:3:username"] = "hermes" + r["user:4:username"] = "hercules" + r["user:5:username"] = "apollo" + r["user:6:username"] = "athena" + r["user:7:username"] = "hades" + r["user:8:username"] = "dionysus" + + r["user:1:favorite_drink"] = "yuengling" + r["user:2:favorite_drink"] = "rum" + r["user:3:favorite_drink"] = "vodka" + r["user:4:favorite_drink"] = "milk" + r["user:5:favorite_drink"] = "pinot noir" + r["user:6:favorite_drink"] = "water" + r["user:7:favorite_drink"] = "gin" + r["user:8:favorite_drink"] = "apple juice" + + r.rpush("gods", "5", "8", "3", "1", "2", "7", "6", "4") + num = r.sort( + "gods", + start=2, + num=4, + by="user:*:username", + get="user:*:favorite_drink", + desc=True, + alpha=True, + store="sorted", + ) assert num == 4 - assert r.lrange('sorted', 0, 10) == \ - [b'vodka', b'milk', b'gin', b'apple juice'] + assert r.lrange("sorted", 0, 10) == [b"vodka", b"milk", b"gin", b"apple juice"] def test_sort_issue_924(self, r): # Tests for issue https://github.com/andymccurdy/redis-py/issues/924 - r.execute_command('SADD', 'issue#924', 1) - r.execute_command('SORT', 'issue#924') + r.execute_command("SADD", "issue#924", 1) + r.execute_command("SORT", "issue#924") @pytest.mark.onlynoncluster def test_cluster_addslots(self, mock_cluster_resp_ok): - assert mock_cluster_resp_ok.cluster('ADDSLOTS', 1) is True + assert mock_cluster_resp_ok.cluster("ADDSLOTS", 1) is True @pytest.mark.onlynoncluster def test_cluster_count_failure_reports(self, mock_cluster_resp_int): - assert isinstance(mock_cluster_resp_int.cluster( - 'COUNT-FAILURE-REPORTS', 'node'), int) + assert isinstance( + mock_cluster_resp_int.cluster("COUNT-FAILURE-REPORTS", "node"), int + ) @pytest.mark.onlynoncluster def test_cluster_countkeysinslot(self, mock_cluster_resp_int): - assert isinstance(mock_cluster_resp_int.cluster( - 'COUNTKEYSINSLOT', 2), int) + assert isinstance(mock_cluster_resp_int.cluster("COUNTKEYSINSLOT", 2), int) @pytest.mark.onlynoncluster def test_cluster_delslots(self, mock_cluster_resp_ok): - assert mock_cluster_resp_ok.cluster('DELSLOTS', 1) is True + assert mock_cluster_resp_ok.cluster("DELSLOTS", 1) is True @pytest.mark.onlynoncluster def test_cluster_failover(self, mock_cluster_resp_ok): - assert mock_cluster_resp_ok.cluster('FAILOVER', 1) is True + assert mock_cluster_resp_ok.cluster("FAILOVER", 1) is True @pytest.mark.onlynoncluster def test_cluster_forget(self, mock_cluster_resp_ok): - assert mock_cluster_resp_ok.cluster('FORGET', 1) is True + assert mock_cluster_resp_ok.cluster("FORGET", 1) is True @pytest.mark.onlynoncluster def test_cluster_info(self, mock_cluster_resp_info): - assert isinstance(mock_cluster_resp_info.cluster('info'), dict) + assert isinstance(mock_cluster_resp_info.cluster("info"), dict) @pytest.mark.onlynoncluster def test_cluster_keyslot(self, mock_cluster_resp_int): - assert isinstance(mock_cluster_resp_int.cluster( - 'keyslot', 'asdf'), int) + assert isinstance(mock_cluster_resp_int.cluster("keyslot", "asdf"), int) @pytest.mark.onlynoncluster def test_cluster_meet(self, mock_cluster_resp_ok): - assert mock_cluster_resp_ok.cluster('meet', 'ip', 'port', 1) is True + assert mock_cluster_resp_ok.cluster("meet", "ip", "port", 1) is True @pytest.mark.onlynoncluster def test_cluster_nodes(self, mock_cluster_resp_nodes): - assert isinstance(mock_cluster_resp_nodes.cluster('nodes'), dict) + assert isinstance(mock_cluster_resp_nodes.cluster("nodes"), dict) @pytest.mark.onlynoncluster def test_cluster_replicate(self, mock_cluster_resp_ok): - assert mock_cluster_resp_ok.cluster('replicate', 'nodeid') is True + assert mock_cluster_resp_ok.cluster("replicate", "nodeid") is True @pytest.mark.onlynoncluster def test_cluster_reset(self, mock_cluster_resp_ok): - assert mock_cluster_resp_ok.cluster('reset', 'hard') is True + assert mock_cluster_resp_ok.cluster("reset", "hard") is True @pytest.mark.onlynoncluster def test_cluster_saveconfig(self, mock_cluster_resp_ok): - assert mock_cluster_resp_ok.cluster('saveconfig') is True + assert mock_cluster_resp_ok.cluster("saveconfig") is True @pytest.mark.onlynoncluster def test_cluster_setslot(self, mock_cluster_resp_ok): - assert mock_cluster_resp_ok.cluster('setslot', 1, - 'IMPORTING', 'nodeid') is True + assert mock_cluster_resp_ok.cluster("setslot", 1, "IMPORTING", "nodeid") is True @pytest.mark.onlynoncluster def test_cluster_slaves(self, mock_cluster_resp_slaves): - assert isinstance(mock_cluster_resp_slaves.cluster( - 'slaves', 'nodeid'), dict) + assert isinstance(mock_cluster_resp_slaves.cluster("slaves", "nodeid"), dict) @pytest.mark.onlynoncluster - @skip_if_server_version_lt('3.0.0') + @skip_if_server_version_lt("3.0.0") @skip_if_redis_enterprise def test_readwrite(self, r): assert r.readwrite() @pytest.mark.onlynoncluster - @skip_if_server_version_lt('3.0.0') + @skip_if_server_version_lt("3.0.0") def test_readonly_invalid_cluster_state(self, r): with pytest.raises(exceptions.RedisError): r.readonly() @pytest.mark.onlynoncluster - @skip_if_server_version_lt('3.0.0') + @skip_if_server_version_lt("3.0.0") def test_readonly(self, mock_cluster_resp_ok): assert mock_cluster_resp_ok.readonly() is True # GEO COMMANDS - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_geoadd(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') - assert r.geoadd('barcelona', values) == 2 - assert r.zcard('barcelona') == 2 + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + assert r.geoadd("barcelona", values) == 2 + assert r.zcard("barcelona") == 2 - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_geoadd_nx(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') - assert r.geoadd('a', values) == 2 - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + \ - (2.1804738294738, 41.405647879212, 'place3') - assert r.geoadd('a', values, nx=True) == 1 - assert r.zrange('a', 0, -1) == [b'place3', b'place2', b'place1'] - - @skip_if_server_version_lt('6.2.0') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + assert r.geoadd("a", values) == 2 + values = ( + (2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2") + + (2.1804738294738, 41.405647879212, "place3") + ) + assert r.geoadd("a", values, nx=True) == 1 + assert r.zrange("a", 0, -1) == [b"place3", b"place2", b"place1"] + + @skip_if_server_version_lt("6.2.0") def test_geoadd_xx(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') - assert r.geoadd('a', values) == 1 - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') - assert r.geoadd('a', values, xx=True) == 0 - assert r.zrange('a', 0, -1) == \ - [b'place1'] - - @skip_if_server_version_lt('6.2.0') + values = (2.1909389952632, 41.433791470673, "place1") + assert r.geoadd("a", values) == 1 + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + assert r.geoadd("a", values, xx=True) == 0 + assert r.zrange("a", 0, -1) == [b"place1"] + + @skip_if_server_version_lt("6.2.0") def test_geoadd_ch(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') - assert r.geoadd('a', values) == 1 - values = (2.1909389952632, 31.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') - assert r.geoadd('a', values, ch=True) == 2 - assert r.zrange('a', 0, -1) == \ - [b'place1', b'place2'] - - @skip_if_server_version_lt('3.2.0') + values = (2.1909389952632, 41.433791470673, "place1") + assert r.geoadd("a", values) == 1 + values = (2.1909389952632, 31.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + assert r.geoadd("a", values, ch=True) == 2 + assert r.zrange("a", 0, -1) == [b"place1", b"place2"] + + @skip_if_server_version_lt("3.2.0") def test_geoadd_invalid_params(self, r): with pytest.raises(exceptions.RedisError): - r.geoadd('barcelona', (1, 2)) + r.geoadd("barcelona", (1, 2)) - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_geodist(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') - assert r.geoadd('barcelona', values) == 2 - assert r.geodist('barcelona', 'place1', 'place2') == 3067.4157 + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + assert r.geoadd("barcelona", values) == 2 + assert r.geodist("barcelona", "place1", "place2") == 3067.4157 - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_geodist_units(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') - r.geoadd('barcelona', values) - assert r.geodist('barcelona', 'place1', 'place2', 'km') == 3.0674 + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + r.geoadd("barcelona", values) + assert r.geodist("barcelona", "place1", "place2", "km") == 3.0674 - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_geodist_missing_one_member(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') - r.geoadd('barcelona', values) - assert r.geodist('barcelona', 'place1', 'missing_member', 'km') is None + values = (2.1909389952632, 41.433791470673, "place1") + r.geoadd("barcelona", values) + assert r.geodist("barcelona", "place1", "missing_member", "km") is None - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_geodist_invalid_units(self, r): with pytest.raises(exceptions.RedisError): - assert r.geodist('x', 'y', 'z', 'inches') + assert r.geodist("x", "y", "z", "inches") - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_geohash(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') - r.geoadd('barcelona', values) - assert r.geohash('barcelona', 'place1', 'place2', 'place3') == \ - ['sp3e9yg3kd0', 'sp3e9cbc3t0', None] + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + r.geoadd("barcelona", values) + assert r.geohash("barcelona", "place1", "place2", "place3") == [ + "sp3e9yg3kd0", + "sp3e9cbc3t0", + None, + ] @skip_unless_arch_bits(64) - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_geopos(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') - r.geoadd('barcelona', values) + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + r.geoadd("barcelona", values) # redis uses 52 bits precision, hereby small errors may be introduced. - assert r.geopos('barcelona', 'place1', 'place2') == \ - [(2.19093829393386841, 41.43379028184083523), - (2.18737632036209106, 41.40634178640635099)] + assert r.geopos("barcelona", "place1", "place2") == [ + (2.19093829393386841, 41.43379028184083523), + (2.18737632036209106, 41.40634178640635099), + ] - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_geopos_no_value(self, r): - assert r.geopos('barcelona', 'place1', 'place2') == [None, None] + assert r.geopos("barcelona", "place1", "place2") == [None, None] - @skip_if_server_version_lt('3.2.0') - @skip_if_server_version_gte('4.0.0') + @skip_if_server_version_lt("3.2.0") + @skip_if_server_version_gte("4.0.0") def test_old_geopos_no_value(self, r): - assert r.geopos('barcelona', 'place1', 'place2') == [] + assert r.geopos("barcelona", "place1", "place2") == [] - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_geosearch(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, b'\x80place2') + \ - (2.583333, 41.316667, 'place3') - r.geoadd('barcelona', values) - assert r.geosearch('barcelona', longitude=2.191, - latitude=41.433, radius=1000) == [b'place1'] - assert r.geosearch('barcelona', longitude=2.187, - latitude=41.406, radius=1000) == [b'\x80place2'] - assert r.geosearch('barcelona', longitude=2.191, latitude=41.433, - height=1000, width=1000) == [b'place1'] - assert r.geosearch('barcelona', member='place3', radius=100, - unit='km') == [b'\x80place2', b'place1', b'place3'] + values = ( + (2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, b"\x80place2") + + (2.583333, 41.316667, "place3") + ) + r.geoadd("barcelona", values) + assert r.geosearch( + "barcelona", longitude=2.191, latitude=41.433, radius=1000 + ) == [b"place1"] + assert r.geosearch( + "barcelona", longitude=2.187, latitude=41.406, radius=1000 + ) == [b"\x80place2"] + assert r.geosearch( + "barcelona", longitude=2.191, latitude=41.433, height=1000, width=1000 + ) == [b"place1"] + assert r.geosearch("barcelona", member="place3", radius=100, unit="km") == [ + b"\x80place2", + b"place1", + b"place3", + ] # test count - assert r.geosearch('barcelona', member='place3', radius=100, - unit='km', count=2) == [b'place3', b'\x80place2'] - assert r.geosearch('barcelona', member='place3', radius=100, - unit='km', count=1, any=1)[0] \ - in [b'place1', b'place3', b'\x80place2'] + assert r.geosearch( + "barcelona", member="place3", radius=100, unit="km", count=2 + ) == [b"place3", b"\x80place2"] + assert r.geosearch( + "barcelona", member="place3", radius=100, unit="km", count=1, any=1 + )[0] in [b"place1", b"place3", b"\x80place2"] @skip_unless_arch_bits(64) - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_geosearch_member(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, b'\x80place2') - - r.geoadd('barcelona', values) - assert r.geosearch('barcelona', member='place1', radius=4000) == \ - [b'\x80place2', b'place1'] - assert r.geosearch('barcelona', member='place1', radius=10) == \ - [b'place1'] - - assert r.geosearch('barcelona', member='place1', radius=4000, - withdist=True, - withcoord=True, - withhash=True) == \ - [[b'\x80place2', 3067.4157, 3471609625421029, - (2.187376320362091, 41.40634178640635)], - [b'place1', 0.0, 3471609698139488, - (2.1909382939338684, 41.433790281840835)]] - - @skip_if_server_version_lt('6.2.0') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + b"\x80place2", + ) + + r.geoadd("barcelona", values) + assert r.geosearch("barcelona", member="place1", radius=4000) == [ + b"\x80place2", + b"place1", + ] + assert r.geosearch("barcelona", member="place1", radius=10) == [b"place1"] + + assert r.geosearch( + "barcelona", + member="place1", + radius=4000, + withdist=True, + withcoord=True, + withhash=True, + ) == [ + [ + b"\x80place2", + 3067.4157, + 3471609625421029, + (2.187376320362091, 41.40634178640635), + ], + [ + b"place1", + 0.0, + 3471609698139488, + (2.1909382939338684, 41.433790281840835), + ], + ] + + @skip_if_server_version_lt("6.2.0") def test_geosearch_sort(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') - r.geoadd('barcelona', values) - assert r.geosearch('barcelona', longitude=2.191, - latitude=41.433, radius=3000, sort='ASC') == \ - [b'place1', b'place2'] - assert r.geosearch('barcelona', longitude=2.191, - latitude=41.433, radius=3000, sort='DESC') == \ - [b'place2', b'place1'] + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + r.geoadd("barcelona", values) + assert r.geosearch( + "barcelona", longitude=2.191, latitude=41.433, radius=3000, sort="ASC" + ) == [b"place1", b"place2"] + assert r.geosearch( + "barcelona", longitude=2.191, latitude=41.433, radius=3000, sort="DESC" + ) == [b"place2", b"place1"] @skip_unless_arch_bits(64) - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_geosearch_with(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') - r.geoadd('barcelona', values) + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + r.geoadd("barcelona", values) # test a bunch of combinations to test the parse response # function. - assert r.geosearch('barcelona', longitude=2.191, latitude=41.433, - radius=1, unit='km', withdist=True, - withcoord=True, withhash=True) == \ - [[b'place1', 0.0881, 3471609698139488, - (2.19093829393386841, 41.43379028184083523)]] - assert r.geosearch('barcelona', longitude=2.191, latitude=41.433, - radius=1, unit='km', - withdist=True, withcoord=True) == \ - [[b'place1', 0.0881, - (2.19093829393386841, 41.43379028184083523)]] - assert r.geosearch('barcelona', longitude=2.191, latitude=41.433, - radius=1, unit='km', - withhash=True, withcoord=True) == \ - [[b'place1', 3471609698139488, - (2.19093829393386841, 41.43379028184083523)]] + assert r.geosearch( + "barcelona", + longitude=2.191, + latitude=41.433, + radius=1, + unit="km", + withdist=True, + withcoord=True, + withhash=True, + ) == [ + [ + b"place1", + 0.0881, + 3471609698139488, + (2.19093829393386841, 41.43379028184083523), + ] + ] + assert ( + r.geosearch( + "barcelona", + longitude=2.191, + latitude=41.433, + radius=1, + unit="km", + withdist=True, + withcoord=True, + ) + == [[b"place1", 0.0881, (2.19093829393386841, 41.43379028184083523)]] + ) + assert r.geosearch( + "barcelona", + longitude=2.191, + latitude=41.433, + radius=1, + unit="km", + withhash=True, + withcoord=True, + ) == [ + [b"place1", 3471609698139488, (2.19093829393386841, 41.43379028184083523)] + ] # test no values. - assert r.geosearch('barcelona', longitude=2, latitude=1, - radius=1, unit='km', withdist=True, - withcoord=True, withhash=True) == [] + assert ( + r.geosearch( + "barcelona", + longitude=2, + latitude=1, + radius=1, + unit="km", + withdist=True, + withcoord=True, + withhash=True, + ) + == [] + ) - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_geosearch_negative(self, r): # not specifying member nor longitude and latitude with pytest.raises(exceptions.DataError): - assert r.geosearch('barcelona') + assert r.geosearch("barcelona") # specifying member and longitude and latitude with pytest.raises(exceptions.DataError): - assert r.geosearch('barcelona', - member="Paris", longitude=2, latitude=1) + assert r.geosearch("barcelona", member="Paris", longitude=2, latitude=1) # specifying one of longitude and latitude with pytest.raises(exceptions.DataError): - assert r.geosearch('barcelona', longitude=2) + assert r.geosearch("barcelona", longitude=2) with pytest.raises(exceptions.DataError): - assert r.geosearch('barcelona', latitude=2) + assert r.geosearch("barcelona", latitude=2) # not specifying radius nor width and height with pytest.raises(exceptions.DataError): - assert r.geosearch('barcelona', member="Paris") + assert r.geosearch("barcelona", member="Paris") # specifying radius and width and height with pytest.raises(exceptions.DataError): - assert r.geosearch('barcelona', member="Paris", - radius=3, width=2, height=1) + assert r.geosearch("barcelona", member="Paris", radius=3, width=2, height=1) # specifying one of width and height with pytest.raises(exceptions.DataError): - assert r.geosearch('barcelona', member="Paris", width=2) + assert r.geosearch("barcelona", member="Paris", width=2) with pytest.raises(exceptions.DataError): - assert r.geosearch('barcelona', member="Paris", height=2) + assert r.geosearch("barcelona", member="Paris", height=2) # invalid sort with pytest.raises(exceptions.DataError): - assert r.geosearch('barcelona', - member="Paris", width=2, height=2, sort="wrong") + assert r.geosearch( + "barcelona", member="Paris", width=2, height=2, sort="wrong" + ) # invalid unit with pytest.raises(exceptions.DataError): - assert r.geosearch('barcelona', - member="Paris", width=2, height=2, unit="miles") + assert r.geosearch( + "barcelona", member="Paris", width=2, height=2, unit="miles" + ) # use any without count with pytest.raises(exceptions.DataError): - assert r.geosearch('barcelona', member='place3', radius=100, any=1) + assert r.geosearch("barcelona", member="place3", radius=100, any=1) @pytest.mark.onlynoncluster - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_geosearchstore(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('barcelona', values) - r.geosearchstore('places_barcelona', 'barcelona', - longitude=2.191, latitude=41.433, radius=1000) - assert r.zrange('places_barcelona', 0, -1) == [b'place1'] + r.geoadd("barcelona", values) + r.geosearchstore( + "places_barcelona", + "barcelona", + longitude=2.191, + latitude=41.433, + radius=1000, + ) + assert r.zrange("places_barcelona", 0, -1) == [b"place1"] @pytest.mark.onlynoncluster @skip_unless_arch_bits(64) - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_geosearchstore_dist(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('barcelona', values) - r.geosearchstore('places_barcelona', 'barcelona', - longitude=2.191, latitude=41.433, - radius=1000, storedist=True) + r.geoadd("barcelona", values) + r.geosearchstore( + "places_barcelona", + "barcelona", + longitude=2.191, + latitude=41.433, + radius=1000, + storedist=True, + ) # instead of save the geo score, the distance is saved. - assert r.zscore('places_barcelona', 'place1') == 88.05060698409301 + assert r.zscore("places_barcelona", "place1") == 88.05060698409301 - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_georadius(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, b'\x80place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + b"\x80place2", + ) - r.geoadd('barcelona', values) - assert r.georadius('barcelona', 2.191, 41.433, 1000) == [b'place1'] - assert r.georadius('barcelona', 2.187, 41.406, 1000) == [b'\x80place2'] + r.geoadd("barcelona", values) + assert r.georadius("barcelona", 2.191, 41.433, 1000) == [b"place1"] + assert r.georadius("barcelona", 2.187, 41.406, 1000) == [b"\x80place2"] - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_georadius_no_values(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('barcelona', values) - assert r.georadius('barcelona', 1, 2, 1000) == [] + r.geoadd("barcelona", values) + assert r.georadius("barcelona", 1, 2, 1000) == [] - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_georadius_units(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('barcelona', values) - assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km') == \ - [b'place1'] + r.geoadd("barcelona", values) + assert r.georadius("barcelona", 2.191, 41.433, 1, unit="km") == [b"place1"] @skip_unless_arch_bits(64) - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_georadius_with(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('barcelona', values) + r.geoadd("barcelona", values) # test a bunch of combinations to test the parse response # function. - assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km', - withdist=True, withcoord=True, withhash=True) == \ - [[b'place1', 0.0881, 3471609698139488, - (2.19093829393386841, 41.43379028184083523)]] + assert r.georadius( + "barcelona", + 2.191, + 41.433, + 1, + unit="km", + withdist=True, + withcoord=True, + withhash=True, + ) == [ + [ + b"place1", + 0.0881, + 3471609698139488, + (2.19093829393386841, 41.43379028184083523), + ] + ] - assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km', - withdist=True, withcoord=True) == \ - [[b'place1', 0.0881, - (2.19093829393386841, 41.43379028184083523)]] + assert r.georadius( + "barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True + ) == [[b"place1", 0.0881, (2.19093829393386841, 41.43379028184083523)]] - assert r.georadius('barcelona', 2.191, 41.433, 1, unit='km', - withhash=True, withcoord=True) == \ - [[b'place1', 3471609698139488, - (2.19093829393386841, 41.43379028184083523)]] + assert r.georadius( + "barcelona", 2.191, 41.433, 1, unit="km", withhash=True, withcoord=True + ) == [ + [b"place1", 3471609698139488, (2.19093829393386841, 41.43379028184083523)] + ] # test no values. - assert r.georadius('barcelona', 2, 1, 1, unit='km', - withdist=True, withcoord=True, withhash=True) == [] + assert ( + r.georadius( + "barcelona", + 2, + 1, + 1, + unit="km", + withdist=True, + withcoord=True, + withhash=True, + ) + == [] + ) - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_georadius_count(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('barcelona', values) - assert r.georadius('barcelona', 2.191, 41.433, 3000, count=1) == \ - [b'place1'] - assert r.georadius('barcelona', 2.191, 41.433, 3000, - count=1, any=True) == \ - [b'place2'] + r.geoadd("barcelona", values) + assert r.georadius("barcelona", 2.191, 41.433, 3000, count=1) == [b"place1"] + assert r.georadius("barcelona", 2.191, 41.433, 3000, count=1, any=True) == [ + b"place2" + ] - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_georadius_sort(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('barcelona', values) - assert r.georadius('barcelona', 2.191, 41.433, 3000, sort='ASC') == \ - [b'place1', b'place2'] - assert r.georadius('barcelona', 2.191, 41.433, 3000, sort='DESC') == \ - [b'place2', b'place1'] + r.geoadd("barcelona", values) + assert r.georadius("barcelona", 2.191, 41.433, 3000, sort="ASC") == [ + b"place1", + b"place2", + ] + assert r.georadius("barcelona", 2.191, 41.433, 3000, sort="DESC") == [ + b"place2", + b"place1", + ] @pytest.mark.onlynoncluster - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_georadius_store(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('barcelona', values) - r.georadius('barcelona', 2.191, 41.433, 1000, store='places_barcelona') - assert r.zrange('places_barcelona', 0, -1) == [b'place1'] + r.geoadd("barcelona", values) + r.georadius("barcelona", 2.191, 41.433, 1000, store="places_barcelona") + assert r.zrange("places_barcelona", 0, -1) == [b"place1"] @pytest.mark.onlynoncluster @skip_unless_arch_bits(64) - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_georadius_store_dist(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, 'place2') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) - r.geoadd('barcelona', values) - r.georadius('barcelona', 2.191, 41.433, 1000, - store_dist='places_barcelona') + r.geoadd("barcelona", values) + r.georadius("barcelona", 2.191, 41.433, 1000, store_dist="places_barcelona") # instead of save the geo score, the distance is saved. - assert r.zscore('places_barcelona', 'place1') == 88.05060698409301 + assert r.zscore("places_barcelona", "place1") == 88.05060698409301 @skip_unless_arch_bits(64) - @skip_if_server_version_lt('3.2.0') + @skip_if_server_version_lt("3.2.0") def test_georadiusmember(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, b'\x80place2') - - r.geoadd('barcelona', values) - assert r.georadiusbymember('barcelona', 'place1', 4000) == \ - [b'\x80place2', b'place1'] - assert r.georadiusbymember('barcelona', 'place1', 10) == [b'place1'] - - assert r.georadiusbymember('barcelona', 'place1', 4000, - withdist=True, withcoord=True, - withhash=True) == \ - [[b'\x80place2', 3067.4157, 3471609625421029, - (2.187376320362091, 41.40634178640635)], - [b'place1', 0.0, 3471609698139488, - (2.1909382939338684, 41.433790281840835)]] - - @skip_if_server_version_lt('6.2.0') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + b"\x80place2", + ) + + r.geoadd("barcelona", values) + assert r.georadiusbymember("barcelona", "place1", 4000) == [ + b"\x80place2", + b"place1", + ] + assert r.georadiusbymember("barcelona", "place1", 10) == [b"place1"] + + assert r.georadiusbymember( + "barcelona", "place1", 4000, withdist=True, withcoord=True, withhash=True + ) == [ + [ + b"\x80place2", + 3067.4157, + 3471609625421029, + (2.187376320362091, 41.40634178640635), + ], + [ + b"place1", + 0.0, + 3471609698139488, + (2.1909382939338684, 41.433790281840835), + ], + ] + + @skip_if_server_version_lt("6.2.0") def test_georadiusmember_count(self, r): - values = (2.1909389952632, 41.433791470673, 'place1') + \ - (2.1873744593677, 41.406342043777, b'\x80place2') - r.geoadd('barcelona', values) - assert r.georadiusbymember('barcelona', 'place1', 4000, - count=1, any=True) == \ - [b'\x80place2'] - - @skip_if_server_version_lt('5.0.0') + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + b"\x80place2", + ) + r.geoadd("barcelona", values) + assert r.georadiusbymember("barcelona", "place1", 4000, count=1, any=True) == [ + b"\x80place2" + ] + + @skip_if_server_version_lt("5.0.0") def test_xack(self, r): - stream = 'stream' - group = 'group' - consumer = 'consumer' + stream = "stream" + group = "group" + consumer = "consumer" # xack on a stream that doesn't exist - assert r.xack(stream, group, '0-0') == 0 + assert r.xack(stream, group, "0-0") == 0 - m1 = r.xadd(stream, {'one': 'one'}) - m2 = r.xadd(stream, {'two': 'two'}) - m3 = r.xadd(stream, {'three': 'three'}) + m1 = r.xadd(stream, {"one": "one"}) + m2 = r.xadd(stream, {"two": "two"}) + m3 = r.xadd(stream, {"three": "three"}) # xack on a group that doesn't exist assert r.xack(stream, group, m1) == 0 r.xgroup_create(stream, group, 0) - r.xreadgroup(group, consumer, streams={stream: '>'}) + r.xreadgroup(group, consumer, streams={stream: ">"}) # xack returns the number of ack'd elements assert r.xack(stream, group, m1) == 1 assert r.xack(stream, group, m2, m3) == 2 - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xadd(self, r): - stream = 'stream' - message_id = r.xadd(stream, {'foo': 'bar'}) - assert re.match(br'[0-9]+\-[0-9]+', message_id) + stream = "stream" + message_id = r.xadd(stream, {"foo": "bar"}) + assert re.match(br"[0-9]+\-[0-9]+", message_id) # explicit message id - message_id = b'9999999999999999999-0' - assert message_id == r.xadd(stream, {'foo': 'bar'}, id=message_id) + message_id = b"9999999999999999999-0" + assert message_id == r.xadd(stream, {"foo": "bar"}, id=message_id) # with maxlen, the list evicts the first message - r.xadd(stream, {'foo': 'bar'}, maxlen=2, approximate=False) + r.xadd(stream, {"foo": "bar"}, maxlen=2, approximate=False) assert r.xlen(stream) == 2 - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_xadd_nomkstream(self, r): # nomkstream option - stream = 'stream' - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'some': 'other'}, nomkstream=False) + stream = "stream" + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"some": "other"}, nomkstream=False) assert r.xlen(stream) == 2 - r.xadd(stream, {'some': 'other'}, nomkstream=True) + r.xadd(stream, {"some": "other"}, nomkstream=True) assert r.xlen(stream) == 3 - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_xadd_minlen_and_limit(self, r): - stream = 'stream' + stream = "stream" - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) # Future self: No limits without approximate, according to the api with pytest.raises(redis.ResponseError): - assert r.xadd(stream, {'foo': 'bar'}, maxlen=3, - approximate=False, limit=2) + assert r.xadd(stream, {"foo": "bar"}, maxlen=3, approximate=False, limit=2) # limit can not be provided without maxlen or minid with pytest.raises(redis.ResponseError): - assert r.xadd(stream, {'foo': 'bar'}, limit=2) + assert r.xadd(stream, {"foo": "bar"}, limit=2) # maxlen with a limit - assert r.xadd(stream, {'foo': 'bar'}, maxlen=3, - approximate=True, limit=2) + assert r.xadd(stream, {"foo": "bar"}, maxlen=3, approximate=True, limit=2) r.delete(stream) # maxlen and minid can not be provided together with pytest.raises(redis.DataError): - assert r.xadd(stream, {'foo': 'bar'}, maxlen=3, - minid="sometestvalue") + assert r.xadd(stream, {"foo": "bar"}, maxlen=3, minid="sometestvalue") # minid with a limit - m1 = r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - assert r.xadd(stream, {'foo': 'bar'}, approximate=True, - minid=m1, limit=3) + m1 = r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + assert r.xadd(stream, {"foo": "bar"}, approximate=True, minid=m1, limit=3) # pure minid - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - m4 = r.xadd(stream, {'foo': 'bar'}) - assert r.xadd(stream, {'foo': 'bar'}, approximate=False, minid=m4) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + m4 = r.xadd(stream, {"foo": "bar"}) + assert r.xadd(stream, {"foo": "bar"}, approximate=False, minid=m4) # minid approximate - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - m3 = r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - assert r.xadd(stream, {'foo': 'bar'}, approximate=True, minid=m3) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + m3 = r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + assert r.xadd(stream, {"foo": "bar"}, approximate=True, minid=m3) - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_xautoclaim(self, r): - stream = 'stream' - group = 'group' - consumer1 = 'consumer1' - consumer2 = 'consumer2' + stream = "stream" + group = "group" + consumer1 = "consumer1" + consumer2 = "consumer2" - message_id1 = r.xadd(stream, {'john': 'wick'}) - message_id2 = r.xadd(stream, {'johny': 'deff'}) + message_id1 = r.xadd(stream, {"john": "wick"}) + message_id2 = r.xadd(stream, {"johny": "deff"}) message = get_stream_message(r, stream, message_id1) r.xgroup_create(stream, group, 0) @@ -3084,70 +3348,78 @@ def test_xautoclaim(self, r): assert response == [] # read the group as consumer1 to initially claim the messages - r.xreadgroup(group, consumer1, streams={stream: '>'}) + r.xreadgroup(group, consumer1, streams={stream: ">"}) # claim one message as consumer2 - response = r.xautoclaim(stream, group, consumer2, - min_idle_time=0, count=1) + response = r.xautoclaim(stream, group, consumer2, min_idle_time=0, count=1) assert response == [message] # reclaim the messages as consumer1, but use the justid argument # which only returns message ids - assert r.xautoclaim(stream, group, consumer1, min_idle_time=0, - start_id=0, justid=True) == \ - [message_id1, message_id2] - assert r.xautoclaim(stream, group, consumer1, min_idle_time=0, - start_id=message_id2, justid=True) == \ - [message_id2] - - @skip_if_server_version_lt('6.2.0') + assert r.xautoclaim( + stream, group, consumer1, min_idle_time=0, start_id=0, justid=True + ) == [message_id1, message_id2] + assert r.xautoclaim( + stream, group, consumer1, min_idle_time=0, start_id=message_id2, justid=True + ) == [message_id2] + + @skip_if_server_version_lt("6.2.0") def test_xautoclaim_negative(self, r): - stream = 'stream' - group = 'group' - consumer = 'consumer' + stream = "stream" + group = "group" + consumer = "consumer" with pytest.raises(redis.DataError): r.xautoclaim(stream, group, consumer, min_idle_time=-1) with pytest.raises(ValueError): r.xautoclaim(stream, group, consumer, min_idle_time="wrong") with pytest.raises(redis.DataError): - r.xautoclaim(stream, group, consumer, min_idle_time=0, - count=-1) + r.xautoclaim(stream, group, consumer, min_idle_time=0, count=-1) - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xclaim(self, r): - stream = 'stream' - group = 'group' - consumer1 = 'consumer1' - consumer2 = 'consumer2' - message_id = r.xadd(stream, {'john': 'wick'}) + stream = "stream" + group = "group" + consumer1 = "consumer1" + consumer2 = "consumer2" + message_id = r.xadd(stream, {"john": "wick"}) message = get_stream_message(r, stream, message_id) r.xgroup_create(stream, group, 0) # trying to claim a message that isn't already pending doesn't # do anything - response = r.xclaim(stream, group, consumer2, - min_idle_time=0, message_ids=(message_id,)) + response = r.xclaim( + stream, group, consumer2, min_idle_time=0, message_ids=(message_id,) + ) assert response == [] # read the group as consumer1 to initially claim the messages - r.xreadgroup(group, consumer1, streams={stream: '>'}) + r.xreadgroup(group, consumer1, streams={stream: ">"}) # claim the message as consumer2 - response = r.xclaim(stream, group, consumer2, - min_idle_time=0, message_ids=(message_id,)) + response = r.xclaim( + stream, group, consumer2, min_idle_time=0, message_ids=(message_id,) + ) assert response[0] == message # reclaim the message as consumer1, but use the justid argument # which only returns message ids - assert r.xclaim(stream, group, consumer1, - min_idle_time=0, message_ids=(message_id,), - justid=True) == [message_id] + assert ( + r.xclaim( + stream, + group, + consumer1, + min_idle_time=0, + message_ids=(message_id,), + justid=True, + ) + == [message_id] + ) - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xclaim_trimmed(self, r): # xclaim should not raise an exception if the item is not there - stream = 'stream' - group = 'group' + stream = "stream" + group = "group" r.xgroup_create(stream, group, id="$", mkstream=True) @@ -3156,57 +3428,59 @@ def test_xclaim_trimmed(self, r): sid2 = r.xadd(stream, {"item": 0}) # read them from consumer1 - r.xreadgroup(group, 'consumer1', {stream: ">"}) + r.xreadgroup(group, "consumer1", {stream: ">"}) # add a 3rd and trim the stream down to 2 items r.xadd(stream, {"item": 3}, maxlen=2, approximate=False) # xclaim them from consumer2 # the item that is still in the stream should be returned - item = r.xclaim(stream, group, 'consumer2', 0, [sid1, sid2]) + item = r.xclaim(stream, group, "consumer2", 0, [sid1, sid2]) assert len(item) == 2 assert item[0] == (None, None) assert item[1][0] == sid2 - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xdel(self, r): - stream = 'stream' + stream = "stream" # deleting from an empty stream doesn't do anything assert r.xdel(stream, 1) == 0 - m1 = r.xadd(stream, {'foo': 'bar'}) - m2 = r.xadd(stream, {'foo': 'bar'}) - m3 = r.xadd(stream, {'foo': 'bar'}) + m1 = r.xadd(stream, {"foo": "bar"}) + m2 = r.xadd(stream, {"foo": "bar"}) + m3 = r.xadd(stream, {"foo": "bar"}) # xdel returns the number of deleted elements assert r.xdel(stream, m1) == 1 assert r.xdel(stream, m2, m3) == 2 - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xgroup_create(self, r): # tests xgroup_create and xinfo_groups - stream = 'stream' - group = 'group' - r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + group = "group" + r.xadd(stream, {"foo": "bar"}) # no group is setup yet, no info to obtain assert r.xinfo_groups(stream) == [] assert r.xgroup_create(stream, group, 0) - expected = [{ - 'name': group.encode(), - 'consumers': 0, - 'pending': 0, - 'last-delivered-id': b'0-0' - }] + expected = [ + { + "name": group.encode(), + "consumers": 0, + "pending": 0, + "last-delivered-id": b"0-0", + } + ] assert r.xinfo_groups(stream) == expected - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xgroup_create_mkstream(self, r): # tests xgroup_create and xinfo_groups - stream = 'stream' - group = 'group' + stream = "stream" + group = "group" # an error is raised if a group is created on a stream that # doesn't already exist @@ -3216,53 +3490,55 @@ def test_xgroup_create_mkstream(self, r): # however, with mkstream=True, the underlying stream is created # automatically assert r.xgroup_create(stream, group, 0, mkstream=True) - expected = [{ - 'name': group.encode(), - 'consumers': 0, - 'pending': 0, - 'last-delivered-id': b'0-0' - }] + expected = [ + { + "name": group.encode(), + "consumers": 0, + "pending": 0, + "last-delivered-id": b"0-0", + } + ] assert r.xinfo_groups(stream) == expected - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xgroup_delconsumer(self, r): - stream = 'stream' - group = 'group' - consumer = 'consumer' - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + group = "group" + consumer = "consumer" + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) r.xgroup_create(stream, group, 0) # a consumer that hasn't yet read any messages doesn't do anything assert r.xgroup_delconsumer(stream, group, consumer) == 0 # read all messages from the group - r.xreadgroup(group, consumer, streams={stream: '>'}) + r.xreadgroup(group, consumer, streams={stream: ">"}) # deleting the consumer should return 2 pending messages assert r.xgroup_delconsumer(stream, group, consumer) == 2 - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_xgroup_createconsumer(self, r): - stream = 'stream' - group = 'group' - consumer = 'consumer' - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + group = "group" + consumer = "consumer" + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) r.xgroup_create(stream, group, 0) assert r.xgroup_createconsumer(stream, group, consumer) == 1 # read all messages from the group - r.xreadgroup(group, consumer, streams={stream: '>'}) + r.xreadgroup(group, consumer, streams={stream: ">"}) # deleting the consumer should return 2 pending messages assert r.xgroup_delconsumer(stream, group, consumer) == 2 - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xgroup_destroy(self, r): - stream = 'stream' - group = 'group' - r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + group = "group" + r.xadd(stream, {"foo": "bar"}) # destroying a nonexistent group returns False assert not r.xgroup_destroy(stream, group) @@ -3270,198 +3546,189 @@ def test_xgroup_destroy(self, r): r.xgroup_create(stream, group, 0) assert r.xgroup_destroy(stream, group) - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xgroup_setid(self, r): - stream = 'stream' - group = 'group' - message_id = r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + group = "group" + message_id = r.xadd(stream, {"foo": "bar"}) r.xgroup_create(stream, group, 0) # advance the last_delivered_id to the message_id r.xgroup_setid(stream, group, message_id) - expected = [{ - 'name': group.encode(), - 'consumers': 0, - 'pending': 0, - 'last-delivered-id': message_id - }] + expected = [ + { + "name": group.encode(), + "consumers": 0, + "pending": 0, + "last-delivered-id": message_id, + } + ] assert r.xinfo_groups(stream) == expected - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xinfo_consumers(self, r): - stream = 'stream' - group = 'group' - consumer1 = 'consumer1' - consumer2 = 'consumer2' - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + group = "group" + consumer1 = "consumer1" + consumer2 = "consumer2" + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) r.xgroup_create(stream, group, 0) - r.xreadgroup(group, consumer1, streams={stream: '>'}, count=1) - r.xreadgroup(group, consumer2, streams={stream: '>'}) + r.xreadgroup(group, consumer1, streams={stream: ">"}, count=1) + r.xreadgroup(group, consumer2, streams={stream: ">"}) info = r.xinfo_consumers(stream, group) assert len(info) == 2 expected = [ - {'name': consumer1.encode(), 'pending': 1}, - {'name': consumer2.encode(), 'pending': 2}, + {"name": consumer1.encode(), "pending": 1}, + {"name": consumer2.encode(), "pending": 2}, ] # we can't determine the idle time, so just make sure it's an int - assert isinstance(info[0].pop('idle'), int) - assert isinstance(info[1].pop('idle'), int) + assert isinstance(info[0].pop("idle"), int) + assert isinstance(info[1].pop("idle"), int) assert info == expected - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xinfo_stream(self, r): - stream = 'stream' - m1 = r.xadd(stream, {'foo': 'bar'}) - m2 = r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + m1 = r.xadd(stream, {"foo": "bar"}) + m2 = r.xadd(stream, {"foo": "bar"}) info = r.xinfo_stream(stream) - assert info['length'] == 2 - assert info['first-entry'] == get_stream_message(r, stream, m1) - assert info['last-entry'] == get_stream_message(r, stream, m2) + assert info["length"] == 2 + assert info["first-entry"] == get_stream_message(r, stream, m1) + assert info["last-entry"] == get_stream_message(r, stream, m2) - @skip_if_server_version_lt('6.0.0') + @skip_if_server_version_lt("6.0.0") def test_xinfo_stream_full(self, r): - stream = 'stream' - group = 'group' - m1 = r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + group = "group" + m1 = r.xadd(stream, {"foo": "bar"}) r.xgroup_create(stream, group, 0) info = r.xinfo_stream(stream, full=True) - assert info['length'] == 1 - assert m1 in info['entries'] - assert len(info['groups']) == 1 + assert info["length"] == 1 + assert m1 in info["entries"] + assert len(info["groups"]) == 1 - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xlen(self, r): - stream = 'stream' + stream = "stream" assert r.xlen(stream) == 0 - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) assert r.xlen(stream) == 2 - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xpending(self, r): - stream = 'stream' - group = 'group' - consumer1 = 'consumer1' - consumer2 = 'consumer2' - m1 = r.xadd(stream, {'foo': 'bar'}) - m2 = r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + group = "group" + consumer1 = "consumer1" + consumer2 = "consumer2" + m1 = r.xadd(stream, {"foo": "bar"}) + m2 = r.xadd(stream, {"foo": "bar"}) r.xgroup_create(stream, group, 0) # xpending on a group that has no consumers yet - expected = { - 'pending': 0, - 'min': None, - 'max': None, - 'consumers': [] - } + expected = {"pending": 0, "min": None, "max": None, "consumers": []} assert r.xpending(stream, group) == expected # read 1 message from the group with each consumer - r.xreadgroup(group, consumer1, streams={stream: '>'}, count=1) - r.xreadgroup(group, consumer2, streams={stream: '>'}, count=1) + r.xreadgroup(group, consumer1, streams={stream: ">"}, count=1) + r.xreadgroup(group, consumer2, streams={stream: ">"}, count=1) expected = { - 'pending': 2, - 'min': m1, - 'max': m2, - 'consumers': [ - {'name': consumer1.encode(), 'pending': 1}, - {'name': consumer2.encode(), 'pending': 1}, - ] + "pending": 2, + "min": m1, + "max": m2, + "consumers": [ + {"name": consumer1.encode(), "pending": 1}, + {"name": consumer2.encode(), "pending": 1}, + ], } assert r.xpending(stream, group) == expected - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xpending_range(self, r): - stream = 'stream' - group = 'group' - consumer1 = 'consumer1' - consumer2 = 'consumer2' - m1 = r.xadd(stream, {'foo': 'bar'}) - m2 = r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + group = "group" + consumer1 = "consumer1" + consumer2 = "consumer2" + m1 = r.xadd(stream, {"foo": "bar"}) + m2 = r.xadd(stream, {"foo": "bar"}) r.xgroup_create(stream, group, 0) # xpending range on a group that has no consumers yet - assert r.xpending_range(stream, group, min='-', max='+', count=5) == [] + assert r.xpending_range(stream, group, min="-", max="+", count=5) == [] # read 1 message from the group with each consumer - r.xreadgroup(group, consumer1, streams={stream: '>'}, count=1) - r.xreadgroup(group, consumer2, streams={stream: '>'}, count=1) + r.xreadgroup(group, consumer1, streams={stream: ">"}, count=1) + r.xreadgroup(group, consumer2, streams={stream: ">"}, count=1) - response = r.xpending_range(stream, group, - min='-', max='+', count=5) + response = r.xpending_range(stream, group, min="-", max="+", count=5) assert len(response) == 2 - assert response[0]['message_id'] == m1 - assert response[0]['consumer'] == consumer1.encode() - assert response[1]['message_id'] == m2 - assert response[1]['consumer'] == consumer2.encode() + assert response[0]["message_id"] == m1 + assert response[0]["consumer"] == consumer1.encode() + assert response[1]["message_id"] == m2 + assert response[1]["consumer"] == consumer2.encode() # test with consumer name - response = r.xpending_range(stream, group, - min='-', max='+', count=5, - consumername=consumer1) - assert response[0]['message_id'] == m1 - assert response[0]['consumer'] == consumer1.encode() + response = r.xpending_range( + stream, group, min="-", max="+", count=5, consumername=consumer1 + ) + assert response[0]["message_id"] == m1 + assert response[0]["consumer"] == consumer1.encode() - @skip_if_server_version_lt('6.2.0') + @skip_if_server_version_lt("6.2.0") def test_xpending_range_idle(self, r): - stream = 'stream' - group = 'group' - consumer1 = 'consumer1' - consumer2 = 'consumer2' - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + group = "group" + consumer1 = "consumer1" + consumer2 = "consumer2" + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) r.xgroup_create(stream, group, 0) # read 1 message from the group with each consumer - r.xreadgroup(group, consumer1, streams={stream: '>'}, count=1) - r.xreadgroup(group, consumer2, streams={stream: '>'}, count=1) + r.xreadgroup(group, consumer1, streams={stream: ">"}, count=1) + r.xreadgroup(group, consumer2, streams={stream: ">"}, count=1) - response = r.xpending_range(stream, group, - min='-', max='+', count=5) + response = r.xpending_range(stream, group, min="-", max="+", count=5) assert len(response) == 2 - response = r.xpending_range(stream, group, - min='-', max='+', count=5, idle=1000) + response = r.xpending_range(stream, group, min="-", max="+", count=5, idle=1000) assert len(response) == 0 def test_xpending_range_negative(self, r): - stream = 'stream' - group = 'group' + stream = "stream" + group = "group" with pytest.raises(redis.DataError): - r.xpending_range(stream, group, min='-', max='+', count=None) + r.xpending_range(stream, group, min="-", max="+", count=None) with pytest.raises(ValueError): - r.xpending_range(stream, group, min='-', max='+', count="one") + r.xpending_range(stream, group, min="-", max="+", count="one") with pytest.raises(redis.DataError): - r.xpending_range(stream, group, min='-', max='+', count=-1) + r.xpending_range(stream, group, min="-", max="+", count=-1) with pytest.raises(ValueError): - r.xpending_range(stream, group, min='-', max='+', count=5, - idle="one") + r.xpending_range(stream, group, min="-", max="+", count=5, idle="one") with pytest.raises(redis.exceptions.ResponseError): - r.xpending_range(stream, group, min='-', max='+', count=5, - idle=1.5) + r.xpending_range(stream, group, min="-", max="+", count=5, idle=1.5) with pytest.raises(redis.DataError): - r.xpending_range(stream, group, min='-', max='+', count=5, - idle=-1) + r.xpending_range(stream, group, min="-", max="+", count=5, idle=-1) with pytest.raises(redis.DataError): - r.xpending_range(stream, group, min=None, max=None, count=None, - idle=0) + r.xpending_range(stream, group, min=None, max=None, count=None, idle=0) with pytest.raises(redis.DataError): - r.xpending_range(stream, group, min=None, max=None, count=None, - consumername=0) + r.xpending_range( + stream, group, min=None, max=None, count=None, consumername=0 + ) - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xrange(self, r): - stream = 'stream' - m1 = r.xadd(stream, {'foo': 'bar'}) - m2 = r.xadd(stream, {'foo': 'bar'}) - m3 = r.xadd(stream, {'foo': 'bar'}) - m4 = r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + m1 = r.xadd(stream, {"foo": "bar"}) + m2 = r.xadd(stream, {"foo": "bar"}) + m3 = r.xadd(stream, {"foo": "bar"}) + m4 = r.xadd(stream, {"foo": "bar"}) def get_ids(results): return [result[0] for result in results] @@ -3478,11 +3745,11 @@ def get_ids(results): results = r.xrange(stream, max=m2, count=1) assert get_ids(results) == [m1] - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xread(self, r): - stream = 'stream' - m1 = r.xadd(stream, {'foo': 'bar'}) - m2 = r.xadd(stream, {'bing': 'baz'}) + stream = "stream" + m1 = r.xadd(stream, {"foo": "bar"}) + m2 = r.xadd(stream, {"bing": "baz"}) expected = [ [ @@ -3490,7 +3757,7 @@ def test_xread(self, r): [ get_stream_message(r, stream, m1), get_stream_message(r, stream, m2), - ] + ], ] ] # xread starting at 0 returns both messages @@ -3501,7 +3768,7 @@ def test_xread(self, r): stream.encode(), [ get_stream_message(r, stream, m1), - ] + ], ] ] # xread starting at 0 and count=1 returns only the first message @@ -3512,7 +3779,7 @@ def test_xread(self, r): stream.encode(), [ get_stream_message(r, stream, m2), - ] + ], ] ] # xread starting at m1 returns only the second message @@ -3521,13 +3788,13 @@ def test_xread(self, r): # xread starting at the last message returns an empty list assert r.xread(streams={stream: m2}) == [] - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xreadgroup(self, r): - stream = 'stream' - group = 'group' - consumer = 'consumer' - m1 = r.xadd(stream, {'foo': 'bar'}) - m2 = r.xadd(stream, {'bing': 'baz'}) + stream = "stream" + group = "group" + consumer = "consumer" + m1 = r.xadd(stream, {"foo": "bar"}) + m2 = r.xadd(stream, {"bing": "baz"}) r.xgroup_create(stream, group, 0) expected = [ @@ -3536,11 +3803,11 @@ def test_xreadgroup(self, r): [ get_stream_message(r, stream, m1), get_stream_message(r, stream, m2), - ] + ], ] ] # xread starting at 0 returns both messages - assert r.xreadgroup(group, consumer, streams={stream: '>'}) == expected + assert r.xreadgroup(group, consumer, streams={stream: ">"}) == expected r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, 0) @@ -3550,34 +3817,34 @@ def test_xreadgroup(self, r): stream.encode(), [ get_stream_message(r, stream, m1), - ] + ], ] ] # xread with count=1 returns only the first message - assert r.xreadgroup(group, consumer, - streams={stream: '>'}, count=1) == expected + assert r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) == expected r.xgroup_destroy(stream, group) # create the group using $ as the last id meaning subsequent reads # will only find messages added after this - r.xgroup_create(stream, group, '$') + r.xgroup_create(stream, group, "$") expected = [] # xread starting after the last message returns an empty message list - assert r.xreadgroup(group, consumer, streams={stream: '>'}) == expected + assert r.xreadgroup(group, consumer, streams={stream: ">"}) == expected # xreadgroup with noack does not have any items in the PEL r.xgroup_destroy(stream, group) - r.xgroup_create(stream, group, '0') - assert len(r.xreadgroup(group, consumer, streams={stream: '>'}, - noack=True)[0][1]) == 2 + r.xgroup_create(stream, group, "0") + assert ( + len(r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True)[0][1]) + == 2 + ) # now there should be nothing pending - assert len(r.xreadgroup(group, consumer, - streams={stream: '0'})[0][1]) == 0 + assert len(r.xreadgroup(group, consumer, streams={stream: "0"})[0][1]) == 0 r.xgroup_destroy(stream, group) - r.xgroup_create(stream, group, '0') + r.xgroup_create(stream, group, "0") # delete all the messages in the stream expected = [ [ @@ -3585,20 +3852,20 @@ def test_xreadgroup(self, r): [ (m1, {}), (m2, {}), - ] + ], ] ] - r.xreadgroup(group, consumer, streams={stream: '>'}) + r.xreadgroup(group, consumer, streams={stream: ">"}) r.xtrim(stream, 0) - assert r.xreadgroup(group, consumer, streams={stream: '0'}) == expected + assert r.xreadgroup(group, consumer, streams={stream: "0"}) == expected - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xrevrange(self, r): - stream = 'stream' - m1 = r.xadd(stream, {'foo': 'bar'}) - m2 = r.xadd(stream, {'foo': 'bar'}) - m3 = r.xadd(stream, {'foo': 'bar'}) - m4 = r.xadd(stream, {'foo': 'bar'}) + stream = "stream" + m1 = r.xadd(stream, {"foo": "bar"}) + m2 = r.xadd(stream, {"foo": "bar"}) + m3 = r.xadd(stream, {"foo": "bar"}) + m4 = r.xadd(stream, {"foo": "bar"}) def get_ids(results): return [result[0] for result in results] @@ -3615,17 +3882,17 @@ def get_ids(results): results = r.xrevrange(stream, min=m2, count=1) assert get_ids(results) == [m4] - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_xtrim(self, r): - stream = 'stream' + stream = "stream" # trimming an empty key doesn't do anything assert r.xtrim(stream, 1000) == 0 - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) # trimming an amount large than the number of messages # doesn't do anything @@ -3634,14 +3901,14 @@ def test_xtrim(self, r): # 1 message is trimmed assert r.xtrim(stream, 3, approximate=False) == 1 - @skip_if_server_version_lt('6.2.4') + @skip_if_server_version_lt("6.2.4") def test_xtrim_minlen_and_length_args(self, r): - stream = 'stream' + stream = "stream" - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) # Future self: No limits without approximate, according to the api with pytest.raises(redis.ResponseError): @@ -3655,99 +3922,105 @@ def test_xtrim_minlen_and_length_args(self, r): assert r.xtrim(stream, maxlen=3, minid="sometestvalue") # minid with a limit - m1 = r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) + m1 = r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) assert r.xtrim(stream, None, approximate=True, minid=m1, limit=3) == 0 # pure minid - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - m4 = r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + m4 = r.xadd(stream, {"foo": "bar"}) assert r.xtrim(stream, None, approximate=False, minid=m4) == 7 # minid approximate - r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) - m3 = r.xadd(stream, {'foo': 'bar'}) - r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + m3 = r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) assert r.xtrim(stream, None, approximate=True, minid=m3) == 0 def test_bitfield_operations(self, r): # comments show affected bits - bf = r.bitfield('a') - resp = (bf - .set('u8', 8, 255) # 00000000 11111111 - .get('u8', 0) # 00000000 - .get('u4', 8) # 1111 - .get('u4', 12) # 1111 - .get('u4', 13) # 111 0 - .execute()) + bf = r.bitfield("a") + resp = ( + bf.set("u8", 8, 255) # 00000000 11111111 + .get("u8", 0) # 00000000 + .get("u4", 8) # 1111 + .get("u4", 12) # 1111 + .get("u4", 13) # 111 0 + .execute() + ) assert resp == [0, 0, 15, 15, 14] # .set() returns the previous value... - resp = (bf - .set('u8', 4, 1) # 0000 0001 - .get('u16', 0) # 00000000 00011111 - .set('u16', 0, 0) # 00000000 00000000 - .execute()) + resp = ( + bf.set("u8", 4, 1) # 0000 0001 + .get("u16", 0) # 00000000 00011111 + .set("u16", 0, 0) # 00000000 00000000 + .execute() + ) assert resp == [15, 31, 31] # incrby adds to the value - resp = (bf - .incrby('u8', 8, 254) # 00000000 11111110 - .incrby('u8', 8, 1) # 00000000 11111111 - .get('u16', 0) # 00000000 11111111 - .execute()) + resp = ( + bf.incrby("u8", 8, 254) # 00000000 11111110 + .incrby("u8", 8, 1) # 00000000 11111111 + .get("u16", 0) # 00000000 11111111 + .execute() + ) assert resp == [254, 255, 255] # Verify overflow protection works as a method: - r.delete('a') - resp = (bf - .set('u8', 8, 254) # 00000000 11111110 - .overflow('fail') - .incrby('u8', 8, 2) # incrby 2 would overflow, None returned - .incrby('u8', 8, 1) # 00000000 11111111 - .incrby('u8', 8, 1) # incrby 1 would overflow, None returned - .get('u16', 0) # 00000000 11111111 - .execute()) + r.delete("a") + resp = ( + bf.set("u8", 8, 254) # 00000000 11111110 + .overflow("fail") + .incrby("u8", 8, 2) # incrby 2 would overflow, None returned + .incrby("u8", 8, 1) # 00000000 11111111 + .incrby("u8", 8, 1) # incrby 1 would overflow, None returned + .get("u16", 0) # 00000000 11111111 + .execute() + ) assert resp == [0, None, 255, None, 255] # Verify overflow protection works as arg to incrby: - r.delete('a') - resp = (bf - .set('u8', 8, 255) # 00000000 11111111 - .incrby('u8', 8, 1) # 00000000 00000000 wrap default - .set('u8', 8, 255) # 00000000 11111111 - .incrby('u8', 8, 1, 'FAIL') # 00000000 11111111 fail - .incrby('u8', 8, 1) # 00000000 11111111 still fail - .get('u16', 0) # 00000000 11111111 - .execute()) + r.delete("a") + resp = ( + bf.set("u8", 8, 255) # 00000000 11111111 + .incrby("u8", 8, 1) # 00000000 00000000 wrap default + .set("u8", 8, 255) # 00000000 11111111 + .incrby("u8", 8, 1, "FAIL") # 00000000 11111111 fail + .incrby("u8", 8, 1) # 00000000 11111111 still fail + .get("u16", 0) # 00000000 11111111 + .execute() + ) assert resp == [0, 0, 0, None, None, 255] # test default default_overflow - r.delete('a') - bf = r.bitfield('a', default_overflow='FAIL') - resp = (bf - .set('u8', 8, 255) # 00000000 11111111 - .incrby('u8', 8, 1) # 00000000 11111111 fail default - .get('u16', 0) # 00000000 11111111 - .execute()) + r.delete("a") + bf = r.bitfield("a", default_overflow="FAIL") + resp = ( + bf.set("u8", 8, 255) # 00000000 11111111 + .incrby("u8", 8, 1) # 00000000 11111111 fail default + .get("u16", 0) # 00000000 11111111 + .execute() + ) assert resp == [0, None, 255] - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_memory_help(self, r): with pytest.raises(NotImplementedError): r.memory_help() - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_memory_doctor(self, r): with pytest.raises(NotImplementedError): r.memory_doctor() - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_memory_malloc_stats(self, r): if skip_if_redis_enterprise(None).args[0] is True: with pytest.raises(redis.exceptions.ResponseError): @@ -3756,11 +4029,11 @@ def test_memory_malloc_stats(self, r): assert r.memory_malloc_stats() - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_memory_stats(self, r): # put a key into the current db to make sure that "db." # has data - r.set('foo', 'bar') + r.set("foo", "bar") if skip_if_redis_enterprise(None).args[0] is True: with pytest.raises(redis.exceptions.ResponseError): @@ -3770,104 +4043,113 @@ def test_memory_stats(self, r): stats = r.memory_stats() assert isinstance(stats, dict) for key, value in stats.items(): - if key.startswith('db.'): + if key.startswith("db."): assert isinstance(value, dict) - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") def test_memory_usage(self, r): - r.set('foo', 'bar') - assert isinstance(r.memory_usage('foo'), int) + r.set("foo", "bar") + assert isinstance(r.memory_usage("foo"), int) @pytest.mark.onlynoncluster - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") @skip_if_redis_enterprise def test_module_list(self, r): assert isinstance(r.module_list(), list) for x in r.module_list(): assert isinstance(x, dict) - @skip_if_server_version_lt('2.8.13') + @skip_if_server_version_lt("2.8.13") def test_command_count(self, r): res = r.command_count() assert isinstance(res, int) assert res >= 100 @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.13') + @skip_if_server_version_lt("2.8.13") def test_command_getkeys(self, r): - res = r.command_getkeys('MSET', 'a', 'b', 'c', 'd', 'e', 'f') - assert res == ['a', 'c', 'e'] - res = r.command_getkeys('EVAL', '"not consulted"', - '3', 'key1', 'key2', 'key3', - 'arg1', 'arg2', 'arg3', 'argN') - assert res == ['key1', 'key2', 'key3'] - - @skip_if_server_version_lt('2.8.13') + res = r.command_getkeys("MSET", "a", "b", "c", "d", "e", "f") + assert res == ["a", "c", "e"] + res = r.command_getkeys( + "EVAL", + '"not consulted"', + "3", + "key1", + "key2", + "key3", + "arg1", + "arg2", + "arg3", + "argN", + ) + assert res == ["key1", "key2", "key3"] + + @skip_if_server_version_lt("2.8.13") def test_command(self, r): res = r.command() assert len(res) >= 100 cmds = list(res.keys()) - assert 'set' in cmds - assert 'get' in cmds + assert "set" in cmds + assert "get" in cmds @pytest.mark.onlynoncluster - @skip_if_server_version_lt('4.0.0') + @skip_if_server_version_lt("4.0.0") @skip_if_redis_enterprise def test_module(self, r): with pytest.raises(redis.exceptions.ModuleError) as excinfo: - r.module_load('/some/fake/path') + r.module_load("/some/fake/path") assert "Error loading the extension." in str(excinfo.value) with pytest.raises(redis.exceptions.ModuleError) as excinfo: - r.module_load('/some/fake/path', 'arg1', 'arg2', 'arg3', 'arg4') + r.module_load("/some/fake/path", "arg1", "arg2", "arg3", "arg4") assert "Error loading the extension." in str(excinfo.value) - @skip_if_server_version_lt('2.6.0') + @skip_if_server_version_lt("2.6.0") def test_restore(self, r): # standard restore - key = 'foo' - r.set(key, 'bar') + key = "foo" + r.set(key, "bar") dumpdata = r.dump(key) r.delete(key) assert r.restore(key, 0, dumpdata) - assert r.get(key) == b'bar' + assert r.get(key) == b"bar" # overwrite restore with pytest.raises(redis.exceptions.ResponseError): assert r.restore(key, 0, dumpdata) - r.set(key, 'a new value!') + r.set(key, "a new value!") assert r.restore(key, 0, dumpdata, replace=True) - assert r.get(key) == b'bar' + assert r.get(key) == b"bar" # ttl check - key2 = 'another' - r.set(key2, 'blee!') + key2 = "another" + r.set(key2, "blee!") dumpdata = r.dump(key2) r.delete(key2) assert r.restore(key2, 0, dumpdata) assert r.ttl(key2) == -1 - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_restore_idletime(self, r): - key = 'yayakey' - r.set(key, 'blee!') + key = "yayakey" + r.set(key, "blee!") dumpdata = r.dump(key) r.delete(key) assert r.restore(key, 0, dumpdata, idletime=5) - assert r.get(key) == b'blee!' + assert r.get(key) == b"blee!" - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") def test_restore_frequency(self, r): - key = 'yayakey' - r.set(key, 'blee!') + key = "yayakey" + r.set(key, "blee!") dumpdata = r.dump(key) r.delete(key) assert r.restore(key, 0, dumpdata, frequency=5) - assert r.get(key) == b'blee!' + assert r.get(key) == b"blee!" @pytest.mark.onlynoncluster - @skip_if_server_version_lt('5.0.0') + @skip_if_server_version_lt("5.0.0") @skip_if_redis_enterprise def test_replicaof(self, r): with pytest.raises(redis.ResponseError): @@ -3877,36 +4159,38 @@ def test_replicaof(self, r): @pytest.mark.onlynoncluster class TestBinarySave: - def test_binary_get_set(self, r): - assert r.set(' foo bar ', '123') - assert r.get(' foo bar ') == b'123' + assert r.set(" foo bar ", "123") + assert r.get(" foo bar ") == b"123" - assert r.set(' foo\r\nbar\r\n ', '456') - assert r.get(' foo\r\nbar\r\n ') == b'456' + assert r.set(" foo\r\nbar\r\n ", "456") + assert r.get(" foo\r\nbar\r\n ") == b"456" - assert r.set(' \r\n\t\x07\x13 ', '789') - assert r.get(' \r\n\t\x07\x13 ') == b'789' + assert r.set(" \r\n\t\x07\x13 ", "789") + assert r.get(" \r\n\t\x07\x13 ") == b"789" - assert sorted(r.keys('*')) == \ - [b' \r\n\t\x07\x13 ', b' foo\r\nbar\r\n ', b' foo bar '] + assert sorted(r.keys("*")) == [ + b" \r\n\t\x07\x13 ", + b" foo\r\nbar\r\n ", + b" foo bar ", + ] - assert r.delete(' foo bar ') - assert r.delete(' foo\r\nbar\r\n ') - assert r.delete(' \r\n\t\x07\x13 ') + assert r.delete(" foo bar ") + assert r.delete(" foo\r\nbar\r\n ") + assert r.delete(" \r\n\t\x07\x13 ") def test_binary_lists(self, r): mapping = { - b'foo bar': [b'1', b'2', b'3'], - b'foo\r\nbar\r\n': [b'4', b'5', b'6'], - b'foo\tbar\x07': [b'7', b'8', b'9'], + b"foo bar": [b"1", b"2", b"3"], + b"foo\r\nbar\r\n": [b"4", b"5", b"6"], + b"foo\tbar\x07": [b"7", b"8", b"9"], } # fill in lists for key, value in mapping.items(): r.rpush(key, *value) # check that KEYS returns all the keys as they are - assert sorted(r.keys('*')) == sorted(mapping.keys()) + assert sorted(r.keys("*")) == sorted(mapping.keys()) # check that it is possible to get list content by key name for key, value in mapping.items(): @@ -3917,42 +4201,44 @@ def test_22_info(self, r): Older Redis versions contained 'allocation_stats' in INFO that was the cause of a number of bugs when parsing. """ - info = "allocation_stats:6=1,7=1,8=7141,9=180,10=92,11=116,12=5330," \ - "13=123,14=3091,15=11048,16=225842,17=1784,18=814,19=12020," \ - "20=2530,21=645,22=15113,23=8695,24=142860,25=318,26=3303," \ - "27=20561,28=54042,29=37390,30=1884,31=18071,32=31367,33=160," \ - "34=169,35=201,36=10155,37=1045,38=15078,39=22985,40=12523," \ - "41=15588,42=265,43=1287,44=142,45=382,46=945,47=426,48=171," \ - "49=56,50=516,51=43,52=41,53=46,54=54,55=75,56=647,57=332," \ - "58=32,59=39,60=48,61=35,62=62,63=32,64=221,65=26,66=30," \ - "67=36,68=41,69=44,70=26,71=144,72=169,73=24,74=37,75=25," \ - "76=42,77=21,78=126,79=374,80=27,81=40,82=43,83=47,84=46," \ - "85=114,86=34,87=37,88=7240,89=34,90=38,91=18,92=99,93=20," \ - "94=18,95=17,96=15,97=22,98=18,99=69,100=17,101=22,102=15," \ - "103=29,104=39,105=30,106=70,107=22,108=21,109=26,110=52," \ - "111=45,112=33,113=67,114=41,115=44,116=48,117=53,118=54," \ - "119=51,120=75,121=44,122=57,123=44,124=66,125=56,126=52," \ - "127=81,128=108,129=70,130=50,131=51,132=53,133=45,134=62," \ - "135=12,136=13,137=7,138=15,139=21,140=11,141=20,142=6,143=7," \ - "144=11,145=6,146=16,147=19,148=1112,149=1,151=83,154=1," \ - "155=1,156=1,157=1,160=1,161=1,162=2,166=1,169=1,170=1,171=2," \ - "172=1,174=1,176=2,177=9,178=34,179=73,180=30,181=1,185=3," \ - "187=1,188=1,189=1,192=1,196=1,198=1,200=1,201=1,204=1,205=1," \ - "207=1,208=1,209=1,214=2,215=31,216=78,217=28,218=5,219=2," \ - "220=1,222=1,225=1,227=1,234=1,242=1,250=1,252=1,253=1," \ - ">=256=203" + info = ( + "allocation_stats:6=1,7=1,8=7141,9=180,10=92,11=116,12=5330," + "13=123,14=3091,15=11048,16=225842,17=1784,18=814,19=12020," + "20=2530,21=645,22=15113,23=8695,24=142860,25=318,26=3303," + "27=20561,28=54042,29=37390,30=1884,31=18071,32=31367,33=160," + "34=169,35=201,36=10155,37=1045,38=15078,39=22985,40=12523," + "41=15588,42=265,43=1287,44=142,45=382,46=945,47=426,48=171," + "49=56,50=516,51=43,52=41,53=46,54=54,55=75,56=647,57=332," + "58=32,59=39,60=48,61=35,62=62,63=32,64=221,65=26,66=30," + "67=36,68=41,69=44,70=26,71=144,72=169,73=24,74=37,75=25," + "76=42,77=21,78=126,79=374,80=27,81=40,82=43,83=47,84=46," + "85=114,86=34,87=37,88=7240,89=34,90=38,91=18,92=99,93=20," + "94=18,95=17,96=15,97=22,98=18,99=69,100=17,101=22,102=15," + "103=29,104=39,105=30,106=70,107=22,108=21,109=26,110=52," + "111=45,112=33,113=67,114=41,115=44,116=48,117=53,118=54," + "119=51,120=75,121=44,122=57,123=44,124=66,125=56,126=52," + "127=81,128=108,129=70,130=50,131=51,132=53,133=45,134=62," + "135=12,136=13,137=7,138=15,139=21,140=11,141=20,142=6,143=7," + "144=11,145=6,146=16,147=19,148=1112,149=1,151=83,154=1," + "155=1,156=1,157=1,160=1,161=1,162=2,166=1,169=1,170=1,171=2," + "172=1,174=1,176=2,177=9,178=34,179=73,180=30,181=1,185=3," + "187=1,188=1,189=1,192=1,196=1,198=1,200=1,201=1,204=1,205=1," + "207=1,208=1,209=1,214=2,215=31,216=78,217=28,218=5,219=2," + "220=1,222=1,225=1,227=1,234=1,242=1,250=1,252=1,253=1," + ">=256=203" + ) parsed = parse_info(info) - assert 'allocation_stats' in parsed - assert '6' in parsed['allocation_stats'] - assert '>=256' in parsed['allocation_stats'] + assert "allocation_stats" in parsed + assert "6" in parsed["allocation_stats"] + assert ">=256" in parsed["allocation_stats"] @skip_if_redis_enterprise def test_large_responses(self, r): "The PythonParser has some special cases for return values > 1MB" # load up 5MB of data into a key - data = ''.join([ascii_letters] * (5000000 // len(ascii_letters))) - r['a'] = data - assert r['a'] == data.encode() + data = "".join([ascii_letters] * (5000000 // len(ascii_letters))) + r["a"] = data + assert r["a"] == data.encode() def test_floating_point_encoding(self, r): """ @@ -3960,5 +4246,5 @@ def test_floating_point_encoding(self, r): precision. """ timestamp = 1349673917.939762 - r.zadd('a', {'a1': timestamp}) - assert r.zscore('a', 'a1') == timestamp + r.zadd("a", {"a1": timestamp}) + assert r.zscore("a", "a1") == timestamp diff --git a/tests/test_connection.py b/tests/test_connection.py index 0071acab5c..22f1b718de 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,37 +1,40 @@ -from unittest import mock import types +from unittest import mock + import pytest from redis.exceptions import InvalidResponse from redis.utils import HIREDIS_AVAILABLE + from .conftest import skip_if_server_version_lt -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason='PythonParser only') +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster def test_invalid_response(r): - raw = b'x' + raw = b"x" parser = r.connection._parser - with mock.patch.object(parser._buffer, 'readline', return_value=raw): + with mock.patch.object(parser._buffer, "readline", return_value=raw): with pytest.raises(InvalidResponse) as cm: parser.read_response() - assert str(cm.value) == f'Protocol Error: {raw!r}' + assert str(cm.value) == f"Protocol Error: {raw!r}" -@skip_if_server_version_lt('4.0.0') +@skip_if_server_version_lt("4.0.0") @pytest.mark.redismod def test_loading_external_modules(modclient): def inner(): pass - modclient.load_external_module('myfuncname', inner) - assert getattr(modclient, 'myfuncname') == inner - assert isinstance(getattr(modclient, 'myfuncname'), types.FunctionType) + modclient.load_external_module("myfuncname", inner) + assert getattr(modclient, "myfuncname") == inner + assert isinstance(getattr(modclient, "myfuncname"), types.FunctionType) # and call it from redis.commands import RedisModuleCommands + j = RedisModuleCommands.json - modclient.load_external_module('sometestfuncname', j) + modclient.load_external_module("sometestfuncname", j) # d = {'hello': 'world!'} # mod = j(modclient) diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 5968b2b4fe..32f5e23d53 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -1,6 +1,7 @@ -import pytest -import multiprocessing import contextlib +import multiprocessing + +import pytest import redis from redis.connection import Connection, ConnectionPool @@ -25,10 +26,7 @@ class TestMultiprocessing: # actually fork/process-safe @pytest.fixture() def r(self, request): - return _get_client( - redis.Redis, - request=request, - single_connection_client=False) + return _get_client(redis.Redis, request=request, single_connection_client=False) def test_close_connection_in_child(self, master_host): """ @@ -36,12 +34,12 @@ def test_close_connection_in_child(self, master_host): destroy the file descriptors so a parent can still use it. """ conn = Connection(host=master_host[0], port=master_host[1]) - conn.send_command('ping') - assert conn.read_response() == b'PONG' + conn.send_command("ping") + assert conn.read_response() == b"PONG" def target(conn): - conn.send_command('ping') - assert conn.read_response() == b'PONG' + conn.send_command("ping") + assert conn.read_response() == b"PONG" conn.disconnect() proc = multiprocessing.Process(target=target, args=(conn,)) @@ -53,8 +51,8 @@ def target(conn): # child. The child called socket.close() but did not call # socket.shutdown() because it wasn't the "owning" process. # Therefore the connection still works in the parent. - conn.send_command('ping') - assert conn.read_response() == b'PONG' + conn.send_command("ping") + assert conn.read_response() == b"PONG" def test_close_connection_in_parent(self, master_host): """ @@ -62,8 +60,8 @@ def test_close_connection_in_parent(self, master_host): (the owning process) closes the connection. """ conn = Connection(host=master_host[0], port=master_host[1]) - conn.send_command('ping') - assert conn.read_response() == b'PONG' + conn.send_command("ping") + assert conn.read_response() == b"PONG" def target(conn, ev): ev.wait() @@ -71,7 +69,7 @@ def target(conn, ev): # connection, the connection is shutdown and the child # cannot use it. with pytest.raises(ConnectionError): - conn.send_command('ping') + conn.send_command("ping") ev = multiprocessing.Event() proc = multiprocessing.Process(target=target, args=(conn, ev)) @@ -83,28 +81,30 @@ def target(conn, ev): proc.join(3) assert proc.exitcode == 0 - @pytest.mark.parametrize('max_connections', [1, 2, None]) + @pytest.mark.parametrize("max_connections", [1, 2, None]) def test_pool(self, max_connections, master_host): """ A child will create its own connections when using a pool created by a parent. """ - pool = ConnectionPool.from_url(f'redis://{master_host[0]}:{master_host[1]}', - max_connections=max_connections) + pool = ConnectionPool.from_url( + f"redis://{master_host[0]}:{master_host[1]}", + max_connections=max_connections, + ) - conn = pool.get_connection('ping') + conn = pool.get_connection("ping") main_conn_pid = conn.pid with exit_callback(pool.release, conn): - conn.send_command('ping') - assert conn.read_response() == b'PONG' + conn.send_command("ping") + assert conn.read_response() == b"PONG" def target(pool): with exit_callback(pool.disconnect): - conn = pool.get_connection('ping') + conn = pool.get_connection("ping") assert conn.pid != main_conn_pid with exit_callback(pool.release, conn): - assert conn.send_command('ping') is None - assert conn.read_response() == b'PONG' + assert conn.send_command("ping") is None + assert conn.read_response() == b"PONG" proc = multiprocessing.Process(target=target, args=(pool,)) proc.start() @@ -113,32 +113,34 @@ def target(pool): # Check that connection is still alive after fork process has exited # and disconnected the connections in its pool - conn = pool.get_connection('ping') + conn = pool.get_connection("ping") with exit_callback(pool.release, conn): - assert conn.send_command('ping') is None - assert conn.read_response() == b'PONG' + assert conn.send_command("ping") is None + assert conn.read_response() == b"PONG" - @pytest.mark.parametrize('max_connections', [1, 2, None]) + @pytest.mark.parametrize("max_connections", [1, 2, None]) def test_close_pool_in_main(self, max_connections, master_host): """ A child process that uses the same pool as its parent isn't affected when the parent disconnects all connections within the pool. """ - pool = ConnectionPool.from_url(f'redis://{master_host[0]}:{master_host[1]}', - max_connections=max_connections) + pool = ConnectionPool.from_url( + f"redis://{master_host[0]}:{master_host[1]}", + max_connections=max_connections, + ) - conn = pool.get_connection('ping') - assert conn.send_command('ping') is None - assert conn.read_response() == b'PONG' + conn = pool.get_connection("ping") + assert conn.send_command("ping") is None + assert conn.read_response() == b"PONG" def target(pool, disconnect_event): - conn = pool.get_connection('ping') + conn = pool.get_connection("ping") with exit_callback(pool.release, conn): - assert conn.send_command('ping') is None - assert conn.read_response() == b'PONG' + assert conn.send_command("ping") is None + assert conn.read_response() == b"PONG" disconnect_event.wait() - assert conn.send_command('ping') is None - assert conn.read_response() == b'PONG' + assert conn.send_command("ping") is None + assert conn.read_response() == b"PONG" ev = multiprocessing.Event() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a87ed7182d..0518893f07 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,7 +1,8 @@ import pytest import redis -from .conftest import wait_for_command, skip_if_server_version_lt + +from .conftest import skip_if_server_version_lt, wait_for_command class TestPipeline: @@ -12,31 +13,30 @@ def test_pipeline_is_true(self, r): def test_pipeline(self, r): with r.pipeline() as pipe: - (pipe.set('a', 'a1') - .get('a') - .zadd('z', {'z1': 1}) - .zadd('z', {'z2': 4}) - .zincrby('z', 1, 'z1') - .zrange('z', 0, 5, withscores=True)) - assert pipe.execute() == \ - [ - True, - b'a1', - True, - True, - 2.0, - [(b'z1', 2.0), (b'z2', 4)], - ] + ( + pipe.set("a", "a1") + .get("a") + .zadd("z", {"z1": 1}) + .zadd("z", {"z2": 4}) + .zincrby("z", 1, "z1") + .zrange("z", 0, 5, withscores=True) + ) + assert pipe.execute() == [ + True, + b"a1", + True, + True, + 2.0, + [(b"z1", 2.0), (b"z2", 4)], + ] def test_pipeline_memoryview(self, r): with r.pipeline() as pipe: - (pipe.set('a', memoryview(b'a1')) - .get('a')) - assert pipe.execute() == \ - [ - True, - b'a1', - ] + (pipe.set("a", memoryview(b"a1")).get("a")) + assert pipe.execute() == [ + True, + b"a1", + ] def test_pipeline_length(self, r): with r.pipeline() as pipe: @@ -44,7 +44,7 @@ def test_pipeline_length(self, r): assert len(pipe) == 0 # Fill 'er up! - pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1') + pipe.set("a", "a1").set("b", "b1").set("c", "c1") assert len(pipe) == 3 # Execute calls reset(), so empty once again. @@ -53,83 +53,84 @@ def test_pipeline_length(self, r): def test_pipeline_no_transaction(self, r): with r.pipeline(transaction=False) as pipe: - pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1') + pipe.set("a", "a1").set("b", "b1").set("c", "c1") assert pipe.execute() == [True, True, True] - assert r['a'] == b'a1' - assert r['b'] == b'b1' - assert r['c'] == b'c1' + assert r["a"] == b"a1" + assert r["b"] == b"b1" + assert r["c"] == b"c1" @pytest.mark.onlynoncluster def test_pipeline_no_transaction_watch(self, r): - r['a'] = 0 + r["a"] = 0 with r.pipeline(transaction=False) as pipe: - pipe.watch('a') - a = pipe.get('a') + pipe.watch("a") + a = pipe.get("a") pipe.multi() - pipe.set('a', int(a) + 1) + pipe.set("a", int(a) + 1) assert pipe.execute() == [True] @pytest.mark.onlynoncluster def test_pipeline_no_transaction_watch_failure(self, r): - r['a'] = 0 + r["a"] = 0 with r.pipeline(transaction=False) as pipe: - pipe.watch('a') - a = pipe.get('a') + pipe.watch("a") + a = pipe.get("a") - r['a'] = 'bad' + r["a"] = "bad" pipe.multi() - pipe.set('a', int(a) + 1) + pipe.set("a", int(a) + 1) with pytest.raises(redis.WatchError): pipe.execute() - assert r['a'] == b'bad' + assert r["a"] == b"bad" def test_exec_error_in_response(self, r): """ an invalid pipeline command at exec time adds the exception instance to the list of returned values """ - r['c'] = 'a' + r["c"] = "a" with r.pipeline() as pipe: - pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4) + pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4) result = pipe.execute(raise_on_error=False) assert result[0] - assert r['a'] == b'1' + assert r["a"] == b"1" assert result[1] - assert r['b'] == b'2' + assert r["b"] == b"2" # we can't lpush to a key that's a string value, so this should # be a ResponseError exception assert isinstance(result[2], redis.ResponseError) - assert r['c'] == b'a' + assert r["c"] == b"a" # since this isn't a transaction, the other commands after the # error are still executed assert result[3] - assert r['d'] == b'4' + assert r["d"] == b"4" # make sure the pipe was restored to a working state - assert pipe.set('z', 'zzz').execute() == [True] - assert r['z'] == b'zzz' + assert pipe.set("z", "zzz").execute() == [True] + assert r["z"] == b"zzz" def test_exec_error_raised(self, r): - r['c'] = 'a' + r["c"] = "a" with r.pipeline() as pipe: - pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4) + pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4) with pytest.raises(redis.ResponseError) as ex: pipe.execute() - assert str(ex.value).startswith('Command # 3 (LPUSH c 3) of ' - 'pipeline caused error: ') + assert str(ex.value).startswith( + "Command # 3 (LPUSH c 3) of " "pipeline caused error: " + ) # make sure the pipe was restored to a working state - assert pipe.set('z', 'zzz').execute() == [True] - assert r['z'] == b'zzz' + assert pipe.set("z", "zzz").execute() == [True] + assert r["z"] == b"zzz" @pytest.mark.onlynoncluster def test_transaction_with_empty_error_command(self, r): @@ -139,7 +140,7 @@ def test_transaction_with_empty_error_command(self, r): """ for error_switch in (True, False): with r.pipeline() as pipe: - pipe.set('a', 1).mget([]).set('c', 3) + pipe.set("a", 1).mget([]).set("c", 3) result = pipe.execute(raise_on_error=error_switch) assert result[0] @@ -154,7 +155,7 @@ def test_pipeline_with_empty_error_command(self, r): """ for error_switch in (True, False): with r.pipeline(transaction=False) as pipe: - pipe.set('a', 1).mget([]).set('c', 3) + pipe.set("a", 1).mget([]).set("c", 3) result = pipe.execute(raise_on_error=error_switch) assert result[0] @@ -164,61 +165,63 @@ def test_pipeline_with_empty_error_command(self, r): def test_parse_error_raised(self, r): with r.pipeline() as pipe: # the zrem is invalid because we don't pass any keys to it - pipe.set('a', 1).zrem('b').set('b', 2) + pipe.set("a", 1).zrem("b").set("b", 2) with pytest.raises(redis.ResponseError) as ex: pipe.execute() - assert str(ex.value).startswith('Command # 2 (ZREM b) of ' - 'pipeline caused error: ') + assert str(ex.value).startswith( + "Command # 2 (ZREM b) of " "pipeline caused error: " + ) # make sure the pipe was restored to a working state - assert pipe.set('z', 'zzz').execute() == [True] - assert r['z'] == b'zzz' + assert pipe.set("z", "zzz").execute() == [True] + assert r["z"] == b"zzz" @pytest.mark.onlynoncluster def test_parse_error_raised_transaction(self, r): with r.pipeline() as pipe: pipe.multi() # the zrem is invalid because we don't pass any keys to it - pipe.set('a', 1).zrem('b').set('b', 2) + pipe.set("a", 1).zrem("b").set("b", 2) with pytest.raises(redis.ResponseError) as ex: pipe.execute() - assert str(ex.value).startswith('Command # 2 (ZREM b) of ' - 'pipeline caused error: ') + assert str(ex.value).startswith( + "Command # 2 (ZREM b) of " "pipeline caused error: " + ) # make sure the pipe was restored to a working state - assert pipe.set('z', 'zzz').execute() == [True] - assert r['z'] == b'zzz' + assert pipe.set("z", "zzz").execute() == [True] + assert r["z"] == b"zzz" @pytest.mark.onlynoncluster def test_watch_succeed(self, r): - r['a'] = 1 - r['b'] = 2 + r["a"] = 1 + r["b"] = 2 with r.pipeline() as pipe: - pipe.watch('a', 'b') + pipe.watch("a", "b") assert pipe.watching - a_value = pipe.get('a') - b_value = pipe.get('b') - assert a_value == b'1' - assert b_value == b'2' + a_value = pipe.get("a") + b_value = pipe.get("b") + assert a_value == b"1" + assert b_value == b"2" pipe.multi() - pipe.set('c', 3) + pipe.set("c", 3) assert pipe.execute() == [True] assert not pipe.watching @pytest.mark.onlynoncluster def test_watch_failure(self, r): - r['a'] = 1 - r['b'] = 2 + r["a"] = 1 + r["b"] = 2 with r.pipeline() as pipe: - pipe.watch('a', 'b') - r['b'] = 3 + pipe.watch("a", "b") + r["b"] = 3 pipe.multi() - pipe.get('a') + pipe.get("a") with pytest.raises(redis.WatchError): pipe.execute() @@ -226,12 +229,12 @@ def test_watch_failure(self, r): @pytest.mark.onlynoncluster def test_watch_failure_in_empty_transaction(self, r): - r['a'] = 1 - r['b'] = 2 + r["a"] = 1 + r["b"] = 2 with r.pipeline() as pipe: - pipe.watch('a', 'b') - r['b'] = 3 + pipe.watch("a", "b") + r["b"] = 3 pipe.multi() with pytest.raises(redis.WatchError): pipe.execute() @@ -240,103 +243,104 @@ def test_watch_failure_in_empty_transaction(self, r): @pytest.mark.onlynoncluster def test_unwatch(self, r): - r['a'] = 1 - r['b'] = 2 + r["a"] = 1 + r["b"] = 2 with r.pipeline() as pipe: - pipe.watch('a', 'b') - r['b'] = 3 + pipe.watch("a", "b") + r["b"] = 3 pipe.unwatch() assert not pipe.watching - pipe.get('a') - assert pipe.execute() == [b'1'] + pipe.get("a") + assert pipe.execute() == [b"1"] @pytest.mark.onlynoncluster def test_watch_exec_no_unwatch(self, r): - r['a'] = 1 - r['b'] = 2 + r["a"] = 1 + r["b"] = 2 with r.monitor() as m: with r.pipeline() as pipe: - pipe.watch('a', 'b') + pipe.watch("a", "b") assert pipe.watching - a_value = pipe.get('a') - b_value = pipe.get('b') - assert a_value == b'1' - assert b_value == b'2' + a_value = pipe.get("a") + b_value = pipe.get("b") + assert a_value == b"1" + assert b_value == b"2" pipe.multi() - pipe.set('c', 3) + pipe.set("c", 3) assert pipe.execute() == [True] assert not pipe.watching - unwatch_command = wait_for_command(r, m, 'UNWATCH') + unwatch_command = wait_for_command(r, m, "UNWATCH") assert unwatch_command is None, "should not send UNWATCH" @pytest.mark.onlynoncluster def test_watch_reset_unwatch(self, r): - r['a'] = 1 + r["a"] = 1 with r.monitor() as m: with r.pipeline() as pipe: - pipe.watch('a') + pipe.watch("a") assert pipe.watching pipe.reset() assert not pipe.watching - unwatch_command = wait_for_command(r, m, 'UNWATCH') + unwatch_command = wait_for_command(r, m, "UNWATCH") assert unwatch_command is not None - assert unwatch_command['command'] == 'UNWATCH' + assert unwatch_command["command"] == "UNWATCH" @pytest.mark.onlynoncluster def test_transaction_callable(self, r): - r['a'] = 1 - r['b'] = 2 + r["a"] = 1 + r["b"] = 2 has_run = [] def my_transaction(pipe): - a_value = pipe.get('a') - assert a_value in (b'1', b'2') - b_value = pipe.get('b') - assert b_value == b'2' + a_value = pipe.get("a") + assert a_value in (b"1", b"2") + b_value = pipe.get("b") + assert b_value == b"2" # silly run-once code... incr's "a" so WatchError should be raised # forcing this all to run again. this should incr "a" once to "2" if not has_run: - r.incr('a') - has_run.append('it has') + r.incr("a") + has_run.append("it has") pipe.multi() - pipe.set('c', int(a_value) + int(b_value)) + pipe.set("c", int(a_value) + int(b_value)) - result = r.transaction(my_transaction, 'a', 'b') + result = r.transaction(my_transaction, "a", "b") assert result == [True] - assert r['c'] == b'4' + assert r["c"] == b"4" @pytest.mark.onlynoncluster def test_transaction_callable_returns_value_from_callable(self, r): def callback(pipe): # No need to do anything here since we only want the return value - return 'a' + return "a" - res = r.transaction(callback, 'my-key', value_from_callable=True) - assert res == 'a' + res = r.transaction(callback, "my-key", value_from_callable=True) + assert res == "a" def test_exec_error_in_no_transaction_pipeline(self, r): - r['a'] = 1 + r["a"] = 1 with r.pipeline(transaction=False) as pipe: - pipe.llen('a') - pipe.expire('a', 100) + pipe.llen("a") + pipe.expire("a", 100) with pytest.raises(redis.ResponseError) as ex: pipe.execute() - assert str(ex.value).startswith('Command # 1 (LLEN a) of ' - 'pipeline caused error: ') + assert str(ex.value).startswith( + "Command # 1 (LLEN a) of " "pipeline caused error: " + ) - assert r['a'] == b'1' + assert r["a"] == b"1" def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r): - key = chr(3456) + 'abcd' + chr(3421) + key = chr(3456) + "abcd" + chr(3421) r[key] = 1 with r.pipeline(transaction=False) as pipe: pipe.llen(key) @@ -345,51 +349,52 @@ def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r): with pytest.raises(redis.ResponseError) as ex: pipe.execute() - expected = f'Command # 1 (LLEN {key}) of pipeline caused error: ' + expected = f"Command # 1 (LLEN {key}) of pipeline caused error: " assert str(ex.value).startswith(expected) - assert r[key] == b'1' + assert r[key] == b"1" def test_pipeline_with_bitfield(self, r): with r.pipeline() as pipe: - pipe.set('a', '1') - bf = pipe.bitfield('b') - pipe2 = (bf - .set('u8', 8, 255) - .get('u8', 0) - .get('u4', 8) # 1111 - .get('u4', 12) # 1111 - .get('u4', 13) # 1110 - .execute()) - pipe.get('a') + pipe.set("a", "1") + bf = pipe.bitfield("b") + pipe2 = ( + bf.set("u8", 8, 255) + .get("u8", 0) + .get("u4", 8) # 1111 + .get("u4", 12) # 1111 + .get("u4", 13) # 1110 + .execute() + ) + pipe.get("a") response = pipe.execute() assert pipe == pipe2 - assert response == [True, [0, 0, 15, 15, 14], b'1'] + assert response == [True, [0, 0, 15, 15, 14], b"1"] @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.0.0') + @skip_if_server_version_lt("2.0.0") def test_pipeline_discard(self, r): # empty pipeline should raise an error with r.pipeline() as pipe: - pipe.set('key', 'someval') + pipe.set("key", "someval") pipe.discard() with pytest.raises(redis.exceptions.ResponseError): pipe.execute() # setting a pipeline and discarding should do the same with r.pipeline() as pipe: - pipe.set('key', 'someval') - pipe.set('someotherkey', 'val') + pipe.set("key", "someval") + pipe.set("someotherkey", "val") response = pipe.execute() - pipe.set('key', 'another value!') + pipe.set("key", "another value!") pipe.discard() - pipe.set('key', 'another vae!') + pipe.set("key", "another vae!") with pytest.raises(redis.exceptions.ResponseError): pipe.execute() - pipe.set('foo', 'bar') + pipe.set("foo", "bar") response = pipe.execute() assert response[0] - assert r.get('foo') == b'bar' + assert r.get("foo") == b"bar" diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index b019bae6e2..6df0fafd4b 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -1,17 +1,14 @@ +import platform import threading import time from unittest import mock -import platform import pytest + import redis from redis.exceptions import ConnectionError -from .conftest import ( - _get_client, - skip_if_redis_enterprise, - skip_if_server_version_lt -) +from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): @@ -19,7 +16,8 @@ def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): timeout = now + timeout while now < timeout: message = pubsub.get_message( - ignore_subscribe_messages=ignore_subscribe_messages) + ignore_subscribe_messages=ignore_subscribe_messages + ) if message is not None: return message time.sleep(0.01) @@ -29,39 +27,39 @@ def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): def make_message(type, channel, data, pattern=None): return { - 'type': type, - 'pattern': pattern and pattern.encode('utf-8') or None, - 'channel': channel and channel.encode('utf-8') or None, - 'data': data.encode('utf-8') if isinstance(data, str) else data + "type": type, + "pattern": pattern and pattern.encode("utf-8") or None, + "channel": channel and channel.encode("utf-8") or None, + "data": data.encode("utf-8") if isinstance(data, str) else data, } def make_subscribe_test_data(pubsub, type): - if type == 'channel': + if type == "channel": return { - 'p': pubsub, - 'sub_type': 'subscribe', - 'unsub_type': 'unsubscribe', - 'sub_func': pubsub.subscribe, - 'unsub_func': pubsub.unsubscribe, - 'keys': ['foo', 'bar', 'uni' + chr(4456) + 'code'] + "p": pubsub, + "sub_type": "subscribe", + "unsub_type": "unsubscribe", + "sub_func": pubsub.subscribe, + "unsub_func": pubsub.unsubscribe, + "keys": ["foo", "bar", "uni" + chr(4456) + "code"], } - elif type == 'pattern': + elif type == "pattern": return { - 'p': pubsub, - 'sub_type': 'psubscribe', - 'unsub_type': 'punsubscribe', - 'sub_func': pubsub.psubscribe, - 'unsub_func': pubsub.punsubscribe, - 'keys': ['f*', 'b*', 'uni' + chr(4456) + '*'] + "p": pubsub, + "sub_type": "psubscribe", + "unsub_type": "punsubscribe", + "sub_func": pubsub.psubscribe, + "unsub_func": pubsub.punsubscribe, + "keys": ["f*", "b*", "uni" + chr(4456) + "*"], } - assert False, f'invalid subscribe type: {type}' + assert False, f"invalid subscribe type: {type}" class TestPubSubSubscribeUnsubscribe: - - def _test_subscribe_unsubscribe(self, p, sub_type, unsub_type, sub_func, - unsub_func, keys): + def _test_subscribe_unsubscribe( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): for key in keys: assert sub_func(key) is None @@ -79,15 +77,16 @@ def _test_subscribe_unsubscribe(self, p, sub_type, unsub_type, sub_func, assert wait_for_message(p) == make_message(unsub_type, key, i) def test_channel_subscribe_unsubscribe(self, r): - kwargs = make_subscribe_test_data(r.pubsub(), 'channel') + kwargs = make_subscribe_test_data(r.pubsub(), "channel") self._test_subscribe_unsubscribe(**kwargs) def test_pattern_subscribe_unsubscribe(self, r): - kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_subscribe_unsubscribe(**kwargs) - def _test_resubscribe_on_reconnection(self, p, sub_type, unsub_type, - sub_func, unsub_func, keys): + def _test_resubscribe_on_reconnection( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): for key in keys: assert sub_func(key) is None @@ -109,10 +108,10 @@ def _test_resubscribe_on_reconnection(self, p, sub_type, unsub_type, unique_channels = set() assert len(messages) == len(keys) for i, message in enumerate(messages): - assert message['type'] == sub_type - assert message['data'] == i + 1 - assert isinstance(message['channel'], bytes) - channel = message['channel'].decode('utf-8') + assert message["type"] == sub_type + assert message["data"] == i + 1 + assert isinstance(message["channel"], bytes) + channel = message["channel"].decode("utf-8") unique_channels.add(channel) assert len(unique_channels) == len(keys) @@ -120,16 +119,17 @@ def _test_resubscribe_on_reconnection(self, p, sub_type, unsub_type, assert channel in keys def test_resubscribe_to_channels_on_reconnection(self, r): - kwargs = make_subscribe_test_data(r.pubsub(), 'channel') + kwargs = make_subscribe_test_data(r.pubsub(), "channel") self._test_resubscribe_on_reconnection(**kwargs) @pytest.mark.onlynoncluster def test_resubscribe_to_patterns_on_reconnection(self, r): - kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_resubscribe_on_reconnection(**kwargs) - def _test_subscribed_property(self, p, sub_type, unsub_type, sub_func, - unsub_func, keys): + def _test_subscribed_property( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): assert p.subscribed is False sub_func(keys[0]) @@ -175,22 +175,22 @@ def _test_subscribed_property(self, p, sub_type, unsub_type, sub_func, assert p.subscribed is False def test_subscribe_property_with_channels(self, r): - kwargs = make_subscribe_test_data(r.pubsub(), 'channel') + kwargs = make_subscribe_test_data(r.pubsub(), "channel") self._test_subscribed_property(**kwargs) @pytest.mark.onlynoncluster def test_subscribe_property_with_patterns(self, r): - kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_subscribed_property(**kwargs) def test_ignore_all_subscribe_messages(self, r): p = r.pubsub(ignore_subscribe_messages=True) checks = ( - (p.subscribe, 'foo'), - (p.unsubscribe, 'foo'), - (p.psubscribe, 'f*'), - (p.punsubscribe, 'f*'), + (p.subscribe, "foo"), + (p.unsubscribe, "foo"), + (p.psubscribe, "f*"), + (p.punsubscribe, "f*"), ) assert p.subscribed is False @@ -204,10 +204,10 @@ def test_ignore_individual_subscribe_messages(self, r): p = r.pubsub() checks = ( - (p.subscribe, 'foo'), - (p.unsubscribe, 'foo'), - (p.psubscribe, 'f*'), - (p.punsubscribe, 'f*'), + (p.subscribe, "foo"), + (p.unsubscribe, "foo"), + (p.psubscribe, "f*"), + (p.punsubscribe, "f*"), ) assert p.subscribed is False @@ -219,16 +219,17 @@ def test_ignore_individual_subscribe_messages(self, r): assert p.subscribed is False def test_sub_unsub_resub_channels(self, r): - kwargs = make_subscribe_test_data(r.pubsub(), 'channel') + kwargs = make_subscribe_test_data(r.pubsub(), "channel") self._test_sub_unsub_resub(**kwargs) @pytest.mark.onlynoncluster def test_sub_unsub_resub_patterns(self, r): - kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_sub_unsub_resub(**kwargs) - def _test_sub_unsub_resub(self, p, sub_type, unsub_type, sub_func, - unsub_func, keys): + def _test_sub_unsub_resub( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): # https://github.com/andymccurdy/redis-py/issues/764 key = keys[0] sub_func(key) @@ -241,15 +242,16 @@ def _test_sub_unsub_resub(self, p, sub_type, unsub_type, sub_func, assert p.subscribed is True def test_sub_unsub_all_resub_channels(self, r): - kwargs = make_subscribe_test_data(r.pubsub(), 'channel') + kwargs = make_subscribe_test_data(r.pubsub(), "channel") self._test_sub_unsub_all_resub(**kwargs) def test_sub_unsub_all_resub_patterns(self, r): - kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_sub_unsub_all_resub(**kwargs) - def _test_sub_unsub_all_resub(self, p, sub_type, unsub_type, sub_func, - unsub_func, keys): + def _test_sub_unsub_all_resub( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): # https://github.com/andymccurdy/redis-py/issues/764 key = keys[0] sub_func(key) @@ -271,22 +273,22 @@ def message_handler(self, message): def test_published_message_to_channel(self, r): p = r.pubsub() - p.subscribe('foo') - assert wait_for_message(p) == make_message('subscribe', 'foo', 1) - assert r.publish('foo', 'test message') == 1 + p.subscribe("foo") + assert wait_for_message(p) == make_message("subscribe", "foo", 1) + assert r.publish("foo", "test message") == 1 message = wait_for_message(p) assert isinstance(message, dict) - assert message == make_message('message', 'foo', 'test message') + assert message == make_message("message", "foo", "test message") def test_published_message_to_pattern(self, r): p = r.pubsub() - p.subscribe('foo') - p.psubscribe('f*') - assert wait_for_message(p) == make_message('subscribe', 'foo', 1) - assert wait_for_message(p) == make_message('psubscribe', 'f*', 2) + p.subscribe("foo") + p.psubscribe("f*") + assert wait_for_message(p) == make_message("subscribe", "foo", 1) + assert wait_for_message(p) == make_message("psubscribe", "f*", 2) # 1 to pattern, 1 to channel - assert r.publish('foo', 'test message') == 2 + assert r.publish("foo", "test message") == 2 message1 = wait_for_message(p) message2 = wait_for_message(p) @@ -294,8 +296,8 @@ def test_published_message_to_pattern(self, r): assert isinstance(message2, dict) expected = [ - make_message('message', 'foo', 'test message'), - make_message('pmessage', 'foo', 'test message', pattern='f*') + make_message("message", "foo", "test message"), + make_message("pmessage", "foo", "test message", pattern="f*"), ] assert message1 in expected @@ -306,67 +308,65 @@ def test_channel_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) p.subscribe(foo=self.message_handler) assert wait_for_message(p) is None - assert r.publish('foo', 'test message') == 1 + assert r.publish("foo", "test message") == 1 assert wait_for_message(p) is None - assert self.message == make_message('message', 'foo', 'test message') + assert self.message == make_message("message", "foo", "test message") @pytest.mark.onlynoncluster def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) - p.psubscribe(**{'f*': self.message_handler}) + p.psubscribe(**{"f*": self.message_handler}) assert wait_for_message(p) is None - assert r.publish('foo', 'test message') == 1 + assert r.publish("foo", "test message") == 1 assert wait_for_message(p) is None - assert self.message == make_message('pmessage', 'foo', 'test message', - pattern='f*') + assert self.message == make_message( + "pmessage", "foo", "test message", pattern="f*" + ) def test_unicode_channel_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) - channel = 'uni' + chr(4456) + 'code' + channel = "uni" + chr(4456) + "code" channels = {channel: self.message_handler} p.subscribe(**channels) assert wait_for_message(p) is None - assert r.publish(channel, 'test message') == 1 + assert r.publish(channel, "test message") == 1 assert wait_for_message(p) is None - assert self.message == make_message('message', channel, 'test message') + assert self.message == make_message("message", channel, "test message") @pytest.mark.onlynoncluster # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html # #known-limitations-with-pubsub def test_unicode_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) - pattern = 'uni' + chr(4456) + '*' - channel = 'uni' + chr(4456) + 'code' + pattern = "uni" + chr(4456) + "*" + channel = "uni" + chr(4456) + "code" p.psubscribe(**{pattern: self.message_handler}) assert wait_for_message(p) is None - assert r.publish(channel, 'test message') == 1 + assert r.publish(channel, "test message") == 1 assert wait_for_message(p) is None - assert self.message == make_message('pmessage', channel, - 'test message', pattern=pattern) + assert self.message == make_message( + "pmessage", channel, "test message", pattern=pattern + ) def test_get_message_without_subscribe(self, r): p = r.pubsub() with pytest.raises(RuntimeError) as info: p.get_message() - expect = ('connection not set: ' - 'did you forget to call subscribe() or psubscribe()?') + expect = ( + "connection not set: " "did you forget to call subscribe() or psubscribe()?" + ) assert expect in info.exconly() class TestPubSubAutoDecoding: "These tests only validate that we get unicode values back" - channel = 'uni' + chr(4456) + 'code' - pattern = 'uni' + chr(4456) + '*' - data = 'abc' + chr(4458) + '123' + channel = "uni" + chr(4456) + "code" + pattern = "uni" + chr(4456) + "*" + data = "abc" + chr(4458) + "123" def make_message(self, type, channel, data, pattern=None): - return { - 'type': type, - 'channel': channel, - 'pattern': pattern, - 'data': data - } + return {"type": type, "channel": channel, "pattern": pattern, "data": data} def setup_method(self, method): self.message = None @@ -381,44 +381,37 @@ def r(self, request): def test_channel_subscribe_unsubscribe(self, r): p = r.pubsub() p.subscribe(self.channel) - assert wait_for_message(p) == self.make_message('subscribe', - self.channel, 1) + assert wait_for_message(p) == self.make_message("subscribe", self.channel, 1) p.unsubscribe(self.channel) - assert wait_for_message(p) == self.make_message('unsubscribe', - self.channel, 0) + assert wait_for_message(p) == self.make_message("unsubscribe", self.channel, 0) def test_pattern_subscribe_unsubscribe(self, r): p = r.pubsub() p.psubscribe(self.pattern) - assert wait_for_message(p) == self.make_message('psubscribe', - self.pattern, 1) + assert wait_for_message(p) == self.make_message("psubscribe", self.pattern, 1) p.punsubscribe(self.pattern) - assert wait_for_message(p) == self.make_message('punsubscribe', - self.pattern, 0) + assert wait_for_message(p) == self.make_message("punsubscribe", self.pattern, 0) def test_channel_publish(self, r): p = r.pubsub() p.subscribe(self.channel) - assert wait_for_message(p) == self.make_message('subscribe', - self.channel, 1) + assert wait_for_message(p) == self.make_message("subscribe", self.channel, 1) r.publish(self.channel, self.data) - assert wait_for_message(p) == self.make_message('message', - self.channel, - self.data) + assert wait_for_message(p) == self.make_message( + "message", self.channel, self.data + ) @pytest.mark.onlynoncluster def test_pattern_publish(self, r): p = r.pubsub() p.psubscribe(self.pattern) - assert wait_for_message(p) == self.make_message('psubscribe', - self.pattern, 1) + assert wait_for_message(p) == self.make_message("psubscribe", self.pattern, 1) r.publish(self.channel, self.data) - assert wait_for_message(p) == self.make_message('pmessage', - self.channel, - self.data, - pattern=self.pattern) + assert wait_for_message(p) == self.make_message( + "pmessage", self.channel, self.data, pattern=self.pattern + ) def test_channel_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -426,18 +419,16 @@ def test_channel_message_handler(self, r): assert wait_for_message(p) is None r.publish(self.channel, self.data) assert wait_for_message(p) is None - assert self.message == self.make_message('message', self.channel, - self.data) + assert self.message == self.make_message("message", self.channel, self.data) # test that we reconnected to the correct channel self.message = None p.connection.disconnect() assert wait_for_message(p) is None # should reconnect - new_data = self.data + 'new data' + new_data = self.data + "new data" r.publish(self.channel, new_data) assert wait_for_message(p) is None - assert self.message == self.make_message('message', self.channel, - new_data) + assert self.message == self.make_message("message", self.channel, new_data) def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -445,24 +436,24 @@ def test_pattern_message_handler(self, r): assert wait_for_message(p) is None r.publish(self.channel, self.data) assert wait_for_message(p) is None - assert self.message == self.make_message('pmessage', self.channel, - self.data, - pattern=self.pattern) + assert self.message == self.make_message( + "pmessage", self.channel, self.data, pattern=self.pattern + ) # test that we reconnected to the correct pattern self.message = None p.connection.disconnect() assert wait_for_message(p) is None # should reconnect - new_data = self.data + 'new data' + new_data = self.data + "new data" r.publish(self.channel, new_data) assert wait_for_message(p) is None - assert self.message == self.make_message('pmessage', self.channel, - new_data, - pattern=self.pattern) + assert self.message == self.make_message( + "pmessage", self.channel, new_data, pattern=self.pattern + ) def test_context_manager(self, r): with r.pubsub() as pubsub: - pubsub.subscribe('foo') + pubsub.subscribe("foo") assert pubsub.connection is not None assert pubsub.connection is None @@ -471,86 +462,82 @@ def test_context_manager(self, r): class TestPubSubRedisDown: - def test_channel_subscribe(self, r): - r = redis.Redis(host='localhost', port=6390) + r = redis.Redis(host="localhost", port=6390) p = r.pubsub() with pytest.raises(ConnectionError): - p.subscribe('foo') + p.subscribe("foo") class TestPubSubSubcommands: - @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_pubsub_channels(self, r): p = r.pubsub() - p.subscribe('foo', 'bar', 'baz', 'quux') + p.subscribe("foo", "bar", "baz", "quux") for i in range(4): - assert wait_for_message(p)['type'] == 'subscribe' - expected = [b'bar', b'baz', b'foo', b'quux'] + assert wait_for_message(p)["type"] == "subscribe" + expected = [b"bar", b"baz", b"foo", b"quux"] assert all([channel in r.pubsub_channels() for channel in expected]) @pytest.mark.onlynoncluster - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_pubsub_numsub(self, r): p1 = r.pubsub() - p1.subscribe('foo', 'bar', 'baz') + p1.subscribe("foo", "bar", "baz") for i in range(3): - assert wait_for_message(p1)['type'] == 'subscribe' + assert wait_for_message(p1)["type"] == "subscribe" p2 = r.pubsub() - p2.subscribe('bar', 'baz') + p2.subscribe("bar", "baz") for i in range(2): - assert wait_for_message(p2)['type'] == 'subscribe' + assert wait_for_message(p2)["type"] == "subscribe" p3 = r.pubsub() - p3.subscribe('baz') - assert wait_for_message(p3)['type'] == 'subscribe' + p3.subscribe("baz") + assert wait_for_message(p3)["type"] == "subscribe" - channels = [(b'foo', 1), (b'bar', 2), (b'baz', 3)] - assert r.pubsub_numsub('foo', 'bar', 'baz') == channels + channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] + assert r.pubsub_numsub("foo", "bar", "baz") == channels - @skip_if_server_version_lt('2.8.0') + @skip_if_server_version_lt("2.8.0") def test_pubsub_numpat(self, r): p = r.pubsub() - p.psubscribe('*oo', '*ar', 'b*z') + p.psubscribe("*oo", "*ar", "b*z") for i in range(3): - assert wait_for_message(p)['type'] == 'psubscribe' + assert wait_for_message(p)["type"] == "psubscribe" assert r.pubsub_numpat() == 3 class TestPubSubPings: - - @skip_if_server_version_lt('3.0.0') + @skip_if_server_version_lt("3.0.0") def test_send_pubsub_ping(self, r): p = r.pubsub(ignore_subscribe_messages=True) - p.subscribe('foo') + p.subscribe("foo") p.ping() - assert wait_for_message(p) == make_message(type='pong', channel=None, - data='', - pattern=None) + assert wait_for_message(p) == make_message( + type="pong", channel=None, data="", pattern=None + ) - @skip_if_server_version_lt('3.0.0') + @skip_if_server_version_lt("3.0.0") def test_send_pubsub_ping_message(self, r): p = r.pubsub(ignore_subscribe_messages=True) - p.subscribe('foo') - p.ping(message='hello world') - assert wait_for_message(p) == make_message(type='pong', channel=None, - data='hello world', - pattern=None) + p.subscribe("foo") + p.ping(message="hello world") + assert wait_for_message(p) == make_message( + type="pong", channel=None, data="hello world", pattern=None + ) @pytest.mark.onlynoncluster class TestPubSubConnectionKilled: - - @skip_if_server_version_lt('3.0.0') + @skip_if_server_version_lt("3.0.0") @skip_if_redis_enterprise def test_connection_error_raised_when_connection_dies(self, r): p = r.pubsub() - p.subscribe('foo') - assert wait_for_message(p) == make_message('subscribe', 'foo', 1) + p.subscribe("foo") + assert wait_for_message(p) == make_message("subscribe", "foo", 1) for client in r.client_list(): - if client['cmd'] == 'subscribe': - r.client_kill_filter(_id=client['id']) + if client["cmd"] == "subscribe": + r.client_kill_filter(_id=client["id"]) with pytest.raises(ConnectionError): wait_for_message(p) @@ -558,15 +545,15 @@ def test_connection_error_raised_when_connection_dies(self, r): class TestPubSubTimeouts: def test_get_message_with_timeout_returns_none(self, r): p = r.pubsub() - p.subscribe('foo') - assert wait_for_message(p) == make_message('subscribe', 'foo', 1) + p.subscribe("foo") + assert wait_for_message(p) == make_message("subscribe", "foo", 1) assert p.get_message(timeout=0.01) is None class TestPubSubWorkerThread: - - @pytest.mark.skipif(platform.python_implementation() == 'PyPy', - reason="Pypy threading issue") + @pytest.mark.skipif( + platform.python_implementation() == "PyPy", reason="Pypy threading issue" + ) def test_pubsub_worker_thread_exception_handler(self, r): event = threading.Event() @@ -575,12 +562,10 @@ def exception_handler(ex, pubsub, thread): event.set() p = r.pubsub() - p.subscribe(**{'foo': lambda m: m}) - with mock.patch.object(p, 'get_message', - side_effect=Exception('error')): + p.subscribe(**{"foo": lambda m: m}) + with mock.patch.object(p, "get_message", side_effect=Exception("error")): pubsub_thread = p.run_in_thread( - daemon=True, - exception_handler=exception_handler + daemon=True, exception_handler=exception_handler ) assert event.wait(timeout=1.0) @@ -589,10 +574,9 @@ def exception_handler(ex, pubsub, thread): class TestPubSubDeadlock: - @pytest.mark.timeout(30, method='thread') + @pytest.mark.timeout(30, method="thread") def test_pubsub_deadlock(self, master_host): - pool = redis.ConnectionPool(host=master_host[0], - port=master_host[1]) + pool = redis.ConnectionPool(host=master_host[0], port=master_host[1]) r = redis.Redis(connection_pool=pool) for i in range(60): diff --git a/tox.ini b/tox.ini index 90ff8e71ea..9d78e2a028 100644 --- a/tox.ini +++ b/tox.ini @@ -172,4 +172,3 @@ ignore = W503 E203 E126 -max-line-length = 88