Skip to content

Commit

Permalink
fix: replace async_timeout by asyncio.timeout (#2602)
Browse files Browse the repository at this point in the history
async_timeout does not support python 3.11
aio-libs/async-timeout#295

And have two years old annoying bugs:
aio-libs/async-timeout#229
#2551

Since asyncio.timeout has been shipped in python 3.11, we should start
using it.

Partially fixes 2551
  • Loading branch information
sileht authored Mar 16, 2023
1 parent 91ab12a commit 25e85e5
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Use asyncio.timeout() instead of async_timeout.timeout() for python >= 3.11 (#2602)
* Add test and fix async HiredisParser when reading during a disconnect() (#2349)
* Use hiredis-py pack_command if available.
* Support `.unlink()` in ClusterPipeline
Expand Down
21 changes: 13 additions & 8 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import socket
import ssl
import sys
import threading
import weakref
from itertools import chain
Expand All @@ -24,7 +25,11 @@
)
from urllib.parse import ParseResult, parse_qs, unquote, urlparse

import async_timeout
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout


from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
Expand Down Expand Up @@ -242,7 +247,7 @@ async def can_read_destructive(self) -> bool:
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
async with async_timeout.timeout(0):
async with async_timeout(0):
return await self._stream.read(1)
except asyncio.TimeoutError:
return False
Expand Down Expand Up @@ -380,7 +385,7 @@ async def can_read_destructive(self):
if self._reader.gets():
return True
try:
async with async_timeout.timeout(0):
async with async_timeout(0):
return await self.read_from_socket()
except asyncio.TimeoutError:
return False
Expand Down Expand Up @@ -635,7 +640,7 @@ async def connect(self):

async def _connect(self):
"""Create a TCP socket connection"""
async with async_timeout.timeout(self.socket_connect_timeout):
async with async_timeout(self.socket_connect_timeout):
reader, writer = await asyncio.open_connection(
host=self.host,
port=self.port,
Expand Down Expand Up @@ -722,7 +727,7 @@ async def on_connect(self) -> None:
async def disconnect(self, nowait: bool = False) -> None:
"""Disconnects from the Redis server"""
try:
async with async_timeout.timeout(self.socket_connect_timeout):
async with async_timeout(self.socket_connect_timeout):
self._parser.on_disconnect()
if not self.is_connected:
return
Expand Down Expand Up @@ -827,7 +832,7 @@ async def read_response(
read_timeout = timeout if timeout is not None else self.socket_timeout
try:
if read_timeout is not None:
async with async_timeout.timeout(read_timeout):
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
Expand Down Expand Up @@ -1118,7 +1123,7 @@ def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
return pieces

async def _connect(self):
async with async_timeout.timeout(self.socket_connect_timeout):
async with async_timeout(self.socket_connect_timeout):
reader, writer = await asyncio.open_unix_connection(path=self.path)
self._reader = reader
self._writer = writer
Expand Down Expand Up @@ -1589,7 +1594,7 @@ async def get_connection(self, command_name, *keys, **options):
# self.timeout then raise a ``ConnectionError``.
connection = None
try:
async with async_timeout.timeout(self.timeout):
async with async_timeout(self.timeout):
connection = await self.pool.get()
except (asyncio.QueueEmpty, asyncio.TimeoutError):
# Note that this is not caught by the redis client and will be
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
install_requires=[
'importlib-metadata >= 1.0; python_version < "3.8"',
'typing-extensions; python_version<"3.8"',
"async-timeout>=4.0.2",
'async-timeout>=4.0.2; python_version<"3.11"',
],
classifiers=[
"Development Status :: 5 - Production/Stable",
Expand Down
22 changes: 13 additions & 9 deletions tests/test_asyncio/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from typing import Optional
from unittest.mock import patch

import async_timeout
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout

import pytest
import pytest_asyncio

Expand All @@ -21,7 +25,7 @@ def with_timeout(t):
def wrapper(corofunc):
@functools.wraps(corofunc)
async def run(*args, **kwargs):
async with async_timeout.timeout(t):
async with async_timeout(t):
return await corofunc(*args, **kwargs)

return run
Expand Down Expand Up @@ -648,7 +652,7 @@ async def test_reconnect_listen(self, r: redis.Redis, pubsub):

async def loop():
# must make sure the task exits
async with async_timeout.timeout(2):
async with async_timeout(2):
nonlocal interrupt
await pubsub.subscribe("foo")
while True:
Expand Down Expand Up @@ -677,7 +681,7 @@ async def loop_step():

task = asyncio.get_running_loop().create_task(loop())
# get the initial connect message
async with async_timeout.timeout(1):
async with async_timeout(1):
message = await messages.get()
assert message == {
"channel": b"foo",
Expand Down Expand Up @@ -776,7 +780,7 @@ def callback(message):
if n == 1:
break
await asyncio.sleep(0.1)
async with async_timeout.timeout(0.1):
async with async_timeout(0.1):
message = await messages.get()
task.cancel()
# we expect a cancelled error, not the Runtime error
Expand Down Expand Up @@ -839,7 +843,7 @@ async def test_reconnect_socket_error(self, r: redis.Redis, method):
Test that a socket error will cause reconnect
"""
try:
async with async_timeout.timeout(self.timeout):
async with async_timeout(self.timeout):
await self.mysetup(r, method)
# now, disconnect the connection, and wait for it to be re-established
async with self.cond:
Expand Down Expand Up @@ -868,7 +872,7 @@ async def test_reconnect_disconnect(self, r: redis.Redis, method):
Test that a manual disconnect() will cause reconnect
"""
try:
async with async_timeout.timeout(self.timeout):
async with async_timeout(self.timeout):
await self.mysetup(r, method)
# now, disconnect the connection, and wait for it to be re-established
async with self.cond:
Expand Down Expand Up @@ -923,7 +927,7 @@ async def loop_step_get_message(self):
async def loop_step_listen(self):
# get a single message via listen()
try:
async with async_timeout.timeout(0.1):
async with async_timeout(0.1):
async for message in self.pubsub.listen():
await self.messages.put(message)
return True
Expand All @@ -947,7 +951,7 @@ async def test_outer_timeout(self, r: redis.Redis):
assert pubsub.connection.is_connected

async def get_msg_or_timeout(timeout=0.1):
async with async_timeout.timeout(timeout):
async with async_timeout(timeout):
# blocking method to return messages
while True:
response = await pubsub.parse_response(block=True)
Expand Down

0 comments on commit 25e85e5

Please sign in to comment.