Skip to content

Type hint improvements #2952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion redis/_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@


class BaseParser(ABC):

EXCEPTION_CLASSES = {
"ERR": {
"max number of clients reached": ConnectionError,
Expand Down
6 changes: 2 additions & 4 deletions redis/_parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,8 @@ async def _read_response(
]
res = self.push_handler_func(response)
if not push_request:
return await (
self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res
Expand Down
1 change: 0 additions & 1 deletion redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,6 @@ def __init__(
queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated
**connection_kwargs,
):

super().__init__(
connection_class=connection_class,
max_connections=max_connections,
Expand Down
118 changes: 68 additions & 50 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import time
import warnings
from itertools import chain
from typing import Optional, Type
from typing import Any, Callable, Dict, List, Optional, Type, Union

from redis._parsers.encoders import Encoder
from redis._parsers.helpers import (
_RedisCallbacks,
_RedisCallbacksRESP2,
Expand Down Expand Up @@ -49,7 +50,7 @@
class CaseInsensitiveDict(dict):
"Case insensitive dict implementation. Assumes string keys only."

def __init__(self, data):
def __init__(self, data: Dict[str, str]) -> None:
for k, v in data.items():
self[k.upper()] = v

Expand Down Expand Up @@ -93,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
"""

@classmethod
def from_url(cls, url, **kwargs):
def from_url(cls, url: str, **kwargs) -> None:
"""
Return a Redis client object configured from the given URL

Expand Down Expand Up @@ -202,7 +203,7 @@ def __init__(
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
):
) -> None:
"""
Initialize a new Redis client.
To specify a retry policy for specific errors, first set
Expand Down Expand Up @@ -309,14 +310,14 @@ def __init__(
else:
self.response_callbacks.update(_RedisCallbacksRESP2)

def __repr__(self):
def __repr__(self) -> str:
return f"{type(self).__name__}<{repr(self.connection_pool)}>"

def get_encoder(self):
def get_encoder(self) -> "Encoder":
"""Get the connection pool's encoder"""
return self.connection_pool.get_encoder()

def get_connection_kwargs(self):
def get_connection_kwargs(self) -> Dict:
"""Get the connection's key-word arguments"""
return self.connection_pool.connection_kwargs

Expand All @@ -327,11 +328,11 @@ def set_retry(self, retry: "Retry") -> None:
self.get_connection_kwargs().update({"retry": retry})
self.connection_pool.set_retry(retry)

def set_response_callback(self, command, callback):
def set_response_callback(self, command: str, callback: Callable) -> None:
"""Set a custom Response Callback"""
self.response_callbacks[command] = callback

def load_external_module(self, funcname, func):
def load_external_module(self, funcname, func) -> None:
"""
This function can be used to add externally defined redis modules,
and their namespaces to the redis client.
Expand All @@ -354,7 +355,7 @@ def load_external_module(self, funcname, func):
"""
setattr(self, funcname, func)

def pipeline(self, transaction=True, shard_hint=None):
def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline":
"""
Return a new pipeline object that can queue multiple commands for
later execution. ``transaction`` indicates whether all commands
Expand All @@ -366,7 +367,9 @@ def pipeline(self, transaction=True, shard_hint=None):
self.connection_pool, self.response_callbacks, transaction, shard_hint
)

def transaction(self, func, *watches, **kwargs):
def transaction(
self, func: Callable[["Pipeline"], None], *watches, **kwargs
) -> None:
"""
Convenience method for executing the callable `func` as a transaction
while watching all keys specified in `watches`. The 'func' callable
Expand All @@ -390,13 +393,13 @@ def transaction(self, func, *watches, **kwargs):

def lock(
self,
name,
timeout=None,
sleep=0.1,
blocking=True,
blocking_timeout=None,
lock_class=None,
thread_local=True,
name: str,
timeout: Optional[float] = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: Optional[float] = None,
lock_class: Union[None, Any] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any not includes None?

thread_local: bool = True,
):
"""
Return a new Lock object using key ``name`` that mimics
Expand Down Expand Up @@ -648,9 +651,9 @@ def __init__(
self,
connection_pool,
shard_hint=None,
ignore_subscribe_messages=False,
encoder=None,
push_handler_func=None,
ignore_subscribe_messages: bool = False,
encoder: Optional["Encoder"] = None,
push_handler_func: Union[None, Callable[[str], None]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
push_handler_func: Union[None, Callable[[str], None]] = None,
push_handler_func: Optional[Callable[[str], None]] = None,

):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
Expand All @@ -672,13 +675,13 @@ def __init__(
_set_info_logger()
self.reset()

def __enter__(self):
def __enter__(self) -> "PubSub":
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback) -> None:
self.reset()

def __del__(self):
def __del__(self) -> None:
try:
# if this object went out of scope prior to shutting down
# subscriptions, close the connection manually before
Expand All @@ -687,7 +690,7 @@ def __del__(self):
except Exception:
pass

def reset(self):
def reset(self) -> None:
if self.connection:
self.connection.disconnect()
self.connection._deregister_connect_callback(self.on_connect)
Expand All @@ -702,10 +705,10 @@ def reset(self):
self.pending_unsubscribe_patterns = set()
self.subscribed_event.clear()

def close(self):
def close(self) -> None:
self.reset()

def on_connect(self, connection):
def on_connect(self, connection) -> None:
"Re-subscribe to any channels and patterns previously subscribed to"
# NOTE: for python3, we can't pass bytestrings as keyword arguments
# so we need to decode channel/pattern names back to unicode strings
Expand All @@ -731,7 +734,7 @@ def on_connect(self, connection):
self.ssubscribe(**shard_channels)

@property
def subscribed(self):
def subscribed(self) -> bool:
"""Indicates if there are subscriptions to any channels or patterns"""
return self.subscribed_event.is_set()

Expand All @@ -757,7 +760,7 @@ def execute_command(self, *args):
self.clean_health_check_responses()
self._execute(connection, connection.send_command, *args, **kwargs)

def clean_health_check_responses(self):
def clean_health_check_responses(self) -> None:
"""
If any health check responses are present, clean them
"""
Expand All @@ -775,7 +778,7 @@ def clean_health_check_responses(self):
)
ttl -= 1

def _disconnect_raise_connect(self, conn, error):
def _disconnect_raise_connect(self, conn, error) -> None:
"""
Close the connection and raise an exception
if retry_on_timeout is not set or the error
Expand Down Expand Up @@ -826,7 +829,7 @@ def try_read():
return None
return response

def is_health_check_response(self, response):
def is_health_check_response(self, response) -> bool:
"""
Check if the response is a health check response.
If there are no subscriptions redis responds to PING command with a
Expand All @@ -837,7 +840,7 @@ def is_health_check_response(self, response):
self.health_check_response_b, # If there wasn't
]

def check_health(self):
def check_health(self) -> None:
conn = self.connection
if conn is None:
raise RuntimeError(
Expand All @@ -849,7 +852,7 @@ def check_health(self):
conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
self.health_check_response_counter += 1

def _normalize_keys(self, data):
def _normalize_keys(self, data) -> Dict:
"""
normalize channel/pattern names to be either bytes or strings
based on whether responses are automatically decoded. this saves us
Expand Down Expand Up @@ -983,7 +986,9 @@ def listen(self):
if response is not None:
yield response

def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
def get_message(
self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
):
"""
Get the next message if one is available, otherwise None.

Expand Down Expand Up @@ -1012,7 +1017,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):

get_sharded_message = get_message

def ping(self, message=None):
def ping(self, message: Union[str, None] = None) -> bool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def ping(self, message: Union[str, None] = None) -> bool:
def ping(self, message: Optional[str] = None) -> bool:

"""
Ping the Redis server
"""
Expand Down Expand Up @@ -1093,7 +1098,12 @@ def handle_message(self, response, ignore_subscribe_messages=False):

return message

def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
def run_in_thread(
self,
sleep_time: int = 0,
daemon: bool = False,
exception_handler: Optional[Callable] = None,
) -> "PubSubWorkerThread":
for channel, handler in self.channels.items():
if handler is None:
raise PubSubError(f"Channel: '{channel}' has no handler registered")
Expand All @@ -1114,15 +1124,23 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):


class PubSubWorkerThread(threading.Thread):
def __init__(self, pubsub, sleep_time, daemon=False, exception_handler=None):
def __init__(
self,
pubsub,
sleep_time: float,
daemon: bool = False,
exception_handler: Union[
Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None
] = None,
):
super().__init__()
self.daemon = daemon
self.pubsub = pubsub
self.sleep_time = sleep_time
self.exception_handler = exception_handler
self._running = threading.Event()

def run(self):
def run(self) -> None:
if self._running.is_set():
return
self._running.set()
Expand All @@ -1137,7 +1155,7 @@ def run(self):
self.exception_handler(e, pubsub, self)
pubsub.close()

def stop(self):
def stop(self) -> None:
# trip the flag so the run loop exits. the run loop will
# close the pubsub connection, which disconnects the socket
# and returns the connection to the pool.
Expand Down Expand Up @@ -1175,7 +1193,7 @@ def __init__(self, connection_pool, response_callbacks, transaction, shard_hint)
self.watching = False
self.reset()

def __enter__(self):
def __enter__(self) -> "Pipeline":
return self

def __exit__(self, exc_type, exc_value, traceback):
Expand All @@ -1187,14 +1205,14 @@ def __del__(self):
except Exception:
pass

def __len__(self):
def __len__(self) -> int:
return len(self.command_stack)

def __bool__(self):
def __bool__(self) -> bool:
"""Pipeline instances should always evaluate to True"""
return True

def reset(self):
def reset(self) -> None:
self.command_stack = []
self.scripts = set()
# make sure to reset the connection state in the event that we were
Expand All @@ -1217,11 +1235,11 @@ def reset(self):
self.connection_pool.release(self.connection)
self.connection = None

def close(self):
def close(self) -> None:
"""Close the pipeline"""
self.reset()

def multi(self):
def multi(self) -> None:
"""
Start a transactional block of the pipeline after WATCH commands
are issued. End the transactional block with `execute`.
Expand All @@ -1239,7 +1257,7 @@ def execute_command(self, *args, **kwargs):
return self.immediate_execute_command(*args, **kwargs)
return self.pipeline_execute_command(*args, **kwargs)

def _disconnect_reset_raise(self, conn, error):
def _disconnect_reset_raise(self, conn, error) -> None:
"""
Close the connection, reset watching state and
raise an exception if we were watching,
Expand Down Expand Up @@ -1282,7 +1300,7 @@ def immediate_execute_command(self, *args, **options):
lambda error: self._disconnect_reset_raise(conn, error),
)

def pipeline_execute_command(self, *args, **options):
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
"""
Stage a command to be executed when execute() is next called

Expand All @@ -1297,7 +1315,7 @@ def pipeline_execute_command(self, *args, **options):
self.command_stack.append((args, options))
return self

def _execute_transaction(self, connection, commands, raise_on_error):
def _execute_transaction(self, connection, commands, raise_on_error) -> List:
cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})])
all_cmds = connection.pack_commands(
[args for args, options in cmds if EMPTY_RESPONSE not in options]
Expand Down Expand Up @@ -1415,7 +1433,7 @@ def load_scripts(self):
if not exist:
s.sha = immediate("SCRIPT LOAD", s.script)

def _disconnect_raise_reset(self, conn, error):
def _disconnect_raise_reset(self, conn: Redis, error: Exception) -> None:
"""
Close the connection, raise an exception if we were watching,
and raise an exception if TimeoutError is not part of retry_on_error,
Expand Down Expand Up @@ -1477,6 +1495,6 @@ def watch(self, *names):
raise RedisError("Cannot issue a WATCH after a MULTI")
return self.execute_command("WATCH", *names)

def unwatch(self):
def unwatch(self) -> bool:
"""Unwatches all previously specified keys"""
return self.watching and self.execute_command("UNWATCH") or True
Loading