diff --git a/.env.example b/.env.example index 2aa516d..b03d93a 100644 --- a/.env.example +++ b/.env.example @@ -2,6 +2,6 @@ AI_SERVER_URL="http://localhost:8000" # Discord Bot -DISCORD_BOT_TOKEN="EXAMPLETokenTTTTT.Grz0fK.IUi1IONPcLHZsnyjROTt8lR2fuYCeK4grzRoSQ" +DISCORD_BOT_TOKEN="EXAMPLETokennnnnn.Gro00t.aaaaaBBBBBccccc11111DDDDDDD2222eeeeeeee" # THIS SHOULD BE ASSIGNED TO THE BOT'S `scopes` WHEN DEPLOYED DISCORD_GUILD_ID="99999999999999999999999" diff --git a/src/discord_bot/llm.py b/src/discord_bot/llm.py index 12b3acb..9725993 100644 --- a/src/discord_bot/llm.py +++ b/src/discord_bot/llm.py @@ -7,17 +7,14 @@ import os import openai import re -from utils import pdf - -MAX_CHARACTERS = 2000 -QUESTION_CUT_OFF_LENGTH = 150 -RESERVED_SPACE = 50 # for other additional strings. E.g. number `(1/4)`, `Q: `, `A: `, etc. +from utils import llm_response, pdf async def answer_question(model: str, question: str, server_url: str, attach_question_to_message: bool = True) -> list[str]: """ Calls the LLM model with the specified question and server URL to get the answer. """ + messages: list[str] try: client = openai.AsyncOpenAI(base_url=server_url, api_key="FAKE") response = await client.chat.completions.create( @@ -25,30 +22,36 @@ async def answer_question(model: str, question: str, server_url: str, attach_que messages=[{"role": "user", "content": question}], ) content = response.choices[0].message.content or "No response from the model. 
Please try again" - messages = split(content) - messages = add_number(messages) + messages = llm_response.split(content) + messages = llm_response.add_number(messages) if attach_question_to_message: - messages = add_question(messages, question) + messages = llm_response.add_question(messages, question) return messages except Exception as e: - return split(f"Error in calling the LLM: {e}") + messages = llm_response.split(f"Error in calling the LLM: {e}") + return messages async def review_resume(model: str, url: str, server_url: str) -> list[str]: - # validate url structure, must have leading "http[s]?" and a domain name (e.g. "example.com"=`\w+\.\w+`) - if not re.search(r"http[s]?://\w+\.\w+", url): - return split("Invalid URL. Please provide a valid URL of your resume.") - + messages: list[str] output_path = "" + try: + # validate url structure, must have leading "http[s]?" and a domain name (e.g. "example.com"=`\w+\.\w+`) + if not re.search(r"http[s]?://\w+\.\w+", url): + raise Exception("Invalid URL. Please provide a valid URL of your resume.") + output_path = pdf.download(url) if not pdf.validate_pdf_format(output_path): - return split("Error: Invalid PDF format. Please provide a valid PDF file.") + raise Exception("The given file is not in PDF format.") + text = pdf.parse_to_text(output_path) + except Exception as e: - return split(f"Error in processing PDF: {e}") + messages = llm_response.split(f"Error: {e}") + return messages finally: if os.path.exists(output_path): os.remove(output_path) @@ -62,47 +65,3 @@ async def review_resume(model: str, url: str, server_url: str) -> list[str]: messages = await answer_question(model, question, server_url, attach_question_to_message=False) return messages - - -def split(answer: str) -> list[str]: - """ - Split the answer into a list of smaller strings so that - each element is less than MAX_CHARACTERS characters. - Full sentences are preserved. 
"""Helpers that shape raw LLM output into Discord-sized message chunks."""

DISCORD_MESSAGE_MAX_CHARACTERS = 2000
QUESTION_CUT_OFF_LENGTH = 150
RESERVED_LENGTH = 50  # for other additional strings. E.g. number `(1/4)`, `Q: `, `A: `, etc.


def split(answer: str) -> list[str]:
    """
    Split the answer into a list of smaller strings so that
    each element fits within discord's message length limit.

    Each chunk is at most
    ``DISCORD_MESSAGE_MAX_CHARACTERS - RESERVED_LENGTH - QUESTION_CUT_OFF_LENGTH``
    characters long, which leaves room for the prefixes added by
    `add_number` and `add_question`.

    Cut points are chosen in this order of preference:
      1. right after the last ``.`` inside the window (keeps full sentences),
      2. right after the last space inside the window (keeps full words),
      3. a hard cut at the window size when neither exists.

    Args:
        answer: The full text to split. Leading/trailing whitespace is stripped.

    Returns:
        The chunks in order. Always contains at least one element
        (``[""]`` for an empty or whitespace-only answer).
    """
    limit = DISCORD_MESSAGE_MAX_CHARACTERS - RESERVED_LENGTH - QUESTION_CUT_OFF_LENGTH
    messages: list[str] = []
    answer = answer.strip()

    while len(answer) > limit:
        window = answer[:limit]
        cut = window.rfind(".")
        if cut == -1:
            cut = window.rfind(" ")
        if cut == -1:
            # No sentence or word boundary in the window (e.g. a long URL or
            # unspaced text). Without this fallback the slice below would be
            # empty and the loop would never make progress.
            cut = limit - 1
        messages.append(answer[: cut + 1])
        answer = answer[cut + 1 :]

    messages.append(answer)
    return messages


def add_question(messages: list[str], questions: str) -> list[str]:
    """
    Add the asked question to the beginning of each message.

    The question is truncated to ``QUESTION_CUT_OFF_LENGTH`` characters so the
    prefixed message stays within the budget reserved by `split`.

    Args:
        messages: The answer chunks.
        questions: The question text that produced the answer.

    Returns:
        A new list where every chunk is formatted as ``Q: <question>\\nA: <chunk>``.
    """
    output: list[str] = [(f"Q: {questions[:QUESTION_CUT_OFF_LENGTH]}\n" + f"A: {message}") for message in messages]
    return output


def add_number(messages: list[str]) -> list[str]:
    """
    Add the number to the beginning of each message. E.g. `(1/4)`
    Do nothing if the length of `messages` is 1.

    Note that `messages` is updated in place and the same list object is
    returned; each numbered chunk is also stripped of surrounding whitespace.

    Args:
        messages: The answer chunks, in order.

    Returns:
        The same list, with ``(i/total) `` prepended to every element when
        there is more than one element.
    """
    if len(messages) == 1:
        return messages

    for i, message in enumerate(messages):
        message = message.strip()
        messages[i] = f"({i+1}/{len(messages)}) {message}"

    return messages
- - response = await llm.answer_question(model, prompt, AI_SERVER_URL) + response = await llm.answer_question(MODEL, simple_prompt, AI_SERVER_URL) assert not response[0].startswith("Error") @pytest.mark.asyncio async def test_answer_question__invalid_model() -> None: - model = "not-a-gpt" - prompt = "Hello, world!" - - response = await llm.answer_question(model, prompt, AI_SERVER_URL) + invalid_model = "not-a-gpt" + response = await llm.answer_question(invalid_model, simple_prompt, AI_SERVER_URL) assert response[0].startswith("Error") @pytest.mark.asyncio async def test_answer_concurrent_question__should_be_at_the_same_time() -> None: - model = "tinydolphin" - prompt = "Respond shortly: hello" n_models = 2 # Get the average time for generating a character in a single run start = time.time() - out_single = await llm.answer_question(model, prompt, AI_SERVER_URL) + out_single = await llm.answer_question(MODEL, simple_prompt, AI_SERVER_URL) average_single_time = (time.time() - start) / len(out_single) # Get the average time for generating a character when running n_models concurrently start = time.time() - out_concurrent = await _concurrent_call(model, n_models, prompt, AI_SERVER_URL) + out_concurrent = await _concurrent_call(MODEL, n_models, simple_prompt, AI_SERVER_URL) average_concurrent_time = (time.time() - start) / sum([len(x) for x in out_concurrent]) assert ( @@ -73,7 +67,7 @@ async def test_review_resume__valid_url() -> None: for url in valid_urls: response = await llm.review_resume(MODEL, url, AI_SERVER_URL) - assert not response[0].startswith("Invalid URL") + assert not response[0].startswith("Error: Invalid URL") @pytest.mark.asyncio @@ -82,7 +76,7 @@ async def test_review_resume__invalid_url() -> None: for url in invalid_urls: response = await llm.review_resume(MODEL, url, AI_SERVER_URL) - assert response[0].startswith("Invalid URL") + assert response[0].startswith("Error: Invalid URL") @pytest.mark.asyncio