Skip to content

Commit

Permalink
Add websocket send lock
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfies committed Jan 3, 2024
1 parent 9ad6d7d commit 72c6a59
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions discord/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def __init__(self, socket: WebSocket, *, loop: asyncio.AbstractEventLoop) -> Non
self._buffer: bytearray = bytearray()
self._close_code: Optional[int] = None
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
self._send_lock: asyncio.Lock = asyncio.Lock()

@property
def open(self) -> bool:
Expand Down Expand Up @@ -666,8 +667,10 @@ async def poll_event(self) -> None:
elif flags & CurlWsFlag.CLOSE:
_log.debug('Received %s.', msg)
err = WebSocketClosure(msg)
_log.info(f'Got close {err.code} reason {err.reason}')
raise WebSocketClosure(msg)
except (asyncio.TimeoutError, CurlError, WebSocketClosure) as e:
_log.info(f'Got poll exception {e}')
# Ensure the keep alive handler is closed
if self._keep_alive:
self._keep_alive.stop()
Expand All @@ -683,6 +686,8 @@ async def poll_event(self) -> None:
_log.debug('Received error %s', e)
reason = str(e)

_log.info(f'Got code {code} and reason {reason}')

if self._can_handle_close(code or None):
_log.debug('Websocket closed with %s, attempting a reconnect.', code)
raise ReconnectWebSocket from None
Expand All @@ -691,7 +696,8 @@ async def poll_event(self) -> None:
raise ConnectionClosed(code, reason) from None

async def _sendstr(self, data: str, /) -> None:
await self.socket.asend(data.encode('utf-8'))
async with self._send_lock:
await self.socket.asend(data.encode('utf-8'))

async def debug_send(self, data: str, /) -> None:
await self._rate_limiter.block()
Expand Down Expand Up @@ -893,6 +899,7 @@ async def search_recent_members(
await self.send_as_json(payload)

async def close(self, code: int = 4000, reason: bytes = b'') -> None:
_log.info(f'Closing websocket with code {code}')
if self._keep_alive:
self._keep_alive.stop()
self._keep_alive = None
Expand All @@ -905,6 +912,7 @@ async def close(self, code: int = 4000, reason: bytes = b'') -> None:
await socket.asend(data, CurlWsFlag.CLOSE)
socket.keep_running = False
await self.loop.run_in_executor(None, socket.curl.close) # TODO: Do I need an executor here?
_log.info('Finished closing websocket')


DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket')
Expand Down Expand Up @@ -974,14 +982,16 @@ def __init__(
self._keep_alive: Optional[VoiceKeepAliveHandler] = None
self._close_code: Optional[int] = None
self.secret_key: Optional[str] = None
self._send_lock: asyncio.Lock = asyncio.Lock()
if hook:
self._hook = hook

async def _hook(self, *args: Any) -> None:
pass

async def _sendstr(self, data: str, /) -> None:
await self.ws.asend(data.encode('utf-8'))
async with self._send_lock:
await self.ws.asend(data.encode('utf-8'))

async def send_as_json(self, data: Any) -> None:
_log.debug('Voice gateway sending: %s.', data)
Expand Down

0 comments on commit 72c6a59

Please sign in to comment.