diff --git a/src/discord_bot/bot.py b/src/discord_bot/bot.py
index efdf70d..95e6b86 100644
--- a/src/discord_bot/bot.py
+++ b/src/discord_bot/bot.py
@@ -21,7 +21,7 @@ async def on_ready() -> None:
 
 @interactions.listen()
 async def on_message_create(event: interactions.api.events.MessageCreate) -> None:
-    print(f"message received: {event.message.content}")
+    print(f"Message received: {event.message.content}")
 
 
 @interactions.slash_command(name="ask", description="Ask an LLM", scopes=[DISCORD_GUILD_ID])
@@ -43,7 +43,8 @@ async def ask_model(ctx: interactions.SlashContext, model: str = "", prompt: str
     await ctx.defer()
 
     response = await answer_question(model, prompt, AI_SERVER_URL)
-    await ctx.send(response)
+    for r in response:
+        await ctx.send(r)
 
 
 @ask_model.autocomplete("model")
diff --git a/src/discord_bot/llm.py b/src/discord_bot/llm.py
index ca2d3ab..e20273b 100644
--- a/src/discord_bot/llm.py
+++ b/src/discord_bot/llm.py
@@ -1,21 +1,72 @@
 """
 Although this module uses `openai` package but we are routing it
-through our LiteLLM proxy to interact with Ollama and OpenAI models
+through our LiteLLM proxy to interact with Ollama and OpenAI models by modifying the `base_url`.
 """
 
 import openai
 
+MAX_CHARACTERS = 2000
+QUESTION_CUT_OFF_LENGTH = 150
+RESERVED_SPACE = 50  # for other additional strings. E.g. number `(1/4)`, `Q: `, `A: `, etc.
 
-async def answer_question(model: str, question: str, server_url: str) -> str:
+
+async def answer_question(model: str, question: str, server_url: str) -> list[str]:
     try:
         client = openai.AsyncOpenAI(base_url=server_url, api_key="FAKE")
         response = await client.chat.completions.create(
             model=model,
             messages=[{"role": "user", "content": question}],
         )
-        out = response.choices[0].message.content or "No response from the model. Please try again"
-        return out
+        content = response.choices[0].message.content or "No response from the model. Please try again"
+        messages = split(content)
+        messages = add_number(messages)
+        messages = add_question(messages, question)
+        return messages
     except Exception as e:
-        return f"Error: {e}"
+        return split(f"Error: {e}")
+
+
+def split(answer: str) -> list[str]:
+    """
+    Split the answer into a list of smaller strings so that
+    each element is less than MAX_CHARACTERS characters.
+    Full sentences are preserved.
+    """
+    limit = MAX_CHARACTERS - RESERVED_SPACE - QUESTION_CUT_OFF_LENGTH
+    messages = []
+    answer = answer.strip()
+
+    while len(answer) > limit:
+        last_period = answer[:limit].rfind(".")
+        if last_period == -1:
+            last_period = answer[:limit].rfind(" ")
+        messages.append(answer[: last_period + 1])
+        answer = answer[last_period + 1 :]
+
+    messages.append(answer)
+
+    return messages
+
+
+def add_question(messages: list[str], questions: str) -> list[str]:
+    """
+    Add the asked question to the beginning of each message.
+    """
+    return [(f"Q: {questions[:QUESTION_CUT_OFF_LENGTH]}\n"
+             f"A: {message}") for message in messages]
+
+
+def add_number(messages: list[str]) -> list[str]:
+    """
+    Add the number to the beginning of each message. E.g. `(1/4)`
+    Do nothing if the length of `messages` is 1.
+    """
+    if len(messages) == 1:
+        return messages
+
+    for i, message in enumerate(messages):
+        message = message.strip()
+        messages[i] = f"({i+1}/{len(messages)}) {message}"
+
+    return messages