Skip to content

Commit

Permalink
Add nowait flag to asyncio.Connection.disconnect() (#2356)
Browse files Browse the repository at this point in the history
* Don't wait for disconnect() when handling errors.
This can result in other errors such as timeouts.

* add CHANGES

* Update redis/asyncio/connection.py

Co-authored-by: Aarni Koskela <akx@iki.fi>

* await a task to try to diagnose unittest failures in CI

Co-authored-by: Aarni Koskela <akx@iki.fi>
  • Loading branch information
kristjanvalur and akx authored Sep 29, 2022
1 parent 9fe8366 commit 652ca79
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* add `nowait` flag to `asyncio.Connection.disconnect()`
* Update README.md links
* Fix timezone handling for datetime to unixtime conversions
* Fix start_id type for XAUTOCLAIM
Expand Down
21 changes: 11 additions & 10 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ async def on_connect(self) -> None:
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Invalid Database")

async def disconnect(self) -> None:
async def disconnect(self, nowait: bool = False) -> None:
"""Disconnects from the Redis server"""
try:
async with async_timeout.timeout(self.socket_connect_timeout):
Expand All @@ -846,8 +846,9 @@ async def disconnect(self) -> None:
try:
if os.getpid() == self.pid:
self._writer.close() # type: ignore[union-attr]
# py3.6 doesn't have this method
if hasattr(self._writer, "wait_closed"):
# wait for close to finish, except when handling errors and
# forcefully disconnecting.
if not nowait:
await self._writer.wait_closed() # type: ignore[union-attr]
except OSError:
pass
Expand Down Expand Up @@ -902,10 +903,10 @@ async def send_packed_command(
self._writer.writelines(command)
await self._writer.drain()
except asyncio.TimeoutError:
await self.disconnect()
await self.disconnect(nowait=True)
raise TimeoutError("Timeout writing to socket") from None
except OSError as e:
await self.disconnect()
await self.disconnect(nowait=True)
if len(e.args) == 1:
err_no, errmsg = "UNKNOWN", e.args[0]
else:
Expand All @@ -915,7 +916,7 @@ async def send_packed_command(
f"Error {err_no} while writing to socket. {errmsg}."
) from e
except Exception:
await self.disconnect()
await self.disconnect(nowait=True)
raise

async def send_command(self, *args: Any, **kwargs: Any) -> None:
Expand All @@ -931,7 +932,7 @@ async def can_read(self, timeout: float = 0):
try:
return await self._parser.can_read(timeout)
except OSError as e:
await self.disconnect()
await self.disconnect(nowait=True)
raise ConnectionError(
f"Error while reading from {self.host}:{self.port}: {e.args}"
)
Expand All @@ -949,15 +950,15 @@ async def read_response(self, disable_decoding: bool = False):
disable_decoding=disable_decoding
)
except asyncio.TimeoutError:
await self.disconnect()
await self.disconnect(nowait=True)
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
except OSError as e:
await self.disconnect()
await self.disconnect(nowait=True)
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except Exception:
await self.disconnect()
await self.disconnect(nowait=True)
raise

if self.health_check_interval:
Expand Down
73 changes: 43 additions & 30 deletions tests/test_asyncio/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,14 +819,16 @@ async def mysetup(self, r, method):
"type": "subscribe",
}

async def mycleanup(self):
async def myfinish(self):
message = await self.messages.get()
assert message == {
"channel": b"foo",
"data": 1,
"pattern": None,
"type": "subscribe",
}

async def mykill(self):
# kill thread
async with self.cond:
self.state = 4 # quit
Expand All @@ -836,41 +838,52 @@ async def test_reconnect_socket_error(self, r: redis.Redis, method):
"""
Test that a socket error will cause reconnect
"""
async with async_timeout.timeout(self.timeout):
await self.mysetup(r, method)
# now, disconnect the connection, and wait for it to be re-established
async with self.cond:
assert self.state == 0
self.state = 1
with mock.patch.object(self.pubsub.connection, "_parser") as mockobj:
mockobj.read_response.side_effect = socket.error
mockobj.can_read.side_effect = socket.error
# wait until task noticies the disconnect until we undo the patch
await self.cond.wait_for(lambda: self.state >= 2)
assert not self.pubsub.connection.is_connected
# it is in a disconnecte state
# wait for reconnect
await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
assert self.state == 3
try:
async with async_timeout.timeout(self.timeout):
await self.mysetup(r, method)
# now, disconnect the connection, and wait for it to be re-established
async with self.cond:
assert self.state == 0
self.state = 1
with mock.patch.object(self.pubsub.connection, "_parser") as m:
m.read_response.side_effect = socket.error
m.can_read.side_effect = socket.error
# wait until task noticies the disconnect until we
# undo the patch
await self.cond.wait_for(lambda: self.state >= 2)
assert not self.pubsub.connection.is_connected
# it is in a disconnecte state
# wait for reconnect
await self.cond.wait_for(
lambda: self.pubsub.connection.is_connected
)
assert self.state == 3

await self.mycleanup()
await self.myfinish()
finally:
await self.mykill()

async def test_reconnect_disconnect(self, r: redis.Redis, method):
"""
Test that a manual disconnect() will cause reconnect
"""
async with async_timeout.timeout(self.timeout):
await self.mysetup(r, method)
# now, disconnect the connection, and wait for it to be re-established
async with self.cond:
self.state = 1
await self.pubsub.connection.disconnect()
assert not self.pubsub.connection.is_connected
# wait for reconnect
await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
assert self.state == 3

await self.mycleanup()
try:
async with async_timeout.timeout(self.timeout):
await self.mysetup(r, method)
# now, disconnect the connection, and wait for it to be re-established
async with self.cond:
self.state = 1
await self.pubsub.connection.disconnect()
assert not self.pubsub.connection.is_connected
# wait for reconnect
await self.cond.wait_for(
lambda: self.pubsub.connection.is_connected
)
assert self.state == 3

await self.myfinish()
finally:
await self.mykill()

async def loop(self):
# reader loop, performing state transitions as it
Expand Down

0 comments on commit 652ca79

Please sign in to comment.