Skip to content

Commit

Permalink
Initial curl websocket migration
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfies committed Jan 1, 2024
1 parent 352ef23 commit 77bc225
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 79 deletions.
12 changes: 6 additions & 6 deletions discord/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from .utils import _get_as_snowflake

if TYPE_CHECKING:
from aiohttp import ClientResponse, ClientWebSocketResponse
from curl_cffi.requests import Response as CurlResponse
from aiohttp import ClientResponse
from curl_cffi.requests import Response as CurlResponse, WebSocket

Check failure on line 33 in discord/errors.py

View workflow job for this annotation

GitHub Actions / check

Import "WebSocket" is not accessed (reportUnusedImport)
from requests import Response
from typing_extensions import TypeGuard

Expand Down Expand Up @@ -301,10 +301,10 @@ class ConnectionClosed(ClientException):

__slots__ = ('code', 'reason')

def __init__(self, socket: ClientWebSocketResponse, *, code: Optional[int] = None):
def __init__(self, code: Optional[int] = None, reason: Optional[str] = None):
# This exception is just the same exception except
# reconfigured to subclass ClientException for users
self.code: int = code or socket.close_code or -1
self.code: int = code or -1
# aiohttp doesn't seem to consistently provide close reason
self.reason: str = ''
super().__init__(f'WebSocket closed with {self.code}')
self.reason: str = reason or ''
super().__init__(f'WebSocket closed with {self.code} (reason: {self.reason!r})')
141 changes: 88 additions & 53 deletions discord/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@

from typing import Any, Callable, Coroutine, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar

import aiohttp
from curl_cffi import CurlError
from curl_cffi.requests import WebSocket
from curl_cffi.const import CurlWsFlag
import yarl

from . import utils
from .activity import BaseActivity, Spotify
from .enums import SpeakingState
from .enums import SpeakingState, Status
from .errors import ConnectionClosed
from .flags import Capabilities

Expand All @@ -58,7 +60,6 @@

from .activity import ActivityTypes
from .client import Client
from .enums import Status
from .state import ConnectionState
from .types.snowflake import Snowflake
from .voice_client import VoiceClient
Expand All @@ -73,9 +74,25 @@ def __init__(self, *, resume: bool = True):


class WebSocketClosure(Exception):
"""An exception to make up for the fact that aiohttp doesn't signal closure."""
"""An exception to make up for the fact that curl doesn't signal closure.
pass
Attributes
-----------
code: :class:`int`
The close code of the websocket.
reason: :class:`str`
The reason provided for the closure.
"""

__slots__ = ('code', 'reason')

CLOSE_CODE = struct.Struct("!H")

def __init__(self, msg: bytes):
# HACK: Unpack code and reason from raw message
self.code: int = self.CLOSE_CODE.unpack(msg[:2])[0]
self.reason: str = msg[2:].decode('utf-8')
super().__init__(f'WebSocket closed with {self.code} (reason: {self.reason!r})')


class EventListener(NamedTuple):
Expand Down Expand Up @@ -256,7 +273,7 @@ class DiscordWebSocket:
RECONNECT
Receive only. Tells the client to reconnect to a new gateway.
REQUEST_MEMBERS
Send only. Asks for the guild members.
Send only. Asks for the guild members. Responds with GUILD_MEMBERS_CHUNK.
INVALIDATE_SESSION
Receive only. Tells the client to optionally invalidate the session
and IDENTIFY again.
Expand All @@ -265,14 +282,14 @@ class DiscordWebSocket:
HEARTBEAT_ACK
Receive only. Confirms receiving of a heartbeat. Not having it implies
a connection issue.
GUILD_SYNC
Send only. Requests a guild sync. This is unfortunately no longer functional.
CALL_CONNECT
Send only. Maybe used for calling? Probably just tracking.
Send only. Requests an existing call on a channel. Might respond with CALL_CREATE.
GUILD_SUBSCRIBE
Send only. Subscribes you to guilds/guild members. Might respond with GUILD_MEMBER_LIST_UPDATE.
REQUEST_COMMANDS
Send only. Requests application commands from a guild. Responds with GUILD_APPLICATION_COMMANDS_UPDATE.
SEARCH_RECENT_MEMBERS
Send only. Searches for recent members in a guild. Responds with GUILD_MEMBERS_CHUNK.
gateway
The gateway we are currently connected to.
token
Expand Down Expand Up @@ -314,8 +331,8 @@ class DiscordWebSocket:
SEARCH_RECENT_MEMBERS = 35
# fmt: on

def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None:
self.socket: aiohttp.ClientWebSocketResponse = socket
def __init__(self, socket: WebSocket, *, loop: asyncio.AbstractEventLoop) -> None:
self.socket: WebSocket = socket
self.loop: asyncio.AbstractEventLoop = loop

# An empty dispatcher to prevent crashes
Expand All @@ -336,7 +353,7 @@ def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.Abs

@property
def open(self) -> bool:
return not self.socket.closed
return self.socket.curl._curl is not None

@property
def capabilities(self) -> Capabilities:
Expand Down Expand Up @@ -631,8 +648,7 @@ def latency(self) -> float:
heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency

def _can_handle_close(self) -> bool:
code = self._close_code or self.socket.close_code
def _can_handle_close(self, code: Optional[int]) -> bool:
return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)

async def poll_event(self) -> None:
Expand All @@ -644,58 +660,62 @@ async def poll_event(self) -> None:
The websocket connection was terminated for unhandled reasons.
"""
try:
msg = await self.socket.receive(timeout=self._max_heartbeat_timeout)
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(msg.data)
elif msg.type is aiohttp.WSMsgType.BINARY:
await self.received_message(msg.data)
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s.', msg)
raise msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE):
msg, flags = await asyncio.wait_for(self.socket.arecv(), timeout=self._max_heartbeat_timeout)
if (flags & CurlWsFlag.TEXT) or (flags & CurlWsFlag.BINARY):
await self.received_message(msg)
elif flags & CurlWsFlag.CLOSE:
_log.debug('Received %s.', msg)
raise WebSocketClosure
except (asyncio.TimeoutError, WebSocketClosure) as e:
err = WebSocketClosure(msg)
raise WebSocketClosure(msg)
except (asyncio.TimeoutError, CurlError, WebSocketClosure) as e:
# Ensure the keep alive handler is closed
if self._keep_alive:
self._keep_alive.stop()
self._keep_alive = None

if isinstance(e, asyncio.TimeoutError):
if isinstance(e, asyncio.TimeoutError): # is this also CancelledError??
_log.debug('Timed out receiving packet. Attempting a reconnect.')
raise ReconnectWebSocket from None

code = self._close_code or self.socket.close_code
if self._can_handle_close():
code = self._close_code or getattr(e, 'code', None)
reason = getattr(e, 'reason', None)
if isinstance(e, CurlError):
_log.debug('Received error %s', e)
reason = str(e)

if self._can_handle_close(code or None):
_log.debug('Websocket closed with %s, attempting a reconnect.', code)
raise ReconnectWebSocket from None
else:
_log.debug('Websocket closed with %s, cannot reconnect.', code)
raise ConnectionClosed(self.socket, code=code) from None
raise ConnectionClosed(code, reason) from None

async def _sendstr(self, data: str, /) -> None:
await self.socket.asend(data.encode('utf-8'))

async def debug_send(self, data: str, /) -> None:
await self._rate_limiter.block()
self._dispatch('socket_raw_send', data)
await self.socket.send_str(data)
await self._sendstr(data)

async def send(self, data: str, /) -> None:
await self._rate_limiter.block()
await self.socket.send_str(data)
await self._sendstr(data)

async def send_as_json(self, data: Any) -> None:
try:
await self.send(utils._to_json(data))
except RuntimeError as exc:
if not self._can_handle_close():
raise ConnectionClosed(self.socket) from exc
if not self._can_handle_close(self._close_code):
raise ConnectionClosed(self._close_code) from exc

async def send_heartbeat(self, data: Any) -> None:
# This bypasses the rate limit handling code since it has a higher priority
try:
await self.socket.send_str(utils._to_json(data))
await self._sendstr(utils._to_json(data))
except RuntimeError as exc:
if not self._can_handle_close():
raise ConnectionClosed(self.socket) from exc
if not self._can_handle_close(self._close_code):
raise ConnectionClosed(self._close_code) from exc

async def change_presence(
self,
Expand Down Expand Up @@ -872,13 +892,19 @@ async def search_recent_members(

await self.send_as_json(payload)

async def close(self, code: int = 4000) -> None:
async def close(self, code: int = 4000, reason: bytes = b'') -> None:
if self._keep_alive:
self._keep_alive.stop()
self._keep_alive = None

self._close_code = code
await self.socket.close(code=code)
socket = self.socket

# HACK: The close implementation in curl-cffi is currently broken so we do it ourselves
data = struct.pack('!H', code) + reason
await socket.asend(data, CurlWsFlag.CLOSE)
socket.keep_running = False
await self.loop.run_in_executor(None, socket.curl.close) # TODO: Do I need an executor here?


DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket')
Expand Down Expand Up @@ -938,12 +964,12 @@ class DiscordVoiceWebSocket:

def __init__(
self,
socket: aiohttp.ClientWebSocketResponse,
socket: WebSocket,
loop: asyncio.AbstractEventLoop,
*,
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
) -> None:
self.ws: aiohttp.ClientWebSocketResponse = socket
self.ws: WebSocket = socket
self.loop: asyncio.AbstractEventLoop = loop
self._keep_alive: Optional[VoiceKeepAliveHandler] = None
self._close_code: Optional[int] = None
Expand All @@ -954,9 +980,12 @@ def __init__(
async def _hook(self, *args: Any) -> None:
pass

async def _sendstr(self, data: str, /) -> None:
await self.ws.asend(data.encode('utf-8'))

async def send_as_json(self, data: Any) -> None:
_log.debug('Voice gateway sending: %s.', data)
await self.ws.send_str(utils._to_json(data))
await self._sendstr(utils._to_json(data))

send_heartbeat = send_as_json

Expand Down Expand Up @@ -992,7 +1021,8 @@ async def from_client(
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
http = client._state.http
socket = await http.ws_connect(gateway, compress=15)
# TODO: <compress=15> is not supported by curl
socket = await http.ws_connect(gateway)
ws = cls(socket, loop=client.loop, hook=hook)
ws.gateway = gateway
ws._connection = client
Expand Down Expand Up @@ -1122,19 +1152,24 @@ async def load_secret_key(self, data: Dict[str, Any]) -> None:

async def poll_event(self) -> None:
# This exception is handled up the chain
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.ERROR:
msg, flags = await asyncio.wait_for(self.ws.arecv(), timeout=self._max_heartbeat_timeout)
if flags & CurlWsFlag.TEXT:
await self.received_message(utils._from_json(msg))
elif flags & CurlWsFlag.CLOSE:
_log.debug('Voice received %s.', msg)
raise ConnectionClosed(self.ws) from msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
_log.debug('Voice received %s.', msg)
raise ConnectionClosed(self.ws, code=self._close_code)
# TODO: hack
data = WebSocketClosure(msg)
raise ConnectionClosed(data.code, data.reason)

async def close(self, code: int = 1000) -> None:
if self._keep_alive is not None:
async def close(self, code: int = 1000, reason: bytes = b'') -> None:
if self._keep_alive:
self._keep_alive.stop()

self._close_code = code
await self.ws.close(code=code)
socket = self.ws

# HACK: The close implementation in curl-cffi is currently broken so we do it ourselves
data = struct.pack('!H', code) + reason
await socket.asend(data, CurlWsFlag.CLOSE)
socket.keep_running = False
await self.loop.run_in_executor(None, socket.curl.close) # TODO: Do I need an executor here?
33 changes: 16 additions & 17 deletions discord/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,25 +645,24 @@ async def startup(self) -> None:

self._started = True

async def ws_connect(self, url: str, *, compress: int = 0) -> aiohttp.ClientWebSocketResponse:
kwargs: Dict[str, Any] = {
'proxy_auth': self.proxy_auth,
'proxy': self.proxy,
'max_msg_size': 0,
'timeout': 30.0,
'autoclose': False,
'headers': {
'Accept-Language': 'en-US,en;q=0.9',
'Cache-Control': 'no-cache',
'Origin': 'https://discord.com',
'Pragma': 'no-cache',
'Sec-WebSocket-Extensions': 'permessage-deflate; client_max_window_bits',
'User-Agent': self.user_agent,
},
'compress': compress,
async def ws_connect(self, url: str, **kwargs) -> requests.WebSocket:
await self.startup()

headers: Dict[str, Any] = {
'Accept-Language': 'en-US,en;q=0.9',
'Cache-Control': 'no-cache',
'Origin': 'https://discord.com',
'Pragma': 'no-cache',
'Sec-WebSocket-Extensions': 'permessage-deflate; client_max_window_bits',
'User-Agent': self.user_agent,
}
if self.proxy is not None:
kwargs['proxies'] = {'http': self.proxy, 'https': self.proxy}
if self.proxy_auth is not None:
headers['Proxy-Authorization'] = self.proxy_auth.encode()

return await self.__asession.ws_connect(url, **kwargs)
session = self.__session
return await session.ws_connect(url, headers=headers, impersonate=session.impersonate, timeout=30.0, **kwargs)

@property
def browser_version(self) -> str:
Expand Down
10 changes: 7 additions & 3 deletions discord/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1444,9 +1444,7 @@ def destroy(self) -> None:

async def _get_info(session: ClientSession) -> Tuple[Dict[str, Any], str]:
try:
async with session.post('https://cordapi.dolfi.es/api/v2/properties/web', timeout=5) as resp:
json = await resp.json()
return json['properties'], json['encoded']
return await asyncio.wait_for(_get_api_properties(session, 'info'), timeout=3)
except Exception:
_log.info('Info API temporarily down. Falling back to manual retrieval...')

Expand Down Expand Up @@ -1482,6 +1480,12 @@ async def _get_info(session: ClientSession) -> Tuple[Dict[str, Any], str]:
return properties, b64encode(_to_json(properties).encode()).decode('utf-8')


async def _get_api_properties(session: ClientSession, type: str) -> Tuple[Dict[str, Any], str]:
async with session.get(f'https://cordapi.dolfi.es/api/v2/properties/{type}') as resp:
json = await resp.json()
return json['properties'], json['encoded']


async def _get_build_number(session: ClientSession) -> int:
"""Fetches client build number"""
async with session.get('https://discord.com/login') as resp:
Expand Down

0 comments on commit 77bc225

Please sign in to comment.