From 72c6a593f4faae8aa5c366bf6ebe982573359c85 Mon Sep 17 00:00:00 2001 From: dolfies Date: Wed, 3 Jan 2024 14:19:15 -0500 Subject: [PATCH] Add websocket send lock --- discord/gateway.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/discord/gateway.py b/discord/gateway.py index ef54117cf8ab..2d21e4bca5e0 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -350,6 +350,7 @@ def __init__(self, socket: WebSocket, *, loop: asyncio.AbstractEventLoop) -> Non self._buffer: bytearray = bytearray() self._close_code: Optional[int] = None self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() + self._send_lock: asyncio.Lock = asyncio.Lock() @property def open(self) -> bool: @@ -666,8 +667,10 @@ async def poll_event(self) -> None: elif flags & CurlWsFlag.CLOSE: _log.debug('Received %s.', msg) err = WebSocketClosure(msg) + _log.info(f'Got close {err.code} reason {err.reason}') raise WebSocketClosure(msg) except (asyncio.TimeoutError, CurlError, WebSocketClosure) as e: + _log.info(f'Got poll exception {e}') # Ensure the keep alive handler is closed if self._keep_alive: self._keep_alive.stop() @@ -683,6 +686,8 @@ async def poll_event(self) -> None: _log.debug('Received error %s', e) reason = str(e) + _log.info(f'Got code {code} and reason {reason}') + if self._can_handle_close(code or None): _log.debug('Websocket closed with %s, attempting a reconnect.', code) raise ReconnectWebSocket from None @@ -691,7 +696,8 @@ async def poll_event(self) -> None: raise ConnectionClosed(code, reason) from None async def _sendstr(self, data: str, /) -> None: - await self.socket.asend(data.encode('utf-8')) + async with self._send_lock: + await self.socket.asend(data.encode('utf-8')) async def debug_send(self, data: str, /) -> None: await self._rate_limiter.block() @@ -893,6 +899,7 @@ async def search_recent_members( await self.send_as_json(payload) async def close(self, code: int = 4000, reason: bytes = b'') -> None: + _log.info(f'Closing websocket with code {code}') if self._keep_alive: self._keep_alive.stop() self._keep_alive = None @@ -905,6 +912,7 @@ async def close(self, code: int = 4000, reason: bytes = b'') -> None: await socket.asend(data, CurlWsFlag.CLOSE) socket.keep_running = False await self.loop.run_in_executor(None, socket.curl.close) # TODO: Do I need an executor here? + _log.info('Finished closing websocket') DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket') @@ -974,6 +982,7 @@ def __init__( self._keep_alive: Optional[VoiceKeepAliveHandler] = None self._close_code: Optional[int] = None self.secret_key: Optional[str] = None + self._send_lock: asyncio.Lock = asyncio.Lock() if hook: self._hook = hook @@ -981,7 +990,8 @@ async def _hook(self, *args: Any) -> None: pass async def _sendstr(self, data: str, /) -> None: - await self.ws.asend(data.encode('utf-8')) + async with self._send_lock: + await self.ws.asend(data.encode('utf-8')) async def send_as_json(self, data: Any) -> None: _log.debug('Voice gateway sending: %s.', data)