feat: Add mock llm api for unit tests that only require static response (#1939) #1943

Open · wants to merge 3 commits into main
336 changes: 183 additions & 153 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -78,12 +78,13 @@ composio-langchain = "^0.5.28"
 composio-core = "^0.5.34"
 alembic = "^1.13.3"
 pyhumps = "^3.8.0"
+trustme = {version = "^1.2.0", optional = true}
 
 [tool.poetry.extras]
 #local = ["llama-index-embeddings-huggingface"]
 postgres = ["pgvector", "pg8000", "psycopg2-binary"]
 milvus = ["pymilvus"]
-dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort"]
+dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort", "trustme"]
 server = ["websockets", "fastapi", "uvicorn"]
 autogen = ["pyautogen"]
 qdrant = ["qdrant-client"]
44 changes: 44 additions & 0 deletions tests/conftest.py
@@ -1,5 +1,49 @@
import logging
import os
import threading
import time

import pytest
from fastapi.testclient import TestClient

from tests.helpers import mock_llm


def pytest_configure(config):
    logging.basicConfig(level=logging.DEBUG)


def pytest_addoption(parser):
    parser.addoption(
        "--llm-api",
        action="store",
        default="openai",
        help="backend options: openai or mock",
        choices=("openai", "mock"),
    )


@pytest.fixture(scope="module")
def llmopt(request):
    return request.config.getoption("--llm-api")


@pytest.fixture(scope="module")
def mock_llm_client(llmopt):
    if llmopt == "mock":
        print("Starting mock llm api server thread")
        print(__name__)
        # Daemon thread so the TLS server dies with the test process
        thread = threading.Thread(target=mock_llm.start_mock_llm_server, daemon=True)
        thread.start()
        time.sleep(5)  # give the server time to come up before tests hit it

        mock_llm_client = TestClient(mock_llm.app)
        yield mock_llm_client
    else:
        yield None

    # Cleanup ssl cert override
    if llmopt == "mock":
        del os.environ["REQUESTS_CA_BUNDLE"]
        os.remove(mock_llm.DEFAULT_MOCK_LLM_SSL_CERT_PATH)
        os.rmdir(mock_llm.DEFAULT_MOCK_LLM_SSL_CERT_PATH.split("/")[0])
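
For illustration (not part of this diff), a hypothetical test shows how the fixture behaves under each backend: when the suite is run with `pytest --llm-api mock`, the fixture yields a TestClient bound to the mock app; with the default `--llm-api openai` it yields None.

# Hypothetical test, assuming the mock_llm_client fixture above; the test
# name and assertions are illustrative, not part of this PR.
def test_mock_configs_available(mock_llm_client):
    if mock_llm_client is None:
        return  # suite is running against the real OpenAI backend
    response = mock_llm_client.get("/configs")
    assert response.status_code == 200
    assert {"llm", "embedding"} <= response.json().keys()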
120 changes: 120 additions & 0 deletions tests/helpers/mock_llm.py
@@ -0,0 +1,120 @@
import os
import secrets
import string
import time
from typing import Optional

import trustme
import uvicorn
from fastapi import FastAPI

from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig

DEFAULT_MOCK_LLM_API_HOST = "localhost"
DEFAULT_MOCK_LLM_API_PORT = 8000
DEFAULT_MOCK_LLM_SSL_CERT_PATH = "certs/ca_cert.pem"


app = FastAPI()


@app.post("/v1/chat/completions")
async def user_message():
    # Static OpenAI-style chat completion whose single choice is a
    # send_message tool call, so agent loops expecting tool calls succeed.
    response = {
        "id": "chatcmpl-" + generate_mock_id(28),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": "memgpt-openai",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "call_" + generate_mock_id(24),
                            "type": "function",
                            "function": {
                                "name": "send_message",
                                "arguments": '{"message":"Hello! It\'s great to meet you! How are you doing today?","inner_thoughts":"User has greeted me. Time to establish a connection and gauge their mood."}',
                            },
                        }
                    ],
                    "refusal": None,
                },
                "logprobs": None,
                "finish_reason": "tool_calls",
            }
        ],
        "usage": {
            "prompt_tokens": 2370,
            "completion_tokens": 48,
            "total_tokens": 2418,
            "prompt_tokens_details": {"cached_tokens": 0},
            "completion_tokens_details": {"reasoning_tokens": 0},
        },
        "system_fingerprint": "fp_" + generate_mock_id(10),
    }
    return response


@app.get("/v1/embeddings")
async def message():
    # Stub endpoint; embeddings are not mocked yet.
    pass


@app.get("/configs")
def get_config():
    return {
        "llm": get_llm_config(),
        "embedding": get_embedding_config(),
    }


def get_llm_config():
    return LLMConfig(
        model="memgpt-openai",
        model_endpoint_type="openai",
        model_endpoint="https://localhost:8000/v1",
        model_wrapper=None,
        context_window=8192,
    )


def get_embedding_config():
    return EmbeddingConfig(
        embedding_model="text-embedding-ada-002",
        embedding_endpoint_type="openai",
        embedding_endpoint="https://localhost:8000/v1",
        embedding_dim=1536,
        embedding_chunk_size=300,
    )


def generate_mock_id(length: int):
    # Random alphanumeric string, mimicking OpenAI-style ids
    possible_characters = string.ascii_letters + string.digits
    return "".join(secrets.choice(possible_characters) for _ in range(length))


def start_mock_llm_server(
    port: Optional[int] = DEFAULT_MOCK_LLM_API_PORT,
    host: Optional[str] = DEFAULT_MOCK_LLM_API_HOST,
):
    # Generate a throwaway CA, write it where requests will trust it, and
    # serve the app over TLS with a cert issued for the mock host
    ca = trustme.CA()
    os.makedirs(DEFAULT_MOCK_LLM_SSL_CERT_PATH.split("/")[0], exist_ok=True)
    ca.cert_pem.write_to_path(DEFAULT_MOCK_LLM_SSL_CERT_PATH)
    os.environ["REQUESTS_CA_BUNDLE"] = DEFAULT_MOCK_LLM_SSL_CERT_PATH

    cert = ca.issue_cert(host)
    with cert.cert_chain_pems[0].tempfile() as cert_path:
        with cert.private_key_pem.tempfile() as key_path:
            print(f"Running: uvicorn server:mock_llm_app --host {host} --port {port}")
            uvicorn.run(
                app,
                host=host,
                port=port,
                ssl_keyfile=key_path,
                ssl_certfile=cert_path,
            )
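
As a quick sketch (not part of the diff), the canned completion can be exercised in-process with FastAPI's TestClient, so no uvicorn or TLS setup is needed:

# Illustrative sketch: hit the mock app in-process, bypassing uvicorn and TLS.
from fastapi.testclient import TestClient

from tests.helpers import mock_llm

client = TestClient(mock_llm.app)
body = client.post("/v1/chat/completions").json()
assert body["object"] == "chat.completion"
# The single choice always carries a send_message tool call.
call = body["choices"][0]["message"]["tool_calls"][0]
assert call["function"]["name"] == "send_message"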
21 changes: 20 additions & 1 deletion tests/helpers/utils.py
@@ -1,6 +1,10 @@
-from typing import Union
+from typing import Optional, Union
+
+from fastapi.testclient import TestClient
 
 from letta import LocalClient, RESTClient
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.llm_config import LLMConfig
 
 
 def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str):
@@ -9,3 +13,18 @@ def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str):
         if agent_state.name == agent_uuid:
             client.delete_agent(agent_id=agent_state.id)
             print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}")
+
+
+def set_default_configs(letta_client: Union[LocalClient, RESTClient], mock_llm_client: Optional[TestClient] = None):
+    # Pull configs from the mock server when it is available; otherwise fall
+    # back to the standard OpenAI defaults.
+    configs_response = mock_llm_client.get("/configs") if mock_llm_client else None
+    if configs_response and configs_response.status_code == 200:
+        configs = configs_response.json()
+        llm_config = LLMConfig.parse_obj(configs["llm"])
+        embedding_config = EmbeddingConfig.parse_obj(configs["embedding"])
+    else:
+        llm_config = LLMConfig.default_config("gpt-4")
+        embedding_config = EmbeddingConfig.default_config(provider="openai")
+
+    letta_client.set_default_llm_config(llm_config)
+    letta_client.set_default_embedding_config(embedding_config)
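
A hypothetical call site (illustration only) mirrors how test_client.py uses this helper: pass the mock client when one exists, or None to fall back to the gpt-4/OpenAI defaults.

# Illustration only, assuming a letta client created the usual way.
from letta import create_client
from tests.helpers.utils import set_default_configs

letta_client = create_client()
set_default_configs(letta_client, mock_llm_client=None)  # None -> gpt-4/OpenAI defaults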
6 changes: 3 additions & 3 deletions tests/test_client.py
@@ -19,6 +19,7 @@
 from letta.schemas.message import Message
 from letta.schemas.usage import LettaUsageStatistics
 from tests.helpers.client_helper import upload_file_using_client
+from tests.helpers.utils import set_default_configs
 
 # from tests.utils import create_config
 
@@ -48,7 +49,7 @@ def run_server():
     params=[{"server": True}],  # whether to use REST API server
     scope="module",
 )
-def client(request):
+def client(request, mock_llm_client):
     if request.param["server"]:
         # get URL from environment
         server_url = os.getenv("LETTA_SERVER_URL")
@@ -67,8 +68,7 @@ def client(request):
         server_url = None
         client = create_client()
 
-    client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
-    client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
+    set_default_configs(client, mock_llm_client)
     yield client