Add functional tests for Discord bot #37

Merged · 14 commits · Mar 6, 2024

Changes from all commits

18 changes: 18 additions & 0 deletions .github/actions/setup-ai/action.yml
@@ -0,0 +1,18 @@
name: setup-ai-server
description: "Setup AI server for testing"

runs:
  using: composite
  steps:
    - name: Install ollama
      shell: bash
      run: curl -fsSL https://ollama.com/install.sh | sh

    - name: Pull ollama model
      if: steps.cache-ollama.outputs.cache-hit != 'true'
      shell: bash
      run: ollama pull tinydolphin

    - name: Start litellm
      shell: bash
      run: poetry run litellm --config src/litellm/proxy_config_for_testing.yaml &
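
With this action in place, the workflow below talks to a local OpenAI-compatible endpoint. As a quick sanity check that the proxy came up, one could list its models; a minimal sketch, assuming the proxy listens on http://localhost:8000 (the URL the tests in this PR use) and exposes the OpenAI-compatible models route:

```python
import asyncio
import openai


async def main() -> None:
    # The API key is required by the client but ignored by the local litellm proxy.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000", api_key="FAKE")
    models = await client.models.list()
    print([m.id for m in models.data])  # expect "tinydolphin" to appear


asyncio.run(main())
```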
5 changes: 4 additions & 1 deletion .github/workflows/test-and-lint.yml
@@ -39,8 +39,11 @@ jobs:
       - name: Set up Python
         uses: ./.github/actions/setup-python
 
+      - name: Set up AI server
+        uses: ./.github/actions/setup-ai
+
       - name: Run pytest
-        run: poetry run pytest
+        run: poetry run pytest tests/
 
   typecheck:
     runs-on: ubuntu-latest
85 changes: 59 additions & 26 deletions poetry.lock

Some generated files are not rendered by default.

8 changes: 8 additions & 0 deletions pyproject.toml
@@ -4,6 +4,7 @@ version = "0.1.0"
 description = ""
 authors = ["VAIT"]
 readme = "README.md"
+packages = [{include = "discord_bot", from = "src"}]
 
 [tool.poetry.dependencies]
 python = "^3.12"
@@ -17,6 +18,8 @@ pre-commit = "^3.6.1"
 ruff = "^0.2.1"
 pytest = "^8.0.0"
 mypy = "^1.8.0"
+pytest-asyncio = "^0.23.5"
+
 
 [tool.ruff]
 # Exclude a variety of commonly ignored directories.
@@ -83,6 +86,11 @@ show_error_codes = true
 warn_return_any = true
 warn_unused_ignores = true
 
+# Prevent mypy from raising missing-import errors for our own packages
+[[tool.mypy.overrides]]
+module = ["discord_bot.*", "src.discord_bot.*"]
+ignore_missing_imports = true
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
4 changes: 2 additions & 2 deletions src/discord_bot/bot.py
@@ -1,7 +1,7 @@
 import os
 import interactions
 import dotenv
-import llm
+from discord_bot.llm import answer_question
 
 dotenv.load_dotenv()
 
@@ -41,7 +41,7 @@ async def ask_model(ctx: interactions.SlashContext, model: str = "", prompt: str
 
     await ctx.defer()
 
-    response = await llm.answer_question(model, prompt, AI_SERVER_URL)
+    response = await answer_question(model, prompt, AI_SERVER_URL)
     await ctx.send(response)
 
18 changes: 10 additions & 8 deletions src/discord_bot/llm.py
@@ -8,12 +8,14 @@
 
 
 async def answer_question(model: str, question: str, server_url: str) -> str:
-    client = openai.AsyncOpenAI(base_url=server_url, api_key="FAKE")
-    response = await client.chat.completions.create(
-        model=model,
-        messages=[{"role": "user", "content": question}],
-    )
-
-    out = response.choices[0].message.content or "No response from the model"
-    return out
+    try:
+        client = openai.AsyncOpenAI(base_url=server_url, api_key="FAKE")
+        response = await client.chat.completions.create(
+            model=model,
+            messages=[{"role": "user", "content": question}],
+        )
+        out = response.choices[0].message.content or "No response from the model"
+
+        return out
+    except Exception as e:
+        return f"Error: {e}"
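
As a usage sketch, the new error-wrapped `answer_question` can be exercised directly against a locally running proxy (the URL mirrors `AI_SERVER_URL` from the tests below; adjust it to wherever your proxy listens):

```python
import asyncio

from discord_bot.llm import answer_question


async def main() -> None:
    # Failures come back as an "Error: ..." string instead of raising.
    reply = await answer_question("tinydolphin", "Respond shortly: hello!", "http://localhost:8000")
    print(reply)


asyncio.run(main())
```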
11 changes: 11 additions & 0 deletions src/litellm/proxy_config_for_testing.yaml
@@ -0,0 +1,11 @@
model_list:
  - model_name: tinydolphin
    litellm_params:
      model: ollama/tinydolphin

litellm_params:
  drop_params: True

router_settings:
  num_retries: 2
  timeout: 60 # seconds
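
For context, `drop_params: True` asks litellm to strip request parameters the backend model doesn't support rather than rejecting the call. A hypothetical illustration (assuming `logit_bias` is unsupported by the ollama backend, so the proxy would drop it silently):

```python
import asyncio
import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000", api_key="FAKE")
    # logit_bias is an OpenAI-specific parameter; with drop_params the proxy
    # removes it from the request instead of returning an error.
    response = await client.chat.completions.create(
        model="tinydolphin",
        messages=[{"role": "user", "content": "hello"}],
        logit_bias={},
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```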
2 changes: 0 additions & 2 deletions tests/test_demo.py

This file was deleted.

59 changes: 59 additions & 0 deletions tests/test_discord_bot_llm.py
@@ -0,0 +1,59 @@
import asyncio
import time
import dotenv
import pytest
from typing import List
from discord_bot.llm import answer_question

AI_SERVER_URL = "http://localhost:8000"
dotenv.load_dotenv()


@pytest.mark.asyncio
async def test_answer_question__LLM_should_respond() -> None:
    model = "tinydolphin"
    prompt = "Respond shortly: hello!"

    response = await answer_question(model, prompt, AI_SERVER_URL)

    assert not response.startswith("Error")


@pytest.mark.asyncio
async def test_answer_question__invalid_model() -> None:
    model = "not-a-gpt"
    prompt = "Hello, world!"

    response = await answer_question(model, prompt, AI_SERVER_URL)

    assert response.startswith("Error")


@pytest.mark.asyncio
async def test_answer_concurrent_question__should_run_concurrently() -> None:
    model = "tinydolphin"
    prompt = "Respond shortly: hello"
    n_models = 2

    # Average time to generate one character in a single run
    start = time.time()
    out_single = await answer_question(model, prompt, AI_SERVER_URL)
    average_single_time = (time.time() - start) / len(out_single)

    # Average time to generate one character when running n_models requests concurrently
    start = time.time()
    out_concurrent = await _concurrent_call(model, n_models, prompt, AI_SERVER_URL)
    average_concurrent_time = (time.time() - start) / sum(len(x) for x in out_concurrent)

    assert (
        average_concurrent_time < average_single_time * n_models
    ), f"Running {n_models} requests concurrently should be faster per character than running them one after another"


async def _concurrent_call(model: str, n_models: int, prompt: str, server_url: str) -> List[str]:
    # Fan the same question out n_models times and await all answers together
    coroutines = []
    for _ in range(n_models):
        coroutines.append(answer_question(model, prompt, server_url))

    out = await asyncio.gather(*coroutines)
    return out
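
Since the project pins Python ^3.12, the same fan-out could also be written with `asyncio.TaskGroup`, which additionally cancels the remaining tasks if one fails; a minimal alternative sketch of the helper above:

```python
import asyncio
from typing import List

from discord_bot.llm import answer_question


async def _concurrent_call_tg(model: str, n_models: int, prompt: str, server_url: str) -> List[str]:
    # TaskGroup (Python 3.11+) awaits every task and propagates the first failure.
    async with asyncio.TaskGroup() as tg:
        tasks = [tg.create_task(answer_question(model, prompt, server_url)) for _ in range(n_models)]
    return [task.result() for task in tasks]
```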