From 6c25a63e1ea9757f06c0d1c4286e37ace4f85d20 Mon Sep 17 00:00:00 2001 From: Jens Reidel Date: Tue, 4 Aug 2020 18:46:39 +0200 Subject: [PATCH] Support custom JSON encoders --- wavelink/client.py | 27 +++++++++++++++++++++++---- wavelink/node.py | 11 ++++++++--- wavelink/websocket.py | 11 ++++++++++- 3 files changed, 41 insertions(+), 8 deletions(-) diff --git a/wavelink/client.py b/wavelink/client.py index 13445498..d2420687 100644 --- a/wavelink/client.py +++ b/wavelink/client.py @@ -25,6 +25,7 @@ import logging from discord.ext import commands from functools import partial +from json import dumps from typing import Optional, Union from .errors import * @@ -69,6 +70,8 @@ def __init__(self, bot: Union[commands.Bot, commands.AutoShardedBot]): self.nodes = {} + self._dumps = dumps + bot.add_listener(self.update_handler, 'on_socket_response') @property @@ -153,7 +156,7 @@ async def get_tracks(self, query: str) -> Optional[list]: There are no :class:`wavelink.node.Node`s currently connected. """ node = self.get_best_node() - + if node is None: raise ZeroConnectedNodes @@ -386,7 +389,7 @@ async def initiate_node(self, host: str, port: int, *, rest_uri: str, password: Whether the websocket should be started with the secure wss protocol. heartbeat: Optional[float] Send ping message every heartbeat seconds and wait pong response, if pong response is not received then close connection. - + Returns --------- :class:`wavelink.node.Node` @@ -412,8 +415,9 @@ async def initiate_node(self, host: str, port: int, *, rest_uri: str, password: session=self.session, client=self, secure=secure, - heartbeat=heartbeat) - + heartbeat=heartbeat, + dumps=self._dumps) + await node.connect(bot=self.bot) node.available = True @@ -467,3 +471,18 @@ async def update_handler(self, data) -> None: pass else: await player._voice_state_update(data['d']) + + def set_serializer(self, serializer_function) -> None: + """Sets the JSON dumps function for use in the websocket. + The default one is the built-in JSON module. + + Parameters + ---------- + serializer_function: Callable[[Dict[str, Any]]], Union[str, bytes]] + The function that serializes the JSON data to a string or bytes. + """ + self._dumps = serializer_function + # Update all existing nodes + for node in self.nodes.values(): + node._dumps = serializer_function + node._websocket._dumps = serializer_function diff --git a/wavelink/node.py b/wavelink/node.py index e8fb5cdb..f37a7251 100644 --- a/wavelink/node.py +++ b/wavelink/node.py @@ -21,9 +21,10 @@ SOFTWARE. """ import inspect +import json import logging from discord.ext import commands -from typing import Optional, Union +from typing import Any, Callable, Dict, Optional, Union from urllib.parse import quote from .errors import * @@ -67,7 +68,8 @@ def __init__(self, host: str, identifier: str, shard_id: int = None, secure: bool = False, - heartbeat: float = None + heartbeat: float = None, + dumps: Callable[[Dict[str, Any]], Union[str, bytes]] = json.dumps ): self.host = host @@ -81,6 +83,8 @@ def __init__(self, host: str, self.secure = secure self.heartbeat = heartbeat + self._dumps = dumps + self.shard_id = shard_id self.players = {} @@ -125,7 +129,8 @@ async def connect(self, bot: Union[commands.Bot, commands.AutoShardedBot]) -> No password=self.password, shard_count=self.shards, user_id=self.uid, - secure=self.secure) + secure=self.secure, + dumps=self._dumps) await self._websocket._connect() __log__.info(f'NODE | {self.identifier} connected:: {self.__repr__()}') diff --git a/wavelink/websocket.py b/wavelink/websocket.py index d8f08b72..b6f2ae21 100644 --- a/wavelink/websocket.py +++ b/wavelink/websocket.py @@ -47,6 +47,7 @@ def __init__(self, **attrs): self.shard_count = attrs.get('shard_count') self.user_id = attrs.get('user_id') self.secure = attrs.get('secure') + self._dumps = attrs.get('dumps') self._websocket = None self._last_exc = None @@ -165,4 +166,12 @@ def _get_event_payload(self, name: str, data): async def _send(self, **data): if self.is_connected: __log__.debug(f'WEBSOCKET | Sending Payload:: {data}') - await self._websocket.send_json(data) + data_str = self._dumps(data) + if isinstance(data_str, bytes): + # Some JSON libraries serialize to bytes + # Yet Lavalink does not support binary websockets + # So we need to decode. In the future, maybe + # self._websocket.send_bytes could be used + # if Lavalink ever implements it + data_str = data_str.decode('utf-8') + await self._websocket.send_str(data_str)