Skip to content

Commit 7d474f9

Browse files
woutdenolfdvora-h
andauthoredMar 16, 2023
introduce AbstractConnection so that UnixDomainSocketConnection can call super().__init__ (#2588)
Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
1 parent c871723 commit 7d474f9

File tree

1 file changed

+120
-158
lines changed

1 file changed

+120
-158
lines changed
 

‎redis/connection.py

+120-158
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import threading
88
import weakref
9+
from abc import abstractmethod
910
from io import SEEK_END
1011
from itertools import chain
1112
from queue import Empty, Full, LifoQueue
@@ -583,20 +584,13 @@ def pack(self, *args):
583584
return output
584585

585586

586-
class Connection:
587-
"Manages TCP communication to and from a Redis server"
587+
class AbstractConnection:
588+
"Manages communication to and from a Redis server"
588589

589590
def __init__(
590591
self,
591-
host="localhost",
592-
port=6379,
593592
db=0,
594593
password=None,
595-
socket_timeout=None,
596-
socket_connect_timeout=None,
597-
socket_keepalive=False,
598-
socket_keepalive_options=None,
599-
socket_type=0,
600594
retry_on_timeout=False,
601595
retry_on_error=SENTINEL,
602596
encoding="utf-8",
@@ -627,18 +621,11 @@ def __init__(
627621
"2. 'credential_provider'"
628622
)
629623
self.pid = os.getpid()
630-
self.host = host
631-
self.port = int(port)
632624
self.db = db
633625
self.client_name = client_name
634626
self.credential_provider = credential_provider
635627
self.password = password
636628
self.username = username
637-
self.socket_timeout = socket_timeout
638-
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
639-
self.socket_keepalive = socket_keepalive
640-
self.socket_keepalive_options = socket_keepalive_options or {}
641-
self.socket_type = socket_type
642629
self.retry_on_timeout = retry_on_timeout
643630
if retry_on_error is SENTINEL:
644631
retry_on_error = []
@@ -671,11 +658,9 @@ def __repr__(self):
671658
repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
672659
return f"{self.__class__.__name__}<{repr_args}>"
673660

661+
@abstractmethod
674662
def repr_pieces(self):
675-
pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
676-
if self.client_name:
677-
pieces.append(("client_name", self.client_name))
678-
return pieces
663+
pass
679664

680665
def __del__(self):
681666
try:
@@ -738,75 +723,17 @@ def connect(self):
738723
if callback:
739724
callback(self)
740725

726+
@abstractmethod
741727
def _connect(self):
742-
"Create a TCP socket connection"
743-
# we want to mimic what socket.create_connection does to support
744-
# ipv4/ipv6, but we want to set options prior to calling
745-
# socket.connect()
746-
err = None
747-
for res in socket.getaddrinfo(
748-
self.host, self.port, self.socket_type, socket.SOCK_STREAM
749-
):
750-
family, socktype, proto, canonname, socket_address = res
751-
sock = None
752-
try:
753-
sock = socket.socket(family, socktype, proto)
754-
# TCP_NODELAY
755-
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
756-
757-
# TCP_KEEPALIVE
758-
if self.socket_keepalive:
759-
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
760-
for k, v in self.socket_keepalive_options.items():
761-
sock.setsockopt(socket.IPPROTO_TCP, k, v)
762-
763-
# set the socket_connect_timeout before we connect
764-
sock.settimeout(self.socket_connect_timeout)
765-
766-
# connect
767-
sock.connect(socket_address)
768-
769-
# set the socket_timeout now that we're connected
770-
sock.settimeout(self.socket_timeout)
771-
return sock
772-
773-
except OSError as _:
774-
err = _
775-
if sock is not None:
776-
sock.close()
777-
778-
if err is not None:
779-
raise err
780-
raise OSError("socket.getaddrinfo returned an empty list")
728+
pass
781729

730+
@abstractmethod
782731
def _host_error(self):
783-
try:
784-
host_error = f"{self.host}:{self.port}"
785-
except AttributeError:
786-
host_error = "connection"
787-
788-
return host_error
732+
pass
789733

734+
@abstractmethod
790735
def _error_message(self, exception):
791-
# args for socket.error can either be (errno, "message")
792-
# or just "message"
793-
794-
host_error = self._host_error()
795-
796-
if len(exception.args) == 1:
797-
try:
798-
return f"Error connecting to {host_error}. \
799-
{exception.args[0]}."
800-
except AttributeError:
801-
return f"Connection Error: {exception.args[0]}"
802-
else:
803-
try:
804-
return (
805-
f"Error {exception.args[0]} connecting to "
806-
f"{host_error}. {exception.args[1]}."
807-
)
808-
except AttributeError:
809-
return f"Connection Error: {exception.args[0]}"
736+
pass
810737

811738
def on_connect(self):
812739
"Initialize the connection, authenticate and select a database"
@@ -990,6 +917,101 @@ def pack_commands(self, commands):
990917
return output
991918

992919

920+
class Connection(AbstractConnection):
921+
"Manages TCP communication to and from a Redis server"
922+
923+
def __init__(
924+
self,
925+
host="localhost",
926+
port=6379,
927+
socket_timeout=None,
928+
socket_connect_timeout=None,
929+
socket_keepalive=False,
930+
socket_keepalive_options=None,
931+
socket_type=0,
932+
**kwargs,
933+
):
934+
self.host = host
935+
self.port = int(port)
936+
self.socket_timeout = socket_timeout
937+
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
938+
self.socket_keepalive = socket_keepalive
939+
self.socket_keepalive_options = socket_keepalive_options or {}
940+
self.socket_type = socket_type
941+
super().__init__(**kwargs)
942+
943+
def repr_pieces(self):
944+
pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
945+
if self.client_name:
946+
pieces.append(("client_name", self.client_name))
947+
return pieces
948+
949+
def _connect(self):
950+
"Create a TCP socket connection"
951+
# we want to mimic what socket.create_connection does to support
952+
# ipv4/ipv6, but we want to set options prior to calling
953+
# socket.connect()
954+
err = None
955+
for res in socket.getaddrinfo(
956+
self.host, self.port, self.socket_type, socket.SOCK_STREAM
957+
):
958+
family, socktype, proto, canonname, socket_address = res
959+
sock = None
960+
try:
961+
sock = socket.socket(family, socktype, proto)
962+
# TCP_NODELAY
963+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
964+
965+
# TCP_KEEPALIVE
966+
if self.socket_keepalive:
967+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
968+
for k, v in self.socket_keepalive_options.items():
969+
sock.setsockopt(socket.IPPROTO_TCP, k, v)
970+
971+
# set the socket_connect_timeout before we connect
972+
sock.settimeout(self.socket_connect_timeout)
973+
974+
# connect
975+
sock.connect(socket_address)
976+
977+
# set the socket_timeout now that we're connected
978+
sock.settimeout(self.socket_timeout)
979+
return sock
980+
981+
except OSError as _:
982+
err = _
983+
if sock is not None:
984+
sock.close()
985+
986+
if err is not None:
987+
raise err
988+
raise OSError("socket.getaddrinfo returned an empty list")
989+
990+
def _host_error(self):
991+
return f"{self.host}:{self.port}"
992+
993+
def _error_message(self, exception):
994+
# args for socket.error can either be (errno, "message")
995+
# or just "message"
996+
997+
host_error = self._host_error()
998+
999+
if len(exception.args) == 1:
1000+
try:
1001+
return f"Error connecting to {host_error}. \
1002+
{exception.args[0]}."
1003+
except AttributeError:
1004+
return f"Connection Error: {exception.args[0]}"
1005+
else:
1006+
try:
1007+
return (
1008+
f"Error {exception.args[0]} connecting to "
1009+
f"{host_error}. {exception.args[1]}."
1010+
)
1011+
except AttributeError:
1012+
return f"Connection Error: {exception.args[0]}"
1013+
1014+
9931015
class SSLConnection(Connection):
9941016
"""Manages SSL connections to and from the Redis server(s).
9951017
This class extends the Connection class, adding SSL functionality, and making
@@ -1035,8 +1057,6 @@ def __init__(
10351057
if not ssl_available:
10361058
raise RedisError("Python wasn't built with SSL support")
10371059

1038-
super().__init__(**kwargs)
1039-
10401060
self.keyfile = ssl_keyfile
10411061
self.certfile = ssl_certfile
10421062
if ssl_cert_reqs is None:
@@ -1062,6 +1082,7 @@ def __init__(
10621082
self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
10631083
self.ssl_ocsp_context = ssl_ocsp_context
10641084
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
1085+
super().__init__(**kwargs)
10651086

10661087
def _connect(self):
10671088
"Wrap the socket with SSL support"
@@ -1131,77 +1152,12 @@ def _connect(self):
11311152
return sslsock
11321153

11331154

1134-
class UnixDomainSocketConnection(Connection):
1135-
def __init__(
1136-
self,
1137-
path="",
1138-
db=0,
1139-
username=None,
1140-
password=None,
1141-
socket_timeout=None,
1142-
encoding="utf-8",
1143-
encoding_errors="strict",
1144-
decode_responses=False,
1145-
retry_on_timeout=False,
1146-
retry_on_error=SENTINEL,
1147-
parser_class=DefaultParser,
1148-
socket_read_size=65536,
1149-
health_check_interval=0,
1150-
client_name=None,
1151-
retry=None,
1152-
redis_connect_func=None,
1153-
credential_provider: Optional[CredentialProvider] = None,
1154-
command_packer=None,
1155-
):
1156-
"""
1157-
Initialize a new UnixDomainSocketConnection.
1158-
To specify a retry policy for specific errors, first set
1159-
`retry_on_error` to a list of the error/s to retry on, then set
1160-
`retry` to a valid `Retry` object.
1161-
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
1162-
"""
1163-
if (username or password) and credential_provider is not None:
1164-
raise DataError(
1165-
"'username' and 'password' cannot be passed along with 'credential_"
1166-
"provider'. Please provide only one of the following arguments: \n"
1167-
"1. 'password' and (optional) 'username'\n"
1168-
"2. 'credential_provider'"
1169-
)
1170-
self.pid = os.getpid()
1155+
class UnixDomainSocketConnection(AbstractConnection):
1156+
"Manages UDS communication to and from a Redis server"
1157+
1158+
def __init__(self, path="", **kwargs):
11711159
self.path = path
1172-
self.db = db
1173-
self.client_name = client_name
1174-
self.credential_provider = credential_provider
1175-
self.password = password
1176-
self.username = username
1177-
self.socket_timeout = socket_timeout
1178-
self.retry_on_timeout = retry_on_timeout
1179-
if retry_on_error is SENTINEL:
1180-
retry_on_error = []
1181-
if retry_on_timeout:
1182-
# Add TimeoutError to the errors list to retry on
1183-
retry_on_error.append(TimeoutError)
1184-
self.retry_on_error = retry_on_error
1185-
if self.retry_on_error:
1186-
if retry is None:
1187-
self.retry = Retry(NoBackoff(), 1)
1188-
else:
1189-
# deep-copy the Retry object as it is mutable
1190-
self.retry = copy.deepcopy(retry)
1191-
# Update the retry's supported errors with the specified errors
1192-
self.retry.update_supported_errors(retry_on_error)
1193-
else:
1194-
self.retry = Retry(NoBackoff(), 0)
1195-
self.health_check_interval = health_check_interval
1196-
self.next_health_check = 0
1197-
self.redis_connect_func = redis_connect_func
1198-
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
1199-
self._sock = None
1200-
self._socket_read_size = socket_read_size
1201-
self.set_parser(parser_class)
1202-
self._connect_callbacks = []
1203-
self._buffer_cutoff = 6000
1204-
self._command_packer = self._construct_command_packer(command_packer)
1160+
super().__init__(**kwargs)
12051161

12061162
def repr_pieces(self):
12071163
pieces = [("path", self.path), ("db", self.db)]
@@ -1216,15 +1172,21 @@ def _connect(self):
12161172
sock.connect(self.path)
12171173
return sock
12181174

1175+
def _host_error(self):
1176+
return self.path
1177+
12191178
def _error_message(self, exception):
12201179
# args for socket.error can either be (errno, "message")
12211180
# or just "message"
1181+
host_error = self._host_error()
12221182
if len(exception.args) == 1:
1223-
return f"Error connecting to unix socket: {self.path}. {exception.args[0]}."
1183+
return (
1184+
f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
1185+
)
12241186
else:
12251187
return (
12261188
f"Error {exception.args[0]} connecting to unix socket: "
1227-
f"{self.path}. {exception.args[1]}."
1189+
f"{host_error}. {exception.args[1]}."
12281190
)
12291191

12301192

0 commit comments

Comments
 (0)
Please sign in to comment.