diff --git a/changes/1744.bugfix.md b/changes/1744.bugfix.md new file mode 100644 index 0000000000..3596046f36 --- /dev/null +++ b/changes/1744.bugfix.md @@ -0,0 +1 @@ +Ensure shard connect and disconnect always get sent in pairs and properly waited for diff --git a/hikari/events/shard_events.py b/hikari/events/shard_events.py index e427ac532b..4c79024db8 100644 --- a/hikari/events/shard_events.py +++ b/hikari/events/shard_events.py @@ -100,7 +100,7 @@ class ShardStateEvent(ShardEvent, abc.ABC): @attrs_extensions.with_copy @attrs.define(kw_only=True, weakref_slot=False) class ShardConnectedEvent(ShardStateEvent): - """Event fired when a shard connects.""" + """Event fired when a shard successfully connects.""" app: traits.RESTAware = attrs.field(metadata={attrs_extensions.SKIP_DEEP_COPY: True}) # <>. diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index fbf8ae43f0..a24ef7bf40 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -731,7 +731,7 @@ async def _request( await aio.first_completed(request_task, self._close_event.wait()) - if not self._close_event.is_set(): + if not request_task.cancelled(): return request_task.result() raise errors.ComponentStateConflictError("The REST client was closed mid-request") diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index 5738941cf2..7bbe4241a9 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -818,7 +818,6 @@ async def _connect(self) -> typing.Tuple[asyncio.Task[None], ...]: dumps=self._dumps, url=url, ) - self._event_manager.dispatch(self._event_factory.deserialize_connected_event(self)) # Expect initial HELLO hello_payload = await self._ws.receive_json() @@ -893,6 +892,7 @@ async def _keep_alive(self) -> None: if not self._handshake_event.is_set(): continue + await self._event_manager.dispatch(self._event_factory.deserialize_connected_event(self)) await aio.first_completed(*lifetime_tasks) # Since nothing went wrong, we can reset the backoff and try again @@ -957,7 +957,9 @@ async def _keep_alive(self) -> None: else: await ws.send_close(code=_RESUME_CLOSE_CODE, message=b"shard disconnecting temporarily") - self._event_manager.dispatch(self._event_factory.deserialize_disconnected_event(self)) + if self._handshake_event.is_set(): + # We dispatched the connected event, so we can dispatch the disconnected one too + await self._event_manager.dispatch(self._event_factory.deserialize_disconnected_event(self)) def _serialize_and_store_presence_payload( self, diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py index 5b10ab1b0a..af1bc377e8 100644 --- a/tests/hikari/impl/test_shard.py +++ b/tests/hikari/impl/test_shard.py @@ -1014,10 +1014,6 @@ async def test__connect_when_not_reconnecting(self, client, http_settings, proxy dumps=client._dumps, url="wss://somewhere.com?somewhere=true&v=400&encoding=json", ) - client._event_factory.deserialize_connected_event.assert_called_once_with(client) - client._event_manager.dispatch.assert_called_once_with( - client._event_factory.deserialize_connected_event.return_value - ) assert create_task.call_count == 2 create_task.assert_has_calls( @@ -1103,10 +1099,6 @@ async def test__connect_when_reconnecting(self, client, http_settings, proxy_set transport_compression=True, url="wss://notsomewhere.com?somewhere=true&v=400&encoding=json&compress=zlib-stream", ) - client._event_factory.deserialize_connected_event.assert_called_once_with(client) - client._event_manager.dispatch.assert_called_once_with( - client._event_factory.deserialize_connected_event.return_value - ) assert create_task.call_count == 2 create_task.assert_has_calls(