diff --git a/discord/errors.py b/discord/errors.py index 25738b45a88b..306f6b8d33c9 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -29,8 +29,8 @@ from .utils import _get_as_snowflake if TYPE_CHECKING: - from aiohttp import ClientResponse, ClientWebSocketResponse - from curl_cffi.requests import Response as CurlResponse + from aiohttp import ClientResponse + from curl_cffi.requests import Response as CurlResponse, WebSocket from requests import Response from typing_extensions import TypeGuard @@ -301,10 +301,10 @@ class ConnectionClosed(ClientException): __slots__ = ('code', 'reason') - def __init__(self, socket: ClientWebSocketResponse, *, code: Optional[int] = None): + def __init__(self, code: Optional[int] = None, reason: Optional[str] = None): # This exception is just the same exception except # reconfigured to subclass ClientException for users - self.code: int = code or socket.close_code or -1 + self.code: int = code or -1 # aiohttp doesn't seem to consistently provide close reason - self.reason: str = '' - super().__init__(f'WebSocket closed with {self.code}') + self.reason: str = reason or '' + super().__init__(f'WebSocket closed with {self.code} (reason: {self.reason!r})') diff --git a/discord/gateway.py b/discord/gateway.py index cc3872d7d877..ef54117cf8ab 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -34,12 +34,14 @@ from typing import Any, Callable, Coroutine, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar -import aiohttp +from curl_cffi import CurlError +from curl_cffi.requests import WebSocket +from curl_cffi.const import CurlWsFlag import yarl from . import utils from .activity import BaseActivity, Spotify -from .enums import SpeakingState +from .enums import SpeakingState, Status from .errors import ConnectionClosed from .flags import Capabilities @@ -58,7 +60,6 @@ from .activity import ActivityTypes from .client import Client - from .enums import Status from .state import ConnectionState from .types.snowflake import Snowflake from .voice_client import VoiceClient @@ -73,9 +74,25 @@ def __init__(self, *, resume: bool = True): class WebSocketClosure(Exception): - """An exception to make up for the fact that aiohttp doesn't signal closure.""" + """An exception to make up for the fact that curl doesn't signal closure. - pass + Attributes + ----------- + code: :class:`int` + The close code of the websocket. + reason: :class:`str` + The reason provided for the closure. + """ + + __slots__ = ('code', 'reason') + + CLOSE_CODE = struct.Struct("!H") + + def __init__(self, msg: bytes): + # HACK: Unpack code and reason from raw message + self.code: int = self.CLOSE_CODE.unpack(msg[:2])[0] + self.reason: str = msg[2:].decode('utf-8') + super().__init__(f'WebSocket closed with {self.code} (reason: {self.reason!r})') class EventListener(NamedTuple): @@ -256,7 +273,7 @@ class DiscordWebSocket: RECONNECT Receive only. Tells the client to reconnect to a new gateway. REQUEST_MEMBERS - Send only. Asks for the guild members. + Send only. Asks for the guild members. Responds with GUILD_MEMBERS_CHUNK. INVALIDATE_SESSION Receive only. Tells the client to optionally invalidate the session and IDENTIFY again. @@ -265,14 +282,14 @@ class DiscordWebSocket: HEARTBEAT_ACK Receive only. Confirms receiving of a heartbeat. Not having it implies a connection issue. - GUILD_SYNC - Send only. Requests a guild sync. This is unfortunately no longer functional. CALL_CONNECT - Send only. Maybe used for calling? Probably just tracking. + Send only. Requests an existing call on a channel. Might respond with CALL_CREATE. GUILD_SUBSCRIBE Send only. Subscribes you to guilds/guild members. Might respond with GUILD_MEMBER_LIST_UPDATE. REQUEST_COMMANDS Send only. Requests application commands from a guild. Responds with GUILD_APPLICATION_COMMANDS_UPDATE. + SEARCH_RECENT_MEMBERS + Send only. Searches for recent members in a guild. Responds with GUILD_MEMBERS_CHUNK. gateway The gateway we are currently connected to. token @@ -314,8 +331,8 @@ class DiscordWebSocket: SEARCH_RECENT_MEMBERS = 35 # fmt: on - def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None: - self.socket: aiohttp.ClientWebSocketResponse = socket + def __init__(self, socket: WebSocket, *, loop: asyncio.AbstractEventLoop) -> None: + self.socket: WebSocket = socket self.loop: asyncio.AbstractEventLoop = loop # An empty dispatcher to prevent crashes @@ -336,7 +353,7 @@ def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.Abs @property def open(self) -> bool: - return not self.socket.closed + return self.socket.curl._curl is not None @property def capabilities(self) -> Capabilities: @@ -631,8 +648,7 @@ def latency(self) -> float: heartbeat = self._keep_alive return float('inf') if heartbeat is None else heartbeat.latency - def _can_handle_close(self) -> bool: - code = self._close_code or self.socket.close_code + def _can_handle_close(self, code: Optional[int]) -> bool: return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014) async def poll_event(self) -> None: @@ -644,58 +660,62 @@ async def poll_event(self) -> None: The websocket connection was terminated for unhandled reasons. """ try: - msg = await self.socket.receive(timeout=self._max_heartbeat_timeout) - if msg.type is aiohttp.WSMsgType.TEXT: - await self.received_message(msg.data) - elif msg.type is aiohttp.WSMsgType.BINARY: - await self.received_message(msg.data) - elif msg.type is aiohttp.WSMsgType.ERROR: - _log.debug('Received %s.', msg) - raise msg.data - elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE): + msg, flags = await asyncio.wait_for(self.socket.arecv(), timeout=self._max_heartbeat_timeout) + if (flags & CurlWsFlag.TEXT) or (flags & CurlWsFlag.BINARY): + await self.received_message(msg) + elif flags & CurlWsFlag.CLOSE: _log.debug('Received %s.', msg) - raise WebSocketClosure - except (asyncio.TimeoutError, WebSocketClosure) as e: + err = WebSocketClosure(msg) + raise WebSocketClosure(msg) + except (asyncio.TimeoutError, CurlError, WebSocketClosure) as e: # Ensure the keep alive handler is closed if self._keep_alive: self._keep_alive.stop() self._keep_alive = None - if isinstance(e, asyncio.TimeoutError): + if isinstance(e, asyncio.TimeoutError): # is this also CancelledError?? _log.debug('Timed out receiving packet. Attempting a reconnect.') raise ReconnectWebSocket from None - code = self._close_code or self.socket.close_code - if self._can_handle_close(): + code = self._close_code or getattr(e, 'code', None) + reason = getattr(e, 'reason', None) + if isinstance(e, CurlError): + _log.debug('Received error %s', e) + reason = str(e) + + if self._can_handle_close(code or None): _log.debug('Websocket closed with %s, attempting a reconnect.', code) raise ReconnectWebSocket from None else: _log.debug('Websocket closed with %s, cannot reconnect.', code) - raise ConnectionClosed(self.socket, code=code) from None + raise ConnectionClosed(code, reason) from None + + async def _sendstr(self, data: str, /) -> None: + await self.socket.asend(data.encode('utf-8')) async def debug_send(self, data: str, /) -> None: await self._rate_limiter.block() self._dispatch('socket_raw_send', data) - await self.socket.send_str(data) + await self._sendstr(data) async def send(self, data: str, /) -> None: await self._rate_limiter.block() - await self.socket.send_str(data) + await self._sendstr(data) async def send_as_json(self, data: Any) -> None: try: await self.send(utils._to_json(data)) except RuntimeError as exc: - if not self._can_handle_close(): - raise ConnectionClosed(self.socket) from exc + if not self._can_handle_close(self._close_code): + raise ConnectionClosed(self._close_code) from exc async def send_heartbeat(self, data: Any) -> None: # This bypasses the rate limit handling code since it has a higher priority try: - await self.socket.send_str(utils._to_json(data)) + await self._sendstr(utils._to_json(data)) except RuntimeError as exc: - if not self._can_handle_close(): - raise ConnectionClosed(self.socket) from exc + if not self._can_handle_close(self._close_code): + raise ConnectionClosed(self._close_code) from exc async def change_presence( self, @@ -872,13 +892,19 @@ async def search_recent_members( await self.send_as_json(payload) - async def close(self, code: int = 4000) -> None: + async def close(self, code: int = 4000, reason: bytes = b'') -> None: if self._keep_alive: self._keep_alive.stop() self._keep_alive = None self._close_code = code - await self.socket.close(code=code) + socket = self.socket + + # HACK: The close implementation in curl-cffi is currently broken so we do it ourselves + data = struct.pack('!H', code) + reason + await socket.asend(data, CurlWsFlag.CLOSE) + socket.keep_running = False + await self.loop.run_in_executor(None, socket.curl.close) # TODO: Do I need an executor here? DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket') @@ -938,12 +964,12 @@ class DiscordVoiceWebSocket: def __init__( self, - socket: aiohttp.ClientWebSocketResponse, + socket: WebSocket, loop: asyncio.AbstractEventLoop, *, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None, ) -> None: - self.ws: aiohttp.ClientWebSocketResponse = socket + self.ws: WebSocket = socket self.loop: asyncio.AbstractEventLoop = loop self._keep_alive: Optional[VoiceKeepAliveHandler] = None self._close_code: Optional[int] = None @@ -954,9 +980,12 @@ def __init__( async def _hook(self, *args: Any) -> None: pass + async def _sendstr(self, data: str, /) -> None: + await self.ws.asend(data.encode('utf-8')) + async def send_as_json(self, data: Any) -> None: _log.debug('Voice gateway sending: %s.', data) - await self.ws.send_str(utils._to_json(data)) + await self._sendstr(utils._to_json(data)) send_heartbeat = send_as_json @@ -992,7 +1021,8 @@ async def from_client( """Creates a voice websocket for the :class:`VoiceClient`.""" gateway = 'wss://' + client.endpoint + '/?v=4' http = client._state.http - socket = await http.ws_connect(gateway, compress=15) + # TODO: is not supported by curl + socket = await http.ws_connect(gateway) ws = cls(socket, loop=client.loop, hook=hook) ws.gateway = gateway ws._connection = client @@ -1122,19 +1152,24 @@ async def load_secret_key(self, data: Dict[str, Any]) -> None: async def poll_event(self) -> None: # This exception is handled up the chain - msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) - if msg.type is aiohttp.WSMsgType.TEXT: - await self.received_message(utils._from_json(msg.data)) - elif msg.type is aiohttp.WSMsgType.ERROR: + msg, flags = await asyncio.wait_for(self.ws.arecv(), timeout=self._max_heartbeat_timeout) + if flags & CurlWsFlag.TEXT: + await self.received_message(utils._from_json(msg)) + elif flags & CurlWsFlag.CLOSE: _log.debug('Voice received %s.', msg) - raise ConnectionClosed(self.ws) from msg.data - elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): - _log.debug('Voice received %s.', msg) - raise ConnectionClosed(self.ws, code=self._close_code) + # TODO: hack + data = WebSocketClosure(msg) + raise ConnectionClosed(data.code, data.reason) - async def close(self, code: int = 1000) -> None: - if self._keep_alive is not None: + async def close(self, code: int = 1000, reason: bytes = b'') -> None: + if self._keep_alive: self._keep_alive.stop() self._close_code = code - await self.ws.close(code=code) + socket = self.ws + + # HACK: The close implementation in curl-cffi is currently broken so we do it ourselves + data = struct.pack('!H', code) + reason + await socket.asend(data, CurlWsFlag.CLOSE) + socket.keep_running = False + await self.loop.run_in_executor(None, socket.curl.close) # TODO: Do I need an executor here? diff --git a/discord/http.py b/discord/http.py index 23f983d032bc..7b855bae1dd1 100644 --- a/discord/http.py +++ b/discord/http.py @@ -645,25 +645,24 @@ async def startup(self) -> None: self._started = True - async def ws_connect(self, url: str, *, compress: int = 0) -> aiohttp.ClientWebSocketResponse: - kwargs: Dict[str, Any] = { - 'proxy_auth': self.proxy_auth, - 'proxy': self.proxy, - 'max_msg_size': 0, - 'timeout': 30.0, - 'autoclose': False, - 'headers': { - 'Accept-Language': 'en-US,en;q=0.9', - 'Cache-Control': 'no-cache', - 'Origin': 'https://discord.com', - 'Pragma': 'no-cache', - 'Sec-WebSocket-Extensions': 'permessage-deflate; client_max_window_bits', - 'User-Agent': self.user_agent, - }, - 'compress': compress, + async def ws_connect(self, url: str, **kwargs) -> requests.WebSocket: + await self.startup() + + headers: Dict[str, Any] = { + 'Accept-Language': 'en-US,en;q=0.9', + 'Cache-Control': 'no-cache', + 'Origin': 'https://discord.com', + 'Pragma': 'no-cache', + 'Sec-WebSocket-Extensions': 'permessage-deflate; client_max_window_bits', + 'User-Agent': self.user_agent, } + if self.proxy is not None: + kwargs['proxies'] = {'http': self.proxy, 'https': self.proxy} + if self.proxy_auth is not None: + headers['Proxy-Authorization'] = self.proxy_auth.encode() - return await self.__asession.ws_connect(url, **kwargs) + session = self.__session + return await session.ws_connect(url, headers=headers, impersonate=session.impersonate, timeout=30.0, **kwargs) @property def browser_version(self) -> str: diff --git a/discord/utils.py b/discord/utils.py index 5696546bfd42..84fcf3cbaa35 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -1444,9 +1444,7 @@ def destroy(self) -> None: async def _get_info(session: ClientSession) -> Tuple[Dict[str, Any], str]: try: - async with session.post('https://cordapi.dolfi.es/api/v2/properties/web', timeout=5) as resp: - json = await resp.json() - return json['properties'], json['encoded'] + return await asyncio.wait_for(_get_api_properties(session, 'info'), timeout=3) except Exception: _log.info('Info API temporarily down. Falling back to manual retrieval...') @@ -1482,6 +1480,12 @@ async def _get_info(session: ClientSession) -> Tuple[Dict[str, Any], str]: return properties, b64encode(_to_json(properties).encode()).decode('utf-8') +async def _get_api_properties(session: ClientSession, type: str) -> Tuple[Dict[str, Any], str]: + async with session.get(f'https://cordapi.dolfi.es/api/v2/properties/{type}') as resp: + json = await resp.json() + return json['properties'], json['encoded'] + + async def _get_build_number(session: ClientSession) -> int: """Fetches client build number""" async with session.get('https://discord.com/login') as resp: