Skip to content

Commit

Permalink
Attempt to reconnect on a gateway TimeoutError
Browse files Browse the repository at this point in the history
  - Improve `GatewayShard.is_alive` detection
  - Cleanup code for `GatewayTransport.send_close`
  • Loading branch information
davfsa committed Feb 8, 2022
1 parent 92f5be6 commit f993071
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 87 deletions.
1 change: 1 addition & 0 deletions changes/1014.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Attempt to reconnect on a gateway `TimeoutError`.
38 changes: 20 additions & 18 deletions hikari/impl/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,21 +135,22 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self.logger: logging.Logger
self.log_filterer: typing.Callable[[str], str]

async def send_close(self, *, code: int = 1000, message: bytes = b"") -> bool:
async def send_close(self, *, code: int = 1000, message: bytes = b"") -> None:
# aiohttp may close the socket by invoking close() internally. By giving
# a different name, we can ensure aiohttp won't invoke this method.
# We can then guarantee any call to this method was made by us, as
# opposed to, for example, Windows injecting a spurious EOF when
# something disconnects, which makes aiohttp just shut down as if we
# did it.
if not self.sent_close:
self.sent_close = True
self.logger.debug("sending close frame with code %s and message %s", int(code), message)
try:
return await asyncio.wait_for(super().close(code=code, message=message), timeout=5)
except asyncio.TimeoutError:
self.logger.debug("failed to send close frame in time, probably connection issues")
return False
if self.sent_close:
return

self.sent_close = True
self.logger.debug("sending close frame with code %s and message %s", int(code), message)
try:
await asyncio.wait_for(super().close(code=code, message=message), timeout=5)
except asyncio.TimeoutError:
self.logger.debug("failed to send close frame in time, probably connection issues")

async def receive_json(
self,
Expand Down Expand Up @@ -296,17 +297,18 @@ async def connect(
message=b"client is shutting down",
)

except (aiohttp.ClientOSError, aiohttp.ClientConnectionError, aiohttp.WSServerHandshakeError) as ex:
except (
aiohttp.ClientOSError,
aiohttp.ClientConnectionError,
aiohttp.WSServerHandshakeError,
asyncio.TimeoutError,
) as ex:
# Windows will sometimes raise an aiohttp.ClientOSError
# If we cannot do DNS lookup, this will fail with a ClientConnectionError
# usually.
# usually, but it might also fail with asyncio.TimeoutError if its gets stuck in a weird way.
#
# aiohttp.WSServerHandshakeError has a really bad str, so we use the repr instead.
if isinstance(ex, aiohttp.WSServerHandshakeError):
reason = repr(ex)
else:
reason = str(ex)

# aiohttp.WSServerHandshakeError has a really bad str, so we use the repr instead
reason = repr(ex) if isinstance(ex, aiohttp.WSServerHandshakeError) else str(ex)
raise errors.GatewayConnectionError(reason) from None

finally:
Expand Down Expand Up @@ -495,7 +497,7 @@ def intents(self) -> intents_.Intents:

@property
def is_alive(self) -> bool:
return self._run_task is not None and not self._run_task.done()
return self._ws is not None and not self._ws.sent_close

@property
def shard_count(self) -> int:
Expand Down
94 changes: 25 additions & 69 deletions tests/hikari/impl/test_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import contextlib
import datetime
import platform
import re

import aiohttp
import mock
Expand All @@ -35,7 +36,6 @@
from hikari import presences
from hikari import undefined
from hikari.impl import shard
from hikari.internal import aio
from hikari.internal import time
from tests.hikari import client_session_stub
from tests.hikari import hikari_test_helpers
Expand Down Expand Up @@ -94,7 +94,7 @@ async def test_send_close_when_not_closed_nor_closing_logs(self, transport_impl)

with mock.patch.object(aiohttp.ClientWebSocketResponse, "close", new=mock.Mock()) as close:
with mock.patch.object(asyncio, "wait_for", return_value=mock.AsyncMock()) as wait_for:
assert await transport_impl.send_close(code=1234, message=b"some message") is wait_for.return_value
await transport_impl.send_close(code=1234, message=b"some message")

wait_for.assert_awaited_once_with(close.return_value, timeout=5)
close.assert_called_once_with(code=1234, message=b"some message")
Expand All @@ -104,7 +104,7 @@ async def test_send_close_when_TimeoutError(self, transport_impl):
transport_impl.sent_close = False

with mock.patch.object(aiohttp.ClientWebSocketResponse, "close", side_effect=asyncio.TimeoutError) as close:
assert await transport_impl.send_close(code=1234, message=b"some message") is False
await transport_impl.send_close(code=1234, message=b"some message")

close.assert_called_once_with(code=1234, message=b"some message")

Expand Down Expand Up @@ -468,54 +468,28 @@ def __init__(self):
mock_websocket.assert_used_once()

@pytest.mark.asyncio()
async def test_connect_when_error_connecting(self, http_settings, proxy_settings):
mock_client_session = hikari_test_helpers.AsyncContextManagerMock()
mock_client_session.ws_connect = mock.MagicMock(side_effect=aiohttp.ClientConnectionError("some error"))

stack = contextlib.ExitStack()
sleep = stack.enter_context(mock.patch.object(asyncio, "sleep"))
stack.enter_context(mock.patch.object(aiohttp, "ClientSession", return_value=mock_client_session))
stack.enter_context(mock.patch.object(aiohttp, "TCPConnector"))
stack.enter_context(mock.patch.object(aiohttp, "ClientTimeout"))
stack.enter_context(
pytest.raises(errors.GatewayConnectionError, match=r"Failed to connect to server: 'some error'")
)
logger = mock.Mock()
log_filterer = mock.Mock()

with stack:
async with shard._GatewayTransport.connect(
http_settings=http_settings,
proxy_settings=proxy_settings,
logger=logger,
url="https://some.url",
log_filterer=log_filterer,
):
pass

sleep.assert_awaited_once_with(0.25)
mock_client_session.assert_used_once()

@pytest.mark.asyncio()
async def test_connect_when_handshake_error_with_unknown_reason(self, http_settings, proxy_settings):
@pytest.mark.parametrize(
"error",
[
aiohttp.WSServerHandshakeError(status=123, message="some error", request_info=None, history=None),
aiohttp.ClientOSError("some error"),
aiohttp.ClientConnectionError("some error"),
asyncio.TimeoutError("some error"),
],
)
async def test_connect_when_expected_error_connecting(self, http_settings, proxy_settings, error):
mock_client_session = hikari_test_helpers.AsyncContextManagerMock()
mock_client_session.ws_connect = mock.MagicMock(
side_effect=aiohttp.WSServerHandshakeError(
status=123, message="some error", request_info=None, history=None
)
)
mock_client_session.ws_connect = mock.MagicMock(side_effect=error)

stack = contextlib.ExitStack()
sleep = stack.enter_context(mock.patch.object(asyncio, "sleep"))
stack.enter_context(mock.patch.object(aiohttp, "ClientSession", return_value=mock_client_session))
stack.enter_context(mock.patch.object(aiohttp, "TCPConnector"))
stack.enter_context(mock.patch.object(aiohttp, "ClientTimeout"))
err_string = repr(error) if isinstance(error, aiohttp.WSServerHandshakeError) else str(error)
stack.enter_context(
pytest.raises(
errors.GatewayConnectionError,
match=(
r'Failed to connect to server: "WSServerHandshakeError\(None, None, status=123, message=\'some error\'\)"'
),
errors.GatewayConnectionError, match=re.escape(f"Failed to connect to server: {err_string!r}")
)
)
logger = mock.Mock()
Expand All @@ -535,27 +509,16 @@ async def test_connect_when_handshake_error_with_unknown_reason(self, http_setti
mock_client_session.assert_used_once()

@pytest.mark.asyncio()
async def test_connect_when_handshake_error_with_known_reason(self, http_settings, proxy_settings):
async def test_connect_when_unexpected_error_connecting(self, http_settings, proxy_settings):
mock_client_session = hikari_test_helpers.AsyncContextManagerMock()
mock_client_session.ws_connect = mock.MagicMock(
side_effect=aiohttp.WSServerHandshakeError(
status=500, message="some error", request_info=None, history=None
)
)
mock_client_session.ws_connect = mock.MagicMock(side_effect=RuntimeError("in tests"))

stack = contextlib.ExitStack()
sleep = stack.enter_context(mock.patch.object(asyncio, "sleep"))
stack.enter_context(mock.patch.object(aiohttp, "ClientSession", return_value=mock_client_session))
stack.enter_context(mock.patch.object(aiohttp, "TCPConnector"))
stack.enter_context(mock.patch.object(aiohttp, "ClientTimeout"))
stack.enter_context(
pytest.raises(
errors.GatewayConnectionError,
match=(
r'Failed to connect to server: "WSServerHandshakeError\(None, None, status=500, message=\'some error\'\)"'
),
)
)
stack.enter_context(pytest.raises(RuntimeError, match="in tests"))
logger = mock.Mock()
log_filterer = mock.Mock()

Expand Down Expand Up @@ -643,19 +606,12 @@ def test_intents_property(self, client):
client._intents = mock_intents
assert client.intents is mock_intents

def test_is_alive_property(self, client):
client._run_task = None
assert client.is_alive is False

@pytest.mark.asyncio()
async def test_is_alive_property_with_active_future(self, client):
client._run_task = asyncio.get_running_loop().create_future()
assert client.is_alive is True

@pytest.mark.asyncio()
async def test_is_alive_property_with_finished_future(self, client):
client._run_task = aio.completed_future()
assert client.is_alive is False
@pytest.mark.parametrize(
("ws", "expected"), [(None, False), (mock.Mock(sent_close=True), False), (mock.Mock(sent_close=False), True)]
)
def test_is_alive_property(self, client, ws, expected):
client._ws = ws
assert client.is_alive is expected

def test_shard_count_property(self, client):
client._shard_count = 69
Expand Down

0 comments on commit f993071

Please sign in to comment.