Skip to content

Commit

Permalink
safety fixes for interaction responses, list removals
Browse files Browse the repository at this point in the history
  • Loading branch information
Kav-K committed Nov 11, 2023
1 parent 70bfe9c commit 1f0ab52
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 43 deletions.
6 changes: 3 additions & 3 deletions cogs/code_interpreter_service_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from services.deletion_service import Deletion
from services.environment_service import EnvService
from services.moderations_service import Moderation
from utils.safe_ctx_respond import safe_ctx_respond
from utils.safe_ctx_respond import safe_ctx_respond, safe_remove_list


class CaptureStdout:
Expand Down Expand Up @@ -240,7 +240,7 @@ async def on_message(self, message):
await message.reply(
embed=EmbedStatics.get_code_chat_failure_embed(response)
)
self.thread_awaiting_responses.remove(message.channel.id)
safe_remove_list(self.thread_awaiting_responses, message.channel.id)
return

# Parse the artifact names. After Artifacts: there should be a list in form [] where the artifact names are inside, comma separated inside stdout_output
Expand Down Expand Up @@ -292,7 +292,7 @@ async def on_message(self, message):
else None,
)

self.thread_awaiting_responses.remove(message.channel.id)
safe_remove_list(self.thread_awaiting_responses, message.channel.id)

class SessionedCodeExecutor:
def __init__(self):
Expand Down
9 changes: 5 additions & 4 deletions cogs/index_service_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from services.moderations_service import Moderation
from services.text_service import TextService
from models.index_model import Index_handler
from utils.safe_ctx_respond import safe_remove_list

USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
USER_KEY_DB = EnvService.get_api_db()
Expand Down Expand Up @@ -79,7 +80,7 @@ async def process_indexing(self, message, index_type, content=None, link=None):
)
failure_embed.set_thumbnail(url="https://i.imgur.com/hbdBZfG.png")
await message.reply(embed=failure_embed)
self.thread_awaiting_responses.remove(message.channel.id)
safe_remove_list(self.thread_awaiting_responses, message.channel.id)
return False

success_embed = discord.Embed(
Expand Down Expand Up @@ -146,7 +147,7 @@ async def on_message(self, message):
)

if not indexing_result:
self.thread_awaiting_responses.remove(message.channel.id)
safe_remove_list(self.thread_awaiting_responses, message.channel.id)
return

prompt += (
Expand All @@ -165,7 +166,7 @@ async def on_message(self, message):
)

if not indexing_result:
self.thread_awaiting_responses.remove(message.channel.id)
safe_remove_list(self.thread_awaiting_responses, message.channel.id)
return

prompt += (
Expand Down Expand Up @@ -204,7 +205,7 @@ async def on_message(self, message):
await message.reply(
embed=response_embed,
)
self.thread_awaiting_responses.remove(message.channel.id)
safe_remove_list(self.thread_awaiting_responses, message.channel.id)

async def index_chat_command(self, ctx, model):
await self.index_handler.start_index_chat(ctx, model)
Expand Down
6 changes: 3 additions & 3 deletions cogs/search_service_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from services.moderations_service import Moderation
from services.text_service import TextService
from models.openai_model import Models
from utils.safe_ctx_respond import safe_ctx_respond
from utils.safe_ctx_respond import safe_ctx_respond, safe_remove_list

from contextlib import redirect_stdout

Expand Down Expand Up @@ -390,7 +390,7 @@ async def on_message(self, message):
await message.reply(
embed=EmbedStatics.get_internet_chat_failure_embed(response)
)
self.thread_awaiting_responses.remove(message.channel.id)
safe_remove_list(self.thread_awaiting_responses, message.channel.id)
return

if len(response) > 2000:
Expand All @@ -416,7 +416,7 @@ async def on_message(self, message):
)
await message.reply(embed=response_embed)

self.thread_awaiting_responses.remove(message.channel.id)
safe_remove_list(self.thread_awaiting_responses, message.channel.id)

async def search_chat_command(
self, ctx: discord.ApplicationContext, model, search_scope=2
Expand Down
13 changes: 5 additions & 8 deletions cogs/text_service_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from services.pickle_service import Pickler
from services.sharegpt_service import ShareGPTService
from services.text_service import SetupModal, TextService
from utils.safe_ctx_respond import safe_ctx_respond
from utils.safe_ctx_respond import safe_ctx_respond, safe_remove_list

original_message = {}
ALLOWED_GUILDS = EnvService.get_allowed_guilds()
Expand Down Expand Up @@ -764,11 +764,9 @@ def remove_awaiting(
self, author_id, channel_id, from_ask_command, from_edit_command
):
"""Remove user from ask/edit command response wait, if not any of those then process the id to remove user from thread response wait"""
if author_id in self.awaiting_responses:
self.awaiting_responses.remove(author_id)
safe_remove_list(self.awaiting_responses, author_id)
if not from_ask_command and not from_edit_command:
if channel_id in self.awaiting_thread_responses:
self.awaiting_thread_responses.remove(channel_id)
safe_remove_list(self.awaiting_thread_responses, channel_id)

async def mention_to_username(self, ctx, message):
"""replaces discord mentions with their server nickname in text, if the user is not found keep the mention as is"""
Expand Down Expand Up @@ -1358,9 +1356,8 @@ async def converse_command(
model=self.conversation_threads[target.id].model,
custom_api_key=user_api_key,
)
self.awaiting_responses.remove(user_id_normalized)
if target.id in self.awaiting_thread_responses:
self.awaiting_thread_responses.remove(target.id)
safe_remove_list(self.awaiting_responses, user_id_normalized)
safe_remove_list(self.awaiting_thread_responses, target.id)

async def end_command(self, ctx: discord.ApplicationContext):
"""Command handler. Gets the user's thread and ends it"""
Expand Down
2 changes: 1 addition & 1 deletion gpt3discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from models.openai_model import Model


__version__ = "12.2.0"
__version__ = "12.2.1"


PID_FILE = Path("bot.pid")
Expand Down
39 changes: 15 additions & 24 deletions utils/safe_ctx_respond.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import discord

async def safe_ctx_respond(*args, **kwargs) -> None:

def safe_remove_list(remove_from, element):
try:
remove_from.remove(element)
except ValueError:
pass


async def safe_ctx_respond(ctx: discord.ApplicationContext, content: str) -> None:
"""
Safely responds to a Discord interaction.
Expand All @@ -18,35 +26,18 @@ async def safe_ctx_respond(*args, **kwargs) -> None:
await safe_ctx_respond(ctx=ctx, content="Hello World!")
```
"""
# Get the context from the kwargs
ctx: discord.ApplicationContext = kwargs.get("ctx", None)
kwargs.pop("ctx", None)

# Raise an error if context is not provided
if ctx is None:
raise ValueError("ctx is a required keyword argument")

try:
# Try to respond to the interaction
await ctx.respond(*args, **kwargs)
await ctx.respond(content)
except discord.NotFound: # NotFound is raised when the interaction is not found
try:
# If the interaction is not found, try to reply to the message
if kwargs.get("ephemeral", False):
kwargs.pop("ephemeral")
kwargs["delete_after"] = 5
await ctx.message.reply(*args, **kwargs)
await ctx.message.reply(content)
except (
discord.NotFound,
AttributeError,
discord.NotFound,
AttributeError,
): # AttributeError is raised when ctx.message is None, NotFound is raised when the message is not found
# If the message is not found, send a new message to the channel
if len(args) > 0:
content = args[0] or ""
args = args[1:]
else:
content = kwargs.get("content", "")
kwargs["content"] = f"**{ctx.author.mention}** \n{content}".strip(
content = f"**{ctx.message.author.mention}** \n{content}".strip(
"\n"
).strip()
await ctx.channel.send(*args, **kwargs)
await ctx.channel.send(content)

0 comments on commit 1f0ab52

Please sign in to comment.