diff --git a/.gitignore b/.gitignore index b13a0aa4bb..0c1180d8f4 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ docs/crowdin.py *.mp3 *.m4a *.wav +*.pcm *.png *.jpg *.flac @@ -18,6 +19,8 @@ docs/crowdin.py .DS_Store .python-version __pycache__ -.vs/slnx.sqlite +.vs/* +.vscode/* env/ build/ +test.py diff --git a/discord/__init__.py b/discord/__init__.py index c7169908c3..6451ebebfa 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -57,6 +57,7 @@ from .sticker import * from .stage_instance import * from .interactions import * +from .sink import * from .components import * from .threads import * from .bot import * diff --git a/discord/errors.py b/discord/errors.py index 6cc549c61d..7757137dff 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -59,7 +59,8 @@ 'ExtensionNotLoaded', 'NoEntryPointError', 'ExtensionFailed', - 'ExtensionNotFound' + 'ExtensionNotFound', + 'RecordingException', ) @@ -269,6 +270,19 @@ def __init__(self, shard_id: Optional[int]): ) super().__init__(msg % shard_id) +class RecordingException(ClientException): + """Exception that's thrown when there is an error while trying to record + audio from a voice channel. + + .. versionadded:: 2.0 + """ + pass + +class SinkException(ClientException): + """Raised when a Sink error occurs. + + .. versionadded:: 2.0 + """ class InteractionResponded(ClientException): """Exception that's raised when sending another interaction response using diff --git a/discord/gateway.py b/discord/gateway.py index 54c7768ca6..81ac87c1aa 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -728,6 +728,7 @@ def __init__(self, socket, loop, *, hook=None): self._keep_alive = None self._close_code = None self.secret_key = None + self.ssrc_map = {} if hook: self._hook = hook @@ -839,6 +840,15 @@ async def received_message(self, msg): self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0)) self._keep_alive.start() + elif op == self.SPEAKING: + ssrc = data['ssrc'] + user = int(data['user_id']) + speaking = data['speaking'] + if ssrc in self.ssrc_map: + self.ssrc_map[ssrc]['speaking'] = speaking + else: + self.ssrc_map.update({ssrc: {'user_id': user, 'speaking': speaking}}) + await self._hook(self, msg) async def initial_connection(self, data): diff --git a/discord/opus.py b/discord/opus.py index 515fc3db7f..a37c5c52db 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -25,7 +25,18 @@ from __future__ import annotations -from typing import List, Tuple, TypedDict, Any, TYPE_CHECKING, Callable, TypeVar, Literal, Optional, overload +from typing import ( + List, + Tuple, + TypedDict, + Any, + TYPE_CHECKING, + Callable, + TypeVar, + Literal, + Optional, + overload, +) import array import ctypes @@ -35,13 +46,19 @@ import os.path import struct import sys +import gc +import threading +import traceback +import time from .errors import DiscordException, InvalidArgument +from .sink import RawData if TYPE_CHECKING: - T = TypeVar('T') - BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full'] - SIGNAL_CTL = Literal['auto', 'voice', 'music'] + T = TypeVar("T") + BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"] + SIGNAL_CTL = Literal["auto", "voice", "music"] + class BandCtl(TypedDict): narrow: int @@ -50,15 +67,19 @@ class BandCtl(TypedDict): superwide: int full: int + class SignalCtl(TypedDict): auto: int voice: int music: int + __all__ = ( - 'Encoder', - 'OpusError', - 'OpusNotLoaded', + "Encoder", + "Decoder", + "DecodeManager", + "OpusError", + "OpusNotLoaded", ) _log = logging.getLogger(__name__) @@ -69,62 +90,68 @@ class SignalCtl(TypedDict): _lib = None + class EncoderStruct(ctypes.Structure): pass + class DecoderStruct(ctypes.Structure): pass + EncoderStructPtr = ctypes.POINTER(EncoderStruct) DecoderStructPtr = ctypes.POINTER(DecoderStruct) ## Some constants from opus_defines.h # Error codes -OK = 0 +OK = 0 BAD_ARG = -1 # Encoder CTLs -APPLICATION_AUDIO = 2049 -APPLICATION_VOIP = 2048 -APPLICATION_LOWDELAY = 2051 +APPLICATION_AUDIO = 2049 +APPLICATION_VOIP = 2048 +APPLICATION_LOWDELAY = 2051 -CTL_SET_BITRATE = 4002 -CTL_SET_BANDWIDTH = 4008 -CTL_SET_FEC = 4012 -CTL_SET_PLP = 4014 -CTL_SET_SIGNAL = 4024 +CTL_SET_BITRATE = 4002 +CTL_SET_BANDWIDTH = 4008 +CTL_SET_FEC = 4012 +CTL_SET_PLP = 4014 +CTL_SET_SIGNAL = 4024 # Decoder CTLs -CTL_SET_GAIN = 4034 +CTL_SET_GAIN = 4034 CTL_LAST_PACKET_DURATION = 4039 band_ctl: BandCtl = { - 'narrow': 1101, - 'medium': 1102, - 'wide': 1103, - 'superwide': 1104, - 'full': 1105, + "narrow": 1101, + "medium": 1102, + "wide": 1103, + "superwide": 1104, + "full": 1105, } signal_ctl: SignalCtl = { - 'auto': -1000, - 'voice': 3001, - 'music': 3002, + "auto": -1000, + "voice": 3001, + "music": 3002, } + def _err_lt(result: int, func: Callable, args: List) -> int: if result < OK: - _log.info('error has happened in %s', func.__name__) + _log.info("error has happened in %s", func.__name__) raise OpusError(result) return result + def _err_ne(result: T, func: Callable, args: List) -> T: ret = args[-1]._obj if ret.value != OK: - _log.info('error has happened in %s', func.__name__) + _log.info("error has happened in %s", func.__name__) raise OpusError(ret.value) return result + # A list of exported functions. # The first argument is obviously the name. # The second one are the types of arguments it takes. @@ -132,54 +159,90 @@ def _err_ne(result: T, func: Callable, args: List) -> T: # The fourth is the error handler. exported_functions: List[Tuple[Any, ...]] = [ # Generic - ('opus_get_version_string', - None, ctypes.c_char_p, None), - ('opus_strerror', - [ctypes.c_int], ctypes.c_char_p, None), - + ("opus_get_version_string", None, ctypes.c_char_p, None), + ("opus_strerror", [ctypes.c_int], ctypes.c_char_p, None), # Encoder functions - ('opus_encoder_get_size', - [ctypes.c_int], ctypes.c_int, None), - ('opus_encoder_create', - [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne), - ('opus_encode', - [EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), - ('opus_encode_float', - [EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), - ('opus_encoder_ctl', - None, ctypes.c_int32, _err_lt), - ('opus_encoder_destroy', - [EncoderStructPtr], None, None), - + ("opus_encoder_get_size", [ctypes.c_int], ctypes.c_int, None), + ( + "opus_encoder_create", + [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], + EncoderStructPtr, + _err_ne, + ), + ( + "opus_encode", + [EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], + ctypes.c_int32, + _err_lt, + ), + ( + "opus_encode_float", + [EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], + ctypes.c_int32, + _err_lt, + ), + ("opus_encoder_ctl", None, ctypes.c_int32, _err_lt), + ("opus_encoder_destroy", [EncoderStructPtr], None, None), # Decoder functions - ('opus_decoder_get_size', - [ctypes.c_int], ctypes.c_int, None), - ('opus_decoder_create', - [ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne), - ('opus_decode', - [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int], - ctypes.c_int, _err_lt), - ('opus_decode_float', - [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int], - ctypes.c_int, _err_lt), - ('opus_decoder_ctl', - None, ctypes.c_int32, _err_lt), - ('opus_decoder_destroy', - [DecoderStructPtr], None, None), - ('opus_decoder_get_nb_samples', - [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt), - + ("opus_decoder_get_size", [ctypes.c_int], ctypes.c_int, None), + ( + "opus_decoder_create", + [ctypes.c_int, ctypes.c_int, c_int_ptr], + DecoderStructPtr, + _err_ne, + ), + ( + "opus_decode", + [ + DecoderStructPtr, + ctypes.c_char_p, + ctypes.c_int32, + c_int16_ptr, + ctypes.c_int, + ctypes.c_int, + ], + ctypes.c_int, + _err_lt, + ), + ( + "opus_decode_float", + [ + DecoderStructPtr, + ctypes.c_char_p, + ctypes.c_int32, + c_float_ptr, + ctypes.c_int, + ctypes.c_int, + ], + ctypes.c_int, + _err_lt, + ), + ("opus_decoder_ctl", None, ctypes.c_int32, _err_lt), + ("opus_decoder_destroy", [DecoderStructPtr], None, None), + ( + "opus_decoder_get_nb_samples", + [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], + ctypes.c_int, + _err_lt, + ), # Packet functions - ('opus_packet_get_bandwidth', - [ctypes.c_char_p], ctypes.c_int, _err_lt), - ('opus_packet_get_nb_channels', - [ctypes.c_char_p], ctypes.c_int, _err_lt), - ('opus_packet_get_nb_frames', - [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), - ('opus_packet_get_samples_per_frame', - [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), + ("opus_packet_get_bandwidth", [ctypes.c_char_p], ctypes.c_int, _err_lt), + ("opus_packet_get_nb_channels", [ctypes.c_char_p], ctypes.c_int, _err_lt), + ( + "opus_packet_get_nb_frames", + [ctypes.c_char_p, ctypes.c_int], + ctypes.c_int, + _err_lt, + ), + ( + "opus_packet_get_samples_per_frame", + [ctypes.c_char_p, ctypes.c_int], + ctypes.c_int, + _err_lt, + ), ] + def libopus_loader(name: str) -> Any: # create the library... lib = ctypes.cdll.LoadLibrary(name) @@ -204,22 +267,24 @@ def libopus_loader(name: str) -> Any: return lib + def _load_default() -> bool: global _lib try: - if sys.platform == 'win32': + if sys.platform == "win32": _basedir = os.path.dirname(os.path.abspath(__file__)) - _bitness = struct.calcsize('P') * 8 - _target = 'x64' if _bitness > 32 else 'x86' - _filename = os.path.join(_basedir, 'bin', f'libopus-0.{_target}.dll') + _bitness = struct.calcsize("P") * 8 + _target = "x64" if _bitness > 32 else "x86" + _filename = os.path.join(_basedir, "bin", f"libopus-0.{_target}.dll") _lib = libopus_loader(_filename) else: - _lib = libopus_loader(ctypes.util.find_library('opus')) + _lib = libopus_loader(ctypes.util.find_library("opus")) except Exception: _lib = None return _lib is not None + def load_opus(name: str) -> None: """Loads the libopus shared library for use with voice. @@ -258,6 +323,7 @@ def load_opus(name: str) -> None: global _lib _lib = libopus_loader(name) + def is_loaded() -> bool: """Function to check if opus lib is successfully loaded either via the :func:`ctypes.util.find_library` call of :func:`load_opus`. @@ -272,6 +338,7 @@ def is_loaded() -> bool: global _lib return _lib is not None + class OpusError(DiscordException): """An exception that is thrown for libopus related errors. @@ -283,19 +350,22 @@ class OpusError(DiscordException): def __init__(self, code: int): self.code: int = code - msg = _lib.opus_strerror(self.code).decode('utf-8') + msg = _lib.opus_strerror(self.code).decode("utf-8") _log.info('"%s" has happened', msg) super().__init__(msg) + class OpusNotLoaded(DiscordException): """An exception that is thrown for when libopus is not loaded.""" + pass + class _OpusStruct: SAMPLING_RATE = 48000 CHANNELS = 2 FRAME_LENGTH = 20 # in milliseconds - SAMPLE_SIZE = struct.calcsize('h') * CHANNELS + SAMPLE_SIZE = struct.calcsize("h") * CHANNELS SAMPLES_PER_FRAME = int(SAMPLING_RATE / 1000 * FRAME_LENGTH) FRAME_SIZE = SAMPLES_PER_FRAME * SAMPLE_SIZE @@ -305,7 +375,8 @@ def get_opus_version() -> str: if not is_loaded() and not _load_default(): raise OpusNotLoaded() - return _lib.opus_get_version_string().decode('utf-8') + return _lib.opus_get_version_string().decode("utf-8") + class Encoder(_OpusStruct): def __init__(self, application: int = APPLICATION_AUDIO): @@ -316,18 +387,20 @@ def __init__(self, application: int = APPLICATION_AUDIO): self.set_bitrate(128) self.set_fec(True) self.set_expected_packet_loss_percent(0.15) - self.set_bandwidth('full') - self.set_signal_type('auto') + self.set_bandwidth("full") + self.set_signal_type("auto") def __del__(self) -> None: - if hasattr(self, '_state'): + if hasattr(self, "_state"): _lib.opus_encoder_destroy(self._state) # This is a destructor, so it's okay to assign None - self._state = None # type: ignore + self._state = None # type: ignore def _create_state(self) -> EncoderStruct: ret = ctypes.c_int() - return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret)) + return _lib.opus_encoder_create( + self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret) + ) def set_bitrate(self, kbps: int) -> int: kbps = min(512, max(16, int(kbps))) @@ -337,14 +410,18 @@ def set_bitrate(self, kbps: int) -> int: def set_bandwidth(self, req: BAND_CTL) -> None: if req not in band_ctl: - raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}') + raise KeyError( + f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}' + ) k = band_ctl[req] _lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k) def set_signal_type(self, req: SIGNAL_CTL) -> None: if req not in signal_ctl: - raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}') + raise KeyError( + f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}' + ) k = signal_ctl[req] _lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k) @@ -353,53 +430,54 @@ def set_fec(self, enabled: bool = True) -> None: _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) def set_expected_packet_loss_percent(self, percentage: float) -> None: - _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore + _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore def encode(self, pcm: bytes, frame_size: int) -> bytes: max_data_bytes = len(pcm) # bytes can be used to reference pointer - pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore + pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore data = (ctypes.c_char * max_data_bytes)() ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) # array can be initialized with bytes but mypy doesn't know - return array.array('b', data[:ret]).tobytes() # type: ignore + return array.array("b", data[:ret]).tobytes() # type: ignore + class Decoder(_OpusStruct): def __init__(self): _OpusStruct.get_opus_version() - self._state: DecoderStruct = self._create_state() + self._state = self._create_state() - def __del__(self) -> None: - if hasattr(self, '_state'): + def __del__(self): + if hasattr(self, "_state"): _lib.opus_decoder_destroy(self._state) - # This is a destructor, so it's okay to assign None - self._state = None # type: ignore + self._state = None - def _create_state(self) -> DecoderStruct: + def _create_state(self): ret = ctypes.c_int() - return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret)) + return _lib.opus_decoder_create( + self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret) + ) @staticmethod - def packet_get_nb_frames(data: bytes) -> int: + def packet_get_nb_frames(data): """Gets the number of frames in an Opus packet""" return _lib.opus_packet_get_nb_frames(data, len(data)) @staticmethod - def packet_get_nb_channels(data: bytes) -> int: + def packet_get_nb_channels(data): """Gets the number of channels in an Opus packet""" return _lib.opus_packet_get_nb_channels(data) @classmethod - def packet_get_samples_per_frame(cls, data: bytes) -> int: + def packet_get_samples_per_frame(cls, data): """Gets the number of samples per frame from an Opus packet""" return _lib.opus_packet_get_samples_per_frame(data, cls.SAMPLING_RATE) - def _set_gain(self, adjustment: int) -> int: + def _set_gain(self, adjustment): """Configures decoder gain adjustment. - Scales the decoded output by a factor specified in Q8 dB units. This has a maximum range of -32768 to 32767 inclusive, and returns OPUS_BAD_ARG (-1) otherwise. The default is zero indicating no adjustment. @@ -409,47 +487,101 @@ def _set_gain(self, adjustment: int) -> int: """ return _lib.opus_decoder_ctl(self._state, CTL_SET_GAIN, adjustment) - def set_gain(self, dB: float) -> int: + def set_gain(self, dB): """Sets the decoder gain in dB, from -128 to 128.""" - dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) + dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) return self._set_gain(dB_Q8) - def set_volume(self, mult: float) -> int: + def set_volume(self, mult): """Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc.""" - return self.set_gain(20 * math.log10(mult)) # amplitude ratio + return self.set_gain(20 * math.log10(mult)) # amplitude ratio - def _get_last_packet_duration(self) -> int: + def _get_last_packet_duration(self): """Gets the duration (in samples) of the last packet successfully decoded or concealed.""" ret = ctypes.c_int32() _lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret)) return ret.value - @overload - def decode(self, data: bytes, *, fec: bool) -> bytes: - ... - - @overload - def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes: - ... - - def decode(self, data: Optional[bytes], *, fec: bool = False) -> bytes: + def decode(self, data, *, fec=False): if data is None and fec: - raise InvalidArgument("Invalid arguments: FEC cannot be used with null data") + raise OpusError("Invalid arguments: FEC cannot be used with null data") if data is None: frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME channel_count = self.CHANNELS else: frames = self.packet_get_nb_frames(data) - channel_count = self.packet_get_nb_channels(data) + channel_count = self.CHANNELS samples_per_frame = self.packet_get_samples_per_frame(data) frame_size = frames * samples_per_frame - pcm = (ctypes.c_int16 * (frame_size * channel_count))() + pcm = ( + ctypes.c_int16 + * (frame_size * channel_count * ctypes.sizeof(ctypes.c_int16)) + )() pcm_ptr = ctypes.cast(pcm, c_int16_ptr) - ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec) + ret = _lib.opus_decode( + self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec + ) + + return array.array("h", pcm[: ret * channel_count]).tobytes() + + +class DecodeManager(threading.Thread, _OpusStruct): + def __init__(self, client): + super().__init__(daemon=True, name="DecodeManager") + + self.client = client + self.decode_queue = [] + + self.decoder = {} + + self._end_thread = threading.Event() + + def decode(self, opus_frame): + if not isinstance(opus_frame, RawData): + raise TypeError("opus_frame should be a RawData object.") + self.decode_queue.append(opus_frame) + + def run(self): + while not self._end_thread.is_set(): + try: + data = self.decode_queue.pop(0) + except IndexError: + continue + + try: + if data.decrypted_data is None: + continue + else: + data.decoded_data = self.get_decoder(data.ssrc).decode(data.decrypted_data) + except OpusError: + print("Error occurred while decoding opus frame.") + continue + + self.client.recv_decoded_audio(data) + + def stop(self): + while self.decoding: + time.sleep(0.1) + self.decoder = {} + gc.collect() + print("Decoder Process Killed") + self._end_thread.set() + + + def get_decoder(self, ssrc): + d = self.decoder.get(ssrc) + if d is None: + self.decoder[ssrc] = Decoder() + return self.decoder[ssrc] + else: + return d + + @property + def decoding(self): + return bool(self.decode_queue) - return array.array('h', pcm[:ret * channel_count]).tobytes() diff --git a/discord/sink.py b/discord/sink.py new file mode 100644 index 0000000000..ed53df1a09 --- /dev/null +++ b/discord/sink.py @@ -0,0 +1,268 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz & (c) 2021-present Pycord-Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +import wave +import logging +import os +import threading +import time +import subprocess +import sys +import struct +from .errors import SinkException + +_log = logging.getLogger(__name__) + +__all__ = ( + "Filters", + "Sink", + "AudioData", + "RawData", +) + + +if sys.platform != "win32": + CREATE_NO_WINDOW = 0 +else: + CREATE_NO_WINDOW = 0x08000000 + + +default_filters = { + "time": 0, + "users": [], + "max_size": 0, +} + + +class Filters: + """Filters for sink + + .. versionadded:: 2.0 + + Parameters + ---------- + filter_decorator: :meth:`Filters.filter_decorator` + + """ + def __init__(self, **kwargs): + self.filtered_users = kwargs.get("users", default_filters["users"]) + self.seconds = kwargs.get("time", default_filters["time"]) + self.max_size = kwargs.get("max_size", default_filters["max_size"]) + self.finished = False + + @staticmethod + def filter_decorator(func): # Contains all filters + def _filter(self, data, user): + if not self.filtered_users or user in self.filtered_users: + return func(self, data, user) + + return _filter + + def init(self): + if self.seconds != 0: + thread = threading.Thread(target=self.wait_and_stop) + thread.start() + + def wait_and_stop(self): + time.sleep(self.seconds) + if self.finished: + return + self.vc.stop_recording() + + +class RawData: + """Handles raw data from Discord so that it can be decrypted and decoded to be used. + + .. versionadded:: 2.0 + + """ + + def __init__(self, data, client): + self.data = bytearray(data) + self.client = client + + self.header = data[:12] + self.data = self.data[12:] + + unpacker = struct.Struct(">xxHII") + self.sequence, self.timestamp, self.ssrc = unpacker.unpack_from(self.header) + self.decrypted_data = getattr(self.client, "_decrypt_" + self.client.mode)( + self.header, self.data + ) + self.decoded_data = None + + self.user_id = None + + +class AudioData: + """Handles data that's been completely decrypted and decoded and is ready to be saved to file. + + .. versionadded:: 2.0 + + Raises + ------ + ClientException + The AudioData is already finished writing, + The AudioData is still writing + """ + + def __init__(self, file): + self.file = open(file, "ab") + self.dir_path = os.path.split(file)[0] + + self.finished = False + + def write(self, data): + if self.finished: + raise SinkException("The AudioData is already finished writing.") + try: + self.file.write(data) + except ValueError: + pass + + def cleanup(self): + if self.finished: + raise SinkException("The AudioData is already finished writing.") + self.file.close() + self.file = os.path.join(self.dir_path, self.file.name) + self.finished = True + + def on_format(self, encoding): + if not self.finished: + raise SinkException("The AudioData is still writing.") + name = os.path.split(self.file)[1] + name = name.split(".")[0] + f".{encoding}" + self.file = os.path.join(self.dir_path, name) + + +class Sink(Filters): + """A Sink "stores" all the audio data. + + .. versionadded:: 2.0 + + Parameters + ---------- + encoding: :class:`string` + The encoding to use. Valid types include wav, mp3, and pcm (even though it's not an actual encoding). + output_path: :class:`string` + A path to where the audio files should be output. + + Raises + ------ + ClientException + An invalid encoding type was specified. + Audio may only be formatted after recording is finished. + """ + + valid_encodings = [ + "wav", + "mp3", + "pcm", + ] + + def __init__(self, *, encoding="wav", output_path="", filters=None): + if filters is None: + filters = default_filters + self.filters = filters + Filters.__init__(self, **self.filters) + + encoding = encoding.lower() + + if encoding not in self.valid_encodings: + raise SinkException("An invalid encoding type was specified.") + + self.encoding = encoding + self.file_path = output_path + self.vc = None + self.audio_data = {} + + def init(self, vc): # called under start_recording + self.vc = vc + super().init() + + @Filters.filter_decorator + def write(self, data, user): + if user not in self.audio_data: + ssrc = self.vc.get_ssrc(user) + file = os.path.join(self.file_path, f"{ssrc}.pcm") + self.audio_data.update({user: AudioData(file)}) + + file = self.audio_data[user] + file.write(data) + + def cleanup(self): + self.finished = True + for file in self.audio_data.values(): + file.cleanup() + self.format_audio(file) + + def format_audio(self, audio): + if self.vc.recording: + raise SinkException( + "Audio may only be formatted after recording is finished." + ) + if self.encoding == "pcm": + return + if self.encoding == "mp3": + mp3_file = audio.file.split(".")[0] + ".mp3" + args = [ + "ffmpeg", + "-f", + "s16le", + "-ar", + "48000", + "-ac", + "2", + "-i", + audio.file, + mp3_file, + ] + process = None + if os.path.exists(mp3_file): + os.remove( + mp3_file + ) # process will get stuck asking whether or not to overwrite, if file already exists. + try: + process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW) + except FileNotFoundError: + raise SinkException("ffmpeg was not found.") from None + except subprocess.SubprocessError as exc: + raise SinkException( + "Popen failed: {0.__class__.__name__}: {0}".format(exc) + ) from exc + process.wait() + elif self.encoding == "wav": + with open(audio.file, "rb") as pcm: + data = pcm.read() + pcm.close() + + wav_file = audio.file.split(".")[0] + ".wav" + with wave.open(wav_file, "wb") as f: + f.setnchannels(self.vc.decoder.CHANNELS) + f.setsampwidth(self.vc.decoder.SAMPLE_SIZE // self.vc.decoder.CHANNELS) + f.setframerate(self.vc.decoder.SAMPLING_RATE) + f.writeframes(data) + f.close() + + os.remove(audio.file) + audio.on_format(self.encoding) diff --git a/discord/voice_client.py b/discord/voice_client.py index fab2c7e95e..734b9256d5 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -45,13 +45,17 @@ import logging import struct import threading +import select +import time from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple from . import opus, utils from .backoff import ExponentialBackoff from .gateway import * -from .errors import ClientException, ConnectionClosed +from .errors import ClientException, ConnectionClosed, RecordingException from .player import AudioPlayer, AudioSource +from .sink import Sink, RawData + from .utils import MISSING if TYPE_CHECKING: @@ -67,26 +71,26 @@ VoiceServerUpdate as VoiceServerUpdatePayload, SupportedModes, ) - + has_nacl: bool try: import nacl.secret # type: ignore + has_nacl = True except ImportError: has_nacl = False __all__ = ( - 'VoiceProtocol', - 'VoiceClient', + "VoiceProtocol", + "VoiceClient", ) - - _log = logging.getLogger(__name__) + class VoiceProtocol: """A class that represents the Discord voice protocol. @@ -196,6 +200,7 @@ def cleanup(self) -> None: key_id, _ = self.channel._get_voice_client_key() self.client._connection._remove_voice_client(key_id) + class VoiceClient(VoiceProtocol): """Represents a Discord voice connection. @@ -222,12 +227,12 @@ class VoiceClient(VoiceProtocol): loop: :class:`asyncio.AbstractEventLoop` The event loop that the voice client is running on. """ + endpoint_ip: str voice_port: int secret_key: List[int] ssrc: int - def __init__(self, client: Client, channel: abc.Connectable): if not has_nacl: raise RuntimeError("PyNaCl library needed in order to use voice") @@ -254,20 +259,28 @@ def __init__(self, client: Client, channel: abc.Connectable): self._runner: asyncio.Task = MISSING self._player: Optional[AudioPlayer] = None self.encoder: Encoder = MISSING + self.decoder = None self._lite_nonce: int = 0 self.ws: DiscordVoiceWebSocket = MISSING + self.paused = False + self.recording = False + self.user_timestamps = {} + self.sink = None + self.starting_time = None + self.stopping_time = None + warn_nacl = not has_nacl supported_modes: Tuple[SupportedModes, ...] = ( - 'xsalsa20_poly1305_lite', - 'xsalsa20_poly1305_suffix', - 'xsalsa20_poly1305', + "xsalsa20_poly1305_lite", + "xsalsa20_poly1305_suffix", + "xsalsa20_poly1305", ) @property def guild(self) -> Optional[Guild]: """Optional[:class:`Guild`]: The guild we're connected to, if applicable.""" - return getattr(self.channel, 'guild', None) + return getattr(self.channel, "guild", None) @property def user(self) -> ClientUser: @@ -284,8 +297,8 @@ def checked_add(self, attr, value, limit): # connection related async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: - self.session_id = data['session_id'] - channel_id = data['channel_id'] + self.session_id = data["session_id"] + channel_id = data["channel_id"] if not self._handshaking or self._potentially_reconnecting: # If we're done handshaking then we just need to update ourselves @@ -302,20 +315,22 @@ async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: if self._voice_server_complete.is_set(): - _log.info('Ignoring extraneous voice server update.') + _log.info("Ignoring extraneous voice server update.") return - self.token = data.get('token') - self.server_id = int(data['guild_id']) - endpoint = data.get('endpoint') + self.token = data.get("token") + self.server_id = int(data["guild_id"]) + endpoint = data.get("endpoint") if endpoint is None or self.token is None: - _log.warning('Awaiting endpoint... This requires waiting. ' \ - 'If timeout occurred considering raising the timeout and reconnecting.') + _log.warning( + "Awaiting endpoint... This requires waiting. " + "If timeout occurred considering raising the timeout and reconnecting." + ) return - self.endpoint, _, _ = endpoint.rpartition(':') - if self.endpoint.startswith('wss://'): + self.endpoint, _, _ = endpoint.rpartition(":") + if self.endpoint.startswith("wss://"): # Just in case, strip it off since we're going to add it later self.endpoint = self.endpoint[6:] @@ -336,18 +351,24 @@ async def voice_connect(self) -> None: await self.channel.guild.change_voice_state(channel=self.channel) async def voice_disconnect(self) -> None: - _log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id) + _log.info( + "The voice handshake is being terminated for Channel ID %s (Guild ID %s)", + self.channel.id, + self.guild.id, + ) await self.channel.guild.change_voice_state(channel=None) def prepare_handshake(self) -> None: self._voice_state_complete.clear() self._voice_server_complete.clear() self._handshaking = True - _log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1) + _log.info( + "Starting voice handshake... (connection attempt %d)", self._connections + 1 + ) self._connections += 1 def finish_handshake(self) -> None: - _log.info('Voice handshake complete. Endpoint found %s', self.endpoint) + _log.info("Voice handshake complete. Endpoint found %s", self.endpoint) self._handshaking = False self._voice_server_complete.clear() self._voice_state_complete.clear() @@ -360,8 +381,8 @@ async def connect_websocket(self) -> DiscordVoiceWebSocket: self._connected.set() return ws - async def connect(self, *, reconnect: bool, timeout: float) ->None: - _log.info('Connecting to voice...') + async def connect(self, *, reconnect: bool, timeout: float) -> None: + _log.info("Connecting to voice...") self.timeout = timeout for i in range(5): @@ -389,7 +410,7 @@ async def connect(self, *, reconnect: bool, timeout: float) ->None: break except (ConnectionClosed, asyncio.TimeoutError): if reconnect: - _log.exception('Failed to connect to voice... Retrying...') + _log.exception("Failed to connect to voice... Retrying...") await asyncio.sleep(1 + i * 2.0) await self.voice_disconnect() continue @@ -406,7 +427,9 @@ async def potential_reconnect(self) -> bool: self._potentially_reconnecting = True try: # We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected - await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout) + await asyncio.wait_for( + self._voice_server_complete.wait(), timeout=self.timeout + ) except asyncio.TimeoutError: self._potentially_reconnecting = False await self.disconnect(force=True) @@ -454,14 +477,21 @@ async def poll_voice_ws(self, reconnect: bool) -> None: # 4014 - voice channel has been deleted. # 4015 - voice server has crashed if exc.code in (1000, 4015): - _log.info('Disconnecting from voice normally, close code %d.', exc.code) + _log.info( + "Disconnecting from voice normally, close code %d.", + exc.code, + ) await self.disconnect() break if exc.code == 4014: - _log.info('Disconnected from voice by force... potentially reconnecting.') + _log.info( + "Disconnected from voice by force... potentially reconnecting." + ) successful = await self.potential_reconnect() if not successful: - _log.info('Reconnect was unsuccessful, disconnecting from voice normally...') + _log.info( + "Reconnect was unsuccessful, disconnecting from voice normally..." + ) await self.disconnect() break else: @@ -472,7 +502,9 @@ async def poll_voice_ws(self, reconnect: bool) -> None: raise retry = backoff.delay() - _log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry) + _log.exception( + "Disconnected from voice... Reconnecting in %.2fs.", retry + ) self._connected.clear() await asyncio.sleep(retry) await self.voice_disconnect() @@ -480,7 +512,7 @@ async def poll_voice_ws(self, reconnect: bool) -> None: await self.connect(reconnect=True, timeout=self.timeout) except asyncio.TimeoutError: # at this point we've retried 5 times... let's continue the loop. - _log.warning('Could not connect to voice... Retrying...') + _log.warning("Could not connect to voice... Retrying...") continue async def disconnect(self, *, force: bool = False) -> None: @@ -528,11 +560,11 @@ def _get_voice_packet(self, data): # Formulate rtp header header[0] = 0x80 header[1] = 0x78 - struct.pack_into('>H', header, 2, self.sequence) - struct.pack_into('>I', header, 4, self.timestamp) - struct.pack_into('>I', header, 8, self.ssrc) + struct.pack_into(">H", header, 2, self.sequence) + struct.pack_into(">I", header, 4, self.timestamp) + struct.pack_into(">I", header, 8, self.ssrc) - encrypt_packet = getattr(self, '_encrypt_' + self.mode) + encrypt_packet = getattr(self, "_encrypt_" + self.mode) return encrypt_packet(header, data) def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes: @@ -552,12 +584,52 @@ def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data) -> bytes: box = nacl.secret.SecretBox(bytes(self.secret_key)) nonce = bytearray(24) - nonce[:4] = struct.pack('>I', self._lite_nonce) - self.checked_add('_lite_nonce', 1, 4294967295) + nonce[:4] = struct.pack(">I", self._lite_nonce) + self.checked_add("_lite_nonce", 1, 4294967295) return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4] - def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any]=None) -> None: + def _decrypt_xsalsa20_poly1305(self, header, data): + box = nacl.secret.SecretBox(bytes(self.secret_key)) + + nonce = bytearray(24) + nonce[:12] = header + + return self.strip_header_ext(box.decrypt(bytes(data), bytes(nonce))) + + def _decrypt_xsalsa20_poly1305_suffix(self, header, data): + box = nacl.secret.SecretBox(bytes(self.secret_key)) + + nonce_size = nacl.secret.SecretBox.NONCE_SIZE + nonce = data[-nonce_size:] + + return self.strip_header_ext(box.decrypt(bytes(data[:-nonce_size]), nonce)) + + def _decrypt_xsalsa20_poly1305_lite(self, header, data): + box = nacl.secret.SecretBox(bytes(self.secret_key)) + + nonce = bytearray(24) + nonce[:4] = data[-4:] + data = data[:-4] + + return self.strip_header_ext(box.decrypt(bytes(data), bytes(nonce))) + + @staticmethod + def strip_header_ext(data): + if data[0] == 0xBE and data[1] == 0xDE and len(data) > 4: + _, length = struct.unpack_from(">HH", data) + offset = 4 + length * 4 + data = data[offset:] + return data + + def get_ssrc(self, user_id): + return {info["user_id"]: ssrc for ssrc, info in self.ws.ssrc_map.items()}[ + user_id + ] + + def play( + self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any] = None + ) -> None: """Plays an :class:`AudioSource`. The finalizer, ``after`` is called after the source has been exhausted @@ -587,13 +659,15 @@ def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], An """ if not self.is_connected(): - raise ClientException('Not connected to voice.') + raise ClientException("Not connected to voice.") if self.is_playing(): - raise ClientException('Already playing audio.') + raise ClientException("Already playing audio.") if not isinstance(source, AudioSource): - raise TypeError(f'source must be an AudioSource not {source.__class__.__name__}') + raise TypeError( + f"source must be an AudioSource not {source.__class__.__name__}" + ) if not self.encoder and not source.is_opus(): self.encoder = opus.Encoder() @@ -601,6 +675,171 @@ def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], An self._player = AudioPlayer(source, self, after=after) self._player.start() + def unpack_audio(self, data): + """Takes an audio packet received from Discord and decodes it into pcm audio data. + If there are no users talking in the channel, `None` will be returned. + + You must be connected to receive audio. + + .. versionadded:: 2.0 + + Parameters + --------- + data: :class:`bytes` + Bytes received by Discord via the UDP connection used for sending and receiving voice data. + """ + if 200 <= data[1] <= 204: + # RTCP received. + # RTCP provides information about the connection + # as opposed to actual audio data, so it's not + # important at the moment. + return + if self.paused: + return + + data = RawData(data, self) + + if data.decrypted_data == b"\xf8\xff\xfe": # Frame of silence + return + + self.decoder.decode(data) + + def start_recording(self, sink, callback, *args): + """The bot will begin recording audio from the current voice channel it is in. + This function uses a thread so the current code line will not be stopped. + Must be in a voice channel to use. + Must not be already recording. + + .. versionadded:: 2.0 + + Parameters + ---------- + sink: :class:`Sink` + A Sink which will "store" all the audio data. + callback: :class:`asynchronous function` + A function which is called after the bot has stopped recording. + *args: + Args which will be passed to the callback function. + Raises + ------ + RecordingException + Not connected to a voice channel. + RecordingException + Already recording. + RecordingException + Must provide a Sink object. + """ + if not self.is_connected(): + raise RecordingException("Not connected to voice channel.") + if self.recording: + raise RecordingException("Already recording.") + if not isinstance(sink, Sink): + raise RecordingException("Must provide a Sink object.") + + self.empty_socket() + + self.decoder = opus.DecodeManager(self) + self.decoder.start() + self.recording = True + self.sink = sink + sink.init(self) + + t = threading.Thread( + target=self.recv_audio, + args=( + sink, + callback, + *args, + ), + ) + t.start() + + def stop_recording(self): + """Stops the recording. + Must be already recording. + Raises + + .. versionadded:: 2.0 + + ------ + RecordingException + Not currently recording. + """ + if not self.recording: + raise RecordingException("Not currently recording audio.") + self.decoder.stop() + self.recording = False + self.paused = False + + def toggle_pause(self): + """Pauses or unpauses the recording. + Must be already recording. + + .. versionadded:: 2.0 + + Raises + ------ + RecordingException + Not currently recording. + """ + if not self.recording: + raise RecordingException("Not currently recording audio.") + self.paused = not self.paused + + def empty_socket(self): + while True: + ready, _, _ = select.select([self.socket], [], [], 0.0) + if not ready: + break + for s in ready: + s.recv(4096) + + def recv_audio(self, sink, callback, *args): + # Gets data from _recv_audio and sorts + # it by user, handles pcm files and + # silence that should be added. + + self.user_timestamps = {} + self.starting_time = time.perf_counter() + while self.recording: + ready, _, err = select.select([self.socket], [], [self.socket], 0.01) + if not ready: + if err: + print(f"Socket error: {err}") + continue + + try: + data = self.socket.recv(4096) + except OSError: + self.stop_recording() + continue + + self.unpack_audio(data) + + self.stopping_time = time.perf_counter() + self.sink.cleanup() + callback = asyncio.run_coroutine_threadsafe( + callback(self.sink, *args), self.loop + ) + result = callback.result() + + if result is not None: + print(result) + + def recv_decoded_audio(self, data): + if data.ssrc not in self.user_timestamps: + self.user_timestamps.update({data.ssrc: data.timestamp}) + # Add silence when they were not being recorded. + silence = 0 + else: + silence = data.timestamp - self.user_timestamps[data.ssrc] - 960 + self.user_timestamps[data.ssrc] = data.timestamp + + data.decoded_data = struct.pack(' bool: """Indicates if we're currently playing audio.""" return self._player is not None and self._player.is_playing() @@ -636,10 +875,10 @@ def source(self) -> Optional[AudioSource]: @source.setter def source(self, value: AudioSource) -> None: if not isinstance(value, AudioSource): - raise TypeError(f'expected AudioSource not {value.__class__.__name__}.') + raise TypeError(f"expected AudioSource not {value.__class__.__name__}.") if self._player is None: - raise ValueError('Not playing anything.') + raise ValueError("Not playing anything.") self._player._set_source(value) @@ -663,7 +902,7 @@ def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: Encoding the data failed. """ - self.checked_add('sequence', 1, 65535) + self.checked_add("sequence", 1, 65535) if encode: encoded_data = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME) else: @@ -672,6 +911,10 @@ def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: try: self.socket.sendto(packet, (self.endpoint_ip, self.voice_port)) except BlockingIOError: - _log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp) + _log.warning( + "A packet has been dropped (seq: %s, timestamp: %s)", + self.sequence, + self.timestamp, + ) - self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295) + self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295) diff --git a/docs/api.rst b/docs/api.rst index 812d35f64a..8120291796 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4444,6 +4444,23 @@ Select .. autofunction:: discord.ui.select +Voice Recording +--------------- + +.. attributetable:: discord.sink + +.. autoclass:: discord.sink.Filters + :members: + +.. autoclass:: discord.sink.Sink + :members: + +.. autoclass:: discord.sink.AudioData + :members: + +.. autoclass:: discord.sink.RawData + :members: + Exceptions ------------ @@ -4491,6 +4508,7 @@ Exception Hierarchy - :exc:`Exception` - :exc:`DiscordException` - :exc:`ClientException` + - :exc:`RecordingException` - :exc:`InvalidData` - :exc:`InvalidArgument` - :exc:`LoginFailure` diff --git a/examples/audio_recording.py b/examples/audio_recording.py new file mode 100644 index 0000000000..e16d018cef --- /dev/null +++ b/examples/audio_recording.py @@ -0,0 +1,55 @@ +import os +import discord +from discord.commands import Option + +bot = discord.Bot(debug_guilds=[...]) +bot.connections = {} + + +@bot.command() +async def start(ctx, encoding: Option(str, choices=["mp3", "wav", "pcm"])): + """ + Record your voice! + """ + + voice = ctx.author.voice + + if not voice: + return await ctx.respond("You're not in a vc right now") + + vc = await voice.channel.connect() + bot.connections.update({ctx.guild.id: vc}) + + vc.start_recording( + discord.Sink(encoding=encoding), + finished_callback, + ctx.channel, + ) + + await ctx.respond("The recording has started!") + + +async def finished_callback(sink, channel, *args): + + recorded_users = [ + f" <@{user_id}> ({os.path.split(audio.file)[1]}) " + for user_id, audio in sink.audio_data.items() + ] + await sink.vc.disconnect() + await channel.send(f"Finished! Recorded audio for {', '.join(recorded_users)}.") + +@bot.command() +async def stop(ctx): + """ + Stop recording. + """ + if ctx.guild.id in bot.connections: + vc = bot.connections[ctx.guild.id] + vc.stop_recording() + del bot.connections[ctx.guild.id] + await ctx.delete() + else: + await ctx.respond("Not recording in this guild.") + + +bot.run("TOKEN") \ No newline at end of file