From 46042cea3dd9b1d71b03c824e05b9c887521b0c8 Mon Sep 17 00:00:00 2001 From: woutdenolf Date: Sat, 11 Feb 2023 08:54:13 +0100 Subject: [PATCH] introduce AbstractConnection so that UnixDomainSocketConnection can call super().__init__ --- redis/connection.py | 278 +++++++++++++++++++------------------------- 1 file changed, 120 insertions(+), 158 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index d35980c167..c19b1c02fc 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -6,6 +6,7 @@ import sys import threading import weakref +from abc import abstractmethod from io import SEEK_END from itertools import chain from queue import Empty, Full, LifoQueue @@ -585,20 +586,13 @@ def pack(self, *args): return output -class Connection: - "Manages TCP communication to and from a Redis server" +class AbstractConnection: + "Manages communication to and from a Redis server" def __init__( self, - host="localhost", - port=6379, db=0, password=None, - socket_timeout=None, - socket_connect_timeout=None, - socket_keepalive=False, - socket_keepalive_options=None, - socket_type=0, retry_on_timeout=False, retry_on_error=SENTINEL, encoding="utf-8", @@ -629,18 +623,11 @@ def __init__( "2. 'credential_provider'" ) self.pid = os.getpid() - self.host = host - self.port = int(port) self.db = db self.client_name = client_name self.credential_provider = credential_provider self.password = password self.username = username - self.socket_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout or socket_timeout - self.socket_keepalive = socket_keepalive - self.socket_keepalive_options = socket_keepalive_options or {} - self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: retry_on_error = [] @@ -673,11 +660,9 @@ def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) return f"{self.__class__.__name__}<{repr_args}>" + @abstractmethod def repr_pieces(self): - pieces = [("host", self.host), ("port", self.port), ("db", self.db)] - if self.client_name: - pieces.append(("client_name", self.client_name)) - return pieces + pass def __del__(self): try: @@ -740,75 +725,17 @@ def connect(self): if callback: callback(self) + @abstractmethod def _connect(self): - "Create a TCP socket connection" - # we want to mimic what socket.create_connection does to support - # ipv4/ipv6, but we want to set options prior to calling - # socket.connect() - err = None - for res in socket.getaddrinfo( - self.host, self.port, self.socket_type, socket.SOCK_STREAM - ): - family, socktype, proto, canonname, socket_address = res - sock = None - try: - sock = socket.socket(family, socktype, proto) - # TCP_NODELAY - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - # TCP_KEEPALIVE - if self.socket_keepalive: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - for k, v in self.socket_keepalive_options.items(): - sock.setsockopt(socket.IPPROTO_TCP, k, v) - - # set the socket_connect_timeout before we connect - sock.settimeout(self.socket_connect_timeout) - - # connect - sock.connect(socket_address) - - # set the socket_timeout now that we're connected - sock.settimeout(self.socket_timeout) - return sock - - except OSError as _: - err = _ - if sock is not None: - sock.close() - - if err is not None: - raise err - raise OSError("socket.getaddrinfo returned an empty list") + pass + @abstractmethod def _host_error(self): - try: - host_error = f"{self.host}:{self.port}" - except AttributeError: - host_error = "connection" - - return host_error + pass + @abstractmethod def _error_message(self, exception): - # args for socket.error can either be (errno, "message") - # or just "message" - - host_error = self._host_error() - - if len(exception.args) == 1: - try: - return f"Error connecting to {host_error}. \ - {exception.args[0]}." - except AttributeError: - return f"Connection Error: {exception.args[0]}" - else: - try: - return ( - f"Error {exception.args[0]} connecting to " - f"{host_error}. {exception.args[1]}." - ) - except AttributeError: - return f"Connection Error: {exception.args[0]}" + pass def on_connect(self): "Initialize the connection, authenticate and select a database" @@ -992,6 +919,101 @@ def pack_commands(self, commands): return output +class Connection(AbstractConnection): + "Manages TCP communication to and from a Redis server" + + def __init__( + self, + host="localhost", + port=6379, + socket_timeout=None, + socket_connect_timeout=None, + socket_keepalive=False, + socket_keepalive_options=None, + socket_type=0, + **kwargs, + ): + self.host = host + self.port = int(port) + self.socket_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout or socket_timeout + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = socket_keepalive_options or {} + self.socket_type = socket_type + super().__init__(**kwargs) + + def repr_pieces(self): + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def _connect(self): + "Create a TCP socket connection" + # we want to mimic what socket.create_connection does to support + # ipv4/ipv6, but we want to set options prior to calling + # socket.connect() + err = None + for res in socket.getaddrinfo( + self.host, self.port, self.socket_type, socket.SOCK_STREAM + ): + family, socktype, proto, canonname, socket_address = res + sock = None + try: + sock = socket.socket(family, socktype, proto) + # TCP_NODELAY + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + # TCP_KEEPALIVE + if self.socket_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + for k, v in self.socket_keepalive_options.items(): + sock.setsockopt(socket.IPPROTO_TCP, k, v) + + # set the socket_connect_timeout before we connect + sock.settimeout(self.socket_connect_timeout) + + # connect + sock.connect(socket_address) + + # set the socket_timeout now that we're connected + sock.settimeout(self.socket_timeout) + return sock + + except OSError as _: + err = _ + if sock is not None: + sock.close() + + if err is not None: + raise err + raise OSError("socket.getaddrinfo returned an empty list") + + def _host_error(self): + return f"{self.host}:{self.port}" + + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + + host_error = self._host_error() + + if len(exception.args) == 1: + try: + return f"Error connecting to {host_error}. \ + {exception.args[0]}." + except AttributeError: + return f"Connection Error: {exception.args[0]}" + else: + try: + return ( + f"Error {exception.args[0]} connecting to " + f"{host_error}. {exception.args[1]}." + ) + except AttributeError: + return f"Connection Error: {exception.args[0]}" + + class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). This class extends the Connection class, adding SSL functionality, and making @@ -1037,8 +1059,6 @@ def __init__( if not ssl_available: raise RedisError("Python wasn't built with SSL support") - super().__init__(**kwargs) - self.keyfile = ssl_keyfile self.certfile = ssl_certfile if ssl_cert_reqs is None: @@ -1064,6 +1084,7 @@ def __init__( self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled self.ssl_ocsp_context = ssl_ocsp_context self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert + super().__init__(**kwargs) def _connect(self): "Wrap the socket with SSL support" @@ -1133,77 +1154,12 @@ def _connect(self): return sslsock -class UnixDomainSocketConnection(Connection): - def __init__( - self, - path="", - db=0, - username=None, - password=None, - socket_timeout=None, - encoding="utf-8", - encoding_errors="strict", - decode_responses=False, - retry_on_timeout=False, - retry_on_error=SENTINEL, - parser_class=DefaultParser, - socket_read_size=65536, - health_check_interval=0, - client_name=None, - retry=None, - redis_connect_func=None, - credential_provider: Optional[CredentialProvider] = None, - command_packer=None, - ): - """ - Initialize a new UnixDomainSocketConnection. - To specify a retry policy for specific errors, first set - `retry_on_error` to a list of the error/s to retry on, then set - `retry` to a valid `Retry` object. - To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. - """ - if (username or password) and credential_provider is not None: - raise DataError( - "'username' and 'password' cannot be passed along with 'credential_" - "provider'. Please provide only one of the following arguments: \n" - "1. 'password' and (optional) 'username'\n" - "2. 'credential_provider'" - ) - self.pid = os.getpid() +class UnixDomainSocketConnection(AbstractConnection): + "Manages UDS communication to and from a Redis server" + + def __init__(self, path="", **kwargs): self.path = path - self.db = db - self.client_name = client_name - self.credential_provider = credential_provider - self.password = password - self.username = username - self.socket_timeout = socket_timeout - self.retry_on_timeout = retry_on_timeout - if retry_on_error is SENTINEL: - retry_on_error = [] - if retry_on_timeout: - # Add TimeoutError to the errors list to retry on - retry_on_error.append(TimeoutError) - self.retry_on_error = retry_on_error - if self.retry_on_error: - if retry is None: - self.retry = Retry(NoBackoff(), 1) - else: - # deep-copy the Retry object as it is mutable - self.retry = copy.deepcopy(retry) - # Update the retry's supported errors with the specified errors - self.retry.update_supported_errors(retry_on_error) - else: - self.retry = Retry(NoBackoff(), 0) - self.health_check_interval = health_check_interval - self.next_health_check = 0 - self.redis_connect_func = redis_connect_func - self.encoder = Encoder(encoding, encoding_errors, decode_responses) - self._sock = None - self._socket_read_size = socket_read_size - self.set_parser(parser_class) - self._connect_callbacks = [] - self._buffer_cutoff = 6000 - self._command_packer = self._construct_command_packer(command_packer) + super().__init__(**kwargs) def repr_pieces(self): pieces = [("path", self.path), ("db", self.db)] @@ -1218,15 +1174,21 @@ def _connect(self): sock.connect(self.path) return sock + def _host_error(self): + return self.path + def _error_message(self, exception): # args for socket.error can either be (errno, "message") # or just "message" + host_error = self._host_error() if len(exception.args) == 1: - return f"Error connecting to unix socket: {self.path}. {exception.args[0]}." + return ( + f"Error connecting to unix socket: {host_error}. {exception.args[0]}." + ) else: return ( f"Error {exception.args[0]} connecting to unix socket: " - f"{self.path}. {exception.args[1]}." + f"{host_error}. {exception.args[1]}." )