Skip to content

Commit

Permalink
Fix voice state typing
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfies committed Dec 29, 2023
1 parent 2dbb369 commit f6d8c66
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 24 deletions.
4 changes: 2 additions & 2 deletions discord/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
from .types.threads import (
Thread as ThreadPayload,
)
from .types.voice import GuildVoiceState
from .types.voice import BaseVoiceState as VoiceStatePayload
from .permissions import Permissions
from .channel import VoiceChannel, StageChannel, TextChannel, ForumChannel, CategoryChannel
from .template import Template
Expand Down Expand Up @@ -572,7 +572,7 @@ def __repr__(self) -> str:
return f'<Guild {inner}>'

def _update_voice_state(
self, data: GuildVoiceState, channel_id: Optional[int]
self, data: VoiceStatePayload, channel_id: Optional[int]
) -> Tuple[Optional[Member], VoiceState, VoiceState]:
cache_flags = self._state.member_cache_flags
user_id = int(data['user_id'])
Expand Down
9 changes: 3 additions & 6 deletions discord/member.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,7 @@
from .state import ConnectionState, Presence
from .message import Message
from .role import Role
from .types.voice import (
GuildVoiceState as GuildVoiceStatePayload,
VoiceState as VoiceStatePayload,
)
from .types.voice import BaseVoiceState as VoiceStatePayload
from .user import Note
from .relationship import Relationship
from .calls import PrivateCall
Expand Down Expand Up @@ -147,12 +144,12 @@ class VoiceState:
)

def __init__(
self, *, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[ConnectableChannel] = None
self, *, data: VoiceStatePayload, channel: Optional[ConnectableChannel] = None
):
self.session_id: Optional[str] = data.get('session_id')
self._update(data, channel)

def _update(self, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[ConnectableChannel]):
def _update(self, data: VoiceStatePayload, channel: Optional[ConnectableChannel]):
self.self_mute: bool = data.get('self_mute', False)
self.self_deaf: bool = data.get('self_deaf', False)
self.self_stream: bool = data.get('self_stream', False)
Expand Down
9 changes: 7 additions & 2 deletions discord/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
PartialMessage as PartialMessagePayload,
)
from .types import gateway as gw
from .types.voice import VoiceState as VoiceStatePayload
from .types.voice import BaseVoiceState as VoiceStatePayload
from .types.activity import ClientStatus as ClientStatusPayload

T = TypeVar('T')
Expand Down Expand Up @@ -2797,11 +2797,16 @@ def parse_call_delete(self, data: gw.CallDeleteEvent) -> None:
self.dispatch('call_delete', call)

def parse_voice_state_update(self, data: gw.VoiceStateUpdateEvent) -> None:
guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id'))
guild_id = utils._get_as_snowflake(data, 'guild_id')
guild = self._get_guild(guild_id)
channel_id = utils._get_as_snowflake(data, 'channel_id')
flags = self.member_cache_flags
self_id = self.self_id

if guild_id is not None and guild is None:
_log.debug('VOICE_STATE_UPDATE referencing unknown guild ID: %s. Discarding.', guild_id)
return

if int(data['user_id']) == self_id:
voice = self._get_voice_client(guild.id if guild else self_id)
if voice is not None:
Expand Down
6 changes: 3 additions & 3 deletions discord/types/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
User,
UserGuildSettings,
)
from .voice import GuildVoiceState, VoiceState
from .voice import GuildVoiceState, PrivateVoiceState, VoiceState

T = TypeVar('T')

Expand Down Expand Up @@ -428,7 +428,7 @@ class _GuildScheduledEventUsersEvent(TypedDict):

GuildScheduledEventUserAdd = GuildScheduledEventUserRemove = _GuildScheduledEventUsersEvent

VoiceStateUpdateEvent = GuildVoiceState
VoiceStateUpdateEvent = Union[GuildVoiceState, PrivateVoiceState]


class VoiceServerUpdateEvent(TypedDict):
Expand Down Expand Up @@ -559,7 +559,7 @@ class PartialUpdateChannel(TypedDict):
class PassiveUpdateEvent(TypedDict):
guild_id: Snowflake
channels: List[PartialUpdateChannel]
voice_states: NotRequired[List[GuildVoiceState]]
voice_states: NotRequired[List[VoiceState]]
members: NotRequired[List[MemberWithUser]]


Expand Down
6 changes: 3 additions & 3 deletions discord/types/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .sticker import GuildSticker
from .snowflake import Snowflake
from .channel import GuildChannel, StageInstance
from .voice import GuildVoiceState
from .voice import VoiceState
from .welcome_screen import WelcomeScreen
from .activity import PartialPresenceUpdate
from .role import Role
Expand Down Expand Up @@ -116,7 +116,7 @@ class Guild(UnavailableGuild, _GuildMedia):
joined_at: NotRequired[Optional[str]]
large: NotRequired[bool]
member_count: NotRequired[int]
voice_states: NotRequired[List[GuildVoiceState]]
voice_states: NotRequired[List[VoiceState]]
members: NotRequired[List[MemberWithUser]]
channels: NotRequired[List[GuildChannel]]
presences: NotRequired[List[PartialPresenceUpdate]]
Expand Down Expand Up @@ -189,4 +189,4 @@ class CommandScopeMigration(TypedDict):

class SupplementalGuild(UnavailableGuild):
embedded_activities: list
voice_states: List[GuildVoiceState]
voice_states: List[VoiceState]
13 changes: 8 additions & 5 deletions discord/types/voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305']


class _VoiceState(TypedDict):
class BaseVoiceState(TypedDict):
user_id: Snowflake
session_id: str
deaf: bool
Expand All @@ -45,13 +45,16 @@ class _VoiceState(TypedDict):
self_stream: NotRequired[bool]


class GuildVoiceState(_VoiceState):
class VoiceState(BaseVoiceState):
channel_id: Snowflake


class VoiceState(_VoiceState, total=False):
channel_id: NotRequired[Optional[Snowflake]]
guild_id: NotRequired[Optional[Snowflake]]
class PrivateVoiceState(BaseVoiceState):
channel_id: Optional[Snowflake]


class GuildVoiceState(PrivateVoiceState):
guild_id: Snowflake


class VoiceRegion(TypedDict):
Expand Down
6 changes: 3 additions & 3 deletions discord/voice_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@
from .opus import Encoder
from . import abc

from .types.gateway import VoiceStateUpdateEvent as VoiceStateUpdatePayload
from .types.voice import (
GuildVoiceState as GuildVoiceStatePayload,
VoiceServerUpdate as VoiceServerUpdatePayload,
SupportedModes,
)
Expand Down Expand Up @@ -114,7 +114,7 @@ def __init__(self, client: Client, channel: VocalChannel) -> None:
self.client: Client = client
self.channel: VocalChannel = channel

async def on_voice_state_update(self, data: GuildVoiceStatePayload, /) -> None:
async def on_voice_state_update(self, data: VoiceStateUpdatePayload, /) -> None:
"""|coro|
An abstract method that is called when the client's voice state
Expand Down Expand Up @@ -288,7 +288,7 @@ def checked_add(self, attr: str, value: int, limit: int) -> None:

# connection related

async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
async def on_voice_state_update(self, data: VoiceStateUpdatePayload) -> None:
self.session_id: str = data['session_id']
channel_id = data['channel_id']

Expand Down

0 comments on commit f6d8c66

Please sign in to comment.