Skip to content

Commit

Permalink
Simplify async timeouts and allowing timeout=None in `PubSub.get_me…
Browse files Browse the repository at this point in the history
…ssage()` to wait forever (#2295)

* Avoid an extra "can_read" call and use timeout directly.

* Remove low-level read timeouts from the Parser, now handled in the Connection

* Allow pubsub.get_message(time=None) to block.

* update Changes

* increase test timeout for robustness

* expand with statement to avoid invoking null context managers.

remove nullcontext

* Remove unused import
  • Loading branch information
kristjanvalur authored Sep 29, 2022
1 parent cdbc662 commit b0883b7
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 112 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Allow `timeout=None` in `PubSub.get_message()` to wait forever
* add `nowait` flag to `asyncio.Connection.disconnect()`
* Update README.md links
* Fix timezone handling for datetime to unixtime conversions
Expand Down
22 changes: 5 additions & 17 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
cast,
)

import async_timeout

from redis.asyncio.connection import (
Connection,
ConnectionPool,
Expand Down Expand Up @@ -759,18 +757,8 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
if not conn.is_connected:
await conn.connect()

if not block:

async def read_with_timeout():
try:
async with async_timeout.timeout(timeout):
return await conn.read_response()
except asyncio.TimeoutError:
return None

response = await self._execute(conn, read_with_timeout)
else:
response = await self._execute(conn, conn.read_response)
read_timeout = None if block else timeout
response = await self._execute(conn, conn.read_response, timeout=read_timeout)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand Down Expand Up @@ -882,16 +870,16 @@ async def listen(self) -> AsyncIterator:
yield response

async def get_message(
self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
):
"""
Get the next message if one is available, otherwise None.
If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number.
number or None to wait indefinitely.
"""
response = await self.parse_response(block=False, timeout=timeout)
response = await self.parse_response(block=(timeout is None), timeout=timeout)
if response:
return await self.handle_message(response, ignore_subscribe_messages)
return None
Expand Down
138 changes: 47 additions & 91 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import copy
import enum
import errno
import inspect
import io
import os
Expand Down Expand Up @@ -55,16 +54,6 @@
if HIREDIS_AVAILABLE:
import hiredis

NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {
BlockingIOError: errno.EWOULDBLOCK,
ssl.SSLWantReadError: 2,
ssl.SSLWantWriteError: 2,
ssl.SSLError: 2,
}

NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())


SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
Expand Down Expand Up @@ -229,11 +218,9 @@ def __init__(
self,
stream_reader: asyncio.StreamReader,
socket_read_size: int,
socket_timeout: Optional[float],
):
self._stream: Optional[asyncio.StreamReader] = stream_reader
self.socket_read_size = socket_read_size
self.socket_timeout = socket_timeout
self._buffer: Optional[io.BytesIO] = io.BytesIO()
# number of bytes written to the buffer from the socket
self.bytes_written = 0
Expand All @@ -244,52 +231,35 @@ def __init__(
def length(self):
return self.bytes_written - self.bytes_read

async def _read_from_socket(
self,
length: Optional[int] = None,
timeout: Union[float, None, _Sentinel] = SENTINEL,
raise_on_timeout: bool = True,
) -> bool:
async def _read_from_socket(self, length: Optional[int] = None) -> bool:
buf = self._buffer
if buf is None or self._stream is None:
raise RedisError("Buffer is closed.")
buf.seek(self.bytes_written)
marker = 0
timeout = timeout if timeout is not SENTINEL else self.socket_timeout

try:
while True:
async with async_timeout.timeout(timeout):
data = await self._stream.read(self.socket_read_size)
# an empty string indicates the server shutdown the socket
if isinstance(data, bytes) and len(data) == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
buf.write(data)
data_length = len(data)
self.bytes_written += data_length
marker += data_length

if length is not None and length > marker:
continue
return True
except (socket.timeout, asyncio.TimeoutError):
if raise_on_timeout:
raise TimeoutError("Timeout reading from socket")
return False
except NONBLOCKING_EXCEPTIONS as ex:
# if we're in nonblocking mode and the recv raises a
# blocking error, simply return False indicating that
# there's no data to be read. otherwise raise the
# original exception.
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
if not raise_on_timeout and ex.errno == allowed:
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")
while True:
data = await self._stream.read(self.socket_read_size)
# an empty string indicates the server shutdown the socket
if isinstance(data, bytes) and len(data) == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
buf.write(data)
data_length = len(data)
self.bytes_written += data_length
marker += data_length

if length is not None and length > marker:
continue
return True

async def can_read_destructive(self) -> bool:
return bool(self.length) or await self._read_from_socket(
timeout=0, raise_on_timeout=False
)
if self.length:
return True
try:
async with async_timeout.timeout(0):
return await self._read_from_socket()
except asyncio.TimeoutError:
return False

async def read(self, length: int) -> bytes:
length = length + 2 # make sure to read the \r\n terminator
Expand Down Expand Up @@ -372,9 +342,7 @@ def on_connect(self, connection: "Connection"):
if self._stream is None:
raise RedisError("Buffer is closed.")

self._buffer = SocketBuffer(
self._stream, self._read_size, connection.socket_timeout
)
self._buffer = SocketBuffer(self._stream, self._read_size)
self.encoder = connection.encoder

def on_disconnect(self):
Expand Down Expand Up @@ -444,14 +412,13 @@ async def read_response(
class HiredisParser(BaseParser):
"""Parser class for connections using Hiredis"""

__slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout")
__slots__ = BaseParser.__slots__ + ("_reader",)

def __init__(self, socket_read_size: int):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not available.")
super().__init__(socket_read_size=socket_read_size)
self._reader: Optional[hiredis.Reader] = None
self._socket_timeout: Optional[float] = None

def on_connect(self, connection: "Connection"):
self._stream = connection._reader
Expand All @@ -464,7 +431,6 @@ def on_connect(self, connection: "Connection"):
kwargs["errors"] = connection.encoder.encoding_errors

self._reader = hiredis.Reader(**kwargs)
self._socket_timeout = connection.socket_timeout

def on_disconnect(self):
self._stream = None
Expand All @@ -475,39 +441,20 @@ async def can_read_destructive(self):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._reader.gets():
return True
return await self.read_from_socket(timeout=0, raise_on_timeout=False)

async def read_from_socket(
self,
timeout: Union[float, None, _Sentinel] = SENTINEL,
raise_on_timeout: bool = True,
):
timeout = self._socket_timeout if timeout is SENTINEL else timeout
try:
if timeout is None:
buffer = await self._stream.read(self._read_size)
else:
async with async_timeout.timeout(timeout):
buffer = await self._stream.read(self._read_size)
if not buffer or not isinstance(buffer, bytes):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
self._reader.feed(buffer)
# data was read from the socket and added to the buffer.
# return True to indicate that data was read.
return True
except (socket.timeout, asyncio.TimeoutError):
if raise_on_timeout:
raise TimeoutError("Timeout reading from socket") from None
async with async_timeout.timeout(0):
return await self.read_from_socket()
except asyncio.TimeoutError:
return False
except NONBLOCKING_EXCEPTIONS as ex:
# if we're in nonblocking mode and the recv raises a
# blocking error, simply return False indicating that
# there's no data to be read. otherwise raise the
# original exception.
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
if not raise_on_timeout and ex.errno == allowed:
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")

async def read_from_socket(self):
buffer = await self._stream.read(self._read_size)
if not buffer or not isinstance(buffer, bytes):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
self._reader.feed(buffer)
# data was read from the socket and added to the buffer.
# return True to indicate that data was read.
return True

async def read_response(
self, disable_decoding: bool = False
Expand Down Expand Up @@ -922,11 +869,16 @@ async def can_read_destructive(self):
f"Error while reading from {self.host}:{self.port}: {e.args}"
)

async def read_response(self, disable_decoding: bool = False):
async def read_response(
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
):
"""Read the response from a previously sent command"""
read_timeout = timeout if timeout is not None else self.socket_timeout
try:
if self.socket_timeout:
async with async_timeout.timeout(self.socket_timeout):
if read_timeout is not None:
async with async_timeout.timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
Expand All @@ -935,6 +887,10 @@ async def read_response(self, disable_decoding: bool = False):
disable_decoding=disable_decoding
)
except asyncio.TimeoutError:
if timeout is not None:
# user requested timeout, return None
return None
# it was a self.socket_timeout error.
await self.disconnect(nowait=True)
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
except OSError as e:
Expand Down
6 changes: 3 additions & 3 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,13 +1637,13 @@ def listen(self):
if response is not None:
yield response

def get_message(self, ignore_subscribe_messages=False, timeout=0):
def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
"""
Get the next message if one is available, otherwise None.
If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number.
number, or None, to wait indefinitely.
"""
if not self.subscribed:
# Wait for subscription
Expand All @@ -1659,7 +1659,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0):
# so no messages are available
return None

response = self.parse_response(block=False, timeout=timeout)
response = self.parse_response(block=(timeout is None), timeout=timeout)
if response:
return self.handle_message(response, ignore_subscribe_messages)
return None
Expand Down
2 changes: 1 addition & 1 deletion tests/test_asyncio/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def run(*args, **kwargs):
return wrapper


async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
async def wait_for_message(pubsub, timeout=0.2, ignore_subscribe_messages=False):
now = asyncio.get_event_loop().time()
timeout = now + timeout
while now < timeout:
Expand Down

0 comments on commit b0883b7

Please sign in to comment.