Skip to content

Commit a96a38a

Browse files
authored
Add support for PubSub with RESP3 parser (#2721)
* add resp3 pubsub * linters * _set_info_logger func * async pubsun * docstring
1 parent 0db4eba commit a96a38a

File tree

8 files changed

+197
-30
lines changed

8 files changed

+197
-30
lines changed

redis/asyncio/client.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
WatchError,
5858
)
5959
from redis.typing import ChannelT, EncodableT, KeyT
60-
from redis.utils import safe_str, str_if_bytes
60+
from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes
6161

6262
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
6363
_KeyT = TypeVar("_KeyT", bound=KeyT)
@@ -658,6 +658,7 @@ def __init__(
658658
shard_hint: Optional[str] = None,
659659
ignore_subscribe_messages: bool = False,
660660
encoder=None,
661+
push_handler_func: Optional[Callable] = None,
661662
):
662663
self.connection_pool = connection_pool
663664
self.shard_hint = shard_hint
@@ -666,6 +667,7 @@ def __init__(
666667
# we need to know the encoding options for this connection in order
667668
# to lookup channel and pattern names for callback handlers.
668669
self.encoder = encoder
670+
self.push_handler_func = push_handler_func
669671
if self.encoder is None:
670672
self.encoder = self.connection_pool.get_encoder()
671673
if self.encoder.decode_responses:
@@ -678,6 +680,8 @@ def __init__(
678680
b"pong",
679681
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
680682
]
683+
if self.push_handler_func is None:
684+
_set_info_logger()
681685
self.channels = {}
682686
self.pending_unsubscribe_channels = set()
683687
self.patterns = {}
@@ -757,6 +761,8 @@ async def connect(self):
757761
self.connection.register_connect_callback(self.on_connect)
758762
else:
759763
await self.connection.connect()
764+
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
765+
self.connection._parser.set_push_handler(self.push_handler_func)
760766

761767
async def _disconnect_raise_connect(self, conn, error):
762768
"""
@@ -797,7 +803,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
797803
await conn.connect()
798804

799805
read_timeout = None if block else timeout
800-
response = await self._execute(conn, conn.read_response, timeout=read_timeout)
806+
response = await self._execute(
807+
conn, conn.read_response, timeout=read_timeout, push_request=True
808+
)
801809

802810
if conn.health_check_interval and response == self.health_check_response:
803811
# ignore the health check message as user might not expect it
@@ -927,15 +935,19 @@ def ping(self, message=None) -> Awaitable:
927935
"""
928936
Ping the Redis server
929937
"""
930-
message = "" if message is None else message
931-
return self.execute_command("PING", message)
938+
args = ["PING", message] if message is not None else ["PING"]
939+
return self.execute_command(*args)
932940

933941
async def handle_message(self, response, ignore_subscribe_messages=False):
934942
"""
935943
Parses a pub/sub message. If the channel or pattern was subscribed to
936944
with a message handler, the handler is invoked instead of a parsed
937945
message being returned.
938946
"""
947+
if response is None:
948+
return None
949+
if isinstance(response, bytes):
950+
response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
939951
message_type = str_if_bytes(response[0])
940952
if message_type == "pmessage":
941953
message = {

redis/asyncio/connection.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -485,15 +485,29 @@ async def read_response(
485485
self,
486486
disable_decoding: bool = False,
487487
timeout: Optional[float] = None,
488+
push_request: Optional[bool] = False,
488489
):
489490
"""Read the response from a previously sent command"""
490491
read_timeout = timeout if timeout is not None else self.socket_timeout
491492
try:
492-
if read_timeout is not None:
493+
if (
494+
read_timeout is not None
495+
and self.protocol == "3"
496+
and not HIREDIS_AVAILABLE
497+
):
498+
async with async_timeout(read_timeout):
499+
response = await self._parser.read_response(
500+
disable_decoding=disable_decoding, push_request=push_request
501+
)
502+
elif read_timeout is not None:
493503
async with async_timeout(read_timeout):
494504
response = await self._parser.read_response(
495505
disable_decoding=disable_decoding
496506
)
507+
elif self.protocol == "3" and not HIREDIS_AVAILABLE:
508+
response = await self._parser.read_response(
509+
disable_decoding=disable_decoding, push_request=push_request
510+
)
497511
else:
498512
response = await self._parser.read_response(
499513
disable_decoding=disable_decoding

redis/client.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828
from redis.lock import Lock
2929
from redis.retry import Retry
30-
from redis.utils import safe_str, str_if_bytes
30+
from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes
3131

3232
SYM_EMPTY = b""
3333
EMPTY_RESPONSE = "EMPTY_RESPONSE"
@@ -1429,6 +1429,7 @@ def __init__(
14291429
shard_hint=None,
14301430
ignore_subscribe_messages=False,
14311431
encoder=None,
1432+
push_handler_func=None,
14321433
):
14331434
self.connection_pool = connection_pool
14341435
self.shard_hint = shard_hint
@@ -1438,13 +1439,16 @@ def __init__(
14381439
# we need to know the encoding options for this connection in order
14391440
# to lookup channel and pattern names for callback handlers.
14401441
self.encoder = encoder
1442+
self.push_handler_func = push_handler_func
14411443
if self.encoder is None:
14421444
self.encoder = self.connection_pool.get_encoder()
14431445
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
14441446
if self.encoder.decode_responses:
14451447
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
14461448
else:
14471449
self.health_check_response = [b"pong", self.health_check_response_b]
1450+
if self.push_handler_func is None:
1451+
_set_info_logger()
14481452
self.reset()
14491453

14501454
def __enter__(self):
@@ -1515,6 +1519,8 @@ def execute_command(self, *args):
15151519
# register a callback that re-subscribes to any channels we
15161520
# were listening to when we were disconnected
15171521
self.connection.register_connect_callback(self.on_connect)
1522+
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
1523+
self.connection._parser.set_push_handler(self.push_handler_func)
15181524
connection = self.connection
15191525
kwargs = {"check_health": not self.subscribed}
15201526
if not self.subscribed:
@@ -1580,7 +1586,7 @@ def try_read():
15801586
return None
15811587
else:
15821588
conn.connect()
1583-
return conn.read_response()
1589+
return conn.read_response(push_request=True)
15841590

15851591
response = self._execute(conn, try_read)
15861592

@@ -1739,8 +1745,8 @@ def ping(self, message=None):
17391745
"""
17401746
Ping the Redis server
17411747
"""
1742-
message = "" if message is None else message
1743-
return self.execute_command("PING", message)
1748+
args = ["PING", message] if message is not None else ["PING"]
1749+
return self.execute_command(*args)
17441750

17451751
def handle_message(self, response, ignore_subscribe_messages=False):
17461752
"""
@@ -1750,6 +1756,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
17501756
"""
17511757
if response is None:
17521758
return None
1759+
if isinstance(response, bytes):
1760+
response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
17531761
message_type = str_if_bytes(response[0])
17541762
if message_type == "pmessage":
17551763
message = {

redis/connection.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,18 @@ def can_read(self, timeout=0):
406406
self.disconnect()
407407
raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
408408

409-
def read_response(self, disable_decoding=False):
409+
def read_response(self, disable_decoding=False, push_request=False):
410410
"""Read the response from a previously sent command"""
411411

412412
host_error = self._host_error()
413413

414414
try:
415-
response = self._parser.read_response(disable_decoding=disable_decoding)
415+
if self.protocol == "3" and not HIREDIS_AVAILABLE:
416+
response = self._parser.read_response(
417+
disable_decoding=disable_decoding, push_request=push_request
418+
)
419+
else:
420+
response = self._parser.read_response(disable_decoding=disable_decoding)
416421
except socket.timeout:
417422
self.disconnect()
418423
raise TimeoutError(f"Timeout reading from {host_error}")
@@ -705,8 +710,9 @@ def _connect(self):
705710
class UnixDomainSocketConnection(AbstractConnection):
706711
"Manages UDS communication to and from a Redis server"
707712

708-
def __init__(self, path="", **kwargs):
713+
def __init__(self, path="", socket_timeout=None, **kwargs):
709714
self.path = path
715+
self.socket_timeout = socket_timeout
710716
super().__init__(**kwargs)
711717

712718
def repr_pieces(self):

redis/parsers/resp3.py

+74-7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from logging import getLogger
12
from typing import Any, Union
23

34
from ..exceptions import ConnectionError, InvalidResponse, ResponseError
@@ -9,18 +10,29 @@
910
class _RESP3Parser(_RESPBase):
1011
"""RESP3 protocol implementation"""
1112

12-
def read_response(self, disable_decoding=False):
13+
def __init__(self, socket_read_size):
14+
super().__init__(socket_read_size)
15+
self.push_handler_func = self.handle_push_response
16+
17+
def handle_push_response(self, response):
18+
logger = getLogger("push_response")
19+
logger.info("Push response: " + str(response))
20+
return response
21+
22+
def read_response(self, disable_decoding=False, push_request=False):
1323
pos = self._buffer.get_pos()
1424
try:
15-
result = self._read_response(disable_decoding=disable_decoding)
25+
result = self._read_response(
26+
disable_decoding=disable_decoding, push_request=push_request
27+
)
1628
except BaseException:
1729
self._buffer.rewind(pos)
1830
raise
1931
else:
2032
self._buffer.purge()
2133
return result
2234

23-
def _read_response(self, disable_decoding=False):
35+
def _read_response(self, disable_decoding=False, push_request=False):
2436
raw = self._buffer.readline()
2537
if not raw:
2638
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
@@ -77,31 +89,64 @@ def _read_response(self, disable_decoding=False):
7789
response = {
7890
self._read_response(
7991
disable_decoding=disable_decoding
80-
): self._read_response(disable_decoding=disable_decoding)
92+
): self._read_response(
93+
disable_decoding=disable_decoding, push_request=push_request
94+
)
8195
for _ in range(int(response))
8296
}
97+
# push response
98+
elif byte == b">":
99+
response = [
100+
self._read_response(
101+
disable_decoding=disable_decoding, push_request=push_request
102+
)
103+
for _ in range(int(response))
104+
]
105+
res = self.push_handler_func(response)
106+
if not push_request:
107+
return self._read_response(
108+
disable_decoding=disable_decoding, push_request=push_request
109+
)
110+
else:
111+
return res
83112
else:
84113
raise InvalidResponse(f"Protocol Error: {raw!r}")
85114

86115
if isinstance(response, bytes) and disable_decoding is False:
87116
response = self.encoder.decode(response)
88117
return response
89118

119+
def set_push_handler(self, push_handler_func):
120+
self.push_handler_func = push_handler_func
121+
90122

91123
class _AsyncRESP3Parser(_AsyncRESPBase):
92-
async def read_response(self, disable_decoding: bool = False):
124+
def __init__(self, socket_read_size):
125+
super().__init__(socket_read_size)
126+
self.push_handler_func = self.handle_push_response
127+
128+
def handle_push_response(self, response):
129+
logger = getLogger("push_response")
130+
logger.info("Push response: " + str(response))
131+
return response
132+
133+
async def read_response(
134+
self, disable_decoding: bool = False, push_request: bool = False
135+
):
93136
if self._chunks:
94137
# augment parsing buffer with previously read data
95138
self._buffer += b"".join(self._chunks)
96139
self._chunks.clear()
97140
self._pos = 0
98-
response = await self._read_response(disable_decoding=disable_decoding)
141+
response = await self._read_response(
142+
disable_decoding=disable_decoding, push_request=push_request
143+
)
99144
# Successfully parsing a response allows us to clear our parsing buffer
100145
self._clear()
101146
return response
102147

103148
async def _read_response(
104-
self, disable_decoding: bool = False
149+
self, disable_decoding: bool = False, push_request: bool = False
105150
) -> Union[EncodableT, ResponseError, None]:
106151
if not self._stream or not self.encoder:
107152
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
@@ -166,9 +211,31 @@ async def _read_response(
166211
)
167212
for _ in range(int(response))
168213
}
214+
# push response
215+
elif byte == b">":
216+
response = [
217+
(
218+
await self._read_response(
219+
disable_decoding=disable_decoding, push_request=push_request
220+
)
221+
)
222+
for _ in range(int(response))
223+
]
224+
res = self.push_handler_func(response)
225+
if not push_request:
226+
return await (
227+
self._read_response(
228+
disable_decoding=disable_decoding, push_request=push_request
229+
)
230+
)
231+
else:
232+
return res
169233
else:
170234
raise InvalidResponse(f"Protocol Error: {raw!r}")
171235

172236
if isinstance(response, bytes) and disable_decoding is False:
173237
response = self.encoder.decode(response)
174238
return response
239+
240+
def set_push_handler(self, push_handler_func):
241+
self.push_handler_func = push_handler_func

redis/utils.py

+14
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from contextlib import contextmanager
23
from functools import wraps
34
from typing import Any, Dict, Mapping, Union
@@ -117,3 +118,16 @@ def wrapper(*args, **kwargs):
117118
return wrapper
118119

119120
return decorator
121+
122+
123+
def _set_info_logger():
124+
"""
125+
Set up a logger that log info logs to stdout.
126+
(This is used by the default push response handler)
127+
"""
128+
if "push_response" not in logging.root.manager.loggerDict.keys():
129+
logger = logging.getLogger("push_response")
130+
logger.setLevel(logging.INFO)
131+
handler = logging.StreamHandler()
132+
handler.setLevel(logging.INFO)
133+
logger.addHandler(handler)

0 commit comments

Comments
 (0)