Skip to content

Commit

Permalink
Make channel access checks consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
MinnDevelopment committed May 12, 2024
1 parent b2010f3 commit f37cc49
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,6 @@ public boolean canTalk()
return user == null || !user.isBot();
}

@Override
public void checkCanAccessChannel() {}

@Override
public void checkCanSendMessage() {
checkBot();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ public String getStatus()
public AuditableRestAction<Void> modifyStatus(@Nonnull String status)
{
Checks.notLonger(status, MAX_STATUS_LENGTH, "Voice Status");
checkCanAccessChannel();
checkCanAccess();
if (this.equals(getGuild().getSelfMember().getVoiceState().getChannel()))
checkPermission(Permission.VOICE_SET_STATUS);
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package net.dv8tion.jda.internal.entities.channel.mixin.middleman;

import gnu.trove.map.TLongObjectMap;
import net.dv8tion.jda.api.Permission;
import net.dv8tion.jda.api.entities.Member;
import net.dv8tion.jda.api.entities.channel.unions.AudioChannelUnion;
import net.dv8tion.jda.api.exceptions.MissingAccessException;

public interface AudioChannelMixin<T extends AudioChannelMixin<T>>
extends AudioChannelUnion, StandardGuildChannelMixin<T>
Expand All @@ -31,4 +33,14 @@ public interface AudioChannelMixin<T extends AudioChannelMixin<T>>
T setUserLimit(int userlimit);

T setRegion(String region);

// AudioChannels also require connect permission to grant access
@Override
default void checkCanAccess()
{
if (!hasPermission(Permission.VIEW_CHANNEL))
throw new MissingAccessException(this, Permission.VIEW_CHANNEL);
if (!hasPermission(Permission.VOICE_CONNECT))
throw new MissingAccessException(this, Permission.VOICE_CONNECT);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import net.dv8tion.jda.api.entities.channel.middleman.GuildChannel;
import net.dv8tion.jda.api.entities.channel.unions.GuildChannelUnion;
import net.dv8tion.jda.api.exceptions.InsufficientPermissionException;
import net.dv8tion.jda.api.exceptions.MissingAccessException;
import net.dv8tion.jda.api.requests.Route;
import net.dv8tion.jda.api.requests.restaction.AuditableRestAction;
import net.dv8tion.jda.internal.entities.channel.mixin.ChannelMixin;
Expand All @@ -40,6 +41,7 @@ public interface GuildChannelMixin<T extends GuildChannelMixin<T>> extends
@CheckReturnValue
default AuditableRestAction<Void> delete()
{
checkCanAccess();
checkCanManage();

Route.CompiledRoute route = Route.Channels.DELETE_CHANNEL.compile(getId());
Expand Down Expand Up @@ -70,4 +72,11 @@ default void checkCanManage()
{
checkPermission(Permission.MANAGE_CHANNEL);
}

// Overridden by AudioChannelMixin
default void checkCanAccess()
{
if (!hasPermission(Permission.VIEW_CHANNEL))
throw new MissingAccessException(this, Permission.VIEW_CHANNEL);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,17 @@ default RestAction<Void> clearReactionsById(@Nonnull String messageId, @Nonnull
@Override
default MessageCreateAction sendStickers(@Nonnull Collection<? extends StickerSnowflake> stickers)
{
checkCanAccessChannel();
checkCanSendMessage();
Checks.notEmpty(stickers, "Stickers");
Checks.noneNull(stickers, "Stickers");
return new MessageCreateActionImpl(this).setStickers(stickers);
}

// ---- Default implementation of parent mixins hooks ----
default void checkCanAccessChannel()
{
checkPermission(Permission.VIEW_CHANNEL);
}

default void checkCanSendMessage()
{
checkCanAccess();
if (getType().isThread())
checkPermission(Permission.MESSAGE_SEND_IN_THREADS);
else
Expand All @@ -131,32 +127,38 @@ default void checkCanSendMessage()

default void checkCanSendMessageEmbeds()
{
checkCanAccess();
checkPermission(Permission.MESSAGE_EMBED_LINKS);
}

default void checkCanSendFiles()
{
checkCanAccess();
checkPermission(Permission.MESSAGE_ATTACH_FILES);
}

default void checkCanViewHistory()
{
checkCanAccess();
checkPermission(Permission.MESSAGE_HISTORY);
}

default void checkCanAddReactions()
{
checkCanAccess();
checkPermission(Permission.MESSAGE_ADD_REACTION);
checkPermission(Permission.MESSAGE_HISTORY, "You need MESSAGE_HISTORY to add reactions to a message");
}

default void checkCanRemoveReactions()
{
checkCanAccess();
checkPermission(Permission.MESSAGE_HISTORY, "You need MESSAGE_HISTORY to remove reactions from a message");
}

default void checkCanControlMessagePins()
{
checkCanAccess();
checkPermission(Permission.MESSAGE_MANAGE, "You need MESSAGE_MANAGE to pin or unpin messages.");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ else if (!toDelete.isEmpty())
@CheckReturnValue
default MessageCreateAction sendMessage(@Nonnull CharSequence text)
{
checkCanAccessChannel();
checkCanSendMessage();
return MessageChannelUnion.super.sendMessage(text);
}
Expand All @@ -141,7 +140,6 @@ default MessageCreateAction sendMessage(@Nonnull CharSequence text)
@CheckReturnValue
default MessageCreateAction sendMessageEmbeds(@Nonnull MessageEmbed embed, @Nonnull MessageEmbed... other)
{
checkCanAccessChannel();
checkCanSendMessage();
checkCanSendMessageEmbeds();
return MessageChannelUnion.super.sendMessageEmbeds(embed, other);
Expand All @@ -151,7 +149,6 @@ default MessageCreateAction sendMessageEmbeds(@Nonnull MessageEmbed embed, @Nonn
@CheckReturnValue
default MessageCreateAction sendMessageEmbeds(@Nonnull Collection<? extends MessageEmbed> embeds)
{
checkCanAccessChannel();
checkCanSendMessage();
checkCanSendMessageEmbeds();
return MessageChannelUnion.super.sendMessageEmbeds(embeds);
Expand All @@ -161,7 +158,6 @@ default MessageCreateAction sendMessageEmbeds(@Nonnull Collection<? extends Mess
@Override
default MessageCreateAction sendMessageComponents(@NotNull LayoutComponent component, @NotNull LayoutComponent... other)
{
checkCanAccessChannel();
checkCanSendMessage();
return MessageChannelUnion.super.sendMessageComponents(component, other);
}
Expand All @@ -170,7 +166,6 @@ default MessageCreateAction sendMessageComponents(@NotNull LayoutComponent compo
@Override
default MessageCreateAction sendMessageComponents(@Nonnull Collection<? extends LayoutComponent> components)
{
checkCanAccessChannel();
checkCanSendMessage();
return MessageChannelUnion.super.sendMessageComponents(components);
}
Expand All @@ -179,7 +174,6 @@ default MessageCreateAction sendMessageComponents(@Nonnull Collection<? extends
@Override
default MessageCreateAction sendMessagePoll(@Nonnull MessagePollData poll)
{
checkCanAccessChannel();
checkCanSendMessage();
return MessageChannelUnion.super.sendMessagePoll(poll);
}
Expand All @@ -188,7 +182,6 @@ default MessageCreateAction sendMessagePoll(@Nonnull MessagePollData poll)
@CheckReturnValue
default MessageCreateAction sendMessage(@Nonnull MessageCreateData msg)
{
checkCanAccessChannel();
checkCanSendMessage();
return MessageChannelUnion.super.sendMessage(msg);
}
Expand All @@ -197,7 +190,6 @@ default MessageCreateAction sendMessage(@Nonnull MessageCreateData msg)
@CheckReturnValue
default MessageCreateAction sendFiles(@Nonnull Collection<? extends FileUpload> files)
{
checkCanAccessChannel();
checkCanSendMessage();
checkCanSendFiles();
return MessageChannelUnion.super.sendFiles(files);
Expand All @@ -207,7 +199,6 @@ default MessageCreateAction sendFiles(@Nonnull Collection<? extends FileUpload>
@CheckReturnValue
default RestAction<Message> retrieveMessageById(@Nonnull String messageId)
{
checkCanAccessChannel();
checkCanViewHistory();
return MessageChannelUnion.super.retrieveMessageById(messageId);
}
Expand All @@ -216,7 +207,6 @@ default RestAction<Message> retrieveMessageById(@Nonnull String messageId)
@CheckReturnValue
default AuditableRestAction<Void> deleteMessageById(@Nonnull String messageId)
{
checkCanAccessChannel();
//We don't know if this is a Message sent by us or another user, so we can't run checks for Permission.MESSAGE_MANAGE
return MessageChannelUnion.super.deleteMessageById(messageId);
}
Expand All @@ -225,7 +215,6 @@ default AuditableRestAction<Void> deleteMessageById(@Nonnull String messageId)
@Override
default MessageHistory getHistory()
{
checkCanAccessChannel();
checkCanViewHistory();
return MessageChannelUnion.super.getHistory();
}
Expand All @@ -234,7 +223,6 @@ default MessageHistory getHistory()
@CheckReturnValue
default MessagePaginationAction getIterableHistory()
{
checkCanAccessChannel();
checkCanViewHistory();
return MessageChannelUnion.super.getIterableHistory();
}
Expand All @@ -243,7 +231,6 @@ default MessagePaginationAction getIterableHistory()
@CheckReturnValue
default MessageHistory.MessageRetrieveAction getHistoryAround(@Nonnull String messageId, int limit)
{
checkCanAccessChannel();
checkCanViewHistory();
return MessageChannelUnion.super.getHistoryAround(messageId, limit);
}
Expand All @@ -252,7 +239,6 @@ default MessageHistory.MessageRetrieveAction getHistoryAround(@Nonnull String me
@CheckReturnValue
default MessageHistory.MessageRetrieveAction getHistoryAfter(@Nonnull String messageId, int limit)
{
checkCanAccessChannel();
checkCanViewHistory();
return MessageChannelUnion.super.getHistoryAfter(messageId, limit);
}
Expand All @@ -261,7 +247,6 @@ default MessageHistory.MessageRetrieveAction getHistoryAfter(@Nonnull String mes
@CheckReturnValue
default MessageHistory.MessageRetrieveAction getHistoryBefore(@Nonnull String messageId, int limit)
{
checkCanAccessChannel();
checkCanViewHistory();
return MessageChannelUnion.super.getHistoryBefore(messageId, limit);
}
Expand All @@ -270,7 +255,6 @@ default MessageHistory.MessageRetrieveAction getHistoryBefore(@Nonnull String me
@CheckReturnValue
default MessageHistory.MessageRetrieveAction getHistoryFromBeginning(int limit)
{
checkCanAccessChannel();
checkCanViewHistory();
return MessageHistory.getHistoryFromBeginning(this).limit(limit);
}
Expand All @@ -279,15 +263,13 @@ default MessageHistory.MessageRetrieveAction getHistoryFromBeginning(int limit)
@CheckReturnValue
default RestAction<Void> sendTyping()
{
checkCanAccessChannel();
return MessageChannelUnion.super.sendTyping();
}

@Nonnull
@CheckReturnValue
default RestAction<Void> addReactionById(@Nonnull String messageId, @Nonnull Emoji emoji)
{
checkCanAccessChannel();
checkCanAddReactions();
return MessageChannelUnion.super.addReactionById(messageId, emoji);
}
Expand All @@ -296,7 +278,6 @@ default RestAction<Void> addReactionById(@Nonnull String messageId, @Nonnull Emo
@CheckReturnValue
default RestAction<Void> removeReactionById(@Nonnull String messageId, @Nonnull Emoji emoji)
{
checkCanAccessChannel();
checkCanRemoveReactions();
return MessageChannelUnion.super.removeReactionById(messageId, emoji);
}
Expand All @@ -305,7 +286,6 @@ default RestAction<Void> removeReactionById(@Nonnull String messageId, @Nonnull
@CheckReturnValue
default ReactionPaginationAction retrieveReactionUsersById(@Nonnull String messageId, @Nonnull Emoji emoji)
{
checkCanAccessChannel();
checkCanRemoveReactions();
return MessageChannelUnion.super.retrieveReactionUsersById(messageId, emoji);
}
Expand All @@ -314,7 +294,6 @@ default ReactionPaginationAction retrieveReactionUsersById(@Nonnull String messa
@CheckReturnValue
default RestAction<Void> pinMessageById(@Nonnull String messageId)
{
checkCanAccessChannel();
checkCanControlMessagePins();
return MessageChannelUnion.super.pinMessageById(messageId);
}
Expand All @@ -323,7 +302,6 @@ default RestAction<Void> pinMessageById(@Nonnull String messageId)
@CheckReturnValue
default RestAction<Void> unpinMessageById(@Nonnull String messageId)
{
checkCanAccessChannel();
checkCanControlMessagePins();
return MessageChannelUnion.super.unpinMessageById(messageId);
}
Expand All @@ -332,15 +310,13 @@ default RestAction<Void> unpinMessageById(@Nonnull String messageId)
@CheckReturnValue
default RestAction<List<Message>> retrievePinnedMessages()
{
checkCanAccessChannel();
return MessageChannelUnion.super.retrievePinnedMessages();
}

@Nonnull
@CheckReturnValue
default MessageEditAction editMessageById(@Nonnull String messageId, @Nonnull CharSequence newContent)
{
checkCanAccessChannel();
checkCanSendMessage();
return MessageChannelUnion.super.editMessageById(messageId, newContent);
}
Expand All @@ -349,7 +325,6 @@ default MessageEditAction editMessageById(@Nonnull String messageId, @Nonnull Ch
@CheckReturnValue
default MessageEditAction editMessageById(@Nonnull String messageId, @Nonnull MessageEditData data)
{
checkCanAccessChannel();
checkCanSendMessage();
return MessageChannelUnion.super.editMessageById(messageId, data);
}
Expand All @@ -359,7 +334,6 @@ default MessageEditAction editMessageById(@Nonnull String messageId, @Nonnull Me
@CheckReturnValue
default MessageEditAction editMessageEmbedsById(@Nonnull String messageId, @Nonnull Collection<? extends MessageEmbed> newEmbeds)
{
checkCanAccessChannel();
checkCanSendMessage();
checkCanSendMessageEmbeds();
return MessageChannelUnion.super.editMessageEmbedsById(messageId, newEmbeds);
Expand All @@ -369,7 +343,6 @@ default MessageEditAction editMessageEmbedsById(@Nonnull String messageId, @Nonn
@CheckReturnValue
default MessageEditAction editMessageComponentsById(@Nonnull String messageId, @Nonnull Collection<? extends LayoutComponent> components)
{
checkCanAccessChannel();
checkCanSendMessage();
return MessageChannelUnion.super.editMessageComponentsById(messageId, components);
}
Expand All @@ -378,7 +351,6 @@ default MessageEditAction editMessageComponentsById(@Nonnull String messageId, @
@Override
default MessageEditAction editMessageAttachmentsById(@Nonnull String messageId, @Nonnull Collection<? extends AttachedFile> attachments)
{
checkCanAccessChannel();
checkCanSendMessage();
return MessageChannelUnion.super.editMessageAttachmentsById(messageId, attachments);
}
Expand All @@ -387,7 +359,6 @@ default MessageEditAction editMessageAttachmentsById(@Nonnull String messageId,
T setLatestMessageIdLong(long latestMessageId);

// ---- Mixin Hooks ----
void checkCanAccessChannel();
void checkCanSendMessage();
void checkCanSendMessageEmbeds();
void checkCanSendFiles();
Expand Down

0 comments on commit f37cc49

Please sign in to comment.